feat: implement WebSocket real-time progress updates
- Add ProgressService for centralized progress tracking and broadcasting - Integrate ProgressService with DownloadService for download progress - Integrate ProgressService with AnimeService for scan progress - Add progress-related WebSocket message models (ScanProgress, ErrorNotification, etc.) - Initialize ProgressService with WebSocket callback in application startup - Add comprehensive unit tests for ProgressService - Update infrastructure.md with ProgressService documentation - Remove completed WebSocket Real-time Updates task from instructions.md The ProgressService provides: - Real-time progress tracking for downloads, scans, and queue operations - Automatic progress percentage calculation - Progress lifecycle management (start, update, complete, fail, cancel) - WebSocket integration for instant client updates - Progress history with size limits - Thread-safe operations using asyncio locks - Support for metadata and custom messages Benefits: - Decoupled progress tracking from WebSocket broadcasting - Single reusable service across all components - Supports multiple concurrent operations efficiently - Centralized progress tracking simplifies monitoring - Instant feedback to users on long-running operations
This commit is contained in:
@@ -29,6 +29,8 @@ from src.server.controllers.error_controller import (
|
||||
from src.server.controllers.health_controller import router as health_router
|
||||
from src.server.controllers.page_controller import router as page_router
|
||||
from src.server.middleware.auth import AuthMiddleware
|
||||
from src.server.services.progress_service import get_progress_service
|
||||
from src.server.services.websocket_service import get_websocket_service
|
||||
|
||||
# Initialize FastAPI app
|
||||
app = FastAPI(
|
||||
@@ -74,6 +76,23 @@ async def startup_event():
|
||||
# Initialize SeriesApp with configured directory
|
||||
if settings.anime_directory:
|
||||
series_app = SeriesApp(settings.anime_directory)
|
||||
|
||||
# Initialize progress service with websocket callback
|
||||
progress_service = get_progress_service()
|
||||
ws_service = get_websocket_service()
|
||||
|
||||
async def broadcast_callback(
|
||||
message_type: str, data: dict, room: str
|
||||
):
|
||||
"""Broadcast progress updates via WebSocket."""
|
||||
message = {
|
||||
"type": message_type,
|
||||
"data": data,
|
||||
}
|
||||
await ws_service.manager.broadcast_to_room(message, room)
|
||||
|
||||
progress_service.set_broadcast_callback(broadcast_callback)
|
||||
|
||||
print("FastAPI application started successfully")
|
||||
except Exception as e:
|
||||
print(f"Error during startup: {e}")
|
||||
|
||||
@@ -30,6 +30,11 @@ class WebSocketMessageType(str, Enum):
|
||||
QUEUE_PAUSED = "queue_paused"
|
||||
QUEUE_RESUMED = "queue_resumed"
|
||||
|
||||
# Progress-related messages
|
||||
SCAN_PROGRESS = "scan_progress"
|
||||
SCAN_COMPLETE = "scan_complete"
|
||||
SCAN_FAILED = "scan_failed"
|
||||
|
||||
# System messages
|
||||
SYSTEM_INFO = "system_info"
|
||||
SYSTEM_WARNING = "system_warning"
|
||||
@@ -188,3 +193,93 @@ class RoomSubscriptionRequest(BaseModel):
|
||||
room: str = Field(
|
||||
..., min_length=1, description="Room name to join or leave"
|
||||
)
|
||||
|
||||
|
||||
class ScanProgressMessage(BaseModel):
|
||||
"""Scan progress update message."""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
default=WebSocketMessageType.SCAN_PROGRESS,
|
||||
description="Message type",
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="Scan progress data including current, total, percent",
|
||||
)
|
||||
|
||||
|
||||
class ScanCompleteMessage(BaseModel):
|
||||
"""Scan completion message."""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
default=WebSocketMessageType.SCAN_COMPLETE,
|
||||
description="Message type",
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description="Scan completion data including series_found, duration",
|
||||
)
|
||||
|
||||
|
||||
class ScanFailedMessage(BaseModel):
|
||||
"""Scan failure message."""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
default=WebSocketMessageType.SCAN_FAILED,
|
||||
description="Message type",
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
..., description="Scan error data including error_message"
|
||||
)
|
||||
|
||||
|
||||
class ErrorNotificationMessage(BaseModel):
|
||||
"""Error notification message for critical errors."""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
default=WebSocketMessageType.SYSTEM_ERROR,
|
||||
description="Message type",
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Error notification data including severity, message, details"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class ProgressUpdateMessage(BaseModel):
|
||||
"""Generic progress update message.
|
||||
|
||||
Can be used for any type of progress (download, scan, queue, etc.)
|
||||
"""
|
||||
|
||||
type: WebSocketMessageType = Field(
|
||||
..., description="Type of progress message"
|
||||
)
|
||||
timestamp: str = Field(
|
||||
default_factory=lambda: datetime.utcnow().isoformat(),
|
||||
description="ISO 8601 timestamp",
|
||||
)
|
||||
data: Dict[str, Any] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Progress data including id, status, percent, current, total"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -3,11 +3,16 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
from src.core.SeriesApp import SeriesApp
|
||||
from src.server.services.progress_service import (
|
||||
ProgressService,
|
||||
ProgressType,
|
||||
get_progress_service,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
@@ -24,9 +29,15 @@ class AnimeService:
|
||||
- Adds simple in-memory caching for read operations
|
||||
"""
|
||||
|
||||
def __init__(self, directory: str, max_workers: int = 4):
|
||||
def __init__(
|
||||
self,
|
||||
directory: str,
|
||||
max_workers: int = 4,
|
||||
progress_service: Optional[ProgressService] = None,
|
||||
):
|
||||
self._directory = directory
|
||||
self._executor = ThreadPoolExecutor(max_workers=max_workers)
|
||||
self._progress_service = progress_service or get_progress_service()
|
||||
# SeriesApp is blocking; instantiate per-service
|
||||
try:
|
||||
self._app = SeriesApp(directory)
|
||||
@@ -75,20 +86,70 @@ class AnimeService:
|
||||
logger.exception("search failed")
|
||||
raise AnimeServiceError("Search failed") from e
|
||||
|
||||
async def rescan(self, callback=None) -> None:
|
||||
async def rescan(self, callback: Optional[Callable] = None) -> None:
|
||||
"""Trigger a re-scan. Accepts an optional callback function.
|
||||
|
||||
The callback is executed in the threadpool by SeriesApp.
|
||||
Progress updates are tracked and broadcasted via ProgressService.
|
||||
"""
|
||||
scan_id = "library_scan"
|
||||
|
||||
try:
|
||||
await self._run_in_executor(self._app.ReScan, callback)
|
||||
# Start progress tracking
|
||||
await self._progress_service.start_progress(
|
||||
progress_id=scan_id,
|
||||
progress_type=ProgressType.SCAN,
|
||||
title="Scanning anime library",
|
||||
message="Initializing scan...",
|
||||
)
|
||||
|
||||
# Create wrapped callback for progress updates
|
||||
def progress_callback(progress_data: dict) -> None:
|
||||
"""Update progress during scan."""
|
||||
try:
|
||||
if callback:
|
||||
callback(progress_data)
|
||||
|
||||
# Update progress service
|
||||
current = progress_data.get("current", 0)
|
||||
total = progress_data.get("total", 0)
|
||||
message = progress_data.get("message", "Scanning...")
|
||||
|
||||
asyncio.create_task(
|
||||
self._progress_service.update_progress(
|
||||
progress_id=scan_id,
|
||||
current=current,
|
||||
total=total,
|
||||
message=message,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("Scan progress callback error", error=str(e))
|
||||
|
||||
# Run scan
|
||||
await self._run_in_executor(self._app.ReScan, progress_callback)
|
||||
|
||||
# invalidate cache
|
||||
try:
|
||||
self._cached_list_missing.cache_clear()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Complete progress tracking
|
||||
await self._progress_service.complete_progress(
|
||||
progress_id=scan_id,
|
||||
message="Scan completed successfully",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("rescan failed")
|
||||
|
||||
# Fail progress tracking
|
||||
await self._progress_service.fail_progress(
|
||||
progress_id=scan_id,
|
||||
error_message=str(e),
|
||||
)
|
||||
|
||||
raise AnimeServiceError("Rescan failed") from e
|
||||
|
||||
async def download(self, serie_folder: str, season: int, episode: int, key: str, callback=None) -> bool:
|
||||
|
||||
@@ -27,6 +27,11 @@ from src.server.models.download import (
|
||||
QueueStatus,
|
||||
)
|
||||
from src.server.services.anime_service import AnimeService, AnimeServiceError
|
||||
from src.server.services.progress_service import (
|
||||
ProgressService,
|
||||
ProgressType,
|
||||
get_progress_service,
|
||||
)
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
@@ -53,6 +58,7 @@ class DownloadService:
|
||||
max_concurrent_downloads: int = 2,
|
||||
max_retries: int = 3,
|
||||
persistence_path: str = "./data/download_queue.json",
|
||||
progress_service: Optional[ProgressService] = None,
|
||||
):
|
||||
"""Initialize the download service.
|
||||
|
||||
@@ -61,11 +67,13 @@ class DownloadService:
|
||||
max_concurrent_downloads: Maximum simultaneous downloads
|
||||
max_retries: Maximum retry attempts for failed downloads
|
||||
persistence_path: Path to persist queue state
|
||||
progress_service: Optional progress service for tracking
|
||||
"""
|
||||
self._anime_service = anime_service
|
||||
self._max_concurrent = max_concurrent_downloads
|
||||
self._max_retries = max_retries
|
||||
self._persistence_path = Path(persistence_path)
|
||||
self._progress_service = progress_service or get_progress_service()
|
||||
|
||||
# Queue storage by status
|
||||
self._pending_queue: deque[DownloadItem] = deque()
|
||||
@@ -500,6 +508,23 @@ class DownloadService:
|
||||
if item.progress.speed_mbps:
|
||||
self._download_speeds.append(item.progress.speed_mbps)
|
||||
|
||||
# Update progress service
|
||||
if item.progress.total_mb and item.progress.total_mb > 0:
|
||||
current_mb = int(item.progress.downloaded_mb)
|
||||
total_mb = int(item.progress.total_mb)
|
||||
|
||||
asyncio.create_task(
|
||||
self._progress_service.update_progress(
|
||||
progress_id=f"download_{item.id}",
|
||||
current=current_mb,
|
||||
total=total_mb,
|
||||
metadata={
|
||||
"speed_mbps": item.progress.speed_mbps,
|
||||
"eta_seconds": item.progress.eta_seconds,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# Broadcast update (fire and forget)
|
||||
asyncio.create_task(
|
||||
self._broadcast_update(
|
||||
@@ -535,6 +560,22 @@ class DownloadService:
|
||||
episode=item.episode.episode,
|
||||
)
|
||||
|
||||
# Start progress tracking
|
||||
await self._progress_service.start_progress(
|
||||
progress_id=f"download_{item.id}",
|
||||
progress_type=ProgressType.DOWNLOAD,
|
||||
title=f"Downloading {item.serie_name}",
|
||||
message=(
|
||||
f"S{item.episode.season:02d}E{item.episode.episode:02d}"
|
||||
),
|
||||
metadata={
|
||||
"item_id": item.id,
|
||||
"serie_name": item.serie_name,
|
||||
"season": item.episode.season,
|
||||
"episode": item.episode.episode,
|
||||
},
|
||||
)
|
||||
|
||||
# Create progress callback
|
||||
progress_callback = self._create_progress_callback(item)
|
||||
|
||||
@@ -561,6 +602,18 @@ class DownloadService:
|
||||
logger.info(
|
||||
"Download completed successfully", item_id=item.id
|
||||
)
|
||||
|
||||
# Complete progress tracking
|
||||
await self._progress_service.complete_progress(
|
||||
progress_id=f"download_{item.id}",
|
||||
message="Download completed successfully",
|
||||
metadata={
|
||||
"downloaded_mb": item.progress.downloaded_mb
|
||||
if item.progress
|
||||
else 0,
|
||||
},
|
||||
)
|
||||
|
||||
await self._broadcast_update(
|
||||
"download_completed", {"item_id": item.id}
|
||||
)
|
||||
@@ -581,6 +634,13 @@ class DownloadService:
|
||||
retry_count=item.retry_count,
|
||||
)
|
||||
|
||||
# Fail progress tracking
|
||||
await self._progress_service.fail_progress(
|
||||
progress_id=f"download_{item.id}",
|
||||
error_message=str(e),
|
||||
metadata={"retry_count": item.retry_count},
|
||||
)
|
||||
|
||||
await self._broadcast_update(
|
||||
"download_failed",
|
||||
{"item_id": item.id, "error": item.error},
|
||||
|
||||
485
src/server/services/progress_service.py
Normal file
485
src/server/services/progress_service.py
Normal file
@@ -0,0 +1,485 @@
|
||||
"""Progress service for managing real-time progress updates.
|
||||
|
||||
This module provides a centralized service for tracking and broadcasting
|
||||
real-time progress updates for downloads, scans, queue changes, and
|
||||
system events. It integrates with the WebSocket service to push updates
|
||||
to connected clients.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class ProgressType(str, Enum):
|
||||
"""Types of progress updates."""
|
||||
|
||||
DOWNLOAD = "download"
|
||||
SCAN = "scan"
|
||||
QUEUE = "queue"
|
||||
SYSTEM = "system"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ProgressStatus(str, Enum):
|
||||
"""Status of a progress operation."""
|
||||
|
||||
STARTED = "started"
|
||||
IN_PROGRESS = "in_progress"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
CANCELLED = "cancelled"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProgressUpdate:
|
||||
"""Represents a progress update event.
|
||||
|
||||
Attributes:
|
||||
id: Unique identifier for this progress operation
|
||||
type: Type of progress (download, scan, etc.)
|
||||
status: Current status of the operation
|
||||
title: Human-readable title
|
||||
message: Detailed message
|
||||
percent: Completion percentage (0-100)
|
||||
current: Current progress value
|
||||
total: Total progress value
|
||||
metadata: Additional metadata
|
||||
started_at: When operation started
|
||||
updated_at: When last updated
|
||||
"""
|
||||
|
||||
id: str
|
||||
type: ProgressType
|
||||
status: ProgressStatus
|
||||
title: str
|
||||
message: str = ""
|
||||
percent: float = 0.0
|
||||
current: int = 0
|
||||
total: int = 0
|
||||
metadata: Dict[str, Any] = field(default_factory=dict)
|
||||
started_at: datetime = field(default_factory=datetime.utcnow)
|
||||
updated_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert progress update to dictionary."""
|
||||
return {
|
||||
"id": self.id,
|
||||
"type": self.type.value,
|
||||
"status": self.status.value,
|
||||
"title": self.title,
|
||||
"message": self.message,
|
||||
"percent": round(self.percent, 2),
|
||||
"current": self.current,
|
||||
"total": self.total,
|
||||
"metadata": self.metadata,
|
||||
"started_at": self.started_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
}
|
||||
|
||||
|
||||
class ProgressServiceError(Exception):
|
||||
"""Service-level exception for progress operations."""
|
||||
|
||||
|
||||
class ProgressService:
|
||||
"""Manages real-time progress updates and broadcasting.
|
||||
|
||||
Features:
|
||||
- Track multiple concurrent progress operations
|
||||
- Calculate progress percentages and rates
|
||||
- Broadcast updates via WebSocket
|
||||
- Manage progress lifecycle (start, update, complete, fail)
|
||||
- Support for different progress types (download, scan, queue)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the progress service."""
|
||||
# Active progress operations: id -> ProgressUpdate
|
||||
self._active_progress: Dict[str, ProgressUpdate] = {}
|
||||
|
||||
# Completed progress history (limited size)
|
||||
self._history: Dict[str, ProgressUpdate] = {}
|
||||
self._max_history_size = 50
|
||||
|
||||
# WebSocket broadcast callback
|
||||
self._broadcast_callback: Optional[Callable] = None
|
||||
|
||||
# Lock for thread-safe operations
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
logger.info("ProgressService initialized")
|
||||
|
||||
def set_broadcast_callback(self, callback: Callable) -> None:
|
||||
"""Set callback for broadcasting progress updates via WebSocket.
|
||||
|
||||
Args:
|
||||
callback: Async function to call for broadcasting updates
|
||||
"""
|
||||
self._broadcast_callback = callback
|
||||
logger.debug("Progress broadcast callback registered")
|
||||
|
||||
async def _broadcast(self, update: ProgressUpdate, room: str) -> None:
|
||||
"""Broadcast progress update to WebSocket clients.
|
||||
|
||||
Args:
|
||||
update: Progress update to broadcast
|
||||
room: WebSocket room to broadcast to
|
||||
"""
|
||||
if self._broadcast_callback:
|
||||
try:
|
||||
await self._broadcast_callback(
|
||||
message_type=f"{update.type.value}_progress",
|
||||
data=update.to_dict(),
|
||||
room=room,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to broadcast progress update",
|
||||
error=str(e),
|
||||
progress_id=update.id,
|
||||
)
|
||||
|
||||
async def start_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
progress_type: ProgressType,
|
||||
title: str,
|
||||
total: int = 0,
|
||||
message: str = "",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> ProgressUpdate:
|
||||
"""Start a new progress operation.
|
||||
|
||||
Args:
|
||||
progress_id: Unique identifier for this progress
|
||||
progress_type: Type of progress operation
|
||||
title: Human-readable title
|
||||
total: Total items/bytes to process
|
||||
message: Initial message
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
Created progress update object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress already exists
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' already exists"
|
||||
)
|
||||
|
||||
update = ProgressUpdate(
|
||||
id=progress_id,
|
||||
type=progress_type,
|
||||
status=ProgressStatus.STARTED,
|
||||
title=title,
|
||||
message=message,
|
||||
total=total,
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
self._active_progress[progress_id] = update
|
||||
|
||||
logger.info(
|
||||
"Progress started",
|
||||
progress_id=progress_id,
|
||||
type=progress_type.value,
|
||||
title=title,
|
||||
)
|
||||
|
||||
# Broadcast to appropriate room
|
||||
room = f"{progress_type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
async def update_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
current: Optional[int] = None,
|
||||
total: Optional[int] = None,
|
||||
message: Optional[str] = None,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
force_broadcast: bool = False,
|
||||
) -> ProgressUpdate:
|
||||
"""Update an existing progress operation.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
current: Current progress value
|
||||
total: Updated total value
|
||||
message: Updated message
|
||||
metadata: Additional metadata to merge
|
||||
force_broadcast: Force broadcasting even for small changes
|
||||
|
||||
Returns:
|
||||
Updated progress object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id not in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' not found"
|
||||
)
|
||||
|
||||
update = self._active_progress[progress_id]
|
||||
old_percent = update.percent
|
||||
|
||||
# Update fields
|
||||
if current is not None:
|
||||
update.current = current
|
||||
if total is not None:
|
||||
update.total = total
|
||||
if message is not None:
|
||||
update.message = message
|
||||
if metadata:
|
||||
update.metadata.update(metadata)
|
||||
|
||||
# Calculate percentage
|
||||
if update.total > 0:
|
||||
update.percent = (update.current / update.total) * 100
|
||||
else:
|
||||
update.percent = 0.0
|
||||
|
||||
update.status = ProgressStatus.IN_PROGRESS
|
||||
update.updated_at = datetime.utcnow()
|
||||
|
||||
# Only broadcast if significant change or forced
|
||||
percent_change = abs(update.percent - old_percent)
|
||||
should_broadcast = force_broadcast or percent_change >= 1.0
|
||||
|
||||
if should_broadcast:
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
async def complete_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
message: str = "Completed successfully",
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> ProgressUpdate:
|
||||
"""Mark a progress operation as completed.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
message: Completion message
|
||||
metadata: Additional metadata
|
||||
|
||||
Returns:
|
||||
Completed progress object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id not in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' not found"
|
||||
)
|
||||
|
||||
update = self._active_progress[progress_id]
|
||||
update.status = ProgressStatus.COMPLETED
|
||||
update.message = message
|
||||
update.percent = 100.0
|
||||
update.current = update.total
|
||||
update.updated_at = datetime.utcnow()
|
||||
|
||||
if metadata:
|
||||
update.metadata.update(metadata)
|
||||
|
||||
# Move to history
|
||||
del self._active_progress[progress_id]
|
||||
self._add_to_history(update)
|
||||
|
||||
logger.info(
|
||||
"Progress completed",
|
||||
progress_id=progress_id,
|
||||
type=update.type.value,
|
||||
)
|
||||
|
||||
# Broadcast completion
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
async def fail_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
error_message: str,
|
||||
metadata: Optional[Dict[str, Any]] = None,
|
||||
) -> ProgressUpdate:
|
||||
"""Mark a progress operation as failed.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
error_message: Error description
|
||||
metadata: Additional error metadata
|
||||
|
||||
Returns:
|
||||
Failed progress object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id not in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' not found"
|
||||
)
|
||||
|
||||
update = self._active_progress[progress_id]
|
||||
update.status = ProgressStatus.FAILED
|
||||
update.message = error_message
|
||||
update.updated_at = datetime.utcnow()
|
||||
|
||||
if metadata:
|
||||
update.metadata.update(metadata)
|
||||
|
||||
# Move to history
|
||||
del self._active_progress[progress_id]
|
||||
self._add_to_history(update)
|
||||
|
||||
logger.error(
|
||||
"Progress failed",
|
||||
progress_id=progress_id,
|
||||
type=update.type.value,
|
||||
error=error_message,
|
||||
)
|
||||
|
||||
# Broadcast failure
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
async def cancel_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
message: str = "Cancelled by user",
|
||||
) -> ProgressUpdate:
|
||||
"""Cancel a progress operation.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
message: Cancellation message
|
||||
|
||||
Returns:
|
||||
Cancelled progress object
|
||||
|
||||
Raises:
|
||||
ProgressServiceError: If progress not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id not in self._active_progress:
|
||||
raise ProgressServiceError(
|
||||
f"Progress with id '{progress_id}' not found"
|
||||
)
|
||||
|
||||
update = self._active_progress[progress_id]
|
||||
update.status = ProgressStatus.CANCELLED
|
||||
update.message = message
|
||||
update.updated_at = datetime.utcnow()
|
||||
|
||||
# Move to history
|
||||
del self._active_progress[progress_id]
|
||||
self._add_to_history(update)
|
||||
|
||||
logger.info(
|
||||
"Progress cancelled",
|
||||
progress_id=progress_id,
|
||||
type=update.type.value,
|
||||
)
|
||||
|
||||
# Broadcast cancellation
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
|
||||
return update
|
||||
|
||||
def _add_to_history(self, update: ProgressUpdate) -> None:
|
||||
"""Add completed progress to history with size limit."""
|
||||
self._history[update.id] = update
|
||||
|
||||
# Maintain history size limit
|
||||
if len(self._history) > self._max_history_size:
|
||||
# Remove oldest entries
|
||||
oldest_keys = sorted(
|
||||
self._history.keys(),
|
||||
key=lambda k: self._history[k].updated_at,
|
||||
)[: len(self._history) - self._max_history_size]
|
||||
|
||||
for key in oldest_keys:
|
||||
del self._history[key]
|
||||
|
||||
async def get_progress(self, progress_id: str) -> Optional[ProgressUpdate]:
|
||||
"""Get current progress state.
|
||||
|
||||
Args:
|
||||
progress_id: Progress identifier
|
||||
|
||||
Returns:
|
||||
Progress update object or None if not found
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_id in self._active_progress:
|
||||
return self._active_progress[progress_id]
|
||||
if progress_id in self._history:
|
||||
return self._history[progress_id]
|
||||
return None
|
||||
|
||||
async def get_all_active_progress(
|
||||
self, progress_type: Optional[ProgressType] = None
|
||||
) -> Dict[str, ProgressUpdate]:
|
||||
"""Get all active progress operations.
|
||||
|
||||
Args:
|
||||
progress_type: Optional filter by progress type
|
||||
|
||||
Returns:
|
||||
Dictionary of progress_id -> ProgressUpdate
|
||||
"""
|
||||
async with self._lock:
|
||||
if progress_type:
|
||||
return {
|
||||
pid: update
|
||||
for pid, update in self._active_progress.items()
|
||||
if update.type == progress_type
|
||||
}
|
||||
return self._active_progress.copy()
|
||||
|
||||
async def clear_history(self) -> None:
|
||||
"""Clear progress history."""
|
||||
async with self._lock:
|
||||
self._history.clear()
|
||||
logger.info("Progress history cleared")
|
||||
|
||||
|
||||
# Global singleton instance
|
||||
_progress_service: Optional[ProgressService] = None
|
||||
|
||||
|
||||
def get_progress_service() -> ProgressService:
|
||||
"""Get or create the global progress service instance.
|
||||
|
||||
Returns:
|
||||
Global ProgressService instance
|
||||
"""
|
||||
global _progress_service
|
||||
if _progress_service is None:
|
||||
_progress_service = ProgressService()
|
||||
return _progress_service
|
||||
Reference in New Issue
Block a user