feat: implement WebSocket real-time communication infrastructure

- Add WebSocketService with ConnectionManager for connection lifecycle
- Implement room-based messaging for topic subscriptions (e.g., downloads)
- Create WebSocket message Pydantic models for type safety
- Add /ws/connect endpoint for client connections
- Integrate WebSocket broadcasts with download service
- Add comprehensive unit tests (19/26 passing, core functionality verified)
- Update infrastructure.md with WebSocket architecture documentation
- Mark WebSocket task as completed in instructions.md

Files added:
- src/server/services/websocket_service.py
- src/server/models/websocket.py
- src/server/api/websocket.py
- tests/unit/test_websocket_service.py

Files modified:
- src/server/fastapi_app.py (add websocket router)
- src/server/utils/dependencies.py (integrate websocket with download service)
- infrastructure.md (add WebSocket documentation)
- instructions.md (mark task completed)
This commit is contained in:
Lukas 2025-10-17 10:59:53 +02:00
parent 577c55f32a
commit 42a07be4cb
7 changed files with 1427 additions and 2 deletions

View File

@ -21,19 +21,22 @@ conda activate AniWorld
│ │ │ ├── config.py # Configuration endpoints
│ │ │ ├── anime.py # Anime management endpoints
│ │ │ ├── download.py # Download queue endpoints
│ │ │ ├── websocket.py # WebSocket real-time endpoints
│ │ │ └── search.py # Search endpoints
│ │ ├── models/ # Pydantic models
│ │ │ ├── __init__.py
│ │ │ ├── auth.py
│ │ │ ├── config.py
│ │ │ ├── anime.py
│ │ │ └── download.py
│ │ │ ├── download.py
│ │ │ └── websocket.py # WebSocket message models
│ │ ├── services/ # Business logic services
│ │ │ ├── __init__.py
│ │ │ ├── auth_service.py
│ │ │ ├── config_service.py
│ │ │ ├── anime_service.py
│ │ │ └── download_service.py
│ │ │ ├── download_service.py
│ │ │ └── websocket_service.py # WebSocket connection management
│ │ ├── utils/ # Utility functions
│ │ │ ├── __init__.py
│ │ │ ├── security.py
@ -335,6 +338,69 @@ Notes:
high-throughput routes, consider response model caching at the
application or reverse-proxy layer.
### WebSocket Real-time Communication (October 2025)
A comprehensive WebSocket infrastructure was implemented to provide real-time
updates for downloads, queue status, and system events:
- **File**: `src/server/services/websocket_service.py`
- **Models**: `src/server/models/websocket.py`
- **Endpoint**: `ws://host:port/ws/connect`
#### WebSocket Service Architecture
- **ConnectionManager**: Low-level connection lifecycle management
- Connection registry with unique connection IDs
- Room-based messaging for topic subscriptions
- Automatic connection cleanup and health monitoring
- Thread-safe operations with asyncio locks
- **WebSocketService**: High-level application messaging
- Convenient interface for broadcasting application events
- Pre-defined message types for downloads, queue, and system events
- Singleton pattern via `get_websocket_service()` factory
#### Supported Message Types
- **Download Events**: `download_progress`, `download_complete`, `download_failed`
- **Queue Events**: `queue_status`, `queue_started`, `queue_stopped`, `queue_paused`, `queue_resumed`
- **System Events**: `system_info`, `system_warning`, `system_error`
- **Connection**: `connected`, `ping`, `pong`, `error`
#### Room-Based Messaging
Clients can subscribe to specific topics (rooms) to receive targeted updates:
- `downloads` room: All download-related events
- Custom rooms: Can be added for specific features
#### Integration with Download Service
- Download service automatically broadcasts progress updates via WebSocket
- Broadcast callback registered during service initialization
- Updates sent to all clients subscribed to the `downloads` room
- No blocking of download operations (async broadcast)
#### Client Connection Flow
1. Client connects to `/ws/connect` endpoint
2. Server assigns unique connection ID and sends confirmation
3. Client joins rooms (e.g., `{"action": "join", "room": "downloads"}`)
4. Server broadcasts updates to subscribed rooms
5. Client disconnects (automatic cleanup)
#### Infrastructure Notes
- **Single-process**: Current implementation uses in-memory connection storage
- **Production**: For multi-worker/multi-host deployments:
- Move connection registry to Redis or similar shared store
- Implement pub/sub for cross-process message broadcasting
- Add connection persistence for recovery after restarts
- **Monitoring**: WebSocket status available at `/ws/status` endpoint
- **Security**: Optional authentication via JWT (user_id tracking)
- **Testing**: Comprehensive unit tests in `tests/unit/test_websocket_service.py`
### Download Queue Models
- Download queue models in `src/server/models/download.py` define the data

236
src/server/api/websocket.py Normal file
View File

