"""WebSocket API endpoints for real-time communication. This module provides WebSocket endpoints for clients to connect and receive real-time updates about downloads, queue status, and system events. Series Identifier Convention: - `key`: Primary identifier for series (provider-assigned, URL-safe) e.g., "attack-on-titan" - `folder`: Display metadata only (e.g., "Attack on Titan (2013)") All series-related WebSocket events include `key` as the primary identifier in their data payload. The `folder` field is optional for display purposes. """ from __future__ import annotations import time import uuid from typing import Dict, Optional, Set import structlog from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status from fastapi.responses import JSONResponse from src.server.models.websocket import ( ClientMessage, RoomSubscriptionRequest, WebSocketMessageType, ) from src.server.services.websocket_service import ( WebSocketService, get_websocket_service, ) logger = structlog.get_logger(__name__) router = APIRouter(prefix="/ws", tags=["websocket"]) # Valid room names - explicit allow-list for security VALID_ROOMS: Set[str] = { "downloads", # Download progress updates "queue", # Queue status changes "scan", # Scan progress updates "system", # System notifications "errors", # Error notifications } # Rate limiting configuration for WebSocket messages WS_RATE_LIMIT_MESSAGES_PER_MINUTE = 60 WS_RATE_LIMIT_WINDOW_SECONDS = 60 # In-memory rate limiting for WebSocket connections # WARNING: This resets on process restart. For production, consider Redis. _ws_rate_limits: Dict[str, Dict[str, float]] = {} def _check_ws_rate_limit(connection_id: str) -> bool: """Check if a WebSocket connection has exceeded its rate limit. Args: connection_id: Unique identifier for the WebSocket connection Returns: bool: True if within rate limit, False if exceeded """ now = time.time() if connection_id not in _ws_rate_limits: _ws_rate_limits[connection_id] = { "count": 0, "window_start": now, } record = _ws_rate_limits[connection_id] # Reset window if expired if now - record["window_start"] > WS_RATE_LIMIT_WINDOW_SECONDS: record["window_start"] = now record["count"] = 0 record["count"] += 1 return record["count"] <= WS_RATE_LIMIT_MESSAGES_PER_MINUTE def _cleanup_ws_rate_limits(connection_id: str) -> None: """Remove rate limit record for a disconnected connection. Args: connection_id: Unique identifier for the WebSocket connection """ _ws_rate_limits.pop(connection_id, None) def _validate_room_name(room: str) -> bool: """Validate that a room name is in the allowed set. Args: room: Room name to validate Returns: bool: True if room is valid, False otherwise """ return room in VALID_ROOMS @router.websocket("/connect") async def websocket_endpoint( websocket: WebSocket, token: Optional[str] = None, ws_service: WebSocketService = Depends(get_websocket_service), ): """WebSocket endpoint for client connections. Clients connect to this endpoint to receive real-time updates. The connection is maintained until the client disconnects or an error occurs. Authentication: - Optional token can be passed as query parameter: /ws/connect?token= - Unauthenticated connections are allowed but may have limited access Message flow: 1. Client connects 2. Server sends "connected" message 3. Client can send subscription requests (join/leave rooms) 4. Server broadcasts updates to subscribed rooms 5. Client disconnects Example client subscription: ```json { "action": "join", "room": "downloads" } ``` Server message format (series-related events include 'key' identifier): ```json { "type": "download_progress", "timestamp": "2025-10-17T10:30:00.000Z", "data": { "download_id": "abc123", "key": "attack-on-titan", "folder": "Attack on Titan (2013)", "percent": 45.2, "speed_mbps": 2.5, "eta_seconds": 180 } } ``` Note: - `key` is the primary series identifier (provider-assigned, URL-safe) - `folder` is optional display metadata """ connection_id = str(uuid.uuid4()) user_id: Optional[str] = None # Optional: Validate token if provided if token: try: from src.server.services.auth_service import auth_service session = auth_service.create_session_model(token) user_id = session.user_id except Exception as e: logger.warning( "Invalid WebSocket authentication token", connection_id=connection_id, error=str(e), ) try: # Accept connection and register with service await ws_service.connect(websocket, connection_id, user_id=user_id) # Send connection confirmation await ws_service.manager.send_personal_message( { "type": WebSocketMessageType.CONNECTED, "data": { "connection_id": connection_id, "message": "Connected to Aniworld WebSocket", }, }, connection_id, ) logger.info( "WebSocket client connected", connection_id=connection_id, user_id=user_id, ) # Handle incoming messages while True: try: # Receive message from client data = await websocket.receive_json() # Check rate limit if not _check_ws_rate_limit(connection_id): logger.warning( "WebSocket rate limit exceeded", connection_id=connection_id, ) await ws_service.send_error( connection_id, "Rate limit exceeded. Please slow down.", "RATE_LIMIT_EXCEEDED", ) continue # Parse client message try: client_msg = ClientMessage(**data) except Exception as e: logger.warning( "Invalid client message format", connection_id=connection_id, error=str(e), ) await ws_service.send_error( connection_id, "Invalid message format", "INVALID_MESSAGE", ) continue # Handle room subscription requests if client_msg.action in ["join", "leave"]: try: room_name = client_msg.data.get("room", "") # Validate room name against allow-list if not _validate_room_name(room_name): logger.warning( "Invalid room name requested", connection_id=connection_id, room=room_name, ) await ws_service.send_error( connection_id, f"Invalid room name: {room_name}. " f"Valid rooms: {', '.join(sorted(VALID_ROOMS))}", "INVALID_ROOM", ) continue room_req = RoomSubscriptionRequest( action=client_msg.action, room=room_name, ) if room_req.action == "join": await ws_service.manager.join_room( connection_id, room_req.room ) await ws_service.manager.send_personal_message( { "type": WebSocketMessageType.SYSTEM_INFO, "data": { "message": ( f"Joined room: {room_req.room}" ) }, }, connection_id, ) elif room_req.action == "leave": await ws_service.manager.leave_room( connection_id, room_req.room ) await ws_service.manager.send_personal_message( { "type": WebSocketMessageType.SYSTEM_INFO, "data": { "message": ( f"Left room: {room_req.room}" ) }, }, connection_id, ) except Exception as e: logger.warning( "Invalid room subscription request", connection_id=connection_id, error=str(e), ) await ws_service.send_error( connection_id, "Invalid room subscription", "INVALID_SUBSCRIPTION", ) # Handle ping/pong for keepalive elif client_msg.action == "ping": await ws_service.manager.send_personal_message( {"type": WebSocketMessageType.PONG, "data": {}}, connection_id, ) else: logger.debug( "Unknown action from client", connection_id=connection_id, action=client_msg.action, ) await ws_service.send_error( connection_id, f"Unknown action: {client_msg.action}", "UNKNOWN_ACTION", ) except WebSocketDisconnect: logger.info( "WebSocket client disconnected", connection_id=connection_id, ) break except Exception as e: logger.error( "Error handling WebSocket message", connection_id=connection_id, error=str(e), ) await ws_service.send_error( connection_id, "Internal server error", "SERVER_ERROR", ) except Exception as e: logger.error( "WebSocket connection error", connection_id=connection_id, error=str(e), ) finally: # Cleanup connection and rate limit record _cleanup_ws_rate_limits(connection_id) await ws_service.disconnect(connection_id) logger.info("WebSocket connection closed", connection_id=connection_id) @router.get("/status") async def websocket_status( ws_service: WebSocketService = Depends(get_websocket_service), ): """Get WebSocket service status and statistics. Returns information about active connections and rooms. Useful for monitoring and debugging. """ connection_count = await ws_service.manager.get_connection_count() return JSONResponse( status_code=status.HTTP_200_OK, content={ "status": "operational", "active_connections": connection_count, "supported_message_types": [t.value for t in WebSocketMessageType], "valid_rooms": sorted(VALID_ROOMS), }, )