- Create transaction.py with @transactional decorator, atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager, get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
669 lines
24 KiB
Python
669 lines
24 KiB
Python
"""Unit tests for database transaction utilities.
|
|
|
|
Tests the transaction management utilities including decorators,
|
|
context managers, and helper functions.
|
|
"""
|
|
import asyncio
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
from sqlalchemy.orm import Session, sessionmaker
|
|
|
|
from src.server.database.base import Base
|
|
from src.server.database.transaction import (
|
|
AsyncTransactionContext,
|
|
TransactionContext,
|
|
TransactionError,
|
|
TransactionPropagation,
|
|
atomic,
|
|
atomic_sync,
|
|
is_in_transaction,
|
|
transactional,
|
|
)
|
|
|
|
# ============================================================================
|
|
# Fixtures
|
|
# ============================================================================
|
|
|
|
|
|
@pytest.fixture
async def async_engine():
    """Yield an async engine bound to an in-memory SQLite database.

    All tables registered on ``Base.metadata`` are created before the engine
    is handed to the test; the engine is disposed once the test is done.
    """
    db_engine = create_async_engine("sqlite+aiosqlite:///:memory:", echo=False)

    # Build the schema up front so tests can use the models immediately.
    async with db_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield db_engine

    await db_engine.dispose()
|
|
|
|
|
|
@pytest.fixture
async def async_session(async_engine):
    """Yield an ``AsyncSession`` bound to the ``async_engine`` fixture.

    The session is created with ``expire_on_commit=False`` and is rolled back
    after the test body has run.
    """
    from sqlalchemy.ext.asyncio import async_sessionmaker

    make_session = async_sessionmaker(
        async_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )

    async with make_session() as db_session:
        yield db_session
        # Discard whatever the test left pending on the session.
        await db_session.rollback()
|
|
|
|
|
|
# ============================================================================
|
|
# TransactionContext Tests (Sync)
|
|
# ============================================================================
|
|
|
|
|
|
class TestTransactionContext:
    """Exercise the synchronous ``TransactionContext`` against a mocked ``Session``."""

    @staticmethod
    def _fake_session(active=False):
        """Return a ``Session`` mock whose ``in_transaction()`` yields *active*."""
        fake = MagicMock(spec=Session)
        fake.in_transaction.return_value = active
        return fake

    def test_context_manager_protocol(self):
        """Entering begins a transaction; leaving without error commits it."""
        db = self._fake_session()

        with TransactionContext(db) as tx:
            assert tx.session == db
            db.begin.assert_called_once()

        db.commit.assert_called_once()

    def test_rollback_on_exception(self):
        """An exception raised in the body triggers rollback and no commit."""
        db = self._fake_session()

        with pytest.raises(ValueError):
            with TransactionContext(db):
                raise ValueError("Test error")

        db.rollback.assert_called_once()
        db.commit.assert_not_called()

    def test_no_begin_if_already_in_transaction(self):
        """``begin`` is skipped when the session reports an active transaction."""
        db = self._fake_session(active=True)

        with TransactionContext(db):
            pass

        db.begin.assert_not_called()

    def test_explicit_commit(self):
        """A manual ``commit()`` inside the block is not repeated on exit."""
        db = self._fake_session()

        with TransactionContext(db) as tx:
            tx.commit()
            db.commit.assert_called_once()

        # Leaving the block must not issue a second commit.
        assert db.commit.call_count == 1

    def test_explicit_rollback(self):
        """A manual ``rollback()`` inside the block suppresses the exit commit."""
        db = self._fake_session()

        with TransactionContext(db) as tx:
            tx.rollback()
            db.rollback.assert_called_once()

        # Nothing should be committed once the caller rolled back.
        db.commit.assert_not_called()
