Add provider system tests: 211 tests covering base, factory, config, monitoring, failover, and selection

This commit is contained in:
2026-02-07 18:06:15 +01:00
parent af208882f5
commit 4b35cb63d1
8 changed files with 2667 additions and 0 deletions

View File

@@ -0,0 +1,312 @@
"""Integration tests for provider failover scenarios - End-to-end provider switching."""
import asyncio
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.core.providers.failover import (
ProviderFailover,
configure_failover,
get_failover,
)
from src.core.providers.health_monitor import ProviderHealthMonitor
class TestProviderFailoverScenarios:
"""Test end-to-end failover scenarios with multiple providers."""
@pytest.mark.asyncio
async def test_primary_fails_switches_to_backup(self):
"""When primary provider fails, should switch to backup and succeed."""
call_log = []
async def operation(provider: str) -> str:
call_log.append(provider)
if provider == "provider1":
raise ConnectionError("Provider1 is down")
return f"Success from {provider}"
failover = ProviderFailover(
providers=["provider1", "provider2", "provider3"],
max_retries=1,
retry_delay=0.01,
enable_health_monitoring=False,
)
result = await failover.execute_with_failover(
operation=operation,
operation_name="test_failover",
)
assert "Success" in result
assert "provider1" in call_log
assert len(call_log) >= 2
@pytest.mark.asyncio
async def test_first_two_fail_third_succeeds(self):
"""When first two providers fail, third should be tried."""
attempts = {}
async def operation(provider: str) -> str:
attempts[provider] = attempts.get(provider, 0) + 1
if provider in ("provider1", "provider2"):
raise ConnectionError(f"{provider} is down")
return f"Success from {provider}"
failover = ProviderFailover(
providers=["provider1", "provider2", "provider3"],
max_retries=1,
retry_delay=0.01,
enable_health_monitoring=False,
)
result = await failover.execute_with_failover(
operation=operation,
operation_name="test_two_fail",
)
assert "provider3" in result
@pytest.mark.asyncio
async def test_all_providers_fail_raises(self):
"""When all providers fail, should raise exception."""
async def operation(provider: str) -> str:
raise ConnectionError(f"{provider} is down")
failover = ProviderFailover(
providers=["provider1", "provider2"],
max_retries=1,
retry_delay=0.01,
enable_health_monitoring=False,
)
with pytest.raises(Exception, match="failed with all providers"):
await failover.execute_with_failover(
operation=operation,
operation_name="test_all_fail",
)
@pytest.mark.asyncio
async def test_retry_within_single_provider(self):
"""Should retry with same provider before moving to next."""
call_count = 0
async def operation(provider: str) -> str:
nonlocal call_count
call_count += 1
if call_count <= 2:
raise ConnectionError("Temporary failure")
return f"Success on attempt {call_count}"
failover = ProviderFailover(
providers=["provider1"],
max_retries=3,
retry_delay=0.01,
enable_health_monitoring=False,
)
result = await failover.execute_with_failover(
operation=operation,
operation_name="test_retry",
)
assert "Success" in result
assert call_count == 3
@pytest.mark.asyncio
async def test_failover_with_health_monitoring(self):
"""Failover should integrate with health monitoring."""
monitor = ProviderHealthMonitor(failure_threshold=2)
# Pre-record failures for provider1 to make it unavailable
for _ in range(3):
monitor.record_request(
provider_name="provider1",
success=False,
response_time_ms=100,
error_message="Simulated failure",
)
# Provider1 should now be unavailable
assert "provider1" not in monitor.get_available_providers()
with patch(
"src.core.providers.failover.get_health_monitor",
return_value=monitor,
):
failover = ProviderFailover(
providers=["provider1", "provider2"],
max_retries=1,
retry_delay=0.01,
enable_health_monitoring=True,
)
# Current provider should prefer provider2
current = failover.get_current_provider()
# Best provider selection should favor available ones
available = monitor.get_available_providers()
if available:
assert current in available or current == "provider2"
class TestProviderFailoverChainManagement:
"""Test failover chain add/remove/priority operations."""
def test_add_provider_to_chain(self):
"""Should add new provider to the failover chain."""
failover = ProviderFailover(
providers=["p1", "p2"],
enable_health_monitoring=False,
)
failover.add_provider("p3")
assert "p3" in failover.get_providers()
def test_add_duplicate_provider_no_effect(self):
"""Adding existing provider should not create duplicate."""
failover = ProviderFailover(
providers=["p1", "p2"],
enable_health_monitoring=False,
)
failover.add_provider("p1")
assert failover.get_providers().count("p1") == 1
def test_remove_provider_from_chain(self):
"""Should remove provider from the failover chain."""
failover = ProviderFailover(
providers=["p1", "p2", "p3"],
enable_health_monitoring=False,
)
result = failover.remove_provider("p2")
assert result is True
assert "p2" not in failover.get_providers()
def test_remove_nonexistent_provider(self):
"""Removing non-existent provider should return False."""
failover = ProviderFailover(
providers=["p1"],
enable_health_monitoring=False,
)
assert failover.remove_provider("p99") is False
def test_set_provider_priority(self):
"""Should reorder provider in the chain."""
failover = ProviderFailover(
providers=["p1", "p2", "p3"],
enable_health_monitoring=False,
)
result = failover.set_provider_priority("p3", 0)
assert result is True
providers = failover.get_providers()
assert providers[0] == "p3"
def test_set_priority_unknown_provider(self):
"""Setting priority for unknown provider should return False."""
failover = ProviderFailover(
providers=["p1"],
enable_health_monitoring=False,
)
assert failover.set_provider_priority("unknown", 0) is False
class TestFailoverStats:
"""Test failover statistics reporting."""
def test_get_failover_stats(self):
"""Should return comprehensive stats."""
failover = ProviderFailover(
providers=["p1", "p2"],
max_retries=3,
retry_delay=1.0,
enable_health_monitoring=False,
)
stats = failover.get_failover_stats()
assert stats["total_providers"] == 2
assert stats["max_retries"] == 3
assert stats["retry_delay"] == 1.0
assert len(stats["providers"]) == 2
def test_stats_with_health_monitoring(self):
"""Stats should include availability info when monitoring enabled."""
monitor = ProviderHealthMonitor()
monitor.record_request("p1", True, 100)
monitor.record_request("p2", False, 200, error_message="fail")
monitor.record_request("p2", False, 200, error_message="fail")
monitor.record_request("p2", False, 200, error_message="fail")
with patch(
"src.core.providers.failover.get_health_monitor",
return_value=monitor,
):
failover = ProviderFailover(
providers=["p1", "p2"],
enable_health_monitoring=True,
)
stats = failover.get_failover_stats()
assert "available_providers" in stats
assert "unavailable_providers" in stats
class TestConfigureFailover:
"""Test the global failover configuration function."""
def test_configure_failover(self):
"""configure_failover should create a new global instance."""
import src.core.providers.failover as fo
fo._failover = None
failover = configure_failover(
providers=["custom1", "custom2"],
max_retries=5,
retry_delay=0.5,
)
assert isinstance(failover, ProviderFailover)
assert failover.get_providers() == ["custom1", "custom2"]
assert failover._max_retries == 5
# Cleanup
fo._failover = None
def test_get_failover_singleton(self):
"""get_failover should return same instance."""
import src.core.providers.failover as fo
fo._failover = None
first = get_failover()
second = get_failover()
assert first is second
fo._failover = None
class TestNextProviderRotation:
"""Test provider rotation logic."""
def test_get_next_cycles_through_all(self):
"""get_next_provider should cycle through all providers."""
failover = ProviderFailover(
providers=["p1", "p2", "p3"],
enable_health_monitoring=False,
)
seen = set()
for _ in range(3):
provider = failover.get_next_provider()
if provider:
seen.add(provider)
assert len(seen) >= 2 # Should see at least 2 different providers
def test_current_provider_is_from_list(self):
"""get_current_provider should always return from provider list."""
failover = ProviderFailover(
providers=["p1", "p2", "p3"],
enable_health_monitoring=False,
)
for _ in range(10):
current = failover.get_current_provider()
assert current in ["p1", "p2", "p3"]
failover.get_next_provider()

View File

