fix progress events

This commit is contained in:
2025-11-07 18:40:36 +01:00
parent 5c4bd3d7e8
commit 2441730862
5 changed files with 673 additions and 249 deletions

View File

@@ -11,7 +11,7 @@ import asyncio
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any, Callable, Dict, Optional
from typing import Any, Callable, Dict, List, Optional
import structlog
@@ -85,6 +85,30 @@ class ProgressUpdate:
}
@dataclass
class ProgressEvent:
"""Represents a progress event for subscribers.
Attributes:
event_type: Type of event (e.g., 'download_progress')
progress_id: Unique identifier for the progress operation
progress: The progress update data
room: WebSocket room to broadcast to (default: 'progress')
"""
event_type: str
progress_id: str
progress: ProgressUpdate
room: str = "progress"
def to_dict(self) -> Dict[str, Any]:
"""Convert event to dictionary for broadcasting."""
return {
"type": self.event_type,
"data": self.progress.to_dict(),
}
class ProgressServiceError(Exception):
"""Service-level exception for progress operations."""
@@ -109,44 +133,82 @@ class ProgressService:
self._history: Dict[str, ProgressUpdate] = {}
self._max_history_size = 50
# WebSocket broadcast callback
self._broadcast_callback: Optional[Callable] = None
# Event subscribers: event_name -> list of handlers
self._event_handlers: Dict[
str, List[Callable[[ProgressEvent], None]]
] = {}
# Lock for thread-safe operations
self._lock = asyncio.Lock()
logger.info("ProgressService initialized")
def set_broadcast_callback(self, callback: Callable) -> None:
"""Set callback for broadcasting progress updates via WebSocket.
def subscribe(
self, event_name: str, handler: Callable[[ProgressEvent], None]
) -> None:
"""Subscribe to progress events.
Args:
callback: Async function to call for broadcasting updates
event_name: Name of event to subscribe to
(e.g., 'progress_updated')
handler: Async function to call when event occurs
"""
self._broadcast_callback = callback
logger.debug("Progress broadcast callback registered")
if event_name not in self._event_handlers:
self._event_handlers[event_name] = []
async def _broadcast(self, update: ProgressUpdate, room: str) -> None:
"""Broadcast progress update to WebSocket clients.
self._event_handlers[event_name].append(handler)
logger.debug("Event handler subscribed", event=event_name)
def unsubscribe(
self, event_name: str, handler: Callable[[ProgressEvent], None]
) -> None:
"""Unsubscribe from progress events.
Args:
update: Progress update to broadcast
room: WebSocket room to broadcast to
event_name: Name of event to unsubscribe from
handler: Handler function to remove
"""
if self._broadcast_callback:
if event_name in self._event_handlers:
try:
await self._broadcast_callback(
message_type=f"{update.type.value}_progress",
data=update.to_dict(),
room=room,
)
except Exception as e:
logger.error(
"Failed to broadcast progress update",
error=str(e),
progress_id=update.id,
self._event_handlers[event_name].remove(handler)
logger.debug("Event handler unsubscribed", event=event_name)
except ValueError:
logger.warning(
"Handler not found for unsubscribe", event=event_name
)
async def _emit_event(self, event: ProgressEvent) -> None:
"""Emit event to all subscribers.
Args:
event: Progress event to emit
Note:
Errors in individual handlers are logged but do not
prevent other handlers from executing.
"""
event_name = "progress_updated"
if event_name in self._event_handlers:
handlers = self._event_handlers[event_name]
if handlers:
# Execute all handlers, capturing exceptions
tasks = [handler(event) for handler in handlers]
# Ignore type error - tasks will be coroutines at runtime
results = await asyncio.gather(
*tasks, return_exceptions=True
) # type: ignore[arg-type]
# Log any exceptions that occurred
for idx, result in enumerate(results):
if isinstance(result, Exception):
logger.error(
"Event handler raised exception",
event=event_name,
error=str(result),
handler_index=idx,
)
async def start_progress(
self,
progress_id: str,
@@ -197,9 +259,15 @@ class ProgressService:
title=title,
)
# Broadcast to appropriate room
# Emit event to subscribers
room = f"{progress_type.value}_progress"
await self._broadcast(update, room)
event = ProgressEvent(
event_type=f"{progress_type.value}_progress",
progress_id=progress_id,
progress=update,
room=room,
)
await self._emit_event(event)
return update
@@ -262,7 +330,13 @@ class ProgressService:
if should_broadcast:
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
event = ProgressEvent(
event_type=f"{update.type.value}_progress",
progress_id=progress_id,
progress=update,
room=room,
)
await self._emit_event(event)
return update
@@ -311,9 +385,15 @@ class ProgressService:
type=update.type.value,
)
# Broadcast completion
# Emit completion event
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
event = ProgressEvent(
event_type=f"{update.type.value}_progress",
progress_id=progress_id,
progress=update,
room=room,
)
await self._emit_event(event)
return update
@@ -361,9 +441,15 @@ class ProgressService:
error=error_message,
)
# Broadcast failure
# Emit failure event
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
event = ProgressEvent(
event_type=f"{update.type.value}_progress",
progress_id=progress_id,
progress=update,
room=room,
)
await self._emit_event(event)
return update
@@ -405,9 +491,15 @@ class ProgressService:
type=update.type.value,
)
# Broadcast cancellation
# Emit cancellation event
room = f"{update.type.value}_progress"
await self._broadcast(update, room)
event = ProgressEvent(
event_type=f"{update.type.value}_progress",
progress_id=progress_id,
progress=update,
room=room,
)
await self._emit_event(event)
return update