Files
BanGUI/backend/app/utils/rate_limiter.py
Lukas 7ec80fdeec refactor(logging): replace structlog with stdlib logging compat layer
- Remove structlog dependency from backend/pyproject.toml
- Add app.utils.logging_compat shim for keyword-arg logging API
- Add app.utils.json_formatter for JSON log output with extra fields
- Update all backend modules to use logging_compat.get_logger()
- Update docstrings in log_sanitizer.py and json_formatter.py
- Update test comment in test_async_utils.py
- Record 406 failing tests in Docs/Tasks.md for tracking
2026-05-10 13:37:54 +02:00

295 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""In-memory global rate limiter for IP-based request throttling.

Implements a sliding-window request counter per IP address, plus independent
per-bucket counters for endpoints that need their own limits. Entries for IPs
with no recent requests are removed by a periodic cleanup pass to prevent
unbounded memory growth.

Process-local implementation — in multi-worker setups, each worker keeps
independent counters. The effective deployment-wide allowance is therefore up
to N × the configured limit for N workers; a shared store (e.g. Redis) is
required for a true deployment-wide limit.

**Cleanup Lifecycle**: The rate limiter state grows as IPs interact with the
system. To prevent unbounded memory growth during long runtimes, a scheduled
background task (rate_limiter_cleanup, configured outside this module) calls
cleanup_expired() periodically (every 30 minutes per the scheduler setup).
This is safe because:
- cleanup_expired() only removes IPs with no recent requests (all timestamps
  outside the rate-limit window), so active IPs are never disrupted.
- The cleanup is a single synchronous in-memory pass over the tracked
  entries; it performs no I/O.
- Individual requests already prune old timestamps from each IP's deque during
  check_allowed(), so cleanup primarily handles dormant IPs.

