"""Unit tests for database transaction management.

This module tests transaction contexts, savepoints, decorators, and
transaction propagation behavior to ensure data consistency.

Coverage Target: 90%+
"""

import asyncio
import logging
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from src.server.database.transaction import (
    AsyncSavepointContext,
    AsyncTransactionContext,
    SavepointContext,
    TransactionContext,
    TransactionError,
    TransactionPropagation,
    _extract_session,
    atomic,
    atomic_sync,
    get_transaction_depth,
    is_in_transaction,
    transactional,
)


# ---------------------------------------------------------------------------
# Shared fixtures
#
# Defined once at module level so every test class gets an identically
# configured mock instead of carrying its own copy of the same fixture.
# ---------------------------------------------------------------------------


@pytest.fixture
def mock_session():
    """Create a mock synchronous Session that reports no active transaction."""
    session = MagicMock(spec=Session)
    session.in_transaction.return_value = False
    return session


@pytest.fixture
def mock_async_session():
    """Create a mock AsyncSession that reports no active transaction.

    ``in_transaction`` is a plain ``MagicMock`` because it is called
    without ``await``; the transaction-control methods are ``AsyncMock``
    instances because they are awaited.
    """
    session = MagicMock(spec=AsyncSession)
    session.in_transaction = MagicMock(return_value=False)
    session.begin = AsyncMock()
    session.commit = AsyncMock()
    session.rollback = AsyncMock()
    session.begin_nested = AsyncMock()
    return session


class TestTransactionPropagation:
    """Test TransactionPropagation enum."""

    def test_propagation_values(self):
        """Test enum values are defined correctly."""
        assert TransactionPropagation.REQUIRED.value == "required"
        assert TransactionPropagation.REQUIRES_NEW.value == "requires_new"
        assert TransactionPropagation.NESTED.value == "nested"

    def test_propagation_members(self):
        """Test all expected members exist."""
        members = [e.name for e in TransactionPropagation]
        assert "REQUIRED" in members
        assert "REQUIRES_NEW" in members
        assert "NESTED" in members


class TestTransactionContext:
    """Test synchronous TransactionContext."""

    def test_context_enter_starts_transaction(self, mock_session):
        """Test entering context starts transaction."""
        with TransactionContext(mock_session) as tx:
            assert tx.session == mock_session
            mock_session.begin.assert_called_once()

    def test_context_enter_existing_transaction(self, mock_session):
        """Test entering context with existing transaction."""
        mock_session.in_transaction.return_value = True
        with TransactionContext(mock_session):
            mock_session.begin.assert_not_called()

    def test_context_exit_commits_on_success(self, mock_session):
        """Test exiting context commits on success."""
        with TransactionContext(mock_session):
            pass
        mock_session.commit.assert_called_once()
        mock_session.rollback.assert_not_called()

    def test_context_exit_rollback_on_exception(self, mock_session):
        """Test exiting context rolls back on exception."""
        with pytest.raises(ValueError):
            with TransactionContext(mock_session):
                raise ValueError("Test error")
        mock_session.rollback.assert_called_once()
        mock_session.commit.assert_not_called()

    def test_explicit_commit(self, mock_session):
        """Test explicit commit within context."""
        with TransactionContext(mock_session) as tx:
            tx.commit()
            mock_session.commit.assert_called_once()
        # Should not commit again on exit
        assert mock_session.commit.call_count == 1

    def test_explicit_rollback(self, mock_session):
        """Test explicit rollback within context."""
        with TransactionContext(mock_session) as tx:
            tx.rollback()
            mock_session.rollback.assert_called_once()
        # Should not commit on exit after rollback
        mock_session.commit.assert_not_called()

    def test_savepoint_creation(self, mock_session):
        """Test savepoint creation."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested
        with TransactionContext(mock_session) as tx:
            with tx.savepoint() as sp:
                assert isinstance(sp, SavepointContext)
                mock_session.begin_nested.assert_called_once()

    def test_savepoint_with_custom_name(self, mock_session):
        """Test savepoint with custom name."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested
        with TransactionContext(mock_session) as tx:
            with tx.savepoint("custom_sp") as sp:
                assert sp._name == "custom_sp"

    def test_savepoint_auto_naming(self, mock_session):
        """Test savepoint automatic naming."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested
        with TransactionContext(mock_session) as tx:
            with tx.savepoint() as sp1:
                assert sp1._name == "sp_1"
            with tx.savepoint() as sp2:
                assert sp2._name == "sp_2"

    def test_savepoint_rollback_on_exception(self, mock_session):
        """Test savepoint rolls back on exception."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested
        with pytest.raises(ValueError):
            with TransactionContext(mock_session) as tx:
                with tx.savepoint():
                    raise ValueError("Savepoint error")
        mock_nested.rollback.assert_called_once()

    def test_multiple_savepoints(self, mock_session):
        """Test multiple nested savepoints."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested
        with TransactionContext(mock_session) as tx:
            with tx.savepoint("outer"):
                with tx.savepoint("inner"):
                    pass
        # Both savepoints should be created
        assert mock_session.begin_nested.call_count == 2


class TestSavepointContext:
    """Test SavepointContext."""

    def test_rollback(self):
        """Test savepoint rollback."""
        mock_nested = MagicMock()
        sp = SavepointContext(mock_nested, "test_sp")
        sp.rollback()
        mock_nested.rollback.assert_called_once()
        assert sp._rolled_back is True

    def test_rollback_idempotent(self):
        """Test rollback can only happen once."""
        mock_nested = MagicMock()
        sp = SavepointContext(mock_nested, "test_sp")
        sp.rollback()
        sp.rollback()
        # Second call should not rollback again
        mock_nested.rollback.assert_called_once()

    def test_commit_logs_only(self):
        """Test commit method logs but doesn't call nested."""
        mock_nested = MagicMock()
        sp = SavepointContext(mock_nested, "test_sp")
        sp.commit()
        # Commit doesn't call nested methods (auto-handled by SQLAlchemy)
        assert not sp._rolled_back
        # Check the mock itself too, not only the context's internal flag.
        mock_nested.rollback.assert_not_called()


