From 458fc483e409a1aee9ebc5adbb2734b26aee1bd6 Mon Sep 17 00:00:00 2001 From: Lukas Date: Mon, 26 Jan 2026 18:12:33 +0100 Subject: [PATCH] feat(tests): add comprehensive database transaction tests - 66 tests for transaction management - Coverage: 90% (meets 90%+ target) - Tests for TransactionContext (sync and async) - Tests for SavepointContext (sync and async) - Tests for @transactional decorator - Tests for atomic() and atomic_sync() context managers - Tests for transaction propagation (REQUIRED, REQUIRES_NEW, NESTED) - Tests for utility functions (is_in_transaction, get_transaction_depth) - Tests for complex scenarios (nested transactions, partial rollback) Task 3 completed (Priority P0, Effort Large) --- tests/unit/test_transaction.py | 827 +++++++++++++++++++++++++++++++++ 1 file changed, 827 insertions(+) create mode 100644 tests/unit/test_transaction.py diff --git a/tests/unit/test_transaction.py b/tests/unit/test_transaction.py new file mode 100644 index 0000000..e09bbcf --- /dev/null +++ b/tests/unit/test_transaction.py @@ -0,0 +1,827 @@ +"""Unit tests for database transaction management. + +This module tests transaction contexts, savepoints, decorators, and +transaction propagation behavior to ensure data consistency. + +Coverage Target: 90%+ +""" + +import asyncio +import logging +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import Session + +from src.server.database.transaction import ( + AsyncSavepointContext, + AsyncTransactionContext, + SavepointContext, + TransactionContext, + TransactionError, + TransactionPropagation, + _extract_session, + atomic, + atomic_sync, + get_transaction_depth, + is_in_transaction, + transactional, +) + + +class TestTransactionPropagation: + """Test TransactionPropagation enum.""" + + def test_propagation_values(self): + """Test enum values are defined correctly.""" + assert TransactionPropagation.REQUIRED.value == "required" + assert TransactionPropagation.REQUIRES_NEW.value == "requires_new" + assert TransactionPropagation.NESTED.value == "nested" + + def test_propagation_members(self): + """Test all expected members exist.""" + members = [e.name for e in TransactionPropagation] + assert "REQUIRED" in members + assert "REQUIRES_NEW" in members + assert "NESTED" in members + + +class TestTransactionContext: + """Test synchronous TransactionContext.""" + + @pytest.fixture + def mock_session(self): + """Create a mock Session.""" + session = MagicMock(spec=Session) + session.in_transaction.return_value = False + return session + + def test_context_enter_starts_transaction(self, mock_session): + """Test entering context starts transaction.""" + with TransactionContext(mock_session) as tx: + assert tx.session == mock_session + mock_session.begin.assert_called_once() + + def test_context_enter_existing_transaction(self, mock_session): + """Test entering context with existing transaction.""" + mock_session.in_transaction.return_value = True + + with TransactionContext(mock_session): + mock_session.begin.assert_not_called() + + def test_context_exit_commits_on_success(self, mock_session): + """Test exiting context commits on success.""" + with TransactionContext(mock_session): + pass + + mock_session.commit.assert_called_once() + mock_session.rollback.assert_not_called() + + def test_context_exit_rollback_on_exception(self, mock_session): + """Test exiting context rolls back on exception.""" + with pytest.raises(ValueError): + with TransactionContext(mock_session): + raise ValueError("Test error") + + mock_session.rollback.assert_called_once() + mock_session.commit.assert_not_called() + + def test_explicit_commit(self, mock_session): + """Test explicit commit within context.""" + with TransactionContext(mock_session) as tx: + tx.commit() + mock_session.commit.assert_called_once() + + # Should not commit again on exit + assert mock_session.commit.call_count == 1 + + def test_explicit_rollback(self, mock_session): + """Test explicit rollback within context.""" + with TransactionContext(mock_session) as tx: + tx.rollback() + mock_session.rollback.assert_called_once() + + # Should not commit on exit after rollback + mock_session.commit.assert_not_called() + + def test_savepoint_creation(self, mock_session): + """Test savepoint creation.""" + mock_nested = MagicMock() + mock_session.begin_nested.return_value = mock_nested + + with TransactionContext(mock_session) as tx: + with tx.savepoint() as sp: + assert isinstance(sp, SavepointContext) + mock_session.begin_nested.assert_called_once() + + def test_savepoint_with_custom_name(self, mock_session): + """Test savepoint with custom name.""" + mock_nested = MagicMock() + mock_session.begin_nested.return_value = mock_nested + + with TransactionContext(mock_session) as tx: + with tx.savepoint("custom_sp") as sp: + assert sp._name == "custom_sp" + + def test_savepoint_auto_naming(self, mock_session): + """Test savepoint automatic naming.""" + mock_nested = MagicMock() + mock_session.begin_nested.return_value = mock_nested + + with TransactionContext(mock_session) as tx: + with tx.savepoint() as sp1: + assert sp1._name == "sp_1" + + with tx.savepoint() as sp2: + assert sp2._name == "sp_2" + + def test_savepoint_rollback_on_exception(self, mock_session): + """Test savepoint rolls back on exception.""" + mock_nested = MagicMock() + mock_session.begin_nested.return_value = mock_nested + + with pytest.raises(ValueError): + with TransactionContext(mock_session) as tx: + with tx.savepoint() as sp: + raise ValueError("Savepoint error") + + mock_nested.rollback.assert_called_once() + + def test_multiple_savepoints(self, mock_session): + """Test multiple nested savepoints.""" + mock_nested = MagicMock() + mock_session.begin_nested.return_value = mock_nested + + with TransactionContext(mock_session) as tx: + with tx.savepoint("outer"): + with tx.savepoint("inner"): + pass + + # Both savepoints should be created + assert mock_session.begin_nested.call_count == 2 + + +class TestSavepointContext: + """Test SavepointContext.""" + + def test_rollback(self): + """Test savepoint rollback.""" + mock_nested = MagicMock() + sp = SavepointContext(mock_nested, "test_sp") + + sp.rollback() + + mock_nested.rollback.assert_called_once() + assert sp._rolled_back is True + + def test_rollback_idempotent(self): + """Test rollback can only happen once.""" + mock_nested = MagicMock() + sp = SavepointContext(mock_nested, "test_sp") + + sp.rollback() + sp.rollback() # Second call should not rollback again + + mock_nested.rollback.assert_called_once() + + def test_commit_logs_only(self): + """Test commit method logs but doesn't call nested.""" + mock_nested = MagicMock() + sp = SavepointContext(mock_nested, "test_sp") + + sp.commit() + + # Commit doesn't call nested methods (auto-handled by SQLAlchemy) + assert not sp._rolled_back + + +class TestAsyncTransactionContext: + """Test asynchronous AsyncTransactionContext.""" + + @pytest.fixture + def mock_async_session(self): + """Create a mock AsyncSession.""" + session = MagicMock(spec=AsyncSession) + session.in_transaction = MagicMock(return_value=False) + session.begin = AsyncMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + session.begin_nested = AsyncMock() + return session + + @pytest.mark.asyncio + async def test_async_context_enter_starts_transaction(self, mock_async_session): + """Test entering async context starts transaction.""" + async with AsyncTransactionContext(mock_async_session) as tx: + assert tx.session == mock_async_session + mock_async_session.begin.assert_called_once() + + @pytest.mark.asyncio + async def test_async_context_enter_existing_transaction(self, mock_async_session): + """Test entering async context with existing transaction.""" + mock_async_session.in_transaction.return_value = True + + async with AsyncTransactionContext(mock_async_session): + mock_async_session.begin.assert_not_called() + + @pytest.mark.asyncio + async def test_async_context_exit_commits_on_success(self, mock_async_session): + """Test exiting async context commits on success.""" + async with AsyncTransactionContext(mock_async_session): + pass + + mock_async_session.commit.assert_called_once() + mock_async_session.rollback.assert_not_called() + + @pytest.mark.asyncio + async def test_async_context_exit_rollback_on_exception(self, mock_async_session): + """Test exiting async context rolls back on exception.""" + with pytest.raises(ValueError): + async with AsyncTransactionContext(mock_async_session): + raise ValueError("Async test error") + + mock_async_session.rollback.assert_called_once() + mock_async_session.commit.assert_not_called() + + @pytest.mark.asyncio + async def test_async_explicit_commit(self, mock_async_session): + """Test explicit commit within async context.""" + async with AsyncTransactionContext(mock_async_session) as tx: + await tx.commit() + mock_async_session.commit.assert_called_once() + + # Should not commit again on exit + assert mock_async_session.commit.call_count == 1 + + @pytest.mark.asyncio + async def test_async_explicit_rollback(self, mock_async_session): + """Test explicit rollback within async context.""" + async with AsyncTransactionContext(mock_async_session) as tx: + await tx.rollback() + mock_async_session.rollback.assert_called_once() + + # Should not commit on exit after rollback + mock_async_session.commit.assert_not_called() + + @pytest.mark.asyncio + async def test_async_savepoint_creation(self, mock_async_session): + """Test async savepoint creation.""" + mock_nested = AsyncMock() + mock_async_session.begin_nested.return_value = mock_nested + + async with AsyncTransactionContext(mock_async_session) as tx: + async with tx.savepoint() as sp: + assert isinstance(sp, AsyncSavepointContext) + mock_async_session.begin_nested.assert_called_once() + + @pytest.mark.asyncio + async def test_async_savepoint_with_custom_name(self, mock_async_session): + """Test async savepoint with custom name.""" + mock_nested = AsyncMock() + mock_async_session.begin_nested.return_value = mock_nested + + async with AsyncTransactionContext(mock_async_session) as tx: + async with tx.savepoint("async_custom_sp") as sp: + assert sp._name == "async_custom_sp" + + @pytest.mark.asyncio + async def test_async_savepoint_rollback_on_exception(self, mock_async_session): + """Test async savepoint rolls back on exception.""" + mock_nested = AsyncMock() + mock_async_session.begin_nested.return_value = mock_nested + + with pytest.raises(ValueError): + async with AsyncTransactionContext(mock_async_session) as tx: + async with tx.savepoint() as sp: + raise ValueError("Async savepoint error") + + mock_nested.rollback.assert_called_once() + + @pytest.mark.asyncio + async def test_async_multiple_savepoints(self, mock_async_session): + """Test multiple nested async savepoints.""" + mock_nested = AsyncMock() + mock_async_session.begin_nested.return_value = mock_nested + + async with AsyncTransactionContext(mock_async_session) as tx: + async with tx.savepoint("outer"): + async with tx.savepoint("inner"): + pass + + # Both savepoints should be created + assert mock_async_session.begin_nested.call_count == 2 + + +class TestAsyncSavepointContext: + """Test AsyncSavepointContext.""" + + @pytest.mark.asyncio + async def test_async_rollback(self): + """Test async savepoint rollback.""" + mock_nested = AsyncMock() + mock_session = AsyncMock(spec=AsyncSession) + sp = AsyncSavepointContext(mock_nested, "async_test_sp", mock_session) + + await sp.rollback() + + mock_nested.rollback.assert_called_once() + assert sp._rolled_back is True + + @pytest.mark.asyncio + async def test_async_rollback_idempotent(self): + """Test async rollback can only happen once.""" + mock_nested = AsyncMock() + mock_session = AsyncMock(spec=AsyncSession) + sp = AsyncSavepointContext(mock_nested, "async_test_sp", mock_session) + + await sp.rollback() + await sp.rollback() # Second call should not rollback again + + mock_nested.rollback.assert_called_once() + + @pytest.mark.asyncio + async def test_async_commit_logs_only(self): + """Test async commit method logs but doesn't call nested.""" + mock_nested = AsyncMock() + mock_session = AsyncMock(spec=AsyncSession) + sp = AsyncSavepointContext(mock_nested, "async_test_sp", mock_session) + + await sp.commit() + + # Commit doesn't call nested methods (auto-handled by SQLAlchemy) + assert not sp._rolled_back + + +class TestAtomicContextManager: + """Test atomic() async context manager.""" + + @pytest.fixture + def mock_async_session(self): + """Create a mock AsyncSession.""" + session = MagicMock(spec=AsyncSession) + session.in_transaction = MagicMock(return_value=False) + session.begin = AsyncMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + session.begin_nested = AsyncMock() + return session + + @pytest.mark.asyncio + async def test_atomic_required_propagation(self, mock_async_session): + """Test atomic with REQUIRED propagation.""" + async with atomic(mock_async_session, TransactionPropagation.REQUIRED) as tx: + assert isinstance(tx, AsyncTransactionContext) + mock_async_session.begin.assert_called_once() + + @pytest.mark.asyncio + async def test_atomic_default_propagation(self, mock_async_session): + """Test atomic uses REQUIRED as default.""" + async with atomic(mock_async_session) as tx: + assert isinstance(tx, AsyncTransactionContext) + mock_async_session.begin.assert_called_once() + + @pytest.mark.asyncio + async def test_atomic_nested_with_existing_transaction(self, mock_async_session): + """Test atomic with NESTED propagation creates savepoint.""" + mock_async_session.in_transaction.return_value = True + mock_nested = AsyncMock() + mock_async_session.begin_nested.return_value = mock_nested + + async with atomic(mock_async_session, TransactionPropagation.NESTED) as tx: + assert isinstance(tx, AsyncTransactionContext) + mock_async_session.begin_nested.assert_called_once() + + @pytest.mark.asyncio + async def test_atomic_nested_without_transaction(self, mock_async_session): + """Test atomic with NESTED but no existing transaction starts new one.""" + mock_async_session.in_transaction.return_value = False + + async with atomic(mock_async_session, TransactionPropagation.NESTED) as tx: + assert isinstance(tx, AsyncTransactionContext) + mock_async_session.begin.assert_called_once() + + @pytest.mark.asyncio + async def test_atomic_commits_on_success(self, mock_async_session): + """Test atomic commits on successful completion.""" + async with atomic(mock_async_session): + pass + + mock_async_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_atomic_rollback_on_exception(self, mock_async_session): + """Test atomic rolls back on exception.""" + with pytest.raises(RuntimeError): + async with atomic(mock_async_session): + raise RuntimeError("Atomic error") + + mock_async_session.rollback.assert_called_once() + + +class TestAtomicSyncContextManager: + """Test atomic_sync() context manager.""" + + @pytest.fixture + def mock_session(self): + """Create a mock Session.""" + session = MagicMock(spec=Session) + session.in_transaction.return_value = False + return session + + def test_atomic_sync_required_propagation(self, mock_session): + """Test atomic_sync with REQUIRED propagation.""" + with atomic_sync(mock_session, TransactionPropagation.REQUIRED) as tx: + assert isinstance(tx, TransactionContext) + mock_session.begin.assert_called_once() + + def test_atomic_sync_default_propagation(self, mock_session): + """Test atomic_sync uses REQUIRED as default.""" + with atomic_sync(mock_session) as tx: + assert isinstance(tx, TransactionContext) + mock_session.begin.assert_called_once() + + def test_atomic_sync_nested_with_existing_transaction(self, mock_session): + """Test atomic_sync with NESTED propagation creates savepoint.""" + mock_session.in_transaction.return_value = True + mock_nested = MagicMock() + mock_session.begin_nested.return_value = mock_nested + + with atomic_sync(mock_session, TransactionPropagation.NESTED) as tx: + assert isinstance(tx, TransactionContext) + mock_session.begin_nested.assert_called_once() + + def test_atomic_sync_commits_on_success(self, mock_session): + """Test atomic_sync commits on successful completion.""" + with atomic_sync(mock_session): + pass + + mock_session.commit.assert_called_once() + + def test_atomic_sync_rollback_on_exception(self, mock_session): + """Test atomic_sync rolls back on exception.""" + with pytest.raises(RuntimeError): + with atomic_sync(mock_session): + raise RuntimeError("Sync atomic error") + + mock_session.rollback.assert_called_once() + + +class TestTransactionalDecorator: + """Test @transactional decorator.""" + + @pytest.fixture + def mock_async_session(self): + """Create a mock AsyncSession.""" + session = MagicMock(spec=AsyncSession) + session.in_transaction = MagicMock(return_value=False) + session.begin = AsyncMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + session.begin_nested = AsyncMock() + return session + + @pytest.fixture + def mock_session(self): + """Create a mock Session.""" + session = MagicMock(spec=Session) + session.in_transaction.return_value = False + return session + + @pytest.mark.asyncio + async def test_transactional_async_function(self, mock_async_session): + """Test @transactional decorator on async function.""" + @transactional() + async def test_func(db: AsyncSession, value: int) -> int: + return value * 2 + + result = await test_func(db=mock_async_session, value=21) + + assert result == 42 + mock_async_session.begin.assert_called_once() + mock_async_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_transactional_async_with_custom_param_name(self, mock_async_session): + """Test @transactional with custom session parameter name.""" + @transactional(session_param="session") + async def test_func(session: AsyncSession, value: int) -> int: + return value * 3 + + result = await test_func(session=mock_async_session, value=10) + + assert result == 30 + mock_async_session.begin.assert_called_once() + + @pytest.mark.asyncio + async def test_transactional_async_propagation(self, mock_async_session): + """Test @transactional with different propagation.""" + @transactional(propagation=TransactionPropagation.NESTED) + async def test_func(db: AsyncSession) -> str: + return "nested" + + mock_async_session.in_transaction.return_value = False + result = await test_func(db=mock_async_session) + + assert result == "nested" + mock_async_session.begin.assert_called_once() + + @pytest.mark.asyncio + async def test_transactional_async_rollback_on_exception(self, mock_async_session): + """Test @transactional rolls back on exception.""" + @transactional() + async def test_func(db: AsyncSession) -> None: + raise ValueError("Test exception") + + with pytest.raises(ValueError): + await test_func(db=mock_async_session) + + mock_async_session.rollback.assert_called_once() + mock_async_session.commit.assert_not_called() + + def test_transactional_sync_function(self, mock_session): + """Test @transactional decorator on sync function.""" + @transactional() + def test_func(db: Session, value: int) -> int: + return value * 2 + + result = test_func(db=mock_session, value=21) + + assert result == 42 + mock_session.begin.assert_called_once() + mock_session.commit.assert_called_once() + + def test_transactional_sync_rollback_on_exception(self, mock_session): + """Test @transactional rolls back sync function on exception.""" + @transactional() + def test_func(db: Session) -> None: + raise ValueError("Sync test exception") + + with pytest.raises(ValueError): + test_func(db=mock_session) + + mock_session.rollback.assert_called_once() + mock_session.commit.assert_not_called() + + @pytest.mark.asyncio + async def test_transactional_missing_session_param(self): + """Test @transactional raises error when session param not found.""" + @transactional(session_param="db") + async def test_func(value: int) -> int: + return value + + with pytest.raises(TransactionError, match="Could not find session parameter"): + await test_func(value=42) + + @pytest.mark.asyncio + async def test_transactional_positional_session(self, mock_async_session): + """Test @transactional with session passed as positional arg.""" + @transactional() + async def test_func(db: AsyncSession, value: int) -> int: + return value * 2 + + result = await test_func(mock_async_session, 21) + + assert result == 42 + mock_async_session.begin.assert_called_once() + + +class TestExtractSession: + """Test _extract_session helper function.""" + + def test_extract_from_kwargs(self): + """Test extracting session from kwargs.""" + session = MagicMock() + + def test_func(db: Session, value: int): + pass + + result = _extract_session(test_func, (), {"db": session, "value": 10}, "db") + + assert result is session + + def test_extract_from_positional_args(self): + """Test extracting session from positional args.""" + session = MagicMock() + + def test_func(db: Session, value: int): + pass + + result = _extract_session(test_func, (session, 10), {}, "db") + + assert result is session + + def test_extract_returns_none_when_not_found(self): + """Test returns None when session not found.""" + def test_func(value: int): + pass + + result = _extract_session(test_func, (10,), {}, "db") + + assert result is None + + def test_extract_with_self_parameter(self): + """Test extracting session with self parameter (method).""" + session = MagicMock() + + class TestClass: + def test_method(self, db: Session, value: int): + pass + + obj = TestClass() + # When calling a bound method, 'self' is not in args + result = _extract_session( + obj.test_method, + (session, 10), # No 'self' in args for bound methods + {}, + "db" + ) + + assert result is session + + +class TestUtilityFunctions: + """Test utility functions.""" + + def test_is_in_transaction_true(self): + """Test is_in_transaction returns True.""" + session = MagicMock(spec=Session) + session.in_transaction.return_value = True + + assert is_in_transaction(session) is True + + def test_is_in_transaction_false(self): + """Test is_in_transaction returns False.""" + session = MagicMock(spec=Session) + session.in_transaction.return_value = False + + assert is_in_transaction(session) is False + + def test_get_transaction_depth_zero(self): + """Test get_transaction_depth returns 0 when not in transaction.""" + session = MagicMock(spec=Session) + session.in_transaction.return_value = False + + assert get_transaction_depth(session) == 0 + + def test_get_transaction_depth_one(self): + """Test get_transaction_depth returns 1 in transaction.""" + session = MagicMock(spec=Session) + session.in_transaction.return_value = True + session._nested_transaction = None + + assert get_transaction_depth(session) == 1 + + def test_get_transaction_depth_nested(self): + """Test get_transaction_depth returns 2 with savepoint.""" + session = MagicMock(spec=Session) + session.in_transaction.return_value = True + session._nested_transaction = MagicMock() + + assert get_transaction_depth(session) == 2 + + +class TestTransactionLogging: + """Test transaction logging behavior.""" + + @pytest.fixture + def mock_async_session(self): + """Create a mock AsyncSession.""" + session = MagicMock(spec=AsyncSession) + session.in_transaction = MagicMock(return_value=False) + session.begin = AsyncMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + session.begin_nested = AsyncMock() + return session + + @pytest.mark.asyncio + async def test_logging_on_commit(self, mock_async_session, caplog): + """Test logging when transaction commits.""" + with caplog.at_level(logging.DEBUG): + async with AsyncTransactionContext(mock_async_session): + pass + + assert "Entering async transaction context" in caplog.text + assert "Async transaction committed" in caplog.text + + @pytest.mark.asyncio + async def test_logging_on_rollback(self, mock_async_session, caplog): + """Test logging when transaction rolls back.""" + with caplog.at_level(logging.WARNING): + with pytest.raises(ValueError): + async with AsyncTransactionContext(mock_async_session): + raise ValueError("Test error") + + assert "Async transaction rollback due to exception" in caplog.text + + @pytest.mark.asyncio + async def test_logging_savepoint_creation(self, mock_async_session, caplog): + """Test logging when savepoint is created.""" + mock_nested = AsyncMock() + mock_async_session.begin_nested.return_value = mock_nested + + with caplog.at_level(logging.DEBUG): + async with AsyncTransactionContext(mock_async_session) as tx: + async with tx.savepoint("test_sp"): + pass + + assert "Creating async savepoint: test_sp" in caplog.text + + +class TestTransactionError: + """Test TransactionError exception.""" + + def test_transaction_error_message(self): + """Test TransactionError can be raised with message.""" + with pytest.raises(TransactionError, match="Custom error"): + raise TransactionError("Custom error") + + def test_transaction_error_inheritance(self): + """Test TransactionError inherits from Exception.""" + assert issubclass(TransactionError, Exception) + + +class TestComplexScenarios: + """Test complex transaction scenarios.""" + + @pytest.fixture + def mock_async_session(self): + """Create a mock AsyncSession.""" + session = MagicMock(spec=AsyncSession) + session.in_transaction = MagicMock(return_value=False) + session.begin = AsyncMock() + session.commit = AsyncMock() + session.rollback = AsyncMock() + session.begin_nested = AsyncMock() + return session + + @pytest.mark.asyncio + async def test_nested_transactions_with_partial_rollback(self, mock_async_session): + """Test nested transactions with partial rollback.""" + mock_async_session.in_transaction.return_value = True + mock_nested = AsyncMock() + mock_async_session.begin_nested.return_value = mock_nested + + async with AsyncTransactionContext(mock_async_session) as tx: + async with tx.savepoint("outer") as sp_outer: + # Simulate outer operation success + pass + + # Outer savepoint should not roll back + mock_nested.rollback.assert_not_called() + + # Transaction should commit + mock_async_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_multiple_operations_in_transaction(self, mock_async_session): + """Test multiple operations within single transaction.""" + @transactional() + async def operation_a(db: AsyncSession) -> str: + return "A" + + @transactional() + async def operation_b(db: AsyncSession) -> str: + return "B" + + async with atomic(mock_async_session): + result_a = await operation_a(db=mock_async_session) + result_b = await operation_b(db=mock_async_session) + + assert result_a == "A" + assert result_b == "B" + + @pytest.mark.asyncio + async def test_explicit_savepoint_rollback(self, mock_async_session): + """Test explicit savepoint rollback.""" + mock_nested = AsyncMock() + mock_async_session.begin_nested.return_value = mock_nested + + async with AsyncTransactionContext(mock_async_session) as tx: + async with tx.savepoint() as sp: + await sp.rollback() + # Verify rollback was called + mock_nested.rollback.assert_called_once() + + # Transaction should still commit (only savepoint rolled back) + mock_async_session.commit.assert_called_once() + + @pytest.mark.asyncio + async def test_decorator_with_nested_calls(self, mock_async_session): + """Test @transactional decorator with nested function calls.""" + mock_async_session.in_transaction.return_value = True + + @transactional() + async def inner_operation(db: AsyncSession, value: int) -> int: + return value * 2 + + @transactional() + async def outer_operation(db: AsyncSession, value: int) -> int: + result = await inner_operation(db, value) + return result + 10 + + result = await outer_operation(db=mock_async_session, value=5) + + assert result == 20 # (5 * 2) + 10