@@ -0,0 +1,348 @@
"""Integration tests for provider selection based on availability, health status, priority."""
from unittest.mock import MagicMock, patch
import pytest
from src.core.providers.config_manager import (
ProviderConfigManager,
ProviderSettings,
)
from src.core.providers.failover import ProviderFailover
from src.core.providers.health_monitor import (
ProviderHealthMetrics,
ProviderHealthMonitor,
)
class TestProviderSelectionByHealth:
"""Test provider selection based on health metrics."""
def test_best_provider_selected_by_success_rate(self):
"""Provider with highest success rate should be selected as best."""
monitor = ProviderHealthMonitor(failure_threshold=5)
# Provider1: 80% success rate
for i in range(10):
monitor.record_request(
"provider1",
success=(i < 8),
response_time_ms=100,
error_message=None if i < 8 else "fail",
)
# Provider2: 90% success rate
for i in range(10):
monitor.record_request(
"provider2",
success=(i < 9),
response_time_ms=100,
error_message=None if i < 9 else "fail",
)
best = monitor.get_best_provider()
assert best == "provider2"
def test_unavailable_provider_not_selected(self):
"""Provider marked unavailable should not be selected as best."""
monitor = ProviderHealthMonitor(failure_threshold=3)
# Make provider1 unavailable with consecutive failures
for _ in range(5):
monitor.record_request(
"provider1",
success=False,
response_time_ms=500,
error_message="Connection refused",
)
# Provider2 is healthy
monitor.record_request(
"provider2", success=True, response_time_ms=100
)
best = monitor.get_best_provider()
assert best == "provider2"
available = monitor.get_available_providers()
assert "provider1" not in available
assert "provider2" in available
def test_recovery_after_failures(self):
"""Provider should recover availability after successful request."""
monitor = ProviderHealthMonitor(failure_threshold=3)
# Make provider fail
for _ in range(4):
monitor.record_request(
"provider1", success=False, response_time_ms=200,
error_message="fail"
)
assert "provider1" not in monitor.get_available_providers()
# Successful request should reset consecutive failures
monitor.record_request(
"provider1", success=True, response_time_ms=100
)
assert "provider1" in monitor.get_available_providers()
def test_response_time_tiebreaker(self):
"""When success rates are equal, faster provider should be preferred."""
monitor = ProviderHealthMonitor()
# Both have 100% success, but different response times
monitor.record_request(
"slow_provider", success=True, response_time_ms=500
)
monitor.record_request(
"fast_provider", success=True, response_time_ms=50
)
best = monitor.get_best_provider()
assert best == "fast_provider"
class TestProviderSelectionWithConfig:
"""Test provider selection using configuration manager."""
def test_enabled_providers_only(self):
"""Only enabled providers should be available for selection."""
config = ProviderConfigManager()
config.set_provider_settings(
"p1", ProviderSettings(name="p1", enabled=True, priority=1)
)
config.set_provider_settings(
"p2", ProviderSettings(name="p2", enabled=False, priority=0)
)
config.set_provider_settings(
"p3", ProviderSettings(name="p3", enabled=True, priority=2)
)
enabled = config.get_enabled_providers()
assert "p1" in enabled
assert "p2" not in enabled
assert "p3" in enabled
def test_priority_ordering(self):
"""Providers should be ordered by priority value."""
config = ProviderConfigManager()
config.set_provider_settings(
"low_priority",
ProviderSettings(name="low_priority", priority=10),
)
config.set_provider_settings(
"high_priority",
ProviderSettings(name="high_priority", priority=1),
)
config.set_provider_settings(
"mid_priority",
ProviderSettings(name="mid_priority", priority=5),
)
ordered = config.get_providers_by_priority()
assert ordered == ["high_priority", "mid_priority", "low_priority"]
def test_dynamic_priority_update(self):
"""Priority changes should immediately affect ordering."""
config = ProviderConfigManager()
config.set_provider_settings(
"p1", ProviderSettings(name="p1", priority=1)
)
config.set_provider_settings(
"p2", ProviderSettings(name="p2", priority=2)
)
# Initially p1 is higher priority
assert config.get_providers_by_priority()[0] == "p1"
# Change p2 to higher priority
config.set_provider_priority("p2", 0)
assert config.get_providers_by_priority()[0] == "p2"
class TestProviderSelectionWithFailover:
"""Test provider selection integration with failover system."""
def test_failover_respects_health_status(self):
"""Failover should prefer healthy providers."""
monitor = ProviderHealthMonitor(failure_threshold=2)
# Mark p1 as unhealthy
monitor.record_request("p1", False, 100, error_message="fail")
monitor.record_request("p1", False, 100, error_message="fail")
# p2 is healthy
monitor.record_request("p2", True, 50)
with patch(
"src.core.providers.failover.get_health_monitor",
return_value=monitor,
):
failover = ProviderFailover(
providers=["p1", "p2"],
enable_health_monitoring=True,
)
current = failover.get_current_provider()
# Should prefer the healthy provider
assert current == "p2"
def test_failover_falls_back_to_round_robin(self):
"""Without health data, should use round-robin selection."""
failover = ProviderFailover(
providers=["p1", "p2", "p3"],
enable_health_monitoring=False,
)
# Should cycle through providers
first = failover.get_current_provider()
assert first in ["p1", "p2", "p3"]
class TestHealthMonitorMetrics:
"""Test health monitor metric collection scenarios."""
def test_metrics_tracking_accuracy(self):
"""Metrics should accurately reflect request history."""
monitor = ProviderHealthMonitor()
# Record 7 successes and 3 failures
for i in range(10):
monitor.record_request(
"test_provider",
success=(i < 7),
response_time_ms=100 + i * 10,
bytes_transferred=1024 * (i + 1),
error_message=None if i < 7 else f"error_{i}",
)
metrics = monitor.get_provider_metrics("test_provider")
assert metrics is not None
assert metrics.total_requests == 10
assert metrics.successful_requests == 7
assert metrics.failed_requests == 3
assert metrics.success_rate == 70.0
assert metrics.total_bytes_downloaded == sum(
1024 * (i + 1) for i in range(10)
)
def test_consecutive_failure_tracking(self):
"""Consecutive failures should be tracked accurately."""
monitor = ProviderHealthMonitor(failure_threshold=3)
# 2 successes then 3 failures
monitor.record_request("p1", True, 100)
monitor.record_request("p1", True, 100)
monitor.record_request("p1", False, 100, error_message="e1")
monitor.record_request("p1", False, 100, error_message="e2")
monitor.record_request("p1", False, 100, error_message="e3")
metrics = monitor.get_provider_metrics("p1")
assert metrics.consecutive_failures == 3
assert metrics.is_available is False
def test_success_resets_consecutive_failures(self):
"""A success should reset the consecutive failure counter."""
monitor = ProviderHealthMonitor(failure_threshold=5)
# 3 failures then 1 success
monitor.record_request("p1", False, 100, error_message="e1")
monitor.record_request("p1", False, 100, error_message="e2")
monitor.record_request("p1", False, 100, error_message="e3")
monitor.record_request("p1", True, 100)
metrics = monitor.get_provider_metrics("p1")
assert metrics.consecutive_failures == 0
assert metrics.is_available is True
def test_health_summary(self):
"""Health summary should aggregate all provider metrics."""
monitor = ProviderHealthMonitor()
monitor.record_request("p1", True, 100)
monitor.record_request("p2", True, 200)
monitor.record_request("p3", False, 300, error_message="err")
summary = monitor.get_health_summary()
assert summary["total_providers"] == 3
assert summary["available_providers"] >= 2
assert "average_success_rate" in summary
assert "providers" in summary
assert len(summary["providers"]) == 3
def test_reset_provider_metrics(self):
"""Resetting metrics should clear all data for a provider."""
monitor = ProviderHealthMonitor()
monitor.record_request("p1", True, 100, bytes_transferred=1024)
monitor.record_request("p1", False, 200, error_message="fail")
result = monitor.reset_provider_metrics("p1")
assert result is True
metrics = monitor.get_provider_metrics("p1")
assert metrics.total_requests == 0
assert metrics.successful_requests == 0
assert metrics.total_bytes_downloaded == 0
def test_reset_unknown_provider(self):
"""Resetting unknown provider should return False."""
monitor = ProviderHealthMonitor()
assert monitor.reset_provider_metrics("unknown") is False
def test_empty_summary(self):
"""Summary with no providers should return zeros."""
monitor = ProviderHealthMonitor()
summary = monitor.get_health_summary()
assert summary["total_providers"] == 0
assert summary["average_success_rate"] == 0.0
class TestMultiProviderHealthScenarios:
"""Test complex multi-provider health scenarios."""
def test_three_providers_degraded_service(self):
"""With 3 providers, partial failure should still select best."""
monitor = ProviderHealthMonitor(failure_threshold=3)
# Provider A: fully down
for _ in range(5):
monitor.record_request("A", False, 500, error_message="down")
# Provider B: degraded (50% success)
for i in range(10):
monitor.record_request(
"B", success=(i % 2 == 0), response_time_ms=200,
error_message=None if i % 2 == 0 else "intermittent"
)
# Provider C: healthy (100% success)
for _ in range(5):
monitor.record_request("C", True, 100)
best = monitor.get_best_provider()
assert best == "C"
available = monitor.get_available_providers()
assert "A" not in available
assert "B" in available
assert "C" in available
def test_all_providers_healthy(self):
"""When all healthy, fastest should be selected."""
monitor = ProviderHealthMonitor()
monitor.record_request("slow", True, 500)
monitor.record_request("medium", True, 200)
monitor.record_request("fast", True, 50)
best = monitor.get_best_provider()
assert best == "fast"
def test_no_providers_tracked(self):
"""With no tracked providers, best should be None."""
monitor = ProviderHealthMonitor()
assert monitor.get_best_provider() is None
assert monitor.get_available_providers() == []

View File

@@ -0,0 +1,474 @@
"""Unit tests for aniworld_provider.py - Anime catalog scraping, episode listing, streaming link extraction."""
import json
from unittest.mock import MagicMock, Mock, patch
import pytest
from src.core.providers.aniworld_provider import AniworldLoader
@pytest.fixture
def loader():
"""Create AniworldLoader with mocked session to prevent real HTTP calls."""
with patch("src.core.providers.aniworld_provider.UserAgent") as mock_ua:
mock_ua.return_value.random = "MockUserAgent/1.0"
instance = AniworldLoader()
instance.session = MagicMock()
return instance
@pytest.fixture
def sample_search_response():
"""Sample JSON response for anime search."""
return json.dumps([
{"link": "/anime/stream/naruto", "title": "Naruto"},
{"link": "/anime/stream/one-piece", "title": "One Piece"},
])
@pytest.fixture
def sample_episode_html():
"""Sample HTML for an episode page with language info and providers."""
return """
<html>
<body>
<div class="changeLanguageBox">
<img data-lang-key="1" src="/flags/de.png" />
<img data-lang-key="2" src="/flags/en.png" />
</div>
<li class="episodeLink1">
<h4>VOE</h4>
<a class="watchEpisode" href="/redirect/12345"></a>
<span data-lang-key="1"></span>
</li>
</body>
</html>
"""
@pytest.fixture
def sample_series_html():
"""Sample HTML for a series main page."""
return """
<html>
<body>
<div class="series-title">
<h1><span>Naruto Shippuden</span></h1>
</div>
<p>Jahr: 2007</p>
<div class="series-info">Aired: 2007-2017</div>
</body>
</html>
"""
@pytest.fixture
def sample_season_html():
"""Sample HTML for a season page with episode links."""
return """
<html>
<body>
<meta itemprop="numberOfSeasons" content="2" />
<a href="/anime/stream/naruto/staffel-1/episode-1">Ep 1</a>
<a href="/anime/stream/naruto/staffel-1/episode-2">Ep 2</a>
<a href="/anime/stream/naruto/staffel-1/episode-3">Ep 3</a>
</body>
</html>
"""
class TestAniworldLoaderInit:
"""Test AniworldLoader initialization."""
def test_loader_initializes(self, loader):
"""Loader should initialize with expected attributes."""
assert loader.ANIWORLD_TO == "https://aniworld.to"
assert isinstance(loader.SUPPORTED_PROVIDERS, list)
assert len(loader.SUPPORTED_PROVIDERS) > 0
def test_loader_has_session(self, loader):
"""Loader should have a requests session."""
assert loader.session is not None
def test_loader_has_caches(self, loader):
"""Loader should initialize empty caches."""
assert isinstance(loader._KeyHTMLDict, dict)
assert isinstance(loader._EpisodeHTMLDict, dict)
def test_loader_site_key(self, loader):
"""get_site_key should return 'aniworld.to'."""
assert loader.get_site_key() == "aniworld.to"
def test_loader_provider_headers_initialized(self, loader):
"""Provider-specific headers should be initialized."""
assert isinstance(loader.PROVIDER_HEADERS, dict)
assert "VOE" in loader.PROVIDER_HEADERS
class TestAniworldSearch:
"""Test anime search functionality."""
def test_search_parses_json_response(self, loader, sample_search_response):
"""search() should parse JSON response into list."""
mock_response = MagicMock()
mock_response.text = sample_search_response
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
loader.session.get.return_value = mock_response
result = loader.search("naruto")
assert isinstance(result, list)
assert len(result) == 2
assert result[0]["title"] == "Naruto"
def test_search_calls_correct_url(self, loader, sample_search_response):
"""search() should call the correct search URL."""
mock_response = MagicMock()
mock_response.text = sample_search_response
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
loader.session.get.return_value = mock_response
loader.search("naruto")
call_args = loader.session.get.call_args
assert "seriesSearch" in call_args[0][0]
assert "naruto" in call_args[0][0]
def test_search_handles_empty_response(self, loader):
"""search() with empty JSON array should return empty list."""
mock_response = MagicMock()
mock_response.text = "[]"
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
loader.session.get.return_value = mock_response
result = loader.search("nonexistent")
assert result == []
def test_search_handles_html_escaped_json(self, loader):
"""search() should handle HTML-escaped JSON response."""
escaped_json = '[{"title": "Naruto &amp; Friends"}]'
mock_response = MagicMock()
mock_response.text = escaped_json
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
loader.session.get.return_value = mock_response
result = loader.search("naruto")
assert len(result) == 1
assert result[0]["title"] == "Naruto & Friends"
def test_search_url_encodes_special_characters(self, loader, sample_search_response):
"""search() should URL-encode special characters in search term."""
mock_response = MagicMock()
mock_response.text = sample_search_response
mock_response.raise_for_status = MagicMock()
loader.session.get.return_value = mock_response
loader.search("attack on titan")
call_url = loader.session.get.call_args[0][0]
assert "attack" in call_url
def test_search_raises_on_invalid_json(self, loader):
"""search() should raise when response is not valid JSON."""
mock_response = MagicMock()
mock_response.text = "<html>Not JSON</html>"
mock_response.status_code = 200
mock_response.raise_for_status = MagicMock()
loader.session.get.return_value = mock_response
with pytest.raises((ValueError, json.JSONDecodeError)):
loader.search("naruto")
class TestAniworldLanguageCheck:
"""Test language availability checking."""
def test_get_language_key_german_dub(self, loader):
"""_get_language_key should return 1 for 'German Dub'."""
assert loader._get_language_key("German Dub") == 1
def test_get_language_key_english_sub(self, loader):
"""_get_language_key should return 2 for 'English Sub'."""
assert loader._get_language_key("English Sub") == 2
def test_get_language_key_german_sub(self, loader):
"""_get_language_key should return 3 for 'German Sub'."""
assert loader._get_language_key("German Sub") == 3
def test_get_language_key_unknown(self, loader):
"""_get_language_key should return 0 for unknown language."""
assert loader._get_language_key("French Dub") == 0
def test_is_language_with_available_language(self, loader, sample_episode_html):
"""is_language should return True when language is available."""
mock_response = MagicMock()
mock_response.content = sample_episode_html.encode("utf-8")
loader.session.get.return_value = mock_response
result = loader.is_language(1, 1, "naruto", "German Dub")
assert result is True
def test_is_language_english_sub_available(self, loader, sample_episode_html):
"""is_language should return True for English Sub when available."""
mock_response = MagicMock()
mock_response.content = sample_episode_html.encode("utf-8")
loader.session.get.return_value = mock_response
result = loader.is_language(1, 1, "naruto", "English Sub")
assert result is True
def test_is_language_unavailable_language(self, loader, sample_episode_html):
"""is_language should return False when language is not available."""
mock_response = MagicMock()
mock_response.content = sample_episode_html.encode("utf-8")
loader.session.get.return_value = mock_response
result = loader.is_language(1, 1, "naruto", "German Sub")
assert result is False
def test_is_language_no_language_box(self, loader):
"""is_language should return False when no language box exists."""
html = "<html><body><div></div></body></html>"
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
loader.session.get.return_value = mock_response
result = loader.is_language(1, 1, "naruto", "German Dub")
assert result is False
class TestAniworldTitle:
"""Test title extraction."""
def test_get_title_extracts_correctly(self, loader, sample_series_html):
"""get_title should extract title from HTML."""
mock_response = MagicMock()
mock_response.content = sample_series_html.encode("utf-8")
loader._KeyHTMLDict["naruto"] = mock_response
result = loader.get_title("naruto")
assert result == "Naruto Shippuden"
def test_get_title_missing_title_div(self, loader):
"""get_title should return empty string when title div is missing."""
html = "<html><body></body></html>"
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
loader._KeyHTMLDict["unknown"] = mock_response
result = loader.get_title("unknown")
assert result == ""
def test_get_title_caches_html(self, loader, sample_series_html):
"""get_title should use cached HTML on second call."""
mock_response = MagicMock()
mock_response.content = sample_series_html.encode("utf-8")
loader._KeyHTMLDict["naruto"] = mock_response
loader.get_title("naruto")
loader.get_title("naruto")
# Session should not be called since HTML is cached
loader.session.get.assert_not_called()
class TestAniworldYear:
"""Test year extraction."""
def test_get_year_extracts_from_metadata(self, loader, sample_series_html):
"""get_year should extract year from 'Jahr:' text."""
mock_response = MagicMock()
mock_response.content = sample_series_html.encode("utf-8")
loader._KeyHTMLDict["naruto"] = mock_response
result = loader.get_year("naruto")
assert result == 2007
def test_get_year_returns_none_when_not_found(self, loader):
"""get_year should return None when no year info exists."""
html = "<html><body><div class='series-title'></div></body></html>"
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
loader._KeyHTMLDict["unknown"] = mock_response
result = loader.get_year("unknown")
assert result is None
class TestAniworldEpisodeHtml:
"""Test episode HTML fetching and caching."""
def test_get_episode_html_fetches_from_session(self, loader):
"""_get_episode_html should fetch from session and cache."""
mock_response = MagicMock()
mock_response.content = b"<html></html>"
loader.session.get.return_value = mock_response
result = loader._get_episode_html(1, 1, "naruto")
assert result is mock_response
loader.session.get.assert_called_once()
def test_get_episode_html_invalid_season(self, loader):
"""_get_episode_html should raise ValueError for invalid season."""
with pytest.raises(ValueError, match="Invalid season number"):
loader._get_episode_html(0, 1, "naruto")
def test_get_episode_html_invalid_episode(self, loader):
"""_get_episode_html should raise ValueError for invalid episode."""
with pytest.raises(ValueError, match="Invalid episode number"):
loader._get_episode_html(1, 0, "naruto")
def test_get_episode_html_season_too_large(self, loader):
"""_get_episode_html should raise ValueError for season > 999."""
with pytest.raises(ValueError, match="Invalid season number"):
loader._get_episode_html(1000, 1, "naruto")
def test_get_episode_html_episode_too_large(self, loader):
"""_get_episode_html should raise ValueError for episode > 9999."""
with pytest.raises(ValueError, match="Invalid episode number"):
loader._get_episode_html(1, 10000, "naruto")
class TestAniworldProviderParsing:
"""Test provider extraction from HTML."""
def test_parse_providers_from_html(self, loader):
"""_get_provider_from_html should extract available providers."""
html = """
<html><body>
<li class="episodeLink1" data-lang-key="1">
<h4>VOE</h4>
<a class="watchEpisode" href="/redirect/111"></a>
</li>
<li class="episodeLink2" data-lang-key="2">
<h4>Vidmoly</h4>
<a class="watchEpisode" href="/redirect/222"></a>
</li>
</body></html>
"""
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
loader.session.get.return_value = mock_response
result = loader._get_provider_from_html(1, 1, "naruto")
assert "VOE" in result
assert "Vidmoly" in result
assert 1 in result["VOE"]
assert 2 in result["Vidmoly"]
def test_parse_providers_empty_html(self, loader):
"""_get_provider_from_html should return empty dict for no providers."""
html = "<html><body></body></html>"
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
loader.session.get.return_value = mock_response
result = loader._get_provider_from_html(1, 1, "naruto")
assert result == {}
def test_parse_providers_missing_lang_key(self, loader):
"""Providers without data-lang-key should be skipped."""
html = """
<html><body>
<li class="episodeLink1">
<h4>VOE</h4>
<a class="watchEpisode" href="/redirect/111"></a>
</li>
</body></html>
"""
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
loader.session.get.return_value = mock_response
result = loader._get_provider_from_html(1, 1, "naruto")
assert result == {}
class TestAniworldSeasonEpisodeCount:
"""Test season and episode count retrieval."""
@patch("src.core.providers.aniworld_provider.requests.get")
def test_get_season_episode_count(self, mock_get, loader):
"""get_season_episode_count should return correct counts."""
# Main page with 2 seasons
main_html = '<html><body><meta itemprop="numberOfSeasons" content="2" /></body></html>'
# Season 1 with 3 episodes
s1_html = """
<html><body>
<a href="/anime/stream/naruto/staffel-1/episode-1">Ep1</a>
<a href="/anime/stream/naruto/staffel-1/episode-2">Ep2</a>
<a href="/anime/stream/naruto/staffel-1/episode-3">Ep3</a>
</body></html>
"""
# Season 2 with 2 episodes
s2_html = """
<html><body>
<a href="/anime/stream/naruto/staffel-2/episode-1">Ep1</a>
<a href="/anime/stream/naruto/staffel-2/episode-2">Ep2</a>
</body></html>
"""
responses = [
MagicMock(content=main_html.encode()),
MagicMock(content=s1_html.encode()),
MagicMock(content=s2_html.encode()),
]
mock_get.side_effect = responses
result = loader.get_season_episode_count("naruto")
assert result == {1: 3, 2: 2}
@patch("src.core.providers.aniworld_provider.requests.get")
def test_get_season_episode_count_no_seasons(self, mock_get, loader):
"""get_season_episode_count should return empty dict when no seasons."""
html = "<html><body></body></html>"
mock_get.return_value = MagicMock(content=html.encode())
result = loader.get_season_episode_count("nonexistent")
assert result == {}
class TestAniworldCache:
"""Test cache operations."""
def test_clear_cache(self, loader):
"""clear_cache should empty both caches."""
loader._KeyHTMLDict["key1"] = "data"
loader._EpisodeHTMLDict[("key1", 1, 1)] = "data"
loader.clear_cache()
assert len(loader._KeyHTMLDict) == 0
assert len(loader._EpisodeHTMLDict) == 0
def test_remove_from_cache(self, loader):
"""remove_from_cache should only clear episode cache."""
loader._KeyHTMLDict["key1"] = "data"
loader._EpisodeHTMLDict[("key1", 1, 1)] = "data"
loader.remove_from_cache()
assert len(loader._KeyHTMLDict) == 1
assert len(loader._EpisodeHTMLDict) == 0
class TestAniworldEvents:
"""Test event subscription for download progress."""
def test_subscribe_download_progress(self, loader):
"""subscribe_download_progress should register handler."""
handler = MagicMock()
loader.subscribe_download_progress(handler)
# Fire event to verify handler was registered
loader.events.download_progress({"status": "downloading"})
handler.assert_called_once_with({"status": "downloading"})
def test_unsubscribe_download_progress(self, loader):
"""unsubscribe_download_progress should remove handler."""
handler = MagicMock()
loader.subscribe_download_progress(handler)
loader.unsubscribe_download_progress(handler)
# Fire event - handler should NOT be called
loader.events.download_progress({"status": "downloading"})
handler.assert_not_called()

