"""Global rate limiting middleware.

Implements per-IP request rate limiting for all endpoints using a configurable
sliding window algorithm. Intercepts requests before they reach route handlers
and blocks those exceeding the per-IP limit with a 429 response.

Rate limits can be customized per endpoint or use a global default. IP
addresses are extracted using the same trusted-proxy-aware logic as
authentication to ensure consistent behavior across all rate limiting.

**Process-local implementation** — Each worker process maintains its own
independent counter store. In multi-worker deployments (N workers), an attacker
can send up to N × limit requests before any single worker triggers the limit.
This is a fundamental limitation of in-process stores.

**Short-term mitigation:** Deploy with a single worker (enforced by the
scheduler lock). The startup warning log documents this constraint.

**Long-term solution:** Replace the in-process GlobalRateLimiter with a
Redis-backed adapter that uses atomic INCR + EXPIRE semantics. The
check_allowed() and check_allowed_for_bucket() interfaces are designed to make
this swap possible without touching middleware or router code.

Processing order
----------------
This middleware must be the innermost in the security-critical chain:

    CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware

Rate limiting is last so that requests blocked by CsrfMiddleware do not consume
rate-limit budget, and so that rate-limit log entries (which are unusual and
potentially suspicious) always carry a correlation ID for tracing.
"""

from __future__ import annotations

import math
from typing import TYPE_CHECKING

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from app.exceptions import RateLimitError
from app.utils.client_ip import get_client_ip
from app.utils.logging_compat import get_logger

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from starlette.requests import Request

    from app.config import Settings
    from app.utils.rate_limiter import GlobalRateLimiter

log = get_logger(__name__)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Enforce per-IP request rate limiting on matching endpoints.

    Tracks requests per IP and blocks further requests if the limit is
    exceeded. Uses the application's GlobalRateLimiter instance and
    trusted-proxy settings for consistent IP extraction.

    Each middleware instance is scoped to a set of path prefixes (or all paths
    if no prefixes are given). This allows multiple instances to coexist
    without double-counting requests.
    """

    def __init__(
        self,
        app: object,
        rate_limiter: GlobalRateLimiter,
        settings: Settings,
        bucket_override: str | None = None,
        bucket_max_requests: int | None = None,
        bucket_window_seconds: int | None = None,
        path_prefixes: list[str] | None = None,
        skip_paths: list[str] | None = None,
    ) -> None:
        """Initialize the rate limit middleware.

        Args:
            app: The FastAPI application.
            rate_limiter: The GlobalRateLimiter instance to use for checking
                limits.
            settings: Application settings (used for trusted proxies).
            bucket_override: Optional named bucket to use instead of the
                default limiter.
            bucket_max_requests: Max requests for the bucket override.
            bucket_window_seconds: Window for the bucket override.
            path_prefixes: If provided, only apply rate limiting to paths that
                start with one of these prefixes. If ``None`` (or empty), all
                paths are matched.
            skip_paths: If provided, do not apply rate limiting to paths that
                start with one of these prefixes. Takes precedence over
                ``path_prefixes``: a path matching both is skipped.
        """
        super().__init__(app)  # type: ignore[arg-type]
        self.rate_limiter: GlobalRateLimiter = rate_limiter
        self.settings: Settings = settings
        self.bucket_override = bucket_override
        self.bucket_max_requests = bucket_max_requests
        self.bucket_window_seconds = bucket_window_seconds
        self.path_prefixes = path_prefixes or []
        self.skip_paths = skip_paths or []

    def _should_check(self, path: str) -> bool:
        """Return whether the given path should be rate-limited by this instance.

        Args:
            path: The request URL path.

        Returns:
            ``True`` if this instance should enforce its limit on the path.
        """
        # str.startswith accepts a tuple of candidates; an empty tuple never
        # matches, so an empty skip list excludes nothing.
        if path.startswith(tuple(self.skip_paths)):
            return False
        if self.path_prefixes:
            return path.startswith(tuple(self.path_prefixes))
        # No prefixes configured: this instance covers every non-skipped path.
        return True

    @staticmethod
    def _rate_limited_response(request: Request, retry_after: float) -> JSONResponse:
        """Build the 429 JSON response returned to a rate-limited client.

        Args:
            request: The blocked request; ``request.state.correlation_id`` is
                echoed in the body when present (``None`` otherwise).
            retry_after: Wait time reported by the rate limiter. Presumably
                seconds as an int or float — confirm against
                GlobalRateLimiter.

        Returns:
            A ``JSONResponse`` with status 429 and a ``Retry-After`` header.
        """
        rate_limit_error = RateLimitError(
            "Too many requests. Please try again later.",
            retry_after_seconds=retry_after,
        )
        # Round *up* to whole seconds: truncating a fractional wait (e.g. 0.4)
        # would advertise ``Retry-After: 0`` and invite an immediate retry
        # that is still inside the blocked window.
        retry_after_header = str(math.ceil(retry_after))
        return JSONResponse(
            status_code=429,
            content={
                "code": "rate_limit_exceeded",
                "detail": str(rate_limit_error),
                "metadata": rate_limit_error.get_error_metadata(),
                "correlation_id": getattr(request.state, "correlation_id", None),
            },
            headers={"Retry-After": retry_after_header},
        )

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Check rate limit before passing request to next middleware/handler.

        If the client IP has exceeded the request limit, returns a 429 response
        immediately. Otherwise passes the request through normally.

        Args:
            request: The incoming HTTP request.
            call_next: Callable to pass the request to the next
                middleware/handler.

        Returns:
            A response object (either rate limit response or from handler).
        """
        path = request.url.path
        if not self._should_check(path):
            return await call_next(request)

        client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)

        # The named bucket is used only when all three bucket settings are
        # truthy; a partial (or zero-valued) bucket configuration falls back
        # to the limiter's default check.
        if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds:
            is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
                self.bucket_override,
                client_ip,
                self.bucket_max_requests,
                self.bucket_window_seconds,
            )
        else:
            is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)

        if not is_allowed:
            log.warning(
                "global_rate_limit_exceeded",
                client_ip=client_ip,
                path=path,
                method=request.method,
                retry_after=retry_after,
            )
            return self._rate_limited_response(request, retry_after)

        response: Response = await call_next(request)
        return response