267 lines
7.6 KiB
Python
267 lines
7.6 KiB
Python
"""Database connection and session management for SQLAlchemy.
|
|
|
|
This module provides database engine creation, session factory configuration,
|
|
and dependency injection helpers for FastAPI endpoints.
|
|
|
|
Functions:
|
|
- init_db: Initialize database engine and create tables
|
|
- close_db: Close database connections and cleanup
|
|
- get_db_session: FastAPI dependency for database sessions
|
|
- get_engine: Get database engine instance
|
|
"""
|
|
from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional

from sqlalchemy import create_engine, event, pool
from sqlalchemy.engine import Engine, make_url
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import Session, sessionmaker

from src.config.settings import settings
from src.server.database.base import Base
|
|
|
|
logger = logging.getLogger(__name__)

# Module-level singletons: populated by init_db() and reset to None by
# close_db(). Accessors raise RuntimeError while they are None.
_engine: Optional[AsyncEngine] = None
# The sync engine is an ``Engine`` instance (the object returned by
# ``create_engine``), not the ``create_engine`` function itself.
_sync_engine: Optional[Engine] = None
_session_factory: Optional[async_sessionmaker[AsyncSession]] = None
_sync_session_factory: Optional[sessionmaker[Session]] = None
|
|
|
|
|
|
def _get_database_url(url: Optional[str] = None) -> str:
    """Return the database URL in a form usable by the async engine.

    Plain SQLite URLs -- both the in-memory form ``sqlite://`` and file
    URLs such as ``sqlite:///path.db`` -- are rewritten to use the
    ``aiosqlite`` driver. URLs that already name a driver (for example
    ``sqlite+aiosqlite://``) and non-SQLite URLs are returned unchanged.

    Args:
        url: URL to convert. Defaults to ``settings.database_url``.

    Returns:
        Database URL string suitable for the async engine.
    """
    if url is None:
        url = settings.database_url

    sync_prefix = "sqlite://"
    # Rewrite only the scheme prefix. Matching "sqlite://" covers both the
    # bare in-memory form and "sqlite:///path"; the remainder of the URL
    # (path, query) is kept verbatim.
    if url.startswith(sync_prefix):
        url = "sqlite+aiosqlite://" + url[len(sync_prefix):]

    return url
|
|
|
|
|
|
def _configure_sqlite_engine(engine: AsyncEngine) -> None:
    """Register SQLite connection pragmas on *engine*.

    Attaches a ``connect`` event listener to the engine's underlying sync
    engine. The listener runs ``PRAGMA foreign_keys=ON`` and
    ``PRAGMA journal_mode=WAL`` on each newly connected DBAPI connection.

    Args:
        engine: SQLAlchemy async engine instance using a SQLite URL.
    """

    @event.listens_for(engine.sync_engine, "connect")
    def set_sqlite_pragma(dbapi_conn, connection_record):
        """Enable foreign keys and WAL journaling on a new SQLite connection."""
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON")
            cursor.execute("PRAGMA journal_mode=WAL")
        finally:
            # Close the cursor even if a PRAGMA fails so it is not leaked.
            cursor.close()
|
|
|
|
|
|
def _is_sqlite_url(url: str) -> bool:
    """Return True when *url* targets the SQLite backend (any driver)."""
    return make_url(url).get_backend_name() == "sqlite"


def _build_engine_kwargs(url: str, *, is_async: bool) -> dict:
    """Return keyword arguments for ``create_engine``/``create_async_engine``.

    SQLite URLs use ``StaticPool``. Other backends use a queue pool with
    ``pool_size=5`` and ``max_overflow=10``. Async engines get
    ``AsyncAdaptedQueuePool`` because SQLAlchemy rejects the thread-based
    ``QueuePool`` on an asyncio engine. Sync engines get ``QueuePool``.

    Args:
        url: Database URL the engine will be created for.
        is_async: True when building kwargs for ``create_async_engine``.

    Returns:
        Dict of engine keyword arguments.
    """
    kwargs: dict = {
        "echo": settings.log_level == "DEBUG",
        "pool_pre_ping": True,
    }
    if _is_sqlite_url(url):
        kwargs["poolclass"] = pool.StaticPool
    else:
        kwargs["poolclass"] = (
            pool.AsyncAdaptedQueuePool if is_async else pool.QueuePool
        )
        kwargs["pool_size"] = 5
        kwargs["max_overflow"] = 10
    return kwargs


