"""In-memory global rate limiter for IP-based request throttling. Implements a sliding-window request counter per IP address. Old entries are cleaned up by a background task to prevent unbounded growth. Process-local implementation — in multi-worker setups, each worker has independent counters. This constraint limits the blast radius of abuse to a single worker. **Cleanup Lifecycle**: The rate limiter state grows as IPs interact with the system. To prevent unbounded memory growth during long runtimes, a scheduled background task (rate_limiter_cleanup) calls cleanup_expired() every 30 minutes. This is safe because: - cleanup_expired() only removes IPs with no recent requests (all timestamps outside the rate-limit window), so active IPs are never disrupted. - The cleanup is non-blocking and logged for observability. - Individual requests already prune old timestamps from each IP's deque during check_allowed(), so cleanup primarily handles dormant IPs. For monitoring, check logs for "global_rate_limiter_cleanup" events to observe how many IPs are being retired from memory each cleanup cycle. """ from __future__ import annotations from collections import deque from time import time from typing import TYPE_CHECKING from app.utils.logging_compat import get_logger from app.utils.ip_utils import normalise_ip if TYPE_CHECKING: from collections.abc import Mapping log = get_logger(__name__) class GlobalRateLimiter: """Global per-IP request rate limiter using sliding window algorithm. Tracks total request count within a configurable time window per IP address. This implements simple request counting: when an IP exceeds the limit, the next request is blocked until the oldest request in the window expires. **Process-local implementation** — Each worker maintains independent counters. In multi-worker deployments (N workers), an attacker can send up to N × limit requests before any single worker triggers a block. The single-worker scheduler lock provides partial protection, but deployments requiring horizontal scaling should replace this with a Redis-backed store using atomic INCR + EXPIRE. **Long-term migration path:** The check_allowed() and check_allowed_for_bucket() interfaces map directly to Redis INCR + EXPIRE. A drop-in RedisRateLimiter adapter would only need to replace the deque-based in-memory store with Redis calls, without touching any caller code. **How It Works:** 1. Each request is recorded as a timestamp in the IP's deque. 2. Old timestamps outside the window are automatically removed. 3. If the number of requests within the window exceeds the limit, the IP is blocked until the oldest request expires. 4. A background cleanup task removes dormant IPs from memory periodically. **Per-Bucket Configuration:** Different endpoints can have different limits via named buckets: - `bans:ban` — 100/minute per IP (ban operations) - `bans:unban` — 100/minute per IP (unban operations) - `blocklist:import` — 10/hour per IP (import operations) - `config:update` — 50/minute per IP (config write operations) Each bucket tracks its own requests independently, so hitting the blocklist:import limit does not affect the bans:ban limit. """ def __init__( self, max_requests: int = 200, window_seconds: int = 60, ) -> None: """Initialize the global rate limiter. Args: max_requests: Maximum requests allowed within the window. window_seconds: Time window (seconds) for the rate limit. """ self.max_requests: int = max_requests self.window_seconds: int = window_seconds self._requests: dict[str, deque[float]] = {} self._buckets: dict[str, dict[str, deque[float]]] = {} def _get_bucket_deque( self, bucket: str, ip_address: str, max_requests: int, window_seconds: int, ) -> deque[float]: """Get or create the deque for a specific bucket and IP. Args: bucket: Bucket name (e.g., "bans:ban"). ip_address: Client IP address. max_requests: Maximum requests for this bucket (unused, for future). window_seconds: Window in seconds (unused, for future). Returns: The deque of timestamps for this bucket+IP. """ if bucket not in self._buckets: self._buckets[bucket] = {} bucket_dict = self._buckets[bucket] if ip_address not in bucket_dict: bucket_dict[ip_address] = deque() return bucket_dict[ip_address] def check_allowed(self, ip_address: str) -> tuple[bool, float]: """Check if a request from *ip_address* is allowed. Returns both whether the request is allowed and the seconds to wait if it's not (used for Retry-After header). Args: ip_address: The client IP address to rate-limit. Returns: A tuple of (is_allowed, retry_after_seconds). If is_allowed is True, retry_after_seconds is 0. If False, it's the estimated time to wait. """ ip_address = normalise_ip(ip_address) now = time() if ip_address not in self._requests: self._requests[ip_address] = deque() requests = self._requests[ip_address] cutoff = now - self.window_seconds # Remove old requests outside the window while requests and requests[0] < cutoff: requests.popleft() # If under the limit, allow the request if len(requests) < self.max_requests: requests.append(now) return True, 0.0 # Over the limit: calculate how long to wait # The oldest request in the window will expire in (window - age) seconds oldest_request = requests[0] age = now - oldest_request retry_after = self.window_seconds - age # Ensure retry_after is at least 1 second (avoid 0 values) retry_after = max(retry_after, 1.0) return False, retry_after def check_allowed_for_bucket( self, bucket: str, ip_address: str, max_requests: int, window_seconds: int, ) -> tuple[bool, float]: """Check if a request for a specific bucket is allowed. Each bucket has independent rate limiting. This allows different endpoints to have different limits (e.g., blocklist import is more restrictive than ban operations). Args: bucket: Bucket name (e.g., "bans:ban", "blocklist:import"). ip_address: The client IP address to rate-limit. max_requests: Maximum requests allowed within the window. window_seconds: Time window (seconds) for this bucket. Returns: A tuple of (is_allowed, retry_after_seconds). If is_allowed is True, retry_after_seconds is 0. If False, it's the estimated time to wait. """ now = time() ip_address = normalise_ip(ip_address) requests = self._get_bucket_deque(bucket, ip_address, max_requests, window_seconds) cutoff = now - window_seconds # Remove old requests outside the window while requests and requests[0] < cutoff: requests.popleft() # If under the limit, allow the request if len(requests) < max_requests: requests.append(now) return True, 0.0 # Over the limit: calculate how long to wait oldest_request = requests[0] age = now - oldest_request retry_after = window_seconds - age # Ensure retry_after is at least 1 second retry_after = max(retry_after, 1.0) return False, retry_after def cleanup_expired(self) -> None: """Remove all IPs with no recent requests (cleanup task). Called periodically by the background task to prevent unbounded growth of the tracking dictionary. """ now = time() cutoff = now - self.window_seconds ips_to_remove = [] for ip_address, requests in self._requests.items(): # Remove old requests while requests and requests[0] < cutoff: requests.popleft() # Mark IP for removal if no requests remain if not requests: ips_to_remove.append(ip_address) for ip_address in ips_to_remove: del self._requests[ip_address] if ips_to_remove: log.debug("global_rate_limiter_cleanup", removed_ips=len(ips_to_remove)) # Cleanup per-bucket dictionaries for bucket, bucket_dict in list(self._buckets.items()): bucket_ips_to_remove = [] bucket_window = 60 # Use a reasonable window for bucket cleanup bucket_cutoff = now - bucket_window for ip_address, requests in bucket_dict.items(): while requests and requests[0] < bucket_cutoff: requests.popleft() if not requests: bucket_ips_to_remove.append(ip_address) for ip_address in bucket_ips_to_remove: del bucket_dict[ip_address] if not bucket_dict: del self._buckets[bucket] def get_state(self) -> Mapping[str, int]: """Return a read-only view of current request counts per IP. For debugging and monitoring. Returns: A mapping of IP addresses to their request counts. """ now = time() cutoff = now - self.window_seconds result = {} for ip_address, requests in self._requests.items(): # Count non-expired requests count = sum(1 for ts in requests if ts >= cutoff) if count > 0: result[ip_address] = count return result def get_bucket_state(self, bucket: str) -> Mapping[str, int]: """Return a read-only view of current request counts per IP for a bucket. For debugging and monitoring. Args: bucket: Bucket name to get state for. Returns: A mapping of IP addresses to their request counts in this bucket. """ if bucket not in self._buckets: return {} now = time() result = {} for ip_address, requests in self._buckets[bucket].items(): # Count non-expired requests (use max window of 3600s for hourly buckets) cutoff = now - 3600 count = sum(1 for ts in requests if ts >= cutoff) if count > 0: result[ip_address] = count return result def reset(self) -> None: """Clear all tracked requests (for testing).""" self._requests.clear() self._buckets.clear()