Fix architecture issues from the to-do list

- Add documentation warnings for in-memory rate limiting and failed login attempts
- Consolidate duplicate health endpoints into api/health.py
- Fix CLI to use correct async rescan method names
- Update download.py and anime.py to use custom exception classes
- Add WebSocket room validation and rate limiting
This commit is contained in:
2025-12-15 14:23:41 +01:00
parent 54790a7ebb
commit 27108aacda
13 changed files with 303 additions and 255 deletions

View File

@@ -1,5 +1,6 @@
"""Command-line interface for the Aniworld anime download manager."""
import asyncio
import logging
import os
from typing import Optional, Sequence
@@ -179,8 +180,11 @@ class SeriesCLI:
# Rescan logic
# ------------------------------------------------------------------
def rescan(self) -> None:
"""Trigger a rescan of the anime directory using the core app."""
total_to_scan = self.series_app.SerieScanner.get_total_to_scan()
"""Trigger a rescan of the anime directory using the core app.
Uses the legacy file-based scan mode for CLI compatibility.
"""
total_to_scan = self.series_app.serie_scanner.get_total_to_scan()
total_to_scan = max(total_to_scan, 1)
self._progress = Progress()
@@ -190,17 +194,16 @@ class SeriesCLI:
total=total_to_scan,
)
result = self.series_app.ReScan(
callback=self._wrap_scan_callback(total_to_scan)
# Run async rescan in sync context with file-based mode
asyncio.run(
self.series_app.rescan(use_database=False)
)
self._progress = None
self._scan_task_id = None
if result.success:
print(result.message)
else:
print(f"Scan failed: {result.message}")
series_count = len(self.series_app.series_list)
print(f"Scan completed. Found {series_count} series with missing episodes.")
def _wrap_scan_callback(self, total: int):
"""Create a callback that updates the scan progress bar."""

View File

@@ -2,12 +2,18 @@ import logging
import warnings
from typing import Any, List, Optional
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi import APIRouter, Depends, status
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from src.core.entities.series import Serie
from src.server.database.service import AnimeSeriesService
from src.server.exceptions import (
BadRequestError,
NotFoundError,
ServerError,
ValidationError,
)
from src.server.services.anime_service import AnimeService, AnimeServiceError
from src.server.utils.dependencies import (
get_anime_service,
@@ -55,9 +61,8 @@ async def get_anime_status(
"series_count": series_count
}
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to get status: {str(exc)}",
raise ServerError(
message=f"Failed to get status: {str(exc)}"
) from exc
@@ -208,35 +213,30 @@ async def list_anime(
try:
page_num = int(page)
if page_num < 1:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Page number must be positive"
raise ValidationError(
message="Page number must be positive"
)
page = page_num
except (ValueError, TypeError):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Page must be a valid number"
raise ValidationError(
message="Page must be a valid number"
)
if per_page is not None:
try:
per_page_num = int(per_page)
if per_page_num < 1:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Per page must be positive"
raise ValidationError(
message="Per page must be positive"
)
if per_page_num > 1000:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Per page cannot exceed 1000"
raise ValidationError(
message="Per page cannot exceed 1000"
)
per_page = per_page_num
except (ValueError, TypeError):
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Per page must be a valid number"
raise ValidationError(
message="Per page must be a valid number"
)
# Validate sort_by parameter to prevent ORM injection
@@ -245,9 +245,8 @@ async def list_anime(
allowed_sort_fields = ["title", "id", "missing_episodes", "name"]
if sort_by not in allowed_sort_fields:
allowed = ", ".join(allowed_sort_fields)
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"Invalid sort_by parameter. Allowed: {allowed}"
raise ValidationError(
message=f"Invalid sort_by parameter. Allowed: {allowed}"
)
# Validate filter parameter
@@ -260,9 +259,8 @@ async def list_anime(
lower_filter = filter.lower()
for pattern in dangerous_patterns:
if pattern in lower_filter:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail="Invalid filter parameter"
raise ValidationError(
message="Invalid filter parameter"
)
try:
@@ -310,12 +308,11 @@ async def list_anime(
)
return summaries
except HTTPException:
except (ValidationError, BadRequestError, NotFoundError, ServerError):
raise
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to retrieve anime list",
raise ServerError(
message="Failed to retrieve anime list"
) from exc
@@ -346,14 +343,12 @@ async def trigger_rescan(
"message": "Rescan started successfully",
}
except AnimeServiceError as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Rescan failed: {str(e)}",
raise ServerError(
message=f"Rescan failed: {str(e)}"
) from e
except Exception as exc:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to start rescan",
raise ServerError(
message="Failed to start rescan"
) from exc

View File

@@ -4,9 +4,10 @@ This module provides REST API endpoints for managing the anime download queue,
including adding episodes, removing items, controlling queue processing, and
retrieving queue status and statistics.
"""
from fastapi import APIRouter, Depends, HTTPException, Path, status
from fastapi import APIRouter, Depends, Path, status
from fastapi.responses import JSONResponse
from src.server.exceptions import BadRequestError, NotFoundError, ServerError
from src.server.models.download import (
DownloadRequest,
QueueOperationRequest,
@@ -52,9 +53,8 @@ async def get_queue_status(
return response
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retrieve queue status: {str(e)}",
raise ServerError(
message=f"Failed to retrieve queue status: {str(e)}"
)
@@ -91,9 +91,8 @@ async def add_to_queue(
try:
# Validate request
if not request.episodes:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="At least one episode must be specified",
raise BadRequestError(
message="At least one episode must be specified"
)
# Add to queue
@@ -122,16 +121,12 @@ async def add_to_queue(
)
except DownloadServiceError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except HTTPException:
raise BadRequestError(message=str(e))
except (BadRequestError, NotFoundError, ServerError):
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to add episodes to queue: {str(e)}",
raise ServerError(
message=f"Failed to add episodes to queue: {str(e)}"
)
@@ -163,9 +158,8 @@ async def clear_completed(
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to clear completed items: {str(e)}",
raise ServerError(
message=f"Failed to clear completed items: {str(e)}"
)
@@ -197,9 +191,8 @@ async def clear_failed(
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to clear failed items: {str(e)}",
raise ServerError(
message=f"Failed to clear failed items: {str(e)}"
)
@@ -231,9 +224,8 @@ async def clear_pending(
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to clear pending items: {str(e)}",
raise ServerError(
message=f"Failed to clear pending items: {str(e)}"
)
@@ -262,22 +254,19 @@ async def remove_from_queue(
removed_ids = await download_service.remove_from_queue([item_id])
if not removed_ids:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Download item {item_id} not found in queue",
raise NotFoundError(
message=f"Download item {item_id} not found in queue",
resource_type="download_item",
resource_id=item_id
)
except DownloadServiceError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except HTTPException:
raise BadRequestError(message=str(e))
except (BadRequestError, NotFoundError, ServerError):
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to remove item from queue: {str(e)}",
raise ServerError(
message=f"Failed to remove item from queue: {str(e)}"
)
@@ -307,22 +296,18 @@ async def remove_multiple_from_queue(
)
if not removed_ids:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No matching items found in queue",
raise NotFoundError(
message="No matching items found in queue",
resource_type="download_items"
)
except DownloadServiceError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except HTTPException:
raise BadRequestError(message=str(e))
except (BadRequestError, NotFoundError, ServerError):
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to remove items from queue: {str(e)}",
raise ServerError(
message=f"Failed to remove items from queue: {str(e)}"
)
@@ -354,9 +339,8 @@ async def start_queue(
result = await download_service.start_queue_processing()
if result is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No pending downloads in queue",
raise BadRequestError(
message="No pending downloads in queue"
)
return {
@@ -365,16 +349,12 @@ async def start_queue(
}
except DownloadServiceError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=str(e),
)
except HTTPException:
raise BadRequestError(message=str(e))
except (BadRequestError, NotFoundError, ServerError):
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to start queue processing: {str(e)}",
raise ServerError(
message=f"Failed to start queue processing: {str(e)}"
)
@@ -408,9 +388,8 @@ async def stop_queue(
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to stop queue processing: {str(e)}",
raise ServerError(
message=f"Failed to stop queue processing: {str(e)}"
)
@@ -442,9 +421,8 @@ async def pause_queue(
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to pause queue processing: {str(e)}",
raise ServerError(
message=f"Failed to pause queue processing: {str(e)}"
)
@@ -480,9 +458,8 @@ async def reorder_queue(
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to reorder queue: {str(e)}",
raise ServerError(
message=f"Failed to reorder queue: {str(e)}"
)
@@ -522,7 +499,6 @@ async def retry_failed(
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to retry downloads: {str(e)}",
raise ServerError(
message=f"Failed to retry downloads: {str(e)}"
)

View File

@@ -23,6 +23,9 @@ class HealthStatus(BaseModel):
status: str
timestamp: str
version: str = "1.0.0"
service: str = "aniworld-api"
series_app_initialized: bool = False
anime_directory_configured: bool = False
class DatabaseHealth(BaseModel):
@@ -170,14 +173,24 @@ def get_system_metrics() -> SystemMetrics:
@router.get("", response_model=HealthStatus)
async def basic_health_check() -> HealthStatus:
"""Basic health check endpoint.
This endpoint does not depend on anime_directory configuration
and should always return 200 OK for basic health monitoring.
Includes service information for identification.
Returns:
HealthStatus: Simple health status with timestamp.
HealthStatus: Simple health status with timestamp and service info.
"""
from src.config.settings import settings
from src.server.utils.dependencies import _series_app
logger.debug("Basic health check requested")
return HealthStatus(
status="healthy",
timestamp=datetime.now().isoformat(),
service="aniworld-api",
series_app_initialized=_series_app is not None,
anime_directory_configured=bool(settings.anime_directory),
)

View File

@@ -13,8 +13,9 @@ in their data payload. The `folder` field is optional for display purposes.
"""
from __future__ import annotations
import time
import uuid
from typing import Optional
from typing import Dict, Optional, Set
import structlog
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
@@ -34,6 +35,73 @@ logger = structlog.get_logger(__name__)
router = APIRouter(prefix="/ws", tags=["websocket"])
# Valid room names - explicit allow-list for security
VALID_ROOMS: Set[str] = {
"downloads", # Download progress updates
"queue", # Queue status changes
"scan", # Scan progress updates
"system", # System notifications
"errors", # Error notifications
}
# Rate limiting configuration for WebSocket messages
WS_RATE_LIMIT_MESSAGES_PER_MINUTE = 60
WS_RATE_LIMIT_WINDOW_SECONDS = 60
# In-memory rate limiting for WebSocket connections
# WARNING: This resets on process restart. For production, consider Redis.
_ws_rate_limits: Dict[str, Dict[str, float]] = {}
def _check_ws_rate_limit(connection_id: str) -> bool:
    """Record one message for a connection and report its rate-limit status.

    Args:
        connection_id: Unique identifier for the WebSocket connection

    Returns:
        bool: True if within rate limit, False if exceeded
    """
    now = time.time()
    # Lazily create the per-connection record on first message.
    record = _ws_rate_limits.setdefault(
        connection_id,
        {"count": 0, "window_start": now},
    )

    # Start a fresh counting window once the previous one has elapsed.
    if now - record["window_start"] > WS_RATE_LIMIT_WINDOW_SECONDS:
        record["count"] = 0
        record["window_start"] = now

    record["count"] += 1
    return record["count"] <= WS_RATE_LIMIT_MESSAGES_PER_MINUTE
def _cleanup_ws_rate_limits(connection_id: str) -> None:
    """Drop the rate-limit bookkeeping for a disconnected connection.

    Args:
        connection_id: Unique identifier for the WebSocket connection
    """
    # Missing entries are ignored so cleanup is safe to call twice.
    if connection_id in _ws_rate_limits:
        del _ws_rate_limits[connection_id]
def _validate_room_name(room: str) -> bool:
    """Check a requested room name against the explicit allow-list.

    Args:
        room: Room name to validate

    Returns:
        bool: True if room is valid, False otherwise
    """
    is_allowed = room in VALID_ROOMS
    return is_allowed
@router.websocket("/connect")
async def websocket_endpoint(
@@ -130,6 +198,19 @@ async def websocket_endpoint(
# Receive message from client
data = await websocket.receive_json()
# Check rate limit
if not _check_ws_rate_limit(connection_id):
logger.warning(
"WebSocket rate limit exceeded",
connection_id=connection_id,
)
await ws_service.send_error(
connection_id,
"Rate limit exceeded. Please slow down.",
"RATE_LIMIT_EXCEEDED",
)
continue
# Parse client message
try:
client_msg = ClientMessage(**data)
@@ -149,9 +230,26 @@ async def websocket_endpoint(
# Handle room subscription requests
if client_msg.action in ["join", "leave"]:
try:
room_name = client_msg.data.get("room", "")
# Validate room name against allow-list
if not _validate_room_name(room_name):
logger.warning(
"Invalid room name requested",
connection_id=connection_id,
room=room_name,
)
await ws_service.send_error(
connection_id,
f"Invalid room name: {room_name}. "
f"Valid rooms: {', '.join(sorted(VALID_ROOMS))}",
"INVALID_ROOM",
)
continue
room_req = RoomSubscriptionRequest(
action=client_msg.action,
room=client_msg.data.get("room", ""),
room=room_name,
)
if room_req.action == "join":
@@ -241,7 +339,8 @@ async def websocket_endpoint(
error=str(e),
)
finally:
# Cleanup connection
# Cleanup connection and rate limit record
_cleanup_ws_rate_limits(connection_id)
await ws_service.disconnect(connection_id)
logger.info("WebSocket connection closed", connection_id=connection_id)
@@ -263,5 +362,6 @@ async def websocket_status(
"status": "operational",
"active_connections": connection_count,
"supported_message_types": [t.value for t in WebSocketMessageType],
"valid_rooms": sorted(VALID_ROOMS),
},
)

View File

@@ -1,27 +0,0 @@
"""
Health check controller for monitoring and status endpoints.
This module provides health check endpoints for application monitoring.
"""
from fastapi import APIRouter
from src.config.settings import settings
from src.server.utils.dependencies import _series_app
router = APIRouter(prefix="/health", tags=["health"])
@router.get("")
async def health_check():
"""Health check endpoint for monitoring.
This endpoint does not depend on anime_directory configuration
and should always return 200 OK for basic health monitoring.
"""
return {
"status": "healthy",
"service": "aniworld-api",
"version": "1.0.0",
"series_app_initialized": _series_app is not None,
"anime_directory_configured": bool(settings.anime_directory)
}

View File

@@ -144,6 +144,23 @@ class ConflictError(AniWorldAPIException):
)
class BadRequestError(AniWorldAPIException):
    """Exception raised for bad request (400) errors."""

    def __init__(
        self,
        message: str = "Bad request",
        details: Optional[Dict[str, Any]] = None,
    ):
        """Initialize bad request error.

        Args:
            message: Human-readable description of the problem.
            details: Optional structured context included in the response.
        """
        # Delegate to the shared base so the error envelope stays uniform.
        super().__init__(
            status_code=400,
            error_code="BAD_REQUEST",
            message=message,
            details=details,
        )
class RateLimitError(AniWorldAPIException):
"""Exception raised when rate limit is exceeded."""

View File

@@ -21,6 +21,7 @@ from src.server.api.anime import router as anime_router
from src.server.api.auth import router as auth_router
from src.server.api.config import router as config_router
from src.server.api.download import router as download_router
from src.server.api.health import router as health_router
from src.server.api.scheduler import router as scheduler_router
from src.server.api.websocket import router as websocket_router
from src.server.controllers.error_controller import (
@@ -29,7 +30,6 @@ from src.server.controllers.error_controller import (
)
# Import controllers
from src.server.controllers.health_controller import router as health_router
from src.server.controllers.page_controller import router as page_router
from src.server.middleware.auth import AuthMiddleware
from src.server.middleware.error_handler import register_exception_handlers

View File

@@ -8,6 +8,17 @@ Responsibilities:
This middleware is intentionally lightweight and synchronous.
For production use consider a distributed rate limiter (Redis) and
a proper token revocation store.
WARNING - SINGLE PROCESS LIMITATION:
Rate limiting state is stored in memory dictionaries which RESET when
the process restarts. This means:
- Attackers can bypass rate limits by triggering a process restart
- Rate limits are not shared across multiple workers/processes
For production deployments, consider:
- Using Redis-backed rate limiting (e.g., slowapi with Redis)
- Running behind a reverse proxy with rate limiting (nginx, HAProxy)
- Using a dedicated rate limiting service
"""
from __future__ import annotations

View File

@@ -15,6 +15,7 @@ from src.server.exceptions import (
AniWorldAPIException,
AuthenticationError,
AuthorizationError,
BadRequestError,
ConflictError,
NotFoundError,
RateLimitError,
@@ -127,6 +128,26 @@ def register_exception_handlers(app: FastAPI) -> None:
),
)
@app.exception_handler(BadRequestError)
async def bad_request_error_handler(
    request: Request, exc: BadRequestError
) -> JSONResponse:
    """Handle bad request errors (400)."""
    # Client errors are expected traffic, so log at info rather than error.
    logger.info(
        f"Bad request error: {exc.message}",
        extra={"details": exc.details, "path": str(request.url.path)},
    )
    body = create_error_response(
        status_code=exc.status_code,
        error=exc.error_code,
        message=exc.message,
        details=exc.details,
        request_id=getattr(request.state, "request_id", None),
    )
    return JSONResponse(status_code=exc.status_code, content=body)
@app.exception_handler(NotFoundError)
async def not_found_error_handler(
request: Request, exc: NotFoundError

View File

@@ -42,6 +42,17 @@ class AuthService:
config persistence should be used (not implemented here).
- Lockout policy is kept in-memory and will reset when the process
restarts. This is acceptable for single-process deployments.
WARNING - SINGLE PROCESS LIMITATION:
Failed login attempts are stored in memory dictionaries which RESET
when the process restarts. This means:
- Attackers can bypass lockouts by triggering a process restart
- Lockout state is not shared across multiple workers/processes
For production deployments, consider:
- Storing failed attempts in database with TTL-based expiration
- Using Redis for distributed lockout state
- Implementing account-based (not just IP-based) lockout tracking
"""
def __init__(self) -> None: