Expand test coverage: ~188 new tests across 6 critical files

- Fix failing test_authenticated_request_succeeds (dependency override)
- Expand test_anime_service.py (+35 tests: status events, DB, broadcasts)
- Create test_queue_repository.py (27 tests: CRUD, model conversion)
- Expand test_enhanced_provider.py (+24 tests: fetch, download, redirect)
- Expand test_serie_scanner.py (+25 tests: events, year extract, mp4 scan)
- Create test_database_connection.py (38 tests: sessions, transactions)
- Expand test_anime_endpoints.py (+39 tests: status, search, loading)
- Clean up docs/instructions.md TODO list
This commit is contained in:
2026-02-15 17:44:27 +01:00
parent d7ab689fe1
commit e84a220f55
8 changed files with 3254 additions and 115 deletions

View File

@@ -1,6 +1,6 @@
"""Tests for anime API endpoints."""
import asyncio
from unittest.mock import AsyncMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
@@ -414,3 +414,536 @@ async def test_add_series_special_characters_in_name(authenticated_client):
invalid_chars = [':', '\\', '?', '*', '<', '>', '|', '"']
for char in invalid_chars:
assert char not in folder_name, f"Found '{char}' in folder name for {name}"
# ---------------------------------------------------------------------------
# New tests: get_anime_status
# ---------------------------------------------------------------------------
class TestGetAnimeStatusEndpoint:
"""Tests for GET /api/anime/status."""
@pytest.mark.asyncio
async def test_status_unauthorized(self):
"""Status endpoint should require authentication."""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/api/anime/status")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_status_authenticated(self, authenticated_client):
"""Authenticated request returns directory and series count."""
response = await authenticated_client.get("/api/anime/status")
assert response.status_code == 200
data = response.json()
assert "directory" in data
assert "series_count" in data
assert isinstance(data["series_count"], int)
def test_status_direct_no_series_app(self):
"""When series_app is None, returns empty directory and 0 count."""
result = asyncio.run(anime_module.get_anime_status(series_app=None))
assert result["directory"] == ""
assert result["series_count"] == 0
# ---------------------------------------------------------------------------
# New tests: list_anime authenticated
# ---------------------------------------------------------------------------
class TestListAnimeAuthenticated:
"""Tests for GET /api/anime/ with authentication."""
@pytest.mark.asyncio
async def test_list_anime_returns_summaries(self, authenticated_client):
"""Authenticated list returns anime summaries."""
response = await authenticated_client.get("/api/anime/")
assert response.status_code == 200
data = response.json()
assert isinstance(data, list)
@pytest.mark.asyncio
async def test_list_anime_invalid_page(self, authenticated_client):
"""Negative page number returns validation error."""
response = await authenticated_client.get(
"/api/anime/", params={"page": -1}
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_list_anime_per_page_too_large(self, authenticated_client):
"""Per page > 1000 returns validation error."""
response = await authenticated_client.get(
"/api/anime/", params={"per_page": 5000}
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_list_anime_invalid_sort_by(self, authenticated_client):
"""Invalid sort_by parameter returns validation error."""
response = await authenticated_client.get(
"/api/anime/", params={"sort_by": "injection_attempt"}
)
assert response.status_code == 422
@pytest.mark.asyncio
async def test_list_anime_valid_sort_by(self, authenticated_client):
"""Valid sort_by parameter is accepted."""
response = await authenticated_client.get(
"/api/anime/", params={"sort_by": "title"}
)
assert response.status_code == 200
@pytest.mark.asyncio
async def test_list_anime_invalid_filter(self, authenticated_client):
"""Invalid filter value returns validation error."""
response = await authenticated_client.get(
"/api/anime/", params={"filter": "hacked"}
)
assert response.status_code == 422
# ---------------------------------------------------------------------------
# New tests: get_scan_status
# ---------------------------------------------------------------------------
class TestGetScanStatusEndpoint:
"""Tests for GET /api/anime/scan/status."""
@pytest.mark.asyncio
async def test_scan_status_unauthorized(self):
"""Scan status endpoint should require authentication."""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/api/anime/scan/status")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_scan_status_authenticated(self, authenticated_client):
"""Authenticated request returns scan status dict."""
response = await authenticated_client.get("/api/anime/scan/status")
assert response.status_code == 200
data = response.json()
assert isinstance(data, dict)
# ---------------------------------------------------------------------------
# New tests: _validate_search_query_extended
# ---------------------------------------------------------------------------
class TestValidateSearchQueryExtended:
"""Tests for the internal _validate_search_query_extended function."""
def test_empty_query_raises(self):
"""Empty string raises 422."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
anime_module._validate_search_query_extended("")
assert exc_info.value.status_code == 422
def test_whitespace_only_raises(self):
"""Whitespace-only raises 422."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
anime_module._validate_search_query_extended(" ")
assert exc_info.value.status_code == 422
def test_null_bytes_raise(self):
"""Null bytes in query raise 400."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
anime_module._validate_search_query_extended("test\x00query")
assert exc_info.value.status_code == 400
def test_too_long_query_raises(self):
"""Query exceeding 200 chars raises 422."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
anime_module._validate_search_query_extended("a" * 201)
assert exc_info.value.status_code == 422
def test_valid_query_returns_string(self):
"""Valid query is returned (possibly normalised)."""
result = anime_module._validate_search_query_extended("Naruto")
assert isinstance(result, str)
assert len(result) > 0
# ---------------------------------------------------------------------------
# New tests: search_anime_post
# ---------------------------------------------------------------------------
class TestSearchAnimePost:
"""Tests for POST /api/anime/search."""
@pytest.mark.asyncio
async def test_search_post_returns_results(self):
"""POST search with valid query returns results."""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.post(
"/api/anime/search",
json={"query": "test"},
)
assert response.status_code == 200
assert isinstance(response.json(), list)
@pytest.mark.asyncio
async def test_search_post_empty_query_rejected(self):
"""POST search with empty query returns 422."""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.post(
"/api/anime/search",
json={"query": ""},
)
assert response.status_code == 422
# ---------------------------------------------------------------------------
# New tests: _perform_search
# ---------------------------------------------------------------------------
class TestPerformSearch:
"""Tests for the internal _perform_search function."""
@pytest.mark.asyncio
async def test_search_no_series_app(self):
"""When series_app is None return empty list."""
result = await anime_module._perform_search("test", None)
assert result == []
@pytest.mark.asyncio
async def test_search_dict_results(self):
"""Dict-format results are converted to AnimeSummary."""
mock_app = AsyncMock()
mock_app.search = AsyncMock(
return_value=[
{
"key": "k1",
"title": "Title One",
"site": "aniworld.to",
"folder": "f1",
"link": "https://aniworld.to/anime/stream/k1",
"missing_episodes": {},
}
]
)
results = await anime_module._perform_search("query", mock_app)
assert len(results) == 1
assert results[0].key == "k1"
assert results[0].name == "Title One"
@pytest.mark.asyncio
async def test_search_object_results(self):
"""Object-format results (with attributes) are handled."""
match = MagicMock(spec=[])
match.key = "obj-key"
match.id = ""
match.title = "Object Title"
match.name = "Object Title"
match.site = "aniworld.to"
match.folder = "Object Folder"
match.link = ""
match.url = ""
match.missing_episodes = {}
mock_app = AsyncMock()
mock_app.search = AsyncMock(return_value=[match])
results = await anime_module._perform_search("query", mock_app)
assert len(results) == 1
assert results[0].key == "obj-key"
@pytest.mark.asyncio
async def test_search_key_extracted_from_link(self):
"""When key is empty, extract from link URL."""
mock_app = AsyncMock()
mock_app.search = AsyncMock(
return_value=[
{
"key": "",
"name": "No Key",
"site": "",
"folder": "",
"link": "https://aniworld.to/anime/stream/extracted-key",
"missing_episodes": {},
}
]
)
results = await anime_module._perform_search("q", mock_app)
assert results[0].key == "extracted-key"
@pytest.mark.asyncio
async def test_search_exception_raises_500(self):
"""Non-HTTP exception in search raises 500."""
from fastapi import HTTPException
mock_app = AsyncMock()
mock_app.search = AsyncMock(side_effect=RuntimeError("boom"))
with pytest.raises(HTTPException) as exc_info:
await anime_module._perform_search("q", mock_app)
assert exc_info.value.status_code == 500
# ---------------------------------------------------------------------------
# New tests: get_loading_status
# ---------------------------------------------------------------------------
class TestGetLoadingStatusEndpoint:
"""Tests for GET /api/anime/{key}/loading-status."""
@pytest.mark.asyncio
async def test_loading_status_unauthorized(self):
"""Loading status requires authentication."""
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
response = await client.get("/api/anime/some-key/loading-status")
assert response.status_code == 401
@pytest.mark.asyncio
async def test_loading_status_no_db(self, authenticated_client):
"""Without database session returns 503."""
response = await authenticated_client.get(
"/api/anime/some-key/loading-status"
)
# get_optional_database_session may return None → 503
assert response.status_code in (503, 404, 500)
def test_loading_status_direct_no_db(self):
"""Direct call with db=None raises 503."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(anime_module.get_loading_status("key", db=None))
assert exc_info.value.status_code == 503
def test_loading_status_not_found(self):
"""Direct call with unknown key raises 404."""
from fastapi import HTTPException
mock_db = AsyncMock()
async def _run():
with patch(
"src.server.database.service.AnimeSeriesService"
) as mock_svc:
mock_svc.get_by_key = AsyncMock(return_value=None)
return await anime_module.get_loading_status(
"missing-key", db=mock_db
)
with pytest.raises(HTTPException) as exc_info:
asyncio.run(_run())
assert exc_info.value.status_code == 404
def test_loading_status_pending(self):
"""Direct call returns correct pending status payload."""
mock_db = AsyncMock()
series_row = MagicMock()
series_row.key = "test-key"
series_row.loading_status = "pending"
series_row.episodes_loaded = False
series_row.has_nfo = False
series_row.logo_loaded = False
series_row.images_loaded = False
series_row.loading_started_at = None
series_row.loading_completed_at = None
series_row.loading_error = None
async def _run():
with patch(
"src.server.database.service.AnimeSeriesService"
) as mock_svc:
mock_svc.get_by_key = AsyncMock(return_value=series_row)
return await anime_module.get_loading_status(
"test-key", db=mock_db
)
result = asyncio.run(_run())
assert result["key"] == "test-key"
assert result["loading_status"] == "pending"
assert "Queued" in result["message"]
assert result["progress"]["episodes"] is False
def test_loading_status_completed(self):
"""Completed status returns correct message."""
from datetime import datetime
mock_db = AsyncMock()
series_row = MagicMock()
series_row.key = "done-key"
series_row.loading_status = "completed"
series_row.episodes_loaded = True
series_row.has_nfo = True
series_row.logo_loaded = True
series_row.images_loaded = True
series_row.loading_started_at = datetime(2025, 1, 1)
series_row.loading_completed_at = datetime(2025, 1, 1, 0, 5)
series_row.loading_error = None
async def _run():
with patch(
"src.server.database.service.AnimeSeriesService"
) as mock_svc:
mock_svc.get_by_key = AsyncMock(return_value=series_row)
return await anime_module.get_loading_status(
"done-key", db=mock_db
)
result = asyncio.run(_run())
assert result["loading_status"] == "completed"
assert "successfully" in result["message"]
assert result["progress"]["episodes"] is True
assert result["completed_at"] is not None
# ---------------------------------------------------------------------------
# New tests: get_anime detail
# ---------------------------------------------------------------------------
class TestGetAnimeDetail:
"""Tests for GET /api/anime/{anime_id} detail endpoint."""
def test_get_anime_by_key(self):
"""Primary lookup by key returns correct detail."""
fake = FakeSeriesApp()
result = asyncio.run(
anime_module.get_anime("test-show-key", series_app=fake)
)
assert result.key == "test-show-key"
assert result.title == "Test Show"
def test_get_anime_by_folder_fallback(self):
"""Folder-based lookup works as deprecated fallback."""
fake = FakeSeriesApp()
result = asyncio.run(
anime_module.get_anime("Test Show (2023)", series_app=fake)
)
assert result.key == "test-show-key"
def test_get_anime_not_found(self):
"""Unknown anime_id raises 404."""
from fastapi import HTTPException
fake = FakeSeriesApp()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
anime_module.get_anime("nonexistent", series_app=fake)
)
assert exc_info.value.status_code == 404
def test_get_anime_no_series_app(self):
"""None series_app raises 404."""
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
asyncio.run(
anime_module.get_anime("any-id", series_app=None)
)
assert exc_info.value.status_code == 404
def test_get_anime_episodes_formatted(self):
"""Episode dict is converted to season-episode strings."""
fake = FakeSeriesApp()
result = asyncio.run(
anime_module.get_anime("test-show-key", series_app=fake)
)
assert "1-1" in result.episodes
assert "1-2" in result.episodes
def test_get_anime_complete_show_no_episodes(self):
"""Complete show with empty episodeDict returns empty episodes list."""
fake = FakeSeriesApp()
result = asyncio.run(
anime_module.get_anime("complete-show-key", series_app=fake)
)
assert result.episodes == []
# ---------------------------------------------------------------------------
# New tests: trigger_rescan authenticated
# ---------------------------------------------------------------------------
class TestTriggerRescanAuthenticated:
"""Tests for POST /api/anime/rescan with authentication."""
@pytest.mark.asyncio
async def test_rescan_authenticated(self, authenticated_client):
"""Authenticated rescan returns success."""
from src.server.services.anime_service import AnimeService
from src.server.utils.dependencies import get_anime_service
mock_service = AsyncMock(spec=AnimeService)
mock_service.rescan = AsyncMock()
app.dependency_overrides[get_anime_service] = lambda: mock_service
try:
response = await authenticated_client.post("/api/anime/rescan")
assert response.status_code == 200
data = response.json()
assert data["success"] is True
mock_service.rescan.assert_called_once()
finally:
app.dependency_overrides.pop(get_anime_service, None)
def test_rescan_service_error(self):
"""AnimeServiceError is converted to ServerError."""
from src.server.services.anime_service import AnimeServiceError
mock_service = AsyncMock()
mock_service.rescan = AsyncMock(
side_effect=AnimeServiceError("scan failed")
)
from src.server.exceptions import ServerError
with pytest.raises(ServerError):
asyncio.run(
anime_module.trigger_rescan(anime_service=mock_service)
)
# ---------------------------------------------------------------------------
# New tests: search_anime_get additional
# ---------------------------------------------------------------------------
class TestSearchAnimeGetAdditional:
"""Additional tests for GET /api/anime/search."""
@pytest.mark.asyncio
async def test_search_get_with_query(self):
"""Search GET with valid query returns list."""
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport, base_url="http://test"
) as client:
response = await client.get(
"/api/anime/search", params={"query": "naruto"}
)
assert response.status_code == 200
assert isinstance(response.json(), list)
@pytest.mark.asyncio
async def test_search_get_null_byte_query(self):
"""Search GET with null byte in query returns 400."""
transport = ASGITransport(app=app)
async with AsyncClient(
transport=transport, base_url="http://test"
) as client:
response = await client.get(
"/api/anime/search", params={"query": "test\x00bad"}
)
assert response.status_code == 400

View File

@@ -158,17 +158,17 @@ class TestFrontendAuthentication:
async def test_authenticated_request_succeeds(self, authenticated_client):
"""Test that requests with valid token succeed."""
with patch("src.server.utils.dependencies.get_series_app") as mock_get_app:
mock_app = AsyncMock()
mock_list = AsyncMock()
mock_list.GetMissingEpisode = AsyncMock(return_value=[])
mock_list.GetList = AsyncMock(return_value=[])
mock_app.List = mock_list
mock_get_app.return_value = mock_app
mock_anime_service = AsyncMock()
mock_anime_service.list_series_with_filters = AsyncMock(return_value=[])
from src.server.utils.dependencies import get_anime_service
app.dependency_overrides[get_anime_service] = lambda: mock_anime_service
try:
response = await authenticated_client.get("/api/anime/")
assert response.status_code == 200
finally:
app.dependency_overrides.pop(get_anime_service, None)
class TestFrontendAnimeAPI:

View File

@@ -1,16 +1,23 @@
"""Unit tests for AnimeService.
Tests cover service initialization, async operations, caching,
error handling, and progress reporting integration.
error handling, progress reporting integration, scan/download status
event handling, database persistence, and WebSocket broadcasting.
"""
from __future__ import annotations
import asyncio
import time
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.server.services.anime_service import AnimeService, AnimeServiceError
from src.server.services.anime_service import (
AnimeService,
AnimeServiceError,
sync_series_from_data_files,
)
from src.server.services.progress_service import ProgressService
@@ -472,3 +479,965 @@ class TestFactoryFunction:
assert isinstance(service, AnimeService)
assert service._app is mock_series_app
# =============================================================================
# New coverage tests download / scan status, DB persistence, broadcasting
# =============================================================================
class _FakeDownloadArgs:
"""Minimal stand-in for DownloadStatusEventArgs."""
def __init__(self, **kwargs):
self.status = kwargs.get("status", "started")
self.serie_folder = kwargs.get("serie_folder", "TestFolder")
self.season = kwargs.get("season", 1)
self.episode = kwargs.get("episode", 1)
self.item_id = kwargs.get("item_id", None)
self.progress = kwargs.get("progress", 0)
self.message = kwargs.get("message", None)
self.error = kwargs.get("error", None)
self.mbper_sec = kwargs.get("mbper_sec", None)
self.eta = kwargs.get("eta", None)
class _FakeScanArgs:
"""Minimal stand-in for ScanStatusEventArgs."""
def __init__(self, **kwargs):
self.status = kwargs.get("status", "started")
self.current = kwargs.get("current", 0)
self.total = kwargs.get("total", 10)
self.folder = kwargs.get("folder", "")
self.message = kwargs.get("message", None)
self.error = kwargs.get("error", None)
class TestOnDownloadStatus:
"""Test _on_download_status event handler."""
def test_download_started_schedules_start_progress(
self, anime_service, mock_progress_service
):
"""started event should schedule start_progress."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe") as mock_run:
args = _FakeDownloadArgs(
status="started", item_id="q-1"
)
anime_service._on_download_status(args)
mock_run.assert_called_once()
coro = mock_run.call_args[0][0]
assert coro is not None
finally:
loop.close()
def test_download_progress_schedules_update(
self, anime_service, mock_progress_service
):
"""progress event should schedule update_progress."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe") as mock_run:
args = _FakeDownloadArgs(
status="progress",
progress=42,
message="Downloading...",
mbper_sec=5.5,
eta=30,
)
anime_service._on_download_status(args)
mock_run.assert_called_once()
finally:
loop.close()
def test_download_completed_schedules_complete(
self, anime_service, mock_progress_service
):
"""completed event should schedule complete_progress."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe") as mock_run:
args = _FakeDownloadArgs(status="completed")
anime_service._on_download_status(args)
mock_run.assert_called_once()
finally:
loop.close()
def test_download_failed_schedules_fail(
self, anime_service, mock_progress_service
):
"""failed event should schedule fail_progress."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe") as mock_run:
args = _FakeDownloadArgs(
status="failed", error=Exception("Err")
)
anime_service._on_download_status(args)
mock_run.assert_called_once()
finally:
loop.close()
def test_progress_id_from_item_id(self, anime_service):
"""item_id should be used as progress_id when available."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe") as mock_run:
args = _FakeDownloadArgs(
status="started", item_id="queue-42"
)
anime_service._on_download_status(args)
coro = mock_run.call_args[0][0]
# The coroutine was created with progress_id="queue-42"
assert mock_run.called
finally:
loop.close()
def test_progress_id_fallback_without_item_id(self, anime_service):
"""Without item_id, progress_id is built from folder/season/episode."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe") as mock_run:
args = _FakeDownloadArgs(
status="started",
item_id=None,
serie_folder="FolderX",
season=2,
episode=5,
)
anime_service._on_download_status(args)
assert mock_run.called
finally:
loop.close()
def test_no_event_loop_returns_silently(self, anime_service):
"""No loop available should not raise."""
anime_service._event_loop = None
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
args = _FakeDownloadArgs(status="started")
anime_service._on_download_status(args) # should not raise
class TestOnScanStatus:
"""Test _on_scan_status event handler."""
def test_scan_started_schedules_progress_and_broadcast(
self, anime_service, mock_progress_service
):
"""started scan event should schedule start_progress and broadcast."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe") as mock_run:
args = _FakeScanArgs(status="started", total=5)
anime_service._on_scan_status(args)
# 2 calls: start_progress + broadcast_scan_started_safe
assert mock_run.call_count == 2
assert anime_service._is_scanning is True
finally:
loop.close()
def test_scan_progress_updates_counters(
self, anime_service, mock_progress_service
):
"""progress scan event should update counters."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe"):
args = _FakeScanArgs(
status="progress", current=3, total=10,
folder="Naruto"
)
anime_service._on_scan_status(args)
assert anime_service._scan_directories_count == 3
assert anime_service._scan_current_directory == "Naruto"
finally:
loop.close()
def test_scan_completed_marks_done(
self, anime_service, mock_progress_service
):
"""completed scan event should mark scanning as False."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
anime_service._is_scanning = True
anime_service._scan_start_time = time.time() - 5
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe"):
args = _FakeScanArgs(status="completed", total=10)
anime_service._on_scan_status(args)
assert anime_service._is_scanning is False
finally:
loop.close()
def test_scan_failed_marks_done(
self, anime_service, mock_progress_service
):
"""failed scan event should reset scanning state."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
anime_service._is_scanning = True
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe"):
args = _FakeScanArgs(
status="failed", error=Exception("boom")
)
anime_service._on_scan_status(args)
assert anime_service._is_scanning is False
finally:
loop.close()
def test_scan_cancelled_marks_done(
self, anime_service, mock_progress_service
):
"""cancelled scan event should reset scanning state."""
loop = asyncio.new_event_loop()
anime_service._event_loop = loop
anime_service._is_scanning = True
try:
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
with patch("asyncio.run_coroutine_threadsafe"):
args = _FakeScanArgs(status="cancelled")
anime_service._on_scan_status(args)
assert anime_service._is_scanning is False
finally:
loop.close()
def test_scan_no_loop_returns_silently(self, anime_service):
"""No loop available should not raise for scan events."""
anime_service._event_loop = None
with patch("asyncio.get_running_loop", side_effect=RuntimeError):
args = _FakeScanArgs(status="started")
anime_service._on_scan_status(args) # no error
class TestGetScanStatus:
"""Test get_scan_status method."""
def test_returns_status_dict(self, anime_service):
"""Should return dict with expected keys."""
anime_service._is_scanning = True
anime_service._scan_total_items = 42
anime_service._scan_directories_count = 7
anime_service._scan_current_directory = "Naruto"
result = anime_service.get_scan_status()
assert result["is_scanning"] is True
assert result["total_items"] == 42
assert result["directories_scanned"] == 7
assert result["current_directory"] == "Naruto"
class TestBroadcastHelpers:
"""Test WebSocket broadcast safety wrappers."""
@pytest.mark.asyncio
async def test_broadcast_scan_started_safe(self, anime_service):
"""Should call websocket_service.broadcast_scan_started."""
anime_service._websocket_service.broadcast_scan_started = AsyncMock()
await anime_service._broadcast_scan_started_safe(total_items=5)
anime_service._websocket_service.broadcast_scan_started.assert_called_once()
@pytest.mark.asyncio
async def test_broadcast_scan_started_safe_handles_error(
self, anime_service
):
"""WS failure should be swallowed, not raised."""
anime_service._websocket_service.broadcast_scan_started = AsyncMock(
side_effect=Exception("ws-down")
)
# Should NOT raise
await anime_service._broadcast_scan_started_safe(total_items=5)
@pytest.mark.asyncio
async def test_broadcast_scan_progress_safe(self, anime_service):
"""Should call broadcast_scan_progress."""
anime_service._websocket_service.broadcast_scan_progress = AsyncMock()
await anime_service._broadcast_scan_progress_safe(
directories_scanned=3, files_found=3,
current_directory="AOT", total_items=10,
)
anime_service._websocket_service.broadcast_scan_progress.assert_called_once()
@pytest.mark.asyncio
async def test_broadcast_scan_progress_safe_handles_error(
self, anime_service
):
"""WS failure should be swallowed."""
anime_service._websocket_service.broadcast_scan_progress = AsyncMock(
side_effect=Exception("ws-down")
)
await anime_service._broadcast_scan_progress_safe(
directories_scanned=0, files_found=0,
current_directory="", total_items=0,
)
@pytest.mark.asyncio
async def test_broadcast_scan_completed_safe(self, anime_service):
"""Should call broadcast_scan_completed."""
anime_service._websocket_service.broadcast_scan_completed = AsyncMock()
await anime_service._broadcast_scan_completed_safe(
total_directories=10, total_files=10, elapsed_seconds=5.0,
)
anime_service._websocket_service.broadcast_scan_completed.assert_called_once()
@pytest.mark.asyncio
async def test_broadcast_scan_completed_safe_handles_error(
self, anime_service
):
"""WS failure should be swallowed."""
anime_service._websocket_service.broadcast_scan_completed = AsyncMock(
side_effect=Exception("ws-down")
)
await anime_service._broadcast_scan_completed_safe(
total_directories=0, total_files=0, elapsed_seconds=0,
)
@pytest.mark.asyncio
async def test_broadcast_series_updated(self, anime_service):
"""Should broadcast series_updated over WebSocket."""
anime_service._websocket_service.broadcast = AsyncMock()
await anime_service._broadcast_series_updated("aot")
anime_service._websocket_service.broadcast.assert_called_once()
payload = anime_service._websocket_service.broadcast.call_args[0][0]
assert payload["type"] == "series_updated"
assert payload["key"] == "aot"
@pytest.mark.asyncio
async def test_broadcast_series_updated_no_ws_service(self, anime_service):
"""Should return silently if no websocket service."""
anime_service._websocket_service = None
await anime_service._broadcast_series_updated("aot") # no error
class TestListSeriesWithFilters:
"""Test list_series_with_filters with database enrichment."""
@pytest.mark.asyncio
async def test_returns_enriched_list(
self, anime_service, mock_series_app
):
"""Should merge SeriesApp data with DB metadata."""
mock_serie = MagicMock()
mock_serie.key = "aot"
mock_serie.name = "Attack on Titan"
mock_serie.site = "aniworld.to"
mock_serie.folder = "Attack on Titan (2013)"
mock_serie.episodeDict = {1: [2, 3]}
mock_list = MagicMock()
mock_list.GetList.return_value = [mock_serie]
mock_series_app.list = mock_list
mock_db_series = MagicMock()
mock_db_series.folder = "Attack on Titan (2013)"
mock_db_series.has_nfo = True
mock_db_series.nfo_created_at = None
mock_db_series.nfo_updated_at = None
mock_db_series.tmdb_id = 1234
mock_db_series.tvdb_id = None
mock_db_series.id = 1
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService"
) as MockASS:
MockASS.get_all = AsyncMock(return_value=[mock_db_series])
result = await anime_service.list_series_with_filters()
assert len(result) == 1
assert result[0]["key"] == "aot"
assert result[0]["has_nfo"] is True
assert result[0]["tmdb_id"] == 1234
@pytest.mark.asyncio
async def test_empty_series_returns_empty(
self, anime_service, mock_series_app
):
"""Should return [] when SeriesApp has no series."""
mock_list = MagicMock()
mock_list.GetList.return_value = []
mock_series_app.list = mock_list
result = await anime_service.list_series_with_filters()
assert result == []
@pytest.mark.asyncio
async def test_no_list_attribute_returns_empty(
self, anime_service, mock_series_app
):
"""Should return [] when SeriesApp has no list attribute."""
del mock_series_app.list
result = await anime_service.list_series_with_filters()
assert result == []
@pytest.mark.asyncio
async def test_db_error_raises_anime_service_error(
self, anime_service, mock_series_app
):
"""DB failure should raise AnimeServiceError."""
mock_serie = MagicMock()
mock_serie.key = "aot"
mock_serie.name = "AOT"
mock_serie.site = "x"
mock_serie.folder = "AOT"
mock_serie.episodeDict = {}
mock_list = MagicMock()
mock_list.GetList.return_value = [mock_serie]
mock_series_app.list = mock_list
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(
side_effect=RuntimeError("DB down")
)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
):
with pytest.raises(AnimeServiceError):
await anime_service.list_series_with_filters()
class TestSaveAndLoadDB:
"""Test DB persistence helpers."""
@pytest.mark.asyncio
async def test_save_scan_results_creates_new(
self, anime_service
):
"""New series should be created in DB."""
mock_serie = MagicMock()
mock_serie.key = "naruto"
mock_serie.name = "Naruto"
mock_serie.site = "aniworld.to"
mock_serie.folder = "Naruto"
mock_serie.year = 2002
mock_serie.episodeDict = {1: [1, 2]}
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=None,
), patch(
"src.server.database.service.AnimeSeriesService.create",
new_callable=AsyncMock,
return_value=MagicMock(id=1),
) as mock_create, patch(
"src.server.database.service.EpisodeService.create",
new_callable=AsyncMock,
) as mock_ep_create:
count = await anime_service._save_scan_results_to_db(
[mock_serie]
)
assert count == 1
mock_create.assert_called_once()
assert mock_ep_create.call_count == 2
@pytest.mark.asyncio
async def test_save_scan_results_updates_existing(
self, anime_service
):
"""Existing series should be updated in DB."""
mock_serie = MagicMock()
mock_serie.key = "naruto"
mock_serie.name = "Naruto"
mock_serie.site = "aniworld.to"
mock_serie.folder = "Naruto"
mock_serie.episodeDict = {1: [3]}
existing = MagicMock()
existing.id = 1
existing.folder = "Naruto"
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=existing,
), patch.object(
anime_service,
"_update_series_in_db",
new_callable=AsyncMock,
) as mock_update:
count = await anime_service._save_scan_results_to_db(
[mock_serie]
)
assert count == 1
mock_update.assert_called_once()
@pytest.mark.asyncio
async def test_load_series_from_db(
self, anime_service, mock_series_app
):
"""Should populate SeriesApp from DB records."""
mock_ep = MagicMock()
mock_ep.season = 1
mock_ep.episode_number = 5
mock_db_series = MagicMock()
mock_db_series.key = "naruto"
mock_db_series.name = "Naruto"
mock_db_series.site = "aniworld.to"
mock_db_series.folder = "Naruto"
mock_db_series.episodes = [mock_ep]
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.get_all",
new_callable=AsyncMock,
return_value=[mock_db_series],
):
await anime_service._load_series_from_db()
mock_series_app.load_series_from_list.assert_called_once()
loaded = mock_series_app.load_series_from_list.call_args[0][0]
assert len(loaded) == 1
assert loaded[0].key == "naruto"
@pytest.mark.asyncio
async def test_sync_episodes_to_db(
self, anime_service, mock_series_app
):
"""Should sync missing episodes from memory to DB."""
mock_serie = MagicMock()
mock_serie.episodeDict = {1: [4, 5]}
mock_list = MagicMock()
mock_list.keyDict = {"aot": mock_serie}
mock_series_app.list = mock_list
mock_db_series = MagicMock()
mock_db_series.id = 10
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
anime_service._websocket_service = MagicMock()
anime_service._websocket_service.broadcast = AsyncMock()
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=mock_db_series,
), patch(
"src.server.database.service.EpisodeService.get_by_series",
new_callable=AsyncMock,
return_value=[],
), patch(
"src.server.database.service.EpisodeService.create",
new_callable=AsyncMock,
) as mock_ep_create:
count = await anime_service.sync_episodes_to_db("aot")
assert count == 2
assert mock_ep_create.call_count == 2
@pytest.mark.asyncio
async def test_sync_episodes_no_list_returns_zero(
self, anime_service, mock_series_app
):
"""No series list should return 0."""
del mock_series_app.list
count = await anime_service.sync_episodes_to_db("aot")
assert count == 0
class TestAddSeriesToDB:
"""Test add_series_to_db method."""
@pytest.mark.asyncio
async def test_creates_new_series(self, anime_service):
"""New series should be created in DB."""
mock_serie = MagicMock()
mock_serie.key = "x"
mock_serie.name = "X"
mock_serie.site = "y"
mock_serie.folder = "X"
mock_serie.year = 2020
mock_serie.episodeDict = {1: [1]}
mock_db = AsyncMock()
mock_created = MagicMock(id=99)
with patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=None,
), patch(
"src.server.database.service.AnimeSeriesService.create",
new_callable=AsyncMock,
return_value=mock_created,
), patch(
"src.server.database.service.EpisodeService.create",
new_callable=AsyncMock,
):
result = await anime_service.add_series_to_db(mock_serie, mock_db)
assert result is mock_created
@pytest.mark.asyncio
async def test_existing_returns_none(self, anime_service):
"""Already-existing series should return None."""
mock_serie = MagicMock()
mock_serie.key = "x"
mock_serie.name = "X"
mock_db = AsyncMock()
with patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=MagicMock(),
):
result = await anime_service.add_series_to_db(mock_serie, mock_db)
assert result is None
class TestContainsInDB:
"""Test contains_in_db method."""
@pytest.mark.asyncio
async def test_exists(self, anime_service):
"""Should return True when series exists."""
mock_db = AsyncMock()
with patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=MagicMock(),
):
assert await anime_service.contains_in_db("aot", mock_db) is True
@pytest.mark.asyncio
async def test_not_exists(self, anime_service):
"""Should return False when series missing."""
mock_db = AsyncMock()
with patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=None,
):
assert await anime_service.contains_in_db("x", mock_db) is False
class TestUpdateNFOStatusWithoutSession:
"""Test update_nfo_status when no db session is passed (self-managed)."""
@pytest.mark.asyncio
async def test_update_creates_session_and_commits(self, anime_service):
"""Should open its own session and commit."""
mock_series = MagicMock()
mock_series.id = 1
mock_series.has_nfo = False
mock_series.nfo_created_at = None
mock_series.nfo_updated_at = None
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=mock_series,
), patch(
"src.server.database.service.AnimeSeriesService.update",
new_callable=AsyncMock,
):
await anime_service.update_nfo_status(
key="test", has_nfo=True, tmdb_id=42
)
# commit called by update path
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_update_not_found_skips(self, anime_service):
"""Should return without error if series not in DB."""
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=None,
):
await anime_service.update_nfo_status(key="missing", has_nfo=True)
mock_session.commit.assert_not_called()
class TestGetSeriesWithoutNFOSelfManaged:
"""Test get_series_without_nfo when db=None (self-managed session)."""
@pytest.mark.asyncio
async def test_returns_list(self, anime_service):
"""Should return formatted dicts."""
mock_s = MagicMock()
mock_s.key = "test"
mock_s.name = "Test"
mock_s.folder = "Test"
mock_s.tmdb_id = 1
mock_s.tvdb_id = 2
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService"
".get_series_without_nfo",
new_callable=AsyncMock,
return_value=[mock_s],
):
result = await anime_service.get_series_without_nfo()
assert len(result) == 1
assert result[0]["has_nfo"] is False
class TestGetNFOStatisticsSelfManaged:
"""Test get_nfo_statistics when db=None (self-managed session)."""
@pytest.mark.asyncio
async def test_returns_stats(self, anime_service):
"""Should compute statistics correctly."""
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.count_all",
new_callable=AsyncMock,
return_value=50,
), patch(
"src.server.database.service.AnimeSeriesService.count_with_nfo",
new_callable=AsyncMock,
return_value=30,
), patch(
"src.server.database.service.AnimeSeriesService"
".count_with_tmdb_id",
new_callable=AsyncMock,
return_value=40,
), patch(
"src.server.database.service.AnimeSeriesService"
".count_with_tvdb_id",
new_callable=AsyncMock,
return_value=20,
):
result = await anime_service.get_nfo_statistics()
assert result["total"] == 50
assert result["without_nfo"] == 20
assert result["with_tmdb_id"] == 40
class TestSyncSeriesFromDataFiles:
"""Test module-level sync_series_from_data_files function."""
@pytest.mark.asyncio
async def test_sync_adds_new_series(self, tmp_path):
"""Should create series for data files not in DB."""
mock_serie = MagicMock()
mock_serie.key = "new-series"
mock_serie.name = "New Series"
mock_serie.site = "aniworld.to"
mock_serie.folder = "New Series"
mock_serie.episodeDict = {1: [1]}
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.services.anime_service.SeriesApp"
) as MockApp, patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=None,
), patch(
"src.server.database.service.AnimeSeriesService.create",
new_callable=AsyncMock,
return_value=MagicMock(id=1),
) as mock_create, patch(
"src.server.database.service.EpisodeService.create",
new_callable=AsyncMock,
):
mock_app_instance = MagicMock()
mock_app_instance.get_all_series_from_data_files.return_value = [
mock_serie
]
MockApp.return_value = mock_app_instance
count = await sync_series_from_data_files(str(tmp_path))
assert count == 1
mock_create.assert_called_once()
@pytest.mark.asyncio
async def test_sync_skips_existing(self, tmp_path):
"""Already-existing series should be skipped."""
mock_serie = MagicMock()
mock_serie.key = "exists"
mock_serie.name = "Exists"
mock_serie.site = "x"
mock_serie.folder = "Exists"
mock_serie.episodeDict = {}
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.services.anime_service.SeriesApp"
) as MockApp, patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=MagicMock(),
), patch(
"src.server.database.service.AnimeSeriesService.create",
new_callable=AsyncMock,
) as mock_create:
mock_app_instance = MagicMock()
mock_app_instance.get_all_series_from_data_files.return_value = [
mock_serie
]
MockApp.return_value = mock_app_instance
count = await sync_series_from_data_files(str(tmp_path))
assert count == 0
mock_create.assert_not_called()
@pytest.mark.asyncio
async def test_sync_no_data_files(self, tmp_path):
"""Empty directory should return 0."""
with patch(
"src.server.services.anime_service.SeriesApp"
) as MockApp:
mock_app_instance = MagicMock()
mock_app_instance.get_all_series_from_data_files.return_value = []
MockApp.return_value = mock_app_instance
count = await sync_series_from_data_files(str(tmp_path))
assert count == 0
@pytest.mark.asyncio
async def test_sync_handles_empty_name(self, tmp_path):
"""Series with empty name should use folder as fallback."""
mock_serie = MagicMock()
mock_serie.key = "no-name"
mock_serie.name = ""
mock_serie.site = "x"
mock_serie.folder = "FallbackFolder"
mock_serie.episodeDict = {}
mock_session = AsyncMock()
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_session)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
with patch(
"src.server.services.anime_service.SeriesApp"
) as MockApp, patch(
"src.server.database.connection.get_db_session",
return_value=mock_ctx,
), patch(
"src.server.database.service.AnimeSeriesService.get_by_key",
new_callable=AsyncMock,
return_value=None,
), patch(
"src.server.database.service.AnimeSeriesService.create",
new_callable=AsyncMock,
return_value=MagicMock(id=1),
) as mock_create:
mock_app_instance = MagicMock()
mock_app_instance.get_all_series_from_data_files.return_value = [
mock_serie
]
MockApp.return_value = mock_app_instance
count = await sync_series_from_data_files(str(tmp_path))
assert count == 1
# The name should have been set to folder
assert mock_serie.name == "FallbackFolder"

View File

@@ -0,0 +1,475 @@
"""Unit tests for database connection module.
Tests cover engine/session lifecycle, utility functions,
TransactionManager, SavepointHandle, and various error paths.
"""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import src.server.database.connection as conn_mod
from src.server.database.connection import (
SavepointHandle,
TransactionManager,
_get_database_url,
get_session_transaction_depth,
is_session_in_transaction,
)
# ── Helpers ───────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _reset_globals():
"""Reset the module-level globals before/after every test."""
old_engine = conn_mod._engine
old_sync = conn_mod._sync_engine
old_sf = conn_mod._session_factory
old_ssf = conn_mod._sync_session_factory
conn_mod._engine = None
conn_mod._sync_engine = None
conn_mod._session_factory = None
conn_mod._sync_session_factory = None
yield
conn_mod._engine = old_engine
conn_mod._sync_engine = old_sync
conn_mod._session_factory = old_sf
conn_mod._sync_session_factory = old_ssf
# ══════════════════════════════════════════════════════════════════════════════
# _get_database_url
# ══════════════════════════════════════════════════════════════════════════════
class TestGetDatabaseURL:
"""Test _get_database_url helper."""
def test_sqlite_url_converted(self):
"""sqlite:/// should be converted to sqlite+aiosqlite:///."""
with patch.object(
conn_mod.settings, "database_url",
"sqlite:///./data/anime.db",
):
result = _get_database_url()
assert "aiosqlite" in result
def test_non_sqlite_url_unchanged(self):
"""Non-SQLite URL should remain unchanged."""
with patch.object(
conn_mod.settings, "database_url",
"postgresql://user:pass@localhost/db",
):
result = _get_database_url()
assert result == "postgresql://user:pass@localhost/db"
# ══════════════════════════════════════════════════════════════════════════════
# get_engine / get_sync_engine
# ══════════════════════════════════════════════════════════════════════════════
class TestGetEngine:
"""Test get_engine and get_sync_engine."""
def test_raises_when_not_initialized(self):
"""get_engine should raise RuntimeError before init_db."""
with pytest.raises(RuntimeError, match="not initialized"):
conn_mod.get_engine()
def test_returns_engine_when_set(self):
"""Should return the engine when initialised."""
fake_engine = MagicMock()
conn_mod._engine = fake_engine
assert conn_mod.get_engine() is fake_engine
def test_get_sync_engine_raises(self):
"""get_sync_engine should raise RuntimeError before init_db."""
with pytest.raises(RuntimeError, match="not initialized"):
conn_mod.get_sync_engine()
def test_get_sync_engine_returns(self):
"""Should return sync engine when set."""
fake = MagicMock()
conn_mod._sync_engine = fake
assert conn_mod.get_sync_engine() is fake
# ══════════════════════════════════════════════════════════════════════════════
# get_db_session
# ══════════════════════════════════════════════════════════════════════════════
class TestGetDBSession:
"""Test get_db_session async context manager."""
@pytest.mark.asyncio
async def test_raises_when_not_initialized(self):
"""Should raise RuntimeError if session factory is None."""
with pytest.raises(RuntimeError, match="not initialized"):
async with conn_mod.get_db_session():
pass
@pytest.mark.asyncio
async def test_commits_on_success(self):
"""Session should be committed on normal exit."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
conn_mod._session_factory = factory
async with conn_mod.get_db_session() as session:
assert session is mock_session
mock_session.commit.assert_called_once()
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_rollback_on_exception(self):
"""Session should be rolled back on exception."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
conn_mod._session_factory = factory
with pytest.raises(ValueError):
async with conn_mod.get_db_session():
raise ValueError("boom")
mock_session.rollback.assert_called_once()
mock_session.commit.assert_not_called()
mock_session.close.assert_called_once()
# ══════════════════════════════════════════════════════════════════════════════
# get_sync_session / get_async_session_factory
# ══════════════════════════════════════════════════════════════════════════════
class TestGetSyncSession:
"""Test get_sync_session."""
def test_raises_when_not_initialized(self):
"""Should raise RuntimeError."""
with pytest.raises(RuntimeError, match="not initialized"):
conn_mod.get_sync_session()
def test_returns_session(self):
"""Should return a session from the factory."""
mock_session = MagicMock()
conn_mod._sync_session_factory = MagicMock(return_value=mock_session)
assert conn_mod.get_sync_session() is mock_session
class TestGetAsyncSessionFactory:
"""Test get_async_session_factory."""
def test_raises_when_not_initialized(self):
"""Should raise RuntimeError."""
with pytest.raises(RuntimeError, match="not initialized"):
conn_mod.get_async_session_factory()
def test_returns_session(self):
"""Should return a new async session."""
mock_session = AsyncMock()
conn_mod._session_factory = MagicMock(return_value=mock_session)
assert conn_mod.get_async_session_factory() is mock_session
# ══════════════════════════════════════════════════════════════════════════════
# get_transactional_session
# ══════════════════════════════════════════════════════════════════════════════
class TestGetTransactionalSession:
"""Test get_transactional_session."""
@pytest.mark.asyncio
async def test_raises_when_not_initialized(self):
"""Should raise RuntimeError."""
with pytest.raises(RuntimeError, match="not initialized"):
async with conn_mod.get_transactional_session():
pass
@pytest.mark.asyncio
async def test_does_not_auto_commit(self):
"""Session should NOT be committed on normal exit."""
mock_session = AsyncMock()
conn_mod._session_factory = MagicMock(return_value=mock_session)
async with conn_mod.get_transactional_session() as session:
pass
mock_session.commit.assert_not_called()
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_rollback_on_exception(self):
"""Should rollback on exception."""
mock_session = AsyncMock()
conn_mod._session_factory = MagicMock(return_value=mock_session)
with pytest.raises(ValueError):
async with conn_mod.get_transactional_session():
raise ValueError("boom")
mock_session.rollback.assert_called_once()
# ══════════════════════════════════════════════════════════════════════════════
# close_db
# ══════════════════════════════════════════════════════════════════════════════
class TestCloseDB:
"""Test close_db function."""
@pytest.mark.asyncio
async def test_disposes_engines(self):
"""Should dispose both engines."""
mock_engine = AsyncMock()
mock_sync = MagicMock()
mock_sync.url = "sqlite:///test.db"
mock_sync.connect.return_value.__enter__ = MagicMock()
mock_sync.connect.return_value.__exit__ = MagicMock()
conn_ctx = MagicMock()
conn_ctx.__enter__ = MagicMock(return_value=MagicMock())
conn_ctx.__exit__ = MagicMock(return_value=False)
mock_sync.connect.return_value = conn_ctx
conn_mod._engine = mock_engine
conn_mod._sync_engine = mock_sync
conn_mod._session_factory = MagicMock()
conn_mod._sync_session_factory = MagicMock()
await conn_mod.close_db()
mock_engine.dispose.assert_called_once()
mock_sync.dispose.assert_called_once()
assert conn_mod._engine is None
assert conn_mod._sync_engine is None
@pytest.mark.asyncio
async def test_noop_when_not_initialized(self):
"""Should not raise if engines are None."""
await conn_mod.close_db() # should not raise
# ══════════════════════════════════════════════════════════════════════════════
# TransactionManager
# ══════════════════════════════════════════════════════════════════════════════
class TestTransactionManager:
"""Test TransactionManager class."""
def test_init_raises_without_factory(self):
"""Should raise RuntimeError when no session factory."""
with pytest.raises(RuntimeError, match="not initialized"):
TransactionManager()
@pytest.mark.asyncio
async def test_context_manager_creates_and_closes_session(self):
"""Should create session on enter and close on exit."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
async with TransactionManager(session_factory=factory) as tm:
session = await tm.get_session()
assert session is mock_session
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_begin_commit(self):
"""begin then commit should work."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
async with TransactionManager(session_factory=factory) as tm:
await tm.begin()
assert tm.is_in_transaction() is True
await tm.commit()
assert tm.is_in_transaction() is False
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_begin_rollback(self):
"""begin then rollback should work."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
async with TransactionManager(session_factory=factory) as tm:
await tm.begin()
await tm.rollback()
assert tm.is_in_transaction() is False
mock_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_exception_auto_rollback(self):
"""Exception inside context manager should auto rollback."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
with pytest.raises(ValueError):
async with TransactionManager(session_factory=factory) as tm:
await tm.begin()
raise ValueError("boom")
mock_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_double_begin_raises(self):
"""begin called twice should raise."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
async with TransactionManager(session_factory=factory) as tm:
await tm.begin()
with pytest.raises(RuntimeError, match="Already in"):
await tm.begin()
@pytest.mark.asyncio
async def test_commit_without_begin_raises(self):
"""commit without begin should raise."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
async with TransactionManager(session_factory=factory) as tm:
with pytest.raises(RuntimeError, match="Not in"):
await tm.commit()
@pytest.mark.asyncio
async def test_get_session_outside_context_raises(self):
"""get_session outside context manager should raise."""
factory = MagicMock()
tm = TransactionManager(session_factory=factory)
with pytest.raises(RuntimeError, match="context manager"):
await tm.get_session()
@pytest.mark.asyncio
async def test_transaction_depth(self):
"""get_transaction_depth should reflect state."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
async with TransactionManager(session_factory=factory) as tm:
assert tm.get_transaction_depth() == 0
await tm.begin()
assert tm.get_transaction_depth() == 1
await tm.commit()
assert tm.get_transaction_depth() == 0
@pytest.mark.asyncio
async def test_savepoint_creation(self):
"""savepoint should return SavepointHandle."""
mock_session = AsyncMock()
mock_nested = AsyncMock()
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
factory = MagicMock(return_value=mock_session)
async with TransactionManager(session_factory=factory) as tm:
await tm.begin()
sp = await tm.savepoint("sp1")
assert isinstance(sp, SavepointHandle)
@pytest.mark.asyncio
async def test_savepoint_without_transaction_raises(self):
"""savepoint outside transaction should raise."""
mock_session = AsyncMock()
factory = MagicMock(return_value=mock_session)
async with TransactionManager(session_factory=factory) as tm:
with pytest.raises(RuntimeError, match="Must be in"):
await tm.savepoint()
@pytest.mark.asyncio
async def test_rollback_without_session_raises(self):
"""rollback without active session should raise."""
factory = MagicMock()
tm = TransactionManager(session_factory=factory)
with pytest.raises(RuntimeError, match="No active session"):
await tm.rollback()
# ══════════════════════════════════════════════════════════════════════════════
# SavepointHandle
# ══════════════════════════════════════════════════════════════════════════════
class TestSavepointHandle:
"""Test SavepointHandle class."""
@pytest.mark.asyncio
async def test_rollback(self):
"""Should call nested.rollback()."""
mock_nested = AsyncMock()
sp = SavepointHandle(mock_nested, "sp1")
await sp.rollback()
mock_nested.rollback.assert_called_once()
assert sp._released is True
@pytest.mark.asyncio
async def test_rollback_idempotent(self):
"""Second rollback should be a noop."""
mock_nested = AsyncMock()
sp = SavepointHandle(mock_nested, "sp1")
await sp.rollback()
await sp.rollback()
mock_nested.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_release(self):
"""Should mark as released."""
mock_nested = AsyncMock()
sp = SavepointHandle(mock_nested, "sp1")
await sp.release()
assert sp._released is True
@pytest.mark.asyncio
async def test_release_idempotent(self):
"""Second release should be a noop."""
mock_nested = AsyncMock()
sp = SavepointHandle(mock_nested, "sp1")
await sp.release()
await sp.release()
assert sp._released is True
# ══════════════════════════════════════════════════════════════════════════════
# Utility Functions
# ══════════════════════════════════════════════════════════════════════════════
class TestUtilityFunctions:
"""Test is_session_in_transaction and get_session_transaction_depth."""
def test_in_transaction_true(self):
"""Should return True when session is in transaction."""
session = MagicMock()
session.in_transaction.return_value = True
assert is_session_in_transaction(session) is True
def test_in_transaction_false(self):
"""Should return False when session is not in transaction."""
session = MagicMock()
session.in_transaction.return_value = False
assert is_session_in_transaction(session) is False
def test_transaction_depth_zero(self):
"""Should return 0 when not in transaction."""
session = MagicMock()
session.in_transaction.return_value = False
assert get_session_transaction_depth(session) == 0
def test_transaction_depth_one(self):
"""Should return 1 when in transaction."""
session = MagicMock()
session.in_transaction.return_value = True
assert get_session_transaction_depth(session) == 1

View File

@@ -442,3 +442,479 @@ class TestEnhancedProviderFromHTML:
result = enhanced_loader._get_provider_from_html(1, 1, "test")
assert result == {}
# ══════════════════════════════════════════════════════════════════════════════
# New coverage tests fetch, download flow, redirect, season counts
# ══════════════════════════════════════════════════════════════════════════════
class TestFetchAnimeListWithRecovery:
"""Test _fetch_anime_list_with_recovery."""
def test_successful_fetch(self, enhanced_loader):
"""Should fetch and parse a JSON response."""
mock_response = MagicMock()
mock_response.ok = True
mock_response.text = json.dumps([{"title": "Naruto"}])
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.return_value = mock_response
result = enhanced_loader._fetch_anime_list_with_recovery(
"https://example.com/search"
)
assert len(result) == 1
assert result[0]["title"] == "Naruto"
def test_404_raises_non_retryable(self, enhanced_loader):
"""404 should raise NonRetryableError."""
mock_response = MagicMock()
mock_response.ok = False
mock_response.status_code = 404
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.return_value = mock_response
with pytest.raises(NonRetryableError, match="not found"):
enhanced_loader._fetch_anime_list_with_recovery(
"https://example.com/search"
)
def test_403_raises_non_retryable(self, enhanced_loader):
"""403 should raise NonRetryableError."""
mock_response = MagicMock()
mock_response.ok = False
mock_response.status_code = 403
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.return_value = mock_response
with pytest.raises(NonRetryableError, match="forbidden"):
enhanced_loader._fetch_anime_list_with_recovery(
"https://example.com/search"
)
def test_500_raises_retryable(self, enhanced_loader):
"""500 should raise RetryableError."""
mock_response = MagicMock()
mock_response.ok = False
mock_response.status_code = 500
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.return_value = mock_response
with pytest.raises(RetryableError, match="Server error"):
enhanced_loader._fetch_anime_list_with_recovery(
"https://example.com/search"
)
def test_network_error_raises_network_error(self, enhanced_loader):
"""requests.RequestException should raise NetworkError."""
import requests as req
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.side_effect = (
req.RequestException("timeout")
)
with pytest.raises(NetworkError, match="Network error"):
enhanced_loader._fetch_anime_list_with_recovery(
"https://example.com/search"
)
class TestGetKeyHTML:
"""Test _GetKeyHTML fetching and caching."""
def test_cached_html_returned(self, enhanced_loader):
"""Already-cached key should skip HTTP call."""
mock_resp = MagicMock()
enhanced_loader._KeyHTMLDict["cached-key"] = mock_resp
result = enhanced_loader._GetKeyHTML("cached-key")
assert result is mock_resp
enhanced_loader.session.get.assert_not_called()
def test_fetches_and_caches(self, enhanced_loader):
"""Missing key should be fetched and cached."""
mock_response = MagicMock()
mock_response.ok = True
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.return_value = mock_response
result = enhanced_loader._GetKeyHTML("new-key")
assert result is mock_response
assert enhanced_loader._KeyHTMLDict["new-key"] is mock_response
def test_404_raises_non_retryable(self, enhanced_loader):
"""404 from server should raise NonRetryableError."""
mock_response = MagicMock()
mock_response.ok = False
mock_response.status_code = 404
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.return_value = mock_response
with pytest.raises(NonRetryableError, match="not found"):
enhanced_loader._GetKeyHTML("missing-key")
class TestGetRedirectLink:
"""Test _get_redirect_link method."""
def test_returns_link_and_provider(self, enhanced_loader):
"""Should return (link, provider_name) tuple."""
with patch.object(
enhanced_loader, "IsLanguage", return_value=True
), patch.object(
enhanced_loader,
"_get_provider_from_html",
return_value={
"VOE": {1: "https://aniworld.to/redirect/100"}
},
):
link, provider = enhanced_loader._get_redirect_link(
1, 1, "test", "German Dub"
)
assert link == "https://aniworld.to/redirect/100"
assert provider == "VOE"
def test_language_unavailable_raises(self, enhanced_loader):
"""Should raise NonRetryableError if language not available."""
with patch.object(
enhanced_loader, "IsLanguage", return_value=False
):
with pytest.raises(NonRetryableError, match="not available"):
enhanced_loader._get_redirect_link(
1, 1, "test", "German Dub"
)
def test_no_provider_found_raises(self, enhanced_loader):
"""Should raise when no provider has the language."""
with patch.object(
enhanced_loader, "IsLanguage", return_value=True
), patch.object(
enhanced_loader,
"_get_provider_from_html",
return_value={"VOE": {2: "link"}}, # English Sub only
):
with pytest.raises(NonRetryableError, match="No provider"):
enhanced_loader._get_redirect_link(
1, 1, "test", "German Dub"
)
class TestGetEmbeddedLink:
"""Test _get_embeded_link method."""
def test_returns_final_url(self, enhanced_loader):
"""Should follow redirect and return final URL."""
mock_response = MagicMock()
mock_response.url = "https://voe.sx/e/abc123"
with patch.object(
enhanced_loader,
"_get_redirect_link",
return_value=("https://aniworld.to/redirect/100", "VOE"),
), patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.return_value = mock_response
result = enhanced_loader._get_embeded_link(
1, 1, "test", "German Dub"
)
assert result == "https://voe.sx/e/abc123"
def test_redirect_failure_raises(self, enhanced_loader):
"""Should propagate error from _get_redirect_link."""
with patch.object(
enhanced_loader,
"_get_redirect_link",
side_effect=NonRetryableError("no link"),
):
with pytest.raises(NonRetryableError):
enhanced_loader._get_embeded_link(
1, 1, "test", "German Dub"
)
class TestGetDirectLinkFromProvider:
"""Test _get_direct_link_from_provider method."""
def test_returns_link_from_voe(self, enhanced_loader):
"""Should use VOE provider to extract direct link."""
mock_provider = MagicMock()
mock_provider.get_link.return_value = (
"https://direct.example.com/video.mp4",
[],
)
enhanced_loader.Providers = MagicMock()
enhanced_loader.Providers.GetProvider.return_value = mock_provider
with patch.object(
enhanced_loader,
"_get_embeded_link",
return_value="https://voe.sx/e/abc123",
):
result = enhanced_loader._get_direct_link_from_provider(
1, 1, "test", "German Dub"
)
assert result == ("https://direct.example.com/video.mp4", [])
def test_no_embedded_link_raises(self, enhanced_loader):
"""Should raise if embedded link is None."""
with patch.object(
enhanced_loader,
"_get_embeded_link",
return_value=None,
):
with pytest.raises(NonRetryableError, match="No embedded link"):
enhanced_loader._get_direct_link_from_provider(
1, 1, "test", "German Dub"
)
def test_no_provider_raises(self, enhanced_loader):
"""Should raise if VOE provider unavailable."""
enhanced_loader.Providers = MagicMock()
enhanced_loader.Providers.GetProvider.return_value = None
with patch.object(
enhanced_loader,
"_get_embeded_link",
return_value="https://voe.sx/e/abc",
):
with pytest.raises(NonRetryableError, match="VOE provider"):
enhanced_loader._get_direct_link_from_provider(
1, 1, "test", "German Dub"
)
class TestDownloadWithRecovery:
"""Test _download_with_recovery method."""
def test_successful_download(self, enhanced_loader, tmp_path):
"""Should download, verify, and move file."""
temp_path = str(tmp_path / "temp.mp4")
output_path = str(tmp_path / "output.mp4")
# Create a fake temp file after "download"
def fake_download(*args, **kwargs):
with open(temp_path, "wb") as f:
f.write(b"fake-video-data")
return True
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs, patch(
"src.core.providers.enhanced_provider.file_corruption_detector"
) as mock_fcd, patch(
"src.core.providers.enhanced_provider.get_integrity_manager"
) as mock_im:
mock_rs.handle_network_failure.return_value = (
"https://direct.example.com/v.mp4",
[],
)
mock_rs.handle_download_failure.side_effect = fake_download
mock_fcd.is_valid_video_file.return_value = True
mock_im.return_value.store_checksum.return_value = "abc123"
result = enhanced_loader._download_with_recovery(
1, 1, "test", "German Dub",
temp_path, output_path, None,
)
assert result is True
assert os.path.exists(output_path)
def test_all_providers_fail_returns_false(self, enhanced_loader, tmp_path):
"""Should return False when all providers fail."""
temp_path = str(tmp_path / "temp.mp4")
output_path = str(tmp_path / "output.mp4")
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.side_effect = Exception("fail")
result = enhanced_loader._download_with_recovery(
1, 1, "test", "German Dub",
temp_path, output_path, None,
)
assert result is False
def test_corrupted_download_removed(self, enhanced_loader, tmp_path):
"""Corrupted downloads should be removed and next provider tried."""
temp_path = str(tmp_path / "temp.mp4")
output_path = str(tmp_path / "output.mp4")
# Create a fake temp file after "download"
def fake_download(*args, **kwargs):
with open(temp_path, "wb") as f:
f.write(b"corrupt")
return True
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs, patch(
"src.core.providers.enhanced_provider.file_corruption_detector"
) as mock_fcd:
mock_rs.handle_network_failure.return_value = (
"https://direct.example.com/v.mp4",
[],
)
mock_rs.handle_download_failure.side_effect = fake_download
mock_fcd.is_valid_video_file.return_value = False
result = enhanced_loader._download_with_recovery(
1, 1, "test", "German Dub",
temp_path, output_path, None,
)
assert result is False
class TestGetSeasonEpisodeCount:
"""Test get_season_episode_count method."""
def test_returns_episode_counts(self, enhanced_loader):
"""Should return dict of season -> episode count."""
base_html = (
b'<html><meta itemprop="numberOfSeasons" content="2">'
b"</html>"
)
s1_html = (
b'<html><body>'
b'<a href="/anime/stream/test/staffel-1/episode-1">E1</a>'
b'<a href="/anime/stream/test/staffel-1/episode-2">E2</a>'
b'</body></html>'
)
s2_html = (
b'<html><body>'
b'<a href="/anime/stream/test/staffel-2/episode-1">E1</a>'
b'</body></html>'
)
responses = [
MagicMock(content=base_html),
MagicMock(content=s1_html),
MagicMock(content=s2_html),
]
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.side_effect = responses
result = enhanced_loader.get_season_episode_count("test")
assert result == {1: 2, 2: 1}
def test_no_seasons_meta_returns_empty(self, enhanced_loader):
"""Missing numberOfSeasons meta should return empty dict."""
base_html = b"<html><body>No seasons</body></html>"
with patch(
"src.core.providers.enhanced_provider.recovery_strategies"
) as mock_rs:
mock_rs.handle_network_failure.return_value = MagicMock(
content=base_html
)
result = enhanced_loader.get_season_episode_count("test")
assert result == {}
class TestPerformYtdlDownload:
"""Test _perform_ytdl_download method."""
def test_success(self, enhanced_loader):
"""Should return True on successful download."""
with patch(
"src.core.providers.enhanced_provider.YoutubeDL"
) as MockYDL:
mock_ydl = MagicMock()
MockYDL.return_value.__enter__ = MagicMock(return_value=mock_ydl)
MockYDL.return_value.__exit__ = MagicMock(return_value=False)
result = enhanced_loader._perform_ytdl_download(
{}, "https://example.com/video"
)
assert result is True
def test_failure_raises_download_error(self, enhanced_loader):
"""yt-dlp failure should raise DownloadError."""
with patch(
"src.core.providers.enhanced_provider.YoutubeDL"
) as MockYDL:
mock_ydl = MagicMock()
mock_ydl.download.side_effect = Exception("yt-dlp crash")
MockYDL.return_value.__enter__ = MagicMock(return_value=mock_ydl)
MockYDL.return_value.__exit__ = MagicMock(return_value=False)
with pytest.raises(DownloadError, match="Download failed"):
enhanced_loader._perform_ytdl_download(
{}, "https://example.com/video"
)
class TestDownloadFlow:
"""Test full Download method flow."""
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
def test_existing_valid_file_returns_true(
self, mock_integrity, enhanced_loader, tmp_path
):
"""Should return True if file already exists and is valid."""
# Create fake existing file
folder = tmp_path / "Folder" / "Season 1"
folder.mkdir(parents=True)
video = folder / "Test - S01E001 - (German Dub).mp4"
video.write_bytes(b"valid-video")
enhanced_loader._KeyHTMLDict["key"] = MagicMock(
content=b"<html><div class='series-title'><h1><span>Test</span></h1></div></html>"
)
with patch(
"src.core.providers.enhanced_provider.file_corruption_detector"
) as mock_fcd:
mock_fcd.is_valid_video_file.return_value = True
mock_integrity.return_value.has_checksum.return_value = False
result = enhanced_loader.Download(
str(tmp_path), "Folder", 1, 1, "key"
)
assert result is True
assert enhanced_loader.download_stats["successful_downloads"] == 1
@patch("src.core.providers.enhanced_provider.get_integrity_manager")
def test_missing_key_raises_value_error(
self, mock_integrity, enhanced_loader, tmp_path
):
"""Download with empty key should raise."""
with pytest.raises((ValueError, DownloadError)):
enhanced_loader.Download(str(tmp_path), "folder", 1, 1, "")
class TestAniworldLoaderCompat:
"""Test backward compatibility wrapper."""
def test_inherits_from_enhanced(self):
"""AniworldLoader should extend EnhancedAniWorldLoader."""
from src.core.providers.enhanced_provider import AniworldLoader
assert issubclass(AniworldLoader, EnhancedAniWorldLoader)

View File

@@ -0,0 +1,454 @@
"""Unit tests for QueueRepository.
Tests cover model conversion, CRUD operations (save, get, get_all,
set_error, delete, clear), error handling, and the singleton factory.
"""
from __future__ import annotations
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from src.server.models.download import (
DownloadItem,
DownloadPriority,
DownloadStatus,
EpisodeIdentifier,
)
from src.server.services.queue_repository import (
QueueRepository,
QueueRepositoryError,
get_queue_repository,
reset_queue_repository,
)
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _reset_singleton():
"""Ensure singleton is reset before and after every test."""
reset_queue_repository()
yield
reset_queue_repository()
@pytest.fixture
def mock_session():
"""Async session mock."""
session = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.close = AsyncMock()
return session
@pytest.fixture
def session_factory(mock_session):
"""Factory that returns the mock session."""
return MagicMock(return_value=mock_session)
@pytest.fixture
def repo(session_factory):
"""QueueRepository instance backed by mock session."""
return QueueRepository(db_session_factory=session_factory)
def _make_db_item(
*,
db_id: int = 1,
series_key: str = "aot",
series_name: str = "Attack on Titan",
series_folder: str = "Attack on Titan (2013)",
season: int = 1,
episode_number: int = 3,
episode_title: str | None = None,
created_at: datetime | None = None,
started_at: datetime | None = None,
completed_at: datetime | None = None,
error_message: str | None = None,
download_url: str | None = None,
):
"""Build a fake DB DownloadQueueItem."""
episode = MagicMock()
episode.season = season
episode.episode_number = episode_number
episode.title = episode_title
series = MagicMock()
series.key = series_key
series.folder = series_folder
series.name = series_name
db_item = MagicMock()
db_item.id = db_id
db_item.episode = episode
db_item.series = series
db_item.created_at = created_at or datetime(2025, 1, 1, tzinfo=timezone.utc)
db_item.started_at = started_at
db_item.completed_at = completed_at
db_item.error_message = error_message
db_item.download_url = download_url
return db_item
def _make_download_item(**kwargs) -> DownloadItem:
"""Build a DownloadItem for save tests."""
defaults = dict(
id="tmp-1",
serie_id="naruto",
serie_folder="Naruto",
serie_name="Naruto",
episode=EpisodeIdentifier(season=1, episode=5),
status=DownloadStatus.PENDING,
priority=DownloadPriority.NORMAL,
added_at=datetime.now(timezone.utc),
)
defaults.update(kwargs)
return DownloadItem(**defaults)
# ══════════════════════════════════════════════════════════════════════════════
# Model Conversion
# ══════════════════════════════════════════════════════════════════════════════
class TestFromDBModel:
"""Test _from_db_model conversion."""
def test_basic_conversion(self, repo):
"""Should produce a DownloadItem from a DB model."""
db_item = _make_db_item()
result = repo._from_db_model(db_item)
assert isinstance(result, DownloadItem)
assert result.id == "1"
assert result.serie_id == "aot"
assert result.serie_folder == "Attack on Titan (2013)"
assert result.serie_name == "Attack on Titan"
assert result.episode.season == 1
assert result.episode.episode == 3
assert result.status == DownloadStatus.PENDING
assert result.priority == DownloadPriority.NORMAL
def test_custom_item_id(self, repo):
"""item_id kwarg should override the DB ID."""
db_item = _make_db_item(db_id=99)
result = repo._from_db_model(db_item, item_id="custom-42")
assert result.id == "custom-42"
def test_missing_episode(self, repo):
"""If episode is None, defaults should be used."""
db_item = _make_db_item()
db_item.episode = None
result = repo._from_db_model(db_item)
assert result.episode.season == 1
assert result.episode.episode == 1
def test_missing_series(self, repo):
"""If series is None, defaults should be used."""
db_item = _make_db_item()
db_item.series = None
# serie_name has min_length=1 in Pydantic, so empty string
# causes validation error. This test verifies the fallback behavior.
# The _from_db_model method falls back to empty strings for key/folder
# and empty string for name, which will trigger a Pydantic validation error.
with pytest.raises(Exception):
repo._from_db_model(db_item)
def test_error_message_preserved(self, repo):
"""Error message from DB should be carried over."""
db_item = _make_db_item(error_message="timeout")
result = repo._from_db_model(db_item)
assert result.error == "timeout"
def test_download_url_preserved(self, repo):
"""Source URL from DB should be carried over."""
db_item = _make_db_item(download_url="https://example.com/video.mp4")
result = repo._from_db_model(db_item)
assert str(result.source_url) == "https://example.com/video.mp4"
# ══════════════════════════════════════════════════════════════════════════════
# get_item
# ══════════════════════════════════════════════════════════════════════════════
class TestGetItem:
"""Test get_item method."""
@pytest.mark.asyncio
async def test_returns_item(self, repo, mock_session):
"""Should return a DownloadItem when found."""
db_item = _make_db_item(db_id=5)
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.get_by_id = AsyncMock(return_value=db_item)
result = await repo.get_item("5")
assert result is not None
assert result.id == "5"
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_returns_none_when_missing(self, repo, mock_session):
"""Should return None when item not found."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.get_by_id = AsyncMock(return_value=None)
result = await repo.get_item("999")
assert result is None
@pytest.mark.asyncio
async def test_invalid_id_returns_none(self, repo, mock_session):
"""Non-numeric ID should return None."""
result = await repo.get_item("abc")
assert result is None
@pytest.mark.asyncio
async def test_db_error_raises(self, repo, mock_session):
"""DB error should raise QueueRepositoryError."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.get_by_id = AsyncMock(
side_effect=RuntimeError("DB down")
)
with pytest.raises(QueueRepositoryError, match="Failed to get item"):
await repo.get_item("1")
# ══════════════════════════════════════════════════════════════════════════════
# get_all_items
# ══════════════════════════════════════════════════════════════════════════════
class TestGetAllItems:
"""Test get_all_items method."""
@pytest.mark.asyncio
async def test_returns_list(self, repo, mock_session):
"""Should return list of DownloadItems."""
db_items = [_make_db_item(db_id=i) for i in range(3)]
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.get_all = AsyncMock(return_value=db_items)
result = await repo.get_all_items()
assert len(result) == 3
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_empty_returns_empty_list(self, repo, mock_session):
"""Should return [] when no items exist."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.get_all = AsyncMock(return_value=[])
result = await repo.get_all_items()
assert result == []
@pytest.mark.asyncio
async def test_db_error_raises(self, repo, mock_session):
"""DB error should raise QueueRepositoryError."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.get_all = AsyncMock(
side_effect=RuntimeError("DB down")
)
with pytest.raises(QueueRepositoryError, match="Failed to get all"):
await repo.get_all_items()
# ══════════════════════════════════════════════════════════════════════════════
# set_error
# ══════════════════════════════════════════════════════════════════════════════
class TestSetError:
"""Test set_error method."""
@pytest.mark.asyncio
async def test_success(self, repo, mock_session):
"""Should return True on success."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.set_error = AsyncMock(return_value=MagicMock())
result = await repo.set_error("1", "some error")
assert result is True
mock_session.commit.assert_called_once()
mock_session.close.assert_called_once()
@pytest.mark.asyncio
async def test_not_found(self, repo, mock_session):
"""Should return False when item not found."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.set_error = AsyncMock(return_value=None)
result = await repo.set_error("999", "err")
assert result is False
@pytest.mark.asyncio
async def test_invalid_id_returns_false(self, repo, mock_session):
"""Non-numeric ID should return False."""
result = await repo.set_error("abc", "err")
assert result is False
@pytest.mark.asyncio
async def test_db_error_raises(self, repo, mock_session):
"""DB error should raise QueueRepositoryError."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.set_error = AsyncMock(
side_effect=RuntimeError("boom")
)
with pytest.raises(QueueRepositoryError, match="Failed to set error"):
await repo.set_error("1", "err")
mock_session.rollback.assert_called_once()
# ══════════════════════════════════════════════════════════════════════════════
# delete_item
# ══════════════════════════════════════════════════════════════════════════════
class TestDeleteItem:
"""Test delete_item method."""
@pytest.mark.asyncio
async def test_success(self, repo, mock_session):
"""Should return True when deleted."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.delete = AsyncMock(return_value=True)
result = await repo.delete_item("1")
assert result is True
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_not_found(self, repo, mock_session):
"""Should return False when item does not exist."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.delete = AsyncMock(return_value=False)
result = await repo.delete_item("999")
assert result is False
@pytest.mark.asyncio
async def test_invalid_id_returns_false(self, repo, mock_session):
"""Non-numeric ID should return False."""
result = await repo.delete_item("abc")
assert result is False
@pytest.mark.asyncio
async def test_db_error_raises(self, repo, mock_session):
"""DB error should raise QueueRepositoryError."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS:
MockDQS.delete = AsyncMock(side_effect=RuntimeError("boom"))
with pytest.raises(QueueRepositoryError, match="Failed to delete"):
await repo.delete_item("1")
mock_session.rollback.assert_called_once()
# ══════════════════════════════════════════════════════════════════════════════
# clear_all
# ══════════════════════════════════════════════════════════════════════════════
class TestClearAll:
"""Test clear_all method."""
@pytest.mark.asyncio
async def test_returns_count(self, repo, mock_session):
"""Should return number of deleted items."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS, patch(
"src.server.services.queue_repository.atomic"
) as mock_atomic:
MockDQS.clear_all = AsyncMock(return_value=5)
# atomic context manager
mock_atomic.return_value.__aenter__ = AsyncMock()
mock_atomic.return_value.__aexit__ = AsyncMock(return_value=False)
result = await repo.clear_all()
assert result == 5
@pytest.mark.asyncio
async def test_empty_queue_returns_zero(self, repo, mock_session):
"""Should return 0 when queue is empty."""
with patch(
"src.server.services.queue_repository.DownloadQueueService"
) as MockDQS, patch(
"src.server.services.queue_repository.atomic"
) as mock_atomic:
MockDQS.clear_all = AsyncMock(return_value=0)
mock_atomic.return_value.__aenter__ = AsyncMock()
mock_atomic.return_value.__aexit__ = AsyncMock(return_value=False)
result = await repo.clear_all()
assert result == 0
@pytest.mark.asyncio
async def test_db_error_raises(self, repo, mock_session):
"""DB error should raise QueueRepositoryError."""
with patch(
"src.server.services.queue_repository.atomic"
) as mock_atomic:
mock_atomic.return_value.__aenter__ = AsyncMock(
side_effect=RuntimeError("boom")
)
mock_atomic.return_value.__aexit__ = AsyncMock(return_value=False)
with pytest.raises(QueueRepositoryError, match="Failed to clear"):
await repo.clear_all()
# ══════════════════════════════════════════════════════════════════════════════
# Singleton Factory
# ══════════════════════════════════════════════════════════════════════════════
class TestSingletonFactory:
"""Test get_queue_repository and reset."""
def test_creates_singleton(self):
"""Should return same instance on repeated calls."""
factory = MagicMock()
instance1 = get_queue_repository(factory)
instance2 = get_queue_repository(factory)
assert instance1 is instance2
def test_reset_clears_instance(self):
"""After reset, a new instance should be created."""
factory = MagicMock()
instance1 = get_queue_repository(factory)
reset_queue_repository()
instance2 = get_queue_repository(factory)
assert instance1 is not instance2
def test_default_factory_used_when_none(self):
"""When no factory passed, should use default from connection."""
with patch(
"src.server.database.connection.get_async_session_factory"
) as mock_factory:
instance = get_queue_repository()
assert instance is not None

