"""Database connection and session management for SQLAlchemy.

This module provides database engine creation, session factory
configuration, and session helpers for FastAPI endpoints.

Functions:
    - init_db: Initialize database engines and create tables
    - close_db: Close database connections and cleanup
    - get_engine: Get the async database engine instance
    - get_sync_engine: Get the sync database engine instance
    - get_db_session: Async context manager yielding a managed session
    - get_sync_session: Create a new sync session
    - get_async_session_factory: Create a new, unmanaged async session
"""
from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from typing import Any, AsyncGenerator, Dict, Optional, Union

from sqlalchemy import create_engine, event, pool
from sqlalchemy.engine import Engine, make_url
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import Session, sessionmaker

from src.config.settings import settings
from src.server.database.base import Base

logger = logging.getLogger(__name__)

# Global engine and session factory instances.
# They are populated by init_db() and reset to None by close_db().
_engine: Optional[AsyncEngine] = None
_sync_engine: Optional[Engine] = None
_session_factory: Optional[async_sessionmaker[AsyncSession]] = None
_sync_session_factory: Optional[sessionmaker[Session]] = None


def _is_sqlite_url(url: str) -> bool:
    """Return True when *url* uses a SQLite scheme.

    The scheme prefix is checked rather than searching the whole string,
    so a non-SQLite URL that merely contains "sqlite" (for example in a
    database name) is not misclassified.
    """
    return url.startswith("sqlite")


def _redact_url(url: str) -> str:
    """Return *url* in a form that is safe to write to logs.

    The password component, if any, is masked. Redaction is best-effort:
    if the URL cannot be parsed, a placeholder is returned instead of the
    raw string so that credentials are never logged by accident.
    """
    try:
        return make_url(url).render_as_string(hide_password=True)
    except Exception:
        # Logging must not be the reason startup fails; engine creation
        # will report the malformed URL itself.
        return "<unparseable database URL>"


def _get_database_url() -> str:
    """Get database URL from settings.

    Converts SQLite URLs to async format if needed.

    Returns:
        Database URL string suitable for async engine
    """
    url = settings.database_url

    # Convert sqlite:/// to sqlite+aiosqlite:/// for async support.
    # Only the leading scheme is rewritten (count=1).
    if url.startswith("sqlite:///"):
        url = url.replace("sqlite:///", "sqlite+aiosqlite:///", 1)

    return url


def _configure_sqlite_engine(engine: Union[AsyncEngine, Engine]) -> None:
    """Configure SQLite-specific connection settings.

    Registers a "connect" listener that enables foreign key enforcement
    and WAL journaling on every new DBAPI connection.

    Args:
        engine: SQLAlchemy engine instance. An AsyncEngine is accepted as
            well as a plain Engine; for an AsyncEngine the listener is
            attached to its underlying ``sync_engine``.
    """
    # Events are registered on the synchronous Engine; a plain Engine has
    # no ``sync_engine`` attribute and is used directly.
    target = getattr(engine, "sync_engine", engine)

    @event.listens_for(target, "connect")
    def set_sqlite_pragma(dbapi_conn, connection_record):
        """Enable foreign keys and set pragmas for SQLite."""
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
            cursor.execute("PRAGMA journal_mode=WAL")
        finally:
            cursor.close()


async def init_db() -> None:
    """Initialize database engine and create tables.

    Creates async and sync engines, session factories, and database
    tables. Should be called during application startup.

    If any step fails, engines created so far are disposed and the module
    state is reset before the exception is re-raised.

    Raises:
        Exception: If database initialization fails
    """
    global _engine, _sync_engine, _session_factory, _sync_session_factory

    try:
        # Get database URL (never log it with the password included)
        db_url = _get_database_url()
        logger.info("Initializing database: %s", _redact_url(db_url))

        # Build engine kwargs based on database type
        is_sqlite = _is_sqlite_url(db_url)
        echo = settings.log_level == "DEBUG"

        engine_kwargs: Dict[str, Any] = {
            "echo": echo,
            # Asyncio engines cannot use the blocking QueuePool; the
            # asyncio-adapted variant is required for non-SQLite URLs.
            "poolclass": (
                pool.StaticPool if is_sqlite
                else pool.AsyncAdaptedQueuePool
            ),
            "pool_pre_ping": True,
        }

        # Only add pool_size and max_overflow for non-SQLite databases
        # (StaticPool does not take sizing arguments).
        if not is_sqlite:
            engine_kwargs["pool_size"] = 5
            engine_kwargs["max_overflow"] = 10

        # Create async engine
        _engine = create_async_engine(db_url, **engine_kwargs)

        # Configure SQLite if needed
        if is_sqlite:
            _configure_sqlite_engine(_engine)

        # Create async session factory
        _session_factory = async_sessionmaker(
            bind=_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )

        # Create sync engine for initial setup.
        # NOTE(review): the sync engine uses settings.database_url as-is;
        # presumably it names a sync-capable driver - confirm.
        sync_url = settings.database_url
        is_sqlite_sync = _is_sqlite_url(sync_url)

        sync_engine_kwargs: Dict[str, Any] = {
            "echo": echo,
            "poolclass": (
                pool.StaticPool if is_sqlite_sync else pool.QueuePool
            ),
            "pool_pre_ping": True,
        }

        _sync_engine = create_engine(sync_url, **sync_engine_kwargs)

        # Apply the same pragmas as the async engine so sync sessions
        # enforce foreign keys too.
        if is_sqlite_sync:
            _configure_sqlite_engine(_sync_engine)

        # Create sync session factory
        _sync_session_factory = sessionmaker(
            bind=_sync_engine,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )

        # Create all tables. This is a synchronous call made from a
        # coroutine, so it blocks the event loop while it runs.
        logger.info("Creating database tables...")
        Base.metadata.create_all(bind=_sync_engine)
        logger.info("Database initialization complete")

    except Exception as e:
        logger.error("Failed to initialize database: %s", e)
        # Do not leave half-initialized engines behind: dispose whatever
        # was created and reset the globals. close_db() never raises.
        await close_db()
        raise


