# Aniworld/tests/unit/test_callbacks.py
# (source-viewer metadata: 439 lines, 15 KiB, Python)

"""
Unit tests for the progress callback system.
Tests the callback interfaces, context classes, and callback manager
functionality.
"""
import unittest
from src.core.interfaces.callbacks import (
CallbackManager,
CompletionCallback,
CompletionContext,
ErrorCallback,
ErrorContext,
OperationType,
ProgressCallback,
ProgressContext,
ProgressPhase,
)
class TestProgressContext(unittest.TestCase):
    """Test ProgressContext dataclass."""

    def test_progress_context_creation(self):
        """Test creating a progress context."""
        context = ProgressContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test-123",
            phase=ProgressPhase.IN_PROGRESS,
            current=50,
            total=100,
            percentage=50.0,
            message="Downloading...",
            details="Episode 5",
            key="attack-on-titan",
            folder="Attack on Titan (2013)",
            metadata={"series": "Test"},
        )
        self.assertEqual(context.operation_type, OperationType.DOWNLOAD)
        self.assertEqual(context.operation_id, "test-123")
        self.assertEqual(context.phase, ProgressPhase.IN_PROGRESS)
        self.assertEqual(context.current, 50)
        self.assertEqual(context.total, 100)
        self.assertEqual(context.percentage, 50.0)
        self.assertEqual(context.message, "Downloading...")
        self.assertEqual(context.details, "Episode 5")
        self.assertEqual(context.key, "attack-on-titan")
        self.assertEqual(context.folder, "Attack on Titan (2013)")
        self.assertEqual(context.metadata, {"series": "Test"})

    def test_progress_context_to_dict(self):
        """Test converting progress context to dictionary."""
        context = ProgressContext(
            operation_type=OperationType.SCAN,
            operation_id="scan-456",
            phase=ProgressPhase.COMPLETED,
            current=100,
            total=100,
            percentage=100.0,
            message="Scan complete",
        )
        result = context.to_dict()
        self.assertEqual(result["operation_type"], "scan")
        self.assertEqual(result["operation_id"], "scan-456")
        self.assertEqual(result["phase"], "completed")
        self.assertEqual(result["current"], 100)
        self.assertEqual(result["total"], 100)
        self.assertEqual(result["percentage"], 100.0)
        self.assertEqual(result["message"], "Scan complete")
        self.assertIsNone(result["details"])
        self.assertIsNone(result["key"])
        self.assertIsNone(result["folder"])
        self.assertEqual(result["metadata"], {})

    def test_progress_context_default_metadata(self):
        """Test that metadata defaults to a fresh, unshared empty dict."""

        def make_context(operation_id):
            # Build a context without passing metadata so the default applies.
            return ProgressContext(
                operation_type=OperationType.DOWNLOAD,
                operation_id=operation_id,
                phase=ProgressPhase.STARTING,
                current=0,
                total=100,
                percentage=0.0,
                message="Starting",
            )

        context = make_context("test")
        self.assertIsNotNone(context.metadata)
        self.assertEqual(context.metadata, {})

        # Guard against the classic mutable-default bug: every instance must
        # get its own dict, so mutating one context's metadata must not leak
        # into another context created with the default.
        other = make_context("test-other")
        context.metadata["series"] = "Test"
        self.assertEqual(other.metadata, {})
class TestErrorContext(unittest.TestCase):
    """Unit tests for the ErrorContext dataclass."""

    def test_error_context_creation(self):
        """Constructor arguments are stored verbatim on the context."""
        cause = ValueError("Test error")

        def expected_fields():
            # Rebuilt on every call so the comparison values are independent
            # of the objects handed to the constructor.
            return {
                "operation_type": OperationType.DOWNLOAD,
                "operation_id": "test-789",
                "error": cause,
                "message": "Download failed",
                "retry_count": 2,
                "key": "jujutsu-kaisen",
                "folder": "Jujutsu Kaisen",
                "metadata": {"attempt": 3},
            }

        ctx = ErrorContext(recoverable=True, **expected_fields())

        self.assertTrue(ctx.recoverable)
        for attr, want in expected_fields().items():
            with self.subTest(attr=attr):
                self.assertEqual(getattr(ctx, attr), want)

    def test_error_context_to_dict(self):
        """to_dict exposes the exception's type name and text plus defaults."""
        ctx = ErrorContext(
            operation_type=OperationType.SCAN,
            operation_id="scan-error",
            error=RuntimeError("Network error"),
            message="Scan error occurred",
            recoverable=False,
        )
        as_dict = ctx.to_dict()

        self.assertFalse(as_dict["recoverable"])

        expected_values = {
            "operation_type": "scan",
            "operation_id": "scan-error",
            "error_type": "RuntimeError",
            "error_message": "Network error",
            "message": "Scan error occurred",
            "retry_count": 0,
            "metadata": {},
        }
        for field_name, want in expected_values.items():
            with self.subTest(field=field_name):
                self.assertEqual(as_dict[field_name], want)

        # Optional identifiers that were not supplied serialize as None.
        for field_name in ("key", "folder"):
            with self.subTest(field=field_name):
                self.assertIsNone(as_dict[field_name])
