- Migrate settings.py to Pydantic V2 (SettingsConfigDict, validation_alias) - Update config models to use @field_validator with @classmethod - Replace deprecated datetime.utcnow() with datetime.now(timezone.utc) - Migrate FastAPI app from @app.on_event to lifespan context manager - Implement comprehensive rate limiting middleware with: * Endpoint-specific rate limits (login: 5/min, register: 3/min) * IP-based and user-based tracking * Authenticated user multiplier (2x limits) * Bypass paths for health, docs, static, websocket endpoints * Rate limit headers in responses - Add 13 comprehensive tests for rate limiting (all passing) - Update instructions.md to mark completed tasks - Fix asyncio.create_task usage in anime_service.py All 714 tests passing. No deprecation warnings.
332 lines
10 KiB
Python
332 lines
10 KiB
Python
"""Rate limiting middleware for API endpoints.
|
|
|
|
This module provides comprehensive rate limiting with support for:
|
|
- Endpoint-specific rate limits
|
|
- IP-based limiting
|
|
- User-based rate limiting
|
|
- Bypass mechanisms for authenticated users
|
|
"""
|
|
|
|
import time
|
|
from collections import defaultdict
|
|
from typing import Callable, Dict, Optional, Tuple
|
|
|
|
from fastapi import Request, status
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.responses import JSONResponse
|
|
|
|
|
|
class RateLimitConfig:
    """Configuration for rate limiting rules.

    Holds the per-minute and per-hour request budgets for unauthenticated
    clients plus the factor by which those budgets are scaled for
    authenticated users.
    """

    def __init__(
        self,
        requests_per_minute: int = 60,
        requests_per_hour: int = 1000,
        authenticated_multiplier: float = 2.0,
    ):
        """Initialize rate limit configuration.

        Args:
            requests_per_minute: Max requests per minute for
                unauthenticated users
            requests_per_hour: Max requests per hour for
                unauthenticated users
            authenticated_multiplier: Multiplier for authenticated users

        Raises:
            ValueError: If any of the values is negative. Zero is accepted
                and means "never allow".
        """
        # Fail fast on nonsensical values: a negative limit can never be
        # satisfied and would otherwise only surface at request time.
        if requests_per_minute < 0:
            raise ValueError("requests_per_minute must be >= 0")
        if requests_per_hour < 0:
            raise ValueError("requests_per_hour must be >= 0")
        if authenticated_multiplier < 0:
            raise ValueError("authenticated_multiplier must be >= 0")

        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        self.authenticated_multiplier = authenticated_multiplier

    def __repr__(self) -> str:
        """Return a constructor-style representation for debugging/logging."""
        return (
            f"{type(self).__name__}("
            f"requests_per_minute={self.requests_per_minute!r}, "
            f"requests_per_hour={self.requests_per_hour!r}, "
            f"authenticated_multiplier={self.authenticated_multiplier!r})"
        )
