"""
Unit tests for the progress callback system.

Tests the callback interfaces, context classes, and callback manager
functionality.
"""

import unittest

from src.core.interfaces.callbacks import (
    CallbackManager,
    CompletionCallback,
    CompletionContext,
    ErrorCallback,
    ErrorContext,
    OperationType,
    ProgressCallback,
    ProgressContext,
    ProgressPhase,
)


class TestProgressContext(unittest.TestCase):
    """Test ProgressContext dataclass."""

    def test_progress_context_creation(self):
        """Test creating a progress context."""
        context = ProgressContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test-123",
            phase=ProgressPhase.IN_PROGRESS,
            current=50,
            total=100,
            percentage=50.0,
            message="Downloading...",
            details="Episode 5",
            metadata={"series": "Test"},
        )

        self.assertEqual(context.operation_type, OperationType.DOWNLOAD)
        self.assertEqual(context.operation_id, "test-123")
        self.assertEqual(context.phase, ProgressPhase.IN_PROGRESS)
        self.assertEqual(context.current, 50)
        self.assertEqual(context.total, 100)
        self.assertEqual(context.percentage, 50.0)
        self.assertEqual(context.message, "Downloading...")
        self.assertEqual(context.details, "Episode 5")
        self.assertEqual(context.metadata, {"series": "Test"})

    def test_progress_context_to_dict(self):
        """Test converting progress context to dictionary."""
        context = ProgressContext(
            operation_type=OperationType.SCAN,
            operation_id="scan-456",
            phase=ProgressPhase.COMPLETED,
            current=100,
            total=100,
            percentage=100.0,
            message="Scan complete",
        )

        result = context.to_dict()

        # Enum members are expected to be serialized as their string values.
        self.assertEqual(result["operation_type"], "scan")
        self.assertEqual(result["operation_id"], "scan-456")
        self.assertEqual(result["phase"], "completed")
        self.assertEqual(result["current"], 100)
        self.assertEqual(result["total"], 100)
        self.assertEqual(result["percentage"], 100.0)
        self.assertEqual(result["message"], "Scan complete")
        # Optional fields that were not supplied.
        self.assertIsNone(result["details"])
        self.assertEqual(result["metadata"], {})

    def test_progress_context_default_metadata(self):
        """Test that metadata defaults to empty dict."""
        context = ProgressContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test",
            phase=ProgressPhase.STARTING,
            current=0,
            total=100,
            percentage=0.0,
            message="Starting",
        )

        self.assertIsNotNone(context.metadata)
        self.assertEqual(context.metadata, {})


class TestErrorContext(unittest.TestCase):
    """Test ErrorContext dataclass."""

    def test_error_context_creation(self):
        """Test creating an error context."""
        error = ValueError("Test error")
        context = ErrorContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test-789",
            error=error,
            message="Download failed",
            recoverable=True,
            retry_count=2,
            metadata={"attempt": 3},
        )

        self.assertEqual(context.operation_type, OperationType.DOWNLOAD)
        self.assertEqual(context.operation_id, "test-789")
        # The context must hold the very exception object it was given.
        self.assertIs(context.error, error)
        self.assertEqual(context.message, "Download failed")
        self.assertTrue(context.recoverable)
        self.assertEqual(context.retry_count, 2)
        self.assertEqual(context.metadata, {"attempt": 3})

    def test_error_context_to_dict(self):
        """Test converting error context to dictionary."""
        error = RuntimeError("Network error")
        context = ErrorContext(
            operation_type=OperationType.SCAN,
            operation_id="scan-error",
            error=error,
            message="Scan error occurred",
            recoverable=False,
        )

        result = context.to_dict()

        self.assertEqual(result["operation_type"], "scan")
        self.assertEqual(result["operation_id"], "scan-error")
        # The exception is flattened into its class name and message text.
        self.assertEqual(result["error_type"], "RuntimeError")
        self.assertEqual(result["error_message"], "Network error")
        self.assertEqual(result["message"], "Scan error occurred")
        self.assertFalse(result["recoverable"])
        # retry_count and metadata were not supplied; check their defaults.
        self.assertEqual(result["retry_count"], 0)
        self.assertEqual(result["metadata"], {})


