Aniworld/tests/unit/test_transactions.py
Lukas 1ba67357dc Add database transaction support with atomic operations
- Create transaction.py with @transactional decorator, atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager, get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
2025-12-25 18:05:33 +01:00

669 lines
24 KiB
Python

"""Unit tests for database transaction utilities.

Tests the transaction management utilities, including decorators,
context managers, and helper functions.
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import Session, sessionmaker
from src.server.database.base import Base
from src.server.database.transaction import (
AsyncTransactionContext,
TransactionContext,
TransactionError,
TransactionPropagation,
atomic,
atomic_sync,
is_in_transaction,
transactional,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
async def async_engine():
    """Yield an in-memory SQLite async engine with the full ORM schema created.

    The engine is disposed after the dependent test has finished.
    """
    test_engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)
    async with test_engine.begin() as connection:
        # Create every table registered on Base before handing the engine out.
        await connection.run_sync(Base.metadata.create_all)
    yield test_engine
    await test_engine.dispose()
@pytest.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the in-memory test engine.

    ``expire_on_commit=False`` keeps ORM objects readable after a commit;
    whatever the test leaves open is rolled back during teardown.
    """
    from sqlalchemy.ext.asyncio import async_sessionmaker

    make_session = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with make_session() as session:
        yield session
        await session.rollback()
# ============================================================================
# TransactionContext Tests (Sync)
# ============================================================================
class TestTransactionContext:
    """Exercise the synchronous ``TransactionContext`` against a mocked Session."""

    @staticmethod
    def _idle_session():
        """Return a Session mock that reports no active transaction."""
        session = MagicMock(spec=Session)
        session.in_transaction.return_value = False
        return session

    def test_context_manager_protocol(self):
        """Entering begins a transaction; a clean exit commits it."""
        session = self._idle_session()
        with TransactionContext(session) as tx_ctx:
            assert tx_ctx.session == session
            session.begin.assert_called_once()
        session.commit.assert_called_once()

    def test_rollback_on_exception(self):
        """An exception escaping the block triggers rollback, never commit."""
        session = self._idle_session()
        with pytest.raises(ValueError):
            with TransactionContext(session):
                raise ValueError("Test error")
        session.rollback.assert_called_once()
        session.commit.assert_not_called()

    def test_no_begin_if_already_in_transaction(self):
        """An already-active transaction is reused instead of begun again."""
        session = MagicMock(spec=Session)
        session.in_transaction.return_value = True
        with TransactionContext(session):
            pass
        session.begin.assert_not_called()

    def test_explicit_commit(self):
        """A manual commit inside the block is not repeated on exit."""
        session = self._idle_session()
        with TransactionContext(session) as tx_ctx:
            tx_ctx.commit()
            session.commit.assert_called_once()
        # Leaving the block must not issue a second commit.
        assert session.commit.call_count == 1

    def test_explicit_rollback(self):
        """A manual rollback inside the block suppresses the exit commit."""
        session = self._idle_session()
        with TransactionContext(session) as tx_ctx:
            tx_ctx.rollback()
            session.rollback.assert_called_once()
        # Leaving the block after an explicit rollback must not commit.
        session.commit.assert_not_called()
