Aniworld/tests/unit/test_service_transactions.py
Lukas 1ba67357dc Add database transaction support with atomic operations
- Create transaction.py with @transactional decorator, atomic() context manager
- Add TransactionPropagation modes: REQUIRED, REQUIRES_NEW, NESTED
- Add savepoint support for nested transactions with partial rollback
- Update connection.py with TransactionManager, get_transactional_session
- Update service.py with bulk operations (bulk_mark_downloaded, bulk_delete)
- Wrap QueueRepository.save_item() and clear_all() in atomic transactions
- Add comprehensive tests (66 transaction tests, 90% coverage)
- All 1090 tests passing
2025-12-25 18:05:33 +01:00

547 lines
18 KiB
Python

"""Unit tests for service layer transaction behavior.
Tests that service operations correctly handle transactions,
especially compound operations that require atomicity.
"""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from src.server.database.base import Base
from src.server.database.models import (
AnimeSeries,
DownloadQueueItem,
Episode,
UserSession,
)
from src.server.database.service import (
AnimeSeriesService,
DownloadQueueService,
EpisodeService,
UserSessionService,
)
from src.server.database.transaction import atomic
# ============================================================================
# Fixtures
# ============================================================================
@pytest.fixture
async def db_engine():
    """Provide an async SQLAlchemy engine backed by an in-memory SQLite DB.

    Every table registered on ``Base.metadata`` is created before the engine
    is handed to the test; the engine is disposed once the test finishes.
    """
    memory_url = "sqlite+aiosqlite:///:memory:"
    test_engine = create_async_engine(memory_url, echo=False)

    # Build the schema inside a short-lived transaction on the new engine.
    async with test_engine.begin() as connection:
        await connection.run_sync(Base.metadata.create_all)

    yield test_engine

    await test_engine.dispose()