class TestAsyncTransactionContext:
    """Test asynchronous AsyncTransactionContext."""

    @pytest.mark.asyncio
    async def test_async_context_enter_starts_transaction(self, mock_async_session):
        """Test entering async context starts transaction."""
        async with AsyncTransactionContext(mock_async_session) as tx:
            assert tx.session == mock_async_session
            mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_context_enter_existing_transaction(
        self, mock_async_session
    ):
        """Test entering async context with existing transaction."""
        mock_async_session.in_transaction.return_value = True
        async with AsyncTransactionContext(mock_async_session):
            mock_async_session.begin.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_context_exit_commits_on_success(self, mock_async_session):
        """Test exiting async context commits on success."""
        async with AsyncTransactionContext(mock_async_session):
            pass
        mock_async_session.commit.assert_called_once()
        mock_async_session.rollback.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_context_exit_rollback_on_exception(
        self, mock_async_session
    ):
        """Test exiting async context rolls back on exception."""
        with pytest.raises(ValueError):
            async with AsyncTransactionContext(mock_async_session):
                raise ValueError("Async test error")
        mock_async_session.rollback.assert_called_once()
        mock_async_session.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_explicit_commit(self, mock_async_session):
        """Test explicit commit within async context."""
        async with AsyncTransactionContext(mock_async_session) as tx:
            await tx.commit()
            mock_async_session.commit.assert_called_once()
        # Should not commit again on exit
        assert mock_async_session.commit.call_count == 1

    @pytest.mark.asyncio
    async def test_async_explicit_rollback(self, mock_async_session):
        """Test explicit rollback within async context."""
        async with AsyncTransactionContext(mock_async_session) as tx:
            await tx.rollback()
            mock_async_session.rollback.assert_called_once()
        # Should not commit on exit after rollback
        mock_async_session.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_savepoint_creation(self, mock_async_session):
        """Test async savepoint creation."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested
        async with AsyncTransactionContext(mock_async_session) as tx:
            async with tx.savepoint() as sp:
                assert isinstance(sp, AsyncSavepointContext)
                mock_async_session.begin_nested.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_savepoint_with_custom_name(self, mock_async_session):
        """Test async savepoint with custom name."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested
        async with AsyncTransactionContext(mock_async_session) as tx:
            async with tx.savepoint("async_custom_sp") as sp:
                assert sp._name == "async_custom_sp"

    @pytest.mark.asyncio
    async def test_async_savepoint_rollback_on_exception(self, mock_async_session):
        """Test async savepoint rolls back on exception."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested
        with pytest.raises(ValueError):
            async with AsyncTransactionContext(mock_async_session) as tx:
                async with tx.savepoint():
                    raise ValueError("Async savepoint error")
        mock_nested.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_multiple_savepoints(self, mock_async_session):
        """Test multiple nested async savepoints."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested
        async with AsyncTransactionContext(mock_async_session) as tx:
            async with tx.savepoint("outer"):
                async with tx.savepoint("inner"):
                    pass
        # Both savepoints should be created
        assert mock_async_session.begin_nested.call_count == 2