# ============================================================================
# AsyncTransactionContext Tests
# ============================================================================
class TestAsyncTransactionContext:
    """Exercise ``AsyncTransactionContext`` against a mocked AsyncSession."""

    @staticmethod
    def _idle_async_session():
        """Return an AsyncSession mock with no active transaction.

        ``begin``, ``commit`` and ``rollback`` are explicit ``AsyncMock`` objects
        so their calls can be asserted.
        """
        session = AsyncMock(spec=AsyncSession)
        session.in_transaction.return_value = False
        for method_name in ("begin", "commit", "rollback"):
            setattr(session, method_name, AsyncMock())
        return session

    @pytest.mark.asyncio
    async def test_async_context_manager_protocol(self):
        """Entering begins a transaction; a clean exit commits it."""
        session = self._idle_async_session()
        async with AsyncTransactionContext(session) as tx_ctx:
            assert tx_ctx.session == session
            session.begin.assert_called_once()
        session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_rollback_on_exception(self):
        """An exception escaping the block triggers rollback, never commit."""
        session = self._idle_async_session()
        with pytest.raises(ValueError):
            async with AsyncTransactionContext(session):
                raise ValueError("Test error")
        session.rollback.assert_called_once()
        session.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_explicit_commit(self):
        """A manual commit inside the block is not repeated on exit."""
        session = self._idle_async_session()
        async with AsyncTransactionContext(session) as tx_ctx:
            await tx_ctx.commit()
            session.commit.assert_called_once()
        # Leaving the block must not issue a second commit.
        assert session.commit.call_count == 1

    @pytest.mark.asyncio
    async def test_async_explicit_rollback(self):
        """A manual rollback inside the block suppresses the exit commit."""
        session = self._idle_async_session()
        async with AsyncTransactionContext(session) as tx_ctx:
            await tx_ctx.rollback()
            session.rollback.assert_called_once()
        # Leaving the block after an explicit rollback must not commit.
        session.commit.assert_not_called()
# ============================================================================
# atomic() Context Manager Tests
# ============================================================================
class TestAtomicContextManager:
    """Tests for the ``atomic()`` async context manager."""

    @pytest.mark.asyncio
    async def test_atomic_commits_on_success(self):
        """A clean exit commits the transaction and never rolls it back."""
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_session.rollback = AsyncMock()
        async with atomic(mock_session):
            pass
        mock_session.commit.assert_called_once()
        mock_session.rollback.assert_not_called()

    @pytest.mark.asyncio
    async def test_atomic_rollback_on_failure(self):
        """An exception rolls the transaction back and skips the commit."""
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_session.rollback = AsyncMock()
        with pytest.raises(RuntimeError):
            async with atomic(mock_session):
                raise RuntimeError("Operation failed")
        mock_session.rollback.assert_called_once()
        # Mirrors the AsyncTransactionContext contract: failure must not commit.
        mock_session.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_atomic_nested_propagation(self):
        """NESTED propagation inside an active transaction opens a savepoint."""
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = True
        mock_nested = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)
        async with atomic(
            mock_session, propagation=TransactionPropagation.NESTED
        ):
            pass
        mock_session.begin_nested.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_required_propagation_default(self):
        """With no propagation argument, atomic() behaves as REQUIRED.

        No transaction is active, so a new one must be begun on entry and
        committed on a clean exit.
        """
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        async with atomic(mock_session):
            # Should start new transaction
            mock_session.begin.assert_called_once()
        # Previously unchecked: the transaction REQUIRED started must commit.
        mock_session.commit.assert_called_once()
