Aniworld/tests/integration/test_db_transactions.py
Lukas 1ba67357dc Add database transaction support with atomic operations
- Create transaction.py with @transactional decorator, atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager, get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
2025-12-25 18:05:33 +01:00

547 lines
18 KiB
Python

"""Integration tests for database transaction behavior.
Tests real database transaction handling including:
- Transaction isolation
- Concurrent transaction handling
- Real commit/rollback behavior
"""
import asyncio
from datetime import datetime, timedelta, timezone
from typing import List

import pytest
from sqlalchemy import select, text
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.pool import StaticPool

from src.server.database.base import Base
from src.server.database.connection import (
    TransactionManager,
    get_session_transaction_depth,
    is_session_in_transaction,
)
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
from src.server.database.service import (
    AnimeSeriesService,
    DownloadQueueService,
    EpisodeService,
)
from src.server.database.transaction import (
    TransactionPropagation,
    atomic,
    transactional,
)
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
async def db_engine():
    """Create a fresh in-memory SQLite engine with the full schema.

    The fixture is function-scoped, so every test starts from an empty
    database.

    ``StaticPool`` is requested explicitly so that every session created from
    this engine talks to the *same* ``:memory:`` database. In SQLite each
    connection to ``:memory:`` opens its own private database, so a pool
    that handed out new connections would make cross-session checks see an
    empty schema.

    The engine is disposed in ``finally`` so it is released even if schema
    creation fails.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
        poolclass=StaticPool,
    )
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
        yield engine
    finally:
        await engine.dispose()
@pytest.fixture
async def session_factory(db_engine):
    """Build an ``async_sessionmaker`` bound to the per-test engine.

    - ``expire_on_commit=False`` keeps loaded attribute values on ORM objects
      after a commit, so tests can inspect objects returned from committed
      work without triggering a reload.
    - ``autoflush=False`` disables implicit flushes before queries; tests call
      ``flush()`` explicitly where they need SQL emitted mid-transaction.

    The legacy ``autocommit`` keyword is omitted: ``False`` is already its
    default and the only value modern SQLAlchemy accepts.
    """
    return async_sessionmaker(
        db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )
@pytest.fixture
async def db_session(session_factory):
    """Provide a single ``AsyncSession`` for the duration of one test.

    On teardown, any transaction the test left open is rolled back; the
    session itself is then closed by the ``async with`` block.
    """
    async with session_factory() as test_session:
        yield test_session
        # Discard whatever the test left uncommitted before closing.
        await test_session.rollback()
# ============================================================================
# Real Database Transaction Tests
# ============================================================================
class TestRealDatabaseTransactions:
    """Commit/rollback behavior of ``atomic()`` against a real in-memory DB.

    Commit checks call ``db_session.rollback()`` before reading back. Reading
    through the same session would also see rows that were merely flushed.
    Rolling back first discards anything uncommitted, so only data that
    ``atomic()`` really committed survives the check.
    """

    @pytest.mark.asyncio
    async def test_commit_persists_data(self, db_session):
        """Test that committed data is actually persisted."""
        async with atomic(db_session):
            await AnimeSeriesService.create(
                db_session,
                key="commit-test",
                name="Commit Test Series",
                site="https://test.com",
                folder="/test/folder",
            )

        # Only committed rows survive this rollback.
        await db_session.rollback()

        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "commit-test"
        )
        assert retrieved is not None
        assert retrieved.name == "Commit Test Series"

    @pytest.mark.asyncio
    async def test_rollback_discards_data(self, db_session):
        """Test that rolled back data is discarded."""
        # pytest.raises (unlike try/except/pass) also fails the test if
        # atomic() swallowed the exception instead of re-raising it.
        with pytest.raises(ValueError, match="Force rollback"):
            async with atomic(db_session):
                await AnimeSeriesService.create(
                    db_session,
                    key="rollback-test",
                    name="Rollback Test Series",
                    site="https://test.com",
                    folder="/test/folder",
                )
                await db_session.flush()
                raise ValueError("Force rollback")

        # Data should NOT be retrievable after rollback
        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "rollback-test"
        )
        assert retrieved is None

    @pytest.mark.asyncio
    async def test_multiple_operations_atomic(self, db_session):
        """Test multiple operations are committed together."""
        async with atomic(db_session):
            series = await AnimeSeriesService.create(
                db_session,
                key="atomic-multi-test",
                name="Atomic Multi Test",
                site="https://test.com",
                folder="/test/folder",
            )
            episode = await EpisodeService.create(
                db_session,
                series_id=series.id,
                season=1,
                episode_number=1,
                title="Episode 1",
            )
            await DownloadQueueService.create(
                db_session,
                series_id=series.id,
                episode_id=episode.id,
            )

        # Only committed rows survive this rollback.
        await db_session.rollback()

        retrieved_series = await AnimeSeriesService.get_by_key(
            db_session, "atomic-multi-test"
        )
        assert retrieved_series is not None
        episodes = await EpisodeService.get_by_series(
            db_session, retrieved_series.id
        )
        assert len(episodes) == 1
        queue_items = await DownloadQueueService.get_all(db_session)
        # The engine fixture is function-scoped, so the database is fresh and
        # the count is exact.
        assert len(queue_items) == 1

    @pytest.mark.asyncio
    async def test_multiple_operations_rollback_all(self, db_session):
        """Test multiple operations are all rolled back on failure."""
        with pytest.raises(RuntimeError, match="Force complete rollback"):
            async with atomic(db_session):
                series = await AnimeSeriesService.create(
                    db_session,
                    key="rollback-multi-test",
                    name="Rollback Multi Test",
                    site="https://test.com",
                    folder="/test/folder",
                )
                episode = await EpisodeService.create(
                    db_session,
                    series_id=series.id,
                    season=1,
                    episode_number=1,
                )
                await DownloadQueueService.create(
                    db_session,
                    series_id=series.id,
                    episode_id=episode.id,
                )
                await db_session.flush()
                raise RuntimeError("Force complete rollback")

        # None of the three rows should be persisted.
        retrieved_series = await AnimeSeriesService.get_by_key(
            db_session, "rollback-multi-test"
        )
        assert retrieved_series is None
        queue_items = await DownloadQueueService.get_all(db_session)
        assert len(queue_items) == 0
# ============================================================================
# Transaction Manager Tests
# ============================================================================
class TestTransactionManager:
    """Tests for TransactionManager class.

    Persistence is checked through a separate session from the same factory,
    after the manager's ``async with`` block has exited.
    """

    @staticmethod
    async def _find_series(session_factory, key: str):
        """Return the ``AnimeSeries`` with ``key`` via a fresh session, or None."""
        async with session_factory() as verify_session:
            result = await verify_session.execute(
                select(AnimeSeries).where(AnimeSeries.key == key)
            )
            return result.scalar_one_or_none()

    @pytest.mark.asyncio
    async def test_transaction_manager_basic_flow(self, session_factory):
        """Test basic transaction manager usage."""
        async with TransactionManager(session_factory) as tm:
            session = await tm.get_session()
            await tm.begin()
            series = AnimeSeries(
                key="tm-test",
                name="TM Test",
                site="https://test.com",
                folder="/test",
            )
            session.add(series)
            await tm.commit()

        # Verify data persisted
        assert await self._find_series(session_factory, "tm-test") is not None

    @pytest.mark.asyncio
    async def test_transaction_manager_rollback(self, session_factory):
        """Test transaction manager rollback."""
        async with TransactionManager(session_factory) as tm:
            session = await tm.get_session()
            await tm.begin()
            series = AnimeSeries(
                key="tm-rollback-test",
                name="TM Rollback Test",
                site="https://test.com",
                folder="/test",
            )
            session.add(series)
            await session.flush()
            await tm.rollback()

        # Verify data NOT persisted
        assert (
            await self._find_series(session_factory, "tm-rollback-test")
            is None
        )

    @pytest.mark.asyncio
    async def test_transaction_manager_auto_rollback_on_exception(
        self, session_factory
    ):
        """Test transaction manager auto-rolls back on exception."""
        with pytest.raises(ValueError, match="Force exception"):
            async with TransactionManager(session_factory) as tm:
                session = await tm.get_session()
                await tm.begin()
                series = AnimeSeries(
                    key="tm-auto-rollback",
                    name="TM Auto Rollback",
                    site="https://test.com",
                    folder="/test",
                )
                session.add(series)
                await session.flush()
                raise ValueError("Force exception")

        # Verify data NOT persisted
        assert (
            await self._find_series(session_factory, "tm-auto-rollback")
            is None
        )

    @pytest.mark.asyncio
    async def test_transaction_manager_state_tracking(self, session_factory):
        """Test transaction manager tracks state correctly."""
        async with TransactionManager(session_factory) as tm:
            assert tm.is_in_transaction() is False
            await tm.begin()
            assert tm.is_in_transaction() is True
            await tm.commit()
            assert tm.is_in_transaction() is False
# ============================================================================
# Helper Function Tests
# ============================================================================
class TestConnectionHelpers:
    """Checks for the session-state helpers exported by the connection module."""

    @pytest.mark.asyncio
    async def test_is_session_in_transaction(self, db_session):
        """A fresh session reports no transaction; inside atomic() it reports one."""
        # Before any work, the fixture's session has not begun a transaction.
        assert is_session_in_transaction(db_session) is False

        async with atomic(db_session):
            assert is_session_in_transaction(db_session) is True

        # The state after atomic() exits is deliberately not asserted: it
        # depends on the session's state and may vary under SQLite.

    @pytest.mark.asyncio
    async def test_get_session_transaction_depth(self, db_session):
        """get_session_transaction_depth never reports a negative depth."""
        assert get_session_transaction_depth(db_session) >= 0
# ============================================================================
# @transactional Decorator Integration Tests
# ============================================================================
class TestTransactionalDecoratorIntegration:
    """Integration tests for @transactional decorator.

    Commit checks call ``db_session.rollback()`` before reading back, so only
    data the decorator really committed (not merely flushed) is observed.
    """

    @pytest.mark.asyncio
    async def test_decorated_function_commits(self, db_session):
        """Test decorated function commits on success."""

        @transactional()
        async def create_series_decorated(db: AsyncSession):
            return await AnimeSeriesService.create(
                db,
                key="decorated-test",
                name="Decorated Test",
                site="https://test.com",
                folder="/test",
            )

        await create_series_decorated(db=db_session)

        # Only committed rows survive this rollback.
        await db_session.rollback()

        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "decorated-test"
        )
        assert retrieved is not None

    @pytest.mark.asyncio
    async def test_decorated_function_rollback(self, db_session):
        """Test decorated function rolls back on error."""

        @transactional()
        async def create_then_fail(db: AsyncSession):
            await AnimeSeriesService.create(
                db,
                key="decorated-rollback",
                name="Decorated Rollback",
                site="https://test.com",
                folder="/test",
            )
            raise ValueError("Force failure")

        with pytest.raises(ValueError, match="Force failure"):
            await create_then_fail(db=db_session)

        # Verify NOT committed
        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "decorated-rollback"
        )
        assert retrieved is None

    @pytest.mark.asyncio
    async def test_nested_decorated_functions(self, db_session):
        """Test nested decorated functions work correctly."""

        @transactional(propagation=TransactionPropagation.NESTED)
        async def inner_operation(db: AsyncSession, series_id: int):
            return await EpisodeService.create(
                db,
                series_id=series_id,
                season=1,
                episode_number=1,
            )

        @transactional()
        async def outer_operation(db: AsyncSession):
            series = await AnimeSeriesService.create(
                db,
                key="nested-decorated",
                name="Nested Decorated",
                site="https://test.com",
                folder="/test",
            )
            episode = await inner_operation(db=db, series_id=series.id)
            return series, episode

        series, episode = await outer_operation(db=db_session)
        assert series is not None
        assert episode is not None

        # Both rows must have been committed by the outer transaction:
        # discard anything uncommitted, then read them back from the DB.
        await db_session.rollback()
        retrieved = await AnimeSeriesService.get_by_key(
            db_session, "nested-decorated"
        )
        assert retrieved is not None
        episodes = await EpisodeService.get_by_series(
            db_session, retrieved.id
        )
        assert len(episodes) == 1
# ============================================================================
# Concurrent Transaction Tests
# ============================================================================
class TestConcurrentTransactions:
    """Tests for concurrent transaction handling."""

    @pytest.mark.asyncio
    async def test_concurrent_writes_different_keys(self, session_factory):
        """Three writers, each with its own session, commit distinct series.

        NOTE(review): with an in-memory SQLite engine the sessions likely
        share one pooled connection, so this exercises interleaving on the
        event loop rather than truly parallel database access.
        """
        keys = [f"concurrent-{n}" for n in range(1, 4)]

        async def write_one(series_key: str):
            # The session is opened first, then atomic() wraps the work in it.
            async with session_factory() as writer, atomic(writer):
                await AnimeSeriesService.create(
                    writer,
                    key=series_key,
                    name=f"Series {series_key}",
                    site="https://test.com",
                    folder=f"/test/{series_key}",
                )

        await asyncio.gather(*(write_one(k) for k in keys))

        # Every writer's row must be visible from an independent session.
        async with session_factory() as verify_session:
            for k in keys:
                found = await AnimeSeriesService.get_by_key(verify_session, k)
                assert found is not None
# ============================================================================
# Queue Repository Transaction Tests
# ============================================================================
class TestQueueRepositoryTransactions:
    """Integration tests for QueueRepository transaction handling."""

    @staticmethod
    def _pending_item(
        item_id: str, serie_id: str, serie_folder: str, serie_name: str
    ):
        """Build a PENDING ``DownloadItem`` for season 1, episode 1."""
        # Kept as a local import, matching how the original tests imported it.
        from src.server.models.download import (
            DownloadItem,
            DownloadStatus,
            EpisodeIdentifier,
        )

        return DownloadItem(
            id=item_id,
            serie_id=serie_id,
            serie_folder=serie_folder,
            serie_name=serie_name,
            episode=EpisodeIdentifier(season=1, episode=1),
            status=DownloadStatus.PENDING,
        )

    @pytest.mark.asyncio
    async def test_save_item_atomic(self, session_factory):
        """Test save_item creates series, episode, and queue item atomically."""
        from src.server.services.queue_repository import QueueRepository

        repo = QueueRepository(session_factory)
        item = self._pending_item(
            item_id="temp-id",
            serie_id="repo-atomic-test",
            serie_folder="/test/folder",
            serie_name="Repo Atomic Test",
        )

        saved_item = await repo.save_item(item)
        assert saved_item.id != "temp-id"  # Should have DB ID

        # Verify all entities created
        async with session_factory() as verify_session:
            series = await AnimeSeriesService.get_by_key(
                verify_session, "repo-atomic-test"
            )
            assert series is not None
            episodes = await EpisodeService.get_by_series(
                verify_session, series.id
            )
            assert len(episodes) == 1
            queue_items = await DownloadQueueService.get_all(verify_session)
            # The engine fixture is function-scoped, so the database is fresh
            # and exactly one queue row must exist.
            assert len(queue_items) == 1

    @pytest.mark.asyncio
    async def test_clear_all_atomic(self, session_factory):
        """Test clear_all removes all items atomically."""
        from src.server.services.queue_repository import QueueRepository

        repo = QueueRepository(session_factory)
        for i in range(3):
            await repo.save_item(
                self._pending_item(
                    item_id=f"clear-{i}",
                    serie_id=f"clear-series-{i}",
                    serie_folder=f"/test/folder/{i}",
                    serie_name=f"Clear Series {i}",
                )
            )

        count = await repo.clear_all()
        assert count == 3

        # Verify all cleared
        async with session_factory() as verify_session:
            queue_items = await DownloadQueueService.get_all(verify_session)
            assert len(queue_items) == 0