"""Unit tests for database transaction utilities.

Tests the transaction management utilities including decorators, context
managers, and helper functions.
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import Session, sessionmaker

from src.server.database.base import Base
from src.server.database.transaction import (
    AsyncTransactionContext,
    TransactionContext,
    TransactionError,
    TransactionPropagation,
    atomic,
    atomic_sync,
    is_in_transaction,
    transactional,
)


# ============================================================================
# Mock helpers
# ============================================================================


def _make_sync_session(*, in_transaction: bool = False) -> MagicMock:
    """Return a ``Session`` mock whose ``in_transaction()`` returns *in_transaction*."""
    session = MagicMock(spec=Session)
    session.in_transaction.return_value = in_transaction
    return session


def _make_async_session(*, in_transaction: bool = False) -> AsyncMock:
    """Return an ``AsyncSession`` mock with awaitable begin/commit/rollback.

    ``in_transaction()`` returns *in_transaction*; ``begin``, ``commit`` and
    ``rollback`` are replaced with ``AsyncMock`` so the code under test can
    await them and the tests can inspect their calls.
    """
    session = AsyncMock(spec=AsyncSession)
    session.in_transaction.return_value = in_transaction
    session.begin = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    return session


# ============================================================================
# Fixtures
# ============================================================================


@pytest.fixture
async def async_engine():
    """Create in-memory async database engine for testing.

    All tables registered on ``Base.metadata`` are created up front; the
    engine is disposed once the test is done.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest.fixture
async def async_session(async_engine):
    """Create async database session for testing.

    Any work left pending by the test is rolled back on teardown.
    """
    from sqlalchemy.ext.asyncio import async_sessionmaker

    async_session_factory = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with async_session_factory() as session:
        yield session
        await session.rollback()


# ============================================================================
# TransactionContext Tests (Sync)
# ============================================================================


class TestTransactionContext:
    """Tests for synchronous TransactionContext."""

    def test_context_manager_protocol(self):
        """Test context manager enters and exits properly."""
        mock_session = _make_sync_session()

        with TransactionContext(mock_session) as ctx:
            assert ctx.session == mock_session

        mock_session.begin.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_rollback_on_exception(self):
        """Test rollback is called when exception occurs."""
        mock_session = _make_sync_session()

        with pytest.raises(ValueError):
            with TransactionContext(mock_session):
                raise ValueError("Test error")

        mock_session.rollback.assert_called_once()
        mock_session.commit.assert_not_called()

    def test_no_begin_if_already_in_transaction(self):
        """Test no new transaction started if already in one."""
        mock_session = _make_sync_session(in_transaction=True)

        with TransactionContext(mock_session):
            pass

        mock_session.begin.assert_not_called()

    def test_explicit_commit(self):
        """Test explicit commit within context."""
        mock_session = _make_sync_session()

        with TransactionContext(mock_session) as ctx:
            ctx.commit()

        # Checked after exit: exactly one call means the explicit commit was
        # made and the context did not commit a second time on exit.
        mock_session.commit.assert_called_once()

    def test_explicit_rollback(self):
        """Test explicit rollback within context."""
        mock_session = _make_sync_session()

        with TransactionContext(mock_session) as ctx:
            ctx.rollback()
            mock_session.rollback.assert_called_once()

        # Should not commit after explicit rollback
        mock_session.commit.assert_not_called()


# ============================================================================
# AsyncTransactionContext Tests
# ============================================================================


