"""Rate limiting middleware for API endpoints.

This module provides in-memory, sliding-window rate limiting with support for:

- Endpoint-specific rate limits (each limited endpoint has its own bucket)
- IP-based limiting for anonymous clients
- User-based limiting for authenticated clients
- Higher limits for authenticated users (via a configurable multiplier)
- Bypass paths that are never rate limited (health checks, docs, static
  assets, websockets)
"""

import math
import time
from collections import defaultdict, deque
from typing import Callable, Deque, Dict, FrozenSet, Iterable, Optional, Tuple

from fastapi import Request, status
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse


class RateLimitConfig:
    """Configuration for rate limiting rules."""

    def __init__(
        self,
        requests_per_minute: int = 60,
        requests_per_hour: int = 1000,
        authenticated_multiplier: float = 2.0,
    ):
        """Initialize rate limit configuration.

        Args:
            requests_per_minute: Max requests per minute for unauthenticated
                users.
            requests_per_hour: Max requests per hour for unauthenticated
                users.
            authenticated_multiplier: Factor applied to both limits for
                authenticated users (the products are truncated to ``int``).
        """
        self.requests_per_minute = requests_per_minute
        self.requests_per_hour = requests_per_hour
        self.authenticated_multiplier = authenticated_multiplier

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}("
            f"requests_per_minute={self.requests_per_minute!r}, "
            f"requests_per_hour={self.requests_per_hour!r}, "
            f"authenticated_multiplier={self.authenticated_multiplier!r})"
        )