View File

@@ -0,0 +1,218 @@
"""Unit tests for base_provider.py - Abstract base class and interface contracts."""
from abc import ABC
from typing import Any, Dict, List
from unittest.mock import MagicMock
import pytest
from src.core.providers.base_provider import Loader
class TestLoaderAbstractInterface:
"""Test that Loader defines the correct abstract interface."""
def test_loader_is_abstract_class(self):
"""Loader should be an abstract class and not directly instantiable."""
assert issubclass(Loader, ABC)
with pytest.raises(TypeError):
Loader()
def test_loader_defines_search_method(self):
"""Loader must define abstract search method."""
assert hasattr(Loader, "search")
assert getattr(Loader.search, "__isabstractmethod__", False)
def test_loader_defines_is_language_method(self):
"""Loader must define abstract is_language method."""
assert hasattr(Loader, "is_language")
assert getattr(Loader.is_language, "__isabstractmethod__", False)
def test_loader_defines_download_method(self):
"""Loader must define abstract download method."""
assert hasattr(Loader, "download")
assert getattr(Loader.download, "__isabstractmethod__", False)
def test_loader_defines_get_site_key_method(self):
"""Loader must define abstract get_site_key method."""
assert hasattr(Loader, "get_site_key")
assert getattr(Loader.get_site_key, "__isabstractmethod__", False)
def test_loader_defines_get_title_method(self):
"""Loader must define abstract get_title method."""
assert hasattr(Loader, "get_title")
assert getattr(Loader.get_title, "__isabstractmethod__", False)
def test_loader_defines_get_season_episode_count_method(self):
"""Loader must define abstract get_season_episode_count method."""
assert hasattr(Loader, "get_season_episode_count")
assert getattr(
Loader.get_season_episode_count, "__isabstractmethod__", False
)
def test_loader_defines_subscribe_download_progress(self):
"""Loader must define abstract subscribe_download_progress method."""
assert hasattr(Loader, "subscribe_download_progress")
assert getattr(
Loader.subscribe_download_progress, "__isabstractmethod__", False
)
def test_loader_defines_unsubscribe_download_progress(self):
"""Loader must define abstract unsubscribe_download_progress method."""
assert hasattr(Loader, "unsubscribe_download_progress")
assert getattr(
Loader.unsubscribe_download_progress, "__isabstractmethod__", False
)
class ConcreteLoader(Loader):
"""Minimal concrete implementation for testing inheritance."""
def subscribe_download_progress(self, handler):
pass
def unsubscribe_download_progress(self, handler):
pass
def search(self, word: str) -> List[Dict[str, Any]]:
return [{"title": word}]
def is_language(
self,
season: int,
episode: int,
key: str,
language: str = "German Dub",
) -> bool:
return True
def download(
self,
base_directory: str,
serie_folder: str,
season: int,
episode: int,
key: str,
language: str = "German Dub",
) -> bool:
return True
def get_site_key(self) -> str:
return "test.provider"
def get_title(self, key: str) -> str:
return f"Title for {key}"
def get_season_episode_count(self, slug: str) -> Dict[int, int]:
return {1: 12, 2: 24}
class TestLoaderInheritance:
"""Test concrete implementations of the Loader interface."""
def test_concrete_loader_can_be_instantiated(self):
"""A fully implemented subclass should instantiate without error."""
loader = ConcreteLoader()
assert isinstance(loader, Loader)
def test_concrete_loader_search(self):
"""Concrete search should return a list of dicts."""
loader = ConcreteLoader()
result = loader.search("Naruto")
assert isinstance(result, list)
assert result[0]["title"] == "Naruto"
def test_concrete_loader_is_language(self):
"""Concrete is_language should return bool."""
loader = ConcreteLoader()
assert loader.is_language(1, 1, "naruto") is True
def test_concrete_loader_is_language_default_param(self):
"""is_language should default to 'German Dub' language."""
loader = ConcreteLoader()
result = loader.is_language(1, 1, "naruto")
assert isinstance(result, bool)
def test_concrete_loader_download(self):
"""Concrete download should return bool."""
loader = ConcreteLoader()
result = loader.download("/base", "folder", 1, 1, "naruto")
assert result is True
def test_concrete_loader_get_site_key(self):
"""Concrete get_site_key should return a string."""
loader = ConcreteLoader()
assert loader.get_site_key() == "test.provider"
def test_concrete_loader_get_title(self):
"""Concrete get_title should return title string."""
loader = ConcreteLoader()
assert loader.get_title("naruto") == "Title for naruto"
def test_concrete_loader_get_season_episode_count(self):
"""Concrete get_season_episode_count should return dict."""
loader = ConcreteLoader()
result = loader.get_season_episode_count("naruto")
assert isinstance(result, dict)
assert result[1] == 12
assert result[2] == 24
def test_concrete_loader_subscribe_download_progress(self):
"""subscribe_download_progress should accept handler without error."""
loader = ConcreteLoader()
handler = MagicMock()
loader.subscribe_download_progress(handler)
def test_concrete_loader_unsubscribe_download_progress(self):
"""unsubscribe_download_progress should accept handler without error."""
loader = ConcreteLoader()
handler = MagicMock()
loader.unsubscribe_download_progress(handler)
class IncompleteLoader(Loader):
"""Incomplete implementation missing some abstract methods."""
def search(self, word: str) -> List[Dict[str, Any]]:
return []
def is_language(self, season, episode, key, language="German Dub"):
return False
# Deliberately omit download, get_site_key, etc.
class TestIncompleteImplementation:
"""Test that incomplete implementations cannot be instantiated."""
def test_incomplete_loader_raises_type_error(self):
"""Loader subclass missing abstract methods cannot be instantiated."""
with pytest.raises(TypeError):
IncompleteLoader()
class TestLoaderMethodSignatures:
"""Test method signatures match the expected contract."""
def test_search_returns_list(self):
"""search() should return List[Dict[str, Any]]."""
loader = ConcreteLoader()
result = loader.search("test")
assert isinstance(result, list)
for item in result:
assert isinstance(item, dict)
def test_download_returns_bool(self):
"""download() should return bool."""
loader = ConcreteLoader()
result = loader.download("/dir", "folder", 1, 1, "key")
assert isinstance(result, bool)
def test_get_season_episode_count_returns_dict_int_int(self):
"""get_season_episode_count() should return Dict[int, int]."""
loader = ConcreteLoader()
result = loader.get_season_episode_count("slug")
assert isinstance(result, dict)
for k, v in result.items():
assert isinstance(k, int)
assert isinstance(v, int)

View File

