Aniworld/src/server/services/websocket_service.py
Lukas d70d70e193 feat: implement graceful shutdown with SIGINT/SIGTERM support
- Add WebSocket shutdown() with client notification and graceful close
- Enhance download service stop() with pending state persistence
- Expand FastAPI lifespan shutdown with proper cleanup sequence
- Add SQLite WAL checkpoint before database close
- Update stop_server.sh to use SIGTERM with timeout fallback
- Configure uvicorn timeout_graceful_shutdown=30s
- Update ARCHITECTURE.md with shutdown documentation
2025-12-25 18:59:07 +01:00

688 lines
24 KiB
Python

"""WebSocket service for real-time communication with clients.
This module provides a comprehensive WebSocket manager for handling
real-time updates, connection management, room-based messaging, and
broadcast functionality for the Aniworld web application.
Series Identifier Convention:
- `key`: Primary identifier for series (provider-assigned, URL-safe)
e.g., "attack-on-titan"
- `folder`: Display metadata only (e.g., "Attack on Titan (2013)")
All broadcast methods that handle series-related data should include `key`
as the primary identifier in the message payload. The `folder` field is
optional and used for display purposes only.
"""
from __future__ import annotations
import asyncio
from collections import defaultdict
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, Set
import structlog
from fastapi import WebSocket, WebSocketDisconnect
# Module-level structlog logger, obtained with this module's name.
logger = structlog.get_logger(__name__)
class WebSocketServiceError(Exception):
    """Exception type for service-level WebSocket failures."""
