- Fixed room name mismatch: ProgressService was broadcasting to 'download_progress' but JS clients join the 'downloads' room
- Added _get_room_for_progress_type() mapping function
- Updated all progress methods to use correct room names
- Added 13 new tests for room name mapping and broadcast verification
- Updated existing tests to expect correct room names
- Fixed JS clients to join valid rooms (downloads, queue, scan)
446 lines
16 KiB
Python
446 lines
16 KiB
Python
"""Unit tests for routing progress broadcasts to the correct WebSocket rooms.

This module tests that download progress events are broadcast to the
correct WebSocket rooms ('downloads' for DOWNLOAD type progress).
These tests verify the fix for progress not transmitting to clients.

No real downloads are started - all tests use mocks to verify the
event flow from ProgressService through WebSocket broadcasting.
"""

import asyncio
from typing import Any, Awaitable, Callable, Dict, List
from unittest.mock import AsyncMock

import pytest

from src.server.services.progress_service import (
    ProgressEvent,
    ProgressService,
    ProgressStatus,
    ProgressType,
    _get_room_for_progress_type,
)
from src.server.services.websocket_service import WebSocketService


class TestRoomNameMapping:
    """Verify which WebSocket room each ProgressType is routed to."""

    @staticmethod
    def _expect_room(progress_type, room_name):
        """Assert that ``progress_type`` resolves to ``room_name``."""
        resolved = _get_room_for_progress_type(progress_type)
        assert resolved == room_name

    def test_download_progress_maps_to_downloads_room(self):
        """DOWNLOAD progress belongs in the 'downloads' room."""
        self._expect_room(ProgressType.DOWNLOAD, "downloads")

    def test_scan_progress_maps_to_scan_room(self):
        """SCAN progress belongs in the 'scan' room."""
        self._expect_room(ProgressType.SCAN, "scan")

    def test_queue_progress_maps_to_queue_room(self):
        """QUEUE progress belongs in the 'queue' room."""
        self._expect_room(ProgressType.QUEUE, "queue")

    def test_system_progress_maps_to_system_room(self):
        """SYSTEM progress belongs in the 'system' room."""
        self._expect_room(ProgressType.SYSTEM, "system")

    def test_error_progress_maps_to_errors_room(self):
        """ERROR progress belongs in the 'errors' room."""
        self._expect_room(ProgressType.ERROR, "errors")