class TestAsyncSavepointContext:
    """Test AsyncSavepointContext."""

    @pytest.mark.asyncio
    async def test_async_rollback(self):
        """Test async savepoint rollback."""
        mock_nested = AsyncMock()
        session = AsyncMock(spec=AsyncSession)
        sp = AsyncSavepointContext(mock_nested, "async_test_sp", session)
        await sp.rollback()
        mock_nested.rollback.assert_called_once()
        assert sp._rolled_back is True

    @pytest.mark.asyncio
    async def test_async_rollback_idempotent(self):
        """Test async rollback can only happen once."""
        mock_nested = AsyncMock()
        session = AsyncMock(spec=AsyncSession)
        sp = AsyncSavepointContext(mock_nested, "async_test_sp", session)
        await sp.rollback()
        await sp.rollback()
        # Second call should not rollback again
        mock_nested.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_commit_logs_only(self):
        """Test async commit method logs but doesn't call nested."""
        mock_nested = AsyncMock()
        session = AsyncMock(spec=AsyncSession)
        sp = AsyncSavepointContext(mock_nested, "async_test_sp", session)
        await sp.commit()
        # Commit doesn't call nested methods (auto-handled by SQLAlchemy)
        assert not sp._rolled_back
        # Check the mock itself too, not only the context's internal flag.
        mock_nested.rollback.assert_not_called()


class TestAtomicContextManager:
    """Test atomic() async context manager."""

    @pytest.mark.asyncio
    async def test_atomic_required_propagation(self, mock_async_session):
        """Test atomic with REQUIRED propagation."""
        async with atomic(
            mock_async_session, TransactionPropagation.REQUIRED
        ) as tx:
            assert isinstance(tx, AsyncTransactionContext)
            mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_default_propagation(self, mock_async_session):
        """Test atomic uses REQUIRED as default."""
        async with atomic(mock_async_session) as tx:
            assert isinstance(tx, AsyncTransactionContext)
            mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_nested_with_existing_transaction(
        self, mock_async_session
    ):
        """Test atomic with NESTED propagation creates savepoint."""
        mock_async_session.in_transaction.return_value = True
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested
        async with atomic(mock_async_session, TransactionPropagation.NESTED) as tx:
            assert isinstance(tx, AsyncTransactionContext)
            mock_async_session.begin_nested.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_nested_without_transaction(self, mock_async_session):
        """Test atomic with NESTED but no existing transaction starts new one."""
        mock_async_session.in_transaction.return_value = False
        async with atomic(mock_async_session, TransactionPropagation.NESTED) as tx:
            assert isinstance(tx, AsyncTransactionContext)
            mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_commits_on_success(self, mock_async_session):
        """Test atomic commits on successful completion."""
        async with atomic(mock_async_session):
            pass
        mock_async_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_rollback_on_exception(self, mock_async_session):
        """Test atomic rolls back on exception."""
        with pytest.raises(RuntimeError):
            async with atomic(mock_async_session):
                raise RuntimeError("Atomic error")
        mock_async_session.rollback.assert_called_once()


class TestAtomicSyncContextManager:
    """Test atomic_sync() context manager."""

    def test_atomic_sync_required_propagation(self, mock_session):
        """Test atomic_sync with REQUIRED propagation."""
        with atomic_sync(mock_session, TransactionPropagation.REQUIRED) as tx:
            assert isinstance(tx, TransactionContext)
            mock_session.begin.assert_called_once()

    def test_atomic_sync_default_propagation(self, mock_session):
        """Test atomic_sync uses REQUIRED as default."""
        with atomic_sync(mock_session) as tx:
            assert isinstance(tx, TransactionContext)
            mock_session.begin.assert_called_once()

    def test_atomic_sync_nested_with_existing_transaction(self, mock_session):
        """Test atomic_sync with NESTED propagation creates savepoint."""
        mock_session.in_transaction.return_value = True
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested
        with atomic_sync(mock_session, TransactionPropagation.NESTED) as tx:
            assert isinstance(tx, TransactionContext)
            mock_session.begin_nested.assert_called_once()

    def test_atomic_sync_commits_on_success(self, mock_session):
        """Test atomic_sync commits on successful completion."""
        with atomic_sync(mock_session):
            pass
        mock_session.commit.assert_called_once()

    def test_atomic_sync_rollback_on_exception(self, mock_session):
        """Test atomic_sync rolls back on exception."""
        with pytest.raises(RuntimeError):
            with atomic_sync(mock_session):
                raise RuntimeError("Sync atomic error")
        mock_session.rollback.assert_called_once()


