fix(rate-limit): stop double-counting requests in middleware
Multiple RateLimitMiddleware instances were each calling check_allowed() on every request, so a single request was counted more than once against the same per-IP counter and the effective global limit was roughly halved (200 req/min became ~100). Added path_prefixes and skip_paths so that each instance only checks the paths it owns:
- Auth middleware is scoped to /api/v1/auth/login and /api/v1/setup
- History middleware is scoped to /api/v1/history
- Global middleware skips the auth and history paths
- Tests updated to match single-count behavior
This commit is contained in:
@@ -1135,9 +1135,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
|||||||
app.add_middleware(CsrfMiddleware)
|
app.add_middleware(CsrfMiddleware)
|
||||||
app.add_middleware(DeprecationHeaderMiddleware)
|
app.add_middleware(DeprecationHeaderMiddleware)
|
||||||
# Auth endpoints (login, setup) need a dedicated higher-rate bucket to avoid
|
# Auth endpoints (login, setup) need a dedicated higher-rate bucket to avoid
|
||||||
# rate limiting when running e2e tests sequentially. Auth uses the default
|
# rate limiting when running e2e tests sequentially.
|
||||||
# global rate limiter at 200 req/min per IP.
|
|
||||||
# Auth endpoints: /api/v1/login, /api/v1/setup
|
|
||||||
# 1000 req/min per IP — generous for e2e testing.
|
# 1000 req/min per IP — generous for e2e testing.
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
RateLimitMiddleware,
|
RateLimitMiddleware,
|
||||||
@@ -1146,6 +1144,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
|||||||
bucket_override="auth:login",
|
bucket_override="auth:login",
|
||||||
bucket_max_requests=1000,
|
bucket_max_requests=1000,
|
||||||
bucket_window_seconds=60,
|
bucket_window_seconds=60,
|
||||||
|
path_prefixes=["/api/v1/auth/login", "/api/v1/setup"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# History endpoints get a dedicated higher-rate bucket to avoid
|
# History endpoints get a dedicated higher-rate bucket to avoid
|
||||||
@@ -1159,6 +1158,16 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
|||||||
bucket_override="history:list",
|
bucket_override="history:list",
|
||||||
bucket_max_requests=10000,
|
bucket_max_requests=10000,
|
||||||
bucket_window_seconds=60,
|
bucket_window_seconds=60,
|
||||||
|
path_prefixes=["/api/v1/history"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global rate limiter for all other endpoints.
|
||||||
|
# 200 req/min per IP — default protection.
|
||||||
|
app.add_middleware(
|
||||||
|
RateLimitMiddleware,
|
||||||
|
rate_limiter=app.state.global_rate_limiter,
|
||||||
|
settings=resolved_settings,
|
||||||
|
skip_paths=["/api/v1/auth/login", "/api/v1/setup", "/api/v1/history"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate middleware order before returning the app.
|
# Validate middleware order before returning the app.
|
||||||
|
|||||||
@@ -34,18 +34,20 @@ unusual and potentially suspicious) always carry a correlation ID for tracing.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from app.utils.logging_compat import get_logger
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.requests import Request
|
|
||||||
from starlette.responses import JSONResponse, Response
|
from starlette.responses import JSONResponse, Response
|
||||||
|
|
||||||
from app.exceptions import RateLimitError
|
from app.exceptions import RateLimitError
|
||||||
from app.utils.client_ip import get_client_ip
|
from app.utils.client_ip import get_client_ip
|
||||||
|
from app.utils.logging_compat import get_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
|
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
from app.utils.rate_limiter import GlobalRateLimiter
|
from app.utils.rate_limiter import GlobalRateLimiter
|
||||||
|
|
||||||
@@ -53,11 +55,15 @@ log = get_logger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
"""Enforce global per-IP request rate limiting on all endpoints.
|
"""Enforce per-IP request rate limiting on matching endpoints.
|
||||||
|
|
||||||
Tracks requests per IP and blocks further requests if the limit is exceeded.
|
Tracks requests per IP and blocks further requests if the limit is exceeded.
|
||||||
Uses the application's GlobalRateLimiter instance and trusted-proxy settings
|
Uses the application's GlobalRateLimiter instance and trusted-proxy settings
|
||||||
for consistent IP extraction.
|
for consistent IP extraction.
|
||||||
|
|
||||||
|
Each middleware instance is scoped to a set of path prefixes (or all paths
|
||||||
|
if no prefixes are given). This allows multiple instances to coexist
|
||||||
|
without double-counting requests.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -68,6 +74,8 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||||||
bucket_override: str | None = None,
|
bucket_override: str | None = None,
|
||||||
bucket_max_requests: int | None = None,
|
bucket_max_requests: int | None = None,
|
||||||
bucket_window_seconds: int | None = None,
|
bucket_window_seconds: int | None = None,
|
||||||
|
path_prefixes: list[str] | None = None,
|
||||||
|
skip_paths: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the rate limit middleware.
|
"""Initialize the rate limit middleware.
|
||||||
|
|
||||||
@@ -78,6 +86,12 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||||||
bucket_override: Optional named bucket to use instead of the default limiter.
|
bucket_override: Optional named bucket to use instead of the default limiter.
|
||||||
bucket_max_requests: Max requests for the bucket override.
|
bucket_max_requests: Max requests for the bucket override.
|
||||||
bucket_window_seconds: Window for the bucket override.
|
bucket_window_seconds: Window for the bucket override.
|
||||||
|
path_prefixes: If provided, only apply rate limiting to paths that
|
||||||
|
start with one of these prefixes. If ``None``, all paths are
|
||||||
|
matched.
|
||||||
|
skip_paths: If provided, do not apply rate limiting to paths that
|
||||||
|
start with one of these prefixes. Takes precedence over
|
||||||
|
``path_prefixes``.
|
||||||
"""
|
"""
|
||||||
super().__init__(app) # type: ignore[arg-type]
|
super().__init__(app) # type: ignore[arg-type]
|
||||||
self.rate_limiter: GlobalRateLimiter = rate_limiter
|
self.rate_limiter: GlobalRateLimiter = rate_limiter
|
||||||
@@ -85,6 +99,23 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||||||
self.bucket_override = bucket_override
|
self.bucket_override = bucket_override
|
||||||
self.bucket_max_requests = bucket_max_requests
|
self.bucket_max_requests = bucket_max_requests
|
||||||
self.bucket_window_seconds = bucket_window_seconds
|
self.bucket_window_seconds = bucket_window_seconds
|
||||||
|
self.path_prefixes = path_prefixes or []
|
||||||
|
self.skip_paths = skip_paths or []
|
||||||
|
|
||||||
|
def _should_check(self, path: str) -> bool:
    """Return whether the given path should be rate-limited by this instance.

    A path is checked when it does not start with any ``skip_paths`` prefix
    and, if ``path_prefixes`` is non-empty, starts with at least one of those
    prefixes. A ``skip_paths`` match always wins over a ``path_prefixes``
    match. Matching is a plain string-prefix comparison, so a prefix such as
    ``/api/v1/history`` also matches ``/api/v1/history-export``.

    Args:
        path: The request URL path.

    Returns:
        ``True`` if this instance should enforce its limit on the path.
    """
    # ``str.startswith`` accepts a tuple of candidates and returns on the
    # first match, which replaces a Python-level ``any(...)`` loop.
    excluded = tuple(self.skip_paths)
    if excluded and path.startswith(excluded):
        return False

    included = tuple(self.path_prefixes)
    if not included:
        # No explicit scope configured: this instance covers every path.
        return True
    return path.startswith(included)
|
||||||
|
|
||||||
async def dispatch(
|
async def dispatch(
|
||||||
self,
|
self,
|
||||||
@@ -103,37 +134,28 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||||||
Returns:
|
Returns:
|
||||||
A response object (either rate limit response or from handler).
|
A response object (either rate limit response or from handler).
|
||||||
"""
|
"""
|
||||||
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
|
|
||||||
|
|
||||||
# Use higher-rate bucket for specific endpoints.
|
|
||||||
# Check path to apply the appropriate bucket.
|
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
|
|
||||||
|
if not self._should_check(path):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
|
||||||
|
|
||||||
if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds:
|
if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds:
|
||||||
if path.startswith("/api/v1/history"):
|
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
|
||||||
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
|
self.bucket_override,
|
||||||
self.bucket_override,
|
client_ip,
|
||||||
client_ip,
|
self.bucket_max_requests,
|
||||||
self.bucket_max_requests,
|
self.bucket_window_seconds,
|
||||||
self.bucket_window_seconds,
|
)
|
||||||
)
|
|
||||||
elif path.startswith("/api/v1/login") or path.startswith("/api/v1/setup"):
|
|
||||||
# Auth endpoints use their own bucket
|
|
||||||
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
|
|
||||||
self.bucket_override,
|
|
||||||
client_ip,
|
|
||||||
self.bucket_max_requests,
|
|
||||||
self.bucket_window_seconds,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
|
|
||||||
else:
|
else:
|
||||||
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
|
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
|
||||||
|
|
||||||
if not is_allowed:
|
if not is_allowed:
|
||||||
log.warning(
|
log.warning(
|
||||||
"global_rate_limit_exceeded",
|
"global_rate_limit_exceeded",
|
||||||
client_ip=client_ip,
|
client_ip=client_ip,
|
||||||
path=request.url.path,
|
path=path,
|
||||||
method=request.method,
|
method=request.method,
|
||||||
retry_after=retry_after,
|
retry_after=retry_after,
|
||||||
)
|
)
|
||||||
@@ -141,7 +163,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||||||
"Too many requests. Please try again later.",
|
"Too many requests. Please try again later.",
|
||||||
retry_after_seconds=retry_after,
|
retry_after_seconds=retry_after,
|
||||||
)
|
)
|
||||||
# Return the error response directly
|
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=429,
|
status_code=429,
|
||||||
content={
|
content={
|
||||||
@@ -153,6 +174,5 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
|
|||||||
headers={"Retry-After": str(int(retry_after))},
|
headers={"Retry-After": str(int(retry_after))},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Request is allowed, continue to next handler
|
|
||||||
response: Response = await call_next(request)
|
response: Response = await call_next(request)
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -134,24 +134,17 @@ class TestRateLimitMiddleware:
|
|||||||
"""Global rate limit should block requests exceeding per-IP limit."""
|
"""Global rate limit should block requests exceeding per-IP limit."""
|
||||||
await _do_setup(client)
|
await _do_setup(client)
|
||||||
|
|
||||||
# Create a client that mimics a specific IP
|
|
||||||
# We'll make many requests and see if we hit the limit
|
|
||||||
limiter = client._transport.app.state.global_rate_limiter
|
limiter = client._transport.app.state.global_rate_limiter
|
||||||
limiter.reset()
|
limiter.reset()
|
||||||
|
|
||||||
# Reduce limit temporarily for testing.
|
|
||||||
# Each request is checked by two middleware instances, so the
|
|
||||||
# effective limit is doubled for non-bucket endpoints.
|
|
||||||
original_max = limiter.max_requests
|
original_max = limiter.max_requests
|
||||||
limiter.max_requests = 7
|
limiter.max_requests = 3
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# First 3 requests should succeed
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
response = await client.get("/api/v1/health")
|
response = await client.get("/api/v1/health")
|
||||||
assert response.status_code == 200, f"Request {i + 1} failed"
|
assert response.status_code == 200, f"Request {i + 1} failed"
|
||||||
|
|
||||||
# Fourth request should be rate limited
|
|
||||||
response = await client.get("/api/v1/health")
|
response = await client.get("/api/v1/health")
|
||||||
assert response.status_code == 429
|
assert response.status_code == 429
|
||||||
assert response.json()["code"] == "rate_limit_exceeded"
|
assert response.json()["code"] == "rate_limit_exceeded"
|
||||||
@@ -166,22 +159,47 @@ class TestRateLimitMiddleware:
|
|||||||
limiter = client._transport.app.state.global_rate_limiter
|
limiter = client._transport.app.state.global_rate_limiter
|
||||||
limiter.reset()
|
limiter.reset()
|
||||||
|
|
||||||
# Two middleware instances check each request, so the effective
|
|
||||||
# limit is doubled for non-bucket endpoints.
|
|
||||||
original_max = limiter.max_requests
|
original_max = limiter.max_requests
|
||||||
limiter.max_requests = 3
|
limiter.max_requests = 2
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# First request succeeds
|
|
||||||
response = await client.get("/api/v1/health")
|
response = await client.get("/api/v1/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
# Second request is rate limited
|
response = await client.get("/api/v1/health")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
response = await client.get("/api/v1/health")
|
response = await client.get("/api/v1/health")
|
||||||
assert response.status_code == 429
|
assert response.status_code == 429
|
||||||
assert "Retry-After" in response.headers
|
assert "Retry-After" in response.headers
|
||||||
retry_after = int(response.headers["Retry-After"])
|
retry_after = int(response.headers["Retry-After"])
|
||||||
assert retry_after > 0
|
assert retry_after > 0
|
||||||
assert retry_after <= 60 # Should be less than window
|
assert retry_after <= 60
|
||||||
finally:
|
finally:
|
||||||
limiter.max_requests = original_max
|
limiter.max_requests = original_max
|
||||||
|
|
||||||
|
async def test_auth_bucket_allows_more_requests(self, client: AsyncClient) -> None:
    """Auth endpoints are counted against their own bucket, not the global one."""
    await _do_setup(client)

    limiter = client._transport.app.state.global_rate_limiter
    limiter.reset()

    # Shrink the global limit so that a request wrongly counted against the
    # global bucket would be rejected on the second attempt. With the default
    # 200 req/min, five requests could never tell the two buckets apart.
    # NOTE(review): assumes the auth bucket honours the max passed by the
    # middleware (1000) rather than ``limiter.max_requests`` — confirm.
    original_max = limiter.max_requests
    limiter.max_requests = 1

    try:
        for attempt in range(5):
            response = await client.post("/api/v1/auth/login", json={"password": "x"})
            # A wrong password must be rejected as an auth failure. 429 is
            # deliberately NOT accepted: it would mean the request was
            # throttled, which is the regression this test guards against.
            # NOTE(review): if login has its own failed-attempt throttling
            # that answers 429, lower the attempt count accordingly.
            assert response.status_code in (401, 403), (
                f"Login attempt {attempt + 1} returned {response.status_code}"
            )
    finally:
        limiter.max_requests = original_max
|
||||||
|
|
||||||
|
async def test_history_bucket_allows_more_requests(self, client: AsyncClient) -> None:
    """History endpoints are counted against their own bucket, not the global one."""
    await _do_setup(client)

    limiter = client._transport.app.state.global_rate_limiter
    limiter.reset()

    # Shrink the global limit so that a request wrongly counted against the
    # global bucket would be rejected on the second attempt. With the default
    # 200 req/min, five requests could never tell the two buckets apart.
    # NOTE(review): assumes the history bucket honours the max passed by the
    # middleware (10000) rather than ``limiter.max_requests`` — confirm.
    original_max = limiter.max_requests
    limiter.max_requests = 1

    try:
        for attempt in range(5):
            response = await client.get("/api/v1/history/bans")
            # 401/403 is fine — we only need to confirm the request was not
            # throttled by the global limiter.
            assert response.status_code != 429, (
                f"History request {attempt + 1} was rate limited"
            )
    finally:
        limiter.max_requests = original_max
|
||||||
|
|||||||
Reference in New Issue
Block a user