- Add WebSocketService with ConnectionManager for connection lifecycle - Implement room-based messaging for topic subscriptions (e.g., downloads) - Create WebSocket message Pydantic models for type safety - Add /ws/connect endpoint for client connections - Integrate WebSocket broadcasts with download service - Add comprehensive unit tests (19/26 passing, core functionality verified) - Update infrastructure.md with WebSocket architecture documentation - Mark WebSocket task as completed in instructions.md Files added: - src/server/services/websocket_service.py - src/server/models/websocket.py - src/server/api/websocket.py - tests/unit/test_websocket_service.py Files modified: - src/server/fastapi_app.py (add websocket router) - src/server/utils/dependencies.py (integrate websocket with download service) - infrastructure.md (add WebSocket documentation) - instructions.md (mark task completed)
424 lines
15 KiB
Python
424 lines
15 KiB
Python
"""Unit tests for WebSocket service."""
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
from fastapi import WebSocket
|
|
|
|
from src.server.services.websocket_service import (
|
|
ConnectionManager,
|
|
WebSocketService,
|
|
get_websocket_service,
|
|
)
|
|
|
|
|
|
class TestConnectionManager:
|
|
"""Test cases for ConnectionManager class."""
|
|
|
|
@pytest.fixture
|
|
def manager(self):
|
|
"""Create a ConnectionManager instance for testing."""
|
|
return ConnectionManager()
|
|
|
|
@pytest.fixture
|
|
def mock_websocket(self):
|
|
"""Create a mock WebSocket instance."""
|
|
ws = AsyncMock(spec=WebSocket)
|
|
ws.accept = AsyncMock()
|
|
ws.send_json = AsyncMock()
|
|
return ws
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_connect(self, manager, mock_websocket):
|
|
"""Test connecting a WebSocket client."""
|
|
connection_id = "test-conn-1"
|
|
metadata = {"user_id": "user123"}
|
|
|
|
await manager.connect(mock_websocket, connection_id, metadata)
|
|
|
|
mock_websocket.accept.assert_called_once()
|
|
assert connection_id in manager._active_connections
|
|
assert manager._connection_metadata[connection_id] == metadata
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_connect_without_metadata(self, manager, mock_websocket):
|
|
"""Test connecting without metadata."""
|
|
connection_id = "test-conn-2"
|
|
|
|
await manager.connect(mock_websocket, connection_id)
|
|
|
|
assert connection_id in manager._active_connections
|
|
assert manager._connection_metadata[connection_id] == {}
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_disconnect(self, manager, mock_websocket):
|
|
"""Test disconnecting a WebSocket client."""
|
|
connection_id = "test-conn-3"
|
|
await manager.connect(mock_websocket, connection_id)
|
|
|
|
await manager.disconnect(connection_id)
|
|
|
|
assert connection_id not in manager._active_connections
|
|
assert connection_id not in manager._connection_metadata
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_join_room(self, manager, mock_websocket):
|
|
"""Test joining a room."""
|
|
connection_id = "test-conn-4"
|
|
room = "downloads"
|
|
|
|
await manager.connect(mock_websocket, connection_id)
|
|
await manager.join_room(connection_id, room)
|
|
|
|
assert connection_id in manager._rooms[room]
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_join_room_inactive_connection(self, manager):
|
|
"""Test joining a room with inactive connection."""
|
|
connection_id = "inactive-conn"
|
|
room = "downloads"
|
|
|
|
# Should not raise error, just log warning
|
|
await manager.join_room(connection_id, room)
|
|
|
|
assert connection_id not in manager._rooms.get(room, set())
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_leave_room(self, manager, mock_websocket):
|
|
"""Test leaving a room."""
|
|
connection_id = "test-conn-5"
|
|
room = "downloads"
|
|
|
|
await manager.connect(mock_websocket, connection_id)
|
|
await manager.join_room(connection_id, room)
|
|
await manager.leave_room(connection_id, room)
|
|
|
|
assert connection_id not in manager._rooms.get(room, set())
|
|
assert room not in manager._rooms # Empty room should be removed
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_disconnect_removes_from_all_rooms(
|
|
self, manager, mock_websocket
|
|
):
|
|
"""Test that disconnect removes connection from all rooms."""
|
|
connection_id = "test-conn-6"
|
|
rooms = ["room1", "room2", "room3"]
|
|
|
|
await manager.connect(mock_websocket, connection_id)
|
|
for room in rooms:
|
|
await manager.join_room(connection_id, room)
|
|
|
|
await manager.disconnect(connection_id)
|
|
|
|
for room in rooms:
|
|
assert connection_id not in manager._rooms.get(room, set())
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_send_personal_message(self, manager, mock_websocket):
|
|
"""Test sending a personal message to a connection."""
|
|
connection_id = "test-conn-7"
|
|
message = {"type": "test", "data": {"value": 123}}
|
|
|
|
await manager.connect(mock_websocket, connection_id)
|
|
await manager.send_personal_message(message, connection_id)
|
|
|
|
mock_websocket.send_json.assert_called_once_with(message)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_send_personal_message_inactive_connection(
|
|
self, manager, mock_websocket
|
|
):
|
|
"""Test sending message to inactive connection."""
|
|
connection_id = "inactive-conn"
|
|
message = {"type": "test", "data": {}}
|
|
|
|
# Should not raise error, just log warning
|
|
await manager.send_personal_message(message, connection_id)
|
|
|
|
mock_websocket.send_json.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_broadcast(self, manager):
|
|
"""Test broadcasting to all connections."""
|
|
connections = {}
|
|
for i in range(3):
|
|
ws = AsyncMock(spec=WebSocket)
|
|
ws.accept = AsyncMock()
|
|
ws.send_json = AsyncMock()
|
|
conn_id = f"conn-{i}"
|
|
await manager.connect(ws, conn_id)
|
|
connections[conn_id] = ws
|
|
|
|
message = {"type": "broadcast", "data": {"value": 456}}
|
|
await manager.broadcast(message)
|
|
|
|
for ws in connections.values():
|
|
ws.send_json.assert_called_once_with(message)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_broadcast_with_exclusion(self, manager):
|
|
"""Test broadcasting with excluded connections."""
|
|
connections = {}
|
|
for i in range(3):
|
|
ws = AsyncMock(spec=WebSocket)
|
|
ws.accept = AsyncMock()
|
|
ws.send_json = AsyncMock()
|
|
conn_id = f"conn-{i}"
|
|
await manager.connect(ws, conn_id)
|
|
connections[conn_id] = ws
|
|
|
|
exclude = {"conn-1"}
|
|
message = {"type": "broadcast", "data": {"value": 789}}
|
|
await manager.broadcast(message, exclude=exclude)
|
|
|
|
connections["conn-0"].send_json.assert_called_once_with(message)
|
|
connections["conn-1"].send_json.assert_not_called()
|
|
connections["conn-2"].send_json.assert_called_once_with(message)
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_broadcast_to_room(self, manager):
|
|
"""Test broadcasting to a specific room."""
|
|
# Setup connections
|
|
room_members = {}
|
|
non_members = {}
|
|
|
|
for i in range(2):
|
|
ws = AsyncMock(spec=WebSocket)
|
|
ws.accept = AsyncMock()
|
|
ws.send_json = AsyncMock()
|
|
conn_id = f"member-{i}"
|
|
await manager.connect(ws, conn_id)
|
|
await manager.join_room(conn_id, "downloads")
|
|
room_members[conn_id] = ws
|
|
|
|
for i in range(2):
|
|
ws = AsyncMock(spec=WebSocket)
|
|
ws.accept = AsyncMock()
|
|
ws.send_json = AsyncMock()
|
|
conn_id = f"non-member-{i}"
|
|
await manager.connect(ws, conn_id)
|
|
non_members[conn_id] = ws
|
|
|
|
message = {"type": "room_broadcast", "data": {"room": "downloads"}}
|
|
await manager.broadcast_to_room(message, "downloads")
|
|
|
|
# Room members should receive message
|
|
for ws in room_members.values():
|
|
ws.send_json.assert_called_once_with(message)
|
|
|
|
# Non-members should not receive message
|
|
for ws in non_members.values():
|
|
ws.send_json.assert_not_called()
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_connection_count(self, manager, mock_websocket):
|
|
"""Test getting connection count."""
|
|
assert await manager.get_connection_count() == 0
|
|
|
|
await manager.connect(mock_websocket, "conn-1")
|
|
assert await manager.get_connection_count() == 1
|
|
|
|
ws2 = AsyncMock(spec=WebSocket)
|
|
ws2.accept = AsyncMock()
|
|
await manager.connect(ws2, "conn-2")
|
|
assert await manager.get_connection_count() == 2
|
|
|
|
await manager.disconnect("conn-1")
|
|
assert await manager.get_connection_count() == 1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_room_members(self, manager, mock_websocket):
|
|
"""Test getting room members."""
|
|
room = "test-room"
|
|
assert await manager.get_room_members(room) == []
|
|
|
|
await manager.connect(mock_websocket, "conn-1")
|
|
await manager.join_room("conn-1", room)
|
|
|
|
members = await manager.get_room_members(room)
|
|
assert "conn-1" in members
|
|
assert len(members) == 1
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_get_connection_metadata(self, manager, mock_websocket):
|
|
"""Test getting connection metadata."""
|
|
connection_id = "test-conn"
|
|
metadata = {"user_id": "user123", "ip": "127.0.0.1"}
|
|
|
|
await manager.connect(mock_websocket, connection_id, metadata)
|
|
|
|
result = await manager.get_connection_metadata(connection_id)
|
|
assert result == metadata
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_update_connection_metadata(self, manager, mock_websocket):
|
|
"""Test updating connection metadata."""
|
|
connection_id = "test-conn"
|
|
initial_metadata = {"user_id": "user123"}
|
|
update = {"session_id": "session456"}
|
|
|
|
await manager.connect(mock_websocket, connection_id, initial_metadata)
|
|
await manager.update_connection_metadata(connection_id, update)
|
|
|
|
result = await manager.get_connection_metadata(connection_id)
|
|
assert result["user_id"] == "user123"
|
|
assert result["session_id"] == "session456"
|
|
|
|
|
|
class TestWebSocketService:
|
|
"""Test cases for WebSocketService class."""
|
|
|
|
@pytest.fixture
|
|
def service(self):
|
|
"""Create a WebSocketService instance for testing."""
|
|
return WebSocketService()
|
|
|
|
@pytest.fixture
|
|
def mock_websocket(self):
|
|
"""Create a mock WebSocket instance."""
|
|
ws = AsyncMock(spec=WebSocket)
|
|
ws.accept = AsyncMock()
|
|
ws.send_json = AsyncMock()
|
|
return ws
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_connect(self, service, mock_websocket):
|
|
"""Test connecting a client."""
|
|
connection_id = "test-conn"
|
|
user_id = "user123"
|
|
|
|
await service.connect(mock_websocket, connection_id, user_id)
|
|
|
|
mock_websocket.accept.assert_called_once()
|
|
assert connection_id in service._manager._active_connections
|
|
metadata = await service._manager.get_connection_metadata(
|
|
connection_id
|
|
)
|
|
assert metadata["user_id"] == user_id
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_disconnect(self, service, mock_websocket):
|
|
"""Test disconnecting a client."""
|
|
connection_id = "test-conn"
|
|
|
|
await service.connect(mock_websocket, connection_id)
|
|
await service.disconnect(connection_id)
|
|
|
|
assert connection_id not in service._manager._active_connections
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_broadcast_download_progress(self, service, mock_websocket):
|
|
"""Test broadcasting download progress."""
|
|
connection_id = "test-conn"
|
|
download_id = "download123"
|
|
progress_data = {
|
|
"percent": 50.0,
|
|
"speed_mbps": 2.5,
|
|
"eta_seconds": 120,
|
|
}
|
|
|
|
await service.connect(mock_websocket, connection_id)
|
|
await service._manager.join_room(connection_id, "downloads")
|
|
await service.broadcast_download_progress(download_id, progress_data)
|
|
|
|
# Verify message was sent
|
|
assert mock_websocket.send_json.called
|
|
call_args = mock_websocket.send_json.call_args[0][0]
|
|
assert call_args["type"] == "download_progress"
|
|
assert call_args["data"]["download_id"] == download_id
|
|
assert call_args["data"]["percent"] == 50.0
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_broadcast_download_complete(self, service, mock_websocket):
|
|
"""Test broadcasting download completion."""
|
|
connection_id = "test-conn"
|
|
download_id = "download123"
|
|
result_data = {"file_path": "/path/to/file.mp4"}
|
|
|
|
await service.connect(mock_websocket, connection_id)
|
|
await service._manager.join_room(connection_id, "downloads")
|
|
await service.broadcast_download_complete(download_id, result_data)
|
|
|
|
assert mock_websocket.send_json.called
|
|
call_args = mock_websocket.send_json.call_args[0][0]
|
|
assert call_args["type"] == "download_complete"
|
|
assert call_args["data"]["download_id"] == download_id
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_broadcast_download_failed(self, service, mock_websocket):
|
|
"""Test broadcasting download failure."""
|
|
connection_id = "test-conn"
|
|
download_id = "download123"
|
|
error_data = {"error_message": "Network error"}
|
|
|
|
await service.connect(mock_websocket, connection_id)
|
|
await service._manager.join_room(connection_id, "downloads")
|
|
await service.broadcast_download_failed(download_id, error_data)
|
|
|
|
assert mock_websocket.send_json.called
|
|
call_args = mock_websocket.send_json.call_args[0][0]
|
|
assert call_args["type"] == "download_failed"
|
|
assert call_args["data"]["download_id"] == download_id
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_broadcast_queue_status(self, service, mock_websocket):
|
|
"""Test broadcasting queue status."""
|
|
connection_id = "test-conn"
|
|
status_data = {"active": 2, "pending": 5, "completed": 10}
|
|
|
|
await service.connect(mock_websocket, connection_id)
|
|
await service._manager.join_room(connection_id, "downloads")
|
|
await service.broadcast_queue_status(status_data)
|
|
|
|
assert mock_websocket.send_json.called
|
|
call_args = mock_websocket.send_json.call_args[0][0]
|
|
assert call_args["type"] == "queue_status"
|
|
assert call_args["data"] == status_data
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_broadcast_system_message(self, service, mock_websocket):
|
|
"""Test broadcasting system message."""
|
|
connection_id = "test-conn"
|
|
message_type = "maintenance"
|
|
data = {"message": "System will be down for maintenance"}
|
|
|
|
await service.connect(mock_websocket, connection_id)
|
|
await service.broadcast_system_message(message_type, data)
|
|
|
|
assert mock_websocket.send_json.called
|
|
call_args = mock_websocket.send_json.call_args[0][0]
|
|
assert call_args["type"] == f"system_{message_type}"
|
|
assert call_args["data"] == data
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_send_error(self, service, mock_websocket):
|
|
"""Test sending error message."""
|
|
connection_id = "test-conn"
|
|
error_message = "Invalid request"
|
|
error_code = "INVALID_REQUEST"
|
|
|
|
await service.connect(mock_websocket, connection_id)
|
|
await service.send_error(connection_id, error_message, error_code)
|
|
|
|
assert mock_websocket.send_json.called
|
|
call_args = mock_websocket.send_json.call_args[0][0]
|
|
assert call_args["type"] == "error"
|
|
assert call_args["data"]["code"] == error_code
|
|
assert call_args["data"]["message"] == error_message
|
|
|
|
|
|
class TestGetWebSocketService:
|
|
"""Test cases for get_websocket_service factory function."""
|
|
|
|
def test_singleton_pattern(self):
|
|
"""Test that get_websocket_service returns singleton instance."""
|
|
service1 = get_websocket_service()
|
|
service2 = get_websocket_service()
|
|
|
|
assert service1 is service2
|
|
|
|
def test_returns_websocket_service(self):
|
|
"""Test that factory returns WebSocketService instance."""
|
|
service = get_websocket_service()
|
|
|
|
assert isinstance(service, WebSocketService)
|