"""Integration tests for WebSocket integration with core services.

This module tests the integration between WebSocket broadcasting and core
services (DownloadService, AnimeService, ProgressService) to ensure real-time
updates are properly broadcasted to connected clients.

The tests capture the ``progress_updated`` events emitted by ProgressService
rather than going through a live WebSocket connection.
"""
import asyncio
from datetime import datetime, timezone
from typing import Any, Dict, List
from unittest.mock import AsyncMock, Mock, patch

import pytest

from src.server.models.download import (
    DownloadItem,
    DownloadPriority,
    DownloadStatus,
    EpisodeIdentifier,
)
from src.server.services.anime_service import AnimeService
from src.server.services.download_service import DownloadService
from src.server.services.progress_service import ProgressService, ProgressType
from src.server.services.websocket_service import WebSocketService


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


@pytest.fixture
def mock_series_app():
    """Mock SeriesApp whose async search/rescan/download succeed trivially."""
    app = Mock()
    app.series_list = []

    async def mock_search():
        return []

    async def mock_rescan():
        pass

    async def mock_download(*args, **kwargs):
        return True

    app.search = mock_search
    app.rescan = mock_rescan
    app.download = mock_download
    return app


@pytest.fixture
def progress_service():
    """Create a ProgressService instance for testing.

    Each test gets its own instance to avoid state pollution.
    """
    return ProgressService()


@pytest.fixture
def websocket_service():
    """Create a WebSocketService instance for testing."""
    return WebSocketService()


@pytest.fixture
async def anime_service(mock_series_app, progress_service):
    """Create an AnimeService wired to the mock SeriesApp and ProgressService."""
    service = AnimeService(
        series_app=mock_series_app,
        progress_service=progress_service,
    )
    # Stub out the database operations that rescan would otherwise hit.
    service._save_scan_results_to_db = AsyncMock(return_value=0)
    service._load_series_from_db = AsyncMock(return_value=None)
    yield service


@pytest.fixture
async def download_service(anime_service, progress_service):
    """Create a DownloadService backed by a mock queue repository.

    Yields a ``(download_service, progress_service)`` tuple. The mock
    repository gives each test isolated queue storage; the service is
    stopped during fixture teardown.
    """
    # Imported lazily so this module does not depend on the unit-test
    # package unless the fixture is actually used.
    from tests.unit.test_download_service import MockQueueRepository

    service = DownloadService(
        anime_service=anime_service,
        progress_service=progress_service,
        queue_repository=MockQueueRepository(),
    )
    yield service, progress_service
    await service.stop()


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _capture_progress_events(
    progress_svc: ProgressService,
) -> List[Dict[str, Any]]:
    """Subscribe to ``progress_updated`` and return the list collecting events.

    Each captured entry has the keys ``type`` (the event type), ``data``
    (``event.progress.to_dict()``) and ``room`` (``event.room``). The list
    is filled in as events are emitted.
    """
    captured: List[Dict[str, Any]] = []

    async def _handler(event) -> None:
        captured.append(
            {
                "type": event.event_type,
                "data": event.progress.to_dict(),
                "room": event.room,
            }
        )

    progress_svc.subscribe("progress_updated", _handler)
    return captured


