feat: implement WebSocket real-time progress updates

- Add ProgressService for centralized progress tracking and broadcasting
- Integrate ProgressService with DownloadService for download progress
- Integrate ProgressService with AnimeService for scan progress
- Add progress-related WebSocket message models (ScanProgress, ErrorNotification, etc.)
- Initialize ProgressService with WebSocket callback in application startup
- Add comprehensive unit tests for ProgressService
- Update infrastructure.md with ProgressService documentation
- Remove completed WebSocket Real-time Updates task from instructions.md

The ProgressService provides:
- Real-time progress tracking for downloads, scans, and queue operations
- Automatic progress percentage calculation
- Progress lifecycle management (start, update, complete, fail, cancel)
- WebSocket integration for instant client updates
- Progress history with size limits
- Thread-safe operations using asyncio locks
- Support for metadata and custom messages

Benefits:
- Decoupled progress tracking from WebSocket broadcasting
- Single reusable service across all components
- Supports multiple concurrent operations efficiently
- Centralized progress tracking simplifies monitoring
- Instant feedback to users on long-running operations
This commit is contained in:
2025-10-17 11:12:06 +02:00
parent 42a07be4cb
commit 94de91ffa0
7 changed files with 1375 additions and 4 deletions

View File

@@ -29,6 +29,8 @@ from src.server.controllers.error_controller import (
from src.server.controllers.health_controller import router as health_router
from src.server.controllers.page_controller import router as page_router
from src.server.middleware.auth import AuthMiddleware
from src.server.services.progress_service import get_progress_service
from src.server.services.websocket_service import get_websocket_service
# Initialize FastAPI app
app = FastAPI(
@@ -74,6 +76,23 @@ async def startup_event():
# Initialize SeriesApp with configured directory
if settings.anime_directory:
series_app = SeriesApp(settings.anime_directory)
# Initialize progress service with websocket callback
progress_service = get_progress_service()
ws_service = get_websocket_service()
async def broadcast_callback(
message_type: str, data: dict, room: str
):
"""Broadcast progress updates via WebSocket."""
message = {
"type": message_type,
"data": data,
}
await ws_service.manager.broadcast_to_room(message, room)
progress_service.set_broadcast_callback(broadcast_callback)
print("FastAPI application started successfully")
except Exception as e:
print(f"Error during startup: {e}")

View File

@@ -30,6 +30,11 @@ class WebSocketMessageType(str, Enum):
QUEUE_PAUSED = "queue_paused"
QUEUE_RESUMED = "queue_resumed"
# Progress-related messages
SCAN_PROGRESS = "scan_progress"
SCAN_COMPLETE = "scan_complete"
SCAN_FAILED = "scan_failed"
# System messages
SYSTEM_INFO = "system_info"
SYSTEM_WARNING = "system_warning"
@@ -188,3 +193,93 @@ class RoomSubscriptionRequest(BaseModel):
room: str = Field(
..., min_length=1, description="Room name to join or leave"
)
class ScanProgressMessage(BaseModel):
"""Scan progress update message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.SCAN_PROGRESS,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description="Scan progress data including current, total, percent",
)
class ScanCompleteMessage(BaseModel):
"""Scan completion message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.SCAN_COMPLETE,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description="Scan completion data including series_found, duration",
)
class ScanFailedMessage(BaseModel):
"""Scan failure message."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.SCAN_FAILED,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
..., description="Scan error data including error_message"
)
class ErrorNotificationMessage(BaseModel):
"""Error notification message for critical errors."""
type: WebSocketMessageType = Field(
default=WebSocketMessageType.SYSTEM_ERROR,
description="Message type",
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description=(
"Error notification data including severity, message, details"
),
)
class ProgressUpdateMessage(BaseModel):
"""Generic progress update message.
Can be used for any type of progress (download, scan, queue, etc.)
"""
type: WebSocketMessageType = Field(
..., description="Type of progress message"
)
timestamp: str = Field(
default_factory=lambda: datetime.utcnow().isoformat(),
description="ISO 8601 timestamp",
)
data: Dict[str, Any] = Field(
...,
description=(
"Progress data including id, status, percent, current, total"
),
)

View File

@@ -3,11 +3,16 @@ from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
from functools import lru_cache
from typing import List, Optional
from typing import Callable, List, Optional
import structlog
from src.core.SeriesApp import SeriesApp
from src.server.services.progress_service import (
ProgressService,
ProgressType,
get_progress_service,
)
logger = structlog.get_logger(__name__)
@@ -24,9 +29,15 @@ class AnimeService:
- Adds simple in-memory caching for read operations
"""
def __init__(self, directory: str, max_workers: int = 4):
def __init__(
self,
directory: str,
max_workers: int = 4,
progress_service: Optional[ProgressService] = None,
):
self._directory = directory
self._executor = ThreadPoolExecutor(max_workers=max_workers)
self._progress_service = progress_service or get_progress_service()
# SeriesApp is blocking; instantiate per-service
try:
self._app = SeriesApp(directory)
@@ -75,20 +86,70 @@ class AnimeService:
logger.exception("search failed")
raise AnimeServiceError("Search failed") from e
async def rescan(self, callback=None) -> None:
async def rescan(self, callback: Optional[Callable] = None) -> None:
"""Trigger a re-scan. Accepts an optional callback function.
The callback is executed in the threadpool by SeriesApp.
Progress updates are tracked and broadcasted via ProgressService.
"""
scan_id = "library_scan"
try:
await self._run_in_executor(self._app.ReScan, callback)
# Start progress tracking
await self._progress_service.start_progress(
progress_id=scan_id,
progress_type=ProgressType.SCAN,
title="Scanning anime library",
message="Initializing scan...",
)
# Create wrapped callback for progress updates
def progress_callback(progress_data: dict) -> None:
"""Update progress during scan."""
try:
if callback:
callback(progress_data)
# Update progress service
current = progress_data.get("current", 0)
total = progress_data.get("total", 0)
message = progress_data.get("message", "Scanning...")
asyncio.create_task(
self._progress_service.update_progress(
progress_id=scan_id,
current=current,
total=total,
message=message,
)
)
except Exception as e:
logger.error("Scan progress callback error", error=str(e))
# Run scan
await self._run_in_executor(self._app.ReScan, progress_callback)
# invalidate cache
try:
self._cached_list_missing.cache_clear()
except Exception:
pass
# Complete progress tracking
await self._progress_service.complete_progress(
progress_id=scan_id,
message="Scan completed successfully",
)
except Exception as e:
logger.exception("rescan failed")
# Fail progress tracking
await self._progress_service.fail_progress(
progress_id=scan_id,
error_message=str(e),
)
raise AnimeServiceError("Rescan failed") from e
async def download(self, serie_folder: str, season: int, episode: int, key: str, callback=None) -> bool:

View File

@@ -27,6 +27,11 @@ from src.server.models.download import (
QueueStatus,
)
from src.server.services.anime_service import AnimeService, AnimeServiceError
from src.server.services.progress_service import (
ProgressService,
ProgressType,
get_progress_service,
)
logger = structlog.get_logger(__name__)
@@ -53,6 +58,7 @@ class DownloadService:
max_concurrent_downloads: int = 2,
max_retries: int = 3,
persistence_path: str = "./data/download_queue.json",
progress_service: Optional[ProgressService] = None,
):
"""Initialize the download service.
@@ -61,11 +67,13 @@ class DownloadService:
max_concurrent_downloads: Maximum simultaneous downloads
max_retries: Maximum retry attempts for failed downloads
persistence_path: Path to persist queue state
progress_service: Optional progress service for tracking
"""
self._anime_service = anime_service
self._max_concurrent = max_concurrent_downloads
self._max_retries = max_retries
self._persistence_path = Path(persistence_path)
self._progress_service = progress_service or get_progress_service()
# Queue storage by status
self._pending_queue: deque[DownloadItem] = deque()
@@ -500,6 +508,23 @@ class DownloadService:
if item.progress.speed_mbps:
self._download_speeds.append(item.progress.speed_mbps)
# Update progress service
if item.progress.total_mb and item.progress.total_mb > 0:
current_mb = int(item.progress.downloaded_mb)
total_mb = int(item.progress.total_mb)
asyncio.create_task(
self._progress_service.update_progress(
progress_id=f"download_{item.id}",
current=current_mb,
total=total_mb,
metadata={
"speed_mbps": item.progress.speed_mbps,
"eta_seconds": item.progress.eta_seconds,
},
)
)
# Broadcast update (fire and forget)
asyncio.create_task(
self._broadcast_update(
@@ -535,6 +560,22 @@ class DownloadService:
episode=item.episode.episode,
)
# Start progress tracking
await self._progress_service.start_progress(
progress_id=f"download_{item.id}",
progress_type=ProgressType.DOWNLOAD,
title=f"Downloading {item.serie_name}",
message=(
f"S{item.episode.season:02d}E{item.episode.episode:02d}"
),
metadata={
"item_id": item.id,
"serie_name": item.serie_name,
"season": item.episode.season,
"episode": item.episode.episode,
},
)
# Create progress callback
progress_callback = self._create_progress_callback(item)
@@ -561,6 +602,18 @@ class DownloadService:
logger.info(
"Download completed successfully", item_id=item.id
)
# Complete progress tracking
await self._progress_service.complete_progress(
progress_id=f"download_{item.id}",
message="Download completed successfully",
metadata={
"downloaded_mb": item.progress.downloaded_mb
if item.progress
else 0,
},
)
await self._broadcast_update(
"download_completed", {"item_id": item.id}
)
@@ -581,6 +634,13 @@ class DownloadService:
retry_count=item.retry_count,
)
# Fail progress tracking
await self._progress_service.fail_progress(
progress_id=f"download_{item.id}",
error_message=str(e),
metadata={"retry_count": item.retry_count},
)
await self._broadcast_update(
"download_failed",
{"item_id": item.id, "error": item.error},

View File

@@ -0,0 +1,485 @@
"""Progress service for managing real-time progress updates.
This module provides a centralized service for tracking and broadcasting
real-time progress updates for downloads, scans, queue changes, and
system events. It integrates with the WebSocket service to push updates
to connected clients.
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any, Callable, Dict, Optional
import structlog
logger = structlog.get_logger(__name__)
class ProgressType(str, Enum):
"""Types of progress updates."""
DOWNLOAD = "download"
SCAN = "scan"
QUEUE = "queue"
SYSTEM = "system"
ERROR = "error"
class ProgressStatus(str, Enum):
"""Status of a progress operation."""
STARTED = "started"
IN_PROGRESS = "in_progress"
COMPLETED = "completed"
FAILED = "failed"
CANCELLED = "cancelled"
@dataclass
class ProgressUpdate:
"""Represents a progress update event.
Attributes:
id: Unique identifier for this progress operation
type: Type of progress (download, scan, etc.)
status: Current status of the operation
title: Human-readable title
message: Detailed message
percent: Completion percentage (0-100)
current: Current progress value
total: Total progress value
metadata: Additional metadata
started_at: When operation started
updated_at: When last updated
"""
id: str
type: ProgressType
status: ProgressStatus
title: str
message: str = ""
percent: float = 0.0
current: int = 0
total: int = 0
metadata: Dict[str, Any] = field(default_factory=dict)
started_at: datetime = field(default_factory=datetime.utcnow)
updated_at: datetime = field(default_factory=datetime.utcnow)
def to_dict(self) -> Dict[str, Any]:
"""Convert progress update to dictionary."""
return {
"id": self.id,
"type": self.type.value,
"status": self.status.value,
"title": self.title,
"message": self.message,
"percent": round(self.percent, 2),
"current": self.current,
"total": self.total,
"metadata": self.metadata,
"started_at": self.started_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
}
class ProgressServiceError(Exception):
"""Service-level exception for progress operations."""
class ProgressService:
"""Manages real-time progress updates and broadcasting.
Features:
- Track multiple concurrent progress operations
- Calculate progress percentages and rates
- Broadcast updates via WebSocket
- Manage progress lifecycle (start, update, complete, fail)
- Support for different progress types (download, scan, queue)
"""
def __init__(self):
"""Initialize the progress service."""
# Active progress operations: id -> ProgressUpdate
self._active_progress: Dict[str, ProgressUpdate] = {}
# Completed progress history (limited size)
self._history: Dict[str, ProgressUpdate] = {}
self._max_history_size = 50
# WebSocket broadcast callback
self._broadcast_callback: Optional[Callable] = None
# Lock for thread-safe operations
self._lock = asyncio.Lock()
logger.info("ProgressService initialized")
def set_broadcast_callback(self, callback: Callable) -> None:
"""Set callback for broadcasting progress updates via WebSocket.
Args:
callback: Async function to call for broadcasting updates
"""
self._broadcast_callback = callback
logger.debug("Progress broadcast callback registered")
async def _broadcast(self, update: ProgressUpdate, room: str) -> None:
"""Broadcast progress update to WebSocket clients.
Args:
update: Progress update to broadcast
room: WebSocket room to broadcast to
"""
if self._broadcast_callback:
try:
await self._broadcast_callback(
message_type=f"{update.type.value}_progress",
data=update.to_dict(),
room=room,
)
except Exception as e:
logger.error(
"Failed to broadcast progress update",
error=str(e),
progress_id=update.id,
)
async def start_progress(
self,
progress_id: str,
progress_type: ProgressType,
title: str,
total: int = 0,
message: str = "",
metadata: Optional[Dict[str, Any]] = None,
) -> ProgressUpdate:
"""Start a new progress operation.
Args:
progress_id: Unique identifier for this progress
progress_type: Type of progress operation
title: Human-readable title
total: Total items/bytes to process
message: Initial message
metadata: Additional metadata
Returns:
Created progress update object
Raises:
ProgressServiceError: If progress already exists
"""
async with self._lock:
if progress_id in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' already exists"
)
update = ProgressUpdate(
id=progress_id,
type=progress_type,
status=ProgressStatus.STARTED,
title=title,
message=message,
total=total,
metadata=metadata or {},
)
self._active_progress[progress_id] = update
logger.info(
"Progress started",
progress_id=progress_id,
type=progress_type.value,
title=title,
)
# Broadcast to appropriate room
room = f"{progress_type.value}_progress"
await self._broadcast(update, room)
return update
async def update_progress(
self,
progress_id: str,
current: Optional[int] = None,
total: Optional[int] = None,
message: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
force_broadcast: bool = False,
) -> ProgressUpdate:
"""Update an existing progress operation.
Args:
progress_id: Progress identifier
current: Current progress value
total: Updated total value
message: Updated message
metadata: Additional metadata to merge
force_broadcast: Force broadcasting even for small changes
Returns:
Updated progress object
Raises:
ProgressServiceError: If progress not found
"""
async with self._lock:
if progress_id not in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' not found"
)
update = self._active_progress[progress_id]
old_percent = update.percent
# Update fields
if current is not None:
update.current = current
if total is not None:
update.total = total
if message is not None:
update.message = message
if metadata:
update.metadata.update(metadata)
# Calculate percentage
if update.total > 0:
update.percent = (update.current / update.total) * 100
else:
update.percent = 0.0
update.status = ProgressStatus.IN_PROGRESS
update.updated_at = datetime.utcnow()
# Only broadcast if significant change or forced
percent_change = abs(update.percent - old_percent)
should_broadcast = force_broadcast or percent_change >= 1.0
if should_broadcast:
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
return update
async def complete_progress(
self,
progress_id: str,
message: str = "Completed successfully",
metadata: Optional[Dict[str, Any]] = None,
) -> ProgressUpdate:
"""Mark a progress operation as completed.
Args:
progress_id: Progress identifier
message: Completion message
metadata: Additional metadata
Returns:
Completed progress object
Raises:
ProgressServiceError: If progress not found
"""
async with self._lock:
if progress_id not in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' not found"
)
update = self._active_progress[progress_id]
update.status = ProgressStatus.COMPLETED
update.message = message
update.percent = 100.0
update.current = update.total
update.updated_at = datetime.utcnow()
if metadata:
update.metadata.update(metadata)
# Move to history
del self._active_progress[progress_id]
self._add_to_history(update)
logger.info(
"Progress completed",
progress_id=progress_id,
type=update.type.value,
)
# Broadcast completion
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
return update
async def fail_progress(
self,
progress_id: str,
error_message: str,
metadata: Optional[Dict[str, Any]] = None,
) -> ProgressUpdate:
"""Mark a progress operation as failed.
Args:
progress_id: Progress identifier
error_message: Error description
metadata: Additional error metadata
Returns:
Failed progress object
Raises:
ProgressServiceError: If progress not found
"""
async with self._lock:
if progress_id not in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' not found"
)
update = self._active_progress[progress_id]
update.status = ProgressStatus.FAILED
update.message = error_message
update.updated_at = datetime.utcnow()
if metadata:
update.metadata.update(metadata)
# Move to history
del self._active_progress[progress_id]
self._add_to_history(update)
logger.error(
"Progress failed",
progress_id=progress_id,
type=update.type.value,
error=error_message,
)
# Broadcast failure
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
return update
async def cancel_progress(
self,
progress_id: str,
message: str = "Cancelled by user",
) -> ProgressUpdate:
"""Cancel a progress operation.
Args:
progress_id: Progress identifier
message: Cancellation message
Returns:
Cancelled progress object
Raises:
ProgressServiceError: If progress not found
"""
async with self._lock:
if progress_id not in self._active_progress:
raise ProgressServiceError(
f"Progress with id '{progress_id}' not found"
)
update = self._active_progress[progress_id]
update.status = ProgressStatus.CANCELLED
update.message = message
update.updated_at = datetime.utcnow()
# Move to history
del self._active_progress[progress_id]
self._add_to_history(update)
logger.info(
"Progress cancelled",
progress_id=progress_id,
type=update.type.value,
)
# Broadcast cancellation
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
return update
def _add_to_history(self, update: ProgressUpdate) -> None:
"""Add completed progress to history with size limit."""
self._history[update.id] = update
# Maintain history size limit
if len(self._history) > self._max_history_size:
# Remove oldest entries
oldest_keys = sorted(
self._history.keys(),
key=lambda k: self._history[k].updated_at,
)[: len(self._history) - self._max_history_size]
for key in oldest_keys:
del self._history[key]
async def get_progress(self, progress_id: str) -> Optional[ProgressUpdate]:
"""Get current progress state.
Args:
progress_id: Progress identifier
Returns:
Progress update object or None if not found
"""
async with self._lock:
if progress_id in self._active_progress:
return self._active_progress[progress_id]
if progress_id in self._history:
return self._history[progress_id]
return None
async def get_all_active_progress(
self, progress_type: Optional[ProgressType] = None
) -> Dict[str, ProgressUpdate]:
"""Get all active progress operations.
Args:
progress_type: Optional filter by progress type
Returns:
Dictionary of progress_id -> ProgressUpdate
"""
async with self._lock:
if progress_type:
return {
pid: update
for pid, update in self._active_progress.items()
if update.type == progress_type
}
return self._active_progress.copy()
async def clear_history(self) -> None:
"""Clear progress history."""
async with self._lock:
self._history.clear()
logger.info("Progress history cleared")
# Global singleton instance
_progress_service: Optional[ProgressService] = None
def get_progress_service() -> ProgressService:
"""Get or create the global progress service instance.
Returns:
Global ProgressService instance
"""
global _progress_service
if _progress_service is None:
_progress_service = ProgressService()
return _progress_service