"""Unit tests for service layer transaction behavior.

Tests that service operations correctly handle transactions, especially
compound operations that require atomicity.
"""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from sqlalchemy.ext.asyncio import (
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import sessionmaker

from src.server.database.base import Base
from src.server.database.models import (
    AnimeSeries,
    DownloadQueueItem,
    Episode,
    UserSession,
)
from src.server.database.service import (
    AnimeSeriesService,
    DownloadQueueService,
    EpisodeService,
    UserSessionService,
)
from src.server.database.transaction import atomic


# ============================================================================
# Fixtures
# ============================================================================


@pytest.fixture
async def db_engine():
    """Create an in-memory SQLite engine with every table created.

    Yields:
        The async engine; it is disposed once the test has finished.
    """
    engine = create_async_engine(
        "sqlite+aiosqlite:///:memory:",
        echo=False,
    )
    async with engine.begin() as conn:
        await conn.run_sync(Base.metadata.create_all)
    yield engine
    await engine.dispose()


@pytest.fixture
async def db_session(db_engine):
    """Create a database session bound to ``db_engine``.

    ``expire_on_commit=False`` keeps already-loaded ORM attributes (such as
    ``series.id``) readable after a commit. Anything the test leaves
    uncommitted is rolled back on teardown.
    """
    async_session = async_sessionmaker(
        db_engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with async_session() as session:
        yield session
        await session.rollback()


# ============================================================================
# AnimeSeriesService Transaction Tests
# ============================================================================


class TestAnimeSeriesServiceTransactions:
    """Tests for AnimeSeriesService transaction behavior."""

    @pytest.mark.asyncio
    async def test_create_uses_flush_not_commit(self, db_session):
        """Test create uses flush for transaction compatibility."""
        # Spy on commit so the test really proves the service leaves the
        # transaction open for its caller instead of committing itself.
        with patch.object(
            db_session, "commit", new_callable=AsyncMock
        ) as mock_commit:
            series = await AnimeSeriesService.create(
                db_session,
                key="test-key",
                name="Test Series",
                site="https://test.com",
                folder="/test/folder",
            )

        # The primary key is populated by the flush...
        assert series.id is not None
        assert series in db_session
        # ...but nothing was committed.
        mock_commit.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_update_uses_flush_not_commit(self, db_session):
        """Test update uses flush for transaction compatibility."""
        series = await AnimeSeriesService.create(
            db_session,
            key="update-test",
            name="Original Name",
            site="https://test.com",
            folder="/test/folder",
        )

        # Only the update call is under observation here.
        with patch.object(
            db_session, "commit", new_callable=AsyncMock
        ) as mock_commit:
            updated = await AnimeSeriesService.update(
                db_session,
                series.id,
                name="Updated Name",
            )

        assert updated.name == "Updated Name"
        assert updated in db_session
        mock_commit.assert_not_awaited()


# ============================================================================
# EpisodeService Transaction Tests
# ============================================================================


class TestEpisodeServiceTransactions:
    """Tests for EpisodeService transaction behavior."""

    @pytest.mark.asyncio
    async def test_bulk_mark_downloaded_atomicity(self, db_session):
        """Test bulk_mark_downloaded updates all or none."""
        series = await AnimeSeriesService.create(
            db_session,
            key="bulk-test-series",
            name="Bulk Test",
            site="https://test.com",
            folder="/test/folder",
        )

        episodes = []
        for i in range(1, 4):
            ep = await EpisodeService.create(
                db_session,
                series_id=series.id,
                season=1,
                episode_number=i,
                title=f"Episode {i}",
            )
            episodes.append(ep)

        episode_ids = [ep.id for ep in episodes]
        file_paths = [f"/path/ep{i}.mp4" for i in range(1, 4)]

        async with atomic(db_session):
            count = await EpisodeService.bulk_mark_downloaded(
                db_session,
                episode_ids,
                file_paths,
            )

        assert count == 3

        # Each episode must carry the file path given at the same position.
        for ep_id, expected_path in zip(episode_ids, file_paths):
            episode = await EpisodeService.get_by_id(db_session, ep_id)
            assert episode.is_downloaded is True
            assert episode.file_path == expected_path

    @pytest.mark.asyncio
    async def test_bulk_mark_downloaded_empty_list(self, db_session):
        """Test bulk_mark_downloaded handles empty list."""
        count = await EpisodeService.bulk_mark_downloaded(
            db_session,
            episode_ids=[],
        )
        assert count == 0

    @pytest.mark.asyncio
    async def test_delete_by_series_and_episode_transaction(self, db_session):
        """Test delete_by_series_and_episode in transaction."""
        series = await AnimeSeriesService.create(
            db_session,
            key="delete-test-series",
            name="Delete Test",
            site="https://test.com",
            folder="/test/folder",
        )
        await EpisodeService.create(
            db_session,
            series_id=series.id,
            season=1,
            episode_number=1,
            title="Episode 1",
        )
        await db_session.commit()

        async with atomic(db_session):
            deleted = await EpisodeService.delete_by_series_and_episode(
                db_session,
                series_key="delete-test-series",
                season=1,
                episode_number=1,
            )
            assert deleted is True

        episode = await EpisodeService.get_by_episode(
            db_session,
            series.id,
            season=1,
            episode_number=1,
        )
        assert episode is None


# ============================================================================
# DownloadQueueService Transaction Tests
# ============================================================================


class TestDownloadQueueServiceTransactions:
    """Tests for DownloadQueueService transaction behavior."""

    @pytest.mark.asyncio
    async def test_bulk_delete_atomicity(self, db_session):
        """Test bulk_delete removes all or none."""
        series = await AnimeSeriesService.create(
            db_session,
            key="queue-bulk-test",
            name="Queue Bulk Test",
            site="https://test.com",
            folder="/test/folder",
        )

        item_ids = []
        for i in range(1, 4):
            episode = await EpisodeService.create(
                db_session,
                series_id=series.id,
                season=1,
                episode_number=i,
            )
            item = await DownloadQueueService.create(
                db_session,
                series_id=series.id,
                episode_id=episode.id,
            )
            item_ids.append(item.id)

        async with atomic(db_session):
            count = await DownloadQueueService.bulk_delete(
                db_session,
                item_ids,
            )

        assert count == 3

        all_items = await DownloadQueueService.get_all(db_session)
        assert len(all_items) == 0

    @pytest.mark.asyncio
    async def test_bulk_delete_empty_list(self, db_session):
        """Test bulk_delete handles empty list."""
        count = await DownloadQueueService.bulk_delete(
            db_session,
            item_ids=[],
        )
        assert count == 0

    @pytest.mark.asyncio
    async def test_clear_all_atomicity(self, db_session):
        """Test clear_all removes all items atomically."""
        series = await AnimeSeriesService.create(
            db_session,
            key="clear-all-test",
            name="Clear All Test",
            site="https://test.com",
            folder="/test/folder",
        )

        for i in range(1, 4):
            episode = await EpisodeService.create(
                db_session,
                series_id=series.id,
                season=1,
                episode_number=i,
            )
            await DownloadQueueService.create(
                db_session,
                series_id=series.id,
                episode_id=episode.id,
            )

        async with atomic(db_session):
            count = await DownloadQueueService.clear_all(db_session)

        assert count == 3

        all_items = await DownloadQueueService.get_all(db_session)
        assert len(all_items) == 0


# ============================================================================
# UserSessionService Transaction Tests
# ============================================================================


class TestUserSessionServiceTransactions:
    """Tests for UserSessionService transaction behavior."""

    @pytest.mark.asyncio
    async def test_rotate_session_atomicity(self, db_session):
        """Test rotate_session is atomic (revoke + create)."""
        await UserSessionService.create(
            db_session,
            session_id="old-session-123",
            token_hash="old_hash",
            expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
        )
        await db_session.commit()

        async with atomic(db_session):
            new_session = await UserSessionService.rotate_session(
                db_session,
                old_session_id="old-session-123",
                new_session_id="new-session-456",
                new_token_hash="new_hash",
                new_expires_at=datetime.now(timezone.utc) + timedelta(hours=2),
            )

        assert new_session is not None
        assert new_session.session_id == "new-session-456"

        # The old session must have been revoked in the same operation.
        old = await UserSessionService.get_by_session_id(
            db_session, "old-session-123"
        )
        assert old.is_active is False

    @pytest.mark.asyncio
    async def test_rotate_session_old_not_found(self, db_session):
        """Test rotate_session returns None if old session not found."""
        result = await UserSessionService.rotate_session(
            db_session,
            old_session_id="nonexistent-session",
            new_session_id="new-session",
            new_token_hash="hash",
            new_expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
        )
        assert result is None

    @pytest.mark.asyncio
    async def test_cleanup_expired_bulk_delete(self, db_session):
        """Test cleanup_expired removes all expired sessions."""
        for i in range(3):
            await UserSessionService.create(
                db_session,
                session_id=f"expired-{i}",
                token_hash=f"hash-{i}",
                expires_at=datetime.now(timezone.utc) - timedelta(hours=1),
            )
        await UserSessionService.create(
            db_session,
            session_id="active-session",
            token_hash="active_hash",
            expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
        )
        await db_session.commit()

        async with atomic(db_session):
            count = await UserSessionService.cleanup_expired(db_session)

        assert count == 3

        active = await UserSessionService.get_by_session_id(
            db_session, "active-session"
        )
        assert active is not None


# ============================================================================
# Compound Operation Rollback Tests
# ============================================================================


class TestCompoundOperationRollback:
    """Tests for rollback behavior in compound operations."""

    @pytest.mark.asyncio
    async def test_rollback_on_partial_failure(self, db_session):
        """Test rollback when compound operation fails mid-way."""
        series = await AnimeSeriesService.create(
            db_session,
            key="rollback-test-series",
            name="Rollback Test",
            site="https://test.com",
            folder="/test/folder",
        )
        await db_session.commit()

        # Store the id before starting the transaction to avoid expired
        # state access.
        series_id = series.id

        # pytest.raises (rather than try/except/pass) also asserts that
        # atomic() propagates the failure instead of swallowing it.
        with pytest.raises(ValueError, match="Simulated failure"):
            async with atomic(db_session):
                await EpisodeService.create(
                    db_session,
                    series_id=series_id,
                    season=1,
                    episode_number=1,
                )
                # Force the INSERT to be sent so a rollback really has
                # something to undo.
                await db_session.flush()

                raise ValueError("Simulated failure")

        episode = await EpisodeService.get_by_episode(
            db_session,
            series_id,
            season=1,
            episode_number=1,
        )
        assert episode is None

    @pytest.mark.asyncio
    async def test_no_orphan_data_on_failure(self, db_session):
        """Test no orphaned data when multi-service operation fails."""
        with pytest.raises(RuntimeError, match="Critical failure"):
            async with atomic(db_session):
                series = await AnimeSeriesService.create(
                    db_session,
                    key="orphan-test-series",
                    name="Orphan Test",
                    site="https://test.com",
                    folder="/test/folder",
                )
                episode = await EpisodeService.create(
                    db_session,
                    series_id=series.id,
                    season=1,
                    episode_number=1,
                )
                await DownloadQueueService.create(
                    db_session,
                    series_id=series.id,
                    episode_id=episode.id,
                )
                await db_session.flush()

                # Fail only after every create has reached the database.
                raise RuntimeError("Critical failure")

        all_series = await AnimeSeriesService.get_all(db_session)
        series_keys = [s.key for s in all_series]
        assert "orphan-test-series" not in series_keys


# ============================================================================
# Nested Transaction Tests
# ============================================================================


class TestNestedTransactions:
    """Tests for nested transaction (savepoint) behavior."""

    @pytest.mark.asyncio
    async def test_savepoint_partial_rollback(self, db_session):
        """Test savepoint allows partial rollback."""
        series = await AnimeSeriesService.create(
            db_session,
            key="savepoint-test",
            name="Savepoint Test",
            site="https://test.com",
            folder="/test/folder",
        )

        async with atomic(db_session) as tx:
            # Outside the savepoint: expected to persist.
            await EpisodeService.create(
                db_session,
                series_id=series.id,
                season=1,
                episode_number=1,
            )

            # Inside the savepoint: discarded by the explicit rollback.
            async with tx.savepoint() as sp:
                await EpisodeService.create(
                    db_session,
                    series_id=series.id,
                    season=1,
                    episode_number=2,
                )
                await sp.rollback()

            # After the savepoint: the outer transaction is still usable.
            await EpisodeService.create(
                db_session,
                series_id=series.id,
                season=1,
                episode_number=3,
            )

        episodes = await EpisodeService.get_by_series(db_session, series.id)
        episode_numbers = [ep.episode_number for ep in episodes]

        assert 1 in episode_numbers
        assert 2 not in episode_numbers  # Rolled back
        assert 3 in episode_numbers