Fix async generator exception handling and add comprehensive tests
This commit is contained in:
@@ -16,6 +16,7 @@ from src.server.utils.dependencies import (
|
||||
common_parameters,
|
||||
get_current_user,
|
||||
get_database_session,
|
||||
get_optional_database_session,
|
||||
get_series_app,
|
||||
log_request_dependency,
|
||||
optional_auth,
|
||||
@@ -291,6 +292,156 @@ class TestUtilityDependencies:
|
||||
# Assert - no exception should be raised
|
||||
|
||||
|
||||
class TestOptionalDatabaseSession:
|
||||
"""Test cases for optional database session dependency."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optional_database_session_success(self):
|
||||
"""Test successful database session creation."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
# Mock the database session
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
mock_get_db = Mock(return_value=mock_session)
|
||||
|
||||
with patch('src.server.database.get_db_session', mock_get_db):
|
||||
# Act
|
||||
gen = get_optional_database_session()
|
||||
session = await gen.__anext__()
|
||||
|
||||
# Assert
|
||||
assert session is mock_session
|
||||
mock_get_db.assert_called_once()
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await gen.aclose()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optional_database_session_not_available(self):
|
||||
"""Test when database is not available (ImportError)."""
|
||||
# Mock ImportError when trying to import get_db_session
|
||||
with patch(
|
||||
'src.server.database.get_db_session',
|
||||
side_effect=ImportError("No module named 'database'")
|
||||
):
|
||||
# Act
|
||||
gen = get_optional_database_session()
|
||||
session = await gen.__anext__()
|
||||
|
||||
# Assert - should return None when database not available
|
||||
assert session is None
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await gen.aclose()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optional_database_session_runtime_error(self):
|
||||
"""Test when database raises RuntimeError."""
|
||||
# Mock RuntimeError when trying to get database session
|
||||
with patch(
|
||||
'src.server.database.get_db_session',
|
||||
side_effect=RuntimeError("Database connection failed")
|
||||
):
|
||||
# Act
|
||||
gen = get_optional_database_session()
|
||||
session = await gen.__anext__()
|
||||
|
||||
# Assert - should return None on RuntimeError
|
||||
assert session is None
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
await gen.aclose()
|
||||
except StopAsyncIteration:
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optional_database_session_exception_during_use(self):
|
||||
"""
|
||||
Test that exceptions during database operations are properly propagated.
|
||||
|
||||
This test specifically addresses the "generator didn't stop after athrow()"
|
||||
error that occurred when exceptions were caught and re-raised within the
|
||||
yield block of an async context manager.
|
||||
"""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
# Create a mock session that will raise an exception when used
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = AsyncMock(return_value=None)
|
||||
|
||||
# Mock the get_db_session to return our mock session
|
||||
mock_get_db = Mock(return_value=mock_session)
|
||||
|
||||
with patch('src.server.database.get_db_session', mock_get_db):
|
||||
# Act & Assert
|
||||
gen = get_optional_database_session()
|
||||
session = await gen.__anext__()
|
||||
|
||||
assert session is mock_session
|
||||
|
||||
# Simulate an exception being thrown into the generator
|
||||
# This mimics what happens when an endpoint raises an exception
|
||||
# after the dependency has yielded
|
||||
test_exception = ValueError("Database operation failed")
|
||||
|
||||
try:
|
||||
# This should not cause "generator didn't stop after athrow()" error
|
||||
await gen.athrow(test_exception)
|
||||
# If we get here, the exception was swallowed (shouldn't happen)
|
||||
pytest.fail("Exception should have been propagated")
|
||||
except ValueError as e:
|
||||
# Exception should be properly propagated
|
||||
assert str(e) == "Database operation failed"
|
||||
except StopAsyncIteration:
|
||||
# Generator stopped normally after exception
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_optional_database_session_cleanup_on_exception(self):
|
||||
"""Test that database session is properly cleaned up when exception occurs."""
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
# Track cleanup
|
||||
cleanup_called = []
|
||||
|
||||
async def mock_exit(*args):
|
||||
cleanup_called.append(True)
|
||||
return None
|
||||
|
||||
# Create a mock session with tracked cleanup
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_session.__aexit__ = mock_exit
|
||||
|
||||
mock_get_db = Mock(return_value=mock_session)
|
||||
|
||||
with patch('src.server.database.get_db_session', mock_get_db):
|
||||
gen = get_optional_database_session()
|
||||
session = await gen.__anext__()
|
||||
|
||||
assert session is mock_session
|
||||
|
||||
# Throw an exception to simulate endpoint failure
|
||||
try:
|
||||
await gen.athrow(RuntimeError("Simulated endpoint error"))
|
||||
except (RuntimeError, StopAsyncIteration):
|
||||
pass
|
||||
|
||||
# Assert cleanup was called
|
||||
assert len(cleanup_called) > 0, "Session cleanup should have been called"
|
||||
|
||||
|
||||
class TestIntegrationScenarios:
|
||||
"""Integration test scenarios for dependency injection."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user