Add database transaction support with atomic operations

- Create transaction.py with @transactional decorator, atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager, get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
This commit is contained in:
2025-12-25 18:05:33 +01:00
parent b2728a7cf4
commit 1ba67357dc
15 changed files with 3385 additions and 202 deletions

View File

@@ -0,0 +1,546 @@
"""Integration tests for database transaction behavior.
Tests real database transaction handling including:
- Transaction isolation
- Concurrent transaction handling
- Real commit/rollback behavior
"""
import asyncio
from datetime import datetime, timedelta, timezone
from typing import List
import pytest
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from src.server.database.base import Base
from src.server.database.connection import (
TransactionManager,
get_session_transaction_depth,
is_session_in_transaction,
)
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
EpisodeService,
)
from src.server.database.transaction import (
TransactionPropagation,
atomic,
transactional,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
async def db_engine():
"""Create in-memory database engine for testing."""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def session_factory(db_engine):
"""Create session factory for testing."""
from sqlalchemy.ext.asyncio import async_sessionmaker
return async_sessionmaker(
db_engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
autocommit=False,
)
@pytest.fixture
async def db_session(session_factory):
"""Create database session for testing."""
async with session_factory() as session:
yield session
await session.rollback()
# ============================================================================
# Real Database Transaction Tests
# ============================================================================
class TestRealDatabaseTransactions:
"""Tests using real in-memory database."""
@pytest.mark.asyncio
async def test_commit_persists_data(self, db_session):
"""Test that committed data is actually persisted."""
async with atomic(db_session):
series = await AnimeSeriesService.create(
db_session,
key="commit-test",
name="Commit Test Series",
site="https://test.com",
folder="/test/folder",
)
# Data should be retrievable after commit
retrieved = await AnimeSeriesService.get_by_key(
db_session, "commit-test"
)
assert retrieved is not None
assert retrieved.name == "Commit Test Series"
@pytest.mark.asyncio
async def test_rollback_discards_data(self, db_session):
"""Test that rolled back data is discarded."""
try:
async with atomic(db_session):
series = await AnimeSeriesService.create(
db_session,
key="rollback-test",
name="Rollback Test Series",
site="https://test.com",
folder="/test/folder",
)
await db_session.flush()
raise ValueError("Force rollback")
except ValueError:
pass
# Data should NOT be retrievable after rollback
retrieved = await AnimeSeriesService.get_by_key(
db_session, "rollback-test"
)
assert retrieved is None
@pytest.mark.asyncio
async def test_multiple_operations_atomic(self, db_session):
"""Test multiple operations are committed together."""
async with atomic(db_session):
# Create series
series = await AnimeSeriesService.create(
db_session,
key="atomic-multi-test",
name="Atomic Multi Test",
site="https://test.com",
folder="/test/folder",
)
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
title="Episode 1",
)
# Create queue item
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
# All should be persisted
retrieved_series = await AnimeSeriesService.get_by_key(
db_session, "atomic-multi-test"
)
assert retrieved_series is not None
episodes = await EpisodeService.get_by_series(
db_session, retrieved_series.id
)
assert len(episodes) == 1
queue_items = await DownloadQueueService.get_all(db_session)
assert len(queue_items) >= 1
@pytest.mark.asyncio
async def test_multiple_operations_rollback_all(self, db_session):
"""Test multiple operations are all rolled back on failure."""
try:
async with atomic(db_session):
# Create series
series = await AnimeSeriesService.create(
db_session,
key="rollback-multi-test",
name="Rollback Multi Test",
site="https://test.com",
folder="/test/folder",
)
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
# Create queue item
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
await db_session.flush()
raise RuntimeError("Force complete rollback")
except RuntimeError:
pass
# None should be persisted
retrieved_series = await AnimeSeriesService.get_by_key(
db_session, "rollback-multi-test"
)
assert retrieved_series is None
# ============================================================================
# Transaction Manager Tests
# ============================================================================
class TestTransactionManager:
"""Tests for TransactionManager class."""
@pytest.mark.asyncio
async def test_transaction_manager_basic_flow(self, session_factory):
"""Test basic transaction manager usage."""
async with TransactionManager(session_factory) as tm:
session = await tm.get_session()
await tm.begin()
series = AnimeSeries(
key="tm-test",
name="TM Test",
site="https://test.com",
folder="/test",
)
session.add(series)
await tm.commit()
# Verify data persisted
async with session_factory() as verify_session:
from sqlalchemy import select
result = await verify_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "tm-test")
)
series = result.scalar_one_or_none()
assert series is not None
@pytest.mark.asyncio
async def test_transaction_manager_rollback(self, session_factory):
"""Test transaction manager rollback."""
async with TransactionManager(session_factory) as tm:
session = await tm.get_session()
await tm.begin()
series = AnimeSeries(
key="tm-rollback-test",
name="TM Rollback Test",
site="https://test.com",
folder="/test",
)
session.add(series)
await session.flush()
await tm.rollback()
# Verify data NOT persisted
async with session_factory() as verify_session:
from sqlalchemy import select
result = await verify_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "tm-rollback-test")
)
series = result.scalar_one_or_none()
assert series is None
@pytest.mark.asyncio
async def test_transaction_manager_auto_rollback_on_exception(
self, session_factory
):
"""Test transaction manager auto-rolls back on exception."""
with pytest.raises(ValueError):
async with TransactionManager(session_factory) as tm:
session = await tm.get_session()
await tm.begin()
series = AnimeSeries(
key="tm-auto-rollback",
name="TM Auto Rollback",
site="https://test.com",
folder="/test",
)
session.add(series)
await session.flush()
raise ValueError("Force exception")
# Verify data NOT persisted
async with session_factory() as verify_session:
from sqlalchemy import select
result = await verify_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "tm-auto-rollback")
)
series = result.scalar_one_or_none()
assert series is None
@pytest.mark.asyncio
async def test_transaction_manager_state_tracking(self, session_factory):
"""Test transaction manager tracks state correctly."""
async with TransactionManager(session_factory) as tm:
assert tm.is_in_transaction() is False
await tm.begin()
assert tm.is_in_transaction() is True
await tm.commit()
assert tm.is_in_transaction() is False
# ============================================================================
# Helper Function Tests
# ============================================================================
class TestConnectionHelpers:
"""Tests for connection module helper functions."""
@pytest.mark.asyncio
async def test_is_session_in_transaction(self, db_session):
"""Test is_session_in_transaction helper."""
# Initially not in transaction
assert is_session_in_transaction(db_session) is False
async with atomic(db_session):
# Now in transaction
assert is_session_in_transaction(db_session) is True
# After exit, depends on session state
# SQLite behavior may vary
@pytest.mark.asyncio
async def test_get_session_transaction_depth(self, db_session):
"""Test get_session_transaction_depth helper."""
depth = get_session_transaction_depth(db_session)
assert depth >= 0
# ============================================================================
# @transactional Decorator Integration Tests
# ============================================================================
class TestTransactionalDecoratorIntegration:
"""Integration tests for @transactional decorator."""
@pytest.mark.asyncio
async def test_decorated_function_commits(self, db_session):
"""Test decorated function commits on success."""
@transactional()
async def create_series_decorated(db: AsyncSession):
return await AnimeSeriesService.create(
db,
key="decorated-test",
name="Decorated Test",
site="https://test.com",
folder="/test",
)
series = await create_series_decorated(db=db_session)
# Verify committed
retrieved = await AnimeSeriesService.get_by_key(
db_session, "decorated-test"
)
assert retrieved is not None
@pytest.mark.asyncio
async def test_decorated_function_rollback(self, db_session):
"""Test decorated function rolls back on error."""
@transactional()
async def create_then_fail(db: AsyncSession):
await AnimeSeriesService.create(
db,
key="decorated-rollback",
name="Decorated Rollback",
site="https://test.com",
folder="/test",
)
raise ValueError("Force failure")
with pytest.raises(ValueError):
await create_then_fail(db=db_session)
# Verify NOT committed
retrieved = await AnimeSeriesService.get_by_key(
db_session, "decorated-rollback"
)
assert retrieved is None
@pytest.mark.asyncio
async def test_nested_decorated_functions(self, db_session):
"""Test nested decorated functions work correctly."""
@transactional(propagation=TransactionPropagation.NESTED)
async def inner_operation(db: AsyncSession, series_id: int):
return await EpisodeService.create(
db,
series_id=series_id,
season=1,
episode_number=1,
)
@transactional()
async def outer_operation(db: AsyncSession):
series = await AnimeSeriesService.create(
db,
key="nested-decorated",
name="Nested Decorated",
site="https://test.com",
folder="/test",
)
episode = await inner_operation(db=db, series_id=series.id)
return series, episode
series, episode = await outer_operation(db=db_session)
# Both should be committed
assert series is not None
assert episode is not None
# ============================================================================
# Concurrent Transaction Tests
# ============================================================================
class TestConcurrentTransactions:
"""Tests for concurrent transaction handling."""
@pytest.mark.asyncio
async def test_concurrent_writes_different_keys(self, session_factory):
"""Test concurrent writes to different records."""
async def create_series(key: str):
async with session_factory() as session:
async with atomic(session):
await AnimeSeriesService.create(
session,
key=key,
name=f"Series {key}",
site="https://test.com",
folder=f"/test/{key}",
)
# Run concurrent creates
await asyncio.gather(
create_series("concurrent-1"),
create_series("concurrent-2"),
create_series("concurrent-3"),
)
# Verify all created
async with session_factory() as verify_session:
for i in range(1, 4):
series = await AnimeSeriesService.get_by_key(
verify_session, f"concurrent-{i}"
)
assert series is not None
# ============================================================================
# Queue Repository Transaction Tests
# ============================================================================
class TestQueueRepositoryTransactions:
"""Integration tests for QueueRepository transaction handling."""
@pytest.mark.asyncio
async def test_save_item_atomic(self, session_factory):
"""Test save_item creates series, episode, and queue item atomically."""
from src.server.models.download import (
DownloadItem,
DownloadStatus,
EpisodeIdentifier,
)
from src.server.services.queue_repository import QueueRepository
repo = QueueRepository(session_factory)
item = DownloadItem(
id="temp-id",
serie_id="repo-atomic-test",
serie_folder="/test/folder",
serie_name="Repo Atomic Test",
episode=EpisodeIdentifier(season=1, episode=1),
status=DownloadStatus.PENDING,
)
saved_item = await repo.save_item(item)
assert saved_item.id != "temp-id" # Should have DB ID
# Verify all entities created
async with session_factory() as verify_session:
series = await AnimeSeriesService.get_by_key(
verify_session, "repo-atomic-test"
)
assert series is not None
episodes = await EpisodeService.get_by_series(
verify_session, series.id
)
assert len(episodes) == 1
queue_items = await DownloadQueueService.get_all(verify_session)
assert len(queue_items) >= 1
@pytest.mark.asyncio
async def test_clear_all_atomic(self, session_factory):
"""Test clear_all removes all items atomically."""
from src.server.models.download import (
DownloadItem,
DownloadStatus,
EpisodeIdentifier,
)
from src.server.services.queue_repository import QueueRepository
repo = QueueRepository(session_factory)
# Add some items
for i in range(3):
item = DownloadItem(
id=f"clear-{i}",
serie_id=f"clear-series-{i}",
serie_folder=f"/test/folder/{i}",
serie_name=f"Clear Series {i}",
episode=EpisodeIdentifier(season=1, episode=1),
status=DownloadStatus.PENDING,
)
await repo.save_item(item)
# Clear all
count = await repo.clear_all()
assert count == 3
# Verify all cleared
async with session_factory() as verify_session:
queue_items = await DownloadQueueService.get_all(verify_session)
assert len(queue_items) == 0

