"""Integration tests for provider failover scenarios - End-to-end provider switching.""" import asyncio from unittest.mock import AsyncMock, MagicMock, patch import pytest from src.core.providers.failover import ( ProviderFailover, configure_failover, get_failover, ) from src.core.providers.health_monitor import ProviderHealthMonitor class TestProviderFailoverScenarios: """Test end-to-end failover scenarios with multiple providers.""" @pytest.mark.asyncio async def test_primary_fails_switches_to_backup(self): """When primary provider fails, should switch to backup and succeed.""" call_log = [] async def operation(provider: str) -> str: call_log.append(provider) if provider == "provider1": raise ConnectionError("Provider1 is down") return f"Success from {provider}" failover = ProviderFailover( providers=["provider1", "provider2", "provider3"], max_retries=1, retry_delay=0.01, enable_health_monitoring=False, ) result = await failover.execute_with_failover( operation=operation, operation_name="test_failover", ) assert "Success" in result assert "provider1" in call_log assert len(call_log) >= 2 @pytest.mark.asyncio async def test_first_two_fail_third_succeeds(self): """When first two providers fail, third should be tried.""" attempts = {} async def operation(provider: str) -> str: attempts[provider] = attempts.get(provider, 0) + 1 if provider in ("provider1", "provider2"): raise ConnectionError(f"{provider} is down") return f"Success from {provider}" failover = ProviderFailover( providers=["provider1", "provider2", "provider3"], max_retries=1, retry_delay=0.01, enable_health_monitoring=False, ) result = await failover.execute_with_failover( operation=operation, operation_name="test_two_fail", ) assert "provider3" in result @pytest.mark.asyncio async def test_all_providers_fail_raises(self): """When all providers fail, should raise exception.""" async def operation(provider: str) -> str: raise ConnectionError(f"{provider} is down") failover = ProviderFailover( providers=["provider1", "provider2"], max_retries=1, retry_delay=0.01, enable_health_monitoring=False, ) with pytest.raises(Exception, match="failed with all providers"): await failover.execute_with_failover( operation=operation, operation_name="test_all_fail", ) @pytest.mark.asyncio async def test_retry_within_single_provider(self): """Should retry with same provider before moving to next.""" call_count = 0 async def operation(provider: str) -> str: nonlocal call_count call_count += 1 if call_count <= 2: raise ConnectionError("Temporary failure") return f"Success on attempt {call_count}" failover = ProviderFailover( providers=["provider1"], max_retries=3, retry_delay=0.01, enable_health_monitoring=False, ) result = await failover.execute_with_failover( operation=operation, operation_name="test_retry", ) assert "Success" in result assert call_count == 3 @pytest.mark.asyncio async def test_failover_with_health_monitoring(self): """Failover should integrate with health monitoring.""" monitor = ProviderHealthMonitor(failure_threshold=2) # Pre-record failures for provider1 to make it unavailable for _ in range(3): monitor.record_request( provider_name="provider1", success=False, response_time_ms=100, error_message="Simulated failure", ) # Provider1 should now be unavailable assert "provider1" not in monitor.get_available_providers() with patch( "src.core.providers.failover.get_health_monitor", return_value=monitor, ): failover = ProviderFailover( providers=["provider1", "provider2"], max_retries=1, retry_delay=0.01, enable_health_monitoring=True, ) # Current provider should prefer provider2 current = failover.get_current_provider() # Best provider selection should favor available ones available = monitor.get_available_providers() if available: assert current in available or current == "provider2" class TestProviderFailoverChainManagement: """Test failover chain add/remove/priority operations.""" def test_add_provider_to_chain(self): """Should add new provider to the failover chain.""" failover = ProviderFailover( providers=["p1", "p2"], enable_health_monitoring=False, ) failover.add_provider("p3") assert "p3" in failover.get_providers() def test_add_duplicate_provider_no_effect(self): """Adding existing provider should not create duplicate.""" failover = ProviderFailover( providers=["p1", "p2"], enable_health_monitoring=False, ) failover.add_provider("p1") assert failover.get_providers().count("p1") == 1 def test_remove_provider_from_chain(self): """Should remove provider from the failover chain.""" failover = ProviderFailover( providers=["p1", "p2", "p3"], enable_health_monitoring=False, ) result = failover.remove_provider("p2") assert result is True assert "p2" not in failover.get_providers() def test_remove_nonexistent_provider(self): """Removing non-existent provider should return False.""" failover = ProviderFailover( providers=["p1"], enable_health_monitoring=False, ) assert failover.remove_provider("p99") is False def test_set_provider_priority(self): """Should reorder provider in the chain.""" failover = ProviderFailover( providers=["p1", "p2", "p3"], enable_health_monitoring=False, ) result = failover.set_provider_priority("p3", 0) assert result is True providers = failover.get_providers() assert providers[0] == "p3" def test_set_priority_unknown_provider(self): """Setting priority for unknown provider should return False.""" failover = ProviderFailover( providers=["p1"], enable_health_monitoring=False, ) assert failover.set_provider_priority("unknown", 0) is False class TestFailoverStats: """Test failover statistics reporting.""" def test_get_failover_stats(self): """Should return comprehensive stats.""" failover = ProviderFailover( providers=["p1", "p2"], max_retries=3, retry_delay=1.0, enable_health_monitoring=False, ) stats = failover.get_failover_stats() assert stats["total_providers"] == 2 assert stats["max_retries"] == 3 assert stats["retry_delay"] == 1.0 assert len(stats["providers"]) == 2 def test_stats_with_health_monitoring(self): """Stats should include availability info when monitoring enabled.""" monitor = ProviderHealthMonitor() monitor.record_request("p1", True, 100) monitor.record_request("p2", False, 200, error_message="fail") monitor.record_request("p2", False, 200, error_message="fail") monitor.record_request("p2", False, 200, error_message="fail") with patch( "src.core.providers.failover.get_health_monitor", return_value=monitor, ): failover = ProviderFailover( providers=["p1", "p2"], enable_health_monitoring=True, ) stats = failover.get_failover_stats() assert "available_providers" in stats assert "unavailable_providers" in stats class TestConfigureFailover: """Test the global failover configuration function.""" def test_configure_failover(self): """configure_failover should create a new global instance.""" import src.core.providers.failover as fo fo._failover = None failover = configure_failover( providers=["custom1", "custom2"], max_retries=5, retry_delay=0.5, ) assert isinstance(failover, ProviderFailover) assert failover.get_providers() == ["custom1", "custom2"] assert failover._max_retries == 5 # Cleanup fo._failover = None def test_get_failover_singleton(self): """get_failover should return same instance.""" import src.core.providers.failover as fo fo._failover = None first = get_failover() second = get_failover() assert first is second fo._failover = None class TestNextProviderRotation: """Test provider rotation logic.""" def test_get_next_cycles_through_all(self): """get_next_provider should cycle through all providers.""" failover = ProviderFailover( providers=["p1", "p2", "p3"], enable_health_monitoring=False, ) seen = set() for _ in range(3): provider = failover.get_next_provider() if provider: seen.add(provider) assert len(seen) >= 2 # Should see at least 2 different providers def test_current_provider_is_from_list(self): """get_current_provider should always return from provider list.""" failover = ProviderFailover( providers=["p1", "p2", "p3"], enable_health_monitoring=False, ) for _ in range(10): current = failover.get_current_provider() assert current in ["p1", "p2", "p3"] failover.get_next_provider()