|
|
|
|
|
|
# ============================================================================
|
|
# AsyncTransactionContext Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestAsyncTransactionContext:
    """Exercise ``AsyncTransactionContext`` against a mocked ``AsyncSession``."""

    @pytest.mark.asyncio
    async def test_async_context_manager_protocol(self):
        """Entering begins a transaction; leaving without error commits it."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit", "rollback"):
            setattr(db, method, AsyncMock())

        async with AsyncTransactionContext(db) as tx:
            assert tx.session == db
            db.begin.assert_called_once()

        db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_rollback_on_exception(self):
        """An exception raised in the body triggers rollback and no commit."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit", "rollback"):
            setattr(db, method, AsyncMock())

        with pytest.raises(ValueError):
            async with AsyncTransactionContext(db):
                raise ValueError("Test error")

        db.rollback.assert_called_once()
        db.commit.assert_not_called()

    @pytest.mark.asyncio
    async def test_async_explicit_commit(self):
        """A manual ``commit()`` inside the block is not repeated on exit."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit"):
            setattr(db, method, AsyncMock())

        async with AsyncTransactionContext(db) as tx:
            await tx.commit()
            db.commit.assert_called_once()

        # Leaving the block must not issue a second commit.
        assert db.commit.call_count == 1

    @pytest.mark.asyncio
    async def test_async_explicit_rollback(self):
        """A manual ``rollback()`` inside the block suppresses the exit commit."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit", "rollback"):
            setattr(db, method, AsyncMock())

        async with AsyncTransactionContext(db) as tx:
            await tx.rollback()
            db.rollback.assert_called_once()

        # Nothing should be committed once the caller rolled back.
        db.commit.assert_not_called()
|
|
|
|
|
|
# ============================================================================
|
|
# atomic() Context Manager Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestAtomicContextManager:
    """Checks for the ``atomic()`` async context manager using mocked sessions."""

    @pytest.mark.asyncio
    async def test_atomic_commits_on_success(self):
        """A body that finishes normally results in exactly one commit."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit", "rollback"):
            setattr(db, method, AsyncMock())

        async with atomic(db):
            pass

        db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_rollback_on_failure(self):
        """An error in the body propagates and results in exactly one rollback."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit", "rollback"):
            setattr(db, method, AsyncMock())

        with pytest.raises(RuntimeError):
            async with atomic(db):
                raise RuntimeError("Operation failed")

        db.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_nested_propagation(self):
        """NESTED propagation inside an active transaction calls ``begin_nested``."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = True
        savepoint_tx = AsyncMock()
        db.begin_nested = AsyncMock(return_value=savepoint_tx)

        async with atomic(db, propagation=TransactionPropagation.NESTED):
            pass

        db.begin_nested.assert_called_once()

    @pytest.mark.asyncio
    async def test_atomic_required_propagation_default(self):
        """Without an explicit mode, ``atomic`` begins a transaction on entry."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit"):
            setattr(db, method, AsyncMock())

        async with atomic(db):
            # No transaction was active, so one must have been started by now.
            db.begin.assert_called_once()
