"""WebSocket service for real-time communication with clients.

This module provides a comprehensive WebSocket manager for handling
real-time updates, connection management, room-based messaging, and
broadcast functionality for the Aniworld web application.

Series Identifier Convention:
    - `key`: Primary identifier for series (provider-assigned, URL-safe)
      e.g., "attack-on-titan"
    - `folder`: Display metadata only (e.g., "Attack on Titan (2013)")

All broadcast methods that handle series-related data should include
`key` as the primary identifier in the message payload. The `folder`
field is optional and used for display purposes only.
"""
from __future__ import annotations

import asyncio
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set

import structlog
from fastapi import WebSocket, WebSocketDisconnect

logger = structlog.get_logger(__name__)


class WebSocketServiceError(Exception):
    """Service-level exception for WebSocket operations."""


class ConnectionManager:
    """Manages WebSocket connections with room-based messaging support.

    Features:
        - Connection lifecycle management
        - Room-based messaging (rooms for specific topics)
        - Broadcast to all connections or specific rooms
        - Connection health monitoring
        - Automatic cleanup on disconnect

    State mutations are serialised with an ``asyncio.Lock``. Send paths do
    not hold the lock while awaiting network I/O; they work on snapshots of
    the registry instead, so a slow client cannot block connect/disconnect.
    """

    def __init__(self) -> None:
        """Initialize the connection manager."""
        # Active connections: connection_id -> WebSocket
        self._active_connections: Dict[str, WebSocket] = {}
        # Room memberships: room_name -> set of connection_ids
        self._rooms: Dict[str, Set[str]] = defaultdict(set)
        # Connection metadata: connection_id -> metadata dict
        self._connection_metadata: Dict[str, Dict[str, Any]] = {}
        # Lock guarding mutations of the three registries above
        self._lock = asyncio.Lock()

        logger.info("ConnectionManager initialized")

    def _forget_connection_locked(self, connection_id: str) -> None:
        """Drop every trace of ``connection_id`` from the registries.

        The caller must already hold ``self._lock``. This helper never
        awaits, so the cleanup is atomic with respect to other coroutines.

        Args:
            connection_id: The connection whose bookkeeping is removed
        """
        for members in self._rooms.values():
            members.discard(connection_id)

        # Collect first, delete afterwards: deleting while iterating the
        # dict itself would raise RuntimeError.
        empty_rooms = [
            name for name, members in self._rooms.items() if not members
        ]
        for name in empty_rooms:
            del self._rooms[name]

        self._active_connections.pop(connection_id, None)
        self._connection_metadata.pop(connection_id, None)

    async def connect(
        self,
        websocket: WebSocket,
        connection_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Accept and register a new WebSocket connection.

        If ``connection_id`` is already registered, the previous socket is
        closed (best effort) and its room memberships and metadata are
        discarded before the new socket is stored.

        Args:
            websocket: The WebSocket connection to accept
            connection_id: Unique identifier for this connection
            metadata: Optional metadata to associate with the connection.
                A shallow copy is stored, so later metadata updates do not
                mutate the caller's dictionary.
        """
        await websocket.accept()

        async with self._lock:
            # If a connection with the same ID already exists, remove it to
            # prevent stale references during repeated test setups.
            if connection_id in self._active_connections:
                try:
                    await self._active_connections[connection_id].close()
                except Exception as exc:
                    # Best effort: the old socket may already be closed or
                    # may be a test mock. Record the failure and carry on.
                    logger.debug(
                        "Error closing replaced WebSocket connection",
                        connection_id=connection_id,
                        error=str(exc),
                    )
                self._forget_connection_locked(connection_id)

            self._active_connections[connection_id] = websocket
            self._connection_metadata[connection_id] = (
                dict(metadata) if metadata else {}
            )

        logger.info(
            "WebSocket connected",
            connection_id=connection_id,
            total_connections=len(self._active_connections),
        )

    async def disconnect(self, connection_id: str) -> None:
        """Remove a WebSocket connection and cleanup associated resources.

        Unknown connection IDs are ignored; the socket itself is not closed
        here, only the bookkeeping is removed.

        Args:
            connection_id: The connection to remove
        """
        async with self._lock:
            self._forget_connection_locked(connection_id)

        logger.info(
            "WebSocket disconnected",
            connection_id=connection_id,
            total_connections=len(self._active_connections),
        )

    async def join_room(self, connection_id: str, room: str) -> None:
        """Add a connection to a room.

        Only active connections can join; otherwise a warning is logged
        and no room is created.

        Args:
            connection_id: The connection to add
            room: The room name to join
        """
        async with self._lock:
            if connection_id in self._active_connections:
                self._rooms[room].add(connection_id)
                logger.debug(
                    "Connection joined room",
                    connection_id=connection_id,
                    room=room,
                    room_size=len(self._rooms[room]),
                )
            else:
                logger.warning(
                    "Attempted to join room with inactive connection",
                    connection_id=connection_id,
                    room=room,
                )

    async def leave_room(self, connection_id: str, room: str) -> None:
        """Remove a connection from a room.

        Rooms that become empty are deleted. Leaving an unknown room is a
        no-op.

        Args:
            connection_id: The connection to remove
            room: The room name to leave
        """
        async with self._lock:
            # Membership test (not indexing) so the defaultdict does not
            # create an empty room as a side effect.
            if room in self._rooms:
                self._rooms[room].discard(connection_id)

                # Remove empty room
                if not self._rooms[room]:
                    del self._rooms[room]

                logger.debug(
                    "Connection left room",
                    connection_id=connection_id,
                    room=room,
                )

    async def send_personal_message(
        self, message: Dict[str, Any], connection_id: str
    ) -> None:
        """Send a message to a specific connection.

        A ``WebSocketDisconnect`` raised by the send removes the connection
        from the manager. Any other send error is logged and swallowed.

        Args:
            message: The message to send (will be JSON serialized)
            connection_id: Target connection identifier
        """
        websocket = self._active_connections.get(connection_id)
        if websocket is None:
            logger.warning(
                "Attempted to send message to inactive connection",
                connection_id=connection_id,
            )
            return

        try:
            await websocket.send_json(message)
            logger.debug(
                "Personal message sent",
                connection_id=connection_id,
                message_type=message.get("type", "unknown"),
            )
        except WebSocketDisconnect:
            logger.warning(
                "Connection disconnected during send",
                connection_id=connection_id,
            )
            await self.disconnect(connection_id)
        except Exception as e:
            logger.error(
                "Failed to send personal message",
                connection_id=connection_id,
                error=str(e),
            )

    async def broadcast(
        self, message: Dict[str, Any], exclude: Optional[Set[str]] = None
    ) -> None:
        """Broadcast a message to all active connections.

        Connections that raise ``WebSocketDisconnect`` are removed after
        the loop. Other send errors are logged and the broadcast continues.

        Args:
            message: The message to broadcast (will be JSON serialized)
            exclude: Optional set of connection IDs to exclude from broadcast
        """
        excluded = exclude or set()

        # Snapshot the registry: every ``await`` below yields to the event
        # loop, where another coroutine may connect or disconnect. Iterating
        # the live dict would then raise "dictionary changed size during
        # iteration".
        targets = [
            (connection_id, websocket)
            for connection_id, websocket in self._active_connections.items()
            if connection_id not in excluded
        ]

        disconnected: List[str] = []
        delivered = 0

        for connection_id, websocket in targets:
            # Skip sockets that were removed or replaced since the snapshot.
            if self._active_connections.get(connection_id) is not websocket:
                continue

            try:
                await websocket.send_json(message)
            except WebSocketDisconnect:
                logger.warning(
                    "Connection disconnected during broadcast",
                    connection_id=connection_id,
                )
                disconnected.append(connection_id)
            except Exception as e:
                logger.error(
                    "Failed to broadcast to connection",
                    connection_id=connection_id,
                    error=str(e),
                )
            else:
                delivered += 1

        # Cleanup disconnected connections
        for connection_id in disconnected:
            await self.disconnect(connection_id)

        logger.debug(
            "Message broadcast",
            message_type=message.get("type", "unknown"),
            # Number of sockets the message was actually delivered to.
            recipient_count=delivered,
            failed_count=len(disconnected),
        )

    async def broadcast_to_room(
        self, message: Dict[str, Any], room: str
    ) -> None:
        """Broadcast a message to all connections in a specific room.

        Room members without an active socket are skipped. Connections
        that raise ``WebSocketDisconnect`` are removed after the loop.

        Args:
            message: The message to broadcast (will be JSON serialized)
            room: The room to broadcast to
        """
        # ``.get`` avoids creating the room in the defaultdict; the copy
        # isolates the loop from concurrent joins and leaves.
        room_members = self._rooms.get(room, set()).copy()
        disconnected: List[str] = []
        delivered = 0

        for connection_id in room_members:
            websocket = self._active_connections.get(connection_id)
            if websocket is None:
                continue

            try:
                await websocket.send_json(message)
            except WebSocketDisconnect:
                logger.warning(
                    "Connection disconnected during room broadcast",
                    connection_id=connection_id,
                    room=room,
                )
                disconnected.append(connection_id)
            except Exception as e:
                logger.error(
                    "Failed to broadcast to room member",
                    connection_id=connection_id,
                    room=room,
                    error=str(e),
                )
            else:
                delivered += 1

        # Cleanup disconnected connections
        for connection_id in disconnected:
            await self.disconnect(connection_id)

        logger.debug(
            "Message broadcast to room",
            room=room,
            message_type=message.get("type", "unknown"),
            # Number of sockets the message was actually delivered to.
            recipient_count=delivered,
            failed_count=len(disconnected),
        )

    async def get_connection_count(self) -> int:
        """Get the total number of active connections."""
        return len(self._active_connections)

    async def get_room_members(self, room: str) -> List[str]:
        """Get list of connection IDs in a specific room.

        Returns an empty list for unknown rooms (without creating them).
        """
        return list(self._rooms.get(room, set()))

    async def get_connection_metadata(
        self, connection_id: str
    ) -> Optional[Dict[str, Any]]:
        """Get metadata associated with a connection.

        Returns:
            The stored metadata dict, or ``None`` for unknown connections.
        """
        return self._connection_metadata.get(connection_id)

    async def update_connection_metadata(
        self, connection_id: str, metadata: Dict[str, Any]
    ) -> None:
        """Update metadata for a connection.

        The given keys are merged into the stored metadata. Inactive
        connections are left untouched and a warning is logged.

        Args:
            connection_id: The connection whose metadata is updated
            metadata: Keys and values to merge into the stored metadata
        """
        async with self._lock:
            # Check and update under the same lock acquisition so a
            # concurrent disconnect cannot slip in between and cause a
            # KeyError on the metadata lookup.
            if connection_id in self._active_connections:
                self._connection_metadata.setdefault(
                    connection_id, {}
                ).update(metadata)
                return

        logger.warning(
            "Attempted to update metadata for inactive connection",
            connection_id=connection_id,
        )

    async def shutdown(self, timeout: float = 5.0) -> None:
        """Gracefully shutdown all WebSocket connections.

        Broadcasts a shutdown notification to all clients, then closes
        each connection with proper close codes. All registries are
        cleared afterwards, even if closing timed out.

        Args:
            timeout: Maximum time (seconds) to wait for all closes to complete
        """
        logger.info(
            "Initiating WebSocket shutdown, connections=%d",
            len(self._active_connections)
        )

        # Broadcast shutdown notification to all clients
        shutdown_message = {
            "type": "server_shutdown",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "data": {
                "message": "Server is shutting down",
                "reason": "graceful_shutdown",
            },
        }

        try:
            await self.broadcast(shutdown_message)
        except Exception as e:
            logger.warning("Failed to broadcast shutdown message: %s", e)

        # Close all connections gracefully
        async with self._lock:
            connection_ids = list(self._active_connections.keys())

            close_tasks = []
            for connection_id in connection_ids:
                websocket = self._active_connections.get(connection_id)
                # Identity test: do not rely on the socket's truthiness.
                if websocket is not None:
                    close_tasks.append(
                        self._close_connection_gracefully(
                            connection_id, websocket
                        )
                    )

        if close_tasks:
            # Wait for all closes with timeout
            try:
                await asyncio.wait_for(
                    asyncio.gather(*close_tasks, return_exceptions=True),
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                logger.warning(
                    "WebSocket shutdown timed out after %.1f seconds", timeout
                )

        # Clear all data structures
        async with self._lock:
            self._active_connections.clear()
            self._rooms.clear()
            self._connection_metadata.clear()

        logger.info("WebSocket shutdown complete")

    async def _close_connection_gracefully(
        self, connection_id: str, websocket: WebSocket
    ) -> None:
        """Close a single WebSocket connection gracefully.

        Errors while closing are logged at debug level and suppressed so
        one failing socket does not abort the shutdown of the others.

        Args:
            connection_id: The connection identifier
            websocket: The WebSocket connection to close
        """
        try:
            # Code 1001 = Going Away (server shutdown)
            await websocket.close(code=1001, reason="Server shutdown")
            logger.debug("Closed WebSocket connection: %s", connection_id)
        except Exception as e:
            logger.debug(
                "Error closing WebSocket %s: %s", connection_id, str(e)
            )


class WebSocketService:
    """High-level WebSocket service for application-wide messaging.

    This service provides a convenient interface for broadcasting
    application events and managing WebSocket connections. It wraps
    the ConnectionManager with application-specific message types.

    Every outgoing message uses the same envelope: ``type``, an ISO-8601
    UTC ``timestamp``, and a ``data`` payload.
    """

    def __init__(self) -> None:
        """Initialize the WebSocket service."""
        self._manager = ConnectionManager()
        logger.info("WebSocketService initialized")

    @property
    def manager(self) -> ConnectionManager:
        """Access the underlying connection manager."""
        return self._manager

    @staticmethod
    def _build_message(
        message_type: str, data: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Wrap ``data`` in the standard message envelope.

        Args:
            message_type: Value for the ``type`` field
            data: Payload stored under ``data``

        Returns:
            Dict with ``type``, ``timestamp`` (current UTC time, ISO-8601)
            and ``data`` keys, in that order.
        """
        return {
            "type": message_type,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "data": data,
        }

    async def connect(
        self,
        websocket: WebSocket,
        connection_id: str,
        user_id: Optional[str] = None,
    ) -> None:
        """Connect a new WebSocket client.

        Stores the connection time (UTC, ISO-8601) and the user ID as
        connection metadata.

        Args:
            websocket: The WebSocket connection
            connection_id: Unique connection identifier
            user_id: Optional user identifier for authentication
        """
        metadata = {
            "connected_at": datetime.now(timezone.utc).isoformat(),
            "user_id": user_id,
        }
        await self._manager.connect(websocket, connection_id, metadata)

    async def disconnect(self, connection_id: str) -> None:
        """Disconnect a WebSocket client."""
        await self._manager.disconnect(connection_id)

    async def broadcast_download_progress(
        self, download_id: str, progress_data: Dict[str, Any]
    ) -> None:
        """Broadcast download progress update to the "downloads" room.

        Args:
            download_id: The download item identifier
            progress_data: Progress information (percent, speed, etc.)
                Should include 'key' (series identifier) and
                optionally 'folder' (display name)

        Note:
            The progress_data should include:
            - key: Series identifier (primary, e.g., 'attack-on-titan')
            - folder: Series folder name (optional, display only)
            - percent: Download progress percentage
            - speed_mbps: Download speed
            - eta_seconds: Estimated time remaining
        """
        message = self._build_message(
            "download_progress",
            {"download_id": download_id, **progress_data},
        )
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_download_complete(
        self, download_id: str, result_data: Dict[str, Any]
    ) -> None:
        """Broadcast download completion to the "downloads" room.

        Args:
            download_id: The download item identifier
            result_data: Download result information
                Should include 'key' (series identifier) and
                optionally 'folder' (display name)

        Note:
            The result_data should include:
            - key: Series identifier (primary, e.g., 'attack-on-titan')
            - folder: Series folder name (optional, display only)
            - file_path: Path to the downloaded file
        """
        message = self._build_message(
            "download_complete",
            {"download_id": download_id, **result_data},
        )
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_download_failed(
        self, download_id: str, error_data: Dict[str, Any]
    ) -> None:
        """Broadcast download failure to the "downloads" room.

        Args:
            download_id: The download item identifier
            error_data: Error information
                Should include 'key' (series identifier) and
                optionally 'folder' (display name)

        Note:
            The error_data should include:
            - key: Series identifier (primary, e.g., 'attack-on-titan')
            - folder: Series folder name (optional, display only)
            - error_message: Description of the failure
        """
        message = self._build_message(
            "download_failed",
            {"download_id": download_id, **error_data},
        )
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_queue_status(
        self, status_data: Dict[str, Any]
    ) -> None:
        """Broadcast queue status update to the "downloads" room.

        Args:
            status_data: Queue status information
        """
        message = self._build_message("queue_status", status_data)
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_system_message(
        self, message_type: str, data: Dict[str, Any]
    ) -> None:
        """Broadcast a system message to all clients.

        The emitted ``type`` is ``message_type`` prefixed with ``system_``.

        Args:
            message_type: Type of system message
            data: Message data
        """
        message = self._build_message(f"system_{message_type}", data)
        await self._manager.broadcast(message)

    async def send_error(
        self, connection_id: str, error_message: str, error_code: str = "ERROR"
    ) -> None:
        """Send an error message to a specific connection.

        Args:
            connection_id: Target connection
            error_message: Error description
            error_code: Error code for client handling
        """
        message = self._build_message(
            "error",
            {
                "code": error_code,
                "message": error_message,
            },
        )
        await self._manager.send_personal_message(message, connection_id)

    async def broadcast_scan_started(
        self, directory: str, total_items: int = 0
    ) -> None:
        """Broadcast that a library scan has started.

        Args:
            directory: The root directory path being scanned
            total_items: Total number of items to scan (for progress display)
        """
        message = self._build_message(
            "scan_started",
            {
                "directory": directory,
                "total_items": total_items,
            },
        )
        await self._manager.broadcast(message)
        logger.info(
            "Broadcast scan_started",
            directory=directory,
            total_items=total_items,
        )

    async def broadcast_scan_progress(
        self,
        directories_scanned: int,
        files_found: int,
        current_directory: str,
        total_items: int = 0,
    ) -> None:
        """Broadcast scan progress update to all clients.

        Args:
            directories_scanned: Number of directories scanned so far
            files_found: Number of MP4 files found so far
            current_directory: Current directory being scanned
            total_items: Total number of items to scan (for progress display)
        """
        message = self._build_message(
            "scan_progress",
            {
                "directories_scanned": directories_scanned,
                "files_found": files_found,
                "current_directory": current_directory,
                "total_items": total_items,
            },
        )
        await self._manager.broadcast(message)

    async def broadcast_scan_completed(
        self,
        total_directories: int,
        total_files: int,
        elapsed_seconds: float,
    ) -> None:
        """Broadcast scan completion to all clients.

        Args:
            total_directories: Total number of directories scanned
            total_files: Total number of MP4 files found
            elapsed_seconds: Time taken for the scan in seconds
                (rounded to two decimals in the payload and the log)
        """
        elapsed_rounded = round(elapsed_seconds, 2)
        message = self._build_message(
            "scan_completed",
            {
                "total_directories": total_directories,
                "total_files": total_files,
                "elapsed_seconds": elapsed_rounded,
            },
        )
        await self._manager.broadcast(message)
        logger.info(
            "Broadcast scan_completed",
            total_directories=total_directories,
            total_files=total_files,
            elapsed_seconds=elapsed_rounded,
        )

    async def shutdown(self, timeout: float = 5.0) -> None:
        """Gracefully shutdown the WebSocket service.

        Broadcasts shutdown notification and closes all connections.

        Args:
            timeout: Maximum time (seconds) to wait for shutdown
        """
        logger.info("Shutting down WebSocket service...")
        await self._manager.shutdown(timeout=timeout)
        logger.info("WebSocket service shutdown complete")


# Singleton instance for application-wide access
_websocket_service: Optional[WebSocketService] = None


def get_websocket_service() -> WebSocketService:
    """Get or create the singleton WebSocket service instance.

    The instance is created lazily on first call and reused afterwards.

    Returns:
        The WebSocket service instance
    """
    global _websocket_service

    if _websocket_service is None:
        _websocket_service = WebSocketService()

    return _websocket_service