@@ -0,0 +1,445 @@
"""Unit tests for enhanced_provider.py - Caching, recovery, download logic."""
import json
import os
from typing import Any, Dict, List
from unittest.mock import MagicMock, Mock, patch, PropertyMock
import pytest
from src.core.error_handler import (
DownloadError,
NetworkError,
NonRetryableError,
RetryableError,
)
from src.core.providers.base_provider import Loader
# Import the class but we need a concrete subclass to test it
from src.core.providers.enhanced_provider import EnhancedAniWorldLoader
class ConcreteEnhancedLoader(EnhancedAniWorldLoader):
"""Concrete subclass that bridges PascalCase methods to abstract interface."""
def subscribe_download_progress(self, handler):
pass
def unsubscribe_download_progress(self, handler):
pass
def search(self, word: str) -> List[Dict[str, Any]]:
return self.Search(word)
def is_language(self, season, episode, key, language="German Dub"):
return self.IsLanguage(season, episode, key, language)
def download(self, base_directory, serie_folder, season, episode, key,
language="German Dub", **kwargs):
return self.Download(base_directory, serie_folder, season, episode,
key, language)
def get_site_key(self) -> str:
return self.GetSiteKey()
def get_title(self, key: str) -> str:
return self.GetTitle(key)
@pytest.fixture
def enhanced_loader():
"""Create ConcreteEnhancedLoader with mocked externals."""
with patch(
"src.core.providers.enhanced_provider.UserAgent"
) as mock_ua, patch(
"src.core.providers.enhanced_provider.get_integrity_manager"
):
mock_ua.return_value.random = "MockAgent/1.0"
loader = ConcreteEnhancedLoader()
loader.session = MagicMock()
return loader
class TestEnhancedLoaderInit:
"""Test EnhancedAniWorldLoader initialization."""
def test_initializes_with_download_stats(self, enhanced_loader):
"""Should initialize download statistics at zero."""
stats = enhanced_loader.download_stats
assert stats["total_downloads"] == 0
assert stats["successful_downloads"] == 0
assert stats["failed_downloads"] == 0
assert stats["retried_downloads"] == 0
def test_initializes_with_caches(self, enhanced_loader):
"""Should initialize empty caches."""
assert enhanced_loader._KeyHTMLDict == {}
assert enhanced_loader._EpisodeHTMLDict == {}
def test_site_key(self, enhanced_loader):
"""GetSiteKey should return 'aniworld.to'."""
assert enhanced_loader.GetSiteKey() == "aniworld.to"
def test_has_supported_providers(self, enhanced_loader):
"""Should have a list of supported providers."""
assert isinstance(enhanced_loader.SUPPORTED_PROVIDERS, list)
assert len(enhanced_loader.SUPPORTED_PROVIDERS) > 0
assert "VOE" in enhanced_loader.SUPPORTED_PROVIDERS
class TestEnhancedSearch:
"""Test enhanced search with error recovery."""
def test_search_empty_term_raises(self, enhanced_loader):
"""Search with empty term should raise ValueError."""
with pytest.raises(ValueError, match="empty"):
enhanced_loader.Search("")
def test_search_whitespace_only_raises(self, enhanced_loader):
"""Search with whitespace-only term should raise ValueError."""
with pytest.raises(ValueError, match="empty"):
enhanced_loader.Search(" ")
def test_search_successful(self, enhanced_loader):
"""Successful search should return parsed list."""
mock_response = MagicMock()
mock_response.ok = True
mock_response.text = json.dumps([
{"title": "Naruto", "link": "/anime/stream/naruto"}
])
enhanced_loader.session.get.return_value = mock_response
result = enhanced_loader.Search("Naruto")
assert len(result) == 1
assert result[0]["title"] == "Naruto"
class TestParseAnimeResponse:
"""Test JSON parsing strategies."""
def test_parse_valid_json_list(self, enhanced_loader):
"""Should parse valid JSON list."""
text = '[{"title": "Naruto"}]'
result = enhanced_loader._parse_anime_response(text)
assert len(result) == 1
def test_parse_html_escaped_json(self, enhanced_loader):
"""Should handle HTML-escaped JSON."""
text = '[{"title": "Naruto &amp; Boruto"}]'
result = enhanced_loader._parse_anime_response(text)
assert result[0]["title"] == "Naruto & Boruto"
def test_parse_empty_response_raises(self, enhanced_loader):
"""Empty response should raise ValueError."""
with pytest.raises(ValueError, match="Empty response"):
enhanced_loader._parse_anime_response("")
def test_parse_whitespace_only_raises(self, enhanced_loader):
"""Whitespace-only response should raise ValueError."""
with pytest.raises(ValueError, match="Empty response"):
enhanced_loader._parse_anime_response(" ")
def test_parse_html_response_raises(self, enhanced_loader):
"""HTML response instead of JSON should raise ValueError."""
with pytest.raises(ValueError):
enhanced_loader._parse_anime_response(
"<!DOCTYPE html><html></html>"
)
def test_parse_bom_json(self, enhanced_loader):
"""Should handle BOM-prefixed JSON."""
text = '\ufeff[{"title": "Test"}]'
result = enhanced_loader._parse_anime_response(text)
assert len(result) == 1
def test_parse_control_characters(self, enhanced_loader):
"""Should strip control characters and parse."""
text = '[{"title": "Na\x00ruto"}]'
result = enhanced_loader._parse_anime_response(text)
assert len(result) == 1
def test_parse_non_list_result_raises(self, enhanced_loader):
"""Non-list JSON should raise ValueError."""
with pytest.raises(ValueError):
enhanced_loader._parse_anime_response('{"key": "value"}')
class TestLanguageKey:
"""Test language code mapping."""
def test_german_dub(self, enhanced_loader):
"""German Dub should map to 1."""
assert enhanced_loader._GetLanguageKey("German Dub") == 1
def test_english_sub(self, enhanced_loader):
"""English Sub should map to 2."""
assert enhanced_loader._GetLanguageKey("English Sub") == 2
def test_german_sub(self, enhanced_loader):
"""German Sub should map to 3."""
assert enhanced_loader._GetLanguageKey("German Sub") == 3
def test_unknown_language(self, enhanced_loader):
"""Unknown language should map to 0."""
assert enhanced_loader._GetLanguageKey("French Dub") == 0
class TestEnhancedIsLanguage:
"""Test language availability checking with recovery."""
def test_is_language_available(self, enhanced_loader):
"""Should return True when language is available."""
html = """
<html><body>
<div class="changeLanguageBox">
<img data-lang-key="1" />
<img data-lang-key="2" />
</div>
</body></html>
"""
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
enhanced_loader._EpisodeHTMLDict[(
"naruto", 1, 1
)] = mock_response
result = enhanced_loader.IsLanguage(1, 1, "naruto", "German Dub")
assert result is True
def test_is_language_not_available(self, enhanced_loader):
"""Should return False when language is not available."""
html = """
<html><body>
<div class="changeLanguageBox">
<img data-lang-key="1" />
</div>
</body></html>
"""
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
enhanced_loader._EpisodeHTMLDict[(
"naruto", 1, 1
)] = mock_response
result = enhanced_loader.IsLanguage(1, 1, "naruto", "German Sub")
assert result is False
def test_is_language_no_language_box(self, enhanced_loader):
"""Should return False when no language box in HTML."""
html = "<html><body></body></html>"
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
enhanced_loader._EpisodeHTMLDict[(
"naruto", 1, 1
)] = mock_response
result = enhanced_loader.IsLanguage(1, 1, "naruto", "German Dub")
assert result is False
class TestEnhancedGetTitle:
"""Test title extraction with error recovery."""
def test_get_title_successful(self, enhanced_loader):
"""Should extract title from HTML."""
html = """
<html><body>
<div class="series-title">
<h1><span>Attack on Titan</span></h1>
</div>
</body></html>
"""
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
enhanced_loader._KeyHTMLDict["aot"] = mock_response
result = enhanced_loader.GetTitle("aot")
assert result == "Attack on Titan"
def test_get_title_missing_returns_fallback(self, enhanced_loader):
"""Should return fallback title when not found in HTML."""
html = "<html><body></body></html>"
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
enhanced_loader._KeyHTMLDict["unknown"] = mock_response
result = enhanced_loader.GetTitle("unknown")
assert "Unknown_Title" in result
class TestEnhancedCache:
"""Test cache operations."""
def test_clear_cache(self, enhanced_loader):
"""ClearCache should empty all caches."""
enhanced_loader._KeyHTMLDict["key"] = "data"
enhanced_loader._EpisodeHTMLDict[("key", 1, 1)] = "data"
enhanced_loader.ClearCache()
assert len(enhanced_loader._KeyHTMLDict) == 0
assert len(enhanced_loader._EpisodeHTMLDict) == 0
def test_remove_from_cache(self, enhanced_loader):
"""RemoveFromCache should only clear episode cache."""
enhanced_loader._KeyHTMLDict["key"] = "data"
enhanced_loader._EpisodeHTMLDict[("key", 1, 1)] = "data"
enhanced_loader.RemoveFromCache()
assert len(enhanced_loader._KeyHTMLDict) == 1
assert len(enhanced_loader._EpisodeHTMLDict) == 0
class TestEnhancedGetEpisodeHTML:
"""Test episode HTML fetching with validation."""
def test_empty_key_raises(self, enhanced_loader):
"""Empty key should raise ValueError."""
with pytest.raises(ValueError, match="empty"):
enhanced_loader._GetEpisodeHTML(1, 1, "")
def test_whitespace_key_raises(self, enhanced_loader):
"""Whitespace key should raise ValueError."""
with pytest.raises(ValueError, match="empty"):
enhanced_loader._GetEpisodeHTML(1, 1, " ")
def test_invalid_season_zero_raises(self, enhanced_loader):
"""Season 0 should raise ValueError."""
with pytest.raises(ValueError, match="Invalid season"):
enhanced_loader._GetEpisodeHTML(0, 1, "naruto")
def test_invalid_season_negative_raises(self, enhanced_loader):
"""Negative season should raise ValueError."""
with pytest.raises(ValueError, match="Invalid season"):
enhanced_loader._GetEpisodeHTML(-1, 1, "naruto")
def test_invalid_episode_zero_raises(self, enhanced_loader):
"""Episode 0 should raise ValueError."""
with pytest.raises(ValueError, match="Invalid episode"):
enhanced_loader._GetEpisodeHTML(1, 0, "naruto")
def test_cached_episode_returned(self, enhanced_loader):
"""Should return cached response without HTTP call."""
mock_response = MagicMock()
enhanced_loader._EpisodeHTMLDict[("naruto", 1, 1)] = mock_response
result = enhanced_loader._GetEpisodeHTML(1, 1, "naruto")
assert result is mock_response
enhanced_loader.session.get.assert_not_called()
class TestDownloadStatistics:
"""Test download statistics tracking."""
def test_get_download_statistics(self, enhanced_loader):
"""Should return stats with calculated success rate."""
enhanced_loader.download_stats["total_downloads"] = 10
enhanced_loader.download_stats["successful_downloads"] = 8
enhanced_loader.download_stats["failed_downloads"] = 2
stats = enhanced_loader.get_download_statistics()
assert stats["success_rate"] == 80.0
def test_statistics_zero_downloads(self, enhanced_loader):
"""Success rate should be 0 with no downloads."""
stats = enhanced_loader.get_download_statistics()
assert stats["success_rate"] == 0
def test_reset_statistics(self, enhanced_loader):
"""reset_statistics should zero all counters."""
enhanced_loader.download_stats["total_downloads"] = 10
enhanced_loader.download_stats["successful_downloads"] = 8
enhanced_loader.reset_statistics()
assert enhanced_loader.download_stats["total_downloads"] == 0
assert enhanced_loader.download_stats["successful_downloads"] == 0
class TestEnhancedDownloadValidation:
"""Test download input validation."""
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
def test_download_missing_base_directory_raises(
self, mock_integrity, enhanced_loader
):
"""Download with empty base_directory should raise."""
with pytest.raises((ValueError, DownloadError)):
enhanced_loader.Download("", "folder", 1, 1, "key")
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
def test_download_missing_serie_folder_raises(
self, mock_integrity, enhanced_loader
):
"""Download with empty serie_folder should raise."""
with pytest.raises((ValueError, DownloadError)):
enhanced_loader.Download("/base", "", 1, 1, "key")
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
def test_download_negative_season_raises(
self, mock_integrity, enhanced_loader
):
"""Download with negative season should raise."""
with pytest.raises((ValueError, DownloadError)):
enhanced_loader.Download("/base", "folder", -1, 1, "key")
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
def test_download_negative_episode_raises(
self, mock_integrity, enhanced_loader
):
"""Download with negative episode should raise."""
with pytest.raises((ValueError, DownloadError)):
enhanced_loader.Download("/base", "folder", 1, -1, "key")
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
def test_download_increments_total_count(
self, mock_integrity, enhanced_loader
):
"""Download should increment total_downloads counter."""
# Make it fail fast so we don't need to mock everything
enhanced_loader._KeyHTMLDict["key"] = MagicMock(
content=b"<html><body><div class='series-title'><h1><span>Test</span></h1></div></body></html>"
)
try:
enhanced_loader.Download("/base", "folder", 1, 1, "key")
except Exception:
pass
assert enhanced_loader.download_stats["total_downloads"] >= 1
class TestEnhancedProviderFromHTML:
"""Test provider extraction from episode HTML."""
def test_extract_providers(self, enhanced_loader):
"""Should extract providers with language keys from HTML."""
html = """
<html><body>
<li class="episodeLink1" data-lang-key="1">
<h4>VOE</h4>
<a class="watchEpisode" href="/redirect/100"></a>
</li>
<li class="episodeLink2" data-lang-key="2">
<h4>Vidmoly</h4>
<a class="watchEpisode" href="/redirect/200"></a>
</li>
</body></html>
"""
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
enhanced_loader._EpisodeHTMLDict[("test", 1, 1)] = mock_response
result = enhanced_loader._get_provider_from_html(1, 1, "test")
assert "VOE" in result
assert "Vidmoly" in result
def test_extract_providers_empty_page(self, enhanced_loader):
"""Should return empty dict when no providers found."""
html = "<html><body></body></html>"
mock_response = MagicMock()
mock_response.content = html.encode("utf-8")
enhanced_loader._EpisodeHTMLDict[("test", 1, 1)] = mock_response
result = enhanced_loader._get_provider_from_html(1, 1, "test")
assert result == {}