|
|
|
|
|
|
# ============================================================================
|
|
# @transactional Decorator Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestTransactionalDecorator:
    """Checks for ``@transactional`` applied to coroutine functions."""

    @pytest.mark.asyncio
    async def test_async_function_wrapped(self):
        """The wrapped coroutine's result is returned and one commit is issued."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit", "rollback"):
            setattr(db, method, AsyncMock())

        @transactional()
        async def sample_operation(db: AsyncSession):
            return "result"

        outcome = await sample_operation(db=db)

        assert outcome == "result"
        db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_rollback_on_error(self):
        """An error in the wrapped coroutine propagates and rolls back once."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit", "rollback"):
            setattr(db, method, AsyncMock())

        @transactional()
        async def failing_operation(db: AsyncSession):
            raise ValueError("Operation failed")

        with pytest.raises(ValueError):
            await failing_operation(db=db)

        db.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_custom_session_param_name(self):
        """``session_param`` selects which keyword argument carries the session."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit"):
            setattr(db, method, AsyncMock())

        @transactional(session_param="session")
        async def operation_with_session(session: AsyncSession):
            return "done"

        outcome = await operation_with_session(session=db)

        assert outcome == "done"
        db.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_missing_session_raises_error(self):
        """Calling a decorated coroutine without a session raises ``TransactionError``."""

        @transactional()
        async def operation_no_session(data: dict):
            return data

        with pytest.raises(TransactionError):
            await operation_no_session(data={"key": "value"})

    @pytest.mark.asyncio
    async def test_propagation_passed_to_atomic(self):
        """NESTED propagation on the decorator results in a ``begin_nested`` call."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = True
        savepoint_tx = AsyncMock()
        db.begin_nested = AsyncMock(return_value=savepoint_tx)

        @transactional(propagation=TransactionPropagation.NESTED)
        async def nested_operation(db: AsyncSession):
            return "nested"

        outcome = await nested_operation(db=db)

        assert outcome == "nested"
        db.begin_nested.assert_called_once()
|
|
|
|
|
|
# ============================================================================
|
|
# Sync transactional decorator Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestSyncTransactionalDecorator:
    """Checks for ``@transactional`` applied to plain (non-async) functions."""

    def test_sync_function_wrapped(self):
        """The wrapped function's result is returned and one commit is issued."""
        db = MagicMock(spec=Session)
        db.in_transaction.return_value = False

        @transactional()
        def sample_sync_operation(db: Session):
            return "sync_result"

        outcome = sample_sync_operation(db=db)

        assert outcome == "sync_result"
        db.commit.assert_called_once()

    def test_sync_rollback_on_error(self):
        """An error in the wrapped function propagates and rolls back once."""
        db = MagicMock(spec=Session)
        db.in_transaction.return_value = False

        @transactional()
        def failing_sync_operation(db: Session):
            raise ValueError("Sync operation failed")

        with pytest.raises(ValueError):
            failing_sync_operation(db=db)

        db.rollback.assert_called_once()
|
|
|
|
|
|
# ============================================================================
|
|
# Helper Function Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestHelperFunctions:
    """Checks for the ``is_in_transaction`` helper."""

    def test_is_in_transaction_true(self):
        """The helper returns ``True`` when the session reports a transaction."""
        db = MagicMock()
        db.in_transaction.return_value = True

        outcome = is_in_transaction(db)

        assert outcome is True

    def test_is_in_transaction_false(self):
        """The helper returns ``False`` when the session reports no transaction."""
        db = MagicMock()
        db.in_transaction.return_value = False

        outcome = is_in_transaction(db)

        assert outcome is False
|
|
|
|
|
|
# ============================================================================
|
|
# Integration Tests with Real Database
|
|
# ============================================================================
|
|
|
|
|
|
class TestTransactionIntegration:
    """Integration tests that run ``atomic`` against the ``async_session`` fixture.

    Unlike the mock-based classes in this module, these tests insert real
    ``AnimeSeries`` rows and query them back through the same session.
    """

    @pytest.mark.asyncio
    async def test_real_transaction_commit(self, async_session):
        """A row added inside ``atomic`` is returned by a later query.

        Args:
            async_session: Session fixture the row is written to and read from.
        """
        from sqlalchemy import select

        from src.server.database.models import AnimeSeries

        async with atomic(async_session):
            series = AnimeSeries(
                key="test-series",
                name="Test Series",
                site="https://test.com",
                folder="/test/folder",
            )
            async_session.add(series)

        # Read the row back through the same session once the block has exited.
        result = await async_session.execute(
            select(AnimeSeries).where(AnimeSeries.key == "test-series")
        )
        saved_series = result.scalar_one_or_none()

        assert saved_series is not None
        assert saved_series.name == "Test Series"

    @pytest.mark.asyncio
    async def test_real_transaction_rollback(self, async_session):
        """A row flushed inside a failing ``atomic`` block is not kept.

        The error must also propagate out of ``atomic``; ``pytest.raises``
        makes the test fail if the context manager swallows it.

        Args:
            async_session: Session fixture the row is written to and read from.
        """
        from sqlalchemy import select

        from src.server.database.models import AnimeSeries

        with pytest.raises(ValueError):
            async with atomic(async_session):
                series = AnimeSeries(
                    key="rollback-series",
                    name="Rollback Series",
                    site="https://test.com",
                    folder="/test/folder",
                )
                async_session.add(series)
                # Flush so the INSERT is actually emitted before the failure.
                await async_session.flush()

                # Force rollback
                raise ValueError("Simulated error")

        # The flushed row must be gone after the failed block.
        result = await async_session.execute(
            select(AnimeSeries).where(AnimeSeries.key == "rollback-series")
        )
        saved_series = result.scalar_one_or_none()

        assert saved_series is None
