Aniworld/src/server/middleware/rate_limit.py
Lukas 17e5a551e1 feat: migrate to Pydantic V2 and implement rate limiting middleware
- Migrate settings.py to Pydantic V2 (SettingsConfigDict, validation_alias)
- Update config models to use @field_validator with @classmethod
- Replace deprecated datetime.utcnow() with datetime.now(timezone.utc)
- Migrate FastAPI app from @app.on_event to lifespan context manager
- Implement comprehensive rate limiting middleware with:
  * Endpoint-specific rate limits (login: 5/min, register: 3/min)
  * IP-based and user-based tracking
  * Authenticated user multiplier (2x limits)
  * Bypass paths for health, docs, static, websocket endpoints
  * Rate limit headers in responses
- Add 13 comprehensive tests for rate limiting (all passing)
- Update instructions.md to mark completed tasks
- Fix asyncio.create_task usage in anime_service.py

All 714 tests passing. No deprecation warnings.
2025-10-23 22:03:15 +02:00

332 lines
10 KiB
Python

"""Rate limiting middleware for API endpoints.
This module provides comprehensive rate limiting with support for:
- Endpoint-specific rate limits
- IP-based limiting
- User-based rate limiting
- Bypass paths (health, docs, static, websocket) and raised limits for
  authenticated users
"""
import time
from collections import defaultdict
from typing import Callable, Dict, Optional, Tuple
from fastapi import Request, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse
class RateLimitConfig:
    """Thresholds describing one rate limiting rule.

    Instances are plain value holders: the constructor stores its arguments
    unchanged and performs no validation.
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        requests_per_hour: int = 1000,
        authenticated_multiplier: float = 2.0,
    ):
        """Store the thresholds for this rule.

        Args:
            requests_per_minute: Requests allowed per minute for
                unauthenticated users.
            requests_per_hour: Requests allowed per hour for
                unauthenticated users.
            authenticated_multiplier: Factor applied to both limits when
                the request comes from an authenticated user.
        """
        # Per-window base limits (before any authenticated multiplier).
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        # Scaling factor for authenticated callers.
        self.authenticated_multiplier = authenticated_multiplier
class RateLimitStore:
    """In-memory sliding-window store for rate limit tracking.

    For every identifier the store keeps the ``time.time()`` timestamps of
    its recorded requests, once for a 60 second window and once for a
    3600 second window. All state lives in this object only.
    """

    # Window lengths in seconds.
    _MINUTE_WINDOW = 60
    _HOUR_WINDOW = 3600
    # Minimum number of seconds between two full sweeps of all identifiers.
    _SWEEP_INTERVAL = 60

    def __init__(self):
        """Initialize the rate limit store."""
        # Store format: {identifier: [timestamp, ...]}, timestamps appended
        # in arrival order (oldest first).
        self._minute_store: Dict[str, list] = defaultdict(list)
        self._hour_store: Dict[str, list] = defaultdict(list)
        self._last_sweep = time.time()

    def check_limit(
        self,
        identifier: str,
        max_per_minute: int,
        max_per_hour: int,
    ) -> Tuple[bool, Optional[int]]:
        """Check if the identifier has exceeded rate limits.

        This only inspects the store; call ``record_request`` to count a
        request that was allowed.

        Args:
            identifier: Unique identifier (IP or user ID)
            max_per_minute: Maximum requests allowed per minute
            max_per_hour: Maximum requests allowed per hour

        Returns:
            Tuple of (allowed, retry_after_seconds). ``retry_after_seconds``
            is ``None`` when the request is allowed and at least 1 otherwise.
        """
        current_time = time.time()
        self._sweep_stale(current_time)
        self._cleanup_old_entries(identifier, current_time)

        # The minute window is checked first, then the hour window.
        windows = (
            (self._minute_store, self._MINUTE_WINDOW, max_per_minute),
            (self._hour_store, self._HOUR_WINDOW, max_per_hour),
        )
        for store, window, limit in windows:
            # .get() instead of [] so the defaultdict does not re-create the
            # empty entry that the cleanup above just removed.
            entries = store.get(identifier, ())
            if len(entries) >= limit:
                return False, self._retry_after(entries, window, current_time)
        return True, None

    def record_request(self, identifier: str) -> None:
        """Record a request for the identifier.

        Args:
            identifier: Unique identifier (IP or user ID)
        """
        current_time = time.time()
        self._minute_store[identifier].append(current_time)
        self._hour_store[identifier].append(current_time)

    def get_remaining_requests(
        self, identifier: str, max_per_minute: int, max_per_hour: int
    ) -> Tuple[int, int]:
        """Get remaining requests for the identifier.

        Expired timestamps are dropped first so they do not count against
        the remaining budget.

        Args:
            identifier: Unique identifier
            max_per_minute: Maximum per minute
            max_per_hour: Maximum per hour

        Returns:
            Tuple of (remaining_per_minute, remaining_per_hour), each
            clamped to a minimum of 0.
        """
        self._cleanup_old_entries(identifier, time.time())
        minute_used = len(self._minute_store.get(identifier, ()))
        hour_used = len(self._hour_store.get(identifier, ()))
        return (
            max(0, max_per_minute - minute_used),
            max(0, max_per_hour - hour_used),
        )

    @staticmethod
    def _retry_after(entries, window: int, current_time: float) -> int:
        """Return the seconds until the oldest entry leaves the window.

        Args:
            entries: Timestamps for one identifier, oldest first
            window: Window length in seconds
            current_time: Current timestamp

        Returns:
            Whole seconds to wait, never less than 1.
        """
        if not entries:
            # Only reachable when the limit is 0 or negative: there is no
            # entry whose expiry would free a slot, so report a full window
            # instead of indexing into an empty list.
            return window
        return max(int(window - (current_time - entries[0])), 1)

    def _cleanup_old_entries(
        self, identifier: str, current_time: float
    ) -> None:
        """Remove entries older than the time windows.

        An identifier whose list becomes empty is removed from the store.

        Args:
            identifier: Unique identifier
            current_time: Current timestamp
        """
        windows = (
            (self._minute_store, self._MINUTE_WINDOW),
            (self._hour_store, self._HOUR_WINDOW),
        )
        for store, window in windows:
            entries = store.get(identifier)
            if entries is None:
                continue
            cutoff = current_time - window
            fresh = [ts for ts in entries if ts > cutoff]
            if fresh:
                store[identifier] = fresh
            else:
                del store[identifier]

    def _sweep_stale(self, current_time: float) -> None:
        """Drop identifiers whose newest entry has left the window.

        Per-identifier cleanup only runs when that identifier is seen
        again, so clients that never return would otherwise stay in memory
        forever. The sweep runs at most once per ``_SWEEP_INTERVAL`` seconds.

        Args:
            current_time: Current timestamp
        """
        if current_time - self._last_sweep < self._SWEEP_INTERVAL:
            return
        self._last_sweep = current_time
        windows = (
            (self._minute_store, self._MINUTE_WINDOW),
            (self._hour_store, self._HOUR_WINDOW),
        )
        for store, window in windows:
            cutoff = current_time - window
            # Timestamps are appended in arrival order, so the last one is
            # the newest; if it is expired, the whole list is.
            stale = [
                key
                for key, entries in store.items()
                if not entries or entries[-1] <= cutoff
            ]
            for key in stale:
                del store[key]
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware for API rate limiting.

    Each non-bypassed request is counted against a bucket in an in-memory
    ``RateLimitStore``. Requests over the limit receive a 429 response with
    a ``Retry-After`` header; allowed responses carry ``X-RateLimit-*``
    headers.

    Paths covered by ``ENDPOINT_LIMITS`` are counted in their own bucket per
    client, so ordinary API traffic does not consume the small login or
    register budget (and vice versa).
    """

    # Endpoint-specific rate limits (overrides defaults)
    ENDPOINT_LIMITS: Dict[str, RateLimitConfig] = {
        "/api/auth/login": RateLimitConfig(
            requests_per_minute=5,
            requests_per_hour=20,
        ),
        "/api/auth/register": RateLimitConfig(
            requests_per_minute=3,
            requests_per_hour=10,
        ),
        "/api/download": RateLimitConfig(
            requests_per_minute=10,
            requests_per_hour=100,
            authenticated_multiplier=3.0,
        ),
    }

    # Paths that bypass rate limiting (the path itself and anything below it)
    BYPASS_PATHS = {
        "/health",
        "/health/detailed",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/static",
        "/ws",
    }

    def __init__(
        self,
        app,
        default_config: Optional[RateLimitConfig] = None,
    ):
        """Initialize rate limiting middleware.

        Args:
            app: FastAPI application
            default_config: Default rate limit configuration, used for
                paths without an entry in ``ENDPOINT_LIMITS``
        """
        super().__init__(app)
        self.default_config = default_config or RateLimitConfig()
        self.store = RateLimitStore()

    async def dispatch(self, request: Request, call_next: Callable):
        """Process request and apply rate limiting.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware or endpoint handler

        Returns:
            HTTP response (either rate limit error or normal response)
        """
        path = request.url.path
        if self._should_bypass(path):
            return await call_next(request)

        config = self._get_endpoint_config(path)

        # Endpoint-specific rules get their own counter. With one shared
        # counter per client, a handful of ordinary API calls would already
        # exhaust e.g. the 5/min login budget.
        endpoint = self._match_endpoint(path)
        bucket = self._get_identifier(request)
        if endpoint is not None:
            bucket = f"{bucket}|{endpoint}"

        # Apply authenticated user multiplier if applicable
        multiplier = (
            config.authenticated_multiplier
            if self._is_authenticated(request)
            else 1.0
        )
        max_per_minute = int(config.requests_per_minute * multiplier)
        max_per_hour = int(config.requests_per_hour * multiplier)

        allowed, retry_after = self.store.check_limit(
            bucket,
            max_per_minute,
            max_per_hour,
        )
        if not allowed:
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={"detail": "Rate limit exceeded"},
                headers={"Retry-After": str(retry_after)},
            )

        # Count the request before handing it on, so it is recorded even if
        # the downstream handler raises.
        self.store.record_request(bucket)

        response = await call_next(request)

        minute_remaining, hour_remaining = self.store.get_remaining_requests(
            bucket, max_per_minute, max_per_hour
        )
        response.headers["X-RateLimit-Limit-Minute"] = str(max_per_minute)
        response.headers["X-RateLimit-Limit-Hour"] = str(max_per_hour)
        response.headers["X-RateLimit-Remaining-Minute"] = str(
            minute_remaining
        )
        response.headers["X-RateLimit-Remaining-Hour"] = str(hour_remaining)
        return response

    def _should_bypass(self, path: str) -> bool:
        """Check if path should bypass rate limiting.

        A path bypasses when it equals a bypass entry or lies below it
        (``/static/app.js`` for ``/static``). A bare string prefix is not
        enough: ``/docs-internal`` does not bypass because of ``/docs``.

        Args:
            path: Request path

        Returns:
            True if path should bypass rate limiting
        """
        return any(
            path == bypass_path or path.startswith(bypass_path + "/")
            for bypass_path in self.BYPASS_PATHS
        )

    def _get_user_id(self, request: Request):
        """Return the user ID stored on the request state, if any.

        Args:
            request: HTTP request

        Returns:
            The ``request.state.user_id`` value, or None when it is missing
            or falsy (None, empty string, 0).
        """
        user_id = getattr(request.state, "user_id", None)
        return user_id if user_id else None

    def _get_identifier(self, request: Request) -> str:
        """Get unique identifier for rate limiting.

        Args:
            request: HTTP request

        Returns:
            Unique identifier (``user:<id>`` or ``ip:<address>``)
        """
        # Try to get user ID from request state (set by auth middleware)
        user_id = self._get_user_id(request)
        if user_id is not None:
            return f"user:{user_id}"

        # Fall back to IP address.
        # NOTE(review): X-Forwarded-For is supplied by the client unless a
        # trusted proxy overwrites it. If this app is reachable directly,
        # a client can rotate the header value to evade IP-based limits -
        # confirm the deployment before relying on it.
        forwarded_for = request.headers.get("X-Forwarded-For", "")
        # Take the first IP in the chain
        client_ip = forwarded_for.split(",")[0].strip()
        if not client_ip:
            # Header missing or its first element empty: use the socket peer.
            client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}"

    def _match_endpoint(self, path: str) -> Optional[str]:
        """Find the ``ENDPOINT_LIMITS`` key that applies to a path.

        An exact match wins; otherwise the longest key that is a prefix of
        the path is used, so the result does not depend on dict order.

        Args:
            path: Request path

        Returns:
            The matching key, or None if no endpoint-specific rule applies
        """
        if path in self.ENDPOINT_LIMITS:
            return path
        return max(
            (
                endpoint_path
                for endpoint_path in self.ENDPOINT_LIMITS
                if path.startswith(endpoint_path)
            ),
            key=len,
            default=None,
        )

    def _get_endpoint_config(self, path: str) -> RateLimitConfig:
        """Get rate limit configuration for endpoint.

        Args:
            path: Request path

        Returns:
            The endpoint-specific configuration if one matches, otherwise
            the default configuration
        """
        endpoint = self._match_endpoint(path)
        if endpoint is None:
            return self.default_config
        return self.ENDPOINT_LIMITS[endpoint]

    def _is_authenticated(self, request: Request) -> bool:
        """Check if request is from authenticated user.

        Uses the same user ID lookup as ``_get_identifier`` so that the
        authenticated multiplier is only applied to ``user:`` buckets.

        Args:
            request: HTTP request

        Returns:
            True if user is authenticated
        """
        return self._get_user_id(request) is not None