For monitoring, check logs for "global_rate_limiter_cleanup" events to observe
how many IPs are being retired from the global per-IP store each cleanup
cycle.
"""

from __future__ import annotations

from collections import deque
from time import time
from typing import TYPE_CHECKING

from app.utils.ip_utils import normalise_ip
from app.utils.logging_compat import get_logger

if TYPE_CHECKING:
    from collections.abc import Mapping

log = get_logger(__name__)
class GlobalRateLimiter:
    """Global per-IP request rate limiter using a sliding-window algorithm.

    Tracks the timestamps of accepted requests per IP address within a
    configurable time window. Once an IP has ``max_requests`` accepted
    requests inside the window, further requests are rejected until the
    oldest accepted request ages out of the window. Rejected requests are not
    recorded, so retrying while blocked does not extend the block.

    **Process-local implementation** — Each worker maintains independent
    counters. In multi-worker deployments (N workers), an attacker can send up
    to N × limit requests before any single worker triggers a block. The
    single-worker scheduler lock provides partial protection, but deployments
    requiring horizontal scaling should replace this with a shared store such
    as Redis.

    **Long-term migration path:** The check_allowed() and
    check_allowed_for_bucket() interfaces do not expose the storage, so a
    Redis-backed adapter could replace the deque-based in-memory store
    without touching any caller code. Note that Redis INCR + EXPIRE gives a
    fixed-window counter; preserving the sliding-window semantics implemented
    here needs a sorted set (ZADD / ZREMRANGEBYSCORE / ZCARD) or similar.

    **How It Works:**
    1. Each accepted request is recorded as a timestamp in the IP's deque.
    2. Timestamps older than the window are pruned on every check.
    3. If the IP already has ``max_requests`` timestamps inside the window,
       the request is rejected and the caller is told how long to wait until
       the oldest one expires (never less than 1 second).
    4. cleanup_expired() removes dormant IPs from memory when it is called.

    **Per-Bucket Configuration:**
    Different endpoints can have different limits via named buckets; the
    limit and window are supplied by the caller on every check, e.g.:
    - `bans:ban` — 100/minute per IP (ban operations)
    - `bans:unban` — 100/minute per IP (unban operations)
    - `blocklist:import` — 10/hour per IP (import operations)
    - `config:update` — 50/minute per IP (config write operations)
    Each bucket tracks its own requests independently, so hitting the
    blocklist:import limit does not affect the bans:ban limit. The longest
    window seen for each bucket is remembered, so cleanup and monitoring
    expire a bucket's timestamps against that bucket's own window.
    """

    # Retention used for a bucket whose window was never recorded (only
    # possible if ``_buckets`` is populated without going through
    # ``_get_bucket_deque``). Deliberately long so cleanup errs on the side of
    # keeping history rather than letting a client bypass a long-window limit.
    _FALLBACK_BUCKET_WINDOW_SECONDS: int = 3600

    def __init__(
        self,
        max_requests: int = 200,
        window_seconds: int = 60,
    ) -> None:
        """Initialize the global rate limiter.

        Args:
            max_requests: Maximum requests allowed within the window.
            window_seconds: Time window (seconds) for the rate limit.
        """
        self.max_requests: int = max_requests
        self.window_seconds: int = window_seconds
        # ip -> timestamps of accepted requests for the global limit.
        self._requests: dict[str, deque[float]] = {}
        # bucket -> ip -> timestamps of accepted requests in that bucket.
        self._buckets: dict[str, dict[str, deque[float]]] = {}
        # bucket -> longest window (seconds) any caller has used for it. This
        # is needed because the window is passed per call rather than stored.
        self._bucket_windows: dict[str, int] = {}

    @staticmethod
    def _evaluate_window(
        requests: deque[float],
        now: float,
        max_requests: int,
        window_seconds: int,
    ) -> tuple[bool, float]:
        """Prune expired timestamps from *requests*, then admit or reject one request.

        Args:
            requests: Timestamps of previously accepted requests, oldest first.
                Mutated in place: expired entries are removed, and *now* is
                appended if the request is admitted.
            now: Current time (seconds since the epoch).
            max_requests: Maximum accepted requests within the window.
            window_seconds: Length of the sliding window in seconds.

        Returns:
            ``(True, 0.0)`` if admitted, otherwise ``(False, retry_after)``
            where ``retry_after`` is at least 1 second.
        """
        cutoff = now - window_seconds
        while requests and requests[0] < cutoff:
            requests.popleft()

        if len(requests) < max_requests:
            requests.append(now)
            return True, 0.0

        if requests:
            # The oldest request in the window frees a slot once it is
            # window_seconds old.
            retry_after = window_seconds - (now - requests[0])
        else:
            # Only reachable when max_requests <= 0: no slot can ever free up,
            # so report a full window instead of indexing an empty deque.
            retry_after = float(window_seconds)
        # Avoid advertising 0 seconds (or less) in a Retry-After header.
        return False, max(retry_after, 1.0)

    @staticmethod
    def _prune_store(store: dict[str, deque[float]], cutoff: float) -> int:
        """Drop timestamps older than *cutoff* and remove IPs left with none.

        Args:
            store: Mapping of IP address to timestamp deque; mutated in place.
            cutoff: Timestamps strictly older than this are discarded.

        Returns:
            The number of IP entries removed from *store*.
        """
        emptied: list[str] = []
        for ip_address, requests in store.items():
            while requests and requests[0] < cutoff:
                requests.popleft()
            if not requests:
                emptied.append(ip_address)
        # Deleting after the loop avoids mutating the dict while iterating it.
        for ip_address in emptied:
            del store[ip_address]
        return len(emptied)

    def _get_bucket_deque(
        self,
        bucket: str,
        ip_address: str,
        max_requests: int,
        window_seconds: int,
    ) -> deque[float]:
        """Get or create the deque for a specific bucket and IP.

        Also records the longest window used for *bucket*, so that
        cleanup_expired() and get_bucket_state() can tell which timestamps
        have expired.

        Args:
            bucket: Bucket name (e.g., "bans:ban").
            ip_address: Client IP address.
            max_requests: Maximum requests for this bucket (currently unused;
                kept for interface stability).
            window_seconds: Window in seconds used by the caller for this bucket.

        Returns:
            The deque of timestamps for this bucket+IP.
        """
        known_window = self._bucket_windows.get(bucket)
        if known_window is None or window_seconds > known_window:
            self._bucket_windows[bucket] = window_seconds

        bucket_dict = self._buckets.get(bucket)
        if bucket_dict is None:
            bucket_dict = self._buckets[bucket] = {}
        requests = bucket_dict.get(ip_address)
        if requests is None:
            requests = bucket_dict[ip_address] = deque()
        return requests

    def check_allowed(self, ip_address: str) -> tuple[bool, float]:
        """Check if a request from *ip_address* is allowed.

        Returns both whether the request is allowed and the seconds to wait
        if it's not (used for the Retry-After header). An allowed request is
        recorded against the IP's global counter.

        Args:
            ip_address: The client IP address to rate-limit.

        Returns:
            A tuple of (is_allowed, retry_after_seconds). If is_allowed is True,
            retry_after_seconds is 0. If False, it's the estimated time to wait
            (at least 1 second).
        """
        ip_address = normalise_ip(ip_address)
        now = time()
        requests = self._requests.get(ip_address)
        if requests is None:
            requests = self._requests[ip_address] = deque()
        return self._evaluate_window(
            requests, now, self.max_requests, self.window_seconds
        )

    def check_allowed_for_bucket(
        self,
        bucket: str,
        ip_address: str,
        max_requests: int,
        window_seconds: int,
    ) -> tuple[bool, float]:
        """Check if a request for a specific bucket is allowed.

        Each bucket has independent rate limiting. This allows different
        endpoints to have different limits (e.g., blocklist import is more
        restrictive than ban operations).

        Args:
            bucket: Bucket name (e.g., "bans:ban", "blocklist:import").
            ip_address: The client IP address to rate-limit.
            max_requests: Maximum requests allowed within the window.
            window_seconds: Time window (seconds) for this bucket.

        Returns:
            A tuple of (is_allowed, retry_after_seconds). If is_allowed is True,
            retry_after_seconds is 0. If False, it's the estimated time to wait
            (at least 1 second).
        """
        now = time()
        ip_address = normalise_ip(ip_address)
        requests = self._get_bucket_deque(
            bucket, ip_address, max_requests, window_seconds
        )
        return self._evaluate_window(requests, now, max_requests, window_seconds)

    def cleanup_expired(self) -> None:
        """Remove all IPs with no recent requests (cleanup task).

        Called periodically by the background task to prevent unbounded
        growth of the tracking dictionaries. The global store is pruned
        against ``self.window_seconds``. Each bucket is pruned against the
        longest window recorded for that bucket, so long-window buckets
        (e.g. hourly) keep their history. Buckets left empty are dropped.
        """
        now = time()

        removed_ips = self._prune_store(self._requests, now - self.window_seconds)
        if removed_ips:
            log.debug("global_rate_limiter_cleanup", removed_ips=removed_ips)

        # Iterate over a snapshot because empty buckets are deleted in the loop.
        for bucket, bucket_dict in list(self._buckets.items()):
            bucket_window = self._bucket_windows.get(
                bucket, self._FALLBACK_BUCKET_WINDOW_SECONDS
            )
            self._prune_store(bucket_dict, now - bucket_window)
            if not bucket_dict:
                del self._buckets[bucket]
                self._bucket_windows.pop(bucket, None)

    def get_state(self) -> Mapping[str, int]:
        """Return a snapshot of current request counts per IP.

        For debugging and monitoring. The returned dict is a fresh copy;
        mutating it does not affect the limiter. Tracked state is not pruned.

        Returns:
            A mapping of IP addresses to their number of non-expired requests
            (IPs with zero are omitted).
        """
        cutoff = time() - self.window_seconds
        result: dict[str, int] = {}
        for ip_address, requests in self._requests.items():
            count = sum(1 for ts in requests if ts >= cutoff)
            if count > 0:
                result[ip_address] = count
        return result

    def get_bucket_state(self, bucket: str) -> Mapping[str, int]:
        """Return a snapshot of current request counts per IP for a bucket.

        For debugging and monitoring. Expiry is measured against the longest
        window recorded for *bucket*.

        Args:
            bucket: Bucket name to get state for.

        Returns:
            A mapping of IP addresses to their number of non-expired requests
            in this bucket (IPs with zero are omitted). Unknown buckets yield
            an empty mapping.
        """
        bucket_dict = self._buckets.get(bucket)
        if not bucket_dict:
            return {}
        bucket_window = self._bucket_windows.get(
            bucket, self._FALLBACK_BUCKET_WINDOW_SECONDS
        )
        cutoff = time() - bucket_window
        result: dict[str, int] = {}
        for ip_address, requests in bucket_dict.items():
            count = sum(1 for ts in requests if ts >= cutoff)
            if count > 0:
                result[ip_address] = count
        return result

    def reset(self) -> None:
        """Clear all tracked requests and bucket windows (for testing)."""
        self._requests.clear()
        self._buckets.clear()
        self._bucket_windows.clear()