fix(rate-limit): stop double-counting requests in middleware

Multiple RateLimitMiddleware instances were each calling
check_allowed() on every request, halving the effective global
limit (200 req/min became ~100). Added path_prefixes and skip_paths
so each instance only checks the paths it owns.

- Auth middleware scoped to /api/v1/auth/login and /api/v1/setup
- History middleware scoped to /api/v1/history
- Global middleware skips auth and history paths
- Updated tests to match single-count behavior
This commit is contained in:
2026-05-15 23:04:02 +02:00
parent 77df5d5d65
commit 7308ff88d6
3 changed files with 92 additions and 45 deletions

View File

@@ -1135,9 +1135,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
app.add_middleware(CsrfMiddleware) app.add_middleware(CsrfMiddleware)
app.add_middleware(DeprecationHeaderMiddleware) app.add_middleware(DeprecationHeaderMiddleware)
# Auth endpoints (login, setup) need a dedicated higher-rate bucket to avoid # Auth endpoints (login, setup) need a dedicated higher-rate bucket to avoid
# rate limiting when running e2e tests sequentially. Auth uses the default # rate limiting when running e2e tests sequentially.
# global rate limiter at 200 req/min per IP.
# Auth endpoints: /api/v1/login, /api/v1/setup
# 1000 req/min per IP — generous for e2e testing. # 1000 req/min per IP — generous for e2e testing.
app.add_middleware( app.add_middleware(
RateLimitMiddleware, RateLimitMiddleware,
@@ -1146,6 +1144,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
bucket_override="auth:login", bucket_override="auth:login",
bucket_max_requests=1000, bucket_max_requests=1000,
bucket_window_seconds=60, bucket_window_seconds=60,
path_prefixes=["/api/v1/auth/login", "/api/v1/setup"],
) )
# History endpoints get a dedicated higher-rate bucket to avoid # History endpoints get a dedicated higher-rate bucket to avoid
@@ -1159,6 +1158,16 @@ def create_app(settings: Settings | None = None) -> FastAPI:
bucket_override="history:list", bucket_override="history:list",
bucket_max_requests=10000, bucket_max_requests=10000,
bucket_window_seconds=60, bucket_window_seconds=60,
path_prefixes=["/api/v1/history"],
)
# Global rate limiter for all other endpoints.
# 200 req/min per IP — default protection.
app.add_middleware(
RateLimitMiddleware,
rate_limiter=app.state.global_rate_limiter,
settings=resolved_settings,
skip_paths=["/api/v1/auth/login", "/api/v1/setup", "/api/v1/history"],
) )
# Validate middleware order before returning the app. # Validate middleware order before returning the app.

View File

