feat: implement WebSocket real-time progress updates

- Add ProgressService for centralized progress tracking and broadcasting
- Integrate ProgressService with DownloadService for download progress
- Integrate ProgressService with AnimeService for scan progress
- Add progress-related WebSocket message models (ScanProgress, ErrorNotification, etc.)
- Initialize ProgressService with WebSocket callback in application startup
- Add comprehensive unit tests for ProgressService
- Update infrastructure.md with ProgressService documentation
- Remove completed WebSocket Real-time Updates task from instructions.md

The ProgressService provides:
- Real-time progress tracking for downloads, scans, and queue operations
- Automatic progress percentage calculation
- Progress lifecycle management (start, update, complete, fail, cancel)
- WebSocket integration for instant client updates
- Progress history with size limits
- Thread-safe operations using asyncio locks
- Support for metadata and custom messages

Benefits:
- Decoupled progress tracking from WebSocket broadcasting
- Single reusable service across all components
- Supports multiple concurrent operations efficiently
- Centralized progress tracking simplifies monitoring
- Instant feedback to users on long-running operations
This commit is contained in:
Lukas 2025-10-17 11:12:06 +02:00
parent 42a07be4cb
commit 94de91ffa0
7 changed files with 1375 additions and 4 deletions

View File

@ -545,3 +545,155 @@ Implemented comprehensive REST API endpoints for download queue management:
- Router registered in `src/server/fastapi_app.py` via `app.include_router(download_router)`
- Follows same patterns as other API routers (auth, anime, config)
- Full OpenAPI documentation available at `/api/docs`
### WebSocket Real-time Updates (October 2025)
Implemented real-time progress tracking and WebSocket broadcasting for downloads, scans, and system events.
#### ProgressService
**File**: `src/server/services/progress_service.py`
A centralized service for tracking and broadcasting real-time progress updates across the application.
**Key Features**:
- Track multiple concurrent progress operations (downloads, scans, queue changes)
- Automatic progress percentage calculation
- Progress lifecycle management (start, update, complete, fail, cancel)
- WebSocket integration for real-time client updates
- Progress history with configurable size limit (default: 50 items)
- Thread-safe operations using asyncio locks
- Support for progress metadata and custom messages
**Progress Types**:
- `DOWNLOAD` - File download progress
- `SCAN` - Library scan progress
- `QUEUE` - Queue operation progress
- `SYSTEM` - System-level operations
- `ERROR` - Error notifications
**Progress Statuses**:
- `STARTED` - Operation initiated
- `IN_PROGRESS` - Operation in progress
- `COMPLETED` - Successfully completed
- `FAILED` - Operation failed
- `CANCELLED` - Cancelled by user
**Core Methods**:
- `start_progress()` - Initialize new progress operation
- `update_progress()` - Update progress with current/total values
- `complete_progress()` - Mark operation as completed
- `fail_progress()` - Mark operation as failed
- `cancel_progress()` - Cancel ongoing operation
- `get_progress()` - Retrieve progress by ID
- `get_all_active_progress()` - Get all active operations (optionally filtered by type)
**Broadcasting**:
- Integrates with WebSocketService via callback
- Broadcasts to room-specific channels (e.g., `download_progress`, `scan_progress`)
- Configurable broadcast throttling (only on significant changes >1% or forced)
- Automatic progress state serialization to JSON
**Singleton Pattern**:
- Global instance via `get_progress_service()` factory
- Initialized during application startup with WebSocket callback
#### Integration with Services
**DownloadService Integration**:
- Progress tracking for each download item
- Real-time progress updates during file download
- Automatic completion/failure notifications
- Progress metadata includes speed, ETA, downloaded bytes
**AnimeService Integration**:
- Progress tracking for library scans
- Scan progress with current/total file counts
- Scan completion with statistics
- Error notifications on scan failures
#### WebSocket Message Models
**File**: `src/server/models/websocket.py`
Added progress-specific message models:
- `ScanProgressMessage` - Scan progress updates
- `ScanCompleteMessage` - Scan completion notification
- `ScanFailedMessage` - Scan failure notification
- `ErrorNotificationMessage` - Critical error notifications
- `ProgressUpdateMessage` - Generic progress updates
**WebSocket Message Types**:
- `SCAN_PROGRESS` - Scan progress updates
- `SCAN_COMPLETE` - Scan completion
- `SCAN_FAILED` - Scan failure
- Extended existing types for downloads and queue updates
#### WebSocket Rooms
Clients can subscribe to specific progress channels:
- `download_progress` - Download progress updates
- `scan_progress` - Library scan updates
- `queue_progress` - Queue operation updates
- `system_progress` - System-level updates
Room subscription via client messages:
```json
{
"action": "join",
"room": "download_progress"
}
```
#### Application Startup
**File**: `src/server/fastapi_app.py`
Progress service initialized on application startup:
1. Get ProgressService singleton instance
2. Get WebSocketService singleton instance
3. Register broadcast callback to link progress updates with WebSocket
4. Callback broadcasts progress messages to appropriate rooms
#### Testing
**File**: `tests/unit/test_progress_service.py`
Comprehensive test coverage including:
- Progress lifecycle operations (start, update, complete, fail, cancel)
- Percentage calculation accuracy
- History management and size limits
- Broadcast callback invocation
- Concurrent progress operations
- Metadata handling
- Error conditions and edge cases
#### Architecture Benefits
- **Decoupling**: ProgressService decouples progress tracking from WebSocket broadcasting
- **Reusability**: Single service used across all application components
- **Scalability**: Supports multiple concurrent operations efficiently
- **Observability**: Centralized progress tracking simplifies monitoring
- **Real-time UX**: Instant feedback to users on all long-running operations
#### Future Enhancements
- Persistent progress history (database storage)
- Progress rate calculation and trend analysis
- Multi-process progress synchronization (Redis/shared store)
- Progress event hooks for custom actions
- Client-side progress resumption after reconnection