class RateLimitStore:
    """In-memory sliding-window store for rate limit tracking.

    For every identifier, two deques hold the timestamps (from
    ``time.monotonic()``) of admitted requests, oldest first: one trimmed to
    the last minute and one trimmed to the last hour. State lives only in
    this object, so separate processes keep separate counts.
    """

    # Window lengths, in seconds.
    MINUTE_WINDOW = 60.0
    HOUR_WINDOW = 3600.0
    # How often (in seconds) to purge identifiers that stopped sending
    # requests. Without a sweep, every identifier ever seen (e.g. each
    # distinct client IP) would stay in memory indefinitely.
    SWEEP_INTERVAL = 60.0

    def __init__(self):
        """Initialize the rate limit store."""
        # Store format: {identifier: deque([timestamp, ...])}, oldest first.
        self._minute_store: Dict[str, Deque[float]] = defaultdict(deque)
        self._hour_store: Dict[str, Deque[float]] = defaultdict(deque)
        self._last_sweep = time.monotonic()

    def check_limit(
        self,
        identifier: str,
        max_per_minute: int,
        max_per_hour: int,
    ) -> Tuple[bool, Optional[int]]:
        """Check if the identifier has exceeded rate limits.

        This does not record the request; call :meth:`record_request` once
        the request is admitted.

        Args:
            identifier: Unique identifier (IP or user ID, optionally scoped
                to an endpoint).
            max_per_minute: Maximum requests allowed per minute.
            max_per_hour: Maximum requests allowed per hour.

        Returns:
            Tuple of (allowed, retry_after_seconds). ``retry_after_seconds``
            is ``None`` when allowed; otherwise it is a whole number of
            seconds (at least 1) after which every exceeded window has a
            free slot.
        """
        current_time = time.monotonic()
        self._sweep_stale_identifiers(current_time)
        self._cleanup_old_entries(identifier, current_time)

        # ``.get`` instead of ``[]`` so a check never re-creates an empty
        # entry in the defaultdicts.
        minute_entries = self._minute_store.get(identifier, ())
        hour_entries = self._hour_store.get(identifier, ())

        waits = []
        if len(minute_entries) >= max_per_minute:
            waits.append(
                self._seconds_until_slot(
                    minute_entries, max_per_minute,
                    self.MINUTE_WINDOW, current_time,
                )
            )
        if len(hour_entries) >= max_per_hour:
            waits.append(
                self._seconds_until_slot(
                    hour_entries, max_per_hour,
                    self.HOUR_WINDOW, current_time,
                )
            )

        if waits:
            # Report the longest wait so a client that honors Retry-After
            # is not immediately rejected again by the other window.
            return False, max(waits)
        return True, None

    def record_request(self, identifier: str) -> None:
        """Record a request for the identifier.

        Args:
            identifier: Unique identifier (IP or user ID, optionally scoped
                to an endpoint).
        """
        current_time = time.monotonic()
        self._minute_store[identifier].append(current_time)
        self._hour_store[identifier].append(current_time)

    def get_remaining_requests(
        self, identifier: str, max_per_minute: int, max_per_hour: int
    ) -> Tuple[int, int]:
        """Get remaining requests for the identifier.

        Expired entries are trimmed first so the counts are current even if
        :meth:`check_limit` was not called just before.

        Args:
            identifier: Unique identifier.
            max_per_minute: Maximum per minute.
            max_per_hour: Maximum per hour.

        Returns:
            Tuple of (remaining_per_minute, remaining_per_hour), each >= 0.
        """
        self._cleanup_old_entries(identifier, time.monotonic())
        minute_used = len(self._minute_store.get(identifier, ()))
        hour_used = len(self._hour_store.get(identifier, ()))
        return (
            max(0, max_per_minute - minute_used),
            max(0, max_per_hour - hour_used),
        )

    @staticmethod
    def _seconds_until_slot(
        entries, limit: int, window: float, current_time: float
    ) -> int:
        """Return whole seconds until fewer than ``limit`` entries remain.

        Args:
            entries: Timestamps in the window, oldest first.
            limit: Maximum entries allowed in the window.
            window: Window length in seconds.
            current_time: Current ``time.monotonic()`` value.
        """
        if limit <= 0:
            # A non-positive limit never frees a slot; ask for a full window
            # instead of indexing into a possibly empty sequence.
            return max(math.ceil(window), 1)
        # Entries expire oldest-first, so a slot opens once the entry at
        # index ``len(entries) - limit`` ages out of the window.
        blocking = entries[len(entries) - limit]
        # Round up: truncating would tell the client to retry slightly
        # before the slot actually opens.
        return max(math.ceil(window - (current_time - blocking)), 1)

    def _cleanup_old_entries(
        self, identifier: str, current_time: float
    ) -> None:
        """Remove entries older than the time windows.

        Identifiers whose deques become empty are deleted outright.

        Args:
            identifier: Unique identifier.
            current_time: Current ``time.monotonic()`` value.
        """
        for store, window in (
            (self._minute_store, self.MINUTE_WINDOW),
            (self._hour_store, self.HOUR_WINDOW),
        ):
            entries = store.get(identifier)
            if entries is None:
                continue
            cutoff = current_time - window
            # Timestamps come from a monotonic clock and are appended in
            # order, so expired ones are always at the left end.
            while entries and entries[0] <= cutoff:
                entries.popleft()
            if not entries:
                del store[identifier]

    def _sweep_stale_identifiers(self, current_time: float) -> None:
        """Trim every identifier at most once per ``SWEEP_INTERVAL``.

        Args:
            current_time: Current ``time.monotonic()`` value.
        """
        if current_time - self._last_sweep < self.SWEEP_INTERVAL:
            return
        self._last_sweep = current_time
        # Iterate over a snapshot: cleanup deletes emptied entries.
        for identifier in set(self._minute_store) | set(self._hour_store):
            self._cleanup_old_entries(identifier, current_time)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Middleware for API rate limiting."""

    # Endpoint-specific rate limits (overrides defaults). Each entry is
    # matched as a path prefix and counted in its own bucket.
    ENDPOINT_LIMITS: Dict[str, RateLimitConfig] = {
        "/api/auth/login": RateLimitConfig(
            requests_per_minute=5,
            requests_per_hour=20,
        ),
        "/api/auth/register": RateLimitConfig(
            requests_per_minute=3,
            requests_per_hour=10,
        ),
        "/api/download": RateLimitConfig(
            requests_per_minute=10,
            requests_per_hour=100,
            authenticated_multiplier=3.0,
        ),
    }

    # Paths that bypass rate limiting (the path itself and anything below
    # it as a path segment).
    BYPASS_PATHS = {
        "/health",
        "/health/detailed",
        "/docs",
        "/redoc",
        "/openapi.json",
        "/static",
        "/ws",
    }

    def __init__(
        self,
        app,
        default_config: Optional[RateLimitConfig] = None,
        trusted_proxies: Optional[Iterable[str]] = None,
    ):
        """Initialize rate limiting middleware.

        Args:
            app: FastAPI application.
            default_config: Default rate limit configuration.
            trusted_proxies: Exact peer addresses (as reported by
                ``request.client.host``) of reverse proxies whose
                ``X-Forwarded-For`` header may be trusted. ``None`` keeps the
                legacy behavior of trusting the header from any peer.
        """
        super().__init__(app)
        self.default_config = default_config or RateLimitConfig()
        self.store = RateLimitStore()
        self.trusted_proxies: Optional[FrozenSet[str]] = (
            frozenset(trusted_proxies) if trusted_proxies is not None else None
        )

    async def dispatch(self, request: Request, call_next: Callable):
        """Process request and apply rate limiting.

        Args:
            request: Incoming HTTP request.
            call_next: Next middleware or endpoint handler.

        Returns:
            HTTP response (either a 429 rate limit error or the normal
            response with ``X-RateLimit-*`` headers added).
        """
        path = request.url.path

        # Check if path should bypass rate limiting
        if self._should_bypass(path):
            return await call_next(request)

        # Get identifier (user ID if authenticated, otherwise IP)
        identifier = self._get_identifier(request)

        # Get rate limit configuration for this endpoint
        config = self._get_endpoint_config(path)

        # Endpoint-specific limits get their own bucket so ordinary traffic
        # does not eat into e.g. the login allowance (and vice versa).
        scope = self._match_endpoint(path)
        bucket = identifier if scope is None else f"{identifier}|{scope}"

        # Apply authenticated user multiplier if applicable
        multiplier = (
            config.authenticated_multiplier
            if self._is_authenticated(request)
            else 1.0
        )
        max_per_minute = int(config.requests_per_minute * multiplier)
        max_per_hour = int(config.requests_per_hour * multiplier)

        # Check rate limit
        allowed, retry_after = self.store.check_limit(
            bucket,
            max_per_minute,
            max_per_hour,
        )

        if not allowed:
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={"detail": "Rate limit exceeded"},
                headers={"Retry-After": str(retry_after)},
            )

        # Record the request and snapshot the remaining quota at admission
        # time, before awaiting the handler, so concurrent requests from the
        # same client do not skew this response's headers.
        self.store.record_request(bucket)
        minute_remaining, hour_remaining = self.store.get_remaining_requests(
            bucket, max_per_minute, max_per_hour
        )

        # Add rate limit headers to response
        response = await call_next(request)
        response.headers["X-RateLimit-Limit-Minute"] = str(max_per_minute)
        response.headers["X-RateLimit-Limit-Hour"] = str(max_per_hour)
        response.headers["X-RateLimit-Remaining-Minute"] = str(
            minute_remaining
        )
        response.headers["X-RateLimit-Remaining-Hour"] = str(hour_remaining)

        return response

    @staticmethod
    def _path_within(path: str, prefix: str) -> bool:
        """Return True if ``path`` equals ``prefix`` or is a sub-path of it.

        Matching stops at path-segment boundaries, so ``/docs`` covers
        ``/docs`` and ``/docs/x`` but not ``/docsearch``.
        """
        if path == prefix:
            return True
        return path.startswith(prefix.rstrip("/") + "/")

    def _should_bypass(self, path: str) -> bool:
        """Check if path should bypass rate limiting.

        Args:
            path: Request path.

        Returns:
            True if path equals, or lies beneath, one of ``BYPASS_PATHS``.
        """
        return any(
            self._path_within(path, bypass_path)
            for bypass_path in self.BYPASS_PATHS
        )

    def _get_user_id(self, request: Request):
        """Return the authenticated user ID, or None for anonymous requests.

        Reads ``request.state.user_id`` (presumably set by an upstream auth
        middleware). Falsy values are treated as anonymous so that
        authentication detection and identifier selection always agree.
        """
        user_id = getattr(request.state, "user_id", None)
        return user_id if user_id else None

    def _get_identifier(self, request: Request) -> str:
        """Get unique identifier for rate limiting.

        Args:
            request: HTTP request.

        Returns:
            ``"user:<id>"`` for authenticated requests, otherwise
            ``"ip:<address>"``.
        """
        user_id = self._get_user_id(request)
        if user_id is not None:
            return f"user:{user_id}"
        return f"ip:{self._get_client_ip(request)}"

    def _get_client_ip(self, request: Request) -> str:
        """Resolve the client IP, honoring ``X-Forwarded-For`` where allowed.

        Args:
            request: HTTP request.

        Returns:
            Best-effort client address, or ``"unknown"`` if none is known.
        """
        peer = request.client.host if request.client else "unknown"
        forwarded_for = request.headers.get("X-Forwarded-For")
        if not forwarded_for:
            return peer

        hops = [hop.strip() for hop in forwarded_for.split(",") if hop.strip()]
        if not hops:
            return peer

        if self.trusted_proxies is None:
            # Legacy behavior: trust the header from any peer and take the
            # left-most (client-claimed) address. WARNING: clients can spoof
            # this header to evade IP-based limits; configure
            # ``trusted_proxies`` to harden this.
            return hops[0]

        if peer not in self.trusted_proxies:
            # The header did not come from a proxy we trust; ignore it.
            return peer

        # Each trusted proxy appends the address it received the request
        # from, so walking right-to-left the first untrusted hop is the
        # real client.
        for hop in reversed(hops):
            if hop not in self.trusted_proxies:
                return hop
        return hops[0]

    def _match_endpoint(self, path: str) -> Optional[str]:
        """Return the ``ENDPOINT_LIMITS`` key that applies to ``path``.

        Plain prefix matching (no segment boundary) intentionally errs on the
        side of applying the stricter endpoint limit. When several prefixes
        match, the longest one wins.

        Args:
            path: Request path.

        Returns:
            The matching key, or None if the default configuration applies.
        """
        if path in self.ENDPOINT_LIMITS:
            return path
        candidates = [
            endpoint_path
            for endpoint_path in self.ENDPOINT_LIMITS
            if path.startswith(endpoint_path)
        ]
        return max(candidates, key=len) if candidates else None

    def _get_endpoint_config(self, path: str) -> RateLimitConfig:
        """Get rate limit configuration for endpoint.

        Args:
            path: Request path.

        Returns:
            The endpoint-specific configuration if one matches, otherwise
            the default configuration.
        """
        scope = self._match_endpoint(path)
        if scope is None:
            return self.default_config
        return self.ENDPOINT_LIMITS[scope]

    def _is_authenticated(self, request: Request) -> bool:
        """Check if request is from authenticated user.

        Args:
            request: HTTP request.

        Returns:
            True if ``request.state.user_id`` is set to a truthy value.
        """
        return self._get_user_id(request) is not None