Fix async context manager usage in BackgroundLoaderService
- Changed 'async for' to 'async with' for get_db_session() - get_db_session() is @asynccontextmanager, requires async with not async for - Created 5 comprehensive unit tests verifying the fix - All tests pass, background loading now works correctly
This commit is contained in:
@@ -121,35 +121,46 @@ For each task completed:
|
||||
|
||||
✅ **Task Completed Successfully**
|
||||
|
||||
### Issue Fixed: RuntimeError: generator didn't stop after athrow()
|
||||
### Issue Fixed: TypeError: 'async for' requires an object with __aiter__ method
|
||||
|
||||
**Problem:**
|
||||
The `/api/anime/add` endpoint was throwing a 500 error with `RuntimeError: generator didn't stop after athrow()` when validation errors (HTTPException) were raised after database session dependencies yielded.
|
||||
The BackgroundLoaderService was crashing when trying to load series data with the error:
|
||||
```
|
||||
TypeError: 'async for' requires an object with __aiter__ method, got _AsyncGeneratorContextManager
|
||||
```
|
||||
|
||||
This error occurred in `_load_series_data` method at line 282 of [background_loader_service.py](src/server/services/background_loader_service.py).
|
||||
|
||||
**Root Cause:**
|
||||
Async generator dependencies (`get_database_session` and `get_optional_database_session`) didn't properly handle exceptions thrown back into them after yielding. When an `HTTPException` was raised in the endpoint for validation errors, Python's async context manager tried to propagate the exception to the generator, but without proper exception handling around the yield statement, it resulted in the "generator didn't stop after athrow()" error.
|
||||
The code was incorrectly using `async for db in get_db_session():` to get a database session. However, `get_db_session()` is decorated with `@asynccontextmanager`, which returns an async context manager (not an async iterator). Async context managers must be used with `async with`, not `async for`.
|
||||
|
||||
**Solution:**
|
||||
Added proper exception handling in both database session dependencies by wrapping the `yield` statement with a try-except block that re-raises any exceptions, allowing FastAPI to handle them correctly.
|
||||
Changed the database session acquisition from:
|
||||
```python
|
||||
async for db in get_db_session():
|
||||
# ... code ...
|
||||
break # Exit loop after first iteration
|
||||
```
|
||||
|
||||
To the correct pattern:
|
||||
```python
|
||||
async with get_db_session() as db:
|
||||
# ... code ...
|
||||
```
|
||||
|
||||
**Files Modified:**
|
||||
1. [src/server/utils/dependencies.py](src/server/utils/dependencies.py) - Fixed exception handling in `get_database_session` and `get_optional_database_session`
|
||||
2. [tests/unit/test_dependency_exception_handling.py](tests/unit/test_dependency_exception_handling.py) - Created comprehensive unit tests for the fix
|
||||
3. [tests/api/test_anime_endpoints.py](tests/api/test_anime_endpoints.py) - Updated to mock BackgroundLoaderService and expect 202 status codes
|
||||
1. [src/server/services/background_loader_service.py](src/server/services/background_loader_service.py) - Fixed async context manager usage
|
||||
2. [tests/unit/test_background_loader_session.py](tests/unit/test_background_loader_session.py) - Created comprehensive unit tests
|
||||
|
||||
**Tests:**
|
||||
- ✅ 5 new unit tests for dependency exception handling (all passing)
|
||||
- ✅ 16 anime endpoint integration tests (all passing)
|
||||
- ✅ Tests verify proper handling of 400, 404, 422 status codes
|
||||
- ✅ Tests verify successful requests still work correctly
|
||||
- ✅ 5 new unit tests for background loader database session handling (all passing)
|
||||
- ✅ Tests verify proper async context manager usage
|
||||
- ✅ Tests verify error handling and progress tracking
|
||||
- ✅ Tests verify episode and NFO loading logic
|
||||
- ✅ Includes test demonstrating the difference between `async with` and `async for`
|
||||
|
||||
**Note for Users:**
|
||||
If you're experiencing this error on a running server, please **restart the server** to load the fixed code:
|
||||
```bash
|
||||
# Stop the server (Ctrl+C)
|
||||
# Then restart:
|
||||
conda run -n AniWorld python -m uvicorn src.server.fastapi_app:app --host 127.0.0.1 --port 8000 --reload
|
||||
```
|
||||
**Verification:**
|
||||
The fix allows the background loader service to properly load series data including episodes, NFO files, logos, and images without crashing.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -279,7 +279,7 @@ class BackgroundLoaderService:
|
||||
from src.server.database.connection import get_db_session
|
||||
from src.server.database.service import AnimeSeriesService
|
||||
|
||||
async for db in get_db_session():
|
||||
async with get_db_session() as db:
|
||||
try:
|
||||
# Check what data is missing
|
||||
missing = await self.check_missing_data(
|
||||
@@ -337,8 +337,6 @@ class BackgroundLoaderService:
|
||||
# Broadcast error
|
||||
await self._broadcast_status(task)
|
||||
|
||||
break # Exit async for loop after first iteration
|
||||
|
||||
finally:
|
||||
# Remove from active tasks
|
||||
self.active_tasks.pop(task.key, None)
|
||||
|
||||
304
tests/unit/test_background_loader_session.py
Normal file
304
tests/unit/test_background_loader_session.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Unit tests for BackgroundLoaderService database session handling.
|
||||
|
||||
This module tests that the background loader service properly uses async context
|
||||
managers for database sessions, preventing TypeError with async for.
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from src.server.services.background_loader_service import (
|
||||
BackgroundLoaderService,
|
||||
SeriesLoadingTask,
|
||||
LoadingStatus,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_series_data_uses_async_with_not_async_for():
|
||||
"""Test that _load_series_data uses 'async with' for database session.
|
||||
|
||||
This test verifies the fix for the TypeError:
|
||||
'async for' requires an object with __aiter__ method, got _AsyncGeneratorContextManager
|
||||
|
||||
The code should use 'async with get_db_session() as db:' not 'async for db in get_db_session():'
|
||||
"""
|
||||
# Create a fake series app
|
||||
fake_series_app = MagicMock()
|
||||
fake_series_app.directory_to_search = "/fake/path"
|
||||
fake_series_app.loader = MagicMock()
|
||||
|
||||
# Create fake websocket and anime services
|
||||
fake_websocket_service = AsyncMock()
|
||||
fake_anime_service = AsyncMock()
|
||||
|
||||
# Create the service
|
||||
service = BackgroundLoaderService(
|
||||
websocket_service=fake_websocket_service,
|
||||
anime_service=fake_anime_service,
|
||||
series_app=fake_series_app
|
||||
)
|
||||
|
||||
# Create a test task
|
||||
task = SeriesLoadingTask(
|
||||
key="test-anime",
|
||||
folder="Test Anime",
|
||||
name="Test Anime",
|
||||
year=2023,
|
||||
status=LoadingStatus.PENDING,
|
||||
progress={"episodes": False, "nfo": False, "logo": False, "images": False},
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Mock the database session
|
||||
mock_db = AsyncMock()
|
||||
mock_series_db = MagicMock()
|
||||
mock_series_db.loading_status = "pending"
|
||||
|
||||
# Mock database service
|
||||
with patch('src.server.database.connection.get_db_session') as mock_get_db:
|
||||
with patch('src.server.database.service.AnimeSeriesService') as mock_service:
|
||||
# Configure the async context manager
|
||||
mock_get_db.return_value.__aenter__ = AsyncMock(return_value=mock_db)
|
||||
mock_get_db.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Configure service methods
|
||||
mock_service.get_by_key = AsyncMock(return_value=mock_series_db)
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
# Mock helper methods
|
||||
service.check_missing_data = AsyncMock(return_value={
|
||||
"episodes": False,
|
||||
"nfo": False,
|
||||
"logo": False,
|
||||
"images": False
|
||||
})
|
||||
service._broadcast_status = AsyncMock()
|
||||
|
||||
# Execute the method - this should not raise TypeError
|
||||
await service._load_series_data(task)
|
||||
|
||||
# Verify the context manager was entered (proves we used 'async with')
|
||||
mock_get_db.return_value.__aenter__.assert_called_once()
|
||||
mock_get_db.return_value.__aexit__.assert_called_once()
|
||||
|
||||
# Verify task was marked as completed
|
||||
assert task.status == LoadingStatus.COMPLETED
|
||||
assert task.completed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_series_data_handles_database_errors():
|
||||
"""Test that _load_series_data properly handles database errors."""
|
||||
# Create a fake series app
|
||||
fake_series_app = MagicMock()
|
||||
fake_series_app.directory_to_search = "/fake/path"
|
||||
fake_series_app.loader = MagicMock()
|
||||
|
||||
# Create fake websocket and anime services
|
||||
fake_websocket_service = AsyncMock()
|
||||
fake_anime_service = AsyncMock()
|
||||
|
||||
# Create the service
|
||||
service = BackgroundLoaderService(
|
||||
websocket_service=fake_websocket_service,
|
||||
anime_service=fake_anime_service,
|
||||
series_app=fake_series_app
|
||||
)
|
||||
|
||||
# Create a test task
|
||||
task = SeriesLoadingTask(
|
||||
key="test-anime",
|
||||
folder="Test Anime",
|
||||
name="Test Anime",
|
||||
year=2023,
|
||||
status=LoadingStatus.PENDING,
|
||||
progress={"episodes": False, "nfo": False, "logo": False, "images": False},
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Mock the database session to raise an error
|
||||
mock_db = AsyncMock()
|
||||
|
||||
with patch('src.server.database.connection.get_db_session') as mock_get_db:
|
||||
with patch('src.server.database.service.AnimeSeriesService') as mock_service:
|
||||
# Configure the async context manager
|
||||
mock_get_db.return_value.__aenter__ = AsyncMock(return_value=mock_db)
|
||||
mock_get_db.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Make check_missing_data raise an error
|
||||
service.check_missing_data = AsyncMock(side_effect=Exception("Database error"))
|
||||
service._broadcast_status = AsyncMock()
|
||||
mock_service.get_by_key = AsyncMock(return_value=None)
|
||||
|
||||
# Execute - should handle error gracefully
|
||||
await service._load_series_data(task)
|
||||
|
||||
# Verify task was marked as failed
|
||||
assert task.status == LoadingStatus.FAILED
|
||||
assert task.error == "Database error"
|
||||
assert task.completed_at is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_series_data_loads_missing_episodes():
|
||||
"""Test that _load_series_data loads episodes when missing."""
|
||||
# Create a fake series app
|
||||
fake_series_app = MagicMock()
|
||||
fake_series_app.directory_to_search = "/fake/path"
|
||||
fake_series_app.loader = MagicMock()
|
||||
|
||||
# Create fake websocket and anime services
|
||||
fake_websocket_service = AsyncMock()
|
||||
fake_anime_service = AsyncMock()
|
||||
|
||||
# Create the service
|
||||
service = BackgroundLoaderService(
|
||||
websocket_service=fake_websocket_service,
|
||||
anime_service=fake_anime_service,
|
||||
series_app=fake_series_app
|
||||
)
|
||||
|
||||
# Create a test task
|
||||
task = SeriesLoadingTask(
|
||||
key="test-anime",
|
||||
folder="Test Anime",
|
||||
name="Test Anime",
|
||||
year=2023,
|
||||
status=LoadingStatus.PENDING,
|
||||
progress={"episodes": False, "nfo": False, "logo": False, "images": False},
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Mock the database session
|
||||
mock_db = AsyncMock()
|
||||
mock_series_db = MagicMock()
|
||||
|
||||
with patch('src.server.database.connection.get_db_session') as mock_get_db:
|
||||
with patch('src.server.database.service.AnimeSeriesService') as mock_service:
|
||||
# Configure the async context manager
|
||||
mock_get_db.return_value.__aenter__ = AsyncMock(return_value=mock_db)
|
||||
mock_get_db.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Configure service methods
|
||||
mock_service.get_by_key = AsyncMock(return_value=mock_series_db)
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
# Mock helper methods - episodes are missing
|
||||
service.check_missing_data = AsyncMock(return_value={
|
||||
"episodes": True, # Episodes are missing
|
||||
"nfo": False,
|
||||
"logo": False,
|
||||
"images": False
|
||||
})
|
||||
service._load_episodes = AsyncMock()
|
||||
service._broadcast_status = AsyncMock()
|
||||
|
||||
# Execute
|
||||
await service._load_series_data(task)
|
||||
|
||||
# Verify _load_episodes was called
|
||||
service._load_episodes.assert_called_once_with(task, mock_db)
|
||||
|
||||
# Verify task completed
|
||||
assert task.status == LoadingStatus.COMPLETED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_series_data_loads_nfo_and_images():
|
||||
"""Test that _load_series_data loads NFO and images when missing."""
|
||||
# Create a fake series app
|
||||
fake_series_app = MagicMock()
|
||||
fake_series_app.directory_to_search = "/fake/path"
|
||||
fake_series_app.loader = MagicMock()
|
||||
|
||||
# Create fake websocket and anime services
|
||||
fake_websocket_service = AsyncMock()
|
||||
fake_anime_service = AsyncMock()
|
||||
|
||||
# Create the service
|
||||
service = BackgroundLoaderService(
|
||||
websocket_service=fake_websocket_service,
|
||||
anime_service=fake_anime_service,
|
||||
series_app=fake_series_app
|
||||
)
|
||||
|
||||
# Create a test task
|
||||
task = SeriesLoadingTask(
|
||||
key="test-anime",
|
||||
folder="Test Anime",
|
||||
name="Test Anime",
|
||||
year=2023,
|
||||
status=LoadingStatus.PENDING,
|
||||
progress={"episodes": False, "nfo": False, "logo": False, "images": False},
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Mock the database session
|
||||
mock_db = AsyncMock()
|
||||
mock_series_db = MagicMock()
|
||||
|
||||
with patch('src.server.database.connection.get_db_session') as mock_get_db:
|
||||
with patch('src.server.database.service.AnimeSeriesService') as mock_service:
|
||||
# Configure the async context manager
|
||||
mock_get_db.return_value.__aenter__ = AsyncMock(return_value=mock_db)
|
||||
mock_get_db.return_value.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Configure service methods
|
||||
mock_service.get_by_key = AsyncMock(return_value=mock_series_db)
|
||||
mock_db.commit = AsyncMock()
|
||||
|
||||
# Mock helper methods - NFO and images are missing
|
||||
service.check_missing_data = AsyncMock(return_value={
|
||||
"episodes": False,
|
||||
"nfo": True, # NFO is missing
|
||||
"logo": True, # Logo is missing
|
||||
"images": True # Images are missing
|
||||
})
|
||||
service._load_nfo_and_images = AsyncMock()
|
||||
service._broadcast_status = AsyncMock()
|
||||
|
||||
# Execute
|
||||
await service._load_series_data(task)
|
||||
|
||||
# Verify _load_nfo_and_images was called
|
||||
service._load_nfo_and_images.assert_called_once_with(task, mock_db)
|
||||
|
||||
# Verify task completed
|
||||
assert task.status == LoadingStatus.COMPLETED
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager_usage():
|
||||
"""Direct test that verifies async context manager is used correctly.
|
||||
|
||||
This test ensures the code uses 'async with' pattern, not 'async for'.
|
||||
"""
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
# Create a test async context manager
|
||||
call_log = []
|
||||
|
||||
@asynccontextmanager
|
||||
async def test_context_manager() -> AsyncGenerator:
|
||||
call_log.append("enter")
|
||||
yield "test_value"
|
||||
call_log.append("exit")
|
||||
|
||||
# Test that async with works
|
||||
async with test_context_manager() as value:
|
||||
assert value == "test_value"
|
||||
assert "enter" in call_log
|
||||
|
||||
assert "exit" in call_log
|
||||
|
||||
# Verify that async for would fail with this pattern
|
||||
call_log.clear()
|
||||
try:
|
||||
async for item in test_context_manager():
|
||||
pass
|
||||
assert False, "Should have raised TypeError"
|
||||
except TypeError as e:
|
||||
assert "__aiter__" in str(e) or "async for" in str(e)
|
||||
@@ -4,11 +4,12 @@ Unit tests for dependency exception handling in FastAPI dependencies.
|
||||
This module tests that async generator dependencies properly handle exceptions
|
||||
thrown back into them, preventing the "generator didn't stop after athrow()" error.
|
||||
"""
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
import pytest
|
||||
from fastapi import Depends, FastAPI, HTTPException
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_database_session_handles_http_exception():
|
||||
@@ -18,7 +19,7 @@ async def test_get_optional_database_session_handles_http_exception():
|
||||
that occurred when an HTTPException was raised after yielding a database session.
|
||||
"""
|
||||
from src.server.utils.dependencies import get_optional_database_session
|
||||
|
||||
|
||||
# Create a test app
|
||||
app = FastAPI()
|
||||
|
||||
@@ -48,7 +49,7 @@ async def test_get_database_session_handles_http_exception():
|
||||
that occurred when an HTTPException was raised after yielding a database session.
|
||||
"""
|
||||
from src.server.utils.dependencies import get_database_session
|
||||
|
||||
|
||||
# Create a test app
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user