fix: progress broadcasts now use correct WebSocket room names
- Fixed room name mismatch: ProgressService was broadcasting to 'download_progress' but JS clients join 'downloads' room - Added _get_room_for_progress_type() mapping function - Updated all progress methods to use correct room names - Added 13 new tests for room name mapping and broadcast verification - Updated existing tests to expect correct room names - Fixed JS clients to join valid rooms (downloads, queue, scan)
This commit is contained in:
parent
4c9bf6b982
commit
700f491ef9
@ -133,6 +133,30 @@ class ProgressServiceError(Exception):
|
|||||||
"""Service-level exception for progress operations."""
|
"""Service-level exception for progress operations."""
|
||||||
|
|
||||||
|
|
||||||
|
# Mapping from ProgressType to WebSocket room names
|
||||||
|
# This ensures compatibility with the valid rooms defined in the WebSocket API:
|
||||||
|
# "downloads", "queue", "scan", "system", "errors"
|
||||||
|
_PROGRESS_TYPE_TO_ROOM: Dict[ProgressType, str] = {
|
||||||
|
ProgressType.DOWNLOAD: "downloads",
|
||||||
|
ProgressType.SCAN: "scan",
|
||||||
|
ProgressType.QUEUE: "queue",
|
||||||
|
ProgressType.SYSTEM: "system",
|
||||||
|
ProgressType.ERROR: "errors",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_room_for_progress_type(progress_type: ProgressType) -> str:
|
||||||
|
"""Get the WebSocket room name for a progress type.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
progress_type: The type of progress update
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The WebSocket room name to broadcast to
|
||||||
|
"""
|
||||||
|
return _PROGRESS_TYPE_TO_ROOM.get(progress_type, "system")
|
||||||
|
|
||||||
|
|
||||||
class ProgressService:
|
class ProgressService:
|
||||||
"""Manages real-time progress updates and broadcasting.
|
"""Manages real-time progress updates and broadcasting.
|
||||||
|
|
||||||
@ -293,7 +317,7 @@ class ProgressService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Emit event to subscribers
|
# Emit event to subscribers
|
||||||
room = f"{progress_type.value}_progress"
|
room = _get_room_for_progress_type(progress_type)
|
||||||
event = ProgressEvent(
|
event = ProgressEvent(
|
||||||
event_type=f"{progress_type.value}_progress",
|
event_type=f"{progress_type.value}_progress",
|
||||||
progress_id=progress_id,
|
progress_id=progress_id,
|
||||||
@ -370,7 +394,7 @@ class ProgressService:
|
|||||||
should_broadcast = force_broadcast or percent_change >= 1.0
|
should_broadcast = force_broadcast or percent_change >= 1.0
|
||||||
|
|
||||||
if should_broadcast:
|
if should_broadcast:
|
||||||
room = f"{update.type.value}_progress"
|
room = _get_room_for_progress_type(update.type)
|
||||||
event = ProgressEvent(
|
event = ProgressEvent(
|
||||||
event_type=f"{update.type.value}_progress",
|
event_type=f"{update.type.value}_progress",
|
||||||
progress_id=progress_id,
|
progress_id=progress_id,
|
||||||
@ -427,7 +451,7 @@ class ProgressService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Emit completion event
|
# Emit completion event
|
||||||
room = f"{update.type.value}_progress"
|
room = _get_room_for_progress_type(update.type)
|
||||||
event = ProgressEvent(
|
event = ProgressEvent(
|
||||||
event_type=f"{update.type.value}_progress",
|
event_type=f"{update.type.value}_progress",
|
||||||
progress_id=progress_id,
|
progress_id=progress_id,
|
||||||
@ -483,7 +507,7 @@ class ProgressService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Emit failure event
|
# Emit failure event
|
||||||
room = f"{update.type.value}_progress"
|
room = _get_room_for_progress_type(update.type)
|
||||||
event = ProgressEvent(
|
event = ProgressEvent(
|
||||||
event_type=f"{update.type.value}_progress",
|
event_type=f"{update.type.value}_progress",
|
||||||
progress_id=progress_id,
|
progress_id=progress_id,
|
||||||
@ -533,7 +557,7 @@ class ProgressService:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Emit cancellation event
|
# Emit cancellation event
|
||||||
room = f"{update.type.value}_progress"
|
room = _get_room_for_progress_type(update.type)
|
||||||
event = ProgressEvent(
|
event = ProgressEvent(
|
||||||
event_type=f"{update.type.value}_progress",
|
event_type=f"{update.type.value}_progress",
|
||||||
progress_id=progress_id,
|
progress_id=progress_id,
|
||||||
|
|||||||
@ -186,9 +186,10 @@ class AniWorldApp {
|
|||||||
console.log('Connected to server');
|
console.log('Connected to server');
|
||||||
|
|
||||||
// Subscribe to rooms for targeted updates
|
// Subscribe to rooms for targeted updates
|
||||||
this.socket.join('scan_progress');
|
// Valid rooms: downloads, queue, scan, system, errors
|
||||||
this.socket.join('download_progress');
|
this.socket.join('scan');
|
||||||
this.socket.join('downloads');
|
this.socket.join('downloads');
|
||||||
|
this.socket.join('queue');
|
||||||
|
|
||||||
this.showToast(this.localization.getText('connected-server'), 'success');
|
this.showToast(this.localization.getText('connected-server'), 'success');
|
||||||
this.updateConnectionStatus();
|
this.updateConnectionStatus();
|
||||||
|
|||||||
@ -32,8 +32,9 @@ class QueueManager {
|
|||||||
console.log('Connected to server');
|
console.log('Connected to server');
|
||||||
|
|
||||||
// Subscribe to rooms for targeted updates
|
// Subscribe to rooms for targeted updates
|
||||||
|
// Valid rooms: downloads, queue, scan, system, errors
|
||||||
this.socket.join('downloads');
|
this.socket.join('downloads');
|
||||||
this.socket.join('download_progress');
|
this.socket.join('queue');
|
||||||
|
|
||||||
this.showToast('Connected to server', 'success');
|
this.showToast('Connected to server', 'success');
|
||||||
});
|
});
|
||||||
|
|||||||
@ -180,9 +180,9 @@ class TestDownloadProgressIntegration:
|
|||||||
connection_id = "test_client_1"
|
connection_id = "test_client_1"
|
||||||
await websocket_service.connect(mock_ws, connection_id)
|
await websocket_service.connect(mock_ws, connection_id)
|
||||||
|
|
||||||
# Join the queue_progress room to receive queue updates
|
# Join the queue room to receive queue updates
|
||||||
await websocket_service.manager.join_room(
|
await websocket_service.manager.join_room(
|
||||||
connection_id, "queue_progress"
|
connection_id, "queue"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Subscribe to progress events and forward to WebSocket
|
# Subscribe to progress events and forward to WebSocket
|
||||||
@ -254,12 +254,12 @@ class TestDownloadProgressIntegration:
|
|||||||
await websocket_service.connect(client1, "client1")
|
await websocket_service.connect(client1, "client1")
|
||||||
await websocket_service.connect(client2, "client2")
|
await websocket_service.connect(client2, "client2")
|
||||||
|
|
||||||
# Join both clients to the queue_progress room
|
# Join both clients to the queue room
|
||||||
await websocket_service.manager.join_room(
|
await websocket_service.manager.join_room(
|
||||||
"client1", "queue_progress"
|
"client1", "queue"
|
||||||
)
|
)
|
||||||
await websocket_service.manager.join_room(
|
await websocket_service.manager.join_room(
|
||||||
"client2", "queue_progress"
|
"client2", "queue"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Subscribe to progress events and forward to WebSocket
|
# Subscribe to progress events and forward to WebSocket
|
||||||
|
|||||||
@ -325,8 +325,9 @@ class TestWebSocketScanIntegration:
|
|||||||
assert len(broadcasts) >= 2 # At least start and complete
|
assert len(broadcasts) >= 2 # At least start and complete
|
||||||
|
|
||||||
# Check for scan progress broadcasts
|
# Check for scan progress broadcasts
|
||||||
|
# Room name is 'scan' for SCAN type progress
|
||||||
scan_broadcasts = [
|
scan_broadcasts = [
|
||||||
b for b in broadcasts if b["room"] == "scan_progress"
|
b for b in broadcasts if b["room"] == "scan"
|
||||||
]
|
]
|
||||||
assert len(scan_broadcasts) >= 2
|
assert len(scan_broadcasts) >= 2
|
||||||
|
|
||||||
@ -379,8 +380,9 @@ class TestWebSocketScanIntegration:
|
|||||||
await anime_service.rescan()
|
await anime_service.rescan()
|
||||||
|
|
||||||
# Verify failure broadcast
|
# Verify failure broadcast
|
||||||
|
# Room name is 'scan' for SCAN type progress
|
||||||
scan_broadcasts = [
|
scan_broadcasts = [
|
||||||
b for b in broadcasts if b["room"] == "scan_progress"
|
b for b in broadcasts if b["room"] == "scan"
|
||||||
]
|
]
|
||||||
assert len(scan_broadcasts) >= 2 # Start and fail
|
assert len(scan_broadcasts) >= 2 # Start and fail
|
||||||
|
|
||||||
@ -438,7 +440,7 @@ class TestWebSocketProgressIntegration:
|
|||||||
|
|
||||||
start_broadcast = broadcasts[0]
|
start_broadcast = broadcasts[0]
|
||||||
assert start_broadcast["data"]["status"] == "started"
|
assert start_broadcast["data"]["status"] == "started"
|
||||||
assert start_broadcast["room"] == "download_progress"
|
assert start_broadcast["room"] == "downloads" # Room name for DOWNLOAD type
|
||||||
|
|
||||||
update_broadcast = broadcasts[1]
|
update_broadcast = broadcasts[1]
|
||||||
assert update_broadcast["data"]["status"] == "in_progress"
|
assert update_broadcast["data"]["status"] == "in_progress"
|
||||||
|
|||||||
@ -352,7 +352,7 @@ class TestProgressService:
|
|||||||
# First positional arg is ProgressEvent
|
# First positional arg is ProgressEvent
|
||||||
call_args = mock_broadcast.call_args[0][0]
|
call_args = mock_broadcast.call_args[0][0]
|
||||||
assert call_args.event_type == "download_progress"
|
assert call_args.event_type == "download_progress"
|
||||||
assert call_args.room == "download_progress"
|
assert call_args.room == "downloads" # Room name for DOWNLOAD type
|
||||||
assert call_args.progress_id == "test-1"
|
assert call_args.progress_id == "test-1"
|
||||||
assert call_args.progress.id == "test-1"
|
assert call_args.progress.id == "test-1"
|
||||||
|
|
||||||
|
|||||||
445
tests/unit/test_queue_progress_broadcast.py
Normal file
445
tests/unit/test_queue_progress_broadcast.py
Normal file
@ -0,0 +1,445 @@
|
|||||||
|
"""Unit tests for queue progress broadcast to correct WebSocket rooms.
|
||||||
|
|
||||||
|
This module tests that download progress events are broadcast to the
|
||||||
|
correct WebSocket rooms ('downloads' for DOWNLOAD type progress).
|
||||||
|
These tests verify the fix for progress not transmitting to clients.
|
||||||
|
|
||||||
|
No real downloads are started - all tests use mocks to verify the
|
||||||
|
event flow from ProgressService through WebSocket broadcasting.
|
||||||
|
"""
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from src.server.services.progress_service import (
|
||||||
|
ProgressEvent,
|
||||||
|
ProgressService,
|
||||||
|
ProgressStatus,
|
||||||
|
ProgressType,
|
||||||
|
_get_room_for_progress_type,
|
||||||
|
)
|
||||||
|
from src.server.services.websocket_service import WebSocketService
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoomNameMapping:
|
||||||
|
"""Tests for progress type to room name mapping."""
|
||||||
|
|
||||||
|
def test_download_progress_maps_to_downloads_room(self):
|
||||||
|
"""Test that DOWNLOAD type maps to 'downloads' room."""
|
||||||
|
room = _get_room_for_progress_type(ProgressType.DOWNLOAD)
|
||||||
|
assert room == "downloads"
|
||||||
|
|
||||||
|
def test_scan_progress_maps_to_scan_room(self):
|
||||||
|
"""Test that SCAN type maps to 'scan' room."""
|
||||||
|
room = _get_room_for_progress_type(ProgressType.SCAN)
|
||||||
|
assert room == "scan"
|
||||||
|
|
||||||
|
def test_queue_progress_maps_to_queue_room(self):
|
||||||
|
"""Test that QUEUE type maps to 'queue' room."""
|
||||||
|
room = _get_room_for_progress_type(ProgressType.QUEUE)
|
||||||
|
assert room == "queue"
|
||||||
|
|
||||||
|
def test_system_progress_maps_to_system_room(self):
|
||||||
|
"""Test that SYSTEM type maps to 'system' room."""
|
||||||
|
room = _get_room_for_progress_type(ProgressType.SYSTEM)
|
||||||
|
assert room == "system"
|
||||||
|
|
||||||
|
def test_error_progress_maps_to_errors_room(self):
|
||||||
|
"""Test that ERROR type maps to 'errors' room."""
|
||||||
|
room = _get_room_for_progress_type(ProgressType.ERROR)
|
||||||
|
assert room == "errors"
|
||||||
|
|
||||||
|
|
||||||
|
class TestProgressServiceBroadcastRoom:
|
||||||
|
"""Tests for ProgressService broadcasting to correct rooms."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def progress_service(self):
|
||||||
|
"""Create a fresh ProgressService for each test."""
|
||||||
|
return ProgressService()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_handler(self):
|
||||||
|
"""Create a mock event handler to capture broadcasts."""
|
||||||
|
return AsyncMock()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_start_download_progress_broadcasts_to_downloads_room(
|
||||||
|
self, progress_service, mock_handler
|
||||||
|
):
|
||||||
|
"""Test start_progress with DOWNLOAD type uses 'downloads' room."""
|
||||||
|
# Subscribe to progress events
|
||||||
|
progress_service.subscribe("progress_updated", mock_handler)
|
||||||
|
|
||||||
|
# Start a download progress
|
||||||
|
await progress_service.start_progress(
|
||||||
|
progress_id="test-download-1",
|
||||||
|
progress_type=ProgressType.DOWNLOAD,
|
||||||
|
title="Test Download",
|
||||||
|
message="Downloading episode",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify handler was called with correct room
|
||||||
|
mock_handler.assert_called_once()
|
||||||
|
event: ProgressEvent = mock_handler.call_args[0][0]
|
||||||
|
|
||||||
|
assert event.room == "downloads", (
|
||||||
|
f"Expected room 'downloads' but got '{event.room}'"
|
||||||
|
)
|
||||||
|
assert event.event_type == "download_progress"
|
||||||
|
assert event.progress.status == ProgressStatus.STARTED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_download_progress_broadcasts_to_downloads_room(
|
||||||
|
self, progress_service, mock_handler
|
||||||
|
):
|
||||||
|
"""Test update_progress with DOWNLOAD type uses 'downloads' room."""
|
||||||
|
# Start progress first
|
||||||
|
await progress_service.start_progress(
|
||||||
|
progress_id="test-download-2",
|
||||||
|
progress_type=ProgressType.DOWNLOAD,
|
||||||
|
title="Test Download",
|
||||||
|
total=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subscribe after start to only capture update event
|
||||||
|
progress_service.subscribe("progress_updated", mock_handler)
|
||||||
|
|
||||||
|
# Update progress with force_broadcast
|
||||||
|
await progress_service.update_progress(
|
||||||
|
progress_id="test-download-2",
|
||||||
|
current=50,
|
||||||
|
message="50% complete",
|
||||||
|
force_broadcast=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify handler was called with correct room
|
||||||
|
mock_handler.assert_called_once()
|
||||||
|
event: ProgressEvent = mock_handler.call_args[0][0]
|
||||||
|
|
||||||
|
assert event.room == "downloads", (
|
||||||
|
f"Expected room 'downloads' but got '{event.room}'"
|
||||||
|
)
|
||||||
|
assert event.event_type == "download_progress"
|
||||||
|
assert event.progress.status == ProgressStatus.IN_PROGRESS
|
||||||
|
assert event.progress.percent == 50.0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_complete_download_progress_broadcasts_to_downloads_room(
|
||||||
|
self, progress_service, mock_handler
|
||||||
|
):
|
||||||
|
"""Test complete_progress with DOWNLOAD uses 'downloads' room."""
|
||||||
|
# Start progress first
|
||||||
|
await progress_service.start_progress(
|
||||||
|
progress_id="test-download-3",
|
||||||
|
progress_type=ProgressType.DOWNLOAD,
|
||||||
|
title="Test Download",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subscribe after start to only capture complete event
|
||||||
|
progress_service.subscribe("progress_updated", mock_handler)
|
||||||
|
|
||||||
|
# Complete progress
|
||||||
|
await progress_service.complete_progress(
|
||||||
|
progress_id="test-download-3",
|
||||||
|
message="Download completed",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify handler was called with correct room
|
||||||
|
mock_handler.assert_called_once()
|
||||||
|
event: ProgressEvent = mock_handler.call_args[0][0]
|
||||||
|
|
||||||
|
assert event.room == "downloads", (
|
||||||
|
f"Expected room 'downloads' but got '{event.room}'"
|
||||||
|
)
|
||||||
|
assert event.event_type == "download_progress"
|
||||||
|
assert event.progress.status == ProgressStatus.COMPLETED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_fail_download_progress_broadcasts_to_downloads_room(
|
||||||
|
self, progress_service, mock_handler
|
||||||
|
):
|
||||||
|
"""Test that fail_progress with DOWNLOAD type uses 'downloads' room."""
|
||||||
|
# Start progress first
|
||||||
|
await progress_service.start_progress(
|
||||||
|
progress_id="test-download-4",
|
||||||
|
progress_type=ProgressType.DOWNLOAD,
|
||||||
|
title="Test Download",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Subscribe after start to only capture fail event
|
||||||
|
progress_service.subscribe("progress_updated", mock_handler)
|
||||||
|
|
||||||
|
# Fail progress
|
||||||
|
await progress_service.fail_progress(
|
||||||
|
progress_id="test-download-4",
|
||||||
|
error_message="Connection lost",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify handler was called with correct room
|
||||||
|
mock_handler.assert_called_once()
|
||||||
|
event: ProgressEvent = mock_handler.call_args[0][0]
|
||||||
|
|
||||||
|
assert event.room == "downloads", (
|
||||||
|
f"Expected room 'downloads' but got '{event.room}'"
|
||||||
|
)
|
||||||
|
assert event.event_type == "download_progress"
|
||||||
|
assert event.progress.status == ProgressStatus.FAILED
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_queue_progress_broadcasts_to_queue_room(
|
||||||
|
self, progress_service, mock_handler
|
||||||
|
):
|
||||||
|
"""Test that QUEUE type progress uses 'queue' room."""
|
||||||
|
progress_service.subscribe("progress_updated", mock_handler)
|
||||||
|
|
||||||
|
await progress_service.start_progress(
|
||||||
|
progress_id="test-queue-1",
|
||||||
|
progress_type=ProgressType.QUEUE,
|
||||||
|
title="Queue Status",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_handler.assert_called_once()
|
||||||
|
event: ProgressEvent = mock_handler.call_args[0][0]
|
||||||
|
|
||||||
|
assert event.room == "queue", (
|
||||||
|
f"Expected room 'queue' but got '{event.room}'"
|
||||||
|
)
|
||||||
|
assert event.event_type == "queue_progress"
|
||||||
|
|
||||||
|
|
||||||
|
class TestEndToEndProgressBroadcast:
|
||||||
|
"""End-to-end tests for progress broadcast via WebSocket."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def websocket_service(self):
|
||||||
|
"""Create a WebSocketService."""
|
||||||
|
return WebSocketService()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def progress_service(self):
|
||||||
|
"""Create a ProgressService."""
|
||||||
|
return ProgressService()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_progress_broadcast_reaches_downloads_room_clients(
|
||||||
|
self, websocket_service, progress_service
|
||||||
|
):
|
||||||
|
"""Test that download progress reaches clients in 'downloads' room.
|
||||||
|
|
||||||
|
This is the key test verifying the fix: progress updates should
|
||||||
|
be broadcast to the 'downloads' room, not 'download_progress'.
|
||||||
|
"""
|
||||||
|
# Track messages received by mock client
|
||||||
|
received_messages: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
# Create mock WebSocket
|
||||||
|
class MockWebSocket:
|
||||||
|
async def accept(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send_json(self, data):
|
||||||
|
received_messages.append(data)
|
||||||
|
|
||||||
|
async def receive_json(self):
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
# Connect client to WebSocket service
|
||||||
|
mock_ws = MockWebSocket()
|
||||||
|
connection_id = "test_client"
|
||||||
|
await websocket_service.connect(mock_ws, connection_id)
|
||||||
|
|
||||||
|
# Join the 'downloads' room (this is what the JS client does)
|
||||||
|
await websocket_service.manager.join_room(connection_id, "downloads")
|
||||||
|
|
||||||
|
# Set up the progress event handler (mimics fastapi_app.py)
|
||||||
|
async def progress_event_handler(event: ProgressEvent) -> None:
|
||||||
|
"""Handle progress events and broadcast via WebSocket."""
|
||||||
|
message = {
|
||||||
|
"type": event.event_type,
|
||||||
|
"data": event.progress.to_dict(),
|
||||||
|
}
|
||||||
|
await websocket_service.manager.broadcast_to_room(
|
||||||
|
message, event.room
|
||||||
|
)
|
||||||
|
|
||||||
|
progress_service.subscribe("progress_updated", progress_event_handler)
|
||||||
|
|
||||||
|
# Simulate download progress lifecycle
|
||||||
|
# 1. Start download
|
||||||
|
await progress_service.start_progress(
|
||||||
|
progress_id="real-download-test",
|
||||||
|
progress_type=ProgressType.DOWNLOAD,
|
||||||
|
title="Downloading Anime Episode",
|
||||||
|
total=100,
|
||||||
|
metadata={"item_id": "item-123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. Update progress multiple times
|
||||||
|
for percent in [25, 50, 75]:
|
||||||
|
await progress_service.update_progress(
|
||||||
|
progress_id="real-download-test",
|
||||||
|
current=percent,
|
||||||
|
message=f"{percent}% complete",
|
||||||
|
metadata={"speed_mbps": 2.5},
|
||||||
|
force_broadcast=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. Complete download
|
||||||
|
await progress_service.complete_progress(
|
||||||
|
progress_id="real-download-test",
|
||||||
|
message="Download completed successfully",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify client received all messages
|
||||||
|
# Filter for download_progress type messages
|
||||||
|
download_messages = [
|
||||||
|
m for m in received_messages
|
||||||
|
if m.get("type") == "download_progress"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Should have: start + 3 updates + complete = 5 messages
|
||||||
|
assert len(download_messages) >= 4, (
|
||||||
|
f"Expected at least 4 download_progress messages, "
|
||||||
|
f"got {len(download_messages)}: {download_messages}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify first message is start
|
||||||
|
assert download_messages[0]["data"]["status"] == "started"
|
||||||
|
|
||||||
|
# Verify last message is completed
|
||||||
|
assert download_messages[-1]["data"]["status"] == "completed"
|
||||||
|
assert download_messages[-1]["data"]["percent"] == 100.0
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await websocket_service.disconnect(connection_id)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clients_not_in_downloads_room_dont_receive_progress(
|
||||||
|
self, websocket_service, progress_service
|
||||||
|
):
|
||||||
|
"""Test that clients not in 'downloads' room don't receive progress."""
|
||||||
|
downloads_messages: List[Dict] = []
|
||||||
|
other_messages: List[Dict] = []
|
||||||
|
|
||||||
|
class MockWebSocket:
|
||||||
|
def __init__(self, message_list):
|
||||||
|
self.messages = message_list
|
||||||
|
|
||||||
|
async def accept(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send_json(self, data):
|
||||||
|
self.messages.append(data)
|
||||||
|
|
||||||
|
async def receive_json(self):
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
# Client in 'downloads' room
|
||||||
|
ws_downloads = MockWebSocket(downloads_messages)
|
||||||
|
await websocket_service.connect(ws_downloads, "client_downloads")
|
||||||
|
await websocket_service.manager.join_room(
|
||||||
|
"client_downloads", "downloads"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Client in 'system' room (different room)
|
||||||
|
ws_other = MockWebSocket(other_messages)
|
||||||
|
await websocket_service.connect(ws_other, "client_other")
|
||||||
|
await websocket_service.manager.join_room("client_other", "system")
|
||||||
|
|
||||||
|
# Set up progress handler
|
||||||
|
async def progress_event_handler(event: ProgressEvent) -> None:
|
||||||
|
message = {
|
||||||
|
"type": event.event_type,
|
||||||
|
"data": event.progress.to_dict(),
|
||||||
|
}
|
||||||
|
await websocket_service.manager.broadcast_to_room(
|
||||||
|
message, event.room
|
||||||
|
)
|
||||||
|
|
||||||
|
progress_service.subscribe("progress_updated", progress_event_handler)
|
||||||
|
|
||||||
|
# Emit download progress
|
||||||
|
await progress_service.start_progress(
|
||||||
|
progress_id="isolation-test",
|
||||||
|
progress_type=ProgressType.DOWNLOAD,
|
||||||
|
title="Test Download",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Only 'downloads' room client should receive the message
|
||||||
|
download_progress_in_downloads = [
|
||||||
|
m for m in downloads_messages
|
||||||
|
if m.get("type") == "download_progress"
|
||||||
|
]
|
||||||
|
download_progress_in_other = [
|
||||||
|
m for m in other_messages
|
||||||
|
if m.get("type") == "download_progress"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(download_progress_in_downloads) == 1, (
|
||||||
|
"Client in 'downloads' room should receive download_progress"
|
||||||
|
)
|
||||||
|
assert len(download_progress_in_other) == 0, (
|
||||||
|
"Client in 'system' room should NOT receive download_progress"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
await websocket_service.disconnect("client_downloads")
|
||||||
|
await websocket_service.disconnect("client_other")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_progress_update_includes_item_id_in_metadata(
|
||||||
|
self, websocket_service, progress_service
|
||||||
|
):
|
||||||
|
"""Test progress updates include item_id for JS client matching."""
|
||||||
|
received_messages: List[Dict] = []
|
||||||
|
|
||||||
|
class MockWebSocket:
|
||||||
|
async def accept(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def send_json(self, data):
|
||||||
|
received_messages.append(data)
|
||||||
|
|
||||||
|
async def receive_json(self):
|
||||||
|
await asyncio.sleep(10)
|
||||||
|
|
||||||
|
mock_ws = MockWebSocket()
|
||||||
|
await websocket_service.connect(mock_ws, "test_client")
|
||||||
|
await websocket_service.manager.join_room("test_client", "downloads")
|
||||||
|
|
||||||
|
async def progress_event_handler(event: ProgressEvent) -> None:
|
||||||
|
message = {
|
||||||
|
"type": event.event_type,
|
||||||
|
"data": event.progress.to_dict(),
|
||||||
|
}
|
||||||
|
await websocket_service.manager.broadcast_to_room(
|
||||||
|
message, event.room
|
||||||
|
)
|
||||||
|
|
||||||
|
progress_service.subscribe("progress_updated", progress_event_handler)
|
||||||
|
|
||||||
|
# Start progress with item_id in metadata
|
||||||
|
item_id = "uuid-12345-67890"
|
||||||
|
await progress_service.start_progress(
|
||||||
|
progress_id=f"download_{item_id}",
|
||||||
|
progress_type=ProgressType.DOWNLOAD,
|
||||||
|
title="Test Download",
|
||||||
|
metadata={"item_id": item_id},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify item_id is present in broadcast
|
||||||
|
download_messages = [
|
||||||
|
m for m in received_messages
|
||||||
|
if m.get("type") == "download_progress"
|
||||||
|
]
|
||||||
|
|
||||||
|
assert len(download_messages) == 1
|
||||||
|
metadata = download_messages[0]["data"].get("metadata", {})
|
||||||
|
assert metadata.get("item_id") == item_id, (
|
||||||
|
f"Expected item_id '{item_id}' in metadata, got: {metadata}"
|
||||||
|
)
|
||||||
|
|
||||||
|
await websocket_service.disconnect("test_client")
|
||||||
Loading…
x
Reference in New Issue
Block a user