@@ -34,18 +34,20 @@ unusual and potentially suspicious) always carry a correlation ID for tracing.
from __future__ import annotations from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from app.utils.logging_compat import get_logger
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response from starlette.responses import JSONResponse, Response
from app.exceptions import RateLimitError from app.exceptions import RateLimitError
from app.utils.client_ip import get_client_ip from app.utils.client_ip import get_client_ip
from app.utils.logging_compat import get_logger
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
from starlette.requests import Request
from app.config import Settings from app.config import Settings
from app.utils.rate_limiter import GlobalRateLimiter from app.utils.rate_limiter import GlobalRateLimiter
@@ -53,11 +55,15 @@ log = get_logger(__name__)
class RateLimitMiddleware(BaseHTTPMiddleware): class RateLimitMiddleware(BaseHTTPMiddleware):
"""Enforce global per-IP request rate limiting on all endpoints. """Enforce per-IP request rate limiting on matching endpoints.
Tracks requests per IP and blocks further requests if the limit is exceeded. Tracks requests per IP and blocks further requests if the limit is exceeded.
Uses the application's GlobalRateLimiter instance and trusted-proxy settings Uses the application's GlobalRateLimiter instance and trusted-proxy settings
for consistent IP extraction. for consistent IP extraction.
Each middleware instance is scoped to a set of path prefixes (or all paths
if no prefixes are given). This allows multiple instances to coexist
without double-counting requests.
""" """
def __init__( def __init__(
@@ -68,6 +74,8 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
bucket_override: str | None = None, bucket_override: str | None = None,
bucket_max_requests: int | None = None, bucket_max_requests: int | None = None,
bucket_window_seconds: int | None = None, bucket_window_seconds: int | None = None,
path_prefixes: list[str] | None = None,
skip_paths: list[str] | None = None,
) -> None: ) -> None:
"""Initialize the rate limit middleware. """Initialize the rate limit middleware.
@@ -78,6 +86,12 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
bucket_override: Optional named bucket to use instead of the default limiter. bucket_override: Optional named bucket to use instead of the default limiter.
bucket_max_requests: Max requests for the bucket override. bucket_max_requests: Max requests for the bucket override.
bucket_window_seconds: Window for the bucket override. bucket_window_seconds: Window for the bucket override.
path_prefixes: If provided, only apply rate limiting to paths that
start with one of these prefixes. If ``None``, all paths are
matched.
skip_paths: If provided, do not apply rate limiting to paths that
start with one of these prefixes. Evaluated after
``path_prefixes``.
""" """
super().__init__(app) # type: ignore[arg-type] super().__init__(app) # type: ignore[arg-type]
self.rate_limiter: GlobalRateLimiter = rate_limiter self.rate_limiter: GlobalRateLimiter = rate_limiter
@@ -85,6 +99,23 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
self.bucket_override = bucket_override self.bucket_override = bucket_override
self.bucket_max_requests = bucket_max_requests self.bucket_max_requests = bucket_max_requests
self.bucket_window_seconds = bucket_window_seconds self.bucket_window_seconds = bucket_window_seconds
self.path_prefixes = path_prefixes or []
self.skip_paths = skip_paths or []
def _should_check(self, path: str) -> bool:
"""Return whether the given path should be rate-limited by this instance.
Args:
path: The request URL path.
Returns:
``True`` if this instance should enforce its limit on the path.
"""
if self.skip_paths and any(path.startswith(p) for p in self.skip_paths):
return False
if self.path_prefixes:
return any(path.startswith(p) for p in self.path_prefixes)
return True
async def dispatch( async def dispatch(
self, self,
@@ -103,22 +134,14 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
Returns: Returns:
A response object (either rate limit response or from handler). A response object (either rate limit response or from handler).
""" """
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
# Use higher-rate bucket for specific endpoints.
# Check path to apply the appropriate bucket.
path = request.url.path path = request.url.path
if not self._should_check(path):
return await call_next(request)
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds: if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds:
if path.startswith("/api/v1/history"):
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override,
client_ip,
self.bucket_max_requests,
self.bucket_window_seconds,
)
elif path.startswith("/api/v1/login") or path.startswith("/api/v1/setup"):
# Auth endpoints use their own bucket
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket( is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override, self.bucket_override,
client_ip, client_ip,
@@ -127,13 +150,12 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
) )
else: else:
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip) is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
else:
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
if not is_allowed: if not is_allowed:
log.warning( log.warning(
"global_rate_limit_exceeded", "global_rate_limit_exceeded",
client_ip=client_ip, client_ip=client_ip,
path=request.url.path, path=path,
method=request.method, method=request.method,
retry_after=retry_after, retry_after=retry_after,
) )
@@ -141,7 +163,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
"Too many requests. Please try again later.", "Too many requests. Please try again later.",
retry_after_seconds=retry_after, retry_after_seconds=retry_after,
) )
# Return the error response directly
return JSONResponse( return JSONResponse(
status_code=429, status_code=429,
content={ content={
@@ -153,6 +174,5 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
headers={"Retry-After": str(int(retry_after))}, headers={"Retry-After": str(int(retry_after))},
) )
# Request is allowed, continue to next handler
response: Response = await call_next(request) response: Response = await call_next(request)
return response return response

View File

@@ -134,24 +134,17 @@ class TestRateLimitMiddleware:
"""Global rate limit should block requests exceeding per-IP limit.""" """Global rate limit should block requests exceeding per-IP limit."""
await _do_setup(client) await _do_setup(client)
# Create a client that mimics a specific IP
# We'll make many requests and see if we hit the limit
limiter = client._transport.app.state.global_rate_limiter limiter = client._transport.app.state.global_rate_limiter
limiter.reset() limiter.reset()
# Reduce limit temporarily for testing.
# Each request is checked by two middleware instances, so the
# effective limit is doubled for non-bucket endpoints.
original_max = limiter.max_requests original_max = limiter.max_requests
limiter.max_requests = 7 limiter.max_requests = 3
try: try:
# First 3 requests should succeed
for i in range(3): for i in range(3):
response = await client.get("/api/v1/health") response = await client.get("/api/v1/health")
assert response.status_code == 200, f"Request {i + 1} failed" assert response.status_code == 200, f"Request {i + 1} failed"
# Fourth request should be rate limited
response = await client.get("/api/v1/health") response = await client.get("/api/v1/health")
assert response.status_code == 429 assert response.status_code == 429
assert response.json()["code"] == "rate_limit_exceeded" assert response.json()["code"] == "rate_limit_exceeded"
@@ -166,22 +159,47 @@ class TestRateLimitMiddleware:
limiter = client._transport.app.state.global_rate_limiter limiter = client._transport.app.state.global_rate_limiter
limiter.reset() limiter.reset()
# Two middleware instances check each request, so the effective
# limit is doubled for non-bucket endpoints.
original_max = limiter.max_requests original_max = limiter.max_requests
limiter.max_requests = 3 limiter.max_requests = 2
try: try:
# First request succeeds
response = await client.get("/api/v1/health") response = await client.get("/api/v1/health")
assert response.status_code == 200 assert response.status_code == 200
# Second request is rate limited response = await client.get("/api/v1/health")
assert response.status_code == 200
response = await client.get("/api/v1/health") response = await client.get("/api/v1/health")
assert response.status_code == 429 assert response.status_code == 429
assert "Retry-After" in response.headers assert "Retry-After" in response.headers
retry_after = int(response.headers["Retry-After"]) retry_after = int(response.headers["Retry-After"])
assert retry_after > 0 assert retry_after > 0
assert retry_after <= 60 # Should be less than window assert retry_after <= 60
finally: finally:
limiter.max_requests = original_max limiter.max_requests = original_max
async def test_auth_bucket_allows_more_requests(self, client: AsyncClient) -> None:
"""Auth endpoints use a dedicated high-rate bucket."""
await _do_setup(client)
limiter = client._transport.app.state.global_rate_limiter
limiter.reset()
# The auth bucket is configured for 1000 req/min; we only need to
# verify that it is *not* the global bucket (200 req/min).
for _ in range(5):
response = await client.post("/api/v1/auth/login", json={"password": "x"})
assert response.status_code in (401, 403, 429)
async def test_history_bucket_allows_more_requests(self, client: AsyncClient) -> None:
"""History endpoints use a dedicated high-rate bucket."""
await _do_setup(client)
limiter = client._transport.app.state.global_rate_limiter
limiter.reset()
for _ in range(5):
response = await client.get("/api/v1/history/bans")
# 401/403 is fine — we just need to confirm we are not 429'd
# by the global limiter.
assert response.status_code != 429