"""Unit tests for routing progress broadcasts to the correct WebSocket rooms.

This module verifies that download progress events are broadcast to the
correct WebSocket rooms ('downloads' for DOWNLOAD-type progress). These tests
guard the fix for progress updates not reaching connected clients.

No real downloads are started - all tests use mocks to verify the event flow
from ProgressService through WebSocket broadcasting.
"""

import asyncio
from typing import Any, Awaitable, Callable, Dict, List
from unittest.mock import AsyncMock

import pytest

from src.server.services.progress_service import (
    ProgressEvent,
    ProgressService,
    ProgressStatus,
    ProgressType,
    _get_room_for_progress_type,
)
from src.server.services.websocket_service import WebSocketService


# ---------------------------------------------------------------------------
# Shared test helpers
# ---------------------------------------------------------------------------


class _RecordingWebSocket:
    """Minimal WebSocket stand-in that records every JSON payload sent to it."""

    def __init__(self, messages: List[Dict[str, Any]]) -> None:
        # The caller owns the list so it can inspect it after broadcasting.
        self.messages = messages

    async def accept(self) -> None:
        pass

    async def send_json(self, data: Dict[str, Any]) -> None:
        self.messages.append(data)

    async def receive_json(self) -> None:
        # Simulate a client that never sends anything during a test.
        await asyncio.sleep(10)


def _make_progress_event_handler(
    websocket_service: WebSocketService,
) -> Callable[[ProgressEvent], Awaitable[None]]:
    """Build a handler that forwards progress events to their WebSocket room.

    Mirrors the handler wired up in ``fastapi_app.py``: the message type is
    the event type and the payload is the serialized progress, broadcast to
    ``event.room``.
    """

    async def progress_event_handler(event: ProgressEvent) -> None:
        message = {
            "type": event.event_type,
            "data": event.progress.to_dict(),
        }
        await websocket_service.manager.broadcast_to_room(message, event.room)

    return progress_event_handler