class TestTransactionalDecorator:
    """Test @transactional decorator."""

    @pytest.mark.asyncio
    async def test_transactional_async_function(self, mock_async_session):
        """Test @transactional decorator on async function."""

        @transactional()
        async def test_func(db: AsyncSession, value: int) -> int:
            return value * 2

        result = await test_func(db=mock_async_session, value=21)

        assert result == 42
        mock_async_session.begin.assert_called_once()
        mock_async_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_transactional_async_with_custom_param_name(
        self, mock_async_session
    ):
        """Test @transactional with custom session parameter name."""

        @transactional(session_param="session")
        async def test_func(session: AsyncSession, value: int) -> int:
            return value * 3

        result = await test_func(session=mock_async_session, value=10)

        assert result == 30
        mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_transactional_async_propagation(self, mock_async_session):
        """Test @transactional with different propagation."""

        @transactional(propagation=TransactionPropagation.NESTED)
        async def test_func(db: AsyncSession) -> str:
            return "nested"

        mock_async_session.in_transaction.return_value = False
        result = await test_func(db=mock_async_session)

        assert result == "nested"
        mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_transactional_async_rollback_on_exception(
        self, mock_async_session
    ):
        """Test @transactional rolls back on exception."""

        @transactional()
        async def test_func(db: AsyncSession) -> None:
            raise ValueError("Test exception")

        with pytest.raises(ValueError):
            await test_func(db=mock_async_session)

        mock_async_session.rollback.assert_called_once()
        mock_async_session.commit.assert_not_called()

    def test_transactional_sync_function(self, mock_session):
        """Test @transactional decorator on sync function."""

        @transactional()
        def test_func(db: Session, value: int) -> int:
            return value * 2

        result = test_func(db=mock_session, value=21)

        assert result == 42
        mock_session.begin.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_transactional_sync_rollback_on_exception(self, mock_session):
        """Test @transactional rolls back sync function on exception."""

        @transactional()
        def test_func(db: Session) -> None:
            raise ValueError("Sync test exception")

        with pytest.raises(ValueError):
            test_func(db=mock_session)

        mock_session.rollback.assert_called_once()
        mock_session.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_transactional_missing_session_param(self):
        """Test @transactional raises error when session param not found."""

        @transactional(session_param="db")
        async def test_func(value: int) -> int:
            return value

        with pytest.raises(
            TransactionError, match="Could not find session parameter"
        ):
            await test_func(value=42)

    @pytest.mark.asyncio
    async def test_transactional_positional_session(self, mock_async_session):
        """Test @transactional with session passed as positional arg."""

        @transactional()
        async def test_func(db: AsyncSession, value: int) -> int:
            return value * 2

        result = await test_func(mock_async_session, 21)

        assert result == 42
        mock_async_session.begin.assert_called_once()


class TestExtractSession:
    """Test _extract_session helper function."""

    def test_extract_from_kwargs(self):
        """Test extracting session from kwargs."""
        session = MagicMock()

        def test_func(db: Session, value: int):
            pass

        result = _extract_session(test_func, (), {"db": session, "value": 10}, "db")
        assert result is session

    def test_extract_from_positional_args(self):
        """Test extracting session from positional args."""
        session = MagicMock()

        def test_func(db: Session, value: int):
            pass

        result = _extract_session(test_func, (session, 10), {}, "db")
        assert result is session

    def test_extract_returns_none_when_not_found(self):
        """Test returns None when session not found."""

        def test_func(value: int):
            pass

        result = _extract_session(test_func, (10,), {}, "db")
        assert result is None

    def test_extract_with_self_parameter(self):
        """Test extracting session with self parameter (method)."""
        session = MagicMock()

        class TestClass:
            def test_method(self, db: Session, value: int):
                pass

        obj = TestClass()
        # When calling a bound method, 'self' is not in args
        result = _extract_session(
            obj.test_method,
            (session, 10),  # No 'self' in args for bound methods
            {},
            "db",
        )
        assert result is session


