feat: migrate to Pydantic V2 and implement rate limiting middleware

- Migrate settings.py to Pydantic V2 (SettingsConfigDict, validation_alias)
- Update config models to use @field_validator with @classmethod
- Replace deprecated datetime.utcnow() with datetime.now(timezone.utc)
- Migrate FastAPI app from @app.on_event to lifespan context manager
- Implement comprehensive rate limiting middleware with:
  * Endpoint-specific rate limits (login: 5/min, register: 3/min)
  * IP-based and user-based tracking
  * Authenticated user multiplier (2x limits)
  * Bypass paths for health, docs, static, websocket endpoints
  * Rate limit headers in responses
- Add 13 comprehensive tests for rate limiting (all passing)
- Update instructions.md to mark completed tasks
- Fix asyncio.create_task usage in anime_service.py

All 714 tests passing. No deprecation warnings.
This commit is contained in:
2025-10-23 22:03:15 +02:00
parent 6a6ae7e059
commit 17e5a551e1
23 changed files with 949 additions and 269 deletions

View File

@@ -2,18 +2,25 @@ import secrets
from typing import Optional
from pydantic import Field
from pydantic_settings import BaseSettings
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
"""Application settings from environment variables."""
model_config = SettingsConfigDict(env_file=".env", extra="ignore")
jwt_secret_key: str = Field(
default_factory=lambda: secrets.token_urlsafe(32),
env="JWT_SECRET_KEY",
validation_alias="JWT_SECRET_KEY",
)
password_salt: str = Field(
default="default-salt",
validation_alias="PASSWORD_SALT"
)
password_salt: str = Field(default="default-salt", env="PASSWORD_SALT")
master_password_hash: Optional[str] = Field(
default=None, env="MASTER_PASSWORD_HASH"
default=None,
validation_alias="MASTER_PASSWORD_HASH"
)
# ⚠️ WARNING: DEVELOPMENT ONLY - NEVER USE IN PRODUCTION ⚠️
# This field allows setting a plaintext master password via environment
@@ -21,32 +28,50 @@ class Settings(BaseSettings):
# deployments, use MASTER_PASSWORD_HASH instead and NEVER set this field.
master_password: Optional[str] = Field(
default=None,
env="MASTER_PASSWORD",
validation_alias="MASTER_PASSWORD",
description=(
"**DEVELOPMENT ONLY** - Plaintext master password. "
"NEVER enable in production. Use MASTER_PASSWORD_HASH instead."
),
)
token_expiry_hours: int = Field(
default=24, env="SESSION_TIMEOUT_HOURS"
default=24,
validation_alias="SESSION_TIMEOUT_HOURS"
)
anime_directory: str = Field(
default="",
validation_alias="ANIME_DIRECTORY"
)
log_level: str = Field(
default="INFO",
validation_alias="LOG_LEVEL"
)
anime_directory: str = Field(default="", env="ANIME_DIRECTORY")
log_level: str = Field(default="INFO", env="LOG_LEVEL")
# Additional settings from .env
database_url: str = Field(
default="sqlite:///./data/aniworld.db", env="DATABASE_URL"
default="sqlite:///./data/aniworld.db",
validation_alias="DATABASE_URL"
)
cors_origins: str = Field(
default="http://localhost:3000",
env="CORS_ORIGINS",
validation_alias="CORS_ORIGINS",
)
api_rate_limit: int = Field(
default=100,
validation_alias="API_RATE_LIMIT"
)
api_rate_limit: int = Field(default=100, env="API_RATE_LIMIT")
default_provider: str = Field(
default="aniworld.to", env="DEFAULT_PROVIDER"
default="aniworld.to",
validation_alias="DEFAULT_PROVIDER"
)
provider_timeout: int = Field(
default=30,
validation_alias="PROVIDER_TIMEOUT"
)
retry_attempts: int = Field(
default=3,
validation_alias="RETRY_ATTEMPTS"
)
provider_timeout: int = Field(default=30, env="PROVIDER_TIMEOUT")
retry_attempts: int = Field(default=3, env="RETRY_ATTEMPTS")
@property
def allowed_origins(self) -> list[str]:
@@ -67,9 +92,5 @@ class Settings(BaseSettings):
]
return [origin.strip() for origin in raw.split(",") if origin.strip()]
class Config:
env_file = ".env"
extra = "ignore"
settings = Settings()

