"""Unit tests for database connection module. Tests cover engine/session lifecycle, utility functions, TransactionManager, SavepointHandle, and various error paths. """ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch import pytest import src.server.database.connection as conn_mod from src.server.database.connection import ( SavepointHandle, TransactionManager, _get_database_url, get_session_transaction_depth, is_session_in_transaction, ) # ── Helpers ─────────────────────────────────────────────────────────────────── @pytest.fixture(autouse=True) def _reset_globals(): """Reset the module-level globals before/after every test.""" old_engine = conn_mod._engine old_sync = conn_mod._sync_engine old_sf = conn_mod._session_factory old_ssf = conn_mod._sync_session_factory conn_mod._engine = None conn_mod._sync_engine = None conn_mod._session_factory = None conn_mod._sync_session_factory = None yield conn_mod._engine = old_engine conn_mod._sync_engine = old_sync conn_mod._session_factory = old_sf conn_mod._sync_session_factory = old_ssf # ══════════════════════════════════════════════════════════════════════════════ # _get_database_url # ══════════════════════════════════════════════════════════════════════════════ class TestGetDatabaseURL: """Test _get_database_url helper.""" def test_sqlite_url_converted(self): """sqlite:/// should be converted to sqlite+aiosqlite:///.""" with patch.object( conn_mod.settings, "database_url", "sqlite:///./data/anime.db", ): result = _get_database_url() assert "aiosqlite" in result def test_non_sqlite_url_unchanged(self): """Non-SQLite URL should remain unchanged.""" with patch.object( conn_mod.settings, "database_url", "postgresql://user:pass@localhost/db", ): result = _get_database_url() assert result == "postgresql://user:pass@localhost/db" # ══════════════════════════════════════════════════════════════════════════════ # get_engine / get_sync_engine # 
# ══════════════════════════════════════════════════════════════════════════════


class TestGetEngine:
    """Test get_engine and get_sync_engine."""

    def test_raises_when_not_initialized(self):
        """get_engine should raise RuntimeError before init_db."""
        with pytest.raises(RuntimeError, match="not initialized"):
            conn_mod.get_engine()

    def test_returns_engine_when_set(self):
        """Should return the engine when initialised."""
        fake_engine = MagicMock()
        conn_mod._engine = fake_engine
        assert conn_mod.get_engine() is fake_engine

    def test_get_sync_engine_raises(self):
        """get_sync_engine should raise RuntimeError before init_db."""
        with pytest.raises(RuntimeError, match="not initialized"):
            conn_mod.get_sync_engine()

    def test_get_sync_engine_returns(self):
        """Should return sync engine when set."""
        fake = MagicMock()
        conn_mod._sync_engine = fake
        assert conn_mod.get_sync_engine() is fake


# ══════════════════════════════════════════════════════════════════════════════
# get_db_session
# ══════════════════════════════════════════════════════════════════════════════


class TestGetDBSession:
    """Test get_db_session async context manager."""

    @pytest.mark.asyncio
    async def test_raises_when_not_initialized(self):
        """Should raise RuntimeError if session factory is None."""
        with pytest.raises(RuntimeError, match="not initialized"):
            async with conn_mod.get_db_session():
                pass

    @pytest.mark.asyncio
    async def test_commits_on_success(self):
        """Session should be committed on normal exit."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)
        conn_mod._session_factory = factory

        async with conn_mod.get_db_session() as session:
            assert session is mock_session

        # Commit/close happen on context exit, so check them afterwards.
        mock_session.commit.assert_called_once()
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_rollback_on_exception(self):
        """Session should be rolled back on exception."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)
        conn_mod._session_factory = factory

        with pytest.raises(ValueError):
            async with conn_mod.get_db_session():
                raise ValueError("boom")

        mock_session.rollback.assert_called_once()
        mock_session.commit.assert_not_called()
        mock_session.close.assert_called_once()


# ══════════════════════════════════════════════════════════════════════════════
# get_sync_session / get_async_session_factory
# ══════════════════════════════════════════════════════════════════════════════


class TestGetSyncSession:
    """Test get_sync_session."""

    def test_raises_when_not_initialized(self):
        """Should raise RuntimeError."""
        with pytest.raises(RuntimeError, match="not initialized"):
            conn_mod.get_sync_session()

    def test_returns_session(self):
        """Should return a session from the factory."""
        mock_session = MagicMock()
        conn_mod._sync_session_factory = MagicMock(return_value=mock_session)
        assert conn_mod.get_sync_session() is mock_session