# ============================================================================
# @transactional Decorator Tests
# ============================================================================
class TestTransactionalDecorator:
    """Tests for ``@transactional`` applied to coroutine functions."""

    @staticmethod
    def _fresh_async_session():
        """Return an AsyncSession mock outside any transaction.

        ``begin``, ``commit`` and ``rollback`` are explicit ``AsyncMock`` objects.
        """
        fake_db = AsyncMock(spec=AsyncSession)
        fake_db.in_transaction.return_value = False
        fake_db.begin = AsyncMock()
        fake_db.commit = AsyncMock()
        fake_db.rollback = AsyncMock()
        return fake_db

    @pytest.mark.asyncio
    async def test_async_function_wrapped(self):
        """The decorated coroutine's result is returned and the work committed."""
        fake_db = self._fresh_async_session()

        @transactional()
        async def sample_operation(db: AsyncSession):
            return "result"

        assert await sample_operation(db=fake_db) == "result"
        fake_db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_rollback_on_error(self):
        """An exception from the coroutine propagates after a rollback."""
        fake_db = self._fresh_async_session()

        @transactional()
        async def failing_operation(db: AsyncSession):
            raise ValueError("Operation failed")

        with pytest.raises(ValueError):
            await failing_operation(db=fake_db)
        fake_db.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_custom_session_param_name(self):
        """``session_param`` tells the decorator which argument is the session."""
        fake_db = self._fresh_async_session()

        @transactional(session_param="session")
        async def operation_with_session(session: AsyncSession):
            return "done"

        assert await operation_with_session(session=fake_db) == "done"
        fake_db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_missing_session_raises_error(self):
        """A call with no resolvable session argument raises TransactionError."""

        @transactional()
        async def operation_no_session(data: dict):
            return data

        with pytest.raises(TransactionError):
            await operation_no_session(data={"key": "value"})

    @pytest.mark.asyncio
    async def test_propagation_passed_to_atomic(self):
        """The decorator forwards its propagation mode (NESTED -> savepoint)."""
        fake_db = AsyncMock(spec=AsyncSession)
        fake_db.in_transaction.return_value = True
        savepoint_tx = AsyncMock()
        fake_db.begin_nested = AsyncMock(return_value=savepoint_tx)

        @transactional(propagation=TransactionPropagation.NESTED)
        async def nested_operation(db: AsyncSession):
            return "nested"

        assert await nested_operation(db=fake_db) == "nested"
        fake_db.begin_nested.assert_called_once()
# ============================================================================
# Sync transactional decorator Tests
# ============================================================================
class TestSyncTransactionalDecorator:
    """Tests for ``@transactional`` applied to plain (synchronous) functions."""

    @staticmethod
    def _fresh_session():
        """Return a Session mock that reports no active transaction."""
        fake_db = MagicMock(spec=Session)
        fake_db.in_transaction.return_value = False
        return fake_db

    def test_sync_function_wrapped(self):
        """The decorated function's result is returned and the work committed."""
        fake_db = self._fresh_session()

        @transactional()
        def sample_sync_operation(db: Session):
            return "sync_result"

        assert sample_sync_operation(db=fake_db) == "sync_result"
        fake_db.commit.assert_called_once()

    def test_sync_rollback_on_error(self):
        """An exception from the function propagates after a rollback."""
        fake_db = self._fresh_session()

        @transactional()
        def failing_sync_operation(db: Session):
            raise ValueError("Sync operation failed")

        with pytest.raises(ValueError):
            failing_sync_operation(db=fake_db)
        fake_db.rollback.assert_called_once()
# ============================================================================
# Helper Function Tests
# ============================================================================
class TestHelperFunctions:
    """Tests for the ``is_in_transaction`` helper."""

    @staticmethod
    def _session_reporting(active):
        """Return an unspecced mock whose ``in_transaction()`` yields *active*."""
        fake = MagicMock()
        fake.in_transaction.return_value = active
        return fake

    def test_is_in_transaction_true(self):
        """An active session is reported as being in a transaction."""
        assert is_in_transaction(self._session_reporting(True)) is True

    def test_is_in_transaction_false(self):
        """An idle session is reported as not being in a transaction."""
        assert is_in_transaction(self._session_reporting(False)) is False