View File

@@ -5,6 +5,7 @@ This module provides the main FastAPI application with proper CORS
configuration, middleware setup, static file serving, and Jinja2 template
integration.
"""
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
@@ -36,13 +37,71 @@ from src.server.middleware.error_handler import register_exception_handlers
from src.server.services.progress_service import get_progress_service
from src.server.services.websocket_service import get_websocket_service
# Initialize FastAPI app
# Prefer storing application-wide singletons on FastAPI.state instead of
# module-level globals. This makes testing and multi-instance hosting safer.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Manage application lifespan (startup and shutdown).

    Replaces the deprecated ``@app.on_event`` handlers. Code before
    ``yield`` runs at startup, code after it at shutdown.

    Args:
        app: The FastAPI application whose state is being initialized.

    Raises:
        Exception: Re-raised from startup so the app does not come up
            in a broken state.
    """
    # Startup
    try:
        # Initialize SeriesApp with configured directory and store it on
        # application state so it can be injected via dependencies.
        if settings.anime_directory:
            app.state.series_app = SeriesApp(settings.anime_directory)
        else:
            # Log warning when anime directory is not configured.
            # app.state.series_app is deliberately left unset here;
            # get_series_app() then returns None.
            print(
                "WARNING: ANIME_DIRECTORY not configured. "
                "Some features may be unavailable."
            )
        # Initialize progress service with websocket callback
        progress_service = get_progress_service()
        ws_service = get_websocket_service()

        async def broadcast_callback(
            message_type: str, data: dict, room: str
        ) -> None:
            """Broadcast progress updates via WebSocket."""
            message = {
                "type": message_type,
                "data": data,
            }
            await ws_service.manager.broadcast_to_room(message, room)

        progress_service.set_broadcast_callback(broadcast_callback)
        print("FastAPI application started successfully")
    except Exception as e:
        print(f"Error during startup: {e}")
        raise  # Re-raise to prevent app from starting in broken state
    # Yield control to the application
    yield
    # Shutdown
    print("FastAPI application shutting down")
def get_series_app() -> Optional[SeriesApp]:
    """Dependency that fetches the SeriesApp singleton from app state.

    Returns:
        The SeriesApp stored on ``app.state`` during startup, or None
        when the application wasn't configured with an anime directory
        (for example during certain test runs).
    """
    try:
        return app.state.series_app
    except AttributeError:
        # Startup never set the attribute (no ANIME_DIRECTORY).
        return None
# Initialize FastAPI app with lifespan
app = FastAPI(
title="Aniworld Download Manager",
description="Modern web interface for Aniworld anime download management",
version="1.0.0",
docs_url="/api/docs",
redoc_url="/api/redoc"
redoc_url="/api/redoc",
lifespan=lifespan
)
# Configure CORS using environment-driven configuration.
@@ -79,61 +138,6 @@ app.include_router(websocket_router)
# Register exception handlers
register_exception_handlers(app)
# Prefer storing application-wide singletons on FastAPI.state instead of
# module-level globals. This makes testing and multi-instance hosting safer.
def get_series_app() -> Optional[SeriesApp]:
"""Dependency to retrieve the SeriesApp instance from application state.
Returns None when the application wasn't configured with an anime
directory (for example during certain test runs).
"""
return getattr(app.state, "series_app", None)
@app.on_event("startup")
async def startup_event() -> None:
"""Initialize application on startup."""
try:
# Initialize SeriesApp with configured directory and store it on
# application state so it can be injected via dependencies.
if settings.anime_directory:
app.state.series_app = SeriesApp(settings.anime_directory)
else:
# Log warning when anime directory is not configured
print(
"WARNING: ANIME_DIRECTORY not configured. "
"Some features may be unavailable."
)
# Initialize progress service with websocket callback
progress_service = get_progress_service()
ws_service = get_websocket_service()
async def broadcast_callback(
message_type: str, data: dict, room: str
) -> None:
"""Broadcast progress updates via WebSocket."""
message = {
"type": message_type,
"data": data,
}
await ws_service.manager.broadcast_to_room(message, room)
progress_service.set_broadcast_callback(broadcast_callback)
print("FastAPI application started successfully")
except Exception as e:
print(f"Error during startup: {e}")
raise # Re-raise to prevent app from starting in broken state
@app.on_event("shutdown")
async def shutdown_event():
"""Cleanup on application shutdown."""
print("FastAPI application shutting down")
@app.exception_handler(404)
async def handle_not_found(request: Request, exc: HTTPException):