async def close_db() -> None:
    """Close database connections and cleanup resources.

    Should be called during application shutdown. Errors raised while
    disposing an engine are logged and not propagated; the module globals
    are reset to ``None`` regardless, and a failure on one engine does
    not prevent the other from being disposed.
    """
    global _engine, _sync_engine, _session_factory, _sync_session_factory

    # Detach the current instances first so module state is consistent
    # even if a dispose() call fails below.
    async_engine, sync_engine = _engine, _sync_engine
    _engine = None
    _session_factory = None
    _sync_engine = None
    _sync_session_factory = None

    failed = False

    if async_engine is not None:
        logger.info("Closing async database engine...")
        try:
            await async_engine.dispose()
        except Exception as e:
            failed = True
            logger.error("Error closing database: %s", e)

    if sync_engine is not None:
        logger.info("Closing sync database engine...")
        try:
            sync_engine.dispose()
        except Exception as e:
            failed = True
            logger.error("Error closing database: %s", e)

    if not failed:
        logger.info("Database connections closed")


def get_engine() -> AsyncEngine:
    """Get the database engine instance.

    Returns:
        AsyncEngine instance

    Raises:
        RuntimeError: If database is not initialized
    """
    if _engine is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )
    return _engine


def get_sync_engine() -> Engine:
    """Get the sync database engine instance.

    Returns:
        Engine instance

    Raises:
        RuntimeError: If database is not initialized
    """
    if _sync_engine is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )
    return _sync_engine


@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide an async database session with automatic commit/rollback.

    This is an async context manager. The session is committed when the
    ``async with`` body finishes normally, rolled back if the body raises
    an ``Exception`` (which is then re-raised), and always closed.

    NOTE(review): because this function is wrapped with
    ``@asynccontextmanager``, calling it returns a context manager rather
    than an async generator. Passing it directly to FastAPI's
    ``Depends()`` likely injects that context manager object instead of
    an ``AsyncSession`` - confirm against the FastAPI version in use
    before relying on it as a dependency.

    Yields:
        AsyncSession: Database session for async operations

    Raises:
        RuntimeError: If database is not initialized

    Example:
        async with get_db_session() as db:
            result = await db.execute(select(AnimeSeries))
            series = result.scalars().all()
    """
    if _session_factory is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )

    session = _session_factory()
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()


def get_sync_session() -> Session:
    """Get a sync database session.

    Use this for synchronous operations outside FastAPI endpoints.
    Remember to close the session when done.

    Returns:
        Session: Database session for sync operations

    Raises:
        RuntimeError: If database is not initialized

    Example:
        session = get_sync_session()
        try:
            result = session.execute(select(AnimeSeries))
            return result.scalars().all()
        finally:
            session.close()
    """
    if _sync_session_factory is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )

    return _sync_session_factory()


def get_async_session_factory() -> AsyncSession:
    """Get a new async database session (factory function).

    Despite the name, this returns a new session instance, not the
    factory itself. It is intended for use in repository patterns.
    The caller is responsible for committing/rolling back and closing.

    Returns:
        AsyncSession: New database session for async operations

    Raises:
        RuntimeError: If database is not initialized

    Example:
        session = get_async_session_factory()
        try:
            result = await session.execute(select(AnimeSeries))
            await session.commit()
            return result.scalars().all()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
    """
    if _session_factory is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )

    return _session_factory()