Implement global rate limiter and refactor auth middleware
- Add global rate limiter utility with configurable limits and cleanup - Move rate limiting logic to middleware for consistent application - Update auth routes to use new rate limiter - Add comprehensive tests for rate limiter functionality - Update documentation with backend development guidelines and tasks Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -39,7 +39,6 @@ See Backend-Development.md for the complete exception contract.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Exception Base Classes (Categories)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -107,6 +106,19 @@ class RateLimitError(DomainError):
|
||||
|
||||
error_code: str = "rate_limit_exceeded"
|
||||
|
||||
def __init__(self, message: str, retry_after_seconds: float = 60.0) -> None:
|
||||
"""Initialize with a message and optional retry-after time.
|
||||
|
||||
Args:
|
||||
message: Description of the rate limit violation.
|
||||
retry_after_seconds: Estimated seconds to wait before retrying (default 60).
|
||||
"""
|
||||
self.retry_after_seconds: float = retry_after_seconds
|
||||
super().__init__(message)
|
||||
|
||||
def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
|
||||
return {"retry_after_seconds": self.retry_after_seconds}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Jail-Specific Exceptions
|
||||
|
||||
@@ -44,6 +44,7 @@ from app.exceptions import (
|
||||
)
|
||||
from app.middleware.correlation import CorrelationIdMiddleware
|
||||
from app.middleware.csrf import CsrfMiddleware
|
||||
from app.middleware.rate_limit import RateLimitMiddleware
|
||||
from app.models.response import ErrorResponse
|
||||
from app.routers import (
|
||||
auth,
|
||||
@@ -60,7 +61,7 @@ from app.routers import (
|
||||
setup,
|
||||
)
|
||||
from app.startup import startup_shared_resources
|
||||
from app.utils.rate_limiter import RateLimiter
|
||||
from app.utils.rate_limiter import GlobalRateLimiter, RateLimiter
|
||||
from app.utils.runtime_state import ApplicationState, RuntimeState
|
||||
from app.utils.scheduler_lock import release_scheduler_lock
|
||||
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
|
||||
@@ -158,6 +159,10 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# each worker has independent counters, limiting the blast radius of attacks.
|
||||
app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
|
||||
|
||||
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
|
||||
# Applied to all endpoints via middleware. Process-local implementation.
|
||||
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
|
||||
|
||||
log.info("bangui_started")
|
||||
|
||||
try:
|
||||
@@ -535,6 +540,8 @@ async def _rate_limit_error_handler(
|
||||
) -> JSONResponse:
|
||||
"""Return a ``429 Too Many Requests`` response for rate limit exceeded errors.
|
||||
|
||||
Uses dynamic Retry-After header based on the actual rate limit configuration.
|
||||
|
||||
Args:
|
||||
request: The incoming FastAPI request.
|
||||
exc: The :class:`~app.exceptions.RateLimitError`.
|
||||
@@ -547,6 +554,7 @@ async def _rate_limit_error_handler(
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
error=str(exc),
|
||||
retry_after_seconds=exc.retry_after_seconds,
|
||||
)
|
||||
error_response = ErrorResponse(
|
||||
code=_get_error_code(exc),
|
||||
@@ -557,7 +565,7 @@ async def _rate_limit_error_handler(
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content=error_response.model_dump(),
|
||||
headers={"Retry-After": "60"},
|
||||
headers={"Retry-After": str(int(exc.retry_after_seconds))},
|
||||
)
|
||||
|
||||
|
||||
@@ -752,6 +760,12 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
# This is also re-initialized in the lifespan, but must be present here
|
||||
# for tests that bypass the lifespan via ASGITransport.
|
||||
app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
|
||||
|
||||
# Initialize the global rate limiter (200 requests per 60 seconds per IP).
|
||||
# This is also re-initialized in the lifespan, but must be present here
|
||||
# for tests that bypass the lifespan via ASGITransport.
|
||||
app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
|
||||
|
||||
set_setup_complete_cache(app, False)
|
||||
|
||||
# --- CORS ---
|
||||
@@ -771,15 +785,21 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
# Note: middleware is applied in reverse order of registration.
|
||||
# The setup-redirect must run *after* CSRF, so it is added last.
|
||||
# CSRF middleware protects cookie-authenticated state-mutating requests.
|
||||
# RateLimitMiddleware checks per-IP request limits and must run early.
|
||||
# CorrelationIdMiddleware must run first (added last) so correlation ID
|
||||
# is available to all downstream handlers and loggers.
|
||||
app.add_middleware(CorrelationIdMiddleware)
|
||||
app.add_middleware(SetupRedirectMiddleware)
|
||||
app.add_middleware(CsrfMiddleware)
|
||||
app.add_middleware(
|
||||
RateLimitMiddleware,
|
||||
rate_limiter=app.state.global_rate_limiter,
|
||||
settings=resolved_settings,
|
||||
)
|
||||
|
||||
|
||||
# --- Exception handlers ---
|
||||
#
|
||||
#
|
||||
# Exception handlers are registered from most specific to least specific. FastAPI evaluates
|
||||
# them in registration order, allowing specific handlers to match before fallback handlers.
|
||||
#
|
||||
|
||||
106
backend/app/middleware/rate_limit.py
Normal file
106
backend/app/middleware/rate_limit.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Global rate limiting middleware.
|
||||
|
||||
Implements per-IP request rate limiting for all endpoints using a configurable
|
||||
sliding window algorithm. Intercepts requests before they reach route handlers
|
||||
and blocks those exceeding the per-IP limit with a 429 response.
|
||||
|
||||
Rate limits can be customized per endpoint or use a global default.
|
||||
IP addresses are extracted using the same trusted-proxy-aware logic as
|
||||
authentication to ensure consistent behavior across all rate limiting.
|
||||
|
||||
Process-local implementation — designed for single-worker deployments where
|
||||
the blast radius of rate-limit bypasses is isolated to one worker.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
from app.exceptions import RateLimitError
|
||||
from app.utils.client_ip import get_client_ip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.config import Settings
|
||||
from app.utils.rate_limiter import GlobalRateLimiter
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Enforce global per-IP request rate limiting on all endpoints.
|
||||
|
||||
Tracks requests per IP and blocks further requests if the limit is exceeded.
|
||||
Uses the application's GlobalRateLimiter instance and trusted-proxy settings
|
||||
for consistent IP extraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app: object,
|
||||
rate_limiter: GlobalRateLimiter,
|
||||
settings: Settings,
|
||||
) -> None:
|
||||
"""Initialize the rate limit middleware.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application.
|
||||
rate_limiter: The GlobalRateLimiter instance to use for checking limits.
|
||||
settings: Application settings (used for trusted proxies).
|
||||
"""
|
||||
super().__init__(app) # type: ignore[arg-type]
|
||||
self.rate_limiter: GlobalRateLimiter = rate_limiter
|
||||
self.settings: Settings = settings
|
||||
|
||||
async def dispatch(
|
||||
self,
|
||||
request: Request,
|
||||
call_next: Callable[[Request], Awaitable[Response]],
|
||||
) -> Response:
|
||||
"""Check rate limit before passing request to next middleware/handler.
|
||||
|
||||
If the client IP has exceeded the request limit, returns a 429 response
|
||||
immediately. Otherwise passes the request through normally.
|
||||
|
||||
Args:
|
||||
request: The incoming HTTP request.
|
||||
call_next: Callable to pass the request to the next middleware/handler.
|
||||
|
||||
Returns:
|
||||
A response object (either rate limit response or from handler).
|
||||
"""
|
||||
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
|
||||
|
||||
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
|
||||
if not is_allowed:
|
||||
log.warning(
|
||||
"global_rate_limit_exceeded",
|
||||
client_ip=client_ip,
|
||||
path=request.url.path,
|
||||
method=request.method,
|
||||
retry_after=retry_after,
|
||||
)
|
||||
rate_limit_error = RateLimitError(
|
||||
"Too many requests. Please try again later.",
|
||||
retry_after_seconds=retry_after,
|
||||
)
|
||||
# Return the error response directly
|
||||
return JSONResponse(
|
||||
status_code=429,
|
||||
content={
|
||||
"code": "rate_limit_exceeded",
|
||||
"detail": str(rate_limit_error),
|
||||
"metadata": rate_limit_error.get_error_metadata(),
|
||||
"correlation_id": getattr(request.state, "correlation_id", None),
|
||||
},
|
||||
headers={"Retry-After": str(int(retry_after))},
|
||||
)
|
||||
|
||||
# Request is allowed, continue to next handler
|
||||
response: Response = await call_next(request)
|
||||
return response
|
||||
@@ -84,7 +84,7 @@ async def login(
|
||||
# Check if this IP is currently blocked by exponential backoff
|
||||
if not rate_limiter.is_allowed(client_ip):
|
||||
log.warning("login_rate_limit_exceeded", client_ip=client_ip)
|
||||
raise RateLimitError("Too many login attempts. Please try again later.")
|
||||
raise RateLimitError("Too many login attempts. Please try again later.", retry_after_seconds=60.0)
|
||||
|
||||
try:
|
||||
signed_token, expires_at = await auth_service.login(
|
||||
|
||||
@@ -33,18 +33,29 @@ JOB_ID: str = "rate_limiter_cleanup"
|
||||
def _run_cleanup(app: FastAPI) -> None:
|
||||
"""Trigger cleanup of expired rate-limiter entries.
|
||||
|
||||
Cleans up both the login-specific rate limiter (exponential backoff)
|
||||
and the global request rate limiter.
|
||||
|
||||
Args:
|
||||
app: The FastAPI application instance (holds the rate limiter).
|
||||
app: The FastAPI application instance (holds the rate limiters).
|
||||
"""
|
||||
rate_limiter = getattr(app.state, "login_rate_limiter", None)
|
||||
if rate_limiter is None:
|
||||
login_limiter = getattr(app.state, "login_rate_limiter", None)
|
||||
if login_limiter is None:
|
||||
log.warning(
|
||||
"rate_limiter_cleanup_skipped",
|
||||
reason="rate_limiter not found on app.state",
|
||||
reason="login_rate_limiter not found on app.state",
|
||||
)
|
||||
return
|
||||
else:
|
||||
login_limiter.cleanup_expired()
|
||||
|
||||
rate_limiter.cleanup_expired()
|
||||
global_limiter = getattr(app.state, "global_rate_limiter", None)
|
||||
if global_limiter is None:
|
||||
log.warning(
|
||||
"rate_limiter_cleanup_skipped",
|
||||
reason="global_rate_limiter not found on app.state",
|
||||
)
|
||||
else:
|
||||
global_limiter.cleanup_expired()
|
||||
|
||||
|
||||
def register(app: FastAPI) -> None:
|
||||
|
||||
@@ -206,3 +206,134 @@ class RateLimiter:
|
||||
|
||||
# Record this failure
|
||||
failures.append(now)
|
||||
|
||||
|
||||
class GlobalRateLimiter:
|
||||
"""Global per-IP request rate limiter using sliding window algorithm.
|
||||
|
||||
Tracks total request count within a configurable time window per IP address.
|
||||
Unlike RateLimiter (which uses exponential backoff), this implements simple
|
||||
request counting: when an IP exceeds the limit, the next request is blocked
|
||||
until the oldest request in the window expires.
|
||||
|
||||
Process-local implementation — each worker maintains independent counters.
|
||||
Designed for single-worker deployments where the blast radius is isolated
|
||||
to one worker.
|
||||
|
||||
**How It Works:**
|
||||
|
||||
1. Each request is recorded as a timestamp in the IP's deque.
|
||||
2. Old timestamps outside the window are automatically removed.
|
||||
3. If the number of requests within the window exceeds the limit, the IP is
|
||||
blocked until the oldest request expires.
|
||||
4. A background cleanup task removes dormant IPs from memory periodically.
|
||||
|
||||
**Per-Endpoint Configuration:**
|
||||
|
||||
Different endpoints can have different limits. For example:
|
||||
- Login endpoint: 5 requests per 60 seconds
|
||||
- Dashboard read: 100 requests per 60 seconds
|
||||
- Config write: 20 requests per 60 seconds
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_requests: int = 200,
|
||||
window_seconds: int = 60,
|
||||
) -> None:
|
||||
"""Initialize the global rate limiter.
|
||||
|
||||
Args:
|
||||
max_requests: Maximum requests allowed within the window.
|
||||
window_seconds: Time window (seconds) for the rate limit.
|
||||
"""
|
||||
self.max_requests: int = max_requests
|
||||
self.window_seconds: int = window_seconds
|
||||
self._requests: dict[str, deque[float]] = {}
|
||||
|
||||
def check_allowed(self, ip_address: str) -> tuple[bool, float]:
|
||||
"""Check if a request from *ip_address* is allowed.
|
||||
|
||||
Returns both whether the request is allowed and the seconds to wait
|
||||
if it's not (used for Retry-After header).
|
||||
|
||||
Args:
|
||||
ip_address: The client IP address to rate-limit.
|
||||
|
||||
Returns:
|
||||
A tuple of (is_allowed, retry_after_seconds). If is_allowed is True,
|
||||
retry_after_seconds is 0. If False, it's the estimated time to wait.
|
||||
"""
|
||||
now = time()
|
||||
|
||||
if ip_address not in self._requests:
|
||||
self._requests[ip_address] = deque()
|
||||
|
||||
requests = self._requests[ip_address]
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
# Remove old requests outside the window
|
||||
while requests and requests[0] < cutoff:
|
||||
requests.popleft()
|
||||
|
||||
# If under the limit, allow the request
|
||||
if len(requests) < self.max_requests:
|
||||
requests.append(now)
|
||||
return True, 0.0
|
||||
|
||||
# Over the limit: calculate how long to wait
|
||||
# The oldest request in the window will expire in (window - age) seconds
|
||||
oldest_request = requests[0]
|
||||
age = now - oldest_request
|
||||
retry_after = self.window_seconds - age
|
||||
|
||||
# Ensure retry_after is at least 1 second (avoid 0 values)
|
||||
retry_after = max(retry_after, 1.0)
|
||||
|
||||
return False, retry_after
|
||||
|
||||
def cleanup_expired(self) -> None:
|
||||
"""Remove all IPs with no recent requests (cleanup task).
|
||||
|
||||
Called periodically by the background task to prevent unbounded
|
||||
growth of the tracking dictionary.
|
||||
"""
|
||||
now = time()
|
||||
cutoff = now - self.window_seconds
|
||||
|
||||
ips_to_remove = []
|
||||
for ip_address, requests in self._requests.items():
|
||||
# Remove old requests
|
||||
while requests and requests[0] < cutoff:
|
||||
requests.popleft()
|
||||
# Mark IP for removal if no requests remain
|
||||
if not requests:
|
||||
ips_to_remove.append(ip_address)
|
||||
|
||||
for ip_address in ips_to_remove:
|
||||
del self._requests[ip_address]
|
||||
|
||||
if ips_to_remove:
|
||||
log.debug("global_rate_limiter_cleanup", removed_ips=len(ips_to_remove))
|
||||
|
||||
def get_state(self) -> Mapping[str, int]:
|
||||
"""Return a read-only view of current request counts per IP.
|
||||
|
||||
For debugging and monitoring.
|
||||
|
||||
Returns:
|
||||
A mapping of IP addresses to their request counts.
|
||||
"""
|
||||
now = time()
|
||||
cutoff = now - self.window_seconds
|
||||
result = {}
|
||||
for ip_address, requests in self._requests.items():
|
||||
# Count non-expired requests
|
||||
count = sum(1 for ts in requests if ts >= cutoff)
|
||||
if count > 0:
|
||||
result[ip_address] = count
|
||||
return result
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all tracked requests (for testing)."""
|
||||
self._requests.clear()
|
||||
|
||||
Reference in New Issue
Block a user