async def init_db() -> None:
    """Initialize database engines, session factories and tables.

    Creates an async engine plus session factory (used by get_db_session) and
    a sync engine plus session factory (used by get_sync_session and for
    table creation). Then creates all tables registered on
    ``Base.metadata``. Should be called during application startup.

    Raises:
        Exception: Any error raised while creating engines or tables is
            logged with its traceback and re-raised.
    """
    global _engine, _sync_engine, _session_factory, _sync_session_factory

    try:
        db_url = _get_database_url()
        # Mask any password so credentials are never written to the logs.
        logger.info(
            "Initializing database: %s",
            make_url(db_url).render_as_string(hide_password=True),
        )

        # Async engine and session factory.
        _engine = create_async_engine(
            db_url, **_build_engine_kwargs(db_url, is_async=True)
        )
        if _is_sqlite_url(db_url):
            _configure_sqlite_engine(_engine)

        _session_factory = async_sessionmaker(
            bind=_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )

        # The sync engine uses the URL exactly as configured (no async
        # driver rewrite).
        # NOTE(review): with an in-memory SQLite URL the sync and async
        # engines each hold their own connection, i.e. separate databases,
        # so tables created below may not be visible through the async
        # engine -- confirm in-memory URLs are not used outside tests.
        sync_url = settings.database_url
        _sync_engine = create_engine(
            sync_url, **_build_engine_kwargs(sync_url, is_async=False)
        )

        _sync_session_factory = sessionmaker(
            bind=_sync_engine,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )

        logger.info("Creating database tables...")
        # Blocking call on the event loop; tolerated because this function
        # is meant to run once during startup.
        Base.metadata.create_all(bind=_sync_engine)
        logger.info("Database initialization complete")

    except Exception as exc:
        logger.exception("Failed to initialize database: %s", exc)
        raise
|
|
|
|
|
|
async def close_db() -> None:
    """Dispose both engines and clear the module-level singletons.

    Should be called during application shutdown. Shutdown is best-effort.
    A failure disposing one engine is logged with its traceback and does not
    prevent the other engine from being disposed. The globals are cleared
    before disposal, so afterwards the accessors report "not initialized"
    even if a dispose call failed. Calling this when nothing is initialized
    is a no-op apart from logging.
    """
    global _engine, _sync_engine, _session_factory, _sync_session_factory

    # Detach everything first so no caller can grab a half-disposed engine.
    async_engine, sync_engine = _engine, _sync_engine
    _engine = _session_factory = None
    _sync_engine = _sync_session_factory = None

    closed_cleanly = True

    if async_engine is not None:
        logger.info("Closing async database engine...")
        try:
            await async_engine.dispose()
        except Exception as exc:
            closed_cleanly = False
            logger.exception("Error closing database: %s", exc)

    if sync_engine is not None:
        logger.info("Closing sync database engine...")
        try:
            sync_engine.dispose()
        except Exception as exc:
            closed_cleanly = False
            logger.exception("Error closing database: %s", exc)

    if closed_cleanly:
        logger.info("Database connections closed")
|
|
|
|
|
|
def get_engine() -> AsyncEngine:
    """Return the async engine created by ``init_db()``.

    Returns:
        AsyncEngine: The module's async engine singleton.

    Raises:
        RuntimeError: If ``init_db()`` has not been called, or
            ``close_db()`` has since cleared the engine.
    """
    engine = _engine
    if engine is not None:
        return engine
    raise RuntimeError("Database not initialized. Call init_db() first.")
|
|
|
|
|
|
def get_sync_engine() -> Engine:
    """Return the sync engine created by ``init_db()``.

    Returns:
        Engine: The module's synchronous engine singleton.

    Raises:
        RuntimeError: If ``init_db()`` has not been called, or
            ``close_db()`` has since cleared the engine.
    """
    if _sync_engine is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )
    return _sync_engine
|
|
|
|
|
|
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide an async database session scoped to an ``async with`` block.

    A new session is created from the async session factory. It is
    committed when the block exits normally. If the block (or the commit)
    raises an ``Exception``, the session is rolled back and the exception is
    re-raised. The session is always closed afterwards.

    Because this function is wrapped with ``asynccontextmanager``, calling it
    returns an async context manager, not a session. Do not pass it directly
    to FastAPI's ``Depends``; wrap it in an async-generator dependency
    instead (see example).

    Yields:
        AsyncSession: Database session for async operations.

    Raises:
        RuntimeError: If the database is not initialized (raised when the
            ``async with`` block is entered).

    Example:
        async with get_db_session() as db:
            result = await db.execute(select(AnimeSeries))
            series = result.scalars().all()

        # As a FastAPI dependency:
        async def db_dependency() -> AsyncGenerator[AsyncSession, None]:
            async with get_db_session() as db:
                yield db

        @app.get("/anime")
        async def get_anime(db: AsyncSession = Depends(db_dependency)):
            result = await db.execute(select(AnimeSeries))
            return result.scalars().all()
    """
    if _session_factory is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )

    session = _session_factory()
    try:
        yield session
        # Reached only when the caller's block finished without raising.
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()
|
|
|
|
|
|
def get_sync_session() -> Session:
    """Create a new synchronous session from the sync session factory.

    Intended for synchronous code paths outside the async request flow.
    Each call returns a fresh session; the caller owns it and is
    responsible for closing it.

    Returns:
        Session: A new session from the sync session factory.

    Raises:
        RuntimeError: If ``init_db()`` has not been called yet.

    Example:
        session = get_sync_session()
        try:
            result = session.execute(select(AnimeSeries))
            return result.scalars().all()
        finally:
            session.close()
    """
    factory = _sync_session_factory
    if factory is not None:
        return factory()
    raise RuntimeError("Database not initialized. Call init_db() first.")
|