View File

@@ -0,0 +1,546 @@
"""Unit tests for service layer transaction behavior.
Tests that service operations correctly handle transactions,
especially compound operations that require atomicity.
"""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from src.server.database.base import Base
from src.server.database.models import (
AnimeSeries,
DownloadQueueItem,
Episode,
UserSession,
)
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
EpisodeService,
UserSessionService,
)
from src.server.database.transaction import atomic
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
async def db_engine():
"""Create in-memory database engine for testing."""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def db_session(db_engine):
"""Create database session for testing."""
from sqlalchemy.ext.asyncio import async_sessionmaker
async_session = async_sessionmaker(
db_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session() as session:
yield session
await session.rollback()
# ============================================================================
# AnimeSeriesService Transaction Tests
# ============================================================================
class TestAnimeSeriesServiceTransactions:
"""Tests for AnimeSeriesService transaction behavior."""
@pytest.mark.asyncio
async def test_create_uses_flush_not_commit(self, db_session):
"""Test create uses flush for transaction compatibility."""
series = await AnimeSeriesService.create(
db_session,
key="test-key",
name="Test Series",
site="https://test.com",
folder="/test/folder",
)
# Series should exist in session
assert series.id is not None
# But not committed yet (we're in an uncommitted transaction)
# We can verify by checking the session's uncommitted state
assert series in db_session
@pytest.mark.asyncio
async def test_update_uses_flush_not_commit(self, db_session):
"""Test update uses flush for transaction compatibility."""
# Create series
series = await AnimeSeriesService.create(
db_session,
key="update-test",
name="Original Name",
site="https://test.com",
folder="/test/folder",
)
# Update series
updated = await AnimeSeriesService.update(
db_session,
series.id,
name="Updated Name",
)
assert updated.name == "Updated Name"
assert updated in db_session
# ============================================================================
# EpisodeService Transaction Tests
# ============================================================================
class TestEpisodeServiceTransactions:
"""Tests for EpisodeService transaction behavior."""
@pytest.mark.asyncio
async def test_bulk_mark_downloaded_atomicity(self, db_session):
"""Test bulk_mark_downloaded updates all or none."""
# Create series and episodes
series = await AnimeSeriesService.create(
db_session,
key="bulk-test-series",
name="Bulk Test",
site="https://test.com",
folder="/test/folder",
)
episodes = []
for i in range(1, 4):
ep = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=i,
title=f"Episode {i}",
)
episodes.append(ep)
episode_ids = [ep.id for ep in episodes]
file_paths = [f"/path/ep{i}.mp4" for i in range(1, 4)]
# Bulk update within atomic context
async with atomic(db_session):
count = await EpisodeService.bulk_mark_downloaded(
db_session,
episode_ids,
file_paths,
)
assert count == 3
# Verify all episodes were marked
for i, ep_id in enumerate(episode_ids):
episode = await EpisodeService.get_by_id(db_session, ep_id)
assert episode.is_downloaded is True
assert episode.file_path == file_paths[i]
@pytest.mark.asyncio
async def test_bulk_mark_downloaded_empty_list(self, db_session):
"""Test bulk_mark_downloaded handles empty list."""
count = await EpisodeService.bulk_mark_downloaded(
db_session,
episode_ids=[],
)
assert count == 0
@pytest.mark.asyncio
async def test_delete_by_series_and_episode_transaction(self, db_session):
"""Test delete_by_series_and_episode in transaction."""
# Create series and episode
series = await AnimeSeriesService.create(
db_session,
key="delete-test-series",
name="Delete Test",
site="https://test.com",
folder="/test/folder",
)
await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
title="Episode 1",
)
await db_session.commit()
# Delete episode within transaction
async with atomic(db_session):
deleted = await EpisodeService.delete_by_series_and_episode(
db_session,
series_key="delete-test-series",
season=1,
episode_number=1,
)
assert deleted is True
# Verify episode is gone
episode = await EpisodeService.get_by_episode(
db_session,
series.id,
season=1,
episode_number=1,
)
assert episode is None
# ============================================================================
# DownloadQueueService Transaction Tests
# ============================================================================
class TestDownloadQueueServiceTransactions:
"""Tests for DownloadQueueService transaction behavior."""
@pytest.mark.asyncio
async def test_bulk_delete_atomicity(self, db_session):
"""Test bulk_delete removes all or none."""
# Create series and episodes
series = await AnimeSeriesService.create(
db_session,
key="queue-bulk-test",
name="Queue Bulk Test",
site="https://test.com",
folder="/test/folder",
)
item_ids = []
for i in range(1, 4):
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=i,
)
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
item_ids.append(item.id)
# Bulk delete within atomic context
async with atomic(db_session):
count = await DownloadQueueService.bulk_delete(
db_session,
item_ids,
)
assert count == 3
# Verify all items deleted
all_items = await DownloadQueueService.get_all(db_session)
assert len(all_items) == 0
@pytest.mark.asyncio
async def test_bulk_delete_empty_list(self, db_session):
"""Test bulk_delete handles empty list."""
count = await DownloadQueueService.bulk_delete(
db_session,
item_ids=[],
)
assert count == 0
@pytest.mark.asyncio
async def test_clear_all_atomicity(self, db_session):
"""Test clear_all removes all items atomically."""
# Create series and queue items
series = await AnimeSeriesService.create(
db_session,
key="clear-all-test",
name="Clear All Test",
site="https://test.com",
folder="/test/folder",
)
for i in range(1, 4):
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=i,
)
await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
# Clear all within atomic context
async with atomic(db_session):
count = await DownloadQueueService.clear_all(db_session)
assert count == 3
# Verify all items cleared
all_items = await DownloadQueueService.get_all(db_session)
assert len(all_items) == 0
# ============================================================================
# UserSessionService Transaction Tests
# ============================================================================
class TestUserSessionServiceTransactions:
"""Tests for UserSessionService transaction behavior."""
@pytest.mark.asyncio
async def test_rotate_session_atomicity(self, db_session):
"""Test rotate_session is atomic (revoke + create)."""
# Create old session
old_session = await UserSessionService.create(
db_session,
session_id="old-session-123",
token_hash="old_hash",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
)
await db_session.commit()
# Rotate session within atomic context
async with atomic(db_session):
new_session = await UserSessionService.rotate_session(
db_session,
old_session_id="old-session-123",
new_session_id="new-session-456",
new_token_hash="new_hash",
new_expires_at=datetime.now(timezone.utc) + timedelta(hours=2),
)
assert new_session is not None
assert new_session.session_id == "new-session-456"
# Verify old session is revoked
old = await UserSessionService.get_by_session_id(
db_session, "old-session-123"
)
assert old.is_active is False
@pytest.mark.asyncio
async def test_rotate_session_old_not_found(self, db_session):
"""Test rotate_session returns None if old session not found."""
result = await UserSessionService.rotate_session(
db_session,
old_session_id="nonexistent-session",
new_session_id="new-session",
new_token_hash="hash",
new_expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
)
assert result is None
@pytest.mark.asyncio
async def test_cleanup_expired_bulk_delete(self, db_session):
"""Test cleanup_expired removes all expired sessions."""
# Create expired sessions
for i in range(3):
await UserSessionService.create(
db_session,
session_id=f"expired-{i}",
token_hash=f"hash-{i}",
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
)
# Create active session
await UserSessionService.create(
db_session,
session_id="active-session",
token_hash="active_hash",
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
)
await db_session.commit()
# Cleanup expired within atomic context
async with atomic(db_session):
count = await UserSessionService.cleanup_expired(db_session)
assert count == 3
# Verify active session still exists
active = await UserSessionService.get_by_session_id(
db_session, "active-session"
)
assert active is not None
# ============================================================================
# Compound Operation Rollback Tests
# ============================================================================
class TestCompoundOperationRollback:
"""Tests for rollback behavior in compound operations."""
@pytest.mark.asyncio
async def test_rollback_on_partial_failure(self, db_session):
"""Test rollback when compound operation fails mid-way."""
# Create initial series
series = await AnimeSeriesService.create(
db_session,
key="rollback-test-series",
name="Rollback Test",
site="https://test.com",
folder="/test/folder",
)
await db_session.commit()
# Store the id before starting the transaction to avoid expired state access
series_id = series.id
try:
async with atomic(db_session):
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series_id,
season=1,
episode_number=1,
)
# Force flush to persist episode in transaction
await db_session.flush()
# Simulate failure mid-operation
raise ValueError("Simulated failure")
except ValueError:
pass
# Verify episode was NOT persisted
episode = await EpisodeService.get_by_episode(
db_session,
series_id,
season=1,
episode_number=1,
)
assert episode is None
@pytest.mark.asyncio
async def test_no_orphan_data_on_failure(self, db_session):
"""Test no orphaned data when multi-service operation fails."""
try:
async with atomic(db_session):
# Create series
series = await AnimeSeriesService.create(
db_session,
key="orphan-test-series",
name="Orphan Test",
site="https://test.com",
folder="/test/folder",
)
# Create episode
episode = await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
# Create queue item
item = await DownloadQueueService.create(
db_session,
series_id=series.id,
episode_id=episode.id,
)
await db_session.flush()
# Fail after all creates
raise RuntimeError("Critical failure")
except RuntimeError:
pass
# Verify nothing was persisted
all_series = await AnimeSeriesService.get_all(db_session)
series_keys = [s.key for s in all_series]
assert "orphan-test-series" not in series_keys
# ============================================================================
# Nested Transaction Tests
# ============================================================================
class TestNestedTransactions:
"""Tests for nested transaction (savepoint) behavior."""
@pytest.mark.asyncio
async def test_savepoint_partial_rollback(self, db_session):
"""Test savepoint allows partial rollback."""
# Create series
series = await AnimeSeriesService.create(
db_session,
key="savepoint-test",
name="Savepoint Test",
site="https://test.com",
folder="/test/folder",
)
async with atomic(db_session) as tx:
# Create first episode (should persist)
await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=1,
)
# Nested transaction for second episode
async with tx.savepoint() as sp:
await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=2,
)
# Rollback only the savepoint
await sp.rollback()
# Create third episode (should persist)
await EpisodeService.create(
db_session,
series_id=series.id,
season=1,
episode_number=3,
)
# Verify first and third episodes exist, second doesn't
episodes = await EpisodeService.get_by_series(db_session, series.id)
episode_numbers = [ep.episode_number for ep in episodes]
assert 1 in episode_numbers
assert 2 not in episode_numbers # Rolled back
assert 3 in episode_numbers