|
|
|
|
|
|
# ============================================================================
|
|
# TransactionPropagation Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestTransactionPropagation:
    """Checks for the ``TransactionPropagation`` enum."""

    def test_propagation_enum_values(self):
        """Each propagation member carries its expected string value."""
        expected_values = (
            (TransactionPropagation.REQUIRED, "required"),
            (TransactionPropagation.REQUIRES_NEW, "requires_new"),
            (TransactionPropagation.NESTED, "nested"),
        )

        for member, expected in expected_values:
            assert member.value == expected
|
|
|
|
|
|
# ============================================================================
|
|
# Additional Coverage Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestSyncSavepointCoverage:
    """Savepoint behaviour of the synchronous ``TransactionContext``.

    ``Session.begin_nested`` is mocked to return ``mock_nested`` so the tests
    can observe what the savepoint does with the nested transaction.
    """

    def test_savepoint_exception_rolls_back(self):
        """An error raised inside a savepoint rolls the nested transaction back.

        The error must propagate out of the savepoint block as well.
        """
        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = False
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested

        with TransactionContext(mock_session) as ctx:
            with pytest.raises(ValueError):
                with ctx.savepoint():
                    raise ValueError("Error in savepoint")

        # Nested transaction should have been rolled back
        mock_nested.rollback.assert_called_once()

    def test_savepoint_commit_explicit(self):
        """An explicit ``commit()`` on a savepoint completes without raising.

        The savepoint is expected to have been opened through one
        ``begin_nested`` call; previously this test asserted nothing at all.
        """
        mock_session = MagicMock(spec=Session)
        mock_session.in_transaction.return_value = False
        mock_nested = MagicMock()
        mock_session.begin_nested.return_value = mock_nested

        with TransactionContext(mock_session) as ctx:
            with ctx.savepoint() as sp:
                sp.commit()
                # Commit should just log, SQLAlchemy handles actual commit

        # The savepoint must have been backed by a single nested transaction.
        mock_session.begin_nested.assert_called_once()
|
|
|
|
|
|
class TestAsyncSavepointCoverage:
    """Savepoint behaviour of ``AsyncTransactionContext``.

    ``begin_nested`` is replaced with an ``AsyncMock`` returning
    ``mock_nested`` so the tests can observe what the savepoint does with the
    nested transaction.
    """

    @pytest.mark.asyncio
    async def test_async_savepoint_exception_rolls_back(self):
        """An error raised inside a savepoint rolls the nested transaction back.

        The error must propagate out of the savepoint block as well.
        """
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_session.rollback = AsyncMock()
        mock_nested = AsyncMock()
        mock_nested.rollback = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)

        async with AsyncTransactionContext(mock_session) as ctx:
            with pytest.raises(ValueError):
                async with ctx.savepoint():
                    raise ValueError("Error in async savepoint")

        # Nested transaction should have been rolled back
        mock_nested.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_async_savepoint_commit_explicit(self):
        """An explicit ``commit()`` on a savepoint completes without raising.

        The savepoint is expected to have been opened through one
        ``begin_nested`` call; previously this test asserted nothing at all.
        """
        mock_session = AsyncMock(spec=AsyncSession)
        mock_session.in_transaction.return_value = False
        mock_session.begin = AsyncMock()
        mock_session.commit = AsyncMock()
        mock_nested = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)

        async with AsyncTransactionContext(mock_session) as ctx:
            async with ctx.savepoint() as sp:
                await sp.commit()
                # Commit should just log, SQLAlchemy handles actual commit

        # The savepoint must have been backed by a single nested transaction.
        mock_session.begin_nested.assert_called_once()
