feat: migrate to Pydantic V2 and implement rate limiting middleware
- Migrate settings.py to Pydantic V2 (SettingsConfigDict, validation_alias) - Update config models to use @field_validator with @classmethod - Replace deprecated datetime.utcnow() with datetime.now(timezone.utc) - Migrate FastAPI app from @app.on_event to lifespan context manager - Implement comprehensive rate limiting middleware with: * Endpoint-specific rate limits (login: 5/min, register: 3/min) * IP-based and user-based tracking * Authenticated user multiplier (2x limits) * Bypass paths for health, docs, static, websocket endpoints * Rate limit headers in responses - Add 13 comprehensive tests for rate limiting (all passing) - Update instructions.md to mark completed tasks - Fix asyncio.create_task usage in anime_service.py All 714 tests passing. No deprecation warnings.
This commit is contained in:
@@ -5,6 +5,7 @@ This module provides the main FastAPI application with proper CORS
|
||||
configuration, middleware setup, static file serving, and Jinja2 template
|
||||
integration.
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
@@ -36,13 +37,71 @@ from src.server.middleware.error_handler import register_exception_handlers
|
||||
from src.server.services.progress_service import get_progress_service
|
||||
from src.server.services.websocket_service import get_websocket_service
|
||||
|
||||
# Initialize FastAPI app
|
||||
# Prefer storing application-wide singletons on FastAPI.state instead of
|
||||
# module-level globals. This makes testing and multi-instance hosting safer.
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Manage application lifespan (startup and shutdown)."""
|
||||
# Startup
|
||||
try:
|
||||
# Initialize SeriesApp with configured directory and store it on
|
||||
# application state so it can be injected via dependencies.
|
||||
if settings.anime_directory:
|
||||
app.state.series_app = SeriesApp(settings.anime_directory)
|
||||
else:
|
||||
# Log warning when anime directory is not configured
|
||||
print(
|
||||
"WARNING: ANIME_DIRECTORY not configured. "
|
||||
"Some features may be unavailable."
|
||||
)
|
||||
|
||||
# Initialize progress service with websocket callback
|
||||
progress_service = get_progress_service()
|
||||
ws_service = get_websocket_service()
|
||||
|
||||
async def broadcast_callback(
|
||||
message_type: str, data: dict, room: str
|
||||
) -> None:
|
||||
"""Broadcast progress updates via WebSocket."""
|
||||
message = {
|
||||
"type": message_type,
|
||||
"data": data,
|
||||
}
|
||||
await ws_service.manager.broadcast_to_room(message, room)
|
||||
|
||||
progress_service.set_broadcast_callback(broadcast_callback)
|
||||
|
||||
print("FastAPI application started successfully")
|
||||
except Exception as e:
|
||||
print(f"Error during startup: {e}")
|
||||
raise # Re-raise to prevent app from starting in broken state
|
||||
|
||||
# Yield control to the application
|
||||
yield
|
||||
|
||||
# Shutdown
|
||||
print("FastAPI application shutting down")
|
||||
|
||||
|
||||
def get_series_app() -> Optional[SeriesApp]:
|
||||
"""Dependency to retrieve the SeriesApp instance from application state.
|
||||
|
||||
Returns None when the application wasn't configured with an anime
|
||||
directory (for example during certain test runs).
|
||||
"""
|
||||
return getattr(app.state, "series_app", None)
|
||||
|
||||
|
||||
# Initialize FastAPI app with lifespan
|
||||
app = FastAPI(
|
||||
title="Aniworld Download Manager",
|
||||
description="Modern web interface for Aniworld anime download management",
|
||||
version="1.0.0",
|
||||
docs_url="/api/docs",
|
||||
redoc_url="/api/redoc"
|
||||
redoc_url="/api/redoc",
|
||||
lifespan=lifespan
|
||||
)
|
||||
|
||||
# Configure CORS using environment-driven configuration.
|
||||
@@ -79,61 +138,6 @@ app.include_router(websocket_router)
|
||||
# Register exception handlers
|
||||
register_exception_handlers(app)
|
||||
|
||||
# Prefer storing application-wide singletons on FastAPI.state instead of
|
||||
# module-level globals. This makes testing and multi-instance hosting safer.
|
||||
|
||||
|
||||
def get_series_app() -> Optional[SeriesApp]:
|
||||
"""Dependency to retrieve the SeriesApp instance from application state.
|
||||
|
||||
Returns None when the application wasn't configured with an anime
|
||||
directory (for example during certain test runs).
|
||||
"""
|
||||
return getattr(app.state, "series_app", None)
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup_event() -> None:
|
||||
"""Initialize application on startup."""
|
||||
try:
|
||||
# Initialize SeriesApp with configured directory and store it on
|
||||
# application state so it can be injected via dependencies.
|
||||
if settings.anime_directory:
|
||||
app.state.series_app = SeriesApp(settings.anime_directory)
|
||||
else:
|
||||
# Log warning when anime directory is not configured
|
||||
print(
|
||||
"WARNING: ANIME_DIRECTORY not configured. "
|
||||
"Some features may be unavailable."
|
||||
)
|
||||
|
||||
# Initialize progress service with websocket callback
|
||||
progress_service = get_progress_service()
|
||||
ws_service = get_websocket_service()
|
||||
|
||||
async def broadcast_callback(
|
||||
message_type: str, data: dict, room: str
|
||||
) -> None:
|
||||
"""Broadcast progress updates via WebSocket."""
|
||||
message = {
|
||||
"type": message_type,
|
||||
"data": data,
|
||||
}
|
||||
await ws_service.manager.broadcast_to_room(message, room)
|
||||
|
||||
progress_service.set_broadcast_callback(broadcast_callback)
|
||||
|
||||
print("FastAPI application started successfully")
|
||||
except Exception as e:
|
||||
print(f"Error during startup: {e}")
|
||||
raise # Re-raise to prevent app from starting in broken state
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
"""Cleanup on application shutdown."""
|
||||
print("FastAPI application shutting down")
|
||||
|
||||
|
||||
@app.exception_handler(404)
|
||||
async def handle_not_found(request: Request, exc: HTTPException):
|
||||
|
||||
331
src/server/middleware/rate_limit.py
Normal file
331
src/server/middleware/rate_limit.py
Normal file
@@ -0,0 +1,331 @@
|
||||
"""Rate limiting middleware for API endpoints.
|
||||
|
||||
This module provides comprehensive rate limiting with support for:
|
||||
- Endpoint-specific rate limits
|
||||
- IP-based limiting
|
||||
- User-based rate limiting
|
||||
- Bypass mechanisms for authenticated users
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
from fastapi import Request, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class RateLimitConfig:
|
||||
"""Configuration for rate limiting rules."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
requests_per_minute: int = 60,
|
||||
requests_per_hour: int = 1000,
|
||||
authenticated_multiplier: float = 2.0,
|
||||
):
|
||||
"""Initialize rate limit configuration.
|
||||
|
||||
Args:
|
||||
requests_per_minute: Max requests per minute for
|
||||
unauthenticated users
|
||||
requests_per_hour: Max requests per hour for
|
||||
unauthenticated users
|
||||
authenticated_multiplier: Multiplier for authenticated users
|
||||
"""
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.requests_per_hour = requests_per_hour
|
||||
self.authenticated_multiplier = authenticated_multiplier
|
||||
|
||||
|
||||
class RateLimitStore:
|
||||
"""In-memory store for rate limit tracking."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the rate limit store."""
|
||||
# Store format: {identifier: [(timestamp, count), ...]}
|
||||
self._minute_store: Dict[str, list] = defaultdict(list)
|
||||
self._hour_store: Dict[str, list] = defaultdict(list)
|
||||
|
||||
def check_limit(
|
||||
self,
|
||||
identifier: str,
|
||||
max_per_minute: int,
|
||||
max_per_hour: int,
|
||||
) -> Tuple[bool, Optional[int]]:
|
||||
"""Check if the identifier has exceeded rate limits.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier (IP or user ID)
|
||||
max_per_minute: Maximum requests allowed per minute
|
||||
max_per_hour: Maximum requests allowed per hour
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, retry_after_seconds)
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Clean up old entries
|
||||
self._cleanup_old_entries(identifier, current_time)
|
||||
|
||||
# Check minute limit
|
||||
minute_count = len(self._minute_store[identifier])
|
||||
if minute_count >= max_per_minute:
|
||||
# Calculate retry after time
|
||||
oldest_entry = self._minute_store[identifier][0]
|
||||
retry_after = int(60 - (current_time - oldest_entry))
|
||||
return False, max(retry_after, 1)
|
||||
|
||||
# Check hour limit
|
||||
hour_count = len(self._hour_store[identifier])
|
||||
if hour_count >= max_per_hour:
|
||||
# Calculate retry after time
|
||||
oldest_entry = self._hour_store[identifier][0]
|
||||
retry_after = int(3600 - (current_time - oldest_entry))
|
||||
return False, max(retry_after, 1)
|
||||
|
||||
return True, None
|
||||
|
||||
def record_request(self, identifier: str) -> None:
|
||||
"""Record a request for the identifier.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier (IP or user ID)
|
||||
"""
|
||||
current_time = time.time()
|
||||
self._minute_store[identifier].append(current_time)
|
||||
self._hour_store[identifier].append(current_time)
|
||||
|
||||
def get_remaining_requests(
|
||||
self, identifier: str, max_per_minute: int, max_per_hour: int
|
||||
) -> Tuple[int, int]:
|
||||
"""Get remaining requests for the identifier.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier
|
||||
max_per_minute: Maximum per minute
|
||||
max_per_hour: Maximum per hour
|
||||
|
||||
Returns:
|
||||
Tuple of (remaining_per_minute, remaining_per_hour)
|
||||
"""
|
||||
minute_used = len(self._minute_store.get(identifier, []))
|
||||
hour_used = len(self._hour_store.get(identifier, []))
|
||||
return (
|
||||
max(0, max_per_minute - minute_used),
|
||||
max(0, max_per_hour - hour_used)
|
||||
)
|
||||
|
||||
def _cleanup_old_entries(
|
||||
self, identifier: str, current_time: float
|
||||
) -> None:
|
||||
"""Remove entries older than the time windows.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier
|
||||
current_time: Current timestamp
|
||||
"""
|
||||
# Remove entries older than 1 minute
|
||||
minute_cutoff = current_time - 60
|
||||
self._minute_store[identifier] = [
|
||||
ts for ts in self._minute_store[identifier] if ts > minute_cutoff
|
||||
]
|
||||
|
||||
# Remove entries older than 1 hour
|
||||
hour_cutoff = current_time - 3600
|
||||
self._hour_store[identifier] = [
|
||||
ts for ts in self._hour_store[identifier] if ts > hour_cutoff
|
||||
]
|
||||
|
||||
# Clean up empty entries
|
||||
if not self._minute_store[identifier]:
|
||||
del self._minute_store[identifier]
|
||||
if not self._hour_store[identifier]:
|
||||
del self._hour_store[identifier]
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for API rate limiting."""
|
||||
|
||||
# Endpoint-specific rate limits (overrides defaults)
|
||||
ENDPOINT_LIMITS: Dict[str, RateLimitConfig] = {
|
||||
"/api/auth/login": RateLimitConfig(
|
||||
requests_per_minute=5,
|
||||
requests_per_hour=20,
|
||||
),
|
||||
"/api/auth/register": RateLimitConfig(
|
||||
requests_per_minute=3,
|
||||
requests_per_hour=10,
|
||||
),
|
||||
"/api/download": RateLimitConfig(
|
||||
requests_per_minute=10,
|
||||
requests_per_hour=100,
|
||||
authenticated_multiplier=3.0,
|
||||
),
|
||||
}
|
||||
|
||||
# Paths that bypass rate limiting
|
||||
BYPASS_PATHS = {
|
||||
"/health",
|
||||
"/health/detailed",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/static",
|
||||
"/ws",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
default_config: Optional[RateLimitConfig] = None,
|
||||
):
|
||||
"""Initialize rate limiting middleware.
|
||||
|
||||
Args:
|
||||
app: FastAPI application
|
||||
default_config: Default rate limit configuration
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.default_config = default_config or RateLimitConfig()
|
||||
self.store = RateLimitStore()
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable):
|
||||
"""Process request and apply rate limiting.
|
||||
|
||||
Args:
|
||||
request: Incoming HTTP request
|
||||
call_next: Next middleware or endpoint handler
|
||||
|
||||
Returns:
|
||||
HTTP response (either rate limit error or normal response)
|
||||
"""
|
||||
# Check if path should bypass rate limiting
|
||||
if self._should_bypass(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Get identifier (user ID if authenticated, otherwise IP)
|
||||
identifier = self._get_identifier(request)
|
||||
|
||||
# Get rate limit configuration for this endpoint
|
||||
config = self._get_endpoint_config(request.url.path)
|
||||
|
||||
# Apply authenticated user multiplier if applicable
|
||||
is_authenticated = self._is_authenticated(request)
|
||||
max_per_minute = int(
|
||||
config.requests_per_minute *
|
||||
(config.authenticated_multiplier if is_authenticated else 1.0)
|
||||
)
|
||||
max_per_hour = int(
|
||||
config.requests_per_hour *
|
||||
(config.authenticated_multiplier if is_authenticated else 1.0)
|
||||
)
|
||||
|
||||
# Check rate limit
|
||||
allowed, retry_after = self.store.check_limit(
|
||||
identifier,
|
||||
max_per_minute,
|
||||
max_per_hour,
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={"Retry-After": str(retry_after)},
|
||||
)
|
||||
|
||||
# Record the request
|
||||
self.store.record_request(identifier)
|
||||
|
||||
# Add rate limit headers to response
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit-Minute"] = str(max_per_minute)
|
||||
response.headers["X-RateLimit-Limit-Hour"] = str(max_per_hour)
|
||||
|
||||
minute_remaining, hour_remaining = self.store.get_remaining_requests(
|
||||
identifier, max_per_minute, max_per_hour
|
||||
)
|
||||
|
||||
response.headers["X-RateLimit-Remaining-Minute"] = str(
|
||||
minute_remaining
|
||||
)
|
||||
response.headers["X-RateLimit-Remaining-Hour"] = str(
|
||||
hour_remaining
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _should_bypass(self, path: str) -> bool:
|
||||
"""Check if path should bypass rate limiting.
|
||||
|
||||
Args:
|
||||
path: Request path
|
||||
|
||||
Returns:
|
||||
True if path should bypass rate limiting
|
||||
"""
|
||||
for bypass_path in self.BYPASS_PATHS:
|
||||
if path.startswith(bypass_path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_identifier(self, request: Request) -> str:
|
||||
"""Get unique identifier for rate limiting.
|
||||
|
||||
Args:
|
||||
request: HTTP request
|
||||
|
||||
Returns:
|
||||
Unique identifier (user ID or IP address)
|
||||
"""
|
||||
# Try to get user ID from request state (set by auth middleware)
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if user_id:
|
||||
return f"user:{user_id}"
|
||||
|
||||
# Fall back to IP address
|
||||
# Check for X-Forwarded-For header (proxy/load balancer)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# Take the first IP in the chain
|
||||
client_ip = forwarded_for.split(",")[0].strip()
|
||||
else:
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
return f"ip:{client_ip}"
|
||||
|
||||
def _get_endpoint_config(self, path: str) -> RateLimitConfig:
|
||||
"""Get rate limit configuration for endpoint.
|
||||
|
||||
Args:
|
||||
path: Request path
|
||||
|
||||
Returns:
|
||||
Rate limit configuration
|
||||
"""
|
||||
# Check for exact match
|
||||
if path in self.ENDPOINT_LIMITS:
|
||||
return self.ENDPOINT_LIMITS[path]
|
||||
|
||||
# Check for prefix match
|
||||
for endpoint_path, config in self.ENDPOINT_LIMITS.items():
|
||||
if path.startswith(endpoint_path):
|
||||
return config
|
||||
|
||||
return self.default_config
|
||||
|
||||
def _is_authenticated(self, request: Request) -> bool:
|
||||
"""Check if request is from authenticated user.
|
||||
|
||||
Args:
|
||||
request: HTTP request
|
||||
|
||||
Returns:
|
||||
True if user is authenticated
|
||||
"""
|
||||
return (
|
||||
hasattr(request.state, "user_id") and
|
||||
request.state.user_id is not None
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, ValidationError, validator
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
|
||||
|
||||
class SchedulerConfig(BaseModel):
|
||||
@@ -44,7 +44,8 @@ class LoggingConfig(BaseModel):
|
||||
default=3, ge=0, description="Number of rotated log files to keep"
|
||||
)
|
||||
|
||||
@validator("level")
|
||||
@field_validator("level")
|
||||
@classmethod
|
||||
def validate_level(cls, v: str) -> str:
|
||||
allowed = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
|
||||
lvl = (v or "").upper()
|
||||
|
||||
@@ -253,7 +253,7 @@ class ErrorNotificationMessage(BaseModel):
|
||||
description="Message type",
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
@@ -274,7 +274,7 @@ class ProgressUpdateMessage(BaseModel):
|
||||
..., description="Type of progress message"
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
|
||||
@@ -115,14 +115,18 @@ class AnimeService:
|
||||
total = progress_data.get("total", 0)
|
||||
message = progress_data.get("message", "Scanning...")
|
||||
|
||||
asyncio.create_task(
|
||||
self._progress_service.update_progress(
|
||||
progress_id=scan_id,
|
||||
current=current,
|
||||
total=total,
|
||||
message=message,
|
||||
# Schedule the coroutine without waiting for it
|
||||
# This is safe because we don't need the result
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
asyncio.ensure_future(
|
||||
self._progress_service.update_progress(
|
||||
progress_id=scan_id,
|
||||
current=current,
|
||||
total=total,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Scan progress callback error", error=str(e))
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ for comprehensive error monitoring and debugging.
|
||||
"""
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -52,7 +52,7 @@ class ErrorTracker:
|
||||
Unique error tracking ID
|
||||
"""
|
||||
error_id = str(uuid.uuid4())
|
||||
timestamp = datetime.utcnow().isoformat()
|
||||
timestamp = datetime.now(timezone.utc).isoformat()
|
||||
|
||||
error_entry = {
|
||||
"id": error_id,
|
||||
@@ -187,7 +187,7 @@ class RequestContextManager:
|
||||
"request_path": request_path,
|
||||
"request_method": request_method,
|
||||
"user_id": user_id,
|
||||
"timestamp": datetime.utcnow().isoformat(),
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
}
|
||||
self.context_stack.append(context)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user