@pytest.fixture
async def db_session(db_engine):
    """Provide an ``AsyncSession`` bound to the per-test engine.

    Sessions are configured with ``expire_on_commit=False`` so attributes of
    committed objects stay readable without reloading. Anything still
    uncommitted when the test finishes is rolled back.
    """
    from sqlalchemy.ext.asyncio import async_sessionmaker

    session_factory = async_sessionmaker(
        db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with session_factory() as session:
        yield session
        await session.rollback()
# ============================================================================
# AnimeSeriesService Transaction Tests
# ============================================================================
class TestAnimeSeriesServiceTransactions:
    """Tests for AnimeSeriesService transaction behavior."""

    @pytest.mark.asyncio
    async def test_create_uses_flush_not_commit(self, db_session):
        """Test create uses flush for transaction compatibility.

        The new series must be written into the current transaction (it has
        an id and belongs to the session), but a rollback of that transaction
        must remove it again -- which proves ``create`` did not commit.
        """
        series = await AnimeSeriesService.create(
            db_session,
            key="test-key",
            name="Test Series",
            site="https://test.com",
            folder="/test/folder",
        )

        # The row has reached the transaction: it has an id and is tracked.
        assert series.id is not None
        assert series in db_session

        # Membership in the session holds even after a commit, so it cannot
        # distinguish flush from commit. A rollback can: a committed row
        # would survive it, a merely flushed one does not.
        await db_session.rollback()
        all_series = await AnimeSeriesService.get_all(db_session)
        assert "test-key" not in [s.key for s in all_series]

    @pytest.mark.asyncio
    async def test_update_uses_flush_not_commit(self, db_session):
        """Test update uses flush for transaction compatibility.

        The new value must be visible inside the current transaction but must
        be undone by a rollback, proving ``update`` did not commit on its own.
        """
        series = await AnimeSeriesService.create(
            db_session,
            key="update-test",
            name="Original Name",
            site="https://test.com",
            folder="/test/folder",
        )
        # Commit the original state so the rollback below has a baseline.
        await db_session.commit()

        updated = await AnimeSeriesService.update(
            db_session,
            series.id,
            name="Updated Name",
        )
        assert updated.name == "Updated Name"
        assert updated in db_session

        # Had update() committed, the new name would survive this rollback.
        await db_session.rollback()
        all_series = await AnimeSeriesService.get_all(db_session)
        names_by_key = {s.key: s.name for s in all_series}
        assert names_by_key["update-test"] == "Original Name"
# ============================================================================
# EpisodeService Transaction Tests
# ============================================================================
class TestEpisodeServiceTransactions:
    """Verify EpisodeService operations behave correctly inside transactions."""

    @pytest.mark.asyncio
    async def test_bulk_mark_downloaded_atomicity(self, db_session):
        """bulk_mark_downloaded inside ``atomic`` flags every given episode."""
        parent = await AnimeSeriesService.create(
            db_session,
            key="bulk-test-series",
            name="Bulk Test",
            site="https://test.com",
            folder="/test/folder",
        )
        numbers = range(1, 4)
        created = [
            await EpisodeService.create(
                db_session,
                series_id=parent.id,
                season=1,
                episode_number=n,
                title=f"Episode {n}",
            )
            for n in numbers
        ]
        ids = [ep.id for ep in created]
        paths = [f"/path/ep{n}.mp4" for n in numbers]

        async with atomic(db_session):
            updated_count = await EpisodeService.bulk_mark_downloaded(
                db_session,
                ids,
                paths,
            )
        assert updated_count == 3

        # Every episode must now be flagged and point at its own file.
        for ep_id, expected_path in zip(ids, paths):
            fetched = await EpisodeService.get_by_id(db_session, ep_id)
            assert fetched.is_downloaded is True
            assert fetched.file_path == expected_path

    @pytest.mark.asyncio
    async def test_bulk_mark_downloaded_empty_list(self, db_session):
        """An empty id list is a no-op that reports zero updates."""
        updated_count = await EpisodeService.bulk_mark_downloaded(
            db_session,
            episode_ids=[],
        )
        assert updated_count == 0

    @pytest.mark.asyncio
    async def test_delete_by_series_and_episode_transaction(self, db_session):
        """delete_by_series_and_episode inside ``atomic`` removes the episode."""
        series_key = "delete-test-series"
        parent = await AnimeSeriesService.create(
            db_session,
            key=series_key,
            name="Delete Test",
            site="https://test.com",
            folder="/test/folder",
        )
        await EpisodeService.create(
            db_session,
            series_id=parent.id,
            season=1,
            episode_number=1,
            title="Episode 1",
        )
        await db_session.commit()

        async with atomic(db_session):
            was_deleted = await EpisodeService.delete_by_series_and_episode(
                db_session,
                series_key=series_key,
                season=1,
                episode_number=1,
            )
        assert was_deleted is True

        # Looking the episode up again must now come back empty.
        remaining = await EpisodeService.get_by_episode(
            db_session,
            parent.id,
            season=1,
            episode_number=1,
        )
        assert remaining is None
# ============================================================================
# DownloadQueueService Transaction Tests
# ============================================================================
class TestDownloadQueueServiceTransactions:
    """Verify DownloadQueueService bulk operations inside transactions."""

    @staticmethod
    async def _seed_queue(db_session, key, name):
        """Create a series with episodes 1-3, queue each one, return the item ids."""
        series = await AnimeSeriesService.create(
            db_session,
            key=key,
            name=name,
            site="https://test.com",
            folder="/test/folder",
        )
        queued_ids = []
        for number in range(1, 4):
            episode = await EpisodeService.create(
                db_session,
                series_id=series.id,
                season=1,
                episode_number=number,
            )
            queued = await DownloadQueueService.create(
                db_session,
                series_id=series.id,
                episode_id=episode.id,
            )
            queued_ids.append(queued.id)
        return queued_ids

    @pytest.mark.asyncio
    async def test_bulk_delete_atomicity(self, db_session):
        """bulk_delete inside ``atomic`` removes every listed queue item."""
        item_ids = await self._seed_queue(
            db_session, "queue-bulk-test", "Queue Bulk Test"
        )

        async with atomic(db_session):
            removed = await DownloadQueueService.bulk_delete(
                db_session,
                item_ids,
            )
        assert removed == 3

        leftovers = await DownloadQueueService.get_all(db_session)
        assert len(leftovers) == 0

    @pytest.mark.asyncio
    async def test_bulk_delete_empty_list(self, db_session):
        """An empty id list is a no-op that reports zero deletions."""
        removed = await DownloadQueueService.bulk_delete(
            db_session,
            item_ids=[],
        )
        assert removed == 0

    @pytest.mark.asyncio
    async def test_clear_all_atomicity(self, db_session):
        """clear_all inside ``atomic`` empties the whole queue."""
        await self._seed_queue(db_session, "clear-all-test", "Clear All Test")

        async with atomic(db_session):
            removed = await DownloadQueueService.clear_all(db_session)
        assert removed == 3

        leftovers = await DownloadQueueService.get_all(db_session)
        assert len(leftovers) == 0
# ============================================================================
# UserSessionService Transaction Tests
# ============================================================================
class TestUserSessionServiceTransactions:
    """Tests for UserSessionService transaction behavior."""

    @pytest.mark.asyncio
    async def test_rotate_session_atomicity(self, db_session):
        """Test rotate_session is atomic (revoke + create).

        Within one ``atomic`` block, rotating must both create the replacement
        session and deactivate the old one.
        """
        # Create and commit the session that will be rotated away.
        await UserSessionService.create(
            db_session,
            session_id="old-session-123",
            token_hash="old_hash",
            expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
        )
        await db_session.commit()

        async with atomic(db_session):
            new_session = await UserSessionService.rotate_session(
                db_session,
                old_session_id="old-session-123",
                new_session_id="new-session-456",
                new_token_hash="new_hash",
                new_expires_at=datetime.now(timezone.utc) + timedelta(hours=2),
            )
        assert new_session is not None
        assert new_session.session_id == "new-session-456"

        # The old session must still be retrievable, but no longer active.
        old = await UserSessionService.get_by_session_id(
            db_session, "old-session-123"
        )
        assert old is not None
        assert old.is_active is False

    @pytest.mark.asyncio
    async def test_rotate_session_old_not_found(self, db_session):
        """Test rotate_session returns None if old session not found."""
        result = await UserSessionService.rotate_session(
            db_session,
            old_session_id="nonexistent-session",
            new_session_id="new-session",
            new_token_hash="hash",
            new_expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
        )
        assert result is None

    @pytest.mark.asyncio
    async def test_cleanup_expired_bulk_delete(self, db_session):
        """Test cleanup_expired removes all expired sessions."""
        # One reference instant, so "expired" and "active" are unambiguous.
        now = datetime.now(timezone.utc)

        for i in range(3):
            await UserSessionService.create(
                db_session,
                session_id=f"expired-{i}",
                token_hash=f"hash-{i}",
                expires_at=now - timedelta(hours=1),
            )
        await UserSessionService.create(
            db_session,
            session_id="active-session",
            token_hash="active_hash",
            expires_at=now + timedelta(hours=1),
        )
        await db_session.commit()

        async with atomic(db_session):
            count = await UserSessionService.cleanup_expired(db_session)
        assert count == 3

        # The non-expired session must be left untouched.
        active = await UserSessionService.get_by_session_id(
            db_session, "active-session"
        )
        assert active is not None
# ============================================================================
# Compound Operation Rollback Tests
# ============================================================================
class TestCompoundOperationRollback:
    """Tests for rollback behavior in compound operations."""

    @pytest.mark.asyncio
    async def test_rollback_on_partial_failure(self, db_session):
        """Test rollback when compound operation fails mid-way.

        The failure must propagate out of ``atomic`` rather than being
        swallowed, and the episode flushed inside the block must not survive.
        """
        series = await AnimeSeriesService.create(
            db_session,
            key="rollback-test-series",
            name="Rollback Test",
            site="https://test.com",
            folder="/test/folder",
        )
        await db_session.commit()
        # Store the id before starting the transaction to avoid touching
        # expired state on ``series`` after the rollback.
        series_id = series.id

        # pytest.raises (instead of try/except: pass) also fails the test if
        # atomic() were to swallow the error.
        with pytest.raises(ValueError, match="Simulated failure"):
            async with atomic(db_session):
                await EpisodeService.create(
                    db_session,
                    series_id=series_id,
                    season=1,
                    episode_number=1,
                )
                # Push the INSERT to the database so the rollback has real
                # work to undo.
                await db_session.flush()
                raise ValueError("Simulated failure")

        episode = await EpisodeService.get_by_episode(
            db_session,
            series_id,
            season=1,
            episode_number=1,
        )
        assert episode is None

    @pytest.mark.asyncio
    async def test_no_orphan_data_on_failure(self, db_session):
        """Test no orphaned data when multi-service operation fails.

        A series, its episode and a queue item are created in one ``atomic``
        block that then fails; none of them may remain afterwards.
        """
        with pytest.raises(RuntimeError, match="Critical failure"):
            async with atomic(db_session):
                series = await AnimeSeriesService.create(
                    db_session,
                    key="orphan-test-series",
                    name="Orphan Test",
                    site="https://test.com",
                    folder="/test/folder",
                )
                episode = await EpisodeService.create(
                    db_session,
                    series_id=series.id,
                    season=1,
                    episode_number=1,
                )
                await DownloadQueueService.create(
                    db_session,
                    series_id=series.id,
                    episode_id=episode.id,
                )
                await db_session.flush()
                # Fail only after every dependent row has been written.
                raise RuntimeError("Critical failure")

        all_series = await AnimeSeriesService.get_all(db_session)
        assert "orphan-test-series" not in [s.key for s in all_series]
        # The dependent queue row must have been rolled back as well.
        remaining_items = await DownloadQueueService.get_all(db_session)
        assert len(remaining_items) == 0
# ============================================================================
# Nested Transaction Tests
# ============================================================================
class TestNestedTransactions:
    """Savepoint (nested transaction) behavior of ``atomic``."""

    @pytest.mark.asyncio
    async def test_savepoint_partial_rollback(self, db_session):
        """Rolling back a savepoint discards only the work done inside it."""
        parent = await AnimeSeriesService.create(
            db_session,
            key="savepoint-test",
            name="Savepoint Test",
            site="https://test.com",
            folder="/test/folder",
        )
        parent_id = parent.id

        async with atomic(db_session) as tx:
            # Episode 1 lives in the outer transaction and must be kept.
            await EpisodeService.create(
                db_session,
                series_id=parent_id,
                season=1,
                episode_number=1,
            )

            # Episode 2 is created under a savepoint that is then undone.
            async with tx.savepoint() as sp:
                await EpisodeService.create(
                    db_session,
                    series_id=parent_id,
                    season=1,
                    episode_number=2,
                )
                await sp.rollback()

            # Episode 3 is back in the outer transaction and must be kept.
            await EpisodeService.create(
                db_session,
                series_id=parent_id,
                season=1,
                episode_number=3,
            )

        stored_episodes = await EpisodeService.get_by_series(db_session, parent_id)
        stored_numbers = {ep.episode_number for ep in stored_episodes}
        assert stored_numbers.issuperset({1, 3})
        assert 2 not in stored_numbers