class TestCompletionContext(unittest.TestCase):
    """Test CompletionContext dataclass."""

    def test_completion_context_creation(self):
        """Test creating a completion context."""
        context = CompletionContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="download-complete",
            success=True,
            message="Download completed successfully",
            result_data={"file": "episode.mp4"},
            statistics={"size": 1024, "time": 60},
            metadata={"quality": "HD"},
        )

        self.assertEqual(context.operation_type, OperationType.DOWNLOAD)
        self.assertEqual(context.operation_id, "download-complete")
        self.assertTrue(context.success)
        self.assertEqual(context.message, "Download completed successfully")
        self.assertEqual(context.result_data, {"file": "episode.mp4"})
        self.assertEqual(context.statistics, {"size": 1024, "time": 60})
        self.assertEqual(context.metadata, {"quality": "HD"})

    def test_completion_context_to_dict(self):
        """Test converting completion context to dictionary."""
        context = CompletionContext(
            operation_type=OperationType.SCAN,
            operation_id="scan-complete",
            success=False,
            message="Scan failed",
        )

        result = context.to_dict()

        self.assertEqual(result["operation_type"], "scan")
        self.assertEqual(result["operation_id"], "scan-complete")
        self.assertFalse(result["success"])
        self.assertEqual(result["message"], "Scan failed")
        # statistics and metadata were not supplied; check their defaults.
        self.assertEqual(result["statistics"], {})
        self.assertEqual(result["metadata"], {})


class MockProgressCallback(ProgressCallback):
    """Mock implementation of ProgressCallback for testing.

    Every context passed to ``on_progress`` is appended to ``calls``.
    """

    def __init__(self):
        self.calls = []

    def on_progress(self, context: ProgressContext) -> None:
        self.calls.append(context)


class MockErrorCallback(ErrorCallback):
    """Mock implementation of ErrorCallback for testing.

    Every context passed to ``on_error`` is appended to ``calls``.
    """

    def __init__(self):
        self.calls = []

    def on_error(self, context: ErrorContext) -> None:
        self.calls.append(context)


class MockCompletionCallback(CompletionCallback):
    """Mock implementation of CompletionCallback for testing.

    Every context passed to ``on_completion`` is appended to ``calls``.
    """

    def __init__(self):
        self.calls = []

    def on_completion(self, context: CompletionContext) -> None:
        self.calls.append(context)


