313 lines
10 KiB
Python
313 lines
10 KiB
Python
"""Integration tests for provider failover scenarios - End-to-end provider switching."""
|
|
|
|
import asyncio
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from src.core.providers.failover import (
|
|
ProviderFailover,
|
|
configure_failover,
|
|
get_failover,
|
|
)
|
|
from src.core.providers.health_monitor import ProviderHealthMonitor
|
|
|
|
|
|
class TestProviderFailoverScenarios:
    """Test end-to-end failover scenarios with multiple providers."""

    @pytest.mark.asyncio
    async def test_primary_fails_switches_to_backup(self):
        """When primary provider fails, should switch to backup and succeed."""
        call_log = []

        async def operation(provider: str) -> str:
            call_log.append(provider)
            if provider == "provider1":
                raise ConnectionError("Provider1 is down")
            return f"Success from {provider}"

        failover = ProviderFailover(
            providers=["provider1", "provider2", "provider3"],
            max_retries=1,
            retry_delay=0.01,
            enable_health_monitoring=False,
        )

        result = await failover.execute_with_failover(
            operation=operation,
            operation_name="test_failover",
        )

        assert "Success" in result
        # The failing primary must have been attempted, plus at least one backup.
        assert "provider1" in call_log
        assert len(call_log) >= 2

    @pytest.mark.asyncio
    async def test_first_two_fail_third_succeeds(self):
        """When first two providers fail, third should be tried."""
        # Per-provider invocation counts, filled in by ``operation``.
        attempts = {}

        async def operation(provider: str) -> str:
            attempts[provider] = attempts.get(provider, 0) + 1
            if provider in ("provider1", "provider2"):
                raise ConnectionError(f"{provider} is down")
            return f"Success from {provider}"

        failover = ProviderFailover(
            providers=["provider1", "provider2", "provider3"],
            max_retries=1,
            retry_delay=0.01,
            enable_health_monitoring=False,
        )

        result = await failover.execute_with_failover(
            operation=operation,
            operation_name="test_two_fail",
        )

        assert "provider3" in result
        # The failing primary was tried, and the provider that succeeded was
        # invoked exactly once (a success should end the failover loop).
        assert "provider1" in attempts
        assert attempts["provider3"] == 1

    @pytest.mark.asyncio
    async def test_all_providers_fail_raises(self):
        """When all providers fail, should raise exception."""

        async def operation(provider: str) -> str:
            raise ConnectionError(f"{provider} is down")

        failover = ProviderFailover(
            providers=["provider1", "provider2"],
            max_retries=1,
            retry_delay=0.01,
            enable_health_monitoring=False,
        )

        with pytest.raises(Exception, match="failed with all providers"):
            await failover.execute_with_failover(
                operation=operation,
                operation_name="test_all_fail",
            )

    @pytest.mark.asyncio
    async def test_retry_within_single_provider(self):
        """Should retry with same provider before moving to next."""
        call_count = 0

        async def operation(provider: str) -> str:
            nonlocal call_count
            call_count += 1
            # Fail the first two calls, succeed on the third.
            if call_count <= 2:
                raise ConnectionError("Temporary failure")
            return f"Success on attempt {call_count}"

        failover = ProviderFailover(
            providers=["provider1"],
            max_retries=3,
            retry_delay=0.01,
            enable_health_monitoring=False,
        )

        result = await failover.execute_with_failover(
            operation=operation,
            operation_name="test_retry",
        )

        assert "Success" in result
        assert call_count == 3

    @pytest.mark.asyncio
    async def test_failover_with_health_monitoring(self):
        """Failover should integrate with health monitoring."""
        monitor = ProviderHealthMonitor(failure_threshold=2)

        # Pre-record more failures than the threshold to make provider1 unavailable.
        for _ in range(3):
            monitor.record_request(
                provider_name="provider1",
                success=False,
                response_time_ms=100,
                error_message="Simulated failure",
            )

        # Provider1 should now be unavailable
        assert "provider1" not in monitor.get_available_providers()

        with patch(
            "src.core.providers.failover.get_health_monitor",
            return_value=monitor,
        ):
            failover = ProviderFailover(
                providers=["provider1", "provider2"],
                max_retries=1,
                retry_delay=0.01,
                enable_health_monitoring=True,
            )
            # Query while the patch is still active so that, if the failover
            # looks up the health monitor lazily, it still sees the stub above
            # rather than the real global monitor.
            current = failover.get_current_provider()

        # Best provider selection should favor available ones.
        # NOTE(review): provider2 was never recorded on the monitor, so
        # get_available_providers() may be empty, in which case this check is
        # skipped — confirm the monitor's semantics and tighten if possible.
        available = monitor.get_available_providers()
        if available:
            assert current in available or current == "provider2"
