Add error handling tests: 74 tests for core errors, middleware, and recovery workflows
This commit is contained in:
315
tests/integration/test_error_recovery_workflows.py
Normal file
315
tests/integration/test_error_recovery_workflows.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Integration tests for error recovery workflows.
|
||||
|
||||
Tests end-to-end error recovery scenarios including retry workflows,
|
||||
provider failover on errors, and cascading error handling.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.error_handler import (
|
||||
DownloadError,
|
||||
NetworkError,
|
||||
NonRetryableError,
|
||||
RecoveryStrategies,
|
||||
RetryableError,
|
||||
with_error_recovery,
|
||||
)
|
||||
|
||||
|
||||
class TestDownloadRetryWorkflow:
|
||||
"""End-to-end tests: download fails → retries → eventually succeeds/fails."""
|
||||
|
||||
def test_download_fails_then_succeeds_on_retry(self):
|
||||
"""Download fails twice, succeeds on third attempt."""
|
||||
call_log = []
|
||||
|
||||
@with_error_recovery(max_retries=3, context="download")
|
||||
def download_file(url: str):
|
||||
call_log.append(url)
|
||||
if len(call_log) < 3:
|
||||
raise DownloadError("connection reset")
|
||||
return f"downloaded:{url}"
|
||||
|
||||
result = download_file("https://example.com/video.mp4")
|
||||
assert result == "downloaded:https://example.com/video.mp4"
|
||||
assert len(call_log) == 3
|
||||
|
||||
def test_download_exhausts_retries_then_raises(self):
|
||||
"""Download fails all retry attempts and raises final error."""
|
||||
|
||||
@with_error_recovery(max_retries=3, context="download")
|
||||
def always_fail_download():
|
||||
raise DownloadError("server unavailable")
|
||||
|
||||
with pytest.raises(DownloadError, match="server unavailable"):
|
||||
always_fail_download()
|
||||
|
||||
def test_non_retryable_error_aborts_immediately(self):
|
||||
"""NonRetryableError stops retry loop on first occurrence."""
|
||||
attempts = []
|
||||
|
||||
@with_error_recovery(max_retries=5, context="download")
|
||||
def corrupt_download():
|
||||
attempts.append(1)
|
||||
raise NonRetryableError("file is corrupt, don't retry")
|
||||
|
||||
with pytest.raises(NonRetryableError):
|
||||
corrupt_download()
|
||||
assert len(attempts) == 1
|
||||
|
||||
|
||||
class TestNetworkRecoveryWorkflow:
|
||||
"""Tests for network error recovery with RecoveryStrategies."""
|
||||
|
||||
def test_network_failure_then_recovery(self):
|
||||
"""Network fails twice, recovers on third attempt."""
|
||||
attempts = []
|
||||
|
||||
def fetch_data():
|
||||
attempts.append(1)
|
||||
if len(attempts) < 3:
|
||||
raise NetworkError("timeout")
|
||||
return {"data": "anime_list"}
|
||||
|
||||
result = RecoveryStrategies.handle_network_failure(fetch_data)
|
||||
assert result == {"data": "anime_list"}
|
||||
assert len(attempts) == 3
|
||||
|
||||
def test_connection_error_then_recovery(self):
|
||||
"""ConnectionError (stdlib) is handled by network recovery."""
|
||||
attempts = []
|
||||
|
||||
def connect():
|
||||
attempts.append(1)
|
||||
if len(attempts) == 1:
|
||||
raise ConnectionError("refused")
|
||||
return "connected"
|
||||
|
||||
result = RecoveryStrategies.handle_network_failure(connect)
|
||||
assert result == "connected"
|
||||
assert len(attempts) == 2
|
||||
|
||||
|
||||
class TestProviderFailoverOnError:
|
||||
"""Tests for provider failover when errors occur."""
|
||||
|
||||
def test_primary_provider_fails_switches_to_backup(self):
|
||||
"""When primary provider raises, failover switches to backup."""
|
||||
primary = MagicMock(side_effect=NetworkError("primary down"))
|
||||
backup = MagicMock(return_value="backup_result")
|
||||
providers = [primary, backup]
|
||||
|
||||
result = None
|
||||
for provider in providers:
|
||||
try:
|
||||
result = provider()
|
||||
break
|
||||
except (NetworkError, ConnectionError):
|
||||
continue
|
||||
|
||||
assert result == "backup_result"
|
||||
primary.assert_called_once()
|
||||
backup.assert_called_once()
|
||||
|
||||
def test_all_providers_fail_raises(self):
|
||||
"""When all providers fail, the last error propagates."""
|
||||
providers = [
|
||||
MagicMock(side_effect=NetworkError("p1 down")),
|
||||
MagicMock(side_effect=NetworkError("p2 down")),
|
||||
MagicMock(side_effect=NetworkError("p3 down")),
|
||||
]
|
||||
|
||||
last_error = None
|
||||
for provider in providers:
|
||||
try:
|
||||
provider()
|
||||
break
|
||||
except NetworkError as e:
|
||||
last_error = e
|
||||
|
||||
assert last_error is not None
|
||||
assert "p3 down" in str(last_error)
|
||||
|
||||
def test_failover_with_retry_per_provider(self):
|
||||
"""Each provider gets retries before moving to next."""
|
||||
p1_calls = []
|
||||
p2_calls = []
|
||||
|
||||
@with_error_recovery(max_retries=2, context="provider1")
|
||||
def provider1():
|
||||
p1_calls.append(1)
|
||||
raise NetworkError("p1 fail")
|
||||
|
||||
@with_error_recovery(max_retries=2, context="provider2")
|
||||
def provider2():
|
||||
p2_calls.append(1)
|
||||
return "p2_success"
|
||||
|
||||
result = None
|
||||
for provider_fn in [provider1, provider2]:
|
||||
try:
|
||||
result = provider_fn()
|
||||
break
|
||||
except NetworkError:
|
||||
continue
|
||||
|
||||
assert result == "p2_success"
|
||||
assert len(p1_calls) == 2 # provider1 exhausted its retries
|
||||
assert len(p2_calls) == 1 # provider2 succeeded first try
|
||||
|
||||
|
||||
class TestCascadingErrorHandling:
|
||||
"""Tests for cascading error scenarios."""
|
||||
|
||||
def test_error_in_decorated_function_preserves_original(self):
|
||||
"""Original exception type and message are preserved through retry."""
|
||||
|
||||
@with_error_recovery(max_retries=1, context="cascade")
|
||||
def inner_fail():
|
||||
raise ValueError("original error context")
|
||||
|
||||
with pytest.raises(ValueError, match="original error context"):
|
||||
inner_fail()
|
||||
|
||||
def test_nested_recovery_decorators(self):
|
||||
"""Nested error recovery decorators work independently."""
|
||||
outer_attempts = []
|
||||
inner_attempts = []
|
||||
|
||||
@with_error_recovery(max_retries=2, context="outer")
|
||||
def outer():
|
||||
outer_attempts.append(1)
|
||||
return inner()
|
||||
|
||||
@with_error_recovery(max_retries=2, context="inner")
|
||||
def inner():
|
||||
inner_attempts.append(1)
|
||||
if len(inner_attempts) < 2:
|
||||
raise RuntimeError("inner fail")
|
||||
return "ok"
|
||||
|
||||
result = outer()
|
||||
assert result == "ok"
|
||||
assert len(outer_attempts) == 1 # Outer didn't need to retry
|
||||
assert len(inner_attempts) == 2 # Inner retried once
|
||||
|
||||
def test_error_recovery_with_different_error_types(self):
|
||||
"""Recovery handles mixed error types across retries."""
|
||||
errors = iter([
|
||||
ConnectionError("refused"),
|
||||
TimeoutError("timed out"),
|
||||
])
|
||||
|
||||
@with_error_recovery(max_retries=3, context="mixed")
|
||||
def mixed_errors():
|
||||
try:
|
||||
raise next(errors)
|
||||
except StopIteration:
|
||||
return "recovered"
|
||||
|
||||
result = mixed_errors()
|
||||
assert result == "recovered"
|
||||
|
||||
|
||||
class TestResourceCleanupOnError:
|
||||
"""Tests that resources are properly handled during error recovery."""
|
||||
|
||||
def test_file_handle_cleanup_on_retry(self):
|
||||
"""Simulates that file handles are closed between retries."""
|
||||
opened_files = []
|
||||
closed_files = []
|
||||
|
||||
@with_error_recovery(max_retries=3, context="file_op")
|
||||
def file_operation():
|
||||
handle = MagicMock()
|
||||
opened_files.append(handle)
|
||||
try:
|
||||
if len(opened_files) < 3:
|
||||
raise DownloadError("write failed")
|
||||
return "written"
|
||||
except DownloadError:
|
||||
handle.close()
|
||||
closed_files.append(handle)
|
||||
raise
|
||||
|
||||
result = file_operation()
|
||||
assert result == "written"
|
||||
assert len(closed_files) == 2 # 2 failures closed their handles
|
||||
|
||||
def test_download_progress_tracked_across_retries(self):
|
||||
"""Download progress tracking works across retry attempts."""
|
||||
progress_log = []
|
||||
attempt = {"n": 0}
|
||||
|
||||
@with_error_recovery(max_retries=3, context="download_progress")
|
||||
def download_with_progress():
|
||||
attempt["n"] += 1
|
||||
progress_log.append("started")
|
||||
if attempt["n"] < 3:
|
||||
progress_log.append("failed")
|
||||
raise DownloadError("interrupted")
|
||||
progress_log.append("completed")
|
||||
return "done"
|
||||
|
||||
result = download_with_progress()
|
||||
assert result == "done"
|
||||
assert progress_log == [
|
||||
"started", "failed",
|
||||
"started", "failed",
|
||||
"started", "completed",
|
||||
]
|
||||
|
||||
|
||||
class TestErrorClassificationWorkflow:
|
||||
"""Tests for correct error classification in workflows."""
|
||||
|
||||
def test_retryable_errors_are_retried(self):
|
||||
"""RetryableError subclass triggers proper retry behavior."""
|
||||
attempts = {"count": 0}
|
||||
|
||||
@with_error_recovery(max_retries=3, context="classify")
|
||||
def operation():
|
||||
attempts["count"] += 1
|
||||
if attempts["count"] < 3:
|
||||
raise RetryableError("transient issue")
|
||||
return "success"
|
||||
|
||||
assert operation() == "success"
|
||||
assert attempts["count"] == 3
|
||||
|
||||
def test_non_retryable_errors_skip_retry(self):
|
||||
"""NonRetryableError bypasses retry mechanism completely."""
|
||||
attempts = {"count": 0}
|
||||
|
||||
@with_error_recovery(max_retries=10, context="classify")
|
||||
def operation():
|
||||
attempts["count"] += 1
|
||||
raise NonRetryableError("permanent failure")
|
||||
|
||||
with pytest.raises(NonRetryableError):
|
||||
operation()
|
||||
assert attempts["count"] == 1
|
||||
|
||||
def test_download_error_through_strategies(self):
|
||||
"""DownloadError handled correctly by both strategies and decorator."""
|
||||
# Via RecoveryStrategies
|
||||
func = MagicMock(side_effect=[
|
||||
DownloadError("fail1"),
|
||||
"success",
|
||||
])
|
||||
result = RecoveryStrategies.handle_download_failure(func)
|
||||
assert result == "success"
|
||||
|
||||
# Via decorator
|
||||
counter = {"n": 0}
|
||||
|
||||
@with_error_recovery(max_retries=3, context="dl")
|
||||
def dl():
|
||||
counter["n"] += 1
|
||||
if counter["n"] < 2:
|
||||
raise DownloadError("fail")
|
||||
return "downloaded"
|
||||
|
||||
assert dl() == "downloaded"
|
||||
302
tests/unit/test_core_error_handler.py
Normal file
302
tests/unit/test_core_error_handler.py
Normal file
@@ -0,0 +1,302 @@
|
||||
"""Unit tests for core error handler module.
|
||||
|
||||
Tests custom exceptions, retry logic, error recovery strategies,
|
||||
file corruption detection, and the with_error_recovery decorator.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.error_handler import (
|
||||
DownloadError,
|
||||
FileCorruptionDetector,
|
||||
NetworkError,
|
||||
NonRetryableError,
|
||||
RecoveryStrategies,
|
||||
RetryableError,
|
||||
file_corruption_detector,
|
||||
recovery_strategies,
|
||||
with_error_recovery,
|
||||
)
|
||||
|
||||
|
||||
class TestCustomExceptions:
|
||||
"""Tests for custom exception classes."""
|
||||
|
||||
def test_retryable_error_is_exception(self):
|
||||
"""RetryableError is a proper Exception subclass."""
|
||||
with pytest.raises(RetryableError):
|
||||
raise RetryableError("transient failure")
|
||||
|
||||
def test_non_retryable_error_is_exception(self):
|
||||
"""NonRetryableError is a proper Exception subclass."""
|
||||
with pytest.raises(NonRetryableError):
|
||||
raise NonRetryableError("permanent failure")
|
||||
|
||||
def test_network_error_is_exception(self):
|
||||
"""NetworkError is a proper Exception subclass."""
|
||||
with pytest.raises(NetworkError):
|
||||
raise NetworkError("connection lost")
|
||||
|
||||
def test_download_error_is_exception(self):
|
||||
"""DownloadError is a proper Exception subclass."""
|
||||
with pytest.raises(DownloadError):
|
||||
raise DownloadError("download failed")
|
||||
|
||||
def test_exception_message_preserved(self):
|
||||
"""Custom exceptions preserve their message string."""
|
||||
msg = "specific failure reason"
|
||||
err = RetryableError(msg)
|
||||
assert str(err) == msg
|
||||
|
||||
def test_exceptions_are_independent(self):
|
||||
"""Each error type is a distinct class (no inheritance among them)."""
|
||||
assert not issubclass(RetryableError, NonRetryableError)
|
||||
assert not issubclass(NetworkError, DownloadError)
|
||||
assert not issubclass(DownloadError, NetworkError)
|
||||
|
||||
|
||||
class TestRecoveryStrategiesNetworkFailure:
|
||||
"""Tests for RecoveryStrategies.handle_network_failure."""
|
||||
|
||||
def test_success_on_first_attempt(self):
|
||||
"""Returns result immediately if function succeeds."""
|
||||
func = MagicMock(return_value="ok")
|
||||
result = RecoveryStrategies.handle_network_failure(func)
|
||||
assert result == "ok"
|
||||
assert func.call_count == 1
|
||||
|
||||
def test_retries_on_network_error(self):
|
||||
"""Retries up to 3 times on NetworkError."""
|
||||
func = MagicMock(
|
||||
side_effect=[NetworkError("fail"), NetworkError("fail"), "ok"]
|
||||
)
|
||||
result = RecoveryStrategies.handle_network_failure(func)
|
||||
assert result == "ok"
|
||||
assert func.call_count == 3
|
||||
|
||||
def test_retries_on_connection_error(self):
|
||||
"""Retries on ConnectionError (built-in)."""
|
||||
func = MagicMock(
|
||||
side_effect=[ConnectionError("fail"), "ok"]
|
||||
)
|
||||
result = RecoveryStrategies.handle_network_failure(func)
|
||||
assert result == "ok"
|
||||
assert func.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries(self):
|
||||
"""Raises NetworkError after 3 failed attempts."""
|
||||
func = MagicMock(side_effect=NetworkError("persistent failure"))
|
||||
with pytest.raises(NetworkError):
|
||||
RecoveryStrategies.handle_network_failure(func)
|
||||
assert func.call_count == 3
|
||||
|
||||
def test_passes_args_and_kwargs(self):
|
||||
"""Arguments are forwarded to the wrapped function."""
|
||||
func = MagicMock(return_value="result")
|
||||
RecoveryStrategies.handle_network_failure(func, "arg1", key="val")
|
||||
func.assert_called_with("arg1", key="val")
|
||||
|
||||
def test_non_network_error_not_caught(self):
|
||||
"""Non-network exceptions propagate immediately without retry."""
|
||||
func = MagicMock(side_effect=ValueError("bad input"))
|
||||
with pytest.raises(ValueError):
|
||||
RecoveryStrategies.handle_network_failure(func)
|
||||
assert func.call_count == 1
|
||||
|
||||
|
||||
class TestRecoveryStrategiesDownloadFailure:
|
||||
"""Tests for RecoveryStrategies.handle_download_failure."""
|
||||
|
||||
def test_success_on_first_attempt(self):
|
||||
"""Returns result if function succeeds first time."""
|
||||
func = MagicMock(return_value="downloaded")
|
||||
result = RecoveryStrategies.handle_download_failure(func)
|
||||
assert result == "downloaded"
|
||||
assert func.call_count == 1
|
||||
|
||||
def test_retries_on_download_error(self):
|
||||
"""Retries up to 2 times on DownloadError."""
|
||||
func = MagicMock(
|
||||
side_effect=[DownloadError("fail"), "ok"]
|
||||
)
|
||||
result = RecoveryStrategies.handle_download_failure(func)
|
||||
assert result == "ok"
|
||||
assert func.call_count == 2
|
||||
|
||||
def test_raises_after_max_retries(self):
|
||||
"""Raises DownloadError after 2 failed attempts."""
|
||||
func = MagicMock(side_effect=DownloadError("persistent"))
|
||||
with pytest.raises(DownloadError):
|
||||
RecoveryStrategies.handle_download_failure(func)
|
||||
assert func.call_count == 2
|
||||
|
||||
def test_download_max_retries_is_two(self):
|
||||
"""Download recovery allows exactly 2 attempts."""
|
||||
call_count = 0
|
||||
|
||||
def counting_func():
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
raise DownloadError("fail")
|
||||
|
||||
with pytest.raises(DownloadError):
|
||||
RecoveryStrategies.handle_download_failure(counting_func)
|
||||
assert call_count == 2
|
||||
|
||||
|
||||
class TestFileCorruptionDetector:
|
||||
"""Tests for FileCorruptionDetector."""
|
||||
|
||||
def test_valid_large_file(self, tmp_path):
|
||||
"""File larger than 1MB is considered valid."""
|
||||
filepath = tmp_path / "video.mp4"
|
||||
filepath.write_bytes(b"\x00" * (1024 * 1024 + 1))
|
||||
assert FileCorruptionDetector.is_valid_video_file(str(filepath))
|
||||
|
||||
def test_file_too_small(self, tmp_path):
|
||||
"""File smaller than 1MB is invalid."""
|
||||
filepath = tmp_path / "video.mp4"
|
||||
filepath.write_bytes(b"\x00" * 100)
|
||||
assert not FileCorruptionDetector.is_valid_video_file(str(filepath))
|
||||
|
||||
def test_exactly_1mb_is_invalid(self, tmp_path):
|
||||
"""File of exactly 1MB (1048576 bytes) is invalid (needs > 1MB)."""
|
||||
filepath = tmp_path / "video.mp4"
|
||||
filepath.write_bytes(b"\x00" * (1024 * 1024))
|
||||
assert not FileCorruptionDetector.is_valid_video_file(str(filepath))
|
||||
|
||||
def test_nonexistent_file(self):
|
||||
"""Nonexistent file returns False."""
|
||||
assert not FileCorruptionDetector.is_valid_video_file("/no/such/file")
|
||||
|
||||
def test_module_level_instance(self):
|
||||
"""Module provides a pre-created FileCorruptionDetector instance."""
|
||||
assert isinstance(file_corruption_detector, FileCorruptionDetector)
|
||||
|
||||
|
||||
class TestWithErrorRecoveryDecorator:
|
||||
"""Tests for the with_error_recovery decorator."""
|
||||
|
||||
def test_success_returns_result(self):
|
||||
"""Decorated function returns result on success."""
|
||||
|
||||
@with_error_recovery(max_retries=3, context="test")
|
||||
def succeed():
|
||||
return 42
|
||||
|
||||
assert succeed() == 42
|
||||
|
||||
def test_retries_on_generic_exception(self):
|
||||
"""Generic exception triggers retry."""
|
||||
counter = {"n": 0}
|
||||
|
||||
@with_error_recovery(max_retries=3, context="test")
|
||||
def fail_then_succeed():
|
||||
counter["n"] += 1
|
||||
if counter["n"] < 3:
|
||||
raise RuntimeError("not yet")
|
||||
return "ok"
|
||||
|
||||
assert fail_then_succeed() == "ok"
|
||||
assert counter["n"] == 3
|
||||
|
||||
def test_non_retryable_error_raises_immediately(self):
|
||||
"""NonRetryableError is not retried - raises on first occurrence."""
|
||||
counter = {"n": 0}
|
||||
|
||||
@with_error_recovery(max_retries=5, context="test")
|
||||
def permanent_failure():
|
||||
counter["n"] += 1
|
||||
raise NonRetryableError("do not retry")
|
||||
|
||||
with pytest.raises(NonRetryableError):
|
||||
permanent_failure()
|
||||
assert counter["n"] == 1
|
||||
|
||||
def test_raises_last_error_after_exhausting_retries(self):
|
||||
"""After max retries, the last error is raised."""
|
||||
|
||||
@with_error_recovery(max_retries=2, context="test")
|
||||
def always_fail():
|
||||
raise ValueError("always fails")
|
||||
|
||||
with pytest.raises(ValueError, match="always fails"):
|
||||
always_fail()
|
||||
|
||||
def test_max_retries_limits_attempts(self):
|
||||
"""Function is called exactly max_retries times when always failing."""
|
||||
counter = {"n": 0}
|
||||
|
||||
@with_error_recovery(max_retries=4, context="test")
|
||||
def counting_fail():
|
||||
counter["n"] += 1
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
counting_fail()
|
||||
assert counter["n"] == 4
|
||||
|
||||
def test_preserves_function_name(self):
|
||||
"""Decorator preserves the original function name via functools.wraps."""
|
||||
|
||||
@with_error_recovery(max_retries=1, context="test")
|
||||
def my_function():
|
||||
pass
|
||||
|
||||
assert my_function.__name__ == "my_function"
|
||||
|
||||
def test_context_used_in_logging(self):
|
||||
"""Context string is used in error log messages."""
|
||||
counter = {"n": 0}
|
||||
|
||||
@with_error_recovery(max_retries=2, context="my_context")
|
||||
def fail_once():
|
||||
counter["n"] += 1
|
||||
if counter["n"] == 1:
|
||||
raise RuntimeError("oops")
|
||||
return "ok"
|
||||
|
||||
with patch("src.core.error_handler.logger") as mock_logger:
|
||||
fail_once()
|
||||
# Should have logged a warning with context
|
||||
mock_logger.warning.assert_called()
|
||||
logged_msg = mock_logger.warning.call_args[0][0]
|
||||
assert "my_context" in logged_msg
|
||||
|
||||
def test_retryable_error_is_retried(self):
|
||||
"""RetryableError (standard Exception subclass) is retried."""
|
||||
counter = {"n": 0}
|
||||
|
||||
@with_error_recovery(max_retries=3, context="test")
|
||||
def retryable():
|
||||
counter["n"] += 1
|
||||
if counter["n"] < 3:
|
||||
raise RetryableError("try again")
|
||||
return "done"
|
||||
|
||||
assert retryable() == "done"
|
||||
assert counter["n"] == 3
|
||||
|
||||
def test_passes_arguments_through(self):
|
||||
"""Decorated function receives all arguments correctly."""
|
||||
|
||||
@with_error_recovery(max_retries=1, context="test")
|
||||
def add(a, b, c=0):
|
||||
return a + b + c
|
||||
|
||||
assert add(1, 2, c=3) == 6
|
||||
|
||||
|
||||
class TestModuleLevelInstances:
|
||||
"""Tests for module-level singleton instances."""
|
||||
|
||||
def test_recovery_strategies_instance(self):
|
||||
"""Module provides a RecoveryStrategies instance."""
|
||||
assert isinstance(recovery_strategies, RecoveryStrategies)
|
||||
|
||||
def test_file_corruption_detector_instance(self):
|
||||
"""Module provides a FileCorruptionDetector instance."""
|
||||
assert isinstance(file_corruption_detector, FileCorruptionDetector)
|
||||
345
tests/unit/test_middleware_error_handler.py
Normal file
345
tests/unit/test_middleware_error_handler.py
Normal file
@@ -0,0 +1,345 @@
|
||||
"""Unit tests for FastAPI middleware error handler.
|
||||
|
||||
Tests error response formatting, exception handler registration,
|
||||
custom exception handling, and the general exception handler.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from src.server.exceptions import (
|
||||
AniWorldAPIException,
|
||||
AuthenticationError,
|
||||
AuthorizationError,
|
||||
BadRequestError,
|
||||
ConflictError,
|
||||
NotFoundError,
|
||||
RateLimitError,
|
||||
ValidationError,
|
||||
)
|
||||
from src.server.middleware.error_handler import (
|
||||
create_error_response,
|
||||
register_exception_handlers,
|
||||
)
|
||||
|
||||
|
||||
class TestCreateErrorResponse:
|
||||
"""Tests for the create_error_response utility function."""
|
||||
|
||||
def test_basic_error_response_structure(self):
|
||||
"""Error response has success, error, and message keys."""
|
||||
resp = create_error_response(
|
||||
status_code=400, error="BAD_REQUEST", message="Invalid input"
|
||||
)
|
||||
assert resp["success"] is False
|
||||
assert resp["error"] == "BAD_REQUEST"
|
||||
assert resp["message"] == "Invalid input"
|
||||
|
||||
def test_response_includes_details_when_provided(self):
|
||||
"""Details dict is included when specified."""
|
||||
details = {"field": "name", "reason": "too long"}
|
||||
resp = create_error_response(
|
||||
status_code=422, error="VALIDATION", message="Bad",
|
||||
details=details,
|
||||
)
|
||||
assert resp["details"] == details
|
||||
|
||||
def test_response_excludes_details_when_none(self):
|
||||
"""Details key absent when not specified."""
|
||||
resp = create_error_response(
|
||||
status_code=400, error="ERR", message="msg"
|
||||
)
|
||||
assert "details" not in resp
|
||||
|
||||
def test_response_includes_request_id(self):
|
||||
"""Request ID is included when provided."""
|
||||
resp = create_error_response(
|
||||
status_code=500, error="ERR", message="msg",
|
||||
request_id="req-123",
|
||||
)
|
||||
assert resp["request_id"] == "req-123"
|
||||
|
||||
def test_response_excludes_request_id_when_none(self):
|
||||
"""Request ID key absent when not specified."""
|
||||
resp = create_error_response(
|
||||
status_code=500, error="ERR", message="msg"
|
||||
)
|
||||
assert "request_id" not in resp
|
||||
|
||||
|
||||
class TestExceptionHandlerRegistration:
|
||||
"""Tests that exception handlers are correctly registered on a FastAPI app."""
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_handlers(self) -> FastAPI:
|
||||
"""Create a FastAPI app with registered exception handlers."""
|
||||
app = FastAPI()
|
||||
register_exception_handlers(app)
|
||||
return app
|
||||
|
||||
def _add_route_raising(self, app: FastAPI, exc: Exception):
|
||||
"""Add a GET /test route that raises the given exception."""
|
||||
@app.get("/test")
|
||||
async def route():
|
||||
raise exc
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authentication_error_returns_401(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""AuthenticationError maps to HTTP 401."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers, AuthenticationError("bad creds")
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["success"] is False
|
||||
assert body["error"] == "AUTHENTICATION_ERROR"
|
||||
assert body["message"] == "bad creds"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_authorization_error_returns_403(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""AuthorizationError maps to HTTP 403."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers, AuthorizationError("forbidden")
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.status_code == 403
|
||||
assert resp.json()["error"] == "AUTHORIZATION_ERROR"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bad_request_error_returns_400(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""BadRequestError maps to HTTP 400."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers, BadRequestError("invalid")
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.status_code == 400
|
||||
assert resp.json()["error"] == "BAD_REQUEST"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_not_found_error_returns_404(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""NotFoundError maps to HTTP 404."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers,
|
||||
NotFoundError("anime not found", resource_type="anime", resource_id=42),
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.status_code == 404
|
||||
body = resp.json()
|
||||
assert body["error"] == "NOT_FOUND"
|
||||
assert body["details"]["resource_type"] == "anime"
|
||||
assert body["details"]["resource_id"] == 42
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_validation_error_returns_422(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""ValidationError maps to HTTP 422."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers, ValidationError("bad data")
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.status_code == 422
|
||||
assert resp.json()["error"] == "VALIDATION_ERROR"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conflict_error_returns_409(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""ConflictError maps to HTTP 409."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers, ConflictError("duplicate")
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.status_code == 409
|
||||
assert resp.json()["error"] == "CONFLICT"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rate_limit_error_returns_429(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""RateLimitError maps to HTTP 429."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers, RateLimitError("too many", retry_after=60)
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.status_code == 429
|
||||
body = resp.json()
|
||||
assert body["error"] == "RATE_LIMIT_EXCEEDED"
|
||||
assert body["details"]["retry_after"] == 60
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generic_api_exception_returns_status(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""AniWorldAPIException uses its status_code."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers,
|
||||
AniWorldAPIException("custom error", status_code=418),
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.status_code == 418
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unexpected_exception_returns_500(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""Unhandled exceptions map to HTTP 500 with generic message."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers, RuntimeError("unexpected crash")
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers, raise_app_exceptions=False),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.status_code == 500
|
||||
body = resp.json()
|
||||
assert body["error"] == "INTERNAL_SERVER_ERROR"
|
||||
assert body["message"] == "An unexpected error occurred"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unexpected_exception_hides_stack_trace(
|
||||
self, app_with_handlers
|
||||
):
|
||||
"""Stack traces are not leaked in 500 error responses."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers, RuntimeError("internal secret")
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers, raise_app_exceptions=False),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
body = resp.json()
|
||||
assert "internal secret" not in body["message"]
|
||||
assert "Traceback" not in str(body)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_response_is_json(self, app_with_handlers):
|
||||
"""All error responses are JSON formatted."""
|
||||
self._add_route_raising(
|
||||
app_with_handlers, NotFoundError("missing")
|
||||
)
|
||||
async with AsyncClient(
|
||||
transport=ASGITransport(app=app_with_handlers),
|
||||
base_url="http://test",
|
||||
) as client:
|
||||
resp = await client.get("/test")
|
||||
assert resp.headers["content-type"] == "application/json"
|
||||
|
||||
|
||||
class TestExceptionClasses:
|
||||
"""Tests for custom exception class properties."""
|
||||
|
||||
def test_aniworld_exception_defaults(self):
|
||||
"""AniWorldAPIException has sensible defaults."""
|
||||
exc = AniWorldAPIException("test")
|
||||
assert exc.message == "test"
|
||||
assert exc.status_code == 500
|
||||
assert exc.error_code == "AniWorldAPIException"
|
||||
assert exc.details == {}
|
||||
|
||||
def test_to_dict_format(self):
|
||||
"""to_dict returns proper structure."""
|
||||
exc = AniWorldAPIException(
|
||||
"fail", status_code=400, error_code="FAIL",
|
||||
details={"reason": "bad"}
|
||||
)
|
||||
d = exc.to_dict()
|
||||
assert d["error"] == "FAIL"
|
||||
assert d["message"] == "fail"
|
||||
assert d["details"]["reason"] == "bad"
|
||||
|
||||
def test_not_found_with_resource_info(self):
|
||||
"""NotFoundError includes resource_type and resource_id in details."""
|
||||
exc = NotFoundError(
|
||||
"not found", resource_type="anime", resource_id="abc-123"
|
||||
)
|
||||
assert exc.details["resource_type"] == "anime"
|
||||
assert exc.details["resource_id"] == "abc-123"
|
||||
|
||||
def test_rate_limit_with_retry_after(self):
|
||||
"""RateLimitError includes retry_after in details."""
|
||||
exc = RateLimitError("slow down", retry_after=30)
|
||||
assert exc.details["retry_after"] == 30
|
||||
|
||||
def test_authentication_error_defaults(self):
|
||||
"""AuthenticationError defaults to 401 status."""
|
||||
exc = AuthenticationError()
|
||||
assert exc.status_code == 401
|
||||
assert exc.error_code == "AUTHENTICATION_ERROR"
|
||||
|
||||
def test_authorization_error_defaults(self):
|
||||
"""AuthorizationError defaults to 403 status."""
|
||||
exc = AuthorizationError()
|
||||
assert exc.status_code == 403
|
||||
|
||||
def test_validation_error_defaults(self):
|
||||
"""ValidationError defaults to 422 status."""
|
||||
exc = ValidationError()
|
||||
assert exc.status_code == 422
|
||||
|
||||
def test_bad_request_error_defaults(self):
|
||||
"""BadRequestError defaults to 400 status."""
|
||||
exc = BadRequestError()
|
||||
assert exc.status_code == 400
|
||||
|
||||
def test_conflict_error_defaults(self):
|
||||
"""ConflictError defaults to 409 status."""
|
||||
exc = ConflictError()
|
||||
assert exc.status_code == 409
|
||||
|
||||
def test_exception_inheritance_chain(self):
|
||||
"""All custom exceptions inherit from AniWorldAPIException."""
|
||||
assert issubclass(AuthenticationError, AniWorldAPIException)
|
||||
assert issubclass(AuthorizationError, AniWorldAPIException)
|
||||
assert issubclass(NotFoundError, AniWorldAPIException)
|
||||
assert issubclass(ValidationError, AniWorldAPIException)
|
||||
assert issubclass(BadRequestError, AniWorldAPIException)
|
||||
assert issubclass(ConflictError, AniWorldAPIException)
|
||||
assert issubclass(RateLimitError, AniWorldAPIException)
|
||||
Reference in New Issue
Block a user