diff --git a/backend/app/main.py b/backend/app/main.py index fdcaece..91e750d 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -1135,9 +1135,7 @@ def create_app(settings: Settings | None = None) -> FastAPI: app.add_middleware(CsrfMiddleware) app.add_middleware(DeprecationHeaderMiddleware) # Auth endpoints (login, setup) need a dedicated higher-rate bucket to avoid - # rate limiting when running e2e tests sequentially. Auth uses the default - # global rate limiter at 200 req/min per IP. - # Auth endpoints: /api/v1/login, /api/v1/setup + # rate limiting when running e2e tests sequentially. # 1000 req/min per IP — generous for e2e testing. app.add_middleware( RateLimitMiddleware, @@ -1146,6 +1144,7 @@ def create_app(settings: Settings | None = None) -> FastAPI: bucket_override="auth:login", bucket_max_requests=1000, bucket_window_seconds=60, + path_prefixes=["/api/v1/auth/login", "/api/v1/setup"], ) # History endpoints get a dedicated higher-rate bucket to avoid @@ -1159,6 +1158,16 @@ def create_app(settings: Settings | None = None) -> FastAPI: bucket_override="history:list", bucket_max_requests=10000, bucket_window_seconds=60, + path_prefixes=["/api/v1/history"], + ) + + # Global rate limiter for all other endpoints. + # 200 req/min per IP — default protection. + app.add_middleware( + RateLimitMiddleware, + rate_limiter=app.state.global_rate_limiter, + settings=resolved_settings, + skip_paths=["/api/v1/auth/login", "/api/v1/setup", "/api/v1/history"], ) # Validate middleware order before returning the app. diff --git a/backend/app/middleware/rate_limit.py b/backend/app/middleware/rate_limit.py index f22f312..6cbf670 100644 --- a/backend/app/middleware/rate_limit.py +++ b/backend/app/middleware/rate_limit.py @@ -34,18 +34,20 @@ unusual and potentially suspicious) always carry a correlation ID for tracing. 
from __future__ import annotations -from collections.abc import Awaitable, Callable from typing import TYPE_CHECKING -from app.utils.logging_compat import get_logger from starlette.middleware.base import BaseHTTPMiddleware -from starlette.requests import Request from starlette.responses import JSONResponse, Response from app.exceptions import RateLimitError from app.utils.client_ip import get_client_ip +from app.utils.logging_compat import get_logger if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + from starlette.requests import Request + from app.config import Settings from app.utils.rate_limiter import GlobalRateLimiter @@ -53,11 +55,15 @@ log = get_logger(__name__) class RateLimitMiddleware(BaseHTTPMiddleware): - """Enforce global per-IP request rate limiting on all endpoints. + """Enforce per-IP request rate limiting on matching endpoints. Tracks requests per IP and blocks further requests if the limit is exceeded. Uses the application's GlobalRateLimiter instance and trusted-proxy settings for consistent IP extraction. + + Each middleware instance is scoped to a set of path prefixes (or all paths + if no prefixes are given). This allows multiple instances to coexist + without double-counting requests. """ def __init__( @@ -68,6 +74,8 @@ class RateLimitMiddleware(BaseHTTPMiddleware): bucket_override: str | None = None, bucket_max_requests: int | None = None, bucket_window_seconds: int | None = None, + path_prefixes: list[str] | None = None, + skip_paths: list[str] | None = None, ) -> None: """Initialize the rate limit middleware. @@ -78,6 +86,12 @@ class RateLimitMiddleware(BaseHTTPMiddleware): bucket_override: Optional named bucket to use instead of the default limiter. bucket_max_requests: Max requests for the bucket override. bucket_window_seconds: Window for the bucket override. + path_prefixes: If provided, only apply rate limiting to paths that + start with one of these prefixes. If ``None``, all paths are + matched. 
+ skip_paths: If provided, do not apply rate limiting to paths that + start with one of these prefixes. Takes precedence over + ``path_prefixes``. """ super().__init__(app) # type: ignore[arg-type] self.rate_limiter: GlobalRateLimiter = rate_limiter @@ -85,6 +99,23 @@ class RateLimitMiddleware(BaseHTTPMiddleware): self.bucket_override = bucket_override self.bucket_max_requests = bucket_max_requests self.bucket_window_seconds = bucket_window_seconds + self.path_prefixes = path_prefixes or [] + self.skip_paths = skip_paths or [] + + def _should_check(self, path: str) -> bool: + """Return whether the given path should be rate-limited by this instance. + + Args: + path: The request URL path. + + Returns: + ``True`` if this instance should enforce its limit on the path. + """ + if self.skip_paths and any(path.startswith(p) for p in self.skip_paths): + return False + if self.path_prefixes: + return any(path.startswith(p) for p in self.path_prefixes) + return True async def dispatch( self, @@ -103,37 +134,28 @@ class RateLimitMiddleware(BaseHTTPMiddleware): Returns: A response object (either rate limit response or from handler). """ - client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies) - - # Use higher-rate bucket for specific endpoints. - # Check path to apply the appropriate bucket. 
path = request.url.path + if not self._should_check(path): + return await call_next(request) + + client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies) + if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds: - if path.startswith("/api/v1/history"): - is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket( - self.bucket_override, - client_ip, - self.bucket_max_requests, - self.bucket_window_seconds, - ) - elif path.startswith("/api/v1/login") or path.startswith("/api/v1/setup"): - # Auth endpoints use their own bucket - is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket( - self.bucket_override, - client_ip, - self.bucket_max_requests, - self.bucket_window_seconds, - ) - else: - is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip) + is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket( + self.bucket_override, + client_ip, + self.bucket_max_requests, + self.bucket_window_seconds, + ) else: is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip) + if not is_allowed: log.warning( "global_rate_limit_exceeded", client_ip=client_ip, - path=request.url.path, + path=path, method=request.method, retry_after=retry_after, ) @@ -141,7 +163,6 @@ class RateLimitMiddleware(BaseHTTPMiddleware): "Too many requests. 
Please try again later.", retry_after_seconds=retry_after, ) - # Return the error response directly return JSONResponse( status_code=429, content={ @@ -153,6 +174,5 @@ class RateLimitMiddleware(BaseHTTPMiddleware): headers={"Retry-After": str(int(retry_after))}, ) - # Request is allowed, continue to next handler response: Response = await call_next(request) return response diff --git a/backend/tests/test_utils/test_global_rate_limiter.py b/backend/tests/test_utils/test_global_rate_limiter.py index b66dd3e..bea3ac8 100644 --- a/backend/tests/test_utils/test_global_rate_limiter.py +++ b/backend/tests/test_utils/test_global_rate_limiter.py @@ -134,24 +134,17 @@ class TestRateLimitMiddleware: """Global rate limit should block requests exceeding per-IP limit.""" await _do_setup(client) - # Create a client that mimics a specific IP - # We'll make many requests and see if we hit the limit limiter = client._transport.app.state.global_rate_limiter limiter.reset() - # Reduce limit temporarily for testing. - # Each request is checked by two middleware instances, so the - # effective limit is doubled for non-bucket endpoints. original_max = limiter.max_requests - limiter.max_requests = 7 + limiter.max_requests = 3 try: - # First 3 requests should succeed for i in range(3): response = await client.get("/api/v1/health") assert response.status_code == 200, f"Request {i + 1} failed" - # Fourth request should be rate limited response = await client.get("/api/v1/health") assert response.status_code == 429 assert response.json()["code"] == "rate_limit_exceeded" @@ -166,22 +159,47 @@ class TestRateLimitMiddleware: limiter = client._transport.app.state.global_rate_limiter limiter.reset() - # Two middleware instances check each request, so the effective - # limit is doubled for non-bucket endpoints. 
original_max = limiter.max_requests - limiter.max_requests = 3 + limiter.max_requests = 2 try: - # First request succeeds response = await client.get("/api/v1/health") assert response.status_code == 200 - # Second request is rate limited + response = await client.get("/api/v1/health") + assert response.status_code == 200 + response = await client.get("/api/v1/health") assert response.status_code == 429 assert "Retry-After" in response.headers retry_after = int(response.headers["Retry-After"]) assert retry_after > 0 - assert retry_after <= 60 # Should be less than window + assert retry_after <= 60 finally: limiter.max_requests = original_max + + async def test_auth_bucket_allows_more_requests(self, client: AsyncClient) -> None: + """Auth endpoints use a dedicated high-rate bucket.""" + await _do_setup(client) + + limiter = client._transport.app.state.global_rate_limiter + limiter.reset() + + # The auth bucket is configured for 1000 req/min; we only need to + # verify that it is *not* the global bucket (200 req/min). + for _ in range(5): + response = await client.post("/api/v1/auth/login", json={"password": "x"}) + assert response.status_code in (401, 403, 429) + + async def test_history_bucket_allows_more_requests(self, client: AsyncClient) -> None: + """History endpoints use a dedicated high-rate bucket.""" + await _do_setup(client) + + limiter = client._transport.app.state.global_rate_limiter + limiter.reset() + + for _ in range(5): + response = await client.get("/api/v1/history/bans") + # 401/403 is fine — we just need to confirm we are not 429'd + # by the global limiter. + assert response.status_code != 429