|
|
|
|
|
|
class TestAtomicNestedPropagationNoTransaction:
    """NESTED propagation when the session has no active transaction."""

    @pytest.mark.asyncio
    async def test_async_nested_starts_new_when_not_in_transaction(self):
        """``atomic`` with NESTED begins and commits a transaction when none exists."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit", "rollback"):
            setattr(db, method, AsyncMock())

        # Propagation is passed positionally on purpose.
        async with atomic(db, TransactionPropagation.NESTED):
            pass

        db.begin.assert_called_once()
        db.commit.assert_called_once()

    def test_sync_nested_starts_new_when_not_in_transaction(self):
        """``atomic_sync`` with NESTED begins and commits a transaction when none exists."""
        db = MagicMock(spec=Session)
        db.in_transaction.return_value = False

        with atomic_sync(db, TransactionPropagation.NESTED):
            pass

        db.begin.assert_called_once()
        db.commit.assert_called_once()
|
|
|
|
|
|
class TestGetTransactionDepth:
    """Checks for the ``get_transaction_depth`` helper."""

    def test_depth_zero_when_not_in_transaction(self):
        """Depth is 0 when the session reports no transaction."""
        from src.server.database.transaction import get_transaction_depth

        db = MagicMock(spec=Session)
        db.in_transaction.return_value = False

        assert get_transaction_depth(db) == 0

    def test_depth_one_in_transaction(self):
        """Depth is 1 with a transaction and ``_nested_transaction`` set to None."""
        from src.server.database.transaction import get_transaction_depth

        db = MagicMock(spec=Session)
        db.in_transaction.return_value = True
        db._nested_transaction = None

        assert get_transaction_depth(db) == 1

    def test_depth_two_with_nested_transaction(self):
        """Depth is 2 with a transaction and a non-None ``_nested_transaction``."""
        from src.server.database.transaction import get_transaction_depth

        db = MagicMock(spec=Session)
        db.in_transaction.return_value = True
        db._nested_transaction = MagicMock()  # stands in for an open savepoint

        assert get_transaction_depth(db) == 2
|
|
|
|
|
|
class TestTransactionalDecoratorPositionalArgs:
    """``@transactional`` when the session is supplied positionally."""

    @pytest.mark.asyncio
    async def test_session_from_positional_arg(self):
        """The async wrapper finds the session among positional arguments."""
        db = AsyncMock(spec=AsyncSession)
        db.in_transaction.return_value = False
        for method in ("begin", "commit", "rollback"):
            setattr(db, method, AsyncMock())

        @transactional()
        async def operation(db: AsyncSession, data: str):
            return f"processed: {data}"

        # The session goes in as the first positional argument, not a keyword.
        outcome = await operation(db, "test")

        assert outcome == "processed: test"
        db.commit.assert_called_once()

    def test_sync_session_from_positional_arg(self):
        """The sync wrapper finds the session among positional arguments."""
        db = MagicMock(spec=Session)
        db.in_transaction.return_value = False

        @transactional()
        def operation(db: Session, data: str):
            return f"processed: {data}"

        outcome = operation(db, "test")

        assert outcome == "processed: test"
        db.commit.assert_called_once()
|