92 lines
3.8 KiB
Python

"""Authentication middleware for Aniworld FastAPI app.
Responsibilities:
- Validate Bearer JWT tokens (optional on public endpoints)
- Attach session info to request.state.session when valid
- Enforce simple in-memory rate limiting for auth endpoints
This middleware is intentionally lightweight and synchronous.
For production use consider a distributed rate limiter (Redis) and
a proper token revocation store.
"""
from __future__ import annotations
import time
from typing import Callable, Dict, Optional
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from src.server.services.auth_service import AuthError, auth_service
class AuthMiddleware(BaseHTTPMiddleware):
    """Decode Bearer tokens (if present) and rate-limit credential endpoints.

    How it works

    - For POST requests to ``/api/auth/login`` and ``/api/auth/setup`` a
      simple per-IP, fixed-window rate limiter is applied to mitigate
      brute-force attempts. Requests over the limit get a JSON 429 response.
    - If an ``Authorization: Bearer <token>`` header is present, the token is
      passed to ``auth_service.create_session_model``. On success the session
      dict is stored on ``request.state.session``.
    - If the token is rejected with ``AuthError`` on a path under ``/api/``
      that does not start with ``/api/auth``, a JSON 401 response is returned.
      On any other path the request continues without a session attached.

    Responses are returned directly instead of raising ``HTTPException``:
    Starlette runs user middleware outside its ``ExceptionMiddleware``, so an
    ``HTTPException`` raised here would bypass the application's exception
    handlers and surface as a 500.
    """

    # Credential-accepting endpoints that are subject to rate limiting.
    _RATE_LIMITED_PATHS = ("/api/auth/login", "/api/auth/setup")

    def __init__(self, app: ASGIApp, *, rate_limit_per_minute: int = 5) -> None:
        """Wrap *app*.

        Args:
            app: The downstream ASGI application.
            rate_limit_per_minute: Maximum number of rate-limited requests a
                single client address may make per window.
        """
        super().__init__(app)
        # in-memory rate limiter: ip -> {count, window_start}
        self._rate: Dict[str, Dict[str, float]] = {}
        self.rate_limit_per_minute = rate_limit_per_minute
        self.window_seconds = 60
        # Timestamp of the last sweep that removed expired rate records.
        self._last_prune = time.time()

    async def dispatch(self, request: Request, call_next: Callable):
        """Apply rate limiting and optional token decoding, then forward.

        Args:
            request: The incoming request.
            call_next: Callable that forwards the request downstream.

        Returns:
            A 429 JSON response when the client exceeded the auth rate limit,
            a 401 JSON response when an invalid token is sent to a protected
            ``/api/`` path, otherwise the downstream response.
        """
        path = request.url.path or ""

        # Apply rate limiting to auth endpoints that accept credentials
        if path in self._RATE_LIMITED_PATHS and request.method.upper() == "POST":
            if self._register_attempt(self._get_client_ip(request)):
                # Too many requests in window — return a JSON 429 response
                return JSONResponse(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    content={"detail": "Too many authentication attempts, try again later"},
                )

        # If Authorization header present try to decode token and attach session
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
            try:
                session = auth_service.create_session_model(token)
            except AuthError:
                # Invalid token: if this is a protected API path, reject.
                # For public/auth endpoints let the dependency system handle
                # optional auth.
                # NOTE(review): the "/api/auth" prefix test also matches
                # sibling paths such as "/api/authx" — confirm that is intended.
                if path.startswith("/api/") and not path.startswith("/api/auth"):
                    # Return (not raise): see the class docstring.
                    return JSONResponse(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        content={"detail": "Invalid token"},
                        headers={"WWW-Authenticate": "Bearer"},
                    )
            else:
                # attach to request.state for downstream usage
                request.state.session = session.dict()

        return await call_next(request)

    def _register_attempt(self, client_host: str) -> bool:
        """Count one attempt for *client_host*; return True if over the limit."""
        now = time.time()
        self._prune_expired(now)

        rec = self._rate.setdefault(client_host, {"count": 0, "window_start": now})
        if now - rec["window_start"] > self.window_seconds:
            # reset window
            rec["window_start"] = now
            rec["count"] = 0
        rec["count"] += 1
        return rec["count"] > self.rate_limit_per_minute

    def _prune_expired(self, now: float) -> None:
        """Drop rate records whose window has expired, at most once per window.

        Without this the table would keep one entry per client address ever
        seen. Removing an expired record does not change limiting behaviour:
        such a record would be reset to a fresh window on its next use anyway.
        """
        if now - self._last_prune <= self.window_seconds:
            return
        self._last_prune = now
        expired = [
            host
            for host, rec in self._rate.items()
            if now - rec["window_start"] > self.window_seconds
        ]
        for host in expired:
            del self._rate[host]

    @staticmethod
    def _get_client_ip(request: Request) -> str:
        """Return the client host of *request*, or ``"unknown"`` if unavailable."""
        # Deliberately best-effort: failing to resolve the client address must
        # never fail the request, so any error falls back to a shared bucket.
        try:
            client = request.client
            if client is None:
                return "unknown"
            return client.host or "unknown"
        except Exception:
            return "unknown"