"""Integration tests for download progress WebSocket real-time updates.

This module tests the end-to-end flow of download progress from the
download service through the WebSocket service to connected clients.
"""
import asyncio
from typing import Any, Callable, Dict, List, Optional
from unittest.mock import Mock, patch

import pytest

from src.server.models.download import EpisodeIdentifier
from src.server.services.anime_service import AnimeService
from src.server.services.download_service import DownloadService
from src.server.services.progress_service import ProgressService
from src.server.services.websocket_service import WebSocketService

# Upper bound for polling waits. Polling returns as soon as its condition
# holds, so the full timeout is only spent when the test is about to fail.
PROCESSING_TIMEOUT_SECONDS = 5.0

# Fixed wait for tests that must observe *every* event in a window, or
# assert that *no* message arrives -- there is no condition to poll for.
SETTLE_SECONDS = 1.0


class _RecordingWebSocket:
    """Minimal WebSocket stand-in that records every JSON payload sent."""

    def __init__(self, messages: List[Dict[str, Any]]) -> None:
        self.messages = messages

    async def accept(self) -> None:
        pass

    async def send_json(self, data: Dict[str, Any]) -> None:
        """Capture sent messages."""
        self.messages.append(data)

    async def receive_json(self) -> None:
        # Keep connection open
        await asyncio.sleep(10)


def _subscribe_progress_forwarder(
    progress_service: ProgressService,
    websocket_service: WebSocketService,
    captured: Optional[List[Dict[str, Any]]] = None,
) -> None:
    """Subscribe a handler that rebroadcasts progress events via WebSocket.

    On every ``progress_updated`` event the handler sends
    ``{"type": event.event_type, "data": event.progress.to_dict()}`` to
    ``event.room``.

    Args:
        progress_service: Service to subscribe to.
        websocket_service: Service whose manager performs the broadcast.
            ``broadcast_to_room`` is looked up at call time, so a test may
            replace it after subscribing.
        captured: If given, ``event.progress.to_dict()`` is appended to it
            for every event before broadcasting.
    """

    async def progress_event_handler(event) -> None:
        """Handle progress events and broadcast via WebSocket."""
        if captured is not None:
            captured.append(event.progress.to_dict())
        message = {
            "type": event.event_type,
            "data": event.progress.to_dict(),
        }
        await websocket_service.manager.broadcast_to_room(
            message, event.room
        )

    progress_service.subscribe("progress_updated", progress_event_handler)


def _queue_progress_messages(
    messages: List[Dict[str, Any]]
) -> List[Dict[str, Any]]:
    """Return the messages whose ``type`` contains ``'queue_progress'``.

    A missing or ``None`` type counts as no match rather than raising
    ``TypeError``.
    """
    return [m for m in messages if "queue_progress" in (m.get("type") or "")]


