fix: progress broadcasts now use correct WebSocket room names

- Fixed room name mismatch: ProgressService was broadcasting to
  'download_progress' but JS clients join 'downloads' room
- Added _get_room_for_progress_type() mapping function
- Updated all progress methods to use correct room names
- Added 13 new tests for room name mapping and broadcast verification
- Updated existing tests to expect correct room names
- Fixed JS clients to join valid rooms (downloads, queue, scan)
This commit is contained in:
Lukas 2025-12-16 19:21:30 +01:00
parent 4c9bf6b982
commit 700f491ef9
7 changed files with 490 additions and 17 deletions

View File

@ -133,6 +133,30 @@ class ProgressServiceError(Exception):
"""Service-level exception for progress operations.""" """Service-level exception for progress operations."""
# Mapping from ProgressType to WebSocket room names
# This ensures compatibility with the valid rooms defined in the WebSocket API:
# "downloads", "queue", "scan", "system", "errors"
_PROGRESS_TYPE_TO_ROOM: Dict[ProgressType, str] = {
ProgressType.DOWNLOAD: "downloads",
ProgressType.SCAN: "scan",
ProgressType.QUEUE: "queue",
ProgressType.SYSTEM: "system",
ProgressType.ERROR: "errors",
}
def _get_room_for_progress_type(progress_type: ProgressType) -> str:
"""Get the WebSocket room name for a progress type.
Args:
progress_type: The type of progress update
Returns:
The WebSocket room name to broadcast to
"""
return _PROGRESS_TYPE_TO_ROOM.get(progress_type, "system")
class ProgressService: class ProgressService:
"""Manages real-time progress updates and broadcasting. """Manages real-time progress updates and broadcasting.
@ -293,7 +317,7 @@ class ProgressService:
) )
# Emit event to subscribers # Emit event to subscribers
room = f"{progress_type.value}_progress" room = _get_room_for_progress_type(progress_type)
event = ProgressEvent( event = ProgressEvent(
event_type=f"{progress_type.value}_progress", event_type=f"{progress_type.value}_progress",
progress_id=progress_id, progress_id=progress_id,
@ -370,7 +394,7 @@ class ProgressService:
should_broadcast = force_broadcast or percent_change >= 1.0 should_broadcast = force_broadcast or percent_change >= 1.0
if should_broadcast: if should_broadcast:
room = f"{update.type.value}_progress" room = _get_room_for_progress_type(update.type)
event = ProgressEvent( event = ProgressEvent(
event_type=f"{update.type.value}_progress", event_type=f"{update.type.value}_progress",
progress_id=progress_id, progress_id=progress_id,
@ -427,7 +451,7 @@ class ProgressService:
) )
# Emit completion event # Emit completion event
room = f"{update.type.value}_progress" room = _get_room_for_progress_type(update.type)
event = ProgressEvent( event = ProgressEvent(
event_type=f"{update.type.value}_progress", event_type=f"{update.type.value}_progress",
progress_id=progress_id, progress_id=progress_id,
@ -483,7 +507,7 @@ class ProgressService:
) )
# Emit failure event # Emit failure event
room = f"{update.type.value}_progress" room = _get_room_for_progress_type(update.type)
event = ProgressEvent( event = ProgressEvent(
event_type=f"{update.type.value}_progress", event_type=f"{update.type.value}_progress",
progress_id=progress_id, progress_id=progress_id,
@ -533,7 +557,7 @@ class ProgressService:
) )
# Emit cancellation event # Emit cancellation event
room = f"{update.type.value}_progress" room = _get_room_for_progress_type(update.type)
event = ProgressEvent( event = ProgressEvent(
event_type=f"{update.type.value}_progress", event_type=f"{update.type.value}_progress",
progress_id=progress_id, progress_id=progress_id,

View File

@ -186,9 +186,10 @@ class AniWorldApp {
console.log('Connected to server'); console.log('Connected to server');
// Subscribe to rooms for targeted updates // Subscribe to rooms for targeted updates
this.socket.join('scan_progress'); // Valid rooms: downloads, queue, scan, system, errors
this.socket.join('download_progress'); this.socket.join('scan');
this.socket.join('downloads'); this.socket.join('downloads');
this.socket.join('queue');
this.showToast(this.localization.getText('connected-server'), 'success'); this.showToast(this.localization.getText('connected-server'), 'success');
this.updateConnectionStatus(); this.updateConnectionStatus();

View File

@ -32,8 +32,9 @@ class QueueManager {
console.log('Connected to server'); console.log('Connected to server');
// Subscribe to rooms for targeted updates // Subscribe to rooms for targeted updates
// Valid rooms: downloads, queue, scan, system, errors
this.socket.join('downloads'); this.socket.join('downloads');
this.socket.join('download_progress'); this.socket.join('queue');
this.showToast('Connected to server', 'success'); this.showToast('Connected to server', 'success');
}); });

View File

@ -180,9 +180,9 @@ class TestDownloadProgressIntegration:
connection_id = "test_client_1" connection_id = "test_client_1"
await websocket_service.connect(mock_ws, connection_id) await websocket_service.connect(mock_ws, connection_id)
# Join the queue_progress room to receive queue updates # Join the queue room to receive queue updates
await websocket_service.manager.join_room( await websocket_service.manager.join_room(
connection_id, "queue_progress" connection_id, "queue"
) )
# Subscribe to progress events and forward to WebSocket # Subscribe to progress events and forward to WebSocket
@ -254,12 +254,12 @@ class TestDownloadProgressIntegration:
await websocket_service.connect(client1, "client1") await websocket_service.connect(client1, "client1")
await websocket_service.connect(client2, "client2") await websocket_service.connect(client2, "client2")
# Join both clients to the queue_progress room # Join both clients to the queue room
await websocket_service.manager.join_room( await websocket_service.manager.join_room(
"client1", "queue_progress" "client1", "queue"
) )
await websocket_service.manager.join_room( await websocket_service.manager.join_room(
"client2", "queue_progress" "client2", "queue"
) )
# Subscribe to progress events and forward to WebSocket # Subscribe to progress events and forward to WebSocket

View File

@ -325,8 +325,9 @@ class TestWebSocketScanIntegration:
assert len(broadcasts) >= 2 # At least start and complete assert len(broadcasts) >= 2 # At least start and complete
# Check for scan progress broadcasts # Check for scan progress broadcasts
# Room name is 'scan' for SCAN type progress
scan_broadcasts = [ scan_broadcasts = [
b for b in broadcasts if b["room"] == "scan_progress" b for b in broadcasts if b["room"] == "scan"
] ]
assert len(scan_broadcasts) >= 2 assert len(scan_broadcasts) >= 2
@ -379,8 +380,9 @@ class TestWebSocketScanIntegration:
await anime_service.rescan() await anime_service.rescan()
# Verify failure broadcast # Verify failure broadcast
# Room name is 'scan' for SCAN type progress
scan_broadcasts = [ scan_broadcasts = [
b for b in broadcasts if b["room"] == "scan_progress" b for b in broadcasts if b["room"] == "scan"
] ]
assert len(scan_broadcasts) >= 2 # Start and fail assert len(scan_broadcasts) >= 2 # Start and fail
@ -438,7 +440,7 @@ class TestWebSocketProgressIntegration:
start_broadcast = broadcasts[0] start_broadcast = broadcasts[0]
assert start_broadcast["data"]["status"] == "started" assert start_broadcast["data"]["status"] == "started"
assert start_broadcast["room"] == "download_progress" assert start_broadcast["room"] == "downloads" # Room name for DOWNLOAD type
update_broadcast = broadcasts[1] update_broadcast = broadcasts[1]
assert update_broadcast["data"]["status"] == "in_progress" assert update_broadcast["data"]["status"] == "in_progress"

View File

@ -352,7 +352,7 @@ class TestProgressService:
# First positional arg is ProgressEvent # First positional arg is ProgressEvent
call_args = mock_broadcast.call_args[0][0] call_args = mock_broadcast.call_args[0][0]
assert call_args.event_type == "download_progress" assert call_args.event_type == "download_progress"
assert call_args.room == "download_progress" assert call_args.room == "downloads" # Room name for DOWNLOAD type
assert call_args.progress_id == "test-1" assert call_args.progress_id == "test-1"
assert call_args.progress.id == "test-1" assert call_args.progress.id == "test-1"

View File

@ -0,0 +1,445 @@
"""Unit tests for queue progress broadcast to correct WebSocket rooms.
This module tests that download progress events are broadcast to the
correct WebSocket rooms ('downloads' for DOWNLOAD type progress).
These tests verify the fix for progress not transmitting to clients.
No real downloads are started - all tests use mocks to verify the
event flow from ProgressService through WebSocket broadcasting.
"""
import asyncio
from typing import Any, Dict, List
from unittest.mock import AsyncMock
import pytest
from src.server.services.progress_service import (
ProgressEvent,
ProgressService,
ProgressStatus,
ProgressType,
_get_room_for_progress_type,
)
from src.server.services.websocket_service import WebSocketService
class TestRoomNameMapping:
"""Tests for progress type to room name mapping."""
def test_download_progress_maps_to_downloads_room(self):
"""Test that DOWNLOAD type maps to 'downloads' room."""
room = _get_room_for_progress_type(ProgressType.DOWNLOAD)
assert room == "downloads"
def test_scan_progress_maps_to_scan_room(self):
"""Test that SCAN type maps to 'scan' room."""
room = _get_room_for_progress_type(ProgressType.SCAN)
assert room == "scan"
def test_queue_progress_maps_to_queue_room(self):
"""Test that QUEUE type maps to 'queue' room."""
room = _get_room_for_progress_type(ProgressType.QUEUE)
assert room == "queue"
def test_system_progress_maps_to_system_room(self):
"""Test that SYSTEM type maps to 'system' room."""
room = _get_room_for_progress_type(ProgressType.SYSTEM)
assert room == "system"
def test_error_progress_maps_to_errors_room(self):
"""Test that ERROR type maps to 'errors' room."""
room = _get_room_for_progress_type(ProgressType.ERROR)
assert room == "errors"
class TestProgressServiceBroadcastRoom:
"""Tests for ProgressService broadcasting to correct rooms."""
@pytest.fixture
def progress_service(self):
"""Create a fresh ProgressService for each test."""
return ProgressService()
@pytest.fixture
def mock_handler(self):
"""Create a mock event handler to capture broadcasts."""
return AsyncMock()
@pytest.mark.asyncio
async def test_start_download_progress_broadcasts_to_downloads_room(
self, progress_service, mock_handler
):
"""Test start_progress with DOWNLOAD type uses 'downloads' room."""
# Subscribe to progress events
progress_service.subscribe("progress_updated", mock_handler)
# Start a download progress
await progress_service.start_progress(
progress_id="test-download-1",
progress_type=ProgressType.DOWNLOAD,
title="Test Download",
message="Downloading episode",
)
# Verify handler was called with correct room
mock_handler.assert_called_once()
event: ProgressEvent = mock_handler.call_args[0][0]
assert event.room == "downloads", (
f"Expected room 'downloads' but got '{event.room}'"
)
assert event.event_type == "download_progress"
assert event.progress.status == ProgressStatus.STARTED
@pytest.mark.asyncio
async def test_update_download_progress_broadcasts_to_downloads_room(
self, progress_service, mock_handler
):
"""Test update_progress with DOWNLOAD type uses 'downloads' room."""
# Start progress first
await progress_service.start_progress(
progress_id="test-download-2",
progress_type=ProgressType.DOWNLOAD,
title="Test Download",
total=100,
)
# Subscribe after start to only capture update event
progress_service.subscribe("progress_updated", mock_handler)
# Update progress with force_broadcast
await progress_service.update_progress(
progress_id="test-download-2",
current=50,
message="50% complete",
force_broadcast=True,
)
# Verify handler was called with correct room
mock_handler.assert_called_once()
event: ProgressEvent = mock_handler.call_args[0][0]
assert event.room == "downloads", (
f"Expected room 'downloads' but got '{event.room}'"
)
assert event.event_type == "download_progress"
assert event.progress.status == ProgressStatus.IN_PROGRESS
assert event.progress.percent == 50.0
@pytest.mark.asyncio
async def test_complete_download_progress_broadcasts_to_downloads_room(
self, progress_service, mock_handler
):
"""Test complete_progress with DOWNLOAD uses 'downloads' room."""
# Start progress first
await progress_service.start_progress(
progress_id="test-download-3",
progress_type=ProgressType.DOWNLOAD,
title="Test Download",
)
# Subscribe after start to only capture complete event
progress_service.subscribe("progress_updated", mock_handler)
# Complete progress
await progress_service.complete_progress(
progress_id="test-download-3",
message="Download completed",
)
# Verify handler was called with correct room
mock_handler.assert_called_once()
event: ProgressEvent = mock_handler.call_args[0][0]
assert event.room == "downloads", (
f"Expected room 'downloads' but got '{event.room}'"
)
assert event.event_type == "download_progress"
assert event.progress.status == ProgressStatus.COMPLETED
@pytest.mark.asyncio
async def test_fail_download_progress_broadcasts_to_downloads_room(
self, progress_service, mock_handler
):
"""Test that fail_progress with DOWNLOAD type uses 'downloads' room."""
# Start progress first
await progress_service.start_progress(
progress_id="test-download-4",
progress_type=ProgressType.DOWNLOAD,
title="Test Download",
)
# Subscribe after start to only capture fail event
progress_service.subscribe("progress_updated", mock_handler)
# Fail progress
await progress_service.fail_progress(
progress_id="test-download-4",
error_message="Connection lost",
)
# Verify handler was called with correct room
mock_handler.assert_called_once()
event: ProgressEvent = mock_handler.call_args[0][0]
assert event.room == "downloads", (
f"Expected room 'downloads' but got '{event.room}'"
)
assert event.event_type == "download_progress"
assert event.progress.status == ProgressStatus.FAILED
@pytest.mark.asyncio
async def test_queue_progress_broadcasts_to_queue_room(
self, progress_service, mock_handler
):
"""Test that QUEUE type progress uses 'queue' room."""
progress_service.subscribe("progress_updated", mock_handler)
await progress_service.start_progress(
progress_id="test-queue-1",
progress_type=ProgressType.QUEUE,
title="Queue Status",
)
mock_handler.assert_called_once()
event: ProgressEvent = mock_handler.call_args[0][0]
assert event.room == "queue", (
f"Expected room 'queue' but got '{event.room}'"
)
assert event.event_type == "queue_progress"
class TestEndToEndProgressBroadcast:
"""End-to-end tests for progress broadcast via WebSocket."""
@pytest.fixture
def websocket_service(self):
"""Create a WebSocketService."""
return WebSocketService()
@pytest.fixture
def progress_service(self):
"""Create a ProgressService."""
return ProgressService()
@pytest.mark.asyncio
async def test_progress_broadcast_reaches_downloads_room_clients(
self, websocket_service, progress_service
):
"""Test that download progress reaches clients in 'downloads' room.
This is the key test verifying the fix: progress updates should
be broadcast to the 'downloads' room, not 'download_progress'.
"""
# Track messages received by mock client
received_messages: List[Dict[str, Any]] = []
# Create mock WebSocket
class MockWebSocket:
async def accept(self):
pass
async def send_json(self, data):
received_messages.append(data)
async def receive_json(self):
await asyncio.sleep(10)
# Connect client to WebSocket service
mock_ws = MockWebSocket()
connection_id = "test_client"
await websocket_service.connect(mock_ws, connection_id)
# Join the 'downloads' room (this is what the JS client does)
await websocket_service.manager.join_room(connection_id, "downloads")
# Set up the progress event handler (mimics fastapi_app.py)
async def progress_event_handler(event: ProgressEvent) -> None:
"""Handle progress events and broadcast via WebSocket."""
message = {
"type": event.event_type,
"data": event.progress.to_dict(),
}
await websocket_service.manager.broadcast_to_room(
message, event.room
)
progress_service.subscribe("progress_updated", progress_event_handler)
# Simulate download progress lifecycle
# 1. Start download
await progress_service.start_progress(
progress_id="real-download-test",
progress_type=ProgressType.DOWNLOAD,
title="Downloading Anime Episode",
total=100,
metadata={"item_id": "item-123"},
)
# 2. Update progress multiple times
for percent in [25, 50, 75]:
await progress_service.update_progress(
progress_id="real-download-test",
current=percent,
message=f"{percent}% complete",
metadata={"speed_mbps": 2.5},
force_broadcast=True,
)
# 3. Complete download
await progress_service.complete_progress(
progress_id="real-download-test",
message="Download completed successfully",
)
# Verify client received all messages
# Filter for download_progress type messages
download_messages = [
m for m in received_messages
if m.get("type") == "download_progress"
]
# Should have: start + 3 updates + complete = 5 messages
assert len(download_messages) >= 4, (
f"Expected at least 4 download_progress messages, "
f"got {len(download_messages)}: {download_messages}"
)
# Verify first message is start
assert download_messages[0]["data"]["status"] == "started"
# Verify last message is completed
assert download_messages[-1]["data"]["status"] == "completed"
assert download_messages[-1]["data"]["percent"] == 100.0
# Cleanup
await websocket_service.disconnect(connection_id)
@pytest.mark.asyncio
async def test_clients_not_in_downloads_room_dont_receive_progress(
self, websocket_service, progress_service
):
"""Test that clients not in 'downloads' room don't receive progress."""
downloads_messages: List[Dict] = []
other_messages: List[Dict] = []
class MockWebSocket:
def __init__(self, message_list):
self.messages = message_list
async def accept(self):
pass
async def send_json(self, data):
self.messages.append(data)
async def receive_json(self):
await asyncio.sleep(10)
# Client in 'downloads' room
ws_downloads = MockWebSocket(downloads_messages)
await websocket_service.connect(ws_downloads, "client_downloads")
await websocket_service.manager.join_room(
"client_downloads", "downloads"
)
# Client in 'system' room (different room)
ws_other = MockWebSocket(other_messages)
await websocket_service.connect(ws_other, "client_other")
await websocket_service.manager.join_room("client_other", "system")
# Set up progress handler
async def progress_event_handler(event: ProgressEvent) -> None:
message = {
"type": event.event_type,
"data": event.progress.to_dict(),
}
await websocket_service.manager.broadcast_to_room(
message, event.room
)
progress_service.subscribe("progress_updated", progress_event_handler)
# Emit download progress
await progress_service.start_progress(
progress_id="isolation-test",
progress_type=ProgressType.DOWNLOAD,
title="Test Download",
)
# Only 'downloads' room client should receive the message
download_progress_in_downloads = [
m for m in downloads_messages
if m.get("type") == "download_progress"
]
download_progress_in_other = [
m for m in other_messages
if m.get("type") == "download_progress"
]
assert len(download_progress_in_downloads) == 1, (
"Client in 'downloads' room should receive download_progress"
)
assert len(download_progress_in_other) == 0, (
"Client in 'system' room should NOT receive download_progress"
)
# Cleanup
await websocket_service.disconnect("client_downloads")
await websocket_service.disconnect("client_other")
@pytest.mark.asyncio
async def test_progress_update_includes_item_id_in_metadata(
self, websocket_service, progress_service
):
"""Test progress updates include item_id for JS client matching."""
received_messages: List[Dict] = []
class MockWebSocket:
async def accept(self):
pass
async def send_json(self, data):
received_messages.append(data)
async def receive_json(self):
await asyncio.sleep(10)
mock_ws = MockWebSocket()
await websocket_service.connect(mock_ws, "test_client")
await websocket_service.manager.join_room("test_client", "downloads")
async def progress_event_handler(event: ProgressEvent) -> None:
message = {
"type": event.event_type,
"data": event.progress.to_dict(),
}
await websocket_service.manager.broadcast_to_room(
message, event.room
)
progress_service.subscribe("progress_updated", progress_event_handler)
# Start progress with item_id in metadata
item_id = "uuid-12345-67890"
await progress_service.start_progress(
progress_id=f"download_{item_id}",
progress_type=ProgressType.DOWNLOAD,
title="Test Download",
metadata={"item_id": item_id},
)
# Verify item_id is present in broadcast
download_messages = [
m for m in received_messages
if m.get("type") == "download_progress"
]
assert len(download_messages) == 1
metadata = download_messages[0]["data"].get("metadata", {})
assert metadata.get("item_id") == item_id, (
f"Expected item_id '{item_id}' in metadata, got: {metadata}"
)
await websocket_service.disconnect("test_client")