- 66 tests for transaction management - Coverage: 90% (meets 90%+ target) - Tests for TransactionContext (sync and async) - Tests for SavepointContext (sync and async) - Tests for @transactional decorator - Tests for atomic() and atomic_sync() context managers - Tests for transaction propagation (REQUIRED, REQUIRES_NEW, NESTED) - Tests for utility functions (is_in_transaction, get_transaction_depth) - Tests for complex scenarios (nested transactions, partial rollback) Task 3 completed (Priority P0, Effort Large)
828 lines
31 KiB
Python
828 lines
31 KiB
Python
"""Unit tests for database transaction management.

This module tests transaction contexts, savepoints, decorators, and
transaction propagation behavior to ensure data consistency.

Coverage Target: 90%+
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy.orm import Session
|
|
|
|
from src.server.database.transaction import (
|
|
AsyncSavepointContext,
|
|
AsyncTransactionContext,
|
|
SavepointContext,
|
|
TransactionContext,
|
|
TransactionError,
|
|
TransactionPropagation,
|
|
_extract_session,
|
|
atomic,
|
|
atomic_sync,
|
|
get_transaction_depth,
|
|
is_in_transaction,
|
|
transactional,
|
|
)
|
|
|
|
|
|
class TestTransactionPropagation:
    """Checks on the TransactionPropagation enum definition."""

    def test_propagation_values(self):
        """Each propagation mode carries its expected string value."""
        expected_pairs = (
            (TransactionPropagation.REQUIRED, "required"),
            (TransactionPropagation.REQUIRES_NEW, "requires_new"),
            (TransactionPropagation.NESTED, "nested"),
        )
        for mode, raw_value in expected_pairs:
            assert mode.value == raw_value

    def test_propagation_members(self):
        """The enum exposes REQUIRED, REQUIRES_NEW and NESTED."""
        member_names = {mode.name for mode in TransactionPropagation}
        for wanted in ("REQUIRED", "REQUIRES_NEW", "NESTED"):
            assert wanted in member_names
|
|
|
|
|
|
class TestTransactionContext:
    """Tests for the synchronous TransactionContext.

    Every test runs against a ``MagicMock`` specced on SQLAlchemy's
    ``Session``; only the calls made on that mock are verified.
    """

    @pytest.fixture
    def mock_session(self):
        """Return a mocked Session whose in_transaction() reports False."""
        session = MagicMock(spec=Session)
        session.in_transaction.return_value = False
        return session

    def test_context_enter_starts_transaction(self, mock_session):
        """Entering the context calls begin() on the session once."""
        with TransactionContext(mock_session) as tx:
            # Identity check: the context must expose the very session
            # object it was constructed with.
            assert tx.session is mock_session
            mock_session.begin.assert_called_once()

    def test_context_enter_existing_transaction(self, mock_session):
        """begin() is not called when the session is already in a transaction."""
        mock_session.in_transaction.return_value = True

        with TransactionContext(mock_session):
            mock_session.begin.assert_not_called()

    def test_context_exit_commits_on_success(self, mock_session):
        """Leaving the context without an error commits and does not roll back."""
        with TransactionContext(mock_session):
            pass

        mock_session.commit.assert_called_once()
        mock_session.rollback.assert_not_called()

    def test_context_exit_rollback_on_exception(self, mock_session):
        """An exception inside the context rolls back and re-raises."""
        with pytest.raises(ValueError):
            with TransactionContext(mock_session):
                raise ValueError("Test error")

        mock_session.rollback.assert_called_once()
        mock_session.commit.assert_not_called()

    def test_explicit_commit(self, mock_session):
        """An explicit commit() is not repeated when the context exits."""
        with TransactionContext(mock_session) as tx:
            tx.commit()
            mock_session.commit.assert_called_once()

        # Should not commit again on exit
        assert mock_session.commit.call_count == 1

    def test_explicit_rollback(self, mock_session):
        """After an explicit rollback() the context exit does not commit."""
        with TransactionContext(mock_session) as tx:
            tx.rollback()
            mock_session.rollback.assert_called_once()

        # Should not commit on exit after rollback
        mock_session.commit.assert_not_called()

    def test_savepoint_creation(self, mock_session):
        """savepoint() yields a SavepointContext and calls begin_nested()."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested

        with TransactionContext(mock_session) as tx:
            with tx.savepoint() as sp:
                assert isinstance(sp, SavepointContext)
                mock_session.begin_nested.assert_called_once()

    def test_savepoint_with_custom_name(self, mock_session):
        """A name passed to savepoint() is stored on the savepoint context."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested

        with TransactionContext(mock_session) as tx:
            with tx.savepoint("custom_sp") as sp:
                assert sp._name == "custom_sp"

    def test_savepoint_auto_naming(self, mock_session):
        """Unnamed savepoints are named sp_1, sp_2 in creation order."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested

        with TransactionContext(mock_session) as tx:
            with tx.savepoint() as sp1:
                assert sp1._name == "sp_1"

            with tx.savepoint() as sp2:
                assert sp2._name == "sp_2"

    def test_savepoint_rollback_on_exception(self, mock_session):
        """An exception inside a savepoint rolls the nested transaction back."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested

        with pytest.raises(ValueError):
            with TransactionContext(mock_session) as tx:
                # The savepoint object itself is not needed here, so it is
                # deliberately left unbound.
                with tx.savepoint():
                    raise ValueError("Savepoint error")

        mock_nested.rollback.assert_called_once()

    def test_multiple_savepoints(self, mock_session):
        """Two nested savepoints each trigger a begin_nested() call."""
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested

        with TransactionContext(mock_session) as tx:
            with tx.savepoint("outer"):
                with tx.savepoint("inner"):
                    pass

        # Both savepoints should be created
        assert mock_session.begin_nested.call_count == 2
|
|
|
|
|
|
class TestSavepointContext:
    """Tests for SavepointContext built directly around a mocked nested transaction."""

    def test_rollback(self):
        """rollback() forwards to the nested transaction and sets the flag."""
        mock_nested = MagicMock()
        sp = SavepointContext(mock_nested, "test_sp")

        sp.rollback()

        mock_nested.rollback.assert_called_once()
        assert sp._rolled_back is True

    def test_rollback_idempotent(self):
        """A second rollback() does not hit the nested transaction again."""
        mock_nested = MagicMock()
        sp = SavepointContext(mock_nested, "test_sp")

        sp.rollback()
        sp.rollback()  # Second call should not rollback again

        mock_nested.rollback.assert_called_once()

    def test_commit_logs_only(self):
        """commit() leaves the savepoint un-rolled-back and never rolls back."""
        mock_nested = MagicMock()
        sp = SavepointContext(mock_nested, "test_sp")

        sp.commit()

        # Commit doesn't call nested methods (auto-handled by SQLAlchemy)
        assert not sp._rolled_back
        # Back the statement above with a check on the mock itself, not
        # only on the context's private flag.
        mock_nested.rollback.assert_not_called()
|
|
|
|
|
|
class TestAsyncTransactionContext:
    """Tests for the asynchronous AsyncTransactionContext.

    The session is a ``MagicMock`` specced on ``AsyncSession`` whose
    awaited methods are replaced by ``AsyncMock`` instances.
    """

    @pytest.fixture
    def mock_async_session(self):
        """Return a mocked AsyncSession whose in_transaction() reports False."""
        session = MagicMock(spec=AsyncSession)
        session.in_transaction = MagicMock(return_value=False)
        session.begin = AsyncMock()
        session.commit = AsyncMock()
        session.rollback = AsyncMock()
        session.begin_nested = AsyncMock()
        return session

    @pytest.mark.asyncio
    async def test_async_context_enter_starts_transaction(self, mock_async_session):
        """Entering the async context calls begin() on the session once."""
        async with AsyncTransactionContext(mock_async_session) as tx:
            # Identity check: the context must expose the very session
            # object it was constructed with.
            assert tx.session is mock_async_session
            mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_context_enter_existing_transaction(self, mock_async_session):
        """begin() is not called when the session is already in a transaction."""
        mock_async_session.in_transaction.return_value = True

        async with AsyncTransactionContext(mock_async_session):
            mock_async_session.begin.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_context_exit_commits_on_success(self, mock_async_session):
        """Leaving the context without an error commits and does not roll back."""
        async with AsyncTransactionContext(mock_async_session):
            pass

        mock_async_session.commit.assert_called_once()
        mock_async_session.rollback.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_context_exit_rollback_on_exception(self, mock_async_session):
        """An exception inside the context rolls back and re-raises."""
        with pytest.raises(ValueError):
            async with AsyncTransactionContext(mock_async_session):
                raise ValueError("Async test error")

        mock_async_session.rollback.assert_called_once()
        mock_async_session.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_explicit_commit(self, mock_async_session):
        """An explicit commit() is not repeated when the context exits."""
        async with AsyncTransactionContext(mock_async_session) as tx:
            await tx.commit()
            mock_async_session.commit.assert_called_once()

        # Should not commit again on exit
        assert mock_async_session.commit.call_count == 1

    @pytest.mark.asyncio
    async def test_async_explicit_rollback(self, mock_async_session):
        """After an explicit rollback() the context exit does not commit."""
        async with AsyncTransactionContext(mock_async_session) as tx:
            await tx.rollback()
            mock_async_session.rollback.assert_called_once()

        # Should not commit on exit after rollback
        mock_async_session.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_savepoint_creation(self, mock_async_session):
        """savepoint() yields an AsyncSavepointContext and calls begin_nested()."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested

        async with AsyncTransactionContext(mock_async_session) as tx:
            async with tx.savepoint() as sp:
                assert isinstance(sp, AsyncSavepointContext)
                mock_async_session.begin_nested.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_savepoint_with_custom_name(self, mock_async_session):
        """A name passed to savepoint() is stored on the savepoint context."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested

        async with AsyncTransactionContext(mock_async_session) as tx:
            async with tx.savepoint("async_custom_sp") as sp:
                assert sp._name == "async_custom_sp"

    @pytest.mark.asyncio
    async def test_async_savepoint_rollback_on_exception(self, mock_async_session):
        """An exception inside a savepoint rolls the nested transaction back."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested

        with pytest.raises(ValueError):
            async with AsyncTransactionContext(mock_async_session) as tx:
                # The savepoint object itself is not needed here, so it is
                # deliberately left unbound.
                async with tx.savepoint():
                    raise ValueError("Async savepoint error")

        mock_nested.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_multiple_savepoints(self, mock_async_session):
        """Two nested savepoints each trigger a begin_nested() call."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested

        async with AsyncTransactionContext(mock_async_session) as tx:
            async with tx.savepoint("outer"):
                async with tx.savepoint("inner"):
                    pass

        # Both savepoints should be created
        assert mock_async_session.begin_nested.call_count == 2
|
|
|
|
|
|
class TestAsyncSavepointContext:
    """Tests for AsyncSavepointContext built directly around mocked collaborators."""

    @pytest.mark.asyncio
    async def test_async_rollback(self):
        """rollback() forwards to the nested transaction and sets the flag."""
        mock_nested = AsyncMock()
        mock_session = AsyncMock(spec=AsyncSession)
        sp = AsyncSavepointContext(mock_nested, "async_test_sp", mock_session)

        await sp.rollback()

        mock_nested.rollback.assert_called_once()
        assert sp._rolled_back is True

    @pytest.mark.asyncio
    async def test_async_rollback_idempotent(self):
        """A second rollback() does not hit the nested transaction again."""
        mock_nested = AsyncMock()
        mock_session = AsyncMock(spec=AsyncSession)
        sp = AsyncSavepointContext(mock_nested, "async_test_sp", mock_session)

        await sp.rollback()
        await sp.rollback()  # Second call should not rollback again

        mock_nested.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_commit_logs_only(self):
        """commit() leaves the savepoint un-rolled-back and never rolls back."""
        mock_nested = AsyncMock()
        mock_session = AsyncMock(spec=AsyncSession)
        sp = AsyncSavepointContext(mock_nested, "async_test_sp", mock_session)

        await sp.commit()

        # Commit doesn't call nested methods (auto-handled by SQLAlchemy)
        assert not sp._rolled_back
        # Back the statement above with a check on the mock itself, not
        # only on the context's private flag.
        mock_nested.rollback.assert_not_called()
|
|
|
|
|
|
class TestAtomicContextManager:
    """Tests for the atomic() async context manager."""

    @pytest.fixture
    def mock_async_session(self):
        """Provide an AsyncSession mock that is not inside a transaction."""
        fake = MagicMock(spec=AsyncSession)
        fake.in_transaction = MagicMock(return_value=False)
        # Every awaited session method gets its own AsyncMock.
        for awaited_name in ("begin", "commit", "rollback", "begin_nested"):
            setattr(fake, awaited_name, AsyncMock())
        return fake

    @pytest.mark.asyncio
    async def test_atomic_required_propagation(self, mock_async_session):
        """REQUIRED propagation yields a transaction context and calls begin()."""
        mode = TransactionPropagation.REQUIRED
        async with atomic(mock_async_session, mode) as ctx:
            assert isinstance(ctx, AsyncTransactionContext)
            mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_default_propagation(self, mock_async_session):
        """Without an explicit propagation, begin() is still called once."""
        async with atomic(mock_async_session) as ctx:
            assert isinstance(ctx, AsyncTransactionContext)
            mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_nested_with_existing_transaction(self, mock_async_session):
        """NESTED inside an active transaction calls begin_nested()."""
        mock_async_session.in_transaction.return_value = True
        mock_async_session.begin_nested.return_value = AsyncMock()

        mode = TransactionPropagation.NESTED
        async with atomic(mock_async_session, mode) as ctx:
            assert isinstance(ctx, AsyncTransactionContext)
            mock_async_session.begin_nested.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_nested_without_transaction(self, mock_async_session):
        """NESTED with no active transaction calls begin() instead."""
        mock_async_session.in_transaction.return_value = False

        mode = TransactionPropagation.NESTED
        async with atomic(mock_async_session, mode) as ctx:
            assert isinstance(ctx, AsyncTransactionContext)
            mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_commits_on_success(self, mock_async_session):
        """A clean exit from atomic() commits exactly once."""
        async with atomic(mock_async_session):
            pass

        mock_async_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_rollback_on_exception(self, mock_async_session):
        """An exception raised inside atomic() triggers one rollback."""
        with pytest.raises(RuntimeError):
            async with atomic(mock_async_session):
                raise RuntimeError("Atomic error")

        mock_async_session.rollback.assert_called_once()
|
|
|
|
|
|
class TestAtomicSyncContextManager:
    """Tests for the atomic_sync() context manager."""

    @pytest.fixture
    def mock_session(self):
        """Provide a Session mock that is not inside a transaction."""
        fake = MagicMock(spec=Session)
        fake.in_transaction.return_value = False
        return fake

    def test_atomic_sync_required_propagation(self, mock_session):
        """REQUIRED propagation yields a TransactionContext and calls begin()."""
        mode = TransactionPropagation.REQUIRED
        with atomic_sync(mock_session, mode) as ctx:
            assert isinstance(ctx, TransactionContext)
            mock_session.begin.assert_called_once()

    def test_atomic_sync_default_propagation(self, mock_session):
        """Without an explicit propagation, begin() is still called once."""
        with atomic_sync(mock_session) as ctx:
            assert isinstance(ctx, TransactionContext)
            mock_session.begin.assert_called_once()

    def test_atomic_sync_nested_with_existing_transaction(self, mock_session):
        """NESTED inside an active transaction calls begin_nested()."""
        mock_session.in_transaction.return_value = True
        mock_session.begin_nested.return_value = MagicMock()

        mode = TransactionPropagation.NESTED
        with atomic_sync(mock_session, mode) as ctx:
            assert isinstance(ctx, TransactionContext)
            mock_session.begin_nested.assert_called_once()

    def test_atomic_sync_commits_on_success(self, mock_session):
        """A clean exit from atomic_sync() commits exactly once."""
        with atomic_sync(mock_session):
            pass

        mock_session.commit.assert_called_once()

    def test_atomic_sync_rollback_on_exception(self, mock_session):
        """An exception raised inside atomic_sync() triggers one rollback."""
        with pytest.raises(RuntimeError), atomic_sync(mock_session):
            raise RuntimeError("Sync atomic error")

        mock_session.rollback.assert_called_once()
|
|
|
|
|
|
class TestTransactionalDecorator:
    """Tests for the @transactional decorator on sync and async callables."""

    @pytest.fixture
    def mock_async_session(self):
        """Provide an AsyncSession mock that is not inside a transaction."""
        fake = MagicMock(spec=AsyncSession)
        fake.in_transaction = MagicMock(return_value=False)
        # Every awaited session method gets its own AsyncMock.
        for awaited_name in ("begin", "commit", "rollback", "begin_nested"):
            setattr(fake, awaited_name, AsyncMock())
        return fake

    @pytest.fixture
    def mock_session(self):
        """Provide a Session mock that is not inside a transaction."""
        fake = MagicMock(spec=Session)
        fake.in_transaction.return_value = False
        return fake

    @pytest.mark.asyncio
    async def test_transactional_async_function(self, mock_async_session):
        """A decorated coroutine returns its value; begin and commit run once."""

        @transactional()
        async def double(db: AsyncSession, value: int) -> int:
            return value * 2

        outcome = await double(db=mock_async_session, value=21)

        assert outcome == 42
        mock_async_session.begin.assert_called_once()
        mock_async_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_transactional_async_with_custom_param_name(self, mock_async_session):
        """session_param selects which argument holds the session."""

        @transactional(session_param="session")
        async def triple(session: AsyncSession, value: int) -> int:
            return value * 3

        outcome = await triple(session=mock_async_session, value=10)

        assert outcome == 30
        mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_transactional_async_propagation(self, mock_async_session):
        """NESTED propagation with no active transaction calls begin()."""

        @transactional(propagation=TransactionPropagation.NESTED)
        async def label(db: AsyncSession) -> str:
            return "nested"

        mock_async_session.in_transaction.return_value = False
        outcome = await label(db=mock_async_session)

        assert outcome == "nested"
        mock_async_session.begin.assert_called_once()

    @pytest.mark.asyncio
    async def test_transactional_async_rollback_on_exception(self, mock_async_session):
        """An exception from the coroutine rolls back and skips the commit."""

        @transactional()
        async def explode(db: AsyncSession) -> None:
            raise ValueError("Test exception")

        with pytest.raises(ValueError):
            await explode(db=mock_async_session)

        mock_async_session.rollback.assert_called_once()
        mock_async_session.commit.assert_not_called()

    def test_transactional_sync_function(self, mock_session):
        """A decorated sync function returns its value; begin and commit run once."""

        @transactional()
        def double(db: Session, value: int) -> int:
            return value * 2

        outcome = double(db=mock_session, value=21)

        assert outcome == 42
        mock_session.begin.assert_called_once()
        mock_session.commit.assert_called_once()

    def test_transactional_sync_rollback_on_exception(self, mock_session):
        """An exception from the sync function rolls back and skips the commit."""

        @transactional()
        def explode(db: Session) -> None:
            raise ValueError("Sync test exception")

        with pytest.raises(ValueError):
            explode(db=mock_session)

        mock_session.rollback.assert_called_once()
        mock_session.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_transactional_missing_session_param(self):
        """Calling without the configured session argument raises TransactionError."""

        @transactional(session_param="db")
        async def passthrough(value: int) -> int:
            return value

        expected_message = "Could not find session parameter"
        with pytest.raises(TransactionError, match=expected_message):
            await passthrough(value=42)

    @pytest.mark.asyncio
    async def test_transactional_positional_session(self, mock_async_session):
        """The session is also picked up when passed positionally."""

        @transactional()
        async def double(db: AsyncSession, value: int) -> int:
            return value * 2

        outcome = await double(mock_async_session, 21)

        assert outcome == 42
        mock_async_session.begin.assert_called_once()
|
|
|
|
|
|
class TestExtractSession:
    """Tests for the _extract_session helper."""

    def test_extract_from_kwargs(self):
        """A session supplied by keyword is returned as-is."""
        fake_session = MagicMock()

        def target(db: Session, value: int):
            pass

        call_kwargs = {"db": fake_session, "value": 10}
        found = _extract_session(target, (), call_kwargs, "db")

        assert found is fake_session

    def test_extract_from_positional_args(self):
        """A session supplied positionally is returned as-is."""
        fake_session = MagicMock()

        def target(db: Session, value: int):
            pass

        call_args = (fake_session, 10)
        found = _extract_session(target, call_args, {}, "db")

        assert found is fake_session

    def test_extract_returns_none_when_not_found(self):
        """None comes back when the function has no such parameter."""

        def target(value: int):
            pass

        found = _extract_session(target, (10,), {}, "db")

        assert found is None

    def test_extract_with_self_parameter(self):
        """Extraction works on a bound method, whose args exclude 'self'."""
        fake_session = MagicMock()

        class Holder:
            def handler(self, db: Session, value: int):
                pass

        bound = Holder().handler
        # When calling a bound method, 'self' is not in args
        call_args = (fake_session, 10)
        found = _extract_session(bound, call_args, {}, "db")

        assert found is fake_session
|
|
|
|
|
|
class TestUtilityFunctions:
    """Tests for is_in_transaction() and get_transaction_depth()."""

    @staticmethod
    def _session_reporting(active):
        """Build a Session mock whose in_transaction() returns *active*."""
        fake = MagicMock(spec=Session)
        fake.in_transaction.return_value = active
        return fake

    def test_is_in_transaction_true(self):
        """True is returned when the session reports an active transaction."""
        fake = self._session_reporting(True)

        assert is_in_transaction(fake) is True

    def test_is_in_transaction_false(self):
        """False is returned when the session reports no transaction."""
        fake = self._session_reporting(False)

        assert is_in_transaction(fake) is False

    def test_get_transaction_depth_zero(self):
        """Depth is 0 when the session reports no transaction."""
        fake = self._session_reporting(False)

        assert get_transaction_depth(fake) == 0

    def test_get_transaction_depth_one(self):
        """Depth is 1 in a transaction with _nested_transaction set to None."""
        fake = self._session_reporting(True)
        fake._nested_transaction = None

        assert get_transaction_depth(fake) == 1

    def test_get_transaction_depth_nested(self):
        """Depth is 2 in a transaction with _nested_transaction set."""
        fake = self._session_reporting(True)
        fake._nested_transaction = MagicMock()

        assert get_transaction_depth(fake) == 2
|
|
|
|
|
|
class TestTransactionLogging:
    """Tests for the log messages emitted by AsyncTransactionContext."""

    @pytest.fixture
    def mock_async_session(self):
        """Provide an AsyncSession mock that is not inside a transaction."""
        fake = MagicMock(spec=AsyncSession)
        fake.in_transaction = MagicMock(return_value=False)
        # Every awaited session method gets its own AsyncMock.
        for awaited_name in ("begin", "commit", "rollback", "begin_nested"):
            setattr(fake, awaited_name, AsyncMock())
        return fake

    @pytest.mark.asyncio
    async def test_logging_on_commit(self, mock_async_session, caplog):
        """Enter and commit messages are captured at DEBUG level."""
        with caplog.at_level(logging.DEBUG):
            async with AsyncTransactionContext(mock_async_session):
                pass

        captured = caplog.text
        assert "Entering async transaction context" in captured
        assert "Async transaction committed" in captured

    @pytest.mark.asyncio
    async def test_logging_on_rollback(self, mock_async_session, caplog):
        """The rollback message is captured at WARNING level."""
        with caplog.at_level(logging.WARNING), pytest.raises(ValueError):
            async with AsyncTransactionContext(mock_async_session):
                raise ValueError("Test error")

        assert "Async transaction rollback due to exception" in caplog.text

    @pytest.mark.asyncio
    async def test_logging_savepoint_creation(self, mock_async_session, caplog):
        """The savepoint-creation message names the savepoint."""
        mock_async_session.begin_nested.return_value = AsyncMock()

        with caplog.at_level(logging.DEBUG):
            async with AsyncTransactionContext(mock_async_session) as ctx:
                async with ctx.savepoint("test_sp"):
                    pass

        assert "Creating async savepoint: test_sp" in caplog.text
|
|
|
|
|
|
class TestTransactionError:
    """Tests for the TransactionError exception type."""

    def test_transaction_error_message(self):
        """The message given at construction is carried by the raised error."""
        failure = TransactionError("Custom error")

        with pytest.raises(TransactionError, match="Custom error"):
            raise failure

    def test_transaction_error_inheritance(self):
        """TransactionError has Exception among its base classes."""
        assert Exception in TransactionError.__mro__
|
|
|
|
|
|
class TestComplexScenarios:
    """Scenario tests combining transaction contexts, savepoints and decorators.

    All scenarios run against a mocked ``AsyncSession``; they verify which
    session methods get called, not real database state.
    """

    @pytest.fixture
    def mock_async_session(self):
        """Return a mocked AsyncSession whose in_transaction() reports False."""
        session = MagicMock(spec=AsyncSession)
        session.in_transaction = MagicMock(return_value=False)
        session.begin = AsyncMock()
        session.commit = AsyncMock()
        session.rollback = AsyncMock()
        session.begin_nested = AsyncMock()
        return session

    @pytest.mark.asyncio
    async def test_nested_transactions_with_partial_rollback(self, mock_async_session):
        """A savepoint that exits cleanly is not rolled back; the outer commits."""
        mock_async_session.in_transaction.return_value = True
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested

        async with AsyncTransactionContext(mock_async_session) as tx:
            # The savepoint object is never used, so it is left unbound.
            async with tx.savepoint("outer"):
                # Simulate outer operation success
                pass

            # Outer savepoint should not roll back
            mock_nested.rollback.assert_not_called()

        # Transaction should commit
        mock_async_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_multiple_operations_in_transaction(self, mock_async_session):
        """Two decorated operations run inside one atomic() block and return."""

        @transactional()
        async def operation_a(db: AsyncSession) -> str:
            return "A"

        @transactional()
        async def operation_b(db: AsyncSession) -> str:
            return "B"

        async with atomic(mock_async_session):
            result_a = await operation_a(db=mock_async_session)
            result_b = await operation_b(db=mock_async_session)

        assert result_a == "A"
        assert result_b == "B"

    @pytest.mark.asyncio
    async def test_explicit_savepoint_rollback(self, mock_async_session):
        """Rolling back a savepoint explicitly still lets the outer commit."""
        mock_nested = AsyncMock()
        mock_async_session.begin_nested.return_value = mock_nested

        async with AsyncTransactionContext(mock_async_session) as tx:
            async with tx.savepoint() as sp:
                await sp.rollback()
                # Verify rollback was called
                mock_nested.rollback.assert_called_once()

        # Transaction should still commit (only savepoint rolled back)
        mock_async_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_decorator_with_nested_calls(self, mock_async_session):
        """A decorated function may call another decorated function."""
        mock_async_session.in_transaction.return_value = True

        @transactional()
        async def inner_operation(db: AsyncSession, value: int) -> int:
            return value * 2

        @transactional()
        async def outer_operation(db: AsyncSession, value: int) -> int:
            result = await inner_operation(db, value)
            return result + 10

        result = await outer_operation(db=mock_async_session, value=5)

        assert result == 20  # (5 * 2) + 10
|