"""Provider failover system for automatic fallback on failures.

This module implements automatic failover between multiple providers,
ensuring high availability by switching to backup providers when the
primary fails.
"""

import asyncio
import logging
import time
from typing import Any, Awaitable, Callable, Dict, List, Optional, TypeVar

from src.core.providers.health_monitor import get_health_monitor
from src.core.providers.provider_config import DEFAULT_PROVIDERS

logger = logging.getLogger(__name__)

T = TypeVar("T")


class ProviderFailoverError(Exception):
    """Raised when failover cannot produce a result.

    Raised when every provider that was tried has failed, or when the
    failover chain is empty. It subclasses ``Exception``, so existing
    ``except Exception`` handlers keep catching it.
    """


class ProviderFailover:
    """Manages automatic failover between multiple providers."""

    def __init__(
        self,
        providers: Optional[List[str]] = None,
        max_retries: int = 3,
        retry_delay: float = 1.0,
        enable_health_monitoring: bool = True,
    ):
        """Initialize provider failover manager.

        Args:
            providers: List of provider names to use (default: all).
                The list is copied, so later changes to the failover chain
                never mutate the caller's list or DEFAULT_PROVIDERS.
            max_retries: Maximum attempts per provider (values below 1
                are treated as a single attempt).
            retry_delay: Delay between retries in seconds.
            enable_health_monitoring: Whether to use health monitoring.
        """
        self._providers: List[str] = (
            list(providers) if providers else list(DEFAULT_PROVIDERS)
        )
        self._max_retries = max_retries
        self._retry_delay = retry_delay
        self._enable_health_monitoring = enable_health_monitoring

        # Current provider index (used for round-robin selection)
        self._current_index = 0

        # Health monitor (None when monitoring is disabled)
        self._health_monitor = (
            get_health_monitor() if enable_health_monitoring else None
        )

        logger.info(
            "Provider failover initialized with %d providers",
            len(self._providers),
        )

    def get_current_provider(self) -> str:
        """Get the current active provider.

        With health monitoring, the monitor's best provider is returned if
        it is part of this failover chain. Otherwise the provider at the
        round-robin index is returned.

        Returns:
            Name of current provider.

        Raises:
            ProviderFailoverError: If the failover chain is empty.
        """
        if not self._providers:
            raise ProviderFailoverError("No providers configured for failover")

        if self._enable_health_monitoring and self._health_monitor:
            # Try to get best available provider
            best = self._health_monitor.get_best_provider()
            if best and best in self._providers:
                return best

        # Fall back to round-robin selection
        return self._providers[self._current_index % len(self._providers)]

    def get_next_provider(self) -> Optional[str]:
        """Get the next provider in the failover chain.

        Without health monitoring, this advances the round-robin index and
        returns the provider at the new position. With health monitoring,
        it returns the available provider that follows the current one in
        chain order. If the current one is unavailable, it returns the
        first available provider. The round-robin index is moved to the
        returned provider as well.

        Returns:
            Name of next provider, or None if none is available.
        """
        if not self._providers:
            logger.warning("No providers configured for failover")
            return None

        if self._enable_health_monitoring and self._health_monitor:
            # Query the monitor once instead of once per provider.
            available_names = set(
                self._health_monitor.get_available_providers()
            )
            available = [p for p in self._providers if p in available_names]

            if not available:
                logger.warning("No available providers for failover")
                return None

            current = self.get_current_provider()
            if current in available:
                next_idx = (available.index(current) + 1) % len(available)
                chosen = available[next_idx]
            else:
                # Current provider not in available list
                chosen = available[0]

            # Keep the round-robin fallback in step with the choice made.
            self._current_index = self._providers.index(chosen)
            return chosen

        # Fall back to simple rotation
        self._current_index = (self._current_index + 1) % len(
            self._providers
        )
        return self._providers[self._current_index]

    def _next_untried_provider(self, tried: List[str]) -> Optional[str]:
        """Pick the next provider to attempt, skipping names in ``tried``.

        The current provider is preferred. After that:

        * With health monitoring, the first untried provider, in chain
          order, that the monitor reports as available is chosen.
        * Without monitoring, the round-robin rotation is advanced until
          an untried provider is found.

        Returns:
            A provider name, or None when no eligible untried provider
            remains.
        """
        current = self.get_current_provider()
        if current not in tried:
            return current

        if self._enable_health_monitoring and self._health_monitor:
            available = set(self._health_monitor.get_available_providers())
            for name in self._providers:
                if name not in tried and name in available:
                    return name
            return None

        # At most one full rotation; get_next_provider advances the index so
        # a later successful provider stays "current" for future calls.
        for _ in range(len(self._providers)):
            candidate = self.get_next_provider()
            if candidate is not None and candidate not in tried:
                return candidate
        return None

    def _record_request(
        self,
        provider: str,
        success: bool,
        elapsed_ms: float,
        error_message: Optional[str] = None,
    ) -> None:
        """Report a request outcome to the health monitor, if configured.

        Best-effort: a failure inside the monitor is logged and swallowed.
        This keeps bookkeeping from turning a successful call into a retry,
        or from masking the provider's real error.
        """
        if not self._health_monitor:
            return
        # Only pass error_message on failures, matching the monitor calls
        # this module has always made.
        extra: Dict[str, Any] = (
            {} if error_message is None else {"error_message": error_message}
        )
        try:
            self._health_monitor.record_request(
                provider_name=provider,
                success=success,
                response_time_ms=elapsed_ms,
                **extra,
            )
        except Exception:
            logger.warning(
                "Failed to record health data for provider %s",
                provider,
                exc_info=True,
            )

    async def execute_with_failover(
        self,
        operation: Callable[..., Awaitable[Any]],
        operation_name: str = "operation",
        **kwargs: Any,
    ) -> Any:
        """Execute an operation with automatic failover.

        Each provider is tried at most once, in the order chosen by
        ``_next_untried_provider``. Each provider gets up to
        ``max_retries`` attempts, with ``retry_delay`` seconds between
        attempts.

        Args:
            operation: Async callable taking the provider name as its first
                positional argument, plus ``**kwargs``.
            operation_name: Name for logging purposes.
            **kwargs: Additional arguments to pass to operation.

        Returns:
            Result from the first successful operation.

        Raises:
            ProviderFailoverError: If all providers fail (chained to the
                last error raised by ``operation``).
        """
        providers_tried: List[str] = []
        last_error: Optional[BaseException] = None
        attempts = max(1, self._max_retries)

        # The bound is re-evaluated each pass, in case the chain changes while
        # we await; _next_untried_provider returning None also ends the loop.
        while len(providers_tried) < len(self._providers):
            provider = self._next_untried_provider(providers_tried)
            if provider is None:
                break
            providers_tried.append(provider)

            for retry in range(attempts):
                logger.info(
                    "Executing %s with provider %s (attempt %d/%d)",
                    operation_name,
                    provider,
                    retry + 1,
                    attempts,
                )
                # Monotonic clock: immune to wall-clock adjustments.
                start_time = time.perf_counter()
                try:
                    result = await operation(provider, **kwargs)
                except Exception as e:
                    elapsed_ms = (time.perf_counter() - start_time) * 1000
                    last_error = e
                    logger.warning(
                        "%s failed with provider %s (attempt %d): %s",
                        operation_name,
                        provider,
                        retry + 1,
                        e,
                    )
                    self._record_request(
                        provider, False, elapsed_ms, error_message=str(e)
                    )
                    # Retry with delay (no sleep after the final attempt)
                    if retry < attempts - 1:
                        await asyncio.sleep(self._retry_delay)
                    continue

                # Success bookkeeping lives outside the try so that a monitor
                # error can never cause the operation to be re-executed.
                elapsed_ms = (time.perf_counter() - start_time) * 1000
                self._record_request(provider, True, elapsed_ms)
                logger.info(
                    "%s succeeded with provider %s in %.2fms",
                    operation_name,
                    provider,
                    elapsed_ms,
                )
                return result

        # All providers failed
        error_msg = (
            f"{operation_name} failed with all providers. "
            f"Tried: {', '.join(providers_tried)}"
        )
        logger.error(error_msg)
        raise ProviderFailoverError(error_msg) from last_error

    def add_provider(self, provider_name: str) -> None:
        """Add a provider to the end of the failover chain.

        Does nothing if the provider is already in the chain.

        Args:
            provider_name: Name of provider to add.
        """
        if provider_name not in self._providers:
            self._providers.append(provider_name)
            logger.info("Added provider to failover chain: %s", provider_name)

    def remove_provider(self, provider_name: str) -> bool:
        """Remove a provider from the failover chain.

        Args:
            provider_name: Name of provider to remove.

        Returns:
            True if removed, False if not found.
        """
        if provider_name in self._providers:
            self._providers.remove(provider_name)
            logger.info(
                "Removed provider from failover chain: %s", provider_name
            )
            return True
        return False

    def get_providers(self) -> List[str]:
        """Get list of all providers in failover chain.

        Returns:
            A copy of the provider names, in chain order.
        """
        return self._providers.copy()

    def set_provider_priority(
        self, provider_name: str, priority_index: int
    ) -> bool:
        """Set priority of a provider by moving it in the chain.

        Args:
            provider_name: Name of provider to prioritize.
            priority_index: New index position (0 = highest priority).
                Indices past the end place the provider last. Negative
                values follow ``list.insert`` semantics.

        Returns:
            True if updated, False if provider not found.
        """
        if provider_name not in self._providers:
            return False

        self._providers.remove(provider_name)
        self._providers.insert(
            min(priority_index, len(self._providers)), provider_name
        )
        logger.info(
            "Set provider %s priority to index %d",
            provider_name,
            self._providers.index(provider_name),
        )
        return True

    def get_failover_stats(self) -> Dict[str, Any]:
        """Get failover statistics and configuration.

        Returns:
            Dictionary with failover stats. ``current_provider`` is None
            when the chain is empty. When a health monitor is configured,
            it also includes ``available_providers`` and
            ``unavailable_providers``.
        """
        stats: Dict[str, Any] = {
            "total_providers": len(self._providers),
            "providers": self._providers.copy(),
            "current_provider": (
                self.get_current_provider() if self._providers else None
            ),
            "max_retries": self._max_retries,
            "retry_delay": self._retry_delay,
            "health_monitoring_enabled": self._enable_health_monitoring,
        }

        if self._health_monitor:
            available = set(self._health_monitor.get_available_providers())
            stats["available_providers"] = [
                p for p in self._providers if p in available
            ]
            stats["unavailable_providers"] = [
                p for p in self._providers if p not in available
            ]

        return stats


# Global failover instance
_failover: Optional[ProviderFailover] = None


def get_failover() -> ProviderFailover:
    """Get or create global provider failover instance.

    Returns:
        Global ProviderFailover instance (created with defaults on first
        use).
    """
    global _failover
    if _failover is None:
        _failover = ProviderFailover()
    return _failover


def configure_failover(
    providers: Optional[List[str]] = None,
    max_retries: int = 3,
    retry_delay: float = 1.0,
    enable_health_monitoring: bool = True,
) -> ProviderFailover:
    """Configure global provider failover instance.

    Replaces any previously created global instance.

    Args:
        providers: List of provider names to use.
        max_retries: Maximum attempts per provider.
        retry_delay: Delay between retries in seconds.
        enable_health_monitoring: Whether to use health monitoring.

    Returns:
        Configured ProviderFailover instance.
    """
    global _failover
    _failover = ProviderFailover(
        providers=providers,
        max_retries=max_retries,
        retry_delay=retry_delay,
        enable_health_monitoring=enable_health_monitoring,
    )
    return _failover