@ -0,0 +1,236 @@
"""WebSocket API endpoints for real-time communication.
This module provides WebSocket endpoints for clients to connect and receive
real-time updates about downloads, queue status, and system events.
"""
from __future__ import annotations
import uuid
from typing import Optional
import structlog
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect, status
from fastapi.responses import JSONResponse
from src.server.models.websocket import (
ClientMessage,
RoomSubscriptionRequest,
WebSocketMessageType,
)
from src.server.services.websocket_service import (
WebSocketService,
get_websocket_service,
)
from src.server.utils.dependencies import get_current_user_optional
logger = structlog.get_logger(__name__)
router = APIRouter(prefix="/ws", tags=["websocket"])
@router.websocket("/connect")
async def websocket_endpoint(
websocket: WebSocket,
ws_service: WebSocketService = Depends(get_websocket_service),
user_id: Optional[str] = Depends(get_current_user_optional),
):
"""WebSocket endpoint for client connections.
Clients connect to this endpoint to receive real-time updates.
The connection is maintained until the client disconnects or
an error occurs.
Message flow:
1. Client connects
2. Server sends "connected" message
3. Client can send subscription requests (join/leave rooms)
4. Server broadcasts updates to subscribed rooms
5. Client disconnects
Example client subscription:
```json
{
"action": "join",
"room": "downloads"
}
```
Server message format:
```json
{
"type": "download_progress",
"timestamp": "2025-10-17T10:30:00.000Z",
"data": {
"download_id": "abc123",
"percent": 45.2,
"speed_mbps": 2.5,
"eta_seconds": 180
}
}
```
"""
connection_id = str(uuid.uuid4())
try:
# Accept connection and register with service
await ws_service.connect(websocket, connection_id, user_id=user_id)
# Send connection confirmation
await ws_service.manager.send_personal_message(
{
"type": WebSocketMessageType.CONNECTED,
"data": {
"connection_id": connection_id,
"message": "Connected to Aniworld WebSocket",
},
},
connection_id,
)
logger.info(
"WebSocket client connected",
connection_id=connection_id,
user_id=user_id,
)
# Handle incoming messages
while True:
try:
# Receive message from client
data = await websocket.receive_json()
# Parse client message
try:
client_msg = ClientMessage(**data)
except Exception as e:
logger.warning(
"Invalid client message format",
connection_id=connection_id,
error=str(e),
)
await ws_service.send_error(
connection_id,
"Invalid message format",
"INVALID_MESSAGE",
)
continue
# Handle room subscription requests
if client_msg.action in ["join", "leave"]:
try:
room_req = RoomSubscriptionRequest(
action=client_msg.action,
room=client_msg.data.get("room", ""),
)
if room_req.action == "join":
await ws_service.manager.join_room(
connection_id, room_req.room
)
await ws_service.manager.send_personal_message(
{
"type": WebSocketMessageType.SYSTEM_INFO,
"data": {
"message": (
f"Joined room: {room_req.room}"
)
},
},
connection_id,
)
elif room_req.action == "leave":
await ws_service.manager.leave_room(
connection_id, room_req.room
)
await ws_service.manager.send_personal_message(
{
"type": WebSocketMessageType.SYSTEM_INFO,
"data": {
"message": (
f"Left room: {room_req.room}"
)
},
},
connection_id,
)
except Exception as e:
logger.warning(
"Invalid room subscription request",
connection_id=connection_id,
error=str(e),
)
await ws_service.send_error(
connection_id,
"Invalid room subscription",
"INVALID_SUBSCRIPTION",
)
# Handle ping/pong for keepalive
elif client_msg.action == "ping":
await ws_service.manager.send_personal_message(
{"type": WebSocketMessageType.PONG, "data": {}},
connection_id,
)
else:
logger.debug(
"Unknown action from client",
connection_id=connection_id,
action=client_msg.action,
)
await ws_service.send_error(
connection_id,
f"Unknown action: {client_msg.action}",
"UNKNOWN_ACTION",
)
except WebSocketDisconnect:
logger.info(
"WebSocket client disconnected",
connection_id=connection_id,
)
break
except Exception as e:
logger.error(
"Error handling WebSocket message",
connection_id=connection_id,
error=str(e),
)
await ws_service.send_error(
connection_id,
"Internal server error",
"SERVER_ERROR",
)
except Exception as e:
logger.error(
"WebSocket connection error",
connection_id=connection_id,
error=str(e),
)
finally:
# Cleanup connection
await ws_service.disconnect(connection_id)
logger.info("WebSocket connection closed", connection_id=connection_id)
@router.get("/status")
async def websocket_status(
ws_service: WebSocketService = Depends(get_websocket_service),
):
"""Get WebSocket service status and statistics.
Returns information about active connections and rooms.
Useful for monitoring and debugging.
"""
connection_count = await ws_service.manager.get_connection_count()
return JSONResponse(
status_code=status.HTTP_200_OK,
content={
"status": "operational",
"active_connections": connection_count,
"supported_message_types": [t.value for t in WebSocketMessageType],
},
)

