"""Integration tests for error recovery workflows.

Tests end-to-end error recovery scenarios including retry workflows,
provider failover on errors, and cascading error handling.
"""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from src.core.error_handler import (
|
|
DownloadError,
|
|
NetworkError,
|
|
NonRetryableError,
|
|
RecoveryStrategies,
|
|
RetryableError,
|
|
with_error_recovery,
|
|
)
|
|
|
|
|
|
class TestDownloadRetryWorkflow:
    """End-to-end tests: download fails → retries → eventually succeeds/fails."""

    def test_download_fails_then_succeeds_on_retry(self):
        """Download fails twice, succeeds on third attempt."""
        call_log = []

        @with_error_recovery(max_retries=3, context="download")
        def download_file(url: str):
            call_log.append(url)
            # Fail on the first two calls; the third call succeeds.
            if len(call_log) < 3:
                raise DownloadError("connection reset")
            return f"downloaded:{url}"

        result = download_file("https://example.com/video.mp4")

        assert result == "downloaded:https://example.com/video.mp4"
        assert len(call_log) == 3

    def test_download_exhausts_retries_then_raises(self):
        """Download fails all retry attempts and raises the final error.

        The attempt counter ensures the decorator really retried; without it
        this test would also pass if the first failure were re-raised at once.
        """
        attempts = []

        @with_error_recovery(max_retries=3, context="download")
        def always_fail_download():
            attempts.append(1)
            raise DownloadError("server unavailable")

        with pytest.raises(DownloadError, match="server unavailable"):
            always_fail_download()

        # This module treats max_retries as the total attempt budget
        # (e.g. max_retries=2 yields exactly 2 calls before giving up).
        assert len(attempts) == 3

    def test_non_retryable_error_aborts_immediately(self):
        """NonRetryableError stops retry loop on first occurrence."""
        attempts = []

        @with_error_recovery(max_retries=5, context="download")
        def corrupt_download():
            attempts.append(1)
            raise NonRetryableError("file is corrupt, don't retry")

        with pytest.raises(NonRetryableError):
            corrupt_download()
        # Only the initial call happened despite max_retries=5.
        assert len(attempts) == 1
|
|
|
|
|
|
class TestNetworkRecoveryWorkflow:
    """Tests for network error recovery with RecoveryStrategies."""

    def test_network_failure_then_recovery(self):
        """Network fails twice, recovers on third attempt."""
        call_count = 0

        def fetch_data():
            nonlocal call_count
            call_count += 1
            # Third and later calls succeed; earlier ones time out.
            if call_count >= 3:
                return {"data": "anime_list"}
            raise NetworkError("timeout")

        recovered = RecoveryStrategies.handle_network_failure(fetch_data)

        assert recovered == {"data": "anime_list"}
        assert call_count == 3

    def test_connection_error_then_recovery(self):
        """ConnectionError (stdlib) is handled by network recovery."""
        call_count = 0

        def connect():
            nonlocal call_count
            call_count += 1
            # Only the very first call is refused.
            if call_count > 1:
                return "connected"
            raise ConnectionError("refused")

        outcome = RecoveryStrategies.handle_network_failure(connect)

        assert outcome == "connected"
        assert call_count == 2
|
|
|
|
|
|
class TestProviderFailoverOnError:
    """Tests for provider failover when errors occur.

    NOTE(review): the first two tests drive an inline failover loop over mock
    providers rather than a production failover API; only the last test goes
    through ``with_error_recovery``.
    """

    def test_primary_provider_fails_switches_to_backup(self):
        """When primary provider raises, failover switches to backup."""
        primary = MagicMock(side_effect=NetworkError("primary down"))
        backup = MagicMock(return_value="backup_result")
        providers = [primary, backup]

        result = None
        for provider in providers:
            try:
                result = provider()
                break
            except (NetworkError, ConnectionError):
                continue

        assert result == "backup_result"
        primary.assert_called_once()
        backup.assert_called_once()

    def test_all_providers_fail_raises(self):
        """When all providers fail, the last error propagates to the caller."""
        providers = [
            MagicMock(side_effect=NetworkError("p1 down")),
            MagicMock(side_effect=NetworkError("p2 down")),
            MagicMock(side_effect=NetworkError("p3 down")),
        ]

        def call_with_failover():
            # Try each provider in turn; once every provider has failed,
            # re-raise the most recent failure.
            last_error = None
            for provider in providers:
                try:
                    return provider()
                except NetworkError as e:
                    last_error = e
            raise last_error

        with pytest.raises(NetworkError, match="p3 down"):
            call_with_failover()

        # Every provider was attempted exactly once before giving up.
        for provider in providers:
            provider.assert_called_once()

    def test_failover_with_retry_per_provider(self):
        """Each provider gets retries before moving to next."""
        p1_calls = []
        p2_calls = []

        @with_error_recovery(max_retries=2, context="provider1")
        def provider1():
            p1_calls.append(1)
            raise NetworkError("p1 fail")

        @with_error_recovery(max_retries=2, context="provider2")
        def provider2():
            p2_calls.append(1)
            return "p2_success"

        result = None
        for provider_fn in [provider1, provider2]:
            try:
                result = provider_fn()
                break
            except NetworkError:
                continue

        assert result == "p2_success"
        assert len(p1_calls) == 2  # provider1 exhausted its retries
        assert len(p2_calls) == 1  # provider2 succeeded first try
