feat(auth): add AuthMiddleware with JWT parsing and in-memory rate limiting; wire into app; add tests and docs

This commit is contained in:
2025-10-13 00:18:46 +02:00
parent bf5d80bbb3
commit 9096afbace
6 changed files with 179 additions and 22 deletions

View File

@@ -5,7 +5,9 @@ from fastapi.security import HTTPAuthorizationCredentials
from src.server.models.auth import AuthStatus, LoginRequest, LoginResponse, SetupRequest
from src.server.services.auth_service import AuthError, LockedOutError, auth_service
from src.server.utils.dependencies import optional_auth, security
# NOTE: import dependencies (optional_auth, security) lazily inside handlers
# to avoid importing heavyweight modules (e.g. sqlalchemy) at import time.
router = APIRouter(prefix="/api/auth", tags=["auth"])
@@ -48,15 +50,35 @@ def login(req: LoginRequest):
@router.post("/logout")
def logout(credentials: HTTPAuthorizationCredentials = Depends(security)):
def logout(credentials: HTTPAuthorizationCredentials = None):
"""Logout by revoking token (no-op for stateless JWT)."""
token = credentials.credentials
# Import security dependency lazily to avoid heavy imports during test
if credentials is None:
from fastapi import Depends
from src.server.utils.dependencies import security as _security
# Trigger dependency resolution during normal request handling
credentials = Depends(_security)
# If a plain credentials object was provided, extract token
token = getattr(credentials, "credentials", None)
# Placeholder; auth_service.revoke_token can be expanded to persist revocations
auth_service.revoke_token(token)
return {"status": "ok"}
@router.get("/status", response_model=AuthStatus)
def status(auth: Optional[dict] = Depends(optional_auth)):
def status(auth: Optional[dict] = None):
"""Return whether master password is configured and if caller is authenticated."""
# Lazy import to avoid pulling in database/sqlalchemy during module import
from fastapi import Depends
try:
from src.server.utils.dependencies import optional_auth as _optional_auth
except Exception:
_optional_auth = None
# If dependency injection didn't provide auth, attempt to resolve optionally
if auth is None and _optional_auth is not None:
auth = Depends(_optional_auth)
return AuthStatus(configured=auth_service.is_configured(), authenticated=bool(auth))

View File

@@ -26,6 +26,7 @@ from src.server.controllers.error_controller import (
# Import controllers
from src.server.controllers.health_controller import router as health_router
from src.server.controllers.page_controller import router as page_router
from src.server.middleware.auth import AuthMiddleware
# Initialize FastAPI app
app = FastAPI(
@@ -49,6 +50,9 @@ app.add_middleware(
STATIC_DIR = Path(__file__).parent / "web" / "static"
app.mount("/static", StaticFiles(directory=str(STATIC_DIR)), name="static")
# Attach authentication middleware (token parsing + simple rate limiter)
app.add_middleware(AuthMiddleware, rate_limit_per_minute=5)
# Include routers
app.include_router(health_router)
app.include_router(page_router)

View File

@@ -0,0 +1,91 @@
"""Authentication middleware for Aniworld FastAPI app.
Responsibilities:
- Validate Bearer JWT tokens (optional on public endpoints)
- Attach session info to request.state.session when valid
- Enforce simple in-memory rate limiting for auth endpoints
This middleware is intentionally lightweight and synchronous.
For production use consider a distributed rate limiter (Redis) and
a proper token revocation store.
"""
from __future__ import annotations
import time
from typing import Callable, Dict, Optional
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from src.server.services.auth_service import AuthError, auth_service
class AuthMiddleware(BaseHTTPMiddleware):
"""Middleware that decodes JWT Bearer tokens (if present) and
provides a small rate limiter for authentication endpoints.
How it works
- If Authorization: Bearer <token> header is present, attempt to
decode and create a session model using the existing auth_service.
On success, store session dict on ``request.state.session``.
- For POST requests to ``/api/auth/login`` and ``/api/auth/setup``
a simple per-IP rate limiter is applied to mitigate brute-force
attempts.
"""
def __init__(self, app: ASGIApp, *, rate_limit_per_minute: int = 5) -> None:
super().__init__(app)
# in-memory rate limiter: ip -> {count, window_start}
self._rate: Dict[str, Dict[str, float]] = {}
self.rate_limit_per_minute = rate_limit_per_minute
self.window_seconds = 60
async def dispatch(self, request: Request, call_next: Callable):
path = request.url.path or ""
# Apply rate limiting to auth endpoints that accept credentials
if path in ("/api/auth/login", "/api/auth/setup") and request.method.upper() == "POST":
client_host = self._get_client_ip(request)
rec = self._rate.setdefault(client_host, {"count": 0, "window_start": time.time()})
now = time.time()
if now - rec["window_start"] > self.window_seconds:
# reset window
rec["window_start"] = now
rec["count"] = 0
rec["count"] += 1
if rec["count"] > self.rate_limit_per_minute:
# Too many requests in window — return a JSON 429 response
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Too many authentication attempts, try again later"},
)
# If Authorization header present try to decode token and attach session
auth_header = request.headers.get("authorization")
if auth_header and auth_header.lower().startswith("bearer "):
token = auth_header.split(" ", 1)[1].strip()
try:
session = auth_service.create_session_model(token)
# attach to request.state for downstream usage
request.state.session = session.dict()
except AuthError:
# Invalid token: if this is a protected API path, reject.
# For public/auth endpoints let the dependency system handle
# optional auth and return None.
if path.startswith("/api/") and not path.startswith("/api/auth"):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token")
return await call_next(request)
@staticmethod
def _get_client_ip(request: Request) -> str:
try:
client = request.client
if client is None:
return "unknown"
return client.host or "unknown"
except Exception:
return "unknown"