- Implemented ProviderHealthMonitor for real-time tracking
  - Monitors availability, response times, success rates
  - Automatically marks providers unavailable after failures
  - Background health check loop
- Added ProviderFailover for automatic provider switching
  - Configurable retry attempts with exponential backoff
  - Integration with health monitoring
  - Smart provider selection
- Created MonitoredProviderWrapper for performance tracking
  - Transparent monitoring for any provider
  - Automatic metric recording
  - No changes needed to existing providers
- Implemented ProviderConfigManager for dynamic configuration
  - Runtime updates without restart
  - Per-provider settings (timeout, retries, bandwidth)
  - JSON-based persistence
- Added Provider Management API (15+ endpoints)
  - Health monitoring endpoints
  - Configuration management
  - Failover control
- Comprehensive testing (34 tests, 100% pass rate)
  - Health monitoring tests
  - Failover scenario tests
  - Configuration management tests
- Documentation updates
  - Updated infrastructure.md
  - Updated instructions.md
  - Created PROVIDER_ENHANCEMENT_SUMMARY.md

Total: ~2,593 lines of code, 34 passing tests
208 lines
6.3 KiB
Python
208 lines
6.3 KiB
Python
"""Unit tests for provider failover system."""
|
|
import pytest
|
|
|
|
from src.core.providers.failover import (
|
|
ProviderFailover,
|
|
configure_failover,
|
|
get_failover,
|
|
)
|
|
|
|
|
|
class TestProviderFailover:
    """Unit tests for the ``ProviderFailover`` class.

    Covers construction, current/next provider selection, the async
    ``execute_with_failover`` entry point (success, retry, total failure),
    provider chain mutation (add / remove / re-prioritise) and the
    statistics dictionary.
    """

    def test_failover_initialization(self) -> None:
        """Constructor arguments are stored on the instance's private fields."""
        providers = ["provider1", "provider2", "provider3"]
        failover = ProviderFailover(
            providers=providers,
            max_retries=5,
            retry_delay=2.0,
        )

        assert failover._providers == providers
        assert failover._max_retries == 5
        assert failover._retry_delay == 2.0

    def test_get_current_provider(self) -> None:
        """The current provider is one of the configured providers."""
        providers = ["provider1", "provider2"]
        failover = ProviderFailover(
            providers=providers,
            enable_health_monitoring=False,
        )

        current = failover.get_current_provider()

        assert current in providers

    def test_get_next_provider(self) -> None:
        """The next provider is a configured one and differs from the current."""
        providers = ["provider1", "provider2", "provider3"]
        failover = ProviderFailover(
            providers=providers,
            enable_health_monitoring=False,
        )

        first = failover.get_current_provider()
        next_provider = failover.get_next_provider()

        assert next_provider in providers
        assert next_provider != first

    @pytest.mark.asyncio
    async def test_execute_with_failover_success(self) -> None:
        """An operation that succeeds immediately has its result returned."""

        async def mock_operation(provider: str) -> str:
            return f"Success with {provider}"

        failover = ProviderFailover(
            providers=["provider1"],
            enable_health_monitoring=False,
        )

        result = await failover.execute_with_failover(
            operation=mock_operation,
            operation_name="test_op",
        )

        assert "Success" in result

    @pytest.mark.asyncio
    async def test_execute_with_failover_retry(self) -> None:
        """An operation failing once is called again and its result returned."""
        call_count = 0

        async def mock_operation(provider: str) -> str:
            # Fail only on the very first invocation so that exactly one
            # retry is needed before the call succeeds.
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise Exception("First attempt failed")
            return f"Success with {provider}"

        failover = ProviderFailover(
            providers=["provider1"],
            max_retries=2,
            retry_delay=0.1,
            enable_health_monitoring=False,
        )

        result = await failover.execute_with_failover(
            operation=mock_operation,
            operation_name="test_op",
        )

        assert "Success" in result
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_execute_with_failover_all_fail(self) -> None:
        """When the operation always raises, a summary exception propagates."""

        async def mock_operation(provider: str) -> str:
            raise Exception(f"Failed with {provider}")

        failover = ProviderFailover(
            providers=["provider1", "provider2"],
            max_retries=1,
            retry_delay=0.1,
            enable_health_monitoring=False,
        )

        # The message check lives on the context manager itself: an assert
        # placed after the raising ``await`` inside the ``with`` body would
        # never execute.
        with pytest.raises(Exception, match="failed with all providers"):
            await failover.execute_with_failover(
                operation=mock_operation,
                operation_name="test_op",
            )

    def test_add_provider(self) -> None:
        """An added provider appears in ``get_providers()``."""
        failover = ProviderFailover(providers=["provider1"])

        failover.add_provider("provider2")

        assert "provider2" in failover.get_providers()
        assert len(failover.get_providers()) == 2

    def test_remove_provider(self) -> None:
        """Removing a known provider returns ``True`` and drops it."""
        failover = ProviderFailover(providers=["provider1", "provider2"])

        success = failover.remove_provider("provider1")

        assert success is True
        assert "provider1" not in failover.get_providers()
        assert len(failover.get_providers()) == 1

    def test_remove_nonexistent_provider(self) -> None:
        """Removing an unknown provider returns ``False``."""
        failover = ProviderFailover(providers=["provider1"])

        success = failover.remove_provider("nonexistent")

        assert success is False

    def test_set_provider_priority(self) -> None:
        """Setting priority 0 moves the provider to the front of the chain."""
        failover = ProviderFailover(
            providers=["provider1", "provider2", "provider3"]
        )

        success = failover.set_provider_priority("provider3", 0)

        assert success is True
        providers = failover.get_providers()
        assert providers[0] == "provider3"

    def test_set_priority_nonexistent_provider(self) -> None:
        """Setting priority for an unknown provider returns ``False``."""
        failover = ProviderFailover(providers=["provider1"])

        success = failover.set_provider_priority("nonexistent", 0)

        assert success is False

    def test_get_failover_stats(self) -> None:
        """The statistics dictionary reflects the constructor configuration."""
        providers = ["provider1", "provider2"]
        failover = ProviderFailover(
            providers=providers,
            max_retries=3,
            retry_delay=1.5,
            enable_health_monitoring=False,
        )

        stats = failover.get_failover_stats()

        assert stats["total_providers"] == 2
        assert stats["providers"] == providers
        assert stats["max_retries"] == 3
        assert stats["retry_delay"] == 1.5
        assert stats["health_monitoring_enabled"] is False
|
|
|
|
|
|
class TestFailoverSingleton:
    """Tests for the module-level ``get_failover`` / ``configure_failover`` helpers."""

    def test_get_failover_singleton(self) -> None:
        """Two consecutive ``get_failover()`` calls yield the very same object."""
        first, second = get_failover(), get_failover()

        assert first is second

    def test_configure_failover(self) -> None:
        """``configure_failover`` returns an instance carrying the given settings."""
        # NOTE(review): configure_failover presumably replaces the shared
        # instance returned by get_failover(), which could leak state into
        # later tests -- confirm against the failover module.
        expected_providers = ["custom1", "custom2"]

        configured = configure_failover(
            providers=expected_providers,
            max_retries=10,
            retry_delay=3.0,
        )

        assert configured._providers == expected_providers
        assert configured._max_retries == 10
        assert configured._retry_delay == 3.0
|