"""Unit tests for WebSocket service."""
from unittest.mock import AsyncMock

import pytest
from fastapi import WebSocket

from src.server.services.websocket_service import (
    ConnectionManager,
    WebSocketService,
    get_websocket_service,
)


def _make_mock_websocket():
    """Build an ``AsyncMock`` standing in for a FastAPI ``WebSocket``.

    ``accept`` and ``send_json`` are set to explicit ``AsyncMock``
    instances so tests can await them and inspect their calls.
    """
    ws = AsyncMock(spec=WebSocket)
    ws.accept = AsyncMock()
    ws.send_json = AsyncMock()
    return ws


def _last_sent_message(ws):
    """Return the payload of the most recent ``send_json`` call on *ws*.

    Fails the calling test if ``send_json`` was never called.
    """
    assert ws.send_json.called
    return ws.send_json.call_args[0][0]


class TestConnectionManager:
    """Test cases for ConnectionManager class."""

    @pytest.fixture
    def manager(self):
        """Create a ConnectionManager instance for testing."""
        return ConnectionManager()

    @pytest.fixture
    def mock_websocket(self):
        """Create a mock WebSocket instance."""
        return _make_mock_websocket()

    @pytest.mark.asyncio
    async def test_connect(self, manager, mock_websocket):
        """Test connecting a WebSocket client."""
        connection_id = "test-conn-1"
        metadata = {"user_id": "user123"}

        await manager.connect(mock_websocket, connection_id, metadata)

        mock_websocket.accept.assert_called_once()
        assert connection_id in manager._active_connections
        assert manager._connection_metadata[connection_id] == metadata

    @pytest.mark.asyncio
    async def test_connect_without_metadata(self, manager, mock_websocket):
        """Test connecting without metadata."""
        connection_id = "test-conn-2"

        await manager.connect(mock_websocket, connection_id)

        assert connection_id in manager._active_connections
        assert manager._connection_metadata[connection_id] == {}

    @pytest.mark.asyncio
    async def test_disconnect(self, manager, mock_websocket):
        """Test disconnecting a WebSocket client."""
        connection_id = "test-conn-3"
        await manager.connect(mock_websocket, connection_id)

        await manager.disconnect(connection_id)

        assert connection_id not in manager._active_connections
        assert connection_id not in manager._connection_metadata

    @pytest.mark.asyncio
    async def test_join_room(self, manager, mock_websocket):
        """Test joining a room."""
        connection_id = "test-conn-4"
        room = "downloads"
        await manager.connect(mock_websocket, connection_id)

        await manager.join_room(connection_id, room)

        assert connection_id in manager._rooms[room]

    @pytest.mark.asyncio
    async def test_join_room_inactive_connection(self, manager):
        """Test joining a room with inactive connection."""
        connection_id = "inactive-conn"
        room = "downloads"

        # Should not raise error, just log warning
        await manager.join_room(connection_id, room)

        assert connection_id not in manager._rooms.get(room, set())

    @pytest.mark.asyncio
    async def test_leave_room(self, manager, mock_websocket):
        """Test leaving a room."""
        connection_id = "test-conn-5"
        room = "downloads"
        await manager.connect(mock_websocket, connection_id)
        await manager.join_room(connection_id, room)

        await manager.leave_room(connection_id, room)

        assert connection_id not in manager._rooms.get(room, set())
        assert room not in manager._rooms  # Empty room should be removed

    @pytest.mark.asyncio
    async def test_disconnect_removes_from_all_rooms(
        self, manager, mock_websocket
    ):
        """Test that disconnect removes connection from all rooms."""
        connection_id = "test-conn-6"
        rooms = ["room1", "room2", "room3"]
        await manager.connect(mock_websocket, connection_id)
        for room in rooms:
            await manager.join_room(connection_id, room)

        await manager.disconnect(connection_id)

        for room in rooms:
            assert connection_id not in manager._rooms.get(room, set())

    @pytest.mark.asyncio
    async def test_send_personal_message(self, manager, mock_websocket):
        """Test sending a personal message to a connection."""
        connection_id = "test-conn-7"
        message = {"type": "test", "data": {"value": 123}}
        await manager.connect(mock_websocket, connection_id)

        await manager.send_personal_message(message, connection_id)

        mock_websocket.send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_send_personal_message_inactive_connection(
        self, manager, mock_websocket
    ):
        """Test sending message to inactive connection."""
        connection_id = "inactive-conn"
        message = {"type": "test", "data": {}}
        # Register the mock under a *different* id. The final assertion then
        # proves the message was not misrouted to another live socket,
        # instead of passing trivially because nothing was connected.
        await manager.connect(mock_websocket, "other-conn")

        # Should not raise error, just log warning
        await manager.send_personal_message(message, connection_id)

        mock_websocket.send_json.assert_not_called()

    @pytest.mark.asyncio
    async def test_broadcast(self, manager):
        """Test broadcasting to all connections."""
        connections = {}
        for i in range(3):
            ws = _make_mock_websocket()
            conn_id = f"conn-{i}"
            await manager.connect(ws, conn_id)
            connections[conn_id] = ws
        message = {"type": "broadcast", "data": {"value": 456}}

        await manager.broadcast(message)

        for ws in connections.values():
            ws.send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_broadcast_with_exclusion(self, manager):
        """Test broadcasting with excluded connections."""
        connections = {}
        for i in range(3):
            ws = _make_mock_websocket()
            conn_id = f"conn-{i}"
            await manager.connect(ws, conn_id)
            connections[conn_id] = ws
        exclude = {"conn-1"}
        message = {"type": "broadcast", "data": {"value": 789}}

        await manager.broadcast(message, exclude=exclude)

        connections["conn-0"].send_json.assert_called_once_with(message)
        connections["conn-1"].send_json.assert_not_called()
        connections["conn-2"].send_json.assert_called_once_with(message)

    @pytest.mark.asyncio
    async def test_broadcast_to_room(self, manager):
        """Test broadcasting to a specific room."""
        # Setup connections
        room_members = {}
        non_members = {}
        for i in range(2):
            ws = _make_mock_websocket()
            conn_id = f"member-{i}"
            await manager.connect(ws, conn_id)
            await manager.join_room(conn_id, "downloads")
            room_members[conn_id] = ws
        for i in range(2):
            ws = _make_mock_websocket()
            conn_id = f"non-member-{i}"
            await manager.connect(ws, conn_id)
            non_members[conn_id] = ws
        message = {"type": "room_broadcast", "data": {"room": "downloads"}}

        await manager.broadcast_to_room(message, "downloads")

        # Room members should receive message
        for ws in room_members.values():
            ws.send_json.assert_called_once_with(message)
        # Non-members should not receive message
        for ws in non_members.values():
            ws.send_json.assert_not_called()

    @pytest.mark.asyncio
    async def test_get_connection_count(self, manager, mock_websocket):
        """Test getting connection count."""
        assert await manager.get_connection_count() == 0

        await manager.connect(mock_websocket, "conn-1")
        assert await manager.get_connection_count() == 1

        ws2 = _make_mock_websocket()
        await manager.connect(ws2, "conn-2")
        assert await manager.get_connection_count() == 2

        await manager.disconnect("conn-1")
        assert await manager.get_connection_count() == 1

    @pytest.mark.asyncio
    async def test_get_room_members(self, manager, mock_websocket):
        """Test getting room members."""
        room = "test-room"
        assert await manager.get_room_members(room) == []

        await manager.connect(mock_websocket, "conn-1")
        await manager.join_room("conn-1", room)

        members = await manager.get_room_members(room)
        assert "conn-1" in members
        assert len(members) == 1

    @pytest.mark.asyncio
    async def test_get_connection_metadata(self, manager, mock_websocket):
        """Test getting connection metadata."""
        connection_id = "test-conn"
        metadata = {"user_id": "user123", "ip": "127.0.0.1"}
        await manager.connect(mock_websocket, connection_id, metadata)

        result = await manager.get_connection_metadata(connection_id)

        assert result == metadata

    @pytest.mark.asyncio
    async def test_update_connection_metadata(self, manager, mock_websocket):
        """Test updating connection metadata."""
        connection_id = "test-conn"
        initial_metadata = {"user_id": "user123"}
        update = {"session_id": "session456"}
        await manager.connect(mock_websocket, connection_id, initial_metadata)

        await manager.update_connection_metadata(connection_id, update)

        result = await manager.get_connection_metadata(connection_id)
        assert result["user_id"] == "user123"
        assert result["session_id"] == "session456"