View File

@@ -0,0 +1,668 @@
"""Unit tests for database transaction utilities.
Tests the transaction management utilities including decorators,
context managers, and helper functions.
"""
import asyncio
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import Session, sessionmaker
from src.server.database.base import Base
from src.server.database.transaction import (
AsyncTransactionContext,
TransactionContext,
TransactionError,
TransactionPropagation,
atomic,
atomic_sync,
is_in_transaction,
transactional,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
async def async_engine():
"""Create in-memory async database engine for testing."""
engine = create_async_engine(
"sqlite+aiosqlite:///:memory:",
echo=False,
)
async with engine.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
yield engine
await engine.dispose()
@pytest.fixture
async def async_session(async_engine):
"""Create async database session for testing."""
from sqlalchemy.ext.asyncio import async_sessionmaker
async_session_factory = async_sessionmaker(
async_engine,
class_=AsyncSession,
expire_on_commit=False,
)
async with async_session_factory() as session:
yield session
await session.rollback()
# ============================================================================
# TransactionContext Tests (Sync)
# ============================================================================
class TestTransactionContext:
"""Tests for synchronous TransactionContext."""
def test_context_manager_protocol(self):
"""Test context manager enters and exits properly."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with TransactionContext(mock_session) as ctx:
assert ctx.session == mock_session
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
def test_rollback_on_exception(self):
"""Test rollback is called when exception occurs."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with pytest.raises(ValueError):
with TransactionContext(mock_session):
raise ValueError("Test error")
mock_session.rollback.assert_called_once()
mock_session.commit.assert_not_called()
def test_no_begin_if_already_in_transaction(self):
"""Test no new transaction started if already in one."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = True
with TransactionContext(mock_session):
pass
mock_session.begin.assert_not_called()
def test_explicit_commit(self):
"""Test explicit commit within context."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with TransactionContext(mock_session) as ctx:
ctx.commit()
mock_session.commit.assert_called_once()
# Should not commit again on exit
assert mock_session.commit.call_count == 1
def test_explicit_rollback(self):
"""Test explicit rollback within context."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with TransactionContext(mock_session) as ctx:
ctx.rollback()
mock_session.rollback.assert_called_once()
# Should not commit after explicit rollback
mock_session.commit.assert_not_called()
# ============================================================================
# AsyncTransactionContext Tests
# ============================================================================
class TestAsyncTransactionContext:
"""Tests for asynchronous AsyncTransactionContext."""
@pytest.mark.asyncio
async def test_async_context_manager_protocol(self):
"""Test async context manager enters and exits properly."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
async with AsyncTransactionContext(mock_session) as ctx:
assert ctx.session == mock_session
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_async_rollback_on_exception(self):
"""Test async rollback is called when exception occurs."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
with pytest.raises(ValueError):
async with AsyncTransactionContext(mock_session):
raise ValueError("Test error")
mock_session.rollback.assert_called_once()
mock_session.commit.assert_not_called()
@pytest.mark.asyncio
async def test_async_explicit_commit(self):
"""Test async explicit commit within context."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
async with AsyncTransactionContext(mock_session) as ctx:
await ctx.commit()
mock_session.commit.assert_called_once()
# Should not commit again on exit
assert mock_session.commit.call_count == 1
@pytest.mark.asyncio
async def test_async_explicit_rollback(self):
"""Test async explicit rollback within context."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
async with AsyncTransactionContext(mock_session) as ctx:
await ctx.rollback()
mock_session.rollback.assert_called_once()
# Should not commit after explicit rollback
mock_session.commit.assert_not_called()
# ============================================================================
# atomic() Context Manager Tests
# ============================================================================
class TestAtomicContextManager:
"""Tests for atomic() async context manager."""
@pytest.mark.asyncio
async def test_atomic_commits_on_success(self):
"""Test atomic commits transaction on success."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
async with atomic(mock_session) as tx:
pass
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_rollback_on_failure(self):
"""Test atomic rolls back transaction on failure."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
with pytest.raises(RuntimeError):
async with atomic(mock_session):
raise RuntimeError("Operation failed")
mock_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_nested_propagation(self):
"""Test atomic with NESTED propagation creates savepoint."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = True
mock_nested = AsyncMock()
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
async with atomic(
mock_session, propagation=TransactionPropagation.NESTED
):
pass
mock_session.begin_nested.assert_called_once()
@pytest.mark.asyncio
async def test_atomic_required_propagation_default(self):
"""Test atomic uses REQUIRED propagation by default."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
async with atomic(mock_session) as tx:
# Should start new transaction
mock_session.begin.assert_called_once()
# ============================================================================
# @transactional Decorator Tests
# ============================================================================
class TestTransactionalDecorator:
"""Tests for @transactional decorator."""
@pytest.mark.asyncio
async def test_async_function_wrapped(self):
"""Test async function is wrapped in transaction."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
@transactional()
async def sample_operation(db: AsyncSession):
return "result"
result = await sample_operation(db=mock_session)
assert result == "result"
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_async_rollback_on_error(self):
"""Test async function rollback on error."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
@transactional()
async def failing_operation(db: AsyncSession):
raise ValueError("Operation failed")
with pytest.raises(ValueError):
await failing_operation(db=mock_session)
mock_session.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_custom_session_param_name(self):
"""Test decorator with custom session parameter name."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
@transactional(session_param="session")
async def operation_with_session(session: AsyncSession):
return "done"
result = await operation_with_session(session=mock_session)
assert result == "done"
mock_session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_missing_session_raises_error(self):
"""Test error raised when session parameter not found."""
@transactional()
async def operation_no_session(data: dict):
return data
with pytest.raises(TransactionError):
await operation_no_session(data={"key": "value"})
@pytest.mark.asyncio
async def test_propagation_passed_to_atomic(self):
"""Test propagation mode is passed to atomic."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = True
mock_nested = AsyncMock()
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
@transactional(propagation=TransactionPropagation.NESTED)
async def nested_operation(db: AsyncSession):
return "nested"
result = await nested_operation(db=mock_session)
assert result == "nested"
mock_session.begin_nested.assert_called_once()
# ============================================================================
# Sync transactional decorator Tests
# ============================================================================
class TestSyncTransactionalDecorator:
"""Tests for @transactional decorator with sync functions."""
def test_sync_function_wrapped(self):
"""Test sync function is wrapped in transaction."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
@transactional()
def sample_sync_operation(db: Session):
return "sync_result"
result = sample_sync_operation(db=mock_session)
assert result == "sync_result"
mock_session.commit.assert_called_once()
def test_sync_rollback_on_error(self):
"""Test sync function rollback on error."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
@transactional()
def failing_sync_operation(db: Session):
raise ValueError("Sync operation failed")
with pytest.raises(ValueError):
failing_sync_operation(db=mock_session)
mock_session.rollback.assert_called_once()
# ============================================================================
# Helper Function Tests
# ============================================================================
class TestHelperFunctions:
"""Tests for transaction helper functions."""
def test_is_in_transaction_true(self):
"""Test is_in_transaction returns True when in transaction."""
mock_session = MagicMock()
mock_session.in_transaction.return_value = True
assert is_in_transaction(mock_session) is True
def test_is_in_transaction_false(self):
"""Test is_in_transaction returns False when not in transaction."""
mock_session = MagicMock()
mock_session.in_transaction.return_value = False
assert is_in_transaction(mock_session) is False
# ============================================================================
# Integration Tests with Real Database
# ============================================================================
class TestTransactionIntegration:
"""Integration tests using real in-memory database."""
@pytest.mark.asyncio
async def test_real_transaction_commit(self, async_session):
"""Test actual transaction commit with real session."""
from src.server.database.models import AnimeSeries
async with atomic(async_session):
series = AnimeSeries(
key="test-series",
name="Test Series",
site="https://test.com",
folder="/test/folder",
)
async_session.add(series)
# Verify data persisted
from sqlalchemy import select
result = await async_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "test-series")
)
saved_series = result.scalar_one_or_none()
assert saved_series is not None
assert saved_series.name == "Test Series"
@pytest.mark.asyncio
async def test_real_transaction_rollback(self, async_session):
"""Test actual transaction rollback with real session."""
from src.server.database.models import AnimeSeries
try:
async with atomic(async_session):
series = AnimeSeries(
key="rollback-series",
name="Rollback Series",
site="https://test.com",
folder="/test/folder",
)
async_session.add(series)
await async_session.flush()
# Force rollback
raise ValueError("Simulated error")
except ValueError:
pass
# Verify data was NOT persisted
from sqlalchemy import select
result = await async_session.execute(
select(AnimeSeries).where(AnimeSeries.key == "rollback-series")
)
saved_series = result.scalar_one_or_none()
assert saved_series is None
# ============================================================================
# TransactionPropagation Tests
# ============================================================================
class TestTransactionPropagation:
"""Tests for transaction propagation modes."""
def test_propagation_enum_values(self):
"""Test propagation enum has correct values."""
assert TransactionPropagation.REQUIRED.value == "required"
assert TransactionPropagation.REQUIRES_NEW.value == "requires_new"
assert TransactionPropagation.NESTED.value == "nested"
# ============================================================================
# Additional Coverage Tests
# ============================================================================
class TestSyncSavepointCoverage:
"""Additional tests for sync savepoint coverage."""
def test_savepoint_exception_rolls_back(self):
"""Test savepoint rollback when exception occurs within savepoint."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with TransactionContext(mock_session) as ctx:
with pytest.raises(ValueError):
with ctx.savepoint() as sp:
raise ValueError("Error in savepoint")
# Nested transaction should have been rolled back
mock_nested.rollback.assert_called_once()
def test_savepoint_commit_explicit(self):
"""Test explicit commit on savepoint."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
mock_nested = MagicMock()
mock_session.begin_nested.return_value = mock_nested
with TransactionContext(mock_session) as ctx:
with ctx.savepoint() as sp:
sp.commit()
# Commit should just log, SQLAlchemy handles actual commit
class TestAsyncSavepointCoverage:
"""Additional tests for async savepoint coverage."""
@pytest.mark.asyncio
async def test_async_savepoint_exception_rolls_back(self):
"""Test async savepoint rollback when exception occurs."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
mock_nested = AsyncMock()
mock_nested.rollback = AsyncMock()
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
async with AsyncTransactionContext(mock_session) as ctx:
with pytest.raises(ValueError):
async with ctx.savepoint() as sp:
raise ValueError("Error in async savepoint")
# Nested transaction should have been rolled back
mock_nested.rollback.assert_called_once()
@pytest.mark.asyncio
async def test_async_savepoint_commit_explicit(self):
"""Test explicit commit on async savepoint."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_nested = AsyncMock()
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
async with AsyncTransactionContext(mock_session) as ctx:
async with ctx.savepoint() as sp:
await sp.commit()
# Commit should just log, SQLAlchemy handles actual commit
class TestAtomicNestedPropagationNoTransaction:
"""Tests for NESTED propagation when not in transaction."""
@pytest.mark.asyncio
async def test_async_nested_starts_new_when_not_in_transaction(self):
"""Test NESTED propagation starts new transaction when none exists."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
async with atomic(mock_session, TransactionPropagation.NESTED) as tx:
# Should start new transaction since none exists
pass
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
def test_sync_nested_starts_new_when_not_in_transaction(self):
"""Test sync NESTED propagation starts new transaction when none exists."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
with atomic_sync(mock_session, TransactionPropagation.NESTED) as tx:
pass
mock_session.begin.assert_called_once()
mock_session.commit.assert_called_once()
class TestGetTransactionDepth:
"""Tests for get_transaction_depth helper."""
def test_depth_zero_when_not_in_transaction(self):
"""Test depth is 0 when not in transaction."""
from src.server.database.transaction import get_transaction_depth
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
depth = get_transaction_depth(mock_session)
assert depth == 0
def test_depth_one_in_transaction(self):
"""Test depth is 1 in basic transaction."""
from src.server.database.transaction import get_transaction_depth
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = True
mock_session._nested_transaction = None
depth = get_transaction_depth(mock_session)
assert depth == 1
def test_depth_two_with_nested_transaction(self):
"""Test depth is 2 with nested transaction."""
from src.server.database.transaction import get_transaction_depth
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = True
mock_session._nested_transaction = MagicMock() # Has nested
depth = get_transaction_depth(mock_session)
assert depth == 2
class TestTransactionalDecoratorPositionalArgs:
"""Tests for transactional decorator with positional arguments."""
@pytest.mark.asyncio
async def test_session_from_positional_arg(self):
"""Test decorator extracts session from positional argument."""
mock_session = AsyncMock(spec=AsyncSession)
mock_session.in_transaction.return_value = False
mock_session.begin = AsyncMock()
mock_session.commit = AsyncMock()
mock_session.rollback = AsyncMock()
@transactional()
async def operation(db: AsyncSession, data: str):
return f"processed: {data}"
# Pass session as positional argument
result = await operation(mock_session, "test")
assert result == "processed: test"
mock_session.commit.assert_called_once()
def test_sync_session_from_positional_arg(self):
"""Test sync decorator extracts session from positional argument."""
mock_session = MagicMock(spec=Session)
mock_session.in_transaction.return_value = False
@transactional()
def operation(db: Session, data: str):
return f"processed: {data}"
result = operation(mock_session, "test")
assert result == "processed: test"
mock_session.commit.assert_called_once()