|
|
|
|
|
|
class TestProviderFailoverChainManagement:
    """Exercise adding, removing and reprioritising providers in the chain."""

    @staticmethod
    def _chain(*names):
        """Build a failover over *names* (in order) with monitoring disabled."""
        return ProviderFailover(
            providers=list(names),
            enable_health_monitoring=False,
        )

    def test_add_provider_to_chain(self):
        """A newly added provider shows up in the chain."""
        chain = self._chain("p1", "p2")
        chain.add_provider("p3")
        assert "p3" in chain.get_providers()

    def test_add_duplicate_provider_no_effect(self):
        """Re-adding a known provider must not duplicate it."""
        chain = self._chain("p1", "p2")
        chain.add_provider("p1")
        occurrences = chain.get_providers().count("p1")
        assert occurrences == 1

    def test_remove_provider_from_chain(self):
        """Removing a known provider reports success and drops it."""
        chain = self._chain("p1", "p2", "p3")
        removed = chain.remove_provider("p2")
        assert removed is True
        assert "p2" not in chain.get_providers()

    def test_remove_nonexistent_provider(self):
        """Removing an unknown provider reports failure."""
        chain = self._chain("p1")
        removed = chain.remove_provider("p99")
        assert removed is False

    def test_set_provider_priority(self):
        """Moving a provider to index 0 makes it first in the chain."""
        chain = self._chain("p1", "p2", "p3")
        moved = chain.set_provider_priority("p3", 0)
        assert moved is True
        first, *_ = chain.get_providers()
        assert first == "p3"

    def test_set_priority_unknown_provider(self):
        """Reprioritising an unknown provider reports failure."""
        chain = self._chain("p1")
        moved = chain.set_provider_priority("unknown", 0)
        assert moved is False
|
|
|
|
|
|
class TestFailoverStats:
    """Check the contents of the failover statistics report."""

    def test_get_failover_stats(self):
        """Stats should echo the configuration the failover was built with."""
        failover = ProviderFailover(
            providers=["p1", "p2"],
            max_retries=3,
            retry_delay=1.0,
            enable_health_monitoring=False,
        )

        report = failover.get_failover_stats()
        expected = {
            "total_providers": 2,
            "max_retries": 3,
            "retry_delay": 1.0,
        }
        for key, value in expected.items():
            assert report[key] == value
        assert len(report["providers"]) == 2

    def test_stats_with_health_monitoring(self):
        """With monitoring on, stats should expose availability buckets."""
        monitor = ProviderHealthMonitor()
        monitor.record_request("p1", True, 100)
        # Three identical failures recorded for p2.
        for _ in range(3):
            monitor.record_request("p2", False, 200, error_message="fail")

        with patch(
            "src.core.providers.failover.get_health_monitor",
            return_value=monitor,
        ):
            failover = ProviderFailover(
                providers=["p1", "p2"],
                enable_health_monitoring=True,
            )
            report = failover.get_failover_stats()

        for key in ("available_providers", "unavailable_providers"):
            assert key in report
|
|
|
|
|
|
class TestConfigureFailover:
    """Test the global failover configuration function."""

    def test_configure_failover(self):
        """configure_failover should create a new global instance."""
        import src.core.providers.failover as fo

        # Snapshot the module-level singleton and restore it in ``finally`` so
        # a failing assertion cannot leak global state into later tests.
        previous = fo._failover
        fo._failover = None
        try:
            failover = configure_failover(
                providers=["custom1", "custom2"],
                max_retries=5,
                retry_delay=0.5,
            )

            assert isinstance(failover, ProviderFailover)
            assert failover.get_providers() == ["custom1", "custom2"]
            assert failover._max_retries == 5
        finally:
            fo._failover = previous

    def test_get_failover_singleton(self):
        """get_failover should return same instance."""
        import src.core.providers.failover as fo

        # Same isolation as above: start from a fresh singleton and always
        # restore the prior value, even when the identity check fails.
        previous = fo._failover
        fo._failover = None
        try:
            first = get_failover()
            second = get_failover()
            assert first is second
        finally:
            fo._failover = previous
|
|
|
|
|
|
class TestNextProviderRotation:
    """Check how the failover rotates through its configured providers."""

    def test_get_next_cycles_through_all(self):
        """Three advances over three providers should visit several of them."""
        rotation = ProviderFailover(
            providers=["p1", "p2", "p3"],
            enable_health_monitoring=False,
        )

        picks = [rotation.get_next_provider() for _ in range(3)]
        # Ignore any falsy result (e.g. no provider returned).
        distinct = {name for name in picks if name}

        assert len(distinct) >= 2  # Should see at least 2 different providers

    def test_current_provider_is_from_list(self):
        """The current provider must always be one of the configured names."""
        rotation = ProviderFailover(
            providers=["p1", "p2", "p3"],
            enable_health_monitoring=False,
        )
        allowed = ("p1", "p2", "p3")

        for _ in range(10):
            assert rotation.get_current_provider() in allowed
            rotation.get_next_provider()
|