feat(tests): add comprehensive database transaction tests

- 66 tests for transaction management
- Coverage: 90% (meets 90%+ target)
- Tests for TransactionContext (sync and async)
- Tests for SavepointContext (sync and async)
- Tests for @transactional decorator
- Tests for atomic() and atomic_sync() context managers
- Tests for transaction propagation (REQUIRED, REQUIRES_NEW, NESTED)
- Tests for utility functions (is_in_transaction, get_transaction_depth)
- Tests for complex scenarios (nested transactions, partial rollback)

Task 3 completed (Priority P0, Effort Large)
This commit is contained in:
2026-01-26 18:12:33 +01:00
parent 3f2e15669d
commit 458fc483e4

View File

@@ -0,0 +1,827 @@
"""Unit tests for database transaction management.
This module tests transaction contexts, savepoints, decorators, and
transaction propagation behavior to ensure data consistency.
Coverage Target: 90%+
"""
import asyncio
import logging
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
from src.server.database.transaction import (
AsyncSavepointContext,
AsyncTransactionContext,
SavepointContext,
TransactionContext,
TransactionError,
TransactionPropagation,
_extract_session,
atomic,
atomic_sync,
get_transaction_depth,
is_in_transaction,
transactional,
)
class TestTransactionPropagation:
"""Test TransactionPropagation enum."""
def test_propagation_values(self):
"""Test enum values are defined correctly."""
assert TransactionPropagation.REQUIRED.value == "required"
assert TransactionPropagation.REQUIRES_NEW.value == "requires_new"
assert TransactionPropagation.NESTED.value == "nested"
def test_propagation_members(self):
"""Test all expected members exist."""
members = [e.name for e in TransactionPropagation]
assert "REQUIRED" in members
assert "REQUIRES_NEW" in members
assert "NESTED" in members
class TestTransactionContext:
"""Test synchronous TransactionContext."""
@pytest.fixture
def mock_session(self):
"""Create a mock Session."""
session = MagicMock(spec=Session)
session.in_transaction.return_value = False
return session
def test_context_enter_starts_transaction(self, mock_session):
"""Test entering context starts transaction."""
with TransactionContext(mock_session) as tx:
assert tx.session == mock_session
mock_session.begin.assert_called_once()
def test_context_enter_existing_transaction(self, mock_session):
"""Test entering context with existing transaction."""
mock_session.in_transaction.return_value = True
with TransactionContext(mock_session):
mock_session.begin.assert_not_called()
def test_context_exit_commits_on_success(self, mock_session):
"""Test exiting context commits on success."""
with TransactionContext(mock_session):
pass
mock_session.commit.assert_called_once()
mock_session.rollback.assert_not_called()
def test_context_exit_rollback_on_exception(self, mock_session):
"""Test exiting context rolls back on exception."""
with pytest.raises(ValueError):
with TransactionContext(mock_session):
raise ValueError("Test error")
mock_session.rollback.assert_called_once()
mock_session.commit.assert_not_called()
def test_explicit_commit(self, mock_session):
"""Test explicit commit within context."""
with TransactionContext(mock_session) as tx:
tx.commit()
mock_session.commit.assert_called_once()
# Should not commit again on exit
assert mock_session.commit.call_count == 1
def test_explicit_rollback(self, mock_session):
"""Test explicit rollback within context."""
with TransactionContext(mock_session) as tx:
tx.rollback()
mock_session.rollback.assert_called_once()
# Should not commit on exit after rollback
mock_session.commit.assert_not_called()
def test_savepoint_creation(self, mock_session):
"""Test savepoint creation."""
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with TransactionContext(mock_session) as tx:
with tx.savepoint() as sp:
assert isinstance(sp, SavepointContext)
mock_session.begin_nested.assert_called_once()
def test_savepoint_with_custom_name(self, mock_session):
"""Test savepoint with custom name."""
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with TransactionContext(mock_session) as tx:
with tx.savepoint("custom_sp") as sp:
assert sp._name == "custom_sp"
def test_savepoint_auto_naming(self, mock_session):
"""Test savepoint automatic naming."""
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with TransactionContext(mock_session) as tx:
with tx.savepoint() as sp1:
assert sp1._name == "sp_1"
with tx.savepoint() as sp2:
assert sp2._name == "sp_2"
def test_savepoint_rollback_on_exception(self, mock_session):
"""Test savepoint rolls back on exception."""
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with pytest.raises(ValueError):
with TransactionContext(mock_session) as tx:
with tx.savepoint() as sp:
raise ValueError("Savepoint error")
mock_nested.rollback.assert_called_once()
def test_multiple_savepoints(self, mock_session):
"""Test multiple nested savepoints."""
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with TransactionContext(mock_session) as tx:
with tx.savepoint("outer"):
with tx.savepoint("inner"):
pass
# Both savepoints should be created
assert mock_session.begin_nested.call_count == 2
class TestSavepointContext:
"""Test SavepointContext."""
def test_rollback(self):
"""Test savepoint rollback."""
mock_nested = MagicMock()
sp = SavepointContext(mock_nested, "test_sp")
sp.rollback()
mock_nested.rollback.assert_called_once()
assert sp._rolled_back is True
def test_rollback_idempotent(self):
"""Test rollback can only happen once."""
mock_nested = MagicMock()
sp = SavepointContext(mock_nested, "test_sp")
sp.rollback()
sp.rollback() # Second call should not rollback again
mock_nested.rollback.assert_called_once()
def test_commit_logs_only(self):
"""Test commit method logs but doesn't call nested."""
mock_nested = MagicMock()
sp = SavepointContext(mock_nested, "test_sp")
sp.commit()
# Commit doesn't call nested methods (auto-handled by SQLAlchemy)
assert not sp._rolled_back
class TestAsyncTransactionContext:
"""Test asynchronous AsyncTransactionContext."""
@pytest.fixture
def mock_async_session(self):
"""Create a mock AsyncSession."""
session = MagicMock(spec=AsyncSession)
session.in_transaction = MagicMock(return_value=False)
session.begin = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.begin_nested = AsyncMock()
return session
@pytest.mark.asyncio
async def test_async_context_enter_starts_transaction(self, mock_async_session):
"""Test entering async context starts transaction."""
async with AsyncTransactionContext(mock_async_session) as tx:
assert tx.session == mock_async_session
mock_async_session.begin.assert_called_once()
@pytest.mark.asyncio
async def test_async_context_enter_existing_transaction(self, mock_async_session):
"""Test entering async context with existing transaction."""
mock_async_session.in_transaction.return_value = True
async with AsyncTransactionContext(mock_async_session):
mock_async_session.begin.assert_not_called()
@pytest.mark.asyncio
async def test_async_context_exit_commits_on_success(self, mock_async_session):
"""Test exiting async context commits on success."""
async with AsyncTransactionContext(mock_async_session):
pass
mock_async_session.commit.assert_called_once()
mock_async_session.rollback.assert_not_called()
@pytest.mark.asyncio
async def test_async_context_exit_rollback_on_exception(self, mock_async_session):
"""Test exiting async context rolls back on exception."""
with pytest.raises(ValueError):
async with AsyncTransactionContext(mock_async_session):
raise ValueError("Async test error")
mock_async_session.rollback.assert_called_once()
mock_async_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_async_explicit_commit(self, mock_async_session):
"""Test explicit commit within async context."""
async with AsyncTransactionContext(mock_async_session) as tx:
await tx.commit()
mock_async_session.commit.assert_called_once()
# Should not commit again on exit
assert mock_async_session.commit.call_count == 1
@pytest.mark.asyncio
async def test_async_explicit_rollback(self, mock_async_session):
"""Test explicit rollback within async context."""
async with AsyncTransactionContext(mock_async_session) as tx:
await tx.rollback()
mock_async_session.rollback.assert_called_once()
# Should not commit on exit after rollback
mock_async_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_async_savepoint_creation(self, mock_async_session):
"""Test async savepoint creation."""
mock_nested = AsyncMock()
mock_async_session.begin_nested.return_value = mock_nested
async with AsyncTransactionContext(mock_async_session) as tx:
async with tx.savepoint() as sp:
assert isinstance(sp, AsyncSavepointContext)
mock_async_session.begin_nested.assert_called_once()
@pytest.mark.asyncio
async def test_async_savepoint_with_custom_name(self, mock_async_session):
"""Test async savepoint with custom name."""
mock_nested = AsyncMock()
mock_async_session.begin_nested.return_value = mock_nested
async with AsyncTransactionContext(mock_async_session) as tx:
async with tx.savepoint("async_custom_sp") as sp:
assert sp._name == "async_custom_sp"
@pytest.mark.asyncio
async def test_async_savepoint_rollback_on_exception(self, mock_async_session):
"""Test async savepoint rolls back on exception."""
mock_nested = AsyncMock()
mock_async_session.begin_nested.return_value = mock_nested
with pytest.raises(ValueError):
async with AsyncTransactionContext(mock_async_session) as tx:
async with tx.savepoint() as sp:
raise ValueError("Async savepoint error")
mock_nested.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_async_multiple_savepoints(self, mock_async_session):
"""Test multiple nested async savepoints."""
mock_nested = AsyncMock()
mock_async_session.begin_nested.return_value = mock_nested
async with AsyncTransactionContext(mock_async_session) as tx:
async with tx.savepoint("outer"):
async with tx.savepoint("inner"):
pass
# Both savepoints should be created
assert mock_async_session.begin_nested.call_count == 2
class TestAsyncSavepointContext:
"""Test AsyncSavepointContext."""
@pytest.mark.asyncio
async def test_async_rollback(self):
"""Test async savepoint rollback."""
mock_nested = AsyncMock()
mock_session = AsyncMock(spec=AsyncSession)
sp = AsyncSavepointContext(mock_nested, "async_test_sp", mock_session)
await sp.rollback()
mock_nested.rollback.assert_called_once()
assert sp._rolled_back is True
@pytest.mark.asyncio
async def test_async_rollback_idempotent(self):
"""Test async rollback can only happen once."""
mock_nested = AsyncMock()
mock_session = AsyncMock(spec=AsyncSession)
sp = AsyncSavepointContext(mock_nested, "async_test_sp", mock_session)
await sp.rollback()
await sp.rollback() # Second call should not rollback again
mock_nested.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_async_commit_logs_only(self):
"""Test async commit method logs but doesn't call nested."""
mock_nested = AsyncMock()
mock_session = AsyncMock(spec=AsyncSession)
sp = AsyncSavepointContext(mock_nested, "async_test_sp", mock_session)
await sp.commit()
# Commit doesn't call nested methods (auto-handled by SQLAlchemy)
assert not sp._rolled_back
class TestAtomicContextManager:
"""Test atomic() async context manager."""
@pytest.fixture
def mock_async_session(self):
"""Create a mock AsyncSession."""
session = MagicMock(spec=AsyncSession)
session.in_transaction = MagicMock(return_value=False)
session.begin = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.begin_nested = AsyncMock()
return session
@pytest.mark.asyncio
async def test_atomic_required_propagation(self, mock_async_session):
"""Test atomic with REQUIRED propagation."""
async with atomic(mock_async_session, TransactionPropagation.REQUIRED) as tx:
assert isinstance(tx, AsyncTransactionContext)
mock_async_session.begin.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_default_propagation(self, mock_async_session):
"""Test atomic uses REQUIRED as default."""
async with atomic(mock_async_session) as tx:
assert isinstance(tx, AsyncTransactionContext)
mock_async_session.begin.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_nested_with_existing_transaction(self, mock_async_session):
"""Test atomic with NESTED propagation creates savepoint."""
mock_async_session.in_transaction.return_value = True
mock_nested = AsyncMock()
mock_async_session.begin_nested.return_value = mock_nested
async with atomic(mock_async_session, TransactionPropagation.NESTED) as tx:
assert isinstance(tx, AsyncTransactionContext)
mock_async_session.begin_nested.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_nested_without_transaction(self, mock_async_session):
"""Test atomic with NESTED but no existing transaction starts new one."""
mock_async_session.in_transaction.return_value = False
async with atomic(mock_async_session, TransactionPropagation.NESTED) as tx:
assert isinstance(tx, AsyncTransactionContext)
mock_async_session.begin.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_commits_on_success(self, mock_async_session):
"""Test atomic commits on successful completion."""
async with atomic(mock_async_session):
pass
mock_async_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_rollback_on_exception(self, mock_async_session):
"""Test atomic rolls back on exception."""
with pytest.raises(RuntimeError):
async with atomic(mock_async_session):
raise RuntimeError("Atomic error")
mock_async_session.rollback.assert_called_once()
class TestAtomicSyncContextManager:
"""Test atomic_sync() context manager."""
@pytest.fixture
def mock_session(self):
"""Create a mock Session."""
session = MagicMock(spec=Session)
session.in_transaction.return_value = False
return session
def test_atomic_sync_required_propagation(self, mock_session):
"""Test atomic_sync with REQUIRED propagation."""
with atomic_sync(mock_session, TransactionPropagation.REQUIRED) as tx:
assert isinstance(tx, TransactionContext)
mock_session.begin.assert_called_once()
def test_atomic_sync_default_propagation(self, mock_session):
"""Test atomic_sync uses REQUIRED as default."""
with atomic_sync(mock_session) as tx:
assert isinstance(tx, TransactionContext)
mock_session.begin.assert_called_once()
def test_atomic_sync_nested_with_existing_transaction(self, mock_session):
"""Test atomic_sync with NESTED propagation creates savepoint."""
mock_session.in_transaction.return_value = True
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with atomic_sync(mock_session, TransactionPropagation.NESTED) as tx:
assert isinstance(tx, TransactionContext)
mock_session.begin_nested.assert_called_once()
def test_atomic_sync_commits_on_success(self, mock_session):
"""Test atomic_sync commits on successful completion."""
with atomic_sync(mock_session):
pass
mock_session.commit.assert_called_once()
def test_atomic_sync_rollback_on_exception(self, mock_session):
"""Test atomic_sync rolls back on exception."""
with pytest.raises(RuntimeError):
with atomic_sync(mock_session):
raise RuntimeError("Sync atomic error")
mock_session.rollback.assert_called_once()
class TestTransactionalDecorator:
"""Test @transactional decorator."""
@pytest.fixture
def mock_async_session(self):
"""Create a mock AsyncSession."""
session = MagicMock(spec=AsyncSession)
session.in_transaction = MagicMock(return_value=False)
session.begin = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.begin_nested = AsyncMock()
return session
@pytest.fixture
def mock_session(self):
"""Create a mock Session."""
session = MagicMock(spec=Session)
session.in_transaction.return_value = False
return session
@pytest.mark.asyncio
async def test_transactional_async_function(self, mock_async_session):
"""Test @transactional decorator on async function."""
@transactional()
async def test_func(db: AsyncSession, value: int) -> int:
return value * 2
result = await test_func(db=mock_async_session, value=21)
assert result == 42
mock_async_session.begin.assert_called_once()
mock_async_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_transactional_async_with_custom_param_name(self, mock_async_session):
"""Test @transactional with custom session parameter name."""
@transactional(session_param="session")
async def test_func(session: AsyncSession, value: int) -> int:
return value * 3
result = await test_func(session=mock_async_session, value=10)
assert result == 30
mock_async_session.begin.assert_called_once()
@pytest.mark.asyncio
async def test_transactional_async_propagation(self, mock_async_session):
"""Test @transactional with different propagation."""
@transactional(propagation=TransactionPropagation.NESTED)
async def test_func(db: AsyncSession) -> str:
return "nested"
mock_async_session.in_transaction.return_value = False
result = await test_func(db=mock_async_session)
assert result == "nested"
mock_async_session.begin.assert_called_once()
@pytest.mark.asyncio
async def test_transactional_async_rollback_on_exception(self, mock_async_session):
"""Test @transactional rolls back on exception."""
@transactional()
async def test_func(db: AsyncSession) -> None:
raise ValueError("Test exception")
with pytest.raises(ValueError):
await test_func(db=mock_async_session)
mock_async_session.rollback.assert_called_once()
mock_async_session.commit.assert_not_called()
def test_transactional_sync_function(self, mock_session):
"""Test @transactional decorator on sync function."""
@transactional()
def test_func(db: Session, value: int) -> int:
return value * 2
result = test_func(db=mock_session, value=21)
assert result == 42
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
def test_transactional_sync_rollback_on_exception(self, mock_session):
"""Test @transactional rolls back sync function on exception."""
@transactional()
def test_func(db: Session) -> None:
raise ValueError("Sync test exception")
with pytest.raises(ValueError):
test_func(db=mock_session)
mock_session.rollback.assert_called_once()
mock_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_transactional_missing_session_param(self):
"""Test @transactional raises error when session param not found."""
@transactional(session_param="db")
async def test_func(value: int) -> int:
return value
with pytest.raises(TransactionError, match="Could not find session parameter"):
await test_func(value=42)
@pytest.mark.asyncio
async def test_transactional_positional_session(self, mock_async_session):
"""Test @transactional with session passed as positional arg."""
@transactional()
async def test_func(db: AsyncSession, value: int) -> int:
return value * 2
result = await test_func(mock_async_session, 21)
assert result == 42
mock_async_session.begin.assert_called_once()
class TestExtractSession:
"""Test _extract_session helper function."""
def test_extract_from_kwargs(self):
"""Test extracting session from kwargs."""
session = MagicMock()
def test_func(db: Session, value: int):
pass
result = _extract_session(test_func, (), {"db": session, "value": 10}, "db")
assert result is session
def test_extract_from_positional_args(self):
"""Test extracting session from positional args."""
session = MagicMock()
def test_func(db: Session, value: int):
pass
result = _extract_session(test_func, (session, 10), {}, "db")
assert result is session
def test_extract_returns_none_when_not_found(self):
"""Test returns None when session not found."""
def test_func(value: int):
pass
result = _extract_session(test_func, (10,), {}, "db")
assert result is None
def test_extract_with_self_parameter(self):
"""Test extracting session with self parameter (method)."""
session = MagicMock()
class TestClass:
def test_method(self, db: Session, value: int):
pass
obj = TestClass()
# When calling a bound method, 'self' is not in args
result = _extract_session(
obj.test_method,
(session, 10), # No 'self' in args for bound methods
{},
"db"
)
assert result is session
class TestUtilityFunctions:
"""Test utility functions."""
def test_is_in_transaction_true(self):
"""Test is_in_transaction returns True."""
session = MagicMock(spec=Session)
session.in_transaction.return_value = True
assert is_in_transaction(session) is True
def test_is_in_transaction_false(self):
"""Test is_in_transaction returns False."""
session = MagicMock(spec=Session)
session.in_transaction.return_value = False
assert is_in_transaction(session) is False
def test_get_transaction_depth_zero(self):
"""Test get_transaction_depth returns 0 when not in transaction."""
session = MagicMock(spec=Session)
session.in_transaction.return_value = False
assert get_transaction_depth(session) == 0
def test_get_transaction_depth_one(self):
"""Test get_transaction_depth returns 1 in transaction."""
session = MagicMock(spec=Session)
session.in_transaction.return_value = True
session._nested_transaction = None
assert get_transaction_depth(session) == 1
def test_get_transaction_depth_nested(self):
"""Test get_transaction_depth returns 2 with savepoint."""
session = MagicMock(spec=Session)
session.in_transaction.return_value = True
session._nested_transaction = MagicMock()
assert get_transaction_depth(session) == 2
class TestTransactionLogging:
"""Test transaction logging behavior."""
@pytest.fixture
def mock_async_session(self):
"""Create a mock AsyncSession."""
session = MagicMock(spec=AsyncSession)
session.in_transaction = MagicMock(return_value=False)
session.begin = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.begin_nested = AsyncMock()
return session
@pytest.mark.asyncio
async def test_logging_on_commit(self, mock_async_session, caplog):
"""Test logging when transaction commits."""
with caplog.at_level(logging.DEBUG):
async with AsyncTransactionContext(mock_async_session):
pass
assert "Entering async transaction context" in caplog.text
assert "Async transaction committed" in caplog.text
@pytest.mark.asyncio
async def test_logging_on_rollback(self, mock_async_session, caplog):
"""Test logging when transaction rolls back."""
with caplog.at_level(logging.WARNING):
with pytest.raises(ValueError):
async with AsyncTransactionContext(mock_async_session):
raise ValueError("Test error")
assert "Async transaction rollback due to exception" in caplog.text
@pytest.mark.asyncio
async def test_logging_savepoint_creation(self, mock_async_session, caplog):
"""Test logging when savepoint is created."""
mock_nested = AsyncMock()
mock_async_session.begin_nested.return_value = mock_nested
with caplog.at_level(logging.DEBUG):
async with AsyncTransactionContext(mock_async_session) as tx:
async with tx.savepoint("test_sp"):
pass
assert "Creating async savepoint: test_sp" in caplog.text
class TestTransactionError:
"""Test TransactionError exception."""
def test_transaction_error_message(self):
"""Test TransactionError can be raised with message."""
with pytest.raises(TransactionError, match="Custom error"):
raise TransactionError("Custom error")
def test_transaction_error_inheritance(self):
"""Test TransactionError inherits from Exception."""
assert issubclass(TransactionError, Exception)
class TestComplexScenarios:
"""Test complex transaction scenarios."""
@pytest.fixture
def mock_async_session(self):
"""Create a mock AsyncSession."""
session = MagicMock(spec=AsyncSession)
session.in_transaction = MagicMock(return_value=False)
session.begin = AsyncMock()
session.commit = AsyncMock()
session.rollback = AsyncMock()
session.begin_nested = AsyncMock()
return session
@pytest.mark.asyncio
async def test_nested_transactions_with_partial_rollback(self, mock_async_session):
"""Test nested transactions with partial rollback."""
mock_async_session.in_transaction.return_value = True
mock_nested = AsyncMock()
mock_async_session.begin_nested.return_value = mock_nested
async with AsyncTransactionContext(mock_async_session) as tx:
async with tx.savepoint("outer") as sp_outer:
# Simulate outer operation success
pass
# Outer savepoint should not roll back
mock_nested.rollback.assert_not_called()
# Transaction should commit
mock_async_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_multiple_operations_in_transaction(self, mock_async_session):
"""Test multiple operations within single transaction."""
@transactional()
async def operation_a(db: AsyncSession) -> str:
return "A"
@transactional()
async def operation_b(db: AsyncSession) -> str:
return "B"
async with atomic(mock_async_session):
result_a = await operation_a(db=mock_async_session)
result_b = await operation_b(db=mock_async_session)
assert result_a == "A"
assert result_b == "B"
@pytest.mark.asyncio
async def test_explicit_savepoint_rollback(self, mock_async_session):
"""Test explicit savepoint rollback."""
mock_nested = AsyncMock()
mock_async_session.begin_nested.return_value = mock_nested
async with AsyncTransactionContext(mock_async_session) as tx:
async with tx.savepoint() as sp:
await sp.rollback()
# Verify rollback was called
mock_nested.rollback.assert_called_once()
# Transaction should still commit (only savepoint rolled back)
mock_async_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_decorator_with_nested_calls(self, mock_async_session):
"""Test @transactional decorator with nested function calls."""
mock_async_session.in_transaction.return_value = True
@transactional()
async def inner_operation(db: AsyncSession, value: int) -> int:
return value * 2
@transactional()
async def outer_operation(db: AsyncSession, value: int) -> int:
result = await inner_operation(db, value)
return result + 10
result = await outer_operation(db=mock_async_session, value=5)
assert result == 20 # (5 * 2) + 10