class TestWebSocketService:
    """Test cases for WebSocketService class."""

    @pytest.fixture
    def service(self):
        """Create a WebSocketService instance for testing."""
        return WebSocketService()

    @pytest.fixture
    def mock_websocket(self):
        """Create a mock WebSocket instance."""
        return _make_mock_websocket()

    @pytest.mark.asyncio
    async def test_connect(self, service, mock_websocket):
        """Test connecting a client."""
        connection_id = "test-conn"
        user_id = "user123"

        await service.connect(mock_websocket, connection_id, user_id)

        mock_websocket.accept.assert_called_once()
        assert connection_id in service._manager._active_connections
        metadata = await service._manager.get_connection_metadata(
            connection_id
        )
        assert metadata["user_id"] == user_id

    @pytest.mark.asyncio
    async def test_disconnect(self, service, mock_websocket):
        """Test disconnecting a client."""
        connection_id = "test-conn"
        await service.connect(mock_websocket, connection_id)

        await service.disconnect(connection_id)

        assert connection_id not in service._manager._active_connections

    @pytest.mark.asyncio
    async def test_broadcast_download_progress(self, service, mock_websocket):
        """Test broadcasting download progress.

        Verifies that progress data includes 'key' as the primary series
        identifier and 'folder' for display purposes only.
        """
        connection_id = "test-conn"
        download_id = "download123"
        progress_data = {
            "key": "attack-on-titan",
            "folder": "Attack on Titan (2013)",
            "percent": 50.0,
            "speed_mbps": 2.5,
            "eta_seconds": 120,
        }
        await service.connect(mock_websocket, connection_id)
        await service._manager.join_room(connection_id, "downloads")

        await service.broadcast_download_progress(download_id, progress_data)

        # Verify message was sent
        call_args = _last_sent_message(mock_websocket)
        assert call_args["type"] == "download_progress"
        assert call_args["data"]["download_id"] == download_id
        assert call_args["data"]["key"] == "attack-on-titan"
        assert call_args["data"]["folder"] == "Attack on Titan (2013)"
        assert call_args["data"]["percent"] == 50.0

    @pytest.mark.asyncio
    async def test_broadcast_download_complete(self, service, mock_websocket):
        """Test broadcasting download completion.

        Verifies that result data includes 'key' as the primary series
        identifier and 'folder' for display purposes only.
        """
        connection_id = "test-conn"
        download_id = "download123"
        result_data = {
            "key": "attack-on-titan",
            "folder": "Attack on Titan (2013)",
            "file_path": "/path/to/file.mp4",
        }
        await service.connect(mock_websocket, connection_id)
        await service._manager.join_room(connection_id, "downloads")

        await service.broadcast_download_complete(download_id, result_data)

        call_args = _last_sent_message(mock_websocket)
        assert call_args["type"] == "download_complete"
        assert call_args["data"]["download_id"] == download_id
        assert call_args["data"]["key"] == "attack-on-titan"
        assert call_args["data"]["folder"] == "Attack on Titan (2013)"

    @pytest.mark.asyncio
    async def test_broadcast_download_failed(self, service, mock_websocket):
        """Test broadcasting download failure.

        Verifies that error data includes 'key' as the primary series
        identifier and 'folder' for display purposes only.
        """
        connection_id = "test-conn"
        download_id = "download123"
        error_data = {
            "key": "attack-on-titan",
            "folder": "Attack on Titan (2013)",
            "error_message": "Network error",
        }
        await service.connect(mock_websocket, connection_id)
        await service._manager.join_room(connection_id, "downloads")

        await service.broadcast_download_failed(download_id, error_data)

        call_args = _last_sent_message(mock_websocket)
        assert call_args["type"] == "download_failed"
        assert call_args["data"]["download_id"] == download_id
        assert call_args["data"]["key"] == "attack-on-titan"
        assert call_args["data"]["folder"] == "Attack on Titan (2013)"

    @pytest.mark.asyncio
    async def test_broadcast_queue_status(self, service, mock_websocket):
        """Test broadcasting queue status."""
        connection_id = "test-conn"
        status_data = {"active": 2, "pending": 5, "completed": 10}
        await service.connect(mock_websocket, connection_id)
        await service._manager.join_room(connection_id, "downloads")

        await service.broadcast_queue_status(status_data)

        call_args = _last_sent_message(mock_websocket)
        assert call_args["type"] == "queue_status"
        assert call_args["data"] == status_data

    @pytest.mark.asyncio
    async def test_broadcast_system_message(self, service, mock_websocket):
        """Test broadcasting system message."""
        connection_id = "test-conn"
        message_type = "maintenance"
        data = {"message": "System will be down for maintenance"}
        await service.connect(mock_websocket, connection_id)

        await service.broadcast_system_message(message_type, data)

        call_args = _last_sent_message(mock_websocket)
        assert call_args["type"] == f"system_{message_type}"
        assert call_args["data"] == data

    @pytest.mark.asyncio
    async def test_send_error(self, service, mock_websocket):
        """Test sending error message."""
        connection_id = "test-conn"
        error_message = "Invalid request"
        error_code = "INVALID_REQUEST"
        await service.connect(mock_websocket, connection_id)

        await service.send_error(connection_id, error_message, error_code)

        call_args = _last_sent_message(mock_websocket)
        assert call_args["type"] == "error"
        assert call_args["data"]["code"] == error_code
        assert call_args["data"]["message"] == error_message

    @pytest.mark.asyncio
    async def test_broadcast_scan_started(self, service, mock_websocket):
        """Test broadcasting scan started event."""
        connection_id = "test-conn"
        directory = "/home/user/anime"
        await service.connect(mock_websocket, connection_id)

        await service.broadcast_scan_started(directory)

        call_args = _last_sent_message(mock_websocket)
        assert call_args["type"] == "scan_started"
        assert call_args["data"]["directory"] == directory
        assert "timestamp" in call_args

    @pytest.mark.asyncio
    async def test_broadcast_scan_progress(self, service, mock_websocket):
        """Test broadcasting scan progress event."""
        connection_id = "test-conn"
        directories_scanned = 25
        files_found = 150
        current_directory = "/home/user/anime/Attack on Titan"
        await service.connect(mock_websocket, connection_id)

        await service.broadcast_scan_progress(
            directories_scanned, files_found, current_directory
        )

        call_args = _last_sent_message(mock_websocket)
        assert call_args["type"] == "scan_progress"
        assert call_args["data"]["directories_scanned"] == directories_scanned
        assert call_args["data"]["files_found"] == files_found
        assert call_args["data"]["current_directory"] == current_directory
        assert "timestamp" in call_args

    @pytest.mark.asyncio
    async def test_broadcast_scan_completed(self, service, mock_websocket):
        """Test broadcasting scan completed event."""
        connection_id = "test-conn"
        total_directories = 100
        total_files = 500
        elapsed_seconds = 12.5
        await service.connect(mock_websocket, connection_id)

        await service.broadcast_scan_completed(
            total_directories, total_files, elapsed_seconds
        )

        call_args = _last_sent_message(mock_websocket)
        assert call_args["type"] == "scan_completed"
        assert call_args["data"]["total_directories"] == total_directories
        assert call_args["data"]["total_files"] == total_files
        assert call_args["data"]["elapsed_seconds"] == elapsed_seconds
        assert "timestamp" in call_args


class TestGetWebSocketService:
    """Test cases for get_websocket_service factory function."""

    def test_singleton_pattern(self):
        """Test that get_websocket_service returns singleton instance."""
        service1 = get_websocket_service()
        service2 = get_websocket_service()

        assert service1 is service2

    def test_returns_websocket_service(self):
        """Test that factory returns WebSocketService instance."""
        service = get_websocket_service()

        assert isinstance(service, WebSocketService)