Fix architecture issues from todolist
- Add documentation warnings for in-memory rate limiting and failed login attempts - Consolidate duplicate health endpoints into api/health.py - Fix CLI to use correct async rescan method names - Update download.py and anime.py to use custom exception classes - Add WebSocket room validation and rate limiting
This commit is contained in:
@@ -2,12 +2,18 @@ import logging
|
||||
import warnings
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi import APIRouter, Depends, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from src.core.entities.series import Serie
|
||||
from src.server.database.service import AnimeSeriesService
|
||||
from src.server.exceptions import (
|
||||
BadRequestError,
|
||||
NotFoundError,
|
||||
ServerError,
|
||||
ValidationError,
|
||||
)
|
||||
from src.server.services.anime_service import AnimeService, AnimeServiceError
|
||||
from src.server.utils.dependencies import (
|
||||
get_anime_service,
|
||||
@@ -55,9 +61,8 @@ async def get_anime_status(
|
||||
"series_count": series_count
|
||||
}
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to get status: {str(exc)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to get status: {str(exc)}"
|
||||
) from exc
|
||||
|
||||
|
||||
@@ -208,35 +213,30 @@ async def list_anime(
|
||||
try:
|
||||
page_num = int(page)
|
||||
if page_num < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Page number must be positive"
|
||||
raise ValidationError(
|
||||
message="Page number must be positive"
|
||||
)
|
||||
page = page_num
|
||||
except (ValueError, TypeError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Page must be a valid number"
|
||||
raise ValidationError(
|
||||
message="Page must be a valid number"
|
||||
)
|
||||
|
||||
if per_page is not None:
|
||||
try:
|
||||
per_page_num = int(per_page)
|
||||
if per_page_num < 1:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Per page must be positive"
|
||||
raise ValidationError(
|
||||
message="Per page must be positive"
|
||||
)
|
||||
if per_page_num > 1000:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Per page cannot exceed 1000"
|
||||
raise ValidationError(
|
||||
message="Per page cannot exceed 1000"
|
||||
)
|
||||
per_page = per_page_num
|
||||
except (ValueError, TypeError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Per page must be a valid number"
|
||||
raise ValidationError(
|
||||
message="Per page must be a valid number"
|
||||
)
|
||||
|
||||
# Validate sort_by parameter to prevent ORM injection
|
||||
@@ -245,9 +245,8 @@ async def list_anime(
|
||||
allowed_sort_fields = ["title", "id", "missing_episodes", "name"]
|
||||
if sort_by not in allowed_sort_fields:
|
||||
allowed = ", ".join(allowed_sort_fields)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"Invalid sort_by parameter. Allowed: {allowed}"
|
||||
raise ValidationError(
|
||||
message=f"Invalid sort_by parameter. Allowed: {allowed}"
|
||||
)
|
||||
|
||||
# Validate filter parameter
|
||||
@@ -260,9 +259,8 @@ async def list_anime(
|
||||
lower_filter = filter.lower()
|
||||
for pattern in dangerous_patterns:
|
||||
if pattern in lower_filter:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail="Invalid filter parameter"
|
||||
raise ValidationError(
|
||||
message="Invalid filter parameter"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -310,12 +308,11 @@ async def list_anime(
|
||||
)
|
||||
|
||||
return summaries
|
||||
except HTTPException:
|
||||
except (ValidationError, BadRequestError, NotFoundError, ServerError):
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to retrieve anime list",
|
||||
raise ServerError(
|
||||
message="Failed to retrieve anime list"
|
||||
) from exc
|
||||
|
||||
|
||||
@@ -346,14 +343,12 @@ async def trigger_rescan(
|
||||
"message": "Rescan started successfully",
|
||||
}
|
||||
except AnimeServiceError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Rescan failed: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Rescan failed: {str(e)}"
|
||||
) from e
|
||||
except Exception as exc:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="Failed to start rescan",
|
||||
raise ServerError(
|
||||
message="Failed to start rescan"
|
||||
) from exc
|
||||
|
||||
|
||||
|
||||
@@ -4,9 +4,10 @@ This module provides REST API endpoints for managing the anime download queue,
|
||||
including adding episodes, removing items, controlling queue processing, and
|
||||
retrieving queue status and statistics.
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, status
|
||||
from fastapi import APIRouter, Depends, Path, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from src.server.exceptions import BadRequestError, NotFoundError, ServerError
|
||||
from src.server.models.download import (
|
||||
DownloadRequest,
|
||||
QueueOperationRequest,
|
||||
@@ -52,9 +53,8 @@ async def get_queue_status(
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retrieve queue status: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to retrieve queue status: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -91,9 +91,8 @@ async def add_to_queue(
|
||||
try:
|
||||
# Validate request
|
||||
if not request.episodes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="At least one episode must be specified",
|
||||
raise BadRequestError(
|
||||
message="At least one episode must be specified"
|
||||
)
|
||||
|
||||
# Add to queue
|
||||
@@ -122,16 +121,12 @@ async def add_to_queue(
|
||||
)
|
||||
|
||||
except DownloadServiceError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except HTTPException:
|
||||
raise BadRequestError(message=str(e))
|
||||
except (BadRequestError, NotFoundError, ServerError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to add episodes to queue: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to add episodes to queue: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -163,9 +158,8 @@ async def clear_completed(
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to clear completed items: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to clear completed items: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -197,9 +191,8 @@ async def clear_failed(
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to clear failed items: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to clear failed items: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -231,9 +224,8 @@ async def clear_pending(
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to clear pending items: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to clear pending items: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -262,22 +254,19 @@ async def remove_from_queue(
|
||||
removed_ids = await download_service.remove_from_queue([item_id])
|
||||
|
||||
if not removed_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Download item {item_id} not found in queue",
|
||||
raise NotFoundError(
|
||||
message=f"Download item {item_id} not found in queue",
|
||||
resource_type="download_item",
|
||||
resource_id=item_id
|
||||
)
|
||||
|
||||
except DownloadServiceError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except HTTPException:
|
||||
raise BadRequestError(message=str(e))
|
||||
except (BadRequestError, NotFoundError, ServerError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to remove item from queue: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to remove item from queue: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -307,22 +296,18 @@ async def remove_multiple_from_queue(
|
||||
)
|
||||
|
||||
if not removed_ids:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No matching items found in queue",
|
||||
raise NotFoundError(
|
||||
message="No matching items found in queue",
|
||||
resource_type="download_items"
|
||||
)
|
||||
|
||||
except DownloadServiceError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except HTTPException:
|
||||
raise BadRequestError(message=str(e))
|
||||
except (BadRequestError, NotFoundError, ServerError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to remove items from queue: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to remove items from queue: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -354,9 +339,8 @@ async def start_queue(
|
||||
result = await download_service.start_queue_processing()
|
||||
|
||||
if result is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="No pending downloads in queue",
|
||||
raise BadRequestError(
|
||||
message="No pending downloads in queue"
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -365,16 +349,12 @@ async def start_queue(
|
||||
}
|
||||
|
||||
except DownloadServiceError as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=str(e),
|
||||
)
|
||||
except HTTPException:
|
||||
raise BadRequestError(message=str(e))
|
||||
except (BadRequestError, NotFoundError, ServerError):
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to start queue processing: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to start queue processing: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -408,9 +388,8 @@ async def stop_queue(
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to stop queue processing: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to stop queue processing: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -442,9 +421,8 @@ async def pause_queue(
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to pause queue processing: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to pause queue processing: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -480,9 +458,8 @@ async def reorder_queue(
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to reorder queue: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to reorder queue: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@@ -522,7 +499,6 @@ async def retry_failed(
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Failed to retry downloads: {str(e)}",
|
||||
raise ServerError(
|
||||
message=f"Failed to retry downloads: {str(e)}"
|
||||
)
|
||||
|
||||
@@ -23,6 +23,9 @@ class HealthStatus(BaseModel):
|
||||
status: str
|
||||
timestamp: str
|
||||
version: str = "1.0.0"
|
||||
service: str = "aniworld-api"
|
||||
series_app_initialized: bool = False
|
||||
anime_directory_configured: bool = False
|
||||
|
||||
|
||||
class DatabaseHealth(BaseModel):
|
||||
@@ -170,14 +173,24 @@ def get_system_metrics() -> SystemMetrics:
|
||||
@router.get("", response_model=HealthStatus)
|
||||
async def basic_health_check() -> HealthStatus:
|
||||
"""Basic health check endpoint.
|
||||
|
||||
This endpoint does not depend on anime_directory configuration
|
||||
and should always return 200 OK for basic health monitoring.
|
||||
Includes service information for identification.
|
||||
|
||||
Returns:
|
||||
HealthStatus: Simple health status with timestamp.
|
||||
HealthStatus: Simple health status with timestamp and service info.
|
||||
"""
|
||||
from src.config.settings import settings
|
||||
from src.server.utils.dependencies import _series_app
|
||||
|
||||
logger.debug("Basic health check requested")
|
||||
return HealthStatus(
|
||||
status="healthy",
|
||||
timestamp=datetime.now().isoformat(),
|
||||
service="aniworld-api",
|
||||
series_app_initialized=_series_app is not None,
|
||||
anime_directory_configured=bool(settings.anime_directory),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,8 +13,9 @@ in their data payload. The `folder` field is optional for display purposes.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
|
||||
@@ -34,6 +35,73 @@ logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["websocket"])
|
||||
|
||||
# Valid room names - explicit allow-list for security
|
||||
VALID_ROOMS: Set[str] = {
|
||||
"downloads", # Download progress updates
|
||||
"queue", # Queue status changes
|
||||
"scan", # Scan progress updates
|
||||
"system", # System notifications
|
||||
"errors", # Error notifications
|
||||
}
|
||||
|
||||
# Rate limiting configuration for WebSocket messages
|
||||
WS_RATE_LIMIT_MESSAGES_PER_MINUTE = 60
|
||||
WS_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
|
||||
# In-memory rate limiting for WebSocket connections
|
||||
# WARNING: This resets on process restart. For production, consider Redis.
|
||||
_ws_rate_limits: Dict[str, Dict[str, float]] = {}
|
||||
|
||||
|
||||
def _check_ws_rate_limit(connection_id: str) -> bool:
|
||||
"""Check if a WebSocket connection has exceeded its rate limit.
|
||||
|
||||
Args:
|
||||
connection_id: Unique identifier for the WebSocket connection
|
||||
|
||||
Returns:
|
||||
bool: True if within rate limit, False if exceeded
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
if connection_id not in _ws_rate_limits:
|
||||
_ws_rate_limits[connection_id] = {
|
||||
"count": 0,
|
||||
"window_start": now,
|
||||
}
|
||||
|
||||
record = _ws_rate_limits[connection_id]
|
||||
|
||||
# Reset window if expired
|
||||
if now - record["window_start"] > WS_RATE_LIMIT_WINDOW_SECONDS:
|
||||
record["window_start"] = now
|
||||
record["count"] = 0
|
||||
|
||||
record["count"] += 1
|
||||
|
||||
return record["count"] <= WS_RATE_LIMIT_MESSAGES_PER_MINUTE
|
||||
|
||||
|
||||
def _cleanup_ws_rate_limits(connection_id: str) -> None:
|
||||
"""Remove rate limit record for a disconnected connection.
|
||||
|
||||
Args:
|
||||
connection_id: Unique identifier for the WebSocket connection
|
||||
"""
|
||||
_ws_rate_limits.pop(connection_id, None)
|
||||
|
||||
|
||||
def _validate_room_name(room: str) -> bool:
|
||||
"""Validate that a room name is in the allowed set.
|
||||
|
||||
Args:
|
||||
room: Room name to validate
|
||||
|
||||
Returns:
|
||||
bool: True if room is valid, False otherwise
|
||||
"""
|
||||
return room in VALID_ROOMS
|
||||
|
||||
|
||||
@router.websocket("/connect")
|
||||
async def websocket_endpoint(
|
||||
@@ -130,6 +198,19 @@ async def websocket_endpoint(
|
||||
# Receive message from client
|
||||
data = await websocket.receive_json()
|
||||
|
||||
# Check rate limit
|
||||
if not _check_ws_rate_limit(connection_id):
|
||||
logger.warning(
|
||||
"WebSocket rate limit exceeded",
|
||||
connection_id=connection_id,
|
||||
)
|
||||
await ws_service.send_error(
|
||||
connection_id,
|
||||
"Rate limit exceeded. Please slow down.",
|
||||
"RATE_LIMIT_EXCEEDED",
|
||||
)
|
||||
continue
|
||||
|
||||
# Parse client message
|
||||
try:
|
||||
client_msg = ClientMessage(**data)
|
||||
@@ -149,9 +230,26 @@ async def websocket_endpoint(
|
||||
# Handle room subscription requests
|
||||
if client_msg.action in ["join", "leave"]:
|
||||
try:
|
||||
room_name = client_msg.data.get("room", "")
|
||||
|
||||
# Validate room name against allow-list
|
||||
if not _validate_room_name(room_name):
|
||||
logger.warning(
|
||||
"Invalid room name requested",
|
||||
connection_id=connection_id,
|
||||
room=room_name,
|
||||
)
|
||||
await ws_service.send_error(
|
||||
connection_id,
|
||||
f"Invalid room name: {room_name}. "
|
||||
f"Valid rooms: {', '.join(sorted(VALID_ROOMS))}",
|
||||
"INVALID_ROOM",
|
||||
)
|
||||
continue
|
||||
|
||||
room_req = RoomSubscriptionRequest(
|
||||
action=client_msg.action,
|
||||
room=client_msg.data.get("room", ""),
|
||||
room=room_name,
|
||||
)
|
||||
|
||||
if room_req.action == "join":
|
||||
@@ -241,7 +339,8 @@ async def websocket_endpoint(
|
||||
error=str(e),
|
||||
)
|
||||
finally:
|
||||
# Cleanup connection
|
||||
# Cleanup connection and rate limit record
|
||||
_cleanup_ws_rate_limits(connection_id)
|
||||
await ws_service.disconnect(connection_id)
|
||||
logger.info("WebSocket connection closed", connection_id=connection_id)
|
||||
|
||||
@@ -263,5 +362,6 @@ async def websocket_status(
|
||||
"status": "operational",
|
||||
"active_connections": connection_count,
|
||||
"supported_message_types": [t.value for t in WebSocketMessageType],
|
||||
"valid_rooms": sorted(VALID_ROOMS),
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user