"""Authentication middleware for Aniworld FastAPI app.

Responsibilities:
- Validate Bearer JWT tokens (optional on public endpoints)
- Attach session info to ``request.state.session`` when valid
- Enforce simple in-memory rate limiting for auth endpoints

This middleware is intentionally lightweight: all rate-limit state lives in
process memory and no external store is consulted. For production use
consider a distributed rate limiter (Redis) and a proper token revocation
store.
"""
from __future__ import annotations

import math
import time
from typing import Callable, Dict

from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp

from src.server.services.auth_service import AuthError, auth_service


class AuthMiddleware(BaseHTTPMiddleware):
    """Decode JWT Bearer tokens and rate-limit authentication endpoints.

    How it works:

    - If an ``Authorization: Bearer <token>`` header is present, the token is
      turned into a session model via ``auth_service``. On success the
      session dict is stored on ``request.state.session``. Invalid tokens
      and missing credentials are rejected with 401 unless the path is
      public (see ``PUBLIC_PATHS``).
    - POST requests to ``/api/auth/login`` and ``/api/auth/setup`` go
      through a fixed-window, per-client-IP limiter to slow down
      brute-force attempts.
    - Requests carrying an ``Origin`` header go through a looser
      fixed-window, per-origin limiter.
    - Rate-limited requests receive a JSON 429 with a ``Retry-After`` header.
    - Rate-limit records are periodically purged to bound memory use.
    """

    # Public endpoints that don't require authentication. Matching is by
    # prefix, so "/api/auth/" makes every auth endpoint public.
    PUBLIC_PATHS = {
        "/api/auth/",  # All auth endpoints
        "/api/health",  # Health check endpoints
        "/api/docs",  # API documentation
        "/api/redoc",  # ReDoc documentation
        "/openapi.json",  # OpenAPI schema
    }

    # Credential-accepting endpoints that get the strict per-IP limit
    # (POST requests only).
    _RATE_LIMITED_AUTH_PATHS = ("/api/auth/login", "/api/auth/setup")

    # The per-origin limit is this multiple of ``rate_limit_per_minute``.
    _ORIGIN_LIMIT_MULTIPLIER = 12

    def __init__(
        self,
        app: ASGIApp,
        *,
        rate_limit_per_minute: int = 5,
        window_seconds: int = 60,
    ) -> None:
        """Create the middleware.

        Args:
            app: The wrapped ASGI application.
            rate_limit_per_minute: Maximum POSTs to the login/setup endpoints
                allowed per client IP within one window. The window length is
                ``window_seconds``, so this is "per minute" only with the
                default window.
            window_seconds: Length of the fixed rate-limit window in seconds.
        """
        super().__init__(app)
        # In-memory rate limiter: client ip -> {"count", "window_start"}
        self._rate: Dict[str, Dict[str, float]] = {}
        # Origin-based rate limiter: origin -> {"count", "window_start"}
        self._origin_rate: Dict[str, Dict[str, float]] = {}
        self.rate_limit_per_minute = rate_limit_per_minute
        self.window_seconds = window_seconds
        # Track last cleanup time to prevent memory leaks
        self._last_cleanup = time.time()
        self._cleanup_interval = 300  # Clean every 5 minutes

    def _cleanup_old_entries(self) -> None:
        """Drop expired rate-limit records, at most once per cleanup interval.

        This prevents memory leaks from accumulating old IP addresses and
        origins.
        """
        now = time.time()
        since_last = now - self._last_cleanup
        # A negative value means the wall clock was stepped backwards. Run a
        # cleanup rather than waiting for the clock to catch up.
        if 0 <= since_last < self._cleanup_interval:
            return

        # Remove entries older than 2x window to be safe. Entries stamped in
        # the future (clock stepped back) are dropped too; a dropped entry
        # is simply recreated from zero on the client's next request.
        cutoff = now - (self.window_seconds * 2)
        for store in (self._rate, self._origin_rate):
            stale = [
                key
                for key, record in store.items()
                if not cutoff <= record["window_start"] <= now
            ]
            for key in stale:
                del store[key]

        self._last_cleanup = now

    def _register_hit(
        self, store: Dict[str, Dict[str, float]], key: str, limit: int
    ) -> int | None:
        """Count one request for ``key`` in a fixed-window limiter.

        Args:
            store: The limiter's record table (``self._rate`` or
                ``self._origin_rate``).
            key: Client identifier (IP address or origin).
            limit: Maximum requests allowed per window.

        Returns:
            None if the request is within the limit. Otherwise, the whole
            number of seconds (at least 1) until the current window ends,
            suitable for a ``Retry-After`` header.
        """
        now = time.time()
        record = store.setdefault(key, {"count": 0, "window_start": now})
        elapsed = now - record["window_start"]
        # Fixed window: once it has expired, restart counting for this key.
        # A negative elapsed time means the wall clock jumped backwards.
        # Reset in that case too, otherwise the client would stay locked out
        # until the clock caught up with the stored window start.
        if elapsed > self.window_seconds or elapsed < 0:
            record["window_start"] = now
            record["count"] = 0
        record["count"] += 1

        if record["count"] <= limit:
            return None
        remaining = record["window_start"] + self.window_seconds - now
        return max(1, math.ceil(remaining))

    @staticmethod
    def _too_many_requests(detail: str, retry_after: int) -> JSONResponse:
        """Build the JSON 429 response, advertising when to retry."""
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={"detail": detail},
            headers={"Retry-After": str(retry_after)},
        )

    def _is_public_path(self, path: str) -> bool:
        """Check whether a path is public and doesn't require authentication.

        Args:
            path: The request path to check.

        Returns:
            bool: True if ``path`` starts with any entry of ``PUBLIC_PATHS``.
        """
        return path.startswith(tuple(self.PUBLIC_PATHS))

    async def dispatch(self, request: Request, call_next: Callable):
        """Apply rate limits, then authenticate the request and forward it.

        Returns a 429 response when a rate limit is exceeded, and a 401
        response when a non-public path is requested with an invalid token
        or without Bearer credentials. Otherwise returns the downstream
        response from ``call_next``.
        """
        path = request.url.path or ""

        # Periodically clean up old rate limit entries
        self._cleanup_old_entries()

        # Origin-based rate limiting for requests carrying an Origin header.
        # The limit is 12x the per-IP auth limit (60 requests per window with
        # the default of 5).
        # NOTE(review): the Origin header is client-supplied, and every user
        # of the same frontend origin shares one bucket. Confirm this global
        # per-origin cap is intended.
        origin = request.headers.get("origin")
        if origin:
            retry_after = self._register_hit(
                self._origin_rate,
                origin,
                self.rate_limit_per_minute * self._ORIGIN_LIMIT_MULTIPLIER,
            )
            if retry_after is not None:
                return self._too_many_requests(
                    "Rate limit exceeded for this origin", retry_after
                )

        # Apply rate limiting to auth endpoints that accept credentials
        if (
            path in self._RATE_LIMITED_AUTH_PATHS
            and request.method.upper() == "POST"
        ):
            retry_after = self._register_hit(
                self._rate,
                self._get_client_ip(request),
                self.rate_limit_per_minute,
            )
            if retry_after is not None:
                return self._too_many_requests(
                    "Too many authentication attempts, try again later",
                    retry_after,
                )

        # If an Authorization header is present, try to decode the token and
        # attach the session.
        # NOTE(review): OPTIONS requests are not exempted here. If this
        # middleware runs outside CORSMiddleware, CORS preflights (which
        # carry no Authorization header) to protected paths get 401.
        # Confirm the middleware order.
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
            try:
                session = auth_service.create_session_model(token)
            except AuthError:
                # Invalid token: reject if not a public endpoint
                if not self._is_public_path(path):
                    return JSONResponse(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        content={"detail": "Invalid or expired token"},
                    )
            else:
                # Attach to request.state for downstream usage
                request.state.session = session.model_dump()
        elif not self._is_public_path(path):
            # No Bearer credentials on a protected endpoint
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"detail": "Missing authorization credentials"},
            )

        return await call_next(request)

    @staticmethod
    def _get_client_ip(request: Request) -> str:
        """Return the directly connected peer's host, or "unknown".

        NOTE(review): behind a reverse proxy this is the proxy's address, so
        all clients would share one rate-limit bucket. Confirm the
        deployment.
        """
        try:
            client = request.client
            if client is None:
                return "unknown"
            return client.host or "unknown"
        except Exception:
            # Best effort: failing to resolve the peer must never fail the
            # request; fall back to one shared bucket instead.
            return "unknown"