"""Integration tests for database transaction behavior. Tests real database transaction handling including: - Transaction isolation - Concurrent transaction handling - Real commit/rollback behavior """ import asyncio from datetime import datetime, timedelta, timezone from typing import List import pytest from sqlalchemy import text from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from src.server.database.base import Base from src.server.database.connection import ( TransactionManager, get_session_transaction_depth, is_session_in_transaction, ) from src.server.database.models import AnimeSeries, DownloadQueueItem, Episode from src.server.database.service import ( AnimeSeriesService, DownloadQueueService, EpisodeService, ) from src.server.database.transaction import ( TransactionPropagation, atomic, transactional, ) # ============================================================================ # Fixtures # ============================================================================ @pytest.fixture async def db_engine(): """Create in-memory database engine for testing.""" engine = create_async_engine( "sqlite+aiosqlite:///:memory:", echo=False, ) async with engine.begin() as conn: await conn.run_sync(Base.metadata.create_all) yield engine await engine.dispose() @pytest.fixture async def session_factory(db_engine): """Create session factory for testing.""" from sqlalchemy.ext.asyncio import async_sessionmaker return async_sessionmaker( db_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False, ) @pytest.fixture async def db_session(session_factory): """Create database session for testing.""" async with session_factory() as session: yield session await session.rollback() # ============================================================================ # Real Database Transaction Tests # ============================================================================ class TestRealDatabaseTransactions: """Tests using real in-memory database.""" @pytest.mark.asyncio async def test_commit_persists_data(self, db_session): """Test that committed data is actually persisted.""" async with atomic(db_session): series = await AnimeSeriesService.create( db_session, key="commit-test", name="Commit Test Series", site="https://test.com", folder="/test/folder", ) # Data should be retrievable after commit retrieved = await AnimeSeriesService.get_by_key( db_session, "commit-test" ) assert retrieved is not None assert retrieved.name == "Commit Test Series" @pytest.mark.asyncio async def test_rollback_discards_data(self, db_session): """Test that rolled back data is discarded.""" try: async with atomic(db_session): series = await AnimeSeriesService.create( db_session, key="rollback-test", name="Rollback Test Series", site="https://test.com", folder="/test/folder", ) await db_session.flush() raise ValueError("Force rollback") except ValueError: pass # Data should NOT be retrievable after rollback retrieved = await AnimeSeriesService.get_by_key( db_session, "rollback-test" ) assert retrieved is None @pytest.mark.asyncio async def test_multiple_operations_atomic(self, db_session): """Test multiple operations are committed together.""" async with atomic(db_session): # Create series series = await AnimeSeriesService.create( db_session, key="atomic-multi-test", name="Atomic Multi Test", site="https://test.com", folder="/test/folder", ) # Create episode episode = await EpisodeService.create( db_session, series_id=series.id, season=1, episode_number=1, title="Episode 1", ) # Create queue item item = await DownloadQueueService.create( db_session, series_id=series.id, episode_id=episode.id, ) # All should be persisted retrieved_series = await AnimeSeriesService.get_by_key( db_session, "atomic-multi-test" ) assert retrieved_series is not None episodes = await EpisodeService.get_by_series( db_session, retrieved_series.id ) assert len(episodes) == 1 queue_items = await DownloadQueueService.get_all(db_session) assert len(queue_items) >= 1 @pytest.mark.asyncio async def test_multiple_operations_rollback_all(self, db_session): """Test multiple operations are all rolled back on failure.""" try: async with atomic(db_session): # Create series series = await AnimeSeriesService.create( db_session, key="rollback-multi-test", name="Rollback Multi Test", site="https://test.com", folder="/test/folder", ) # Create episode episode = await EpisodeService.create( db_session, series_id=series.id, season=1, episode_number=1, ) # Create queue item item = await DownloadQueueService.create( db_session, series_id=series.id, episode_id=episode.id, ) await db_session.flush() raise RuntimeError("Force complete rollback") except RuntimeError: pass # None should be persisted retrieved_series = await AnimeSeriesService.get_by_key( db_session, "rollback-multi-test" ) assert retrieved_series is None # ============================================================================ # Transaction Manager Tests # ============================================================================ class TestTransactionManager: """Tests for TransactionManager class.""" @pytest.mark.asyncio async def test_transaction_manager_basic_flow(self, session_factory): """Test basic transaction manager usage.""" async with TransactionManager(session_factory) as tm: session = await tm.get_session() await tm.begin() series = AnimeSeries( key="tm-test", name="TM Test", site="https://test.com", folder="/test", ) session.add(series) await tm.commit() # Verify data persisted async with session_factory() as verify_session: from sqlalchemy import select result = await verify_session.execute( select(AnimeSeries).where(AnimeSeries.key == "tm-test") ) series = result.scalar_one_or_none() assert series is not None @pytest.mark.asyncio async def test_transaction_manager_rollback(self, session_factory): """Test transaction manager rollback.""" async with TransactionManager(session_factory) as tm: session = await tm.get_session() await tm.begin() series = AnimeSeries( key="tm-rollback-test", name="TM Rollback Test", site="https://test.com", folder="/test", ) session.add(series) await session.flush() await tm.rollback() # Verify data NOT persisted async with session_factory() as verify_session: from sqlalchemy import select result = await verify_session.execute( select(AnimeSeries).where(AnimeSeries.key == "tm-rollback-test") ) series = result.scalar_one_or_none() assert series is None @pytest.mark.asyncio async def test_transaction_manager_auto_rollback_on_exception( self, session_factory ): """Test transaction manager auto-rolls back on exception.""" with pytest.raises(ValueError): async with TransactionManager(session_factory) as tm: session = await tm.get_session() await tm.begin() series = AnimeSeries( key="tm-auto-rollback", name="TM Auto Rollback", site="https://test.com", folder="/test", ) session.add(series) await session.flush() raise ValueError("Force exception") # Verify data NOT persisted async with session_factory() as verify_session: from sqlalchemy import select result = await verify_session.execute( select(AnimeSeries).where(AnimeSeries.key == "tm-auto-rollback") ) series = result.scalar_one_or_none() assert series is None @pytest.mark.asyncio async def test_transaction_manager_state_tracking(self, session_factory): """Test transaction manager tracks state correctly.""" async with TransactionManager(session_factory) as tm: assert tm.is_in_transaction() is False await tm.begin() assert tm.is_in_transaction() is True await tm.commit() assert tm.is_in_transaction() is False # ============================================================================ # Helper Function Tests # ============================================================================ class TestConnectionHelpers: """Tests for connection module helper functions.""" @pytest.mark.asyncio async def test_is_session_in_transaction(self, db_session): """Test is_session_in_transaction helper.""" # Initially not in transaction assert is_session_in_transaction(db_session) is False async with atomic(db_session): # Now in transaction assert is_session_in_transaction(db_session) is True # After exit, depends on session state # SQLite behavior may vary @pytest.mark.asyncio async def test_get_session_transaction_depth(self, db_session): """Test get_session_transaction_depth helper.""" depth = get_session_transaction_depth(db_session) assert depth >= 0 # ============================================================================ # @transactional Decorator Integration Tests # ============================================================================ class TestTransactionalDecoratorIntegration: """Integration tests for @transactional decorator.""" @pytest.mark.asyncio async def test_decorated_function_commits(self, db_session): """Test decorated function commits on success.""" @transactional() async def create_series_decorated(db: AsyncSession): return await AnimeSeriesService.create( db, key="decorated-test", name="Decorated Test", site="https://test.com", folder="/test", ) series = await create_series_decorated(db=db_session) # Verify committed retrieved = await AnimeSeriesService.get_by_key( db_session, "decorated-test" ) assert retrieved is not None @pytest.mark.asyncio async def test_decorated_function_rollback(self, db_session): """Test decorated function rolls back on error.""" @transactional() async def create_then_fail(db: AsyncSession): await AnimeSeriesService.create( db, key="decorated-rollback", name="Decorated Rollback", site="https://test.com", folder="/test", ) raise ValueError("Force failure") with pytest.raises(ValueError): await create_then_fail(db=db_session) # Verify NOT committed retrieved = await AnimeSeriesService.get_by_key( db_session, "decorated-rollback" ) assert retrieved is None @pytest.mark.asyncio async def test_nested_decorated_functions(self, db_session): """Test nested decorated functions work correctly.""" @transactional(propagation=TransactionPropagation.NESTED) async def inner_operation(db: AsyncSession, series_id: int): return await EpisodeService.create( db, series_id=series_id, season=1, episode_number=1, ) @transactional() async def outer_operation(db: AsyncSession): series = await AnimeSeriesService.create( db, key="nested-decorated", name="Nested Decorated", site="https://test.com", folder="/test", ) episode = await inner_operation(db=db, series_id=series.id) return series, episode series, episode = await outer_operation(db=db_session) # Both should be committed assert series is not None assert episode is not None # ============================================================================ # Concurrent Transaction Tests # ============================================================================ class TestConcurrentTransactions: """Tests for concurrent transaction handling.""" @pytest.mark.asyncio async def test_concurrent_writes_different_keys(self, session_factory): """Test concurrent writes to different records.""" async def create_series(key: str): async with session_factory() as session: async with atomic(session): await AnimeSeriesService.create( session, key=key, name=f"Series {key}", site="https://test.com", folder=f"/test/{key}", ) # Run concurrent creates await asyncio.gather( create_series("concurrent-1"), create_series("concurrent-2"), create_series("concurrent-3"), ) # Verify all created async with session_factory() as verify_session: for i in range(1, 4): series = await AnimeSeriesService.get_by_key( verify_session, f"concurrent-{i}" ) assert series is not None # ============================================================================ # Queue Repository Transaction Tests # ============================================================================ class TestQueueRepositoryTransactions: """Integration tests for QueueRepository transaction handling.""" @pytest.mark.asyncio async def test_save_item_atomic(self, session_factory): """Test save_item creates series, episode, and queue item atomically.""" from src.server.models.download import ( DownloadItem, DownloadStatus, EpisodeIdentifier, ) from src.server.services.queue_repository import QueueRepository repo = QueueRepository(session_factory) item = DownloadItem( id="temp-id", serie_id="repo-atomic-test", serie_folder="/test/folder", serie_name="Repo Atomic Test", episode=EpisodeIdentifier(season=1, episode=1), status=DownloadStatus.PENDING, ) saved_item = await repo.save_item(item) assert saved_item.id != "temp-id" # Should have DB ID # Verify all entities created async with session_factory() as verify_session: series = await AnimeSeriesService.get_by_key( verify_session, "repo-atomic-test" ) assert series is not None episodes = await EpisodeService.get_by_series( verify_session, series.id ) assert len(episodes) == 1 queue_items = await DownloadQueueService.get_all(verify_session) assert len(queue_items) >= 1 @pytest.mark.asyncio async def test_clear_all_atomic(self, session_factory): """Test clear_all removes all items atomically.""" from src.server.models.download import ( DownloadItem, DownloadStatus, EpisodeIdentifier, ) from src.server.services.queue_repository import QueueRepository repo = QueueRepository(session_factory) # Add some items for i in range(3): item = DownloadItem( id=f"clear-{i}", serie_id=f"clear-series-{i}", serie_folder=f"/test/folder/{i}", serie_name=f"Clear Series {i}", episode=EpisodeIdentifier(season=1, episode=1), status=DownloadStatus.PENDING, ) await repo.save_item(item) # Clear all count = await repo.clear_all() assert count == 3 # Verify all cleared async with session_factory() as verify_session: queue_items = await DownloadQueueService.get_all(verify_session) assert len(queue_items) == 0