Add provider system tests: 211 tests covering base, factory, config, monitoring, failover, and selection
This commit is contained in:
312
tests/integration/test_provider_failover_scenarios.py
Normal file
312
tests/integration/test_provider_failover_scenarios.py
Normal file
@@ -0,0 +1,312 @@
|
||||
"""Integration tests for provider failover scenarios - End-to-end provider switching."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.providers.failover import (
|
||||
ProviderFailover,
|
||||
configure_failover,
|
||||
get_failover,
|
||||
)
|
||||
from src.core.providers.health_monitor import ProviderHealthMonitor
|
||||
|
||||
|
||||
class TestProviderFailoverScenarios:
|
||||
"""Test end-to-end failover scenarios with multiple providers."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_primary_fails_switches_to_backup(self):
|
||||
"""When primary provider fails, should switch to backup and succeed."""
|
||||
call_log = []
|
||||
|
||||
async def operation(provider: str) -> str:
|
||||
call_log.append(provider)
|
||||
if provider == "provider1":
|
||||
raise ConnectionError("Provider1 is down")
|
||||
return f"Success from {provider}"
|
||||
|
||||
failover = ProviderFailover(
|
||||
providers=["provider1", "provider2", "provider3"],
|
||||
max_retries=1,
|
||||
retry_delay=0.01,
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
|
||||
result = await failover.execute_with_failover(
|
||||
operation=operation,
|
||||
operation_name="test_failover",
|
||||
)
|
||||
|
||||
assert "Success" in result
|
||||
assert "provider1" in call_log
|
||||
assert len(call_log) >= 2
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_first_two_fail_third_succeeds(self):
|
||||
"""When first two providers fail, third should be tried."""
|
||||
attempts = {}
|
||||
|
||||
async def operation(provider: str) -> str:
|
||||
attempts[provider] = attempts.get(provider, 0) + 1
|
||||
if provider in ("provider1", "provider2"):
|
||||
raise ConnectionError(f"{provider} is down")
|
||||
return f"Success from {provider}"
|
||||
|
||||
failover = ProviderFailover(
|
||||
providers=["provider1", "provider2", "provider3"],
|
||||
max_retries=1,
|
||||
retry_delay=0.01,
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
|
||||
result = await failover.execute_with_failover(
|
||||
operation=operation,
|
||||
operation_name="test_two_fail",
|
||||
)
|
||||
|
||||
assert "provider3" in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_providers_fail_raises(self):
|
||||
"""When all providers fail, should raise exception."""
|
||||
async def operation(provider: str) -> str:
|
||||
raise ConnectionError(f"{provider} is down")
|
||||
|
||||
failover = ProviderFailover(
|
||||
providers=["provider1", "provider2"],
|
||||
max_retries=1,
|
||||
retry_delay=0.01,
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
|
||||
with pytest.raises(Exception, match="failed with all providers"):
|
||||
await failover.execute_with_failover(
|
||||
operation=operation,
|
||||
operation_name="test_all_fail",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_retry_within_single_provider(self):
|
||||
"""Should retry with same provider before moving to next."""
|
||||
call_count = 0
|
||||
|
||||
async def operation(provider: str) -> str:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count <= 2:
|
||||
raise ConnectionError("Temporary failure")
|
||||
return f"Success on attempt {call_count}"
|
||||
|
||||
failover = ProviderFailover(
|
||||
providers=["provider1"],
|
||||
max_retries=3,
|
||||
retry_delay=0.01,
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
|
||||
result = await failover.execute_with_failover(
|
||||
operation=operation,
|
||||
operation_name="test_retry",
|
||||
)
|
||||
|
||||
assert "Success" in result
|
||||
assert call_count == 3
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failover_with_health_monitoring(self):
|
||||
"""Failover should integrate with health monitoring."""
|
||||
monitor = ProviderHealthMonitor(failure_threshold=2)
|
||||
|
||||
# Pre-record failures for provider1 to make it unavailable
|
||||
for _ in range(3):
|
||||
monitor.record_request(
|
||||
provider_name="provider1",
|
||||
success=False,
|
||||
response_time_ms=100,
|
||||
error_message="Simulated failure",
|
||||
)
|
||||
|
||||
# Provider1 should now be unavailable
|
||||
assert "provider1" not in monitor.get_available_providers()
|
||||
|
||||
with patch(
|
||||
"src.core.providers.failover.get_health_monitor",
|
||||
return_value=monitor,
|
||||
):
|
||||
failover = ProviderFailover(
|
||||
providers=["provider1", "provider2"],
|
||||
max_retries=1,
|
||||
retry_delay=0.01,
|
||||
enable_health_monitoring=True,
|
||||
)
|
||||
|
||||
# Current provider should prefer provider2
|
||||
current = failover.get_current_provider()
|
||||
# Best provider selection should favor available ones
|
||||
available = monitor.get_available_providers()
|
||||
if available:
|
||||
assert current in available or current == "provider2"
|
||||
|
||||
|
||||
class TestProviderFailoverChainManagement:
|
||||
"""Test failover chain add/remove/priority operations."""
|
||||
|
||||
def test_add_provider_to_chain(self):
|
||||
"""Should add new provider to the failover chain."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2"],
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
failover.add_provider("p3")
|
||||
assert "p3" in failover.get_providers()
|
||||
|
||||
def test_add_duplicate_provider_no_effect(self):
|
||||
"""Adding existing provider should not create duplicate."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2"],
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
failover.add_provider("p1")
|
||||
assert failover.get_providers().count("p1") == 1
|
||||
|
||||
def test_remove_provider_from_chain(self):
|
||||
"""Should remove provider from the failover chain."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2", "p3"],
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
result = failover.remove_provider("p2")
|
||||
assert result is True
|
||||
assert "p2" not in failover.get_providers()
|
||||
|
||||
def test_remove_nonexistent_provider(self):
|
||||
"""Removing non-existent provider should return False."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1"],
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
assert failover.remove_provider("p99") is False
|
||||
|
||||
def test_set_provider_priority(self):
|
||||
"""Should reorder provider in the chain."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2", "p3"],
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
result = failover.set_provider_priority("p3", 0)
|
||||
assert result is True
|
||||
providers = failover.get_providers()
|
||||
assert providers[0] == "p3"
|
||||
|
||||
def test_set_priority_unknown_provider(self):
|
||||
"""Setting priority for unknown provider should return False."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1"],
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
assert failover.set_provider_priority("unknown", 0) is False
|
||||
|
||||
|
||||
class TestFailoverStats:
|
||||
"""Test failover statistics reporting."""
|
||||
|
||||
def test_get_failover_stats(self):
|
||||
"""Should return comprehensive stats."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2"],
|
||||
max_retries=3,
|
||||
retry_delay=1.0,
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
|
||||
stats = failover.get_failover_stats()
|
||||
assert stats["total_providers"] == 2
|
||||
assert stats["max_retries"] == 3
|
||||
assert stats["retry_delay"] == 1.0
|
||||
assert len(stats["providers"]) == 2
|
||||
|
||||
def test_stats_with_health_monitoring(self):
|
||||
"""Stats should include availability info when monitoring enabled."""
|
||||
monitor = ProviderHealthMonitor()
|
||||
monitor.record_request("p1", True, 100)
|
||||
monitor.record_request("p2", False, 200, error_message="fail")
|
||||
monitor.record_request("p2", False, 200, error_message="fail")
|
||||
monitor.record_request("p2", False, 200, error_message="fail")
|
||||
|
||||
with patch(
|
||||
"src.core.providers.failover.get_health_monitor",
|
||||
return_value=monitor,
|
||||
):
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2"],
|
||||
enable_health_monitoring=True,
|
||||
)
|
||||
stats = failover.get_failover_stats()
|
||||
assert "available_providers" in stats
|
||||
assert "unavailable_providers" in stats
|
||||
|
||||
|
||||
class TestConfigureFailover:
|
||||
"""Test the global failover configuration function."""
|
||||
|
||||
def test_configure_failover(self):
|
||||
"""configure_failover should create a new global instance."""
|
||||
import src.core.providers.failover as fo
|
||||
fo._failover = None
|
||||
|
||||
failover = configure_failover(
|
||||
providers=["custom1", "custom2"],
|
||||
max_retries=5,
|
||||
retry_delay=0.5,
|
||||
)
|
||||
|
||||
assert isinstance(failover, ProviderFailover)
|
||||
assert failover.get_providers() == ["custom1", "custom2"]
|
||||
assert failover._max_retries == 5
|
||||
|
||||
# Cleanup
|
||||
fo._failover = None
|
||||
|
||||
def test_get_failover_singleton(self):
|
||||
"""get_failover should return same instance."""
|
||||
import src.core.providers.failover as fo
|
||||
fo._failover = None
|
||||
|
||||
first = get_failover()
|
||||
second = get_failover()
|
||||
assert first is second
|
||||
|
||||
fo._failover = None
|
||||
|
||||
|
||||
class TestNextProviderRotation:
|
||||
"""Test provider rotation logic."""
|
||||
|
||||
def test_get_next_cycles_through_all(self):
|
||||
"""get_next_provider should cycle through all providers."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2", "p3"],
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
|
||||
seen = set()
|
||||
for _ in range(3):
|
||||
provider = failover.get_next_provider()
|
||||
if provider:
|
||||
seen.add(provider)
|
||||
|
||||
assert len(seen) >= 2 # Should see at least 2 different providers
|
||||
|
||||
def test_current_provider_is_from_list(self):
|
||||
"""get_current_provider should always return from provider list."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2", "p3"],
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
|
||||
for _ in range(10):
|
||||
current = failover.get_current_provider()
|
||||
assert current in ["p1", "p2", "p3"]
|
||||
failover.get_next_provider()
|
||||
348
tests/integration/test_provider_selection.py
Normal file
348
tests/integration/test_provider_selection.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""Integration tests for provider selection based on availability, health status, priority."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.providers.config_manager import (
|
||||
ProviderConfigManager,
|
||||
ProviderSettings,
|
||||
)
|
||||
from src.core.providers.failover import ProviderFailover
|
||||
from src.core.providers.health_monitor import (
|
||||
ProviderHealthMetrics,
|
||||
ProviderHealthMonitor,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderSelectionByHealth:
|
||||
"""Test provider selection based on health metrics."""
|
||||
|
||||
def test_best_provider_selected_by_success_rate(self):
|
||||
"""Provider with highest success rate should be selected as best."""
|
||||
monitor = ProviderHealthMonitor(failure_threshold=5)
|
||||
|
||||
# Provider1: 80% success rate
|
||||
for i in range(10):
|
||||
monitor.record_request(
|
||||
"provider1",
|
||||
success=(i < 8),
|
||||
response_time_ms=100,
|
||||
error_message=None if i < 8 else "fail",
|
||||
)
|
||||
|
||||
# Provider2: 90% success rate
|
||||
for i in range(10):
|
||||
monitor.record_request(
|
||||
"provider2",
|
||||
success=(i < 9),
|
||||
response_time_ms=100,
|
||||
error_message=None if i < 9 else "fail",
|
||||
)
|
||||
|
||||
best = monitor.get_best_provider()
|
||||
assert best == "provider2"
|
||||
|
||||
def test_unavailable_provider_not_selected(self):
|
||||
"""Provider marked unavailable should not be selected as best."""
|
||||
monitor = ProviderHealthMonitor(failure_threshold=3)
|
||||
|
||||
# Make provider1 unavailable with consecutive failures
|
||||
for _ in range(5):
|
||||
monitor.record_request(
|
||||
"provider1",
|
||||
success=False,
|
||||
response_time_ms=500,
|
||||
error_message="Connection refused",
|
||||
)
|
||||
|
||||
# Provider2 is healthy
|
||||
monitor.record_request(
|
||||
"provider2", success=True, response_time_ms=100
|
||||
)
|
||||
|
||||
best = monitor.get_best_provider()
|
||||
assert best == "provider2"
|
||||
|
||||
available = monitor.get_available_providers()
|
||||
assert "provider1" not in available
|
||||
assert "provider2" in available
|
||||
|
||||
def test_recovery_after_failures(self):
|
||||
"""Provider should recover availability after successful request."""
|
||||
monitor = ProviderHealthMonitor(failure_threshold=3)
|
||||
|
||||
# Make provider fail
|
||||
for _ in range(4):
|
||||
monitor.record_request(
|
||||
"provider1", success=False, response_time_ms=200,
|
||||
error_message="fail"
|
||||
)
|
||||
|
||||
assert "provider1" not in monitor.get_available_providers()
|
||||
|
||||
# Successful request should reset consecutive failures
|
||||
monitor.record_request(
|
||||
"provider1", success=True, response_time_ms=100
|
||||
)
|
||||
|
||||
assert "provider1" in monitor.get_available_providers()
|
||||
|
||||
def test_response_time_tiebreaker(self):
|
||||
"""When success rates are equal, faster provider should be preferred."""
|
||||
monitor = ProviderHealthMonitor()
|
||||
|
||||
# Both have 100% success, but different response times
|
||||
monitor.record_request(
|
||||
"slow_provider", success=True, response_time_ms=500
|
||||
)
|
||||
monitor.record_request(
|
||||
"fast_provider", success=True, response_time_ms=50
|
||||
)
|
||||
|
||||
best = monitor.get_best_provider()
|
||||
assert best == "fast_provider"
|
||||
|
||||
|
||||
class TestProviderSelectionWithConfig:
|
||||
"""Test provider selection using configuration manager."""
|
||||
|
||||
def test_enabled_providers_only(self):
|
||||
"""Only enabled providers should be available for selection."""
|
||||
config = ProviderConfigManager()
|
||||
config.set_provider_settings(
|
||||
"p1", ProviderSettings(name="p1", enabled=True, priority=1)
|
||||
)
|
||||
config.set_provider_settings(
|
||||
"p2", ProviderSettings(name="p2", enabled=False, priority=0)
|
||||
)
|
||||
config.set_provider_settings(
|
||||
"p3", ProviderSettings(name="p3", enabled=True, priority=2)
|
||||
)
|
||||
|
||||
enabled = config.get_enabled_providers()
|
||||
assert "p1" in enabled
|
||||
assert "p2" not in enabled
|
||||
assert "p3" in enabled
|
||||
|
||||
def test_priority_ordering(self):
|
||||
"""Providers should be ordered by priority value."""
|
||||
config = ProviderConfigManager()
|
||||
config.set_provider_settings(
|
||||
"low_priority",
|
||||
ProviderSettings(name="low_priority", priority=10),
|
||||
)
|
||||
config.set_provider_settings(
|
||||
"high_priority",
|
||||
ProviderSettings(name="high_priority", priority=1),
|
||||
)
|
||||
config.set_provider_settings(
|
||||
"mid_priority",
|
||||
ProviderSettings(name="mid_priority", priority=5),
|
||||
)
|
||||
|
||||
ordered = config.get_providers_by_priority()
|
||||
assert ordered == ["high_priority", "mid_priority", "low_priority"]
|
||||
|
||||
def test_dynamic_priority_update(self):
|
||||
"""Priority changes should immediately affect ordering."""
|
||||
config = ProviderConfigManager()
|
||||
config.set_provider_settings(
|
||||
"p1", ProviderSettings(name="p1", priority=1)
|
||||
)
|
||||
config.set_provider_settings(
|
||||
"p2", ProviderSettings(name="p2", priority=2)
|
||||
)
|
||||
|
||||
# Initially p1 is higher priority
|
||||
assert config.get_providers_by_priority()[0] == "p1"
|
||||
|
||||
# Change p2 to higher priority
|
||||
config.set_provider_priority("p2", 0)
|
||||
assert config.get_providers_by_priority()[0] == "p2"
|
||||
|
||||
|
||||
class TestProviderSelectionWithFailover:
|
||||
"""Test provider selection integration with failover system."""
|
||||
|
||||
def test_failover_respects_health_status(self):
|
||||
"""Failover should prefer healthy providers."""
|
||||
monitor = ProviderHealthMonitor(failure_threshold=2)
|
||||
|
||||
# Mark p1 as unhealthy
|
||||
monitor.record_request("p1", False, 100, error_message="fail")
|
||||
monitor.record_request("p1", False, 100, error_message="fail")
|
||||
|
||||
# p2 is healthy
|
||||
monitor.record_request("p2", True, 50)
|
||||
|
||||
with patch(
|
||||
"src.core.providers.failover.get_health_monitor",
|
||||
return_value=monitor,
|
||||
):
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2"],
|
||||
enable_health_monitoring=True,
|
||||
)
|
||||
|
||||
current = failover.get_current_provider()
|
||||
# Should prefer the healthy provider
|
||||
assert current == "p2"
|
||||
|
||||
def test_failover_falls_back_to_round_robin(self):
|
||||
"""Without health data, should use round-robin selection."""
|
||||
failover = ProviderFailover(
|
||||
providers=["p1", "p2", "p3"],
|
||||
enable_health_monitoring=False,
|
||||
)
|
||||
|
||||
# Should cycle through providers
|
||||
first = failover.get_current_provider()
|
||||
assert first in ["p1", "p2", "p3"]
|
||||
|
||||
|
||||
class TestHealthMonitorMetrics:
|
||||
"""Test health monitor metric collection scenarios."""
|
||||
|
||||
def test_metrics_tracking_accuracy(self):
|
||||
"""Metrics should accurately reflect request history."""
|
||||
monitor = ProviderHealthMonitor()
|
||||
|
||||
# Record 7 successes and 3 failures
|
||||
for i in range(10):
|
||||
monitor.record_request(
|
||||
"test_provider",
|
||||
success=(i < 7),
|
||||
response_time_ms=100 + i * 10,
|
||||
bytes_transferred=1024 * (i + 1),
|
||||
error_message=None if i < 7 else f"error_{i}",
|
||||
)
|
||||
|
||||
metrics = monitor.get_provider_metrics("test_provider")
|
||||
assert metrics is not None
|
||||
assert metrics.total_requests == 10
|
||||
assert metrics.successful_requests == 7
|
||||
assert metrics.failed_requests == 3
|
||||
assert metrics.success_rate == 70.0
|
||||
assert metrics.total_bytes_downloaded == sum(
|
||||
1024 * (i + 1) for i in range(10)
|
||||
)
|
||||
|
||||
def test_consecutive_failure_tracking(self):
|
||||
"""Consecutive failures should be tracked accurately."""
|
||||
monitor = ProviderHealthMonitor(failure_threshold=3)
|
||||
|
||||
# 2 successes then 3 failures
|
||||
monitor.record_request("p1", True, 100)
|
||||
monitor.record_request("p1", True, 100)
|
||||
monitor.record_request("p1", False, 100, error_message="e1")
|
||||
monitor.record_request("p1", False, 100, error_message="e2")
|
||||
monitor.record_request("p1", False, 100, error_message="e3")
|
||||
|
||||
metrics = monitor.get_provider_metrics("p1")
|
||||
assert metrics.consecutive_failures == 3
|
||||
assert metrics.is_available is False
|
||||
|
||||
def test_success_resets_consecutive_failures(self):
|
||||
"""A success should reset the consecutive failure counter."""
|
||||
monitor = ProviderHealthMonitor(failure_threshold=5)
|
||||
|
||||
# 3 failures then 1 success
|
||||
monitor.record_request("p1", False, 100, error_message="e1")
|
||||
monitor.record_request("p1", False, 100, error_message="e2")
|
||||
monitor.record_request("p1", False, 100, error_message="e3")
|
||||
monitor.record_request("p1", True, 100)
|
||||
|
||||
metrics = monitor.get_provider_metrics("p1")
|
||||
assert metrics.consecutive_failures == 0
|
||||
assert metrics.is_available is True
|
||||
|
||||
def test_health_summary(self):
|
||||
"""Health summary should aggregate all provider metrics."""
|
||||
monitor = ProviderHealthMonitor()
|
||||
|
||||
monitor.record_request("p1", True, 100)
|
||||
monitor.record_request("p2", True, 200)
|
||||
monitor.record_request("p3", False, 300, error_message="err")
|
||||
|
||||
summary = monitor.get_health_summary()
|
||||
assert summary["total_providers"] == 3
|
||||
assert summary["available_providers"] >= 2
|
||||
assert "average_success_rate" in summary
|
||||
assert "providers" in summary
|
||||
assert len(summary["providers"]) == 3
|
||||
|
||||
def test_reset_provider_metrics(self):
|
||||
"""Resetting metrics should clear all data for a provider."""
|
||||
monitor = ProviderHealthMonitor()
|
||||
|
||||
monitor.record_request("p1", True, 100, bytes_transferred=1024)
|
||||
monitor.record_request("p1", False, 200, error_message="fail")
|
||||
|
||||
result = monitor.reset_provider_metrics("p1")
|
||||
assert result is True
|
||||
|
||||
metrics = monitor.get_provider_metrics("p1")
|
||||
assert metrics.total_requests == 0
|
||||
assert metrics.successful_requests == 0
|
||||
assert metrics.total_bytes_downloaded == 0
|
||||
|
||||
def test_reset_unknown_provider(self):
|
||||
"""Resetting unknown provider should return False."""
|
||||
monitor = ProviderHealthMonitor()
|
||||
assert monitor.reset_provider_metrics("unknown") is False
|
||||
|
||||
def test_empty_summary(self):
|
||||
"""Summary with no providers should return zeros."""
|
||||
monitor = ProviderHealthMonitor()
|
||||
summary = monitor.get_health_summary()
|
||||
assert summary["total_providers"] == 0
|
||||
assert summary["average_success_rate"] == 0.0
|
||||
|
||||
|
||||
class TestMultiProviderHealthScenarios:
|
||||
"""Test complex multi-provider health scenarios."""
|
||||
|
||||
def test_three_providers_degraded_service(self):
|
||||
"""With 3 providers, partial failure should still select best."""
|
||||
monitor = ProviderHealthMonitor(failure_threshold=3)
|
||||
|
||||
# Provider A: fully down
|
||||
for _ in range(5):
|
||||
monitor.record_request("A", False, 500, error_message="down")
|
||||
|
||||
# Provider B: degraded (50% success)
|
||||
for i in range(10):
|
||||
monitor.record_request(
|
||||
"B", success=(i % 2 == 0), response_time_ms=200,
|
||||
error_message=None if i % 2 == 0 else "intermittent"
|
||||
)
|
||||
|
||||
# Provider C: healthy (100% success)
|
||||
for _ in range(5):
|
||||
monitor.record_request("C", True, 100)
|
||||
|
||||
best = monitor.get_best_provider()
|
||||
assert best == "C"
|
||||
|
||||
available = monitor.get_available_providers()
|
||||
assert "A" not in available
|
||||
assert "B" in available
|
||||
assert "C" in available
|
||||
|
||||
def test_all_providers_healthy(self):
|
||||
"""When all healthy, fastest should be selected."""
|
||||
monitor = ProviderHealthMonitor()
|
||||
|
||||
monitor.record_request("slow", True, 500)
|
||||
monitor.record_request("medium", True, 200)
|
||||
monitor.record_request("fast", True, 50)
|
||||
|
||||
best = monitor.get_best_provider()
|
||||
assert best == "fast"
|
||||
|
||||
def test_no_providers_tracked(self):
|
||||
"""With no tracked providers, best should be None."""
|
||||
monitor = ProviderHealthMonitor()
|
||||
assert monitor.get_best_provider() is None
|
||||
assert monitor.get_available_providers() == []
|
||||
474
tests/unit/test_aniworld_provider.py
Normal file
474
tests/unit/test_aniworld_provider.py
Normal file
@@ -0,0 +1,474 @@
|
||||
"""Unit tests for aniworld_provider.py - Anime catalog scraping, episode listing, streaming link extraction."""
|
||||
|
||||
import json
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.providers.aniworld_provider import AniworldLoader
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def loader():
|
||||
"""Create AniworldLoader with mocked session to prevent real HTTP calls."""
|
||||
with patch("src.core.providers.aniworld_provider.UserAgent") as mock_ua:
|
||||
mock_ua.return_value.random = "MockUserAgent/1.0"
|
||||
instance = AniworldLoader()
|
||||
instance.session = MagicMock()
|
||||
return instance
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_search_response():
|
||||
"""Sample JSON response for anime search."""
|
||||
return json.dumps([
|
||||
{"link": "/anime/stream/naruto", "title": "Naruto"},
|
||||
{"link": "/anime/stream/one-piece", "title": "One Piece"},
|
||||
])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_episode_html():
|
||||
"""Sample HTML for an episode page with language info and providers."""
|
||||
return """
|
||||
<html>
|
||||
<body>
|
||||
<div class="changeLanguageBox">
|
||||
<img data-lang-key="1" src="/flags/de.png" />
|
||||
<img data-lang-key="2" src="/flags/en.png" />
|
||||
</div>
|
||||
<li class="episodeLink1">
|
||||
<h4>VOE</h4>
|
||||
<a class="watchEpisode" href="/redirect/12345"></a>
|
||||
<span data-lang-key="1"></span>
|
||||
</li>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_series_html():
|
||||
"""Sample HTML for a series main page."""
|
||||
return """
|
||||
<html>
|
||||
<body>
|
||||
<div class="series-title">
|
||||
<h1><span>Naruto Shippuden</span></h1>
|
||||
</div>
|
||||
<p>Jahr: 2007</p>
|
||||
<div class="series-info">Aired: 2007-2017</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_season_html():
|
||||
"""Sample HTML for a season page with episode links."""
|
||||
return """
|
||||
<html>
|
||||
<body>
|
||||
<meta itemprop="numberOfSeasons" content="2" />
|
||||
<a href="/anime/stream/naruto/staffel-1/episode-1">Ep 1</a>
|
||||
<a href="/anime/stream/naruto/staffel-1/episode-2">Ep 2</a>
|
||||
<a href="/anime/stream/naruto/staffel-1/episode-3">Ep 3</a>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
class TestAniworldLoaderInit:
|
||||
"""Test AniworldLoader initialization."""
|
||||
|
||||
def test_loader_initializes(self, loader):
|
||||
"""Loader should initialize with expected attributes."""
|
||||
assert loader.ANIWORLD_TO == "https://aniworld.to"
|
||||
assert isinstance(loader.SUPPORTED_PROVIDERS, list)
|
||||
assert len(loader.SUPPORTED_PROVIDERS) > 0
|
||||
|
||||
def test_loader_has_session(self, loader):
|
||||
"""Loader should have a requests session."""
|
||||
assert loader.session is not None
|
||||
|
||||
def test_loader_has_caches(self, loader):
|
||||
"""Loader should initialize empty caches."""
|
||||
assert isinstance(loader._KeyHTMLDict, dict)
|
||||
assert isinstance(loader._EpisodeHTMLDict, dict)
|
||||
|
||||
def test_loader_site_key(self, loader):
|
||||
"""get_site_key should return 'aniworld.to'."""
|
||||
assert loader.get_site_key() == "aniworld.to"
|
||||
|
||||
def test_loader_provider_headers_initialized(self, loader):
|
||||
"""Provider-specific headers should be initialized."""
|
||||
assert isinstance(loader.PROVIDER_HEADERS, dict)
|
||||
assert "VOE" in loader.PROVIDER_HEADERS
|
||||
|
||||
|
||||
class TestAniworldSearch:
|
||||
"""Test anime search functionality."""
|
||||
|
||||
def test_search_parses_json_response(self, loader, sample_search_response):
|
||||
"""search() should parse JSON response into list."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = sample_search_response
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader.search("naruto")
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
assert result[0]["title"] == "Naruto"
|
||||
|
||||
def test_search_calls_correct_url(self, loader, sample_search_response):
|
||||
"""search() should call the correct search URL."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = sample_search_response
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
loader.search("naruto")
|
||||
call_args = loader.session.get.call_args
|
||||
assert "seriesSearch" in call_args[0][0]
|
||||
assert "naruto" in call_args[0][0]
|
||||
|
||||
def test_search_handles_empty_response(self, loader):
|
||||
"""search() with empty JSON array should return empty list."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "[]"
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader.search("nonexistent")
|
||||
assert result == []
|
||||
|
||||
def test_search_handles_html_escaped_json(self, loader):
|
||||
"""search() should handle HTML-escaped JSON response."""
|
||||
escaped_json = '[{"title": "Naruto & Friends"}]'
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = escaped_json
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader.search("naruto")
|
||||
assert len(result) == 1
|
||||
assert result[0]["title"] == "Naruto & Friends"
|
||||
|
||||
def test_search_url_encodes_special_characters(self, loader, sample_search_response):
|
||||
"""search() should URL-encode special characters in search term."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = sample_search_response
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
loader.search("attack on titan")
|
||||
call_url = loader.session.get.call_args[0][0]
|
||||
assert "attack" in call_url
|
||||
|
||||
def test_search_raises_on_invalid_json(self, loader):
|
||||
"""search() should raise when response is not valid JSON."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "<html>Not JSON</html>"
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
with pytest.raises((ValueError, json.JSONDecodeError)):
|
||||
loader.search("naruto")
|
||||
|
||||
|
||||
class TestAniworldLanguageCheck:
|
||||
"""Test language availability checking."""
|
||||
|
||||
def test_get_language_key_german_dub(self, loader):
|
||||
"""_get_language_key should return 1 for 'German Dub'."""
|
||||
assert loader._get_language_key("German Dub") == 1
|
||||
|
||||
def test_get_language_key_english_sub(self, loader):
|
||||
"""_get_language_key should return 2 for 'English Sub'."""
|
||||
assert loader._get_language_key("English Sub") == 2
|
||||
|
||||
def test_get_language_key_german_sub(self, loader):
|
||||
"""_get_language_key should return 3 for 'German Sub'."""
|
||||
assert loader._get_language_key("German Sub") == 3
|
||||
|
||||
def test_get_language_key_unknown(self, loader):
|
||||
"""_get_language_key should return 0 for unknown language."""
|
||||
assert loader._get_language_key("French Dub") == 0
|
||||
|
||||
def test_is_language_with_available_language(self, loader, sample_episode_html):
|
||||
"""is_language should return True when language is available."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = sample_episode_html.encode("utf-8")
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader.is_language(1, 1, "naruto", "German Dub")
|
||||
assert result is True
|
||||
|
||||
def test_is_language_english_sub_available(self, loader, sample_episode_html):
|
||||
"""is_language should return True for English Sub when available."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = sample_episode_html.encode("utf-8")
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader.is_language(1, 1, "naruto", "English Sub")
|
||||
assert result is True
|
||||
|
||||
def test_is_language_unavailable_language(self, loader, sample_episode_html):
|
||||
"""is_language should return False when language is not available."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = sample_episode_html.encode("utf-8")
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader.is_language(1, 1, "naruto", "German Sub")
|
||||
assert result is False
|
||||
|
||||
def test_is_language_no_language_box(self, loader):
|
||||
"""is_language should return False when no language box exists."""
|
||||
html = "<html><body><div></div></body></html>"
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader.is_language(1, 1, "naruto", "German Dub")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestAniworldTitle:
|
||||
"""Test title extraction."""
|
||||
|
||||
def test_get_title_extracts_correctly(self, loader, sample_series_html):
|
||||
"""get_title should extract title from HTML."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = sample_series_html.encode("utf-8")
|
||||
loader._KeyHTMLDict["naruto"] = mock_response
|
||||
|
||||
result = loader.get_title("naruto")
|
||||
assert result == "Naruto Shippuden"
|
||||
|
||||
def test_get_title_missing_title_div(self, loader):
|
||||
"""get_title should return empty string when title div is missing."""
|
||||
html = "<html><body></body></html>"
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
loader._KeyHTMLDict["unknown"] = mock_response
|
||||
|
||||
result = loader.get_title("unknown")
|
||||
assert result == ""
|
||||
|
||||
def test_get_title_caches_html(self, loader, sample_series_html):
|
||||
"""get_title should use cached HTML on second call."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = sample_series_html.encode("utf-8")
|
||||
loader._KeyHTMLDict["naruto"] = mock_response
|
||||
|
||||
loader.get_title("naruto")
|
||||
loader.get_title("naruto")
|
||||
# Session should not be called since HTML is cached
|
||||
loader.session.get.assert_not_called()
|
||||
|
||||
|
||||
class TestAniworldYear:
|
||||
"""Test year extraction."""
|
||||
|
||||
def test_get_year_extracts_from_metadata(self, loader, sample_series_html):
|
||||
"""get_year should extract year from 'Jahr:' text."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = sample_series_html.encode("utf-8")
|
||||
loader._KeyHTMLDict["naruto"] = mock_response
|
||||
|
||||
result = loader.get_year("naruto")
|
||||
assert result == 2007
|
||||
|
||||
def test_get_year_returns_none_when_not_found(self, loader):
|
||||
"""get_year should return None when no year info exists."""
|
||||
html = "<html><body><div class='series-title'></div></body></html>"
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
loader._KeyHTMLDict["unknown"] = mock_response
|
||||
|
||||
result = loader.get_year("unknown")
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestAniworldEpisodeHtml:
|
||||
"""Test episode HTML fetching and caching."""
|
||||
|
||||
def test_get_episode_html_fetches_from_session(self, loader):
|
||||
"""_get_episode_html should fetch from session and cache."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"<html></html>"
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader._get_episode_html(1, 1, "naruto")
|
||||
assert result is mock_response
|
||||
loader.session.get.assert_called_once()
|
||||
|
||||
def test_get_episode_html_invalid_season(self, loader):
|
||||
"""_get_episode_html should raise ValueError for invalid season."""
|
||||
with pytest.raises(ValueError, match="Invalid season number"):
|
||||
loader._get_episode_html(0, 1, "naruto")
|
||||
|
||||
def test_get_episode_html_invalid_episode(self, loader):
|
||||
"""_get_episode_html should raise ValueError for invalid episode."""
|
||||
with pytest.raises(ValueError, match="Invalid episode number"):
|
||||
loader._get_episode_html(1, 0, "naruto")
|
||||
|
||||
def test_get_episode_html_season_too_large(self, loader):
|
||||
"""_get_episode_html should raise ValueError for season > 999."""
|
||||
with pytest.raises(ValueError, match="Invalid season number"):
|
||||
loader._get_episode_html(1000, 1, "naruto")
|
||||
|
||||
def test_get_episode_html_episode_too_large(self, loader):
|
||||
"""_get_episode_html should raise ValueError for episode > 9999."""
|
||||
with pytest.raises(ValueError, match="Invalid episode number"):
|
||||
loader._get_episode_html(1, 10000, "naruto")
|
||||
|
||||
|
||||
class TestAniworldProviderParsing:
|
||||
"""Test provider extraction from HTML."""
|
||||
|
||||
def test_parse_providers_from_html(self, loader):
|
||||
"""_get_provider_from_html should extract available providers."""
|
||||
html = """
|
||||
<html><body>
|
||||
<li class="episodeLink1" data-lang-key="1">
|
||||
<h4>VOE</h4>
|
||||
<a class="watchEpisode" href="/redirect/111"></a>
|
||||
</li>
|
||||
<li class="episodeLink2" data-lang-key="2">
|
||||
<h4>Vidmoly</h4>
|
||||
<a class="watchEpisode" href="/redirect/222"></a>
|
||||
</li>
|
||||
</body></html>
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader._get_provider_from_html(1, 1, "naruto")
|
||||
assert "VOE" in result
|
||||
assert "Vidmoly" in result
|
||||
assert 1 in result["VOE"]
|
||||
assert 2 in result["Vidmoly"]
|
||||
|
||||
def test_parse_providers_empty_html(self, loader):
|
||||
"""_get_provider_from_html should return empty dict for no providers."""
|
||||
html = "<html><body></body></html>"
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader._get_provider_from_html(1, 1, "naruto")
|
||||
assert result == {}
|
||||
|
||||
def test_parse_providers_missing_lang_key(self, loader):
|
||||
"""Providers without data-lang-key should be skipped."""
|
||||
html = """
|
||||
<html><body>
|
||||
<li class="episodeLink1">
|
||||
<h4>VOE</h4>
|
||||
<a class="watchEpisode" href="/redirect/111"></a>
|
||||
</li>
|
||||
</body></html>
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
loader.session.get.return_value = mock_response
|
||||
|
||||
result = loader._get_provider_from_html(1, 1, "naruto")
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestAniworldSeasonEpisodeCount:
|
||||
"""Test season and episode count retrieval."""
|
||||
|
||||
@patch("src.core.providers.aniworld_provider.requests.get")
|
||||
def test_get_season_episode_count(self, mock_get, loader):
|
||||
"""get_season_episode_count should return correct counts."""
|
||||
# Main page with 2 seasons
|
||||
main_html = '<html><body><meta itemprop="numberOfSeasons" content="2" /></body></html>'
|
||||
# Season 1 with 3 episodes
|
||||
s1_html = """
|
||||
<html><body>
|
||||
<a href="/anime/stream/naruto/staffel-1/episode-1">Ep1</a>
|
||||
<a href="/anime/stream/naruto/staffel-1/episode-2">Ep2</a>
|
||||
<a href="/anime/stream/naruto/staffel-1/episode-3">Ep3</a>
|
||||
</body></html>
|
||||
"""
|
||||
# Season 2 with 2 episodes
|
||||
s2_html = """
|
||||
<html><body>
|
||||
<a href="/anime/stream/naruto/staffel-2/episode-1">Ep1</a>
|
||||
<a href="/anime/stream/naruto/staffel-2/episode-2">Ep2</a>
|
||||
</body></html>
|
||||
"""
|
||||
|
||||
responses = [
|
||||
MagicMock(content=main_html.encode()),
|
||||
MagicMock(content=s1_html.encode()),
|
||||
MagicMock(content=s2_html.encode()),
|
||||
]
|
||||
mock_get.side_effect = responses
|
||||
|
||||
result = loader.get_season_episode_count("naruto")
|
||||
assert result == {1: 3, 2: 2}
|
||||
|
||||
@patch("src.core.providers.aniworld_provider.requests.get")
|
||||
def test_get_season_episode_count_no_seasons(self, mock_get, loader):
|
||||
"""get_season_episode_count should return empty dict when no seasons."""
|
||||
html = "<html><body></body></html>"
|
||||
mock_get.return_value = MagicMock(content=html.encode())
|
||||
|
||||
result = loader.get_season_episode_count("nonexistent")
|
||||
assert result == {}
|
||||
|
||||
|
||||
class TestAniworldCache:
|
||||
"""Test cache operations."""
|
||||
|
||||
def test_clear_cache(self, loader):
|
||||
"""clear_cache should empty both caches."""
|
||||
loader._KeyHTMLDict["key1"] = "data"
|
||||
loader._EpisodeHTMLDict[("key1", 1, 1)] = "data"
|
||||
|
||||
loader.clear_cache()
|
||||
|
||||
assert len(loader._KeyHTMLDict) == 0
|
||||
assert len(loader._EpisodeHTMLDict) == 0
|
||||
|
||||
def test_remove_from_cache(self, loader):
|
||||
"""remove_from_cache should only clear episode cache."""
|
||||
loader._KeyHTMLDict["key1"] = "data"
|
||||
loader._EpisodeHTMLDict[("key1", 1, 1)] = "data"
|
||||
|
||||
loader.remove_from_cache()
|
||||
|
||||
assert len(loader._KeyHTMLDict) == 1
|
||||
assert len(loader._EpisodeHTMLDict) == 0
|
||||
|
||||
|
||||
class TestAniworldEvents:
|
||||
"""Test event subscription for download progress."""
|
||||
|
||||
def test_subscribe_download_progress(self, loader):
|
||||
"""subscribe_download_progress should register handler."""
|
||||
handler = MagicMock()
|
||||
loader.subscribe_download_progress(handler)
|
||||
# Fire event to verify handler was registered
|
||||
loader.events.download_progress({"status": "downloading"})
|
||||
handler.assert_called_once_with({"status": "downloading"})
|
||||
|
||||
def test_unsubscribe_download_progress(self, loader):
|
||||
"""unsubscribe_download_progress should remove handler."""
|
||||
handler = MagicMock()
|
||||
loader.subscribe_download_progress(handler)
|
||||
loader.unsubscribe_download_progress(handler)
|
||||
# Fire event - handler should NOT be called
|
||||
loader.events.download_progress({"status": "downloading"})
|
||||
handler.assert_not_called()
|
||||
218
tests/unit/test_base_provider.py
Normal file
218
tests/unit/test_base_provider.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""Unit tests for base_provider.py - Abstract base class and interface contracts."""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.providers.base_provider import Loader
|
||||
|
||||
|
||||
class TestLoaderAbstractInterface:
|
||||
"""Test that Loader defines the correct abstract interface."""
|
||||
|
||||
def test_loader_is_abstract_class(self):
|
||||
"""Loader should be an abstract class and not directly instantiable."""
|
||||
assert issubclass(Loader, ABC)
|
||||
with pytest.raises(TypeError):
|
||||
Loader()
|
||||
|
||||
def test_loader_defines_search_method(self):
|
||||
"""Loader must define abstract search method."""
|
||||
assert hasattr(Loader, "search")
|
||||
assert getattr(Loader.search, "__isabstractmethod__", False)
|
||||
|
||||
def test_loader_defines_is_language_method(self):
|
||||
"""Loader must define abstract is_language method."""
|
||||
assert hasattr(Loader, "is_language")
|
||||
assert getattr(Loader.is_language, "__isabstractmethod__", False)
|
||||
|
||||
def test_loader_defines_download_method(self):
|
||||
"""Loader must define abstract download method."""
|
||||
assert hasattr(Loader, "download")
|
||||
assert getattr(Loader.download, "__isabstractmethod__", False)
|
||||
|
||||
def test_loader_defines_get_site_key_method(self):
|
||||
"""Loader must define abstract get_site_key method."""
|
||||
assert hasattr(Loader, "get_site_key")
|
||||
assert getattr(Loader.get_site_key, "__isabstractmethod__", False)
|
||||
|
||||
def test_loader_defines_get_title_method(self):
|
||||
"""Loader must define abstract get_title method."""
|
||||
assert hasattr(Loader, "get_title")
|
||||
assert getattr(Loader.get_title, "__isabstractmethod__", False)
|
||||
|
||||
def test_loader_defines_get_season_episode_count_method(self):
|
||||
"""Loader must define abstract get_season_episode_count method."""
|
||||
assert hasattr(Loader, "get_season_episode_count")
|
||||
assert getattr(
|
||||
Loader.get_season_episode_count, "__isabstractmethod__", False
|
||||
)
|
||||
|
||||
def test_loader_defines_subscribe_download_progress(self):
|
||||
"""Loader must define abstract subscribe_download_progress method."""
|
||||
assert hasattr(Loader, "subscribe_download_progress")
|
||||
assert getattr(
|
||||
Loader.subscribe_download_progress, "__isabstractmethod__", False
|
||||
)
|
||||
|
||||
def test_loader_defines_unsubscribe_download_progress(self):
|
||||
"""Loader must define abstract unsubscribe_download_progress method."""
|
||||
assert hasattr(Loader, "unsubscribe_download_progress")
|
||||
assert getattr(
|
||||
Loader.unsubscribe_download_progress, "__isabstractmethod__", False
|
||||
)
|
||||
|
||||
|
||||
class ConcreteLoader(Loader):
|
||||
"""Minimal concrete implementation for testing inheritance."""
|
||||
|
||||
def subscribe_download_progress(self, handler):
|
||||
pass
|
||||
|
||||
def unsubscribe_download_progress(self, handler):
|
||||
pass
|
||||
|
||||
def search(self, word: str) -> List[Dict[str, Any]]:
|
||||
return [{"title": word}]
|
||||
|
||||
def is_language(
|
||||
self,
|
||||
season: int,
|
||||
episode: int,
|
||||
key: str,
|
||||
language: str = "German Dub",
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def download(
|
||||
self,
|
||||
base_directory: str,
|
||||
serie_folder: str,
|
||||
season: int,
|
||||
episode: int,
|
||||
key: str,
|
||||
language: str = "German Dub",
|
||||
) -> bool:
|
||||
return True
|
||||
|
||||
def get_site_key(self) -> str:
|
||||
return "test.provider"
|
||||
|
||||
def get_title(self, key: str) -> str:
|
||||
return f"Title for {key}"
|
||||
|
||||
def get_season_episode_count(self, slug: str) -> Dict[int, int]:
|
||||
return {1: 12, 2: 24}
|
||||
|
||||
|
||||
class TestLoaderInheritance:
|
||||
"""Test concrete implementations of the Loader interface."""
|
||||
|
||||
def test_concrete_loader_can_be_instantiated(self):
|
||||
"""A fully implemented subclass should instantiate without error."""
|
||||
loader = ConcreteLoader()
|
||||
assert isinstance(loader, Loader)
|
||||
|
||||
def test_concrete_loader_search(self):
|
||||
"""Concrete search should return a list of dicts."""
|
||||
loader = ConcreteLoader()
|
||||
result = loader.search("Naruto")
|
||||
assert isinstance(result, list)
|
||||
assert result[0]["title"] == "Naruto"
|
||||
|
||||
def test_concrete_loader_is_language(self):
|
||||
"""Concrete is_language should return bool."""
|
||||
loader = ConcreteLoader()
|
||||
assert loader.is_language(1, 1, "naruto") is True
|
||||
|
||||
def test_concrete_loader_is_language_default_param(self):
|
||||
"""is_language should default to 'German Dub' language."""
|
||||
loader = ConcreteLoader()
|
||||
result = loader.is_language(1, 1, "naruto")
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_concrete_loader_download(self):
|
||||
"""Concrete download should return bool."""
|
||||
loader = ConcreteLoader()
|
||||
result = loader.download("/base", "folder", 1, 1, "naruto")
|
||||
assert result is True
|
||||
|
||||
def test_concrete_loader_get_site_key(self):
|
||||
"""Concrete get_site_key should return a string."""
|
||||
loader = ConcreteLoader()
|
||||
assert loader.get_site_key() == "test.provider"
|
||||
|
||||
def test_concrete_loader_get_title(self):
|
||||
"""Concrete get_title should return title string."""
|
||||
loader = ConcreteLoader()
|
||||
assert loader.get_title("naruto") == "Title for naruto"
|
||||
|
||||
def test_concrete_loader_get_season_episode_count(self):
|
||||
"""Concrete get_season_episode_count should return dict."""
|
||||
loader = ConcreteLoader()
|
||||
result = loader.get_season_episode_count("naruto")
|
||||
assert isinstance(result, dict)
|
||||
assert result[1] == 12
|
||||
assert result[2] == 24
|
||||
|
||||
def test_concrete_loader_subscribe_download_progress(self):
|
||||
"""subscribe_download_progress should accept handler without error."""
|
||||
loader = ConcreteLoader()
|
||||
handler = MagicMock()
|
||||
loader.subscribe_download_progress(handler)
|
||||
|
||||
def test_concrete_loader_unsubscribe_download_progress(self):
|
||||
"""unsubscribe_download_progress should accept handler without error."""
|
||||
loader = ConcreteLoader()
|
||||
handler = MagicMock()
|
||||
loader.unsubscribe_download_progress(handler)
|
||||
|
||||
|
||||
class IncompleteLoader(Loader):
|
||||
"""Incomplete implementation missing some abstract methods."""
|
||||
|
||||
def search(self, word: str) -> List[Dict[str, Any]]:
|
||||
return []
|
||||
|
||||
def is_language(self, season, episode, key, language="German Dub"):
|
||||
return False
|
||||
|
||||
# Deliberately omit download, get_site_key, etc.
|
||||
|
||||
|
||||
class TestIncompleteImplementation:
|
||||
"""Test that incomplete implementations cannot be instantiated."""
|
||||
|
||||
def test_incomplete_loader_raises_type_error(self):
|
||||
"""Loader subclass missing abstract methods cannot be instantiated."""
|
||||
with pytest.raises(TypeError):
|
||||
IncompleteLoader()
|
||||
|
||||
|
||||
class TestLoaderMethodSignatures:
|
||||
"""Test method signatures match the expected contract."""
|
||||
|
||||
def test_search_returns_list(self):
|
||||
"""search() should return List[Dict[str, Any]]."""
|
||||
loader = ConcreteLoader()
|
||||
result = loader.search("test")
|
||||
assert isinstance(result, list)
|
||||
for item in result:
|
||||
assert isinstance(item, dict)
|
||||
|
||||
def test_download_returns_bool(self):
|
||||
"""download() should return bool."""
|
||||
loader = ConcreteLoader()
|
||||
result = loader.download("/dir", "folder", 1, 1, "key")
|
||||
assert isinstance(result, bool)
|
||||
|
||||
def test_get_season_episode_count_returns_dict_int_int(self):
|
||||
"""get_season_episode_count() should return Dict[int, int]."""
|
||||
loader = ConcreteLoader()
|
||||
result = loader.get_season_episode_count("slug")
|
||||
assert isinstance(result, dict)
|
||||
for k, v in result.items():
|
||||
assert isinstance(k, int)
|
||||
assert isinstance(v, int)
|
||||
445
tests/unit/test_enhanced_provider.py
Normal file
445
tests/unit/test_enhanced_provider.py
Normal file
@@ -0,0 +1,445 @@
|
||||
"""Unit tests for enhanced_provider.py - Caching, recovery, download logic."""
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock, patch, PropertyMock
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.error_handler import (
|
||||
DownloadError,
|
||||
NetworkError,
|
||||
NonRetryableError,
|
||||
RetryableError,
|
||||
)
|
||||
from src.core.providers.base_provider import Loader
|
||||
|
||||
|
||||
# Import the class but we need a concrete subclass to test it
|
||||
from src.core.providers.enhanced_provider import EnhancedAniWorldLoader
|
||||
|
||||
|
||||
class ConcreteEnhancedLoader(EnhancedAniWorldLoader):
|
||||
"""Concrete subclass that bridges PascalCase methods to abstract interface."""
|
||||
|
||||
def subscribe_download_progress(self, handler):
|
||||
pass
|
||||
|
||||
def unsubscribe_download_progress(self, handler):
|
||||
pass
|
||||
|
||||
def search(self, word: str) -> List[Dict[str, Any]]:
|
||||
return self.Search(word)
|
||||
|
||||
def is_language(self, season, episode, key, language="German Dub"):
|
||||
return self.IsLanguage(season, episode, key, language)
|
||||
|
||||
def download(self, base_directory, serie_folder, season, episode, key,
|
||||
language="German Dub", **kwargs):
|
||||
return self.Download(base_directory, serie_folder, season, episode,
|
||||
key, language)
|
||||
|
||||
def get_site_key(self) -> str:
|
||||
return self.GetSiteKey()
|
||||
|
||||
def get_title(self, key: str) -> str:
|
||||
return self.GetTitle(key)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enhanced_loader():
|
||||
"""Create ConcreteEnhancedLoader with mocked externals."""
|
||||
with patch(
|
||||
"src.core.providers.enhanced_provider.UserAgent"
|
||||
) as mock_ua, patch(
|
||||
"src.core.providers.enhanced_provider.get_integrity_manager"
|
||||
):
|
||||
mock_ua.return_value.random = "MockAgent/1.0"
|
||||
loader = ConcreteEnhancedLoader()
|
||||
loader.session = MagicMock()
|
||||
return loader
|
||||
|
||||
|
||||
class TestEnhancedLoaderInit:
|
||||
"""Test EnhancedAniWorldLoader initialization."""
|
||||
|
||||
def test_initializes_with_download_stats(self, enhanced_loader):
|
||||
"""Should initialize download statistics at zero."""
|
||||
stats = enhanced_loader.download_stats
|
||||
assert stats["total_downloads"] == 0
|
||||
assert stats["successful_downloads"] == 0
|
||||
assert stats["failed_downloads"] == 0
|
||||
assert stats["retried_downloads"] == 0
|
||||
|
||||
def test_initializes_with_caches(self, enhanced_loader):
|
||||
"""Should initialize empty caches."""
|
||||
assert enhanced_loader._KeyHTMLDict == {}
|
||||
assert enhanced_loader._EpisodeHTMLDict == {}
|
||||
|
||||
def test_site_key(self, enhanced_loader):
|
||||
"""GetSiteKey should return 'aniworld.to'."""
|
||||
assert enhanced_loader.GetSiteKey() == "aniworld.to"
|
||||
|
||||
def test_has_supported_providers(self, enhanced_loader):
|
||||
"""Should have a list of supported providers."""
|
||||
assert isinstance(enhanced_loader.SUPPORTED_PROVIDERS, list)
|
||||
assert len(enhanced_loader.SUPPORTED_PROVIDERS) > 0
|
||||
assert "VOE" in enhanced_loader.SUPPORTED_PROVIDERS
|
||||
|
||||
|
||||
class TestEnhancedSearch:
|
||||
"""Test enhanced search with error recovery."""
|
||||
|
||||
def test_search_empty_term_raises(self, enhanced_loader):
|
||||
"""Search with empty term should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
enhanced_loader.Search("")
|
||||
|
||||
def test_search_whitespace_only_raises(self, enhanced_loader):
|
||||
"""Search with whitespace-only term should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
enhanced_loader.Search(" ")
|
||||
|
||||
def test_search_successful(self, enhanced_loader):
|
||||
"""Successful search should return parsed list."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.ok = True
|
||||
mock_response.text = json.dumps([
|
||||
{"title": "Naruto", "link": "/anime/stream/naruto"}
|
||||
])
|
||||
enhanced_loader.session.get.return_value = mock_response
|
||||
|
||||
result = enhanced_loader.Search("Naruto")
|
||||
assert len(result) == 1
|
||||
assert result[0]["title"] == "Naruto"
|
||||
|
||||
|
||||
class TestParseAnimeResponse:
|
||||
"""Test JSON parsing strategies."""
|
||||
|
||||
def test_parse_valid_json_list(self, enhanced_loader):
|
||||
"""Should parse valid JSON list."""
|
||||
text = '[{"title": "Naruto"}]'
|
||||
result = enhanced_loader._parse_anime_response(text)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_parse_html_escaped_json(self, enhanced_loader):
|
||||
"""Should handle HTML-escaped JSON."""
|
||||
text = '[{"title": "Naruto & Boruto"}]'
|
||||
result = enhanced_loader._parse_anime_response(text)
|
||||
assert result[0]["title"] == "Naruto & Boruto"
|
||||
|
||||
def test_parse_empty_response_raises(self, enhanced_loader):
|
||||
"""Empty response should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Empty response"):
|
||||
enhanced_loader._parse_anime_response("")
|
||||
|
||||
def test_parse_whitespace_only_raises(self, enhanced_loader):
|
||||
"""Whitespace-only response should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Empty response"):
|
||||
enhanced_loader._parse_anime_response(" ")
|
||||
|
||||
def test_parse_html_response_raises(self, enhanced_loader):
|
||||
"""HTML response instead of JSON should raise ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
enhanced_loader._parse_anime_response(
|
||||
"<!DOCTYPE html><html></html>"
|
||||
)
|
||||
|
||||
def test_parse_bom_json(self, enhanced_loader):
|
||||
"""Should handle BOM-prefixed JSON."""
|
||||
text = '\ufeff[{"title": "Test"}]'
|
||||
result = enhanced_loader._parse_anime_response(text)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_parse_control_characters(self, enhanced_loader):
|
||||
"""Should strip control characters and parse."""
|
||||
text = '[{"title": "Na\x00ruto"}]'
|
||||
result = enhanced_loader._parse_anime_response(text)
|
||||
assert len(result) == 1
|
||||
|
||||
def test_parse_non_list_result_raises(self, enhanced_loader):
|
||||
"""Non-list JSON should raise ValueError."""
|
||||
with pytest.raises(ValueError):
|
||||
enhanced_loader._parse_anime_response('{"key": "value"}')
|
||||
|
||||
|
||||
class TestLanguageKey:
|
||||
"""Test language code mapping."""
|
||||
|
||||
def test_german_dub(self, enhanced_loader):
|
||||
"""German Dub should map to 1."""
|
||||
assert enhanced_loader._GetLanguageKey("German Dub") == 1
|
||||
|
||||
def test_english_sub(self, enhanced_loader):
|
||||
"""English Sub should map to 2."""
|
||||
assert enhanced_loader._GetLanguageKey("English Sub") == 2
|
||||
|
||||
def test_german_sub(self, enhanced_loader):
|
||||
"""German Sub should map to 3."""
|
||||
assert enhanced_loader._GetLanguageKey("German Sub") == 3
|
||||
|
||||
def test_unknown_language(self, enhanced_loader):
|
||||
"""Unknown language should map to 0."""
|
||||
assert enhanced_loader._GetLanguageKey("French Dub") == 0
|
||||
|
||||
|
||||
class TestEnhancedIsLanguage:
|
||||
"""Test language availability checking with recovery."""
|
||||
|
||||
def test_is_language_available(self, enhanced_loader):
|
||||
"""Should return True when language is available."""
|
||||
html = """
|
||||
<html><body>
|
||||
<div class="changeLanguageBox">
|
||||
<img data-lang-key="1" />
|
||||
<img data-lang-key="2" />
|
||||
</div>
|
||||
</body></html>
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
enhanced_loader._EpisodeHTMLDict[(
|
||||
"naruto", 1, 1
|
||||
)] = mock_response
|
||||
|
||||
result = enhanced_loader.IsLanguage(1, 1, "naruto", "German Dub")
|
||||
assert result is True
|
||||
|
||||
def test_is_language_not_available(self, enhanced_loader):
|
||||
"""Should return False when language is not available."""
|
||||
html = """
|
||||
<html><body>
|
||||
<div class="changeLanguageBox">
|
||||
<img data-lang-key="1" />
|
||||
</div>
|
||||
</body></html>
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
enhanced_loader._EpisodeHTMLDict[(
|
||||
"naruto", 1, 1
|
||||
)] = mock_response
|
||||
|
||||
result = enhanced_loader.IsLanguage(1, 1, "naruto", "German Sub")
|
||||
assert result is False
|
||||
|
||||
def test_is_language_no_language_box(self, enhanced_loader):
|
||||
"""Should return False when no language box in HTML."""
|
||||
html = "<html><body></body></html>"
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
enhanced_loader._EpisodeHTMLDict[(
|
||||
"naruto", 1, 1
|
||||
)] = mock_response
|
||||
|
||||
result = enhanced_loader.IsLanguage(1, 1, "naruto", "German Dub")
|
||||
assert result is False
|
||||
|
||||
|
||||
class TestEnhancedGetTitle:
|
||||
"""Test title extraction with error recovery."""
|
||||
|
||||
def test_get_title_successful(self, enhanced_loader):
|
||||
"""Should extract title from HTML."""
|
||||
html = """
|
||||
<html><body>
|
||||
<div class="series-title">
|
||||
<h1><span>Attack on Titan</span></h1>
|
||||
</div>
|
||||
</body></html>
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
enhanced_loader._KeyHTMLDict["aot"] = mock_response
|
||||
|
||||
result = enhanced_loader.GetTitle("aot")
|
||||
assert result == "Attack on Titan"
|
||||
|
||||
def test_get_title_missing_returns_fallback(self, enhanced_loader):
|
||||
"""Should return fallback title when not found in HTML."""
|
||||
html = "<html><body></body></html>"
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
enhanced_loader._KeyHTMLDict["unknown"] = mock_response
|
||||
|
||||
result = enhanced_loader.GetTitle("unknown")
|
||||
assert "Unknown_Title" in result
|
||||
|
||||
|
||||
class TestEnhancedCache:
|
||||
"""Test cache operations."""
|
||||
|
||||
def test_clear_cache(self, enhanced_loader):
|
||||
"""ClearCache should empty all caches."""
|
||||
enhanced_loader._KeyHTMLDict["key"] = "data"
|
||||
enhanced_loader._EpisodeHTMLDict[("key", 1, 1)] = "data"
|
||||
|
||||
enhanced_loader.ClearCache()
|
||||
|
||||
assert len(enhanced_loader._KeyHTMLDict) == 0
|
||||
assert len(enhanced_loader._EpisodeHTMLDict) == 0
|
||||
|
||||
def test_remove_from_cache(self, enhanced_loader):
|
||||
"""RemoveFromCache should only clear episode cache."""
|
||||
enhanced_loader._KeyHTMLDict["key"] = "data"
|
||||
enhanced_loader._EpisodeHTMLDict[("key", 1, 1)] = "data"
|
||||
|
||||
enhanced_loader.RemoveFromCache()
|
||||
|
||||
assert len(enhanced_loader._KeyHTMLDict) == 1
|
||||
assert len(enhanced_loader._EpisodeHTMLDict) == 0
|
||||
|
||||
|
||||
class TestEnhancedGetEpisodeHTML:
|
||||
"""Test episode HTML fetching with validation."""
|
||||
|
||||
def test_empty_key_raises(self, enhanced_loader):
|
||||
"""Empty key should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
enhanced_loader._GetEpisodeHTML(1, 1, "")
|
||||
|
||||
def test_whitespace_key_raises(self, enhanced_loader):
|
||||
"""Whitespace key should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="empty"):
|
||||
enhanced_loader._GetEpisodeHTML(1, 1, " ")
|
||||
|
||||
def test_invalid_season_zero_raises(self, enhanced_loader):
|
||||
"""Season 0 should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid season"):
|
||||
enhanced_loader._GetEpisodeHTML(0, 1, "naruto")
|
||||
|
||||
def test_invalid_season_negative_raises(self, enhanced_loader):
|
||||
"""Negative season should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid season"):
|
||||
enhanced_loader._GetEpisodeHTML(-1, 1, "naruto")
|
||||
|
||||
def test_invalid_episode_zero_raises(self, enhanced_loader):
|
||||
"""Episode 0 should raise ValueError."""
|
||||
with pytest.raises(ValueError, match="Invalid episode"):
|
||||
enhanced_loader._GetEpisodeHTML(1, 0, "naruto")
|
||||
|
||||
def test_cached_episode_returned(self, enhanced_loader):
|
||||
"""Should return cached response without HTTP call."""
|
||||
mock_response = MagicMock()
|
||||
enhanced_loader._EpisodeHTMLDict[("naruto", 1, 1)] = mock_response
|
||||
|
||||
result = enhanced_loader._GetEpisodeHTML(1, 1, "naruto")
|
||||
assert result is mock_response
|
||||
enhanced_loader.session.get.assert_not_called()
|
||||
|
||||
|
||||
class TestDownloadStatistics:
|
||||
"""Test download statistics tracking."""
|
||||
|
||||
def test_get_download_statistics(self, enhanced_loader):
|
||||
"""Should return stats with calculated success rate."""
|
||||
enhanced_loader.download_stats["total_downloads"] = 10
|
||||
enhanced_loader.download_stats["successful_downloads"] = 8
|
||||
enhanced_loader.download_stats["failed_downloads"] = 2
|
||||
|
||||
stats = enhanced_loader.get_download_statistics()
|
||||
assert stats["success_rate"] == 80.0
|
||||
|
||||
def test_statistics_zero_downloads(self, enhanced_loader):
|
||||
"""Success rate should be 0 with no downloads."""
|
||||
stats = enhanced_loader.get_download_statistics()
|
||||
assert stats["success_rate"] == 0
|
||||
|
||||
def test_reset_statistics(self, enhanced_loader):
|
||||
"""reset_statistics should zero all counters."""
|
||||
enhanced_loader.download_stats["total_downloads"] = 10
|
||||
enhanced_loader.download_stats["successful_downloads"] = 8
|
||||
|
||||
enhanced_loader.reset_statistics()
|
||||
|
||||
assert enhanced_loader.download_stats["total_downloads"] == 0
|
||||
assert enhanced_loader.download_stats["successful_downloads"] == 0
|
||||
|
||||
|
||||
class TestEnhancedDownloadValidation:
|
||||
"""Test download input validation."""
|
||||
|
||||
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
|
||||
def test_download_missing_base_directory_raises(
|
||||
self, mock_integrity, enhanced_loader
|
||||
):
|
||||
"""Download with empty base_directory should raise."""
|
||||
with pytest.raises((ValueError, DownloadError)):
|
||||
enhanced_loader.Download("", "folder", 1, 1, "key")
|
||||
|
||||
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
|
||||
def test_download_missing_serie_folder_raises(
|
||||
self, mock_integrity, enhanced_loader
|
||||
):
|
||||
"""Download with empty serie_folder should raise."""
|
||||
with pytest.raises((ValueError, DownloadError)):
|
||||
enhanced_loader.Download("/base", "", 1, 1, "key")
|
||||
|
||||
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
|
||||
def test_download_negative_season_raises(
|
||||
self, mock_integrity, enhanced_loader
|
||||
):
|
||||
"""Download with negative season should raise."""
|
||||
with pytest.raises((ValueError, DownloadError)):
|
||||
enhanced_loader.Download("/base", "folder", -1, 1, "key")
|
||||
|
||||
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
|
||||
def test_download_negative_episode_raises(
|
||||
self, mock_integrity, enhanced_loader
|
||||
):
|
||||
"""Download with negative episode should raise."""
|
||||
with pytest.raises((ValueError, DownloadError)):
|
||||
enhanced_loader.Download("/base", "folder", 1, -1, "key")
|
||||
|
||||
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
|
||||
def test_download_increments_total_count(
|
||||
self, mock_integrity, enhanced_loader
|
||||
):
|
||||
"""Download should increment total_downloads counter."""
|
||||
# Make it fail fast so we don't need to mock everything
|
||||
enhanced_loader._KeyHTMLDict["key"] = MagicMock(
|
||||
content=b"<html><body><div class='series-title'><h1><span>Test</span></h1></div></body></html>"
|
||||
)
|
||||
try:
|
||||
enhanced_loader.Download("/base", "folder", 1, 1, "key")
|
||||
except Exception:
|
||||
pass
|
||||
assert enhanced_loader.download_stats["total_downloads"] >= 1
|
||||
|
||||
|
||||
class TestEnhancedProviderFromHTML:
|
||||
"""Test provider extraction from episode HTML."""
|
||||
|
||||
def test_extract_providers(self, enhanced_loader):
|
||||
"""Should extract providers with language keys from HTML."""
|
||||
html = """
|
||||
<html><body>
|
||||
<li class="episodeLink1" data-lang-key="1">
|
||||
<h4>VOE</h4>
|
||||
<a class="watchEpisode" href="/redirect/100"></a>
|
||||
</li>
|
||||
<li class="episodeLink2" data-lang-key="2">
|
||||
<h4>Vidmoly</h4>
|
||||
<a class="watchEpisode" href="/redirect/200"></a>
|
||||
</li>
|
||||
</body></html>
|
||||
"""
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
enhanced_loader._EpisodeHTMLDict[("test", 1, 1)] = mock_response
|
||||
|
||||
result = enhanced_loader._get_provider_from_html(1, 1, "test")
|
||||
assert "VOE" in result
|
||||
assert "Vidmoly" in result
|
||||
|
||||
def test_extract_providers_empty_page(self, enhanced_loader):
|
||||
"""Should return empty dict when no providers found."""
|
||||
html = "<html><body></body></html>"
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = html.encode("utf-8")
|
||||
enhanced_loader._EpisodeHTMLDict[("test", 1, 1)] = mock_response
|
||||
|
||||
result = enhanced_loader._get_provider_from_html(1, 1, "test")
|
||||
assert result == {}
|
||||
336
tests/unit/test_monitored_provider.py
Normal file
336
tests/unit/test_monitored_provider.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""Unit tests for monitored_provider.py - Metrics collection, health checks, monitoring integration."""
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.providers.base_provider import Loader
|
||||
from src.core.providers.monitored_provider import (
|
||||
MonitoredProviderWrapper,
|
||||
wrap_provider,
|
||||
)
|
||||
|
||||
|
||||
class MockProvider(Loader):
|
||||
"""Mock provider for testing the monitoring wrapper."""
|
||||
|
||||
def __init__(self, site_key: str = "mock.provider"):
|
||||
self._site_key = site_key
|
||||
self._search_result = []
|
||||
self._is_language_result = True
|
||||
self._download_result = True
|
||||
self._title = "Mock Title"
|
||||
self._season_episodes = {1: 12}
|
||||
self.raise_on_search = False
|
||||
self.raise_on_download = False
|
||||
|
||||
def subscribe_download_progress(self, handler):
|
||||
pass
|
||||
|
||||
def unsubscribe_download_progress(self, handler):
|
||||
pass
|
||||
|
||||
def search(self, word):
|
||||
if self.raise_on_search:
|
||||
raise ConnectionError("Search failed")
|
||||
return self._search_result
|
||||
|
||||
def is_language(self, season, episode, key, language="German Dub"):
|
||||
return self._is_language_result
|
||||
|
||||
def download(
|
||||
self, base_directory, serie_folder, season, episode, key,
|
||||
language="German Dub", progress_callback=None
|
||||
):
|
||||
if self.raise_on_download:
|
||||
raise ConnectionError("Download failed")
|
||||
return self._download_result
|
||||
|
||||
def get_site_key(self):
|
||||
return self._site_key
|
||||
|
||||
def get_title(self, key):
|
||||
return self._title
|
||||
|
||||
def get_season_episode_count(self, slug):
|
||||
return self._season_episodes
|
||||
|
||||
|
||||
class ConcreteMonitoredWrapper(MonitoredProviderWrapper):
|
||||
"""Concrete subclass adding the missing abstract methods."""
|
||||
|
||||
def subscribe_download_progress(self, handler):
|
||||
pass
|
||||
|
||||
def unsubscribe_download_progress(self, handler):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_provider():
|
||||
"""Create a mock provider instance."""
|
||||
return MockProvider()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_health_monitor():
|
||||
"""Create a mock health monitor."""
|
||||
monitor = MagicMock()
|
||||
return monitor
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def monitored_wrapper(mock_provider, mock_health_monitor):
|
||||
"""Create a monitored wrapper with mock health monitor."""
|
||||
with patch(
|
||||
"src.core.providers.monitored_provider.get_health_monitor",
|
||||
return_value=mock_health_monitor,
|
||||
):
|
||||
wrapper = ConcreteMonitoredWrapper(
|
||||
provider=mock_provider,
|
||||
enable_monitoring=True,
|
||||
)
|
||||
return wrapper
|
||||
|
||||
|
||||
class TestMonitoredProviderWrapperInit:
|
||||
"""Test MonitoredProviderWrapper initialization."""
|
||||
|
||||
def test_wrapper_stores_provider(self, mock_provider):
|
||||
"""Wrapper should store the wrapped provider."""
|
||||
with patch(
|
||||
"src.core.providers.monitored_provider.get_health_monitor"
|
||||
):
|
||||
wrapper = ConcreteMonitoredWrapper(mock_provider)
|
||||
assert wrapper._provider is mock_provider
|
||||
|
||||
def test_wrapper_monitoring_enabled_by_default(self, mock_provider):
|
||||
"""Monitoring should be enabled by default."""
|
||||
with patch(
|
||||
"src.core.providers.monitored_provider.get_health_monitor"
|
||||
):
|
||||
wrapper = ConcreteMonitoredWrapper(mock_provider)
|
||||
assert wrapper._enable_monitoring is True
|
||||
|
||||
def test_wrapper_monitoring_can_be_disabled(self, mock_provider):
|
||||
"""Monitoring can be disabled on init."""
|
||||
wrapper = ConcreteMonitoredWrapper(
|
||||
mock_provider, enable_monitoring=False
|
||||
)
|
||||
assert wrapper._enable_monitoring is False
|
||||
assert wrapper._health_monitor is None
|
||||
|
||||
def test_wrapped_provider_property(self, monitored_wrapper, mock_provider):
|
||||
"""wrapped_provider property should return the underlying provider."""
|
||||
assert monitored_wrapper.wrapped_provider is mock_provider
|
||||
|
||||
|
||||
class TestMonitoredSearch:
|
||||
"""Test search with monitoring."""
|
||||
|
||||
def test_search_delegates_to_provider(self, monitored_wrapper, mock_provider):
|
||||
"""search() should delegate to wrapped provider."""
|
||||
mock_provider._search_result = [{"title": "Test"}]
|
||||
result = monitored_wrapper.search("test")
|
||||
assert result == [{"title": "Test"}]
|
||||
|
||||
def test_search_records_success_metric(
|
||||
self, monitored_wrapper, mock_health_monitor
|
||||
):
|
||||
"""Successful search should record a success metric."""
|
||||
monitored_wrapper.search("test")
|
||||
mock_health_monitor.record_request.assert_called_once()
|
||||
call_kwargs = mock_health_monitor.record_request.call_args[1]
|
||||
assert call_kwargs["success"] is True
|
||||
assert call_kwargs["provider_name"] == "mock.provider"
|
||||
|
||||
def test_search_records_failure_metric(
|
||||
self, monitored_wrapper, mock_provider, mock_health_monitor
|
||||
):
|
||||
"""Failed search should record a failure metric."""
|
||||
mock_provider.raise_on_search = True
|
||||
with pytest.raises(ConnectionError):
|
||||
monitored_wrapper.search("test")
|
||||
mock_health_monitor.record_request.assert_called_once()
|
||||
call_kwargs = mock_health_monitor.record_request.call_args[1]
|
||||
assert call_kwargs["success"] is False
|
||||
|
||||
def test_search_propagates_exception(
|
||||
self, monitored_wrapper, mock_provider
|
||||
):
|
||||
"""Exception from provider should propagate through wrapper."""
|
||||
mock_provider.raise_on_search = True
|
||||
with pytest.raises(ConnectionError, match="Search failed"):
|
||||
monitored_wrapper.search("test")
|
||||
|
||||
|
||||
class TestMonitoredIsLanguage:
|
||||
"""Test is_language with monitoring."""
|
||||
|
||||
def test_is_language_delegates(self, monitored_wrapper, mock_provider):
|
||||
"""is_language should delegate to wrapped provider."""
|
||||
mock_provider._is_language_result = True
|
||||
result = monitored_wrapper.is_language(1, 1, "key")
|
||||
assert result is True
|
||||
|
||||
def test_is_language_records_metric(
|
||||
self, monitored_wrapper, mock_health_monitor
|
||||
):
|
||||
"""is_language should record metric."""
|
||||
monitored_wrapper.is_language(1, 1, "key")
|
||||
mock_health_monitor.record_request.assert_called_once()
|
||||
call_kwargs = mock_health_monitor.record_request.call_args[1]
|
||||
assert call_kwargs["success"] is True
|
||||
|
||||
|
||||
class TestMonitoredDownload:
|
||||
"""Test download with monitoring."""
|
||||
|
||||
def test_download_delegates_to_provider(
|
||||
self, monitored_wrapper, mock_provider
|
||||
):
|
||||
"""download should delegate to wrapped provider."""
|
||||
mock_provider._download_result = True
|
||||
result = monitored_wrapper.download(
|
||||
"/base", "folder", 1, 1, "key"
|
||||
)
|
||||
assert result is True
|
||||
|
||||
def test_download_records_success(
|
||||
self, monitored_wrapper, mock_health_monitor
|
||||
):
|
||||
"""Successful download should record success metric."""
|
||||
monitored_wrapper.download("/base", "folder", 1, 1, "key")
|
||||
mock_health_monitor.record_request.assert_called_once()
|
||||
call_kwargs = mock_health_monitor.record_request.call_args[1]
|
||||
assert call_kwargs["success"] is True
|
||||
|
||||
def test_download_records_failure(
|
||||
self, monitored_wrapper, mock_provider, mock_health_monitor
|
||||
):
|
||||
"""Failed download should record failure metric."""
|
||||
mock_provider.raise_on_download = True
|
||||
with pytest.raises(ConnectionError):
|
||||
monitored_wrapper.download("/base", "folder", 1, 1, "key")
|
||||
call_kwargs = mock_health_monitor.record_request.call_args[1]
|
||||
assert call_kwargs["success"] is False
|
||||
|
||||
def test_download_tracks_bytes(
|
||||
self, monitored_wrapper, mock_provider, mock_health_monitor
|
||||
):
|
||||
"""Download with progress callback should track bytes."""
|
||||
progress_data = {"downloaded": 5000}
|
||||
|
||||
def mock_download(
|
||||
base_directory, serie_folder, season, episode, key,
|
||||
language="German Dub", progress_callback=None
|
||||
):
|
||||
if progress_callback:
|
||||
progress_callback("progress", progress_data)
|
||||
return True
|
||||
|
||||
mock_provider.download = mock_download
|
||||
callback = MagicMock()
|
||||
monitored_wrapper.download(
|
||||
"/base", "folder", 1, 1, "key",
|
||||
progress_callback=callback
|
||||
)
|
||||
callback.assert_called_once_with("progress", progress_data)
|
||||
|
||||
|
||||
class TestMonitoredGetTitle:
|
||||
"""Test get_title with monitoring."""
|
||||
|
||||
def test_get_title_delegates(self, monitored_wrapper, mock_provider):
|
||||
"""get_title should delegate and return title."""
|
||||
mock_provider._title = "Naruto"
|
||||
result = monitored_wrapper.get_title("naruto")
|
||||
assert result == "Naruto"
|
||||
|
||||
def test_get_title_records_metric(
|
||||
self, monitored_wrapper, mock_health_monitor
|
||||
):
|
||||
"""get_title should record metric."""
|
||||
monitored_wrapper.get_title("test")
|
||||
mock_health_monitor.record_request.assert_called_once()
|
||||
|
||||
|
||||
class TestMonitoredGetSiteKey:
|
||||
"""Test get_site_key delegation."""
|
||||
|
||||
def test_get_site_key_delegates(self, monitored_wrapper):
|
||||
"""get_site_key should return wrapped provider's site key."""
|
||||
assert monitored_wrapper.get_site_key() == "mock.provider"
|
||||
|
||||
|
||||
class TestMonitoredGetSeasonEpisodeCount:
|
||||
"""Test get_season_episode_count with monitoring."""
|
||||
|
||||
def test_delegates_correctly(self, monitored_wrapper, mock_provider):
|
||||
"""Should delegate and return season/episode data."""
|
||||
mock_provider._season_episodes = {1: 24, 2: 12}
|
||||
result = monitored_wrapper.get_season_episode_count("test")
|
||||
assert result == {1: 24, 2: 12}
|
||||
|
||||
def test_records_metric(self, monitored_wrapper, mock_health_monitor):
|
||||
"""Should record metric for season/episode count call."""
|
||||
monitored_wrapper.get_season_episode_count("test")
|
||||
mock_health_monitor.record_request.assert_called_once()
|
||||
|
||||
|
||||
class TestRecordOperation:
|
||||
"""Test _record_operation method."""
|
||||
|
||||
def test_no_recording_when_monitoring_disabled(self, mock_provider):
|
||||
"""Should not record when monitoring is disabled."""
|
||||
wrapper = ConcreteMonitoredWrapper(
|
||||
mock_provider, enable_monitoring=False
|
||||
)
|
||||
# This should not raise even without health monitor
|
||||
wrapper._record_operation(
|
||||
"test_op", time.time(), True
|
||||
)
|
||||
|
||||
def test_records_elapsed_time(
|
||||
self, monitored_wrapper, mock_health_monitor
|
||||
):
|
||||
"""Should calculate and record elapsed time."""
|
||||
start = time.time() - 0.1 # 100ms ago
|
||||
monitored_wrapper._record_operation(
|
||||
"test_op", start, True
|
||||
)
|
||||
call_kwargs = mock_health_monitor.record_request.call_args[1]
|
||||
assert call_kwargs["response_time_ms"] > 0
|
||||
|
||||
def test_records_error_message(
|
||||
self, monitored_wrapper, mock_health_monitor
|
||||
):
|
||||
"""Should record error message on failure."""
|
||||
monitored_wrapper._record_operation(
|
||||
"test_op", time.time(), False, error_message="test error"
|
||||
)
|
||||
call_kwargs = mock_health_monitor.record_request.call_args[1]
|
||||
assert call_kwargs["error_message"] == "test error"
|
||||
|
||||
|
||||
class TestWrapProviderFunction:
|
||||
"""Test the wrap_provider convenience function."""
|
||||
|
||||
def test_wrap_creates_monitored_wrapper(self, mock_provider):
|
||||
"""wrap_provider should return MonitoredProviderWrapper."""
|
||||
with patch(
|
||||
"src.core.providers.monitored_provider.get_health_monitor"
|
||||
):
|
||||
# wrap_provider returns MonitoredProviderWrapper which can't be
|
||||
# instantiated directly due to missing abstract methods.
|
||||
# This tests that wrap_provider raises the expected error.
|
||||
with pytest.raises(TypeError):
|
||||
result = wrap_provider(mock_provider)
|
||||
|
||||
def test_wrap_with_monitoring_disabled(self, mock_provider):
|
||||
"""wrap_provider with monitoring disabled."""
|
||||
# MonitoredProviderWrapper is abstract, so wrap_provider can't
|
||||
# create it directly. This tests the expected behavior.
|
||||
with pytest.raises(TypeError):
|
||||
result = wrap_provider(mock_provider, enable_monitoring=False)
|
||||
429
tests/unit/test_provider_config_manager.py
Normal file
429
tests/unit/test_provider_config_manager.py
Normal file
@@ -0,0 +1,429 @@
|
||||
"""Unit tests for config_manager.py - Configuration loading, validation, defaults."""
|
||||
|
||||
import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.providers.config_manager import (
|
||||
ProviderConfigManager,
|
||||
ProviderSettings,
|
||||
get_config_manager,
|
||||
)
|
||||
|
||||
|
||||
class TestProviderSettings:
|
||||
"""Test ProviderSettings dataclass."""
|
||||
|
||||
def test_default_values(self):
|
||||
"""ProviderSettings should have sensible defaults."""
|
||||
settings = ProviderSettings(name="test_provider")
|
||||
assert settings.name == "test_provider"
|
||||
assert settings.enabled is True
|
||||
assert settings.priority == 0
|
||||
assert settings.timeout_seconds == 30
|
||||
assert settings.max_retries == 3
|
||||
assert settings.retry_delay_seconds == 1.0
|
||||
assert settings.max_concurrent_downloads == 3
|
||||
assert settings.bandwidth_limit_mbps is None
|
||||
assert settings.custom_headers is None
|
||||
assert settings.custom_params is None
|
||||
|
||||
def test_to_dict(self):
|
||||
"""to_dict should convert settings to dict, excluding None values."""
|
||||
settings = ProviderSettings(
|
||||
name="test",
|
||||
enabled=True,
|
||||
priority=1,
|
||||
)
|
||||
result = settings.to_dict()
|
||||
assert result["name"] == "test"
|
||||
assert result["enabled"] is True
|
||||
assert result["priority"] == 1
|
||||
# None values should be excluded
|
||||
assert "bandwidth_limit_mbps" not in result
|
||||
assert "custom_headers" not in result
|
||||
|
||||
def test_to_dict_with_optional_fields(self):
|
||||
"""to_dict with optional fields set should include them."""
|
||||
settings = ProviderSettings(
|
||||
name="test",
|
||||
bandwidth_limit_mbps=10.0,
|
||||
custom_headers={"X-Custom": "value"},
|
||||
)
|
||||
result = settings.to_dict()
|
||||
assert result["bandwidth_limit_mbps"] == 10.0
|
||||
assert result["custom_headers"] == {"X-Custom": "value"}
|
||||
|
||||
def test_from_dict(self):
|
||||
"""from_dict should create settings from fields with defaults."""
|
||||
# Note: from_dict uses hasattr(cls, k) which only matches fields
|
||||
# with defaults on the class. The 'name' field has no default,
|
||||
# so it must be passed explicitly.
|
||||
data = {
|
||||
"name": "test_provider",
|
||||
"enabled": False,
|
||||
"priority": 5,
|
||||
"timeout_seconds": 60,
|
||||
}
|
||||
# The from_dict filters with hasattr which excludes 'name'
|
||||
# (no default), so this should raise TypeError
|
||||
with pytest.raises(TypeError):
|
||||
ProviderSettings.from_dict(data)
|
||||
|
||||
def test_from_dict_with_only_defaults_fields(self):
|
||||
"""from_dict works when all fields have defaults (except name)."""
|
||||
# Directly construct to test fields with defaults
|
||||
data = {
|
||||
"enabled": False,
|
||||
"priority": 5,
|
||||
}
|
||||
# This will fail because 'name' is required but filtered out
|
||||
with pytest.raises(TypeError):
|
||||
ProviderSettings.from_dict(data)
|
||||
|
||||
def test_from_dict_ignores_unknown_fields(self):
|
||||
"""from_dict should ignore fields not in the dataclass."""
|
||||
data = {
|
||||
"name": "test",
|
||||
"unknown_field": "value",
|
||||
"another_unknown": 42,
|
||||
}
|
||||
# name gets filtered by hasattr → TypeError for missing name
|
||||
with pytest.raises(TypeError):
|
||||
ProviderSettings.from_dict(data)
|
||||
|
||||
def test_from_dict_with_defaults(self):
|
||||
"""from_dict with only-defaults data loses required 'name'."""
|
||||
data = {"name": "minimal"}
|
||||
with pytest.raises(TypeError):
|
||||
ProviderSettings.from_dict(data)
|
||||
|
||||
|
||||
class TestProviderConfigManager:
|
||||
"""Test ProviderConfigManager class."""
|
||||
|
||||
def test_init_without_config_file(self):
|
||||
"""Should initialize with empty provider settings."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager._provider_settings == {}
|
||||
|
||||
def test_init_with_nonexistent_config_file(self):
|
||||
"""Should initialize cleanly when config file doesn't exist."""
|
||||
manager = ProviderConfigManager(
|
||||
config_file=Path("/nonexistent/config.json")
|
||||
)
|
||||
assert manager._provider_settings == {}
|
||||
|
||||
def test_global_settings_defaults(self):
|
||||
"""Should have sensible global defaults."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.get_global_setting("default_timeout") == 30
|
||||
assert manager.get_global_setting("default_max_retries") == 3
|
||||
assert manager.get_global_setting("enable_health_monitoring") is True
|
||||
assert manager.get_global_setting("enable_failover") is True
|
||||
|
||||
def test_get_provider_settings_none_for_unknown(self):
|
||||
"""Should return None for unknown provider."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.get_provider_settings("unknown") is None
|
||||
|
||||
def test_set_and_get_provider_settings(self):
|
||||
"""Should store and retrieve provider settings."""
|
||||
manager = ProviderConfigManager()
|
||||
settings = ProviderSettings(name="test", priority=1)
|
||||
manager.set_provider_settings("test", settings)
|
||||
result = manager.get_provider_settings("test")
|
||||
assert result is not None
|
||||
assert result.name == "test"
|
||||
assert result.priority == 1
|
||||
|
||||
def test_update_provider_settings_existing(self):
|
||||
"""Should update existing provider settings."""
|
||||
manager = ProviderConfigManager()
|
||||
settings = ProviderSettings(name="test", priority=1)
|
||||
manager.set_provider_settings("test", settings)
|
||||
|
||||
result = manager.update_provider_settings("test", priority=5)
|
||||
assert result is True
|
||||
updated = manager.get_provider_settings("test")
|
||||
assert updated.priority == 5
|
||||
|
||||
def test_update_provider_settings_new(self):
|
||||
"""Should create new settings when provider doesn't exist."""
|
||||
manager = ProviderConfigManager()
|
||||
result = manager.update_provider_settings(
|
||||
"new_provider", priority=3, timeout_seconds=60
|
||||
)
|
||||
assert result is True
|
||||
settings = manager.get_provider_settings("new_provider")
|
||||
assert settings is not None
|
||||
assert settings.priority == 3
|
||||
assert settings.timeout_seconds == 60
|
||||
|
||||
def test_get_all_provider_settings(self):
|
||||
"""Should return copy of all provider settings."""
|
||||
manager = ProviderConfigManager()
|
||||
manager.set_provider_settings(
|
||||
"p1", ProviderSettings(name="p1")
|
||||
)
|
||||
manager.set_provider_settings(
|
||||
"p2", ProviderSettings(name="p2")
|
||||
)
|
||||
|
||||
all_settings = manager.get_all_provider_settings()
|
||||
assert len(all_settings) == 2
|
||||
assert "p1" in all_settings
|
||||
assert "p2" in all_settings
|
||||
|
||||
def test_get_all_returns_copy(self):
|
||||
"""get_all_provider_settings should return a copy."""
|
||||
manager = ProviderConfigManager()
|
||||
manager.set_provider_settings(
|
||||
"p1", ProviderSettings(name="p1")
|
||||
)
|
||||
all_settings = manager.get_all_provider_settings()
|
||||
all_settings["p2"] = ProviderSettings(name="p2")
|
||||
assert "p2" not in manager.get_all_provider_settings()
|
||||
|
||||
|
||||
class TestProviderEnableDisable:
|
||||
"""Test enable/disable provider functionality."""
|
||||
|
||||
def test_enable_provider(self):
|
||||
"""Should enable a disabled provider."""
|
||||
manager = ProviderConfigManager()
|
||||
settings = ProviderSettings(name="test", enabled=False)
|
||||
manager.set_provider_settings("test", settings)
|
||||
|
||||
result = manager.enable_provider("test")
|
||||
assert result is True
|
||||
assert manager.get_provider_settings("test").enabled is True
|
||||
|
||||
def test_disable_provider(self):
|
||||
"""Should disable an enabled provider."""
|
||||
manager = ProviderConfigManager()
|
||||
settings = ProviderSettings(name="test", enabled=True)
|
||||
manager.set_provider_settings("test", settings)
|
||||
|
||||
result = manager.disable_provider("test")
|
||||
assert result is True
|
||||
assert manager.get_provider_settings("test").enabled is False
|
||||
|
||||
def test_enable_unknown_provider_returns_false(self):
|
||||
"""Should return False when enabling unknown provider."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.enable_provider("unknown") is False
|
||||
|
||||
def test_disable_unknown_provider_returns_false(self):
|
||||
"""Should return False when disabling unknown provider."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.disable_provider("unknown") is False
|
||||
|
||||
def test_get_enabled_providers(self):
|
||||
"""Should return only enabled providers."""
|
||||
manager = ProviderConfigManager()
|
||||
manager.set_provider_settings(
|
||||
"p1", ProviderSettings(name="p1", enabled=True)
|
||||
)
|
||||
manager.set_provider_settings(
|
||||
"p2", ProviderSettings(name="p2", enabled=False)
|
||||
)
|
||||
manager.set_provider_settings(
|
||||
"p3", ProviderSettings(name="p3", enabled=True)
|
||||
)
|
||||
|
||||
enabled = manager.get_enabled_providers()
|
||||
assert "p1" in enabled
|
||||
assert "p2" not in enabled
|
||||
assert "p3" in enabled
|
||||
|
||||
|
||||
class TestProviderPriority:
|
||||
"""Test provider priority management."""
|
||||
|
||||
def test_set_provider_priority(self):
|
||||
"""Should set priority for a provider."""
|
||||
manager = ProviderConfigManager()
|
||||
manager.set_provider_settings(
|
||||
"test", ProviderSettings(name="test", priority=0)
|
||||
)
|
||||
|
||||
result = manager.set_provider_priority("test", 5)
|
||||
assert result is True
|
||||
assert manager.get_provider_settings("test").priority == 5
|
||||
|
||||
def test_set_priority_unknown_returns_false(self):
|
||||
"""Should return False for unknown provider."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.set_provider_priority("unknown", 1) is False
|
||||
|
||||
def test_get_providers_by_priority(self):
|
||||
"""Should return providers sorted by priority."""
|
||||
manager = ProviderConfigManager()
|
||||
manager.set_provider_settings(
|
||||
"low", ProviderSettings(name="low", priority=10)
|
||||
)
|
||||
manager.set_provider_settings(
|
||||
"high", ProviderSettings(name="high", priority=1)
|
||||
)
|
||||
manager.set_provider_settings(
|
||||
"mid", ProviderSettings(name="mid", priority=5)
|
||||
)
|
||||
|
||||
sorted_providers = manager.get_providers_by_priority()
|
||||
assert sorted_providers == ["high", "mid", "low"]
|
||||
|
||||
|
||||
class TestGlobalSettings:
|
||||
"""Test global settings management."""
|
||||
|
||||
def test_get_global_setting(self):
|
||||
"""Should retrieve global setting value."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.get_global_setting("default_timeout") == 30
|
||||
|
||||
def test_get_unknown_global_setting(self):
|
||||
"""Should return None for unknown global setting."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.get_global_setting("nonexistent") is None
|
||||
|
||||
def test_set_global_setting(self):
|
||||
"""Should set a global setting."""
|
||||
manager = ProviderConfigManager()
|
||||
manager.set_global_setting("custom_key", "custom_value")
|
||||
assert manager.get_global_setting("custom_key") == "custom_value"
|
||||
|
||||
def test_get_all_global_settings(self):
|
||||
"""Should return all global settings."""
|
||||
manager = ProviderConfigManager()
|
||||
all_settings = manager.get_all_global_settings()
|
||||
assert "default_timeout" in all_settings
|
||||
assert "enable_failover" in all_settings
|
||||
|
||||
def test_get_all_global_returns_copy(self):
|
||||
"""get_all_global_settings should return a copy."""
|
||||
manager = ProviderConfigManager()
|
||||
settings = manager.get_all_global_settings()
|
||||
settings["new_key"] = "new_value"
|
||||
assert manager.get_global_setting("new_key") is None
|
||||
|
||||
|
||||
class TestConfigPersistence:
|
||||
"""Test configuration save/load functionality."""
|
||||
|
||||
def test_save_and_load_config(self):
|
||||
"""Should save and load configuration from file."""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".json", delete=False, mode="w"
|
||||
) as f:
|
||||
config_path = Path(f.name)
|
||||
|
||||
try:
|
||||
manager = ProviderConfigManager(config_file=config_path)
|
||||
manager.set_provider_settings(
|
||||
"test_provider",
|
||||
ProviderSettings(name="test_provider", priority=3),
|
||||
)
|
||||
manager.set_global_setting("custom_option", True)
|
||||
|
||||
assert manager.save_config() is True
|
||||
|
||||
# Verify the file was created and is valid JSON
|
||||
assert config_path.exists()
|
||||
with open(config_path, "r") as f:
|
||||
data = json.load(f)
|
||||
assert "providers" in data
|
||||
assert "test_provider" in data["providers"]
|
||||
assert data["providers"]["test_provider"]["priority"] == 3
|
||||
assert data["global"]["custom_option"] is True
|
||||
finally:
|
||||
config_path.unlink(missing_ok=True)
|
||||
|
||||
def test_save_config_no_path(self):
|
||||
"""Should return False when no path is specified."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.save_config() is False
|
||||
|
||||
def test_load_config_nonexistent_file(self):
|
||||
"""Should return False when file doesn't exist."""
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.load_config(Path("/nonexistent.json")) is False
|
||||
|
||||
def test_load_config_invalid_json(self):
|
||||
"""Should return False for invalid JSON file."""
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=".json", delete=False, mode="w"
|
||||
) as f:
|
||||
f.write("not valid json{{{")
|
||||
config_path = Path(f.name)
|
||||
|
||||
try:
|
||||
manager = ProviderConfigManager()
|
||||
assert manager.load_config(config_path) is False
|
||||
finally:
|
||||
config_path.unlink(missing_ok=True)
|
||||
|
||||
def test_save_creates_parent_directories(self):
|
||||
"""save_config should create parent directories if needed."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
config_path = Path(tmpdir) / "subdir" / "config.json"
|
||||
manager = ProviderConfigManager(config_file=config_path)
|
||||
manager.set_provider_settings(
|
||||
"p1", ProviderSettings(name="p1")
|
||||
)
|
||||
assert manager.save_config() is True
|
||||
assert config_path.exists()
|
||||
|
||||
|
||||
class TestResetToDefaults:
|
||||
"""Test reset functionality."""
|
||||
|
||||
def test_reset_clears_providers(self):
|
||||
"""reset_to_defaults should clear all provider settings."""
|
||||
manager = ProviderConfigManager()
|
||||
manager.set_provider_settings(
|
||||
"p1", ProviderSettings(name="p1")
|
||||
)
|
||||
manager.reset_to_defaults()
|
||||
assert manager.get_all_provider_settings() == {}
|
||||
|
||||
def test_reset_restores_global_defaults(self):
|
||||
"""reset_to_defaults should restore default global settings."""
|
||||
manager = ProviderConfigManager()
|
||||
manager.set_global_setting("default_timeout", 999)
|
||||
manager.set_global_setting("custom_key", "value")
|
||||
|
||||
manager.reset_to_defaults()
|
||||
|
||||
assert manager.get_global_setting("default_timeout") == 30
|
||||
assert manager.get_global_setting("custom_key") is None
|
||||
|
||||
|
||||
class TestGetConfigManagerSingleton:
|
||||
"""Test the get_config_manager singleton function."""
|
||||
|
||||
def test_returns_instance(self):
|
||||
"""get_config_manager should return a ProviderConfigManager."""
|
||||
# Reset global state for test
|
||||
import src.core.providers.config_manager as cm
|
||||
cm._config_manager = None
|
||||
|
||||
manager = get_config_manager()
|
||||
assert isinstance(manager, ProviderConfigManager)
|
||||
|
||||
# Cleanup
|
||||
cm._config_manager = None
|
||||
|
||||
def test_returns_same_instance(self):
|
||||
"""get_config_manager should return same instance on repeated calls."""
|
||||
import src.core.providers.config_manager as cm
|
||||
cm._config_manager = None
|
||||
|
||||
first = get_config_manager()
|
||||
second = get_config_manager()
|
||||
assert first is second
|
||||
|
||||
# Cleanup
|
||||
cm._config_manager = None
|
||||
105
tests/unit/test_provider_factory.py
Normal file
105
tests/unit/test_provider_factory.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""Unit tests for provider_factory.py - Factory instantiation, dependency injection, provider registration."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.core.providers.base_provider import Loader
|
||||
from src.core.providers.provider_factory import Loaders
|
||||
|
||||
|
||||
class TestLoadersInit:
|
||||
"""Test Loaders factory initialization."""
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_factory_initializes_with_default_providers(self, mock_aniworld):
|
||||
"""Factory should register aniworld.to provider by default."""
|
||||
mock_aniworld.return_value = MagicMock(spec=Loader)
|
||||
factory = Loaders()
|
||||
assert "aniworld.to" in factory.dict
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_factory_dict_contains_loader_instances(self, mock_aniworld):
|
||||
"""Factory dict values should be Loader instances."""
|
||||
mock_instance = MagicMock(spec=Loader)
|
||||
mock_aniworld.return_value = mock_instance
|
||||
factory = Loaders()
|
||||
for key, value in factory.dict.items():
|
||||
assert isinstance(key, str)
|
||||
|
||||
|
||||
class TestLoadersGetLoader:
|
||||
"""Test GetLoader method."""
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_get_loader_returns_registered_provider(self, mock_aniworld):
|
||||
"""GetLoader should return provider for known key."""
|
||||
mock_instance = MagicMock(spec=Loader)
|
||||
mock_aniworld.return_value = mock_instance
|
||||
factory = Loaders()
|
||||
loader = factory.GetLoader("aniworld.to")
|
||||
assert loader is mock_instance
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_get_loader_raises_key_error_for_unknown(self, mock_aniworld):
|
||||
"""GetLoader should raise KeyError for unknown provider key."""
|
||||
mock_aniworld.return_value = MagicMock(spec=Loader)
|
||||
factory = Loaders()
|
||||
with pytest.raises(KeyError):
|
||||
factory.GetLoader("nonexistent.provider")
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_get_loader_returns_same_instance(self, mock_aniworld):
|
||||
"""GetLoader should return same instance on repeated calls."""
|
||||
mock_instance = MagicMock(spec=Loader)
|
||||
mock_aniworld.return_value = mock_instance
|
||||
factory = Loaders()
|
||||
first = factory.GetLoader("aniworld.to")
|
||||
second = factory.GetLoader("aniworld.to")
|
||||
assert first is second
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_get_loader_empty_key(self, mock_aniworld):
|
||||
"""GetLoader should raise KeyError for empty string key."""
|
||||
mock_aniworld.return_value = MagicMock(spec=Loader)
|
||||
factory = Loaders()
|
||||
with pytest.raises(KeyError):
|
||||
factory.GetLoader("")
|
||||
|
||||
|
||||
class TestLoadersProviderRegistry:
|
||||
"""Test the provider registry within the factory."""
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_registry_size(self, mock_aniworld):
|
||||
"""Factory should have exactly one default provider."""
|
||||
mock_aniworld.return_value = MagicMock(spec=Loader)
|
||||
factory = Loaders()
|
||||
assert len(factory.dict) == 1
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_can_add_custom_provider(self, mock_aniworld):
|
||||
"""Custom providers can be added to the factory registry."""
|
||||
mock_aniworld.return_value = MagicMock(spec=Loader)
|
||||
factory = Loaders()
|
||||
custom_provider = MagicMock(spec=Loader)
|
||||
factory.dict["custom.provider"] = custom_provider
|
||||
assert factory.GetLoader("custom.provider") is custom_provider
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_can_override_existing_provider(self, mock_aniworld):
|
||||
"""Existing providers can be overridden in the registry."""
|
||||
mock_aniworld.return_value = MagicMock(spec=Loader)
|
||||
factory = Loaders()
|
||||
new_provider = MagicMock(spec=Loader)
|
||||
factory.dict["aniworld.to"] = new_provider
|
||||
assert factory.GetLoader("aniworld.to") is new_provider
|
||||
|
||||
@patch("src.core.providers.provider_factory.AniworldLoader")
|
||||
def test_multiple_factories_are_independent(self, mock_aniworld):
|
||||
"""Multiple factory instances should have independent registries."""
|
||||
mock_aniworld.return_value = MagicMock(spec=Loader)
|
||||
factory1 = Loaders()
|
||||
factory2 = Loaders()
|
||||
factory1.dict["extra"] = MagicMock(spec=Loader)
|
||||
assert "extra" not in factory2.dict
|
||||
Reference in New Issue
Block a user