- Create transaction.py with @transactional decorator, atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager, get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
547 lines
18 KiB
Python
547 lines
18 KiB
Python
"""Unit tests for service layer transaction behavior.
|
|
|
|
Tests that service operations correctly handle transactions,
|
|
especially compound operations that require atomicity.
|
|
"""
|
|
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import sessionmaker

from src.server.database.base import Base
from src.server.database.models import (
    AnimeSeries,
    DownloadQueueItem,
    Episode,
    UserSession,
)
from src.server.database.service import (
    AnimeSeriesService,
    DownloadQueueService,
    EpisodeService,
    UserSessionService,
)
from src.server.database.transaction import atomic
|
|
|
|
# ============================================================================
|
|
# Fixtures
|
|
# ============================================================================
|
|
|
|
|
|
@pytest.fixture
async def db_engine():
    """Yield an in-memory async SQLite engine with the schema created.

    Every table registered on ``Base.metadata`` is created before the
    engine is handed to the test; the engine is disposed on teardown.
    """
    memory_url = "sqlite+aiosqlite:///:memory:"
    engine = create_async_engine(memory_url, echo=False)

    # Build the schema inside a single begin() block.
    async with engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield engine

    await engine.dispose()
|
|
|
|
|
|
@pytest.fixture
async def db_session(db_engine):
    """Yield an ``AsyncSession`` bound to the test engine.

    The session factory is configured with ``expire_on_commit=False``.
    After the test finishes, the session is rolled back before the
    ``async with`` block exits.

    Args:
        db_engine: Engine provided by the ``db_engine`` fixture.
    """
    # ``async_sessionmaker`` comes from the file-level import block rather
    # than a function-scope import, matching the other SQLAlchemy names.
    session_factory = async_sessionmaker(
        db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )

    async with session_factory() as session:
        yield session
        await session.rollback()
|
|
|
|
|
|
# ============================================================================
|
|
# AnimeSeriesService Transaction Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestAnimeSeriesServiceTransactions:
    """Transaction-related checks for ``AnimeSeriesService``."""

    @pytest.mark.asyncio
    async def test_create_uses_flush_not_commit(self, db_session):
        """A created series has an id and is tracked by the session.

        The service is expected to flush rather than commit, so the new
        row gets a primary key while the session transaction stays open.
        """
        series_fields = {
            "key": "test-key",
            "name": "Test Series",
            "site": "https://test.com",
            "folder": "/test/folder",
        }
        created = await AnimeSeriesService.create(db_session, **series_fields)

        # A primary key has been assigned to the new object.
        assert created.id is not None

        # This test never commits; the object must still belong to the
        # session it was created through.
        assert created in db_session

    @pytest.mark.asyncio
    async def test_update_uses_flush_not_commit(self, db_session):
        """An updated series reflects the change and stays in the session."""
        original = await AnimeSeriesService.create(
            db_session,
            key="update-test",
            name="Original Name",
            site="https://test.com",
            folder="/test/folder",
        )

        renamed = await AnimeSeriesService.update(
            db_session,
            original.id,
            name="Updated Name",
        )

        assert renamed.name == "Updated Name"
        assert renamed in db_session
|
|
|
|
|
|
# ============================================================================
|
|
# EpisodeService Transaction Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestEpisodeServiceTransactions:
    """Transaction-related checks for ``EpisodeService``."""

    @pytest.mark.asyncio
    async def test_bulk_mark_downloaded_atomicity(self, db_session):
        """Bulk-mark three episodes inside ``atomic`` and verify each one."""
        parent = await AnimeSeriesService.create(
            db_session,
            key="bulk-test-series",
            name="Bulk Test",
            site="https://test.com",
            folder="/test/folder",
        )

        # Episodes 1..3 of season 1, created one after another.
        created = [
            await EpisodeService.create(
                db_session,
                series_id=parent.id,
                season=1,
                episode_number=i,
                title=f"Episode {i}",
            )
            for i in range(1, 4)
        ]
        episode_ids = [episode.id for episode in created]
        file_paths = [f"/path/ep{i}.mp4" for i in range(1, 4)]

        async with atomic(db_session):
            marked = await EpisodeService.bulk_mark_downloaded(
                db_session,
                episode_ids,
                file_paths,
            )

        assert marked == 3

        # Each episode must carry the download flag and its matching path.
        for episode_id, expected_path in zip(episode_ids, file_paths):
            stored = await EpisodeService.get_by_id(db_session, episode_id)
            assert stored.is_downloaded is True
            assert stored.file_path == expected_path

    @pytest.mark.asyncio
    async def test_bulk_mark_downloaded_empty_list(self, db_session):
        """An empty id list is expected to report zero updated rows."""
        marked = await EpisodeService.bulk_mark_downloaded(
            db_session,
            episode_ids=[],
        )

        assert marked == 0

    @pytest.mark.asyncio
    async def test_delete_by_series_and_episode_transaction(self, db_session):
        """Delete a committed episode inside ``atomic`` and confirm it is gone."""
        parent = await AnimeSeriesService.create(
            db_session,
            key="delete-test-series",
            name="Delete Test",
            site="https://test.com",
            folder="/test/folder",
        )
        await EpisodeService.create(
            db_session,
            series_id=parent.id,
            season=1,
            episode_number=1,
            title="Episode 1",
        )
        # Commit the setup so the delete below starts from committed data.
        await db_session.commit()

        async with atomic(db_session):
            was_deleted = await EpisodeService.delete_by_series_and_episode(
                db_session,
                series_key="delete-test-series",
                season=1,
                episode_number=1,
            )

        assert was_deleted is True

        # A lookup for the same episode must now come back empty.
        remaining = await EpisodeService.get_by_episode(
            db_session,
            parent.id,
            season=1,
            episode_number=1,
        )
        assert remaining is None