View File

@@ -317,3 +317,338 @@ class TestSerieScannerSingleSeries:
# Should only show missing episodes
assert result == {1: [4, 5, 6]}
# ══════════════════════════════════════════════════════════════════════════════
# New coverage tests events, year extraction, find_mp4, read_data
# ══════════════════════════════════════════════════════════════════════════════
class TestEventSubscription:
"""Test subscribe/unsubscribe for all event types."""
def test_subscribe_on_progress(self, temp_directory, mock_loader):
"""Should add handler to on_progress."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock()
scanner.subscribe_on_progress(handler)
assert handler in scanner.events.on_progress
def test_unsubscribe_on_progress(self, temp_directory, mock_loader):
"""Should remove handler from on_progress."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock()
scanner.subscribe_on_progress(handler)
scanner.unsubscribe_on_progress(handler)
assert handler not in scanner.events.on_progress
def test_subscribe_duplicate_ignored(self, temp_directory, mock_loader):
"""Subscribing same handler twice should not duplicate."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock()
scanner.subscribe_on_progress(handler)
scanner.subscribe_on_progress(handler)
assert scanner.events.on_progress.count(handler) == 1
def test_unsubscribe_missing_handler_noop(
self, temp_directory, mock_loader
):
"""Unsubscribing unknown handler should not raise."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock()
scanner.unsubscribe_on_progress(handler) # should not raise
def test_subscribe_on_error(self, temp_directory, mock_loader):
"""Should add handler to on_error."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock()
scanner.subscribe_on_error(handler)
assert handler in scanner.events.on_error
def test_unsubscribe_on_error(self, temp_directory, mock_loader):
"""Should remove handler from on_error."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock()
scanner.subscribe_on_error(handler)
scanner.unsubscribe_on_error(handler)
assert handler not in scanner.events.on_error
def test_subscribe_on_completion(self, temp_directory, mock_loader):
"""Should add handler to on_completion."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock()
scanner.subscribe_on_completion(handler)
assert handler in scanner.events.on_completion
def test_unsubscribe_on_completion(self, temp_directory, mock_loader):
"""Should remove handler from on_completion."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock()
scanner.subscribe_on_completion(handler)
scanner.unsubscribe_on_completion(handler)
assert handler not in scanner.events.on_completion
class TestExtractYearFromFolderName:
"""Test _extract_year_from_folder_name."""
def test_extracts_year(self, temp_directory, mock_loader):
"""Should extract year from folder like 'Title (2025)'."""
scanner = SerieScanner(temp_directory, mock_loader)
assert scanner._extract_year_from_folder_name("Dororo (2025)") == 2025
def test_no_year_returns_none(self, temp_directory, mock_loader):
"""Folder without year returns None."""
scanner = SerieScanner(temp_directory, mock_loader)
assert scanner._extract_year_from_folder_name("Dororo") is None
def test_empty_string_returns_none(self, temp_directory, mock_loader):
"""Empty string returns None."""
scanner = SerieScanner(temp_directory, mock_loader)
assert scanner._extract_year_from_folder_name("") is None
def test_none_returns_none(self, temp_directory, mock_loader):
"""None input returns None."""
scanner = SerieScanner(temp_directory, mock_loader)
assert scanner._extract_year_from_folder_name(None) is None
def test_year_out_of_range_returns_none(
self, temp_directory, mock_loader
):
"""Year outside 1900-2100 returns None."""
scanner = SerieScanner(temp_directory, mock_loader)
assert scanner._extract_year_from_folder_name("Title (1800)") is None
assert scanner._extract_year_from_folder_name("Title (2200)") is None
def test_year_in_middle(self, temp_directory, mock_loader):
"""Year in the middle of folder name should be extracted."""
scanner = SerieScanner(temp_directory, mock_loader)
assert (
scanner._extract_year_from_folder_name("Title (2020) - Extra")
== 2020
)
class TestSafeCallEvent:
"""Test _safe_call_event method."""
def test_calls_handler(self, temp_directory, mock_loader):
"""Handler should be called with data."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock()
scanner.events.on_progress = [handler]
scanner._safe_call_event(scanner.events.on_progress, {"test": True})
handler.assert_called_once_with({"test": True})
def test_handler_error_swallowed(self, temp_directory, mock_loader):
"""Handler exceptions should be swallowed."""
scanner = SerieScanner(temp_directory, mock_loader)
handler = MagicMock(side_effect=Exception("boom"))
scanner.events.on_progress = [handler]
# Should not raise
scanner._safe_call_event(scanner.events.on_progress, {"test": True})
def test_empty_handler_list_noop(self, temp_directory, mock_loader):
"""Empty handler list should not raise."""
scanner = SerieScanner(temp_directory, mock_loader)
scanner.events.on_progress = []
scanner._safe_call_event(scanner.events.on_progress, {"test": True})
class TestFindMp4Files:
"""Test __find_mp4_files method."""
def test_finds_mp4_files(self, temp_directory, mock_loader):
"""Should yield folders with mp4 files."""
scanner = SerieScanner(temp_directory, mock_loader)
result = list(scanner._SerieScanner__find_mp4_files())
# temp_directory has "Attack on Titan (2013)" with one mp4
assert len(result) >= 1
folder, mp4s = result[0]
assert folder == "Attack on Titan (2013)"
assert len(mp4s) == 1
def test_empty_directory(self, mock_loader):
"""Should yield nothing for empty directory."""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
scanner = SerieScanner(tmpdir, mock_loader)
result = list(scanner._SerieScanner__find_mp4_files())
assert len(result) == 0
def test_nested_mp4_files(self, mock_loader):
"""Should find mp4 files in subdirectories."""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
# Create nested structure
anime = os.path.join(tmpdir, "Naruto")
season = os.path.join(anime, "Season 1")
os.makedirs(season)
with open(os.path.join(season, "ep1.mp4"), "w") as f:
f.write("dummy")
scanner = SerieScanner(tmpdir, mock_loader)
result = list(scanner._SerieScanner__find_mp4_files())
assert len(result) == 1
assert "Naruto" == result[0][0]
assert len(result[0][1]) == 1
def test_non_mp4_ignored(self, mock_loader):
"""Should ignore non-mp4 files."""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
anime = os.path.join(tmpdir, "TestAnime")
os.makedirs(anime)
with open(os.path.join(anime, "readme.txt"), "w") as f:
f.write("not a video")
scanner = SerieScanner(tmpdir, mock_loader)
result = list(scanner._SerieScanner__find_mp4_files())
# The folder is yielded but with empty mp4 list
assert len(result) == 1
assert result[0][1] == []
class TestReadDataFromFile:
"""Test __read_data_from_file method."""
def test_reads_key_file(self, mock_loader):
"""Should read key from 'key' file."""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
anime_folder = os.path.join(tmpdir, "SomeAnime")
os.makedirs(anime_folder)
with open(os.path.join(anime_folder, "key"), "w") as f:
f.write("some-key")
scanner = SerieScanner(tmpdir, mock_loader)
result = scanner._SerieScanner__read_data_from_file("SomeAnime")
assert result is not None
assert result.key == "some-key"
def test_reads_data_file(self, mock_loader):
"""Should read Serie from 'data' file when no 'key' file."""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
anime_folder = os.path.join(tmpdir, "SomeAnime")
os.makedirs(anime_folder)
# Create a data file
serie = Serie("test-key", "Test", "aniworld.to", "SomeAnime", {})
data_path = os.path.join(anime_folder, "data")
serie.save_to_file(data_path)
scanner = SerieScanner(tmpdir, mock_loader)
result = scanner._SerieScanner__read_data_from_file("SomeAnime")
assert result is not None
assert result.key == "test-key"
def test_no_files_returns_none(self, mock_loader):
"""Should return None when no key or data file exists."""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
anime_folder = os.path.join(tmpdir, "Empty")
os.makedirs(anime_folder)
scanner = SerieScanner(tmpdir, mock_loader)
result = scanner._SerieScanner__read_data_from_file("Empty")
assert result is None
class TestReinit:
"""Test reinit method."""
def test_clears_keydict(self, temp_directory, mock_loader):
"""reinit should clear the keyDict."""
scanner = SerieScanner(temp_directory, mock_loader)
scanner.keyDict["test"] = MagicMock()
scanner.reinit()
assert scanner.keyDict == {}
class TestGetTotalToScan:
"""Test get_total_to_scan."""
def test_counts_folders(self, temp_directory, mock_loader):
"""Should count number of folders."""
scanner = SerieScanner(temp_directory, mock_loader)
count = scanner.get_total_to_scan()
assert count >= 1
def test_empty_directory(self, mock_loader):
"""Should return 0 for empty directory."""
import tempfile
with tempfile.TemporaryDirectory() as tmpdir:
scanner = SerieScanner(tmpdir, mock_loader)
assert scanner.get_total_to_scan() == 0
class TestScanProgressEvents:
"""Test that scan emits progress and completion events."""
def test_scan_emits_progress(self, temp_directory, mock_loader):
"""Should emit on_progress during scan."""
scanner = SerieScanner(temp_directory, mock_loader)
progress_handler = MagicMock()
scanner.subscribe_on_progress(progress_handler)
with patch.object(scanner, 'get_total_to_scan', return_value=0), \
patch.object(
scanner, '_SerieScanner__find_mp4_files',
return_value=iter([])
):
scanner.scan()
# At minimum, STARTING event should fire
assert progress_handler.call_count >= 1
first_call = progress_handler.call_args_list[0][0][0]
assert first_call["phase"] == "STARTING"
def test_scan_emits_completion(self, temp_directory, mock_loader):
"""Should emit on_completion after scan."""
scanner = SerieScanner(temp_directory, mock_loader)
completion_handler = MagicMock()
scanner.subscribe_on_completion(completion_handler)
with patch.object(scanner, 'get_total_to_scan', return_value=0), \
patch.object(
scanner, '_SerieScanner__find_mp4_files',
return_value=iter([])
):
scanner.scan()
completion_handler.assert_called_once()
call_data = completion_handler.call_args[0][0]
assert call_data["success"] is True
def test_scan_emits_error_on_no_key(
self, temp_directory, mock_loader
):
"""Should emit on_error when NoKeyFoundException occurs."""
from src.core.exceptions.Exceptions import NoKeyFoundException
scanner = SerieScanner(temp_directory, mock_loader)
error_handler = MagicMock()
scanner.subscribe_on_error(error_handler)
with patch.object(scanner, 'get_total_to_scan', return_value=1), \
patch.object(
scanner, '_SerieScanner__find_mp4_files',
return_value=iter([("BadFolder", ["e1.mp4"])])
), \
patch.object(
scanner, '_SerieScanner__read_data_from_file',
side_effect=NoKeyFoundException("no key"),
):
scanner.scan()
error_handler.assert_called_once()
call_data = error_handler.call_args[0][0]
assert call_data["recoverable"] is True