# ============================================================================
# Integration Tests with Real Database
# ============================================================================
class TestTransactionIntegration:
    """Integration tests that run ``atomic()`` against a real in-memory DB."""

    @pytest.mark.asyncio
    async def test_real_transaction_commit(self, async_session):
        """A clean exit from ``atomic()`` durably commits the new row."""
        from sqlalchemy import select

        from src.server.database.models import AnimeSeries

        async with atomic(async_session):
            series = AnimeSeries(
                key="test-series",
                name="Test Series",
                site="https://test.com",
                folder="/test/folder",
            )
            async_session.add(series)

        # Discard anything that is only pending or flushed but not committed.
        # Without this, autoflush would let the query below find the row even
        # if atomic() had left the transaction open instead of committing.
        # If atomic() did commit, no transaction is open and this is a no-op.
        await async_session.rollback()

        result = await async_session.execute(
            select(AnimeSeries).where(AnimeSeries.key == "test-series")
        )
        saved_series = result.scalar_one_or_none()
        assert saved_series is not None
        assert saved_series.name == "Test Series"

    @pytest.mark.asyncio
    async def test_real_transaction_rollback(self, async_session):
        """An exception inside ``atomic()`` propagates and discards the row."""
        from sqlalchemy import select

        from src.server.database.models import AnimeSeries

        # pytest.raises (rather than try/except: pass) also proves that
        # atomic() re-raises the error instead of silently swallowing it.
        with pytest.raises(ValueError):
            async with atomic(async_session):
                series = AnimeSeries(
                    key="rollback-series",
                    name="Rollback Series",
                    site="https://test.com",
                    folder="/test/folder",
                )
                async_session.add(series)
                # Flush so the INSERT really reaches the DB before rollback.
                await async_session.flush()
                raise ValueError("Simulated error")

        # Verify data was NOT persisted.
        result = await async_session.execute(
            select(AnimeSeries).where(AnimeSeries.key == "rollback-series")
        )
        assert result.scalar_one_or_none() is None
# ============================================================================
# TransactionPropagation Tests
# ============================================================================
class TestTransactionPropagation:
    """Tests for the ``TransactionPropagation`` enum."""

    def test_propagation_enum_values(self):
        """Each propagation mode carries its expected lowercase string value."""
        expected_values = {
            TransactionPropagation.REQUIRED: "required",
            TransactionPropagation.REQUIRES_NEW: "requires_new",
            TransactionPropagation.NESTED: "nested",
        }
        for member, wire_value in expected_values.items():
            assert member.value == wire_value
# ============================================================================
# Additional Coverage Tests
# ============================================================================
class TestSyncSavepointCoverage:
    """Tests for savepoints created through ``TransactionContext.savepoint()``."""

    def test_savepoint_exception_rolls_back(self):
        """A failing savepoint rolls back only itself, not the outer transaction."""
        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = False
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested
        with TransactionContext(mock_session) as ctx:
            with pytest.raises(ValueError):
                with ctx.savepoint():
                    raise ValueError("Error in savepoint")
            # The nested transaction returned by begin_nested() is rolled back.
            mock_nested.rollback.assert_called_once()
        # The error was handled inside the outer block, so the outer
        # transaction must not be rolled back (partial rollback only).
        mock_session.rollback.assert_not_called()

    def test_savepoint_commit_explicit(self):
        """An explicit savepoint commit opens one savepoint and never rolls it back."""
        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = False
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested
        with TransactionContext(mock_session) as ctx:
            with ctx.savepoint() as sp:
                # Commit is expected to only log; SQLAlchemy releases the
                # savepoint itself. Previously this test asserted nothing.
                sp.commit()
        mock_session.begin_nested.assert_called_once()
        mock_nested.rollback.assert_not_called()
class TestAsyncSavepointCoverage:
    """Tests for savepoints created through ``AsyncTransactionContext.savepoint()``."""

    @pytest.mark.asyncio
    async def test_async_savepoint_exception_rolls_back(self):
        """A failing async savepoint rolls back only itself, not the outer one."""
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_session.rollback = AsyncMock()
        mock_nested = AsyncMock()
        mock_nested.rollback = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)
        async with AsyncTransactionContext(mock_session) as ctx:
            with pytest.raises(ValueError):
                async with ctx.savepoint():
                    raise ValueError("Error in async savepoint")
            # The nested transaction returned by begin_nested() is rolled back.
            mock_nested.rollback.assert_called_once()
        # The error was handled inside the outer block, so the outer
        # transaction must not be rolled back (partial rollback only).
        mock_session.rollback.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_savepoint_commit_explicit(self):
        """An explicit async savepoint commit opens one savepoint, no rollback."""
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_nested = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)
        async with AsyncTransactionContext(mock_session) as ctx:
            async with ctx.savepoint() as sp:
                # Commit is expected to only log; SQLAlchemy releases the
                # savepoint itself. Previously this test asserted nothing.
                await sp.commit()
        mock_session.begin_nested.assert_called_once()
        mock_nested.rollback.assert_not_called()
