- Create transaction.py with the @transactional decorator and the atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager and get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
547 lines · 18 KiB · Python
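"""Integration tests for database transaction behavior.

Exercises real transaction handling against an in-memory SQLite database,
including:

- Real commit/rollback behavior of ``atomic()`` and ``@transactional``
- TransactionManager lifecycle and session-state helpers
- Transaction isolation between independent sessions
- Concurrent transaction handling
- Atomic QueueRepository operations
"""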
|
|
import asyncio
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import List
|
|
|
|
import pytest
|
|
from sqlalchemy import text
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|
|
|
from src.server.database.base import Base
|
|
from src.server.database.connection import (
|
|
TransactionManager,
|
|
get_session_transaction_depth,
|
|
is_session_in_transaction,
|
|
)
|
|
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
|
|
from src.server.database.service import (
|
|
AnimeSeriesService,
|
|
DownloadQueueService,
|
|
EpisodeService,
|
|
)
|
|
from src.server.database.transaction import (
|
|
TransactionPropagation,
|
|
atomic,
|
|
transactional,
|
|
)
|
|
|
|
# ============================================================================
|
|
# Fixtures
|
|
# ============================================================================
|
|
|
|
|
|
@pytest.fixture
async def db_engine():
    """Create a fresh in-memory SQLite engine with all tables created.

    Several tests open a second session from the same factory (a
    ``verify_session``) and expect to see rows written through another
    session. A ``:memory:`` SQLite database exists per DBAPI connection, so
    the engine is pinned to one shared connection via ``StaticPool``.
    Otherwise a newly pooled connection could see an empty database with no
    tables.

    Yields:
        The async engine. It is disposed after the test, even if table
        creation fails.
    """
    from sqlalchemy.pool import StaticPool

    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
        # One shared connection => one shared in-memory database.
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)

        yield engine
    finally:
        await engine.dispose()
|
|
|
|
|
|
@pytest.fixture
async def session_factory(db_engine):
    """Build an ``async_sessionmaker`` bound to the per-test engine.

    Sessions keep loaded attribute values after commit and do not autoflush
    before queries, so each test controls when pending changes are flushed.
    """
    from sqlalchemy.ext.asyncio import async_sessionmaker

    session_options = {
        "class_": AsyncSession,
        "expire_on_commit": False,
        "autoflush": False,
        "autocommit": False,
    }
    factory = async_sessionmaker(db_engine, **session_options)
    return factory
|
|
|
|
|
|
@pytest.fixture
async def db_session(session_factory):
    """Yield one session per test, rolling back whatever is left uncommitted."""
    session = session_factory()
    async with session:
        yield session
        await session.rollback()
|
|
|
|
|
|
# ============================================================================
|
|
# Real Database Transaction Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestRealDatabaseTransactions:
    """Tests using real in-memory database."""

    @pytest.mark.asyncio
    async def test_commit_persists_data(self, db_session):
        """Test that committed data is actually persisted."""
        async with atomic(db_session):
            created = await AnimeSeriesService.create(
                db_session,
                key="commit-test",
                name="Commit Test Series",
                site="https://test.com",
                folder="/test/folder",
            )

        # Data should be retrievable after commit
        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "commit-test"
        )
        assert retrieved is not None
        assert retrieved.id == created.id
        assert retrieved.name == "Commit Test Series"

    @pytest.mark.asyncio
    async def test_rollback_discards_data(self, db_session):
        """Test that rolled back data is discarded and the error propagates."""
        # pytest.raises (instead of try/except/pass) also asserts that
        # atomic() re-raises the error rather than silently swallowing it.
        with pytest.raises(ValueError, match="Force rollback"):
            async with atomic(db_session):
                await AnimeSeriesService.create(
                    db_session,
                    key="rollback-test",
                    name="Rollback Test Series",
                    site="https://test.com",
                    folder="/test/folder",
                )
                await db_session.flush()

                raise ValueError("Force rollback")

        # Data should NOT be retrievable after rollback
        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "rollback-test"
        )
        assert retrieved is None

    @pytest.mark.asyncio
    async def test_multiple_operations_atomic(self, db_session):
        """Test multiple operations are committed together."""
        async with atomic(db_session):
            # Create series
            series = await AnimeSeriesService.create(
                db_session,
                key="atomic-multi-test",
                name="Atomic Multi Test",
                site="https://test.com",
                folder="/test/folder",
            )

            # Create episode
            episode = await EpisodeService.create(
                db_session,
                series_id=series.id,
                season=1,
                episode_number=1,
                title="Episode 1",
            )

            # Create queue item
            item = await DownloadQueueService.create(
                db_session,
                series_id=series.id,
                episode_id=episode.id,
            )

        # All should be persisted
        retrieved_series = await AnimeSeriesService.get_by_key(
            db_session, "atomic-multi-test"
        )
        assert retrieved_series is not None

        episodes = await EpisodeService.get_by_series(
            db_session, retrieved_series.id
        )
        assert len(episodes) == 1
        assert episodes[0].id == episode.id

        # Check for the specific queue item created above, not just "some"
        # item.
        queue_items = await DownloadQueueService.get_all(db_session)
        assert any(queued.id == item.id for queued in queue_items)

    @pytest.mark.asyncio
    async def test_multiple_operations_rollback_all(self, db_session):
        """Test multiple operations are all rolled back on failure."""
        with pytest.raises(RuntimeError, match="Force complete rollback"):
            async with atomic(db_session):
                # Create series
                series = await AnimeSeriesService.create(
                    db_session,
                    key="rollback-multi-test",
                    name="Rollback Multi Test",
                    site="https://test.com",
                    folder="/test/folder",
                )

                # Create episode
                episode = await EpisodeService.create(
                    db_session,
                    series_id=series.id,
                    season=1,
                    episode_number=1,
                )

                # Create queue item
                await DownloadQueueService.create(
                    db_session,
                    series_id=series.id,
                    episode_id=episode.id,
                )

                await db_session.flush()

                raise RuntimeError("Force complete rollback")

        # None should be persisted
        retrieved_series = await AnimeSeriesService.get_by_key(
            db_session, "rollback-multi-test"
        )
        assert retrieved_series is None

        # The dependent queue item must have been discarded too, not just
        # the series. Each test starts from a freshly created database.
        queue_items = await DownloadQueueService.get_all(db_session)
        assert len(queue_items) == 0
|
|
|
|
|
|
# ============================================================================
|
|
# Transaction Manager Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestTransactionManager:
    """Tests for TransactionManager class."""

    @staticmethod
    async def _find_series(session_factory, key):
        """Look up an AnimeSeries by ``key`` through a brand-new session."""
        from sqlalchemy import select

        async with session_factory() as verify_session:
            result = await verify_session.execute(
                select(AnimeSeries).where(AnimeSeries.key == key)
            )
            return result.scalar_one_or_none()

    @pytest.mark.asyncio
    async def test_transaction_manager_basic_flow(self, session_factory):
        """A committed manager transaction is visible to a fresh session."""
        async with TransactionManager(session_factory) as tm:
            session = await tm.get_session()
            await tm.begin()

            session.add(
                AnimeSeries(
                    key="tm-test",
                    name="TM Test",
                    site="https://test.com",
                    folder="/test",
                )
            )

            await tm.commit()

        assert await self._find_series(session_factory, "tm-test") is not None

    @pytest.mark.asyncio
    async def test_transaction_manager_rollback(self, session_factory):
        """An explicit rollback discards flushed-but-uncommitted rows."""
        async with TransactionManager(session_factory) as tm:
            session = await tm.get_session()
            await tm.begin()

            session.add(
                AnimeSeries(
                    key="tm-rollback-test",
                    name="TM Rollback Test",
                    site="https://test.com",
                    folder="/test",
                )
            )
            await session.flush()

            await tm.rollback()

        found = await self._find_series(session_factory, "tm-rollback-test")
        assert found is None

    @pytest.mark.asyncio
    async def test_transaction_manager_auto_rollback_on_exception(
        self, session_factory
    ):
        """An exception escaping the manager's block leaves nothing behind."""
        with pytest.raises(ValueError):
            async with TransactionManager(session_factory) as tm:
                session = await tm.get_session()
                await tm.begin()

                session.add(
                    AnimeSeries(
                        key="tm-auto-rollback",
                        name="TM Auto Rollback",
                        site="https://test.com",
                        folder="/test",
                    )
                )
                await session.flush()

                raise ValueError("Force exception")

        found = await self._find_series(session_factory, "tm-auto-rollback")
        assert found is None

    @pytest.mark.asyncio
    async def test_transaction_manager_state_tracking(self, session_factory):
        """is_in_transaction() follows begin() and commit()."""
        async with TransactionManager(session_factory) as tm:
            assert tm.is_in_transaction() is False

            await tm.begin()
            assert tm.is_in_transaction() is True

            await tm.commit()
            assert tm.is_in_transaction() is False
|
|
|
|
|
|
# ============================================================================
|
|
# Helper Function Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestConnectionHelpers:
    """Tests for the session-inspection helpers in the connection module."""

    @pytest.mark.asyncio
    async def test_is_session_in_transaction(self, db_session):
        """is_session_in_transaction reports True inside atomic()."""
        # A freshly opened session has not begun a transaction yet.
        assert is_session_in_transaction(db_session) is False

        async with atomic(db_session):
            assert is_session_in_transaction(db_session) is True

        # The state after leaving atomic() is intentionally not asserted:
        # it depends on how the session was left, and SQLite behavior may
        # vary.

    @pytest.mark.asyncio
    async def test_get_session_transaction_depth(self, db_session):
        """get_session_transaction_depth never reports a negative depth."""
        assert get_session_transaction_depth(db_session) >= 0
|
|
|
|
|
|
# ============================================================================
|
|
# @transactional Decorator Integration Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestTransactionalDecoratorIntegration:
    """Integration tests for @transactional decorator."""

    @pytest.mark.asyncio
    async def test_decorated_function_commits(self, db_session):
        """Test decorated function commits on success."""

        @transactional()
        async def create_series_decorated(db: AsyncSession):
            return await AnimeSeriesService.create(
                db,
                key="decorated-test",
                name="Decorated Test",
                site="https://test.com",
                folder="/test",
            )

        series = await create_series_decorated(db=db_session)

        # Verify committed, and that it is the same row the function created
        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "decorated-test"
        )
        assert retrieved is not None
        assert retrieved.id == series.id

    @pytest.mark.asyncio
    async def test_decorated_function_rollback(self, db_session):
        """Test decorated function rolls back on error."""

        @transactional()
        async def create_then_fail(db: AsyncSession):
            await AnimeSeriesService.create(
                db,
                key="decorated-rollback",
                name="Decorated Rollback",
                site="https://test.com",
                folder="/test",
            )
            raise ValueError("Force failure")

        with pytest.raises(ValueError, match="Force failure"):
            await create_then_fail(db=db_session)

        # Verify NOT committed
        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "decorated-rollback"
        )
        assert retrieved is None

    @pytest.mark.asyncio
    async def test_nested_decorated_functions(self, db_session):
        """Test nested decorated functions work correctly."""

        @transactional(propagation=TransactionPropagation.NESTED)
        async def inner_operation(db: AsyncSession, series_id: int):
            return await EpisodeService.create(
                db,
                series_id=series_id,
                season=1,
                episode_number=1,
            )

        @transactional()
        async def outer_operation(db: AsyncSession):
            series = await AnimeSeriesService.create(
                db,
                key="nested-decorated",
                name="Nested Decorated",
                site="https://test.com",
                folder="/test",
            )
            episode = await inner_operation(db=db, series_id=series.id)
            return series, episode

        series, episode = await outer_operation(db=db_session)

        assert series is not None
        assert episode is not None

        # Both should be committed: query the database rather than trusting
        # the returned objects alone.
        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "nested-decorated"
        )
        assert retrieved is not None

        episodes = await EpisodeService.get_by_series(db_session, retrieved.id)
        assert len(episodes) == 1
        assert episodes[0].id == episode.id
|
|
|
|
|
|
# ============================================================================
|
|
# Concurrent Transaction Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestConcurrentTransactions:
    """Tests for concurrent transaction handling."""

    @pytest.mark.asyncio
    async def test_concurrent_writes_different_keys(self, session_factory):
        """Independent sessions can each commit a distinct row concurrently."""
        keys = [f"concurrent-{n}" for n in (1, 2, 3)]

        async def insert_series(series_key: str) -> None:
            # Each task owns its own session and its own atomic() block.
            async with session_factory() as own_session, atomic(own_session):
                await AnimeSeriesService.create(
                    own_session,
                    key=series_key,
                    name=f"Series {series_key}",
                    site="https://test.com",
                    folder=f"/test/{series_key}",
                )

        await asyncio.gather(*(insert_series(k) for k in keys))

        # Every key must be visible from an unrelated session afterwards.
        async with session_factory() as verify_session:
            for k in keys:
                found = await AnimeSeriesService.get_by_key(verify_session, k)
                assert found is not None
|
|
|
|
|
|
# ============================================================================
|
|
# Queue Repository Transaction Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestQueueRepositoryTransactions:
    """Integration tests for QueueRepository transaction handling."""

    @staticmethod
    def _pending_item(item_id, series_key, folder, name):
        """Build a PENDING DownloadItem for season 1, episode 1."""
        from src.server.models.download import (
            DownloadItem,
            DownloadStatus,
            EpisodeIdentifier,
        )

        return DownloadItem(
            id=item_id,
            serie_id=series_key,
            serie_folder=folder,
            serie_name=name,
            episode=EpisodeIdentifier(season=1, episode=1),
            status=DownloadStatus.PENDING,
        )

    @pytest.mark.asyncio
    async def test_save_item_atomic(self, session_factory):
        """Test save_item creates series, episode, and queue item atomically."""
        from src.server.services.queue_repository import QueueRepository

        pending = self._pending_item(
            "temp-id", "repo-atomic-test", "/test/folder", "Repo Atomic Test"
        )
        repo = QueueRepository(session_factory)

        saved_item = await repo.save_item(pending)

        assert saved_item.id != "temp-id"  # Should have DB ID

        # Every dependent row must exist once save_item returns.
        async with session_factory() as verify_session:
            series = await AnimeSeriesService.get_by_key(
                verify_session, "repo-atomic-test"
            )
            assert series is not None

            episodes = await EpisodeService.get_by_series(
                verify_session, series.id
            )
            assert len(episodes) == 1

            queued = await DownloadQueueService.get_all(verify_session)
            assert len(queued) >= 1

    @pytest.mark.asyncio
    async def test_clear_all_atomic(self, session_factory):
        """Test clear_all removes all items atomically."""
        from src.server.services.queue_repository import QueueRepository

        repo = QueueRepository(session_factory)

        for idx in range(3):
            await repo.save_item(
                self._pending_item(
                    f"clear-{idx}",
                    f"clear-series-{idx}",
                    f"/test/folder/{idx}",
                    f"Clear Series {idx}",
                )
            )

        removed = await repo.clear_all()
        assert removed == 3

        async with session_factory() as verify_session:
            remaining = await DownloadQueueService.get_all(verify_session)
            assert len(remaining) == 0
|