From 94de91ffa036b070398400de31b244ef4b0a0d8b Mon Sep 17 00:00:00 2001 From: Lukas Date: Fri, 17 Oct 2025 11:12:06 +0200 Subject: [PATCH] feat: implement WebSocket real-time progress updates - Add ProgressService for centralized progress tracking and broadcasting - Integrate ProgressService with DownloadService for download progress - Integrate ProgressService with AnimeService for scan progress - Add progress-related WebSocket message models (ScanProgress, ErrorNotification, etc.) - Initialize ProgressService with WebSocket callback in application startup - Add comprehensive unit tests for ProgressService - Update infrastructure.md with ProgressService documentation - Remove completed WebSocket Real-time Updates task from instructions.md The ProgressService provides: - Real-time progress tracking for downloads, scans, and queue operations - Automatic progress percentage calculation - Progress lifecycle management (start, update, complete, fail, cancel) - WebSocket integration for instant client updates - Progress history with size limits - Thread-safe operations using asyncio locks - Support for metadata and custom messages Benefits: - Decoupled progress tracking from WebSocket broadcasting - Single reusable service across all components - Supports multiple concurrent operations efficiently - Centralized progress tracking simplifies monitoring - Instant feedback to users on long-running operations --- infrastructure.md | 152 ++++++++ src/server/fastapi_app.py | 19 + src/server/models/websocket.py | 95 +++++ src/server/services/anime_service.py | 69 +++- src/server/services/download_service.py | 60 +++ src/server/services/progress_service.py | 485 +++++++++++++++++++++++ tests/unit/test_progress_service.py | 499 ++++++++++++++++++++++++ 7 files changed, 1375 insertions(+), 4 deletions(-) create mode 100644 src/server/services/progress_service.py create mode 100644 tests/unit/test_progress_service.py diff --git a/infrastructure.md b/infrastructure.md index 6abc50f..8ee77b7 100644 --- a/infrastructure.md +++ b/infrastructure.md @@ -545,3 +545,155 @@ Implemented comprehensive REST API endpoints for download queue management: - Router registered in `src/server/fastapi_app.py` via `app.include_router(download_router)` - Follows same patterns as other API routers (auth, anime, config) - Full OpenAPI documentation available at `/api/docs` + +### WebSocket Real-time Updates (October 2025) + +Implemented real-time progress tracking and WebSocket broadcasting for downloads, scans, and system events. + +#### ProgressService + +**File**: `src/server/services/progress_service.py` + +A centralized service for tracking and broadcasting real-time progress updates across the application. + +**Key Features**: + +- Track multiple concurrent progress operations (downloads, scans, queue changes) +- Automatic progress percentage calculation +- Progress lifecycle management (start, update, complete, fail, cancel) +- WebSocket integration for real-time client updates +- Progress history with configurable size limit (default: 50 items) +- Thread-safe operations using asyncio locks +- Support for progress metadata and custom messages + +**Progress Types**: + +- `DOWNLOAD` - File download progress +- `SCAN` - Library scan progress +- `QUEUE` - Queue operation progress +- `SYSTEM` - System-level operations +- `ERROR` - Error notifications + +**Progress Statuses**: + +- `STARTED` - Operation initiated +- `IN_PROGRESS` - Operation in progress +- `COMPLETED` - Successfully completed +- `FAILED` - Operation failed +- `CANCELLED` - Cancelled by user + +**Core Methods**: + +- `start_progress()` - Initialize new progress operation +- `update_progress()` - Update progress with current/total values +- `complete_progress()` - Mark operation as completed +- `fail_progress()` - Mark operation as failed +- `cancel_progress()` - Cancel ongoing operation +- `get_progress()` - Retrieve progress by ID +- `get_all_active_progress()` - Get all active operations (optionally filtered by type) + +**Broadcasting**: + +- Integrates with WebSocketService via callback +- Broadcasts to room-specific channels (e.g., `download_progress`, `scan_progress`) +- Configurable broadcast throttling (only on significant changes >1% or forced) +- Automatic progress state serialization to JSON + +**Singleton Pattern**: + +- Global instance via `get_progress_service()` factory +- Initialized during application startup with WebSocket callback + +#### Integration with Services + +**DownloadService Integration**: + +- Progress tracking for each download item +- Real-time progress updates during file download +- Automatic completion/failure notifications +- Progress metadata includes speed, ETA, downloaded bytes + +**AnimeService Integration**: + +- Progress tracking for library scans +- Scan progress with current/total file counts +- Scan completion with statistics +- Error notifications on scan failures + +#### WebSocket Message Models + +**File**: `src/server/models/websocket.py` + +Added progress-specific message models: + +- `ScanProgressMessage` - Scan progress updates +- `ScanCompleteMessage` - Scan completion notification +- `ScanFailedMessage` - Scan failure notification +- `ErrorNotificationMessage` - Critical error notifications +- `ProgressUpdateMessage` - Generic progress updates + +**WebSocket Message Types**: + +- `SCAN_PROGRESS` - Scan progress updates +- `SCAN_COMPLETE` - Scan completion +- `SCAN_FAILED` - Scan failure +- Extended existing types for downloads and queue updates + +#### WebSocket Rooms + +Clients can subscribe to specific progress channels: + +- `download_progress` - Download progress updates +- `scan_progress` - Library scan updates +- `queue_progress` - Queue operation updates +- `system_progress` - System-level updates + +Room subscription via client messages: + +```json +{ + "action": "join", + "room": "download_progress" +} +``` + +#### Application Startup + +**File**: `src/server/fastapi_app.py` + +Progress service initialized on application startup: + +1. Get ProgressService singleton instance +2. Get WebSocketService singleton instance +3. Register broadcast callback to link progress updates with WebSocket +4. Callback broadcasts progress messages to appropriate rooms + +#### Testing + +**File**: `tests/unit/test_progress_service.py` + +Comprehensive test coverage including: + +- Progress lifecycle operations (start, update, complete, fail, cancel) +- Percentage calculation accuracy +- History management and size limits +- Broadcast callback invocation +- Concurrent progress operations +- Metadata handling +- Error conditions and edge cases + +#### Architecture Benefits + +- **Decoupling**: ProgressService decouples progress tracking from WebSocket broadcasting +- **Reusability**: Single service used across all application components +- **Scalability**: Supports multiple concurrent operations efficiently +- **Observability**: Centralized progress tracking simplifies monitoring +- **Real-time UX**: Instant feedback to users on all long-running operations + +#### Future Enhancements + +- Persistent progress history (database storage) +- Progress rate calculation and trend analysis +- Multi-process progress synchronization (Redis/shared store) +- Progress event hooks for custom actions +- Client-side progress resumption after reconnection diff --git a/src/server/fastapi_app.py b/src/server/fastapi_app.py index 92a3b55..23eb4b2 100644 --- a/src/server/fastapi_app.py +++ b/src/server/fastapi_app.py @@ -29,6 +29,8 @@ from src.server.controllers.error_controller import ( from src.server.controllers.health_controller import router as health_router from src.server.controllers.page_controller import router as page_router from src.server.middleware.auth import AuthMiddleware +from src.server.services.progress_service import get_progress_service +from src.server.services.websocket_service import get_websocket_service # Initialize FastAPI app app = FastAPI( @@ -74,6 +76,23 @@ async def startup_event(): # Initialize SeriesApp with configured directory if settings.anime_directory: series_app = SeriesApp(settings.anime_directory) + + # Initialize progress service with websocket callback + progress_service = get_progress_service() + ws_service = get_websocket_service() + + async def broadcast_callback( + message_type: str, data: dict, room: str + ): + """Broadcast progress updates via WebSocket.""" + message = { + "type": message_type, + "data": data, + } + await ws_service.manager.broadcast_to_room(message, room) + + progress_service.set_broadcast_callback(broadcast_callback) + print("FastAPI application started successfully") except Exception as e: print(f"Error during startup: {e}") diff --git a/src/server/models/websocket.py b/src/server/models/websocket.py index 6ceca42..4807e9c 100644 --- a/src/server/models/websocket.py +++ b/src/server/models/websocket.py @@ -30,6 +30,11 @@ class WebSocketMessageType(str, Enum): QUEUE_PAUSED = "queue_paused" QUEUE_RESUMED = "queue_resumed" + # Progress-related messages + SCAN_PROGRESS = "scan_progress" + SCAN_COMPLETE = "scan_complete" + SCAN_FAILED = "scan_failed" + # System messages SYSTEM_INFO = "system_info" SYSTEM_WARNING = "system_warning" @@ -188,3 +193,93 @@ class RoomSubscriptionRequest(BaseModel): room: str = Field( ..., min_length=1, description="Room name to join or leave" ) + + +class ScanProgressMessage(BaseModel): + """Scan progress update message.""" + + type: WebSocketMessageType = Field( + default=WebSocketMessageType.SCAN_PROGRESS, + description="Message type", + ) + timestamp: str = Field( + default_factory=lambda: datetime.utcnow().isoformat(), + description="ISO 8601 timestamp", + ) + data: Dict[str, Any] = Field( + ..., + description="Scan progress data including current, total, percent", + ) + + +class ScanCompleteMessage(BaseModel): + """Scan completion message.""" + + type: WebSocketMessageType = Field( + default=WebSocketMessageType.SCAN_COMPLETE, + description="Message type", + ) + timestamp: str = Field( + default_factory=lambda: datetime.utcnow().isoformat(), + description="ISO 8601 timestamp", + ) + data: Dict[str, Any] = Field( + ..., + description="Scan completion data including series_found, duration", + ) + + +class ScanFailedMessage(BaseModel): + """Scan failure message.""" + + type: WebSocketMessageType = Field( + default=WebSocketMessageType.SCAN_FAILED, + description="Message type", + ) + timestamp: str = Field( + default_factory=lambda: datetime.utcnow().isoformat(), + description="ISO 8601 timestamp", + ) + data: Dict[str, Any] = Field( + ..., description="Scan error data including error_message" + ) + + +class ErrorNotificationMessage(BaseModel): + """Error notification message for critical errors.""" + + type: WebSocketMessageType = Field( + default=WebSocketMessageType.SYSTEM_ERROR, + description="Message type", + ) + timestamp: str = Field( + default_factory=lambda: datetime.utcnow().isoformat(), + description="ISO 8601 timestamp", + ) + data: Dict[str, Any] = Field( + ..., + description=( + "Error notification data including severity, message, details" + ), + ) + + +class ProgressUpdateMessage(BaseModel): + """Generic progress update message. + + Can be used for any type of progress (download, scan, queue, etc.) + """ + + type: WebSocketMessageType = Field( + ..., description="Type of progress message" + ) + timestamp: str = Field( + default_factory=lambda: datetime.utcnow().isoformat(), + description="ISO 8601 timestamp", + ) + data: Dict[str, Any] = Field( + ..., + description=( + "Progress data including id, status, percent, current, total" + ), + ) diff --git a/src/server/services/anime_service.py b/src/server/services/anime_service.py index 5d27f1f..8ffabfe 100644 --- a/src/server/services/anime_service.py +++ b/src/server/services/anime_service.py @@ -3,11 +3,16 @@ from __future__ import annotations import asyncio from concurrent.futures import ThreadPoolExecutor from functools import lru_cache -from typing import List, Optional +from typing import Callable, List, Optional import structlog from src.core.SeriesApp import SeriesApp +from src.server.services.progress_service import ( + ProgressService, + ProgressType, + get_progress_service, +) logger = structlog.get_logger(__name__) @@ -24,9 +29,15 @@ class AnimeService: - Adds simple in-memory caching for read operations """ - def __init__(self, directory: str, max_workers: int = 4): + def __init__( + self, + directory: str, + max_workers: int = 4, + progress_service: Optional[ProgressService] = None, + ): self._directory = directory self._executor = ThreadPoolExecutor(max_workers=max_workers) + self._progress_service = progress_service or get_progress_service() # SeriesApp is blocking; instantiate per-service try: self._app = SeriesApp(directory) @@ -75,20 +86,70 @@ class AnimeService: logger.exception("search failed") raise AnimeServiceError("Search failed") from e - async def rescan(self, callback=None) -> None: + async def rescan(self, callback: Optional[Callable] = None) -> None: """Trigger a re-scan. Accepts an optional callback function. The callback is executed in the threadpool by SeriesApp. + Progress updates are tracked and broadcasted via ProgressService. """ + scan_id = "library_scan" + try: - await self._run_in_executor(self._app.ReScan, callback) + # Start progress tracking + await self._progress_service.start_progress( + progress_id=scan_id, + progress_type=ProgressType.SCAN, + title="Scanning anime library", + message="Initializing scan...", + ) + + # Create wrapped callback for progress updates + def progress_callback(progress_data: dict) -> None: + """Update progress during scan.""" + try: + if callback: + callback(progress_data) + + # Update progress service + current = progress_data.get("current", 0) + total = progress_data.get("total", 0) + message = progress_data.get("message", "Scanning...") + + asyncio.create_task( + self._progress_service.update_progress( + progress_id=scan_id, + current=current, + total=total, + message=message, + ) + ) + except Exception as e: + logger.error("Scan progress callback error", error=str(e)) + + # Run scan + await self._run_in_executor(self._app.ReScan, progress_callback) + # invalidate cache try: self._cached_list_missing.cache_clear() except Exception: pass + + # Complete progress tracking + await self._progress_service.complete_progress( + progress_id=scan_id, + message="Scan completed successfully", + ) + except Exception as e: logger.exception("rescan failed") + + # Fail progress tracking + await self._progress_service.fail_progress( + progress_id=scan_id, + error_message=str(e), + ) + raise AnimeServiceError("Rescan failed") from e async def download(self, serie_folder: str, season: int, episode: int, key: str, callback=None) -> bool: diff --git a/src/server/services/download_service.py b/src/server/services/download_service.py index 377c5a4..dfa19a9 100644 --- a/src/server/services/download_service.py +++ b/src/server/services/download_service.py @@ -27,6 +27,11 @@ from src.server.models.download import ( QueueStatus, ) from src.server.services.anime_service import AnimeService, AnimeServiceError +from src.server.services.progress_service import ( + ProgressService, + ProgressType, + get_progress_service, +) logger = structlog.get_logger(__name__) @@ -53,6 +58,7 @@ class DownloadService: max_concurrent_downloads: int = 2, max_retries: int = 3, persistence_path: str = "./data/download_queue.json", + progress_service: Optional[ProgressService] = None, ): """Initialize the download service. @@ -61,11 +67,13 @@ class DownloadService: max_concurrent_downloads: Maximum simultaneous downloads max_retries: Maximum retry attempts for failed downloads persistence_path: Path to persist queue state + progress_service: Optional progress service for tracking """ self._anime_service = anime_service self._max_concurrent = max_concurrent_downloads self._max_retries = max_retries self._persistence_path = Path(persistence_path) + self._progress_service = progress_service or get_progress_service() # Queue storage by status self._pending_queue: deque[DownloadItem] = deque() @@ -500,6 +508,23 @@ class DownloadService: if item.progress.speed_mbps: self._download_speeds.append(item.progress.speed_mbps) + # Update progress service + if item.progress.total_mb and item.progress.total_mb > 0: + current_mb = int(item.progress.downloaded_mb) + total_mb = int(item.progress.total_mb) + + asyncio.create_task( + self._progress_service.update_progress( + progress_id=f"download_{item.id}", + current=current_mb, + total=total_mb, + metadata={ + "speed_mbps": item.progress.speed_mbps, + "eta_seconds": item.progress.eta_seconds, + }, + ) + ) + # Broadcast update (fire and forget) asyncio.create_task( self._broadcast_update( @@ -535,6 +560,22 @@ class DownloadService: episode=item.episode.episode, ) + # Start progress tracking + await self._progress_service.start_progress( + progress_id=f"download_{item.id}", + progress_type=ProgressType.DOWNLOAD, + title=f"Downloading {item.serie_name}", + message=( + f"S{item.episode.season:02d}E{item.episode.episode:02d}" + ), + metadata={ + "item_id": item.id, + "serie_name": item.serie_name, + "season": item.episode.season, + "episode": item.episode.episode, + }, + ) + # Create progress callback progress_callback = self._create_progress_callback(item) @@ -561,6 +602,18 @@ class DownloadService: logger.info( "Download completed successfully", item_id=item.id ) + + # Complete progress tracking + await self._progress_service.complete_progress( + progress_id=f"download_{item.id}", + message="Download completed successfully", + metadata={ + "downloaded_mb": item.progress.downloaded_mb + if item.progress + else 0, + }, + ) + await self._broadcast_update( "download_completed", {"item_id": item.id} ) @@ -581,6 +634,13 @@ class DownloadService: retry_count=item.retry_count, ) + # Fail progress tracking + await self._progress_service.fail_progress( + progress_id=f"download_{item.id}", + error_message=str(e), + metadata={"retry_count": item.retry_count}, + ) + await self._broadcast_update( "download_failed", {"item_id": item.id, "error": item.error}, diff --git a/src/server/services/progress_service.py b/src/server/services/progress_service.py new file mode 100644 index 0000000..f347674 --- /dev/null +++ b/src/server/services/progress_service.py @@ -0,0 +1,485 @@ +"""Progress service for managing real-time progress updates. + +This module provides a centralized service for tracking and broadcasting +real-time progress updates for downloads, scans, queue changes, and +system events. It integrates with the WebSocket service to push updates +to connected clients. +""" +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field +from datetime import datetime +from enum import Enum +from typing import Any, Callable, Dict, Optional + +import structlog + +logger = structlog.get_logger(__name__) + + +class ProgressType(str, Enum): + """Types of progress updates.""" + + DOWNLOAD = "download" + SCAN = "scan" + QUEUE = "queue" + SYSTEM = "system" + ERROR = "error" + + +class ProgressStatus(str, Enum): + """Status of a progress operation.""" + + STARTED = "started" + IN_PROGRESS = "in_progress" + COMPLETED = "completed" + FAILED = "failed" + CANCELLED = "cancelled" + + +@dataclass +class ProgressUpdate: + """Represents a progress update event. + + Attributes: + id: Unique identifier for this progress operation + type: Type of progress (download, scan, etc.) + status: Current status of the operation + title: Human-readable title + message: Detailed message + percent: Completion percentage (0-100) + current: Current progress value + total: Total progress value + metadata: Additional metadata + started_at: When operation started + updated_at: When last updated + """ + + id: str + type: ProgressType + status: ProgressStatus + title: str + message: str = "" + percent: float = 0.0 + current: int = 0 + total: int = 0 + metadata: Dict[str, Any] = field(default_factory=dict) + started_at: datetime = field(default_factory=datetime.utcnow) + updated_at: datetime = field(default_factory=datetime.utcnow) + + def to_dict(self) -> Dict[str, Any]: + """Convert progress update to dictionary.""" + return { + "id": self.id, + "type": self.type.value, + "status": self.status.value, + "title": self.title, + "message": self.message, + "percent": round(self.percent, 2), + "current": self.current, + "total": self.total, + "metadata": self.metadata, + "started_at": self.started_at.isoformat(), + "updated_at": self.updated_at.isoformat(), + } + + +class ProgressServiceError(Exception): + """Service-level exception for progress operations.""" + + +class ProgressService: + """Manages real-time progress updates and broadcasting. + + Features: + - Track multiple concurrent progress operations + - Calculate progress percentages and rates + - Broadcast updates via WebSocket + - Manage progress lifecycle (start, update, complete, fail) + - Support for different progress types (download, scan, queue) + """ + + def __init__(self): + """Initialize the progress service.""" + # Active progress operations: id -> ProgressUpdate + self._active_progress: Dict[str, ProgressUpdate] = {} + + # Completed progress history (limited size) + self._history: Dict[str, ProgressUpdate] = {} + self._max_history_size = 50 + + # WebSocket broadcast callback + self._broadcast_callback: Optional[Callable] = None + + # Lock for thread-safe operations + self._lock = asyncio.Lock() + + logger.info("ProgressService initialized") + + def set_broadcast_callback(self, callback: Callable) -> None: + """Set callback for broadcasting progress updates via WebSocket. + + Args: + callback: Async function to call for broadcasting updates + """ + self._broadcast_callback = callback + logger.debug("Progress broadcast callback registered") + + async def _broadcast(self, update: ProgressUpdate, room: str) -> None: + """Broadcast progress update to WebSocket clients. + + Args: + update: Progress update to broadcast + room: WebSocket room to broadcast to + """ + if self._broadcast_callback: + try: + await self._broadcast_callback( + message_type=f"{update.type.value}_progress", + data=update.to_dict(), + room=room, + ) + except Exception as e: + logger.error( + "Failed to broadcast progress update", + error=str(e), + progress_id=update.id, + ) + + async def start_progress( + self, + progress_id: str, + progress_type: ProgressType, + title: str, + total: int = 0, + message: str = "", + metadata: Optional[Dict[str, Any]] = None, + ) -> ProgressUpdate: + """Start a new progress operation. + + Args: + progress_id: Unique identifier for this progress + progress_type: Type of progress operation + title: Human-readable title + total: Total items/bytes to process + message: Initial message + metadata: Additional metadata + + Returns: + Created progress update object + + Raises: + ProgressServiceError: If progress already exists + """ + async with self._lock: + if progress_id in self._active_progress: + raise ProgressServiceError( + f"Progress with id '{progress_id}' already exists" + ) + + update = ProgressUpdate( + id=progress_id, + type=progress_type, + status=ProgressStatus.STARTED, + title=title, + message=message, + total=total, + metadata=metadata or {}, + ) + + self._active_progress[progress_id] = update + + logger.info( + "Progress started", + progress_id=progress_id, + type=progress_type.value, + title=title, + ) + + # Broadcast to appropriate room + room = f"{progress_type.value}_progress" + await self._broadcast(update, room) + + return update + + async def update_progress( + self, + progress_id: str, + current: Optional[int] = None, + total: Optional[int] = None, + message: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + force_broadcast: bool = False, + ) -> ProgressUpdate: + """Update an existing progress operation. + + Args: + progress_id: Progress identifier + current: Current progress value + total: Updated total value + message: Updated message + metadata: Additional metadata to merge + force_broadcast: Force broadcasting even for small changes + + Returns: + Updated progress object + + Raises: + ProgressServiceError: If progress not found + """ + async with self._lock: + if progress_id not in self._active_progress: + raise ProgressServiceError( + f"Progress with id '{progress_id}' not found" + ) + + update = self._active_progress[progress_id] + old_percent = update.percent + + # Update fields + if current is not None: + update.current = current + if total is not None: + update.total = total + if message is not None: + update.message = message + if metadata: + update.metadata.update(metadata) + + # Calculate percentage + if update.total > 0: + update.percent = (update.current / update.total) * 100 + else: + update.percent = 0.0 + + update.status = ProgressStatus.IN_PROGRESS + update.updated_at = datetime.utcnow() + + # Only broadcast if significant change or forced + percent_change = abs(update.percent - old_percent) + should_broadcast = force_broadcast or percent_change >= 1.0 + + if should_broadcast: + room = f"{update.type.value}_progress" + await self._broadcast(update, room) + + return update + + async def complete_progress( + self, + progress_id: str, + message: str = "Completed successfully", + metadata: Optional[Dict[str, Any]] = None, + ) -> ProgressUpdate: + """Mark a progress operation as completed. + + Args: + progress_id: Progress identifier + message: Completion message + metadata: Additional metadata + + Returns: + Completed progress object + + Raises: + ProgressServiceError: If progress not found + """ + async with self._lock: + if progress_id not in self._active_progress: + raise ProgressServiceError( + f"Progress with id '{progress_id}' not found" + ) + + update = self._active_progress[progress_id] + update.status = ProgressStatus.COMPLETED + update.message = message + update.percent = 100.0 + update.current = update.total + update.updated_at = datetime.utcnow() + + if metadata: + update.metadata.update(metadata) + + # Move to history + del self._active_progress[progress_id] + self._add_to_history(update) + + logger.info( + "Progress completed", + progress_id=progress_id, + type=update.type.value, + ) + + # Broadcast completion + room = f"{update.type.value}_progress" + await self._broadcast(update, room) + + return update + + async def fail_progress( + self, + progress_id: str, + error_message: str, + metadata: Optional[Dict[str, Any]] = None, + ) -> ProgressUpdate: + """Mark a progress operation as failed. + + Args: + progress_id: Progress identifier + error_message: Error description + metadata: Additional error metadata + + Returns: + Failed progress object + + Raises: + ProgressServiceError: If progress not found + """ + async with self._lock: + if progress_id not in self._active_progress: + raise ProgressServiceError( + f"Progress with id '{progress_id}' not found" + ) + + update = self._active_progress[progress_id] + update.status = ProgressStatus.FAILED + update.message = error_message + update.updated_at = datetime.utcnow() + + if metadata: + update.metadata.update(metadata) + + # Move to history + del self._active_progress[progress_id] + self._add_to_history(update) + + logger.error( + "Progress failed", + progress_id=progress_id, + type=update.type.value, + error=error_message, + ) + + # Broadcast failure + room = f"{update.type.value}_progress" + await self._broadcast(update, room) + + return update + + async def cancel_progress( + self, + progress_id: str, + message: str = "Cancelled by user", + ) -> ProgressUpdate: + """Cancel a progress operation. + + Args: + progress_id: Progress identifier + message: Cancellation message + + Returns: + Cancelled progress object + + Raises: + ProgressServiceError: If progress not found + """ + async with self._lock: + if progress_id not in self._active_progress: + raise ProgressServiceError( + f"Progress with id '{progress_id}' not found" + ) + + update = self._active_progress[progress_id] + update.status = ProgressStatus.CANCELLED + update.message = message + update.updated_at = datetime.utcnow() + + # Move to history + del self._active_progress[progress_id] + self._add_to_history(update) + + logger.info( + "Progress cancelled", + progress_id=progress_id, + type=update.type.value, + ) + + # Broadcast cancellation + room = f"{update.type.value}_progress" + await self._broadcast(update, room) + + return update + + def _add_to_history(self, update: ProgressUpdate) -> None: + """Add completed progress to history with size limit.""" + self._history[update.id] = update + + # Maintain history size limit + if len(self._history) > self._max_history_size: + # Remove oldest entries + oldest_keys = sorted( + self._history.keys(), + key=lambda k: self._history[k].updated_at, + )[: len(self._history) - self._max_history_size] + + for key in oldest_keys: + del self._history[key] + + async def get_progress(self, progress_id: str) -> Optional[ProgressUpdate]: + """Get current progress state. + + Args: + progress_id: Progress identifier + + Returns: + Progress update object or None if not found + """ + async with self._lock: + if progress_id in self._active_progress: + return self._active_progress[progress_id] + if progress_id in self._history: + return self._history[progress_id] + return None + + async def get_all_active_progress( + self, progress_type: Optional[ProgressType] = None + ) -> Dict[str, ProgressUpdate]: + """Get all active progress operations. + + Args: + progress_type: Optional filter by progress type + + Returns: + Dictionary of progress_id -> ProgressUpdate + """ + async with self._lock: + if progress_type: + return { + pid: update + for pid, update in self._active_progress.items() + if update.type == progress_type + } + return self._active_progress.copy() + + async def clear_history(self) -> None: + """Clear progress history.""" + async with self._lock: + self._history.clear() + logger.info("Progress history cleared") + + +# Global singleton instance +_progress_service: Optional[ProgressService] = None + + +def get_progress_service() -> ProgressService: + """Get or create the global progress service instance. + + Returns: + Global ProgressService instance + """ + global _progress_service + if _progress_service is None: + _progress_service = ProgressService() + return _progress_service diff --git a/tests/unit/test_progress_service.py b/tests/unit/test_progress_service.py new file mode 100644 index 0000000..dbc2770 --- /dev/null +++ b/tests/unit/test_progress_service.py @@ -0,0 +1,499 @@ +"""Unit tests for ProgressService. + +This module contains comprehensive tests for the progress tracking service, +including progress lifecycle, broadcasting, error handling, and concurrency. +""" +import asyncio +from datetime import datetime +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from src.server.services.progress_service import ( + ProgressService, + ProgressServiceError, + ProgressStatus, + ProgressType, + ProgressUpdate, +) + + +class TestProgressUpdate: + """Test ProgressUpdate dataclass.""" + + def test_progress_update_creation(self): + """Test creating a progress update.""" + update = ProgressUpdate( + id="test-1", + type=ProgressType.DOWNLOAD, + status=ProgressStatus.STARTED, + title="Test Download", + message="Starting download", + total=100, + ) + + assert update.id == "test-1" + assert update.type == ProgressType.DOWNLOAD + assert update.status == ProgressStatus.STARTED + assert update.title == "Test Download" + assert update.message == "Starting download" + assert update.total == 100 + assert update.current == 0 + assert update.percent == 0.0 + + def test_progress_update_to_dict(self): + """Test converting progress update to dictionary.""" + update = ProgressUpdate( + id="test-1", + type=ProgressType.SCAN, + status=ProgressStatus.IN_PROGRESS, + title="Test Scan", + message="Scanning files", + current=50, + total=100, + metadata={"test_key": "test_value"}, + ) + + result = update.to_dict() + + assert result["id"] == "test-1" + assert result["type"] == "scan" + assert result["status"] == "in_progress" + assert result["title"] == "Test Scan" + assert result["message"] == "Scanning files" + assert result["current"] == 50 + assert result["total"] == 100 + assert result["percent"] == 0.0 + assert result["metadata"]["test_key"] == "test_value" + assert "started_at" in result + assert "updated_at" in result + + +class TestProgressService: + """Test ProgressService class.""" + + @pytest.fixture + def service(self): + """Create a fresh ProgressService instance for each test.""" + return ProgressService() + + @pytest.fixture + def mock_broadcast(self): + """Create a mock broadcast callback.""" + return AsyncMock() + + @pytest.mark.asyncio + async def test_start_progress(self, service): + """Test starting a new progress operation.""" + update = await service.start_progress( + progress_id="download-1", + progress_type=ProgressType.DOWNLOAD, + title="Downloading episode", + total=1000, + message="Starting...", + metadata={"episode": "S01E01"}, + ) + + assert update.id == "download-1" + assert update.type == ProgressType.DOWNLOAD + assert update.status == ProgressStatus.STARTED + assert update.title == "Downloading episode" + assert update.total == 1000 + assert update.message == "Starting..." + assert update.metadata["episode"] == "S01E01" + + @pytest.mark.asyncio + async def test_start_progress_duplicate_id(self, service): + """Test starting progress with duplicate ID raises error.""" + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test", + ) + + with pytest.raises(ProgressServiceError, match="already exists"): + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test Duplicate", + ) + + @pytest.mark.asyncio + async def test_update_progress(self, service): + """Test updating an existing progress operation.""" + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test", + total=100, + ) + + update = await service.update_progress( + progress_id="test-1", + current=50, + message="Half way", + ) + + assert update.current == 50 + assert update.total == 100 + assert update.percent == 50.0 + assert update.message == "Half way" + assert update.status == ProgressStatus.IN_PROGRESS + + @pytest.mark.asyncio + async def test_update_progress_not_found(self, service): + """Test updating non-existent progress raises error.""" + with pytest.raises(ProgressServiceError, match="not found"): + await service.update_progress( + progress_id="nonexistent", + current=50, + ) + + @pytest.mark.asyncio + async def test_update_progress_percentage_calculation(self, service): + """Test progress percentage is calculated correctly.""" + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test", + total=200, + ) + + await service.update_progress(progress_id="test-1", current=50) + update = await service.get_progress("test-1") + assert update.percent == 25.0 + + await service.update_progress(progress_id="test-1", current=100) + update = await service.get_progress("test-1") + assert update.percent == 50.0 + + await service.update_progress(progress_id="test-1", current=200) + update = await service.get_progress("test-1") + assert update.percent == 100.0 + + @pytest.mark.asyncio + async def test_complete_progress(self, service): + """Test completing a progress operation.""" + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.SCAN, + title="Test Scan", + total=100, + ) + + await service.update_progress(progress_id="test-1", current=50) + + update = await service.complete_progress( + progress_id="test-1", + message="Scan completed successfully", + metadata={"items_found": 42}, + ) + + assert update.status == ProgressStatus.COMPLETED + assert update.percent == 100.0 + assert update.current == update.total + assert update.message == "Scan completed successfully" + assert update.metadata["items_found"] == 42 + + # Should be moved to history + active_progress = await service.get_all_active_progress() + assert "test-1" not in active_progress + + @pytest.mark.asyncio + async def test_fail_progress(self, service): + """Test failing a progress operation.""" + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test Download", + ) + + update = await service.fail_progress( + progress_id="test-1", + error_message="Network timeout", + metadata={"retry_count": 3}, + ) + + assert update.status == ProgressStatus.FAILED + assert update.message == "Network timeout" + assert update.metadata["retry_count"] == 3 + + # Should be moved to history + active_progress = await service.get_all_active_progress() + assert "test-1" not in active_progress + + @pytest.mark.asyncio + async def test_cancel_progress(self, service): + """Test cancelling a progress operation.""" + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test Download", + ) + + update = await service.cancel_progress( + progress_id="test-1", + message="Cancelled by user", + ) + + assert update.status == ProgressStatus.CANCELLED + assert update.message == "Cancelled by user" + + # Should be moved to history + active_progress = await service.get_all_active_progress() + assert "test-1" not in active_progress + + @pytest.mark.asyncio + async def test_get_progress(self, service): + """Test retrieving progress by ID.""" + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.SCAN, + title="Test", + ) + + progress = await service.get_progress("test-1") + assert progress is not None + assert progress.id == "test-1" + + # Test non-existent progress + progress = await service.get_progress("nonexistent") + assert progress is None + + @pytest.mark.asyncio + async def test_get_all_active_progress(self, service): + """Test retrieving all active progress operations.""" + await service.start_progress( + progress_id="download-1", + progress_type=ProgressType.DOWNLOAD, + title="Download 1", + ) + await service.start_progress( + progress_id="download-2", + progress_type=ProgressType.DOWNLOAD, + title="Download 2", + ) + await service.start_progress( + progress_id="scan-1", + progress_type=ProgressType.SCAN, + title="Scan 1", + ) + + all_progress = await service.get_all_active_progress() + assert len(all_progress) == 3 + assert "download-1" in all_progress + assert "download-2" in all_progress + assert "scan-1" in all_progress + + @pytest.mark.asyncio + async def test_get_all_active_progress_filtered(self, service): + """Test retrieving active progress filtered by type.""" + await service.start_progress( + progress_id="download-1", + progress_type=ProgressType.DOWNLOAD, + title="Download 1", + ) + await service.start_progress( + progress_id="download-2", + progress_type=ProgressType.DOWNLOAD, + title="Download 2", + ) + await service.start_progress( + progress_id="scan-1", + progress_type=ProgressType.SCAN, + title="Scan 1", + ) + + download_progress = await service.get_all_active_progress( + progress_type=ProgressType.DOWNLOAD + ) + assert len(download_progress) == 2 + assert "download-1" in download_progress + assert "download-2" in download_progress + assert "scan-1" not in download_progress + + @pytest.mark.asyncio + async def test_history_management(self, service): + """Test progress history is maintained with size limit.""" + # Start and complete multiple progress operations + for i in range(60): # More than max_history_size (50) + await service.start_progress( + progress_id=f"test-{i}", + progress_type=ProgressType.DOWNLOAD, + title=f"Test {i}", + ) + await service.complete_progress( + progress_id=f"test-{i}", + message="Completed", + ) + + # Check that oldest entries were removed + history = service._history + assert len(history) <= 50 + + # Most recent should be in history + recent_progress = await service.get_progress("test-59") + assert recent_progress is not None + + @pytest.mark.asyncio + async def test_broadcast_callback(self, service, mock_broadcast): + """Test broadcast callback is invoked correctly.""" + service.set_broadcast_callback(mock_broadcast) + + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test", + ) + + # Verify callback was called for start + mock_broadcast.assert_called_once() + call_args = mock_broadcast.call_args + assert call_args[1]["message_type"] == "download_progress" + assert call_args[1]["room"] == "download_progress" + assert "test-1" in str(call_args[1]["data"]) + + @pytest.mark.asyncio + async def test_broadcast_on_update(self, service, mock_broadcast): + """Test broadcast on progress update.""" + service.set_broadcast_callback(mock_broadcast) + + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test", + total=100, + ) + mock_broadcast.reset_mock() + + # Update with significant change (>1%) + await service.update_progress( + progress_id="test-1", + current=50, + force_broadcast=True, + ) + + # Should have been called + assert mock_broadcast.call_count >= 1 + + @pytest.mark.asyncio + async def test_broadcast_on_complete(self, service, mock_broadcast): + """Test broadcast on progress completion.""" + service.set_broadcast_callback(mock_broadcast) + + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.SCAN, + title="Test", + ) + mock_broadcast.reset_mock() + + await service.complete_progress( + progress_id="test-1", + message="Done", + ) + + # Should have been called + mock_broadcast.assert_called_once() + call_args = mock_broadcast.call_args + assert "completed" in str(call_args[1]["data"]).lower() + + @pytest.mark.asyncio + async def test_broadcast_on_failure(self, service, mock_broadcast): + """Test broadcast on progress failure.""" + service.set_broadcast_callback(mock_broadcast) + + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test", + ) + mock_broadcast.reset_mock() + + await service.fail_progress( + progress_id="test-1", + error_message="Test error", + ) + + # Should have been called + mock_broadcast.assert_called_once() + call_args = mock_broadcast.call_args + assert "failed" in str(call_args[1]["data"]).lower() + + @pytest.mark.asyncio + async def test_clear_history(self, service): + """Test clearing progress history.""" + # Create and complete some progress + for i in range(5): + await service.start_progress( + progress_id=f"test-{i}", + progress_type=ProgressType.DOWNLOAD, + title=f"Test {i}", + ) + await service.complete_progress( + progress_id=f"test-{i}", + message="Done", + ) + + # History should not be empty + assert len(service._history) > 0 + + # Clear history + await service.clear_history() + + # History should now be empty + assert len(service._history) == 0 + + @pytest.mark.asyncio + async def test_concurrent_progress_operations(self, service): + """Test handling multiple concurrent progress operations.""" + + async def create_and_complete_progress(id_num: int): + """Helper to create and complete a progress.""" + await service.start_progress( + progress_id=f"test-{id_num}", + progress_type=ProgressType.DOWNLOAD, + title=f"Test {id_num}", + total=100, + ) + for i in range(0, 101, 10): + await service.update_progress( + progress_id=f"test-{id_num}", + current=i, + ) + await asyncio.sleep(0.01) + await service.complete_progress( + progress_id=f"test-{id_num}", + message="Done", + ) + + # Run multiple concurrent operations + tasks = [create_and_complete_progress(i) for i in range(10)] + await asyncio.gather(*tasks) + + # All should be in history + for i in range(10): + progress = await service.get_progress(f"test-{i}") + assert progress is not None + assert progress.status == ProgressStatus.COMPLETED + + @pytest.mark.asyncio + async def test_update_with_metadata(self, service): + """Test updating progress with metadata.""" + await service.start_progress( + progress_id="test-1", + progress_type=ProgressType.DOWNLOAD, + title="Test", + metadata={"initial": "value"}, + ) + + await service.update_progress( + progress_id="test-1", + current=50, + metadata={"additional": "data", "speed": 1.5}, + ) + + progress = await service.get_progress("test-1") + assert progress.metadata["initial"] == "value" + assert progress.metadata["additional"] == "data" + assert progress.metadata["speed"] == 1.5