def _events_with_action(
    broadcasts: List[Dict[str, Any]], action: str
) -> List[Dict[str, Any]]:
    """Return captured events whose ``metadata["action"]`` equals *action*.

    Events without metadata are treated as having no action. Order of
    emission is preserved.
    """
    return [
        b
        for b in broadcasts
        if (b["data"].get("metadata") or {}).get("action") == action
    ]


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestWebSocketDownloadIntegration:
    """Test WebSocket integration with DownloadService."""

    @pytest.mark.asyncio
    async def test_download_progress_broadcast(self, download_service):
        """Adding an item to the queue emits a ``queue_progress`` event."""
        download_svc, progress_svc = download_service
        broadcasts = _capture_progress_events(progress_svc)

        # Note: serie_id uses provider key format
        # (URL-safe, lowercase, hyphenated)
        item_ids = await download_svc.add_to_queue(
            serie_id="test-serie-key",
            serie_folder="Test Anime (2024)",
            serie_name="Test Anime",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
            priority=DownloadPriority.HIGH,
        )

        assert len(item_ids) == 1
        assert len(broadcasts) >= 1

        items_added_events = _events_with_action(broadcasts, "items_added")
        assert len(items_added_events) >= 1
        assert items_added_events[0]["type"] == "queue_progress"

    @pytest.mark.asyncio
    async def test_queue_operations_broadcast(self, download_service):
        """Adding and removing queue items emit ``queue_progress`` events."""
        download_svc, progress_svc = download_service
        broadcasts = _capture_progress_events(progress_svc)

        # Note: serie_id uses provider key format
        # (URL-safe, lowercase, hyphenated)
        item_ids = await download_svc.add_to_queue(
            serie_id="test-queue-ops-key",
            serie_folder="Test Queue Ops (2024)",
            serie_name="Test",
            episodes=[
                EpisodeIdentifier(season=1, episode=i) for i in range(1, 4)
            ],
            priority=DownloadPriority.NORMAL,
        )

        removed = await download_svc.remove_from_queue([item_ids[0]])
        assert len(removed) == 1

        # Use the most recent event for each action.
        added_events = _events_with_action(broadcasts, "items_added")
        assert added_events, "expected an items_added event"
        add_broadcast = added_events[-1]
        assert add_broadcast["type"] == "queue_progress"
        assert len(add_broadcast["data"]["metadata"]["added_ids"]) == 3

        removed_events = _events_with_action(broadcasts, "items_removed")
        assert removed_events, "expected an items_removed event"
        remove_broadcast = removed_events[-1]
        assert remove_broadcast["type"] == "queue_progress"
        removed_ids = remove_broadcast["data"]["metadata"]["removed_ids"]
        assert item_ids[0] in removed_ids

    @pytest.mark.asyncio
    async def test_queue_start_stop_broadcast(self, download_service):
        """Queue operations with items emit ``queue_progress`` events.

        This test does not call ``start()``/``stop()``; it checks the queue
        progress events emitted while an item is added.
        """
        download_svc, progress_svc = download_service
        broadcasts = _capture_progress_events(progress_svc)

        # Add an item to initialize the queue progress.
        # Note: serie_id uses provider key format
        # (URL-safe, lowercase, hyphenated)
        await download_svc.add_to_queue(
            serie_id="test-start-stop-key",
            serie_folder="Test Start Stop (2024)",
            serie_name="Test",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )

        queue_broadcasts = [
            b for b in broadcasts if b["type"] == "queue_progress"
        ]
        # Expect at least two queue progress updates (init + items_added).
        assert len(queue_broadcasts) >= 2

    @pytest.mark.asyncio
    async def test_clear_completed_broadcast(self, download_service):
        """Clearing completed items emits a ``completed_cleared`` event."""
        download_svc, progress_svc = download_service
        broadcasts = _capture_progress_events(progress_svc)

        # Initialize the download queue progress by adding an item.
        # Note: serie_id uses provider key format (URL-safe, lowercase)
        await download_svc.add_to_queue(
            serie_id="test-init-key",
            serie_folder="Test Init (2024)",
            serie_name="Test Init",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )

        # Inject a completed item directly so there is something to clear.
        completed_item = DownloadItem(
            id="test_completed",
            serie_id="test-completed-key",
            serie_name="Test",
            serie_folder="Test (2024)",
            episode=EpisodeIdentifier(season=1, episode=1),
            status=DownloadStatus.COMPLETED,
            priority=DownloadPriority.NORMAL,
            added_at=datetime.now(timezone.utc),
        )
        download_svc._completed_items.append(completed_item)

        count = await download_svc.clear_completed()
        assert count == 1

        # Use the first clear event.
        clear_events = _events_with_action(broadcasts, "completed_cleared")
        assert clear_events, "expected a completed_cleared event"
        metadata = clear_events[0]["data"]["metadata"]
        assert metadata["cleared_count"] == 1