class TestCompletionContext(unittest.TestCase):
    """Unit tests for the CompletionContext dataclass."""

    def test_completion_context_creation(self):
        """Every constructor keyword is exposed unchanged as an attribute."""

        def expected_fields():
            # Fresh literals on each call, mirroring independent comparisons.
            return {
                "operation_type": OperationType.DOWNLOAD,
                "operation_id": "download-complete",
                "message": "Download completed successfully",
                "result_data": {"file": "episode.mp4"},
                "statistics": {"size": 1024, "time": 60},
                "key": "bleach",
                "folder": "Bleach (2004)",
                "metadata": {"quality": "HD"},
            }

        ctx = CompletionContext(success=True, **expected_fields())

        self.assertTrue(ctx.success)
        for attr, want in expected_fields().items():
            with self.subTest(attr=attr):
                self.assertEqual(getattr(ctx, attr), want)

    def test_completion_context_to_dict(self):
        """to_dict of a minimal failed completion reports empty defaults."""
        ctx = CompletionContext(
            operation_type=OperationType.SCAN,
            operation_id="scan-complete",
            success=False,
            message="Scan failed",
        )
        as_dict = ctx.to_dict()

        self.assertFalse(as_dict["success"])

        expected_values = {
            "operation_type": "scan",
            "operation_id": "scan-complete",
            "message": "Scan failed",
            "statistics": {},
            "metadata": {},
        }
        for field_name, want in expected_values.items():
            with self.subTest(field=field_name):
                self.assertEqual(as_dict[field_name], want)

        # Identifiers that were not supplied serialize as None.
        for field_name in ("key", "folder"):
            with self.subTest(field=field_name):
                self.assertIsNone(as_dict[field_name])
class MockProgressCallback(ProgressCallback):
    """Recording test double: keeps every ProgressContext it is handed."""

    def __init__(self):
        # Received contexts, oldest first.
        self.calls = list()

    def on_progress(self, context: ProgressContext) -> None:
        """Remember *context* so a test can inspect it afterwards."""
        recorded = self.calls
        recorded.append(context)
class MockErrorCallback(ErrorCallback):
    """Recording test double: keeps every ErrorContext it is handed."""

    def __init__(self):
        # Received contexts, oldest first.
        self.calls = list()

    def on_error(self, context: ErrorContext) -> None:
        """Remember *context* so a test can inspect it afterwards."""
        recorded = self.calls
        recorded.append(context)
class MockCompletionCallback(CompletionCallback):
    """Recording test double: keeps every CompletionContext it is handed."""

    def __init__(self):
        # Received contexts, oldest first.
        self.calls = list()

    def on_completion(self, context: CompletionContext) -> None:
        """Remember *context* so a test can inspect it afterwards."""
        recorded = self.calls
        recorded.append(context)
