439 lines
15 KiB
Python
439 lines
15 KiB
Python
"""
|
|
Unit tests for the progress callback system.

Tests the callback interfaces, context classes, and callback manager
functionality.
|
|
"""
|
|
|
|
import unittest
|
|
|
|
from src.core.interfaces.callbacks import (
|
|
CallbackManager,
|
|
CompletionCallback,
|
|
CompletionContext,
|
|
ErrorCallback,
|
|
ErrorContext,
|
|
OperationType,
|
|
ProgressCallback,
|
|
ProgressContext,
|
|
ProgressPhase,
|
|
)
|
|
|
|
|
|
class TestProgressContext(unittest.TestCase):
    """Tests for the ``ProgressContext`` dataclass.

    Covers attribute assignment, ``to_dict`` serialization (enum members are
    emitted as their string values), and the default for ``metadata``.
    """

    def test_progress_context_creation(self):
        """Every constructor argument is stored on the matching attribute."""
        context = ProgressContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test-123",
            phase=ProgressPhase.IN_PROGRESS,
            current=50,
            total=100,
            percentage=50.0,
            message="Downloading...",
            details="Episode 5",
            key="attack-on-titan",
            folder="Attack on Titan (2013)",
            metadata={"series": "Test"},
        )

        self.assertEqual(context.operation_type, OperationType.DOWNLOAD)
        self.assertEqual(context.operation_id, "test-123")
        self.assertEqual(context.phase, ProgressPhase.IN_PROGRESS)
        self.assertEqual(context.current, 50)
        self.assertEqual(context.total, 100)
        self.assertEqual(context.percentage, 50.0)
        self.assertEqual(context.message, "Downloading...")
        self.assertEqual(context.details, "Episode 5")
        self.assertEqual(context.key, "attack-on-titan")
        self.assertEqual(context.folder, "Attack on Titan (2013)")
        self.assertEqual(context.metadata, {"series": "Test"})

    def test_progress_context_to_dict(self):
        """``to_dict`` emits enum values as strings and unset optionals."""
        context = ProgressContext(
            operation_type=OperationType.SCAN,
            operation_id="scan-456",
            phase=ProgressPhase.COMPLETED,
            current=100,
            total=100,
            percentage=100.0,
            message="Scan complete",
        )

        result = context.to_dict()

        self.assertEqual(result["operation_type"], "scan")
        self.assertEqual(result["operation_id"], "scan-456")
        self.assertEqual(result["phase"], "completed")
        self.assertEqual(result["current"], 100)
        self.assertEqual(result["total"], 100)
        self.assertEqual(result["percentage"], 100.0)
        self.assertEqual(result["message"], "Scan complete")
        # Optional fields that were not supplied serialize as None / {}.
        self.assertIsNone(result["details"])
        self.assertIsNone(result["key"])
        self.assertIsNone(result["folder"])
        self.assertEqual(result["metadata"], {})

    def test_progress_context_default_metadata(self):
        """``metadata`` defaults to an empty dict that is not shared."""
        context = ProgressContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test",
            phase=ProgressPhase.STARTING,
            current=0,
            total=100,
            percentage=0.0,
            message="Starting",
        )

        self.assertIsNotNone(context.metadata)
        self.assertEqual(context.metadata, {})

        # A single shared default dict would leak metadata between contexts,
        # so a second instance must receive its own, distinct dict.
        other = ProgressContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test",
            phase=ProgressPhase.STARTING,
            current=0,
            total=100,
            percentage=0.0,
            message="Starting",
        )
        self.assertIsNot(context.metadata, other.metadata)