View File

@ -29,6 +29,8 @@ from src.server.controllers.error_controller import (
from src.server.controllers.health_controller import router as health_router
from src.server.controllers.page_controller import router as page_router
from src.server.middleware.auth import AuthMiddleware
from src.server.services.progress_service import get_progress_service
from src.server.services.websocket_service import get_websocket_service
# Initialize FastAPI app
app = FastAPI(
@ -74,6 +76,23 @@ async def startup_event():
# Initialize SeriesApp with configured directory
if settings.anime_directory:
series_app = SeriesApp(settings.anime_directory)
# Initialize progress service with websocket callback
progress_service = get_progress_service()
ws_service = get_websocket_service()
async def broadcast_callback(
message_type: str, data: dict, room: str
):
"""Broadcast progress updates via WebSocket."""
message = {
"type": message_type,
"data": data,
}
await ws_service.manager.broadcast_to_room(message, room)
progress_service.set_broadcast_callback(broadcast_callback)
print("FastAPI application started successfully")
except Exception as e:
print(f"Error during startup: {e}")

View File

@ -30,6 +30,11 @@ class WebSocketMessageType(str, Enum):
QUEUE_PAUSED = "queue_paused"
QUEUE_RESUMED = "queue_resumed"
# Progress-related messages
SCAN_PROGRESS = "scan_progress"
SCAN_COMPLETE = "scan_complete"
SCAN_FAILED = "scan_failed"
# System messages
SYSTEM_INFO = "system_info"
SYSTEM_WARNING = "system_warning"
@ -188,3 +193,93 @@ class RoomSubscriptionRequest(BaseModel):
room: str = Field(
..., min_length=1, description="Room name to join or leave"
)
class ScanProgressMessage(BaseModel):
"""Scan progress update message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.SCAN_PROGRESS,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description="Scan progress data including current, total, percent",
)
class ScanCompleteMessage(BaseModel):
"""Scan completion message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.SCAN_COMPLETE,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description="Scan completion data including series_found, duration",
)
class ScanFailedMessage(BaseModel):
"""Scan failure message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.SCAN_FAILED,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
..., description="Scan error data including error_message"
)
class ErrorNotificationMessage(BaseModel):
"""Error notification message for critical errors."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.SYSTEM_ERROR,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description=(
"Error notification data including severity, message, details"
),
)
class ProgressUpdateMessage(BaseModel):
"""Generic progress update message.
Can be used for any type of progress (download, scan, queue, etc.)
"""
type: WebSocketMessageType = Field(
..., description="Type of progress message"
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description=(
"Progress data including id, status, percent, current, total"
),
)

View File

