- Add documentation warnings for in-memory rate limiting and failed login attempts
- Consolidate duplicate health endpoints into api/health.py
- Fix CLI to use correct async rescan method names
- Update download.py and anime.py to use custom exception classes
- Add WebSocket room validation and rate limiting
368 lines
13 KiB
Python
368 lines
13 KiB
Python
"""WebSocket API endpoints for real-time communication.
|
|
|
|
This module provides WebSocket endpoints for clients to connect and receive
|
|
real-time updates about downloads, queue status, and system events.
|
|
|
|
Series Identifier Convention:
|
|
- `key`: Primary identifier for series (provider-assigned, URL-safe)
|
|
e.g., "attack-on-titan"
|
|
- `folder`: Display metadata only (e.g., "Attack on Titan (2013)")
|
|
|
|
All series-related WebSocket events include `key` as the primary identifier
|
|
in their data payload. The `folder` field is optional for display purposes.
|
|
"""
|
|
from __future__ import annotations

import json
import time
import uuid
from typing import Dict, Optional, Set

import structlog
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
from fastapi.responses import JSONResponse

from src.server.models.websocket import (
    ClientMessage,
    RoomSubscriptionRequest,
    WebSocketMessageType,
)
from src.server.services.websocket_service import (
    WebSocketService,
    get_websocket_service,
)
|
|
|
|
# Structured logger bound to this module's import path.
logger = structlog.get_logger(__name__)

# Router for every endpoint in this module; routes are mounted under "/ws"
# and grouped under the "websocket" tag.
router = APIRouter(tags=["websocket"], prefix="/ws")
|
|
|
|
# Valid room names - explicit allow-list for security
|
|
VALID_ROOMS: Set[str] = {
|
|
"downloads", # Download progress updates
|
|
"queue", # Queue status changes
|
|
"scan", # Scan progress updates
|
|
"system", # System notifications
|
|
"errors", # Error notifications
|
|
}
|
|
|
|
# Rate limiting configuration for WebSocket messages
|
|
WS_RATE_LIMIT_MESSAGES_PER_MINUTE = 60
|
|
WS_RATE_LIMIT_WINDOW_SECONDS = 60
|
|
|
|
# In-memory rate limiting for WebSocket connections
|
|
# WARNING: This resets on process restart. For production, consider Redis.
|
|
_ws_rate_limits: Dict[str, Dict[str, float]] = {}
|
|
|
|
|
|
def _check_ws_rate_limit(connection_id: str) -> bool:
    """Check if a WebSocket connection has exceeded its rate limit.

    Implements a fixed-window counter. Every call counts one message against
    the connection's current window. The window restarts once more than
    ``WS_RATE_LIMIT_WINDOW_SECONDS`` have elapsed since it began.

    Args:
        connection_id: Unique identifier for the WebSocket connection

    Returns:
        bool: True if within rate limit, False if exceeded
    """
    # Use a monotonic clock for interval measurement. With time.time(), a
    # backwards wall-clock adjustment (NTP, manual change) could stop the
    # window from ever expiring and lock the client out. A forward jump
    # would reset it early.
    now = time.monotonic()

    record = _ws_rate_limits.get(connection_id)
    if record is None:
        # First message from this connection: open a fresh window.
        record = _ws_rate_limits[connection_id] = {
            "count": 0,
            "window_start": now,
        }

    # Reset window if expired
    if now - record["window_start"] > WS_RATE_LIMIT_WINDOW_SECONDS:
        record["window_start"] = now
        record["count"] = 0

    record["count"] += 1

    return record["count"] <= WS_RATE_LIMIT_MESSAGES_PER_MINUTE
|
|
|
|
|
|
def _cleanup_ws_rate_limits(connection_id: str) -> None:
    """Discard any rate-limit bookkeeping held for a connection.

    Meant to be called once a connection has closed. Calling it for an id
    that has no record is a harmless no-op.

    Args:
        connection_id: Unique identifier for the WebSocket connection
    """
    try:
        del _ws_rate_limits[connection_id]
    except KeyError:
        # No record exists. A record is only created when a message is
        # rate-checked, so a connection that never sent one has nothing
        # to remove.
        pass