|
|
|
|
|
|
class TestErrorContext(unittest.TestCase):
    """Tests for the ``ErrorContext`` dataclass."""

    def test_error_context_creation(self):
        """Each constructor argument can be read back from the context."""
        original_exc = ValueError("Test error")
        ctx = ErrorContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test-789",
            error=original_exc,
            message="Download failed",
            recoverable=True,
            retry_count=2,
            key="jujutsu-kaisen",
            folder="Jujutsu Kaisen",
            metadata={"attempt": 3},
        )

        expected_attrs = {
            "operation_type": OperationType.DOWNLOAD,
            "operation_id": "test-789",
            "error": original_exc,
            "message": "Download failed",
            "retry_count": 2,
            "key": "jujutsu-kaisen",
            "folder": "Jujutsu Kaisen",
            "metadata": {"attempt": 3},
        }
        for attr_name, want in expected_attrs.items():
            self.assertEqual(getattr(ctx, attr_name), want)
        # ``recoverable`` is checked for truthiness, not equality with True.
        self.assertTrue(ctx.recoverable)

    def test_error_context_to_dict(self):
        """``to_dict`` splits the exception into its type name and message."""
        ctx = ErrorContext(
            operation_type=OperationType.SCAN,
            operation_id="scan-error",
            error=RuntimeError("Network error"),
            message="Scan error occurred",
            recoverable=False,
        )

        payload = ctx.to_dict()

        for field_name, want in (
            ("operation_type", "scan"),
            ("operation_id", "scan-error"),
            ("error_type", "RuntimeError"),
            ("error_message", "Network error"),
            ("message", "Scan error occurred"),
            ("retry_count", 0),
            ("metadata", {}),
        ):
            self.assertEqual(payload[field_name], want)
        self.assertFalse(payload["recoverable"])
        self.assertIsNone(payload["key"])
        self.assertIsNone(payload["folder"])
|
|
|
|
|
|
class TestCompletionContext(unittest.TestCase):
    """Tests for the ``CompletionContext`` dataclass."""

    def test_completion_context_creation(self):
        """Constructor arguments round-trip through the attributes."""
        ctx = CompletionContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="download-complete",
            success=True,
            message="Download completed successfully",
            result_data={"file": "episode.mp4"},
            statistics={"size": 1024, "time": 60},
            key="bleach",
            folder="Bleach (2004)",
            metadata={"quality": "HD"},
        )

        self.assertTrue(ctx.success)
        self.assertEqual(
            (ctx.operation_type, ctx.operation_id, ctx.message),
            (
                OperationType.DOWNLOAD,
                "download-complete",
                "Download completed successfully",
            ),
        )
        self.assertEqual(
            (ctx.result_data, ctx.statistics, ctx.metadata),
            (
                {"file": "episode.mp4"},
                {"size": 1024, "time": 60},
                {"quality": "HD"},
            ),
        )
        self.assertEqual((ctx.key, ctx.folder), ("bleach", "Bleach (2004)"))

    def test_completion_context_to_dict(self):
        """A failed completion serializes with empty defaults for extras."""
        ctx = CompletionContext(
            operation_type=OperationType.SCAN,
            operation_id="scan-complete",
            success=False,
            message="Scan failed",
        )

        as_dict = ctx.to_dict()

        self.assertFalse(as_dict["success"])
        self.assertEqual(
            (as_dict["operation_type"], as_dict["operation_id"],
             as_dict["message"]),
            ("scan", "scan-complete", "Scan failed"),
        )
        self.assertEqual(as_dict["statistics"], {})
        self.assertEqual(as_dict["metadata"], {})
        self.assertIsNone(as_dict["key"])
        self.assertIsNone(as_dict["folder"])
|
|
|
|
|
|
class MockProgressCallback(ProgressCallback):
    """Recording ``ProgressCallback`` stub for the manager tests.

    Each context passed to ``on_progress`` is appended to ``self.calls`` in
    arrival order, so tests can inspect exactly what was delivered.
    """

    def __init__(self):
        # Received ProgressContext objects, oldest first.
        self.calls: list = []

    def on_progress(self, context: ProgressContext) -> None:
        """Record *context* instead of acting on it."""
        self.calls.append(context)
|
|
|
|
|
|
class MockErrorCallback(ErrorCallback):
    """Recording ``ErrorCallback`` stub for the manager tests.

    Each context passed to ``on_error`` is appended to ``self.calls`` in
    arrival order.
    """

    def __init__(self):
        # Received ErrorContext objects, oldest first.
        self.calls: list = []

    def on_error(self, context: ErrorContext) -> None:
        """Record *context* instead of acting on it."""
        self.calls.append(context)