@ -3,11 +3,16 @@ from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import List, Optional
from typing import Callable, List, Optional
import structlog
from src.core.SeriesApp import SeriesApp
from src.server.services.progress_service import (
ProgressService,
ProgressType,
get_progress_service,
)
logger = structlog.get_logger(__name__)
@ -24,9 +29,15 @@ class AnimeService:
- Adds simple in-memory caching for read operations
"""
def __init__(self, directory: str, max_workers: int = 4):
def __init__(
self,
directory: str,
max_workers: int = 4,
progress_service: Optional[ProgressService] = None,
):
self._directory = directory
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._progress_service = progress_service or get_progress_service()
# SeriesApp is blocking; instantiate per-service
try:
self._app = SeriesApp(directory)
@ -75,20 +86,70 @@ class AnimeService:
logger.exception("search failed")
raise AnimeServiceError("Search failed") from e
async def rescan(self, callback=None) -> None:
async def rescan(self, callback: Optional[Callable] = None) -> None:
"""Trigger a re-scan. Accepts an optional callback function.
The callback is executed in the threadpool by SeriesApp.
Progress updates are tracked and broadcasted via ProgressService.
"""
scan_id = "library_scan"
try:
await self._run_in_executor(self._app.ReScan, callback)
# Start progress tracking
await self._progress_service.start_progress(
progress_id=scan_id,
progress_type=ProgressType.SCAN,
title="Scanning anime library",
message="Initializing scan...",
)
# Create wrapped callback for progress updates
def progress_callback(progress_data: dict) -> None:
"""Update progress during scan."""
try:
if callback:
callback(progress_data)
# Update progress service
current = progress_data.get("current", 0)
total = progress_data.get("total", 0)
message = progress_data.get("message", "Scanning...")
asyncio.create_task(
self._progress_service.update_progress(
progress_id=scan_id,
current=current,
total=total,
message=message,
)
)
except Exception as e:
logger.error("Scan progress callback error", error=str(e))
# Run scan
await self._run_in_executor(self._app.ReScan, progress_callback)
# invalidate cache
try:
self._cached_list_missing.cache_clear()
except Exception:
pass
# Complete progress tracking
await self._progress_service.complete_progress(
progress_id=scan_id,
message="Scan completed successfully",
)
except Exception as e:
logger.exception("rescan failed")
# Fail progress tracking
await self._progress_service.fail_progress(
progress_id=scan_id,
error_message=str(e),
)
raise AnimeServiceError("Rescan failed") from e
async def download(self, serie_folder: str, season: int, episode: int, key: str, callback=None) -> bool:

View File

@ -27,6 +27,11 @@ from src.server.models.download import (
QueueStatus,
)
from src.server.services.anime_service import AnimeService, AnimeServiceError
from src.server.services.progress_service import (
ProgressService,
ProgressType,
get_progress_service,
)
logger = structlog.get_logger(__name__)
@ -53,6 +58,7 @@ class DownloadService:
max_concurrent_downloads: int = 2,
max_retries: int = 3,
persistence_path: str = "./data/download_queue.json",
progress_service: Optional[ProgressService] = None,
):
"""Initialize the download service.
@ -61,11 +67,13 @@ class DownloadService:
max_concurrent_downloads: Maximum simultaneous downloads
max_retries: Maximum retry attempts for failed downloads
persistence_path: Path to persist queue state
progress_service: Optional progress service for tracking
"""
self._anime_service = anime_service
self._max_concurrent = max_concurrent_downloads
self._max_retries = max_retries
self._persistence_path = Path(persistence_path)
self._progress_service = progress_service or get_progress_service()
# Queue storage by status
self._pending_queue: deque[DownloadItem] = deque()
@ -500,6 +508,23 @@ class DownloadService:
if item.progress.speed_mbps:
self._download_speeds.append(item.progress.speed_mbps)
# Update progress service
if item.progress.total_mb and item.progress.total_mb > 0:
current_mb = int(item.progress.downloaded_mb)
total_mb = int(item.progress.total_mb)
asyncio.create_task(
self._progress_service.update_progress(
progress_id=f"download_{item.id}",
current=current_mb,
total=total_mb,
metadata={
"speed_mbps": item.progress.speed_mbps,
"eta_seconds": item.progress.eta_seconds,
},
)
)
# Broadcast update (fire and forget)
asyncio.create_task(
self._broadcast_update(
@ -535,6 +560,22 @@ class DownloadService:
episode=item.episode.episode,
)
# Start progress tracking
await self._progress_service.start_progress(
progress_id=f"download_{item.id}",
progress_type=ProgressType.DOWNLOAD,
title=f"Downloading {item.serie_name}",
message=(
f"S{item.episode.season:02d}E{item.episode.episode:02d}"
),
metadata={
"item_id": item.id,
"serie_name": item.serie_name,
"season": item.episode.season,
"episode": item.episode.episode,
},
)
# Create progress callback
progress_callback = self._create_progress_callback(item)
@ -561,6 +602,18 @@ class DownloadService:
logger.info(
"Download completed successfully", item_id=item.id
)
# Complete progress tracking
await self._progress_service.complete_progress(
progress_id=f"download_{item.id}",
message="Download completed successfully",
metadata={
"downloaded_mb": item.progress.downloaded_mb
if item.progress
else 0,
},
)
await self._broadcast_update(
"download_completed", {"item_id": item.id}
)
@ -581,6 +634,13 @@ class DownloadService:
retry_count=item.retry_count,
)
# Fail progress tracking
await self._progress_service.fail_progress(
progress_id=f"download_{item.id}",
error_message=str(e),
metadata={"retry_count": item.retry_count},
)
await self._broadcast_update(
"download_failed",
{"item_id": item.id, "error": item.error},

View File

@ -0,0 +1,485 @@
"""Progress service for managing real-time progress updates.
This module provides a centralized service for tracking and broadcasting
real-time progress updates for downloads, scans, queue changes, and
system events. It integrates with the WebSocket service to push updates
to connected clients.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, Optional
import structlog
logger = structlog.get_logger(__name__)
class ProgressType(str, Enum):
"""Types of progress updates."""
DOWNLOAD = "download"
SCAN = "scan"
QUEUE = "queue"
SYSTEM = "system"
ERROR = "error"
class ProgressStatus(str, Enum):
"""Status of a progress operation."""
STARTED = "started"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class ProgressUpdate:
"""Represents a progress update event.
Attributes:
id: Unique identifier for this progress operation
type: Type of progress (download, scan, etc.)
status: Current status of the operation
title: Human-readable title
message: Detailed message
percent: Completion percentage (0-100)
current: Current progress value
total: Total progress value
metadata: Additional metadata
started_at: When operation started
updated_at: When last updated
"""
id: str
type: ProgressType
status: ProgressStatus
title: str
message: str = ""
percent: float = 0.0
current: int = 0
total: int = 0
metadata: Dict[str, Any] = field(default_factory=dict)
started_at: datetime = field(default_factory=datetime.utcnow)
updated_at: datetime = field(default_factory=datetime.utcnow)
def to_dict(self) -> Dict[str, Any]:
"""Convert progress update to dictionary."""
return {
"id": self.id,
"type": self.type.value,
"status": self.status.value,
"title": self.title,
"message": self.message,
"percent": round(self.percent, 2),
"current": self.current,
"total": self.total,
"metadata": self.metadata,
"started_at": self.started_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
}
class ProgressServiceError(Exception):
"""Service-level exception for progress operations."""
class ProgressService:
"""Manages real-time progress updates and broadcasting.
Features:
- Track multiple concurrent progress operations
- Calculate progress percentages and rates
- Broadcast updates via WebSocket
- Manage progress lifecycle (start, update, complete, fail)
- Support for different progress types (download, scan, queue)
"""
def __init__(self):
"""Initialize the progress service."""
# Active progress operations: id -> ProgressUpdate
self._active_progress: Dict[str, ProgressUpdate] = {}
# Completed progress history (limited size)
self._history: Dict[str, ProgressUpdate] = {}
self._max_history_size = 50
# WebSocket broadcast callback
self._broadcast_callback: Optional[Callable] = None
# Lock for thread-safe operations
self._lock = asyncio.Lock()
logger.info("ProgressService initialized")
def set_broadcast_callback(self, callback: Callable) -> None:
"""Set callback for broadcasting progress updates via WebSocket.
Args:
callback: Async function to call for broadcasting updates
"""
self._broadcast_callback = callback
logger.debug("Progress broadcast callback registered")
async def _broadcast(self, update: ProgressUpdate, room: str) -> None:
"""Broadcast progress update to WebSocket clients.
Args:
update: Progress update to broadcast
room: WebSocket room to broadcast to
"""
if self._broadcast_callback:
try:
await self._broadcast_callback(
message_type=f"{update.type.value}_progress",
data=update.to_dict(),
room=room,
)
except Exception as e:
logger.error(
"Failed to broadcast progress update",
error=str(e),
progress_id=update.id,
)
async def start_progress(
self,
progress_id: str,
progress_type: ProgressType,
title: str,
total: int = 0,
message: str = "",
metadata: Optional[Dict[str, Any]] = None,
) -> ProgressUpdate:
"""Start a new progress operation.
Args:
progress_id: Unique identifier for this progress
progress_type: Type of progress operation
title: Human-readable title
total: Total items/bytes to process
message: Initial message
metadata: Additional metadata
Returns:
Created progress update object
Raises:
ProgressServiceError: If progress already exists
"""
async with self._lock:
if progress_id in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' already exists"
)
update = ProgressUpdate(
id=progress_id,
type=progress_type,
status=ProgressStatus.STARTED,
title=title,
message=message,
total=total,
metadata=metadata or {},
)
self._active_progress[progress_id] = update
logger.info(
"Progress started",
progress_id=progress_id,
type=progress_type.value,
title=title,
)
# Broadcast to appropriate room
room = f"{progress_type.value}_progress"
await self._broadcast(update, room)
return update
async def update_progress(
self,
progress_id: str,
current: Optional[int] = None,
total: Optional[int] = None,
message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
force_broadcast: bool = False,
) -> ProgressUpdate:
"""Update an existing progress operation.
Args:
progress_id: Progress identifier
current: Current progress value
total: Updated total value
message: Updated message
metadata: Additional metadata to merge
force_broadcast: Force broadcasting even for small changes
Returns:
Updated progress object
Raises:
ProgressServiceError: If progress not found
"""
async with self._lock:
if progress_id not in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' not found"
)
update = self._active_progress[progress_id]
old_percent = update.percent
# Update fields
if current is not None:
update.current = current
if total is not None:
update.total = total
if message is not None:
update.message = message
if metadata:
update.metadata.update(metadata)
# Calculate percentage
if update.total > 0:
update.percent = (update.current / update.total) * 100
else:
update.percent = 0.0
update.status = ProgressStatus.IN_PROGRESS
update.updated_at = datetime.utcnow()
# Only broadcast if significant change or forced
percent_change = abs(update.percent - old_percent)
should_broadcast = force_broadcast or percent_change >= 1.0
if should_broadcast:
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
return update
async def complete_progress(
self,
progress_id: str,
message: str = "Completed successfully",
metadata: Optional[Dict[str, Any]] = None,
) -> ProgressUpdate:
"""Mark a progress operation as completed.
Args:
progress_id: Progress identifier
message: Completion message
metadata: Additional metadata
Returns:
Completed progress object
Raises:
ProgressServiceError: If progress not found
"""
async with self._lock:
if progress_id not in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' not found"
)
update = self._active_progress[progress_id]
update.status = ProgressStatus.COMPLETED
update.message = message
update.percent = 100.0
update.current = update.total
update.updated_at = datetime.utcnow()
if metadata:
update.metadata.update(metadata)
# Move to history
del self._active_progress[progress_id]
self._add_to_history(update)
logger.info(
"Progress completed",
progress_id=progress_id,
type=update.type.value,
)
# Broadcast completion
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
return update
async def fail_progress(
self,
progress_id: str,
error_message: str,
metadata: Optional[Dict[str, Any]] = None,
) -> ProgressUpdate:
"""Mark a progress operation as failed.
Args:
progress_id: Progress identifier
error_message: Error description
metadata: Additional error metadata
Returns:
Failed progress object
Raises:
ProgressServiceError: If progress not found
"""
async with self._lock:
if progress_id not in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' not found"
)
update = self._active_progress[progress_id]
update.status = ProgressStatus.FAILED
update.message = error_message
update.updated_at = datetime.utcnow()
if metadata:
update.metadata.update(metadata)
# Move to history
del self._active_progress[progress_id]
self._add_to_history(update)
logger.error(
"Progress failed",
progress_id=progress_id,
type=update.type.value,
error=error_message,
)
# Broadcast failure
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
return update
async def cancel_progress(
self,
progress_id: str,
message: str = "Cancelled by user",
) -> ProgressUpdate:
"""Cancel a progress operation.
Args:
progress_id: Progress identifier
message: Cancellation message
Returns:
Cancelled progress object
Raises:
ProgressServiceError: If progress not found
"""
async with self._lock:
if progress_id not in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' not found"
)
update = self._active_progress[progress_id]
update.status = ProgressStatus.CANCELLED
update.message = message
update.updated_at = datetime.utcnow()
# Move to history
del self._active_progress[progress_id]
self._add_to_history(update)
logger.info(
"Progress cancelled",
progress_id=progress_id,
type=update.type.value,
)
# Broadcast cancellation
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
return update
def _add_to_history(self, update: ProgressUpdate) -> None:
"""Add completed progress to history with size limit."""
self._history[update.id] = update
# Maintain history size limit
if len(self._history) > self._max_history_size:
# Remove oldest entries
oldest_keys = sorted(
self._history.keys(),
key=lambda k: self._history[k].updated_at,
)[: len(self._history) - self._max_history_size]
for key in oldest_keys:
del self._history[key]
async def get_progress(self, progress_id: str) -> Optional[ProgressUpdate]:
"""Get current progress state.
Args:
progress_id: Progress identifier
Returns:
Progress update object or None if not found
"""
async with self._lock:
if progress_id in self._active_progress:
return self._active_progress[progress_id]
if progress_id in self._history:
return self._history[progress_id]
return None
async def get_all_active_progress(
self, progress_type: Optional[ProgressType] = None
) -> Dict[str, ProgressUpdate]:
"""Get all active progress operations.
Args:
progress_type: Optional filter by progress type
Returns:
Dictionary of progress_id -> ProgressUpdate
"""
async with self._lock:
if progress_type:
return {
pid: update
for pid, update in self._active_progress.items()
if update.type == progress_type
}
return self._active_progress.copy()
async def clear_history(self) -> None:
"""Clear progress history."""
async with self._lock:
self._history.clear()
logger.info("Progress history cleared")
# Global singleton instance
_progress_service: Optional[ProgressService] = None
def get_progress_service() -> ProgressService:
"""Get or create the global progress service instance.
Returns:
Global ProgressService instance
"""
global _progress_service
if _progress_service is None:
_progress_service = ProgressService()
return _progress_service