|
|
|
|
|
|
def _validate_room_name(room: str) -> bool:
|
|
"""Validate that a room name is in the allowed set.
|
|
|
|
Args:
|
|
room: Room name to validate
|
|
|
|
Returns:
|
|
bool: True if room is valid, False otherwise
|
|
"""
|
|
return room in VALID_ROOMS
|
|
|
|
|
|
def _authenticate_ws_token(
    token: Optional[str], connection_id: str
) -> Optional[str]:
    """Resolve the user id for an optional WebSocket authentication token.

    Args:
        token: Value of the ``token`` query parameter, or None.
        connection_id: Connection identifier, used only for logging.

    Returns:
        The ``user_id`` of the session built from the token. Returns None
        when no token was supplied or validation raised; a failure is logged
        and the connection then proceeds unauthenticated.
    """
    if not token:
        return None

    try:
        # Local import; presumably avoids an import cycle at module load.
        # NOTE(review): confirm before hoisting to the top of the file.
        from src.server.services.auth_service import auth_service

        session = auth_service.create_session_model(token)
        return session.user_id
    except Exception as e:
        logger.warning(
            "Invalid WebSocket authentication token",
            connection_id=connection_id,
            error=str(e),
        )
        return None


async def _handle_room_subscription(
    ws_service: WebSocketService,
    connection_id: str,
    client_msg: ClientMessage,
) -> None:
    """Process a ``join``/``leave`` request from a client.

    Steps:
    - Validate the requested room against ``VALID_ROOMS``.
    - Update room membership through the connection manager.
    - Confirm the change to the client with a ``SYSTEM_INFO`` message.

    An invalid room name is answered with an ``INVALID_ROOM`` error. Any
    other failure is logged and answered with an ``INVALID_SUBSCRIPTION``
    error instead of being propagated.
    """
    try:
        room_name = client_msg.data.get("room", "")

        # Validate room name against allow-list
        if not _validate_room_name(room_name):
            logger.warning(
                "Invalid room name requested",
                connection_id=connection_id,
                room=room_name,
            )
            await ws_service.send_error(
                connection_id,
                f"Invalid room name: {room_name}. "
                f"Valid rooms: {', '.join(sorted(VALID_ROOMS))}",
                "INVALID_ROOM",
            )
            return

        room_req = RoomSubscriptionRequest(
            action=client_msg.action,
            room=room_name,
        )

        if room_req.action == "join":
            await ws_service.manager.join_room(connection_id, room_req.room)
            notice = f"Joined room: {room_req.room}"
        elif room_req.action == "leave":
            await ws_service.manager.leave_room(connection_id, room_req.room)
            notice = f"Left room: {room_req.room}"
        else:
            # Neither join nor leave after model validation: nothing to do.
            return

        await ws_service.manager.send_personal_message(
            {
                "type": WebSocketMessageType.SYSTEM_INFO,
                "data": {"message": notice},
            },
            connection_id,
        )

    except Exception as e:
        logger.warning(
            "Invalid room subscription request",
            connection_id=connection_id,
            error=str(e),
        )
        await ws_service.send_error(
            connection_id,
            "Invalid room subscription",
            "INVALID_SUBSCRIPTION",
        )


