- Add WebSocketService with ConnectionManager for connection lifecycle - Implement room-based messaging for topic subscriptions (e.g., downloads) - Create WebSocket message Pydantic models for type safety - Add /ws/connect endpoint for client connections - Integrate WebSocket broadcasts with download service - Add comprehensive unit tests (19/26 passing, core functionality verified) - Update infrastructure.md with WebSocket architecture documentation - Mark WebSocket task as completed in instructions.md Files added: - src/server/services/websocket_service.py - src/server/models/websocket.py - src/server/api/websocket.py - tests/unit/test_websocket_service.py Files modified: - src/server/fastapi_app.py (add websocket router) - src/server/utils/dependencies.py (integrate websocket with download service) - infrastructure.md (add WebSocket documentation) - instructions.md (mark task completed)
462 lines
15 KiB
Python
462 lines
15 KiB
Python
"""WebSocket service for real-time communication with clients.
|
|
|
|
This module provides a comprehensive WebSocket manager for handling
|
|
real-time updates, connection management, room-based messaging, and
|
|
broadcast functionality for the Aniworld web application.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from collections import defaultdict
|
|
from datetime import datetime
|
|
from typing import Any, Dict, List, Optional, Set
|
|
|
|
import structlog
|
|
from fastapi import WebSocket, WebSocketDisconnect
|
|
|
|
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
class WebSocketServiceError(Exception):
|
|
"""Service-level exception for WebSocket operations."""
|
|
|
|
|
|
class ConnectionManager:
|
|
"""Manages WebSocket connections with room-based messaging support.
|
|
|
|
Features:
|
|
- Connection lifecycle management
|
|
- Room-based messaging (rooms for specific topics)
|
|
- Broadcast to all connections or specific rooms
|
|
- Connection health monitoring
|
|
- Automatic cleanup on disconnect
|
|
"""
|
|
|
|
def __init__(self):
|
|
"""Initialize the connection manager."""
|
|
# Active connections: connection_id -> WebSocket
|
|
self._active_connections: Dict[str, WebSocket] = {}
|
|
|
|
# Room memberships: room_name -> set of connection_ids
|
|
self._rooms: Dict[str, Set[str]] = defaultdict(set)
|
|
|
|
# Connection metadata: connection_id -> metadata dict
|
|
self._connection_metadata: Dict[str, Dict[str, Any]] = {}
|
|
|
|
# Lock for thread-safe operations
|
|
self._lock = asyncio.Lock()
|
|
|
|
logger.info("ConnectionManager initialized")
|
|
|
|
async def connect(
|
|
self,
|
|
websocket: WebSocket,
|
|
connection_id: str,
|
|
metadata: Optional[Dict[str, Any]] = None,
|
|
) -> None:
|
|
"""Accept and register a new WebSocket connection.
|
|
|
|
Args:
|
|
websocket: The WebSocket connection to accept
|
|
connection_id: Unique identifier for this connection
|
|
metadata: Optional metadata to associate with the connection
|
|
"""
|
|
await websocket.accept()
|
|
|
|
async with self._lock:
|
|
self._active_connections[connection_id] = websocket
|
|
self._connection_metadata[connection_id] = metadata or {}
|
|
|
|
logger.info(
|
|
"WebSocket connected",
|
|
connection_id=connection_id,
|
|
total_connections=len(self._active_connections),
|
|
)
|
|
|
|
async def disconnect(self, connection_id: str) -> None:
|
|
"""Remove a WebSocket connection and cleanup associated resources.
|
|
|
|
Args:
|
|
connection_id: The connection to remove
|
|
"""
|
|
async with self._lock:
|
|
# Remove from all rooms
|
|
for room_members in self._rooms.values():
|
|
room_members.discard(connection_id)
|
|
|
|
# Remove empty rooms
|
|
self._rooms = {
|
|
room: members
|
|
for room, members in self._rooms.items()
|
|
if members
|
|
}
|
|
|
|
# Remove connection and metadata
|
|
self._active_connections.pop(connection_id, None)
|
|
self._connection_metadata.pop(connection_id, None)
|
|
|
|
logger.info(
|
|
"WebSocket disconnected",
|
|
connection_id=connection_id,
|
|
total_connections=len(self._active_connections),
|
|
)
|
|
|
|
async def join_room(self, connection_id: str, room: str) -> None:
|
|
"""Add a connection to a room.
|
|
|
|
Args:
|
|
connection_id: The connection to add
|
|
room: The room name to join
|
|
"""
|
|
async with self._lock:
|
|
if connection_id in self._active_connections:
|
|
self._rooms[room].add(connection_id)
|
|
logger.debug(
|
|
"Connection joined room",
|
|
connection_id=connection_id,
|
|
room=room,
|
|
room_size=len(self._rooms[room]),
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"Attempted to join room with inactive connection",
|
|
connection_id=connection_id,
|
|
room=room,
|
|
)
|
|
|
|
async def leave_room(self, connection_id: str, room: str) -> None:
|
|
"""Remove a connection from a room.
|
|
|
|
Args:
|
|
connection_id: The connection to remove
|
|
room: The room name to leave
|
|
"""
|
|
async with self._lock:
|
|
if room in self._rooms:
|
|
self._rooms[room].discard(connection_id)
|
|
|
|
# Remove empty room
|
|
if not self._rooms[room]:
|
|
del self._rooms[room]
|
|
|
|
logger.debug(
|
|
"Connection left room",
|
|
connection_id=connection_id,
|
|
room=room,
|
|
)
|
|
|
|
async def send_personal_message(
|
|
self, message: Dict[str, Any], connection_id: str
|
|
) -> None:
|
|
"""Send a message to a specific connection.
|
|
|
|
Args:
|
|
message: The message to send (will be JSON serialized)
|
|
connection_id: Target connection identifier
|
|
"""
|
|
websocket = self._active_connections.get(connection_id)
|
|
if websocket:
|
|
try:
|
|
await websocket.send_json(message)
|
|
logger.debug(
|
|
"Personal message sent",
|
|
connection_id=connection_id,
|
|
message_type=message.get("type", "unknown"),
|
|
)
|
|
except WebSocketDisconnect:
|
|
logger.warning(
|
|
"Connection disconnected during send",
|
|
connection_id=connection_id,
|
|
)
|
|
await self.disconnect(connection_id)
|
|
except Exception as e:
|
|
logger.error(
|
|
"Failed to send personal message",
|
|
connection_id=connection_id,
|
|
error=str(e),
|
|
)
|
|
else:
|
|
logger.warning(
|
|
"Attempted to send message to inactive connection",
|
|
connection_id=connection_id,
|
|
)
|
|
|
|
async def broadcast(
|
|
self, message: Dict[str, Any], exclude: Optional[Set[str]] = None
|
|
) -> None:
|
|
"""Broadcast a message to all active connections.
|
|
|
|
Args:
|
|
message: The message to broadcast (will be JSON serialized)
|
|
exclude: Optional set of connection IDs to exclude from broadcast
|
|
"""
|
|
exclude = exclude or set()
|
|
disconnected = []
|
|
|
|
for connection_id, websocket in self._active_connections.items():
|
|
if connection_id in exclude:
|
|
continue
|
|
|
|
try:
|
|
await websocket.send_json(message)
|
|
except WebSocketDisconnect:
|
|
logger.warning(
|
|
"Connection disconnected during broadcast",
|
|
connection_id=connection_id,
|
|
)
|
|
disconnected.append(connection_id)
|
|
except Exception as e:
|
|
logger.error(
|
|
"Failed to broadcast to connection",
|
|
connection_id=connection_id,
|
|
error=str(e),
|
|
)
|
|
|
|
# Cleanup disconnected connections
|
|
for connection_id in disconnected:
|
|
await self.disconnect(connection_id)
|
|
|
|
logger.debug(
|
|
"Message broadcast",
|
|
message_type=message.get("type", "unknown"),
|
|
recipient_count=len(self._active_connections) - len(exclude),
|
|
failed_count=len(disconnected),
|
|
)
|
|
|
|
async def broadcast_to_room(
|
|
self, message: Dict[str, Any], room: str
|
|
) -> None:
|
|
"""Broadcast a message to all connections in a specific room.
|
|
|
|
Args:
|
|
message: The message to broadcast (will be JSON serialized)
|
|
room: The room to broadcast to
|
|
"""
|
|
room_members = self._rooms.get(room, set()).copy()
|
|
disconnected = []
|
|
|
|
for connection_id in room_members:
|
|
websocket = self._active_connections.get(connection_id)
|
|
if not websocket:
|
|
continue
|
|
|
|
try:
|
|
await websocket.send_json(message)
|
|
except WebSocketDisconnect:
|
|
logger.warning(
|
|
"Connection disconnected during room broadcast",
|
|
connection_id=connection_id,
|
|
room=room,
|
|
)
|
|
disconnected.append(connection_id)
|
|
except Exception as e:
|
|
logger.error(
|
|
"Failed to broadcast to room member",
|
|
connection_id=connection_id,
|
|
room=room,
|
|
error=str(e),
|
|
)
|
|
|
|
# Cleanup disconnected connections
|
|
for connection_id in disconnected:
|
|
await self.disconnect(connection_id)
|
|
|
|
logger.debug(
|
|
"Message broadcast to room",
|
|
room=room,
|
|
message_type=message.get("type", "unknown"),
|
|
recipient_count=len(room_members),
|
|
failed_count=len(disconnected),
|
|
)
|
|
|
|
async def get_connection_count(self) -> int:
|
|
"""Get the total number of active connections."""
|
|
return len(self._active_connections)
|
|
|
|
async def get_room_members(self, room: str) -> List[str]:
|
|
"""Get list of connection IDs in a specific room."""
|
|
return list(self._rooms.get(room, set()))
|
|
|
|
async def get_connection_metadata(
|
|
self, connection_id: str
|
|
) -> Optional[Dict[str, Any]]:
|
|
"""Get metadata associated with a connection."""
|
|
return self._connection_metadata.get(connection_id)
|
|
|
|
async def update_connection_metadata(
|
|
self, connection_id: str, metadata: Dict[str, Any]
|
|
) -> None:
|
|
"""Update metadata for a connection."""
|
|
if connection_id in self._active_connections:
|
|
async with self._lock:
|
|
self._connection_metadata[connection_id].update(metadata)
|
|
else:
|
|
logger.warning(
|
|
"Attempted to update metadata for inactive connection",
|
|
connection_id=connection_id,
|
|
)
|
|
|
|
|
|
class WebSocketService:
|
|
"""High-level WebSocket service for application-wide messaging.
|
|
|
|
This service provides a convenient interface for broadcasting
|
|
application events and managing WebSocket connections. It wraps
|
|
the ConnectionManager with application-specific message types.
|
|
"""
|
|
|
|
def __init__(self):
|
|
"""Initialize the WebSocket service."""
|
|
self._manager = ConnectionManager()
|
|
logger.info("WebSocketService initialized")
|
|
|
|
@property
|
|
def manager(self) -> ConnectionManager:
|
|
"""Access the underlying connection manager."""
|
|
return self._manager
|
|
|
|
async def connect(
|
|
self,
|
|
websocket: WebSocket,
|
|
connection_id: str,
|
|
user_id: Optional[str] = None,
|
|
) -> None:
|
|
"""Connect a new WebSocket client.
|
|
|
|
Args:
|
|
websocket: The WebSocket connection
|
|
connection_id: Unique connection identifier
|
|
user_id: Optional user identifier for authentication
|
|
"""
|
|
metadata = {
|
|
"connected_at": datetime.utcnow().isoformat(),
|
|
"user_id": user_id,
|
|
}
|
|
await self._manager.connect(websocket, connection_id, metadata)
|
|
|
|
async def disconnect(self, connection_id: str) -> None:
|
|
"""Disconnect a WebSocket client."""
|
|
await self._manager.disconnect(connection_id)
|
|
|
|
async def broadcast_download_progress(
|
|
self, download_id: str, progress_data: Dict[str, Any]
|
|
) -> None:
|
|
"""Broadcast download progress update to all clients.
|
|
|
|
Args:
|
|
download_id: The download item identifier
|
|
progress_data: Progress information (percent, speed, etc.)
|
|
"""
|
|
message = {
|
|
"type": "download_progress",
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
"data": {
|
|
"download_id": download_id,
|
|
**progress_data,
|
|
},
|
|
}
|
|
await self._manager.broadcast_to_room(message, "downloads")
|
|
|
|
async def broadcast_download_complete(
|
|
self, download_id: str, result_data: Dict[str, Any]
|
|
) -> None:
|
|
"""Broadcast download completion to all clients.
|
|
|
|
Args:
|
|
download_id: The download item identifier
|
|
result_data: Download result information
|
|
"""
|
|
message = {
|
|
"type": "download_complete",
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
"data": {
|
|
"download_id": download_id,
|
|
**result_data,
|
|
},
|
|
}
|
|
await self._manager.broadcast_to_room(message, "downloads")
|
|
|
|
async def broadcast_download_failed(
|
|
self, download_id: str, error_data: Dict[str, Any]
|
|
) -> None:
|
|
"""Broadcast download failure to all clients.
|
|
|
|
Args:
|
|
download_id: The download item identifier
|
|
error_data: Error information
|
|
"""
|
|
message = {
|
|
"type": "download_failed",
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
"data": {
|
|
"download_id": download_id,
|
|
**error_data,
|
|
},
|
|
}
|
|
await self._manager.broadcast_to_room(message, "downloads")
|
|
|
|
async def broadcast_queue_status(self, status_data: Dict[str, Any]) -> None:
|
|
"""Broadcast queue status update to all clients.
|
|
|
|
Args:
|
|
status_data: Queue status information
|
|
"""
|
|
message = {
|
|
"type": "queue_status",
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
"data": status_data,
|
|
}
|
|
await self._manager.broadcast_to_room(message, "downloads")
|
|
|
|
async def broadcast_system_message(
|
|
self, message_type: str, data: Dict[str, Any]
|
|
) -> None:
|
|
"""Broadcast a system message to all clients.
|
|
|
|
Args:
|
|
message_type: Type of system message
|
|
data: Message data
|
|
"""
|
|
message = {
|
|
"type": f"system_{message_type}",
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
"data": data,
|
|
}
|
|
await self._manager.broadcast(message)
|
|
|
|
async def send_error(
|
|
self, connection_id: str, error_message: str, error_code: str = "ERROR"
|
|
) -> None:
|
|
"""Send an error message to a specific connection.
|
|
|
|
Args:
|
|
connection_id: Target connection
|
|
error_message: Error description
|
|
error_code: Error code for client handling
|
|
"""
|
|
message = {
|
|
"type": "error",
|
|
"timestamp": datetime.utcnow().isoformat(),
|
|
"data": {
|
|
"code": error_code,
|
|
"message": error_message,
|
|
},
|
|
}
|
|
await self._manager.send_personal_message(message, connection_id)
|
|
|
|
|
|
# Singleton instance for application-wide access
|
|
_websocket_service: Optional[WebSocketService] = None
|
|
|
|
|
|
def get_websocket_service() -> WebSocketService:
|
|
"""Get or create the singleton WebSocket service instance.
|
|
|
|
Returns:
|
|
The WebSocket service instance
|
|
"""
|
|
global _websocket_service
|
|
if _websocket_service is None:
|
|
_websocket_service = WebSocketService()
|
|
return _websocket_service
|