92 lines
3.8 KiB
Python
92 lines
3.8 KiB
Python
"""Authentication middleware for the Aniworld FastAPI app.

Responsibilities:

- Validate Bearer JWT tokens (optional on public endpoints)
- Attach session info to ``request.state.session`` when the token is valid
- Enforce simple in-memory rate limiting for auth endpoints

This middleware is intentionally lightweight: its rate-limit state is kept in
process memory, so it is not shared between worker processes and is lost on
restart. For production use, consider a distributed rate limiter (e.g. Redis)
and a proper token revocation store.
"""
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from typing import Callable, Dict, Optional
|
|
|
|
from fastapi import HTTPException, Request, status
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.types import ASGIApp
|
|
|
|
from src.server.services.auth_service import AuthError, auth_service
|
|
|
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """Decode optional JWT Bearer tokens and rate-limit credential endpoints.

    How it works:

    - If an ``Authorization: Bearer <token>`` header is present, the token is
      passed to ``auth_service.create_session_model``. On success the session
      is stored as a dict on ``request.state.session``. If the token is
      rejected (``AuthError``) on an ``/api/`` path outside the ``/api/auth``
      namespace, a JSON 401 response is returned immediately; on any other
      path the request continues without a session so optional-auth
      dependencies can handle it.
    - ``POST`` requests to ``/api/auth/login`` and ``/api/auth/setup`` pass
      through a fixed-window, per-client-IP rate limiter to mitigate
      brute-force attempts. Exceeding the limit yields a JSON 429 response
      with a ``Retry-After`` header.

    The rate-limit table lives in process memory; limits are therefore per
    worker process and reset on restart.
    """

    # Endpoints whose POST requests are subject to the auth rate limiter.
    _RATE_LIMITED_PATHS = frozenset({"/api/auth/login", "/api/auth/setup"})

    def __init__(self, app: ASGIApp, *, rate_limit_per_minute: int = 5) -> None:
        """Create the middleware.

        Args:
            app: The downstream ASGI application.
            rate_limit_per_minute: Maximum number of POST requests one client
                IP may send to the rate-limited auth endpoints within a single
                window of ``window_seconds`` (60 s).
        """
        super().__init__(app)
        # In-memory rate limiter: client ip -> {"count": attempts in window,
        # "window_start": time.monotonic() timestamp when the window opened}.
        self._rate: Dict[str, Dict[str, float]] = {}
        self.rate_limit_per_minute = rate_limit_per_minute
        self.window_seconds = 60
        # Monotonic timestamp of the last sweep of expired rate-limit entries.
        self._last_prune = time.monotonic()

    async def dispatch(self, request: Request, call_next: Callable):
        """Apply rate limiting and optional token decoding, then forward.

        Returns:
            A JSON 429 response when the client exceeded the auth rate limit,
            a JSON 401 response when an invalid bearer token is sent to a
            protected ``/api/`` path, otherwise the downstream response.
        """
        path = request.url.path or ""

        # Apply rate limiting to auth endpoints that accept credentials.
        if request.method.upper() == "POST" and path in self._RATE_LIMITED_PATHS:
            retry_after = self._register_auth_attempt(self._get_client_ip(request))
            if retry_after is not None:
                return JSONResponse(
                    status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                    content={"detail": "Too many authentication attempts, try again later"},
                    headers={"Retry-After": str(retry_after)},
                )

        # If an Authorization header is present, try to decode the token and
        # attach the session for downstream handlers.
        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
            try:
                session = auth_service.create_session_model(token)
            except AuthError:
                # Exceptions raised from BaseHTTPMiddleware.dispatch run outside
                # FastAPI's HTTPException handlers and would surface as a 500,
                # so the 401 is returned as a response instead of raised.
                if self._is_protected_api_path(path):
                    return JSONResponse(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        content={"detail": "Invalid token"},
                        headers={"WWW-Authenticate": "Bearer"},
                    )
            else:
                # NOTE(review): ``.dict()`` is the pydantic v1 API (deprecated in
                # v2 in favour of ``model_dump()``) — confirm the pydantic version.
                request.state.session = session.dict()

        return await call_next(request)

    def _register_auth_attempt(self, client_ip: str) -> Optional[int]:
        """Count one auth attempt for *client_ip* in its current window.

        Returns:
            ``None`` if the attempt is within the limit; otherwise the number
            of whole seconds (>= 1) until the client's window resets, suitable
            for a ``Retry-After`` header.
        """
        # Monotonic clock: immune to wall-clock adjustments (NTP, DST, manual).
        now = time.monotonic()
        self._prune_expired(now)

        rec = self._rate.get(client_ip)
        if rec is None or now - rec["window_start"] > self.window_seconds:
            # First attempt from this client, or its previous window elapsed.
            rec = {"count": 0, "window_start": now}
            self._rate[client_ip] = rec

        rec["count"] += 1
        if rec["count"] <= self.rate_limit_per_minute:
            return None
        remaining = self.window_seconds - (now - rec["window_start"])
        # The window only resets once strictly more than window_seconds have
        # passed, so round up; remaining >= 0 makes this at least 1.
        return int(remaining) + 1

    def _prune_expired(self, now: float) -> None:
        """Drop rate-limit entries whose window has already elapsed.

        Runs at most once per window, keeping the table bounded by the number
        of clients seen recently without scanning it on every request.
        Expired entries would be reset on their next use anyway, so removing
        them does not change rate-limit decisions.
        """
        if now - self._last_prune < self.window_seconds:
            return
        self._last_prune = now
        cutoff = now - self.window_seconds
        self._rate = {
            ip: rec for ip, rec in self._rate.items() if rec["window_start"] >= cutoff
        }

    @staticmethod
    def _is_protected_api_path(path: str) -> bool:
        """Return True for ``/api/`` paths outside the ``/api/auth`` namespace.

        Matches ``/api/auth`` exactly or by ``/api/auth/`` prefix, so sibling
        routes such as ``/api/authors`` are still treated as protected.
        """
        if not path.startswith("/api/"):
            return False
        return not (path == "/api/auth" or path.startswith("/api/auth/"))

    @staticmethod
    def _get_client_ip(request: Request) -> str:
        """Return the peer IP used as the rate-limit key, or ``"unknown"``.

        Only the ASGI ``client`` address is used. Behind a reverse proxy this
        is the proxy's address, so all clients would share one bucket;
        ``X-Forwarded-For`` is not consulted because it is client-controlled
        unless a trusted proxy sanitises it.
        """
        try:
            client = request.client
        except Exception:  # best effort: never fail a request over IP lookup
            return "unknown"
        if client is None:
            return "unknown"
        return client.host or "unknown"
|