@router.websocket("/connect")
async def websocket_endpoint(
    websocket: WebSocket,
    token: Optional[str] = None,
    ws_service: WebSocketService = Depends(get_websocket_service),
):
    """WebSocket endpoint for client connections.

    Clients connect to this endpoint to receive real-time updates.
    The connection is maintained until the client disconnects or
    an error occurs.

    Authentication:
        - Optional token can be passed as query parameter: /ws/connect?token=<jwt>
        - Unauthenticated connections are allowed but may have limited access
        - An invalid token is logged and the connection proceeds
          unauthenticated; it is not rejected

    Rate limiting:
        Every text frame received is counted against a per-connection limit
        (``WS_RATE_LIMIT_MESSAGES_PER_MINUTE`` per
        ``WS_RATE_LIMIT_WINDOW_SECONDS``). The check runs *before* the frame
        is parsed, so malformed frames are limited too. A frame over the
        limit is answered with a ``RATE_LIMIT_EXCEEDED`` error and otherwise
        ignored.

    Message flow:
        1. Client connects
        2. Server sends "connected" message
        3. Client can send subscription requests (join/leave rooms)
        4. Server broadcasts updates to subscribed rooms
        5. Client disconnects

    Example client subscription:
        ```json
        {
            "action": "join",
            "room": "downloads"
        }
        ```

    Server message format (series-related events include 'key' identifier):
        ```json
        {
            "type": "download_progress",
            "timestamp": "2025-10-17T10:30:00.000Z",
            "data": {
                "download_id": "abc123",
                "key": "attack-on-titan",
                "folder": "Attack on Titan (2013)",
                "percent": 45.2,
                "speed_mbps": 2.5,
                "eta_seconds": 180
            }
        }
        ```

    Note:
        - `key` is the primary series identifier (provider-assigned, URL-safe)
        - `folder` is optional display metadata
    """
    connection_id = str(uuid.uuid4())

    # Optional: Validate token if provided
    user_id = _authenticate_ws_token(token, connection_id)

    try:
        # Accept connection and register with service
        await ws_service.connect(websocket, connection_id, user_id=user_id)

        # Send connection confirmation
        await ws_service.manager.send_personal_message(
            {
                "type": WebSocketMessageType.CONNECTED,
                "data": {
                    "connection_id": connection_id,
                    "message": "Connected to Aniworld WebSocket",
                },
            },
            connection_id,
        )

        logger.info(
            "WebSocket client connected",
            connection_id=connection_id,
            user_id=user_id,
        )

        # Handle incoming messages
        while True:
            try:
                # Read the raw frame and decode it only after the rate-limit
                # check. With receive_json(), a JSON decode error was raised
                # before the check, so a client could flood the server with
                # malformed frames without ever being limited.
                raw_message = await websocket.receive_text()

                # Check rate limit
                if not _check_ws_rate_limit(connection_id):
                    logger.warning(
                        "WebSocket rate limit exceeded",
                        connection_id=connection_id,
                    )
                    await ws_service.send_error(
                        connection_id,
                        "Rate limit exceeded. Please slow down.",
                        "RATE_LIMIT_EXCEEDED",
                    )
                    continue

                # Parse client message. Invalid JSON is a client error, so it
                # is reported as INVALID_MESSAGE rather than SERVER_ERROR.
                try:
                    data = json.loads(raw_message)
                    client_msg = ClientMessage(**data)
                except Exception as e:
                    logger.warning(
                        "Invalid client message format",
                        connection_id=connection_id,
                        error=str(e),
                    )
                    await ws_service.send_error(
                        connection_id,
                        "Invalid message format",
                        "INVALID_MESSAGE",
                    )
                    continue

                # Handle room subscription requests
                if client_msg.action in ("join", "leave"):
                    await _handle_room_subscription(
                        ws_service, connection_id, client_msg
                    )

                # Handle ping/pong for keepalive
                elif client_msg.action == "ping":
                    await ws_service.manager.send_personal_message(
                        {"type": WebSocketMessageType.PONG, "data": {}},
                        connection_id,
                    )

                else:
                    logger.debug(
                        "Unknown action from client",
                        connection_id=connection_id,
                        action=client_msg.action,
                    )
                    await ws_service.send_error(
                        connection_id,
                        f"Unknown action: {client_msg.action}",
                        "UNKNOWN_ACTION",
                    )

            except WebSocketDisconnect:
                logger.info(
                    "WebSocket client disconnected",
                    connection_id=connection_id,
                )
                break

            except Exception as e:
                logger.error(
                    "Error handling WebSocket message",
                    connection_id=connection_id,
                    error=str(e),
                )
                await ws_service.send_error(
                    connection_id,
                    "Internal server error",
                    "SERVER_ERROR",
                )

    except Exception as e:
        logger.error(
            "WebSocket connection error",
            connection_id=connection_id,
            error=str(e),
        )

    finally:
        # Cleanup connection and rate limit record
        _cleanup_ws_rate_limits(connection_id)
        await ws_service.disconnect(connection_id)
        logger.info("WebSocket connection closed", connection_id=connection_id)
|
|
|
|
|
|
@router.get("/status")
async def websocket_status(
    ws_service: WebSocketService = Depends(get_websocket_service),
):
    """Report WebSocket service status and statistics.

    The response is a JSON object with the following keys:
    - the service status;
    - the number of active connections;
    - the value of every ``WebSocketMessageType``;
    - the allow-listed room names, sorted.

    Useful for monitoring and debugging.
    """
    payload = {
        "status": "operational",
        "active_connections": await ws_service.manager.get_connection_count(),
        "supported_message_types": [
            message_type.value for message_type in WebSocketMessageType
        ],
        "valid_rooms": sorted(VALID_ROOMS),
    }
    return JSONResponse(content=payload, status_code=status.HTTP_200_OK)
|