Add database transaction support with atomic operations
- Create transaction.py with @transactional decorator, atomic() context manager - Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED - Add savepoint support for nested transactions with partial rollback - Update connection.py with TransactionManager, get_transactional_session - Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete) - Wrap QueueRepository.save_item() and clear_all() in atomic transactions - Add comprehensive tests (66 transaction tests, 90% coverage) - All 1090 tests passing
This commit is contained in:
546
tests/integration/test_db_transactions.py
Normal file
546
tests/integration/test_db_transactions.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""Integration tests for database transaction behavior.
|
||||
|
||||
Tests real database transaction handling including:
|
||||
- Transaction isolation
|
||||
- Concurrent transaction handling
|
||||
- Real commit/rollback behavior
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from src.server.database.base import Base
|
||||
from src.server.database.connection import (
|
||||
TransactionManager,
|
||||
get_session_transaction_depth,
|
||||
is_session_in_transaction,
|
||||
)
|
||||
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
|
||||
from src.server.database.service import (
|
||||
AnimeSeriesService,
|
||||
DownloadQueueService,
|
||||
EpisodeService,
|
||||
)
|
||||
from src.server.database.transaction import (
|
||||
TransactionPropagation,
|
||||
atomic,
|
||||
transactional,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
"""Create in-memory database engine for testing."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def session_factory(db_engine):
|
||||
"""Create session factory for testing."""
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
return async_sessionmaker(
|
||||
db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(session_factory):
|
||||
"""Create database session for testing."""
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Real Database Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestRealDatabaseTransactions:
|
||||
"""Tests using real in-memory database."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_persists_data(self, db_session):
|
||||
"""Test that committed data is actually persisted."""
|
||||
async with atomic(db_session):
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="commit-test",
|
||||
name="Commit Test Series",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
# Data should be retrievable after commit
|
||||
retrieved = await AnimeSeriesService.get_by_key(
|
||||
db_session, "commit-test"
|
||||
)
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "Commit Test Series"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rollback_discards_data(self, db_session):
|
||||
"""Test that rolled back data is discarded."""
|
||||
try:
|
||||
async with atomic(db_session):
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="rollback-test",
|
||||
name="Rollback Test Series",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
await db_session.flush()
|
||||
|
||||
raise ValueError("Force rollback")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Data should NOT be retrievable after rollback
|
||||
retrieved = await AnimeSeriesService.get_by_key(
|
||||
db_session, "rollback-test"
|
||||
)
|
||||
assert retrieved is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_operations_atomic(self, db_session):
|
||||
"""Test multiple operations are committed together."""
|
||||
async with atomic(db_session):
|
||||
# Create series
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="atomic-multi-test",
|
||||
name="Atomic Multi Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
title="Episode 1",
|
||||
)
|
||||
|
||||
# Create queue item
|
||||
item = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
|
||||
# All should be persisted
|
||||
retrieved_series = await AnimeSeriesService.get_by_key(
|
||||
db_session, "atomic-multi-test"
|
||||
)
|
||||
assert retrieved_series is not None
|
||||
|
||||
episodes = await EpisodeService.get_by_series(
|
||||
db_session, retrieved_series.id
|
||||
)
|
||||
assert len(episodes) == 1
|
||||
|
||||
queue_items = await DownloadQueueService.get_all(db_session)
|
||||
assert len(queue_items) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_operations_rollback_all(self, db_session):
|
||||
"""Test multiple operations are all rolled back on failure."""
|
||||
try:
|
||||
async with atomic(db_session):
|
||||
# Create series
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="rollback-multi-test",
|
||||
name="Rollback Multi Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
|
||||
# Create queue item
|
||||
item = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
|
||||
await db_session.flush()
|
||||
|
||||
raise RuntimeError("Force complete rollback")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# None should be persisted
|
||||
retrieved_series = await AnimeSeriesService.get_by_key(
|
||||
db_session, "rollback-multi-test"
|
||||
)
|
||||
assert retrieved_series is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Transaction Manager Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTransactionManager:
|
||||
"""Tests for TransactionManager class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_manager_basic_flow(self, session_factory):
|
||||
"""Test basic transaction manager usage."""
|
||||
async with TransactionManager(session_factory) as tm:
|
||||
session = await tm.get_session()
|
||||
await tm.begin()
|
||||
|
||||
series = AnimeSeries(
|
||||
key="tm-test",
|
||||
name="TM Test",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
session.add(series)
|
||||
|
||||
await tm.commit()
|
||||
|
||||
# Verify data persisted
|
||||
async with session_factory() as verify_session:
|
||||
from sqlalchemy import select
|
||||
result = await verify_session.execute(
|
||||
select(AnimeSeries).where(AnimeSeries.key == "tm-test")
|
||||
)
|
||||
series = result.scalar_one_or_none()
|
||||
assert series is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_manager_rollback(self, session_factory):
|
||||
"""Test transaction manager rollback."""
|
||||
async with TransactionManager(session_factory) as tm:
|
||||
session = await tm.get_session()
|
||||
await tm.begin()
|
||||
|
||||
series = AnimeSeries(
|
||||
key="tm-rollback-test",
|
||||
name="TM Rollback Test",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
session.add(series)
|
||||
await session.flush()
|
||||
|
||||
await tm.rollback()
|
||||
|
||||
# Verify data NOT persisted
|
||||
async with session_factory() as verify_session:
|
||||
from sqlalchemy import select
|
||||
result = await verify_session.execute(
|
||||
select(AnimeSeries).where(AnimeSeries.key == "tm-rollback-test")
|
||||
)
|
||||
series = result.scalar_one_or_none()
|
||||
assert series is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_manager_auto_rollback_on_exception(
|
||||
self, session_factory
|
||||
):
|
||||
"""Test transaction manager auto-rolls back on exception."""
|
||||
with pytest.raises(ValueError):
|
||||
async with TransactionManager(session_factory) as tm:
|
||||
session = await tm.get_session()
|
||||
await tm.begin()
|
||||
|
||||
series = AnimeSeries(
|
||||
key="tm-auto-rollback",
|
||||
name="TM Auto Rollback",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
session.add(series)
|
||||
await session.flush()
|
||||
|
||||
raise ValueError("Force exception")
|
||||
|
||||
# Verify data NOT persisted
|
||||
async with session_factory() as verify_session:
|
||||
from sqlalchemy import select
|
||||
result = await verify_session.execute(
|
||||
select(AnimeSeries).where(AnimeSeries.key == "tm-auto-rollback")
|
||||
)
|
||||
series = result.scalar_one_or_none()
|
||||
assert series is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_manager_state_tracking(self, session_factory):
|
||||
"""Test transaction manager tracks state correctly."""
|
||||
async with TransactionManager(session_factory) as tm:
|
||||
assert tm.is_in_transaction() is False
|
||||
|
||||
await tm.begin()
|
||||
assert tm.is_in_transaction() is True
|
||||
|
||||
await tm.commit()
|
||||
assert tm.is_in_transaction() is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Function Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestConnectionHelpers:
|
||||
"""Tests for connection module helper functions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_session_in_transaction(self, db_session):
|
||||
"""Test is_session_in_transaction helper."""
|
||||
# Initially not in transaction
|
||||
assert is_session_in_transaction(db_session) is False
|
||||
|
||||
async with atomic(db_session):
|
||||
# Now in transaction
|
||||
assert is_session_in_transaction(db_session) is True
|
||||
|
||||
# After exit, depends on session state
|
||||
# SQLite behavior may vary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_transaction_depth(self, db_session):
|
||||
"""Test get_session_transaction_depth helper."""
|
||||
depth = get_session_transaction_depth(db_session)
|
||||
assert depth >= 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# @transactional Decorator Integration Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTransactionalDecoratorIntegration:
|
||||
"""Integration tests for @transactional decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorated_function_commits(self, db_session):
|
||||
"""Test decorated function commits on success."""
|
||||
@transactional()
|
||||
async def create_series_decorated(db: AsyncSession):
|
||||
return await AnimeSeriesService.create(
|
||||
db,
|
||||
key="decorated-test",
|
||||
name="Decorated Test",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
|
||||
series = await create_series_decorated(db=db_session)
|
||||
|
||||
# Verify committed
|
||||
retrieved = await AnimeSeriesService.get_by_key(
|
||||
db_session, "decorated-test"
|
||||
)
|
||||
assert retrieved is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorated_function_rollback(self, db_session):
|
||||
"""Test decorated function rolls back on error."""
|
||||
@transactional()
|
||||
async def create_then_fail(db: AsyncSession):
|
||||
await AnimeSeriesService.create(
|
||||
db,
|
||||
key="decorated-rollback",
|
||||
name="Decorated Rollback",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
raise ValueError("Force failure")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await create_then_fail(db=db_session)
|
||||
|
||||
# Verify NOT committed
|
||||
retrieved = await AnimeSeriesService.get_by_key(
|
||||
db_session, "decorated-rollback"
|
||||
)
|
||||
assert retrieved is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_decorated_functions(self, db_session):
|
||||
"""Test nested decorated functions work correctly."""
|
||||
@transactional(propagation=TransactionPropagation.NESTED)
|
||||
async def inner_operation(db: AsyncSession, series_id: int):
|
||||
return await EpisodeService.create(
|
||||
db,
|
||||
series_id=series_id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
|
||||
@transactional()
|
||||
async def outer_operation(db: AsyncSession):
|
||||
series = await AnimeSeriesService.create(
|
||||
db,
|
||||
key="nested-decorated",
|
||||
name="Nested Decorated",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
episode = await inner_operation(db=db, series_id=series.id)
|
||||
return series, episode
|
||||
|
||||
series, episode = await outer_operation(db=db_session)
|
||||
|
||||
# Both should be committed
|
||||
assert series is not None
|
||||
assert episode is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Concurrent Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestConcurrentTransactions:
|
||||
"""Tests for concurrent transaction handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_writes_different_keys(self, session_factory):
|
||||
"""Test concurrent writes to different records."""
|
||||
async def create_series(key: str):
|
||||
async with session_factory() as session:
|
||||
async with atomic(session):
|
||||
await AnimeSeriesService.create(
|
||||
session,
|
||||
key=key,
|
||||
name=f"Series {key}",
|
||||
site="https://test.com",
|
||||
folder=f"/test/{key}",
|
||||
)
|
||||
|
||||
# Run concurrent creates
|
||||
await asyncio.gather(
|
||||
create_series("concurrent-1"),
|
||||
create_series("concurrent-2"),
|
||||
create_series("concurrent-3"),
|
||||
)
|
||||
|
||||
# Verify all created
|
||||
async with session_factory() as verify_session:
|
||||
for i in range(1, 4):
|
||||
series = await AnimeSeriesService.get_by_key(
|
||||
verify_session, f"concurrent-{i}"
|
||||
)
|
||||
assert series is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Queue Repository Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestQueueRepositoryTransactions:
|
||||
"""Integration tests for QueueRepository transaction handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_item_atomic(self, session_factory):
|
||||
"""Test save_item creates series, episode, and queue item atomically."""
|
||||
from src.server.models.download import (
|
||||
DownloadItem,
|
||||
DownloadStatus,
|
||||
EpisodeIdentifier,
|
||||
)
|
||||
from src.server.services.queue_repository import QueueRepository
|
||||
|
||||
repo = QueueRepository(session_factory)
|
||||
|
||||
item = DownloadItem(
|
||||
id="temp-id",
|
||||
serie_id="repo-atomic-test",
|
||||
serie_folder="/test/folder",
|
||||
serie_name="Repo Atomic Test",
|
||||
episode=EpisodeIdentifier(season=1, episode=1),
|
||||
status=DownloadStatus.PENDING,
|
||||
)
|
||||
|
||||
saved_item = await repo.save_item(item)
|
||||
|
||||
assert saved_item.id != "temp-id" # Should have DB ID
|
||||
|
||||
# Verify all entities created
|
||||
async with session_factory() as verify_session:
|
||||
series = await AnimeSeriesService.get_by_key(
|
||||
verify_session, "repo-atomic-test"
|
||||
)
|
||||
assert series is not None
|
||||
|
||||
episodes = await EpisodeService.get_by_series(
|
||||
verify_session, series.id
|
||||
)
|
||||
assert len(episodes) == 1
|
||||
|
||||
queue_items = await DownloadQueueService.get_all(verify_session)
|
||||
assert len(queue_items) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_all_atomic(self, session_factory):
|
||||
"""Test clear_all removes all items atomically."""
|
||||
from src.server.models.download import (
|
||||
DownloadItem,
|
||||
DownloadStatus,
|
||||
EpisodeIdentifier,
|
||||
)
|
||||
from src.server.services.queue_repository import QueueRepository
|
||||
|
||||
repo = QueueRepository(session_factory)
|
||||
|
||||
# Add some items
|
||||
for i in range(3):
|
||||
item = DownloadItem(
|
||||
id=f"clear-{i}",
|
||||
serie_id=f"clear-series-{i}",
|
||||
serie_folder=f"/test/folder/{i}",
|
||||
serie_name=f"Clear Series {i}",
|
||||
episode=EpisodeIdentifier(season=1, episode=1),
|
||||
status=DownloadStatus.PENDING,
|
||||
)
|
||||
await repo.save_item(item)
|
||||
|
||||
# Clear all
|
||||
count = await repo.clear_all()
|
||||
|
||||
assert count == 3
|
||||
|
||||
# Verify all cleared
|
||||
async with session_factory() as verify_session:
|
||||
queue_items = await DownloadQueueService.get_all(verify_session)
|
||||
assert len(queue_items) == 0
|
||||
Reference in New Issue
Block a user