# File listing header residue: 494 lines, 16 KiB, Python.
"""Integration tests for WebSocket broadcasting from core services.

This module tests the integration between WebSocket broadcasting and the
core services (DownloadService, AnimeService, ProgressService). It checks
that the progress events driving real-time updates are emitted, so that
they can be broadcast to connected clients.
"""
import asyncio
from datetime import datetime, timezone
from typing import Any, Dict, List
from unittest.mock import Mock, patch

import pytest

from src.server.models.download import (
    DownloadItem,
    DownloadPriority,
    DownloadStatus,
    EpisodeIdentifier,
)
from src.server.services.anime_service import AnimeService
from src.server.services.download_service import DownloadService
from src.server.services.progress_service import ProgressService, ProgressType
from src.server.services.websocket_service import WebSocketService
@pytest.fixture
def mock_series_app():
    """Provide a stand-in SeriesApp whose async operations are no-op stubs.

    ``series_list`` starts empty. ``search`` resolves to ``[]``, ``rescan``
    resolves to ``None``, and ``download`` accepts any arguments and
    resolves to ``True``.
    """

    async def _search():
        return []

    async def _rescan():
        return None

    async def _download(*_args, **_kwargs):
        return True

    # Mock(**kwargs) sets each keyword as an attribute on the new mock.
    return Mock(
        series_list=[],
        search=_search,
        rescan=_rescan,
        download=_download,
    )
@pytest.fixture
def progress_service():
    """Build a fresh ProgressService for every test.

    Using a new instance per test keeps subscriptions (and any other
    service state) from leaking between tests.
    """
    service = ProgressService()
    return service
@pytest.fixture
def websocket_service():
    """Return a newly constructed WebSocketService for the test."""
    service = WebSocketService()
    return service
@pytest.fixture
async def anime_service(mock_series_app, progress_service):
    """Yield an AnimeService built on the mocked SeriesApp.

    The service shares the per-test ProgressService, so tests can observe
    the progress events it emits.
    """
    yield AnimeService(
        progress_service=progress_service,
        series_app=mock_series_app,
    )
@pytest.fixture
async def download_service(anime_service, progress_service, tmp_path):
    """Yield a ``(DownloadService, ProgressService)`` pair for one test.

    The queue is persisted inside pytest's ``tmp_path``. That directory is
    already unique per test, so a fixed file name is enough to isolate
    queue storage between tests. The service is stopped on teardown.
    """
    persistence_path = tmp_path / "test_queue.json"
    service = DownloadService(
        anime_service=anime_service,
        progress_service=progress_service,
        persistence_path=str(persistence_path),
    )
    yield service, progress_service
    # Teardown: pytest resumes the generator here after the test body,
    # whether the test passed or failed.
    await service.stop()