class TestProgressServiceBroadcastRoom:
    """Check that ProgressService lifecycle calls emit events for the right room."""

    @pytest.fixture
    def progress_service(self):
        """Provide a brand-new ProgressService per test."""
        return ProgressService()

    @pytest.fixture
    def mock_handler(self):
        """Provide an async mock that records emitted progress events."""
        return AsyncMock()

    @staticmethod
    def _only_event(handler) -> ProgressEvent:
        """Assert ``handler`` fired exactly once and return the event it got."""
        handler.assert_called_once()
        return handler.call_args[0][0]

    @staticmethod
    def _assert_room(event: ProgressEvent, expected: str) -> None:
        """Assert ``event`` targets room ``expected``, with a readable message."""
        assert event.room == expected, (
            f"Expected room '{expected}' but got '{event.room}'"
        )

    @pytest.mark.asyncio
    async def test_start_download_progress_broadcasts_to_downloads_room(
        self, progress_service, mock_handler
    ):
        """start_progress for DOWNLOAD is emitted for the 'downloads' room."""
        # Listen before starting so the start event itself is captured.
        progress_service.subscribe("progress_updated", mock_handler)

        await progress_service.start_progress(
            progress_id="test-download-1",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
            message="Downloading episode",
        )

        event = self._only_event(mock_handler)
        self._assert_room(event, "downloads")
        assert event.event_type == "download_progress"
        assert event.progress.status == ProgressStatus.STARTED

    @pytest.mark.asyncio
    async def test_update_download_progress_broadcasts_to_downloads_room(
        self, progress_service, mock_handler
    ):
        """update_progress for DOWNLOAD is emitted for the 'downloads' room."""
        await progress_service.start_progress(
            progress_id="test-download-2",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
            total=100,
        )
        # Subscribe only now so the handler sees the update, not the start.
        progress_service.subscribe("progress_updated", mock_handler)

        await progress_service.update_progress(
            progress_id="test-download-2",
            current=50,
            message="50% complete",
            force_broadcast=True,
        )

        event = self._only_event(mock_handler)
        self._assert_room(event, "downloads")
        assert event.event_type == "download_progress"
        assert event.progress.status == ProgressStatus.IN_PROGRESS
        assert event.progress.percent == 50.0

    @pytest.mark.asyncio
    async def test_complete_download_progress_broadcasts_to_downloads_room(
        self, progress_service, mock_handler
    ):
        """complete_progress for DOWNLOAD is emitted for the 'downloads' room."""
        await progress_service.start_progress(
            progress_id="test-download-3",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
        )
        # Subscribe only now so the handler sees the completion, not the start.
        progress_service.subscribe("progress_updated", mock_handler)

        await progress_service.complete_progress(
            progress_id="test-download-3",
            message="Download completed",
        )

        event = self._only_event(mock_handler)
        self._assert_room(event, "downloads")
        assert event.event_type == "download_progress"
        assert event.progress.status == ProgressStatus.COMPLETED

    @pytest.mark.asyncio
    async def test_fail_download_progress_broadcasts_to_downloads_room(
        self, progress_service, mock_handler
    ):
        """fail_progress for DOWNLOAD is emitted for the 'downloads' room."""
        await progress_service.start_progress(
            progress_id="test-download-4",
            progress_type=ProgressType.DOWNLOAD,
            title="Test Download",
        )
        # Subscribe only now so the handler sees the failure, not the start.
        progress_service.subscribe("progress_updated", mock_handler)

        await progress_service.fail_progress(
            progress_id="test-download-4",
            error_message="Connection lost",
        )

        event = self._only_event(mock_handler)
        self._assert_room(event, "downloads")
        assert event.event_type == "download_progress"
        assert event.progress.status == ProgressStatus.FAILED

    @pytest.mark.asyncio
    async def test_queue_progress_broadcasts_to_queue_room(
        self, progress_service, mock_handler
    ):
        """QUEUE progress is emitted for the 'queue' room."""
        progress_service.subscribe("progress_updated", mock_handler)

        await progress_service.start_progress(
            progress_id="test-queue-1",
            progress_type=ProgressType.QUEUE,
            title="Queue Status",
        )

        event = self._only_event(mock_handler)
        self._assert_room(event, "queue")
        assert event.event_type == "queue_progress"


class _RecordingWebSocket:
    """Minimal WebSocket stand-in that records every JSON payload sent to it.

    Implements only the methods the tests' original inline mocks provided:
    ``accept``, ``send_json`` and ``receive_json``.
    """

    def __init__(self, sink: List[Dict[str, Any]]) -> None:
        # Payloads go into a caller-owned list so the test can inspect them.
        self.messages = sink

    async def accept(self) -> None:
        pass

    async def send_json(self, data: Dict[str, Any]) -> None:
        self.messages.append(data)

    async def receive_json(self) -> None:
        # Never yields data; these tests only observe outbound messages.
        await asyncio.sleep(10)


def _make_room_broadcaster(
    websocket_service: WebSocketService,
) -> Callable[[ProgressEvent], Awaitable[None]]:
    """Build a progress handler that mimics the one wired up in fastapi_app.py.

    The returned coroutine function wraps each event as
    ``{"type": event.event_type, "data": event.progress.to_dict()}`` and
    broadcasts it to ``event.room`` through the connection manager.
    """

    async def progress_event_handler(event: ProgressEvent) -> None:
        """Handle progress events and broadcast via WebSocket."""
        message = {
            "type": event.event_type,
            "data": event.progress.to_dict(),
        }
        await websocket_service.manager.broadcast_to_room(message, event.room)

    return progress_event_handler