class TestAsyncTransactionContext:
    """Tests for asynchronous AsyncTransactionContext."""

    @pytest.mark.asyncio
    async def test_async_context_manager_protocol(self):
        """Test async context manager enters and exits properly."""
        mock_session = _make_async_session()

        async with AsyncTransactionContext(mock_session) as ctx:
            assert ctx.session == mock_session

        mock_session.begin.assert_called_once()
        mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_rollback_on_exception(self):
        """Test async rollback is called when exception occurs."""
        mock_session = _make_async_session()

        with pytest.raises(ValueError):
            async with AsyncTransactionContext(mock_session):
                raise ValueError("Test error")

        mock_session.rollback.assert_called_once()
        mock_session.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_explicit_commit(self):
        """Test async explicit commit within context."""
        mock_session = _make_async_session()

        async with AsyncTransactionContext(mock_session) as ctx:
            await ctx.commit()

        # Checked after exit: exactly one call means the explicit commit was
        # made and the context did not commit a second time on exit.
        mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_explicit_rollback(self):
        """Test async explicit rollback within context."""
        mock_session = _make_async_session()

        async with AsyncTransactionContext(mock_session) as ctx:
            await ctx.rollback()
            mock_session.rollback.assert_called_once()

        # Should not commit after explicit rollback
        mock_session.commit.assert_not_called()


# ============================================================================
# atomic() Context Manager Tests
# ============================================================================


class TestAtomicContextManager:
    """Tests for atomic() async context manager."""

    @pytest.mark.asyncio
    async def test_atomic_commits_on_success(self):
        """Test atomic commits transaction on success."""
        mock_session = _make_async_session()

        async with atomic(mock_session):
            pass

        mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_rollback_on_failure(self):
        """Test atomic rolls back transaction on failure."""
        mock_session = _make_async_session()

        with pytest.raises(RuntimeError):
            async with atomic(mock_session):
                raise RuntimeError("Operation failed")

        mock_session.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_nested_propagation(self):
        """Test atomic with NESTED propagation creates savepoint."""
        # Deliberately minimal mock: only begin_nested is configured because
        # an already-open transaction should lead straight to a savepoint.
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = True
        mock_nested = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)

        async with atomic(
            mock_session, propagation=TransactionPropagation.NESTED
        ):
            pass

        mock_session.begin_nested.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_required_propagation_default(self):
        """Test atomic uses REQUIRED propagation by default."""
        mock_session = _make_async_session()

        async with atomic(mock_session):
            # Should start new transaction
            mock_session.begin.assert_called_once()

        # ...and, as the owner of that transaction, commit it on exit.
        mock_session.commit.assert_called_once()


# ============================================================================
# @transactional Decorator Tests
# ============================================================================


class TestTransactionalDecorator:
    """Tests for @transactional decorator."""

    @pytest.mark.asyncio
    async def test_async_function_wrapped(self):
        """Test async function is wrapped in transaction."""
        mock_session = _make_async_session()

        @transactional()
        async def sample_operation(db: AsyncSession):
            return "result"

        result = await sample_operation(db=mock_session)

        assert result == "result"
        mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_rollback_on_error(self):
        """Test async function rollback on error."""
        mock_session = _make_async_session()

        @transactional()
        async def failing_operation(db: AsyncSession):
            raise ValueError("Operation failed")

        with pytest.raises(ValueError):
            await failing_operation(db=mock_session)

        mock_session.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_custom_session_param_name(self):
        """Test decorator with custom session parameter name."""
        mock_session = _make_async_session()

        @transactional(session_param="session")
        async def operation_with_session(session: AsyncSession):
            return "done"

        result = await operation_with_session(session=mock_session)

        assert result == "done"
        mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_missing_session_raises_error(self):
        """Test error raised when session parameter not found."""

        @transactional()
        async def operation_no_session(data: dict):
            return data

        with pytest.raises(TransactionError):
            await operation_no_session(data={"key": "value"})

    @pytest.mark.asyncio
    async def test_propagation_passed_to_atomic(self):
        """Test propagation mode is passed to atomic."""
        # Deliberately minimal mock: only begin_nested is configured.
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = True
        mock_nested = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)

        @transactional(propagation=TransactionPropagation.NESTED)
        async def nested_operation(db: AsyncSession):
            return "nested"

        result = await nested_operation(db=mock_session)

        assert result == "nested"
        mock_session.begin_nested.assert_called_once()


# ============================================================================
# Sync transactional decorator Tests
# ============================================================================