class TestUtilityFunctions:
    """Test utility functions."""

    def test_is_in_transaction_true(self):
        """Test is_in_transaction returns True."""
        session = MagicMock(spec=Session)
        session.in_transaction.return_value = True
        assert is_in_transaction(session) is True

    def test_is_in_transaction_false(self):
        """Test is_in_transaction returns False."""
        session = MagicMock(spec=Session)
        session.in_transaction.return_value = False
        assert is_in_transaction(session) is False

    def test_get_transaction_depth_zero(self):
        """Test get_transaction_depth returns 0 when not in transaction."""
        session = MagicMock(spec=Session)
        session.in_transaction.return_value = False
        assert get_transaction_depth(session) == 0

    def test_get_transaction_depth_one(self):
        """Test get_transaction_depth returns 1 in transaction."""
        session = MagicMock(spec=Session)
        session.in_transaction.return_value = True
        session._nested_transaction = None
        assert get_transaction_depth(session) == 1

    def test_get_transaction_depth_nested(self):
        """Test get_transaction_depth returns 2 with savepoint."""
        session = MagicMock(spec=Session)
        session.in_transaction.return_value = True
        session._nested_transaction = MagicMock()
        assert get_transaction_depth(session) == 2


class TestTransactionLogging:
    """Test transaction logging behavior."""

    @pytest.mark.asyncio
    async def test_logging_on_commit(self, mock_async_session, caplog):
        """Test logging when transaction commits."""
        with caplog.at_level(logging.DEBUG):
            async with AsyncTransactionContext(mock_async_session):
                pass

        assert "Entering async transaction context" in caplog.text
        assert "Async transaction committed" in caplog.text

    @pytest.mark.asyncio
    async def test_logging_on_rollback(self, mock_async_session, caplog):
        """Test logging when transaction rolls back."""
        with caplog.at_level(logging.WARNING):
            with pytest.raises(ValueError):
                async with AsyncTransactionContext(mock_async_session):
                    raise ValueError("Test error")

        assert "Async transaction rollback due to exception" in caplog.text

    @pytest.mark.asyncio
    async def test_logging_savepoint_creation(self, mock_async_session, caplog):
        """Test logging when savepoint is created."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested

        with caplog.at_level(logging.DEBUG):
            async with AsyncTransactionContext(mock_async_session) as tx:
                async with tx.savepoint("test_sp"):
                    pass

        assert "Creating async savepoint: test_sp" in caplog.text


class TestTransactionError:
    """Test TransactionError exception."""

    def test_transaction_error_message(self):
        """Test TransactionError can be raised with message."""
        with pytest.raises(TransactionError, match="Custom error"):
            raise TransactionError("Custom error")

    def test_transaction_error_inheritance(self):
        """Test TransactionError inherits from Exception."""
        assert issubclass(TransactionError, Exception)


class TestComplexScenarios:
    """Test complex transaction scenarios."""

    @pytest.mark.asyncio
    async def test_nested_transactions_with_partial_rollback(
        self, mock_async_session
    ):
        """Test nested transactions with partial rollback."""
        mock_async_session.in_transaction.return_value = True
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested

        async with AsyncTransactionContext(mock_async_session) as tx:
            async with tx.savepoint("outer"):
                # Simulate outer operation success
                pass
            # Outer savepoint should not roll back
            mock_nested.rollback.assert_not_called()

        # Transaction should commit
        mock_async_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_multiple_operations_in_transaction(self, mock_async_session):
        """Test multiple operations within single transaction."""

        @transactional()
        async def operation_a(db: AsyncSession) -> str:
            return "A"

        @transactional()
        async def operation_b(db: AsyncSession) -> str:
            return "B"

        async with atomic(mock_async_session):
            result_a = await operation_a(db=mock_async_session)
            result_b = await operation_b(db=mock_async_session)

        assert result_a == "A"
        assert result_b == "B"

    @pytest.mark.asyncio
    async def test_explicit_savepoint_rollback(self, mock_async_session):
        """Test explicit savepoint rollback."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested

        async with AsyncTransactionContext(mock_async_session) as tx:
            async with tx.savepoint() as sp:
                await sp.rollback()
            # Verify rollback was called
            mock_nested.rollback.assert_called_once()

        # Transaction should still commit (only savepoint rolled back)
        mock_async_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_decorator_with_nested_calls(self, mock_async_session):
        """Test @transactional decorator with nested function calls."""
        mock_async_session.in_transaction.return_value = True

        @transactional()
        async def inner_operation(db: AsyncSession, value: int) -> int:
            return value * 2

        @transactional()
        async def outer_operation(db: AsyncSession, value: int) -> int:
            result = await inner_operation(db, value)
            return result + 10

        result = await outer_operation(db=mock_async_session, value=5)

        assert result == 20  # (5 * 2) + 10