feat: implement WebSocket real-time progress updates
- Add ProgressService for centralized progress tracking and broadcasting - Integrate ProgressService with DownloadService for download progress - Integrate ProgressService with AnimeService for scan progress - Add progress-related WebSocket message models (ScanProgress, ErrorNotification, etc.) - Initialize ProgressService with WebSocket callback in application startup - Add comprehensive unit tests for ProgressService - Update infrastructure.md with ProgressService documentation - Remove completed WebSocket Real-time Updates task from instructions.md The ProgressService provides: - Real-time progress tracking for downloads, scans, and queue operations - Automatic progress percentage calculation - Progress lifecycle management (start, update, complete, fail, cancel) - WebSocket integration for instant client updates - Progress history with size limits - Thread-safe operations using asyncio locks - Support for metadata and custom messages Benefits: - Decoupled progress tracking from WebSocket broadcasting - Single reusable service across all components - Supports multiple concurrent operations efficiently - Centralized progress tracking simplifies monitoring - Instant feedback to users on long-running operations
This commit is contained in:
parent
42a07be4cb
commit
94de91ffa0
@ -545,3 +545,155 @@ Implemented comprehensive REST API endpoints for download queue management:
|
||||
- Router registered in `src/server/fastapi_app.py` via `app.include_router(download_router)`
|
||||
- Follows same patterns as other API routers (auth, anime, config)
|
||||
- Full OpenAPI documentation available at `/api/docs`
|
||||
|
||||
### WebSocket Real-time Updates (October 2025)
|
||||
|
||||
Implemented real-time progress tracking and WebSocket broadcasting for downloads, scans, and system events.
|
||||
|
||||
#### ProgressService
|
||||
|
||||
**File**: `src/server/services/progress_service.py`
|
||||
|
||||
A centralized service for tracking and broadcasting real-time progress updates across the application.
|
||||
|
||||
**Key Features**:
|
||||
|
||||
- Track multiple concurrent progress operations (downloads, scans, queue changes)
|
||||
- Automatic progress percentage calculation
|
||||
- Progress lifecycle management (start, update, complete, fail, cancel)
|
||||
- WebSocket integration for real-time client updates
|
||||
- Progress history with configurable size limit (default: 50 items)
|
||||
- Thread-safe operations using asyncio locks
|
||||
- Support for progress metadata and custom messages
|
||||
|
||||
**Progress Types**:
|
||||
|
||||
- `DOWNLOAD` - File download progress
|
||||
- `SCAN` - Library scan progress
|
||||
- `QUEUE` - Queue operation progress
|
||||
- `SYSTEM` - System-level operations
|
||||
- `ERROR` - Error notifications
|
||||
|
||||
**Progress Statuses**:
|
||||
|
||||
- `STARTED` - Operation initiated
|
||||
- `IN_PROGRESS` - Operation in progress
|
||||
- `COMPLETED` - Successfully completed
|
||||
- `FAILED` - Operation failed
|
||||
- `CANCELLED` - Cancelled by user
|
||||
|
||||
**Core Methods**:
|
||||
|
||||
- `start_progress()` - Initialize new progress operation
|
||||
- `update_progress()` - Update progress with current/total values
|
||||
- `complete_progress()` - Mark operation as completed
|
||||
- `fail_progress()` - Mark operation as failed
|
||||
- `cancel_progress()` - Cancel ongoing operation
|
||||
- `get_progress()` - Retrieve progress by ID
|
||||
- `get_all_active_progress()` - Get all active operations (optionally filtered by type)
|
||||
|
||||
**Broadcasting**:
|
||||
|
||||
- Integrates with WebSocketService via callback
|
||||
- Broadcasts to room-specific channels (e.g., `download_progress`, `scan_progress`)
|
||||
- Configurable broadcast throttling (only on significant changes >1% or forced)
|
||||
- Automatic progress state serialization to JSON
|
||||
|
||||
**Singleton Pattern**:
|
||||
|
||||
- Global instance via `get_progress_service()` factory
|
||||
- Initialized during application startup with WebSocket callback
|
||||
|
||||
#### Integration with Services
|
||||
|
||||
**DownloadService Integration**:
|
||||
|
||||
- Progress tracking for each download item
|
||||
- Real-time progress updates during file download
|
||||
- Automatic completion/failure notifications
|
||||
- Progress metadata includes speed, ETA, downloaded bytes
|
||||
|
||||
**AnimeService Integration**:
|
||||
|
||||
- Progress tracking for library scans
|
||||
- Scan progress with current/total file counts
|
||||
- Scan completion with statistics
|
||||
- Error notifications on scan failures
|
||||
|
||||
#### WebSocket Message Models
|
||||
|
||||
**File**: `src/server/models/websocket.py`
|
||||
|
||||
Added progress-specific message models:
|
||||
|
||||
- `ScanProgressMessage` - Scan progress updates
|
||||
- `ScanCompleteMessage` - Scan completion notification
|
||||
- `ScanFailedMessage` - Scan failure notification
|
||||
- `ErrorNotificationMessage` - Critical error notifications
|
||||
- `ProgressUpdateMessage` - Generic progress updates
|
||||
|
||||
**WebSocket Message Types**:
|
||||
|
||||
- `SCAN_PROGRESS` - Scan progress updates
|
||||
- `SCAN_COMPLETE` - Scan completion
|
||||
- `SCAN_FAILED` - Scan failure
|
||||
- Extended existing types for downloads and queue updates
|
||||
|
||||
#### WebSocket Rooms
|
||||
|
||||
Clients can subscribe to specific progress channels:
|
||||
|
||||
- `download_progress` - Download progress updates
|
||||
- `scan_progress` - Library scan updates
|
||||
- `queue_progress` - Queue operation updates
|
||||
- `system_progress` - System-level updates
|
||||
|
||||
Room subscription via client messages:
|
||||
|
||||
```json
|
||||
{
|
||||
"action": "join",
|
||||
"room": "download_progress"
|
||||
}
|
||||
```
|
||||
|
||||
#### Application Startup
|
||||
|
||||
**File**: `src/server/fastapi_app.py`
|
||||
|
||||
Progress service initialized on application startup:
|
||||
|
||||
1. Get ProgressService singleton instance
|
||||
2. Get WebSocketService singleton instance
|
||||
3. Register broadcast callback to link progress updates with WebSocket
|
||||
4. Callback broadcasts progress messages to appropriate rooms
|
||||
|
||||
#### Testing
|
||||
|
||||
**File**: `tests/unit/test_progress_service.py`
|
||||
|
||||
Comprehensive test coverage including:
|
||||
|
||||
- Progress lifecycle operations (start, update, complete, fail, cancel)
|
||||
- Percentage calculation accuracy
|
||||
- History management and size limits
|
||||
- Broadcast callback invocation
|
||||
- Concurrent progress operations
|
||||
- Metadata handling
|
||||
- Error conditions and edge cases
|
||||
|
||||
#### Architecture Benefits
|
||||
|
||||
- **Decoupling**: ProgressService decouples progress tracking from WebSocket broadcasting
|
||||
- **Reusability**: Single service used across all application components
|
||||
- **Scalability**: Supports multiple concurrent operations efficiently
|
||||
- **Observability**: Centralized progress tracking simplifies monitoring
|
||||
- **Real-time UX**: Instant feedback to users on all long-running operations
|
||||
|
||||
#### Future Enhancements
|
||||
|
||||
- Persistent progress history (database storage)
|
||||
- Progress rate calculation and trend analysis
|
||||
- Multi-process progress synchronization (Redis/shared store)
|
||||
- Progress event hooks for custom actions
|
||||
- Client-side progress resumption after reconnection
|
||||
|
||||
@ -29,6 +29,8 @@ from src.server.controllers.error_controller import (
|
||||
from src.server.controllers.health_controller import router as health_router
|
||||
from src.server.controllers.page_controller import router as page_router
|
||||
from src.server.middleware.auth import AuthMiddleware
|
||||
from src.server.services.progress_service import get_progress_service
|
||||
from src.server.services.websocket_service import get_websocket_service
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
@ -74,6 +76,23 @@ async def startup_event():
|
||||
# Initialize SeriesApp with configured directory
|
||||
if settings.anime_directory:
|
||||
series_app = SeriesApp(settings.anime_directory)
|
||||
|
||||
# Initialize progress service with websocket callback
|
||||
progress_service = get_progress_service()
|
||||
ws_service = get_websocket_service()
|
||||
|
||||
async def broadcast_callback(
|
||||
message_type: str, data: dict, room: str
|
||||
):
|
||||
"""Broadcast progress updates via WebSocket."""
|
||||
message = {
|
||||
"type": message_type,
|
||||
"data": data,
|
||||
}
|
||||
await ws_service.manager.broadcast_to_room(message, room)
|
||||
|
||||
progress_service.set_broadcast_callback(broadcast_callback)
|
||||
|
||||
print("FastAPI application started successfully")
|
||||
except Exception as e:
|
||||
print(f"Error during startup: {e}")
|
||||
|
||||
@ -30,6 +30,11 @@ class WebSocketMessageType(str, Enum):
|
||||
QUEUE_PAUSED = "queue_paused"
|
||||
QUEUE_RESUMED = "queue_resumed"
|
||||
|
||||
# Progress-related messages
|
||||
SCAN_PROGRESS = "scan_progress"
|
||||
SCAN_COMPLETE = "scan_complete"
|
||||
SCAN_FAILED = "scan_failed"
|
||||
|
||||
# System messages
|
||||
SYSTEM_INFO = "system_info"
|
||||
SYSTEM_WARNING = "system_warning"
|
||||
@ -188,3 +193,93 @@ class RoomSubscriptionRequest(BaseModel):
|
||||
room: str = Field(
|
||||
..., min_length=1, description="Room name to join or leave"
|
||||
)
|
||||
|
||||
|
||||
class ScanProgressMessage(BaseModel):
|
||||
"""Scan progress update message."""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
default=WebSocketMessageType.SCAN_PROGRESS,
|
||||
description="Message type",
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="Scan progress data including current, total, percent",
|
||||
)
|
||||
|
||||
|
||||
class ScanCompleteMessage(BaseModel):
|
||||
"""Scan completion message."""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
default=WebSocketMessageType.SCAN_COMPLETE,
|
||||
description="Message type",
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="Scan completion data including series_found, duration",
|
||||
)
|
||||
|
||||
|
||||
class ScanFailedMessage(BaseModel):
|
||||
"""Scan failure message."""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
default=WebSocketMessageType.SCAN_FAILED,
|
||||
description="Message type",
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
..., description="Scan error data including error_message"
|
||||
)
|
||||
|
||||
|
||||
class ErrorNotificationMessage(BaseModel):
|
||||
"""Error notification message for critical errors."""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
default=WebSocketMessageType.SYSTEM_ERROR,
|
||||
description="Message type",
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Error notification data including severity, message, details"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ProgressUpdateMessage(BaseModel):
|
||||
"""Generic progress update message.
|
||||
|
||||
Can be used for any type of progress (download, scan, queue, etc.)
|
||||
"""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
..., description="Type of progress message"
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Progress data including id, status, percent, current, total"
|
||||
),
|
||||
)
|
||||
|
||||
@ -3,11 +3,16 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.core.SeriesApp import SeriesApp
|
||||
from src.server.services.progress_service import (
|
||||
ProgressService,
|
||||
ProgressType,
|
||||
get_progress_service,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
@ -24,9 +29,15 @@ class AnimeService:
|
||||
- Adds simple in-memory caching for read operations
|
||||
"""
|
||||
|
||||
def __init__(self, directory: str, max_workers: int = 4):
|
||||
def __init__(
|
||||
self,
|
||||
directory: str,
|
||||
max_workers: int = 4,
|
||||
progress_service: Optional[ProgressService] = None,
|
||||
):
|
||||
self._directory = directory
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self._progress_service = progress_service or get_progress_service()
|
||||
# SeriesApp is blocking; instantiate per-service
|
||||
try:
|
||||
self._app = SeriesApp(directory)
|
||||
@ -75,20 +86,70 @@ class AnimeService:
|
||||
logger.exception("search failed")
|
||||
raise AnimeServiceError("Search failed") from e
|
||||
|
||||
async def rescan(self, callback=None) -> None:
|
||||
async def rescan(self, callback: Optional[Callable] = None) -> None:
|
||||
"""Trigger a re-scan. Accepts an optional callback function.
|
||||
|
||||
The callback is executed in the threadpool by SeriesApp.
|
||||
Progress updates are tracked and broadcasted via ProgressService.
|
||||
"""
|
||||
scan_id = "library_scan"
|
||||
|
||||
try:
|
||||
await self._run_in_executor(self._app.ReScan, callback)
|
||||
# Start progress tracking
|
||||
await self._progress_service.start_progress(
|
||||
progress_id=scan_id,
|
||||
progress_type=ProgressType.SCAN,
|
||||
title="Scanning anime library",
|
||||
message="Initializing scan...",
|
||||
)
|
||||
|
||||
# Create wrapped callback for progress updates
|
||||
def progress_callback(progress_data: dict) -> None:
|
||||
"""Update progress during scan."""
|
||||
try:
|
||||
if callback:
|
||||
callback(progress_data)
|
||||
|
||||
# Update progress service
|
||||
current = progress_data.get("current", 0)
|
||||
total = progress_data.get("total", 0)
|
||||
message = progress_data.get("message", "Scanning...")
|
||||
|
||||
asyncio.create_task(
|
||||
self._progress_service.update_progress(
|
||||
progress_id=scan_id,
|
||||
current=current,
|
||||
total=total,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Scan progress callback error", error=str(e))
|
||||
|
||||
# Run scan
|
||||
await self._run_in_executor(self._app.ReScan, progress_callback)
|
||||
|
||||
# invalidate cache
|
||||
try:
|
||||
self._cached_list_missing.cache_clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Complete progress tracking
|
||||
await self._progress_service.complete_progress(
|
||||
progress_id=scan_id,
|
||||
message="Scan completed successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("rescan failed")
|
||||
|
||||
# Fail progress tracking
|
||||
await self._progress_service.fail_progress(
|
||||
progress_id=scan_id,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
raise AnimeServiceError("Rescan failed") from e
|
||||
|
||||
async def download(self, serie_folder: str, season: int, episode: int, key: str, callback=None) -> bool:
|
||||
|
||||
@ -27,6 +27,11 @@ from src.server.models.download import (
|
||||
QueueStatus,
|
||||
)
|
||||
from src.server.services.anime_service import AnimeService, AnimeServiceError
|
||||
from src.server.services.progress_service import (
|
||||
ProgressService,
|
||||
ProgressType,
|
||||
get_progress_service,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
@ -53,6 +58,7 @@ class DownloadService:
|
||||
max_concurrent_downloads: int = 2,
|
||||
max_retries: int = 3,
|
||||
persistence_path: str = "./data/download_queue.json",
|
||||
progress_service: Optional[ProgressService] = None,
|
||||
):
|
||||
"""Initialize the download service.
|
||||
|
||||
@ -61,11 +67,13 @@ class DownloadService:
|
||||
max_concurrent_downloads: Maximum simultaneous downloads
|
||||
max_retries: Maximum retry attempts for failed downloads
|
||||
persistence_path: Path to persist queue state
|
||||
progress_service: Optional progress service for tracking
|
||||
"""
|
||||
self._anime_service = anime_service
|
||||
self._max_concurrent = max_concurrent_downloads
|
||||
self._max_retries = max_retries
|
||||
self._persistence_path = Path(persistence_path)
|
||||
self._progress_service = progress_service or get_progress_service()
|
||||
|
||||
# Queue storage by status
|
||||
self._pending_queue: deque[DownloadItem] = deque()
|
||||
@ -500,6 +508,23 @@ class DownloadService:
|
||||
if item.progress.speed_mbps:
|
||||
self._download_speeds.append(item.progress.speed_mbps)
|
||||
|
||||
# Update progress service
|
||||
if item.progress.total_mb and item.progress.total_mb > 0:
|
||||
current_mb = int(item.progress.downloaded_mb)
|
||||
total_mb = int(item.progress.total_mb)
|
||||
|
||||
asyncio.create_task(
|
||||
self._progress_service.update_progress(
|
||||
progress_id=f"download_{item.id}",
|
||||
current=current_mb,
|
||||
total=total_mb,
|
||||
metadata={
|
||||
"speed_mbps": item.progress.speed_mbps,
|
||||
"eta_seconds": item.progress.eta_seconds,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Broadcast update (fire and forget)
|
||||
asyncio.create_task(
|
||||
self._broadcast_update(
|
||||
@ -535,6 +560,22 @@ class DownloadService:
|
||||
episode=item.episode.episode,
|
||||
)
|
||||
|
||||
# Start progress tracking
|
||||
await self._progress_service.start_progress(
|
||||
progress_id=f"download_{item.id}",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title=f"Downloading {item.serie_name}",
|
||||
message=(
|
||||
f"S{item.episode.season:02d}E{item.episode.episode:02d}"
|
||||
),
|
||||
metadata={
|
||||
"item_id": item.id,
|
||||
"serie_name": item.serie_name,
|
||||
"season": item.episode.season,
|
||||
"episode": item.episode.episode,
|
||||
},
|
||||
)
|
||||
|
||||
# Create progress callback
|
||||
progress_callback = self._create_progress_callback(item)
|
||||
|
||||
@ -561,6 +602,18 @@ class DownloadService:
|
||||
logger.info(
|
||||
"Download completed successfully", item_id=item.id
|
||||
)
|
||||
|
||||
# Complete progress tracking
|
||||
await self._progress_service.complete_progress(
|
||||
progress_id=f"download_{item.id}",
|
||||
message="Download completed successfully",
|
||||
metadata={
|
||||
"downloaded_mb": item.progress.downloaded_mb
|
||||
if item.progress
|
||||
else 0,
|
||||
},
|
||||
)
|
||||
|
||||
await self._broadcast_update(
|
||||
"download_completed", {"item_id": item.id}
|
||||
)
|
||||
@ -581,6 +634,13 @@ class DownloadService:
|
||||
retry_count=item.retry_count,
|
||||
)
|
||||
|
||||
# Fail progress tracking
|
||||
await self._progress_service.fail_progress(
|
||||
progress_id=f"download_{item.id}",
|
||||
error_message=str(e),
|
||||
metadata={"retry_count": item.retry_count},
|
||||
)
|
||||
|
||||
await self._broadcast_update(
|
||||
"download_failed",
|
||||
{"item_id": item.id, "error": item.error},
|
||||
|
||||
485
src/server/services/progress_service.py
Normal file
485
src/server/services/progress_service.py
Normal file
@ -0,0 +1,485 @@
|
||||
"""Progress service for managing real-time progress updates.
|
||||
|
||||
This module provides a centralized service for tracking and broadcasting
|
||||
real-time progress updates for downloads, scans, queue changes, and
|
||||
system events. It integrates with the WebSocket service to push updates
|
||||
to connected clients.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class ProgressType(str, Enum):
|
||||
"""Types of progress updates."""
|
||||
|
||||
DOWNLOAD = "download"
|
||||
SCAN = "scan"
|
||||
QUEUE = "queue"
|
||||
SYSTEM = "system"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ProgressStatus(str, Enum):
|
||||
"""Status of a progress operation."""
|
||||
|
||||
STARTED = "started"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProgressUpdate:
|
||||
"""Represents a progress update event.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for this progress operation
|
||||
type: Type of progress (download, scan, etc.)
|
||||
status: Current status of the operation
|
||||
title: Human-readable title
|
||||
message: Detailed message
|
||||
percent: Completion percentage (0-100)
|
||||
current: Current progress value
|
||||
total: Total progress value
|
||||
metadata: Additional metadata
|
||||
started_at: When operation started
|
||||
updated_at: When last updated
|
||||
"""
|
||||
|
||||
id: str
|
||||
type: ProgressType
|
||||
status: ProgressStatus
|
||||
title: str
|
||||
message: str = ""
|
||||
percent: float = 0.0
|
||||
current: int = 0
|
||||
total: int = 0
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
started_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert progress update to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": self.type.value,
|
||||
"status": self.status.value,
|
||||
"title": self.title,
|
||||
"message": self.message,
|
||||
"percent": round(self.percent, 2),
|
||||
"current": self.current,
|
||||
"total": self.total,
|
||||
"metadata": self.metadata,
|
||||
"started_at": self.started_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class ProgressServiceError(Exception):
|
||||
"""Service-level exception for progress operations."""
|
||||
|
||||
|
||||
class ProgressService:
|
||||
"""Manages real-time progress updates and broadcasting.
|
||||
|
||||
Features:
|
||||
- Track multiple concurrent progress operations
|
||||
- Calculate progress percentages and rates
|
||||
- Broadcast updates via WebSocket
|
||||
- Manage progress lifecycle (start, update, complete, fail)
|
||||
- Support for different progress types (download, scan, queue)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the progress service."""
|
||||
# Active progress operations: id -> ProgressUpdate
|
||||
self._active_progress: Dict[str, ProgressUpdate] = {}
|
||||
|
||||
# Completed progress history (limited size)
|
||||
self._history: Dict[str, ProgressUpdate] = {}
|
||||
self._max_history_size = 50
|
||||
|
||||
# WebSocket broadcast callback
|
||||
self._broadcast_callback: Optional[Callable] = None
|
||||
|
||||
# Lock for thread-safe operations
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
logger.info("ProgressService initialized")
|
||||
|
||||
def set_broadcast_callback(self, callback: Callable) -> None:
|
||||
"""Set callback for broadcasting progress updates via WebSocket.
|
||||
|
||||
Args:
|
||||
callback: Async function to call for broadcasting updates
|
||||
"""
|
||||
self._broadcast_callback = callback
|
||||
logger.debug("Progress broadcast callback registered")
|
||||
|
||||
async def _broadcast(self, update: ProgressUpdate, room: str) -> None:
|
||||
"""Broadcast progress update to WebSocket clients.
|
||||
|
||||
Args:
|
||||
update: Progress update to broadcast
|
||||
room: WebSocket room to broadcast to
|
||||
"""
|
||||
if self._broadcast_callback:
|
||||
try:
|
||||
await self._broadcast_callback(
|
||||
message_type=f"{update.type.value}_progress",
|
||||
data=update.to_dict(),
|
||||
room=room,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to broadcast progress update",
|
||||
error=str(e),
|
||||
progress_id=update.id,
|
||||
)
|
||||
|
||||
async def start_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
progress_type: ProgressType,
|
||||
title: str,
|
||||
total: int = 0,
|
||||
message: str = "",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> ProgressUpdate:
|
||||
"""Start a new progress operation.
|
||||
|
||||
Args:
|
||||
progress_id: Unique identifier for this progress
|
||||
progress_type: Type of progress operation
|
||||
title: Human-readable title
|
||||
total: Total items/bytes to process
|
||||
message: Initial message
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
Created progress update object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress already exists
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' already exists"
|
||||
)
|
||||
|
||||
update = ProgressUpdate(
|
||||
id=progress_id,
|
||||
type=progress_type,
|
||||
status=ProgressStatus.STARTED,
|
||||
title=title,
|
||||
message=message,
|
||||
total=total,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
self._active_progress[progress_id] = update
|
||||
|
||||
logger.info(
|
||||
"Progress started",
|
||||
progress_id=progress_id,
|
||||
type=progress_type.value,
|
||||
title=title,
|
||||
)
|
||||
|
||||
# Broadcast to appropriate room
|
||||
room = f"{progress_type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
async def update_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
current: Optional[int] = None,
|
||||
total: Optional[int] = None,
|
||||
message: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
force_broadcast: bool = False,
|
||||
) -> ProgressUpdate:
|
||||
"""Update an existing progress operation.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
current: Current progress value
|
||||
total: Updated total value
|
||||
message: Updated message
|
||||
metadata: Additional metadata to merge
|
||||
force_broadcast: Force broadcasting even for small changes
|
||||
|
||||
Returns:
|
||||
Updated progress object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id not in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' not found"
|
||||
)
|
||||
|
||||
update = self._active_progress[progress_id]
|
||||
old_percent = update.percent
|
||||
|
||||
# Update fields
|
||||
if current is not None:
|
||||
update.current = current
|
||||
if total is not None:
|
||||
update.total = total
|
||||
if message is not None:
|
||||
update.message = message
|
||||
if metadata:
|
||||
update.metadata.update(metadata)
|
||||
|
||||
# Calculate percentage
|
||||
if update.total > 0:
|
||||
update.percent = (update.current / update.total) * 100
|
||||
else:
|
||||
update.percent = 0.0
|
||||
|
||||
update.status = ProgressStatus.IN_PROGRESS
|
||||
update.updated_at = datetime.utcnow()
|
||||
|
||||
# Only broadcast if significant change or forced
|
||||
percent_change = abs(update.percent - old_percent)
|
||||
should_broadcast = force_broadcast or percent_change >= 1.0
|
||||
|
||||
if should_broadcast:
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
async def complete_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
message: str = "Completed successfully",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> ProgressUpdate:
|
||||
"""Mark a progress operation as completed.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
message: Completion message
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
Completed progress object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id not in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' not found"
|
||||
)
|
||||
|
||||
update = self._active_progress[progress_id]
|
||||
update.status = ProgressStatus.COMPLETED
|
||||
update.message = message
|
||||
update.percent = 100.0
|
||||
update.current = update.total
|
||||
update.updated_at = datetime.utcnow()
|
||||
|
||||
if metadata:
|
||||
update.metadata.update(metadata)
|
||||
|
||||
# Move to history
|
||||
del self._active_progress[progress_id]
|
||||
self._add_to_history(update)
|
||||
|
||||
logger.info(
|
||||
"Progress completed",
|
||||
progress_id=progress_id,
|
||||
type=update.type.value,
|
||||
)
|
||||
|
||||
# Broadcast completion
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
async def fail_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
error_message: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> ProgressUpdate:
|
||||
"""Mark a progress operation as failed.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
error_message: Error description
|
||||
metadata: Additional error metadata
|
||||
|
||||
Returns:
|
||||
Failed progress object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id not in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' not found"
|
||||
)
|
||||
|
||||
update = self._active_progress[progress_id]
|
||||
update.status = ProgressStatus.FAILED
|
||||
update.message = error_message
|
||||
update.updated_at = datetime.utcnow()
|
||||
|
||||
if metadata:
|
||||
update.metadata.update(metadata)
|
||||
|
||||
# Move to history
|
||||
del self._active_progress[progress_id]
|
||||
self._add_to_history(update)
|
||||
|
||||
logger.error(
|
||||
"Progress failed",
|
||||
progress_id=progress_id,
|
||||
type=update.type.value,
|
||||
error=error_message,
|
||||
)
|
||||
|
||||
# Broadcast failure
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
async def cancel_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
message: str = "Cancelled by user",
|
||||
) -> ProgressUpdate:
|
||||
"""Cancel a progress operation.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
message: Cancellation message
|
||||
|
||||
Returns:
|
||||
Cancelled progress object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id not in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' not found"
|
||||
)
|
||||
|
||||
update = self._active_progress[progress_id]
|
||||
update.status = ProgressStatus.CANCELLED
|
||||
update.message = message
|
||||
update.updated_at = datetime.utcnow()
|
||||
|
||||
# Move to history
|
||||
del self._active_progress[progress_id]
|
||||
self._add_to_history(update)
|
||||
|
||||
logger.info(
|
||||
"Progress cancelled",
|
||||
progress_id=progress_id,
|
||||
type=update.type.value,
|
||||
)
|
||||
|
||||
# Broadcast cancellation
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
def _add_to_history(self, update: ProgressUpdate) -> None:
|
||||
"""Add completed progress to history with size limit."""
|
||||
self._history[update.id] = update
|
||||
|
||||
# Maintain history size limit
|
||||
if len(self._history) > self._max_history_size:
|
||||
# Remove oldest entries
|
||||
oldest_keys = sorted(
|
||||
self._history.keys(),
|
||||
key=lambda k: self._history[k].updated_at,
|
||||
)[: len(self._history) - self._max_history_size]
|
||||
|
||||
for key in oldest_keys:
|
||||
del self._history[key]
|
||||
|
||||
async def get_progress(self, progress_id: str) -> Optional[ProgressUpdate]:
|
||||
"""Get current progress state.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
|
||||
Returns:
|
||||
Progress update object or None if not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id in self._active_progress:
|
||||
return self._active_progress[progress_id]
|
||||
if progress_id in self._history:
|
||||
return self._history[progress_id]
|
||||
return None
|
||||
|
||||
async def get_all_active_progress(
|
||||
self, progress_type: Optional[ProgressType] = None
|
||||
) -> Dict[str, ProgressUpdate]:
|
||||
"""Get all active progress operations.
|
||||
|
||||
Args:
|
||||
progress_type: Optional filter by progress type
|
||||
|
||||
Returns:
|
||||
Dictionary of progress_id -> ProgressUpdate
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_type:
|
||||
return {
|
||||
pid: update
|
||||
for pid, update in self._active_progress.items()
|
||||
if update.type == progress_type
|
||||
}
|
||||
return self._active_progress.copy()
|
||||
|
||||
async def clear_history(self) -> None:
|
||||
"""Clear progress history."""
|
||||
async with self._lock:
|
||||
self._history.clear()
|
||||
logger.info("Progress history cleared")
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_progress_service: Optional[ProgressService] = None
|
||||
|
||||
|
||||
def get_progress_service() -> ProgressService:
|
||||
"""Get or create the global progress service instance.
|
||||
|
||||
Returns:
|
||||
Global ProgressService instance
|
||||
"""
|
||||
global _progress_service
|
||||
if _progress_service is None:
|
||||
_progress_service = ProgressService()
|
||||
return _progress_service
|
||||
499
tests/unit/test_progress_service.py
Normal file
499
tests/unit/test_progress_service.py
Normal file
@ -0,0 +1,499 @@
|
||||
"""Unit tests for ProgressService.
|
||||
|
||||
This module contains comprehensive tests for the progress tracking service,
|
||||
including progress lifecycle, broadcasting, error handling, and concurrency.
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.server.services.progress_service import (
|
||||
ProgressService,
|
||||
ProgressServiceError,
|
||||
ProgressStatus,
|
||||
ProgressType,
|
||||
ProgressUpdate,
|
||||
)
|
||||
|
||||
|
||||
class TestProgressUpdate:
|
||||
"""Test ProgressUpdate dataclass."""
|
||||
|
||||
def test_progress_update_creation(self):
|
||||
"""Test creating a progress update."""
|
||||
update = ProgressUpdate(
|
||||
id="test-1",
|
||||
type=ProgressType.DOWNLOAD,
|
||||
status=ProgressStatus.STARTED,
|
||||
title="Test Download",
|
||||
message="Starting download",
|
||||
total=100,
|
||||
)
|
||||
|
||||
assert update.id == "test-1"
|
||||
assert update.type == ProgressType.DOWNLOAD
|
||||
assert update.status == ProgressStatus.STARTED
|
||||
assert update.title == "Test Download"
|
||||
assert update.message == "Starting download"
|
||||
assert update.total == 100
|
||||
assert update.current == 0
|
||||
assert update.percent == 0.0
|
||||
|
||||
def test_progress_update_to_dict(self):
|
||||
"""Test converting progress update to dictionary."""
|
||||
update = ProgressUpdate(
|
||||
id="test-1",
|
||||
type=ProgressType.SCAN,
|
||||
status=ProgressStatus.IN_PROGRESS,
|
||||
title="Test Scan",
|
||||
message="Scanning files",
|
||||
current=50,
|
||||
total=100,
|
||||
metadata={"test_key": "test_value"},
|
||||
)
|
||||
|
||||
result = update.to_dict()
|
||||
|
||||
assert result["id"] == "test-1"
|
||||
assert result["type"] == "scan"
|
||||
assert result["status"] == "in_progress"
|
||||
assert result["title"] == "Test Scan"
|
||||
assert result["message"] == "Scanning files"
|
||||
assert result["current"] == 50
|
||||
assert result["total"] == 100
|
||||
assert result["percent"] == 0.0
|
||||
assert result["metadata"]["test_key"] == "test_value"
|
||||
assert "started_at" in result
|
||||
assert "updated_at" in result
|
||||
|
||||
|
||||
class TestProgressService:
|
||||
"""Test ProgressService class."""
|
||||
|
||||
@pytest.fixture
|
||||
def service(self):
|
||||
"""Create a fresh ProgressService instance for each test."""
|
||||
return ProgressService()
|
||||
|
||||
@pytest.fixture
|
||||
def mock_broadcast(self):
|
||||
"""Create a mock broadcast callback."""
|
||||
return AsyncMock()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_progress(self, service):
|
||||
"""Test starting a new progress operation."""
|
||||
update = await service.start_progress(
|
||||
progress_id="download-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Downloading episode",
|
||||
total=1000,
|
||||
message="Starting...",
|
||||
metadata={"episode": "S01E01"},
|
||||
)
|
||||
|
||||
assert update.id == "download-1"
|
||||
assert update.type == ProgressType.DOWNLOAD
|
||||
assert update.status == ProgressStatus.STARTED
|
||||
assert update.title == "Downloading episode"
|
||||
assert update.total == 1000
|
||||
assert update.message == "Starting..."
|
||||
assert update.metadata["episode"] == "S01E01"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_progress_duplicate_id(self, service):
|
||||
"""Test starting progress with duplicate ID raises error."""
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test",
|
||||
)
|
||||
|
||||
with pytest.raises(ProgressServiceError, match="already exists"):
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test Duplicate",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_progress(self, service):
|
||||
"""Test updating an existing progress operation."""
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test",
|
||||
total=100,
|
||||
)
|
||||
|
||||
update = await service.update_progress(
|
||||
progress_id="test-1",
|
||||
current=50,
|
||||
message="Half way",
|
||||
)
|
||||
|
||||
assert update.current == 50
|
||||
assert update.total == 100
|
||||
assert update.percent == 50.0
|
||||
assert update.message == "Half way"
|
||||
assert update.status == ProgressStatus.IN_PROGRESS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_progress_not_found(self, service):
|
||||
"""Test updating non-existent progress raises error."""
|
||||
with pytest.raises(ProgressServiceError, match="not found"):
|
||||
await service.update_progress(
|
||||
progress_id="nonexistent",
|
||||
current=50,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_progress_percentage_calculation(self, service):
|
||||
"""Test progress percentage is calculated correctly."""
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test",
|
||||
total=200,
|
||||
)
|
||||
|
||||
await service.update_progress(progress_id="test-1", current=50)
|
||||
update = await service.get_progress("test-1")
|
||||
assert update.percent == 25.0
|
||||
|
||||
await service.update_progress(progress_id="test-1", current=100)
|
||||
update = await service.get_progress("test-1")
|
||||
assert update.percent == 50.0
|
||||
|
||||
await service.update_progress(progress_id="test-1", current=200)
|
||||
update = await service.get_progress("test-1")
|
||||
assert update.percent == 100.0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_complete_progress(self, service):
|
||||
"""Test completing a progress operation."""
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.SCAN,
|
||||
title="Test Scan",
|
||||
total=100,
|
||||
)
|
||||
|
||||
await service.update_progress(progress_id="test-1", current=50)
|
||||
|
||||
update = await service.complete_progress(
|
||||
progress_id="test-1",
|
||||
message="Scan completed successfully",
|
||||
metadata={"items_found": 42},
|
||||
)
|
||||
|
||||
assert update.status == ProgressStatus.COMPLETED
|
||||
assert update.percent == 100.0
|
||||
assert update.current == update.total
|
||||
assert update.message == "Scan completed successfully"
|
||||
assert update.metadata["items_found"] == 42
|
||||
|
||||
# Should be moved to history
|
||||
active_progress = await service.get_all_active_progress()
|
||||
assert "test-1" not in active_progress
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fail_progress(self, service):
|
||||
"""Test failing a progress operation."""
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test Download",
|
||||
)
|
||||
|
||||
update = await service.fail_progress(
|
||||
progress_id="test-1",
|
||||
error_message="Network timeout",
|
||||
metadata={"retry_count": 3},
|
||||
)
|
||||
|
||||
assert update.status == ProgressStatus.FAILED
|
||||
assert update.message == "Network timeout"
|
||||
assert update.metadata["retry_count"] == 3
|
||||
|
||||
# Should be moved to history
|
||||
active_progress = await service.get_all_active_progress()
|
||||
assert "test-1" not in active_progress
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cancel_progress(self, service):
|
||||
"""Test cancelling a progress operation."""
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test Download",
|
||||
)
|
||||
|
||||
update = await service.cancel_progress(
|
||||
progress_id="test-1",
|
||||
message="Cancelled by user",
|
||||
)
|
||||
|
||||
assert update.status == ProgressStatus.CANCELLED
|
||||
assert update.message == "Cancelled by user"
|
||||
|
||||
# Should be moved to history
|
||||
active_progress = await service.get_all_active_progress()
|
||||
assert "test-1" not in active_progress
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_progress(self, service):
|
||||
"""Test retrieving progress by ID."""
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.SCAN,
|
||||
title="Test",
|
||||
)
|
||||
|
||||
progress = await service.get_progress("test-1")
|
||||
assert progress is not None
|
||||
assert progress.id == "test-1"
|
||||
|
||||
# Test non-existent progress
|
||||
progress = await service.get_progress("nonexistent")
|
||||
assert progress is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_active_progress(self, service):
|
||||
"""Test retrieving all active progress operations."""
|
||||
await service.start_progress(
|
||||
progress_id="download-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Download 1",
|
||||
)
|
||||
await service.start_progress(
|
||||
progress_id="download-2",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Download 2",
|
||||
)
|
||||
await service.start_progress(
|
||||
progress_id="scan-1",
|
||||
progress_type=ProgressType.SCAN,
|
||||
title="Scan 1",
|
||||
)
|
||||
|
||||
all_progress = await service.get_all_active_progress()
|
||||
assert len(all_progress) == 3
|
||||
assert "download-1" in all_progress
|
||||
assert "download-2" in all_progress
|
||||
assert "scan-1" in all_progress
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_active_progress_filtered(self, service):
|
||||
"""Test retrieving active progress filtered by type."""
|
||||
await service.start_progress(
|
||||
progress_id="download-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Download 1",
|
||||
)
|
||||
await service.start_progress(
|
||||
progress_id="download-2",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Download 2",
|
||||
)
|
||||
await service.start_progress(
|
||||
progress_id="scan-1",
|
||||
progress_type=ProgressType.SCAN,
|
||||
title="Scan 1",
|
||||
)
|
||||
|
||||
download_progress = await service.get_all_active_progress(
|
||||
progress_type=ProgressType.DOWNLOAD
|
||||
)
|
||||
assert len(download_progress) == 2
|
||||
assert "download-1" in download_progress
|
||||
assert "download-2" in download_progress
|
||||
assert "scan-1" not in download_progress
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_history_management(self, service):
|
||||
"""Test progress history is maintained with size limit."""
|
||||
# Start and complete multiple progress operations
|
||||
for i in range(60): # More than max_history_size (50)
|
||||
await service.start_progress(
|
||||
progress_id=f"test-{i}",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title=f"Test {i}",
|
||||
)
|
||||
await service.complete_progress(
|
||||
progress_id=f"test-{i}",
|
||||
message="Completed",
|
||||
)
|
||||
|
||||
# Check that oldest entries were removed
|
||||
history = service._history
|
||||
assert len(history) <= 50
|
||||
|
||||
# Most recent should be in history
|
||||
recent_progress = await service.get_progress("test-59")
|
||||
assert recent_progress is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_callback(self, service, mock_broadcast):
|
||||
"""Test broadcast callback is invoked correctly."""
|
||||
service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test",
|
||||
)
|
||||
|
||||
# Verify callback was called for start
|
||||
mock_broadcast.assert_called_once()
|
||||
call_args = mock_broadcast.call_args
|
||||
assert call_args[1]["message_type"] == "download_progress"
|
||||
assert call_args[1]["room"] == "download_progress"
|
||||
assert "test-1" in str(call_args[1]["data"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_on_update(self, service, mock_broadcast):
|
||||
"""Test broadcast on progress update."""
|
||||
service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test",
|
||||
total=100,
|
||||
)
|
||||
mock_broadcast.reset_mock()
|
||||
|
||||
# Update with significant change (>1%)
|
||||
await service.update_progress(
|
||||
progress_id="test-1",
|
||||
current=50,
|
||||
force_broadcast=True,
|
||||
)
|
||||
|
||||
# Should have been called
|
||||
assert mock_broadcast.call_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_on_complete(self, service, mock_broadcast):
|
||||
"""Test broadcast on progress completion."""
|
||||
service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.SCAN,
|
||||
title="Test",
|
||||
)
|
||||
mock_broadcast.reset_mock()
|
||||
|
||||
await service.complete_progress(
|
||||
progress_id="test-1",
|
||||
message="Done",
|
||||
)
|
||||
|
||||
# Should have been called
|
||||
mock_broadcast.assert_called_once()
|
||||
call_args = mock_broadcast.call_args
|
||||
assert "completed" in str(call_args[1]["data"]).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_broadcast_on_failure(self, service, mock_broadcast):
|
||||
"""Test broadcast on progress failure."""
|
||||
service.set_broadcast_callback(mock_broadcast)
|
||||
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test",
|
||||
)
|
||||
mock_broadcast.reset_mock()
|
||||
|
||||
await service.fail_progress(
|
||||
progress_id="test-1",
|
||||
error_message="Test error",
|
||||
)
|
||||
|
||||
# Should have been called
|
||||
mock_broadcast.assert_called_once()
|
||||
call_args = mock_broadcast.call_args
|
||||
assert "failed" in str(call_args[1]["data"]).lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_history(self, service):
|
||||
"""Test clearing progress history."""
|
||||
# Create and complete some progress
|
||||
for i in range(5):
|
||||
await service.start_progress(
|
||||
progress_id=f"test-{i}",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title=f"Test {i}",
|
||||
)
|
||||
await service.complete_progress(
|
||||
progress_id=f"test-{i}",
|
||||
message="Done",
|
||||
)
|
||||
|
||||
# History should not be empty
|
||||
assert len(service._history) > 0
|
||||
|
||||
# Clear history
|
||||
await service.clear_history()
|
||||
|
||||
# History should now be empty
|
||||
assert len(service._history) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_progress_operations(self, service):
|
||||
"""Test handling multiple concurrent progress operations."""
|
||||
|
||||
async def create_and_complete_progress(id_num: int):
|
||||
"""Helper to create and complete a progress."""
|
||||
await service.start_progress(
|
||||
progress_id=f"test-{id_num}",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title=f"Test {id_num}",
|
||||
total=100,
|
||||
)
|
||||
for i in range(0, 101, 10):
|
||||
await service.update_progress(
|
||||
progress_id=f"test-{id_num}",
|
||||
current=i,
|
||||
)
|
||||
await asyncio.sleep(0.01)
|
||||
await service.complete_progress(
|
||||
progress_id=f"test-{id_num}",
|
||||
message="Done",
|
||||
)
|
||||
|
||||
# Run multiple concurrent operations
|
||||
tasks = [create_and_complete_progress(i) for i in range(10)]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# All should be in history
|
||||
for i in range(10):
|
||||
progress = await service.get_progress(f"test-{i}")
|
||||
assert progress is not None
|
||||
assert progress.status == ProgressStatus.COMPLETED
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_with_metadata(self, service):
|
||||
"""Test updating progress with metadata."""
|
||||
await service.start_progress(
|
||||
progress_id="test-1",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title="Test",
|
||||
metadata={"initial": "value"},
|
||||
)
|
||||
|
||||
await service.update_progress(
|
||||
progress_id="test-1",
|
||||
current=50,
|
||||
metadata={"additional": "data", "speed": 1.5},
|
||||
)
|
||||
|
||||
progress = await service.get_progress("test-1")
|
||||
assert progress.metadata["initial"] == "value"
|
||||
assert progress.metadata["additional"] == "data"
|
||||
assert progress.metadata["speed"] == 1.5
|
||||
Loading…
x
Reference in New Issue
Block a user