Files
BanGUI/backend/app/middleware/rate_limit.py
Lukas 7308ff88d6 fix(rate-limit): stop double-counting requests in middleware
Multiple RateLimitMiddleware instances were each calling
check_allowed() on every request, halving the effective global
limit (200 req/min became ~100). Added path_prefixes and skip_paths
so each instance only checks the paths it owns.

- Auth middleware scoped to /api/v1/auth/login and /api/v1/setup
- History middleware scoped to /api/v1/history
- Global middleware skips auth and history paths
- Updated tests to match single-count behavior
2026-05-15 23:04:02 +02:00

179 lines
6.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Global rate limiting middleware.

Implements per-IP request rate limiting using a configurable sliding window
algorithm. Intercepts requests before they reach route handlers and blocks
those exceeding the per-IP limit with a 429 response.

Each middleware instance can be scoped to a set of path prefixes (minus an
optional set of skipped prefixes). It uses either the global default limit or
a named bucket with its own request limit and window. Scoping lets several
instances coexist without counting the same request twice.

IP addresses are extracted using the same trusted-proxy-aware logic as
authentication to ensure consistent behavior across all rate limiting.

**Process-local implementation** — Each worker process maintains its own
independent counter store. In multi-worker deployments (N workers), an
attacker can send up to N * limit requests before any single worker triggers
the limit. This is a fundamental limitation of in-process stores.

**Short-term mitigation:** Deploy with a single worker (enforced by the
scheduler lock). The startup warning log documents this constraint.

**Long-term solution:** Replace the in-process GlobalRateLimiter with a
Redis-backed adapter that uses atomic INCR + EXPIRE semantics. The
check_allowed() and check_allowed_for_bucket() interfaces are designed
to make this swap possible without touching middleware or router code.

Processing order
----------------
This middleware must be the innermost in the security-critical chain::

    CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware

Rate limiting runs last for two reasons. Requests blocked by CsrfMiddleware
do not consume rate-limit budget. Rate-limit log entries (which are unusual
and potentially suspicious) always carry a correlation ID for tracing.
"""
from __future__ import annotations

import math
from typing import TYPE_CHECKING

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import JSONResponse, Response

from app.exceptions import RateLimitError
from app.utils.client_ip import get_client_ip
from app.utils.logging_compat import get_logger

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from starlette.requests import Request

    from app.config import Settings
    from app.utils.rate_limiter import GlobalRateLimiter
# Module-level logger. Events from this module are emitted as an event name
# plus keyword fields (e.g. "global_rate_limit_exceeded" with client_ip, path).
log = get_logger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Enforce per-IP request rate limiting on matching endpoints.

    Tracks requests per client IP through the application's GlobalRateLimiter
    and answers with HTTP 429 once the limit is exceeded. Client IPs are
    resolved with ``get_client_ip`` using ``settings.trusted_proxies`` so
    extraction is consistent with the rest of the application.

    Each middleware instance is scoped to a set of path prefixes (or all
    paths if no prefixes are given), minus an optional set of skipped
    prefixes. This allows multiple instances to coexist without
    double-counting requests, provided every path is owned by exactly one
    instance.
    """

    def __init__(
        self,
        app: object,
        rate_limiter: GlobalRateLimiter,
        settings: Settings,
        bucket_override: str | None = None,
        bucket_max_requests: int | None = None,
        bucket_window_seconds: int | None = None,
        path_prefixes: list[str] | None = None,
        skip_paths: list[str] | None = None,
    ) -> None:
        """Initialize the rate limit middleware.

        Args:
            app: The ASGI application being wrapped.
            rate_limiter: The GlobalRateLimiter instance to use for checking limits.
            settings: Application settings (used for trusted proxies).
            bucket_override: Optional named bucket to use instead of the default
                limiter. Only honoured when ``bucket_max_requests`` and
                ``bucket_window_seconds`` are both set and non-zero; otherwise
                the global limiter is used and a warning is logged.
            bucket_max_requests: Max requests for the bucket override.
            bucket_window_seconds: Window for the bucket override.
            path_prefixes: If provided, only apply rate limiting to paths that
                start with one of these prefixes. If ``None`` or empty, all
                paths are matched. A bare string is treated as one prefix.
            skip_paths: If provided, never apply rate limiting to paths that
                start with one of these prefixes. Takes precedence over
                ``path_prefixes``. A bare string is treated as one prefix.
        """
        super().__init__(app)  # type: ignore[arg-type]
        self.rate_limiter: GlobalRateLimiter = rate_limiter
        self.settings: Settings = settings
        self.bucket_override = bucket_override
        self.bucket_max_requests = bucket_max_requests
        self.bucket_window_seconds = bucket_window_seconds
        # Private copies: later mutation of the caller's lists must not
        # silently change which paths this instance owns.
        self.path_prefixes = self._normalize_prefixes(path_prefixes)
        self.skip_paths = self._normalize_prefixes(skip_paths)

        if bucket_override and not (bucket_max_requests and bucket_window_seconds):
            # dispatch() falls back to the global limiter in this case, which
            # makes these paths consume the global budget instead of their own
            # bucket. Surface the misconfiguration instead of hiding it.
            log.warning(
                "rate_limit_bucket_incomplete",
                bucket=bucket_override,
                bucket_max_requests=bucket_max_requests,
                bucket_window_seconds=bucket_window_seconds,
            )

    @staticmethod
    def _normalize_prefixes(prefixes: str | list[str] | None) -> list[str]:
        """Return a fresh list of path prefixes built from *prefixes*.

        A bare string is wrapped as a single prefix. Iterating it would
        otherwise yield one-character prefixes such as ``"/"``, which match
        every path.

        Args:
            prefixes: A list of prefixes, a single prefix string, or ``None``.

        Returns:
            A new list (empty when *prefixes* is ``None`` or empty).
        """
        if not prefixes:
            return []
        if isinstance(prefixes, str):
            return [prefixes]
        return list(prefixes)

    @staticmethod
    def _retry_after_header(retry_after: float) -> str:
        """Format a wait time as an HTTP ``Retry-After`` delay-seconds value.

        The header carries whole seconds only, so fractional waits are rounded
        up. Truncating (e.g. 0.4 -> 0) would invite the client to retry while
        it is still blocked. The result is at least 1, since a 429 that says
        "retry after 0 seconds" contradicts itself.

        Args:
            retry_after: Seconds until the client may retry, as reported by
                the rate limiter.

        Returns:
            The header value as a decimal string.
        """
        return str(max(1, math.ceil(retry_after)))

    def _should_check(self, path: str) -> bool:
        """Return whether the given path should be rate-limited by this instance.

        ``skip_paths`` wins over ``path_prefixes``. With no ``path_prefixes``
        configured, every non-skipped path is checked.

        Args:
            path: The request URL path.

        Returns:
            ``True`` if this instance should enforce its limit on the path.
        """
        if self.skip_paths and path.startswith(tuple(self.skip_paths)):
            return False
        if self.path_prefixes:
            return path.startswith(tuple(self.path_prefixes))
        return True

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Check rate limit before passing request to next middleware/handler.

        Paths this instance does not own are passed through without touching
        the limiter. Otherwise one request is counted against either the
        configured named bucket or the global limiter. If the client IP is
        over the limit, a JSON 429 response with a ``Retry-After`` header is
        returned without calling the downstream handler.

        Args:
            request: The incoming HTTP request.
            call_next: Callable to pass the request to the next middleware/handler.

        Returns:
            A response object (either rate limit response or from handler).
        """
        path = request.url.path
        if not self._should_check(path):
            return await call_next(request)

        client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)

        # Name of the bucket actually enforced (None means the global limiter);
        # recorded in the log so bucket-scoped rejections can be told apart.
        enforced_bucket: str | None = None
        if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds:
            enforced_bucket = self.bucket_override
            is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
                self.bucket_override,
                client_ip,
                self.bucket_max_requests,
                self.bucket_window_seconds,
            )
        else:
            is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)

        if is_allowed:
            response: Response = await call_next(request)
            return response

        log.warning(
            "global_rate_limit_exceeded",
            client_ip=client_ip,
            path=path,
            method=request.method,
            retry_after=retry_after,
            bucket=enforced_bucket,
        )
        rate_limit_error = RateLimitError(
            "Too many requests. Please try again later.",
            retry_after_seconds=retry_after,
        )
        return JSONResponse(
            status_code=429,
            content={
                "code": "rate_limit_exceeded",
                "detail": str(rate_limit_error),
                "metadata": rate_limit_error.get_error_metadata(),
                # Set by CorrelationIdMiddleware when it runs earlier in the chain.
                "correlation_id": getattr(request.state, "correlation_id", None),
            },
            headers={"Retry-After": self._retry_after_header(retry_after)},
        )