class TestWebSocketDownloadIntegration:
    """Test WebSocket integration with DownloadService.

    Each test subscribes a recorder to the ProgressService's
    ``"progress_updated"`` events and inspects the captured payloads.
    """

    @staticmethod
    def _recorder(broadcasts: List[Dict[str, Any]]):
        """Return an async ``progress_updated`` handler that fills *broadcasts*.

        Each captured entry is ``{"type": event.event_type,
        "data": event.progress.to_dict()}``.
        """

        async def record(event):
            broadcasts.append({
                "type": event.event_type,
                "data": event.progress.to_dict(),
            })

        return record

    @staticmethod
    def _action(broadcast: Dict[str, Any]) -> Any:
        """Return ``metadata["action"]`` of a captured entry, or ``None``.

        An entry whose payload has no ``metadata`` (or ``metadata`` is
        ``None``) is treated as having no action instead of raising.
        """
        metadata = broadcast["data"].get("metadata") or {}
        return metadata.get("action")

    @pytest.mark.asyncio
    async def test_download_progress_broadcast(
        self, download_service, websocket_service
    ):
        """Test that adding to the queue emits an ``items_added`` event.

        ``websocket_service`` is requested, so it gets constructed, but the
        test body does not use it.
        """
        download_svc, progress_svc = download_service
        broadcasts: List[Dict[str, Any]] = []
        progress_svc.subscribe("progress_updated", self._recorder(broadcasts))

        item_ids = await download_svc.add_to_queue(
            serie_id="test_serie",
            serie_folder="test_serie",
            serie_name="Test Anime",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
            priority=DownloadPriority.HIGH,
        )

        assert len(item_ids) == 1
        # Should have at least one event (queue init + items_added).
        assert len(broadcasts) >= 1

        items_added_events = [
            b for b in broadcasts if self._action(b) == "items_added"
        ]
        assert len(items_added_events) >= 1
        assert items_added_events[0]["type"] == "queue_progress"

    @pytest.mark.asyncio
    async def test_queue_operations_broadcast(self, download_service):
        """Test that adding and removing queue items emit progress events."""
        download_svc, progress_svc = download_service
        broadcasts: List[Dict[str, Any]] = []
        progress_svc.subscribe("progress_updated", self._recorder(broadcasts))

        item_ids = await download_svc.add_to_queue(
            serie_id="test",
            serie_folder="test",
            serie_name="Test",
            episodes=[
                EpisodeIdentifier(season=1, episode=i) for i in range(1, 4)
            ],
            priority=DownloadPriority.NORMAL,
        )

        removed = await download_svc.remove_from_queue([item_ids[0]])
        assert len(removed) == 1

        # Use the most recent event for each action.
        add_broadcast = next(
            (b for b in reversed(broadcasts)
             if self._action(b) == "items_added"),
            None,
        )
        remove_broadcast = next(
            (b for b in reversed(broadcasts)
             if self._action(b) == "items_removed"),
            None,
        )

        assert add_broadcast is not None
        assert add_broadcast["type"] == "queue_progress"
        assert len(add_broadcast["data"]["metadata"]["added_ids"]) == 3

        assert remove_broadcast is not None
        assert remove_broadcast["type"] == "queue_progress"
        removed_ids = remove_broadcast["data"]["metadata"]["removed_ids"]
        assert item_ids[0] in removed_ids

    @pytest.mark.asyncio
    async def test_queue_start_stop_broadcast(self, download_service):
        """Test that adding an item emits queue progress events.

        Despite the test name, the queue is not started or stopped here.
        The test only checks that at least two ``queue_progress`` events are
        emitted (queue initialisation + ``items_added``).
        """
        download_svc, progress_svc = download_service
        broadcasts: List[Dict[str, Any]] = []
        progress_svc.subscribe("progress_updated", self._recorder(broadcasts))

        # Adding an item initializes the queue progress.
        await download_svc.add_to_queue(
            serie_id="test",
            serie_folder="test",
            serie_name="Test",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )

        queue_broadcasts = [
            b for b in broadcasts if b["type"] == "queue_progress"
        ]
        # Expect at least: queue init + items_added.
        assert len(queue_broadcasts) >= 2

    @pytest.mark.asyncio
    async def test_clear_completed_broadcast(self, download_service):
        """Test that clearing completed items emits a progress event."""
        download_svc, progress_svc = download_service
        broadcasts: List[Dict[str, Any]] = []
        progress_svc.subscribe("progress_updated", self._recorder(broadcasts))

        # Initialize the download queue progress by adding an item.
        await download_svc.add_to_queue(
            serie_id="test",
            serie_folder="test",
            serie_name="Test Init",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
        )

        # Seed a completed item through the private list, so the test does
        # not depend on a real download finishing.
        completed_item = DownloadItem(
            id="test_completed",
            serie_id="test",
            serie_name="Test",
            serie_folder="Test",
            episode=EpisodeIdentifier(season=1, episode=1),
            status=DownloadStatus.COMPLETED,
            priority=DownloadPriority.NORMAL,
            added_at=datetime.now(timezone.utc),
        )
        download_svc._completed_items.append(completed_item)

        count = await download_svc.clear_completed()
        assert count == 1

        # First queue progress event reporting the clear.
        clear_broadcast = next(
            (b for b in broadcasts if self._action(b) == "completed_cleared"),
            None,
        )
        assert clear_broadcast is not None
        assert clear_broadcast["data"]["metadata"]["cleared_count"] == 1