|
|
|
|
|
|
class RateLimitStore:
    """In-memory sliding-window store for rate limit tracking.

    For every identifier two lists of request timestamps (as returned by
    ``time.time()``) are kept: one for the last minute and one for the
    last hour. Timestamps are appended in arrival order, so index 0 is
    always the oldest request still inside the window.
    """

    # Window lengths in seconds.
    MINUTE_WINDOW = 60
    HOUR_WINDOW = 3600
    # Minimum number of seconds between two full sweeps of stale identifiers.
    SWEEP_INTERVAL = 60

    def __init__(self):
        """Initialize the rate limit store."""
        # Store format: {identifier: [timestamp, timestamp, ...]}
        self._minute_store: Dict[str, list] = defaultdict(list)
        self._hour_store: Dict[str, list] = defaultdict(list)
        self._last_sweep = time.time()

    def check_limit(
        self,
        identifier: str,
        max_per_minute: int,
        max_per_hour: int,
    ) -> Tuple[bool, Optional[int]]:
        """Check if the identifier has exceeded rate limits.

        Expired timestamps for the identifier are dropped first. This
        method does not record a request; call ``record_request`` for
        that once the request has been admitted.

        Args:
            identifier: Unique identifier (IP or user ID)
            max_per_minute: Maximum requests allowed per minute
            max_per_hour: Maximum requests allowed per hour

        Returns:
            Tuple of (allowed, retry_after_seconds). ``retry_after_seconds``
            is None when the request is allowed and at least 1 otherwise.
        """
        current_time = time.time()

        # Bound memory: periodically drop identifiers that stopped sending
        # requests, since per-identifier cleanup only runs when that same
        # identifier comes back.
        self._sweep_if_due(current_time)
        self._cleanup_old_entries(identifier, current_time)

        # Read with .get(): indexing the defaultdict would re-create the
        # empty entries that the cleanup above has just removed.
        minute_entries = self._minute_store.get(identifier, ())
        if len(minute_entries) >= max_per_minute:
            return False, self._retry_after(
                minute_entries, self.MINUTE_WINDOW, current_time
            )

        hour_entries = self._hour_store.get(identifier, ())
        if len(hour_entries) >= max_per_hour:
            return False, self._retry_after(
                hour_entries, self.HOUR_WINDOW, current_time
            )

        return True, None

    def record_request(self, identifier: str) -> None:
        """Record a request for the identifier.

        Args:
            identifier: Unique identifier (IP or user ID)
        """
        current_time = time.time()
        self._minute_store[identifier].append(current_time)
        self._hour_store[identifier].append(current_time)

    def get_remaining_requests(
        self, identifier: str, max_per_minute: int, max_per_hour: int
    ) -> Tuple[int, int]:
        """Get remaining requests for the identifier.

        Only timestamps that are still inside their window are counted;
        the store itself is not modified.

        Args:
            identifier: Unique identifier
            max_per_minute: Maximum per minute
            max_per_hour: Maximum per hour

        Returns:
            Tuple of (remaining_per_minute, remaining_per_hour), each
            clamped to a minimum of 0.
        """
        current_time = time.time()
        minute_cutoff = current_time - self.MINUTE_WINDOW
        hour_cutoff = current_time - self.HOUR_WINDOW

        minute_used = sum(
            1
            for ts in self._minute_store.get(identifier, ())
            if ts > minute_cutoff
        )
        hour_used = sum(
            1
            for ts in self._hour_store.get(identifier, ())
            if ts > hour_cutoff
        )
        return (
            max(0, max_per_minute - minute_used),
            max(0, max_per_hour - hour_used),
        )

    def _cleanup_old_entries(
        self, identifier: str, current_time: float
    ) -> None:
        """Remove entries older than the time windows.

        Identifiers left without any timestamp are deleted from the
        corresponding store.

        Args:
            identifier: Unique identifier
            current_time: Current timestamp
        """
        windows = (
            (self._minute_store, self.MINUTE_WINDOW),
            (self._hour_store, self.HOUR_WINDOW),
        )
        for store, window in windows:
            entries = store.get(identifier)
            if entries is None:
                continue
            cutoff = current_time - window
            fresh = [ts for ts in entries if ts > cutoff]
            if fresh:
                store[identifier] = fresh
            else:
                del store[identifier]

    def _sweep_if_due(self, current_time: float) -> None:
        """Purge expired entries of all identifiers, at most once per interval.

        Args:
            current_time: Current timestamp
        """
        if current_time - self._last_sweep < self.SWEEP_INTERVAL:
            return
        self._last_sweep = current_time
        # Snapshot the keys: the cleanup deletes from both dicts.
        for known in set(self._minute_store) | set(self._hour_store):
            self._cleanup_old_entries(known, current_time)

    @staticmethod
    def _retry_after(entries, window: int, current_time: float) -> int:
        """Return the seconds until the oldest entry leaves the window.

        Args:
            entries: Timestamps currently inside the window, oldest first
            window: Window length in seconds
            current_time: Current timestamp

        Returns:
            Whole seconds to wait, at least 1. When the window is empty
            (only possible with a limit <= 0) the full window length is
            returned instead of indexing into an empty sequence.
        """
        if not entries:
            return window
        return max(int(window - (current_time - entries[0])), 1)
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware for API rate limiting.

    Each request is attributed to a client identifier (user ID when
    ``request.state.user_id`` is set, otherwise the client IP). Paths
    listed in ``ENDPOINT_LIMITS`` get their own limits and their own
    counters; everything else shares the default limits and one counter
    per client.
    """

    # Endpoint-specific rate limits (overrides defaults)
    ENDPOINT_LIMITS: Dict[str, RateLimitConfig] = {
        "/api/auth/login": RateLimitConfig(
            requests_per_minute=5,
            requests_per_hour=20,
        ),
        "/api/auth/register": RateLimitConfig(
            requests_per_minute=3,
            requests_per_hour=10,
        ),
        "/api/download": RateLimitConfig(
            requests_per_minute=10,
            requests_per_hour=100,
            authenticated_multiplier=3.0,
        ),
    }

    # Paths that bypass rate limiting
    BYPASS_PATHS = {
        "/health",
        "/health/detailed",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/static",
        "/ws",
    }

    def __init__(
        self,
        app,
        default_config: Optional[RateLimitConfig] = None,
        trust_forwarded_for: bool = True,
    ):
        """Initialize rate limiting middleware.

        Args:
            app: FastAPI application
            default_config: Default rate limit configuration
            trust_forwarded_for: Whether the first address of the
                ``X-Forwarded-For`` header is used as the client IP.
                Defaults to True for backward compatibility.
        """
        super().__init__(app)
        self.default_config = default_config or RateLimitConfig()
        # SECURITY NOTE: X-Forwarded-For is client-controlled unless a
        # trusted proxy overwrites it. When the app is directly reachable,
        # pass trust_forwarded_for=False, otherwise a client can dodge the
        # limits by sending a different header value on every request.
        self.trust_forwarded_for = trust_forwarded_for
        self.store = RateLimitStore()

    async def dispatch(self, request: Request, call_next: Callable):
        """Process request and apply rate limiting.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware or endpoint handler

        Returns:
            HTTP response (either rate limit error or normal response)
        """
        path = request.url.path

        # Check if path should bypass rate limiting
        if self._should_bypass(path):
            return await call_next(request)

        # Get identifier (user ID if authenticated, otherwise IP)
        identifier = self._get_identifier(request)

        # Get rate limit configuration for this endpoint
        config = self._get_endpoint_config(path)

        # Endpoints with their own limits are counted in their own bucket.
        # With one shared counter per client, ordinary traffic to other
        # endpoints would use up e.g. the 5/min login budget.
        endpoint_key = self._match_endpoint(path)
        bucket = (
            identifier
            if endpoint_key is None
            else f"{identifier}|{endpoint_key}"
        )

        # Apply authenticated user multiplier if applicable
        multiplier = (
            config.authenticated_multiplier
            if self._is_authenticated(request)
            else 1.0
        )
        max_per_minute = int(config.requests_per_minute * multiplier)
        max_per_hour = int(config.requests_per_hour * multiplier)

        # Check rate limit
        allowed, retry_after = self.store.check_limit(
            bucket,
            max_per_minute,
            max_per_hour,
        )

        if not allowed:
            rejection_headers = {"Retry-After": str(retry_after)}
            rejection_headers.update(
                self._rate_limit_headers(bucket, max_per_minute, max_per_hour)
            )
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={"detail": "Rate limit exceeded"},
                headers=rejection_headers,
            )

        # Record the request before handing over, so requests that are
        # still being processed already count against the limit.
        self.store.record_request(bucket)

        # Add rate limit headers to response
        response = await call_next(request)
        limit_headers = self._rate_limit_headers(
            bucket, max_per_minute, max_per_hour
        )
        for header_name, header_value in limit_headers.items():
            response.headers[header_name] = header_value

        return response

    def _rate_limit_headers(
        self, bucket: str, max_per_minute: int, max_per_hour: int
    ) -> Dict[str, str]:
        """Build the ``X-RateLimit-*`` headers for a bucket.

        Args:
            bucket: Key under which the requests are counted
            max_per_minute: Effective per-minute limit
            max_per_hour: Effective per-hour limit

        Returns:
            Mapping of header name to header value
        """
        minute_remaining, hour_remaining = self.store.get_remaining_requests(
            bucket, max_per_minute, max_per_hour
        )
        return {
            "X-RateLimit-Limit-Minute": str(max_per_minute),
            "X-RateLimit-Limit-Hour": str(max_per_hour),
            "X-RateLimit-Remaining-Minute": str(minute_remaining),
            "X-RateLimit-Remaining-Hour": str(hour_remaining),
        }

    def _should_bypass(self, path: str) -> bool:
        """Check if path should bypass rate limiting.

        A path bypasses limiting when it equals a bypass path or lies
        below it (``/static`` covers ``/static/app.js``). A plain string
        prefix is not enough: ``/wsfoo`` must not be exempt because of
        ``/ws``.

        Args:
            path: Request path

        Returns:
            True if path should bypass rate limiting
        """
        for bypass_path in self.BYPASS_PATHS:
            prefix = bypass_path.rstrip("/")
            if path == prefix or path.startswith(prefix + "/"):
                return True
        return False

    def _get_identifier(self, request: Request) -> str:
        """Get unique identifier for rate limiting.

        Args:
            request: HTTP request

        Returns:
            Unique identifier (user ID or IP address)
        """
        # Try to get user ID from request state (set by auth middleware).
        # Same "is not None" test as _is_authenticated, so a falsy ID such
        # as 0 is not silently treated as anonymous here.
        user_id = getattr(request.state, "user_id", None)
        if user_id is not None:
            return f"user:{user_id}"

        # Fall back to IP address
        # Check for X-Forwarded-For header (proxy/load balancer)
        if self.trust_forwarded_for:
            forwarded_for = request.headers.get("X-Forwarded-For")
            if forwarded_for:
                # Take the first IP in the chain
                first_hop = forwarded_for.split(",")[0].strip()
                # An empty first element would put every such client into
                # one shared "ip:" bucket; use the socket peer instead.
                if first_hop:
                    return f"ip:{first_hop}"

        client_ip = request.client.host if request.client else "unknown"
        return f"ip:{client_ip}"

    def _match_endpoint(self, path: str) -> Optional[str]:
        """Find the ``ENDPOINT_LIMITS`` key that applies to a path.

        Args:
            path: Request path

        Returns:
            The exactly matching key, otherwise the longest key that is a
            prefix of the path, otherwise None.
        """
        # Check for exact match
        if path in self.ENDPOINT_LIMITS:
            return path

        # Check for prefix match; the longest prefix wins so the result
        # does not depend on dict ordering when prefixes overlap.
        candidates = [
            endpoint_path
            for endpoint_path in self.ENDPOINT_LIMITS
            if path.startswith(endpoint_path)
        ]
        return max(candidates, key=len) if candidates else None

    def _get_endpoint_config(self, path: str) -> RateLimitConfig:
        """Get rate limit configuration for endpoint.

        Args:
            path: Request path

        Returns:
            Rate limit configuration
        """
        endpoint_key = self._match_endpoint(path)
        if endpoint_key is None:
            return self.default_config
        return self.ENDPOINT_LIMITS[endpoint_key]

    def _is_authenticated(self, request: Request) -> bool:
        """Check if request is from authenticated user.

        Args:
            request: HTTP request

        Returns:
            True if user is authenticated
        """
        return getattr(request.state, "user_id", None) is not None
|