581 lines
18 KiB
Python
581 lines
18 KiB
Python
"""Progress service for managing real-time progress updates.

This module provides a centralized service for tracking and broadcasting
real-time progress updates for downloads, scans, queue changes, and
system events. Updates are published to registered subscribers (such as
the WebSocket service), which push them to connected clients.
"""
|
|
from __future__ import annotations

import asyncio
import inspect
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Awaitable, Callable, Dict, List, Optional

import structlog
|
|
|
|
# Module-level structlog logger, named after this module via __name__.
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
class ProgressType(str, Enum):
    """Category of operation that a progress record describes.

    Members subclass ``str``, so each one compares equal to its raw value
    (e.g. ``ProgressType.SCAN == "scan"``). The ``.value`` is used when
    naming the event/room an update is published under.
    """

    DOWNLOAD = "download"
    SCAN = "scan"
    QUEUE = "queue"
    SYSTEM = "system"
    ERROR = "error"
|
|
|
|
|
|
class ProgressStatus(str, Enum):
    """Lifecycle state of a tracked operation.

    ``STARTED`` and ``IN_PROGRESS`` describe live operations; ``COMPLETED``,
    ``FAILED`` and ``CANCELLED`` are the terminal states.
    """

    STARTED = "started"
    IN_PROGRESS = "in_progress"
    COMPLETED = "completed"
    FAILED = "failed"
    CANCELLED = "cancelled"
|
|
|
|
|
|
@dataclass
class ProgressUpdate:
    """Mutable snapshot of one tracked operation.

    Attributes:
        id: Identifier of the tracked operation.
        type: Category of the operation (download, scan, ...).
        status: Current lifecycle state.
        title: Short human-readable label.
        message: Free-form detail text.
        percent: Completion percentage (0-100).
        current: Units processed so far.
        total: Units expected in total.
        metadata: Arbitrary extra data forwarded to subscribers.
        started_at: Creation time; defaults to "now" in UTC.
        updated_at: Time of the latest change; defaults to "now" in UTC.
    """

    id: str
    type: ProgressType
    status: ProgressStatus
    title: str
    message: str = ""
    percent: float = 0.0
    current: int = 0
    total: int = 0
    metadata: Dict[str, Any] = field(default_factory=dict)
    started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
    updated_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))

    def to_dict(self) -> Dict[str, Any]:
        """Serialise to a plain dict.

        Enum fields become their ``.value``, ``percent`` is rounded to two
        decimals and timestamps become ISO-8601 strings. ``metadata`` is
        returned as-is (not copied).
        """
        return dict(
            id=self.id,
            type=self.type.value,
            status=self.status.value,
            title=self.title,
            message=self.message,
            percent=round(self.percent, 2),
            current=self.current,
            total=self.total,
            metadata=self.metadata,
            started_at=self.started_at.isoformat(),
            updated_at=self.updated_at.isoformat(),
        )
|
|
|
|
|
|
@dataclass
class ProgressEvent:
    """Envelope handed to progress subscribers.

    Attributes:
        event_type: Event name sent to clients (e.g. ``'download_progress'``).
        progress_id: Identifier of the operation the event concerns.
        progress: The progress snapshot carried by the event.
        room: WebSocket room the event is addressed to; defaults to
            ``'progress'``.
    """

    event_type: str
    progress_id: str
    progress: ProgressUpdate
    room: str = "progress"

    def to_dict(self) -> Dict[str, Any]:
        """Build the broadcast payload ``{"type": ..., "data": ...}``.

        ``progress_id`` and ``room`` are not part of the payload.
        """
        event_name = self.event_type
        serialized = self.progress.to_dict()
        return dict(type=event_name, data=serialized)
|
|
|
|
|
|
class ProgressServiceError(Exception):
    """Raised when a progress operation cannot be carried out.

    Examples: starting an id that is already active, or updating/finishing
    an id that is not active.
    """
