diff --git a/docs/instructions.md b/docs/instructions.md index 05e424a..3de7c14 100644 --- a/docs/instructions.md +++ b/docs/instructions.md @@ -118,106 +118,3 @@ For each task completed: --- ## TODO List: - -### High Priority - Test Failures (136 total) - -#### 1. TMDB API Resilience Tests (26 failures) - -**Location**: `tests/integration/test_tmdb_resilience.py`, `tests/unit/test_tmdb_rate_limiting.py` -**Issue**: `TypeError: 'coroutine' object does not support the asynchronous context manager protocol` -**Root cause**: Mock session.get() returns coroutine instead of async context manager -**Impact**: All TMDB API resilience and timeout tests failing - -- [ ] Fix mock setup in TMDB resilience tests -- [ ] Fix mock setup in TMDB rate limiting tests -- [ ] Ensure AsyncMock context managers are properly configured - -#### 2. Config Backup/Restore Tests (18 failures) - -**Location**: `tests/integration/test_config_backup_restore.py` -**Issue**: Authentication failures (401 Unauthorized) -**Root cause**: authenticated_client fixture not properly authenticating -**Affected tests**: - -- [ ] test_create_backup_with_default_name -- [ ] test_multiple_backups_can_be_created -- [ ] test_list_backups_returns_array -- [ ] test_list_backups_contains_metadata -- [ ] test_list_backups_shows_recently_created -- [ ] test_restore_nonexistent_backup_fails -- [ ] test_restore_backup_with_valid_backup -- [ ] test_restore_creates_backup_before_restoring -- [ ] test_restored_config_matches_backup -- [ ] test_delete_existing_backup -- [ ] test_delete_removes_backup_from_list -- [ ] test_delete_removes_backup_file -- [ ] test_delete_nonexistent_backup_fails -- [ ] test_full_backup_restore_workflow -- [ ] test_restore_with_invalid_backup_name -- [ ] test_concurrent_backup_operations -- [ ] test_backup_with_very_long_custom_name -- [ ] test_backup_preserves_all_configuration_sections - -#### 3. Background Loader Service Tests (10 failures) - -**Location**: `tests/integration/test_async_series_loading.py`, `tests/unit/test_background_loader_session.py`, `tests/integration/test_anime_add_nfo_isolation.py` -**Issues**: Service initialization, task processing, NFO loading - -- [ ] test_loader_start_stop - Fix worker_task vs worker_tasks attribute -- [ ] test_add_series_loading_task - Tasks not being added to active_tasks -- [ ] test_multiple_tasks_concurrent - Active tasks not being tracked -- [ ] test_no_duplicate_tasks - No tasks registered -- [ ] test_adding_tasks_is_fast - Active tasks empty -- [ ] test_load_series_data_loads_missing_episodes - \_load_episodes not called -- [ ] test_add_anime_loads_nfo_only_for_new_anime - NFO service not called -- [ ] test_add_anime_has_nfo_check_is_isolated - has_nfo check not called -- [ ] test_multiple_anime_added_each_loads_independently - NFO service call count wrong -- [ ] test_nfo_service_receives_correct_parameters - Call args is None - -#### 4. Performance Tests (4 failures) - -**Location**: `tests/performance/test_large_library.py`, `tests/performance/test_api_load.py` -**Issues**: Missing attributes, database not initialized, service not initialized - -- [ ] test_scanner_progress_reporting_1000_series - AttributeError: '\_SerieClass' missing -- [ ] test_database_query_performance_1000_series - Database not initialized -- [ ] test_concurrent_scan_prevention - get_anime_service() missing required argument -- [ ] test_health_endpoint_load - RPS too low (37.27 < 50 expected) - -#### 5. NFO Tracking Tests (4 failures) - -**Location**: `tests/unit/test_anime_service.py` -**Issue**: `TypeError: object MagicMock can't be used in 'await' expression` -**Root cause**: Database mocks not properly configured for async - -- [ ] test_update_nfo_status_success -- [ ] test_update_nfo_status_not_found -- [ ] test_get_series_without_nfo -- [ ] test_get_nfo_statistics - -#### 6. Concurrent Anime Add Tests (2 failures) - -**Location**: `tests/api/test_concurrent_anime_add.py` -**Issue**: `RuntimeError: BackgroundLoaderService not initialized` -**Root cause**: Service not initialized in test setup - -- [ ] test_concurrent_anime_add_requests -- [ ] test_same_anime_concurrent_add - -#### 7. Other Test Failures (3 failures) - -- [ ] test_get_database_session_handles_http_exception - Database not initialized -- [ ] test_anime_endpoint_returns_series_after_loading - Empty response (expects 2, got 0) - -### Summary - -- **Total failures**: 136 out of 2503 tests -- **Pass rate**: 94.6% -- **Main issues**: - 1. AsyncMock configuration for TMDB tests - 2. Authentication in backup/restore tests - 3. Background loader service lifecycle - 4. Database mock configuration for async operations - 5. Service initialization in tests - ---- diff --git a/tests/api/test_anime_endpoints.py b/tests/api/test_anime_endpoints.py index 8c96e4c..ffba2d1 100644 --- a/tests/api/test_anime_endpoints.py +++ b/tests/api/test_anime_endpoints.py @@ -1,6 +1,6 @@ """Tests for anime API endpoints.""" import asyncio -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock, patch import pytest from httpx import ASGITransport, AsyncClient @@ -414,3 +414,536 @@ async def test_add_series_special_characters_in_name(authenticated_client): invalid_chars = [':', '\\', '?', '*', '<', '>', '|', '"'] for char in invalid_chars: assert char not in folder_name, f"Found '{char}' in folder name for {name}" + + +# --------------------------------------------------------------------------- +# New tests: get_anime_status +# --------------------------------------------------------------------------- + + +class TestGetAnimeStatusEndpoint: + """Tests for GET /api/anime/status.""" + + @pytest.mark.asyncio + async def test_status_unauthorized(self): + """Status endpoint should require authentication.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/api/anime/status") + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_status_authenticated(self, authenticated_client): + """Authenticated request returns directory and series count.""" + response = await authenticated_client.get("/api/anime/status") + assert response.status_code == 200 + data = response.json() + assert "directory" in data + assert "series_count" in data + assert isinstance(data["series_count"], int) + + def test_status_direct_no_series_app(self): + """When series_app is None, returns empty directory and 0 count.""" + result = asyncio.run(anime_module.get_anime_status(series_app=None)) + assert result["directory"] == "" + assert result["series_count"] == 0 + + +# --------------------------------------------------------------------------- +# New tests: list_anime authenticated +# --------------------------------------------------------------------------- + + +class TestListAnimeAuthenticated: + """Tests for GET /api/anime/ with authentication.""" + + @pytest.mark.asyncio + async def test_list_anime_returns_summaries(self, authenticated_client): + """Authenticated list returns anime summaries.""" + response = await authenticated_client.get("/api/anime/") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, list) + + @pytest.mark.asyncio + async def test_list_anime_invalid_page(self, authenticated_client): + """Negative page number returns validation error.""" + response = await authenticated_client.get( + "/api/anime/", params={"page": -1} + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_list_anime_per_page_too_large(self, authenticated_client): + """Per page > 1000 returns validation error.""" + response = await authenticated_client.get( + "/api/anime/", params={"per_page": 5000} + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_list_anime_invalid_sort_by(self, authenticated_client): + """Invalid sort_by parameter returns validation error.""" + response = await authenticated_client.get( + "/api/anime/", params={"sort_by": "injection_attempt"} + ) + assert response.status_code == 422 + + @pytest.mark.asyncio + async def test_list_anime_valid_sort_by(self, authenticated_client): + """Valid sort_by parameter is accepted.""" + response = await authenticated_client.get( + "/api/anime/", params={"sort_by": "title"} + ) + assert response.status_code == 200 + + @pytest.mark.asyncio + async def test_list_anime_invalid_filter(self, authenticated_client): + """Invalid filter value returns validation error.""" + response = await authenticated_client.get( + "/api/anime/", params={"filter": "hacked"} + ) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# New tests: get_scan_status +# --------------------------------------------------------------------------- + + +class TestGetScanStatusEndpoint: + """Tests for GET /api/anime/scan/status.""" + + @pytest.mark.asyncio + async def test_scan_status_unauthorized(self): + """Scan status endpoint should require authentication.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/api/anime/scan/status") + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_scan_status_authenticated(self, authenticated_client): + """Authenticated request returns scan status dict.""" + response = await authenticated_client.get("/api/anime/scan/status") + assert response.status_code == 200 + data = response.json() + assert isinstance(data, dict) + + +# --------------------------------------------------------------------------- +# New tests: _validate_search_query_extended +# --------------------------------------------------------------------------- + + +class TestValidateSearchQueryExtended: + """Tests for the internal _validate_search_query_extended function.""" + + def test_empty_query_raises(self): + """Empty string raises 422.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + anime_module._validate_search_query_extended("") + assert exc_info.value.status_code == 422 + + def test_whitespace_only_raises(self): + """Whitespace-only raises 422.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + anime_module._validate_search_query_extended(" ") + assert exc_info.value.status_code == 422 + + def test_null_bytes_raise(self): + """Null bytes in query raise 400.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + anime_module._validate_search_query_extended("test\x00query") + assert exc_info.value.status_code == 400 + + def test_too_long_query_raises(self): + """Query exceeding 200 chars raises 422.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + anime_module._validate_search_query_extended("a" * 201) + assert exc_info.value.status_code == 422 + + def test_valid_query_returns_string(self): + """Valid query is returned (possibly normalised).""" + result = anime_module._validate_search_query_extended("Naruto") + assert isinstance(result, str) + assert len(result) > 0 + + +# --------------------------------------------------------------------------- +# New tests: search_anime_post +# --------------------------------------------------------------------------- + + +class TestSearchAnimePost: + """Tests for POST /api/anime/search.""" + + @pytest.mark.asyncio + async def test_search_post_returns_results(self): + """POST search with valid query returns results.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/anime/search", + json={"query": "test"}, + ) + assert response.status_code == 200 + assert isinstance(response.json(), list) + + @pytest.mark.asyncio + async def test_search_post_empty_query_rejected(self): + """POST search with empty query returns 422.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.post( + "/api/anime/search", + json={"query": ""}, + ) + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# New tests: _perform_search +# --------------------------------------------------------------------------- + + +class TestPerformSearch: + """Tests for the internal _perform_search function.""" + + @pytest.mark.asyncio + async def test_search_no_series_app(self): + """When series_app is None return empty list.""" + result = await anime_module._perform_search("test", None) + assert result == [] + + @pytest.mark.asyncio + async def test_search_dict_results(self): + """Dict-format results are converted to AnimeSummary.""" + mock_app = AsyncMock() + mock_app.search = AsyncMock( + return_value=[ + { + "key": "k1", + "title": "Title One", + "site": "aniworld.to", + "folder": "f1", + "link": "https://aniworld.to/anime/stream/k1", + "missing_episodes": {}, + } + ] + ) + results = await anime_module._perform_search("query", mock_app) + assert len(results) == 1 + assert results[0].key == "k1" + assert results[0].name == "Title One" + + @pytest.mark.asyncio + async def test_search_object_results(self): + """Object-format results (with attributes) are handled.""" + match = MagicMock(spec=[]) + match.key = "obj-key" + match.id = "" + match.title = "Object Title" + match.name = "Object Title" + match.site = "aniworld.to" + match.folder = "Object Folder" + match.link = "" + match.url = "" + match.missing_episodes = {} + mock_app = AsyncMock() + mock_app.search = AsyncMock(return_value=[match]) + results = await anime_module._perform_search("query", mock_app) + assert len(results) == 1 + assert results[0].key == "obj-key" + + @pytest.mark.asyncio + async def test_search_key_extracted_from_link(self): + """When key is empty, extract from link URL.""" + mock_app = AsyncMock() + mock_app.search = AsyncMock( + return_value=[ + { + "key": "", + "name": "No Key", + "site": "", + "folder": "", + "link": "https://aniworld.to/anime/stream/extracted-key", + "missing_episodes": {}, + } + ] + ) + results = await anime_module._perform_search("q", mock_app) + assert results[0].key == "extracted-key" + + @pytest.mark.asyncio + async def test_search_exception_raises_500(self): + """Non-HTTP exception in search raises 500.""" + from fastapi import HTTPException + + mock_app = AsyncMock() + mock_app.search = AsyncMock(side_effect=RuntimeError("boom")) + with pytest.raises(HTTPException) as exc_info: + await anime_module._perform_search("q", mock_app) + assert exc_info.value.status_code == 500 + + +# --------------------------------------------------------------------------- +# New tests: get_loading_status +# --------------------------------------------------------------------------- + + +class TestGetLoadingStatusEndpoint: + """Tests for GET /api/anime/{key}/loading-status.""" + + @pytest.mark.asyncio + async def test_loading_status_unauthorized(self): + """Loading status requires authentication.""" + transport = ASGITransport(app=app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + response = await client.get("/api/anime/some-key/loading-status") + assert response.status_code == 401 + + @pytest.mark.asyncio + async def test_loading_status_no_db(self, authenticated_client): + """Without database session returns 503.""" + response = await authenticated_client.get( + "/api/anime/some-key/loading-status" + ) + # get_optional_database_session may return None → 503 + assert response.status_code in (503, 404, 500) + + def test_loading_status_direct_no_db(self): + """Direct call with db=None raises 503.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(anime_module.get_loading_status("key", db=None)) + assert exc_info.value.status_code == 503 + + def test_loading_status_not_found(self): + """Direct call with unknown key raises 404.""" + from fastapi import HTTPException + + mock_db = AsyncMock() + + async def _run(): + with patch( + "src.server.database.service.AnimeSeriesService" + ) as mock_svc: + mock_svc.get_by_key = AsyncMock(return_value=None) + return await anime_module.get_loading_status( + "missing-key", db=mock_db + ) + + with pytest.raises(HTTPException) as exc_info: + asyncio.run(_run()) + assert exc_info.value.status_code == 404 + + def test_loading_status_pending(self): + """Direct call returns correct pending status payload.""" + mock_db = AsyncMock() + series_row = MagicMock() + series_row.key = "test-key" + series_row.loading_status = "pending" + series_row.episodes_loaded = False + series_row.has_nfo = False + series_row.logo_loaded = False + series_row.images_loaded = False + series_row.loading_started_at = None + series_row.loading_completed_at = None + series_row.loading_error = None + + async def _run(): + with patch( + "src.server.database.service.AnimeSeriesService" + ) as mock_svc: + mock_svc.get_by_key = AsyncMock(return_value=series_row) + return await anime_module.get_loading_status( + "test-key", db=mock_db + ) + + result = asyncio.run(_run()) + assert result["key"] == "test-key" + assert result["loading_status"] == "pending" + assert "Queued" in result["message"] + assert result["progress"]["episodes"] is False + + def test_loading_status_completed(self): + """Completed status returns correct message.""" + from datetime import datetime + + mock_db = AsyncMock() + series_row = MagicMock() + series_row.key = "done-key" + series_row.loading_status = "completed" + series_row.episodes_loaded = True + series_row.has_nfo = True + series_row.logo_loaded = True + series_row.images_loaded = True + series_row.loading_started_at = datetime(2025, 1, 1) + series_row.loading_completed_at = datetime(2025, 1, 1, 0, 5) + series_row.loading_error = None + + async def _run(): + with patch( + "src.server.database.service.AnimeSeriesService" + ) as mock_svc: + mock_svc.get_by_key = AsyncMock(return_value=series_row) + return await anime_module.get_loading_status( + "done-key", db=mock_db + ) + + result = asyncio.run(_run()) + assert result["loading_status"] == "completed" + assert "successfully" in result["message"] + assert result["progress"]["episodes"] is True + assert result["completed_at"] is not None + + +# --------------------------------------------------------------------------- +# New tests: get_anime detail +# --------------------------------------------------------------------------- + + +class TestGetAnimeDetail: + """Tests for GET /api/anime/{anime_id} detail endpoint.""" + + def test_get_anime_by_key(self): + """Primary lookup by key returns correct detail.""" + fake = FakeSeriesApp() + result = asyncio.run( + anime_module.get_anime("test-show-key", series_app=fake) + ) + assert result.key == "test-show-key" + assert result.title == "Test Show" + + def test_get_anime_by_folder_fallback(self): + """Folder-based lookup works as deprecated fallback.""" + fake = FakeSeriesApp() + result = asyncio.run( + anime_module.get_anime("Test Show (2023)", series_app=fake) + ) + assert result.key == "test-show-key" + + def test_get_anime_not_found(self): + """Unknown anime_id raises 404.""" + from fastapi import HTTPException + + fake = FakeSeriesApp() + with pytest.raises(HTTPException) as exc_info: + asyncio.run( + anime_module.get_anime("nonexistent", series_app=fake) + ) + assert exc_info.value.status_code == 404 + + def test_get_anime_no_series_app(self): + """None series_app raises 404.""" + from fastapi import HTTPException + + with pytest.raises(HTTPException) as exc_info: + asyncio.run( + anime_module.get_anime("any-id", series_app=None) + ) + assert exc_info.value.status_code == 404 + + def test_get_anime_episodes_formatted(self): + """Episode dict is converted to season-episode strings.""" + fake = FakeSeriesApp() + result = asyncio.run( + anime_module.get_anime("test-show-key", series_app=fake) + ) + assert "1-1" in result.episodes + assert "1-2" in result.episodes + + def test_get_anime_complete_show_no_episodes(self): + """Complete show with empty episodeDict returns empty episodes list.""" + fake = FakeSeriesApp() + result = asyncio.run( + anime_module.get_anime("complete-show-key", series_app=fake) + ) + assert result.episodes == [] + + +# --------------------------------------------------------------------------- +# New tests: trigger_rescan authenticated +# --------------------------------------------------------------------------- + + +class TestTriggerRescanAuthenticated: + """Tests for POST /api/anime/rescan with authentication.""" + + @pytest.mark.asyncio + async def test_rescan_authenticated(self, authenticated_client): + """Authenticated rescan returns success.""" + from src.server.services.anime_service import AnimeService + from src.server.utils.dependencies import get_anime_service + + mock_service = AsyncMock(spec=AnimeService) + mock_service.rescan = AsyncMock() + app.dependency_overrides[get_anime_service] = lambda: mock_service + try: + response = await authenticated_client.post("/api/anime/rescan") + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_service.rescan.assert_called_once() + finally: + app.dependency_overrides.pop(get_anime_service, None) + + def test_rescan_service_error(self): + """AnimeServiceError is converted to ServerError.""" + from src.server.services.anime_service import AnimeServiceError + + mock_service = AsyncMock() + mock_service.rescan = AsyncMock( + side_effect=AnimeServiceError("scan failed") + ) + + from src.server.exceptions import ServerError + + with pytest.raises(ServerError): + asyncio.run( + anime_module.trigger_rescan(anime_service=mock_service) + ) + + +# --------------------------------------------------------------------------- +# New tests: search_anime_get additional +# --------------------------------------------------------------------------- + + +class TestSearchAnimeGetAdditional: + """Additional tests for GET /api/anime/search.""" + + @pytest.mark.asyncio + async def test_search_get_with_query(self): + """Search GET with valid query returns list.""" + transport = ASGITransport(app=app) + async with AsyncClient( + transport=transport, base_url="http://test" + ) as client: + response = await client.get( + "/api/anime/search", params={"query": "naruto"} + ) + assert response.status_code == 200 + assert isinstance(response.json(), list) + + @pytest.mark.asyncio + async def test_search_get_null_byte_query(self): + """Search GET with null byte in query returns 400.""" + transport = ASGITransport(app=app) + async with AsyncClient( + transport=transport, base_url="http://test" + ) as client: + response = await client.get( + "/api/anime/search", params={"query": "test\x00bad"} + ) + assert response.status_code == 400 diff --git a/tests/frontend/test_existing_ui_integration.py b/tests/frontend/test_existing_ui_integration.py index 1b8a39b..74c56f7 100644 --- a/tests/frontend/test_existing_ui_integration.py +++ b/tests/frontend/test_existing_ui_integration.py @@ -158,17 +158,17 @@ class TestFrontendAuthentication: async def test_authenticated_request_succeeds(self, authenticated_client): """Test that requests with valid token succeed.""" - with patch("src.server.utils.dependencies.get_series_app") as mock_get_app: - mock_app = AsyncMock() - mock_list = AsyncMock() - mock_list.GetMissingEpisode = AsyncMock(return_value=[]) - mock_list.GetList = AsyncMock(return_value=[]) - mock_app.List = mock_list - mock_get_app.return_value = mock_app - + mock_anime_service = AsyncMock() + mock_anime_service.list_series_with_filters = AsyncMock(return_value=[]) + + from src.server.utils.dependencies import get_anime_service + app.dependency_overrides[get_anime_service] = lambda: mock_anime_service + + try: response = await authenticated_client.get("/api/anime/") - assert response.status_code == 200 + finally: + app.dependency_overrides.pop(get_anime_service, None) class TestFrontendAnimeAPI: diff --git a/tests/unit/test_anime_service.py b/tests/unit/test_anime_service.py index 9f68328..acd89a0 100644 --- a/tests/unit/test_anime_service.py +++ b/tests/unit/test_anime_service.py @@ -1,16 +1,23 @@ """Unit tests for AnimeService. Tests cover service initialization, async operations, caching, -error handling, and progress reporting integration. +error handling, progress reporting integration, scan/download status +event handling, database persistence, and WebSocket broadcasting. """ from __future__ import annotations import asyncio +import time +from datetime import datetime, timezone from unittest.mock import AsyncMock, MagicMock, patch import pytest -from src.server.services.anime_service import AnimeService, AnimeServiceError +from src.server.services.anime_service import ( + AnimeService, + AnimeServiceError, + sync_series_from_data_files, +) from src.server.services.progress_service import ProgressService @@ -472,3 +479,965 @@ class TestFactoryFunction: assert isinstance(service, AnimeService) assert service._app is mock_series_app + + +# ============================================================================= +# New coverage tests – download / scan status, DB persistence, broadcasting +# ============================================================================= + + +class _FakeDownloadArgs: + """Minimal stand-in for DownloadStatusEventArgs.""" + + def __init__(self, **kwargs): + self.status = kwargs.get("status", "started") + self.serie_folder = kwargs.get("serie_folder", "TestFolder") + self.season = kwargs.get("season", 1) + self.episode = kwargs.get("episode", 1) + self.item_id = kwargs.get("item_id", None) + self.progress = kwargs.get("progress", 0) + self.message = kwargs.get("message", None) + self.error = kwargs.get("error", None) + self.mbper_sec = kwargs.get("mbper_sec", None) + self.eta = kwargs.get("eta", None) + + +class _FakeScanArgs: + """Minimal stand-in for ScanStatusEventArgs.""" + + def __init__(self, **kwargs): + self.status = kwargs.get("status", "started") + self.current = kwargs.get("current", 0) + self.total = kwargs.get("total", 10) + self.folder = kwargs.get("folder", "") + self.message = kwargs.get("message", None) + self.error = kwargs.get("error", None) + + +class TestOnDownloadStatus: + """Test _on_download_status event handler.""" + + def test_download_started_schedules_start_progress( + self, anime_service, mock_progress_service + ): + """started event should schedule start_progress.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe") as mock_run: + args = _FakeDownloadArgs( + status="started", item_id="q-1" + ) + anime_service._on_download_status(args) + mock_run.assert_called_once() + coro = mock_run.call_args[0][0] + assert coro is not None + finally: + loop.close() + + def test_download_progress_schedules_update( + self, anime_service, mock_progress_service + ): + """progress event should schedule update_progress.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe") as mock_run: + args = _FakeDownloadArgs( + status="progress", + progress=42, + message="Downloading...", + mbper_sec=5.5, + eta=30, + ) + anime_service._on_download_status(args) + mock_run.assert_called_once() + finally: + loop.close() + + def test_download_completed_schedules_complete( + self, anime_service, mock_progress_service + ): + """completed event should schedule complete_progress.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe") as mock_run: + args = _FakeDownloadArgs(status="completed") + anime_service._on_download_status(args) + mock_run.assert_called_once() + finally: + loop.close() + + def test_download_failed_schedules_fail( + self, anime_service, mock_progress_service + ): + """failed event should schedule fail_progress.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe") as mock_run: + args = _FakeDownloadArgs( + status="failed", error=Exception("Err") + ) + anime_service._on_download_status(args) + mock_run.assert_called_once() + finally: + loop.close() + + def test_progress_id_from_item_id(self, anime_service): + """item_id should be used as progress_id when available.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe") as mock_run: + args = _FakeDownloadArgs( + status="started", item_id="queue-42" + ) + anime_service._on_download_status(args) + coro = mock_run.call_args[0][0] + # The coroutine was created with progress_id="queue-42" + assert mock_run.called + finally: + loop.close() + + def test_progress_id_fallback_without_item_id(self, anime_service): + """Without item_id, progress_id is built from folder/season/episode.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe") as mock_run: + args = _FakeDownloadArgs( + status="started", + item_id=None, + serie_folder="FolderX", + season=2, + episode=5, + ) + anime_service._on_download_status(args) + assert mock_run.called + finally: + loop.close() + + def test_no_event_loop_returns_silently(self, anime_service): + """No loop available should not raise.""" + anime_service._event_loop = None + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + args = _FakeDownloadArgs(status="started") + anime_service._on_download_status(args) # should not raise + + +class TestOnScanStatus: + """Test _on_scan_status event handler.""" + + def test_scan_started_schedules_progress_and_broadcast( + self, anime_service, mock_progress_service + ): + """started scan event should schedule start_progress and broadcast.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe") as mock_run: + args = _FakeScanArgs(status="started", total=5) + anime_service._on_scan_status(args) + # 2 calls: start_progress + broadcast_scan_started_safe + assert mock_run.call_count == 2 + assert anime_service._is_scanning is True + finally: + loop.close() + + def test_scan_progress_updates_counters( + self, anime_service, mock_progress_service + ): + """progress scan event should update counters.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe"): + args = _FakeScanArgs( + status="progress", current=3, total=10, + folder="Naruto" + ) + anime_service._on_scan_status(args) + assert anime_service._scan_directories_count == 3 + assert anime_service._scan_current_directory == "Naruto" + finally: + loop.close() + + def test_scan_completed_marks_done( + self, anime_service, mock_progress_service + ): + """completed scan event should mark scanning as False.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + anime_service._is_scanning = True + anime_service._scan_start_time = time.time() - 5 + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe"): + args = _FakeScanArgs(status="completed", total=10) + anime_service._on_scan_status(args) + assert anime_service._is_scanning is False + finally: + loop.close() + + def test_scan_failed_marks_done( + self, anime_service, mock_progress_service + ): + """failed scan event should reset scanning state.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + anime_service._is_scanning = True + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe"): + args = _FakeScanArgs( + status="failed", error=Exception("boom") + ) + anime_service._on_scan_status(args) + assert anime_service._is_scanning is False + finally: + loop.close() + + def test_scan_cancelled_marks_done( + self, anime_service, mock_progress_service + ): + """cancelled scan event should reset scanning state.""" + loop = asyncio.new_event_loop() + anime_service._event_loop = loop + anime_service._is_scanning = True + try: + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + with patch("asyncio.run_coroutine_threadsafe"): + args = _FakeScanArgs(status="cancelled") + anime_service._on_scan_status(args) + assert anime_service._is_scanning is False + finally: + loop.close() + + def test_scan_no_loop_returns_silently(self, anime_service): + """No loop available should not raise for scan events.""" + anime_service._event_loop = None + with patch("asyncio.get_running_loop", side_effect=RuntimeError): + args = _FakeScanArgs(status="started") + anime_service._on_scan_status(args) # no error + + +class TestGetScanStatus: + """Test get_scan_status method.""" + + def test_returns_status_dict(self, anime_service): + """Should return dict with expected keys.""" + anime_service._is_scanning = True + anime_service._scan_total_items = 42 + anime_service._scan_directories_count = 7 + anime_service._scan_current_directory = "Naruto" + result = anime_service.get_scan_status() + assert result["is_scanning"] is True + assert result["total_items"] == 42 + assert result["directories_scanned"] == 7 + assert result["current_directory"] == "Naruto" + + +class TestBroadcastHelpers: + """Test WebSocket broadcast safety wrappers.""" + + @pytest.mark.asyncio + async def test_broadcast_scan_started_safe(self, anime_service): + """Should call websocket_service.broadcast_scan_started.""" + anime_service._websocket_service.broadcast_scan_started = AsyncMock() + await anime_service._broadcast_scan_started_safe(total_items=5) + anime_service._websocket_service.broadcast_scan_started.assert_called_once() + + @pytest.mark.asyncio + async def test_broadcast_scan_started_safe_handles_error( + self, anime_service + ): + """WS failure should be swallowed, not raised.""" + anime_service._websocket_service.broadcast_scan_started = AsyncMock( + side_effect=Exception("ws-down") + ) + # Should NOT raise + await anime_service._broadcast_scan_started_safe(total_items=5) + + @pytest.mark.asyncio + async def test_broadcast_scan_progress_safe(self, anime_service): + """Should call broadcast_scan_progress.""" + anime_service._websocket_service.broadcast_scan_progress = AsyncMock() + await anime_service._broadcast_scan_progress_safe( + directories_scanned=3, files_found=3, + current_directory="AOT", total_items=10, + ) + anime_service._websocket_service.broadcast_scan_progress.assert_called_once() + + @pytest.mark.asyncio + async def test_broadcast_scan_progress_safe_handles_error( + self, anime_service + ): + """WS failure should be swallowed.""" + anime_service._websocket_service.broadcast_scan_progress = AsyncMock( + side_effect=Exception("ws-down") + ) + await anime_service._broadcast_scan_progress_safe( + directories_scanned=0, files_found=0, + current_directory="", total_items=0, + ) + + @pytest.mark.asyncio + async def test_broadcast_scan_completed_safe(self, anime_service): + """Should call broadcast_scan_completed.""" + anime_service._websocket_service.broadcast_scan_completed = AsyncMock() + await anime_service._broadcast_scan_completed_safe( + total_directories=10, total_files=10, elapsed_seconds=5.0, + ) + anime_service._websocket_service.broadcast_scan_completed.assert_called_once() + + @pytest.mark.asyncio + async def test_broadcast_scan_completed_safe_handles_error( + self, anime_service + ): + """WS failure should be swallowed.""" + anime_service._websocket_service.broadcast_scan_completed = AsyncMock( + side_effect=Exception("ws-down") + ) + await anime_service._broadcast_scan_completed_safe( + total_directories=0, total_files=0, elapsed_seconds=0, + ) + + @pytest.mark.asyncio + async def test_broadcast_series_updated(self, anime_service): + """Should broadcast series_updated over WebSocket.""" + anime_service._websocket_service.broadcast = AsyncMock() + await anime_service._broadcast_series_updated("aot") + anime_service._websocket_service.broadcast.assert_called_once() + payload = anime_service._websocket_service.broadcast.call_args[0][0] + assert payload["type"] == "series_updated" + assert payload["key"] == "aot" + + @pytest.mark.asyncio + async def test_broadcast_series_updated_no_ws_service(self, anime_service): + """Should return silently if no websocket service.""" + anime_service._websocket_service = None + await anime_service._broadcast_series_updated("aot") # no error + + +class TestListSeriesWithFilters: + """Test list_series_with_filters with database enrichment.""" + + @pytest.mark.asyncio + async def test_returns_enriched_list( + self, anime_service, mock_series_app + ): + """Should merge SeriesApp data with DB metadata.""" + mock_serie = MagicMock() + mock_serie.key = "aot" + mock_serie.name = "Attack on Titan" + mock_serie.site = "aniworld.to" + mock_serie.folder = "Attack on Titan (2013)" + mock_serie.episodeDict = {1: [2, 3]} + + mock_list = MagicMock() + mock_list.GetList.return_value = [mock_serie] + mock_series_app.list = mock_list + + mock_db_series = MagicMock() + mock_db_series.folder = "Attack on Titan (2013)" + mock_db_series.has_nfo = True + mock_db_series.nfo_created_at = None + mock_db_series.nfo_updated_at = None + mock_db_series.tmdb_id = 1234 + mock_db_series.tvdb_id = None + mock_db_series.id = 1 + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService" + ) as MockASS: + MockASS.get_all = AsyncMock(return_value=[mock_db_series]) + result = await anime_service.list_series_with_filters() + + assert len(result) == 1 + assert result[0]["key"] == "aot" + assert result[0]["has_nfo"] is True + assert result[0]["tmdb_id"] == 1234 + + @pytest.mark.asyncio + async def test_empty_series_returns_empty( + self, anime_service, mock_series_app + ): + """Should return [] when SeriesApp has no series.""" + mock_list = MagicMock() + mock_list.GetList.return_value = [] + mock_series_app.list = mock_list + + result = await anime_service.list_series_with_filters() + assert result == [] + + @pytest.mark.asyncio + async def test_no_list_attribute_returns_empty( + self, anime_service, mock_series_app + ): + """Should return [] when SeriesApp has no list attribute.""" + del mock_series_app.list + result = await anime_service.list_series_with_filters() + assert result == [] + + @pytest.mark.asyncio + async def test_db_error_raises_anime_service_error( + self, anime_service, mock_series_app + ): + """DB failure should raise AnimeServiceError.""" + mock_serie = MagicMock() + mock_serie.key = "aot" + mock_serie.name = "AOT" + mock_serie.site = "x" + mock_serie.folder = "AOT" + mock_serie.episodeDict = {} + + mock_list = MagicMock() + mock_list.GetList.return_value = [mock_serie] + mock_series_app.list = mock_list + + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock( + side_effect=RuntimeError("DB down") + ) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ): + with pytest.raises(AnimeServiceError): + await anime_service.list_series_with_filters() + + +class TestSaveAndLoadDB: + """Test DB persistence helpers.""" + + @pytest.mark.asyncio + async def test_save_scan_results_creates_new( + self, anime_service + ): + """New series should be created in DB.""" + mock_serie = MagicMock() + mock_serie.key = "naruto" + mock_serie.name = "Naruto" + mock_serie.site = "aniworld.to" + mock_serie.folder = "Naruto" + mock_serie.year = 2002 + mock_serie.episodeDict = {1: [1, 2]} + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=None, + ), patch( + "src.server.database.service.AnimeSeriesService.create", + new_callable=AsyncMock, + return_value=MagicMock(id=1), + ) as mock_create, patch( + "src.server.database.service.EpisodeService.create", + new_callable=AsyncMock, + ) as mock_ep_create: + count = await anime_service._save_scan_results_to_db( + [mock_serie] + ) + assert count == 1 + mock_create.assert_called_once() + assert mock_ep_create.call_count == 2 + + @pytest.mark.asyncio + async def test_save_scan_results_updates_existing( + self, anime_service + ): + """Existing series should be updated in DB.""" + mock_serie = MagicMock() + mock_serie.key = "naruto" + mock_serie.name = "Naruto" + mock_serie.site = "aniworld.to" + mock_serie.folder = "Naruto" + mock_serie.episodeDict = {1: [3]} + + existing = MagicMock() + existing.id = 1 + existing.folder = "Naruto" + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=existing, + ), patch.object( + anime_service, + "_update_series_in_db", + new_callable=AsyncMock, + ) as mock_update: + count = await anime_service._save_scan_results_to_db( + [mock_serie] + ) + assert count == 1 + mock_update.assert_called_once() + + @pytest.mark.asyncio + async def test_load_series_from_db( + self, anime_service, mock_series_app + ): + """Should populate SeriesApp from DB records.""" + mock_ep = MagicMock() + mock_ep.season = 1 + mock_ep.episode_number = 5 + + mock_db_series = MagicMock() + mock_db_series.key = "naruto" + mock_db_series.name = "Naruto" + mock_db_series.site = "aniworld.to" + mock_db_series.folder = "Naruto" + mock_db_series.episodes = [mock_ep] + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.get_all", + new_callable=AsyncMock, + return_value=[mock_db_series], + ): + await anime_service._load_series_from_db() + + mock_series_app.load_series_from_list.assert_called_once() + loaded = mock_series_app.load_series_from_list.call_args[0][0] + assert len(loaded) == 1 + assert loaded[0].key == "naruto" + + @pytest.mark.asyncio + async def test_sync_episodes_to_db( + self, anime_service, mock_series_app + ): + """Should sync missing episodes from memory to DB.""" + mock_serie = MagicMock() + mock_serie.episodeDict = {1: [4, 5]} + + mock_list = MagicMock() + mock_list.keyDict = {"aot": mock_serie} + mock_series_app.list = mock_list + + mock_db_series = MagicMock() + mock_db_series.id = 10 + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + anime_service._websocket_service = MagicMock() + anime_service._websocket_service.broadcast = AsyncMock() + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=mock_db_series, + ), patch( + "src.server.database.service.EpisodeService.get_by_series", + new_callable=AsyncMock, + return_value=[], + ), patch( + "src.server.database.service.EpisodeService.create", + new_callable=AsyncMock, + ) as mock_ep_create: + count = await anime_service.sync_episodes_to_db("aot") + + assert count == 2 + assert mock_ep_create.call_count == 2 + + @pytest.mark.asyncio + async def test_sync_episodes_no_list_returns_zero( + self, anime_service, mock_series_app + ): + """No series list should return 0.""" + del mock_series_app.list + count = await anime_service.sync_episodes_to_db("aot") + assert count == 0 + + +class TestAddSeriesToDB: + """Test add_series_to_db method.""" + + @pytest.mark.asyncio + async def test_creates_new_series(self, anime_service): + """New series should be created in DB.""" + mock_serie = MagicMock() + mock_serie.key = "x" + mock_serie.name = "X" + mock_serie.site = "y" + mock_serie.folder = "X" + mock_serie.year = 2020 + mock_serie.episodeDict = {1: [1]} + + mock_db = AsyncMock() + mock_created = MagicMock(id=99) + + with patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=None, + ), patch( + "src.server.database.service.AnimeSeriesService.create", + new_callable=AsyncMock, + return_value=mock_created, + ), patch( + "src.server.database.service.EpisodeService.create", + new_callable=AsyncMock, + ): + result = await anime_service.add_series_to_db(mock_serie, mock_db) + assert result is mock_created + + @pytest.mark.asyncio + async def test_existing_returns_none(self, anime_service): + """Already-existing series should return None.""" + mock_serie = MagicMock() + mock_serie.key = "x" + mock_serie.name = "X" + mock_db = AsyncMock() + + with patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=MagicMock(), + ): + result = await anime_service.add_series_to_db(mock_serie, mock_db) + assert result is None + + +class TestContainsInDB: + """Test contains_in_db method.""" + + @pytest.mark.asyncio + async def test_exists(self, anime_service): + """Should return True when series exists.""" + mock_db = AsyncMock() + with patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=MagicMock(), + ): + assert await anime_service.contains_in_db("aot", mock_db) is True + + @pytest.mark.asyncio + async def test_not_exists(self, anime_service): + """Should return False when series missing.""" + mock_db = AsyncMock() + with patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=None, + ): + assert await anime_service.contains_in_db("x", mock_db) is False + + +class TestUpdateNFOStatusWithoutSession: + """Test update_nfo_status when no db session is passed (self-managed).""" + + @pytest.mark.asyncio + async def test_update_creates_session_and_commits(self, anime_service): + """Should open its own session and commit.""" + mock_series = MagicMock() + mock_series.id = 1 + mock_series.has_nfo = False + mock_series.nfo_created_at = None + mock_series.nfo_updated_at = None + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=mock_series, + ), patch( + "src.server.database.service.AnimeSeriesService.update", + new_callable=AsyncMock, + ): + await anime_service.update_nfo_status( + key="test", has_nfo=True, tmdb_id=42 + ) + + # commit called by update path + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_update_not_found_skips(self, anime_service): + """Should return without error if series not in DB.""" + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=None, + ): + await anime_service.update_nfo_status(key="missing", has_nfo=True) + + mock_session.commit.assert_not_called() + + +class TestGetSeriesWithoutNFOSelfManaged: + """Test get_series_without_nfo when db=None (self-managed session).""" + + @pytest.mark.asyncio + async def test_returns_list(self, anime_service): + """Should return formatted dicts.""" + mock_s = MagicMock() + mock_s.key = "test" + mock_s.name = "Test" + mock_s.folder = "Test" + mock_s.tmdb_id = 1 + mock_s.tvdb_id = 2 + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService" + ".get_series_without_nfo", + new_callable=AsyncMock, + return_value=[mock_s], + ): + result = await anime_service.get_series_without_nfo() + + assert len(result) == 1 + assert result[0]["has_nfo"] is False + + +class TestGetNFOStatisticsSelfManaged: + """Test get_nfo_statistics when db=None (self-managed session).""" + + @pytest.mark.asyncio + async def test_returns_stats(self, anime_service): + """Should compute statistics correctly.""" + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.count_all", + new_callable=AsyncMock, + return_value=50, + ), patch( + "src.server.database.service.AnimeSeriesService.count_with_nfo", + new_callable=AsyncMock, + return_value=30, + ), patch( + "src.server.database.service.AnimeSeriesService" + ".count_with_tmdb_id", + new_callable=AsyncMock, + return_value=40, + ), patch( + "src.server.database.service.AnimeSeriesService" + ".count_with_tvdb_id", + new_callable=AsyncMock, + return_value=20, + ): + result = await anime_service.get_nfo_statistics() + + assert result["total"] == 50 + assert result["without_nfo"] == 20 + assert result["with_tmdb_id"] == 40 + + +class TestSyncSeriesFromDataFiles: + """Test module-level sync_series_from_data_files function.""" + + @pytest.mark.asyncio + async def test_sync_adds_new_series(self, tmp_path): + """Should create series for data files not in DB.""" + mock_serie = MagicMock() + mock_serie.key = "new-series" + mock_serie.name = "New Series" + mock_serie.site = "aniworld.to" + mock_serie.folder = "New Series" + mock_serie.episodeDict = {1: [1]} + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.services.anime_service.SeriesApp" + ) as MockApp, patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=None, + ), patch( + "src.server.database.service.AnimeSeriesService.create", + new_callable=AsyncMock, + return_value=MagicMock(id=1), + ) as mock_create, patch( + "src.server.database.service.EpisodeService.create", + new_callable=AsyncMock, + ): + mock_app_instance = MagicMock() + mock_app_instance.get_all_series_from_data_files.return_value = [ + mock_serie + ] + MockApp.return_value = mock_app_instance + + count = await sync_series_from_data_files(str(tmp_path)) + + assert count == 1 + mock_create.assert_called_once() + + @pytest.mark.asyncio + async def test_sync_skips_existing(self, tmp_path): + """Already-existing series should be skipped.""" + mock_serie = MagicMock() + mock_serie.key = "exists" + mock_serie.name = "Exists" + mock_serie.site = "x" + mock_serie.folder = "Exists" + mock_serie.episodeDict = {} + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.services.anime_service.SeriesApp" + ) as MockApp, patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=MagicMock(), + ), patch( + "src.server.database.service.AnimeSeriesService.create", + new_callable=AsyncMock, + ) as mock_create: + mock_app_instance = MagicMock() + mock_app_instance.get_all_series_from_data_files.return_value = [ + mock_serie + ] + MockApp.return_value = mock_app_instance + + count = await sync_series_from_data_files(str(tmp_path)) + + assert count == 0 + mock_create.assert_not_called() + + @pytest.mark.asyncio + async def test_sync_no_data_files(self, tmp_path): + """Empty directory should return 0.""" + with patch( + "src.server.services.anime_service.SeriesApp" + ) as MockApp: + mock_app_instance = MagicMock() + mock_app_instance.get_all_series_from_data_files.return_value = [] + MockApp.return_value = mock_app_instance + + count = await sync_series_from_data_files(str(tmp_path)) + + assert count == 0 + + @pytest.mark.asyncio + async def test_sync_handles_empty_name(self, tmp_path): + """Series with empty name should use folder as fallback.""" + mock_serie = MagicMock() + mock_serie.key = "no-name" + mock_serie.name = "" + mock_serie.site = "x" + mock_serie.folder = "FallbackFolder" + mock_serie.episodeDict = {} + + mock_session = AsyncMock() + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_session) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + + with patch( + "src.server.services.anime_service.SeriesApp" + ) as MockApp, patch( + "src.server.database.connection.get_db_session", + return_value=mock_ctx, + ), patch( + "src.server.database.service.AnimeSeriesService.get_by_key", + new_callable=AsyncMock, + return_value=None, + ), patch( + "src.server.database.service.AnimeSeriesService.create", + new_callable=AsyncMock, + return_value=MagicMock(id=1), + ) as mock_create: + mock_app_instance = MagicMock() + mock_app_instance.get_all_series_from_data_files.return_value = [ + mock_serie + ] + MockApp.return_value = mock_app_instance + + count = await sync_series_from_data_files(str(tmp_path)) + + assert count == 1 + # The name should have been set to folder + assert mock_serie.name == "FallbackFolder" diff --git a/tests/unit/test_database_connection.py b/tests/unit/test_database_connection.py new file mode 100644 index 0000000..c752289 --- /dev/null +++ b/tests/unit/test_database_connection.py @@ -0,0 +1,475 @@ +"""Unit tests for database connection module. + +Tests cover engine/session lifecycle, utility functions, +TransactionManager, SavepointHandle, and various error paths. +""" +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +import src.server.database.connection as conn_mod +from src.server.database.connection import ( + SavepointHandle, + TransactionManager, + _get_database_url, + get_session_transaction_depth, + is_session_in_transaction, +) + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +@pytest.fixture(autouse=True) +def _reset_globals(): + """Reset the module-level globals before/after every test.""" + old_engine = conn_mod._engine + old_sync = conn_mod._sync_engine + old_sf = conn_mod._session_factory + old_ssf = conn_mod._sync_session_factory + + conn_mod._engine = None + conn_mod._sync_engine = None + conn_mod._session_factory = None + conn_mod._sync_session_factory = None + + yield + + conn_mod._engine = old_engine + conn_mod._sync_engine = old_sync + conn_mod._session_factory = old_sf + conn_mod._sync_session_factory = old_ssf + + +# ══════════════════════════════════════════════════════════════════════════════ +# _get_database_url +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestGetDatabaseURL: + """Test _get_database_url helper.""" + + def test_sqlite_url_converted(self): + """sqlite:/// should be converted to sqlite+aiosqlite:///.""" + with patch.object( + conn_mod.settings, "database_url", + "sqlite:///./data/anime.db", + ): + result = _get_database_url() + assert "aiosqlite" in result + + def test_non_sqlite_url_unchanged(self): + """Non-SQLite URL should remain unchanged.""" + with patch.object( + conn_mod.settings, "database_url", + "postgresql://user:pass@localhost/db", + ): + result = _get_database_url() + assert result == "postgresql://user:pass@localhost/db" + + +# ══════════════════════════════════════════════════════════════════════════════ +# get_engine / get_sync_engine +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestGetEngine: + """Test get_engine and get_sync_engine.""" + + def test_raises_when_not_initialized(self): + """get_engine should raise RuntimeError before init_db.""" + with pytest.raises(RuntimeError, match="not initialized"): + conn_mod.get_engine() + + def test_returns_engine_when_set(self): + """Should return the engine when initialised.""" + fake_engine = MagicMock() + conn_mod._engine = fake_engine + assert conn_mod.get_engine() is fake_engine + + def test_get_sync_engine_raises(self): + """get_sync_engine should raise RuntimeError before init_db.""" + with pytest.raises(RuntimeError, match="not initialized"): + conn_mod.get_sync_engine() + + def test_get_sync_engine_returns(self): + """Should return sync engine when set.""" + fake = MagicMock() + conn_mod._sync_engine = fake + assert conn_mod.get_sync_engine() is fake + + +# ══════════════════════════════════════════════════════════════════════════════ +# get_db_session +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestGetDBSession: + """Test get_db_session async context manager.""" + + @pytest.mark.asyncio + async def test_raises_when_not_initialized(self): + """Should raise RuntimeError if session factory is None.""" + with pytest.raises(RuntimeError, match="not initialized"): + async with conn_mod.get_db_session(): + pass + + @pytest.mark.asyncio + async def test_commits_on_success(self): + """Session should be committed on normal exit.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + conn_mod._session_factory = factory + + async with conn_mod.get_db_session() as session: + assert session is mock_session + + mock_session.commit.assert_called_once() + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_rollback_on_exception(self): + """Session should be rolled back on exception.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + conn_mod._session_factory = factory + + with pytest.raises(ValueError): + async with conn_mod.get_db_session(): + raise ValueError("boom") + + mock_session.rollback.assert_called_once() + mock_session.commit.assert_not_called() + mock_session.close.assert_called_once() + + +# ══════════════════════════════════════════════════════════════════════════════ +# get_sync_session / get_async_session_factory +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestGetSyncSession: + """Test get_sync_session.""" + + def test_raises_when_not_initialized(self): + """Should raise RuntimeError.""" + with pytest.raises(RuntimeError, match="not initialized"): + conn_mod.get_sync_session() + + def test_returns_session(self): + """Should return a session from the factory.""" + mock_session = MagicMock() + conn_mod._sync_session_factory = MagicMock(return_value=mock_session) + assert conn_mod.get_sync_session() is mock_session + + +class TestGetAsyncSessionFactory: + """Test get_async_session_factory.""" + + def test_raises_when_not_initialized(self): + """Should raise RuntimeError.""" + with pytest.raises(RuntimeError, match="not initialized"): + conn_mod.get_async_session_factory() + + def test_returns_session(self): + """Should return a new async session.""" + mock_session = AsyncMock() + conn_mod._session_factory = MagicMock(return_value=mock_session) + assert conn_mod.get_async_session_factory() is mock_session + + +# ══════════════════════════════════════════════════════════════════════════════ +# get_transactional_session +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestGetTransactionalSession: + """Test get_transactional_session.""" + + @pytest.mark.asyncio + async def test_raises_when_not_initialized(self): + """Should raise RuntimeError.""" + with pytest.raises(RuntimeError, match="not initialized"): + async with conn_mod.get_transactional_session(): + pass + + @pytest.mark.asyncio + async def test_does_not_auto_commit(self): + """Session should NOT be committed on normal exit.""" + mock_session = AsyncMock() + conn_mod._session_factory = MagicMock(return_value=mock_session) + + async with conn_mod.get_transactional_session() as session: + pass + + mock_session.commit.assert_not_called() + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_rollback_on_exception(self): + """Should rollback on exception.""" + mock_session = AsyncMock() + conn_mod._session_factory = MagicMock(return_value=mock_session) + + with pytest.raises(ValueError): + async with conn_mod.get_transactional_session(): + raise ValueError("boom") + + mock_session.rollback.assert_called_once() + + +# ══════════════════════════════════════════════════════════════════════════════ +# close_db +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestCloseDB: + """Test close_db function.""" + + @pytest.mark.asyncio + async def test_disposes_engines(self): + """Should dispose both engines.""" + mock_engine = AsyncMock() + mock_sync = MagicMock() + mock_sync.url = "sqlite:///test.db" + mock_sync.connect.return_value.__enter__ = MagicMock() + mock_sync.connect.return_value.__exit__ = MagicMock() + + conn_ctx = MagicMock() + conn_ctx.__enter__ = MagicMock(return_value=MagicMock()) + conn_ctx.__exit__ = MagicMock(return_value=False) + mock_sync.connect.return_value = conn_ctx + + conn_mod._engine = mock_engine + conn_mod._sync_engine = mock_sync + conn_mod._session_factory = MagicMock() + conn_mod._sync_session_factory = MagicMock() + + await conn_mod.close_db() + + mock_engine.dispose.assert_called_once() + mock_sync.dispose.assert_called_once() + assert conn_mod._engine is None + assert conn_mod._sync_engine is None + + @pytest.mark.asyncio + async def test_noop_when_not_initialized(self): + """Should not raise if engines are None.""" + await conn_mod.close_db() # should not raise + + +# ══════════════════════════════════════════════════════════════════════════════ +# TransactionManager +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestTransactionManager: + """Test TransactionManager class.""" + + def test_init_raises_without_factory(self): + """Should raise RuntimeError when no session factory.""" + with pytest.raises(RuntimeError, match="not initialized"): + TransactionManager() + + @pytest.mark.asyncio + async def test_context_manager_creates_and_closes_session(self): + """Should create session on enter and close on exit.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + + async with TransactionManager(session_factory=factory) as tm: + session = await tm.get_session() + assert session is mock_session + + mock_session.close.assert_called_once() + + @pytest.mark.asyncio + async def test_begin_commit(self): + """begin then commit should work.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + + async with TransactionManager(session_factory=factory) as tm: + await tm.begin() + assert tm.is_in_transaction() is True + await tm.commit() + assert tm.is_in_transaction() is False + + mock_session.begin.assert_called_once() + mock_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_begin_rollback(self): + """begin then rollback should work.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + + async with TransactionManager(session_factory=factory) as tm: + await tm.begin() + await tm.rollback() + assert tm.is_in_transaction() is False + + mock_session.rollback.assert_called_once() + + @pytest.mark.asyncio + async def test_exception_auto_rollback(self): + """Exception inside context manager should auto rollback.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + + with pytest.raises(ValueError): + async with TransactionManager(session_factory=factory) as tm: + await tm.begin() + raise ValueError("boom") + + mock_session.rollback.assert_called_once() + + @pytest.mark.asyncio + async def test_double_begin_raises(self): + """begin called twice should raise.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + + async with TransactionManager(session_factory=factory) as tm: + await tm.begin() + with pytest.raises(RuntimeError, match="Already in"): + await tm.begin() + + @pytest.mark.asyncio + async def test_commit_without_begin_raises(self): + """commit without begin should raise.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + + async with TransactionManager(session_factory=factory) as tm: + with pytest.raises(RuntimeError, match="Not in"): + await tm.commit() + + @pytest.mark.asyncio + async def test_get_session_outside_context_raises(self): + """get_session outside context manager should raise.""" + factory = MagicMock() + tm = TransactionManager(session_factory=factory) + with pytest.raises(RuntimeError, match="context manager"): + await tm.get_session() + + @pytest.mark.asyncio + async def test_transaction_depth(self): + """get_transaction_depth should reflect state.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + + async with TransactionManager(session_factory=factory) as tm: + assert tm.get_transaction_depth() == 0 + await tm.begin() + assert tm.get_transaction_depth() == 1 + await tm.commit() + assert tm.get_transaction_depth() == 0 + + @pytest.mark.asyncio + async def test_savepoint_creation(self): + """savepoint should return SavepointHandle.""" + mock_session = AsyncMock() + mock_nested = AsyncMock() + mock_session.begin_nested = AsyncMock(return_value=mock_nested) + factory = MagicMock(return_value=mock_session) + + async with TransactionManager(session_factory=factory) as tm: + await tm.begin() + sp = await tm.savepoint("sp1") + assert isinstance(sp, SavepointHandle) + + @pytest.mark.asyncio + async def test_savepoint_without_transaction_raises(self): + """savepoint outside transaction should raise.""" + mock_session = AsyncMock() + factory = MagicMock(return_value=mock_session) + + async with TransactionManager(session_factory=factory) as tm: + with pytest.raises(RuntimeError, match="Must be in"): + await tm.savepoint() + + @pytest.mark.asyncio + async def test_rollback_without_session_raises(self): + """rollback without active session should raise.""" + factory = MagicMock() + tm = TransactionManager(session_factory=factory) + with pytest.raises(RuntimeError, match="No active session"): + await tm.rollback() + + +# ══════════════════════════════════════════════════════════════════════════════ +# SavepointHandle +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestSavepointHandle: + """Test SavepointHandle class.""" + + @pytest.mark.asyncio + async def test_rollback(self): + """Should call nested.rollback().""" + mock_nested = AsyncMock() + sp = SavepointHandle(mock_nested, "sp1") + await sp.rollback() + mock_nested.rollback.assert_called_once() + assert sp._released is True + + @pytest.mark.asyncio + async def test_rollback_idempotent(self): + """Second rollback should be a noop.""" + mock_nested = AsyncMock() + sp = SavepointHandle(mock_nested, "sp1") + await sp.rollback() + await sp.rollback() + mock_nested.rollback.assert_called_once() + + @pytest.mark.asyncio + async def test_release(self): + """Should mark as released.""" + mock_nested = AsyncMock() + sp = SavepointHandle(mock_nested, "sp1") + await sp.release() + assert sp._released is True + + @pytest.mark.asyncio + async def test_release_idempotent(self): + """Second release should be a noop.""" + mock_nested = AsyncMock() + sp = SavepointHandle(mock_nested, "sp1") + await sp.release() + await sp.release() + assert sp._released is True + + +# ══════════════════════════════════════════════════════════════════════════════ +# Utility Functions +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestUtilityFunctions: + """Test is_session_in_transaction and get_session_transaction_depth.""" + + def test_in_transaction_true(self): + """Should return True when session is in transaction.""" + session = MagicMock() + session.in_transaction.return_value = True + assert is_session_in_transaction(session) is True + + def test_in_transaction_false(self): + """Should return False when session is not in transaction.""" + session = MagicMock() + session.in_transaction.return_value = False + assert is_session_in_transaction(session) is False + + def test_transaction_depth_zero(self): + """Should return 0 when not in transaction.""" + session = MagicMock() + session.in_transaction.return_value = False + assert get_session_transaction_depth(session) == 0 + + def test_transaction_depth_one(self): + """Should return 1 when in transaction.""" + session = MagicMock() + session.in_transaction.return_value = True + assert get_session_transaction_depth(session) == 1 diff --git a/tests/unit/test_enhanced_provider.py b/tests/unit/test_enhanced_provider.py index 6c546ff..d7d34ce 100644 --- a/tests/unit/test_enhanced_provider.py +++ b/tests/unit/test_enhanced_provider.py @@ -442,3 +442,479 @@ class TestEnhancedProviderFromHTML: result = enhanced_loader._get_provider_from_html(1, 1, "test") assert result == {} + + +# ══════════════════════════════════════════════════════════════════════════════ +# New coverage tests – fetch, download flow, redirect, season counts +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestFetchAnimeListWithRecovery: + """Test _fetch_anime_list_with_recovery.""" + + def test_successful_fetch(self, enhanced_loader): + """Should fetch and parse a JSON response.""" + mock_response = MagicMock() + mock_response.ok = True + mock_response.text = json.dumps([{"title": "Naruto"}]) + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.return_value = mock_response + result = enhanced_loader._fetch_anime_list_with_recovery( + "https://example.com/search" + ) + + assert len(result) == 1 + assert result[0]["title"] == "Naruto" + + def test_404_raises_non_retryable(self, enhanced_loader): + """404 should raise NonRetryableError.""" + mock_response = MagicMock() + mock_response.ok = False + mock_response.status_code = 404 + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.return_value = mock_response + with pytest.raises(NonRetryableError, match="not found"): + enhanced_loader._fetch_anime_list_with_recovery( + "https://example.com/search" + ) + + def test_403_raises_non_retryable(self, enhanced_loader): + """403 should raise NonRetryableError.""" + mock_response = MagicMock() + mock_response.ok = False + mock_response.status_code = 403 + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.return_value = mock_response + with pytest.raises(NonRetryableError, match="forbidden"): + enhanced_loader._fetch_anime_list_with_recovery( + "https://example.com/search" + ) + + def test_500_raises_retryable(self, enhanced_loader): + """500 should raise RetryableError.""" + mock_response = MagicMock() + mock_response.ok = False + mock_response.status_code = 500 + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.return_value = mock_response + with pytest.raises(RetryableError, match="Server error"): + enhanced_loader._fetch_anime_list_with_recovery( + "https://example.com/search" + ) + + def test_network_error_raises_network_error(self, enhanced_loader): + """requests.RequestException should raise NetworkError.""" + import requests as req + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.side_effect = ( + req.RequestException("timeout") + ) + with pytest.raises(NetworkError, match="Network error"): + enhanced_loader._fetch_anime_list_with_recovery( + "https://example.com/search" + ) + + +class TestGetKeyHTML: + """Test _GetKeyHTML fetching and caching.""" + + def test_cached_html_returned(self, enhanced_loader): + """Already-cached key should skip HTTP call.""" + mock_resp = MagicMock() + enhanced_loader._KeyHTMLDict["cached-key"] = mock_resp + + result = enhanced_loader._GetKeyHTML("cached-key") + assert result is mock_resp + enhanced_loader.session.get.assert_not_called() + + def test_fetches_and_caches(self, enhanced_loader): + """Missing key should be fetched and cached.""" + mock_response = MagicMock() + mock_response.ok = True + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.return_value = mock_response + result = enhanced_loader._GetKeyHTML("new-key") + + assert result is mock_response + assert enhanced_loader._KeyHTMLDict["new-key"] is mock_response + + def test_404_raises_non_retryable(self, enhanced_loader): + """404 from server should raise NonRetryableError.""" + mock_response = MagicMock() + mock_response.ok = False + mock_response.status_code = 404 + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.return_value = mock_response + with pytest.raises(NonRetryableError, match="not found"): + enhanced_loader._GetKeyHTML("missing-key") + + +class TestGetRedirectLink: + """Test _get_redirect_link method.""" + + def test_returns_link_and_provider(self, enhanced_loader): + """Should return (link, provider_name) tuple.""" + with patch.object( + enhanced_loader, "IsLanguage", return_value=True + ), patch.object( + enhanced_loader, + "_get_provider_from_html", + return_value={ + "VOE": {1: "https://aniworld.to/redirect/100"} + }, + ): + link, provider = enhanced_loader._get_redirect_link( + 1, 1, "test", "German Dub" + ) + assert link == "https://aniworld.to/redirect/100" + assert provider == "VOE" + + def test_language_unavailable_raises(self, enhanced_loader): + """Should raise NonRetryableError if language not available.""" + with patch.object( + enhanced_loader, "IsLanguage", return_value=False + ): + with pytest.raises(NonRetryableError, match="not available"): + enhanced_loader._get_redirect_link( + 1, 1, "test", "German Dub" + ) + + def test_no_provider_found_raises(self, enhanced_loader): + """Should raise when no provider has the language.""" + with patch.object( + enhanced_loader, "IsLanguage", return_value=True + ), patch.object( + enhanced_loader, + "_get_provider_from_html", + return_value={"VOE": {2: "link"}}, # English Sub only + ): + with pytest.raises(NonRetryableError, match="No provider"): + enhanced_loader._get_redirect_link( + 1, 1, "test", "German Dub" + ) + + +class TestGetEmbeddedLink: + """Test _get_embeded_link method.""" + + def test_returns_final_url(self, enhanced_loader): + """Should follow redirect and return final URL.""" + mock_response = MagicMock() + mock_response.url = "https://voe.sx/e/abc123" + + with patch.object( + enhanced_loader, + "_get_redirect_link", + return_value=("https://aniworld.to/redirect/100", "VOE"), + ), patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.return_value = mock_response + result = enhanced_loader._get_embeded_link( + 1, 1, "test", "German Dub" + ) + + assert result == "https://voe.sx/e/abc123" + + def test_redirect_failure_raises(self, enhanced_loader): + """Should propagate error from _get_redirect_link.""" + with patch.object( + enhanced_loader, + "_get_redirect_link", + side_effect=NonRetryableError("no link"), + ): + with pytest.raises(NonRetryableError): + enhanced_loader._get_embeded_link( + 1, 1, "test", "German Dub" + ) + + +class TestGetDirectLinkFromProvider: + """Test _get_direct_link_from_provider method.""" + + def test_returns_link_from_voe(self, enhanced_loader): + """Should use VOE provider to extract direct link.""" + mock_provider = MagicMock() + mock_provider.get_link.return_value = ( + "https://direct.example.com/video.mp4", + [], + ) + + enhanced_loader.Providers = MagicMock() + enhanced_loader.Providers.GetProvider.return_value = mock_provider + + with patch.object( + enhanced_loader, + "_get_embeded_link", + return_value="https://voe.sx/e/abc123", + ): + result = enhanced_loader._get_direct_link_from_provider( + 1, 1, "test", "German Dub" + ) + + assert result == ("https://direct.example.com/video.mp4", []) + + def test_no_embedded_link_raises(self, enhanced_loader): + """Should raise if embedded link is None.""" + with patch.object( + enhanced_loader, + "_get_embeded_link", + return_value=None, + ): + with pytest.raises(NonRetryableError, match="No embedded link"): + enhanced_loader._get_direct_link_from_provider( + 1, 1, "test", "German Dub" + ) + + def test_no_provider_raises(self, enhanced_loader): + """Should raise if VOE provider unavailable.""" + enhanced_loader.Providers = MagicMock() + enhanced_loader.Providers.GetProvider.return_value = None + + with patch.object( + enhanced_loader, + "_get_embeded_link", + return_value="https://voe.sx/e/abc", + ): + with pytest.raises(NonRetryableError, match="VOE provider"): + enhanced_loader._get_direct_link_from_provider( + 1, 1, "test", "German Dub" + ) + + +class TestDownloadWithRecovery: + """Test _download_with_recovery method.""" + + def test_successful_download(self, enhanced_loader, tmp_path): + """Should download, verify, and move file.""" + temp_path = str(tmp_path / "temp.mp4") + output_path = str(tmp_path / "output.mp4") + + # Create a fake temp file after "download" + def fake_download(*args, **kwargs): + with open(temp_path, "wb") as f: + f.write(b"fake-video-data") + return True + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs, patch( + "src.core.providers.enhanced_provider.file_corruption_detector" + ) as mock_fcd, patch( + "src.core.providers.enhanced_provider.get_integrity_manager" + ) as mock_im: + mock_rs.handle_network_failure.return_value = ( + "https://direct.example.com/v.mp4", + [], + ) + mock_rs.handle_download_failure.side_effect = fake_download + mock_fcd.is_valid_video_file.return_value = True + mock_im.return_value.store_checksum.return_value = "abc123" + + result = enhanced_loader._download_with_recovery( + 1, 1, "test", "German Dub", + temp_path, output_path, None, + ) + + assert result is True + assert os.path.exists(output_path) + + def test_all_providers_fail_returns_false(self, enhanced_loader, tmp_path): + """Should return False when all providers fail.""" + temp_path = str(tmp_path / "temp.mp4") + output_path = str(tmp_path / "output.mp4") + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.side_effect = Exception("fail") + + result = enhanced_loader._download_with_recovery( + 1, 1, "test", "German Dub", + temp_path, output_path, None, + ) + + assert result is False + + def test_corrupted_download_removed(self, enhanced_loader, tmp_path): + """Corrupted downloads should be removed and next provider tried.""" + temp_path = str(tmp_path / "temp.mp4") + output_path = str(tmp_path / "output.mp4") + + # Create a fake temp file after "download" + def fake_download(*args, **kwargs): + with open(temp_path, "wb") as f: + f.write(b"corrupt") + return True + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs, patch( + "src.core.providers.enhanced_provider.file_corruption_detector" + ) as mock_fcd: + mock_rs.handle_network_failure.return_value = ( + "https://direct.example.com/v.mp4", + [], + ) + mock_rs.handle_download_failure.side_effect = fake_download + mock_fcd.is_valid_video_file.return_value = False + + result = enhanced_loader._download_with_recovery( + 1, 1, "test", "German Dub", + temp_path, output_path, None, + ) + + assert result is False + + +class TestGetSeasonEpisodeCount: + """Test get_season_episode_count method.""" + + def test_returns_episode_counts(self, enhanced_loader): + """Should return dict of season -> episode count.""" + base_html = ( + b'' + b"" + ) + s1_html = ( + b'
' + b'E1' + b'E2' + b'' + ) + s2_html = ( + b'' + b'E1' + b'' + ) + + responses = [ + MagicMock(content=base_html), + MagicMock(content=s1_html), + MagicMock(content=s2_html), + ] + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.side_effect = responses + result = enhanced_loader.get_season_episode_count("test") + + assert result == {1: 2, 2: 1} + + def test_no_seasons_meta_returns_empty(self, enhanced_loader): + """Missing numberOfSeasons meta should return empty dict.""" + base_html = b"No seasons" + + with patch( + "src.core.providers.enhanced_provider.recovery_strategies" + ) as mock_rs: + mock_rs.handle_network_failure.return_value = MagicMock( + content=base_html + ) + result = enhanced_loader.get_season_episode_count("test") + + assert result == {} + + +class TestPerformYtdlDownload: + """Test _perform_ytdl_download method.""" + + def test_success(self, enhanced_loader): + """Should return True on successful download.""" + with patch( + "src.core.providers.enhanced_provider.YoutubeDL" + ) as MockYDL: + mock_ydl = MagicMock() + MockYDL.return_value.__enter__ = MagicMock(return_value=mock_ydl) + MockYDL.return_value.__exit__ = MagicMock(return_value=False) + result = enhanced_loader._perform_ytdl_download( + {}, "https://example.com/video" + ) + + assert result is True + + def test_failure_raises_download_error(self, enhanced_loader): + """yt-dlp failure should raise DownloadError.""" + with patch( + "src.core.providers.enhanced_provider.YoutubeDL" + ) as MockYDL: + mock_ydl = MagicMock() + mock_ydl.download.side_effect = Exception("yt-dlp crash") + MockYDL.return_value.__enter__ = MagicMock(return_value=mock_ydl) + MockYDL.return_value.__exit__ = MagicMock(return_value=False) + with pytest.raises(DownloadError, match="Download failed"): + enhanced_loader._perform_ytdl_download( + {}, "https://example.com/video" + ) + + +class TestDownloadFlow: + """Test full Download method flow.""" + + @patch("src.core.providers.enhanced_provider.get_integrity_manager") + def test_existing_valid_file_returns_true( + self, mock_integrity, enhanced_loader, tmp_path + ): + """Should return True if file already exists and is valid.""" + # Create fake existing file + folder = tmp_path / "Folder" / "Season 1" + folder.mkdir(parents=True) + video = folder / "Test - S01E001 - (German Dub).mp4" + video.write_bytes(b"valid-video") + + enhanced_loader._KeyHTMLDict["key"] = MagicMock( + content=b"