async def _wait_for(
    predicate: Callable[[], bool],
    timeout: float = PROCESSING_TIMEOUT_SECONDS,
    interval: float = 0.05,
) -> None:
    """Poll ``predicate`` until it is true or ``timeout`` seconds elapse.

    Never raises on timeout; the caller's assertion reports the failure.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    while not predicate():
        if loop.time() >= deadline:
            return
        await asyncio.sleep(interval)


@pytest.fixture
def mock_series_app():
    """Mock SeriesApp for testing."""
    app = Mock()
    app.series_list = []
    app.search = Mock(return_value=[])
    app.ReScan = Mock()

    async def mock_download(
        serie_folder, season, episode, key, callback=None, **kwargs
    ):
        """Simulate download with realistic progress updates."""
        if callback:
            # Simulate yt-dlp progress updates
            for percent in [10, 25, 50, 75, 90, 100]:
                callback({
                    'percent': float(percent),
                    'downloaded_mb': percent,
                    'total_mb': 100.0,
                    'speed_mbps': 2.5,
                    'eta_seconds': int((100 - percent) / 2.5),
                })

        result = Mock()
        result.success = True
        result.message = "Download completed"
        return result

    app.download = mock_download
    return app


@pytest.fixture
def progress_service():
    """Create a ProgressService instance."""
    return ProgressService()


@pytest.fixture
def websocket_service():
    """Create a WebSocketService instance."""
    return WebSocketService()


@pytest.fixture
async def anime_service(mock_series_app, progress_service):
    """Create an AnimeService."""
    service = AnimeService(
        series_app=mock_series_app,
        progress_service=progress_service,
    )
    yield service


@pytest.fixture
async def download_service(anime_service, progress_service, tmp_path):
    """Create a DownloadService with a per-test persistence file.

    The file lives under pytest's ``tmp_path``. Whatever queue state the
    service writes to ``persistence_path`` is therefore isolated to this
    test, instead of being shared through a fixed ``/tmp`` location.
    The service is stopped on teardown.
    """
    service = DownloadService(
        anime_service=anime_service,
        progress_service=progress_service,
        persistence_path=str(tmp_path / "test_integration_progress_queue.json"),
    )
    yield service
    await service.stop()


class TestDownloadProgressIntegration:
    """Integration tests for download progress WebSocket flow."""

    @pytest.mark.asyncio
    async def test_full_progress_flow_with_websocket(
        self, download_service, websocket_service, progress_service
    ):
        """Test complete flow from download to WebSocket broadcast."""
        # Track all messages sent via WebSocket
        sent_messages: List[Dict[str, Any]] = []

        # Wrap the room broadcast so every call is recorded and still
        # delivered.
        original_broadcast = websocket_service.manager.broadcast_to_room

        async def mock_broadcast(message: dict, room: str):
            """Capture broadcast calls."""
            sent_messages.append({
                'type': message.get('type'),
                'data': message.get('data'),
                'room': room,
            })
            # Call original to maintain functionality
            await original_broadcast(message, room)

        websocket_service.manager.broadcast_to_room = mock_broadcast

        _subscribe_progress_forwarder(progress_service, websocket_service)

        # Note: serie_id uses provider key format (URL-safe, lowercase)
        await download_service.add_to_queue(
            serie_id="integration-test-key",
            serie_folder="Integration Test Anime (2024)",
            serie_name="Integration Test Anime",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )
        await download_service.start_queue_processing()

        await _wait_for(
            lambda: len(_queue_progress_messages(sent_messages)) >= 2
        )

        # Should have queue progress updates
        # (init + items added + processing started + item processing, etc.)
        progress_messages = _queue_progress_messages(sent_messages)
        assert len(progress_messages) >= 2

    @pytest.mark.asyncio
    async def test_websocket_client_receives_progress(
        self, download_service, websocket_service, progress_service
    ):
        """Test that WebSocket clients receive progress messages."""
        client_messages: List[Dict[str, Any]] = []

        connection_id = "test_client_1"
        await websocket_service.connect(
            _RecordingWebSocket(client_messages), connection_id
        )
        try:
            # Join the queue_progress room to receive queue updates
            await websocket_service.manager.join_room(
                connection_id, "queue_progress"
            )

            _subscribe_progress_forwarder(progress_service, websocket_service)

            # Note: serie_id uses provider key format (URL-safe, lowercase)
            await download_service.add_to_queue(
                serie_id="client-test-key",
                serie_folder="Client Test Anime (2024)",
                serie_name="Client Test Anime",
                episodes=[EpisodeIdentifier(season=1, episode=1)],
            )
            await download_service.start_queue_processing()

            await _wait_for(
                lambda: len(_queue_progress_messages(client_messages)) >= 1
            )

            assert len(_queue_progress_messages(client_messages)) >= 1
        finally:
            # Always release the connection, even if an assertion failed.
            await websocket_service.disconnect(connection_id)

    @pytest.mark.asyncio
    async def test_multiple_clients_receive_same_progress(
        self, download_service, websocket_service, progress_service
    ):
        """Test that all connected clients receive progress updates."""
        client1_messages: List[Dict[str, Any]] = []
        client2_messages: List[Dict[str, Any]] = []

        await websocket_service.connect(
            _RecordingWebSocket(client1_messages), "client1"
        )
        await websocket_service.connect(
            _RecordingWebSocket(client2_messages), "client2"
        )
        try:
            # Join both clients to the queue_progress room
            await websocket_service.manager.join_room(
                "client1", "queue_progress"
            )
            await websocket_service.manager.join_room(
                "client2", "queue_progress"
            )

            _subscribe_progress_forwarder(progress_service, websocket_service)

            # Note: serie_id uses provider key format (URL-safe, lowercase)
            await download_service.add_to_queue(
                serie_id="multi-client-test-key",
                serie_folder="Multi Client Test (2024)",
                serie_name="Multi Client Test",
                episodes=[EpisodeIdentifier(season=1, episode=1)],
            )
            await download_service.start_queue_processing()

            await _wait_for(
                lambda: bool(_queue_progress_messages(client1_messages))
                and bool(_queue_progress_messages(client2_messages))
            )

            # Both clients should receive progress (queue progress events)
            assert len(_queue_progress_messages(client1_messages)) >= 1
            assert len(_queue_progress_messages(client2_messages)) >= 1
        finally:
            await websocket_service.disconnect("client1")
            await websocket_service.disconnect("client2")

    @pytest.mark.asyncio
    async def test_progress_data_structure_matches_frontend_expectations(
        self, download_service, websocket_service, progress_service
    ):
        """Test that progress data structure matches frontend requirements."""
        captured_data: List[Dict[str, Any]] = []

        _subscribe_progress_forwarder(
            progress_service, websocket_service, captured=captured_data
        )

        # Note: serie_id uses provider key format (URL-safe, lowercase)
        await download_service.add_to_queue(
            serie_id="structure-test-key",
            serie_folder="Structure Test (2024)",
            serie_name="Structure Test",
            episodes=[EpisodeIdentifier(season=2, episode=3)],
        )
        await download_service.start_queue_processing()

        # Fixed settle (not polling): every event in the window is checked,
        # so returning at the first event would shrink coverage.
        await asyncio.sleep(SETTLE_SECONDS)

        assert len(captured_data) > 0

        # Verify data structure - it's now a ProgressUpdate dict
        required_fields = ('id', 'type', 'status', 'title', 'percent',
                           'metadata')
        for data in captured_data:
            for field in required_fields:
                assert field in data, f"missing {field!r} in {data!r}"

    @pytest.mark.asyncio
    async def test_disconnected_client_doesnt_receive_progress(
        self, download_service, websocket_service, progress_service
    ):
        """Test that disconnected clients don't receive updates."""
        client_messages: List[Dict[str, Any]] = []

        connection_id = "temp_client"
        await websocket_service.connect(
            _RecordingWebSocket(client_messages), connection_id
        )
        # Join the room before disconnecting. Without room membership the
        # client could never receive room broadcasts, so the test would pass
        # even if disconnect left the client subscribed.
        await websocket_service.manager.join_room(
            connection_id, "queue_progress"
        )
        await websocket_service.disconnect(connection_id)

        # Baseline right after disconnect, so that messages produced by
        # add_to_queue below are covered by the assertion.
        initial_message_count = len(client_messages)

        _subscribe_progress_forwarder(progress_service, websocket_service)

        # Start download after disconnect
        # Note: serie_id uses provider key format (URL-safe, lowercase)
        await download_service.add_to_queue(
            serie_id="disconnect-test-key",
            serie_folder="Disconnect Test (2024)",
            serie_name="Disconnect Test",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )
        await download_service.start_queue_processing()

        # Absence cannot be polled for; wait a fixed window.
        await asyncio.sleep(SETTLE_SECONDS)

        # Should not receive progress updates after disconnect
        progress_messages = _queue_progress_messages(
            client_messages[initial_message_count:]
        )
        assert len(progress_messages) == 0