|
|
|
|
|
|
class ProgressService:
    """Manages real-time progress updates and notifies subscribers.

    Features:
    - Track multiple concurrent progress operations, keyed by id
    - Derive completion percentages from ``current`` / ``total``
    - Publish lifecycle events (start, update, complete, fail, cancel)
      to subscribers such as a WebSocket broadcaster
    - Keep a bounded history of finished operations

    Concurrency:
        All state changes happen under one ``asyncio.Lock``. Subscribers are
        notified only *after* the lock is released. ``asyncio.Lock`` is not
        reentrant, so a handler that calls back into this service would
        otherwise deadlock.
    """

    def __init__(self) -> None:
        """Initialize empty state, the subscriber registry and the lock."""
        # Active progress operations: id -> ProgressUpdate
        self._active_progress: Dict[str, ProgressUpdate] = {}

        # Finished (completed/failed/cancelled) operations, bounded in size.
        self._history: Dict[str, ProgressUpdate] = {}
        self._max_history_size = 50

        # Minimum change in percentage points that makes update_progress
        # broadcast (unless force_broadcast is set).
        self._min_broadcast_delta = 1.0

        # Event subscribers: event_name -> list of handlers. Handlers may be
        # plain callables or coroutine functions.
        self._event_handlers: Dict[
            str, List[Callable[[ProgressEvent], Optional[Awaitable[None]]]]
        ] = {}

        # Guards _active_progress and _history against interleaved coroutines.
        self._lock = asyncio.Lock()

        logger.info("ProgressService initialized")

    # ------------------------------------------------------------------
    # Subscription management
    # ------------------------------------------------------------------

    def subscribe(
        self,
        event_name: str,
        handler: Callable[[ProgressEvent], Optional[Awaitable[None]]],
    ) -> None:
        """Subscribe to progress events.

        Args:
            event_name: Name of the event to subscribe to. Only
                ``'progress_updated'`` is currently dispatched by this service.
            handler: Called with each ProgressEvent. It may be a regular
                function or a coroutine function; awaitable results are awaited.
        """
        self._event_handlers.setdefault(event_name, []).append(handler)
        logger.debug("Event handler subscribed", event_type=event_name)

    def unsubscribe(
        self,
        event_name: str,
        handler: Callable[[ProgressEvent], Optional[Awaitable[None]]],
    ) -> None:
        """Unsubscribe from progress events.

        Unknown event names are ignored silently. A handler that is not
        registered for a known event name is logged as a warning.

        Args:
            event_name: Name of the event to unsubscribe from.
            handler: Handler previously passed to :meth:`subscribe`.
        """
        handlers = self._event_handlers.get(event_name)
        if handlers is None:
            return
        try:
            handlers.remove(handler)
        except ValueError:
            logger.warning(
                "Handler not found for unsubscribe",
                event_type=event_name,
            )
            return
        logger.debug("Event handler unsubscribed", event_type=event_name)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    async def _emit_event(self, event: ProgressEvent) -> None:
        """Deliver ``event`` to every ``'progress_updated'`` subscriber.

        Synchronous handlers run immediately. Awaitables returned by async
        handlers are awaited concurrently. A failure in one handler is logged
        and does not stop the others from running, whether it is raised
        synchronously or from its awaitable.

        Args:
            event: Progress event to deliver.
        """
        event_name = "progress_updated"
        # Copy so (un)subscribing from inside a handler cannot alter this
        # dispatch round.
        handlers = list(self._event_handlers.get(event_name, ()))
        if not handlers:
            return

        pending: List[Awaitable[Any]] = []
        pending_indices: List[int] = []
        for idx, handler in enumerate(handlers):
            try:
                outcome = handler(event)
            except Exception as exc:  # isolate a faulty handler
                self._log_handler_error(event_name, exc, idx)
                continue
            # Plain functions return None. asyncio.gather rejects
            # non-awaitables with TypeError, so only collect real awaitables.
            if inspect.isawaitable(outcome):
                pending.append(outcome)
                pending_indices.append(idx)

        if not pending:
            return

        results = await asyncio.gather(*pending, return_exceptions=True)
        for idx, result in zip(pending_indices, results):
            if isinstance(result, Exception):
                self._log_handler_error(event_name, result, idx)

    @staticmethod
    def _log_handler_error(
        event_name: str, error: BaseException, index: int
    ) -> None:
        """Log a subscriber failure without propagating it."""
        logger.error(
            "Event handler raised exception",
            event_type=event_name,
            error=str(error),
            handler_index=index,
        )

    @staticmethod
    def _build_event(update: ProgressUpdate) -> ProgressEvent:
        """Wrap ``update`` in an event named and routed as ``<type>_progress``."""
        channel = f"{update.type.value}_progress"
        return ProgressEvent(
            event_type=channel,
            progress_id=update.id,
            progress=update,
            room=channel,
        )

    @staticmethod
    def _percent_of(current: int, total: int) -> float:
        """Return ``current / total`` as a percentage clamped to [0, 100].

        Returns 0.0 when ``total`` is not positive (no known size).
        """
        if total <= 0:
            return 0.0
        return max(0.0, min(100.0, (current / total) * 100))

    def _require_active(self, progress_id: str) -> ProgressUpdate:
        """Return the active update for ``progress_id``. The caller must hold the lock.

        Raises:
            ProgressServiceError: If no active operation has this id.
        """
        update = self._active_progress.get(progress_id)
        if update is None:
            raise ProgressServiceError(
                f"Progress with id '{progress_id}' not found"
            )
        return update

    async def _retire(
        self,
        progress_id: str,
        status: ProgressStatus,
        message: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> ProgressUpdate:
        """Move an active operation into history with a terminal ``status``.

        Acquires the lock itself and does not emit any event; callers log and
        emit after this returns, i.e. outside the lock.

        Raises:
            ProgressServiceError: If progress not found.
        """
        async with self._lock:
            update = self._require_active(progress_id)
            update.status = status
            update.message = message
            if status is ProgressStatus.COMPLETED:
                # A completed operation is reported as fully done.
                update.percent = 100.0
                update.current = update.total
            update.updated_at = datetime.now(timezone.utc)
            if metadata:
                update.metadata.update(metadata)

            del self._active_progress[progress_id]
            self._add_to_history(update)
        return update

    def _add_to_history(self, update: ProgressUpdate) -> None:
        """Record a finished operation and evict the oldest entries beyond the limit.

        "Oldest" means smallest ``updated_at``. The caller must hold the lock.
        """
        self._history[update.id] = update

        overflow = len(self._history) - self._max_history_size
        if overflow <= 0:
            return
        stale = sorted(
            self._history, key=lambda key: self._history[key].updated_at
        )[:overflow]
        for key in stale:
            del self._history[key]

    # ------------------------------------------------------------------
    # Lifecycle API
    # ------------------------------------------------------------------

    async def start_progress(
        self,
        progress_id: str,
        progress_type: ProgressType,
        title: str,
        total: int = 0,
        message: str = "",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> ProgressUpdate:
        """Start a new progress operation.

        Args:
            progress_id: Unique identifier for this progress.
            progress_type: Type of progress operation.
            title: Human-readable title.
            total: Total items/bytes to process.
            message: Initial message.
            metadata: Additional metadata. The dict is copied, so later
                updates never mutate the caller's object.

        Returns:
            Created progress update object.

        Raises:
            ProgressServiceError: If progress already exists.
        """
        async with self._lock:
            if progress_id in self._active_progress:
                raise ProgressServiceError(
                    f"Progress with id '{progress_id}' already exists"
                )

            update = ProgressUpdate(
                id=progress_id,
                type=progress_type,
                status=ProgressStatus.STARTED,
                title=title,
                message=message,
                total=total,
                metadata=dict(metadata) if metadata else {},
            )
            self._active_progress[progress_id] = update

        logger.info(
            "Progress started",
            progress_id=progress_id,
            type=progress_type.value,
            title=title,
        )

        # Notify subscribers outside the lock (see class docstring).
        await self._emit_event(self._build_event(update))
        return update

    async def update_progress(
        self,
        progress_id: str,
        current: Optional[int] = None,
        total: Optional[int] = None,
        message: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        force_broadcast: bool = False,
    ) -> ProgressUpdate:
        """Update an existing progress operation.

        The percentage is recomputed from ``current`` / ``total`` and clamped
        to [0, 100]. Subscribers are notified only when the percentage moves
        by at least ``self._min_broadcast_delta`` points, or when forced.

        Args:
            progress_id: Progress identifier.
            current: Current progress value.
            total: Updated total value.
            message: Updated message.
            metadata: Additional metadata to merge.
            force_broadcast: Force broadcasting even for small changes.

        Returns:
            Updated progress object.

        Raises:
            ProgressServiceError: If progress not found.
        """
        async with self._lock:
            update = self._require_active(progress_id)
            previous_percent = update.percent

            if current is not None:
                update.current = current
            if total is not None:
                update.total = total
            if message is not None:
                update.message = message
            if metadata:
                update.metadata.update(metadata)

            update.percent = self._percent_of(update.current, update.total)
            update.status = ProgressStatus.IN_PROGRESS
            update.updated_at = datetime.now(timezone.utc)

            # Throttle broadcasts to meaningful changes.
            should_broadcast = force_broadcast or (
                abs(update.percent - previous_percent)
                >= self._min_broadcast_delta
            )

        if should_broadcast:
            await self._emit_event(self._build_event(update))
        return update

    async def complete_progress(
        self,
        progress_id: str,
        message: str = "Completed successfully",
        metadata: Optional[Dict[str, Any]] = None,
    ) -> ProgressUpdate:
        """Mark a progress operation as completed.

        Sets ``percent`` to 100 and ``current`` to ``total``, then moves the
        operation into history.

        Args:
            progress_id: Progress identifier.
            message: Completion message.
            metadata: Additional metadata.

        Returns:
            Completed progress object.

        Raises:
            ProgressServiceError: If progress not found.
        """
        update = await self._retire(
            progress_id, ProgressStatus.COMPLETED, message, metadata
        )
        logger.info(
            "Progress completed",
            progress_id=progress_id,
            type=update.type.value,
        )
        await self._emit_event(self._build_event(update))
        return update

    async def fail_progress(
        self,
        progress_id: str,
        error_message: str,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> ProgressUpdate:
        """Mark a progress operation as failed and move it into history.

        Args:
            progress_id: Progress identifier.
            error_message: Error description (stored as the message).
            metadata: Additional error metadata.

        Returns:
            Failed progress object.

        Raises:
            ProgressServiceError: If progress not found.
        """
        update = await self._retire(
            progress_id, ProgressStatus.FAILED, error_message, metadata
        )
        logger.error(
            "Progress failed",
            progress_id=progress_id,
            type=update.type.value,
            error=error_message,
        )
        await self._emit_event(self._build_event(update))
        return update

    async def cancel_progress(
        self,
        progress_id: str,
        message: str = "Cancelled by user",
    ) -> ProgressUpdate:
        """Cancel a progress operation and move it into history.

        Args:
            progress_id: Progress identifier.
            message: Cancellation message.

        Returns:
            Cancelled progress object.

        Raises:
            ProgressServiceError: If progress not found.
        """
        update = await self._retire(
            progress_id, ProgressStatus.CANCELLED, message
        )
        logger.info(
            "Progress cancelled",
            progress_id=progress_id,
            type=update.type.value,
        )
        await self._emit_event(self._build_event(update))
        return update

    # ------------------------------------------------------------------
    # Queries
    # ------------------------------------------------------------------

    async def get_progress(self, progress_id: str) -> Optional[ProgressUpdate]:
        """Get current progress state.

        Active operations take precedence over history entries with the
        same id.

        Args:
            progress_id: Progress identifier.

        Returns:
            Progress update object, or None if not found.
        """
        async with self._lock:
            update = self._active_progress.get(progress_id)
            if update is None:
                update = self._history.get(progress_id)
            return update

    async def get_all_active_progress(
        self, progress_type: Optional[ProgressType] = None
    ) -> Dict[str, ProgressUpdate]:
        """Get all active progress operations.

        Args:
            progress_type: Optional filter by progress type.

        Returns:
            A new dictionary of progress_id -> ProgressUpdate. The values are
            the live objects, not copies.
        """
        async with self._lock:
            if progress_type is None:
                return dict(self._active_progress)
            return {
                pid: update
                for pid, update in self._active_progress.items()
                if update.type == progress_type
            }

    async def clear_history(self) -> None:
        """Clear progress history."""
        async with self._lock:
            self._history.clear()
            logger.info("Progress history cleared")
|
|
|
|
|
|
# Lazily created module-level ProgressService shared by all callers.
_progress_service: Optional[ProgressService] = None


def get_progress_service() -> ProgressService:
    """Return the shared ProgressService, creating it on first call.

    Returns:
        The module-level ProgressService instance.
    """
    global _progress_service
    if _progress_service is not None:
        return _progress_service
    _progress_service = ProgressService()
    return _progress_service
|