2025-10-23 19:41:24 +02:00

204 lines
7.6 KiB
Python

"""Authentication middleware for Aniworld FastAPI app.
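Responsibilities:
- Validate Bearer JWT tokens (optional on public endpoints) and reject
  requests to protected endpoints that lack a valid token with 401
- Attach session info to request.state.session when valid
- Enforce simple in-memory rate limiting for auth endpoints and, with a
  higher budget, per request Origin header
This middleware is intentionally lightweight: its rate-limit bookkeeping
is synchronous and kept in per-process memory.
For production use consider a distributed rate limiter (Redis) and
a proper token revocation store.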
"""
from __future__ import annotations
import time
from typing import Callable, Dict
from fastapi import Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from src.server.services.auth_service import AuthError, auth_service
class AuthMiddleware(BaseHTTPMiddleware):
    """Decode Bearer JWTs, gate protected routes and throttle clients.

    How it works

    - Requests carrying an ``Origin`` header are counted per origin in a
      fixed window; an origin may issue ``rate_limit_per_minute * 12``
      requests per window before it receives 429.
    - ``POST`` requests to ``/api/auth/login`` and ``/api/auth/setup`` are
      counted per client IP and limited to ``rate_limit_per_minute``
      requests per window to mitigate brute-force attempts.
    - If an ``Authorization: Bearer <token>`` header is present, the token
      is passed to ``auth_service.create_session_model``. On success the
      dumped session is stored on ``request.state.session``.
    - Requests to paths outside ``PUBLIC_PATHS`` that carry no valid
      Bearer token are rejected with 401.
    - Rate limit records are periodically purged to bound memory use.

    Every 429 response carries ``Retry-After`` and every 401 response
    carries ``WWW-Authenticate: Bearer``.
    """

    # Public endpoints that don't require authentication.  Entries are
    # matched as plain string prefixes of the request path.
    PUBLIC_PATHS = {
        "/api/auth/",  # All auth endpoints
        "/api/health",  # Health check endpoints
        "/api/docs",  # API documentation
        "/api/redoc",  # ReDoc documentation
        "/openapi.json",  # OpenAPI schema
    }

    # Credential-accepting endpoints guarded by the per-IP limiter.
    _RATE_LIMITED_PATHS = ("/api/auth/login", "/api/auth/setup")

    def __init__(
        self,
        app: ASGIApp,
        *,
        rate_limit_per_minute: int = 5,
        window_seconds: int = 60
    ) -> None:
        """Create the middleware.

        Args:
            app: The wrapped ASGI application.
            rate_limit_per_minute: Requests allowed per client IP and
                window on the guarded auth endpoints. The per-origin
                budget is twelve times this value.
            window_seconds: Length of the fixed counting window.
        """
        super().__init__(app)
        # in-memory rate limiter: ip -> {count, window_start}
        self._rate: Dict[str, Dict[str, float]] = {}
        # origin-based rate limiter for CORS: origin -> {count, window_start}
        self._origin_rate: Dict[str, Dict[str, float]] = {}
        self.rate_limit_per_minute = rate_limit_per_minute
        self.window_seconds = window_seconds
        # Track last cleanup time to prevent memory leaks
        self._last_cleanup = time.time()
        self._cleanup_interval = 300  # Clean every 5 minutes

    def _cleanup_old_entries(self) -> None:
        """Purge stale rate limit records, at most once per interval.

        Does nothing until ``_cleanup_interval`` seconds have passed since
        the previous purge. When it runs, every IP and origin record whose
        window started more than two windows ago is deleted, so the
        tracking dictionaries do not grow without bound.
        """
        now = time.time()
        if now - self._last_cleanup < self._cleanup_interval:
            return

        # Remove entries older than 2x window to be safe
        cutoff = now - (self.window_seconds * 2)
        for store in (self._rate, self._origin_rate):
            # Collect keys first: a dict must not change size while it is
            # being iterated.
            stale_keys = [
                key for key, record in store.items()
                if record["window_start"] < cutoff
            ]
            for key in stale_keys:
                del store[key]

        self._last_cleanup = now

    def _is_public_path(self, path: str) -> bool:
        """Check if a path is public and doesn't require authentication.

        Args:
            path: The request path to check

        Returns:
            bool: True if the path starts with any ``PUBLIC_PATHS`` entry,
            False otherwise
        """
        return path.startswith(tuple(self.PUBLIC_PATHS))

    def _register_hit(
        self,
        store: Dict[str, Dict[str, float]],
        key: str,
        limit: int,
    ) -> int:
        """Count one request for ``key`` in a fixed-window limiter.

        Args:
            store: Mapping of key -> ``{"count", "window_start"}`` record;
                updated in place.
            key: Identity being throttled (client IP or origin).
            limit: Requests allowed per window.

        Returns:
            int: 0 when the request is within the limit, otherwise the
            number of whole seconds (at least 1) until the window resets.
        """
        now = time.time()
        record = store.setdefault(key, {"count": 0, "window_start": now})
        # The limiter uses a fixed window; once the window expires, we
        # reset the counter for that key and start measuring again.
        if now - record["window_start"] > self.window_seconds:
            record["window_start"] = now
            record["count"] = 0
        record["count"] += 1
        if record["count"] <= limit:
            return 0
        remaining = record["window_start"] + self.window_seconds - now
        # Round up so a client that waits this long lands past the reset.
        return int(max(remaining, 0)) + 1

    @staticmethod
    def _too_many_requests(detail: str, retry_after: int) -> JSONResponse:
        """Build a JSON 429 response that tells the client when to retry."""
        return JSONResponse(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            content={"detail": detail},
            headers={"Retry-After": str(retry_after)},
        )

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build a JSON 401 response carrying a Bearer challenge."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"detail": detail},
            # A 401 is expected to name the scheme the client should use.
            headers={"WWW-Authenticate": "Bearer"},
        )

    async def dispatch(self, request: Request, call_next: Callable):
        """Throttle, authenticate and forward a single request.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            A JSON 429 response when a rate limit is exceeded, a JSON 401
            response when a protected path is requested without a valid
            Bearer token, otherwise the downstream response.
        """
        path = request.url.path or ""

        # Periodically clean up old rate limit entries
        self._cleanup_old_entries()

        # Apply origin-based rate limiting for CORS requests.
        # NOTE(review): the Origin header is client-supplied and is shared
        # by every browser user of one front-end; confirm that a single
        # shared budget per origin is intended.
        origin = request.headers.get("origin")
        if origin:
            # Allow higher rate limit for origins (e.g., 60 req/min)
            retry_after = self._register_hit(
                self._origin_rate,
                origin,
                self.rate_limit_per_minute * 12,
            )
            if retry_after:
                return self._too_many_requests(
                    "Rate limit exceeded for this origin",
                    retry_after,
                )

        # Apply rate limiting to auth endpoints that accept credentials
        if (
            path in self._RATE_LIMITED_PATHS
            and request.method.upper() == "POST"
        ):
            retry_after = self._register_hit(
                self._rate,
                self._get_client_ip(request),
                self.rate_limit_per_minute,
            )
            if retry_after:
                # Too many requests in window: return a JSON 429 response
                return self._too_many_requests(
                    "Too many authentication attempts, try again later",
                    retry_after,
                )

        # If Authorization header present try to decode token
        # and attach session
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
            try:
                session = auth_service.create_session_model(token)
            except AuthError:
                # Invalid token: reject if not a public endpoint; public
                # endpoints proceed without a session attached.
                if not self._is_public_path(path):
                    return self._unauthorized("Invalid or expired token")
            else:
                # attach to request.state for downstream usage
                request.state.session = session.model_dump()
        elif not self._is_public_path(path):
            # No Bearer credentials on a protected endpoint
            return self._unauthorized("Missing authorization credentials")

        return await call_next(request)

    @staticmethod
    def _get_client_ip(request: Request) -> str:
        """Return the peer host of ``request``, or ``"unknown"``.

        The lookup is deliberately best-effort: a missing or unreadable
        client address must never make the rate limiter itself fail, so
        all such requests share the ``"unknown"`` bucket.
        """
        try:
            client = request.client
            if client is None:
                return "unknown"
            return client.host or "unknown"
        except Exception:
            return "unknown"