class TestGetAsyncSessionFactory:
    """Test get_async_session_factory."""

    def test_raises_when_not_initialized(self):
        """Should raise RuntimeError."""
        with pytest.raises(RuntimeError, match="not initialized"):
            conn_mod.get_async_session_factory()

    def test_returns_session(self):
        """Should return a new async session."""
        mock_session = AsyncMock()
        conn_mod._session_factory = MagicMock(return_value=mock_session)
        assert conn_mod.get_async_session_factory() is mock_session


# ══════════════════════════════════════════════════════════════════════════════
# get_transactional_session
# ══════════════════════════════════════════════════════════════════════════════


class TestGetTransactionalSession:
    """Test get_transactional_session."""

    @pytest.mark.asyncio
    async def test_raises_when_not_initialized(self):
        """Should raise RuntimeError."""
        with pytest.raises(RuntimeError, match="not initialized"):
            async with conn_mod.get_transactional_session():
                pass

    @pytest.mark.asyncio
    async def test_does_not_auto_commit(self):
        """Session should NOT be committed on normal exit."""
        mock_session = AsyncMock()
        conn_mod._session_factory = MagicMock(return_value=mock_session)

        async with conn_mod.get_transactional_session() as session:
            # Verify the yielded object is the factory-built session rather
            # than binding it and never checking it.
            assert session is mock_session

        mock_session.commit.assert_not_called()
        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_rollback_on_exception(self):
        """Should rollback on exception."""
        mock_session = AsyncMock()
        conn_mod._session_factory = MagicMock(return_value=mock_session)

        with pytest.raises(ValueError):
            async with conn_mod.get_transactional_session():
                raise ValueError("boom")

        mock_session.rollback.assert_called_once()


# ══════════════════════════════════════════════════════════════════════════════
# close_db
# ══════════════════════════════════════════════════════════════════════════════


class TestCloseDB:
    """Test close_db function."""

    @pytest.mark.asyncio
    async def test_disposes_engines(self):
        """Should dispose both engines."""
        mock_engine = AsyncMock()
        mock_sync = MagicMock()
        mock_sync.url = "sqlite:///test.db"

        # Give connect() a working context manager in case close_db opens a
        # connection on the sync engine before disposing it. (Configuring
        # the auto-created return_value first was dead code: it was
        # immediately replaced by this object.)
        conn_ctx = MagicMock()
        conn_ctx.__enter__ = MagicMock(return_value=MagicMock())
        conn_ctx.__exit__ = MagicMock(return_value=False)
        mock_sync.connect.return_value = conn_ctx

        conn_mod._engine = mock_engine
        conn_mod._sync_engine = mock_sync
        conn_mod._session_factory = MagicMock()
        conn_mod._sync_session_factory = MagicMock()

        await conn_mod.close_db()

        mock_engine.dispose.assert_called_once()
        mock_sync.dispose.assert_called_once()
        assert conn_mod._engine is None
        assert conn_mod._sync_engine is None

    @pytest.mark.asyncio
    async def test_noop_when_not_initialized(self):
        """Should not raise if engines are None."""
        await conn_mod.close_db()  # should not raise


# ══════════════════════════════════════════════════════════════════════════════
# TransactionManager
# ══════════════════════════════════════════════════════════════════════════════


