cleanup
This commit is contained in:
@@ -35,6 +35,15 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
attempts.
|
||||
- Rate limit records are periodically cleaned to prevent memory leaks.
|
||||
"""
|
||||
|
||||
# Public endpoints that don't require authentication
|
||||
PUBLIC_PATHS = {
|
||||
"/api/auth/", # All auth endpoints
|
||||
"/api/health", # Health check endpoints
|
||||
"/api/docs", # API documentation
|
||||
"/api/redoc", # ReDoc documentation
|
||||
"/openapi.json", # OpenAPI schema
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self, app: ASGIApp, *, rate_limit_per_minute: int = 5
|
||||
@@ -42,6 +51,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
super().__init__(app)
|
||||
# in-memory rate limiter: ip -> {count, window_start}
|
||||
self._rate: Dict[str, Dict[str, float]] = {}
|
||||
# origin-based rate limiter for CORS: origin -> {count, window_start}
|
||||
self._origin_rate: Dict[str, Dict[str, float]] = {}
|
||||
self.rate_limit_per_minute = rate_limit_per_minute
|
||||
self.window_seconds = 60
|
||||
# Track last cleanup time to prevent memory leaks
|
||||
@@ -51,7 +62,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
def _cleanup_old_entries(self) -> None:
|
||||
"""Remove rate limit entries older than cleanup interval.
|
||||
|
||||
This prevents memory leaks from accumulating old IP addresses.
|
||||
This prevents memory leaks from accumulating old IP addresses and origins.
|
||||
"""
|
||||
now = time.time()
|
||||
if now - self._last_cleanup < self._cleanup_interval:
|
||||
@@ -59,6 +70,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
|
||||
# Remove entries older than 2x window to be safe
|
||||
cutoff = now - (self.window_seconds * 2)
|
||||
|
||||
# Clean IP-based rate limits
|
||||
old_ips = [
|
||||
ip for ip, record in self._rate.items()
|
||||
if record["window_start"] < cutoff
|
||||
@@ -66,14 +79,58 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
for ip in old_ips:
|
||||
del self._rate[ip]
|
||||
|
||||
# Clean origin-based rate limits
|
||||
old_origins = [
|
||||
origin for origin, record in self._origin_rate.items()
|
||||
if record["window_start"] < cutoff
|
||||
]
|
||||
for origin in old_origins:
|
||||
del self._origin_rate[origin]
|
||||
|
||||
self._last_cleanup = now
|
||||
|
||||
def _is_public_path(self, path: str) -> bool:
|
||||
"""Check if a path is public and doesn't require authentication.
|
||||
|
||||
Args:
|
||||
path: The request path to check
|
||||
|
||||
Returns:
|
||||
bool: True if the path is public, False otherwise
|
||||
"""
|
||||
for public_path in self.PUBLIC_PATHS:
|
||||
if path.startswith(public_path):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable):
|
||||
path = request.url.path or ""
|
||||
|
||||
# Periodically clean up old rate limit entries
|
||||
self._cleanup_old_entries()
|
||||
|
||||
# Apply origin-based rate limiting for CORS requests
|
||||
origin = request.headers.get("origin")
|
||||
if origin:
|
||||
origin_rate_record = self._origin_rate.setdefault(
|
||||
origin,
|
||||
{"count": 0, "window_start": time.time()},
|
||||
)
|
||||
now = time.time()
|
||||
if now - origin_rate_record["window_start"] > self.window_seconds:
|
||||
origin_rate_record["window_start"] = now
|
||||
origin_rate_record["count"] = 0
|
||||
|
||||
origin_rate_record["count"] += 1
|
||||
# Allow higher rate limit for origins (e.g., 60 req/min)
|
||||
if origin_rate_record["count"] > self.rate_limit_per_minute * 12:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={
|
||||
"detail": "Rate limit exceeded for this origin"
|
||||
},
|
||||
)
|
||||
|
||||
# Apply rate limiting to auth endpoints that accept credentials
|
||||
if (
|
||||
path in ("/api/auth/login", "/api/auth/setup")
|
||||
@@ -114,19 +171,15 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
||||
# attach to request.state for downstream usage
|
||||
request.state.session = session.model_dump()
|
||||
except AuthError:
|
||||
# Invalid token: if this is a protected API path, reject.
|
||||
# For public/auth endpoints let the dependency system handle
|
||||
# optional auth and return None.
|
||||
is_api = path.startswith("/api/")
|
||||
is_auth = path.startswith("/api/auth")
|
||||
if is_api and not is_auth:
|
||||
# Invalid token: reject if not a public endpoint
|
||||
if not self._is_public_path(path):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Invalid token"}
|
||||
content={"detail": "Invalid or expired token"}
|
||||
)
|
||||
else:
|
||||
# No authorization header: check if this is a protected endpoint
|
||||
if path.startswith("/api/") and not path.startswith("/api/auth"):
|
||||
if not self._is_public_path(path):
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={"detail": "Missing authorization credentials"}
|
||||
|
||||
Reference in New Issue
Block a user