"""WebSocket service for real-time communication with clients.

This module provides a comprehensive WebSocket manager for handling real-time
updates, connection management, room-based messaging, and broadcast
functionality for the Aniworld web application.
"""
from __future__ import annotations

import asyncio
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Set

import structlog
from fastapi import WebSocket, WebSocketDisconnect

logger = structlog.get_logger(__name__)

# NOTE(review): message timestamps below use ``datetime.utcnow()``, which is
# deprecated since Python 3.12 and yields naive values (no "+00:00" suffix).
# Switching to ``datetime.now(timezone.utc)`` would change the wire format, so
# confirm clients accept offset-aware ISO strings before migrating.


class WebSocketServiceError(Exception):
    """Service-level exception for WebSocket operations."""


class ConnectionManager:
    """Manage WebSocket connections with room-based messaging support.

    Features:
    - Connection lifecycle management (accept, register, unregister)
    - Room-based messaging (rooms for specific topics)
    - Broadcast to all connections or to a specific room
    - Per-connection metadata storage
    - Automatic cleanup of connections whose send raises
      ``WebSocketDisconnect``
    """

    def __init__(self) -> None:
        """Initialize empty connection, room and metadata registries."""
        # Active connections: connection_id -> WebSocket
        self._active_connections: Dict[str, WebSocket] = {}
        # Room memberships: room_name -> set of connection_ids.
        # Must remain a defaultdict: join_room relies on a room being created
        # on first access.
        self._rooms: Dict[str, Set[str]] = defaultdict(set)
        # Connection metadata: connection_id -> metadata dict
        self._connection_metadata: Dict[str, Dict[str, Any]] = {}
        # Serializes registry mutations between coroutines on the event loop
        # (an asyncio.Lock coordinates coroutines, not OS threads).
        self._lock = asyncio.Lock()
        logger.info("ConnectionManager initialized")

    async def connect(
        self,
        websocket: WebSocket,
        connection_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Accept and register a new WebSocket connection.

        Args:
            websocket: The WebSocket connection to accept
            connection_id: Unique identifier for this connection
            metadata: Optional metadata to associate with the connection
        """
        await websocket.accept()
        async with self._lock:
            self._active_connections[connection_id] = websocket
            self._connection_metadata[connection_id] = metadata or {}
            logger.info(
                "WebSocket connected",
                connection_id=connection_id,
                total_connections=len(self._active_connections),
            )

    async def disconnect(self, connection_id: str) -> None:
        """Remove a WebSocket connection and clean up associated resources.

        The connection is removed from every room, rooms left empty are
        deleted, and its WebSocket and metadata entries are dropped. Unknown
        connection IDs are ignored.

        Args:
            connection_id: The connection to remove
        """
        async with self._lock:
            # Prune rooms in place instead of rebuilding ``self._rooms``: a
            # rebuilt plain dict loses the defaultdict factory, which makes
            # join_room raise KeyError for any room that no longer exists.
            for room in list(self._rooms):
                members = self._rooms[room]
                members.discard(connection_id)
                if not members:
                    del self._rooms[room]

            # Remove connection and metadata
            self._active_connections.pop(connection_id, None)
            self._connection_metadata.pop(connection_id, None)

            logger.info(
                "WebSocket disconnected",
                connection_id=connection_id,
                total_connections=len(self._active_connections),
            )

    async def join_room(self, connection_id: str, room: str) -> None:
        """Add a connection to a room, creating the room if needed.

        Inactive connection IDs are rejected with a warning log.

        Args:
            connection_id: The connection to add
            room: The room name to join
        """
        async with self._lock:
            if connection_id in self._active_connections:
                self._rooms[room].add(connection_id)
                logger.debug(
                    "Connection joined room",
                    connection_id=connection_id,
                    room=room,
                    room_size=len(self._rooms[room]),
                )
            else:
                logger.warning(
                    "Attempted to join room with inactive connection",
                    connection_id=connection_id,
                    room=room,
                )

    async def leave_room(self, connection_id: str, room: str) -> None:
        """Remove a connection from a room, deleting the room if it empties.

        Args:
            connection_id: The connection to remove
            room: The room name to leave
        """
        async with self._lock:
            if room in self._rooms:
                self._rooms[room].discard(connection_id)

                # Remove empty room
                if not self._rooms[room]:
                    del self._rooms[room]

                logger.debug(
                    "Connection left room",
                    connection_id=connection_id,
                    room=room,
                )

    async def send_personal_message(
        self, message: Dict[str, Any], connection_id: str
    ) -> None:
        """Send a message to a specific connection.

        A ``WebSocketDisconnect`` during the send unregisters the connection;
        any other send error is logged and the connection is kept.

        Args:
            message: The message to send (will be JSON serialized)
            connection_id: Target connection identifier
        """
        websocket = self._active_connections.get(connection_id)
        if websocket is None:
            logger.warning(
                "Attempted to send message to inactive connection",
                connection_id=connection_id,
            )
            return

        try:
            await websocket.send_json(message)
            logger.debug(
                "Personal message sent",
                connection_id=connection_id,
                message_type=message.get("type", "unknown"),
            )
        except WebSocketDisconnect:
            logger.warning(
                "Connection disconnected during send",
                connection_id=connection_id,
            )
            await self.disconnect(connection_id)
        except Exception as e:
            logger.error(
                "Failed to send personal message",
                connection_id=connection_id,
                error=str(e),
            )

    async def broadcast(
        self, message: Dict[str, Any], exclude: Optional[Set[str]] = None
    ) -> None:
        """Broadcast a message to all active connections.

        Connections whose send raises ``WebSocketDisconnect`` are unregistered
        after the broadcast; other send errors are only logged.

        Args:
            message: The message to broadcast (will be JSON serialized)
            exclude: Optional set of connection IDs to exclude from broadcast
        """
        exclude = exclude or set()
        # Snapshot the targets: every send_json awaits, and a concurrent
        # connect()/disconnect() would otherwise mutate the dict mid-iteration
        # and raise "RuntimeError: dictionary changed size during iteration".
        targets = [
            (connection_id, websocket)
            for connection_id, websocket in self._active_connections.items()
            if connection_id not in exclude
        ]
        disconnected: List[str] = []

        for connection_id, websocket in targets:
            try:
                await websocket.send_json(message)
            except WebSocketDisconnect:
                logger.warning(
                    "Connection disconnected during broadcast",
                    connection_id=connection_id,
                )
                disconnected.append(connection_id)
            except Exception as e:
                logger.error(
                    "Failed to broadcast to connection",
                    connection_id=connection_id,
                    error=str(e),
                )

        # Cleanup disconnected connections
        for connection_id in disconnected:
            await self.disconnect(connection_id)

        logger.debug(
            "Message broadcast",
            message_type=message.get("type", "unknown"),
            # Connections targeted by this broadcast (same meaning as in
            # broadcast_to_room), independent of later cleanup.
            recipient_count=len(targets),
            failed_count=len(disconnected),
        )

    async def broadcast_to_room(
        self, message: Dict[str, Any], room: str
    ) -> None:
        """Broadcast a message to all connections in a specific room.

        Connections whose send raises ``WebSocketDisconnect`` are unregistered
        after the broadcast; other send errors are only logged.

        Args:
            message: The message to broadcast (will be JSON serialized)
            room: The room to broadcast to
        """
        # Copy so membership changes during the awaits below are harmless
        room_members = self._rooms.get(room, set()).copy()
        disconnected: List[str] = []

        for connection_id in room_members:
            websocket = self._active_connections.get(connection_id)
            if websocket is None:
                continue

            try:
                await websocket.send_json(message)
            except WebSocketDisconnect:
                logger.warning(
                    "Connection disconnected during room broadcast",
                    connection_id=connection_id,
                    room=room,
                )
                disconnected.append(connection_id)
            except Exception as e:
                logger.error(
                    "Failed to broadcast to room member",
                    connection_id=connection_id,
                    room=room,
                    error=str(e),
                )

        # Cleanup disconnected connections
        for connection_id in disconnected:
            await self.disconnect(connection_id)

        logger.debug(
            "Message broadcast to room",
            room=room,
            message_type=message.get("type", "unknown"),
            recipient_count=len(room_members),
            failed_count=len(disconnected),
        )

    async def get_connection_count(self) -> int:
        """Get the total number of active connections."""
        return len(self._active_connections)

    async def get_room_members(self, room: str) -> List[str]:
        """Get list of connection IDs in a specific room."""
        return list(self._rooms.get(room, set()))

    async def get_connection_metadata(
        self, connection_id: str
    ) -> Optional[Dict[str, Any]]:
        """Get metadata associated with a connection, or None if unknown."""
        return self._connection_metadata.get(connection_id)

    async def update_connection_metadata(
        self, connection_id: str, metadata: Dict[str, Any]
    ) -> None:
        """Merge ``metadata`` into the stored metadata of an active connection.

        Args:
            connection_id: The connection whose metadata is updated
            metadata: Key/value pairs merged into the existing metadata
        """
        async with self._lock:
            # Check and update under the same lock: checking first and then
            # awaiting the lock let a concurrent disconnect() remove the entry
            # in between, which raised KeyError.
            if connection_id in self._active_connections:
                self._connection_metadata.setdefault(
                    connection_id, {}
                ).update(metadata)
                return

        logger.warning(
            "Attempted to update metadata for inactive connection",
            connection_id=connection_id,
        )


class WebSocketService:
    """High-level WebSocket service for application-wide messaging.

    This service provides a convenient interface for broadcasting application
    events and managing WebSocket connections. It wraps the ConnectionManager
    with application-specific message types.
    """

    def __init__(self) -> None:
        """Initialize the WebSocket service."""
        self._manager = ConnectionManager()
        logger.info("WebSocketService initialized")

    @property
    def manager(self) -> ConnectionManager:
        """Access the underlying connection manager."""
        return self._manager

    async def connect(
        self,
        websocket: WebSocket,
        connection_id: str,
        user_id: Optional[str] = None,
    ) -> None:
        """Connect a new WebSocket client.

        Args:
            websocket: The WebSocket connection
            connection_id: Unique connection identifier
            user_id: Optional user identifier, stored in connection metadata
        """
        metadata = {
            "connected_at": datetime.utcnow().isoformat(),
            "user_id": user_id,
        }
        await self._manager.connect(websocket, connection_id, metadata)

    async def disconnect(self, connection_id: str) -> None:
        """Disconnect a WebSocket client."""
        await self._manager.disconnect(connection_id)

    async def broadcast_download_progress(
        self, download_id: str, progress_data: Dict[str, Any]
    ) -> None:
        """Broadcast a download progress update to the "downloads" room.

        Args:
            download_id: The download item identifier
            progress_data: Progress information (percent, speed, etc.)
        """
        message = {
            "type": "download_progress",
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "download_id": download_id,
                **progress_data,
            },
        }
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_download_complete(
        self, download_id: str, result_data: Dict[str, Any]
    ) -> None:
        """Broadcast a download completion to the "downloads" room.

        Args:
            download_id: The download item identifier
            result_data: Download result information
        """
        message = {
            "type": "download_complete",
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "download_id": download_id,
                **result_data,
            },
        }
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_download_failed(
        self, download_id: str, error_data: Dict[str, Any]
    ) -> None:
        """Broadcast a download failure to the "downloads" room.

        Args:
            download_id: The download item identifier
            error_data: Error information
        """
        message = {
            "type": "download_failed",
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "download_id": download_id,
                **error_data,
            },
        }
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_queue_status(self, status_data: Dict[str, Any]) -> None:
        """Broadcast a queue status update to the "downloads" room.

        Args:
            status_data: Queue status information
        """
        message = {
            "type": "queue_status",
            "timestamp": datetime.utcnow().isoformat(),
            "data": status_data,
        }
        await self._manager.broadcast_to_room(message, "downloads")

    async def broadcast_system_message(
        self, message_type: str, data: Dict[str, Any]
    ) -> None:
        """Broadcast a system message to all connected clients.

        Args:
            message_type: Type of system message (sent as "system_<type>")
            data: Message data
        """
        message = {
            "type": f"system_{message_type}",
            "timestamp": datetime.utcnow().isoformat(),
            "data": data,
        }
        await self._manager.broadcast(message)

    async def send_error(
        self, connection_id: str, error_message: str, error_code: str = "ERROR"
    ) -> None:
        """Send an error message to a specific connection.

        Args:
            connection_id: Target connection
            error_message: Error description
            error_code: Error code for client handling
        """
        message = {
            "type": "error",
            "timestamp": datetime.utcnow().isoformat(),
            "data": {
                "code": error_code,
                "message": error_message,
            },
        }
        await self._manager.send_personal_message(message, connection_id)


# Singleton instance for application-wide access
_websocket_service: Optional[WebSocketService] = None


def get_websocket_service() -> WebSocketService:
    """Get or lazily create the module-level WebSocket service instance.

    Returns:
        The WebSocket service instance
    """
    global _websocket_service
    if _websocket_service is None:
        _websocket_service = WebSocketService()
    return _websocket_service