This commit is contained in:
2025-10-23 18:28:17 +02:00
parent 9a64ca5b01
commit 3d5c19939c
5 changed files with 277 additions and 96 deletions

View File

@@ -91,6 +91,8 @@ async def init_db() -> None:
db_url,
echo=settings.log_level == "DEBUG",
poolclass=pool.StaticPool if "sqlite" in db_url else pool.QueuePool,
pool_size=5 if "sqlite" not in db_url else None,
max_overflow=10 if "sqlite" not in db_url else None,
pool_pre_ping=True,
future=True,
)

View File

@@ -35,6 +35,15 @@ class AuthMiddleware(BaseHTTPMiddleware):
attempts.
- Rate limit records are periodically cleaned to prevent memory leaks.
"""
# Public endpoints that don't require authentication
PUBLIC_PATHS = {
"/api/auth/", # All auth endpoints
"/api/health", # Health check endpoints
"/api/docs", # API documentation
"/api/redoc", # ReDoc documentation
"/openapi.json", # OpenAPI schema
}
def __init__(
self, app: ASGIApp, *, rate_limit_per_minute: int = 5
@@ -42,6 +51,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
super().__init__(app)
# in-memory rate limiter: ip -> {count, window_start}
self._rate: Dict[str, Dict[str, float]] = {}
# origin-based rate limiter for CORS: origin -> {count, window_start}
self._origin_rate: Dict[str, Dict[str, float]] = {}
self.rate_limit_per_minute = rate_limit_per_minute
self.window_seconds = 60
# Track last cleanup time to prevent memory leaks
@@ -51,7 +62,7 @@ class AuthMiddleware(BaseHTTPMiddleware):
def _cleanup_old_entries(self) -> None:
"""Remove rate limit entries older than cleanup interval.
This prevents memory leaks from accumulating old IP addresses.
This prevents memory leaks from accumulating old IP addresses and origins.
"""
now = time.time()
if now - self._last_cleanup < self._cleanup_interval:
@@ -59,6 +70,8 @@ class AuthMiddleware(BaseHTTPMiddleware):
# Remove entries older than 2x window to be safe
cutoff = now - (self.window_seconds * 2)
# Clean IP-based rate limits
old_ips = [
ip for ip, record in self._rate.items()
if record["window_start"] < cutoff
@@ -66,14 +79,58 @@ class AuthMiddleware(BaseHTTPMiddleware):
for ip in old_ips:
del self._rate[ip]
# Clean origin-based rate limits
old_origins = [
origin for origin, record in self._origin_rate.items()
if record["window_start"] < cutoff
]
for origin in old_origins:
del self._origin_rate[origin]
self._last_cleanup = now
def _is_public_path(self, path: str) -> bool:
"""Check if a path is public and doesn't require authentication.
Args:
path: The request path to check
Returns:
bool: True if the path is public, False otherwise
"""
for public_path in self.PUBLIC_PATHS:
if path.startswith(public_path):
return True
return False
async def dispatch(self, request: Request, call_next: Callable):
path = request.url.path or ""
# Periodically clean up old rate limit entries
self._cleanup_old_entries()
# Apply origin-based rate limiting for CORS requests
origin = request.headers.get("origin")
if origin:
origin_rate_record = self._origin_rate.setdefault(
origin,
{"count": 0, "window_start": time.time()},
)
now = time.time()
if now - origin_rate_record["window_start"] > self.window_seconds:
origin_rate_record["window_start"] = now
origin_rate_record["count"] = 0
origin_rate_record["count"] += 1
# Allow higher rate limit for origins (e.g., 60 req/min)
if origin_rate_record["count"] > self.rate_limit_per_minute * 12:
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={
"detail": "Rate limit exceeded for this origin"
},
)
# Apply rate limiting to auth endpoints that accept credentials
if (
path in ("/api/auth/login", "/api/auth/setup")
@@ -114,19 +171,15 @@ class AuthMiddleware(BaseHTTPMiddleware):
# attach to request.state for downstream usage
request.state.session = session.model_dump()
except AuthError:
# Invalid token: if this is a protected API path, reject.
# For public/auth endpoints let the dependency system handle
# optional auth and return None.
is_api = path.startswith("/api/")
is_auth = path.startswith("/api/auth")
if is_api and not is_auth:
# Invalid token: reject if not a public endpoint
if not self._is_public_path(path):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Invalid token"}
content={"detail": "Invalid or expired token"}
)
else:
# No authorization header: check if this is a protected endpoint
if path.startswith("/api/") and not path.startswith("/api/auth"):
if not self._is_public_path(path):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Missing authorization credentials"}