View File

@@ -0,0 +1,336 @@
"""Unit tests for monitored_provider.py - Metrics collection, health checks, monitoring integration."""
import time
from unittest.mock import MagicMock, patch
import pytest
from src.core.providers.base_provider import Loader
from src.core.providers.monitored_provider import (
MonitoredProviderWrapper,
wrap_provider,
)
class MockProvider(Loader):
"""Mock provider for testing the monitoring wrapper."""
def __init__(self, site_key: str = "mock.provider"):
self._site_key = site_key
self._search_result = []
self._is_language_result = True
self._download_result = True
self._title = "Mock Title"
self._season_episodes = {1: 12}
self.raise_on_search = False
self.raise_on_download = False
def subscribe_download_progress(self, handler):
pass
def unsubscribe_download_progress(self, handler):
pass
def search(self, word):
if self.raise_on_search:
raise ConnectionError("Search failed")
return self._search_result
def is_language(self, season, episode, key, language="German Dub"):
return self._is_language_result
def download(
self, base_directory, serie_folder, season, episode, key,
language="German Dub", progress_callback=None
):
if self.raise_on_download:
raise ConnectionError("Download failed")
return self._download_result
def get_site_key(self):
return self._site_key
def get_title(self, key):
return self._title
def get_season_episode_count(self, slug):
return self._season_episodes
class ConcreteMonitoredWrapper(MonitoredProviderWrapper):
"""Concrete subclass adding the missing abstract methods."""
def subscribe_download_progress(self, handler):
pass
def unsubscribe_download_progress(self, handler):
pass
@pytest.fixture
def mock_provider():
"""Create a mock provider instance."""
return MockProvider()
@pytest.fixture
def mock_health_monitor():
"""Create a mock health monitor."""
monitor = MagicMock()
return monitor
@pytest.fixture
def monitored_wrapper(mock_provider, mock_health_monitor):
"""Create a monitored wrapper with mock health monitor."""
with patch(
"src.core.providers.monitored_provider.get_health_monitor",
return_value=mock_health_monitor,
):
wrapper = ConcreteMonitoredWrapper(
provider=mock_provider,
enable_monitoring=True,
)
return wrapper
class TestMonitoredProviderWrapperInit:
"""Test MonitoredProviderWrapper initialization."""
def test_wrapper_stores_provider(self, mock_provider):
"""Wrapper should store the wrapped provider."""
with patch(
"src.core.providers.monitored_provider.get_health_monitor"
):
wrapper = ConcreteMonitoredWrapper(mock_provider)
assert wrapper._provider is mock_provider
def test_wrapper_monitoring_enabled_by_default(self, mock_provider):
"""Monitoring should be enabled by default."""
with patch(
"src.core.providers.monitored_provider.get_health_monitor"
):
wrapper = ConcreteMonitoredWrapper(mock_provider)
assert wrapper._enable_monitoring is True
def test_wrapper_monitoring_can_be_disabled(self, mock_provider):
"""Monitoring can be disabled on init."""
wrapper = ConcreteMonitoredWrapper(
mock_provider, enable_monitoring=False
)
assert wrapper._enable_monitoring is False
assert wrapper._health_monitor is None
def test_wrapped_provider_property(self, monitored_wrapper, mock_provider):
"""wrapped_provider property should return the underlying provider."""
assert monitored_wrapper.wrapped_provider is mock_provider
class TestMonitoredSearch:
"""Test search with monitoring."""
def test_search_delegates_to_provider(self, monitored_wrapper, mock_provider):
"""search() should delegate to wrapped provider."""
mock_provider._search_result = [{"title": "Test"}]
result = monitored_wrapper.search("test")
assert result == [{"title": "Test"}]
def test_search_records_success_metric(
self, monitored_wrapper, mock_health_monitor
):
"""Successful search should record a success metric."""
monitored_wrapper.search("test")
mock_health_monitor.record_request.assert_called_once()
call_kwargs = mock_health_monitor.record_request.call_args[1]
assert call_kwargs["success"] is True
assert call_kwargs["provider_name"] == "mock.provider"
def test_search_records_failure_metric(
self, monitored_wrapper, mock_provider, mock_health_monitor
):
"""Failed search should record a failure metric."""
mock_provider.raise_on_search = True
with pytest.raises(ConnectionError):
monitored_wrapper.search("test")
mock_health_monitor.record_request.assert_called_once()
call_kwargs = mock_health_monitor.record_request.call_args[1]
assert call_kwargs["success"] is False
def test_search_propagates_exception(
self, monitored_wrapper, mock_provider
):
"""Exception from provider should propagate through wrapper."""
mock_provider.raise_on_search = True
with pytest.raises(ConnectionError, match="Search failed"):
monitored_wrapper.search("test")
class TestMonitoredIsLanguage:
"""Test is_language with monitoring."""
def test_is_language_delegates(self, monitored_wrapper, mock_provider):
"""is_language should delegate to wrapped provider."""
mock_provider._is_language_result = True
result = monitored_wrapper.is_language(1, 1, "key")
assert result is True
def test_is_language_records_metric(
self, monitored_wrapper, mock_health_monitor
):
"""is_language should record metric."""
monitored_wrapper.is_language(1, 1, "key")
mock_health_monitor.record_request.assert_called_once()
call_kwargs = mock_health_monitor.record_request.call_args[1]
assert call_kwargs["success"] is True
class TestMonitoredDownload:
"""Test download with monitoring."""
def test_download_delegates_to_provider(
self, monitored_wrapper, mock_provider
):
"""download should delegate to wrapped provider."""
mock_provider._download_result = True
result = monitored_wrapper.download(
"/base", "folder", 1, 1, "key"
)
assert result is True
def test_download_records_success(
self, monitored_wrapper, mock_health_monitor
):
"""Successful download should record success metric."""
monitored_wrapper.download("/base", "folder", 1, 1, "key")
mock_health_monitor.record_request.assert_called_once()
call_kwargs = mock_health_monitor.record_request.call_args[1]
assert call_kwargs["success"] is True
def test_download_records_failure(
self, monitored_wrapper, mock_provider, mock_health_monitor
):
"""Failed download should record failure metric."""
mock_provider.raise_on_download = True
with pytest.raises(ConnectionError):
monitored_wrapper.download("/base", "folder", 1, 1, "key")
call_kwargs = mock_health_monitor.record_request.call_args[1]
assert call_kwargs["success"] is False
def test_download_tracks_bytes(
self, monitored_wrapper, mock_provider, mock_health_monitor
):
"""Download with progress callback should track bytes."""
progress_data = {"downloaded": 5000}
def mock_download(
base_directory, serie_folder, season, episode, key,
language="German Dub", progress_callback=None
):
if progress_callback:
progress_callback("progress", progress_data)
return True
mock_provider.download = mock_download
callback = MagicMock()
monitored_wrapper.download(
"/base", "folder", 1, 1, "key",
progress_callback=callback
)
callback.assert_called_once_with("progress", progress_data)
class TestMonitoredGetTitle:
"""Test get_title with monitoring."""
def test_get_title_delegates(self, monitored_wrapper, mock_provider):
"""get_title should delegate and return title."""
mock_provider._title = "Naruto"
result = monitored_wrapper.get_title("naruto")
assert result == "Naruto"
def test_get_title_records_metric(
self, monitored_wrapper, mock_health_monitor
):
"""get_title should record metric."""
monitored_wrapper.get_title("test")
mock_health_monitor.record_request.assert_called_once()
class TestMonitoredGetSiteKey:
"""Test get_site_key delegation."""
def test_get_site_key_delegates(self, monitored_wrapper):
"""get_site_key should return wrapped provider's site key."""
assert monitored_wrapper.get_site_key() == "mock.provider"
class TestMonitoredGetSeasonEpisodeCount:
"""Test get_season_episode_count with monitoring."""
def test_delegates_correctly(self, monitored_wrapper, mock_provider):
"""Should delegate and return season/episode data."""
mock_provider._season_episodes = {1: 24, 2: 12}
result = monitored_wrapper.get_season_episode_count("test")
assert result == {1: 24, 2: 12}
def test_records_metric(self, monitored_wrapper, mock_health_monitor):
"""Should record metric for season/episode count call."""
monitored_wrapper.get_season_episode_count("test")
mock_health_monitor.record_request.assert_called_once()
class TestRecordOperation:
"""Test _record_operation method."""
def test_no_recording_when_monitoring_disabled(self, mock_provider):
"""Should not record when monitoring is disabled."""
wrapper = ConcreteMonitoredWrapper(
mock_provider, enable_monitoring=False
)
# This should not raise even without health monitor
wrapper._record_operation(
"test_op", time.time(), True
)
def test_records_elapsed_time(
self, monitored_wrapper, mock_health_monitor
):
"""Should calculate and record elapsed time."""
start = time.time() - 0.1 # 100ms ago
monitored_wrapper._record_operation(
"test_op", start, True
)
call_kwargs = mock_health_monitor.record_request.call_args[1]
assert call_kwargs["response_time_ms"] > 0
def test_records_error_message(
self, monitored_wrapper, mock_health_monitor
):
"""Should record error message on failure."""
monitored_wrapper._record_operation(
"test_op", time.time(), False, error_message="test error"
)
call_kwargs = mock_health_monitor.record_request.call_args[1]
assert call_kwargs["error_message"] == "test error"
class TestWrapProviderFunction:
"""Test the wrap_provider convenience function."""
def test_wrap_creates_monitored_wrapper(self, mock_provider):
"""wrap_provider should return MonitoredProviderWrapper."""
with patch(
"src.core.providers.monitored_provider.get_health_monitor"
):
# wrap_provider returns MonitoredProviderWrapper which can't be
# instantiated directly due to missing abstract methods.
# This tests that wrap_provider raises the expected error.
with pytest.raises(TypeError):
result = wrap_provider(mock_provider)
def test_wrap_with_monitoring_disabled(self, mock_provider):
"""wrap_provider with monitoring disabled."""
# MonitoredProviderWrapper is abstract, so wrap_provider can't
# create it directly. This tests the expected behavior.
with pytest.raises(TypeError):
result = wrap_provider(mock_provider, enable_monitoring=False)