class TestCallbackManager(unittest.TestCase):
    """Test CallbackManager functionality."""

    def setUp(self):
        """Set up test fixtures."""
        self.manager = CallbackManager()

    def _make_progress_context(self, current=50, message="Test"):
        """Build an IN_PROGRESS download context out of 100 at *current*."""
        return ProgressContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test",
            phase=ProgressPhase.IN_PROGRESS,
            current=current,
            total=100,
            percentage=float(current),
            message=message,
        )

    def test_register_progress_callback(self):
        """Test registering a progress callback."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)
        # Callback should be registered
        self.assertIn(callback, self.manager._progress_callbacks)

    def test_register_duplicate_progress_callback(self):
        """Test that duplicate callbacks are not added."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)
        self.manager.register_progress_callback(callback)
        # Should only be registered once
        self.assertEqual(self.manager._progress_callbacks.count(callback), 1)

    def test_register_error_callback(self):
        """Test registering an error callback."""
        callback = MockErrorCallback()
        self.manager.register_error_callback(callback)
        self.assertIn(callback, self.manager._error_callbacks)

    def test_register_completion_callback(self):
        """Test registering a completion callback."""
        callback = MockCompletionCallback()
        self.manager.register_completion_callback(callback)
        self.assertIn(callback, self.manager._completion_callbacks)

    def test_unregister_progress_callback(self):
        """Test unregistering a progress callback."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)
        self.manager.unregister_progress_callback(callback)
        self.assertNotIn(callback, self.manager._progress_callbacks)

    def test_unregister_error_callback(self):
        """Test unregistering an error callback."""
        callback = MockErrorCallback()
        self.manager.register_error_callback(callback)
        self.manager.unregister_error_callback(callback)
        self.assertNotIn(callback, self.manager._error_callbacks)

    def test_unregister_completion_callback(self):
        """Test unregistering a completion callback."""
        callback = MockCompletionCallback()
        self.manager.register_completion_callback(callback)
        self.manager.unregister_completion_callback(callback)
        self.assertNotIn(callback, self.manager._completion_callbacks)

    def test_notify_progress(self):
        """Test notifying progress callbacks."""
        callback1 = MockProgressCallback()
        callback2 = MockProgressCallback()
        self.manager.register_progress_callback(callback1)
        self.manager.register_progress_callback(callback2)
        context = self._make_progress_context(message="Test progress")
        self.manager.notify_progress(context)
        # Both callbacks should be called
        self.assertEqual(len(callback1.calls), 1)
        self.assertEqual(len(callback2.calls), 1)
        self.assertEqual(callback1.calls[0], context)
        self.assertEqual(callback2.calls[0], context)

    def test_notify_error(self):
        """Test notifying error callbacks."""
        callback = MockErrorCallback()
        self.manager.register_error_callback(callback)
        context = ErrorContext(
            operation_type=OperationType.DOWNLOAD,
            operation_id="test",
            error=ValueError("Test error"),
            message="Error occurred",
        )
        self.manager.notify_error(context)
        self.assertEqual(len(callback.calls), 1)
        self.assertEqual(callback.calls[0], context)

    def test_notify_completion(self):
        """Test notifying completion callbacks."""
        callback = MockCompletionCallback()
        self.manager.register_completion_callback(callback)
        context = CompletionContext(
            operation_type=OperationType.SCAN,
            operation_id="test",
            success=True,
            message="Operation completed",
        )
        self.manager.notify_completion(context)
        self.assertEqual(len(callback.calls), 1)
        self.assertEqual(callback.calls[0], context)

    def test_callback_exception_handling(self):
        """Test that exceptions in callbacks don't break notification."""

        # A callback that always raises, registered ahead of a working one.
        class FailingCallback(ProgressCallback):
            def on_progress(self, context: ProgressContext) -> None:
                raise RuntimeError("Callback failed")

        failing_callback = FailingCallback()
        working_callback = MockProgressCallback()
        self.manager.register_progress_callback(failing_callback)
        self.manager.register_progress_callback(working_callback)
        # Should not raise exception
        self.manager.notify_progress(self._make_progress_context())
        # Working callback should still be called
        self.assertEqual(len(working_callback.calls), 1)

    def test_clear_all_callbacks(self):
        """Test clearing all callbacks."""
        progress_callback = MockProgressCallback()
        error_callback = MockErrorCallback()
        completion_callback = MockCompletionCallback()
        self.manager.register_progress_callback(progress_callback)
        self.manager.register_error_callback(error_callback)
        self.manager.register_completion_callback(completion_callback)

        self.manager.clear_all_callbacks()

        self.assertEqual(len(self.manager._progress_callbacks), 0)
        self.assertEqual(len(self.manager._error_callbacks), 0)
        self.assertEqual(len(self.manager._completion_callbacks), 0)

        # Verify through the public API as well: cleared callbacks must not
        # receive any further notifications.
        self.manager.notify_progress(self._make_progress_context())
        self.manager.notify_error(
            ErrorContext(
                operation_type=OperationType.DOWNLOAD,
                operation_id="test",
                error=ValueError("Test error"),
                message="Error occurred",
            )
        )
        self.manager.notify_completion(
            CompletionContext(
                operation_type=OperationType.SCAN,
                operation_id="test",
                success=True,
                message="Operation completed",
            )
        )
        self.assertEqual(progress_callback.calls, [])
        self.assertEqual(error_callback.calls, [])
        self.assertEqual(completion_callback.calls, [])

    def test_multiple_notifications(self):
        """Test multiple progress notifications arrive in order."""
        callback = MockProgressCallback()
        self.manager.register_progress_callback(callback)
        for i in range(5):
            self.manager.notify_progress(
                self._make_progress_context(
                    current=i * 20, message=f"Progress {i}"
                )
            )
        self.assertEqual(len(callback.calls), 5)
        # Contexts must be delivered in the same order they were notified.
        self.assertEqual(
            [context.current for context in callback.calls],
            [0, 20, 40, 60, 80],
        )
class TestOperationType(unittest.TestCase):
    """Unit tests for the OperationType enumeration."""

    def test_operation_types(self):
        """Each expected operation type exists and equals its string value."""
        expected_pairs = (
            (OperationType.SCAN, "scan"),
            (OperationType.DOWNLOAD, "download"),
            (OperationType.SEARCH, "search"),
            (OperationType.INITIALIZATION, "initialization"),
        )
        for member, raw_value in expected_pairs:
            with self.subTest(member=member):
                self.assertEqual(member, raw_value)
class TestProgressPhase(unittest.TestCase):
    """Unit tests for the ProgressPhase enumeration."""

    def test_progress_phases(self):
        """Each expected progress phase exists and equals its string value."""
        expected_pairs = (
            (ProgressPhase.STARTING, "starting"),
            (ProgressPhase.IN_PROGRESS, "in_progress"),
            (ProgressPhase.COMPLETING, "completing"),
            (ProgressPhase.COMPLETED, "completed"),
            (ProgressPhase.FAILED, "failed"),
            (ProgressPhase.CANCELLED, "cancelled"),
        )
        for member, raw_value in expected_pairs:
            with self.subTest(member=member):
                self.assertEqual(member, raw_value)
if __name__ == "__main__":
unittest.main()