def _download_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Return only the ``download_progress`` messages, preserving order."""
    return [m for m in messages if m.get("type") == "download_progress"]


class TestEndToEndProgressBroadcast:
    """End-to-end tests for progress broadcast via WebSocket."""

    @pytest.fixture
    def websocket_service(self):
        """Create a WebSocketService."""
        return WebSocketService()

    @pytest.fixture
    def progress_service(self):
        """Create a ProgressService."""
        return ProgressService()

    @pytest.mark.asyncio
    async def test_progress_broadcast_reaches_downloads_room_clients(
        self, websocket_service, progress_service
    ):
        """Test that download progress reaches clients in 'downloads' room.

        This is the key test verifying the fix: progress updates should
        be broadcast to the 'downloads' room, not 'download_progress'.
        """
        received_messages: List[Dict[str, Any]] = []
        connection_id = "test_client"

        await websocket_service.connect(
            _RecordingWebSocket(received_messages), connection_id
        )
        try:
            # Join the 'downloads' room (this is what the JS client does).
            await websocket_service.manager.join_room(connection_id, "downloads")
            progress_service.subscribe(
                "progress_updated", _make_room_broadcaster(websocket_service)
            )

            # Simulate a full download lifecycle: start, 3 updates, complete.
            await progress_service.start_progress(
                progress_id="real-download-test",
                progress_type=ProgressType.DOWNLOAD,
                title="Downloading Anime Episode",
                total=100,
                metadata={"item_id": "item-123"},
            )
            for percent in (25, 50, 75):
                # force_broadcast=True so every update is emitted.
                await progress_service.update_progress(
                    progress_id="real-download-test",
                    current=percent,
                    message=f"{percent}% complete",
                    metadata={"speed_mbps": 2.5},
                    force_broadcast=True,
                )
            await progress_service.complete_progress(
                progress_id="real-download-test",
                message="Download completed successfully",
            )

            download_messages = _download_messages(received_messages)

            # start + 3 forced updates + complete = exactly 5 messages; a
            # looser bound would let a dropped broadcast pass unnoticed.
            assert len(download_messages) == 5, (
                f"Expected 5 download_progress messages, "
                f"got {len(download_messages)}: {download_messages}"
            )
            assert download_messages[0]["data"]["status"] == "started"
            # Updates must arrive in order with the reported percentages.
            assert [m["data"]["percent"] for m in download_messages[1:-1]] == [
                25.0,
                50.0,
                75.0,
            ]
            assert download_messages[-1]["data"]["status"] == "completed"
            assert download_messages[-1]["data"]["percent"] == 100.0
        finally:
            # Always release the connection, even when an assertion fails.
            await websocket_service.disconnect(connection_id)

    @pytest.mark.asyncio
    async def test_clients_not_in_downloads_room_dont_receive_progress(
        self, websocket_service, progress_service
    ):
        """Test that clients not in 'downloads' room don't receive progress."""
        downloads_messages: List[Dict[str, Any]] = []
        other_messages: List[Dict[str, Any]] = []
        clients = (
            ("client_downloads", "downloads", downloads_messages),
            # Client in 'system' room (different room)
            ("client_other", "system", other_messages),
        )

        # Track successful connects so cleanup only disconnects those.
        connected: List[str] = []
        try:
            for connection_id, room, sink in clients:
                await websocket_service.connect(
                    _RecordingWebSocket(sink), connection_id
                )
                connected.append(connection_id)
                await websocket_service.manager.join_room(connection_id, room)

            progress_service.subscribe(
                "progress_updated", _make_room_broadcaster(websocket_service)
            )

            # Emit download progress
            await progress_service.start_progress(
                progress_id="isolation-test",
                progress_type=ProgressType.DOWNLOAD,
                title="Test Download",
            )

            # Only the 'downloads' room client should receive the message
            assert len(_download_messages(downloads_messages)) == 1, (
                "Client in 'downloads' room should receive download_progress"
            )
            assert len(_download_messages(other_messages)) == 0, (
                "Client in 'system' room should NOT receive download_progress"
            )
        finally:
            for connection_id in connected:
                await websocket_service.disconnect(connection_id)

    @pytest.mark.asyncio
    async def test_progress_update_includes_item_id_in_metadata(
        self, websocket_service, progress_service
    ):
        """Test progress updates include item_id for JS client matching."""
        received_messages: List[Dict[str, Any]] = []

        await websocket_service.connect(
            _RecordingWebSocket(received_messages), "test_client"
        )
        try:
            await websocket_service.manager.join_room("test_client", "downloads")
            progress_service.subscribe(
                "progress_updated", _make_room_broadcaster(websocket_service)
            )

            # Start progress with item_id in metadata
            item_id = "uuid-12345-67890"
            await progress_service.start_progress(
                progress_id=f"download_{item_id}",
                progress_type=ProgressType.DOWNLOAD,
                title="Test Download",
                metadata={"item_id": item_id},
            )

            # Verify item_id is present in broadcast
            download_messages = _download_messages(received_messages)
            assert len(download_messages) == 1
            metadata = download_messages[0]["data"].get("metadata", {})
            assert metadata.get("item_id") == item_id, (
                f"Expected item_id '{item_id}' in metadata, got: {metadata}"
            )
        finally:
            await websocket_service.disconnect("test_client")