Implement periodic cleanup of expired rate-limiter entries to prevent unbounded memory growth during long runtimes. Changes: - Create rate_limiter_cleanup task that calls cleanup_expired() every 30 minutes - Register the task in the startup DAG alongside other background jobs - Update rate_limiter module documentation with operational notes about the cleanup lifecycle and memory management strategy The cleanup is conservative and only removes IPs with no recent attempts (all timestamps outside the rate-limit window), so active IPs are preserved. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
239 lines
8.0 KiB
Python
239 lines
8.0 KiB
Python
"""In-memory rate limiter for IP-based request throttling.
|
|
|
|
Tracks login attempts per IP address and enforces a configurable limit.
|
|
Uses a dictionary of deques (per IP) storing timestamps of recent attempts.
|
|
Old entries are cleaned up by a background task to prevent unbounded growth.
|
|
|
|
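Process-local implementation — in multi-worker setups, each worker has
independent counters, so the effective per-IP limit scales with the number
of workers. Because no state is shared between workers, this constraint also
limits the blast radius of brute-force attacks to a single worker.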
|
|
|
|
The penalty strategy for failed login attempts is also managed here:
|
|
record_failure() records a failure timestamp and returns the penalty delay
|
|
to apply, enabling progressive back-off without exhausting request capacity.
|
|
|
|
Operational Notes
|
|
-----------------
|
|
|
|
**Cleanup Lifecycle**: The rate limiter state (_attempts, _failures, _lock_counts)
|
|
grows as IPs interact with the system. To prevent unbounded memory growth during
|
|
long runtimes, a scheduled background task (rate_limiter_cleanup) calls the
|
|
cleanup_expired() method every 30 minutes. This is safe because:
|
|
|
|
- cleanup_expired() only removes IPs with no recent attempts (all timestamps
|
|
outside the rate-limit window), so active IPs are never disrupted.
|
|
- The cleanup is non-blocking and logged for observability.
|
|
- Individual requests already prune old timestamps from each IP's deque during
|
|
is_allowed() and record_failure(), so cleanup primarily handles dormant IPs.
|
|
|
|
For monitoring, check logs for "rate_limiter_cleanup" events (emitted at debug
level, and only when a cleanup cycle actually removed something) to observe how
many IPs are being retired from memory each cleanup cycle.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections import deque
|
|
from time import time
|
|
from typing import TYPE_CHECKING
|
|
|
|
import structlog
|
|
|
|
from app.utils.constants import (
|
|
LOGIN_PENALTY_BASE_SECONDS,
|
|
LOGIN_PENALTY_MAX_SECONDS,
|
|
LOGIN_PENALTY_MULTIPLIER,
|
|
)
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Mapping
|
|
|
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
|
|
# Default policy: at most 5 attempts per IP within a 60-second (one-minute)
# sliding window.
DEFAULT_RATE_LIMIT_ATTEMPTS = 5
DEFAULT_RATE_LIMIT_WINDOW_SECONDS = 60
|
|
|
|
|
|
class RateLimiter:
    """Track and enforce request rate limits per IP address.

    Three in-memory tables are kept, all keyed by client IP:

    * ``_attempts`` - timestamps of requests admitted by :meth:`is_allowed`.
    * ``_failures`` - timestamps of failed logins recorded by
      :meth:`record_failure`.
    * ``_lock_counts`` - number of penalty slots currently held through
      :meth:`acquire` / :meth:`release`.

    Timestamps older than ``window_seconds`` are pruned lazily whenever an IP
    is touched, and in bulk for every IP by :meth:`cleanup_expired`.
    """

    # Upper bound on penalty slots a single IP may hold at once; shared by
    # acquire() and the penalty cap in record_failure().
    MAX_CONCURRENT_PENALTIES: int = 3

    def __init__(
        self,
        max_attempts: int = DEFAULT_RATE_LIMIT_ATTEMPTS,
        window_seconds: int = DEFAULT_RATE_LIMIT_WINDOW_SECONDS,
    ) -> None:
        """Initialize the rate limiter.

        Args:
            max_attempts: Maximum attempts allowed within the window.
            window_seconds: Length of the sliding window, in seconds.
        """
        self.max_attempts: int = max_attempts
        self.window_seconds: int = window_seconds
        self._attempts: dict[str, deque[float]] = {}
        self._failures: dict[str, deque[float]] = {}
        self._lock_counts: dict[str, int] = {}

    # -----------------------------------------------------------------------
    # Internal helpers
    # -----------------------------------------------------------------------

    @staticmethod
    def _prune(timestamps: deque[float], cutoff: float) -> None:
        """Pop timestamps older than *cutoff* from the left of *timestamps*.

        Timestamps are appended in chronological order, so the oldest entry
        is always at index 0.
        """
        while timestamps and timestamps[0] < cutoff:
            timestamps.popleft()

    @classmethod
    def _drop_dormant(cls, table: dict[str, deque[float]], cutoff: float) -> int:
        """Prune every deque in *table* and delete the keys left empty.

        Returns:
            The number of keys deleted from *table*.
        """
        dormant = []
        for ip_address, timestamps in table.items():
            cls._prune(timestamps, cutoff)
            if not timestamps:
                dormant.append(ip_address)

        # Delete after the loop: a dict must not change size while iterated.
        for ip_address in dormant:
            del table[ip_address]
        return len(dormant)

    # -----------------------------------------------------------------------
    # Request throttling
    # -----------------------------------------------------------------------

    def is_allowed(self, ip_address: str) -> bool:
        """Check if a request from *ip_address* is allowed.

        Expired timestamps for the IP are discarded first. If the IP is still
        under the limit, the current timestamp is recorded as a new attempt;
        a rejected request is not recorded.

        Args:
            ip_address: The client IP address to rate-limit.

        Returns:
            ``True`` if the request is allowed, ``False`` if the limit is
            exceeded.
        """
        now = time()
        attempts = self._attempts.setdefault(ip_address, deque())
        self._prune(attempts, now - self.window_seconds)

        if len(attempts) >= self.max_attempts:
            return False

        attempts.append(now)
        return True

    def cleanup_expired(self) -> None:
        """Release memory held for IPs with no recent activity.

        Intended to be called periodically by a background task. All three
        tracking tables are trimmed:

        * IPs whose attempt timestamps have all left the window are removed
          from ``_attempts``.
        * IPs whose failure timestamps have all left the window are removed
          from ``_failures``.
        * ``_lock_counts`` entries equal to zero are removed. A zero entry
          carries no information (``acquire()`` and ``record_failure()``
          recreate it at 0 on demand), while entries with held slots are kept
          so that a later ``release()`` still finds them.

        A ``rate_limiter_cleanup`` debug event is logged when anything was
        removed.
        """
        cutoff = time() - self.window_seconds

        removed_ips = self._drop_dormant(self._attempts, cutoff)
        removed_failure_ips = self._drop_dormant(self._failures, cutoff)

        idle_locks = [ip for ip, held in self._lock_counts.items() if held == 0]
        for ip_address in idle_locks:
            del self._lock_counts[ip_address]

        if removed_ips or removed_failure_ips or idle_locks:
            log.debug(
                "rate_limiter_cleanup",
                removed_ips=removed_ips,
                removed_failure_ips=removed_failure_ips,
                removed_lock_entries=len(idle_locks),
            )

    def get_state(self) -> Mapping[str, int]:
        """Return a snapshot of current attempt counts per IP.

        For debugging and monitoring. The limiter's own state is not
        modified; expired timestamps are merely skipped while counting.

        Returns:
            A mapping of IP addresses to their number of attempts inside the
            current window. IPs with no such attempts are omitted.
        """
        cutoff = time() - self.window_seconds
        snapshot = {}
        for ip_address, attempts in self._attempts.items():
            live = sum(1 for stamp in attempts if stamp >= cutoff)
            if live:
                snapshot[ip_address] = live
        return snapshot

    def reset(self) -> None:
        """Clear all tracked attempts, failures and lock counts (for testing)."""
        self._attempts.clear()
        self._failures.clear()
        self._lock_counts.clear()

    # -----------------------------------------------------------------------
    # Penalty strategy for failed login attempts
    # -----------------------------------------------------------------------

    def record_failure(self, ip_address: str) -> float:
        """Record a failed login attempt and return the penalty delay in seconds.

        The penalty is ``LOGIN_PENALTY_BASE_SECONDS * LOGIN_PENALTY_MULTIPLIER ** n``
        where ``n`` is the number of earlier failures from this IP that are
        still inside the window, bounded by
        :data:`~app.utils.constants.LOGIN_PENALTY_MAX_SECONDS`. Failures are
        forgotten once they age out of ``window_seconds``, which is what brings
        the penalty back down.

        If the IP already holds ``MAX_CONCURRENT_PENALTIES`` slots (see
        :meth:`acquire`), the penalty is additionally capped at the base delay.

        Args:
            ip_address: The client IP address whose login attempt failed.

        Returns:
            The penalty delay in seconds to apply.
        """
        now = time()
        failures = self._failures.setdefault(ip_address, deque())
        lock_count = self._lock_counts.setdefault(ip_address, 0)

        # Only failures inside the window count towards the back-off.
        self._prune(failures, now - self.window_seconds)
        consecutive = len(failures)

        try:
            penalty = min(
                LOGIN_PENALTY_BASE_SECONDS * (LOGIN_PENALTY_MULTIPLIER ** consecutive),
                LOGIN_PENALTY_MAX_SECONDS,
            )
        except OverflowError:
            # A very large failure count can overflow the exponential before
            # min() gets to clamp it; the clamp is the intended result anyway.
            penalty = LOGIN_PENALTY_MAX_SECONDS

        failures.append(now)

        # Concurrency protection: if too many concurrent sleeps are already
        # running for this IP, cap the penalty to avoid thread exhaustion.
        if lock_count >= self.MAX_CONCURRENT_PENALTIES:
            penalty = min(penalty, LOGIN_PENALTY_BASE_SECONDS)

        return penalty

    def acquire(self, ip_address: str) -> bool:
        """Acquire a concurrency slot for a penalty task.

        Args:
            ip_address: The client IP address.

        Returns:
            ``True`` if the slot was acquired, ``False`` if the IP already
            holds ``MAX_CONCURRENT_PENALTIES`` slots.
        """
        held = self._lock_counts.setdefault(ip_address, 0)
        if held >= self.MAX_CONCURRENT_PENALTIES:
            return False

        self._lock_counts[ip_address] = held + 1
        return True

    def release(self, ip_address: str) -> None:
        """Release a concurrency slot when a penalty task completes.

        Releasing for an IP that holds no slot is a no-op, so the counter can
        never become negative.

        Args:
            ip_address: The client IP address.
        """
        if self._lock_counts.get(ip_address, 0) > 0:
            self._lock_counts[ip_address] -= 1
|