|
|
|
|
|
|
# ============================================================================
|
|
# DownloadQueueService Transaction Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestDownloadQueueServiceTransactions:
    """Transaction-related checks for ``DownloadQueueService``."""

    @staticmethod
    async def _seed_queue(session, *, key, name):
        """Create one series and three episodes, each with a queue item.

        Returns the ids of the created queue items in creation order.
        """
        series = await AnimeSeriesService.create(
            session,
            key=key,
            name=name,
            site="https://test.com",
            folder="/test/folder",
        )

        queued_ids = []
        for number in range(1, 4):
            episode = await EpisodeService.create(
                session,
                series_id=series.id,
                season=1,
                episode_number=number,
            )
            queued = await DownloadQueueService.create(
                session,
                series_id=series.id,
                episode_id=episode.id,
            )
            queued_ids.append(queued.id)
        return queued_ids

    @pytest.mark.asyncio
    async def test_bulk_delete_atomicity(self, db_session):
        """Bulk-delete three queue items inside ``atomic``; none remain."""
        item_ids = await self._seed_queue(
            db_session,
            key="queue-bulk-test",
            name="Queue Bulk Test",
        )

        async with atomic(db_session):
            removed = await DownloadQueueService.bulk_delete(
                db_session,
                item_ids,
            )

        assert removed == 3

        # The queue must be empty afterwards.
        leftover = await DownloadQueueService.get_all(db_session)
        assert len(leftover) == 0

    @pytest.mark.asyncio
    async def test_bulk_delete_empty_list(self, db_session):
        """An empty id list is expected to report zero deleted rows."""
        removed = await DownloadQueueService.bulk_delete(
            db_session,
            item_ids=[],
        )

        assert removed == 0

    @pytest.mark.asyncio
    async def test_clear_all_atomicity(self, db_session):
        """``clear_all`` inside ``atomic`` removes all three queue items."""
        await self._seed_queue(
            db_session,
            key="clear-all-test",
            name="Clear All Test",
        )

        async with atomic(db_session):
            cleared = await DownloadQueueService.clear_all(db_session)

        assert cleared == 3

        # The queue must be empty afterwards.
        leftover = await DownloadQueueService.get_all(db_session)
        assert len(leftover) == 0
|
|
|
|
|
|
# ============================================================================
|
|
# UserSessionService Transaction Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestUserSessionServiceTransactions:
    """Transaction-related checks for ``UserSessionService``."""

    @pytest.mark.asyncio
    async def test_rotate_session_atomicity(self, db_session):
        """Rotating a session yields the new one and deactivates the old."""
        # Set up and commit the session that will be rotated out.
        old_expiry = datetime.now(timezone.utc) + timedelta(hours=1)
        await UserSessionService.create(
            db_session,
            session_id="old-session-123",
            token_hash="old_hash",
            expires_at=old_expiry,
        )
        await db_session.commit()

        async with atomic(db_session):
            rotated = await UserSessionService.rotate_session(
                db_session,
                old_session_id="old-session-123",
                new_session_id="new-session-456",
                new_token_hash="new_hash",
                new_expires_at=datetime.now(timezone.utc) + timedelta(hours=2),
            )

        assert rotated is not None
        assert rotated.session_id == "new-session-456"

        # The previous session must no longer be active.
        previous = await UserSessionService.get_by_session_id(
            db_session, "old-session-123"
        )
        assert previous.is_active is False

    @pytest.mark.asyncio
    async def test_rotate_session_old_not_found(self, db_session):
        """Rotating an unknown session id is expected to return ``None``."""
        new_expiry = datetime.now(timezone.utc) + timedelta(hours=1)
        outcome = await UserSessionService.rotate_session(
            db_session,
            old_session_id="nonexistent-session",
            new_session_id="new-session",
            new_token_hash="hash",
            new_expires_at=new_expiry,
        )

        assert outcome is None

    @pytest.mark.asyncio
    async def test_cleanup_expired_bulk_delete(self, db_session):
        """``cleanup_expired`` removes the three expired sessions only."""
        # Three sessions whose expiry lies one hour in the past.
        for i in range(3):
            await UserSessionService.create(
                db_session,
                session_id=f"expired-{i}",
                token_hash=f"hash-{i}",
                expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
            )

        # One session that is still valid for another hour.
        await UserSessionService.create(
            db_session,
            session_id="active-session",
            token_hash="active_hash",
            expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
        )
        await db_session.commit()

        async with atomic(db_session):
            purged = await UserSessionService.cleanup_expired(db_session)

        assert purged == 3

        # The unexpired session must have survived the cleanup.
        survivor = await UserSessionService.get_by_session_id(
            db_session, "active-session"
        )
        assert survivor is not None
