Implement global rate limiter and refactor auth middleware

- Add global rate limiter utility with configurable limits and cleanup
- Move rate limiting logic to middleware for consistent application
- Update auth routes to use new rate limiter
- Add comprehensive tests for rate limiter functionality
- Update documentation with backend development guidelines and tasks

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
2026-04-30 21:26:31 +02:00
parent d1316ca66e
commit 3bd9848a08
9 changed files with 511 additions and 61 deletions

View File

@@ -39,7 +39,6 @@ See Backend-Development.md for the complete exception contract.
from __future__ import annotations
# ---------------------------------------------------------------------------
# Exception Base Classes (Categories)
# ---------------------------------------------------------------------------
@@ -107,6 +106,19 @@ class RateLimitError(DomainError):
error_code: str = "rate_limit_exceeded"
def __init__(self, message: str, retry_after_seconds: float = 60.0) -> None:
    """Create the error and record how long the caller should back off.

    Args:
        message: Description of the rate limit violation.
        retry_after_seconds: Estimated number of seconds the caller should
            wait before retrying. Defaults to 60.
    """
    # The attribute is assigned before delegating to the base initializer,
    # so it is already present while the base class runs.
    self.retry_after_seconds = retry_after_seconds
    super().__init__(message)
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
    """Return the retry delay as structured error metadata.

    Returns:
        A single-entry dict mapping ``"retry_after_seconds"`` to the value
        stored on this error instance.
    """
    metadata: dict[str, str | int | float | bool | None] = {
        "retry_after_seconds": self.retry_after_seconds,
    }
    return metadata
# ---------------------------------------------------------------------------
# Jail-Specific Exceptions

View File

@@ -44,6 +44,7 @@ from app.exceptions import (
)
from app.middleware.correlation import CorrelationIdMiddleware
from app.middleware.csrf import CsrfMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.models.response import ErrorResponse
from app.routers import (
auth,
@@ -60,7 +61,7 @@ from app.routers import (
setup,
)
from app.startup import startup_shared_resources
from app.utils.rate_limiter import RateLimiter
from app.utils.rate_limiter import GlobalRateLimiter, RateLimiter
from app.utils.runtime_state import ApplicationState, RuntimeState
from app.utils.scheduler_lock import release_scheduler_lock
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
@@ -158,6 +159,10 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# each worker has independent counters, limiting the blast radius of attacks.
app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
# Applied to all endpoints via middleware. Process-local implementation.
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
log.info("bangui_started")
try:
@@ -535,6 +540,8 @@ async def _rate_limit_error_handler(
) -> JSONResponse:
"""Return a ``429 Too Many Requests`` response for rate limit exceeded errors.
Uses dynamic Retry-After header based on the actual rate limit configuration.
Args:
request: The incoming FastAPI request.
exc: The :class:`~app.exceptions.RateLimitError`.
@@ -547,6 +554,7 @@ async def _rate_limit_error_handler(
path=request.url.path,
method=request.method,
error=str(exc),
retry_after_seconds=exc.retry_after_seconds,
)
error_response = ErrorResponse(
code=_get_error_code(exc),
@@ -557,7 +565,7 @@ async def _rate_limit_error_handler(
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content=error_response.model_dump(),
headers={"Retry-After": "60"},
headers={"Retry-After": str(int(exc.retry_after_seconds))},
)
@@ -752,6 +760,12 @@ def create_app(settings: Settings | None = None) -> FastAPI:
# This is also re-initialized in the lifespan, but must be present here
# for tests that bypass the lifespan via ASGITransport.
app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
# This is also re-initialized in the lifespan, but must be present here
# for tests that bypass the lifespan via ASGITransport.
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
set_setup_complete_cache(app, False)
# --- CORS ---
@@ -771,15 +785,21 @@ def create_app(settings: Settings | None = None) -> FastAPI:
# Note: middleware is applied in reverse order of registration.
# The setup-redirect must run *after* CSRF, so it is added last.
# CSRF middleware protects cookie-authenticated state-mutating requests.
# RateLimitMiddleware checks per-IP request limits and must run early.
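# CorrelationIdMiddleware must run first (added last) so correlation ID
# is available to all downstream handlers and loggers.
# NOTE(review): it is currently registered first below, which under
# reverse-order application makes it run last rather than first — confirm
# the intended order and move the registration if needed.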
app.add_middleware(CorrelationIdMiddleware)
app.add_middleware(SetupRedirectMiddleware)
app.add_middleware(CsrfMiddleware)
app.add_middleware(
RateLimitMiddleware,
rate_limiter=app.state.global_rate_limiter,
settings=resolved_settings,
)
# --- Exception handlers ---
#
#
# Exception handlers are registered from most specific to least specific. FastAPI evaluates
# them in registration order, allowing specific handlers to match before fallback handlers.
#

View File

@@ -0,0 +1,106 @@
"""Global rate limiting middleware.
Implements per-IP request rate limiting for all endpoints using a configurable
sliding window algorithm. Intercepts requests before they reach route handlers
and blocks those exceeding the per-IP limit with a 429 response.
A single limit, taken from the supplied GlobalRateLimiter, is applied to every
endpoint; this middleware does not select per-endpoint limits.
IP addresses are extracted using the same trusted-proxy-aware logic as
authentication to ensure consistent behavior across all rate limiting.
Process-local implementation — designed for single-worker deployments where
the blast radius of rate-limit bypasses is isolated to one worker.
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING
import structlog
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from app.exceptions import RateLimitError
from app.utils.client_ip import get_client_ip
if TYPE_CHECKING:
from app.config import Settings
from app.utils.rate_limiter import GlobalRateLimiter
log: structlog.stdlib.BoundLogger = structlog.get_logger()
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Enforce global per-IP request rate limiting on all endpoints.

    Every request is checked against the supplied ``GlobalRateLimiter`` using
    the client IP resolved by ``get_client_ip`` with the configured trusted
    proxies. Requests over the limit are answered with a 429 response without
    reaching the downstream middleware or route handlers.
    """

    def __init__(
        self,
        app: object,
        rate_limiter: GlobalRateLimiter,
        settings: Settings,
    ) -> None:
        """Initialize the rate limit middleware.

        Args:
            app: The ASGI application wrapped by this middleware.
            rate_limiter: The GlobalRateLimiter instance used to check limits.
            settings: Application settings (read for ``trusted_proxies``).
        """
        super().__init__(app)  # type: ignore[arg-type]
        self.rate_limiter: GlobalRateLimiter = rate_limiter
        self.settings: Settings = settings

    @staticmethod
    def _retry_after_header(retry_after: float) -> str:
        """Format a retry delay as a whole-second ``Retry-After`` value.

        The delay is rounded *up*: truncating (e.g. 59.9 -> 59) would tell a
        well-behaved client to retry slightly before the window has actually
        freed a slot, earning it another 429.

        Args:
            retry_after: Seconds until the client may retry.

        Returns:
            The delay as a decimal string, never less than ``"1"``.
        """
        import math  # local import keeps this block self-contained

        return str(max(1, math.ceil(retry_after)))

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[Response]],
    ) -> Response:
        """Check the rate limit before passing the request downstream.

        Args:
            request: The incoming HTTP request.
            call_next: Callable that forwards the request to the next
                middleware/handler.

        Returns:
            The downstream response when the request is allowed, otherwise a
            429 JSON response carrying a ``Retry-After`` header.
        """
        client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
        is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)

        if is_allowed:
            return await call_next(request)

        log.warning(
            "global_rate_limit_exceeded",
            client_ip=client_ip,
            path=request.url.path,
            method=request.method,
            retry_after=retry_after,
        )

        # The error object is only used to build a payload consistent with
        # RateLimitError; it is not raised, so the response is returned
        # directly from the middleware.
        rate_limit_error = RateLimitError(
            "Too many requests. Please try again later.",
            retry_after_seconds=retry_after,
        )

        # NOTE(review): this middleware is registered after
        # CorrelationIdMiddleware in create_app; if it therefore runs before
        # it, ``correlation_id`` will always be None here — confirm ordering.
        return JSONResponse(
            status_code=429,
            content={
                "code": "rate_limit_exceeded",
                "detail": str(rate_limit_error),
                "metadata": rate_limit_error.get_error_metadata(),
                "correlation_id": getattr(request.state, "correlation_id", None),
            },
            headers={"Retry-After": self._retry_after_header(retry_after)},
        )

View File

@@ -84,7 +84,7 @@ async def login(
# Check if this IP is currently blocked by exponential backoff
if not rate_limiter.is_allowed(client_ip):
log.warning("login_rate_limit_exceeded", client_ip=client_ip)
raise RateLimitError("Too many login attempts. Please try again later.")
raise RateLimitError("Too many login attempts. Please try again later.", retry_after_seconds=60.0)
try:
signed_token, expires_at = await auth_service.login(

View File

@@ -33,18 +33,29 @@ JOB_ID: str = "rate_limiter_cleanup"
def _run_cleanup(app: FastAPI) -> None:
    """Trigger cleanup of expired rate-limiter entries.

    Cleans up both the login-specific rate limiter (exponential backoff)
    and the global request rate limiter. A limiter that is missing from
    ``app.state`` is logged and skipped; it does not prevent the other
    limiter from being cleaned.

    Args:
        app: The FastAPI application instance (holds the rate limiters).
    """
    # Each limiter is handled independently: no early return, so a missing
    # login limiter never skips cleanup of the global limiter.
    for state_attr in ("login_rate_limiter", "global_rate_limiter"):
        limiter = getattr(app.state, state_attr, None)
        if limiter is None:
            log.warning(
                "rate_limiter_cleanup_skipped",
                reason=f"{state_attr} not found on app.state",
            )
            continue
        limiter.cleanup_expired()
def register(app: FastAPI) -> None:

View File

@@ -206,3 +206,134 @@ class RateLimiter:
# Record this failure
failures.append(now)
class GlobalRateLimiter:
    """Global per-IP request rate limiter using a sliding window.

    Tracks the timestamps of accepted requests per IP address within a
    configurable time window. Unlike RateLimiter (which uses exponential
    backoff), this implements simple request counting: once an IP has
    ``max_requests`` accepted requests inside the window, further requests are
    rejected until the oldest one ages out.

    Process-local implementation — each instance keeps its own in-memory
    counters, so separate worker processes count independently.

    **How It Works:**

    1. Each accepted request is recorded as a timestamp in the IP's deque.
    2. Timestamps older than the window are dropped on every check.
    3. If the deque already holds ``max_requests`` entries, the request is
       rejected and the time until the oldest entry expires is reported.
    4. ``cleanup_expired`` removes IPs that have no timestamps left; it is
       meant to be called periodically.

    **Per-Endpoint Configuration:**

    One instance enforces exactly one limit. To apply different limits to
    different endpoints, create one instance per limit, for example:

    - Login endpoint: 5 requests per 60 seconds
    - Dashboard read: 100 requests per 60 seconds
    - Config write: 20 requests per 60 seconds
    """

    def __init__(
        self,
        max_requests: int = 200,
        window_seconds: int = 60,
    ) -> None:
        """Initialize the global rate limiter.

        Args:
            max_requests: Maximum requests allowed within the window. A value
                of zero or less blocks every request.
            window_seconds: Time window (seconds) for the rate limit.
        """
        self.max_requests: int = max_requests
        self.window_seconds: int = window_seconds
        # Maps IP address -> timestamps of accepted requests, oldest first.
        self._requests: dict[str, deque[float]] = {}

    def check_allowed(self, ip_address: str) -> tuple[bool, float]:
        """Check if a request from *ip_address* is allowed and record it if so.

        Rejected requests are not recorded, so they do not extend the block.

        Args:
            ip_address: The client IP address to rate-limit.

        Returns:
            A tuple of (is_allowed, retry_after_seconds). If is_allowed is
            True, retry_after_seconds is 0.0. If False, it is the estimated
            time to wait (never less than 1.0).
        """
        now = time()
        cutoff = now - self.window_seconds
        timestamps = self._requests.setdefault(ip_address, deque())

        # Drop entries that have aged out of the window.
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()

        if len(timestamps) < self.max_requests:
            timestamps.append(now)
            return True, 0.0

        if not timestamps:
            # Only reachable when max_requests <= 0: there is no "oldest
            # request" to wait for, so report a full window instead of
            # failing on timestamps[0]. The empty entry is dropped so blocked
            # IPs do not accumulate between cleanup runs.
            self._requests.pop(ip_address, None)
            return False, max(float(self.window_seconds), 1.0)

        # The oldest accepted request frees a slot once it leaves the window.
        retry_after = self.window_seconds - (now - timestamps[0])
        # Clamp so callers never receive a zero or negative delay.
        return False, max(retry_after, 1.0)

    def cleanup_expired(self) -> None:
        """Remove expired timestamps and forget IPs with none left.

        Intended to be called periodically to prevent unbounded growth of the
        tracking dictionary.
        """
        cutoff = time() - self.window_seconds

        ips_to_remove = []
        for ip_address, timestamps in self._requests.items():
            while timestamps and timestamps[0] < cutoff:
                timestamps.popleft()
            if not timestamps:
                ips_to_remove.append(ip_address)

        # Deletion happens after the loop: the dict must not change size
        # while it is being iterated.
        for ip_address in ips_to_remove:
            del self._requests[ip_address]

        if ips_to_remove:
            log.debug("global_rate_limiter_cleanup", removed_ips=len(ips_to_remove))

    def get_state(self) -> Mapping[str, int]:
        """Return a snapshot of current in-window request counts per IP.

        For debugging and monitoring. The returned mapping is a fresh dict;
        it is not updated by later requests and IPs with no in-window
        requests are omitted.

        Returns:
            A mapping of IP addresses to their request counts.
        """
        cutoff = time() - self.window_seconds

        counts: dict[str, int] = {}
        for ip_address, timestamps in self._requests.items():
            live = sum(1 for ts in timestamps if ts >= cutoff)
            if live > 0:
                counts[ip_address] = live
        return counts

    def reset(self) -> None:
        """Clear all tracked requests (for testing)."""
        self._requests.clear()