class TestSyncTransactionalDecorator:
    """Tests for @transactional decorator with sync functions."""

    def test_sync_function_wrapped(self):
        """Test sync function is wrapped in transaction."""
        mock_session = _make_sync_session()

        @transactional()
        def sample_sync_operation(db: Session):
            return "sync_result"

        result = sample_sync_operation(db=mock_session)

        assert result == "sync_result"
        mock_session.commit.assert_called_once()

    def test_sync_rollback_on_error(self):
        """Test sync function rollback on error."""
        mock_session = _make_sync_session()

        @transactional()
        def failing_sync_operation(db: Session):
            raise ValueError("Sync operation failed")

        with pytest.raises(ValueError):
            failing_sync_operation(db=mock_session)

        mock_session.rollback.assert_called_once()


# ============================================================================
# Helper Function Tests
# ============================================================================


class TestHelperFunctions:
    """Tests for transaction helper functions."""

    def test_is_in_transaction_true(self):
        """Test is_in_transaction returns True when in transaction."""
        mock_session = MagicMock()
        mock_session.in_transaction.return_value = True

        assert is_in_transaction(mock_session) is True

    def test_is_in_transaction_false(self):
        """Test is_in_transaction returns False when not in transaction."""
        mock_session = MagicMock()
        mock_session.in_transaction.return_value = False

        assert is_in_transaction(mock_session) is False


# ============================================================================
# Integration Tests with Real Database
# ============================================================================


class TestTransactionIntegration:
    """Integration tests using real in-memory database."""

    @pytest.mark.asyncio
    async def test_real_transaction_commit(self, async_session):
        """Test actual transaction commit with real session."""
        from sqlalchemy import select

        from src.server.database.models import AnimeSeries

        async with atomic(async_session):
            series = AnimeSeries(
                key="test-series",
                name="Test Series",
                site="https://test.com",
                folder="/test/folder",
            )
            async_session.add(series)

        # Verify data persisted
        result = await async_session.execute(
            select(AnimeSeries).where(AnimeSeries.key == "test-series")
        )
        saved_series = result.scalar_one_or_none()
        assert saved_series is not None
        assert saved_series.name == "Test Series"

    @pytest.mark.asyncio
    async def test_real_transaction_rollback(self, async_session):
        """Test actual transaction rollback with real session."""
        from sqlalchemy import select

        from src.server.database.models import AnimeSeries

        # pytest.raises (instead of try/except/pass) also fails the test if
        # atomic() swallows the error rather than re-raising it.
        with pytest.raises(ValueError, match="Simulated error"):
            async with atomic(async_session):
                series = AnimeSeries(
                    key="rollback-series",
                    name="Rollback Series",
                    site="https://test.com",
                    folder="/test/folder",
                )
                async_session.add(series)
                # Flush so the row really reaches the database before rollback.
                await async_session.flush()
                # Force rollback
                raise ValueError("Simulated error")

        # Verify data was NOT persisted
        result = await async_session.execute(
            select(AnimeSeries).where(AnimeSeries.key == "rollback-series")
        )
        saved_series = result.scalar_one_or_none()
        assert saved_series is None


# ============================================================================
# TransactionPropagation Tests
# ============================================================================
class TestTransactionPropagation:
    """Tests for transaction propagation modes."""

    def test_propagation_enum_values(self):
        """Test propagation enum has correct values."""
        assert TransactionPropagation.REQUIRED.value == "required"
        assert TransactionPropagation.REQUIRES_NEW.value == "requires_new"
        assert TransactionPropagation.NESTED.value == "nested"


# ============================================================================
# Additional Coverage Tests
# ============================================================================