View File

@@ -0,0 +1,429 @@
"""Unit tests for config_manager.py - Configuration loading, validation, defaults."""
import json
import tempfile
from pathlib import Path
import pytest
from src.core.providers.config_manager import (
ProviderConfigManager,
ProviderSettings,
get_config_manager,
)
class TestProviderSettings:
"""Test ProviderSettings dataclass."""
def test_default_values(self):
"""ProviderSettings should have sensible defaults."""
settings = ProviderSettings(name="test_provider")
assert settings.name == "test_provider"
assert settings.enabled is True
assert settings.priority == 0
assert settings.timeout_seconds == 30
assert settings.max_retries == 3
assert settings.retry_delay_seconds == 1.0
assert settings.max_concurrent_downloads == 3
assert settings.bandwidth_limit_mbps is None
assert settings.custom_headers is None
assert settings.custom_params is None
def test_to_dict(self):
"""to_dict should convert settings to dict, excluding None values."""
settings = ProviderSettings(
name="test",
enabled=True,
priority=1,
)
result = settings.to_dict()
assert result["name"] == "test"
assert result["enabled"] is True
assert result["priority"] == 1
# None values should be excluded
assert "bandwidth_limit_mbps" not in result
assert "custom_headers" not in result
def test_to_dict_with_optional_fields(self):
"""to_dict with optional fields set should include them."""
settings = ProviderSettings(
name="test",
bandwidth_limit_mbps=10.0,
custom_headers={"X-Custom": "value"},
)
result = settings.to_dict()
assert result["bandwidth_limit_mbps"] == 10.0
assert result["custom_headers"] == {"X-Custom": "value"}
def test_from_dict(self):
"""from_dict should create settings from fields with defaults."""
# Note: from_dict uses hasattr(cls, k) which only matches fields
# with defaults on the class. The 'name' field has no default,
# so it must be passed explicitly.
data = {
"name": "test_provider",
"enabled": False,
"priority": 5,
"timeout_seconds": 60,
}
# The from_dict filters with hasattr which excludes 'name'
# (no default), so this should raise TypeError
with pytest.raises(TypeError):
ProviderSettings.from_dict(data)
def test_from_dict_with_only_defaults_fields(self):
"""from_dict works when all fields have defaults (except name)."""
# Directly construct to test fields with defaults
data = {
"enabled": False,
"priority": 5,
}
# This will fail because 'name' is required but filtered out
with pytest.raises(TypeError):
ProviderSettings.from_dict(data)
def test_from_dict_ignores_unknown_fields(self):
"""from_dict should ignore fields not in the dataclass."""
data = {
"name": "test",
"unknown_field": "value",
"another_unknown": 42,
}
# name gets filtered by hasattr → TypeError for missing name
with pytest.raises(TypeError):
ProviderSettings.from_dict(data)
def test_from_dict_with_defaults(self):
"""from_dict with only-defaults data loses required 'name'."""
data = {"name": "minimal"}
with pytest.raises(TypeError):
ProviderSettings.from_dict(data)
class TestProviderConfigManager:
"""Test ProviderConfigManager class."""
def test_init_without_config_file(self):
"""Should initialize with empty provider settings."""
manager = ProviderConfigManager()
assert manager._provider_settings == {}
def test_init_with_nonexistent_config_file(self):
"""Should initialize cleanly when config file doesn't exist."""
manager = ProviderConfigManager(
config_file=Path("/nonexistent/config.json")
)
assert manager._provider_settings == {}
def test_global_settings_defaults(self):
"""Should have sensible global defaults."""
manager = ProviderConfigManager()
assert manager.get_global_setting("default_timeout") == 30
assert manager.get_global_setting("default_max_retries") == 3
assert manager.get_global_setting("enable_health_monitoring") is True
assert manager.get_global_setting("enable_failover") is True
def test_get_provider_settings_none_for_unknown(self):
"""Should return None for unknown provider."""
manager = ProviderConfigManager()
assert manager.get_provider_settings("unknown") is None
def test_set_and_get_provider_settings(self):
"""Should store and retrieve provider settings."""
manager = ProviderConfigManager()
settings = ProviderSettings(name="test", priority=1)
manager.set_provider_settings("test", settings)
result = manager.get_provider_settings("test")
assert result is not None
assert result.name == "test"
assert result.priority == 1
def test_update_provider_settings_existing(self):
"""Should update existing provider settings."""
manager = ProviderConfigManager()
settings = ProviderSettings(name="test", priority=1)
manager.set_provider_settings("test", settings)
result = manager.update_provider_settings("test", priority=5)
assert result is True
updated = manager.get_provider_settings("test")
assert updated.priority == 5
def test_update_provider_settings_new(self):
"""Should create new settings when provider doesn't exist."""
manager = ProviderConfigManager()
result = manager.update_provider_settings(
"new_provider", priority=3, timeout_seconds=60
)
assert result is True
settings = manager.get_provider_settings("new_provider")
assert settings is not None
assert settings.priority == 3
assert settings.timeout_seconds == 60
def test_get_all_provider_settings(self):
"""Should return copy of all provider settings."""
manager = ProviderConfigManager()
manager.set_provider_settings(
"p1", ProviderSettings(name="p1")
)
manager.set_provider_settings(
"p2", ProviderSettings(name="p2")
)
all_settings = manager.get_all_provider_settings()
assert len(all_settings) == 2
assert "p1" in all_settings
assert "p2" in all_settings
def test_get_all_returns_copy(self):
"""get_all_provider_settings should return a copy."""
manager = ProviderConfigManager()
manager.set_provider_settings(
"p1", ProviderSettings(name="p1")
)
all_settings = manager.get_all_provider_settings()
all_settings["p2"] = ProviderSettings(name="p2")
assert "p2" not in manager.get_all_provider_settings()
class TestProviderEnableDisable:
"""Test enable/disable provider functionality."""
def test_enable_provider(self):
"""Should enable a disabled provider."""
manager = ProviderConfigManager()
settings = ProviderSettings(name="test", enabled=False)
manager.set_provider_settings("test", settings)
result = manager.enable_provider("test")
assert result is True
assert manager.get_provider_settings("test").enabled is True
def test_disable_provider(self):
"""Should disable an enabled provider."""
manager = ProviderConfigManager()
settings = ProviderSettings(name="test", enabled=True)
manager.set_provider_settings("test", settings)
result = manager.disable_provider("test")
assert result is True
assert manager.get_provider_settings("test").enabled is False
def test_enable_unknown_provider_returns_false(self):
"""Should return False when enabling unknown provider."""
manager = ProviderConfigManager()
assert manager.enable_provider("unknown") is False
def test_disable_unknown_provider_returns_false(self):
"""Should return False when disabling unknown provider."""
manager = ProviderConfigManager()
assert manager.disable_provider("unknown") is False
def test_get_enabled_providers(self):
"""Should return only enabled providers."""
manager = ProviderConfigManager()
manager.set_provider_settings(
"p1", ProviderSettings(name="p1", enabled=True)
)
manager.set_provider_settings(
"p2", ProviderSettings(name="p2", enabled=False)
)
manager.set_provider_settings(
"p3", ProviderSettings(name="p3", enabled=True)
)
enabled = manager.get_enabled_providers()
assert "p1" in enabled
assert "p2" not in enabled
assert "p3" in enabled
class TestProviderPriority:
"""Test provider priority management."""
def test_set_provider_priority(self):
"""Should set priority for a provider."""
manager = ProviderConfigManager()
manager.set_provider_settings(
"test", ProviderSettings(name="test", priority=0)
)
result = manager.set_provider_priority("test", 5)
assert result is True
assert manager.get_provider_settings("test").priority == 5
def test_set_priority_unknown_returns_false(self):
"""Should return False for unknown provider."""
manager = ProviderConfigManager()
assert manager.set_provider_priority("unknown", 1) is False
def test_get_providers_by_priority(self):
"""Should return providers sorted by priority."""
manager = ProviderConfigManager()
manager.set_provider_settings(
"low", ProviderSettings(name="low", priority=10)
)
manager.set_provider_settings(
"high", ProviderSettings(name="high", priority=1)
)
manager.set_provider_settings(
"mid", ProviderSettings(name="mid", priority=5)
)
sorted_providers = manager.get_providers_by_priority()
assert sorted_providers == ["high", "mid", "low"]
class TestGlobalSettings:
"""Test global settings management."""
def test_get_global_setting(self):
"""Should retrieve global setting value."""
manager = ProviderConfigManager()
assert manager.get_global_setting("default_timeout") == 30
def test_get_unknown_global_setting(self):
"""Should return None for unknown global setting."""
manager = ProviderConfigManager()
assert manager.get_global_setting("nonexistent") is None
def test_set_global_setting(self):
"""Should set a global setting."""
manager = ProviderConfigManager()
manager.set_global_setting("custom_key", "custom_value")
assert manager.get_global_setting("custom_key") == "custom_value"
def test_get_all_global_settings(self):
"""Should return all global settings."""
manager = ProviderConfigManager()
all_settings = manager.get_all_global_settings()
assert "default_timeout" in all_settings
assert "enable_failover" in all_settings
def test_get_all_global_returns_copy(self):
"""get_all_global_settings should return a copy."""
manager = ProviderConfigManager()
settings = manager.get_all_global_settings()
settings["new_key"] = "new_value"
assert manager.get_global_setting("new_key") is None
class TestConfigPersistence:
"""Test configuration save/load functionality."""
def test_save_and_load_config(self):
"""Should save and load configuration from file."""
with tempfile.NamedTemporaryFile(
suffix=".json", delete=False, mode="w"
) as f:
config_path = Path(f.name)
try:
manager = ProviderConfigManager(config_file=config_path)
manager.set_provider_settings(
"test_provider",
ProviderSettings(name="test_provider", priority=3),
)
manager.set_global_setting("custom_option", True)
assert manager.save_config() is True
# Verify the file was created and is valid JSON
assert config_path.exists()
with open(config_path, "r") as f:
data = json.load(f)
assert "providers" in data
assert "test_provider" in data["providers"]
assert data["providers"]["test_provider"]["priority"] == 3
assert data["global"]["custom_option"] is True
finally:
config_path.unlink(missing_ok=True)
def test_save_config_no_path(self):
"""Should return False when no path is specified."""
manager = ProviderConfigManager()
assert manager.save_config() is False
def test_load_config_nonexistent_file(self):
"""Should return False when file doesn't exist."""
manager = ProviderConfigManager()
assert manager.load_config(Path("/nonexistent.json")) is False
def test_load_config_invalid_json(self):
"""Should return False for invalid JSON file."""
with tempfile.NamedTemporaryFile(
suffix=".json", delete=False, mode="w"
) as f:
f.write("not valid json{{{")
config_path = Path(f.name)
try:
manager = ProviderConfigManager()
assert manager.load_config(config_path) is False
finally:
config_path.unlink(missing_ok=True)
def test_save_creates_parent_directories(self):
"""save_config should create parent directories if needed."""
with tempfile.TemporaryDirectory() as tmpdir:
config_path = Path(tmpdir) / "subdir" / "config.json"
manager = ProviderConfigManager(config_file=config_path)
manager.set_provider_settings(
"p1", ProviderSettings(name="p1")
)
assert manager.save_config() is True
assert config_path.exists()
class TestResetToDefaults:
"""Test reset functionality."""
def test_reset_clears_providers(self):
"""reset_to_defaults should clear all provider settings."""
manager = ProviderConfigManager()
manager.set_provider_settings(
"p1", ProviderSettings(name="p1")
)
manager.reset_to_defaults()
assert manager.get_all_provider_settings() == {}
def test_reset_restores_global_defaults(self):
"""reset_to_defaults should restore default global settings."""
manager = ProviderConfigManager()
manager.set_global_setting("default_timeout", 999)
manager.set_global_setting("custom_key", "value")
manager.reset_to_defaults()
assert manager.get_global_setting("default_timeout") == 30
assert manager.get_global_setting("custom_key") is None
class TestGetConfigManagerSingleton:
"""Test the get_config_manager singleton function."""
def test_returns_instance(self):
"""get_config_manager should return a ProviderConfigManager."""
# Reset global state for test
import src.core.providers.config_manager as cm
cm._config_manager = None
manager = get_config_manager()
assert isinstance(manager, ProviderConfigManager)
# Cleanup
cm._config_manager = None
def test_returns_same_instance(self):
"""get_config_manager should return same instance on repeated calls."""
import src.core.providers.config_manager as cm
cm._config_manager = None
first = get_config_manager()
second = get_config_manager()
assert first is second
# Cleanup
cm._config_manager = None