View File

@@ -0,0 +1,331 @@
"""Rate limiting middleware for API endpoints.
This module provides comprehensive rate limiting with support for:
- Endpoint-specific rate limits
- IP-based limiting
- User-based rate limiting
- Bypass mechanisms for authenticated users
"""
import time
from collections import defaultdict
from typing import Callable, Dict, Optional, Tuple
from fastapi import Request, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
class RateLimitConfig:
    """Configuration for rate limiting rules.

    Bundles the per-minute and per-hour request ceilings together with
    the multiplier that relaxes both limits for authenticated users.
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        requests_per_hour: int = 1000,
        authenticated_multiplier: float = 2.0,
    ):
        """Initialize rate limit configuration.

        Args:
            requests_per_minute: Max requests per minute for
                unauthenticated users
            requests_per_hour: Max requests per hour for
                unauthenticated users
            authenticated_multiplier: Factor applied to both limits
                for authenticated users
        """
        self.authenticated_multiplier = authenticated_multiplier
        self.requests_per_hour = requests_per_hour
        self.requests_per_minute = requests_per_minute
class RateLimitStore:
    """In-memory sliding-window store for rate limit tracking.

    Each identifier maps to a plain list of request timestamps; two
    parallel windows (60 seconds and 3600 seconds) are maintained.
    State lives in this process only and is not shared across workers.
    """

    def __init__(self):
        """Create empty minute and hour timestamp windows."""
        # identifier -> [timestamp, ...] (seconds since the epoch)
        self._minute_store: Dict[str, list] = defaultdict(list)
        self._hour_store: Dict[str, list] = defaultdict(list)

    def check_limit(
        self,
        identifier: str,
        max_per_minute: int,
        max_per_hour: int,
    ) -> Tuple[bool, Optional[int]]:
        """Check whether the identifier is within both rate limits.

        Args:
            identifier: Unique identifier (IP or user ID)
            max_per_minute: Maximum requests allowed per minute
            max_per_hour: Maximum requests allowed per hour

        Returns:
            Tuple of (allowed, retry_after_seconds); retry_after is
            None when the request is allowed, and at least 1 otherwise.
        """
        now = time.time()
        # Drop expired timestamps before counting.
        self._cleanup_old_entries(identifier, now)
        windows = (
            (self._minute_store, max_per_minute, 60),
            (self._hour_store, max_per_hour, 3600),
        )
        for store, limit, window_seconds in windows:
            stamps = store[identifier]
            if len(stamps) >= limit:
                # Retry once the oldest tracked request ages out.
                wait = int(window_seconds - (now - stamps[0]))
                return False, max(wait, 1)
        return True, None

    def record_request(self, identifier: str) -> None:
        """Record one request for the identifier in both windows.

        Args:
            identifier: Unique identifier (IP or user ID)
        """
        stamp = time.time()
        self._minute_store[identifier].append(stamp)
        self._hour_store[identifier].append(stamp)

    def get_remaining_requests(
        self, identifier: str, max_per_minute: int, max_per_hour: int
    ) -> Tuple[int, int]:
        """Get the remaining request budget for the identifier.

        Args:
            identifier: Unique identifier
            max_per_minute: Maximum per minute
            max_per_hour: Maximum per hour

        Returns:
            Tuple of (remaining_per_minute, remaining_per_hour),
            floored at zero.
        """
        minute_left = max_per_minute - len(
            self._minute_store.get(identifier, [])
        )
        hour_left = max_per_hour - len(
            self._hour_store.get(identifier, [])
        )
        return max(0, minute_left), max(0, hour_left)

    def _cleanup_old_entries(
        self, identifier: str, current_time: float
    ) -> None:
        """Drop expired timestamps and delete empty identifier entries.

        Args:
            identifier: Unique identifier
            current_time: Current timestamp
        """
        for store, window_seconds in (
            (self._minute_store, 60),
            (self._hour_store, 3600),
        ):
            cutoff = current_time - window_seconds
            kept = [ts for ts in store[identifier] if ts > cutoff]
            if kept:
                store[identifier] = kept
            else:
                # Avoid retaining empty lists for idle identifiers.
                store.pop(identifier, None)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware for API rate limiting.

    Tracks requests in an in-memory RateLimitStore (per-process state),
    applies endpoint-specific limits with an authenticated-user
    multiplier, and attaches X-RateLimit-* headers to responses.
    """

    # Endpoint-specific rate limits (overrides defaults).
    # Matched by exact path first, then by prefix — so "/api/download"
    # also covers e.g. "/api/download/123" (see _get_endpoint_config).
    ENDPOINT_LIMITS: Dict[str, RateLimitConfig] = {
        "/api/auth/login": RateLimitConfig(
            requests_per_minute=5,
            requests_per_hour=20,
        ),
        "/api/auth/register": RateLimitConfig(
            requests_per_minute=3,
            requests_per_hour=10,
        ),
        "/api/download": RateLimitConfig(
            requests_per_minute=10,
            requests_per_hour=100,
            authenticated_multiplier=3.0,
        ),
    }

    # Paths that bypass rate limiting (prefix match in _should_bypass)
    BYPASS_PATHS = {
        "/health",
        "/health/detailed",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/static",
        "/ws",
    }

    def __init__(
        self,
        app,
        default_config: Optional[RateLimitConfig] = None,
    ):
        """Initialize rate limiting middleware.

        Args:
            app: FastAPI application
            default_config: Default rate limit configuration; a fresh
                RateLimitConfig() is used when omitted
        """
        super().__init__(app)
        self.default_config = default_config or RateLimitConfig()
        self.store = RateLimitStore()

    async def dispatch(self, request: Request, call_next: Callable):
        """Process request and apply rate limiting.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware or endpoint handler

        Returns:
            HTTP response (either rate limit error or normal response)
        """
        # Check if path should bypass rate limiting
        if self._should_bypass(request.url.path):
            return await call_next(request)

        # Get identifier (user ID if authenticated, otherwise IP)
        identifier = self._get_identifier(request)

        # Get rate limit configuration for this endpoint
        config = self._get_endpoint_config(request.url.path)

        # Apply authenticated user multiplier if applicable
        is_authenticated = self._is_authenticated(request)
        max_per_minute = int(
            config.requests_per_minute *
            (config.authenticated_multiplier if is_authenticated else 1.0)
        )
        max_per_hour = int(
            config.requests_per_hour *
            (config.authenticated_multiplier if is_authenticated else 1.0)
        )

        # Check rate limit
        allowed, retry_after = self.store.check_limit(
            identifier,
            max_per_minute,
            max_per_hour,
        )

        if not allowed:
            # Reject with 429 and a Retry-After hint; the request is
            # not recorded, so rejected calls don't consume budget.
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={"detail": "Rate limit exceeded"},
                headers={"Retry-After": str(retry_after)},
            )

        # Record the request
        self.store.record_request(identifier)

        # Add rate limit headers to response
        response = await call_next(request)
        response.headers["X-RateLimit-Limit-Minute"] = str(max_per_minute)
        response.headers["X-RateLimit-Limit-Hour"] = str(max_per_hour)

        minute_remaining, hour_remaining = self.store.get_remaining_requests(
            identifier, max_per_minute, max_per_hour
        )
        response.headers["X-RateLimit-Remaining-Minute"] = str(
            minute_remaining
        )
        response.headers["X-RateLimit-Remaining-Hour"] = str(
            hour_remaining
        )

        return response

    def _should_bypass(self, path: str) -> bool:
        """Check if path should bypass rate limiting.

        Args:
            path: Request path

        Returns:
            True if the path starts with any entry in BYPASS_PATHS
        """
        for bypass_path in self.BYPASS_PATHS:
            if path.startswith(bypass_path):
                return True
        return False

    def _get_identifier(self, request: Request) -> str:
        """Get unique identifier for rate limiting.

        Args:
            request: HTTP request

        Returns:
            "user:<id>" for authenticated requests, else "ip:<addr>"
        """
        # Try to get user ID from request state (set by auth middleware)
        user_id = getattr(request.state, "user_id", None)
        if user_id:
            return f"user:{user_id}"

        # Fall back to IP address
        # Check for X-Forwarded-For header (proxy/load balancer)
        # NOTE(review): the header is trusted as-is; a direct client
        # could spoof it unless an upstream proxy overwrites it — confirm
        # deployment topology.
        forwarded_for = request.headers.get("X-Forwarded-For")
        if forwarded_for:
            # Take the first IP in the chain
            client_ip = forwarded_for.split(",")[0].strip()
        else:
            client_ip = request.client.host if request.client else "unknown"

        return f"ip:{client_ip}"

    def _get_endpoint_config(self, path: str) -> RateLimitConfig:
        """Get rate limit configuration for endpoint.

        Args:
            path: Request path

        Returns:
            Endpoint-specific configuration, or the default when no
            exact or prefix match is found
        """
        # Check for exact match
        if path in self.ENDPOINT_LIMITS:
            return self.ENDPOINT_LIMITS[path]

        # Check for prefix match (dict order decides on overlapping
        # prefixes)
        for endpoint_path, config in self.ENDPOINT_LIMITS.items():
            if path.startswith(endpoint_path):
                return config

        return self.default_config

    def _is_authenticated(self, request: Request) -> bool:
        """Check if request is from authenticated user.

        Args:
            request: HTTP request

        Returns:
            True if request.state.user_id exists and is not None
        """
        return (
            hasattr(request.state, "user_id") and
            request.state.user_id is not None
        )