|
|
|
|
|
|
class TestCascadingErrorHandling:
    """Tests for cascading error scenarios."""

    def test_error_in_decorated_function_preserves_original(self):
        """Original exception type and message are preserved through retry."""

        @with_error_recovery(max_retries=1, context="cascade")
        def inner_fail():
            raise ValueError("original error context")

        with pytest.raises(ValueError, match="original error context") as excinfo:
            inner_fail()
        # pytest.raises also accepts subclasses; pin the exact type so a
        # wrapping subclass would be caught as a regression.
        assert excinfo.type is ValueError

    def test_nested_recovery_decorators(self):
        """Nested error recovery decorators work independently."""
        outer_attempts = []
        inner_attempts = []

        @with_error_recovery(max_retries=2, context="inner")
        def inner():
            inner_attempts.append(1)
            if len(inner_attempts) < 2:
                raise RuntimeError("inner fail")
            return "ok"

        @with_error_recovery(max_retries=2, context="outer")
        def outer():
            outer_attempts.append(1)
            return inner()

        result = outer()
        assert result == "ok"
        assert len(outer_attempts) == 1  # Outer didn't need to retry
        assert len(inner_attempts) == 2  # Inner retried once

    def test_error_recovery_with_different_error_types(self):
        """Recovery handles mixed error types across retries."""
        errors = iter([
            ConnectionError("refused"),
            TimeoutError("timed out"),
        ])
        attempts = []

        @with_error_recovery(max_retries=3, context="mixed")
        def mixed_errors():
            attempts.append(1)
            # Raise each queued error once; succeed once the queue is drained.
            error = next(errors, None)
            if error is not None:
                raise error
            return "recovered"

        result = mixed_errors()

        assert result == "recovered"
        # One failure per error type, plus the successful attempt.
        assert len(attempts) == 3
|
|
|
|
|
|
class TestResourceCleanupOnError:
    """Tests that resources are properly handled during error recovery."""

    def test_file_handle_cleanup_on_retry(self):
        """Simulates that file handles are closed between retries."""
        opened = []
        closed = []

        @with_error_recovery(max_retries=3, context="file_op")
        def file_operation():
            handle = MagicMock()
            opened.append(handle)
            try:
                if len(opened) >= 3:
                    return "written"
                raise DownloadError("write failed")
            except DownloadError:
                # Release this attempt's handle before the error escapes
                # to the retry decorator.
                handle.close()
                closed.append(handle)
                raise

        assert file_operation() == "written"
        assert len(closed) == 2  # 2 failures closed their handles

    def test_download_progress_tracked_across_retries(self):
        """Download progress tracking works across retry attempts."""
        events = []
        tries = 0

        @with_error_recovery(max_retries=3, context="download_progress")
        def download_with_progress():
            nonlocal tries
            tries += 1
            events.append("started")
            if tries >= 3:
                events.append("completed")
                return "done"
            events.append("failed")
            raise DownloadError("interrupted")

        assert download_with_progress() == "done"
        expected = ["started", "failed"] * 2 + ["started", "completed"]
        assert events == expected
|
|
|
|
|
|
class TestErrorClassificationWorkflow:
    """Tests for correct error classification in workflows."""

    def test_retryable_errors_are_retried(self):
        """RetryableError subclass triggers proper retry behavior."""
        attempts = {"count": 0}

        @with_error_recovery(max_retries=3, context="classify")
        def operation():
            attempts["count"] += 1
            if attempts["count"] < 3:
                raise RetryableError("transient issue")
            return "success"

        assert operation() == "success"
        assert attempts["count"] == 3

    def test_non_retryable_errors_skip_retry(self):
        """NonRetryableError bypasses retry mechanism completely."""
        attempts = {"count": 0}

        @with_error_recovery(max_retries=10, context="classify")
        def operation():
            attempts["count"] += 1
            raise NonRetryableError("permanent failure")

        with pytest.raises(NonRetryableError):
            operation()
        assert attempts["count"] == 1

    def test_download_error_through_strategies(self):
        """DownloadError handled correctly by both strategies and decorator."""
        # Via RecoveryStrategies: first call fails, second returns the value.
        func = MagicMock(side_effect=[
            DownloadError("fail1"),
            "success",
        ])
        result = RecoveryStrategies.handle_download_failure(func)
        assert result == "success"
        # Exactly one retry was needed to reach the successful call.
        assert func.call_count == 2

        # Via decorator: same pattern, one failure then success.
        counter = {"n": 0}

        @with_error_recovery(max_retries=3, context="dl")
        def dl():
            counter["n"] += 1
            if counter["n"] < 2:
                raise DownloadError("fail")
            return "downloaded"

        assert dl() == "downloaded"
        assert counter["n"] == 2
|