class TestSyncSavepointCoverage:
    """Additional tests for sync savepoint coverage."""

    def test_savepoint_exception_rolls_back(self):
        """Test savepoint rollback when exception occurs within savepoint."""
        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = False
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested

        with TransactionContext(mock_session) as ctx:
            with pytest.raises(ValueError):
                with ctx.savepoint():
                    raise ValueError("Error in savepoint")

            # Nested transaction should have been rolled back
            mock_nested.rollback.assert_called_once()

    def test_savepoint_commit_explicit(self):
        """Test explicit commit on savepoint."""
        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = False
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested

        with TransactionContext(mock_session) as ctx:
            with ctx.savepoint() as sp:
                sp.commit()

            # Commit should just log, SQLAlchemy handles actual commit; the
            # savepoint must still have been opened and not rolled back.
            mock_session.begin_nested.assert_called_once()
            mock_nested.rollback.assert_not_called()


class TestAsyncSavepointCoverage:
    """Additional tests for async savepoint coverage."""

    @pytest.mark.asyncio
    async def test_async_savepoint_exception_rolls_back(self):
        """Test async savepoint rollback when exception occurs."""
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_session.rollback = AsyncMock()
        mock_nested = AsyncMock()
        mock_nested.rollback = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)

        async with AsyncTransactionContext(mock_session) as ctx:
            with pytest.raises(ValueError):
                async with ctx.savepoint():
                    raise ValueError("Error in async savepoint")

            # Nested transaction should have been rolled back
            mock_nested.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_savepoint_commit_explicit(self):
        """Test explicit commit on async savepoint."""
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_nested = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)

        async with AsyncTransactionContext(mock_session) as ctx:
            async with ctx.savepoint() as sp:
                await sp.commit()

            # Commit should just log, SQLAlchemy handles actual commit; the
            # savepoint must still have been opened and not rolled back.
            mock_session.begin_nested.assert_called_once()
            mock_nested.rollback.assert_not_called()


class TestAtomicNestedPropagationNoTransaction:
    """Tests for NESTED propagation when not in transaction."""

    @pytest.mark.asyncio
    async def test_async_nested_starts_new_when_not_in_transaction(self):
        """Test NESTED propagation starts new transaction when none exists."""
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_session.rollback = AsyncMock()

        async with atomic(mock_session, TransactionPropagation.NESTED):
            # Should start new transaction since none exists
            pass

        mock_session.begin.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_sync_nested_starts_new_when_not_in_transaction(self):
        """Test sync NESTED propagation starts new transaction when none exists."""
        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = False

        with atomic_sync(mock_session, TransactionPropagation.NESTED):
            pass

        mock_session.begin.assert_called_once()
        mock_session.commit.assert_called_once()


class TestGetTransactionDepth:
    """Tests for get_transaction_depth helper."""

    def test_depth_zero_when_not_in_transaction(self):
        """Test depth is 0 when not in transaction."""
        from src.server.database.transaction import get_transaction_depth

        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = False

        depth = get_transaction_depth(mock_session)

        assert depth == 0

    def test_depth_one_in_transaction(self):
        """Test depth is 1 in basic transaction."""
        from src.server.database.transaction import get_transaction_depth

        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = True
        mock_session._nested_transaction = None

        depth = get_transaction_depth(mock_session)

        assert depth == 1

    def test_depth_two_with_nested_transaction(self):
        """Test depth is 2 with nested transaction."""
        from src.server.database.transaction import get_transaction_depth

        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = True
        mock_session._nested_transaction = MagicMock()  # Has nested

        depth = get_transaction_depth(mock_session)

        assert depth == 2


class TestTransactionalDecoratorPositionalArgs:
    """Tests for transactional decorator with positional arguments."""

    @pytest.mark.asyncio
    async def test_session_from_positional_arg(self):
        """Test decorator extracts session from positional argument."""
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_session.rollback = AsyncMock()

        @transactional()
        async def operation(db: AsyncSession, data: str):
            return f"processed: {data}"

        # Pass session as positional argument
        result = await operation(mock_session, "test")

        assert result == "processed: test"
        mock_session.commit.assert_called_once()

    def test_sync_session_from_positional_arg(self):
        """Test sync decorator extracts session from positional argument."""
        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = False

        @transactional()
        def operation(db: Session, data: str):
            return f"processed: {data}"

        result = operation(mock_session, "test")

        assert result == "processed: test"
        mock_session.commit.assert_called_once()