Implement global rate limiter and refactor auth middleware
- Add global rate limiter utility with configurable limits and cleanup
- Move rate limiting logic to middleware for consistent application
- Update auth routes to use new rate limiter
- Add comprehensive tests for rate limiter functionality
- Update documentation with backend development guidelines and tasks

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -206,3 +206,134 @@ class RateLimiter:
|
||||
|
||||
# Record this failure
|
||||
failures.append(now)
|
||||
|
||||
|
||||
class GlobalRateLimiter:
    """Global per-IP request rate limiter using a sliding window algorithm.

    Tracks the number of allowed requests per IP address within a
    configurable time window. Unlike RateLimiter (which uses exponential
    backoff), this implements simple request counting: once an IP has
    ``max_requests`` recorded requests inside the window, further requests
    are rejected until the oldest recorded request expires.

    Process-local implementation — each worker maintains independent
    counters. Designed for single-worker deployments where the blast radius
    is isolated to one worker.

    **How It Works:**

    1. Each allowed request is recorded as a timestamp in the IP's deque.
       Rejected requests are not recorded.
    2. Timestamps older than the window are dropped whenever the IP is
       checked.
    3. If the number of recorded requests within the window has reached the
       limit, the IP is blocked until the oldest request expires.
    4. ``cleanup_expired`` removes dormant IPs from memory; it is meant to
       be invoked periodically by a background task outside this class.

    **Per-Endpoint Configuration:**

    Limits are set per instance, so different endpoints can have different
    limits. For example:

    - Login endpoint: 5 requests per 60 seconds
    - Dashboard read: 100 requests per 60 seconds
    - Config write: 20 requests per 60 seconds
    """

    def __init__(
        self,
        max_requests: int = 200,
        window_seconds: int = 60,
    ) -> None:
        """Initialize the global rate limiter.

        Args:
            max_requests: Maximum requests allowed within the window. A
                value of zero or less rejects every request.
            window_seconds: Time window (seconds) for the rate limit.
        """
        self.max_requests: int = max_requests
        self.window_seconds: int = window_seconds
        # IP address -> timestamps of allowed requests, oldest first.
        self._requests: dict[str, deque[float]] = {}

    def check_allowed(self, ip_address: str) -> tuple[bool, float]:
        """Check if a request from *ip_address* is allowed.

        An allowed request is recorded against the IP as a side effect; a
        rejected request is not recorded.

        Args:
            ip_address: The client IP address to rate-limit.

        Returns:
            A tuple of (is_allowed, retry_after_seconds). If is_allowed is
            True, retry_after_seconds is 0.0. If False, it is the estimated
            time to wait (never less than 1 second), suitable for a
            Retry-After header.
        """
        now = time()

        requests = self._requests.get(ip_address)
        if requests is None:
            requests = self._requests[ip_address] = deque()

        # Drop timestamps that have fallen out of the sliding window.
        cutoff = now - self.window_seconds
        while requests and requests[0] < cutoff:
            requests.popleft()

        # Under the limit: record the request and allow it.
        if len(requests) < self.max_requests:
            requests.append(now)
            return True, 0.0

        if not requests:
            # Only reachable when max_requests <= 0: there is no recorded
            # request whose expiry could be awaited, so report the full
            # window instead of indexing an empty deque.
            return False, max(float(self.window_seconds), 1.0)

        # Over the limit: a slot frees up when the oldest recorded request
        # leaves the window, i.e. in (window - age) seconds.
        retry_after = self.window_seconds - (now - requests[0])

        # Never report less than one second, so a Retry-After value
        # derived from it is not zero.
        return False, max(retry_after, 1.0)

    def cleanup_expired(self) -> None:
        """Remove all IPs with no requests left inside the window.

        Prunes expired timestamps for every tracked IP and deletes the
        entries that end up empty, preventing unbounded growth of the
        tracking dictionary. Logs a debug event when at least one IP was
        removed.
        """
        cutoff = time() - self.window_seconds

        stale_ips = []
        for ip_address, requests in self._requests.items():
            while requests and requests[0] < cutoff:
                requests.popleft()
            if not requests:
                stale_ips.append(ip_address)

        # Delete after the loop: the dict must not change size while it is
        # being iterated.
        for ip_address in stale_ips:
            del self._requests[ip_address]

        if stale_ips:
            log.debug("global_rate_limiter_cleanup", removed_ips=len(stale_ips))

    def get_state(self) -> Mapping[str, int]:
        """Return a snapshot of current request counts per IP.

        For debugging and monitoring. The result is a newly built mapping;
        it does not track later changes, and internal state is not
        modified by this call.

        Returns:
            A mapping of IP addresses to the number of their recorded
            requests still inside the window. IPs with no such requests
            are omitted.
        """
        cutoff = time() - self.window_seconds

        counts = {}
        for ip_address, requests in self._requests.items():
            active = sum(1 for ts in requests if ts >= cutoff)
            if active > 0:
                counts[ip_address] = active
        return counts

    def reset(self) -> None:
        """Clear all tracked requests (for testing)."""
        self._requests.clear()
|
||||
|
||||
Reference in New Issue
Block a user