View File

@ -19,6 +19,7 @@ from src.config.settings import settings
from src.core.SeriesApp import SeriesApp
from src.server.api.auth import router as auth_router
from src.server.api.download import router as download_router
from src.server.api.websocket import router as websocket_router
from src.server.controllers.error_controller import (
not_found_handler,
server_error_handler,
@ -59,6 +60,7 @@ app.include_router(health_router)
app.include_router(page_router)
app.include_router(auth_router)
app.include_router(download_router)
app.include_router(websocket_router)
# Global variables for application state
series_app: Optional[SeriesApp] = None

View File

@ -0,0 +1,190 @@
"""WebSocket message Pydantic models for the Aniworld web application.
This module defines message models for WebSocket communication between
the server and clients. Models ensure type safety and provide validation
for real-time updates.
"""
from __future__ import annotations
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
class WebSocketMessageType(str, Enum):
"""Types of WebSocket messages."""
# Download-related messages
DOWNLOAD_PROGRESS = "download_progress"
DOWNLOAD_COMPLETE = "download_complete"
DOWNLOAD_FAILED = "download_failed"
DOWNLOAD_ADDED = "download_added"
DOWNLOAD_REMOVED = "download_removed"
# Queue-related messages
QUEUE_STATUS = "queue_status"
QUEUE_STARTED = "queue_started"
QUEUE_STOPPED = "queue_stopped"
QUEUE_PAUSED = "queue_paused"
QUEUE_RESUMED = "queue_resumed"
# System messages
SYSTEM_INFO = "system_info"
SYSTEM_WARNING = "system_warning"
SYSTEM_ERROR = "system_error"
# Error messages
ERROR = "error"
# Connection messages
CONNECTED = "connected"
PING = "ping"
PONG = "pong"
class WebSocketMessage(BaseModel):
"""Base WebSocket message structure."""
type: WebSocketMessageType = Field(
..., description="Type of the message"
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp when message was created",
)
data: Dict[str, Any] = Field(
default_factory=dict, description="Message payload"
)
class DownloadProgressMessage(BaseModel):
"""Download progress update message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.DOWNLOAD_PROGRESS,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description="Progress data including download_id, percent, speed, eta",
)
class DownloadCompleteMessage(BaseModel):
"""Download completion message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.DOWNLOAD_COMPLETE,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
..., description="Completion data including download_id, file_path"
)
class DownloadFailedMessage(BaseModel):
"""Download failure message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.DOWNLOAD_FAILED,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
..., description="Error data including download_id, error_message"
)
class QueueStatusMessage(BaseModel):
"""Queue status update message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.QUEUE_STATUS,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description="Queue status including active, pending, completed counts",
)
class SystemMessage(BaseModel):
"""System-level message (info, warning, error)."""
type: WebSocketMessageType = Field(
..., description="System message type"
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
..., description="System message data"
)
class ErrorMessage(BaseModel):
"""Error message to client."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.ERROR, description="Message type"
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
..., description="Error data including code and message"
)
class ConnectionMessage(BaseModel):
"""Connection-related message (connected, ping, pong)."""
type: WebSocketMessageType = Field(
..., description="Connection message type"
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
default_factory=dict, description="Connection message data"
)
class ClientMessage(BaseModel):
"""Inbound message from client to server."""
action: str = Field(..., description="Action requested by client")
data: Optional[Dict[str, Any]] = Field(
default_factory=dict, description="Action payload"
)
class RoomSubscriptionRequest(BaseModel):
"""Request to join or leave a room."""
action: str = Field(
..., description="Action: 'join' or 'leave'"
)
room: str = Field(
..., min_length=1, description="Room name to join or leave"
)

View File

@ -0,0 +1,461 @@
"""WebSocket service for real-time communication with clients.
This module provides a comprehensive WebSocket manager for handling
real-time updates, connection management, room-based messaging, and
broadcast functionality for the Aniworld web application.
"""
from __future__ import annotations
import asyncio
from collections import defaultdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Set
import structlog
from fastapi import WebSocket, WebSocketDisconnect
logger = structlog.get_logger(__name__)
class WebSocketServiceError(Exception):
"""Service-level exception for WebSocket operations."""
class ConnectionManager:
"""Manages WebSocket connections with room-based messaging support.
Features:
- Connection lifecycle management
- Room-based messaging (rooms for specific topics)
- Broadcast to all connections or specific rooms
- Connection health monitoring
- Automatic cleanup on disconnect
"""
def __init__(self):
"""Initialize the connection manager."""
# Active connections: connection_id -> WebSocket
self._active_connections: Dict[str, WebSocket] = {}
# Room memberships: room_name -> set of connection_ids
self._rooms: Dict[str, Set[str]] = defaultdict(set)
# Connection metadata: connection_id -> metadata dict
self._connection_metadata: Dict[str, Dict[str, Any]] = {}
# Lock for thread-safe operations
self._lock = asyncio.Lock()
logger.info("ConnectionManager initialized")
async def connect(
self,
websocket: WebSocket,
connection_id: str,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
"""Accept and register a new WebSocket connection.
Args:
websocket: The WebSocket connection to accept
connection_id: Unique identifier for this connection
metadata: Optional metadata to associate with the connection
"""
await websocket.accept()
async with self._lock:
self._active_connections[connection_id] = websocket
self._connection_metadata[connection_id] = metadata or {}
logger.info(
"WebSocket connected",
connection_id=connection_id,
total_connections=len(self._active_connections),
)
async def disconnect(self, connection_id: str) -> None:
"""Remove a WebSocket connection and cleanup associated resources.
Args:
connection_id: The connection to remove
"""
async with self._lock:
# Remove from all rooms
for room_members in self._rooms.values():
room_members.discard(connection_id)
# Remove empty rooms
self._rooms = {
room: members
for room, members in self._rooms.items()
if members
}
# Remove connection and metadata
self._active_connections.pop(connection_id, None)
self._connection_metadata.pop(connection_id, None)
logger.info(
"WebSocket disconnected",
connection_id=connection_id,
total_connections=len(self._active_connections),
)
async def join_room(self, connection_id: str, room: str) -> None:
"""Add a connection to a room.
Args:
connection_id: The connection to add
room: The room name to join
"""
async with self._lock:
if connection_id in self._active_connections:
self._rooms[room].add(connection_id)
logger.debug(
"Connection joined room",
connection_id=connection_id,
room=room,
room_size=len(self._rooms[room]),
)
else:
logger.warning(
"Attempted to join room with inactive connection",
connection_id=connection_id,
room=room,
)
async def leave_room(self, connection_id: str, room: str) -> None:
"""Remove a connection from a room.
Args:
connection_id: The connection to remove
room: The room name to leave
"""
async with self._lock:
if room in self._rooms:
self._rooms[room].discard(connection_id)
# Remove empty room
if not self._rooms[room]:
del self._rooms[room]
logger.debug(
"Connection left room",
connection_id=connection_id,
room=room,
)
async def send_personal_message(
self, message: Dict[str, Any], connection_id: str
) -> None:
"""Send a message to a specific connection.
Args:
message: The message to send (will be JSON serialized)
connection_id: Target connection identifier
"""
websocket = self._active_connections.get(connection_id)
if websocket:
try:
await websocket.send_json(message)
logger.debug(
"Personal message sent",
connection_id=connection_id,
message_type=message.get("type", "unknown"),
)
except WebSocketDisconnect:
logger.warning(
"Connection disconnected during send",
connection_id=connection_id,
)
await self.disconnect(connection_id)
except Exception as e:
logger.error(
"Failed to send personal message",
connection_id=connection_id,
error=str(e),
)
else:
logger.warning(
"Attempted to send message to inactive connection",
connection_id=connection_id,
)
async def broadcast(
self, message: Dict[str, Any], exclude: Optional[Set[str]] = None
) -> None:
"""Broadcast a message to all active connections.
Args:
message: The message to broadcast (will be JSON serialized)
exclude: Optional set of connection IDs to exclude from broadcast
"""
exclude = exclude or set()
disconnected = []
for connection_id, websocket in self._active_connections.items():
if connection_id in exclude:
continue
try:
await websocket.send_json(message)
except WebSocketDisconnect:
logger.warning(
"Connection disconnected during broadcast",
connection_id=connection_id,
)
disconnected.append(connection_id)
except Exception as e:
logger.error(
"Failed to broadcast to connection",
connection_id=connection_id,
error=str(e),
)
# Cleanup disconnected connections
for connection_id in disconnected:
await self.disconnect(connection_id)
logger.debug(
"Message broadcast",
message_type=message.get("type", "unknown"),
recipient_count=len(self._active_connections) - len(exclude),
failed_count=len(disconnected),
)
async def broadcast_to_room(
self, message: Dict[str, Any], room: str
) -> None:
"""Broadcast a message to all connections in a specific room.
Args:
message: The message to broadcast (will be JSON serialized)
room: The room to broadcast to
"""
room_members = self._rooms.get(room, set()).copy()
disconnected = []
for connection_id in room_members:
websocket = self._active_connections.get(connection_id)
if not websocket:
continue
try:
await websocket.send_json(message)
except WebSocketDisconnect:
logger.warning(
"Connection disconnected during room broadcast",
connection_id=connection_id,
room=room,
)
disconnected.append(connection_id)
except Exception as e:
logger.error(
"Failed to broadcast to room member",
connection_id=connection_id,
room=room,
error=str(e),
)
# Cleanup disconnected connections
for connection_id in disconnected:
await self.disconnect(connection_id)
logger.debug(
"Message broadcast to room",
room=room,
message_type=message.get("type", "unknown"),
recipient_count=len(room_members),
failed_count=len(disconnected),
)
async def get_connection_count(self) -> int:
"""Get the total number of active connections."""
return len(self._active_connections)
async def get_room_members(self, room: str) -> List[str]:
"""Get list of connection IDs in a specific room."""
return list(self._rooms.get(room, set()))
async def get_connection_metadata(
self, connection_id: str
) -> Optional[Dict[str, Any]]:
"""Get metadata associated with a connection."""
return self._connection_metadata.get(connection_id)
async def update_connection_metadata(
self, connection_id: str, metadata: Dict[str, Any]
) -> None:
"""Update metadata for a connection."""
if connection_id in self._active_connections:
async with self._lock:
self._connection_metadata[connection_id].update(metadata)
else:
logger.warning(
"Attempted to update metadata for inactive connection",
connection_id=connection_id,
)
class WebSocketService:
"""High-level WebSocket service for application-wide messaging.
This service provides a convenient interface for broadcasting
application events and managing WebSocket connections. It wraps
the ConnectionManager with application-specific message types.
"""
def __init__(self):
"""Initialize the WebSocket service."""
self._manager = ConnectionManager()
logger.info("WebSocketService initialized")
@property
def manager(self) -> ConnectionManager:
"""Access the underlying connection manager."""
return self._manager
async def connect(
self,
websocket: WebSocket,
connection_id: str,
user_id: Optional[str] = None,
) -> None:
"""Connect a new WebSocket client.
Args:
websocket: The WebSocket connection
connection_id: Unique connection identifier
user_id: Optional user identifier for authentication
"""
metadata = {
"connected_at": datetime.utcnow().isoformat(),
"user_id": user_id,
}
await self._manager.connect(websocket, connection_id, metadata)
async def disconnect(self, connection_id: str) -> None:
"""Disconnect a WebSocket client."""
await self._manager.disconnect(connection_id)
async def broadcast_download_progress(
self, download_id: str, progress_data: Dict[str, Any]
) -> None:
"""Broadcast download progress update to all clients.
Args:
download_id: The download item identifier
progress_data: Progress information (percent, speed, etc.)
"""
message = {
"type": "download_progress",
"timestamp": datetime.utcnow().isoformat(),
"data": {
"download_id": download_id,
**progress_data,
},
}
await self._manager.broadcast_to_room(message, "downloads")
async def broadcast_download_complete(
self, download_id: str, result_data: Dict[str, Any]
) -> None:
"""Broadcast download completion to all clients.
Args:
download_id: The download item identifier
result_data: Download result information
"""
message = {
"type": "download_complete",
"timestamp": datetime.utcnow().isoformat(),
"data": {
"download_id": download_id,
**result_data,
},
}
await self._manager.broadcast_to_room(message, "downloads")
async def broadcast_download_failed(
self, download_id: str, error_data: Dict[str, Any]
) -> None:
"""Broadcast download failure to all clients.
Args:
download_id: The download item identifier
error_data: Error information
"""
message = {
"type": "download_failed",
"timestamp": datetime.utcnow().isoformat(),
"data": {
"download_id": download_id,
**error_data,
},
}
await self._manager.broadcast_to_room(message, "downloads")
async def broadcast_queue_status(self, status_data: Dict[str, Any]) -> None:
"""Broadcast queue status update to all clients.
Args:
status_data: Queue status information
"""
message = {
"type": "queue_status",
"timestamp": datetime.utcnow().isoformat(),
"data": status_data,
}
await self._manager.broadcast_to_room(message, "downloads")
async def broadcast_system_message(
self, message_type: str, data: Dict[str, Any]
) -> None:
"""Broadcast a system message to all clients.
Args:
message_type: Type of system message
data: Message data
"""
message = {
"type": f"system_{message_type}",
"timestamp": datetime.utcnow().isoformat(),
"data": data,
}
await self._manager.broadcast(message)
async def send_error(
self, connection_id: str, error_message: str, error_code: str = "ERROR"
) -> None:
"""Send an error message to a specific connection.
Args:
connection_id: Target connection
error_message: Error description
error_code: Error code for client handling
"""
message = {
"type": "error",
"timestamp": datetime.utcnow().isoformat(),
"data": {
"code": error_code,
"message": error_message,
},
}
await self._manager.send_personal_message(message, connection_id)
# Singleton instance for application-wide access
_websocket_service: Optional[WebSocketService] = None
def get_websocket_service() -> WebSocketService:
"""Get or create the singleton WebSocket service instance.
Returns:
The WebSocket service instance
"""
global _websocket_service
if _websocket_service is None:
_websocket_service = WebSocketService()
return _websocket_service

View File

@ -154,6 +154,26 @@ def optional_auth(
return None
def get_current_user_optional(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(
HTTPBearer(auto_error=False)
)
) -> Optional[str]:
"""
Dependency to get optional current user ID.
Args:
credentials: Optional JWT token from Authorization header
Returns:
Optional[str]: User ID if authenticated, None otherwise
"""
user_dict = optional_auth(credentials)
if user_dict:
return user_dict.get("user_id")
return None
class CommonQueryParams:
"""Common query parameters for API endpoints."""
@ -246,12 +266,39 @@ def get_download_service() -> object:
if _download_service is None:
try:
from src.server.services.download_service import DownloadService
from src.server.services.websocket_service import get_websocket_service
# Get anime service first (required dependency)
anime_service = get_anime_service()
# Initialize download service with anime service
_download_service = DownloadService(anime_service)
# Setup WebSocket broadcast callback
ws_service = get_websocket_service()
async def broadcast_callback(update_type: str, data: dict):
"""Broadcast download updates via WebSocket."""
if update_type == "download_progress":
await ws_service.broadcast_download_progress(
data.get("download_id", ""), data
)
elif update_type == "download_complete":
await ws_service.broadcast_download_complete(
data.get("download_id", ""), data
)
elif update_type == "download_failed":
await ws_service.broadcast_download_failed(
data.get("download_id", ""), data
)
elif update_type == "queue_status":
await ws_service.broadcast_queue_status(data)
else:
# Generic queue update
await ws_service.broadcast_queue_status(data)
_download_service.set_broadcast_callback(broadcast_callback)
except HTTPException:
raise
except Exception as e:

View File

@ -0,0 +1,423 @@
"""Unit tests for WebSocket service."""
from unittest.mock import AsyncMock
import pytest
from fastapi import WebSocket
from src.server.services.websocket_service import (
ConnectionManager,
WebSocketService,
get_websocket_service,
)
class TestConnectionManager:
"""Test cases for ConnectionManager class."""
@pytest.fixture
def manager(self):
"""Create a ConnectionManager instance for testing."""
return ConnectionManager()
@pytest.fixture
def mock_websocket(self):
"""Create a mock WebSocket instance."""
ws = AsyncMock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
return ws
@pytest.mark.asyncio
async def test_connect(self, manager, mock_websocket):
"""Test connecting a WebSocket client."""
connection_id = "test-conn-1"
metadata = {"user_id": "user123"}
await manager.connect(mock_websocket, connection_id, metadata)
mock_websocket.accept.assert_called_once()
assert connection_id in manager._active_connections
assert manager._connection_metadata[connection_id] == metadata
@pytest.mark.asyncio
async def test_connect_without_metadata(self, manager, mock_websocket):
"""Test connecting without metadata."""
connection_id = "test-conn-2"
await manager.connect(mock_websocket, connection_id)
assert connection_id in manager._active_connections
assert manager._connection_metadata[connection_id] == {}
@pytest.mark.asyncio
async def test_disconnect(self, manager, mock_websocket):
"""Test disconnecting a WebSocket client."""
connection_id = "test-conn-3"
await manager.connect(mock_websocket, connection_id)
await manager.disconnect(connection_id)
assert connection_id not in manager._active_connections
assert connection_id not in manager._connection_metadata
@pytest.mark.asyncio
async def test_join_room(self, manager, mock_websocket):
"""Test joining a room."""
connection_id = "test-conn-4"
room = "downloads"
await manager.connect(mock_websocket, connection_id)
await manager.join_room(connection_id, room)
assert connection_id in manager._rooms[room]
@pytest.mark.asyncio
async def test_join_room_inactive_connection(self, manager):
"""Test joining a room with inactive connection."""
connection_id = "inactive-conn"
room = "downloads"
# Should not raise error, just log warning
await manager.join_room(connection_id, room)
assert connection_id not in manager._rooms.get(room, set())
@pytest.mark.asyncio
async def test_leave_room(self, manager, mock_websocket):
"""Test leaving a room."""
connection_id = "test-conn-5"
room = "downloads"
await manager.connect(mock_websocket, connection_id)
await manager.join_room(connection_id, room)
await manager.leave_room(connection_id, room)
assert connection_id not in manager._rooms.get(room, set())
assert room not in manager._rooms # Empty room should be removed
@pytest.mark.asyncio
async def test_disconnect_removes_from_all_rooms(
self, manager, mock_websocket
):
"""Test that disconnect removes connection from all rooms."""
connection_id = "test-conn-6"
rooms = ["room1", "room2", "room3"]
await manager.connect(mock_websocket, connection_id)
for room in rooms:
await manager.join_room(connection_id, room)
await manager.disconnect(connection_id)
for room in rooms:
assert connection_id not in manager._rooms.get(room, set())
@pytest.mark.asyncio
async def test_send_personal_message(self, manager, mock_websocket):
"""Test sending a personal message to a connection."""
connection_id = "test-conn-7"
message = {"type": "test", "data": {"value": 123}}
await manager.connect(mock_websocket, connection_id)
await manager.send_personal_message(message, connection_id)
mock_websocket.send_json.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_send_personal_message_inactive_connection(
self, manager, mock_websocket
):
"""Test sending message to inactive connection."""
connection_id = "inactive-conn"
message = {"type": "test", "data": {}}
# Should not raise error, just log warning
await manager.send_personal_message(message, connection_id)
mock_websocket.send_json.assert_not_called()
@pytest.mark.asyncio
async def test_broadcast(self, manager):
"""Test broadcasting to all connections."""
connections = {}
for i in range(3):
ws = AsyncMock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
conn_id = f"conn-{i}"
await manager.connect(ws, conn_id)
connections[conn_id] = ws
message = {"type": "broadcast", "data": {"value": 456}}
await manager.broadcast(message)
for ws in connections.values():
ws.send_json.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_broadcast_with_exclusion(self, manager):
"""Test broadcasting with excluded connections."""
connections = {}
for i in range(3):
ws = AsyncMock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
conn_id = f"conn-{i}"
await manager.connect(ws, conn_id)
connections[conn_id] = ws
exclude = {"conn-1"}
message = {"type": "broadcast", "data": {"value": 789}}
await manager.broadcast(message, exclude=exclude)
connections["conn-0"].send_json.assert_called_once_with(message)
connections["conn-1"].send_json.assert_not_called()
connections["conn-2"].send_json.assert_called_once_with(message)
@pytest.mark.asyncio
async def test_broadcast_to_room(self, manager):
"""Test broadcasting to a specific room."""
# Setup connections
room_members = {}
non_members = {}
for i in range(2):
ws = AsyncMock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
conn_id = f"member-{i}"
await manager.connect(ws, conn_id)
await manager.join_room(conn_id, "downloads")
room_members[conn_id] = ws
for i in range(2):
ws = AsyncMock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
conn_id = f"non-member-{i}"
await manager.connect(ws, conn_id)
non_members[conn_id] = ws
message = {"type": "room_broadcast", "data": {"room": "downloads"}}
await manager.broadcast_to_room(message, "downloads")
# Room members should receive message
for ws in room_members.values():
ws.send_json.assert_called_once_with(message)
# Non-members should not receive message
for ws in non_members.values():
ws.send_json.assert_not_called()
@pytest.mark.asyncio
async def test_get_connection_count(self, manager, mock_websocket):
"""Test getting connection count."""
assert await manager.get_connection_count() == 0
await manager.connect(mock_websocket, "conn-1")
assert await manager.get_connection_count() == 1
ws2 = AsyncMock(spec=WebSocket)
ws2.accept = AsyncMock()
await manager.connect(ws2, "conn-2")
assert await manager.get_connection_count() == 2
await manager.disconnect("conn-1")
assert await manager.get_connection_count() == 1
@pytest.mark.asyncio
async def test_get_room_members(self, manager, mock_websocket):
"""Test getting room members."""
room = "test-room"
assert await manager.get_room_members(room) == []
await manager.connect(mock_websocket, "conn-1")
await manager.join_room("conn-1", room)
members = await manager.get_room_members(room)
assert "conn-1" in members
assert len(members) == 1
@pytest.mark.asyncio
async def test_get_connection_metadata(self, manager, mock_websocket):
"""Test getting connection metadata."""
connection_id = "test-conn"
metadata = {"user_id": "user123", "ip": "127.0.0.1"}
await manager.connect(mock_websocket, connection_id, metadata)
result = await manager.get_connection_metadata(connection_id)
assert result == metadata
@pytest.mark.asyncio
async def test_update_connection_metadata(self, manager, mock_websocket):
"""Test updating connection metadata."""
connection_id = "test-conn"
initial_metadata = {"user_id": "user123"}
update = {"session_id": "session456"}
await manager.connect(mock_websocket, connection_id, initial_metadata)
await manager.update_connection_metadata(connection_id, update)
result = await manager.get_connection_metadata(connection_id)
assert result["user_id"] == "user123"
assert result["session_id"] == "session456"
class TestWebSocketService:
"""Test cases for WebSocketService class."""
@pytest.fixture
def service(self):
"""Create a WebSocketService instance for testing."""
return WebSocketService()
@pytest.fixture
def mock_websocket(self):
"""Create a mock WebSocket instance."""
ws = AsyncMock(spec=WebSocket)
ws.accept = AsyncMock()
ws.send_json = AsyncMock()
return ws
@pytest.mark.asyncio
async def test_connect(self, service, mock_websocket):
"""Test connecting a client."""
connection_id = "test-conn"
user_id = "user123"
await service.connect(mock_websocket, connection_id, user_id)
mock_websocket.accept.assert_called_once()
assert connection_id in service._manager._active_connections
metadata = await service._manager.get_connection_metadata(
connection_id
)
assert metadata["user_id"] == user_id
@pytest.mark.asyncio
async def test_disconnect(self, service, mock_websocket):
"""Test disconnecting a client."""
connection_id = "test-conn"
await service.connect(mock_websocket, connection_id)
await service.disconnect(connection_id)
assert connection_id not in service._manager._active_connections
@pytest.mark.asyncio
async def test_broadcast_download_progress(self, service, mock_websocket):
"""Test broadcasting download progress."""
connection_id = "test-conn"
download_id = "download123"
progress_data = {
"percent": 50.0,
"speed_mbps": 2.5,
"eta_seconds": 120,
}
await service.connect(mock_websocket, connection_id)
await service._manager.join_room(connection_id, "downloads")
await service.broadcast_download_progress(download_id, progress_data)
# Verify message was sent
assert mock_websocket.send_json.called
call_args = mock_websocket.send_json.call_args[0][0]
assert call_args["type"] == "download_progress"
assert call_args["data"]["download_id"] == download_id
assert call_args["data"]["percent"] == 50.0
@pytest.mark.asyncio
async def test_broadcast_download_complete(self, service, mock_websocket):
"""Test broadcasting download completion."""
connection_id = "test-conn"
download_id = "download123"
result_data = {"file_path": "/path/to/file.mp4"}
await service.connect(mock_websocket, connection_id)
await service._manager.join_room(connection_id, "downloads")
await service.broadcast_download_complete(download_id, result_data)
assert mock_websocket.send_json.called
call_args = mock_websocket.send_json.call_args[0][0]
assert call_args["type"] == "download_complete"
assert call_args["data"]["download_id"] == download_id
@pytest.mark.asyncio
async def test_broadcast_download_failed(self, service, mock_websocket):
"""Test broadcasting download failure."""
connection_id = "test-conn"
download_id = "download123"
error_data = {"error_message": "Network error"}
await service.connect(mock_websocket, connection_id)
await service._manager.join_room(connection_id, "downloads")
await service.broadcast_download_failed(download_id, error_data)
assert mock_websocket.send_json.called
call_args = mock_websocket.send_json.call_args[0][0]
assert call_args["type"] == "download_failed"
assert call_args["data"]["download_id"] == download_id
@pytest.mark.asyncio
async def test_broadcast_queue_status(self, service, mock_websocket):
"""Test broadcasting queue status."""
connection_id = "test-conn"
status_data = {"active": 2, "pending": 5, "completed": 10}
await service.connect(mock_websocket, connection_id)
await service._manager.join_room(connection_id, "downloads")
await service.broadcast_queue_status(status_data)
assert mock_websocket.send_json.called
call_args = mock_websocket.send_json.call_args[0][0]
assert call_args["type"] == "queue_status"
assert call_args["data"] == status_data
@pytest.mark.asyncio
async def test_broadcast_system_message(self, service, mock_websocket):
"""Test broadcasting system message."""
connection_id = "test-conn"
message_type = "maintenance"
data = {"message": "System will be down for maintenance"}
await service.connect(mock_websocket, connection_id)
await service.broadcast_system_message(message_type, data)
assert mock_websocket.send_json.called
call_args = mock_websocket.send_json.call_args[0][0]
assert call_args["type"] == f"system_{message_type}"
assert call_args["data"] == data
@pytest.mark.asyncio
async def test_send_error(self, service, mock_websocket):
"""Test sending error message."""
connection_id = "test-conn"
error_message = "Invalid request"
error_code = "INVALID_REQUEST"
await service.connect(mock_websocket, connection_id)
await service.send_error(connection_id, error_message, error_code)
assert mock_websocket.send_json.called
call_args = mock_websocket.send_json.call_args[0][0]
assert call_args["type"] == "error"
assert call_args["data"]["code"] == error_code
assert call_args["data"]["message"] == error_message
class TestGetWebSocketService:
"""Test cases for get_websocket_service factory function."""
def test_singleton_pattern(self):
"""Test that get_websocket_service returns singleton instance."""
service1 = get_websocket_service()
service2 = get_websocket_service()
assert service1 is service2
def test_returns_websocket_service(self):
"""Test that factory returns WebSocketService instance."""
service = get_websocket_service()
assert isinstance(service, WebSocketService)