class TestWebSocketScanIntegration:
    """Test WebSocket integration with AnimeService scan operations."""

    @staticmethod
    def _recorder(broadcasts: List[Dict[str, Any]]):
        """Return an async ``progress_updated`` handler that fills *broadcasts*.

        Each captured entry holds the event type, the serialized progress
        payload and the room the event was addressed to.
        """

        async def record(event):
            broadcasts.append({
                "type": event.event_type,
                "data": event.progress.to_dict(),
                "room": event.room,
            })

        return record

    @pytest.mark.asyncio
    async def test_scan_progress_broadcast(
        self, anime_service, progress_service, mock_series_app
    ):
        """Test that scan progress updates emit events."""
        broadcasts: List[Dict[str, Any]] = []
        progress_service.subscribe(
            "progress_updated", self._recorder(broadcasts)
        )

        async def mock_rescan():
            """Simulate a scan: start, one update, then complete."""
            await progress_service.start_progress(
                progress_id="scan_test",
                progress_type=ProgressType.SCAN,
                title="Scanning library",
                total=10,
            )
            await progress_service.update_progress(
                progress_id="scan_test",
                current=5,
                message="Scanning...",
            )
            await progress_service.complete_progress(
                progress_id="scan_test",
                message="Complete",
            )

        mock_series_app.rescan = mock_rescan

        await anime_service.rescan()

        assert len(broadcasts) >= 2  # At least start and complete

        scan_broadcasts = [
            b for b in broadcasts if b["room"] == "scan_progress"
        ]
        assert len(scan_broadcasts) >= 2

        start_broadcast = scan_broadcasts[0]
        assert start_broadcast["data"]["status"] == "started"
        assert start_broadcast["data"]["type"] == ProgressType.SCAN.value

        complete_broadcast = scan_broadcasts[-1]
        assert complete_broadcast["data"]["status"] == "completed"

    @pytest.mark.asyncio
    async def test_scan_failure_broadcast(
        self, anime_service, progress_service, mock_series_app
    ):
        """Test that scan failures are broadcast."""
        broadcasts: List[Dict[str, Any]] = []
        progress_service.subscribe(
            "progress_updated", self._recorder(broadcasts)
        )

        async def mock_scan_error():
            """Emit a scan start event, report failure, then raise."""
            await progress_service.start_progress(
                progress_id="library_scan",
                progress_type=ProgressType.SCAN,
                title="Scanning anime library",
                message="Initializing scan...",
            )
            await progress_service.fail_progress(
                progress_id="library_scan",
                error_message="Scan failed",
            )
            raise RuntimeError("Scan failed")

        mock_series_app.rescan = mock_scan_error

        # Any exception type is accepted; AnimeService may wrap the
        # RuntimeError (TODO confirm the wrapped type and narrow this).
        with pytest.raises(Exception):
            await anime_service.rescan()

        scan_broadcasts = [
            b for b in broadcasts if b["room"] == "scan_progress"
        ]
        assert len(scan_broadcasts) >= 2  # Start and fail

        fail_broadcast = scan_broadcasts[-1]
        assert fail_broadcast["data"]["status"] == "failed"
        # The failure must be reported on the scan progress itself.
        assert fail_broadcast["data"]["type"] == ProgressType.SCAN.value
class TestWebSocketProgressIntegration:
    """Test WebSocket integration with ProgressService."""

    @pytest.mark.asyncio
    async def test_progress_lifecycle_broadcast(self, progress_service):
        """Start, update and complete must yield exactly three broadcasts.

        The broadcasts are expected in lifecycle order, on the download room,
        with matching status and percent values.
        """
        captured: List[Dict[str, Any]] = []

        async def on_progress(event):
            captured.append({
                "type": event.event_type,
                "data": event.progress.to_dict(),
                "room": event.room,
            })

        progress_service.subscribe("progress_updated", on_progress)

        pid = "test_progress"
        await progress_service.start_progress(
            progress_id=pid,
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
            total=100,
        )
        await progress_service.update_progress(
            progress_id=pid,
            current=50,
            force_broadcast=True,
        )
        await progress_service.complete_progress(
            progress_id=pid,
            message="Download complete",
        )

        assert len(captured) == 3
        started, updated, completed = captured

        assert started["data"]["status"] == "started"
        assert started["room"] == "download_progress"

        assert updated["data"]["status"] == "in_progress"
        assert updated["data"]["percent"] == 50.0

        assert completed["data"]["status"] == "completed"
        assert completed["data"]["percent"] == 100.0
class TestWebSocketEndToEnd:
    """End-to-end integration tests with all services."""

    @pytest.mark.asyncio
    async def test_complete_download_flow_with_broadcasts(
        self, download_service, anime_service, progress_service
    ):
        """Queue an item, run the queue briefly, and check for queue events.

        The ``download_service`` fixture also calls ``stop()`` on teardown,
        so ``stop()`` is invoked twice for this test.
        """
        download_svc, _progress = download_service
        events: List[Dict[str, Any]] = []

        async def on_progress(event):
            events.append({
                "source": "progress",
                "type": event.event_type,
                "data": event.progress.to_dict(),
                "room": event.room,
            })

        progress_service.subscribe("progress_updated", on_progress)

        item_ids = await download_svc.add_to_queue(
            serie_id="test",
            serie_folder="test",
            serie_name="Test Anime",
            episodes=[EpisodeIdentifier(season=1, episode=1)],
            priority=DownloadPriority.HIGH,
        )

        # Let the queue run for a moment, then stop it.
        await download_svc.start()
        await asyncio.sleep(0.1)
        await download_svc.stop()

        assert len(events) >= 1
        assert len(item_ids) == 1

        queue_events = [e for e in events if e["type"] == "queue_progress"]
        assert len(queue_events) >= 1
if __name__ == "__main__":
    # pytest.main returns the session exit code. Propagate it so a direct
    # run (`python this_file.py`) exits non-zero when tests fail.
    raise SystemExit(pytest.main([__file__, "-v"]))