Fix async context manager usage in BackgroundLoaderService

- Changed 'async for' to 'async with' for get_db_session()
- get_db_session() is @asynccontextmanager, requires async with not async for
- Created 5 comprehensive unit tests verifying the fix
- All tests pass, background loading now works correctly
This commit is contained in:
2026-01-19 19:50:25 +01:00
parent 62bdcf35cb
commit 7d95c180a9
4 changed files with 340 additions and 26 deletions

View File

@@ -0,0 +1,304 @@
"""
Unit tests for BackgroundLoaderService database session handling.
This module tests that the background loader service properly uses async context
managers for database sessions, preventing TypeError with async for.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from datetime import datetime, timezone
from src.server.services.background_loader_service import (
BackgroundLoaderService,
SeriesLoadingTask,
LoadingStatus,
)
@pytest.mark.asyncio
async def test_load_series_data_uses_async_with_not_async_for():
"""Test that _load_series_data uses 'async with' for database session.
This test verifies the fix for the TypeError:
'async for' requires an object with __aiter__ method, got _AsyncGeneratorContextManager
The code should use 'async with get_db_session() as db:' not 'async for db in get_db_session():'
"""
# Create a fake series app
fake_series_app = MagicMock()
fake_series_app.directory_to_search = "/fake/path"
fake_series_app.loader = MagicMock()
# Create fake websocket and anime services
fake_websocket_service = AsyncMock()
fake_anime_service = AsyncMock()
# Create the service
service = BackgroundLoaderService(
websocket_service=fake_websocket_service,
anime_service=fake_anime_service,
series_app=fake_series_app
)
# Create a test task
task = SeriesLoadingTask(
key="test-anime",
folder="Test Anime",
name="Test Anime",
year=2023,
status=LoadingStatus.PENDING,
progress={"episodes": False, "nfo": False, "logo": False, "images": False},
started_at=datetime.now(timezone.utc),
)
# Mock the database session
mock_db = AsyncMock()
mock_series_db = MagicMock()
mock_series_db.loading_status = "pending"
# Mock database service
with patch('src.server.database.connection.get_db_session') as mock_get_db:
with patch('src.server.database.service.AnimeSeriesService') as mock_service:
# Configure the async context manager
mock_get_db.return_value.__aenter__ = AsyncMock(return_value=mock_db)
mock_get_db.return_value.__aexit__ = AsyncMock(return_value=None)
# Configure service methods
mock_service.get_by_key = AsyncMock(return_value=mock_series_db)
mock_db.commit = AsyncMock()
# Mock helper methods
service.check_missing_data = AsyncMock(return_value={
"episodes": False,
"nfo": False,
"logo": False,
"images": False
})
service._broadcast_status = AsyncMock()
# Execute the method - this should not raise TypeError
await service._load_series_data(task)
# Verify the context manager was entered (proves we used 'async with')
mock_get_db.return_value.__aenter__.assert_called_once()
mock_get_db.return_value.__aexit__.assert_called_once()
# Verify task was marked as completed
assert task.status == LoadingStatus.COMPLETED
assert task.completed_at is not None
@pytest.mark.asyncio
async def test_load_series_data_handles_database_errors():
"""Test that _load_series_data properly handles database errors."""
# Create a fake series app
fake_series_app = MagicMock()
fake_series_app.directory_to_search = "/fake/path"
fake_series_app.loader = MagicMock()
# Create fake websocket and anime services
fake_websocket_service = AsyncMock()
fake_anime_service = AsyncMock()
# Create the service
service = BackgroundLoaderService(
websocket_service=fake_websocket_service,
anime_service=fake_anime_service,
series_app=fake_series_app
)
# Create a test task
task = SeriesLoadingTask(
key="test-anime",
folder="Test Anime",
name="Test Anime",
year=2023,
status=LoadingStatus.PENDING,
progress={"episodes": False, "nfo": False, "logo": False, "images": False},
started_at=datetime.now(timezone.utc),
)
# Mock the database session to raise an error
mock_db = AsyncMock()
with patch('src.server.database.connection.get_db_session') as mock_get_db:
with patch('src.server.database.service.AnimeSeriesService') as mock_service:
# Configure the async context manager
mock_get_db.return_value.__aenter__ = AsyncMock(return_value=mock_db)
mock_get_db.return_value.__aexit__ = AsyncMock(return_value=None)
# Make check_missing_data raise an error
service.check_missing_data = AsyncMock(side_effect=Exception("Database error"))
service._broadcast_status = AsyncMock()
mock_service.get_by_key = AsyncMock(return_value=None)
# Execute - should handle error gracefully
await service._load_series_data(task)
# Verify task was marked as failed
assert task.status == LoadingStatus.FAILED
assert task.error == "Database error"
assert task.completed_at is not None
@pytest.mark.asyncio
async def test_load_series_data_loads_missing_episodes():
"""Test that _load_series_data loads episodes when missing."""
# Create a fake series app
fake_series_app = MagicMock()
fake_series_app.directory_to_search = "/fake/path"
fake_series_app.loader = MagicMock()
# Create fake websocket and anime services
fake_websocket_service = AsyncMock()
fake_anime_service = AsyncMock()
# Create the service
service = BackgroundLoaderService(
websocket_service=fake_websocket_service,
anime_service=fake_anime_service,
series_app=fake_series_app
)
# Create a test task
task = SeriesLoadingTask(
key="test-anime",
folder="Test Anime",
name="Test Anime",
year=2023,
status=LoadingStatus.PENDING,
progress={"episodes": False, "nfo": False, "logo": False, "images": False},
started_at=datetime.now(timezone.utc),
)
# Mock the database session
mock_db = AsyncMock()
mock_series_db = MagicMock()
with patch('src.server.database.connection.get_db_session') as mock_get_db:
with patch('src.server.database.service.AnimeSeriesService') as mock_service:
# Configure the async context manager
mock_get_db.return_value.__aenter__ = AsyncMock(return_value=mock_db)
mock_get_db.return_value.__aexit__ = AsyncMock(return_value=None)
# Configure service methods
mock_service.get_by_key = AsyncMock(return_value=mock_series_db)
mock_db.commit = AsyncMock()
# Mock helper methods - episodes are missing
service.check_missing_data = AsyncMock(return_value={
"episodes": True, # Episodes are missing
"nfo": False,
"logo": False,
"images": False
})
service._load_episodes = AsyncMock()
service._broadcast_status = AsyncMock()
# Execute
await service._load_series_data(task)
# Verify _load_episodes was called
service._load_episodes.assert_called_once_with(task, mock_db)
# Verify task completed
assert task.status == LoadingStatus.COMPLETED
@pytest.mark.asyncio
async def test_load_series_data_loads_nfo_and_images():
"""Test that _load_series_data loads NFO and images when missing."""
# Create a fake series app
fake_series_app = MagicMock()
fake_series_app.directory_to_search = "/fake/path"
fake_series_app.loader = MagicMock()
# Create fake websocket and anime services
fake_websocket_service = AsyncMock()
fake_anime_service = AsyncMock()
# Create the service
service = BackgroundLoaderService(
websocket_service=fake_websocket_service,
anime_service=fake_anime_service,
series_app=fake_series_app
)
# Create a test task
task = SeriesLoadingTask(
key="test-anime",
folder="Test Anime",
name="Test Anime",
year=2023,
status=LoadingStatus.PENDING,
progress={"episodes": False, "nfo": False, "logo": False, "images": False},
started_at=datetime.now(timezone.utc),
)
# Mock the database session
mock_db = AsyncMock()
mock_series_db = MagicMock()
with patch('src.server.database.connection.get_db_session') as mock_get_db:
with patch('src.server.database.service.AnimeSeriesService') as mock_service:
# Configure the async context manager
mock_get_db.return_value.__aenter__ = AsyncMock(return_value=mock_db)
mock_get_db.return_value.__aexit__ = AsyncMock(return_value=None)
# Configure service methods
mock_service.get_by_key = AsyncMock(return_value=mock_series_db)
mock_db.commit = AsyncMock()
# Mock helper methods - NFO and images are missing
service.check_missing_data = AsyncMock(return_value={
"episodes": False,
"nfo": True, # NFO is missing
"logo": True, # Logo is missing
"images": True # Images are missing
})
service._load_nfo_and_images = AsyncMock()
service._broadcast_status = AsyncMock()
# Execute
await service._load_series_data(task)
# Verify _load_nfo_and_images was called
service._load_nfo_and_images.assert_called_once_with(task, mock_db)
# Verify task completed
assert task.status == LoadingStatus.COMPLETED
@pytest.mark.asyncio
async def test_async_context_manager_usage():
"""Direct test that verifies async context manager is used correctly.
This test ensures the code uses 'async with' pattern, not 'async for'.
"""
from contextlib import asynccontextmanager
from typing import AsyncGenerator
# Create a test async context manager
call_log = []
@asynccontextmanager
async def test_context_manager() -> AsyncGenerator:
call_log.append("enter")
yield "test_value"
call_log.append("exit")
# Test that async with works
async with test_context_manager() as value:
assert value == "test_value"
assert "enter" in call_log
assert "exit" in call_log
# Verify that async for would fail with this pattern
call_log.clear()
try:
async for item in test_context_manager():
pass
assert False, "Should have raised TypeError"
except TypeError as e:
assert "__aiter__" in str(e) or "async for" in str(e)

View File

@@ -4,11 +4,12 @@ Unit tests for dependency exception handling in FastAPI dependencies.
This module tests that async generator dependencies properly handle exceptions
thrown back into them, preventing the "generator didn't stop after athrow()" error.
"""
import pytest
from fastapi import FastAPI, HTTPException, Depends
from httpx import AsyncClient, ASGITransport
from typing import AsyncGenerator, Optional
import pytest
from fastapi import Depends, FastAPI, HTTPException
from httpx import ASGITransport, AsyncClient
@pytest.mark.asyncio
async def test_get_optional_database_session_handles_http_exception():
@@ -18,7 +19,7 @@ async def test_get_optional_database_session_handles_http_exception():
that occurred when an HTTPException was raised after yielding a database session.
"""
from src.server.utils.dependencies import get_optional_database_session
# Create a test app
app = FastAPI()
@@ -48,7 +49,7 @@ async def test_get_database_session_handles_http_exception():
that occurred when an HTTPException was raised after yielding a database session.
"""
from src.server.utils.dependencies import get_database_session
# Create a test app
app = FastAPI()