"""Database connection and session management for SQLAlchemy. This module provides database engine creation, session factory configuration, and dependency injection helpers for FastAPI endpoints. Functions: - init_db: Initialize database engine and create tables - close_db: Close database connections and cleanup - get_db_session: FastAPI dependency for database sessions - get_transactional_session: Session without auto-commit for transactions - get_engine: Get database engine instance Classes: - TransactionManager: Helper class for manual transaction control """ from __future__ import annotations import logging from contextlib import asynccontextmanager from typing import AsyncGenerator, Optional from sqlalchemy import create_engine, event, pool from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine, ) from sqlalchemy.orm import Session, sessionmaker from src.config.settings import settings from src.server.database.base import Base logger = logging.getLogger(__name__) # Global engine and session factory instances _engine: Optional[AsyncEngine] = None _sync_engine: Optional[create_engine] = None _session_factory: Optional[async_sessionmaker[AsyncSession]] = None _sync_session_factory: Optional[sessionmaker[Session]] = None def _get_database_url() -> str: """Get database URL from settings. Converts SQLite URLs to async format if needed. Returns: Database URL string suitable for async engine """ url = settings.database_url # Convert sqlite:/// to sqlite+aiosqlite:/// for async support if url.startswith("sqlite:///"): url = url.replace("sqlite:///", "sqlite+aiosqlite:///") return url def _configure_sqlite_engine(engine: AsyncEngine) -> None: """Configure SQLite-specific engine settings. Enables foreign key support and optimizes connection pooling. Args: engine: SQLAlchemy async engine instance """ @event.listens_for(engine.sync_engine, "connect") def set_sqlite_pragma(dbapi_conn, connection_record): """Enable foreign keys and set pragmas for SQLite.""" cursor = dbapi_conn.cursor() cursor.execute("PRAGMA foreign_keys=ON") cursor.execute("PRAGMA journal_mode=WAL") cursor.close() async def init_db() -> None: """Initialize database engine and create tables. Creates async and sync engines, session factories, and database tables. Should be called during application startup. Raises: Exception: If database initialization fails """ global _engine, _sync_engine, _session_factory, _sync_session_factory try: # Get database URL db_url = _get_database_url() logger.info(f"Initializing database: {db_url}") # Build engine kwargs based on database type is_sqlite = "sqlite" in db_url engine_kwargs = { "echo": settings.log_level == "DEBUG", "poolclass": pool.StaticPool if is_sqlite else pool.QueuePool, "pool_pre_ping": True, } # Only add pool_size and max_overflow for non-SQLite databases if not is_sqlite: engine_kwargs["pool_size"] = 5 engine_kwargs["max_overflow"] = 10 # Create async engine _engine = create_async_engine(db_url, **engine_kwargs) # Configure SQLite if needed if is_sqlite: _configure_sqlite_engine(_engine) # Create async session factory _session_factory = async_sessionmaker( bind=_engine, class_=AsyncSession, expire_on_commit=False, autoflush=False, autocommit=False, ) # Create sync engine for initial setup sync_url = settings.database_url is_sqlite_sync = "sqlite" in sync_url sync_engine_kwargs = { "echo": settings.log_level == "DEBUG", "poolclass": pool.StaticPool if is_sqlite_sync else pool.QueuePool, "pool_pre_ping": True, } _sync_engine = create_engine(sync_url, **sync_engine_kwargs) # Create sync session factory _sync_session_factory = sessionmaker( bind=_sync_engine, expire_on_commit=False, autoflush=False, autocommit=False, ) # Create all tables logger.info("Creating database tables...") Base.metadata.create_all(bind=_sync_engine) logger.info("Database initialization complete") except Exception as e: logger.error(f"Failed to initialize database: {e}") raise async def close_db() -> None: """Close database connections and cleanup resources. Performs a WAL checkpoint for SQLite databases to ensure all pending writes are flushed to the main database file before closing connections. This prevents database corruption during shutdown. Should be called during application shutdown. """ global _engine, _sync_engine, _session_factory, _sync_session_factory try: # For SQLite: checkpoint WAL to ensure all writes are flushed if _sync_engine and "sqlite" in str(_sync_engine.url): logger.info("Running SQLite WAL checkpoint before shutdown...") try: from sqlalchemy import text with _sync_engine.connect() as conn: # TRUNCATE mode: checkpoint and truncate WAL file conn.execute(text("PRAGMA wal_checkpoint(TRUNCATE)")) conn.commit() logger.info("SQLite WAL checkpoint completed") except Exception as e: logger.warning(f"WAL checkpoint failed (non-critical): {e}") if _engine: logger.info("Closing async database engine...") await _engine.dispose() _engine = None _session_factory = None if _sync_engine: logger.info("Closing sync database engine...") _sync_engine.dispose() _sync_engine = None _sync_session_factory = None logger.info("Database connections closed") except Exception as e: logger.error(f"Error closing database: {e}") def get_engine() -> AsyncEngine: """Get the database engine instance. Returns: AsyncEngine instance Raises: RuntimeError: If database is not initialized """ if _engine is None: raise RuntimeError( "Database not initialized. Call init_db() first." ) return _engine def get_sync_engine(): """Get the sync database engine instance. Returns: Engine instance Raises: RuntimeError: If database is not initialized """ if _sync_engine is None: raise RuntimeError( "Database not initialized. Call init_db() first." ) return _sync_engine @asynccontextmanager async def get_db_session() -> AsyncGenerator[AsyncSession, None]: """FastAPI dependency to get database session. Provides an async database session with automatic commit/rollback. Use this as a dependency in FastAPI endpoints. Yields: AsyncSession: Database session for async operations Raises: RuntimeError: If database is not initialized Example: @app.get("/anime") async def get_anime( db: AsyncSession = Depends(get_db_session) ): result = await db.execute(select(AnimeSeries)) return result.scalars().all() """ if _session_factory is None: raise RuntimeError( "Database not initialized. Call init_db() first." ) session = _session_factory() try: yield session await session.commit() except Exception: await session.rollback() raise finally: await session.close() def get_sync_session() -> Session: """Get a sync database session. Use this for synchronous operations outside FastAPI endpoints. Remember to close the session when done. Returns: Session: Database session for sync operations Raises: RuntimeError: If database is not initialized Example: session = get_sync_session() try: result = session.execute(select(AnimeSeries)) return result.scalars().all() finally: session.close() """ if _sync_session_factory is None: raise RuntimeError( "Database not initialized. Call init_db() first." ) return _sync_session_factory() def get_async_session_factory() -> AsyncSession: """Get a new async database session (factory function). Creates a new session instance for use in repository patterns. The caller is responsible for committing/rolling back and closing. Returns: AsyncSession: New database session for async operations Raises: RuntimeError: If database is not initialized Example: session = get_async_session_factory() try: result = await session.execute(select(AnimeSeries)) await session.commit() return result.scalars().all() except Exception: await session.rollback() raise finally: await session.close() """ if _session_factory is None: raise RuntimeError( "Database not initialized. Call init_db() first." ) return _session_factory() @asynccontextmanager async def get_transactional_session() -> AsyncGenerator[AsyncSession, None]: """Get a database session without auto-commit for explicit transaction control. Unlike get_db_session(), this does NOT auto-commit on success. Use this when you need explicit transaction control with the @transactional decorator or atomic() context manager. Yields: AsyncSession: Database session for async operations Raises: RuntimeError: If database is not initialized Example: async with get_transactional_session() as session: async with atomic(session) as tx: # Multiple operations in transaction await operation1(session) await operation2(session) # Committed when exiting atomic() context """ if _session_factory is None: raise RuntimeError( "Database not initialized. Call init_db() first." ) session = _session_factory() try: yield session except Exception: await session.rollback() raise finally: await session.close() class TransactionManager: """Helper class for manual transaction control. Provides a cleaner interface for managing transactions across multiple service calls within a single request. Attributes: _session_factory: Factory for creating new sessions _session: Current active session _in_transaction: Whether currently in a transaction Example: async with TransactionManager() as tm: session = await tm.get_session() await tm.begin() try: await service1.operation(session) await service2.operation(session) await tm.commit() except Exception: await tm.rollback() raise """ def __init__( self, session_factory: Optional[async_sessionmaker] = None ) -> None: """Initialize transaction manager. Args: session_factory: Optional custom session factory. Uses global factory if not provided. """ self._session_factory = session_factory or _session_factory self._session: Optional[AsyncSession] = None self._in_transaction = False if self._session_factory is None: raise RuntimeError( "Database not initialized. Call init_db() first." ) async def __aenter__(self) -> "TransactionManager": """Enter context manager and create session.""" self._session = self._session_factory() logger.debug("TransactionManager: Created new session") return self async def __aexit__( self, exc_type: Optional[type], exc_val: Optional[BaseException], exc_tb: Optional[object], ) -> bool: """Exit context manager and cleanup session. Automatically rolls back if an exception occurred and transaction wasn't explicitly committed. """ if self._session: if exc_type is not None and self._in_transaction: logger.warning( "TransactionManager: Rolling back due to exception: %s", exc_val, ) await self._session.rollback() await self._session.close() self._session = None self._in_transaction = False logger.debug("TransactionManager: Session closed") return False async def get_session(self) -> AsyncSession: """Get the current session. Returns: Current AsyncSession instance Raises: RuntimeError: If not within context manager """ if self._session is None: raise RuntimeError( "TransactionManager must be used as async context manager" ) return self._session async def begin(self) -> None: """Begin a new transaction. Raises: RuntimeError: If already in a transaction or no session """ if self._session is None: raise RuntimeError("No active session") if self._in_transaction: raise RuntimeError("Already in a transaction") await self._session.begin() self._in_transaction = True logger.debug("TransactionManager: Transaction started") async def commit(self) -> None: """Commit the current transaction. Raises: RuntimeError: If not in a transaction """ if not self._in_transaction or self._session is None: raise RuntimeError("Not in a transaction") await self._session.commit() self._in_transaction = False logger.debug("TransactionManager: Transaction committed") async def rollback(self) -> None: """Rollback the current transaction. Raises: RuntimeError: If not in a transaction """ if self._session is None: raise RuntimeError("No active session") await self._session.rollback() self._in_transaction = False logger.debug("TransactionManager: Transaction rolled back") async def savepoint(self, name: Optional[str] = None) -> "SavepointHandle": """Create a savepoint within the current transaction. Args: name: Optional savepoint name Returns: SavepointHandle for controlling the savepoint Raises: RuntimeError: If not in a transaction """ if not self._in_transaction or self._session is None: raise RuntimeError("Must be in a transaction to create savepoint") nested = await self._session.begin_nested() return SavepointHandle(nested, name or "unnamed") def is_in_transaction(self) -> bool: """Check if currently in a transaction. Returns: True if in an active transaction """ return self._in_transaction def get_transaction_depth(self) -> int: """Get current transaction nesting depth. Returns: 0 if not in transaction, 1+ for nested transactions """ if not self._in_transaction: return 0 return 1 # Basic implementation - could be extended class SavepointHandle: """Handle for controlling a database savepoint. Attributes: _nested: SQLAlchemy nested transaction _name: Savepoint name for logging _released: Whether savepoint has been released """ def __init__(self, nested: object, name: str) -> None: """Initialize savepoint handle. Args: nested: SQLAlchemy nested transaction object name: Savepoint name """ self._nested = nested self._name = name self._released = False logger.debug("Created savepoint: %s", name) async def rollback(self) -> None: """Rollback to this savepoint.""" if not self._released: await self._nested.rollback() self._released = True logger.debug("Rolled back savepoint: %s", self._name) async def release(self) -> None: """Release (commit) this savepoint.""" if not self._released: # Nested transactions commit automatically in SQLAlchemy self._released = True logger.debug("Released savepoint: %s", self._name) def is_session_in_transaction(session: AsyncSession | Session) -> bool: """Check if a session is currently in a transaction. Args: session: SQLAlchemy session (sync or async) Returns: True if session is in an active transaction """ return session.in_transaction() def get_session_transaction_depth(session: AsyncSession | Session) -> int: """Get the transaction nesting depth of a session. Args: session: SQLAlchemy session (sync or async) Returns: Number of nested transactions (0 if not in transaction) """ if not session.in_transaction(): return 0 # Check for nested transaction state # Note: SQLAlchemy doesn't directly expose nesting depth return 1