View File

@@ -0,0 +1,105 @@
"""Unit tests for provider_factory.py - Factory instantiation, dependency injection, provider registration."""
from unittest.mock import MagicMock, patch
import pytest
from src.core.providers.base_provider import Loader
from src.core.providers.provider_factory import Loaders
class TestLoadersInit:
"""Test Loaders factory initialization."""
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_factory_initializes_with_default_providers(self, mock_aniworld):
"""Factory should register aniworld.to provider by default."""
mock_aniworld.return_value = MagicMock(spec=Loader)
factory = Loaders()
assert "aniworld.to" in factory.dict
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_factory_dict_contains_loader_instances(self, mock_aniworld):
"""Factory dict values should be Loader instances."""
mock_instance = MagicMock(spec=Loader)
mock_aniworld.return_value = mock_instance
factory = Loaders()
for key, value in factory.dict.items():
assert isinstance(key, str)
class TestLoadersGetLoader:
"""Test GetLoader method."""
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_get_loader_returns_registered_provider(self, mock_aniworld):
"""GetLoader should return provider for known key."""
mock_instance = MagicMock(spec=Loader)
mock_aniworld.return_value = mock_instance
factory = Loaders()
loader = factory.GetLoader("aniworld.to")
assert loader is mock_instance
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_get_loader_raises_key_error_for_unknown(self, mock_aniworld):
"""GetLoader should raise KeyError for unknown provider key."""
mock_aniworld.return_value = MagicMock(spec=Loader)
factory = Loaders()
with pytest.raises(KeyError):
factory.GetLoader("nonexistent.provider")
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_get_loader_returns_same_instance(self, mock_aniworld):
"""GetLoader should return same instance on repeated calls."""
mock_instance = MagicMock(spec=Loader)
mock_aniworld.return_value = mock_instance
factory = Loaders()
first = factory.GetLoader("aniworld.to")
second = factory.GetLoader("aniworld.to")
assert first is second
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_get_loader_empty_key(self, mock_aniworld):
"""GetLoader should raise KeyError for empty string key."""
mock_aniworld.return_value = MagicMock(spec=Loader)
factory = Loaders()
with pytest.raises(KeyError):
factory.GetLoader("")
class TestLoadersProviderRegistry:
"""Test the provider registry within the factory."""
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_registry_size(self, mock_aniworld):
"""Factory should have exactly one default provider."""
mock_aniworld.return_value = MagicMock(spec=Loader)
factory = Loaders()
assert len(factory.dict) == 1
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_can_add_custom_provider(self, mock_aniworld):
"""Custom providers can be added to the factory registry."""
mock_aniworld.return_value = MagicMock(spec=Loader)
factory = Loaders()
custom_provider = MagicMock(spec=Loader)
factory.dict["custom.provider"] = custom_provider
assert factory.GetLoader("custom.provider") is custom_provider
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_can_override_existing_provider(self, mock_aniworld):
"""Existing providers can be overridden in the registry."""
mock_aniworld.return_value = MagicMock(spec=Loader)
factory = Loaders()
new_provider = MagicMock(spec=Loader)
factory.dict["aniworld.to"] = new_provider
assert factory.GetLoader("aniworld.to") is new_provider
@patch("src.core.providers.provider_factory.AniworldLoader")
def test_multiple_factories_are_independent(self, mock_aniworld):
"""Multiple factory instances should have independent registries."""
mock_aniworld.return_value = MagicMock(spec=Loader)
factory1 = Loaders()
factory2 = Loaders()
factory1.dict["extra"] = MagicMock(spec=Loader)
assert "extra" not in factory2.dict