def _download_progress_messages(
    messages: List[Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Return only the ``download_progress`` messages from *messages*."""
    return [m for m in messages if m.get("type") == "download_progress"]


def _single_event(handler: AsyncMock) -> ProgressEvent:
    """Assert *handler* was called exactly once and return its event argument."""
    handler.assert_called_once()
    return handler.call_args[0][0]


def _assert_download_event(
    event: ProgressEvent, status: ProgressStatus
) -> None:
    """Assert *event* is a download_progress event for the 'downloads' room."""
    assert event.room == "downloads", (
        f"Expected room 'downloads' but got '{event.room}'"
    )
    assert event.event_type == "download_progress"
    assert event.progress.status == status


# ---------------------------------------------------------------------------
# Room name mapping
# ---------------------------------------------------------------------------


class TestRoomNameMapping:
    """Tests for progress type to room name mapping."""

    def test_download_progress_maps_to_downloads_room(self):
        """Test that DOWNLOAD type maps to 'downloads' room."""
        room = _get_room_for_progress_type(ProgressType.DOWNLOAD)
        assert room == "downloads"

    def test_scan_progress_maps_to_scan_room(self):
        """Test that SCAN type maps to 'scan' room."""
        room = _get_room_for_progress_type(ProgressType.SCAN)
        assert room == "scan"

    def test_queue_progress_maps_to_queue_room(self):
        """Test that QUEUE type maps to 'queue' room."""
        room = _get_room_for_progress_type(ProgressType.QUEUE)
        assert room == "queue"

    def test_system_progress_maps_to_system_room(self):
        """Test that SYSTEM type maps to 'system' room."""
        room = _get_room_for_progress_type(ProgressType.SYSTEM)
        assert room == "system"

    def test_error_progress_maps_to_errors_room(self):
        """Test that ERROR type maps to 'errors' room."""
        room = _get_room_for_progress_type(ProgressType.ERROR)
        assert room == "errors"


# ---------------------------------------------------------------------------
# ProgressService event rooms
# ---------------------------------------------------------------------------


class TestProgressServiceBroadcastRoom:
    """Tests for ProgressService broadcasting to correct rooms."""

    @pytest.fixture
    def progress_service(self):
        """Create a fresh ProgressService for each test."""
        return ProgressService()

    @pytest.fixture
    def mock_handler(self):
        """Create a mock event handler to capture broadcasts."""
        return AsyncMock()

    @pytest.mark.asyncio
    async def test_start_download_progress_broadcasts_to_downloads_room(
        self, progress_service, mock_handler
    ):
        """Test start_progress with DOWNLOAD type uses 'downloads' room."""
        progress_service.subscribe("progress_updated", mock_handler)

        await progress_service.start_progress(
            progress_id="test-download-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
            message="Downloading episode",
        )

        event = _single_event(mock_handler)
        _assert_download_event(event, ProgressStatus.STARTED)

    @pytest.mark.asyncio
    async def test_update_download_progress_broadcasts_to_downloads_room(
        self, progress_service, mock_handler
    ):
        """Test update_progress with DOWNLOAD type uses 'downloads' room."""
        await progress_service.start_progress(
            progress_id="test-download-2",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
            total=100,
        )
        # Subscribe after start so only the update event is captured.
        progress_service.subscribe("progress_updated", mock_handler)

        # force_broadcast bypasses any update throttling in the service.
        await progress_service.update_progress(
            progress_id="test-download-2",
            current=50,
            message="50% complete",
            force_broadcast=True,
        )

        event = _single_event(mock_handler)
        _assert_download_event(event, ProgressStatus.IN_PROGRESS)
        assert event.progress.percent == 50.0

    @pytest.mark.asyncio
    async def test_complete_download_progress_broadcasts_to_downloads_room(
        self, progress_service, mock_handler
    ):
        """Test complete_progress with DOWNLOAD uses 'downloads' room."""
        await progress_service.start_progress(
            progress_id="test-download-3",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
        )
        # Subscribe after start so only the complete event is captured.
        progress_service.subscribe("progress_updated", mock_handler)

        await progress_service.complete_progress(
            progress_id="test-download-3",
            message="Download completed",
        )

        event = _single_event(mock_handler)
        _assert_download_event(event, ProgressStatus.COMPLETED)

    @pytest.mark.asyncio
    async def test_fail_download_progress_broadcasts_to_downloads_room(
        self, progress_service, mock_handler
    ):
        """Test that fail_progress with DOWNLOAD type uses 'downloads' room."""
        await progress_service.start_progress(
            progress_id="test-download-4",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
        )
        # Subscribe after start so only the fail event is captured.
        progress_service.subscribe("progress_updated", mock_handler)

        await progress_service.fail_progress(
            progress_id="test-download-4",
            error_message="Connection lost",
        )

        event = _single_event(mock_handler)
        _assert_download_event(event, ProgressStatus.FAILED)

    @pytest.mark.asyncio
    async def test_queue_progress_broadcasts_to_queue_room(
        self, progress_service, mock_handler
    ):
        """Test that QUEUE type progress uses 'queue' room."""
        progress_service.subscribe("progress_updated", mock_handler)

        await progress_service.start_progress(
            progress_id="test-queue-1",
            progress_type=ProgressType.QUEUE,
            title="Queue Status",
        )

        event = _single_event(mock_handler)
        assert event.room == "queue", (
            f"Expected room 'queue' but got '{event.room}'"
        )
        assert event.event_type == "queue_progress"


# ---------------------------------------------------------------------------
# End-to-end: ProgressService -> WebSocketService
# ---------------------------------------------------------------------------


class TestEndToEndProgressBroadcast:
    """End-to-end tests for progress broadcast via WebSocket."""

    @pytest.fixture
    def websocket_service(self):
        """Create a WebSocketService."""
        return WebSocketService()

    @pytest.fixture
    def progress_service(self):
        """Create a ProgressService."""
        return ProgressService()

    @pytest.mark.asyncio
    async def test_progress_broadcast_reaches_downloads_room_clients(
        self, websocket_service, progress_service
    ):
        """Test that download progress reaches clients in 'downloads' room.

        This is the key test verifying the fix: progress updates should be
        broadcast to the 'downloads' room, not 'download_progress'.
        """
        received_messages: List[Dict[str, Any]] = []
        connection_id = "test_client"

        mock_ws = _RecordingWebSocket(received_messages)
        await websocket_service.connect(mock_ws, connection_id)
        try:
            # Join the 'downloads' room (this is what the JS client does).
            await websocket_service.manager.join_room(
                connection_id, "downloads"
            )

            handler = _make_progress_event_handler(websocket_service)
            progress_service.subscribe("progress_updated", handler)

            # 1. Start download
            await progress_service.start_progress(
                progress_id="real-download-test",
                progress_type=ProgressType.DOWNLOAD,
                title="Downloading Anime Episode",
                total=100,
                metadata={"item_id": "item-123"},
            )

            # 2. Update progress several times (forced, so none is throttled)
            for percent in [25, 50, 75]:
                await progress_service.update_progress(
                    progress_id="real-download-test",
                    current=percent,
                    message=f"{percent}% complete",
                    metadata={"speed_mbps": 2.5},
                    force_broadcast=True,
                )

            # 3. Complete download
            await progress_service.complete_progress(
                progress_id="real-download-test",
                message="Download completed successfully",
            )

            download_messages = _download_progress_messages(received_messages)

            # start + 3 forced updates + complete = exactly 5 messages.
            assert len(download_messages) == 5, (
                f"Expected 5 download_progress messages, "
                f"got {len(download_messages)}: {download_messages}"
            )
            assert download_messages[0]["data"]["status"] == "started"
            assert [m["data"]["percent"] for m in download_messages[1:4]] == [
                25.0,
                50.0,
                75.0,
            ]
            assert download_messages[-1]["data"]["status"] == "completed"
            assert download_messages[-1]["data"]["percent"] == 100.0
        finally:
            await websocket_service.disconnect(connection_id)

    @pytest.mark.asyncio
    async def test_clients_not_in_downloads_room_dont_receive_progress(
        self, websocket_service, progress_service
    ):
        """Test that clients not in 'downloads' room don't receive progress."""
        downloads_messages: List[Dict[str, Any]] = []
        other_messages: List[Dict[str, Any]] = []

        ws_downloads = _RecordingWebSocket(downloads_messages)
        ws_other = _RecordingWebSocket(other_messages)
        await websocket_service.connect(ws_downloads, "client_downloads")
        await websocket_service.connect(ws_other, "client_other")
        try:
            await websocket_service.manager.join_room(
                "client_downloads", "downloads"
            )
            # A client in a different room must not see download progress.
            await websocket_service.manager.join_room("client_other", "system")

            handler = _make_progress_event_handler(websocket_service)
            progress_service.subscribe("progress_updated", handler)

            await progress_service.start_progress(
                progress_id="isolation-test",
                progress_type=ProgressType.DOWNLOAD,
                title="Test Download",
            )

            assert len(_download_progress_messages(downloads_messages)) == 1, (
                "Client in 'downloads' room should receive download_progress"
            )
            assert len(_download_progress_messages(other_messages)) == 0, (
                "Client in 'system' room should NOT receive download_progress"
            )
        finally:
            await websocket_service.disconnect("client_downloads")
            await websocket_service.disconnect("client_other")

    @pytest.mark.asyncio
    async def test_progress_update_includes_item_id_in_metadata(
        self, websocket_service, progress_service
    ):
        """Test progress updates include item_id for JS client matching."""
        received_messages: List[Dict[str, Any]] = []

        mock_ws = _RecordingWebSocket(received_messages)
        await websocket_service.connect(mock_ws, "test_client")
        try:
            await websocket_service.manager.join_room(
                "test_client", "downloads"
            )

            handler = _make_progress_event_handler(websocket_service)
            progress_service.subscribe("progress_updated", handler)

            item_id = "uuid-12345-67890"
            await progress_service.start_progress(
                progress_id=f"download_{item_id}",
                progress_type=ProgressType.DOWNLOAD,
                title="Test Download",
                metadata={"item_id": item_id},
            )

            download_messages = _download_progress_messages(received_messages)
            assert len(download_messages) == 1

            metadata = download_messages[0]["data"].get("metadata", {})
            assert metadata.get("item_id") == item_id, (
                f"Expected item_id '{item_id}' in metadata, got: {metadata}"
            )
        finally:
            await websocket_service.disconnect("test_client")