"""Authentication middleware for Aniworld FastAPI app. Responsibilities: - Validate Bearer JWT tokens (optional on public endpoints) - Attach session info to request.state.session when valid - Enforce simple in-memory rate limiting for auth endpoints This middleware is intentionally lightweight and synchronous. For production use consider a distributed rate limiter (Redis) and a proper token revocation store. """ from __future__ import annotations import time from typing import Callable, Dict, Optional from fastapi import HTTPException, Request, status from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp from src.server.services.auth_service import AuthError, auth_service class AuthMiddleware(BaseHTTPMiddleware): """Middleware that decodes JWT Bearer tokens (if present) and provides a small rate limiter for authentication endpoints. How it works - If Authorization: Bearer header is present, attempt to decode and create a session model using the existing auth_service. On success, store session dict on ``request.state.session``. - For POST requests to ``/api/auth/login`` and ``/api/auth/setup`` a simple per-IP rate limiter is applied to mitigate brute-force attempts. """ def __init__(self, app: ASGIApp, *, rate_limit_per_minute: int = 5) -> None: super().__init__(app) # in-memory rate limiter: ip -> {count, window_start} self._rate: Dict[str, Dict[str, float]] = {} self.rate_limit_per_minute = rate_limit_per_minute self.window_seconds = 60 async def dispatch(self, request: Request, call_next: Callable): path = request.url.path or "" # Apply rate limiting to auth endpoints that accept credentials if path in ("/api/auth/login", "/api/auth/setup") and request.method.upper() == "POST": client_host = self._get_client_ip(request) rec = self._rate.setdefault(client_host, {"count": 0, "window_start": time.time()}) now = time.time() if now - rec["window_start"] > self.window_seconds: # reset window rec["window_start"] = now rec["count"] = 0 rec["count"] += 1 if rec["count"] > self.rate_limit_per_minute: # Too many requests in window — return a JSON 429 response return JSONResponse( status_code=status.HTTP_429_TOO_MANY_REQUESTS, content={"detail": "Too many authentication attempts, try again later"}, ) # If Authorization header present try to decode token and attach session auth_header = request.headers.get("authorization") if auth_header and auth_header.lower().startswith("bearer "): token = auth_header.split(" ", 1)[1].strip() try: session = auth_service.create_session_model(token) # attach to request.state for downstream usage request.state.session = session.dict() except AuthError: # Invalid token: if this is a protected API path, reject. # For public/auth endpoints let the dependency system handle # optional auth and return None. if path.startswith("/api/") and not path.startswith("/api/auth"): raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") return await call_next(request) @staticmethod def _get_client_ip(request: Request) -> str: try: client = request.client if client is None: return "unknown" return client.host or "unknown" except Exception: return "unknown"