"""In-memory rate limiter for IP-based request throttling. Tracks login attempts per IP address and enforces a configurable limit. Uses a dictionary of deques (per IP) storing timestamps of recent attempts. Old entries are cleaned up by a background task to prevent unbounded growth. Process-local implementation — in multi-worker setups, each worker has independent counters. This constraint limits the blast radius of brute-force attacks to a single worker. The penalty strategy for failed login attempts is also managed here: record_failure() records a failure timestamp and returns the penalty delay to apply, enabling progressive back-off without exhausting request capacity. """ from __future__ import annotations from collections import deque from time import time from typing import TYPE_CHECKING import structlog from app.utils.constants import ( LOGIN_PENALTY_BASE_SECONDS, LOGIN_PENALTY_MAX_SECONDS, LOGIN_PENALTY_MULTIPLIER, ) if TYPE_CHECKING: from collections.abc import Mapping log: structlog.stdlib.BoundLogger = structlog.get_logger() # 5 attempts per minute per IP (300 seconds) DEFAULT_RATE_LIMIT_ATTEMPTS = 5 DEFAULT_RATE_LIMIT_WINDOW_SECONDS = 60 class RateLimiter: """Track and enforce request rate limits per IP address. Stores attempt timestamps in per-IP deques, removing old entries outside the rate limit window. """ def __init__( self, max_attempts: int = DEFAULT_RATE_LIMIT_ATTEMPTS, window_seconds: int = DEFAULT_RATE_LIMIT_WINDOW_SECONDS, ) -> None: """Initialize the rate limiter. Args: max_attempts: Maximum attempts allowed within the window. window_seconds: Time window (seconds) for rate limit. """ self.max_attempts: int = max_attempts self.window_seconds: int = window_seconds self._attempts: dict[str, deque[float]] = {} self._failures: dict[str, deque[float]] = {} self._lock_counts: dict[str, int] = {} def is_allowed(self, ip_address: str) -> bool: """Check if a request from *ip_address* is allowed. If allowed, the current timestamp is recorded. Old entries (outside the window) are removed before checking. Args: ip_address: The client IP address to rate-limit. Returns: ``True`` if the request is allowed, ``False`` if the limit is exceeded. """ now = time() cutoff = now - self.window_seconds if ip_address not in self._attempts: self._attempts[ip_address] = deque() attempts = self._attempts[ip_address] # Remove old attempts outside the window while attempts and attempts[0] < cutoff: attempts.popleft() # Check if the limit is exceeded if len(attempts) >= self.max_attempts: return False # Record this attempt attempts.append(now) return True def cleanup_expired(self) -> None: """Remove all IPs with no recent attempts (cleanup task). Called periodically by the background task to prevent unbounded growth of the tracking dictionary. """ now = time() cutoff = now - self.window_seconds ips_to_remove = [] for ip_address, attempts in self._attempts.items(): # Remove old attempts while attempts and attempts[0] < cutoff: attempts.popleft() # Mark IP for removal if no attempts remain if not attempts: ips_to_remove.append(ip_address) for ip_address in ips_to_remove: del self._attempts[ip_address] if ips_to_remove: log.debug("rate_limiter_cleanup", removed_ips=len(ips_to_remove)) def get_state(self) -> Mapping[str, int]: """Return a read-only view of current attempt counts per IP. For debugging and monitoring. Returns: A mapping of IP addresses to their attempt counts. """ now = time() cutoff = now - self.window_seconds result = {} for ip_address, attempts in self._attempts.items(): # Count non-expired attempts count = sum(1 for ts in attempts if ts >= cutoff) if count > 0: result[ip_address] = count return result def reset(self) -> None: """Clear all tracked attempts (for testing).""" self._attempts.clear() self._failures.clear() self._lock_counts.clear() # --------------------------------------------------------------------------- # Penalty strategy for failed login attempts # --------------------------------------------------------------------------- def record_failure(self, ip_address: str) -> float: """Record a failed login attempt and return the penalty delay in seconds. Tracks consecutive failures per IP. Penalty grows exponentially with each failure, bounded by :data:`~app.utils.constants.LOGIN_PENALTY_MAX_SECONDS`, then resets the failure counter. This provides brute-force resistance without exhausting request capacity. A concurrency guard (``_lock_counts``) prevents a single IP from accumulating many concurrent penalty tasks. Args: ip_address: The client IP address whose login attempt failed. Returns: The penalty delay in seconds to apply. """ now = time() if ip_address not in self._failures: self._failures[ip_address] = deque() if ip_address not in self._lock_counts: self._lock_counts[ip_address] = 0 failures = self._failures[ip_address] lock_count = self._lock_counts[ip_address] # Reset if last failure is outside the window cutoff = now - self.window_seconds while failures and failures[0] < cutoff: failures.popleft() consecutive = len(failures) penalty = min( LOGIN_PENALTY_BASE_SECONDS * (LOGIN_PENALTY_MULTIPLIER ** consecutive), LOGIN_PENALTY_MAX_SECONDS, ) failures.append(now) # Concurrency protection: if too many concurrent sleeps are already # running for this IP, cap the penalty to avoid thread exhaustion. if lock_count >= 3: penalty = min(penalty, LOGIN_PENALTY_BASE_SECONDS) return penalty def acquire(self, ip_address: str) -> bool: """Acquire a concurrency slot for a penalty task. Args: ip_address: The client IP address. Returns: ``True`` if the slot was acquired, ``False`` if the IP already has the maximum number of concurrent penalty tasks running. """ if ip_address not in self._lock_counts: self._lock_counts[ip_address] = 0 if self._lock_counts[ip_address] >= 3: return False self._lock_counts[ip_address] += 1 return True def release(self, ip_address: str) -> None: """Release a concurrency slot when a penalty task completes. Args: ip_address: The client IP address. """ if ip_address in self._lock_counts and self._lock_counts[ip_address] > 0: self._lock_counts[ip_address] -= 1