class TestAtomicNestedPropagationNoTransaction:
    """NESTED propagation falls back to a fresh transaction when none is active."""

    @pytest.mark.asyncio
    async def test_async_nested_starts_new_when_not_in_transaction(self):
        """Async: with no outer transaction, NESTED begins and commits a new one."""
        fake_db = AsyncMock(spec=AsyncSession)
        fake_db.in_transaction.return_value = False
        fake_db.begin, fake_db.commit, fake_db.rollback = (
            AsyncMock(),
            AsyncMock(),
            AsyncMock(),
        )
        async with atomic(fake_db, TransactionPropagation.NESTED):
            pass  # nothing to nest under, so a brand-new transaction is expected
        fake_db.begin.assert_called_once()
        fake_db.commit.assert_called_once()

    def test_sync_nested_starts_new_when_not_in_transaction(self):
        """Sync: with no outer transaction, NESTED begins and commits a new one."""
        fake_db = MagicMock(spec=Session)
        fake_db.in_transaction.return_value = False
        with atomic_sync(fake_db, TransactionPropagation.NESTED):
            pass  # nothing to nest under, so a brand-new transaction is expected
        fake_db.begin.assert_called_once()
        fake_db.commit.assert_called_once()
class TestGetTransactionDepth:
    """Tests for the ``get_transaction_depth`` helper."""

    @staticmethod
    def _depth_of(session):
        """Import the helper lazily and return its depth for *session*."""
        from src.server.database.transaction import get_transaction_depth

        return get_transaction_depth(session)

    def test_depth_zero_when_not_in_transaction(self):
        """An idle session has depth 0."""
        idle = MagicMock(spec=Session)
        idle.in_transaction.return_value = False
        assert self._depth_of(idle) == 0

    def test_depth_one_in_transaction(self):
        """A plain transaction without a savepoint has depth 1."""
        active = MagicMock(spec=Session)
        active.in_transaction.return_value = True
        active._nested_transaction = None  # no savepoint open
        assert self._depth_of(active) == 1

    def test_depth_two_with_nested_transaction(self):
        """A transaction with an open savepoint has depth 2."""
        active = MagicMock(spec=Session)
        active.in_transaction.return_value = True
        active._nested_transaction = MagicMock()  # a savepoint is open
        assert self._depth_of(active) == 2
class TestTransactionalDecoratorPositionalArgs:
    """``@transactional`` must find the session when it is passed positionally."""

    @pytest.mark.asyncio
    async def test_session_from_positional_arg(self):
        """Async: the session is picked up from the first positional argument."""
        fake_db = AsyncMock(spec=AsyncSession)
        fake_db.in_transaction.return_value = False
        for attr in ("begin", "commit", "rollback"):
            setattr(fake_db, attr, AsyncMock())

        @transactional()
        async def operation(db: AsyncSession, data: str):
            return f"processed: {data}"

        # The session is supplied positionally rather than as db=...
        outcome = await operation(fake_db, "test")
        assert outcome == "processed: test"
        fake_db.commit.assert_called_once()

    def test_sync_session_from_positional_arg(self):
        """Sync: the session is picked up from the first positional argument."""
        fake_db = MagicMock(spec=Session)
        fake_db.in_transaction.return_value = False

        @transactional()
        def operation(db: Session, data: str):
            return f"processed: {data}"

        outcome = operation(fake_db, "test")
        assert outcome == "processed: test"
        fake_db.commit.assert_called_once()