|
|
|
|
|
|
# ============================================================================
|
|
# Compound Operation Rollback Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestCompoundOperationRollback:
    """Tests for rollback behavior in compound operations."""

    @pytest.mark.asyncio
    async def test_rollback_on_partial_failure(self, db_session):
        """Test rollback when a compound operation fails mid-way.

        An episode is created and flushed inside ``atomic``; the block then
        raises. The error must propagate out of ``atomic`` and the episode
        must not be found afterwards.
        """
        # Create and commit the initial series.
        series = await AnimeSeriesService.create(
            db_session,
            key="rollback-test-series",
            name="Rollback Test",
            site="https://test.com",
            folder="/test/folder",
        )
        await db_session.commit()

        # Store the id before starting the transaction to avoid expired
        # state access.
        series_id = series.id

        # pytest.raises (instead of try/except/pass) makes the test fail if
        # the context manager swallows the error.
        with pytest.raises(ValueError, match="Simulated failure"):
            async with atomic(db_session):
                await EpisodeService.create(
                    db_session,
                    series_id=series_id,
                    season=1,
                    episode_number=1,
                )

                # Force flush to persist the episode in the transaction.
                await db_session.flush()

                # Simulate failure mid-operation.
                raise ValueError("Simulated failure")

        # Verify the episode was NOT persisted.
        episode = await EpisodeService.get_by_episode(
            db_session,
            series_id,
            season=1,
            episode_number=1,
        )
        assert episode is None

    @pytest.mark.asyncio
    async def test_no_orphan_data_on_failure(self, db_session):
        """Test no orphaned data when a multi-service operation fails.

        A series, an episode and a queue item are created inside one
        ``atomic`` block that then raises. Neither the series nor any queue
        item may be visible afterwards.
        """
        with pytest.raises(RuntimeError, match="Critical failure"):
            async with atomic(db_session):
                series = await AnimeSeriesService.create(
                    db_session,
                    key="orphan-test-series",
                    name="Orphan Test",
                    site="https://test.com",
                    folder="/test/folder",
                )

                episode = await EpisodeService.create(
                    db_session,
                    series_id=series.id,
                    season=1,
                    episode_number=1,
                )

                await DownloadQueueService.create(
                    db_session,
                    series_id=series.id,
                    episode_id=episode.id,
                )

                await db_session.flush()

                # Fail after all creates.
                raise RuntimeError("Critical failure")

        # Verify the parent series was not persisted.
        all_series = await AnimeSeriesService.get_all(db_session)
        series_keys = [s.key for s in all_series]
        assert "orphan-test-series" not in series_keys

        # Verify no queue item was left behind without its series.
        all_items = await DownloadQueueService.get_all(db_session)
        assert len(all_items) == 0
|
|
|
|
|
|
# ============================================================================
|
|
# Nested Transaction Tests
|
|
# ============================================================================
|
|
|
|
|
|
class TestNestedTransactions:
    """Checks for savepoint behavior inside an ``atomic`` block."""

    @pytest.mark.asyncio
    async def test_savepoint_partial_rollback(self, db_session):
        """Rolling back a savepoint discards only the work done inside it.

        Episodes 1 and 3 are created directly in the outer transaction;
        episode 2 is created inside a savepoint that is rolled back.
        """
        series = await AnimeSeriesService.create(
            db_session,
            key="savepoint-test",
            name="Savepoint Test",
            site="https://test.com",
            folder="/test/folder",
        )

        async def add_episode(number):
            # Create a season-1 episode with the given number for the series.
            await EpisodeService.create(
                db_session,
                series_id=series.id,
                season=1,
                episode_number=number,
            )

        async with atomic(db_session) as tx:
            # Outer transaction: expected to persist.
            await add_episode(1)

            async with tx.savepoint() as sp:
                await add_episode(2)

                # Undo only what happened since the savepoint.
                await sp.rollback()

            # Outer transaction again: expected to persist.
            await add_episode(3)

        stored = await EpisodeService.get_by_series(db_session, series.id)
        episode_numbers = [episode.episode_number for episode in stored]

        assert 1 in episode_numbers
        assert 2 not in episode_numbers  # Rolled back
        assert 3 in episode_numbers
|