- Add per-bucket rate limit config (ban, unban, import, config, jail, filter, action) - Add process-local warning at startup for multi-worker deployments - Document Redis migration path for shared state across workers - Remove Issue #42 from Tasks.md (resolved)
468 lines
17 KiB
Python
468 lines
17 KiB
Python
"""In-memory rate limiter for IP-based request throttling.
|
||
|
||
Implements exponential backoff for failed login attempts using failure tracking.
|
||
Each wrong password attempt increments the failure count for that IP, and subsequent
|
||
attempts are blocked for a duration that grows exponentially up to a maximum.
|
||
|
||
Uses a dictionary of deques (per IP) storing timestamps of recent failures.
|
||
Old entries are cleaned up by a background task to prevent unbounded growth.
|
||
|
||
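Process-local implementation — in multi-worker setups, each worker has
independent counters, so backoff triggered on one worker does not apply on
the others. An attacker whose attempts are spread across N workers therefore
gets roughly N times the per-worker failure budget; shared state (e.g. Redis)
is required for a deployment-wide limit.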
|
||
|
||
**How It Works:**
|
||
|
||
1. A successful login resets the failure counter for that IP.
|
||
2. Each failed login (wrong password) calls record_failure() and increments the counter.
|
||
3. is_allowed() checks if enough time has passed since the last failure based on
|
||
the current failure count. The delay grows exponentially with each consecutive failure:
|
||
|
||
- 1st failure: 0.5 second penalty
|
||
- 2nd failure: 1 second penalty (0.5 * 2^1)
|
||
- 3rd failure: 2 seconds penalty (0.5 * 2^2)
|
||
- 4th failure: 4 seconds penalty (0.5 * 2^3)
|
||
- ... up to the configured maximum (default 5 seconds)
|
||
|
||
4. Penalties are not summed: the wait is measured from the most recent failure
and depends only on how many failures are still inside the window. For
example, after 5 failed attempts the attacker waits the capped maximum
(5 seconds by default) once — not 5 seconds per attempt.
|
||
|
||
**Cleanup Lifecycle**: The rate limiter state (_failures) grows as IPs interact
|
||
with the system. To prevent unbounded memory growth during long runtimes, a
|
||
scheduled background task (rate_limiter_cleanup) calls cleanup_expired() every
|
||
30 minutes. This is safe because:
|
||
|
||
- cleanup_expired() only removes IPs with no recent failures (all timestamps
|
||
outside the rate-limit window), so active IPs are never disrupted.
|
||
- The cleanup is a single synchronous pass over the tracked IPs, and it logs how many IPs were removed.
|
||
- Individual requests already prune old timestamps from each IP's deque during
|
||
is_allowed() and record_failure(), so cleanup primarily handles dormant IPs.
|
||
|
||
For monitoring, check logs for "rate_limiter_cleanup" events (emitted at DEBUG
level, and only when at least one IP was removed) to observe how many IPs are
being retired from memory each cleanup cycle.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from collections import deque
|
||
from time import time
|
||
from typing import TYPE_CHECKING
|
||
|
||
import structlog
|
||
|
||
from app.utils.constants import (
|
||
LOGIN_PENALTY_BASE_SECONDS,
|
||
LOGIN_PENALTY_MAX_SECONDS,
|
||
LOGIN_PENALTY_MULTIPLIER,
|
||
)
|
||
from app.utils.ip_utils import normalise_ip
|
||
|
||
if TYPE_CHECKING:
|
||
from collections.abc import Mapping
|
||
|
||
# Module-level structured logger; the cleanup_expired() methods below emit
# debug events through it.
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||
|
||
# Defaults for RateLimiter: a failure keeps counting towards the backoff
# penalty for 60 seconds. DEFAULT_RATE_LIMIT_ATTEMPTS feeds the
# ``max_attempts`` constructor argument, which is stored but not consulted by
# the backoff logic (kept for backward compatibility).
DEFAULT_RATE_LIMIT_ATTEMPTS = 5
DEFAULT_RATE_LIMIT_WINDOW_SECONDS = 60
|
||
|
||
|
||
class RateLimiter:
    """Track failed login attempts per IP address and enforce exponential backoff.

    Failure timestamps are kept in per-IP deques (oldest first); entries older
    than ``window_seconds`` are discarded. The number of failures still inside
    the window decides how long the IP must wait, measured from its most recent
    failure, before :meth:`is_allowed` returns ``True`` again.

    State lives in a plain dict with no locking.
    """

    def __init__(
        self,
        max_attempts: int = DEFAULT_RATE_LIMIT_ATTEMPTS,
        window_seconds: int = DEFAULT_RATE_LIMIT_WINDOW_SECONDS,
    ) -> None:
        """Initialize the rate limiter.

        Args:
            max_attempts: Kept for backward compatibility only; stored on the
                instance but not consulted by the backoff logic.
            window_seconds: How long (seconds) a recorded failure keeps
                counting towards the penalty before it is discarded.
        """
        self.max_attempts: int = max_attempts
        self.window_seconds: int = window_seconds
        # ip -> time() timestamps of in-window failures, oldest first.
        self._failures: dict[str, deque[float]] = {}

    @staticmethod
    def _prune(failures: deque[float], cutoff: float) -> None:
        """Drop timestamps older than *cutoff* from the oldest-first deque."""
        while failures and failures[0] < cutoff:
            failures.popleft()

    @staticmethod
    def _penalty_for(failure_count: int) -> float:
        """Return the backoff delay (seconds) after *failure_count* failures.

        The n-th failure costs ``BASE * MULTIPLIER ** (n - 1)`` seconds, capped
        at ``LOGIN_PENALTY_MAX_SECONDS``, so the first failure costs exactly
        the base penalty (matching the schedule in the module docstring).
        """
        exponent = max(failure_count - 1, 0)
        try:
            raw = LOGIN_PENALTY_BASE_SECONDS * (LOGIN_PENALTY_MULTIPLIER ** exponent)
        except OverflowError:
            # Absurdly large counts overflow float arithmetic; the cap applies.
            return LOGIN_PENALTY_MAX_SECONDS
        return min(raw, LOGIN_PENALTY_MAX_SECONDS)

    def is_allowed(self, ip_address: str) -> bool:
        """Check if a request from *ip_address* is allowed.

        Checks whether the failures the IP has accumulated inside the window
        still impose a backoff penalty. Does NOT record anything — failures are
        recorded by :meth:`record_failure` after a failed password check.

        Args:
            ip_address: The client IP address to rate-limit.

        Returns:
            ``True`` if the request is allowed (no in-window failures, or the
            penalty period has elapsed), ``False`` if currently blocked by
            exponential backoff.
        """
        ip_address = normalise_ip(ip_address)
        now = time()

        failures = self._failures.get(ip_address)
        if failures is None:
            # Unknown IP: allowed, and no entry is created so that lookups
            # alone never grow the tracking dictionary.
            return True

        self._prune(failures, now - self.window_seconds)
        if not failures:
            # Every failure has aged out of the window; forget the IP.
            del self._failures[ip_address]
            return True

        # The wait is measured from the most recent failure, not summed.
        penalty = self._penalty_for(len(failures))
        return now - failures[-1] >= penalty

    def cleanup_expired(self) -> None:
        """Remove all IPs with no recent failures (cleanup task).

        Called periodically by the background task to prevent unbounded
        growth of the tracking dictionary. Logs a ``rate_limiter_cleanup``
        debug event when at least one IP was removed.
        """
        cutoff = time() - self.window_seconds

        stale: list[str] = []
        # Iterate over a snapshot so the dict itself is never mutated mid-walk.
        for ip_address, failures in list(self._failures.items()):
            self._prune(failures, cutoff)
            if not failures:
                stale.append(ip_address)

        for ip_address in stale:
            del self._failures[ip_address]

        if stale:
            log.debug("rate_limiter_cleanup", removed_ips=len(stale))

    def get_state(self) -> Mapping[str, int]:
        """Return a snapshot of in-window failure counts per IP.

        For debugging and monitoring. The result is a fresh dict, so mutating
        it does not affect the limiter; IPs with no in-window failures are
        omitted.

        Returns:
            A mapping of IP addresses to their failure counts.
        """
        cutoff = time() - self.window_seconds
        counts: dict[str, int] = {}
        for ip_address, failures in self._failures.items():
            count = sum(1 for ts in failures if ts >= cutoff)
            if count > 0:
                counts[ip_address] = count
        return counts

    def reset(self) -> None:
        """Clear all tracked failures (for testing)."""
        self._failures.clear()

    # ---------------------------------------------------------------------------
    # Penalty strategy for failed login attempts
    # ---------------------------------------------------------------------------

    def record_failure(self, ip_address: str) -> None:
        """Record a failed login attempt.

        Tracks failures per IP so that :meth:`is_allowed` can compute the
        exponential backoff from the number of in-window failures.

        Args:
            ip_address: The client IP address whose login attempt failed.
        """
        ip_address = normalise_ip(ip_address)
        now = time()

        failures = self._failures.get(ip_address)
        if failures is None:
            failures = self._failures[ip_address] = deque()

        self._prune(failures, now - self.window_seconds)
        failures.append(now)

    def record_success(self, ip_address: str) -> None:
        """Forget every recorded failure for *ip_address*.

        Lets callers reset the backoff for an IP after a successful login.

        Args:
            ip_address: The client IP address whose failures should be cleared.
        """
        self._failures.pop(normalise_ip(ip_address), None)
|
||
|
||
|
||
class GlobalRateLimiter:
    """Global per-IP request rate limiter using a sliding-window log.

    Tracks total request count within a configurable time window per IP address.
    Unlike RateLimiter (which uses exponential backoff), this implements simple
    request counting: when an IP reaches the limit, further requests are blocked
    until the oldest request in the window expires. Blocked requests are not
    recorded, so hammering a blocked endpoint does not extend the block.

    **Process-local implementation** — Each worker maintains independent counters.
    In multi-worker deployments (N workers), an attacker can send up to N × limit
    requests before any single worker triggers a block. Deployments requiring
    horizontal scaling should replace this with a Redis-backed store using atomic
    INCR + EXPIRE.

    **Long-term migration path:** The check_allowed() and check_allowed_for_bucket()
    interfaces map directly to Redis INCR + EXPIRE. A drop-in RedisRateLimiter
    adapter would only need to replace the deque-based in-memory store with Redis
    calls, without touching any caller code.

    **How It Works:**

    1. Each allowed request is recorded as a timestamp in the IP's deque.
    2. Old timestamps outside the window are removed on every check.
    3. If the number of requests within the window has reached the limit, the
       IP is blocked until the oldest request expires.
    4. A background cleanup task removes dormant IPs from memory periodically,
       honouring each bucket's own window (so hourly buckets are not truncated).

    **Per-Bucket Configuration:**

    Different endpoints can have different limits via named buckets; the limit
    and window are supplied by the caller on each check, e.g.:
    - `bans:ban` — 100/minute per IP (ban operations)
    - `bans:unban` — 100/minute per IP (unban operations)
    - `blocklist:import` — 10/hour per IP (import operations)
    - `config:update` — 50/minute per IP (config write operations)

    Each bucket tracks its own requests independently, so hitting the
    blocklist:import limit does not affect the bans:ban limit.
    """

    def __init__(
        self,
        max_requests: int = 200,
        window_seconds: int = 60,
    ) -> None:
        """Initialize the global rate limiter.

        Args:
            max_requests: Maximum requests allowed within the window.
            window_seconds: Time window (seconds) for the rate limit.
        """
        self.max_requests: int = max_requests
        self.window_seconds: int = window_seconds
        self._requests: dict[str, deque[float]] = {}
        self._buckets: dict[str, dict[str, deque[float]]] = {}
        # Largest window each bucket has been checked with. cleanup_expired()
        # and get_bucket_state() use it so long (e.g. hourly) buckets are not
        # pruned or counted with a shorter window.
        self._bucket_windows: dict[str, int] = {}

    @staticmethod
    def _check_window(
        requests: deque[float],
        now: float,
        max_requests: int,
        window_seconds: int,
    ) -> tuple[bool, float]:
        """Apply the sliding-window check to *requests*; record *now* if allowed.

        Returns:
            ``(True, 0.0)`` when allowed, else ``(False, retry_after)`` where
            ``retry_after`` is at least 1 second.
        """
        cutoff = now - window_seconds
        while requests and requests[0] < cutoff:
            requests.popleft()

        if len(requests) < max_requests:
            requests.append(now)
            return True, 0.0

        if not requests:
            # Only reachable with max_requests <= 0: no recorded request will
            # ever expire to free a slot, so report a full window instead of
            # indexing an empty deque.
            return False, max(float(window_seconds), 1.0)

        # The oldest request in the window expires in (window - age) seconds.
        retry_after = window_seconds - (now - requests[0])
        # Avoid 0 / negative values (used for the Retry-After header).
        return False, max(retry_after, 1.0)

    @staticmethod
    def _prune_table(table: dict[str, deque[float]], cutoff: float) -> int:
        """Prune every deque in *table* to *cutoff* and drop emptied entries.

        Returns:
            The number of keys removed from *table*.
        """
        stale = []
        for key, timestamps in table.items():
            while timestamps and timestamps[0] < cutoff:
                timestamps.popleft()
            if not timestamps:
                stale.append(key)
        for key in stale:
            del table[key]
        return len(stale)

    def _get_bucket_deque(
        self,
        bucket: str,
        ip_address: str,
        max_requests: int,
        window_seconds: int,
    ) -> deque[float]:
        """Get or create the deque for a specific bucket and IP.

        Also remembers the largest window seen for *bucket* so that cleanup
        and state reporting use the bucket's real window.

        Args:
            bucket: Bucket name (e.g., "bans:ban").
            ip_address: Client IP address.
            max_requests: Maximum requests for this bucket (currently unused).
            window_seconds: Window in seconds for this bucket.

        Returns:
            The deque of timestamps for this bucket+IP.
        """
        known_window = self._bucket_windows.get(bucket, window_seconds)
        self._bucket_windows[bucket] = max(known_window, window_seconds)

        bucket_dict = self._buckets.get(bucket)
        if bucket_dict is None:
            bucket_dict = self._buckets[bucket] = {}

        requests = bucket_dict.get(ip_address)
        if requests is None:
            requests = bucket_dict[ip_address] = deque()
        return requests

    def check_allowed(self, ip_address: str) -> tuple[bool, float]:
        """Check if a request from *ip_address* is allowed.

        Returns both whether the request is allowed and the seconds to wait
        if it's not (used for Retry-After header). An allowed request is
        recorded.

        Args:
            ip_address: The client IP address to rate-limit.

        Returns:
            A tuple of (is_allowed, retry_after_seconds). If is_allowed is True,
            retry_after_seconds is 0. If False, it's the estimated time to wait.
        """
        ip_address = normalise_ip(ip_address)
        now = time()

        requests = self._requests.get(ip_address)
        if requests is None:
            requests = self._requests[ip_address] = deque()

        return self._check_window(requests, now, self.max_requests, self.window_seconds)

    def check_allowed_for_bucket(
        self,
        bucket: str,
        ip_address: str,
        max_requests: int,
        window_seconds: int,
    ) -> tuple[bool, float]:
        """Check if a request for a specific bucket is allowed.

        Each bucket has independent rate limiting. This allows different
        endpoints to have different limits (e.g., blocklist import is more
        restrictive than ban operations). An allowed request is recorded.

        Args:
            bucket: Bucket name (e.g., "bans:ban", "blocklist:import").
            ip_address: The client IP address to rate-limit.
            max_requests: Maximum requests allowed within the window.
            window_seconds: Time window (seconds) for this bucket.

        Returns:
            A tuple of (is_allowed, retry_after_seconds). If is_allowed is True,
            retry_after_seconds is 0. If False, it's the estimated time to wait.
        """
        now = time()
        ip_address = normalise_ip(ip_address)
        requests = self._get_bucket_deque(bucket, ip_address, max_requests, window_seconds)
        return self._check_window(requests, now, max_requests, window_seconds)

    def cleanup_expired(self) -> None:
        """Remove all IPs with no recent requests (cleanup task).

        Called periodically by the background task to prevent unbounded
        growth of the tracking dictionaries. The global table is pruned with
        ``window_seconds``; each bucket is pruned with its own recorded
        window, so long-window buckets keep their history.
        """
        now = time()

        removed = self._prune_table(self._requests, now - self.window_seconds)
        if removed:
            log.debug("global_rate_limiter_cleanup", removed_ips=removed)

        # Snapshot the items: emptied buckets are deleted while we walk them.
        for bucket, bucket_dict in list(self._buckets.items()):
            window = self._bucket_windows.get(bucket, self.window_seconds)
            self._prune_table(bucket_dict, now - window)
            if not bucket_dict:
                del self._buckets[bucket]
                self._bucket_windows.pop(bucket, None)

    def get_state(self) -> Mapping[str, int]:
        """Return a snapshot of current global request counts per IP.

        For debugging and monitoring. The result is a fresh dict; IPs with no
        in-window requests are omitted.

        Returns:
            A mapping of IP addresses to their request counts.
        """
        cutoff = time() - self.window_seconds
        result: dict[str, int] = {}
        for ip_address, requests in self._requests.items():
            count = sum(1 for ts in requests if ts >= cutoff)
            if count > 0:
                result[ip_address] = count
        return result

    def get_bucket_state(self, bucket: str) -> Mapping[str, int]:
        """Return a snapshot of current request counts per IP for a bucket.

        For debugging and monitoring. Requests are counted against the
        bucket's own recorded window.

        Args:
            bucket: Bucket name to get state for.

        Returns:
            A mapping of IP addresses to their request counts in this bucket.
        """
        bucket_dict = self._buckets.get(bucket)
        if not bucket_dict:
            return {}

        window = self._bucket_windows.get(bucket, self.window_seconds)
        cutoff = time() - window
        result: dict[str, int] = {}
        for ip_address, requests in bucket_dict.items():
            count = sum(1 for ts in requests if ts >= cutoff)
            if count > 0:
                result[ip_address] = count
        return result

    def reset(self) -> None:
        """Clear all tracked requests (for testing)."""
        self._requests.clear()
        self._buckets.clear()
        self._bucket_windows.clear()
|