Fix architecture issues from todolist
- Add documentation warnings for in-memory rate limiting and failed login attempts - Consolidate duplicate health endpoints into api/health.py - Fix CLI to use correct async rescan method names - Update download.py and anime.py to use custom exception classes - Add WebSocket room validation and rate limiting
This commit is contained in:
@@ -13,8 +13,9 @@ in their data payload. The `folder` field is optional for display purposes.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from typing import Dict, Optional, Set
|
||||
|
||||
import structlog
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
|
||||
@@ -34,6 +35,73 @@ logger = structlog.get_logger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["websocket"])
|
||||
|
||||
# Valid room names - explicit allow-list for security
|
||||
VALID_ROOMS: Set[str] = {
|
||||
"downloads", # Download progress updates
|
||||
"queue", # Queue status changes
|
||||
"scan", # Scan progress updates
|
||||
"system", # System notifications
|
||||
"errors", # Error notifications
|
||||
}
|
||||
|
||||
# Rate limiting configuration for WebSocket messages
|
||||
WS_RATE_LIMIT_MESSAGES_PER_MINUTE = 60
|
||||
WS_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
|
||||
# In-memory rate limiting for WebSocket connections
|
||||
# WARNING: This resets on process restart. For production, consider Redis.
|
||||
_ws_rate_limits: Dict[str, Dict[str, float]] = {}
|
||||
|
||||
|
||||
def _check_ws_rate_limit(connection_id: str) -> bool:
|
||||
"""Check if a WebSocket connection has exceeded its rate limit.
|
||||
|
||||
Args:
|
||||
connection_id: Unique identifier for the WebSocket connection
|
||||
|
||||
Returns:
|
||||
bool: True if within rate limit, False if exceeded
|
||||
"""
|
||||
now = time.time()
|
||||
|
||||
if connection_id not in _ws_rate_limits:
|
||||
_ws_rate_limits[connection_id] = {
|
||||
"count": 0,
|
||||
"window_start": now,
|
||||
}
|
||||
|
||||
record = _ws_rate_limits[connection_id]
|
||||
|
||||
# Reset window if expired
|
||||
if now - record["window_start"] > WS_RATE_LIMIT_WINDOW_SECONDS:
|
||||
record["window_start"] = now
|
||||
record["count"] = 0
|
||||
|
||||
record["count"] += 1
|
||||
|
||||
return record["count"] <= WS_RATE_LIMIT_MESSAGES_PER_MINUTE
|
||||
|
||||
|
||||
def _cleanup_ws_rate_limits(connection_id: str) -> None:
|
||||
"""Remove rate limit record for a disconnected connection.
|
||||
|
||||
Args:
|
||||
connection_id: Unique identifier for the WebSocket connection
|
||||
"""
|
||||
_ws_rate_limits.pop(connection_id, None)
|
||||
|
||||
|
||||
def _validate_room_name(room: str) -> bool:
|
||||
"""Validate that a room name is in the allowed set.
|
||||
|
||||
Args:
|
||||
room: Room name to validate
|
||||
|
||||
Returns:
|
||||
bool: True if room is valid, False otherwise
|
||||
"""
|
||||
return room in VALID_ROOMS
|
||||
|
||||
|
||||
@router.websocket("/connect")
|
||||
async def websocket_endpoint(
|
||||
@@ -130,6 +198,19 @@ async def websocket_endpoint(
|
||||
# Receive message from client
|
||||
data = await websocket.receive_json()
|
||||
|
||||
# Check rate limit
|
||||
if not _check_ws_rate_limit(connection_id):
|
||||
logger.warning(
|
||||
"WebSocket rate limit exceeded",
|
||||
connection_id=connection_id,
|
||||
)
|
||||
await ws_service.send_error(
|
||||
connection_id,
|
||||
"Rate limit exceeded. Please slow down.",
|
||||
"RATE_LIMIT_EXCEEDED",
|
||||
)
|
||||
continue
|
||||
|
||||
# Parse client message
|
||||
try:
|
||||
client_msg = ClientMessage(**data)
|
||||
@@ -149,9 +230,26 @@ async def websocket_endpoint(
|
||||
# Handle room subscription requests
|
||||
if client_msg.action in ["join", "leave"]:
|
||||
try:
|
||||
room_name = client_msg.data.get("room", "")
|
||||
|
||||
# Validate room name against allow-list
|
||||
if not _validate_room_name(room_name):
|
||||
logger.warning(
|
||||
"Invalid room name requested",
|
||||
connection_id=connection_id,
|
||||
room=room_name,
|
||||
)
|
||||
await ws_service.send_error(
|
||||
connection_id,
|
||||
f"Invalid room name: {room_name}. "
|
||||
f"Valid rooms: {', '.join(sorted(VALID_ROOMS))}",
|
||||
"INVALID_ROOM",
|
||||
)
|
||||
continue
|
||||
|
||||
room_req = RoomSubscriptionRequest(
|
||||
action=client_msg.action,
|
||||
room=client_msg.data.get("room", ""),
|
||||
room=room_name,
|
||||
)
|
||||
|
||||
if room_req.action == "join":
|
||||
@@ -241,7 +339,8 @@ async def websocket_endpoint(
|
||||
error=str(e),
|
||||
)
|
||||
finally:
|
||||
# Cleanup connection
|
||||
# Cleanup connection and rate limit record
|
||||
_cleanup_ws_rate_limits(connection_id)
|
||||
await ws_service.disconnect(connection_id)
|
||||
logger.info("WebSocket connection closed", connection_id=connection_id)
|
||||
|
||||
@@ -263,5 +362,6 @@ async def websocket_status(
|
||||
"status": "operational",
|
||||
"active_connections": connection_count,
|
||||
"supported_message_types": [t.value for t in WebSocketMessageType],
|
||||
"valid_rooms": sorted(VALID_ROOMS),
|
||||
},
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user