fix progress events
This commit is contained in:
@@ -11,7 +11,7 @@ import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -85,6 +85,30 @@ class ProgressUpdate:
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProgressEvent:
|
||||
"""Represents a progress event for subscribers.
|
||||
|
||||
Attributes:
|
||||
event_type: Type of event (e.g., 'download_progress')
|
||||
progress_id: Unique identifier for the progress operation
|
||||
progress: The progress update data
|
||||
room: WebSocket room to broadcast to (default: 'progress')
|
||||
"""
|
||||
|
||||
event_type: str
|
||||
progress_id: str
|
||||
progress: ProgressUpdate
|
||||
room: str = "progress"
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert event to dictionary for broadcasting."""
|
||||
return {
|
||||
"type": self.event_type,
|
||||
"data": self.progress.to_dict(),
|
||||
}
|
||||
|
||||
|
||||
class ProgressServiceError(Exception):
|
||||
"""Service-level exception for progress operations."""
|
||||
|
||||
@@ -109,44 +133,82 @@ class ProgressService:
|
||||
self._history: Dict[str, ProgressUpdate] = {}
|
||||
self._max_history_size = 50
|
||||
|
||||
# WebSocket broadcast callback
|
||||
self._broadcast_callback: Optional[Callable] = None
|
||||
# Event subscribers: event_name -> list of handlers
|
||||
self._event_handlers: Dict[
|
||||
str, List[Callable[[ProgressEvent], None]]
|
||||
] = {}
|
||||
|
||||
# Lock for thread-safe operations
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
logger.info("ProgressService initialized")
|
||||
|
||||
def set_broadcast_callback(self, callback: Callable) -> None:
|
||||
"""Set callback for broadcasting progress updates via WebSocket.
|
||||
def subscribe(
|
||||
self, event_name: str, handler: Callable[[ProgressEvent], None]
|
||||
) -> None:
|
||||
"""Subscribe to progress events.
|
||||
|
||||
Args:
|
||||
callback: Async function to call for broadcasting updates
|
||||
event_name: Name of event to subscribe to
|
||||
(e.g., 'progress_updated')
|
||||
handler: Async function to call when event occurs
|
||||
"""
|
||||
self._broadcast_callback = callback
|
||||
logger.debug("Progress broadcast callback registered")
|
||||
if event_name not in self._event_handlers:
|
||||
self._event_handlers[event_name] = []
|
||||
|
||||
async def _broadcast(self, update: ProgressUpdate, room: str) -> None:
|
||||
"""Broadcast progress update to WebSocket clients.
|
||||
self._event_handlers[event_name].append(handler)
|
||||
logger.debug("Event handler subscribed", event=event_name)
|
||||
|
||||
def unsubscribe(
|
||||
self, event_name: str, handler: Callable[[ProgressEvent], None]
|
||||
) -> None:
|
||||
"""Unsubscribe from progress events.
|
||||
|
||||
Args:
|
||||
update: Progress update to broadcast
|
||||
room: WebSocket room to broadcast to
|
||||
event_name: Name of event to unsubscribe from
|
||||
handler: Handler function to remove
|
||||
"""
|
||||
if self._broadcast_callback:
|
||||
if event_name in self._event_handlers:
|
||||
try:
|
||||
await self._broadcast_callback(
|
||||
message_type=f"{update.type.value}_progress",
|
||||
data=update.to_dict(),
|
||||
room=room,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Failed to broadcast progress update",
|
||||
error=str(e),
|
||||
progress_id=update.id,
|
||||
self._event_handlers[event_name].remove(handler)
|
||||
logger.debug("Event handler unsubscribed", event=event_name)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
"Handler not found for unsubscribe", event=event_name
|
||||
)
|
||||
|
||||
async def _emit_event(self, event: ProgressEvent) -> None:
|
||||
"""Emit event to all subscribers.
|
||||
|
||||
Args:
|
||||
event: Progress event to emit
|
||||
|
||||
Note:
|
||||
Errors in individual handlers are logged but do not
|
||||
prevent other handlers from executing.
|
||||
"""
|
||||
event_name = "progress_updated"
|
||||
|
||||
if event_name in self._event_handlers:
|
||||
handlers = self._event_handlers[event_name]
|
||||
if handlers:
|
||||
# Execute all handlers, capturing exceptions
|
||||
tasks = [handler(event) for handler in handlers]
|
||||
# Ignore type error - tasks will be coroutines at runtime
|
||||
results = await asyncio.gather(
|
||||
*tasks, return_exceptions=True
|
||||
) # type: ignore[arg-type]
|
||||
|
||||
# Log any exceptions that occurred
|
||||
for idx, result in enumerate(results):
|
||||
if isinstance(result, Exception):
|
||||
logger.error(
|
||||
"Event handler raised exception",
|
||||
event=event_name,
|
||||
error=str(result),
|
||||
handler_index=idx,
|
||||
)
|
||||
|
||||
async def start_progress(
|
||||
self,
|
||||
progress_id: str,
|
||||
@@ -197,9 +259,15 @@ class ProgressService:
|
||||
title=title,
|
||||
)
|
||||
|
||||
# Broadcast to appropriate room
|
||||
# Emit event to subscribers
|
||||
room = f"{progress_type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
event = ProgressEvent(
|
||||
event_type=f"{progress_type.value}_progress",
|
||||
progress_id=progress_id,
|
||||
progress=update,
|
||||
room=room,
|
||||
)
|
||||
await self._emit_event(event)
|
||||
|
||||
return update
|
||||
|
||||
@@ -262,7 +330,13 @@ class ProgressService:
|
||||
|
||||
if should_broadcast:
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
event = ProgressEvent(
|
||||
event_type=f"{update.type.value}_progress",
|
||||
progress_id=progress_id,
|
||||
progress=update,
|
||||
room=room,
|
||||
)
|
||||
await self._emit_event(event)
|
||||
|
||||
return update
|
||||
|
||||
@@ -311,9 +385,15 @@ class ProgressService:
|
||||
type=update.type.value,
|
||||
)
|
||||
|
||||
# Broadcast completion
|
||||
# Emit completion event
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
event = ProgressEvent(
|
||||
event_type=f"{update.type.value}_progress",
|
||||
progress_id=progress_id,
|
||||
progress=update,
|
||||
room=room,
|
||||
)
|
||||
await self._emit_event(event)
|
||||
|
||||
return update
|
||||
|
||||
@@ -361,9 +441,15 @@ class ProgressService:
|
||||
error=error_message,
|
||||
)
|
||||
|
||||
# Broadcast failure
|
||||
# Emit failure event
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
event = ProgressEvent(
|
||||
event_type=f"{update.type.value}_progress",
|
||||
progress_id=progress_id,
|
||||
progress=update,
|
||||
room=room,
|
||||
)
|
||||
await self._emit_event(event)
|
||||
|
||||
return update
|
||||
|
||||
@@ -405,9 +491,15 @@ class ProgressService:
|
||||
type=update.type.value,
|
||||
)
|
||||
|
||||
# Broadcast cancellation
|
||||
# Emit cancellation event
|
||||
room = f"{update.type.value}_progress"
|
||||
await self._broadcast(update, room)
|
||||
event = ProgressEvent(
|
||||
event_type=f"{update.type.value}_progress",
|
||||
progress_id=progress_id,
|
||||
progress=update,
|
||||
room=room,
|
||||
)
|
||||
await self._emit_event(event)
|
||||
|
||||
return update
|
||||
|
||||
|
||||
Reference in New Issue
Block a user