class ConnectionManager:
    """Manages WebSocket connections with room-based messaging support.

    Features:
    - Connection lifecycle management (connect / disconnect / shutdown)
    - Room-based messaging (rooms for specific topics)
    - Broadcast to all connections or specific rooms
    - Removal of connections that raise ``WebSocketDisconnect`` on send
    - Automatic cleanup of room membership and metadata on disconnect
    """

    def __init__(self):
        """Initialize empty registries and the lock that guards them."""
        # Active connections: connection_id -> WebSocket
        self._active_connections: Dict[str, WebSocket] = {}
        # Room memberships: room_name -> set of connection_ids
        self._rooms: Dict[str, Set[str]] = defaultdict(set)
        # Connection metadata: connection_id -> metadata dict
        self._connection_metadata: Dict[str, Dict[str, Any]] = {}
        # Serializes coroutines that mutate the registries. Note that an
        # asyncio.Lock coordinates tasks on one event loop; it is not a
        # thread lock.
        self._lock = asyncio.Lock()
        logger.info("ConnectionManager initialized")

    def _forget_connection(self, connection_id: str) -> None:
        """Drop every reference to ``connection_id`` from the registries.

        Removes the id from all rooms, deletes rooms left empty, and
        removes the socket and its metadata. Unknown ids are a no-op.
        The caller must already hold ``self._lock``; this helper is
        synchronous so it can be used while the (non-reentrant) lock is
        held.

        Args:
            connection_id: The connection to remove
        """
        # Iterate a key snapshot because empty rooms are deleted in-loop.
        for room in list(self._rooms):
            members = self._rooms[room]
            members.discard(connection_id)
            if not members:
                del self._rooms[room]
        self._active_connections.pop(connection_id, None)
        self._connection_metadata.pop(connection_id, None)

    async def connect(
        self,
        websocket: WebSocket,
        connection_id: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Accept and register a new WebSocket connection.

        If ``connection_id`` is already registered, the previous socket is
        closed (best effort) and all of its room memberships and metadata
        are discarded before the new socket is stored.

        Args:
            websocket: The WebSocket connection to accept
            connection_id: Unique identifier for this connection
            metadata: Optional metadata to associate with the connection
        """
        await websocket.accept()
        async with self._lock:
            # Replace a connection that reuses the same ID so no stale
            # socket or room membership survives (e.g. repeated test
            # setups).
            stale = self._active_connections.get(connection_id)
            if stale is not None:
                try:
                    await stale.close()
                except Exception as e:
                    # Best effort: the old socket may already be closed or
                    # be a test double. Record it instead of hiding it.
                    logger.debug(
                        "Failed to close replaced connection",
                        connection_id=connection_id,
                        error=str(e),
                    )
                self._forget_connection(connection_id)
            self._active_connections[connection_id] = websocket
            self._connection_metadata[connection_id] = metadata or {}
        logger.info(
            "WebSocket connected",
            connection_id=connection_id,
            total_connections=len(self._active_connections),
        )

    async def disconnect(self, connection_id: str) -> None:
        """Remove a WebSocket connection and cleanup associated resources.

        Only the registries are updated; the socket itself is not closed
        here. Calling this for an unknown id is harmless.

        Args:
            connection_id: The connection to remove
        """
        async with self._lock:
            self._forget_connection(connection_id)
        logger.info(
            "WebSocket disconnected",
            connection_id=connection_id,
            total_connections=len(self._active_connections),
        )

    async def join_room(self, connection_id: str, room: str) -> None:
        """Add a connection to a room.

        A warning is logged and nothing changes when ``connection_id`` is
        not an active connection.

        Args:
            connection_id: The connection to add
            room: The room name to join
        """
        async with self._lock:
            if connection_id in self._active_connections:
                self._rooms[room].add(connection_id)
                logger.debug(
                    "Connection joined room",
                    connection_id=connection_id,
                    room=room,
                    room_size=len(self._rooms[room]),
                )
            else:
                logger.warning(
                    "Attempted to join room with inactive connection",
                    connection_id=connection_id,
                    room=room,
                )

    async def leave_room(self, connection_id: str, room: str) -> None:
        """Remove a connection from a room.

        The room is deleted once its last member leaves. Unknown rooms
        are ignored.

        Args:
            connection_id: The connection to remove
            room: The room name to leave
        """
        async with self._lock:
            if room in self._rooms:
                self._rooms[room].discard(connection_id)
                # Remove empty room
                if not self._rooms[room]:
                    del self._rooms[room]
                logger.debug(
                    "Connection left room",
                    connection_id=connection_id,
                    room=room,
                )

    async def send_personal_message(
        self, message: Dict[str, Any], connection_id: str
    ) -> None:
        """Send a message to a specific connection.

        A ``WebSocketDisconnect`` raised by the send removes the
        connection; any other exception is logged and swallowed.

        Args:
            message: The message to send (will be JSON serialized)
            connection_id: Target connection identifier
        """
        websocket = self._active_connections.get(connection_id)
        if websocket is None:
            logger.warning(
                "Attempted to send message to inactive connection",
                connection_id=connection_id,
            )
            return
        try:
            await websocket.send_json(message)
            logger.debug(
                "Personal message sent",
                connection_id=connection_id,
                message_type=message.get("type", "unknown"),
            )
        except WebSocketDisconnect:
            logger.warning(
                "Connection disconnected during send",
                connection_id=connection_id,
            )
            await self.disconnect(connection_id)
        except Exception as e:
            logger.error(
                "Failed to send personal message",
                connection_id=connection_id,
                error=str(e),
            )

    async def broadcast(
        self, message: Dict[str, Any], exclude: Optional[Set[str]] = None
    ) -> None:
        """Broadcast a message to all active connections.

        Connections whose send raises ``WebSocketDisconnect`` are removed
        afterwards; other send errors are logged and the broadcast
        continues with the remaining connections.

        Args:
            message: The message to broadcast (will be JSON serialized)
            exclude: Optional set of connection IDs to exclude from broadcast
        """
        excluded = exclude or set()
        # Snapshot the recipients: every send_json() below yields to the
        # event loop, and a concurrent connect()/disconnect() would
        # otherwise mutate the dict mid-iteration and raise RuntimeError.
        targets = [
            (connection_id, websocket)
            for connection_id, websocket in self._active_connections.items()
            if connection_id not in excluded
        ]
        disconnected: List[str] = []
        attempted = 0
        for connection_id, websocket in targets:
            # Skip sockets that were removed or replaced while earlier
            # sends were in flight.
            if self._active_connections.get(connection_id) is not websocket:
                continue
            attempted += 1
            try:
                await websocket.send_json(message)
            except WebSocketDisconnect:
                logger.warning(
                    "Connection disconnected during broadcast",
                    connection_id=connection_id,
                )
                disconnected.append(connection_id)
            except Exception as e:
                logger.error(
                    "Failed to broadcast to connection",
                    connection_id=connection_id,
                    error=str(e),
                )
        # Cleanup disconnected connections
        for connection_id in disconnected:
            await self.disconnect(connection_id)
        logger.debug(
            "Message broadcast",
            message_type=message.get("type", "unknown"),
            # Number of sends actually attempted, independent of how many
            # ids were passed in ``exclude`` or dropped during cleanup.
            recipient_count=attempted,
            failed_count=len(disconnected),
        )

    async def broadcast_to_room(
        self, message: Dict[str, Any], room: str
    ) -> None:
        """Broadcast a message to all connections in a specific room.

        Members without a registered socket are skipped. Connections
        whose send raises ``WebSocketDisconnect`` are removed afterwards;
        other send errors are logged and the broadcast continues.

        Args:
            message: The message to broadcast (will be JSON serialized)
            room: The room to broadcast to
        """
        # Copy so membership changes during the awaits below cannot
        # invalidate the iteration. ``.get`` avoids creating the room in
        # the defaultdict as a side effect.
        room_members = self._rooms.get(room, set()).copy()
        disconnected: List[str] = []
        for connection_id in room_members:
            websocket = self._active_connections.get(connection_id)
            if websocket is None:
                continue
            try:
                await websocket.send_json(message)
            except WebSocketDisconnect:
                logger.warning(
                    "Connection disconnected during room broadcast",
                    connection_id=connection_id,
                    room=room,
                )
                disconnected.append(connection_id)
            except Exception as e:
                logger.error(
                    "Failed to broadcast to room member",
                    connection_id=connection_id,
                    room=room,
                    error=str(e),
                )
        # Cleanup disconnected connections
        for connection_id in disconnected:
            await self.disconnect(connection_id)
        logger.debug(
            "Message broadcast to room",
            room=room,
            message_type=message.get("type", "unknown"),
            recipient_count=len(room_members),
            failed_count=len(disconnected),
        )

    async def get_connection_count(self) -> int:
        """Get the total number of active connections."""
        return len(self._active_connections)

    async def get_room_members(self, room: str) -> List[str]:
        """Get list of connection IDs in a specific room.

        Returns an empty list for an unknown room.
        """
        return list(self._rooms.get(room, set()))

    async def get_connection_metadata(
        self, connection_id: str
    ) -> Optional[Dict[str, Any]]:
        """Get metadata associated with a connection.

        Returns ``None`` when the connection is unknown. The returned
        dict is the stored object, not a copy.
        """
        return self._connection_metadata.get(connection_id)

    async def update_connection_metadata(
        self, connection_id: str, metadata: Dict[str, Any]
    ) -> None:
        """Merge ``metadata`` into the stored metadata of a connection.

        A warning is logged and nothing changes when ``connection_id`` is
        not an active connection.

        Args:
            connection_id: The connection whose metadata is updated
            metadata: Key/value pairs to merge into the existing metadata
        """
        async with self._lock:
            # Check under the lock: acquiring it may suspend, and a
            # disconnect in that window would otherwise make the metadata
            # lookup raise KeyError.
            is_active = connection_id in self._active_connections
            if is_active:
                self._connection_metadata.setdefault(
                    connection_id, {}
                ).update(metadata)
        if not is_active:
            logger.warning(
                "Attempted to update metadata for inactive connection",
                connection_id=connection_id,
            )

    async def shutdown(self, timeout: float = 5.0) -> None:
        """Gracefully shutdown all WebSocket connections.

        Broadcasts a ``server_shutdown`` notification to all clients,
        closes each connection with close code 1001, then clears all
        registries regardless of whether the closes finished in time.

        Args:
            timeout: Maximum time (seconds) to wait for all closes to complete
        """
        logger.info(
            "Initiating WebSocket shutdown, connections=%d",
            len(self._active_connections)
        )
        # Broadcast shutdown notification to all clients
        shutdown_message = {
            "type": "server_shutdown",
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "data": {
                "message": "Server is shutting down",
                "reason": "graceful_shutdown",
            },
        }
        try:
            await self.broadcast(shutdown_message)
        except Exception as e:
            logger.warning("Failed to broadcast shutdown message: %s", e)
        # Snapshot the sockets under the lock, then close them outside it
        # so the closes can run concurrently.
        async with self._lock:
            open_sockets = list(self._active_connections.items())
        close_tasks = [
            self._close_connection_gracefully(connection_id, websocket)
            for connection_id, websocket in open_sockets
            if websocket is not None
        ]
        if close_tasks:
            # Wait for all closes with timeout
            try:
                await asyncio.wait_for(
                    asyncio.gather(*close_tasks, return_exceptions=True),
                    timeout=timeout
                )
            except asyncio.TimeoutError:
                logger.warning(
                    "WebSocket shutdown timed out after %.1f seconds", timeout
                )
        # Clear all data structures
        async with self._lock:
            self._active_connections.clear()
            self._rooms.clear()
            self._connection_metadata.clear()
        logger.info("WebSocket shutdown complete")

    async def _close_connection_gracefully(
        self, connection_id: str, websocket: WebSocket
    ) -> None:
        """Close a single WebSocket connection gracefully.

        Errors raised by the close are logged at debug level and
        swallowed so one failing socket cannot abort the shutdown.

        Args:
            connection_id: The connection identifier
            websocket: The WebSocket connection to close
        """
        try:
            # Code 1001 = Going Away (server shutdown)
            await websocket.close(code=1001, reason="Server shutdown")
            logger.debug("Closed WebSocket connection: %s", connection_id)
        except Exception as e:
            logger.debug(
                "Error closing WebSocket %s: %s", connection_id, str(e)
            )
class WebSocketService:
    """Application-facing facade over :class:`ConnectionManager`.

    Each ``broadcast_*`` / ``send_*`` method wraps its payload in a
    ``{"type", "timestamp", "data"}`` envelope and hands it to the
    underlying connection manager.
    """

    def __init__(self):
        """Create the service together with its own connection manager."""
        self._manager = ConnectionManager()
        logger.info("WebSocketService initialized")

    @property
    def manager(self) -> ConnectionManager:
        """The connection manager this service delegates to."""
        return self._manager

    @staticmethod
    def _envelope(event_type: str, data: Dict[str, Any]) -> Dict[str, Any]:
        """Wrap ``data`` in the standard message envelope.

        Args:
            event_type: Value for the ``type`` field
            data: Value for the ``data`` field

        Returns:
            Dict with ``type``, a UTC ISO-8601 ``timestamp`` and ``data``
        """
        return {
            "type": event_type,
            "timestamp": datetime.now(timezone.utc).isoformat(),
            "data": data,
        }

    async def connect(
        self,
        websocket: WebSocket,
        connection_id: str,
        user_id: Optional[str] = None,
    ) -> None:
        """Register a client with the connection manager.

        The connection metadata records the UTC connect time
        (``connected_at``) and the given ``user_id``.

        Args:
            websocket: The WebSocket connection
            connection_id: Unique connection identifier
            user_id: Optional user identifier stored in the metadata
        """
        connected_at = datetime.now(timezone.utc).isoformat()
        await self._manager.connect(
            websocket,
            connection_id,
            {"connected_at": connected_at, "user_id": user_id},
        )

    async def disconnect(self, connection_id: str) -> None:
        """Remove a client from the connection manager."""
        await self._manager.disconnect(connection_id)

    async def broadcast_download_progress(
        self, download_id: str, progress_data: Dict[str, Any]
    ) -> None:
        """Send a ``download_progress`` event to the "downloads" room.

        Args:
            download_id: The download item identifier
            progress_data: Progress fields merged into the event data after
                ``download_id``. Expected to carry ``key`` (primary series
                identifier, e.g. 'attack-on-titan'), optionally ``folder``
                (display only), plus ``percent``, ``speed_mbps`` and
                ``eta_seconds``.
        """
        envelope = self._envelope(
            "download_progress",
            {"download_id": download_id, **progress_data},
        )
        await self._manager.broadcast_to_room(envelope, "downloads")

    async def broadcast_download_complete(
        self, download_id: str, result_data: Dict[str, Any]
    ) -> None:
        """Send a ``download_complete`` event to the "downloads" room.

        Args:
            download_id: The download item identifier
            result_data: Result fields merged into the event data after
                ``download_id``. Expected to carry ``key`` (primary series
                identifier), optionally ``folder`` (display only), and
                ``file_path``.
        """
        envelope = self._envelope(
            "download_complete",
            {"download_id": download_id, **result_data},
        )
        await self._manager.broadcast_to_room(envelope, "downloads")

    async def broadcast_download_failed(
        self, download_id: str, error_data: Dict[str, Any]
    ) -> None:
        """Send a ``download_failed`` event to the "downloads" room.

        Args:
            download_id: The download item identifier
            error_data: Error fields merged into the event data after
                ``download_id``. Expected to carry ``key`` (primary series
                identifier), optionally ``folder`` (display only), and
                ``error_message``.
        """
        envelope = self._envelope(
            "download_failed",
            {"download_id": download_id, **error_data},
        )
        await self._manager.broadcast_to_room(envelope, "downloads")

    async def broadcast_queue_status(
        self, status_data: Dict[str, Any]
    ) -> None:
        """Send a ``queue_status`` event to the "downloads" room.

        Args:
            status_data: Queue status information, used as the event data
        """
        envelope = self._envelope("queue_status", status_data)
        await self._manager.broadcast_to_room(envelope, "downloads")

    async def broadcast_system_message(
        self, message_type: str, data: Dict[str, Any]
    ) -> None:
        """Send a ``system_<message_type>`` event to every connection.

        Args:
            message_type: Suffix appended to ``system_`` to form the type
            data: Message data
        """
        envelope = self._envelope(f"system_{message_type}", data)
        await self._manager.broadcast(envelope)

    async def send_error(
        self, connection_id: str, error_message: str, error_code: str = "ERROR"
    ) -> None:
        """Send an ``error`` event to a single connection.

        Args:
            connection_id: Target connection
            error_message: Error description
            error_code: Error code for client handling
        """
        envelope = self._envelope(
            "error",
            {"code": error_code, "message": error_message},
        )
        await self._manager.send_personal_message(envelope, connection_id)

    async def broadcast_scan_started(
        self, directory: str, total_items: int = 0
    ) -> None:
        """Send a ``scan_started`` event to every connection and log it.

        Args:
            directory: The root directory path being scanned
            total_items: Total number of items to scan (for progress display)
        """
        envelope = self._envelope(
            "scan_started",
            {"directory": directory, "total_items": total_items},
        )
        await self._manager.broadcast(envelope)
        logger.info(
            "Broadcast scan_started",
            directory=directory,
            total_items=total_items,
        )

    async def broadcast_scan_progress(
        self,
        directories_scanned: int,
        files_found: int,
        current_directory: str,
        total_items: int = 0,
    ) -> None:
        """Send a ``scan_progress`` event to every connection.

        Args:
            directories_scanned: Number of directories scanned so far
            files_found: Number of files found so far
            current_directory: Current directory being scanned
            total_items: Total number of items to scan (for progress display)
        """
        envelope = self._envelope(
            "scan_progress",
            {
                "directories_scanned": directories_scanned,
                "files_found": files_found,
                "current_directory": current_directory,
                "total_items": total_items,
            },
        )
        await self._manager.broadcast(envelope)

    async def broadcast_scan_completed(
        self,
        total_directories: int,
        total_files: int,
        elapsed_seconds: float,
    ) -> None:
        """Send a ``scan_completed`` event to every connection and log it.

        ``elapsed_seconds`` is rounded to two decimals in both the event
        and the log entry.

        Args:
            total_directories: Total number of directories scanned
            total_files: Total number of files found
            elapsed_seconds: Time taken for the scan in seconds
        """
        envelope = self._envelope(
            "scan_completed",
            {
                "total_directories": total_directories,
                "total_files": total_files,
                "elapsed_seconds": round(elapsed_seconds, 2),
            },
        )
        await self._manager.broadcast(envelope)
        logger.info(
            "Broadcast scan_completed",
            total_directories=total_directories,
            total_files=total_files,
            elapsed_seconds=round(elapsed_seconds, 2),
        )

    async def shutdown(self, timeout: float = 5.0) -> None:
        """Shut the service down by delegating to the manager's shutdown.

        Args:
            timeout: Maximum time (seconds) passed on to the manager
        """
        logger.info("Shutting down WebSocket service...")
        await self._manager.shutdown(timeout=timeout)
        logger.info("WebSocket service shutdown complete")
# Process-wide service instance; stays None until first requested.
_websocket_service: Optional[WebSocketService] = None


def get_websocket_service() -> WebSocketService:
    """Return the shared WebSocket service, creating it on first use.

    Returns:
        The single module-level ``WebSocketService`` instance
    """
    global _websocket_service
    service = _websocket_service
    if service is None:
        service = WebSocketService()
        _websocket_service = service
    return service