|
|
|
|
|
|
class MockCompletionCallback(CompletionCallback):
    """Recording ``CompletionCallback`` stub for the manager tests.

    Each context passed to ``on_completion`` is appended to ``self.calls``
    in arrival order.
    """

    def __init__(self):
        # Received CompletionContext objects, oldest first.
        self.calls: list = []

    def on_completion(self, context: CompletionContext) -> None:
        """Record *context* instead of acting on it."""
        self.calls.append(context)
|
|
|
|
|
|
class TestCallbackManager(unittest.TestCase):
    """Tests for ``CallbackManager`` registration and notification.

    Registration tests inspect the manager's private callback lists
    (``_progress_callbacks``, ``_error_callbacks``, ``_completion_callbacks``);
    notification tests use the recording mocks defined in this module.
    """

    def setUp(self):
        """Give each test a fresh, empty manager."""
        self.manager = CallbackManager()

    @staticmethod
    def _progress_context(**overrides):
        """Build a DOWNLOAD / IN_PROGRESS context at 50 of 100.

        Keyword arguments override individual ``ProgressContext`` fields.
        """
        fields = {
            "operation_type": OperationType.DOWNLOAD,
            "operation_id": "test",
            "phase": ProgressPhase.IN_PROGRESS,
            "current": 50,
            "total": 100,
            "percentage": 50.0,
            "message": "Test",
        }
        fields.update(overrides)
        return ProgressContext(**fields)

    @staticmethod
    def _error_context():
        """Build a DOWNLOAD error context wrapping a fresh ValueError."""
        return ErrorContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test",
            error=ValueError("Test error"),
            message="Error occurred",
        )

    @staticmethod
    def _completion_context():
        """Build a successful SCAN completion context."""
        return CompletionContext(
            operation_type=OperationType.SCAN,
            operation_id="test",
            success=True,
            message="Operation completed",
        )

    def test_register_progress_callback(self):
        """A registered progress callback is tracked by the manager."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)

        self.assertIn(callback, self.manager._progress_callbacks)

    def test_register_duplicate_progress_callback(self):
        """Registering the same callback twice keeps a single entry."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)
        self.manager.register_progress_callback(callback)

        self.assertEqual(self.manager._progress_callbacks.count(callback), 1)

    def test_register_error_callback(self):
        """A registered error callback is tracked by the manager."""
        callback = MockErrorCallback()
        self.manager.register_error_callback(callback)

        self.assertIn(callback, self.manager._error_callbacks)

    def test_register_completion_callback(self):
        """A registered completion callback is tracked by the manager."""
        callback = MockCompletionCallback()
        self.manager.register_completion_callback(callback)

        self.assertIn(callback, self.manager._completion_callbacks)

    def test_unregister_progress_callback(self):
        """An unregistered progress callback is removed and not notified."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)
        self.manager.unregister_progress_callback(callback)

        self.assertNotIn(callback, self.manager._progress_callbacks)
        # Removal must be observable: later notifications skip the callback.
        self.manager.notify_progress(self._progress_context())
        self.assertEqual(callback.calls, [])

    def test_unregister_error_callback(self):
        """An unregistered error callback is removed and not notified."""
        callback = MockErrorCallback()
        self.manager.register_error_callback(callback)
        self.manager.unregister_error_callback(callback)

        self.assertNotIn(callback, self.manager._error_callbacks)
        self.manager.notify_error(self._error_context())
        self.assertEqual(callback.calls, [])

    def test_unregister_completion_callback(self):
        """An unregistered completion callback is removed and not notified."""
        callback = MockCompletionCallback()
        self.manager.register_completion_callback(callback)
        self.manager.unregister_completion_callback(callback)

        self.assertNotIn(callback, self.manager._completion_callbacks)
        self.manager.notify_completion(self._completion_context())
        self.assertEqual(callback.calls, [])

    def test_notify_progress(self):
        """Every registered progress callback receives the context once."""
        callback1 = MockProgressCallback()
        callback2 = MockProgressCallback()
        self.manager.register_progress_callback(callback1)
        self.manager.register_progress_callback(callback2)

        context = self._progress_context(message="Test progress")
        self.manager.notify_progress(context)

        self.assertEqual(len(callback1.calls), 1)
        self.assertEqual(len(callback2.calls), 1)
        self.assertEqual(callback1.calls[0], context)
        self.assertEqual(callback2.calls[0], context)

    def test_notify_error(self):
        """A registered error callback receives the context once."""
        callback = MockErrorCallback()
        self.manager.register_error_callback(callback)

        context = self._error_context()
        self.manager.notify_error(context)

        self.assertEqual(len(callback.calls), 1)
        self.assertEqual(callback.calls[0], context)

    def test_notify_completion(self):
        """A registered completion callback receives the context once."""
        callback = MockCompletionCallback()
        self.manager.register_completion_callback(callback)

        context = self._completion_context()
        self.manager.notify_completion(context)

        self.assertEqual(len(callback.calls), 1)
        self.assertEqual(callback.calls[0], context)

    def test_callback_exception_handling(self):
        """A raising callback does not stop delivery to the others."""

        class FailingCallback(ProgressCallback):
            def on_progress(self, context: ProgressContext) -> None:
                raise RuntimeError("Callback failed")

        failing_callback = FailingCallback()
        working_callback = MockProgressCallback()

        # The failing callback is registered first so it runs before the
        # working one; the manager must not propagate its exception.
        self.manager.register_progress_callback(failing_callback)
        self.manager.register_progress_callback(working_callback)

        context = self._progress_context()

        # Should not raise.
        self.manager.notify_progress(context)

        self.assertEqual(len(working_callback.calls), 1)
        self.assertEqual(working_callback.calls[0], context)

    def test_clear_all_callbacks(self):
        """``clear_all_callbacks`` empties all three callback lists."""
        self.manager.register_progress_callback(MockProgressCallback())
        self.manager.register_error_callback(MockErrorCallback())
        self.manager.register_completion_callback(MockCompletionCallback())

        self.manager.clear_all_callbacks()

        self.assertEqual(len(self.manager._progress_callbacks), 0)
        self.assertEqual(len(self.manager._error_callbacks), 0)
        self.assertEqual(len(self.manager._completion_callbacks), 0)

    def test_multiple_notifications(self):
        """Successive notifications are all delivered, in send order."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)

        sent = []
        for i in range(5):
            context = self._progress_context(
                current=i * 20,
                percentage=i * 20.0,
                message=f"Progress {i}",
            )
            sent.append(context)
            self.manager.notify_progress(context)

        self.assertEqual(len(callback.calls), 5)
        # Counting alone would miss dropped, duplicated, or reordered events.
        self.assertEqual(callback.calls, sent)
|
|
|
|
|
|
class TestOperationType(unittest.TestCase):
    """Tests for the string values of the ``OperationType`` enum."""

    def test_operation_types(self):
        """Each operation type compares equal to its string value."""
        expected_values = (
            (OperationType.SCAN, "scan"),
            (OperationType.DOWNLOAD, "download"),
            (OperationType.SEARCH, "search"),
            (OperationType.INITIALIZATION, "initialization"),
        )
        for member, value in expected_values:
            self.assertEqual(member, value)
|
|
|
|
|
|
class TestProgressPhase(unittest.TestCase):
    """Tests for the string values of the ``ProgressPhase`` enum."""

    def test_progress_phases(self):
        """Each lifecycle phase compares equal to its string value."""
        phase_by_value = {
            "starting": ProgressPhase.STARTING,
            "in_progress": ProgressPhase.IN_PROGRESS,
            "completing": ProgressPhase.COMPLETING,
            "completed": ProgressPhase.COMPLETED,
            "failed": ProgressPhase.FAILED,
            "cancelled": ProgressPhase.CANCELLED,
        }
        for value, phase in phase_by_value.items():
            self.assertEqual(phase, value)
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()
|