remove part 3
This commit is contained in:
@@ -1,331 +0,0 @@
|
||||
"""Rate limiting middleware for API endpoints.
|
||||
|
||||
This module provides comprehensive rate limiting with support for:
|
||||
- Endpoint-specific rate limits
|
||||
- IP-based limiting
|
||||
- User-based rate limiting
|
||||
- Bypass mechanisms for authenticated users
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
from fastapi import Request, status
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
|
||||
|
||||
class RateLimitConfig:
|
||||
"""Configuration for rate limiting rules."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
requests_per_minute: int = 60,
|
||||
requests_per_hour: int = 1000,
|
||||
authenticated_multiplier: float = 2.0,
|
||||
):
|
||||
"""Initialize rate limit configuration.
|
||||
|
||||
Args:
|
||||
requests_per_minute: Max requests per minute for
|
||||
unauthenticated users
|
||||
requests_per_hour: Max requests per hour for
|
||||
unauthenticated users
|
||||
authenticated_multiplier: Multiplier for authenticated users
|
||||
"""
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.requests_per_hour = requests_per_hour
|
||||
self.authenticated_multiplier = authenticated_multiplier
|
||||
|
||||
|
||||
class RateLimitStore:
|
||||
"""In-memory store for rate limit tracking."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the rate limit store."""
|
||||
# Store format: {identifier: [(timestamp, count), ...]}
|
||||
self._minute_store: Dict[str, list] = defaultdict(list)
|
||||
self._hour_store: Dict[str, list] = defaultdict(list)
|
||||
|
||||
def check_limit(
|
||||
self,
|
||||
identifier: str,
|
||||
max_per_minute: int,
|
||||
max_per_hour: int,
|
||||
) -> Tuple[bool, Optional[int]]:
|
||||
"""Check if the identifier has exceeded rate limits.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier (IP or user ID)
|
||||
max_per_minute: Maximum requests allowed per minute
|
||||
max_per_hour: Maximum requests allowed per hour
|
||||
|
||||
Returns:
|
||||
Tuple of (allowed, retry_after_seconds)
|
||||
"""
|
||||
current_time = time.time()
|
||||
|
||||
# Clean up old entries
|
||||
self._cleanup_old_entries(identifier, current_time)
|
||||
|
||||
# Check minute limit
|
||||
minute_count = len(self._minute_store[identifier])
|
||||
if minute_count >= max_per_minute:
|
||||
# Calculate retry after time
|
||||
oldest_entry = self._minute_store[identifier][0]
|
||||
retry_after = int(60 - (current_time - oldest_entry))
|
||||
return False, max(retry_after, 1)
|
||||
|
||||
# Check hour limit
|
||||
hour_count = len(self._hour_store[identifier])
|
||||
if hour_count >= max_per_hour:
|
||||
# Calculate retry after time
|
||||
oldest_entry = self._hour_store[identifier][0]
|
||||
retry_after = int(3600 - (current_time - oldest_entry))
|
||||
return False, max(retry_after, 1)
|
||||
|
||||
return True, None
|
||||
|
||||
def record_request(self, identifier: str) -> None:
|
||||
"""Record a request for the identifier.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier (IP or user ID)
|
||||
"""
|
||||
current_time = time.time()
|
||||
self._minute_store[identifier].append(current_time)
|
||||
self._hour_store[identifier].append(current_time)
|
||||
|
||||
def get_remaining_requests(
|
||||
self, identifier: str, max_per_minute: int, max_per_hour: int
|
||||
) -> Tuple[int, int]:
|
||||
"""Get remaining requests for the identifier.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier
|
||||
max_per_minute: Maximum per minute
|
||||
max_per_hour: Maximum per hour
|
||||
|
||||
Returns:
|
||||
Tuple of (remaining_per_minute, remaining_per_hour)
|
||||
"""
|
||||
minute_used = len(self._minute_store.get(identifier, []))
|
||||
hour_used = len(self._hour_store.get(identifier, []))
|
||||
return (
|
||||
max(0, max_per_minute - minute_used),
|
||||
max(0, max_per_hour - hour_used)
|
||||
)
|
||||
|
||||
def _cleanup_old_entries(
|
||||
self, identifier: str, current_time: float
|
||||
) -> None:
|
||||
"""Remove entries older than the time windows.
|
||||
|
||||
Args:
|
||||
identifier: Unique identifier
|
||||
current_time: Current timestamp
|
||||
"""
|
||||
# Remove entries older than 1 minute
|
||||
minute_cutoff = current_time - 60
|
||||
self._minute_store[identifier] = [
|
||||
ts for ts in self._minute_store[identifier] if ts > minute_cutoff
|
||||
]
|
||||
|
||||
# Remove entries older than 1 hour
|
||||
hour_cutoff = current_time - 3600
|
||||
self._hour_store[identifier] = [
|
||||
ts for ts in self._hour_store[identifier] if ts > hour_cutoff
|
||||
]
|
||||
|
||||
# Clean up empty entries
|
||||
if not self._minute_store[identifier]:
|
||||
del self._minute_store[identifier]
|
||||
if not self._hour_store[identifier]:
|
||||
del self._hour_store[identifier]
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for API rate limiting."""
|
||||
|
||||
# Endpoint-specific rate limits (overrides defaults)
|
||||
ENDPOINT_LIMITS: Dict[str, RateLimitConfig] = {
|
||||
"/api/auth/login": RateLimitConfig(
|
||||
requests_per_minute=5,
|
||||
requests_per_hour=20,
|
||||
),
|
||||
"/api/auth/register": RateLimitConfig(
|
||||
requests_per_minute=3,
|
||||
requests_per_hour=10,
|
||||
),
|
||||
"/api/download": RateLimitConfig(
|
||||
requests_per_minute=10,
|
||||
requests_per_hour=100,
|
||||
authenticated_multiplier=3.0,
|
||||
),
|
||||
}
|
||||
|
||||
# Paths that bypass rate limiting
|
||||
BYPASS_PATHS = {
|
||||
"/health",
|
||||
"/health/detailed",
|
||||
"/docs",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/static",
|
||||
"/ws",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app,
|
||||
default_config: Optional[RateLimitConfig] = None,
|
||||
):
|
||||
"""Initialize rate limiting middleware.
|
||||
|
||||
Args:
|
||||
app: FastAPI application
|
||||
default_config: Default rate limit configuration
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.default_config = default_config or RateLimitConfig()
|
||||
self.store = RateLimitStore()
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable):
|
||||
"""Process request and apply rate limiting.
|
||||
|
||||
Args:
|
||||
request: Incoming HTTP request
|
||||
call_next: Next middleware or endpoint handler
|
||||
|
||||
Returns:
|
||||
HTTP response (either rate limit error or normal response)
|
||||
"""
|
||||
# Check if path should bypass rate limiting
|
||||
if self._should_bypass(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
# Get identifier (user ID if authenticated, otherwise IP)
|
||||
identifier = self._get_identifier(request)
|
||||
|
||||
# Get rate limit configuration for this endpoint
|
||||
config = self._get_endpoint_config(request.url.path)
|
||||
|
||||
# Apply authenticated user multiplier if applicable
|
||||
is_authenticated = self._is_authenticated(request)
|
||||
max_per_minute = int(
|
||||
config.requests_per_minute *
|
||||
(config.authenticated_multiplier if is_authenticated else 1.0)
|
||||
)
|
||||
max_per_hour = int(
|
||||
config.requests_per_hour *
|
||||
(config.authenticated_multiplier if is_authenticated else 1.0)
|
||||
)
|
||||
|
||||
# Check rate limit
|
||||
allowed, retry_after = self.store.check_limit(
|
||||
identifier,
|
||||
max_per_minute,
|
||||
max_per_hour,
|
||||
)
|
||||
|
||||
if not allowed:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={"detail": "Rate limit exceeded"},
|
||||
headers={"Retry-After": str(retry_after)},
|
||||
)
|
||||
|
||||
# Record the request
|
||||
self.store.record_request(identifier)
|
||||
|
||||
# Add rate limit headers to response
|
||||
response = await call_next(request)
|
||||
response.headers["X-RateLimit-Limit-Minute"] = str(max_per_minute)
|
||||
response.headers["X-RateLimit-Limit-Hour"] = str(max_per_hour)
|
||||
|
||||
minute_remaining, hour_remaining = self.store.get_remaining_requests(
|
||||
identifier, max_per_minute, max_per_hour
|
||||
)
|
||||
|
||||
response.headers["X-RateLimit-Remaining-Minute"] = str(
|
||||
minute_remaining
|
||||
)
|
||||
response.headers["X-RateLimit-Remaining-Hour"] = str(
|
||||
hour_remaining
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
def _should_bypass(self, path: str) -> bool:
|
||||
"""Check if path should bypass rate limiting.
|
||||
|
||||
Args:
|
||||
path: Request path
|
||||
|
||||
Returns:
|
||||
True if path should bypass rate limiting
|
||||
"""
|
||||
for bypass_path in self.BYPASS_PATHS:
|
||||
if path.startswith(bypass_path):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _get_identifier(self, request: Request) -> str:
|
||||
"""Get unique identifier for rate limiting.
|
||||
|
||||
Args:
|
||||
request: HTTP request
|
||||
|
||||
Returns:
|
||||
Unique identifier (user ID or IP address)
|
||||
"""
|
||||
# Try to get user ID from request state (set by auth middleware)
|
||||
user_id = getattr(request.state, "user_id", None)
|
||||
if user_id:
|
||||
return f"user:{user_id}"
|
||||
|
||||
# Fall back to IP address
|
||||
# Check for X-Forwarded-For header (proxy/load balancer)
|
||||
forwarded_for = request.headers.get("X-Forwarded-For")
|
||||
if forwarded_for:
|
||||
# Take the first IP in the chain
|
||||
client_ip = forwarded_for.split(",")[0].strip()
|
||||
else:
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
|
||||
return f"ip:{client_ip}"
|
||||
|
||||
def _get_endpoint_config(self, path: str) -> RateLimitConfig:
|
||||
"""Get rate limit configuration for endpoint.
|
||||
|
||||
Args:
|
||||
path: Request path
|
||||
|
||||
Returns:
|
||||
Rate limit configuration
|
||||
"""
|
||||
# Check for exact match
|
||||
if path in self.ENDPOINT_LIMITS:
|
||||
return self.ENDPOINT_LIMITS[path]
|
||||
|
||||
# Check for prefix match
|
||||
for endpoint_path, config in self.ENDPOINT_LIMITS.items():
|
||||
if path.startswith(endpoint_path):
|
||||
return config
|
||||
|
||||
return self.default_config
|
||||
|
||||
def _is_authenticated(self, request: Request) -> bool:
|
||||
"""Check if request is from authenticated user.
|
||||
|
||||
Args:
|
||||
request: HTTP request
|
||||
|
||||
Returns:
|
||||
True if user is authenticated
|
||||
"""
|
||||
return (
|
||||
hasattr(request.state, "user_id") and
|
||||
request.state.user_id is not None
|
||||
)
|
||||
Reference in New Issue
Block a user