Implement global rate limiter and refactor auth middleware
- Add global rate limiter utility with configurable limits and cleanup
- Move rate limiting logic to middleware for consistent application
- Update auth routes to use new rate limiter
- Add comprehensive tests for rate limiter functionality
- Update documentation with backend development guidelines and tasks

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
@@ -2224,6 +2224,43 @@ The login endpoint (`POST /api/auth/login`) is protected against brute-force att
 - Dependency: `LoginRateLimiterDep` in `app.dependencies`
 
+
+### Global Rate Limiting
+
+In addition to login-specific rate limiting, all API endpoints are protected by global per-IP rate limiting to prevent resource exhaustion, CPU spikes, and network bandwidth attacks from malicious or misconfigured clients.
+
+**Design:**
+- Uses a `dict[str, deque[float]]` keyed by client IP, storing request timestamps within a time window.
+- Implements a sliding-window algorithm: when an IP exceeds the limit, subsequent requests are blocked until the oldest request timestamp in the window expires.
+- Applied globally via middleware that runs on every request.
+- Respects the same IP extraction logic (trusted proxies) as login rate limiting.
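The design bullets above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the project's `GlobalRateLimiter`: the class name `SlidingWindowSketch` and the explicit `now` argument are hypothetical, chosen so the behaviour is deterministic.

```python
from collections import deque


class SlidingWindowSketch:
    """Illustrative per-IP sliding-window counter (hypothetical name)."""

    def __init__(self, max_requests: int, window_seconds: float) -> None:
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        # One deque of request timestamps per client IP.
        self._requests: dict[str, deque[float]] = {}

    def allow(self, ip: str, now: float) -> bool:
        """Record a request at time *now* and report whether it is allowed."""
        timestamps = self._requests.setdefault(ip, deque())
        cutoff = now - self.window_seconds
        # Drop timestamps that have slid out of the window.
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()
        if len(timestamps) < self.max_requests:
            timestamps.append(now)
            return True
        # Blocked requests are not recorded, so they do not extend the block.
        return False
```

Because only admitted requests are stored, each IP holds at most `max_requests` timestamps at any time.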
+
+**Rate Limit Rules:**
+- **Default limit:** 200 requests per 60 seconds per IP.
+- Blocked requests return **HTTP 429 Too Many Requests** with a `Retry-After` header indicating the estimated seconds until the IP can retry.
+- The `Retry-After` value is dynamically calculated based on when the oldest request in the window will expire.
+- Different endpoints can be configured with different limits by adjusting the global rate limiter settings or using per-endpoint decorators (future enhancement).
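The `Retry-After` calculation described above might look like the sketch below. Both function names are hypothetical. Note one deliberate difference from this commit: the committed code renders the header with `int(...)`, which truncates, whereas this sketch rounds up so a client that waits the advertised time is not rejected again.

```python
import math


def retry_after_seconds(oldest_timestamp: float, now: float, window_seconds: float) -> float:
    """Seconds until the oldest in-window request expires, floored at 1 second."""
    age = now - oldest_timestamp
    return max(window_seconds - age, 1.0)


def retry_after_header(seconds: float) -> str:
    """Render whole seconds, rounding up (an assumption, not the committed behaviour)."""
    return str(math.ceil(seconds))
```

For example, with a 60-second window and an oldest request that is 30 seconds old, the client would be told to wait 30 seconds.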
+
+**IP Extraction (Proxy Safety):**
+- Same as login rate limiting: reads real client IP from `X-Forwarded-For` or `X-Real-IP` headers when the immediate connection is from a trusted proxy.
+- Falls back to direct connection IP when headers cannot be trusted.
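A trusted-proxy check of the kind described above could be sketched as follows. This is not the project's `get_client_ip()`, whose body is not part of this diff: the function name `resolve_client_ip`, its parameters, and the choice to read the leftmost `X-Forwarded-For` entry are all assumptions. The leftmost entry is client-supplied when the proxy appends rather than overwrites, so a production implementation may need to walk the list from the right.

```python
from collections.abc import Mapping, Sequence
from ipaddress import ip_address, ip_network


def resolve_client_ip(
    peer_ip: str,
    headers: Mapping[str, str],
    trusted_proxies: Sequence[str],
) -> str:
    """Return the client IP, honouring forwarding headers only from trusted peers."""
    peer = ip_address(peer_ip)
    peer_is_trusted = any(
        peer in ip_network(net, strict=False) for net in trusted_proxies
    )
    if not peer_is_trusted:
        # Headers from an untrusted peer are ignored entirely.
        return peer_ip

    # Header names are matched case-insensitively.
    lowered = {name.lower(): value for name, value in headers.items()}
    forwarded = lowered.get("x-forwarded-for", "")
    candidate = forwarded.split(",")[0].strip() or lowered.get("x-real-ip", "").strip()
    try:
        return str(ip_address(candidate))
    except ValueError:
        # Missing or malformed header: fall back to the direct connection.
        return peer_ip
```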
+
+**Process-Local Limitation:**
+- The global rate limiter is process-local (in-memory), like the login rate limiter.
+- In single-worker deployments (enforced elsewhere), this is not a constraint.
+- Each worker in a multi-worker setup maintains independent counters, which is acceptable under the single-worker enforcement model.
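To see why independent counters matter, the toy model below counts how many requests from one IP get through within a single window when each worker keeps its own counter. The function name and the round-robin distribution are assumptions for illustration only; real load balancers may distribute requests differently.

```python
from collections import Counter
from itertools import cycle


def allowed_across_workers(total_requests: int, workers: int, max_requests: int) -> int:
    """Count requests admitted for one IP when every worker counts separately."""
    per_worker: Counter[int] = Counter()
    admitted = 0
    worker_ids = cycle(range(workers))  # round-robin dispatch (assumption)
    for _ in range(total_requests):
        worker = next(worker_ids)
        if per_worker[worker] < max_requests:
            per_worker[worker] += 1
            admitted += 1
    return admitted
```

Under this model the effective limit scales with the number of workers, which is why the documentation ties the limiter to the single-worker deployment model.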
+
+**Memory Management:**
+- Old request timestamps outside the rate-limit window are automatically pruned during validation checks.
+- A scheduled background task (`rate_limiter_cleanup` in `app.tasks.rate_limiter_cleanup`) runs every 30 minutes to remove dormant IPs from memory, preventing unbounded growth.
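The cleanup step above amounts to a two-pass prune, sketched here as a standalone function. The name `prune_dormant` and the explicit `now` argument are hypothetical; the committed `cleanup_expired()` method follows the same shape but reads the clock itself.

```python
from collections import deque


def prune_dormant(
    requests: dict[str, deque[float]],
    now: float,
    window_seconds: float,
) -> int:
    """Drop expired timestamps, then delete IPs left with none. Returns IPs removed."""
    cutoff = now - window_seconds
    dormant = []
    for ip, timestamps in requests.items():
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()
        if not timestamps:
            dormant.append(ip)
    # Delete in a second pass: changing a dict's size while iterating over it
    # raises RuntimeError.
    for ip in dormant:
        del requests[ip]
    return len(dormant)
```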
+
+**Implementation:**
+- Rate limiter: `app.utils.rate_limiter.GlobalRateLimiter`
+- Middleware: `app.middleware.rate_limit.RateLimitMiddleware`
+- IP extraction: `app.utils.client_ip.get_client_ip()`
+- Cleanup task: `app.tasks.rate_limiter_cleanup` (registered in `app.startup`)
+- Initialized in: `app.main.create_app()` and the lifespan handler
+
 
 ---
 
 ## 12. Authentication Endpoints
@@ -1,53 +1,3 @@
-## [CRITICAL] Docker containers lack resource limits
-
-**Where found**
-
-- `Docker/docker-compose.yml` — no `deploy.limits` or `deploy.reservations` sections
-
-**Why this is needed**
-
-Without resource limits, single container can consume all host CPU, memory, disk. "Noisy neighbor" scenario where backend memory leak → uses 100% RAM → OOM kill → host unresponsive.
-
-**Goal**
-
-Set hard and soft resource limits for all containers.
-
-**What to do**
-
-1. Add resource limits to `docker-compose.yml`:
-```yaml
-backend:
-  deploy:
-    limits:
-      cpus: '2'
-      memory: 512M
-    reservations:
-      cpus: '1'
-      memory: 256M
-```
-
-2. Document these limits in `Docs/Deployment.md`
-3. For Kubernetes, add equivalent `resources.limits` and `resources.requests`
-
-**Possible traps and issues**
-
-- Limits set too low → OOM kill or throttling
-- Backend may need more memory for large blocklists
-- Test under expected load before finalizing
-- Different environments may need different limits
-
-**Docs changes needed**
-
-- Update `Docker/docker-compose.yml` with `deploy` sections
-- Add section in `Docs/Deployment.md` § Resource Allocation
-
-**Doc references**
-
-- `Docker/docker-compose.yml`
-- `Docs/Deployment.md` (resource allocation)
-
----
-
 ## [CRITICAL] Global rate limiting missing
 
 **Where found**
@@ -39,7 +39,6 @@ See Backend-Development.md for the complete exception contract.
 
 from __future__ import annotations
 
-
 # ---------------------------------------------------------------------------
 # Exception Base Classes (Categories)
 # ---------------------------------------------------------------------------
@@ -107,6 +106,19 @@ class RateLimitError(DomainError):
 
     error_code: str = "rate_limit_exceeded"
 
+    def __init__(self, message: str, retry_after_seconds: float = 60.0) -> None:
+        """Initialize with a message and optional retry-after time.
+
+        Args:
+            message: Description of the rate limit violation.
+            retry_after_seconds: Estimated seconds to wait before retrying (default 60).
+        """
+        self.retry_after_seconds: float = retry_after_seconds
+        super().__init__(message)
+
+    def get_error_metadata(self) -> dict[str, str | int | float | bool | None]:
+        return {"retry_after_seconds": self.retry_after_seconds}
+
 
 # ---------------------------------------------------------------------------
 # Jail-Specific Exceptions
@@ -44,6 +44,7 @@ from app.exceptions import (
 )
 from app.middleware.correlation import CorrelationIdMiddleware
 from app.middleware.csrf import CsrfMiddleware
+from app.middleware.rate_limit import RateLimitMiddleware
 from app.models.response import ErrorResponse
 from app.routers import (
     auth,
@@ -60,7 +61,7 @@ from app.routers import (
     setup,
 )
 from app.startup import startup_shared_resources
-from app.utils.rate_limiter import RateLimiter
+from app.utils.rate_limiter import GlobalRateLimiter, RateLimiter
 from app.utils.runtime_state import ApplicationState, RuntimeState
 from app.utils.scheduler_lock import release_scheduler_lock
 from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
@@ -158,6 +159,10 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
     # each worker has independent counters, limiting the blast radius of attacks.
     app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
 
+    # Initialize the global rate limiter (200 requests per 60 seconds per IP).
+    # Applied to all endpoints via middleware. Process-local implementation.
+    app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
+
     log.info("bangui_started")
 
     try:
@@ -535,6 +540,8 @@ async def _rate_limit_error_handler(
 ) -> JSONResponse:
     """Return a ``429 Too Many Requests`` response for rate limit exceeded errors.
 
+    Uses dynamic Retry-After header based on the actual rate limit configuration.
+
     Args:
         request: The incoming FastAPI request.
         exc: The :class:`~app.exceptions.RateLimitError`.
@@ -547,6 +554,7 @@ async def _rate_limit_error_handler(
         path=request.url.path,
         method=request.method,
         error=str(exc),
+        retry_after_seconds=exc.retry_after_seconds,
     )
     error_response = ErrorResponse(
         code=_get_error_code(exc),
@@ -557,7 +565,7 @@ async def _rate_limit_error_handler(
     return JSONResponse(
         status_code=status.HTTP_429_TOO_MANY_REQUESTS,
         content=error_response.model_dump(),
-        headers={"Retry-After": "60"},
+        headers={"Retry-After": str(int(exc.retry_after_seconds))},
     )
 
 
@@ -752,6 +760,12 @@ def create_app(settings: Settings | None = None) -> FastAPI:
     # This is also re-initialized in the lifespan, but must be present here
     # for tests that bypass the lifespan via ASGITransport.
     app.state.login_rate_limiter = RateLimiter(max_attempts=5, window_seconds=60)
+
+    # Initialize the global rate limiter (200 requests per 60 seconds per IP).
+    # This is also re-initialized in the lifespan, but must be present here
+    # for tests that bypass the lifespan via ASGITransport.
+    app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
+
     set_setup_complete_cache(app, False)
 
     # --- CORS ---
@@ -771,15 +785,21 @@ def create_app(settings: Settings | None = None) -> FastAPI:
     # Note: middleware is applied in reverse order of registration.
     # The setup-redirect must run *after* CSRF, so it is added last.
     # CSRF middleware protects cookie-authenticated state-mutating requests.
+    # RateLimitMiddleware checks per-IP request limits and must run early.
     # CorrelationIdMiddleware must run first (added last) so correlation ID
     # is available to all downstream handlers and loggers.
     app.add_middleware(CorrelationIdMiddleware)
     app.add_middleware(SetupRedirectMiddleware)
     app.add_middleware(CsrfMiddleware)
+    app.add_middleware(
+        RateLimitMiddleware,
+        rate_limiter=app.state.global_rate_limiter,
+        settings=resolved_settings,
+    )
 
 
     # --- Exception handlers ---
     #
     # Exception handlers are registered from most specific to least specific. FastAPI evaluates
     # them in registration order, allowing specific handlers to match before fallback handlers.
     #
106
backend/app/middleware/rate_limit.py
Normal file
@@ -0,0 +1,106 @@
+"""Global rate limiting middleware.
+
+Implements per-IP request rate limiting for all endpoints using a configurable
+sliding window algorithm. Intercepts requests before they reach route handlers
+and blocks those exceeding the per-IP limit with a 429 response.
+
+Rate limits can be customized per endpoint or use a global default.
+IP addresses are extracted using the same trusted-proxy-aware logic as
+authentication to ensure consistent behavior across all rate limiting.
+
+Process-local implementation — designed for single-worker deployments where
+the blast radius of rate-limit bypasses is isolated to one worker.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Awaitable, Callable
+from typing import TYPE_CHECKING
+
+import structlog
+from starlette.middleware.base import BaseHTTPMiddleware
+from starlette.requests import Request
+from starlette.responses import JSONResponse, Response
+
+from app.exceptions import RateLimitError
+from app.utils.client_ip import get_client_ip
+
+if TYPE_CHECKING:
+    from app.config import Settings
+    from app.utils.rate_limiter import GlobalRateLimiter
+
+log: structlog.stdlib.BoundLogger = structlog.get_logger()
+
+
+class RateLimitMiddleware(BaseHTTPMiddleware):
+    """Enforce global per-IP request rate limiting on all endpoints.
+
+    Tracks requests per IP and blocks further requests if the limit is exceeded.
+    Uses the application's GlobalRateLimiter instance and trusted-proxy settings
+    for consistent IP extraction.
+    """
+
+    def __init__(
+        self,
+        app: object,
+        rate_limiter: GlobalRateLimiter,
+        settings: Settings,
+    ) -> None:
+        """Initialize the rate limit middleware.
+
+        Args:
+            app: The FastAPI application.
+            rate_limiter: The GlobalRateLimiter instance to use for checking limits.
+            settings: Application settings (used for trusted proxies).
+        """
+        super().__init__(app)  # type: ignore[arg-type]
+        self.rate_limiter: GlobalRateLimiter = rate_limiter
+        self.settings: Settings = settings
+
+    async def dispatch(
+        self,
+        request: Request,
+        call_next: Callable[[Request], Awaitable[Response]],
+    ) -> Response:
+        """Check rate limit before passing request to next middleware/handler.
+
+        If the client IP has exceeded the request limit, returns a 429 response
+        immediately. Otherwise passes the request through normally.
+
+        Args:
+            request: The incoming HTTP request.
+            call_next: Callable to pass the request to the next middleware/handler.
+
+        Returns:
+            A response object (either rate limit response or from handler).
+        """
+        client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
+
+        is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
+        if not is_allowed:
+            log.warning(
+                "global_rate_limit_exceeded",
+                client_ip=client_ip,
+                path=request.url.path,
+                method=request.method,
+                retry_after=retry_after,
+            )
+            rate_limit_error = RateLimitError(
+                "Too many requests. Please try again later.",
+                retry_after_seconds=retry_after,
+            )
+            # Return the error response directly
+            return JSONResponse(
+                status_code=429,
+                content={
+                    "code": "rate_limit_exceeded",
+                    "detail": str(rate_limit_error),
+                    "metadata": rate_limit_error.get_error_metadata(),
+                    "correlation_id": getattr(request.state, "correlation_id", None),
+                },
+                headers={"Retry-After": str(int(retry_after))},
+            )
+
+        # Request is allowed, continue to next handler
+        response: Response = await call_next(request)
+        return response
@@ -84,7 +84,7 @@ async def login(
     # Check if this IP is currently blocked by exponential backoff
     if not rate_limiter.is_allowed(client_ip):
         log.warning("login_rate_limit_exceeded", client_ip=client_ip)
-        raise RateLimitError("Too many login attempts. Please try again later.")
+        raise RateLimitError("Too many login attempts. Please try again later.", retry_after_seconds=60.0)
 
     try:
         signed_token, expires_at = await auth_service.login(
@@ -33,18 +33,29 @@ JOB_ID: str = "rate_limiter_cleanup"
 def _run_cleanup(app: FastAPI) -> None:
     """Trigger cleanup of expired rate-limiter entries.
 
+    Cleans up both the login-specific rate limiter (exponential backoff)
+    and the global request rate limiter.
+
     Args:
-        app: The FastAPI application instance (holds the rate limiter).
+        app: The FastAPI application instance (holds the rate limiters).
     """
-    rate_limiter = getattr(app.state, "login_rate_limiter", None)
-    if rate_limiter is None:
+    login_limiter = getattr(app.state, "login_rate_limiter", None)
+    if login_limiter is None:
         log.warning(
             "rate_limiter_cleanup_skipped",
-            reason="rate_limiter not found on app.state",
+            reason="login_rate_limiter not found on app.state",
         )
-        return
+    else:
+        login_limiter.cleanup_expired()
 
-    rate_limiter.cleanup_expired()
+    global_limiter = getattr(app.state, "global_rate_limiter", None)
+    if global_limiter is None:
+        log.warning(
+            "rate_limiter_cleanup_skipped",
+            reason="global_rate_limiter not found on app.state",
+        )
+    else:
+        global_limiter.cleanup_expired()
 
 
 def register(app: FastAPI) -> None:
@@ -206,3 +206,134 @@ class RateLimiter:
 
         # Record this failure
         failures.append(now)
+
+
+class GlobalRateLimiter:
+    """Global per-IP request rate limiter using sliding window algorithm.
+
+    Tracks total request count within a configurable time window per IP address.
+    Unlike RateLimiter (which uses exponential backoff), this implements simple
+    request counting: when an IP exceeds the limit, the next request is blocked
+    until the oldest request in the window expires.
+
+    Process-local implementation — each worker maintains independent counters.
+    Designed for single-worker deployments where the blast radius is isolated
+    to one worker.
+
+    **How It Works:**
+
+    1. Each request is recorded as a timestamp in the IP's deque.
+    2. Old timestamps outside the window are automatically removed.
+    3. If the number of requests within the window exceeds the limit, the IP is
+       blocked until the oldest request expires.
+    4. A background cleanup task removes dormant IPs from memory periodically.
+
+    **Per-Endpoint Configuration:**
+
+    Different endpoints can have different limits. For example:
+    - Login endpoint: 5 requests per 60 seconds
+    - Dashboard read: 100 requests per 60 seconds
+    - Config write: 20 requests per 60 seconds
+    """
+
+    def __init__(
+        self,
+        max_requests: int = 200,
+        window_seconds: int = 60,
+    ) -> None:
+        """Initialize the global rate limiter.
+
+        Args:
+            max_requests: Maximum requests allowed within the window.
+            window_seconds: Time window (seconds) for the rate limit.
+        """
+        self.max_requests: int = max_requests
+        self.window_seconds: int = window_seconds
+        self._requests: dict[str, deque[float]] = {}
+
+    def check_allowed(self, ip_address: str) -> tuple[bool, float]:
+        """Check if a request from *ip_address* is allowed.
+
+        Returns both whether the request is allowed and the seconds to wait
+        if it's not (used for Retry-After header).
+
+        Args:
+            ip_address: The client IP address to rate-limit.
+
+        Returns:
+            A tuple of (is_allowed, retry_after_seconds). If is_allowed is True,
+            retry_after_seconds is 0. If False, it's the estimated time to wait.
+        """
+        now = time()
+
+        if ip_address not in self._requests:
+            self._requests[ip_address] = deque()
+
+        requests = self._requests[ip_address]
+        cutoff = now - self.window_seconds
+
+        # Remove old requests outside the window
+        while requests and requests[0] < cutoff:
+            requests.popleft()
+
+        # If under the limit, allow the request
+        if len(requests) < self.max_requests:
+            requests.append(now)
+            return True, 0.0
+
+        # Over the limit: calculate how long to wait
+        # The oldest request in the window will expire in (window - age) seconds
+        oldest_request = requests[0]
+        age = now - oldest_request
+        retry_after = self.window_seconds - age
+
+        # Ensure retry_after is at least 1 second (avoid 0 values)
+        retry_after = max(retry_after, 1.0)
+
+        return False, retry_after
+
+    def cleanup_expired(self) -> None:
+        """Remove all IPs with no recent requests (cleanup task).
+
+        Called periodically by the background task to prevent unbounded
+        growth of the tracking dictionary.
+        """
+        now = time()
+        cutoff = now - self.window_seconds
+
+        ips_to_remove = []
+        for ip_address, requests in self._requests.items():
+            # Remove old requests
+            while requests and requests[0] < cutoff:
+                requests.popleft()
+            # Mark IP for removal if no requests remain
+            if not requests:
+                ips_to_remove.append(ip_address)
+
+        for ip_address in ips_to_remove:
+            del self._requests[ip_address]
+
+        if ips_to_remove:
+            log.debug("global_rate_limiter_cleanup", removed_ips=len(ips_to_remove))
+
+    def get_state(self) -> Mapping[str, int]:
+        """Return a read-only view of current request counts per IP.
+
+        For debugging and monitoring.
+
+        Returns:
+            A mapping of IP addresses to their request counts.
+        """
+        now = time()
+        cutoff = now - self.window_seconds
+        result = {}
+        for ip_address, requests in self._requests.items():
+            # Count non-expired requests
+            count = sum(1 for ts in requests if ts >= cutoff)
+            if count > 0:
+                result[ip_address] = count
+        return result
+
+    def reset(self) -> None:
+        """Clear all tracked requests (for testing)."""
+        self._requests.clear()
183
backend/tests/test_utils/test_global_rate_limiter.py
Normal file
@@ -0,0 +1,183 @@
+"""Tests for the global rate limiter and rate limit middleware."""
+
+from __future__ import annotations
+
+import asyncio
+
+from httpx import AsyncClient
+
+_SETUP_PAYLOAD = {
+    "master_password": "Mysecretpass1!",
+    "database_path": "bangui.db",
+    "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
+    "timezone": "UTC",
+    "session_duration_minutes": 60,
+}
+
+
+async def _do_setup(client: AsyncClient) -> None:
+    """Run the setup wizard so auth endpoints are reachable."""
+    resp = await client.post("/api/setup", json=_SETUP_PAYLOAD)
+    assert resp.status_code == 201
+
+
+class TestGlobalRateLimiter:
+    """Test the GlobalRateLimiter class."""
+
+    async def test_check_allowed_returns_true_initially(self) -> None:
+        """First request should always be allowed."""
+        from app.utils.rate_limiter import GlobalRateLimiter
+
+        limiter = GlobalRateLimiter(max_requests=5, window_seconds=60)
+        is_allowed, retry_after = limiter.check_allowed("192.168.1.1")
+
+        assert is_allowed is True
+        assert retry_after == 0.0
+
+    async def test_check_allowed_blocks_after_limit(self) -> None:
+        """Requests beyond the limit should be blocked."""
+        from app.utils.rate_limiter import GlobalRateLimiter
+
+        limiter = GlobalRateLimiter(max_requests=2, window_seconds=60)
+
+        # First two requests allowed
+        assert limiter.check_allowed("192.168.1.1")[0] is True
+        assert limiter.check_allowed("192.168.1.1")[0] is True
+
+        # Third request blocked
+        is_allowed, retry_after = limiter.check_allowed("192.168.1.1")
+        assert is_allowed is False
+        assert retry_after > 0
+
+    async def test_check_allowed_per_ip_isolation(self) -> None:
+        """Different IPs should have independent limits."""
+        from app.utils.rate_limiter import GlobalRateLimiter
+
+        limiter = GlobalRateLimiter(max_requests=2, window_seconds=60)
+
+        # IP1 hits limit
+        assert limiter.check_allowed("192.168.1.1")[0] is True
+        assert limiter.check_allowed("192.168.1.1")[0] is True
+        assert limiter.check_allowed("192.168.1.1")[0] is False
+
+        # IP2 should still have allowance
+        assert limiter.check_allowed("192.168.1.2")[0] is True
+        assert limiter.check_allowed("192.168.1.2")[0] is True
+        assert limiter.check_allowed("192.168.1.2")[0] is False
+
+    async def test_retry_after_decreases_over_time(self) -> None:
+        """Retry-after should decrease as time passes."""
+        from app.utils.rate_limiter import GlobalRateLimiter
+
+        limiter = GlobalRateLimiter(max_requests=2, window_seconds=10)
+
+        # Hit limit
+        limiter.check_allowed("192.168.1.1")
+        limiter.check_allowed("192.168.1.1")
+        _, retry_after_1 = limiter.check_allowed("192.168.1.1")
+
+        # Wait and check again
+        await asyncio.sleep(2)
+        _, retry_after_2 = limiter.check_allowed("192.168.1.1")
+
+        assert retry_after_2 < retry_after_1
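The timing test above sleeps for two real seconds. One possible way to make such a test instant and deterministic is to inject the clock. The sketch below is only an illustration: `FakeClock` and `ClockedLimiter` are hypothetical stand-ins, and the committed `GlobalRateLimiter` calls `time()` directly and takes no `clock` argument.

```python
from collections import deque
from collections.abc import Callable
from time import time


class FakeClock:
    """Clock that only moves when the test advances it."""

    def __init__(self, start: float = 1000.0) -> None:
        self.now = start

    def __call__(self) -> float:
        return self.now

    def advance(self, seconds: float) -> None:
        self.now += seconds


class ClockedLimiter:
    """Stand-in limiter with an injectable clock (hypothetical, for illustration)."""

    def __init__(
        self,
        max_requests: int,
        window_seconds: float,
        clock: Callable[[], float] = time,
    ) -> None:
        self.max_requests = max_requests
        self.window_seconds = window_seconds
        self._clock = clock
        self._requests: dict[str, deque[float]] = {}

    def check_allowed(self, ip_address: str) -> tuple[bool, float]:
        now = self._clock()
        requests = self._requests.setdefault(ip_address, deque())
        # Same pruning and limit logic as the sliding-window check in this commit.
        while requests and requests[0] < now - self.window_seconds:
            requests.popleft()
        if len(requests) < self.max_requests:
            requests.append(now)
            return True, 0.0
        return False, max(self.window_seconds - (now - requests[0]), 1.0)


def test_retry_after_decreases_without_sleeping() -> None:
    clock = FakeClock()
    limiter = ClockedLimiter(max_requests=2, window_seconds=10, clock=clock)
    limiter.check_allowed("192.168.1.1")
    limiter.check_allowed("192.168.1.1")
    _, retry_after_1 = limiter.check_allowed("192.168.1.1")

    clock.advance(2)  # simulated time, no real waiting
    _, retry_after_2 = limiter.check_allowed("192.168.1.1")

    assert retry_after_2 < retry_after_1
```

The same idea would apply to the cleanup test, which currently sleeps for 1.5 seconds.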
+
+    async def test_get_state(self) -> None:
+        """get_state should return request counts per IP."""
+        from app.utils.rate_limiter import GlobalRateLimiter
+
+        limiter = GlobalRateLimiter(max_requests=5, window_seconds=60)
+
+        limiter.check_allowed("192.168.1.1")
+        limiter.check_allowed("192.168.1.1")
+        limiter.check_allowed("192.168.1.2")
+
+        state = limiter.get_state()
+        assert state["192.168.1.1"] == 2
+        assert state["192.168.1.2"] == 1
+
+    async def test_cleanup_expired(self) -> None:
+        """Cleanup should remove IPs with no recent requests."""
+        from app.utils.rate_limiter import GlobalRateLimiter
+
+        limiter = GlobalRateLimiter(max_requests=5, window_seconds=1)
+
+        limiter.check_allowed("192.168.1.1")
+        state_before = limiter.get_state()
+        assert "192.168.1.1" in state_before
+
+        # Wait for window to expire
+        await asyncio.sleep(1.5)
+
+        limiter.cleanup_expired()
+        state_after = limiter.get_state()
+        assert "192.168.1.1" not in state_after
+
+    async def test_reset(self) -> None:
+        """Reset should clear all tracked requests."""
+        from app.utils.rate_limiter import GlobalRateLimiter
+
+        limiter = GlobalRateLimiter(max_requests=5, window_seconds=60)
+
+        limiter.check_allowed("192.168.1.1")
+        limiter.check_allowed("192.168.1.2")
+
+        limiter.reset()
+        state = limiter.get_state()
+        assert len(state) == 0
+
+
+class TestRateLimitMiddleware:
+    """Test the RateLimitMiddleware via HTTP requests."""
+
+    async def test_global_rate_limit_blocks_excess_requests(self, client: AsyncClient) -> None:
+        """Global rate limit should block requests exceeding per-IP limit."""
+        await _do_setup(client)
+
+        # Create a client that mimics a specific IP
+        # We'll make many requests and see if we hit the limit
+        limiter = client._transport.app.state.global_rate_limiter
+        limiter.reset()
+
+        # Reduce limit temporarily for testing
+        original_max = limiter.max_requests
+        limiter.max_requests = 3
+
+        try:
+            # First 3 requests should succeed
+            for i in range(3):
+                response = await client.get("/api/health")
+                assert response.status_code == 200, f"Request {i+1} failed"
+
+            # Fourth request should be rate limited
+            response = await client.get("/api/health")
+            assert response.status_code == 429
+            assert response.json()["code"] == "rate_limit_exceeded"
+            assert "Retry-After" in response.headers
+        finally:
+            limiter.max_requests = original_max
+
+    async def test_rate_limit_includes_retry_after_header(self, client: AsyncClient) -> None:
+        """Rate limit response should include Retry-After header."""
+        await _do_setup(client)
+
+        limiter = client._transport.app.state.global_rate_limiter
+        limiter.reset()
+
+        original_max = limiter.max_requests
+        limiter.max_requests = 1
+
+        try:
+            # First request succeeds
+            response = await client.get("/api/health")
+            assert response.status_code == 200
+
+            # Second request is rate limited
+            response = await client.get("/api/health")
+            assert response.status_code == 429
+            assert "Retry-After" in response.headers
+            retry_after = int(response.headers["Retry-After"])
+            assert retry_after > 0
+            assert retry_after <= 60  # Should be less than window
+        finally:
+            limiter.max_requests = original_max