class TestCallbackManager(unittest.TestCase):
    """Test CallbackManager functionality.

    Several assertions are white-box: they inspect the manager's private
    ``_progress_callbacks`` / ``_error_callbacks`` / ``_completion_callbacks``
    collections directly, so they must be updated if those internals change.
    """

    def setUp(self):
        """Set up test fixtures."""
        self.manager = CallbackManager()

    def test_register_progress_callback(self):
        """Test registering a progress callback."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)

        # Callback should be registered
        self.assertIn(callback, self.manager._progress_callbacks)

    def test_register_duplicate_progress_callback(self):
        """Test that duplicate callbacks are not added."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)
        self.manager.register_progress_callback(callback)

        # Should only be registered once
        self.assertEqual(self.manager._progress_callbacks.count(callback), 1)

    def test_register_error_callback(self):
        """Test registering an error callback."""
        callback = MockErrorCallback()
        self.manager.register_error_callback(callback)

        self.assertIn(callback, self.manager._error_callbacks)

    def test_register_completion_callback(self):
        """Test registering a completion callback."""
        callback = MockCompletionCallback()
        self.manager.register_completion_callback(callback)

        self.assertIn(callback, self.manager._completion_callbacks)

    def test_unregister_progress_callback(self):
        """Test unregistering a progress callback."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)
        self.manager.unregister_progress_callback(callback)

        self.assertNotIn(callback, self.manager._progress_callbacks)

    def test_unregister_error_callback(self):
        """Test unregistering an error callback."""
        callback = MockErrorCallback()
        self.manager.register_error_callback(callback)
        self.manager.unregister_error_callback(callback)

        self.assertNotIn(callback, self.manager._error_callbacks)

    def test_unregister_completion_callback(self):
        """Test unregistering a completion callback."""
        callback = MockCompletionCallback()
        self.manager.register_completion_callback(callback)
        self.manager.unregister_completion_callback(callback)

        self.assertNotIn(callback, self.manager._completion_callbacks)

    def test_notify_progress(self):
        """Test notifying progress callbacks."""
        callback1 = MockProgressCallback()
        callback2 = MockProgressCallback()

        self.manager.register_progress_callback(callback1)
        self.manager.register_progress_callback(callback2)

        context = ProgressContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test",
            phase=ProgressPhase.IN_PROGRESS,
            current=50,
            total=100,
            percentage=50.0,
            message="Test progress",
        )

        self.manager.notify_progress(context)

        # Both callbacks should be called
        self.assertEqual(len(callback1.calls), 1)
        self.assertEqual(len(callback2.calls), 1)
        self.assertEqual(callback1.calls[0], context)
        self.assertEqual(callback2.calls[0], context)

    def test_notify_error(self):
        """Test notifying error callbacks."""
        callback = MockErrorCallback()
        self.manager.register_error_callback(callback)

        error = ValueError("Test error")
        context = ErrorContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test",
            error=error,
            message="Error occurred",
        )

        self.manager.notify_error(context)

        self.assertEqual(len(callback.calls), 1)
        self.assertEqual(callback.calls[0], context)

    def test_notify_completion(self):
        """Test notifying completion callbacks."""
        callback = MockCompletionCallback()
        self.manager.register_completion_callback(callback)

        context = CompletionContext(
            operation_type=OperationType.SCAN,
            operation_id="test",
            success=True,
            message="Operation completed",
        )

        self.manager.notify_completion(context)

        self.assertEqual(len(callback.calls), 1)
        self.assertEqual(callback.calls[0], context)

    def test_callback_exception_handling(self):
        """Test that exceptions in callbacks don't break notification."""

        # Create a callback that raises an exception
        class FailingCallback(ProgressCallback):
            def on_progress(self, context: ProgressContext) -> None:
                raise RuntimeError("Callback failed")

        failing_callback = FailingCallback()
        working_callback = MockProgressCallback()

        # The failing callback is registered first so that the working one
        # is only reached if the manager keeps going after the exception.
        self.manager.register_progress_callback(failing_callback)
        self.manager.register_progress_callback(working_callback)

        context = ProgressContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test",
            phase=ProgressPhase.IN_PROGRESS,
            current=50,
            total=100,
            percentage=50.0,
            message="Test",
        )

        # Should not raise exception
        self.manager.notify_progress(context)

        # Working callback should still be called, with the same context
        self.assertEqual(len(working_callback.calls), 1)
        self.assertEqual(working_callback.calls[0], context)

    def test_clear_all_callbacks(self):
        """Test clearing all callbacks."""
        self.manager.register_progress_callback(MockProgressCallback())
        self.manager.register_error_callback(MockErrorCallback())
        self.manager.register_completion_callback(MockCompletionCallback())

        self.manager.clear_all_callbacks()

        self.assertEqual(len(self.manager._progress_callbacks), 0)
        self.assertEqual(len(self.manager._error_callbacks), 0)
        self.assertEqual(len(self.manager._completion_callbacks), 0)

    def test_multiple_notifications(self):
        """Test multiple progress notifications."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)

        for i in range(5):
            context = ProgressContext(
                operation_type=OperationType.DOWNLOAD,
                operation_id="test",
                phase=ProgressPhase.IN_PROGRESS,
                current=i * 20,
                total=100,
                percentage=i * 20.0,
                message=f"Progress {i}",
            )
            self.manager.notify_progress(context)

        self.assertEqual(len(callback.calls), 5)
        # A count alone would accept duplicated or reordered deliveries;
        # also verify each notification arrived once, in send order.
        self.assertEqual(
            [received.current for received in callback.calls],
            [0, 20, 40, 60, 80],
        )
        self.assertEqual(
            [received.message for received in callback.calls],
            [f"Progress {i}" for i in range(5)],
        )


class TestOperationType(unittest.TestCase):
    """Test OperationType enum."""

    def test_operation_types(self):
        """Test all operation types are defined."""
        expected = (
            (OperationType.SCAN, "scan"),
            (OperationType.DOWNLOAD, "download"),
            (OperationType.SEARCH, "search"),
            (OperationType.INITIALIZATION, "initialization"),
        )
        # subTest reports every mismatching member instead of stopping at
        # the first failure.
        for member, value in expected:
            with self.subTest(member=member):
                self.assertEqual(member, value)


class TestProgressPhase(unittest.TestCase):
    """Test ProgressPhase enum."""

    def test_progress_phases(self):
        """Test all progress phases are defined."""
        expected = (
            (ProgressPhase.STARTING, "starting"),
            (ProgressPhase.IN_PROGRESS, "in_progress"),
            (ProgressPhase.COMPLETING, "completing"),
            (ProgressPhase.COMPLETED, "completed"),
            (ProgressPhase.FAILED, "failed"),
            (ProgressPhase.CANCELLED, "cancelled"),
        )
        # subTest reports every mismatching member instead of stopping at
        # the first failure.
        for member, value in expected:
            with self.subTest(member=member):
                self.assertEqual(member, value)


if __name__ == "__main__":
    unittest.main()