- Add documentation warnings for in-memory rate limiting and failed login attempts - Consolidate duplicate health endpoints into api/health.py - Fix CLI to use correct async rescan method names - Update download.py and anime.py to use custom exception classes - Add WebSocket room validation and rate limiting
220 lines
8.3 KiB
Python
220 lines
8.3 KiB
Python
"""Authentication middleware for Aniworld FastAPI app.
|
|
|
|
Responsibilities:
|
|
- Validate Bearer JWT tokens (optional on public endpoints)
|
|
- Attach session info to request.state.session when valid
|
|
- Enforce simple in-memory rate limiting for auth endpoints and per request Origin
|
|
|
|
This middleware is intentionally lightweight: token decoding is not awaited and all rate-limit state lives in process memory.
|
|
For production use consider a distributed rate limiter (Redis) and
|
|
a proper token revocation store.
|
|
|
|
WARNING - SINGLE PROCESS LIMITATION:
|
|
Rate limiting state is stored in memory dictionaries which RESET when
|
|
the process restarts. This means:
|
|
- Attackers can bypass rate limits by triggering a process restart
|
|
- Rate limits are not shared across multiple workers/processes
|
|
|
|
For production deployments, consider:
|
|
- Using Redis-backed rate limiting (e.g., slowapi with Redis)
|
|
- Running behind a reverse proxy with rate limiting (nginx, HAProxy)
|
|
- Using a dedicated rate limiting service
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import time
|
|
from typing import Callable, Dict
|
|
|
|
from fastapi import Request, status
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.types import ASGIApp
|
|
|
|
from src.server.services.auth_service import AuthError, auth_service
|
|
|
|
|
|
class AuthMiddleware(BaseHTTPMiddleware):
    """Middleware that decodes JWT Bearer tokens (if present) and
    provides a small rate limiter for authentication endpoints.

    How it works
    - If an ``Authorization: Bearer <token>`` header is present, attempt to
      decode it and create a session model using the existing auth_service.
      On success, store the session dict on ``request.state.session``.
    - Requests to paths that are not public (see ``PUBLIC_PATHS``) are
      rejected with 401 when the token is missing or invalid. CORS
      preflight (``OPTIONS``) requests are exempt because browsers never
      attach credentials to them.
    - For POST requests to ``/api/auth/login`` and ``/api/auth/setup``
      a simple per-IP fixed-window rate limiter is applied to mitigate
      brute-force attempts.
    - Requests carrying an ``Origin`` header are rate limited per origin,
      with a limit ``ORIGIN_RATE_MULTIPLIER`` times the auth limit.
    - Rate limit records are periodically cleaned to prevent memory leaks.
    """

    # Public endpoints that don't require authentication.
    # Matching rules (see _is_public_path):
    # - "/" matches ONLY the landing page itself. A prefix match on "/"
    #   would make every route public.
    # - Every other entry matches the exact path or any sub-path at a
    #   segment boundary, e.g. "/api/health" matches "/api/health/ready"
    #   but not "/api/healthz"; "/static/" matches "/static/css/app.css".
    PUBLIC_PATHS = {
        "/api/auth/",  # All auth endpoints
        "/api/health",  # Health check endpoints
        "/api/docs",  # API documentation
        "/api/redoc",  # ReDoc documentation
        "/openapi.json",  # OpenAPI schema
        "/static/",  # Static files (CSS, JS, images)
        "/",  # Landing page
        "/login",  # Login page
        "/setup",  # Setup page
        "/queue",  # Queue page (needs to be accessible for initial load)
    }

    # Origins get a much higher budget than credential endpoints
    # (5 * 12 = 60 requests per window with the default settings).
    ORIGIN_RATE_MULTIPLIER = 12

    def __init__(
        self,
        app: ASGIApp,
        *,
        rate_limit_per_minute: int = 5,
        window_seconds: int = 60
    ) -> None:
        super().__init__(app)
        # in-memory rate limiter: ip -> {count, window_start}
        self._rate: Dict[str, Dict[str, float]] = {}
        # origin-based rate limiter for CORS: origin -> {count, window_start}
        self._origin_rate: Dict[str, Dict[str, float]] = {}
        self.rate_limit_per_minute = rate_limit_per_minute
        self.window_seconds = window_seconds
        # Track last cleanup time to prevent memory leaks
        self._last_cleanup = time.time()
        self._cleanup_interval = 300  # Clean every 5 minutes

    def _cleanup_old_entries(self) -> None:
        """Remove rate limit entries whose window started long ago.

        Runs at most once per ``_cleanup_interval`` seconds. This prevents
        memory leaks from accumulating old IP addresses and origins.
        """
        now = time.time()
        if now - self._last_cleanup < self._cleanup_interval:
            return

        # Remove entries older than 2x window to be safe
        cutoff = now - (self.window_seconds * 2)

        for records in (self._rate, self._origin_rate):
            stale_keys = [
                key for key, record in records.items()
                if record["window_start"] < cutoff
            ]
            for key in stale_keys:
                del records[key]

        self._last_cleanup = now

    def _register_hit(
        self,
        records: Dict[str, Dict[str, float]],
        key: str,
        limit: int,
    ) -> bool:
        """Count one request for ``key`` in a fixed window.

        The window for ``key`` restarts once more than ``window_seconds``
        have elapsed since it began.

        Args:
            records: Mapping of key -> {"count", "window_start"} to update.
            key: Client identifier (IP address or Origin header value).
            limit: Maximum number of requests allowed per window.

        Returns:
            bool: True if this request exceeds ``limit``, False otherwise.
        """
        now = time.time()
        record = records.setdefault(key, {"count": 0, "window_start": now})
        if now - record["window_start"] > self.window_seconds:
            record["window_start"] = now
            record["count"] = 0
        record["count"] += 1
        return record["count"] > limit

    def _is_public_path(self, path: str) -> bool:
        """Check if a path is public and doesn't require authentication.

        Args:
            path: The request path to check

        Returns:
            bool: True if the path is public, False otherwise
        """
        for public_path in self.PUBLIC_PATHS:
            if path == public_path:
                return True
            if public_path == "/":
                # The root entry must never act as a prefix, otherwise
                # every request path would be considered public.
                continue
            # Match sub-paths only at a "/" segment boundary.
            if path.startswith(public_path.rstrip("/") + "/"):
                return True
        return False

    def _authenticate(self, request: Request, path: str) -> JSONResponse | None:
        """Attach a session for a valid Bearer token, or build a 401 response.

        Args:
            request: The incoming request. On a valid token,
                ``request.state.session`` is set to the session dict.
            path: The request path, used to decide whether auth is required.

        Returns:
            JSONResponse | None: A 401 response when a protected path lacks
            valid credentials; otherwise None (continue processing).
        """
        # Preflight requests never carry credentials; rejecting them would
        # break CORS for every protected endpoint.
        requires_auth = (
            request.method.upper() != "OPTIONS"
            and not self._is_public_path(path)
        )

        auth_header = request.headers.get("authorization")
        if auth_header and auth_header.lower().startswith("bearer "):
            token = auth_header.split(" ", 1)[1].strip()
            try:
                session = auth_service.create_session_model(token)
            except AuthError:
                # Invalid token: reject only if the endpoint is protected
                if requires_auth:
                    return JSONResponse(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        content={"detail": "Invalid or expired token"}
                    )
                return None
            # attach to request.state for downstream usage
            request.state.session = session.model_dump()
            return None

        # No (Bearer) authorization header
        if requires_auth:
            return JSONResponse(
                status_code=status.HTTP_401_UNAUTHORIZED,
                content={"detail": "Missing authorization credentials"}
            )
        return None

    async def dispatch(self, request: Request, call_next: Callable):
        """Apply origin/auth rate limits and authentication, then forward."""
        path = request.url.path or ""

        # Periodically clean up old rate limit entries
        self._cleanup_old_entries()

        # Apply origin-based rate limiting for CORS requests
        origin = request.headers.get("origin")
        if origin and self._register_hit(
            self._origin_rate,
            origin,
            self.rate_limit_per_minute * self.ORIGIN_RATE_MULTIPLIER,
        ):
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": "Rate limit exceeded for this origin"
                },
            )

        # Apply rate limiting to auth endpoints that accept credentials
        if (
            path in ("/api/auth/login", "/api/auth/setup")
            and request.method.upper() == "POST"
            and self._register_hit(
                self._rate,
                self._get_client_ip(request),
                self.rate_limit_per_minute,
            )
        ):
            # Too many requests in window — return a JSON 429 response
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={
                    "detail": (
                        "Too many authentication attempts, "
                        "try again later"
                    )
                },
            )

        rejection = self._authenticate(request, path)
        if rejection is not None:
            return rejection

        return await call_next(request)

    @staticmethod
    def _get_client_ip(request: Request) -> str:
        """Return the direct peer address of the request, or "unknown".

        Note: this is the socket peer, not a forwarded header value; behind
        a reverse proxy all clients may share the proxy's address.
        """
        try:
            client = request.client
            if client is None:
                return "unknown"
            return client.host or "unknown"
        except Exception:
            return "unknown"
|