fix(rate-limit): stop double-counting requests in middleware

Multiple RateLimitMiddleware instances were each calling
check_allowed() on every request, halving the effective global
limit (200 req/min became ~100). Added path_prefixes and skip_paths
so each instance only checks the paths it owns.

- Auth middleware scoped to /api/v1/auth/login and /api/v1/setup
- History middleware scoped to /api/v1/history
- Global middleware skips auth and history paths
- Updated tests to match single-count behavior
This commit is contained in:
2026-05-15 23:04:02 +02:00
parent 77df5d5d65
commit 7308ff88d6
3 changed files with 92 additions and 45 deletions

View File

@@ -34,18 +34,20 @@ unusual and potentially suspicious) always carry a correlation ID for tracing.
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING
from app.utils.logging_compat import get_logger
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from app.exceptions import RateLimitError
from app.utils.client_ip import get_client_ip
from app.utils.logging_compat import get_logger
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from starlette.requests import Request
from app.config import Settings
from app.utils.rate_limiter import GlobalRateLimiter
@@ -53,11 +55,15 @@ log = get_logger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware):
"""Enforce global per-IP request rate limiting on all endpoints.
"""Enforce per-IP request rate limiting on matching endpoints.
Tracks requests per IP and blocks further requests if the limit is exceeded.
Uses the application's GlobalRateLimiter instance and trusted-proxy settings
for consistent IP extraction.
Each middleware instance is scoped to a set of path prefixes (or all paths
if no prefixes are given). This allows multiple instances to coexist
without double-counting requests.
"""
def __init__(
@@ -68,6 +74,8 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
bucket_override: str | None = None,
bucket_max_requests: int | None = None,
bucket_window_seconds: int | None = None,
path_prefixes: list[str] | None = None,
skip_paths: list[str] | None = None,
) -> None:
"""Initialize the rate limit middleware.
@@ -78,6 +86,12 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
bucket_override: Optional named bucket to use instead of the default limiter.
bucket_max_requests: Max requests for the bucket override.
bucket_window_seconds: Window for the bucket override.
path_prefixes: If provided, only apply rate limiting to paths that
start with one of these prefixes. If ``None``, all paths are
matched.
skip_paths: If provided, do not apply rate limiting to paths that
start with one of these prefixes. Evaluated after
``path_prefixes``.
"""
super().__init__(app) # type: ignore[arg-type]
self.rate_limiter: GlobalRateLimiter = rate_limiter
@@ -85,6 +99,23 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
self.bucket_override = bucket_override
self.bucket_max_requests = bucket_max_requests
self.bucket_window_seconds = bucket_window_seconds
self.path_prefixes = path_prefixes or []
self.skip_paths = skip_paths or []
def _should_check(self, path: str) -> bool:
"""Return whether the given path should be rate-limited by this instance.
Args:
path: The request URL path.
Returns:
``True`` if this instance should enforce its limit on the path.
"""
if self.skip_paths and any(path.startswith(p) for p in self.skip_paths):
return False
if self.path_prefixes:
return any(path.startswith(p) for p in self.path_prefixes)
return True
async def dispatch(
self,
@@ -103,37 +134,28 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
Returns:
A response object (either rate limit response or from handler).
"""
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
# Use higher-rate bucket for specific endpoints.
# Check path to apply the appropriate bucket.
path = request.url.path
if not self._should_check(path):
return await call_next(request)
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds:
if path.startswith("/api/v1/history"):
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override,
client_ip,
self.bucket_max_requests,
self.bucket_window_seconds,
)
elif path.startswith("/api/v1/login") or path.startswith("/api/v1/setup"):
# Auth endpoints use their own bucket
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override,
client_ip,
self.bucket_max_requests,
self.bucket_window_seconds,
)
else:
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override,
client_ip,
self.bucket_max_requests,
self.bucket_window_seconds,
)
else:
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
if not is_allowed:
log.warning(
"global_rate_limit_exceeded",
client_ip=client_ip,
path=request.url.path,
path=path,
method=request.method,
retry_after=retry_after,
)
@@ -141,7 +163,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
"Too many requests. Please try again later.",
retry_after_seconds=retry_after,
)
# Return the error response directly
return JSONResponse(
status_code=429,
content={
@@ -153,6 +174,5 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
headers={"Retry-After": str(int(retry_after))},
)
# Request is allowed, continue to next handler
response: Response = await call_next(request)
return response