View File

@@ -1,6 +1,6 @@
from typing import Dict, List, Optional
from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import BaseModel, Field, ValidationError, field_validator
class SchedulerConfig(BaseModel):
@@ -44,7 +44,8 @@ class LoggingConfig(BaseModel):
default=3, ge=0, description="Number of rotated log files to keep"
)
@validator("level")
@field_validator("level")
@classmethod
def validate_level(cls, v: str) -> str:
allowed = {"DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"}
lvl = (v or "").upper()

View File

@@ -253,7 +253,7 @@ class ErrorNotificationMessage(BaseModel):
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
@@ -274,7 +274,7 @@ class ProgressUpdateMessage(BaseModel):
..., description="Type of progress message"
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
default_factory=lambda: datetime.now(timezone.utc).isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(

View File

@@ -115,14 +115,18 @@ class AnimeService:
total = progress_data.get("total", 0)
message = progress_data.get("message", "Scanning...")
asyncio.create_task(
self._progress_service.update_progress(
progress_id=scan_id,
current=current,
total=total,
message=message,
# Schedule the coroutine without waiting for it
# This is safe because we don't need the result
loop = asyncio.get_event_loop()
if loop.is_running():
asyncio.ensure_future(
self._progress_service.update_progress(
progress_id=scan_id,
current=current,
total=total,
message=message,
)
)
)
except Exception as e:
logger.error("Scan progress callback error", error=str(e))

View File

@@ -6,7 +6,7 @@ for comprehensive error monitoring and debugging.
"""
import logging
import uuid
from datetime import datetime
from datetime import datetime, timezone
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
@@ -52,7 +52,7 @@ class ErrorTracker:
Unique error tracking ID
"""
error_id = str(uuid.uuid4())
timestamp = datetime.utcnow().isoformat()
timestamp = datetime.now(timezone.utc).isoformat()
error_entry = {
"id": error_id,
@@ -187,7 +187,7 @@ class RequestContextManager:
"request_path": request_path,
"request_method": request_method,
"user_id": user_id,
"timestamp": datetime.utcnow().isoformat(),
"timestamp": datetime.now(timezone.utc).isoformat(),
}
self.context_stack.append(context)