"""Unit tests for provider failover system."""

import pytest

from src.core.providers.failover import (
    ProviderFailover,
    configure_failover,
    get_failover,
)

# Small retry delay so retry/failover paths are exercised without slowing
# the suite; the exact value is irrelevant to what these tests assert.
_FAST_RETRY_DELAY = 0.01


class TestProviderFailover:
    """Test ProviderFailover class."""

    def test_failover_initialization(self):
        """Test failover initialization."""
        providers = ["provider1", "provider2", "provider3"]
        failover = ProviderFailover(
            providers=providers,
            max_retries=5,
            retry_delay=2.0,
            # Monitoring is not under test here; keep the unit test isolated.
            enable_health_monitoring=False,
        )

        assert failover._providers == providers
        assert failover._max_retries == 5
        assert failover._retry_delay == 2.0

    def test_get_current_provider(self):
        """Test getting current provider."""
        providers = ["provider1", "provider2"]
        failover = ProviderFailover(
            providers=providers,
            enable_health_monitoring=False,
        )

        current = failover.get_current_provider()
        assert current in providers

    def test_get_next_provider(self):
        """Test getting next provider."""
        providers = ["provider1", "provider2", "provider3"]
        failover = ProviderFailover(
            providers=providers,
            enable_health_monitoring=False,
        )

        first = failover.get_current_provider()
        next_provider = failover.get_next_provider()

        # With more than one provider, advancing must move off the current one.
        assert next_provider in providers
        assert next_provider != first

    @pytest.mark.asyncio
    async def test_execute_with_failover_success(self):
        """Test successful execution with failover."""

        async def mock_operation(provider: str) -> str:
            return f"Success with {provider}"

        failover = ProviderFailover(
            providers=["provider1"],
            enable_health_monitoring=False,
        )

        result = await failover.execute_with_failover(
            operation=mock_operation,
            operation_name="test_op",
        )

        assert "Success" in result

    @pytest.mark.asyncio
    async def test_execute_with_failover_retry(self):
        """Test failover with retry on first failure."""
        call_count = 0

        async def mock_operation(provider: str) -> str:
            nonlocal call_count
            call_count += 1
            # Fail only the first attempt so the retry path is taken once.
            if call_count == 1:
                raise Exception("First attempt failed")
            return f"Success with {provider}"

        failover = ProviderFailover(
            providers=["provider1"],
            max_retries=2,
            retry_delay=_FAST_RETRY_DELAY,
            enable_health_monitoring=False,
        )

        result = await failover.execute_with_failover(
            operation=mock_operation,
            operation_name="test_op",
        )

        assert "Success" in result
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_execute_with_failover_all_fail(self):
        """Test failover when all providers fail."""

        async def mock_operation(provider: str) -> str:
            raise Exception(f"Failed with {provider}")

        failover = ProviderFailover(
            providers=["provider1", "provider2"],
            max_retries=1,
            retry_delay=_FAST_RETRY_DELAY,
            enable_health_monitoring=False,
        )

        # `match` performs re.search on str(exc); the pattern is a plain
        # literal, so this is the same substring check done idiomatically.
        with pytest.raises(Exception, match="failed with all providers"):
            await failover.execute_with_failover(
                operation=mock_operation,
                operation_name="test_op",
            )

    def test_add_provider(self):
        """Test adding provider to failover chain."""
        failover = ProviderFailover(
            providers=["provider1"],
            enable_health_monitoring=False,
        )

        failover.add_provider("provider2")

        assert "provider2" in failover.get_providers()
        assert len(failover.get_providers()) == 2

    def test_remove_provider(self):
        """Test removing provider from failover chain."""
        failover = ProviderFailover(
            providers=["provider1", "provider2"],
            enable_health_monitoring=False,
        )

        success = failover.remove_provider("provider1")

        assert success is True
        assert "provider1" not in failover.get_providers()
        assert len(failover.get_providers()) == 1

    def test_remove_nonexistent_provider(self):
        """Test removing provider that doesn't exist."""
        failover = ProviderFailover(
            providers=["provider1"],
            enable_health_monitoring=False,
        )

        success = failover.remove_provider("nonexistent")

        assert success is False

    def test_set_provider_priority(self):
        """Test setting provider priority."""
        failover = ProviderFailover(
            providers=["provider1", "provider2", "provider3"],
            enable_health_monitoring=False,
        )

        success = failover.set_provider_priority("provider3", 0)

        assert success is True
        providers = failover.get_providers()
        # Priority 0 is expected to move the provider to the front.
        assert providers[0] == "provider3"

    def test_set_priority_nonexistent_provider(self):
        """Test setting priority for nonexistent provider."""
        failover = ProviderFailover(
            providers=["provider1"],
            enable_health_monitoring=False,
        )

        success = failover.set_provider_priority("nonexistent", 0)

        assert success is False

    def test_get_failover_stats(self):
        """Test getting failover statistics."""
        providers = ["provider1", "provider2"]
        failover = ProviderFailover(
            providers=providers,
            max_retries=3,
            retry_delay=1.5,
            enable_health_monitoring=False,
        )

        stats = failover.get_failover_stats()

        assert stats["total_providers"] == 2
        assert stats["providers"] == providers
        assert stats["max_retries"] == 3
        assert stats["retry_delay"] == 1.5
        assert stats["health_monitoring_enabled"] is False


class TestFailoverSingleton:
    """Test global failover singleton."""

    # NOTE(review): test_configure_failover replaces the module-level
    # failover instance and nothing restores it afterwards, so later tests
    # that call get_failover() may see providers ["custom1", "custom2"].
    # Consider a fixture that saves/restores the global — confirm the
    # attribute name in src.core.providers.failover before adding one.

    def test_get_failover_singleton(self):
        """Test that get_failover returns singleton."""
        failover1 = get_failover()
        failover2 = get_failover()

        assert failover1 is failover2

    def test_configure_failover(self):
        """Test configuring global failover instance."""
        providers = ["custom1", "custom2"]
        failover = configure_failover(
            providers=providers,
            max_retries=10,
            retry_delay=3.0,
        )

        assert failover._providers == providers
        assert failover._max_retries == 10
        assert failover._retry_delay == 3.0