class TestTransactionManager:
    """Test TransactionManager class."""

    def test_init_raises_without_factory(self):
        """Should raise RuntimeError when no session factory."""
        with pytest.raises(RuntimeError, match="not initialized"):
            TransactionManager()

    @pytest.mark.asyncio
    async def test_context_manager_creates_and_closes_session(self):
        """Should create session on enter and close on exit."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)

        async with TransactionManager(session_factory=factory) as tm:
            session = await tm.get_session()
            assert session is mock_session

        mock_session.close.assert_called_once()

    @pytest.mark.asyncio
    async def test_begin_commit(self):
        """begin then commit should work."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)

        async with TransactionManager(session_factory=factory) as tm:
            await tm.begin()
            assert tm.is_in_transaction() is True
            await tm.commit()
            assert tm.is_in_transaction() is False
            mock_session.begin.assert_called_once()
            mock_session.commit.assert_called_once()

    @pytest.mark.asyncio
    async def test_begin_rollback(self):
        """begin then rollback should work."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)

        async with TransactionManager(session_factory=factory) as tm:
            await tm.begin()
            await tm.rollback()
            assert tm.is_in_transaction() is False
            mock_session.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_exception_auto_rollback(self):
        """Exception inside context manager should auto rollback."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)

        with pytest.raises(ValueError):
            async with TransactionManager(session_factory=factory) as tm:
                await tm.begin()
                raise ValueError("boom")

        mock_session.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_double_begin_raises(self):
        """begin called twice should raise."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)

        async with TransactionManager(session_factory=factory) as tm:
            await tm.begin()
            with pytest.raises(RuntimeError, match="Already in"):
                await tm.begin()

    @pytest.mark.asyncio
    async def test_commit_without_begin_raises(self):
        """commit without begin should raise."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)

        async with TransactionManager(session_factory=factory) as tm:
            with pytest.raises(RuntimeError, match="Not in"):
                await tm.commit()

    @pytest.mark.asyncio
    async def test_get_session_outside_context_raises(self):
        """get_session outside context manager should raise."""
        factory = MagicMock()
        tm = TransactionManager(session_factory=factory)

        with pytest.raises(RuntimeError, match="context manager"):
            await tm.get_session()

    @pytest.mark.asyncio
    async def test_transaction_depth(self):
        """get_transaction_depth should reflect state."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)

        async with TransactionManager(session_factory=factory) as tm:
            assert tm.get_transaction_depth() == 0
            await tm.begin()
            assert tm.get_transaction_depth() == 1
            await tm.commit()
            assert tm.get_transaction_depth() == 0

    @pytest.mark.asyncio
    async def test_savepoint_creation(self):
        """savepoint should return SavepointHandle."""
        mock_session = AsyncMock()
        mock_nested = AsyncMock()
        mock_session.begin_nested = AsyncMock(return_value=mock_nested)
        factory = MagicMock(return_value=mock_session)

        async with TransactionManager(session_factory=factory) as tm:
            await tm.begin()
            sp = await tm.savepoint("sp1")
            assert isinstance(sp, SavepointHandle)

    @pytest.mark.asyncio
    async def test_savepoint_without_transaction_raises(self):
        """savepoint outside transaction should raise."""
        mock_session = AsyncMock()
        factory = MagicMock(return_value=mock_session)

        async with TransactionManager(session_factory=factory) as tm:
            with pytest.raises(RuntimeError, match="Must be in"):
                await tm.savepoint()

    @pytest.mark.asyncio
    async def test_rollback_without_session_raises(self):
        """rollback without active session should raise."""
        factory = MagicMock()
        tm = TransactionManager(session_factory=factory)

        with pytest.raises(RuntimeError, match="No active session"):
            await tm.rollback()


# ══════════════════════════════════════════════════════════════════════════════
# SavepointHandle
# ══════════════════════════════════════════════════════════════════════════════


class TestSavepointHandle:
    """Test SavepointHandle class."""

    @pytest.mark.asyncio
    async def test_rollback(self):
        """Should call nested.rollback()."""
        mock_nested = AsyncMock()
        sp = SavepointHandle(mock_nested, "sp1")

        await sp.rollback()

        mock_nested.rollback.assert_called_once()
        assert sp._released is True

    @pytest.mark.asyncio
    async def test_rollback_idempotent(self):
        """Second rollback should be a noop."""
        mock_nested = AsyncMock()
        sp = SavepointHandle(mock_nested, "sp1")

        await sp.rollback()
        await sp.rollback()

        mock_nested.rollback.assert_called_once()

    @pytest.mark.asyncio
    async def test_release(self):
        """Should mark as released."""
        mock_nested = AsyncMock()
        sp = SavepointHandle(mock_nested, "sp1")

        await sp.release()

        assert sp._released is True

    @pytest.mark.asyncio
    async def test_release_idempotent(self):
        """Second release should be a noop."""
        mock_nested = AsyncMock()
        sp = SavepointHandle(mock_nested, "sp1")

        await sp.release()
        await sp.release()

        assert sp._released is True


# ══════════════════════════════════════════════════════════════════════════════
# Utility Functions
# ══════════════════════════════════════════════════════════════════════════════


class TestUtilityFunctions:
    """Test is_session_in_transaction and get_session_transaction_depth."""

    def test_in_transaction_true(self):
        """Should return True when session is in transaction."""
        session = MagicMock()
        session.in_transaction.return_value = True
        assert is_session_in_transaction(session) is True

    def test_in_transaction_false(self):
        """Should return False when session is not in transaction."""
        session = MagicMock()
        session.in_transaction.return_value = False
        assert is_session_in_transaction(session) is False

    def test_transaction_depth_zero(self):
        """Should return 0 when not in transaction."""
        session = MagicMock()
        session.in_transaction.return_value = False
        assert get_session_transaction_depth(session) == 0

    def test_transaction_depth_one(self):
        """Should return 1 when in transaction."""
        session = MagicMock()
        session.in_transaction.return_value = True
        assert get_session_transaction_depth(session) == 1