Add database transaction support with atomic operations
- Create transaction.py with @transactional decorator, atomic() context manager - Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED - Add savepoint support for nested transactions with partial rollback - Update connection.py with TransactionManager, get_transactional_session - Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete) - Wrap QueueRepository.save_item() and clear_all() in atomic transactions - Add comprehensive tests (66 transaction tests, 90% coverage) - All 1090 tests passing
This commit is contained in:
546
tests/integration/test_db_transactions.py
Normal file
546
tests/integration/test_db_transactions.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""Integration tests for database transaction behavior.
|
||||
|
||||
Tests real database transaction handling including:
|
||||
- Transaction isolation
|
||||
- Concurrent transaction handling
|
||||
- Real commit/rollback behavior
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
|
||||
from src.server.database.base import Base
|
||||
from src.server.database.connection import (
|
||||
TransactionManager,
|
||||
get_session_transaction_depth,
|
||||
is_session_in_transaction,
|
||||
)
|
||||
from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode
|
||||
from src.server.database.service import (
|
||||
AnimeSeriesService,
|
||||
DownloadQueueService,
|
||||
EpisodeService,
|
||||
)
|
||||
from src.server.database.transaction import (
|
||||
TransactionPropagation,
|
||||
atomic,
|
||||
transactional,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
"""Create in-memory database engine for testing."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def session_factory(db_engine):
|
||||
"""Create session factory for testing."""
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
return async_sessionmaker(
|
||||
db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
autocommit=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(session_factory):
|
||||
"""Create database session for testing."""
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Real Database Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestRealDatabaseTransactions:
|
||||
"""Tests using real in-memory database."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_commit_persists_data(self, db_session):
|
||||
"""Test that committed data is actually persisted."""
|
||||
async with atomic(db_session):
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="commit-test",
|
||||
name="Commit Test Series",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
# Data should be retrievable after commit
|
||||
retrieved = await AnimeSeriesService.get_by_key(
|
||||
db_session, "commit-test"
|
||||
)
|
||||
assert retrieved is not None
|
||||
assert retrieved.name == "Commit Test Series"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rollback_discards_data(self, db_session):
|
||||
"""Test that rolled back data is discarded."""
|
||||
try:
|
||||
async with atomic(db_session):
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="rollback-test",
|
||||
name="Rollback Test Series",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
await db_session.flush()
|
||||
|
||||
raise ValueError("Force rollback")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Data should NOT be retrievable after rollback
|
||||
retrieved = await AnimeSeriesService.get_by_key(
|
||||
db_session, "rollback-test"
|
||||
)
|
||||
assert retrieved is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_operations_atomic(self, db_session):
|
||||
"""Test multiple operations are committed together."""
|
||||
async with atomic(db_session):
|
||||
# Create series
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="atomic-multi-test",
|
||||
name="Atomic Multi Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
title="Episode 1",
|
||||
)
|
||||
|
||||
# Create queue item
|
||||
item = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
|
||||
# All should be persisted
|
||||
retrieved_series = await AnimeSeriesService.get_by_key(
|
||||
db_session, "atomic-multi-test"
|
||||
)
|
||||
assert retrieved_series is not None
|
||||
|
||||
episodes = await EpisodeService.get_by_series(
|
||||
db_session, retrieved_series.id
|
||||
)
|
||||
assert len(episodes) == 1
|
||||
|
||||
queue_items = await DownloadQueueService.get_all(db_session)
|
||||
assert len(queue_items) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_operations_rollback_all(self, db_session):
|
||||
"""Test multiple operations are all rolled back on failure."""
|
||||
try:
|
||||
async with atomic(db_session):
|
||||
# Create series
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="rollback-multi-test",
|
||||
name="Rollback Multi Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
|
||||
# Create queue item
|
||||
item = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
|
||||
await db_session.flush()
|
||||
|
||||
raise RuntimeError("Force complete rollback")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# None should be persisted
|
||||
retrieved_series = await AnimeSeriesService.get_by_key(
|
||||
db_session, "rollback-multi-test"
|
||||
)
|
||||
assert retrieved_series is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Transaction Manager Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTransactionManager:
|
||||
"""Tests for TransactionManager class."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_manager_basic_flow(self, session_factory):
|
||||
"""Test basic transaction manager usage."""
|
||||
async with TransactionManager(session_factory) as tm:
|
||||
session = await tm.get_session()
|
||||
await tm.begin()
|
||||
|
||||
series = AnimeSeries(
|
||||
key="tm-test",
|
||||
name="TM Test",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
session.add(series)
|
||||
|
||||
await tm.commit()
|
||||
|
||||
# Verify data persisted
|
||||
async with session_factory() as verify_session:
|
||||
from sqlalchemy import select
|
||||
result = await verify_session.execute(
|
||||
select(AnimeSeries).where(AnimeSeries.key == "tm-test")
|
||||
)
|
||||
series = result.scalar_one_or_none()
|
||||
assert series is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_manager_rollback(self, session_factory):
|
||||
"""Test transaction manager rollback."""
|
||||
async with TransactionManager(session_factory) as tm:
|
||||
session = await tm.get_session()
|
||||
await tm.begin()
|
||||
|
||||
series = AnimeSeries(
|
||||
key="tm-rollback-test",
|
||||
name="TM Rollback Test",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
session.add(series)
|
||||
await session.flush()
|
||||
|
||||
await tm.rollback()
|
||||
|
||||
# Verify data NOT persisted
|
||||
async with session_factory() as verify_session:
|
||||
from sqlalchemy import select
|
||||
result = await verify_session.execute(
|
||||
select(AnimeSeries).where(AnimeSeries.key == "tm-rollback-test")
|
||||
)
|
||||
series = result.scalar_one_or_none()
|
||||
assert series is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_manager_auto_rollback_on_exception(
|
||||
self, session_factory
|
||||
):
|
||||
"""Test transaction manager auto-rolls back on exception."""
|
||||
with pytest.raises(ValueError):
|
||||
async with TransactionManager(session_factory) as tm:
|
||||
session = await tm.get_session()
|
||||
await tm.begin()
|
||||
|
||||
series = AnimeSeries(
|
||||
key="tm-auto-rollback",
|
||||
name="TM Auto Rollback",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
session.add(series)
|
||||
await session.flush()
|
||||
|
||||
raise ValueError("Force exception")
|
||||
|
||||
# Verify data NOT persisted
|
||||
async with session_factory() as verify_session:
|
||||
from sqlalchemy import select
|
||||
result = await verify_session.execute(
|
||||
select(AnimeSeries).where(AnimeSeries.key == "tm-auto-rollback")
|
||||
)
|
||||
series = result.scalar_one_or_none()
|
||||
assert series is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_transaction_manager_state_tracking(self, session_factory):
|
||||
"""Test transaction manager tracks state correctly."""
|
||||
async with TransactionManager(session_factory) as tm:
|
||||
assert tm.is_in_transaction() is False
|
||||
|
||||
await tm.begin()
|
||||
assert tm.is_in_transaction() is True
|
||||
|
||||
await tm.commit()
|
||||
assert tm.is_in_transaction() is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Function Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestConnectionHelpers:
|
||||
"""Tests for connection module helper functions."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_session_in_transaction(self, db_session):
|
||||
"""Test is_session_in_transaction helper."""
|
||||
# Initially not in transaction
|
||||
assert is_session_in_transaction(db_session) is False
|
||||
|
||||
async with atomic(db_session):
|
||||
# Now in transaction
|
||||
assert is_session_in_transaction(db_session) is True
|
||||
|
||||
# After exit, depends on session state
|
||||
# SQLite behavior may vary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_transaction_depth(self, db_session):
|
||||
"""Test get_session_transaction_depth helper."""
|
||||
depth = get_session_transaction_depth(db_session)
|
||||
assert depth >= 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# @transactional Decorator Integration Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTransactionalDecoratorIntegration:
|
||||
"""Integration tests for @transactional decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorated_function_commits(self, db_session):
|
||||
"""Test decorated function commits on success."""
|
||||
@transactional()
|
||||
async def create_series_decorated(db: AsyncSession):
|
||||
return await AnimeSeriesService.create(
|
||||
db,
|
||||
key="decorated-test",
|
||||
name="Decorated Test",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
|
||||
series = await create_series_decorated(db=db_session)
|
||||
|
||||
# Verify committed
|
||||
retrieved = await AnimeSeriesService.get_by_key(
|
||||
db_session, "decorated-test"
|
||||
)
|
||||
assert retrieved is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_decorated_function_rollback(self, db_session):
|
||||
"""Test decorated function rolls back on error."""
|
||||
@transactional()
|
||||
async def create_then_fail(db: AsyncSession):
|
||||
await AnimeSeriesService.create(
|
||||
db,
|
||||
key="decorated-rollback",
|
||||
name="Decorated Rollback",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
raise ValueError("Force failure")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await create_then_fail(db=db_session)
|
||||
|
||||
# Verify NOT committed
|
||||
retrieved = await AnimeSeriesService.get_by_key(
|
||||
db_session, "decorated-rollback"
|
||||
)
|
||||
assert retrieved is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_nested_decorated_functions(self, db_session):
|
||||
"""Test nested decorated functions work correctly."""
|
||||
@transactional(propagation=TransactionPropagation.NESTED)
|
||||
async def inner_operation(db: AsyncSession, series_id: int):
|
||||
return await EpisodeService.create(
|
||||
db,
|
||||
series_id=series_id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
|
||||
@transactional()
|
||||
async def outer_operation(db: AsyncSession):
|
||||
series = await AnimeSeriesService.create(
|
||||
db,
|
||||
key="nested-decorated",
|
||||
name="Nested Decorated",
|
||||
site="https://test.com",
|
||||
folder="/test",
|
||||
)
|
||||
episode = await inner_operation(db=db, series_id=series.id)
|
||||
return series, episode
|
||||
|
||||
series, episode = await outer_operation(db=db_session)
|
||||
|
||||
# Both should be committed
|
||||
assert series is not None
|
||||
assert episode is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Concurrent Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestConcurrentTransactions:
|
||||
"""Tests for concurrent transaction handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_writes_different_keys(self, session_factory):
|
||||
"""Test concurrent writes to different records."""
|
||||
async def create_series(key: str):
|
||||
async with session_factory() as session:
|
||||
async with atomic(session):
|
||||
await AnimeSeriesService.create(
|
||||
session,
|
||||
key=key,
|
||||
name=f"Series {key}",
|
||||
site="https://test.com",
|
||||
folder=f"/test/{key}",
|
||||
)
|
||||
|
||||
# Run concurrent creates
|
||||
await asyncio.gather(
|
||||
create_series("concurrent-1"),
|
||||
create_series("concurrent-2"),
|
||||
create_series("concurrent-3"),
|
||||
)
|
||||
|
||||
# Verify all created
|
||||
async with session_factory() as verify_session:
|
||||
for i in range(1, 4):
|
||||
series = await AnimeSeriesService.get_by_key(
|
||||
verify_session, f"concurrent-{i}"
|
||||
)
|
||||
assert series is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Queue Repository Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestQueueRepositoryTransactions:
|
||||
"""Integration tests for QueueRepository transaction handling."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_item_atomic(self, session_factory):
|
||||
"""Test save_item creates series, episode, and queue item atomically."""
|
||||
from src.server.models.download import (
|
||||
DownloadItem,
|
||||
DownloadStatus,
|
||||
EpisodeIdentifier,
|
||||
)
|
||||
from src.server.services.queue_repository import QueueRepository
|
||||
|
||||
repo = QueueRepository(session_factory)
|
||||
|
||||
item = DownloadItem(
|
||||
id="temp-id",
|
||||
serie_id="repo-atomic-test",
|
||||
serie_folder="/test/folder",
|
||||
serie_name="Repo Atomic Test",
|
||||
episode=EpisodeIdentifier(season=1, episode=1),
|
||||
status=DownloadStatus.PENDING,
|
||||
)
|
||||
|
||||
saved_item = await repo.save_item(item)
|
||||
|
||||
assert saved_item.id != "temp-id" # Should have DB ID
|
||||
|
||||
# Verify all entities created
|
||||
async with session_factory() as verify_session:
|
||||
series = await AnimeSeriesService.get_by_key(
|
||||
verify_session, "repo-atomic-test"
|
||||
)
|
||||
assert series is not None
|
||||
|
||||
episodes = await EpisodeService.get_by_series(
|
||||
verify_session, series.id
|
||||
)
|
||||
assert len(episodes) == 1
|
||||
|
||||
queue_items = await DownloadQueueService.get_all(verify_session)
|
||||
assert len(queue_items) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_all_atomic(self, session_factory):
|
||||
"""Test clear_all removes all items atomically."""
|
||||
from src.server.models.download import (
|
||||
DownloadItem,
|
||||
DownloadStatus,
|
||||
EpisodeIdentifier,
|
||||
)
|
||||
from src.server.services.queue_repository import QueueRepository
|
||||
|
||||
repo = QueueRepository(session_factory)
|
||||
|
||||
# Add some items
|
||||
for i in range(3):
|
||||
item = DownloadItem(
|
||||
id=f"clear-{i}",
|
||||
serie_id=f"clear-series-{i}",
|
||||
serie_folder=f"/test/folder/{i}",
|
||||
serie_name=f"Clear Series {i}",
|
||||
episode=EpisodeIdentifier(season=1, episode=1),
|
||||
status=DownloadStatus.PENDING,
|
||||
)
|
||||
await repo.save_item(item)
|
||||
|
||||
# Clear all
|
||||
count = await repo.clear_all()
|
||||
|
||||
assert count == 3
|
||||
|
||||
# Verify all cleared
|
||||
async with session_factory() as verify_session:
|
||||
queue_items = await DownloadQueueService.get_all(verify_session)
|
||||
assert len(queue_items) == 0
|
||||
546
tests/unit/test_service_transactions.py
Normal file
546
tests/unit/test_service_transactions.py
Normal file
@@ -0,0 +1,546 @@
|
||||
"""Unit tests for service layer transaction behavior.
|
||||
|
||||
Tests that service operations correctly handle transactions,
|
||||
especially compound operations that require atomicity.
|
||||
"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from src.server.database.base import Base
|
||||
from src.server.database.models import (
|
||||
AnimeSeries,
|
||||
DownloadQueueItem,
|
||||
Episode,
|
||||
UserSession,
|
||||
)
|
||||
from src.server.database.service import (
|
||||
AnimeSeriesService,
|
||||
DownloadQueueService,
|
||||
EpisodeService,
|
||||
UserSessionService,
|
||||
)
|
||||
from src.server.database.transaction import atomic
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_engine():
|
||||
"""Create in-memory database engine for testing."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db_session(db_engine):
|
||||
"""Create database session for testing."""
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
async_session = async_sessionmaker(
|
||||
db_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
async with async_session() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AnimeSeriesService Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestAnimeSeriesServiceTransactions:
|
||||
"""Tests for AnimeSeriesService transaction behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_uses_flush_not_commit(self, db_session):
|
||||
"""Test create uses flush for transaction compatibility."""
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="test-key",
|
||||
name="Test Series",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
# Series should exist in session
|
||||
assert series.id is not None
|
||||
|
||||
# But not committed yet (we're in an uncommitted transaction)
|
||||
# We can verify by checking the session's uncommitted state
|
||||
assert series in db_session
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_uses_flush_not_commit(self, db_session):
|
||||
"""Test update uses flush for transaction compatibility."""
|
||||
# Create series
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="update-test",
|
||||
name="Original Name",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
# Update series
|
||||
updated = await AnimeSeriesService.update(
|
||||
db_session,
|
||||
series.id,
|
||||
name="Updated Name",
|
||||
)
|
||||
|
||||
assert updated.name == "Updated Name"
|
||||
assert updated in db_session
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# EpisodeService Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestEpisodeServiceTransactions:
|
||||
"""Tests for EpisodeService transaction behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_mark_downloaded_atomicity(self, db_session):
|
||||
"""Test bulk_mark_downloaded updates all or none."""
|
||||
# Create series and episodes
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="bulk-test-series",
|
||||
name="Bulk Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
episodes = []
|
||||
for i in range(1, 4):
|
||||
ep = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=i,
|
||||
title=f"Episode {i}",
|
||||
)
|
||||
episodes.append(ep)
|
||||
|
||||
episode_ids = [ep.id for ep in episodes]
|
||||
file_paths = [f"/path/ep{i}.mp4" for i in range(1, 4)]
|
||||
|
||||
# Bulk update within atomic context
|
||||
async with atomic(db_session):
|
||||
count = await EpisodeService.bulk_mark_downloaded(
|
||||
db_session,
|
||||
episode_ids,
|
||||
file_paths,
|
||||
)
|
||||
|
||||
assert count == 3
|
||||
|
||||
# Verify all episodes were marked
|
||||
for i, ep_id in enumerate(episode_ids):
|
||||
episode = await EpisodeService.get_by_id(db_session, ep_id)
|
||||
assert episode.is_downloaded is True
|
||||
assert episode.file_path == file_paths[i]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_mark_downloaded_empty_list(self, db_session):
|
||||
"""Test bulk_mark_downloaded handles empty list."""
|
||||
count = await EpisodeService.bulk_mark_downloaded(
|
||||
db_session,
|
||||
episode_ids=[],
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_by_series_and_episode_transaction(self, db_session):
|
||||
"""Test delete_by_series_and_episode in transaction."""
|
||||
# Create series and episode
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="delete-test-series",
|
||||
name="Delete Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
title="Episode 1",
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Delete episode within transaction
|
||||
async with atomic(db_session):
|
||||
deleted = await EpisodeService.delete_by_series_and_episode(
|
||||
db_session,
|
||||
series_key="delete-test-series",
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
|
||||
assert deleted is True
|
||||
|
||||
# Verify episode is gone
|
||||
episode = await EpisodeService.get_by_episode(
|
||||
db_session,
|
||||
series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
assert episode is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# DownloadQueueService Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestDownloadQueueServiceTransactions:
|
||||
"""Tests for DownloadQueueService transaction behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_delete_atomicity(self, db_session):
|
||||
"""Test bulk_delete removes all or none."""
|
||||
# Create series and episodes
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="queue-bulk-test",
|
||||
name="Queue Bulk Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
item_ids = []
|
||||
for i in range(1, 4):
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=i,
|
||||
)
|
||||
item = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
item_ids.append(item.id)
|
||||
|
||||
# Bulk delete within atomic context
|
||||
async with atomic(db_session):
|
||||
count = await DownloadQueueService.bulk_delete(
|
||||
db_session,
|
||||
item_ids,
|
||||
)
|
||||
|
||||
assert count == 3
|
||||
|
||||
# Verify all items deleted
|
||||
all_items = await DownloadQueueService.get_all(db_session)
|
||||
assert len(all_items) == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_delete_empty_list(self, db_session):
|
||||
"""Test bulk_delete handles empty list."""
|
||||
count = await DownloadQueueService.bulk_delete(
|
||||
db_session,
|
||||
item_ids=[],
|
||||
)
|
||||
|
||||
assert count == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_clear_all_atomicity(self, db_session):
|
||||
"""Test clear_all removes all items atomically."""
|
||||
# Create series and queue items
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="clear-all-test",
|
||||
name="Clear All Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
for i in range(1, 4):
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=i,
|
||||
)
|
||||
await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
|
||||
# Clear all within atomic context
|
||||
async with atomic(db_session):
|
||||
count = await DownloadQueueService.clear_all(db_session)
|
||||
|
||||
assert count == 3
|
||||
|
||||
# Verify all items cleared
|
||||
all_items = await DownloadQueueService.get_all(db_session)
|
||||
assert len(all_items) == 0
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# UserSessionService Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestUserSessionServiceTransactions:
|
||||
"""Tests for UserSessionService transaction behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rotate_session_atomicity(self, db_session):
|
||||
"""Test rotate_session is atomic (revoke + create)."""
|
||||
# Create old session
|
||||
old_session = await UserSessionService.create(
|
||||
db_session,
|
||||
session_id="old-session-123",
|
||||
token_hash="old_hash",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Rotate session within atomic context
|
||||
async with atomic(db_session):
|
||||
new_session = await UserSessionService.rotate_session(
|
||||
db_session,
|
||||
old_session_id="old-session-123",
|
||||
new_session_id="new-session-456",
|
||||
new_token_hash="new_hash",
|
||||
new_expires_at=datetime.now(timezone.utc) + timedelta(hours=2),
|
||||
)
|
||||
|
||||
assert new_session is not None
|
||||
assert new_session.session_id == "new-session-456"
|
||||
|
||||
# Verify old session is revoked
|
||||
old = await UserSessionService.get_by_session_id(
|
||||
db_session, "old-session-123"
|
||||
)
|
||||
assert old.is_active is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rotate_session_old_not_found(self, db_session):
|
||||
"""Test rotate_session returns None if old session not found."""
|
||||
result = await UserSessionService.rotate_session(
|
||||
db_session,
|
||||
old_session_id="nonexistent-session",
|
||||
new_session_id="new-session",
|
||||
new_token_hash="hash",
|
||||
new_expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cleanup_expired_bulk_delete(self, db_session):
|
||||
"""Test cleanup_expired removes all expired sessions."""
|
||||
# Create expired sessions
|
||||
for i in range(3):
|
||||
await UserSessionService.create(
|
||||
db_session,
|
||||
session_id=f"expired-{i}",
|
||||
token_hash=f"hash-{i}",
|
||||
expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
|
||||
)
|
||||
|
||||
# Create active session
|
||||
await UserSessionService.create(
|
||||
db_session,
|
||||
session_id="active-session",
|
||||
token_hash="active_hash",
|
||||
expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Cleanup expired within atomic context
|
||||
async with atomic(db_session):
|
||||
count = await UserSessionService.cleanup_expired(db_session)
|
||||
|
||||
assert count == 3
|
||||
|
||||
# Verify active session still exists
|
||||
active = await UserSessionService.get_by_session_id(
|
||||
db_session, "active-session"
|
||||
)
|
||||
assert active is not None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Compound Operation Rollback Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestCompoundOperationRollback:
|
||||
"""Tests for rollback behavior in compound operations."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_rollback_on_partial_failure(self, db_session):
|
||||
"""Test rollback when compound operation fails mid-way."""
|
||||
# Create initial series
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="rollback-test-series",
|
||||
name="Rollback Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
await db_session.commit()
|
||||
|
||||
# Store the id before starting the transaction to avoid expired state access
|
||||
series_id = series.id
|
||||
|
||||
try:
|
||||
async with atomic(db_session):
|
||||
# Create episode
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series_id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
|
||||
# Force flush to persist episode in transaction
|
||||
await db_session.flush()
|
||||
|
||||
# Simulate failure mid-operation
|
||||
raise ValueError("Simulated failure")
|
||||
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Verify episode was NOT persisted
|
||||
episode = await EpisodeService.get_by_episode(
|
||||
db_session,
|
||||
series_id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
assert episode is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_orphan_data_on_failure(self, db_session):
|
||||
"""Test no orphaned data when multi-service operation fails."""
|
||||
try:
|
||||
async with atomic(db_session):
|
||||
# Create series
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="orphan-test-series",
|
||||
name="Orphan Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
# Create episode
|
||||
episode = await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
|
||||
# Create queue item
|
||||
item = await DownloadQueueService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
episode_id=episode.id,
|
||||
)
|
||||
|
||||
await db_session.flush()
|
||||
|
||||
# Fail after all creates
|
||||
raise RuntimeError("Critical failure")
|
||||
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
# Verify nothing was persisted
|
||||
all_series = await AnimeSeriesService.get_all(db_session)
|
||||
series_keys = [s.key for s in all_series]
|
||||
assert "orphan-test-series" not in series_keys
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Nested Transaction Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestNestedTransactions:
|
||||
"""Tests for nested transaction (savepoint) behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_savepoint_partial_rollback(self, db_session):
|
||||
"""Test savepoint allows partial rollback."""
|
||||
# Create series
|
||||
series = await AnimeSeriesService.create(
|
||||
db_session,
|
||||
key="savepoint-test",
|
||||
name="Savepoint Test",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
|
||||
async with atomic(db_session) as tx:
|
||||
# Create first episode (should persist)
|
||||
await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=1,
|
||||
)
|
||||
|
||||
# Nested transaction for second episode
|
||||
async with tx.savepoint() as sp:
|
||||
await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=2,
|
||||
)
|
||||
|
||||
# Rollback only the savepoint
|
||||
await sp.rollback()
|
||||
|
||||
# Create third episode (should persist)
|
||||
await EpisodeService.create(
|
||||
db_session,
|
||||
series_id=series.id,
|
||||
season=1,
|
||||
episode_number=3,
|
||||
)
|
||||
|
||||
# Verify first and third episodes exist, second doesn't
|
||||
episodes = await EpisodeService.get_by_series(db_session, series.id)
|
||||
episode_numbers = [ep.episode_number for ep in episodes]
|
||||
|
||||
assert 1 in episode_numbers
|
||||
assert 2 not in episode_numbers # Rolled back
|
||||
assert 3 in episode_numbers
|
||||
668
tests/unit/test_transactions.py
Normal file
668
tests/unit/test_transactions.py
Normal file
@@ -0,0 +1,668 @@
|
||||
"""Unit tests for database transaction utilities.
|
||||
|
||||
Tests the transaction management utilities including decorators,
|
||||
context managers, and helper functions.
|
||||
"""
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from src.server.database.base import Base
|
||||
from src.server.database.transaction import (
|
||||
AsyncTransactionContext,
|
||||
TransactionContext,
|
||||
TransactionError,
|
||||
TransactionPropagation,
|
||||
atomic,
|
||||
atomic_sync,
|
||||
is_in_transaction,
|
||||
transactional,
|
||||
)
|
||||
|
||||
# ============================================================================
|
||||
# Fixtures
|
||||
# ============================================================================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_engine():
|
||||
"""Create in-memory async database engine for testing."""
|
||||
engine = create_async_engine(
|
||||
"sqlite+aiosqlite:///:memory:",
|
||||
echo=False,
|
||||
)
|
||||
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(Base.metadata.create_all)
|
||||
|
||||
yield engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def async_session(async_engine):
|
||||
"""Create async database session for testing."""
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker
|
||||
|
||||
async_session_factory = async_sessionmaker(
|
||||
async_engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
)
|
||||
|
||||
async with async_session_factory() as session:
|
||||
yield session
|
||||
await session.rollback()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TransactionContext Tests (Sync)
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTransactionContext:
|
||||
"""Tests for synchronous TransactionContext."""
|
||||
|
||||
def test_context_manager_protocol(self):
|
||||
"""Test context manager enters and exits properly."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
with TransactionContext(mock_session) as ctx:
|
||||
assert ctx.session == mock_session
|
||||
mock_session.begin.assert_called_once()
|
||||
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_rollback_on_exception(self):
|
||||
"""Test rollback is called when exception occurs."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with TransactionContext(mock_session):
|
||||
raise ValueError("Test error")
|
||||
|
||||
mock_session.rollback.assert_called_once()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
def test_no_begin_if_already_in_transaction(self):
|
||||
"""Test no new transaction started if already in one."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = True
|
||||
|
||||
with TransactionContext(mock_session):
|
||||
pass
|
||||
|
||||
mock_session.begin.assert_not_called()
|
||||
|
||||
def test_explicit_commit(self):
|
||||
"""Test explicit commit within context."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
with TransactionContext(mock_session) as ctx:
|
||||
ctx.commit()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Should not commit again on exit
|
||||
assert mock_session.commit.call_count == 1
|
||||
|
||||
def test_explicit_rollback(self):
|
||||
"""Test explicit rollback within context."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
with TransactionContext(mock_session) as ctx:
|
||||
ctx.rollback()
|
||||
mock_session.rollback.assert_called_once()
|
||||
|
||||
# Should not commit after explicit rollback
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# AsyncTransactionContext Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestAsyncTransactionContext:
|
||||
"""Tests for asynchronous AsyncTransactionContext."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_context_manager_protocol(self):
|
||||
"""Test async context manager enters and exits properly."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
async with AsyncTransactionContext(mock_session) as ctx:
|
||||
assert ctx.session == mock_session
|
||||
mock_session.begin.assert_called_once()
|
||||
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_rollback_on_exception(self):
|
||||
"""Test async rollback is called when exception occurs."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with AsyncTransactionContext(mock_session):
|
||||
raise ValueError("Test error")
|
||||
|
||||
mock_session.rollback.assert_called_once()
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_explicit_commit(self):
|
||||
"""Test async explicit commit within context."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
async with AsyncTransactionContext(mock_session) as ctx:
|
||||
await ctx.commit()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
# Should not commit again on exit
|
||||
assert mock_session.commit.call_count == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_explicit_rollback(self):
|
||||
"""Test async explicit rollback within context."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
async with AsyncTransactionContext(mock_session) as ctx:
|
||||
await ctx.rollback()
|
||||
mock_session.rollback.assert_called_once()
|
||||
|
||||
# Should not commit after explicit rollback
|
||||
mock_session.commit.assert_not_called()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# atomic() Context Manager Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestAtomicContextManager:
|
||||
"""Tests for atomic() async context manager."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atomic_commits_on_success(self):
|
||||
"""Test atomic commits transaction on success."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
async with atomic(mock_session) as tx:
|
||||
pass
|
||||
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atomic_rollback_on_failure(self):
|
||||
"""Test atomic rolls back transaction on failure."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
async with atomic(mock_session):
|
||||
raise RuntimeError("Operation failed")
|
||||
|
||||
mock_session.rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atomic_nested_propagation(self):
|
||||
"""Test atomic with NESTED propagation creates savepoint."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = True
|
||||
mock_nested = AsyncMock()
|
||||
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
|
||||
|
||||
async with atomic(
|
||||
mock_session, propagation=TransactionPropagation.NESTED
|
||||
):
|
||||
pass
|
||||
|
||||
mock_session.begin_nested.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_atomic_required_propagation_default(self):
|
||||
"""Test atomic uses REQUIRED propagation by default."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
async with atomic(mock_session) as tx:
|
||||
# Should start new transaction
|
||||
mock_session.begin.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# @transactional Decorator Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTransactionalDecorator:
|
||||
"""Tests for @transactional decorator."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function_wrapped(self):
|
||||
"""Test async function is wrapped in transaction."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
@transactional()
|
||||
async def sample_operation(db: AsyncSession):
|
||||
return "result"
|
||||
|
||||
result = await sample_operation(db=mock_session)
|
||||
|
||||
assert result == "result"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_rollback_on_error(self):
|
||||
"""Test async function rollback on error."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
@transactional()
|
||||
async def failing_operation(db: AsyncSession):
|
||||
raise ValueError("Operation failed")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await failing_operation(db=mock_session)
|
||||
|
||||
mock_session.rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_custom_session_param_name(self):
|
||||
"""Test decorator with custom session parameter name."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
|
||||
@transactional(session_param="session")
|
||||
async def operation_with_session(session: AsyncSession):
|
||||
return "done"
|
||||
|
||||
result = await operation_with_session(session=mock_session)
|
||||
|
||||
assert result == "done"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_session_raises_error(self):
|
||||
"""Test error raised when session parameter not found."""
|
||||
@transactional()
|
||||
async def operation_no_session(data: dict):
|
||||
return data
|
||||
|
||||
with pytest.raises(TransactionError):
|
||||
await operation_no_session(data={"key": "value"})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_propagation_passed_to_atomic(self):
|
||||
"""Test propagation mode is passed to atomic."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = True
|
||||
mock_nested = AsyncMock()
|
||||
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
|
||||
|
||||
@transactional(propagation=TransactionPropagation.NESTED)
|
||||
async def nested_operation(db: AsyncSession):
|
||||
return "nested"
|
||||
|
||||
result = await nested_operation(db=mock_session)
|
||||
|
||||
assert result == "nested"
|
||||
mock_session.begin_nested.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Sync transactional decorator Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSyncTransactionalDecorator:
|
||||
"""Tests for @transactional decorator with sync functions."""
|
||||
|
||||
def test_sync_function_wrapped(self):
|
||||
"""Test sync function is wrapped in transaction."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
@transactional()
|
||||
def sample_sync_operation(db: Session):
|
||||
return "sync_result"
|
||||
|
||||
result = sample_sync_operation(db=mock_session)
|
||||
|
||||
assert result == "sync_result"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_sync_rollback_on_error(self):
|
||||
"""Test sync function rollback on error."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
@transactional()
|
||||
def failing_sync_operation(db: Session):
|
||||
raise ValueError("Sync operation failed")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
failing_sync_operation(db=mock_session)
|
||||
|
||||
mock_session.rollback.assert_called_once()
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Helper Function Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestHelperFunctions:
|
||||
"""Tests for transaction helper functions."""
|
||||
|
||||
def test_is_in_transaction_true(self):
|
||||
"""Test is_in_transaction returns True when in transaction."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.in_transaction.return_value = True
|
||||
|
||||
assert is_in_transaction(mock_session) is True
|
||||
|
||||
def test_is_in_transaction_false(self):
|
||||
"""Test is_in_transaction returns False when not in transaction."""
|
||||
mock_session = MagicMock()
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
assert is_in_transaction(mock_session) is False
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Integration Tests with Real Database
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTransactionIntegration:
|
||||
"""Integration tests using real in-memory database."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_transaction_commit(self, async_session):
|
||||
"""Test actual transaction commit with real session."""
|
||||
from src.server.database.models import AnimeSeries
|
||||
|
||||
async with atomic(async_session):
|
||||
series = AnimeSeries(
|
||||
key="test-series",
|
||||
name="Test Series",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
async_session.add(series)
|
||||
|
||||
# Verify data persisted
|
||||
from sqlalchemy import select
|
||||
result = await async_session.execute(
|
||||
select(AnimeSeries).where(AnimeSeries.key == "test-series")
|
||||
)
|
||||
saved_series = result.scalar_one_or_none()
|
||||
|
||||
assert saved_series is not None
|
||||
assert saved_series.name == "Test Series"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_transaction_rollback(self, async_session):
|
||||
"""Test actual transaction rollback with real session."""
|
||||
from src.server.database.models import AnimeSeries
|
||||
|
||||
try:
|
||||
async with atomic(async_session):
|
||||
series = AnimeSeries(
|
||||
key="rollback-series",
|
||||
name="Rollback Series",
|
||||
site="https://test.com",
|
||||
folder="/test/folder",
|
||||
)
|
||||
async_session.add(series)
|
||||
await async_session.flush()
|
||||
|
||||
# Force rollback
|
||||
raise ValueError("Simulated error")
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Verify data was NOT persisted
|
||||
from sqlalchemy import select
|
||||
result = await async_session.execute(
|
||||
select(AnimeSeries).where(AnimeSeries.key == "rollback-series")
|
||||
)
|
||||
saved_series = result.scalar_one_or_none()
|
||||
|
||||
assert saved_series is None
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# TransactionPropagation Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestTransactionPropagation:
|
||||
"""Tests for transaction propagation modes."""
|
||||
|
||||
def test_propagation_enum_values(self):
|
||||
"""Test propagation enum has correct values."""
|
||||
assert TransactionPropagation.REQUIRED.value == "required"
|
||||
assert TransactionPropagation.REQUIRES_NEW.value == "requires_new"
|
||||
assert TransactionPropagation.NESTED.value == "nested"
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Additional Coverage Tests
|
||||
# ============================================================================
|
||||
|
||||
|
||||
class TestSyncSavepointCoverage:
|
||||
"""Additional tests for sync savepoint coverage."""
|
||||
|
||||
def test_savepoint_exception_rolls_back(self):
|
||||
"""Test savepoint rollback when exception occurs within savepoint."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_nested = MagicMock()
|
||||
mock_session.begin_nested.return_value = mock_nested
|
||||
|
||||
with TransactionContext(mock_session) as ctx:
|
||||
with pytest.raises(ValueError):
|
||||
with ctx.savepoint() as sp:
|
||||
raise ValueError("Error in savepoint")
|
||||
|
||||
# Nested transaction should have been rolled back
|
||||
mock_nested.rollback.assert_called_once()
|
||||
|
||||
def test_savepoint_commit_explicit(self):
|
||||
"""Test explicit commit on savepoint."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_nested = MagicMock()
|
||||
mock_session.begin_nested.return_value = mock_nested
|
||||
|
||||
with TransactionContext(mock_session) as ctx:
|
||||
with ctx.savepoint() as sp:
|
||||
sp.commit()
|
||||
# Commit should just log, SQLAlchemy handles actual commit
|
||||
|
||||
|
||||
class TestAsyncSavepointCoverage:
|
||||
"""Additional tests for async savepoint coverage."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_savepoint_exception_rolls_back(self):
|
||||
"""Test async savepoint rollback when exception occurs."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
mock_nested = AsyncMock()
|
||||
mock_nested.rollback = AsyncMock()
|
||||
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
|
||||
|
||||
async with AsyncTransactionContext(mock_session) as ctx:
|
||||
with pytest.raises(ValueError):
|
||||
async with ctx.savepoint() as sp:
|
||||
raise ValueError("Error in async savepoint")
|
||||
|
||||
# Nested transaction should have been rolled back
|
||||
mock_nested.rollback.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_savepoint_commit_explicit(self):
|
||||
"""Test explicit commit on async savepoint."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_nested = AsyncMock()
|
||||
mock_session.begin_nested = AsyncMock(return_value=mock_nested)
|
||||
|
||||
async with AsyncTransactionContext(mock_session) as ctx:
|
||||
async with ctx.savepoint() as sp:
|
||||
await sp.commit()
|
||||
# Commit should just log, SQLAlchemy handles actual commit
|
||||
|
||||
|
||||
class TestAtomicNestedPropagationNoTransaction:
|
||||
"""Tests for NESTED propagation when not in transaction."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_nested_starts_new_when_not_in_transaction(self):
|
||||
"""Test NESTED propagation starts new transaction when none exists."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
async with atomic(mock_session, TransactionPropagation.NESTED) as tx:
|
||||
# Should start new transaction since none exists
|
||||
pass
|
||||
|
||||
mock_session.begin.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_sync_nested_starts_new_when_not_in_transaction(self):
|
||||
"""Test sync NESTED propagation starts new transaction when none exists."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
with atomic_sync(mock_session, TransactionPropagation.NESTED) as tx:
|
||||
pass
|
||||
|
||||
mock_session.begin.assert_called_once()
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
|
||||
class TestGetTransactionDepth:
|
||||
"""Tests for get_transaction_depth helper."""
|
||||
|
||||
def test_depth_zero_when_not_in_transaction(self):
|
||||
"""Test depth is 0 when not in transaction."""
|
||||
from src.server.database.transaction import get_transaction_depth
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
depth = get_transaction_depth(mock_session)
|
||||
assert depth == 0
|
||||
|
||||
def test_depth_one_in_transaction(self):
|
||||
"""Test depth is 1 in basic transaction."""
|
||||
from src.server.database.transaction import get_transaction_depth
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = True
|
||||
mock_session._nested_transaction = None
|
||||
|
||||
depth = get_transaction_depth(mock_session)
|
||||
assert depth == 1
|
||||
|
||||
def test_depth_two_with_nested_transaction(self):
|
||||
"""Test depth is 2 with nested transaction."""
|
||||
from src.server.database.transaction import get_transaction_depth
|
||||
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = True
|
||||
mock_session._nested_transaction = MagicMock() # Has nested
|
||||
|
||||
depth = get_transaction_depth(mock_session)
|
||||
assert depth == 2
|
||||
|
||||
|
||||
class TestTransactionalDecoratorPositionalArgs:
|
||||
"""Tests for transactional decorator with positional arguments."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_from_positional_arg(self):
|
||||
"""Test decorator extracts session from positional argument."""
|
||||
mock_session = AsyncMock(spec=AsyncSession)
|
||||
mock_session.in_transaction.return_value = False
|
||||
mock_session.begin = AsyncMock()
|
||||
mock_session.commit = AsyncMock()
|
||||
mock_session.rollback = AsyncMock()
|
||||
|
||||
@transactional()
|
||||
async def operation(db: AsyncSession, data: str):
|
||||
return f"processed: {data}"
|
||||
|
||||
# Pass session as positional argument
|
||||
result = await operation(mock_session, "test")
|
||||
|
||||
assert result == "processed: test"
|
||||
mock_session.commit.assert_called_once()
|
||||
|
||||
def test_sync_session_from_positional_arg(self):
|
||||
"""Test sync decorator extracts session from positional argument."""
|
||||
mock_session = MagicMock(spec=Session)
|
||||
mock_session.in_transaction.return_value = False
|
||||
|
||||
@transactional()
|
||||
def operation(db: Session, data: str):
|
||||
return f"processed: {data}"
|
||||
|
||||
result = operation(mock_session, "test")
|
||||
|
||||
assert result == "processed: test"
|
||||
mock_session.commit.assert_called_once()
|
||||
Reference in New Issue
Block a user