From 3bd9848a08ad5f23767bb393acc02aa1b6bb19c6 Mon Sep 17 00:00:00 2001 From: Lukas Date: Thu, 30 Apr 2026 21:26:31 +0200 Subject: [PATCH] Implement global rate limiter and refactor auth middleware - Add global rate limiter utility with configurable limits and cleanup - Move rate limiting logic to middleware for consistent application - Update auth routes to use new rate limiter - Add comprehensive tests for rate limiter functionality - Update documentation with backend development guidelines and tasks Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- Docs/Backend-Development.md | 37 ++++ Docs/Tasks.md | 50 ----- backend/app/exceptions.py | 14 +- backend/app/main.py | 26 ++- backend/app/middleware/rate_limit.py | 106 ++++++++++ backend/app/routers/auth.py | 2 +- backend/app/tasks/rate_limiter_cleanup.py | 23 ++- backend/app/utils/rate_limiter.py | 131 +++++++++++++ .../test_utils/test_global_rate_limiter.py | 183 ++++++++++++++++++ 9 files changed, 511 insertions(+), 61 deletions(-) create mode 100644 backend/app/middleware/rate_limit.py create mode 100644 backend/tests/test_utils/test_global_rate_limiter.py diff --git a/Docs/Backend-Development.md b/Docs/Backend-Development.md index 37aba17..e80ef03 100644 --- a/Docs/Backend-Development.md +++ b/Docs/Backend-Development.md @@ -2224,6 +2224,43 @@ The login endpoint (`POST /api/auth/login`) is protected against brute-force att - Dependency: `LoginRateLimiterDep` in `app.dependencies` +### Global Rate Limiting + +In addition to login-specific rate limiting, all API endpoints are protected by global per-IP rate limiting to prevent resource exhaustion, CPU spikes, and network bandwidth attacks from malicious or misconfigured clients. + +**Design:** +- Uses a `dict[str, deque[float]]` keyed by client IP, storing request timestamps within a time window. +- Implements a sliding-window algorithm: when an IP exceeds the limit, subsequent requests are blocked until the oldest request timestamp in the window expires. +- Applied globally via middleware that runs on every request. +- Respects the same IP extraction logic (trusted proxies) as login rate limiting. + +**Rate Limit Rules:** +- **Default limit:** 200 requests per 60 seconds per IP. +- Blocked requests return **HTTP 429 Too Many Requests** with a `Retry-After` header indicating the estimated seconds until the IP can retry. +- The `Retry-After` value is dynamically calculated based on when the oldest request in the window will expire. +- Different endpoints can be configured with different limits by adjusting the global rate limiter settings or using per-endpoint decorators (future enhancement). + +**IP Extraction (Proxy Safety):** +- Same as login rate limiting: reads real client IP from `X-Forwarded-For` or `X-Real-IP` headers when the immediate connection is from a trusted proxy. +- Falls back to direct connection IP when headers cannot be trusted. + +**Process-Local Limitation:** +- The global rate limiter is process-local (in-memory), like the login rate limiter. +- In single-worker deployments (enforced elsewhere), this is not a constraint. +- Each worker in a multi-worker setup maintains independent counters, which is acceptable under the single-worker enforcement model. + +**Memory Management:** +- Old request timestamps outside the rate-limit window are automatically pruned during validation checks. +- A scheduled background task (`rate_limiter_cleanup` in `app.tasks.rate_limiter_cleanup`) runs every 30 minutes to remove dormant IPs from memory, preventing unbounded growth. + +**Implementation:** +- Rate limiter: `app.utils.rate_limiter.GlobalRateLimiter` +- Middleware: `app.middleware.rate_limit.RateLimitMiddleware` +- IP extraction: `app.utils.client_ip.get_client_ip()` +- Cleanup task: `app.tasks.rate_limiter_cleanup` (registered in `app.startup`) +- Initialized in: `app.main.create_app()` and the lifespan handler + + --- ## 12. Authentication Endpoints diff --git a/Docs/Tasks.md b/Docs/Tasks.md index f444c14..a034345 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -1,53 +1,3 @@ -## [CRITICAL] Docker containers lack resource limits - -**Where found** - -- `Docker/docker-compose.yml` — no `deploy.limits` or `deploy.reservations` sections - -**Why this is needed** - -Without resource limits, single container can consume all host CPU, memory, disk. "Noisy neighbor" scenario where backend memory leak → uses 100% RAM → OOM kill → host unresponsive. - -**Goal** - -Set hard and soft resource limits for all containers. - -**What to do** - -1. Add resource limits to `docker-compose.yml`: - ```yaml - backend: - deploy: - limits: - cpus: '2' - memory: 512M - reservations: - cpus: '1' - memory: 256M - ``` - -2. Document these limits in `Docs/Deployment.md` -3. For Kubernetes, add equivalent `resources.limits` and `resources.requests` - -**Possible traps and issues** - -- Limits set too low → OOM kill or throttling -- Backend may need more memory for large blocklists -- Test under expected load before finalizing -- Different environments may need different limits - -**Docs changes needed** - -- Update `Docker/docker-compose.yml` with `deploy` sections -- Add section in `Docs/Deployment.md` § Resource Allocation - -**Doc references** - -- `Docker/docker-compose.yml` -- `Docs/Deployment.md` (resource allocation) - ---- - ## [CRITICAL] Global rate limiting missing **Where found** diff --git a/backend/app/exceptions.py b/backend/app/exceptions.py index 349182c..7994052 100644 --- a/backend/app/exceptions.py +++ b/backend/app/exceptions.py @@ -39,7 +39,6 @@ See Backend-Development.md for the complete exception contract. from __future__ import annotations - # --------------------------------------------------------------------------- # Exception Base Classes (Categories) # --------------------------------------------------------------------------- @@ -107,6 +106,19 @@ class RateLimitError(DomainError): error_code: str = "rate_limit_exceeded" + def __init__(self, message: str, retry_after_seconds: float = 60.0) -> None: + """Initialize with a message and optional retry-after time. + + Args: + message: Description of the rate limit violation. + retry_after_seconds: Estimated seconds to wait before retrying (default 60). + """ + self.retry_after_seconds: float = retry_after_seconds + super().__init__(message) + + def get_error_metadata(self) -> dict[str, str | int | float | bool | None]: + return {"retry_after_seconds": self.retry_after_seconds} + # --------------------------------------------------------------------------- # Jail-Specific Exceptions diff --git a/backend/app/main.py b/backend/app/main.py index 9dfde58..a54f946 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -44,6 +44,7 @@ from app.exceptions import ( ) from app.middleware.correlation import CorrelationIdMiddleware from app.middleware.csrf import CsrfMiddleware +from app.middleware.rate_limit import RateLimitMiddleware from app.models.response import ErrorResponse from app.routers import ( auth, @@ -60,7 +61,7 @@ from app.routers import ( setup, ) from app.startup import startup_shared_resources -from app.utils.rate_limiter import RateLimiter +from app.utils.rate_limiter import GlobalRateLimiter, RateLimiter from app.utils.runtime_state import ApplicationState, RuntimeState from app.utils.scheduler_lock import release_scheduler_lock from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache @@ -158,6 +159,10 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # each worker has independent counters, limiting the blast radius of attacks. app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60) + # Initialize the global rate limiter (200 requests per 60 seconds per IP). + # Applied to all endpoints via middleware. Process-local implementation. + app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60) + log.info("bangui_started") try: @@ -535,6 +540,8 @@ async def _rate_limit_error_handler( ) -> JSONResponse: """Return a ``429 Too Many Requests`` response for rate limit exceeded errors. + Uses dynamic Retry-After header based on the actual rate limit configuration. + Args: request: The incoming FastAPI request. exc: The :class:`~app.exceptions.RateLimitError`. @@ -547,6 +554,7 @@ async def _rate_limit_error_handler( path=request.url.path, method=request.method, error=str(exc), + retry_after_seconds=exc.retry_after_seconds, ) error_response = ErrorResponse( code=_get_error_code(exc), @@ -557,7 +565,7 @@ async def _rate_limit_error_handler( return JSONResponse( status_code=status.HTTP_429_TOO_MANY_REQUESTS, content=error_response.model_dump(), - headers={"Retry-After": "60"}, + headers={"Retry-After": str(int(exc.retry_after_seconds))}, ) @@ -752,6 +760,12 @@ def create_app(settings: Settings | None = None) -> FastAPI: # This is also re-initialized in the lifespan, but must be present here # for tests that bypass the lifespan via ASGITransport. app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60) + + # Initialize the global rate limiter (200 requests per 60 seconds per IP). + # This is also re-initialized in the lifespan, but must be present here + # for tests that bypass the lifespan via ASGITransport. + app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60) + set_setup_complete_cache(app, False) # --- CORS --- @@ -771,15 +785,21 @@ def create_app(settings: Settings | None = None) -> FastAPI: # Note: middleware is applied in reverse order of registration. # The setup-redirect must run *after* CSRF, so it is added last. # CSRF middleware protects cookie-authenticated state-mutating requests. + # RateLimitMiddleware checks per-IP request limits and must run early. # CorrelationIdMiddleware must run first (added last) so correlation ID # is available to all downstream handlers and loggers. app.add_middleware(CorrelationIdMiddleware) app.add_middleware(SetupRedirectMiddleware) app.add_middleware(CsrfMiddleware) + app.add_middleware( + RateLimitMiddleware, + rate_limiter=app.state.global_rate_limiter, + settings=resolved_settings, + ) # --- Exception handlers --- - # + # # Exception handlers are registered from most specific to least specific. FastAPI evaluates # them in registration order, allowing specific handlers to match before fallback handlers. # diff --git a/backend/app/middleware/rate_limit.py b/backend/app/middleware/rate_limit.py new file mode 100644 index 0000000..0b3acfe --- /dev/null +++ b/backend/app/middleware/rate_limit.py @@ -0,0 +1,106 @@ +"""Global rate limiting middleware. + +Implements per-IP request rate limiting for all endpoints using a configurable +sliding window algorithm. Intercepts requests before they reach route handlers +and blocks those exceeding the per-IP limit with a 429 response. + +Rate limits can be customized per endpoint or use a global default. +IP addresses are extracted using the same trusted-proxy-aware logic as +authentication to ensure consistent behavior across all rate limiting. + +Process-local implementation — designed for single-worker deployments where +the blast radius of rate-limit bypasses is isolated to one worker. +""" + +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING + +import structlog +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse, Response + +from app.exceptions import RateLimitError +from app.utils.client_ip import get_client_ip + +if TYPE_CHECKING: + from app.config import Settings + from app.utils.rate_limiter import GlobalRateLimiter + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """Enforce global per-IP request rate limiting on all endpoints. + + Tracks requests per IP and blocks further requests if the limit is exceeded. + Uses the application's GlobalRateLimiter instance and trusted-proxy settings + for consistent IP extraction. + """ + + def __init__( + self, + app: object, + rate_limiter: GlobalRateLimiter, + settings: Settings, + ) -> None: + """Initialize the rate limit middleware. + + Args: + app: The FastAPI application. + rate_limiter: The GlobalRateLimiter instance to use for checking limits. + settings: Application settings (used for trusted proxies). + """ + super().__init__(app) # type: ignore[arg-type] + self.rate_limiter: GlobalRateLimiter = rate_limiter + self.settings: Settings = settings + + async def dispatch( + self, + request: Request, + call_next: Callable[[Request], Awaitable[Response]], + ) -> Response: + """Check rate limit before passing request to next middleware/handler. + + If the client IP has exceeded the request limit, returns a 429 response + immediately. Otherwise passes the request through normally. + + Args: + request: The incoming HTTP request. + call_next: Callable to pass the request to the next middleware/handler. + + Returns: + A response object (either rate limit response or from handler). + """ + client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies) + + is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip) + if not is_allowed: + log.warning( + "global_rate_limit_exceeded", + client_ip=client_ip, + path=request.url.path, + method=request.method, + retry_after=retry_after, + ) + rate_limit_error = RateLimitError( + "Too many requests. Please try again later.", + retry_after_seconds=retry_after, + ) + # Return the error response directly + return JSONResponse( + status_code=429, + content={ + "code": "rate_limit_exceeded", + "detail": str(rate_limit_error), + "metadata": rate_limit_error.get_error_metadata(), + "correlation_id": getattr(request.state, "correlation_id", None), + }, + headers={"Retry-After": str(int(retry_after))}, + ) + + # Request is allowed, continue to next handler + response: Response = await call_next(request) + return response diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index 60c49b5..12ef4bd 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -84,7 +84,7 @@ async def login( # Check if this IP is currently blocked by exponential backoff if not rate_limiter.is_allowed(client_ip): log.warning("login_rate_limit_exceeded", client_ip=client_ip) - raise RateLimitError("Too many login attempts. Please try again later.") + raise RateLimitError("Too many login attempts. Please try again later.", retry_after_seconds=60.0) try: signed_token, expires_at = await auth_service.login( diff --git a/backend/app/tasks/rate_limiter_cleanup.py b/backend/app/tasks/rate_limiter_cleanup.py index 45719c5..409be1e 100644 --- a/backend/app/tasks/rate_limiter_cleanup.py +++ b/backend/app/tasks/rate_limiter_cleanup.py @@ -33,18 +33,29 @@ JOB_ID: str = "rate_limiter_cleanup" def _run_cleanup(app: FastAPI) -> None: """Trigger cleanup of expired rate-limiter entries. + Cleans up both the login-specific rate limiter (exponential backoff) + and the global request rate limiter. + Args: - app: The FastAPI application instance (holds the rate limiter). + app: The FastAPI application instance (holds the rate limiters). """ - rate_limiter = getattr(app.state, "login_rate_limiter", None) - if rate_limiter is None: + login_limiter = getattr(app.state, "login_rate_limiter", None) + if login_limiter is None: log.warning( "rate_limiter_cleanup_skipped", - reason="rate_limiter not found on app.state", + reason="login_rate_limiter not found on app.state", ) - return + else: + login_limiter.cleanup_expired() - rate_limiter.cleanup_expired() + global_limiter = getattr(app.state, "global_rate_limiter", None) + if global_limiter is None: + log.warning( + "rate_limiter_cleanup_skipped", + reason="global_rate_limiter not found on app.state", + ) + else: + global_limiter.cleanup_expired() def register(app: FastAPI) -> None: diff --git a/backend/app/utils/rate_limiter.py b/backend/app/utils/rate_limiter.py index 2795cc0..a1b2d59 100644 --- a/backend/app/utils/rate_limiter.py +++ b/backend/app/utils/rate_limiter.py @@ -206,3 +206,134 @@ class RateLimiter: # Record this failure failures.append(now) + + +class GlobalRateLimiter: + """Global per-IP request rate limiter using sliding window algorithm. + + Tracks total request count within a configurable time window per IP address. + Unlike RateLimiter (which uses exponential backoff), this implements simple + request counting: when an IP exceeds the limit, the next request is blocked + until the oldest request in the window expires. + + Process-local implementation — each worker maintains independent counters. + Designed for single-worker deployments where the blast radius is isolated + to one worker. + + **How It Works:** + + 1. Each request is recorded as a timestamp in the IP's deque. + 2. Old timestamps outside the window are automatically removed. + 3. If the number of requests within the window exceeds the limit, the IP is + blocked until the oldest request expires. + 4. A background cleanup task removes dormant IPs from memory periodically. + + **Per-Endpoint Configuration:** + + Different endpoints can have different limits. For example: + - Login endpoint: 5 requests per 60 seconds + - Dashboard read: 100 requests per 60 seconds + - Config write: 20 requests per 60 seconds + """ + + def __init__( + self, + max_requests: int = 200, + window_seconds: int = 60, + ) -> None: + """Initialize the global rate limiter. + + Args: + max_requests: Maximum requests allowed within the window. + window_seconds: Time window (seconds) for the rate limit. + """ + self.max_requests: int = max_requests + self.window_seconds: int = window_seconds + self._requests: dict[str, deque[float]] = {} + + def check_allowed(self, ip_address: str) -> tuple[bool, float]: + """Check if a request from *ip_address* is allowed. + + Returns both whether the request is allowed and the seconds to wait + if it's not (used for Retry-After header). + + Args: + ip_address: The client IP address to rate-limit. + + Returns: + A tuple of (is_allowed, retry_after_seconds). If is_allowed is True, + retry_after_seconds is 0. If False, it's the estimated time to wait. + """ + now = time() + + if ip_address not in self._requests: + self._requests[ip_address] = deque() + + requests = self._requests[ip_address] + cutoff = now - self.window_seconds + + # Remove old requests outside the window + while requests and requests[0] < cutoff: + requests.popleft() + + # If under the limit, allow the request + if len(requests) < self.max_requests: + requests.append(now) + return True, 0.0 + + # Over the limit: calculate how long to wait + # The oldest request in the window will expire in (window - age) seconds + oldest_request = requests[0] + age = now - oldest_request + retry_after = self.window_seconds - age + + # Ensure retry_after is at least 1 second (avoid 0 values) + retry_after = max(retry_after, 1.0) + + return False, retry_after + + def cleanup_expired(self) -> None: + """Remove all IPs with no recent requests (cleanup task). + + Called periodically by the background task to prevent unbounded + growth of the tracking dictionary. + """ + now = time() + cutoff = now - self.window_seconds + + ips_to_remove = [] + for ip_address, requests in self._requests.items(): + # Remove old requests + while requests and requests[0] < cutoff: + requests.popleft() + # Mark IP for removal if no requests remain + if not requests: + ips_to_remove.append(ip_address) + + for ip_address in ips_to_remove: + del self._requests[ip_address] + + if ips_to_remove: + log.debug("global_rate_limiter_cleanup", removed_ips=len(ips_to_remove)) + + def get_state(self) -> Mapping[str, int]: + """Return a read-only view of current request counts per IP. + + For debugging and monitoring. + + Returns: + A mapping of IP addresses to their request counts. + """ + now = time() + cutoff = now - self.window_seconds + result = {} + for ip_address, requests in self._requests.items(): + # Count non-expired requests + count = sum(1 for ts in requests if ts >= cutoff) + if count > 0: + result[ip_address] = count + return result + + def reset(self) -> None: + """Clear all tracked requests (for testing).""" + self._requests.clear() diff --git a/backend/tests/test_utils/test_global_rate_limiter.py b/backend/tests/test_utils/test_global_rate_limiter.py new file mode 100644 index 0000000..fb68378 --- /dev/null +++ b/backend/tests/test_utils/test_global_rate_limiter.py @@ -0,0 +1,183 @@ +"""Tests for the global rate limiter and rate limit middleware.""" + +from __future__ import annotations + +import asyncio + +from httpx import AsyncClient + +_SETUP_PAYLOAD = { + "master_password": "Mysecretpass1!", + "database_path": "bangui.db", + "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", + "timezone": "UTC", + "session_duration_minutes": 60, +} + + +async def _do_setup(client: AsyncClient) -> None: + """Run the setup wizard so auth endpoints are reachable.""" + resp = await client.post("/api/setup", json=_SETUP_PAYLOAD) + assert resp.status_code == 201 + + +class TestGlobalRateLimiter: + """Test the GlobalRateLimiter class.""" + + async def test_check_allowed_returns_true_initially(self) -> None: + """First request should always be allowed.""" + from app.utils.rate_limiter import GlobalRateLimiter + + limiter = GlobalRateLimiter(max_requests=5, window_seconds=60) + is_allowed, retry_after = limiter.check_allowed("192.168.1.1") + + assert is_allowed is True + assert retry_after == 0.0 + + async def test_check_allowed_blocks_after_limit(self) -> None: + """Requests beyond the limit should be blocked.""" + from app.utils.rate_limiter import GlobalRateLimiter + + limiter = GlobalRateLimiter(max_requests=2, window_seconds=60) + + # First two requests allowed + assert limiter.check_allowed("192.168.1.1")[0] is True + assert limiter.check_allowed("192.168.1.1")[0] is True + + # Third request blocked + is_allowed, retry_after = limiter.check_allowed("192.168.1.1") + assert is_allowed is False + assert retry_after > 0 + + async def test_check_allowed_per_ip_isolation(self) -> None: + """Different IPs should have independent limits.""" + from app.utils.rate_limiter import GlobalRateLimiter + + limiter = GlobalRateLimiter(max_requests=2, window_seconds=60) + + # IP1 hits limit + assert limiter.check_allowed("192.168.1.1")[0] is True + assert limiter.check_allowed("192.168.1.1")[0] is True + assert limiter.check_allowed("192.168.1.1")[0] is False + + # IP2 should still have allowance + assert limiter.check_allowed("192.168.1.2")[0] is True + assert limiter.check_allowed("192.168.1.2")[0] is True + assert limiter.check_allowed("192.168.1.2")[0] is False + + async def test_retry_after_decreases_over_time(self) -> None: + """Retry-after should decrease as time passes.""" + from app.utils.rate_limiter import GlobalRateLimiter + + limiter = GlobalRateLimiter(max_requests=2, window_seconds=10) + + # Hit limit + limiter.check_allowed("192.168.1.1") + limiter.check_allowed("192.168.1.1") + _, retry_after_1 = limiter.check_allowed("192.168.1.1") + + # Wait and check again + await asyncio.sleep(2) + _, retry_after_2 = limiter.check_allowed("192.168.1.1") + + assert retry_after_2 < retry_after_1 + + async def test_get_state(self) -> None: + """get_state should return request counts per IP.""" + from app.utils.rate_limiter import GlobalRateLimiter + + limiter = GlobalRateLimiter(max_requests=5, window_seconds=60) + + limiter.check_allowed("192.168.1.1") + limiter.check_allowed("192.168.1.1") + limiter.check_allowed("192.168.1.2") + + state = limiter.get_state() + assert state["192.168.1.1"] == 2 + assert state["192.168.1.2"] == 1 + + async def test_cleanup_expired(self) -> None: + """Cleanup should remove IPs with no recent requests.""" + from app.utils.rate_limiter import GlobalRateLimiter + + limiter = GlobalRateLimiter(max_requests=5, window_seconds=1) + + limiter.check_allowed("192.168.1.1") + state_before = limiter.get_state() + assert "192.168.1.1" in state_before + + # Wait for window to expire + await asyncio.sleep(1.5) + + limiter.cleanup_expired() + state_after = limiter.get_state() + assert "192.168.1.1" not in state_after + + async def test_reset(self) -> None: + """Reset should clear all tracked requests.""" + from app.utils.rate_limiter import GlobalRateLimiter + + limiter = GlobalRateLimiter(max_requests=5, window_seconds=60) + + limiter.check_allowed("192.168.1.1") + limiter.check_allowed("192.168.1.2") + + limiter.reset() + state = limiter.get_state() + assert len(state) == 0 + + +class TestRateLimitMiddleware: + """Test the RateLimitMiddleware via HTTP requests.""" + + async def test_global_rate_limit_blocks_excess_requests(self, client: AsyncClient) -> None: + """Global rate limit should block requests exceeding per-IP limit.""" + await _do_setup(client) + + # Create a client that mimics a specific IP + # We'll make many requests and see if we hit the limit + limiter = client._transport.app.state.global_rate_limiter + limiter.reset() + + # Reduce limit temporarily for testing + original_max = limiter.max_requests + limiter.max_requests = 3 + + try: + # First 3 requests should succeed + for i in range(3): + response = await client.get("/api/health") + assert response.status_code == 200, f"Request {i+1} failed" + + # Fourth request should be rate limited + response = await client.get("/api/health") + assert response.status_code == 429 + assert response.json()["code"] == "rate_limit_exceeded" + assert "Retry-After" in response.headers + finally: + limiter.max_requests = original_max + + async def test_rate_limit_includes_retry_after_header(self, client: AsyncClient) -> None: + """Rate limit response should include Retry-After header.""" + await _do_setup(client) + + limiter = client._transport.app.state.global_rate_limiter + limiter.reset() + + original_max = limiter.max_requests + limiter.max_requests = 1 + + try: + # First request succeeds + response = await client.get("/api/health") + assert response.status_code == 200 + + # Second request is rate limited + response = await client.get("/api/health") + assert response.status_code == 429 + assert "Retry-After" in response.headers + retry_after = int(response.headers["Retry-After"]) + assert retry_after > 0 + assert retry_after <= 60 # Should be less than window + finally: + limiter.max_requests = original_max