class TestWebSocketScanIntegration:
    """Test WebSocket integration with AnimeService scan operations."""

    @pytest.mark.asyncio
    async def test_scan_progress_broadcast(
        self, anime_service, progress_service, mock_series_app
    ):
        """Scan progress start/update/complete emit ``scan_progress`` events."""
        broadcasts = _capture_progress_events(progress_service)

        async def mock_rescan():
            """Simulate a scan by driving progress_service directly."""
            await progress_service.start_progress(
                progress_id="scan_test",
                progress_type=ProgressType.SCAN,
                title="Scanning library",
                total=10,
            )
            await progress_service.update_progress(
                progress_id="scan_test",
                current=5,
                message="Scanning...",
            )
            await progress_service.complete_progress(
                progress_id="scan_test",
                message="Complete",
            )

        mock_series_app.rescan = mock_rescan

        await anime_service.rescan()

        # At least start and complete.
        assert len(broadcasts) >= 2

        scan_broadcasts = [
            b for b in broadcasts if b["room"] == "scan_progress"
        ]
        assert len(scan_broadcasts) >= 2

        start_broadcast = scan_broadcasts[0]
        assert start_broadcast["data"]["status"] == "started"
        assert start_broadcast["data"]["type"] == ProgressType.SCAN.value

        complete_broadcast = scan_broadcasts[-1]
        assert complete_broadcast["data"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_scan_failure_broadcast(
        self, anime_service, progress_service, mock_series_app
    ):
        """A failing scan emits a final ``failed`` scan progress event."""
        broadcasts = _capture_progress_events(progress_service)

        async def mock_scan_error():
            """Emit a start event, mark it failed, then raise."""
            await progress_service.start_progress(
                progress_id="library_scan",
                progress_type=ProgressType.SCAN,
                title="Scanning anime library",
                message="Initializing scan...",
            )
            await progress_service.fail_progress(
                progress_id="library_scan",
                error_message="Scan failed",
            )
            raise RuntimeError("Scan failed")

        mock_series_app.rescan = mock_scan_error

        # Kept broad: AnimeService may re-raise the error wrapped in its own
        # exception type.
        with pytest.raises(Exception):
            await anime_service.rescan()

        scan_broadcasts = [
            b for b in broadcasts if b["room"] == "scan_progress"
        ]
        # Start and fail.
        assert len(scan_broadcasts) >= 2

        fail_broadcast = scan_broadcasts[-1]
        assert fail_broadcast["data"]["status"] == "failed"


class TestWebSocketProgressIntegration:
    """Test WebSocket integration with ProgressService."""

    @pytest.mark.asyncio
    async def test_progress_lifecycle_broadcast(self, progress_service):
        """Start, forced update and completion each emit exactly one event."""
        broadcasts = _capture_progress_events(progress_service)

        await progress_service.start_progress(
            progress_id="test_progress",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
            total=100,
        )
        await progress_service.update_progress(
            progress_id="test_progress",
            current=50,
            force_broadcast=True,
        )
        await progress_service.complete_progress(
            progress_id="test_progress",
            message="Download complete",
        )

        assert len(broadcasts) == 3

        start_broadcast, update_broadcast, complete_broadcast = broadcasts

        assert start_broadcast["data"]["status"] == "started"
        assert start_broadcast["room"] == "download_progress"

        assert update_broadcast["data"]["status"] == "in_progress"
        assert update_broadcast["data"]["percent"] == 50.0

        assert complete_broadcast["data"]["status"] == "completed"
        assert complete_broadcast["data"]["percent"] == 100.0


class TestWebSocketEndToEnd:
    """End-to-end integration tests with all services."""

    @pytest.mark.asyncio
    async def test_complete_download_flow_with_broadcasts(
        self, download_service, progress_service
    ):
        """Queueing, starting and stopping downloads emit queue events."""
        download_svc, _ = download_service
        all_broadcasts = _capture_progress_events(progress_service)

        # Note: serie_id uses provider key format (URL-safe, lowercase)
        item_ids = await download_svc.add_to_queue(
            serie_id="test-e2e-key",
            serie_folder="Test Anime (2024)",
            serie_name="Test Anime",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
            priority=DownloadPriority.HIGH,
        )

        await download_svc.start()
        # Give the queue a moment to process before stopping it.
        await asyncio.sleep(0.1)
        await download_svc.stop()

        assert len(all_broadcasts) >= 1
        assert len(item_ids) == 1

        queue_events = [
            b for b in all_broadcasts if b["type"] == "queue_progress"
        ]
        assert len(queue_events) >= 1


if __name__ == "__main__":
    pytest.main([__file__, "-v"])