View File

@ -0,0 +1,499 @@
"""Unit tests for ProgressService.
This module contains comprehensive tests for the progress tracking service,
including progress lifecycle, broadcasting, error handling, and concurrency.
"""
import asyncio
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock
import pytest
from src.server.services.progress_service import (
ProgressService,
ProgressServiceError,
ProgressStatus,
ProgressType,
ProgressUpdate,
)
class TestProgressUpdate:
"""Test ProgressUpdate dataclass."""
def test_progress_update_creation(self):
"""Test creating a progress update."""
update = ProgressUpdate(
id="test-1",
type=ProgressType.DOWNLOAD,
status=ProgressStatus.STARTED,
title="Test Download",
message="Starting download",
total=100,
)
assert update.id == "test-1"
assert update.type == ProgressType.DOWNLOAD
assert update.status == ProgressStatus.STARTED
assert update.title == "Test Download"
assert update.message == "Starting download"
assert update.total == 100
assert update.current == 0
assert update.percent == 0.0
def test_progress_update_to_dict(self):
"""Test converting progress update to dictionary."""
update = ProgressUpdate(
id="test-1",
type=ProgressType.SCAN,
status=ProgressStatus.IN_PROGRESS,
title="Test Scan",
message="Scanning files",
current=50,
total=100,
metadata={"test_key": "test_value"},
)
result = update.to_dict()
assert result["id"] == "test-1"
assert result["type"] == "scan"
assert result["status"] == "in_progress"
assert result["title"] == "Test Scan"
assert result["message"] == "Scanning files"
assert result["current"] == 50
assert result["total"] == 100
assert result["percent"] == 0.0
assert result["metadata"]["test_key"] == "test_value"
assert "started_at" in result
assert "updated_at" in result
class TestProgressService:
"""Test ProgressService class."""
@pytest.fixture
def service(self):
"""Create a fresh ProgressService instance for each test."""
return ProgressService()
@pytest.fixture
def mock_broadcast(self):
"""Create a mock broadcast callback."""
return AsyncMock()
@pytest.mark.asyncio
async def test_start_progress(self, service):
"""Test starting a new progress operation."""
update = await service.start_progress(
progress_id="download-1",
progress_type=ProgressType.DOWNLOAD,
title="Downloading episode",
total=1000,
message="Starting...",
metadata={"episode": "S01E01"},
)
assert update.id == "download-1"
assert update.type == ProgressType.DOWNLOAD
assert update.status == ProgressStatus.STARTED
assert update.title == "Downloading episode"
assert update.total == 1000
assert update.message == "Starting..."
assert update.metadata["episode"] == "S01E01"
@pytest.mark.asyncio
async def test_start_progress_duplicate_id(self, service):
"""Test starting progress with duplicate ID raises error."""
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test",
)
with pytest.raises(ProgressServiceError, match="already exists"):
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test Duplicate",
)
@pytest.mark.asyncio
async def test_update_progress(self, service):
"""Test updating an existing progress operation."""
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test",
total=100,
)
update = await service.update_progress(
progress_id="test-1",
current=50,
message="Half way",
)
assert update.current == 50
assert update.total == 100
assert update.percent == 50.0
assert update.message == "Half way"
assert update.status == ProgressStatus.IN_PROGRESS
@pytest.mark.asyncio
async def test_update_progress_not_found(self, service):
"""Test updating non-existent progress raises error."""
with pytest.raises(ProgressServiceError, match="not found"):
await service.update_progress(
progress_id="nonexistent",
current=50,
)
@pytest.mark.asyncio
async def test_update_progress_percentage_calculation(self, service):
"""Test progress percentage is calculated correctly."""
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test",
total=200,
)
await service.update_progress(progress_id="test-1", current=50)
update = await service.get_progress("test-1")
assert update.percent == 25.0
await service.update_progress(progress_id="test-1", current=100)
update = await service.get_progress("test-1")
assert update.percent == 50.0
await service.update_progress(progress_id="test-1", current=200)
update = await service.get_progress("test-1")
assert update.percent == 100.0
@pytest.mark.asyncio
async def test_complete_progress(self, service):
"""Test completing a progress operation."""
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.SCAN,
title="Test Scan",
total=100,
)
await service.update_progress(progress_id="test-1", current=50)
update = await service.complete_progress(
progress_id="test-1",
message="Scan completed successfully",
metadata={"items_found": 42},
)
assert update.status == ProgressStatus.COMPLETED
assert update.percent == 100.0
assert update.current == update.total
assert update.message == "Scan completed successfully"
assert update.metadata["items_found"] == 42
# Should be moved to history
active_progress = await service.get_all_active_progress()
assert "test-1" not in active_progress
@pytest.mark.asyncio
async def test_fail_progress(self, service):
"""Test failing a progress operation."""
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test Download",
)
update = await service.fail_progress(
progress_id="test-1",
error_message="Network timeout",
metadata={"retry_count": 3},
)
assert update.status == ProgressStatus.FAILED
assert update.message == "Network timeout"
assert update.metadata["retry_count"] == 3
# Should be moved to history
active_progress = await service.get_all_active_progress()
assert "test-1" not in active_progress
@pytest.mark.asyncio
async def test_cancel_progress(self, service):
"""Test cancelling a progress operation."""
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test Download",
)
update = await service.cancel_progress(
progress_id="test-1",
message="Cancelled by user",
)
assert update.status == ProgressStatus.CANCELLED
assert update.message == "Cancelled by user"
# Should be moved to history
active_progress = await service.get_all_active_progress()
assert "test-1" not in active_progress
@pytest.mark.asyncio
async def test_get_progress(self, service):
"""Test retrieving progress by ID."""
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.SCAN,
title="Test",
)
progress = await service.get_progress("test-1")
assert progress is not None
assert progress.id == "test-1"
# Test non-existent progress
progress = await service.get_progress("nonexistent")
assert progress is None
@pytest.mark.asyncio
async def test_get_all_active_progress(self, service):
"""Test retrieving all active progress operations."""
await service.start_progress(
progress_id="download-1",
progress_type=ProgressType.DOWNLOAD,
title="Download 1",
)
await service.start_progress(
progress_id="download-2",
progress_type=ProgressType.DOWNLOAD,
title="Download 2",
)
await service.start_progress(
progress_id="scan-1",
progress_type=ProgressType.SCAN,
title="Scan 1",
)
all_progress = await service.get_all_active_progress()
assert len(all_progress) == 3
assert "download-1" in all_progress
assert "download-2" in all_progress
assert "scan-1" in all_progress
@pytest.mark.asyncio
async def test_get_all_active_progress_filtered(self, service):
"""Test retrieving active progress filtered by type."""
await service.start_progress(
progress_id="download-1",
progress_type=ProgressType.DOWNLOAD,
title="Download 1",
)
await service.start_progress(
progress_id="download-2",
progress_type=ProgressType.DOWNLOAD,
title="Download 2",
)
await service.start_progress(
progress_id="scan-1",
progress_type=ProgressType.SCAN,
title="Scan 1",
)
download_progress = await service.get_all_active_progress(
progress_type=ProgressType.DOWNLOAD
)
assert len(download_progress) == 2
assert "download-1" in download_progress
assert "download-2" in download_progress
assert "scan-1" not in download_progress
@pytest.mark.asyncio
async def test_history_management(self, service):
"""Test progress history is maintained with size limit."""
# Start and complete multiple progress operations
for i in range(60): # More than max_history_size (50)
await service.start_progress(
progress_id=f"test-{i}",
progress_type=ProgressType.DOWNLOAD,
title=f"Test {i}",
)
await service.complete_progress(
progress_id=f"test-{i}",
message="Completed",
)
# Check that oldest entries were removed
history = service._history
assert len(history) <= 50
# Most recent should be in history
recent_progress = await service.get_progress("test-59")
assert recent_progress is not None
@pytest.mark.asyncio
async def test_broadcast_callback(self, service, mock_broadcast):
"""Test broadcast callback is invoked correctly."""
service.set_broadcast_callback(mock_broadcast)
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test",
)
# Verify callback was called for start
mock_broadcast.assert_called_once()
call_args = mock_broadcast.call_args
assert call_args[1]["message_type"] == "download_progress"
assert call_args[1]["room"] == "download_progress"
assert "test-1" in str(call_args[1]["data"])
@pytest.mark.asyncio
async def test_broadcast_on_update(self, service, mock_broadcast):
"""Test broadcast on progress update."""
service.set_broadcast_callback(mock_broadcast)
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test",
total=100,
)
mock_broadcast.reset_mock()
# Update with significant change (>1%)
await service.update_progress(
progress_id="test-1",
current=50,
force_broadcast=True,
)
# Should have been called
assert mock_broadcast.call_count >= 1
@pytest.mark.asyncio
async def test_broadcast_on_complete(self, service, mock_broadcast):
"""Test broadcast on progress completion."""
service.set_broadcast_callback(mock_broadcast)
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.SCAN,
title="Test",
)
mock_broadcast.reset_mock()
await service.complete_progress(
progress_id="test-1",
message="Done",
)
# Should have been called
mock_broadcast.assert_called_once()
call_args = mock_broadcast.call_args
assert "completed" in str(call_args[1]["data"]).lower()
@pytest.mark.asyncio
async def test_broadcast_on_failure(self, service, mock_broadcast):
"""Test broadcast on progress failure."""
service.set_broadcast_callback(mock_broadcast)
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test",
)
mock_broadcast.reset_mock()
await service.fail_progress(
progress_id="test-1",
error_message="Test error",
)
# Should have been called
mock_broadcast.assert_called_once()
call_args = mock_broadcast.call_args
assert "failed" in str(call_args[1]["data"]).lower()
@pytest.mark.asyncio
async def test_clear_history(self, service):
"""Test clearing progress history."""
# Create and complete some progress
for i in range(5):
await service.start_progress(
progress_id=f"test-{i}",
progress_type=ProgressType.DOWNLOAD,
title=f"Test {i}",
)
await service.complete_progress(
progress_id=f"test-{i}",
message="Done",
)
# History should not be empty
assert len(service._history) > 0
# Clear history
await service.clear_history()
# History should now be empty
assert len(service._history) == 0
@pytest.mark.asyncio
async def test_concurrent_progress_operations(self, service):
"""Test handling multiple concurrent progress operations."""
async def create_and_complete_progress(id_num: int):
"""Helper to create and complete a progress."""
await service.start_progress(
progress_id=f"test-{id_num}",
progress_type=ProgressType.DOWNLOAD,
title=f"Test {id_num}",
total=100,
)
for i in range(0, 101, 10):
await service.update_progress(
progress_id=f"test-{id_num}",
current=i,
)
await asyncio.sleep(0.01)
await service.complete_progress(
progress_id=f"test-{id_num}",
message="Done",
)
# Run multiple concurrent operations
tasks = [create_and_complete_progress(i) for i in range(10)]
await asyncio.gather(*tasks)
# All should be in history
for i in range(10):
progress = await service.get_progress(f"test-{i}")
assert progress is not None
assert progress.status == ProgressStatus.COMPLETED
@pytest.mark.asyncio
async def test_update_with_metadata(self, service):
"""Test updating progress with metadata."""
await service.start_progress(
progress_id="test-1",
progress_type=ProgressType.DOWNLOAD,
title="Test",
metadata={"initial": "value"},
)
await service.update_progress(
progress_id="test-1",
current=50,
metadata={"additional": "data", "speed": 1.5},
)
progress = await service.get_progress("test-1")
assert progress.metadata["initial"] == "value"
assert progress.metadata["additional"] == "data"
assert progress.metadata["speed"] == 1.5