Aniworld/src/server/database/connection.py
2025-12-02 14:04:37 +01:00

267 lines
7.6 KiB
Python

"""Database connection and session management for SQLAlchemy.

This module provides database engine creation (async and sync), session
factory configuration, and session helpers for async (FastAPI) and sync code.

Functions:
    - init_db: Initialize database engines and create tables
    - close_db: Close database connections and cleanup
    - get_db_session: Async context manager yielding a database session
      (commits on success, rolls back on error)
    - get_engine: Get the async database engine instance
    - get_sync_engine: Get the sync database engine instance
    - get_sync_session: Create a new sync database session
"""
from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional

from sqlalchemy import create_engine, event, pool
from sqlalchemy.engine import Engine, make_url
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import Session, sessionmaker

from src.config.settings import settings
from src.server.database.base import Base
logger = logging.getLogger(__name__)

# Global engine and session factory instances. init_db() populates them and
# close_db() resets them to None; the get_* accessors raise RuntimeError
# while they are None.
_engine: Optional[AsyncEngine] = None
# Annotated with the Engine *type*; create_engine is a factory function and
# is not a valid type for annotations.
_sync_engine: Optional[Engine] = None
_session_factory: Optional[async_sessionmaker[AsyncSession]] = None
_sync_session_factory: Optional[sessionmaker[Session]] = None
def _get_database_url(url: Optional[str] = None) -> str:
    """Get a database URL suitable for the async engine.

    Plain SQLite URLs (``sqlite://...``) are rewritten to use the
    ``aiosqlite`` async driver. This also covers the bare in-memory form
    ``sqlite://``. Only the scheme prefix is rewritten. Every other URL,
    including ones that already name a driver, is returned unchanged.

    Args:
        url: URL to convert. Defaults to ``settings.database_url``.

    Returns:
        Database URL string suitable for the async engine.
    """
    if url is None:
        url = settings.database_url

    sync_prefix = "sqlite://"
    if url.startswith(sync_prefix):
        # Slice instead of str.replace so only the leading scheme is touched.
        url = "sqlite+aiosqlite://" + url[len(sync_prefix):]
    return url
def _configure_sqlite_engine(engine: AsyncEngine) -> None:
    """Attach a connect hook that applies SQLite pragmas to new connections.

    Every DBAPI connection opened by the engine runs
    ``PRAGMA foreign_keys=ON`` followed by ``PRAGMA journal_mode=WAL``.

    Args:
        engine: SQLAlchemy async engine instance. The hook is registered on
            its underlying ``sync_engine``, because pool events are emitted there.
    """
    pragma_statements = ("PRAGMA foreign_keys=ON", "PRAGMA journal_mode=WAL")

    def set_sqlite_pragma(dbapi_conn, connection_record):
        """Enable foreign keys and set pragmas for SQLite."""
        cur = dbapi_conn.cursor()
        for statement in pragma_statements:
            cur.execute(statement)
        cur.close()

    event.listen(engine.sync_engine, "connect", set_sqlite_pragma)
def _apply_sqlite_pragmas(dbapi_conn, connection_record) -> None:
    """Enable foreign keys and WAL journaling on a new SQLite connection.

    Registered by ``init_db`` as a ``"connect"`` listener on the sync engine.
    It applies the same pragmas that ``_configure_sqlite_engine`` installs on
    the async engine, so sync sessions also enforce foreign keys.
    """
    cursor = dbapi_conn.cursor()
    try:
        cursor.execute("PRAGMA foreign_keys=ON")
        cursor.execute("PRAGMA journal_mode=WAL")
    finally:
        cursor.close()


async def init_db() -> None:
    """Initialize database engines, session factories, and tables.

    Steps:
        1. Create an async engine and session factory from
           ``_get_database_url()``.
        2. Create a sync engine and session factory from the unmodified
           ``settings.database_url``.
        3. Create all tables registered on ``Base.metadata`` through the sync
           engine.

    Should be called during application startup.

    The module-level globals are only replaced after every step has succeeded.
    If any step fails, the engines created by this call are disposed, and the
    error is logged and re-raised.

    Raises:
        Exception: If database initialization fails.
    """
    global _engine, _sync_engine, _session_factory, _sync_session_factory

    async_engine: Optional[AsyncEngine] = None
    sync_engine: Optional[Engine] = None
    try:
        db_url = make_url(_get_database_url())
        # Never write credentials of server-based URLs into the logs.
        logger.info(
            "Initializing database: %s",
            db_url.render_as_string(hide_password=True),
        )
        echo = settings.log_level == "DEBUG"

        # Compare the parsed backend name. A substring test would misfire
        # on, e.g., a PostgreSQL database named "sqlite_cache".
        is_sqlite = db_url.get_backend_name() == "sqlite"
        engine_kwargs = {"echo": echo, "pool_pre_ping": True}
        if is_sqlite:
            engine_kwargs["poolclass"] = pool.StaticPool
        else:
            # Asyncio engines reject the plain QueuePool; the
            # asyncio-adapted variant is the async equivalent.
            engine_kwargs["poolclass"] = pool.AsyncAdaptedQueuePool
            engine_kwargs["pool_size"] = 5
            engine_kwargs["max_overflow"] = 10
        async_engine = create_async_engine(db_url, **engine_kwargs)
        if is_sqlite:
            _configure_sqlite_engine(async_engine)

        session_factory = async_sessionmaker(
            bind=async_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )

        # The sync engine (table creation, get_sync_session) uses the
        # configured URL unchanged.
        sync_url = make_url(settings.database_url)
        is_sqlite_sync = sync_url.get_backend_name() == "sqlite"
        sync_engine = create_engine(
            sync_url,
            echo=echo,
            poolclass=pool.StaticPool if is_sqlite_sync else pool.QueuePool,
            pool_pre_ping=True,
        )
        if is_sqlite_sync:
            # Without this listener, sync sessions would not enforce FKs.
            event.listen(sync_engine, "connect", _apply_sqlite_pragmas)

        sync_session_factory = sessionmaker(
            bind=sync_engine,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )

        logger.info("Creating database tables...")
        Base.metadata.create_all(bind=sync_engine)
    except Exception as exc:
        logger.exception("Failed to initialize database: %s", exc)
        # Release whatever this call managed to create before re-raising.
        if sync_engine is not None:
            sync_engine.dispose()
        if async_engine is not None:
            await async_engine.dispose()
        raise

    _engine = async_engine
    _session_factory = session_factory
    _sync_engine = sync_engine
    _sync_session_factory = sync_session_factory
    logger.info("Database initialization complete")
async def close_db() -> None:
    """Dispose the database engines and reset the module-level state.

    Should be called during application shutdown. Cleanup is best-effort.
    An error while disposing one engine is logged and does not stop the
    other engine from being disposed. The engine and session-factory globals
    are reset to ``None`` in every case, so the get_* accessors raise
    ``RuntimeError`` until ``init_db()`` runs again.
    """
    global _engine, _sync_engine, _session_factory, _sync_session_factory

    # Detach the globals first so a failing dispose() cannot leave stale
    # handles behind.
    async_engine, _engine = _engine, None
    sync_engine, _sync_engine = _sync_engine, None
    _session_factory = None
    _sync_session_factory = None

    if async_engine is not None:
        logger.info("Closing async database engine...")
        try:
            await async_engine.dispose()
        except Exception as exc:
            logger.exception("Error closing database: %s", exc)

    if sync_engine is not None:
        logger.info("Closing sync database engine...")
        try:
            sync_engine.dispose()
        except Exception as exc:
            logger.exception("Error closing database: %s", exc)

    logger.info("Database connections closed")
def get_engine() -> AsyncEngine:
    """Return the async engine created by ``init_db()``.

    Returns:
        The module-level AsyncEngine instance.

    Raises:
        RuntimeError: If ``init_db()`` has not been called, or if
            ``close_db()`` has since reset the engine.
    """
    engine = _engine
    if engine is not None:
        return engine
    raise RuntimeError("Database not initialized. Call init_db() first.")
def get_sync_engine() -> Engine:
    """Get the sync database engine instance.

    The return type is annotated to match ``get_engine``.

    Returns:
        The sync Engine created by ``init_db()``.

    Raises:
        RuntimeError: If ``init_db()`` has not been called, or if
            ``close_db()`` has since reset the engine.
    """
    if _sync_engine is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )
    return _sync_engine
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide an async database session with automatic commit/rollback.

    Because the function is wrapped with ``asynccontextmanager``, calling it
    returns an async context manager. Enter it with ``async with`` to obtain
    the session:

    - When the ``async with`` body finishes normally, the session is committed.
    - If the body or the commit raises, the session is rolled back and the
      exception propagates.
    - The session is closed in every case.

    Yields:
        AsyncSession: Database session for async operations

    Raises:
        RuntimeError: If the database is not initialized. This is raised when
            the context is entered, not when the function is called.

    Example:
        async with get_db_session() as db:
            result = await db.execute(select(AnimeSeries))
            anime = result.scalars().all()

    NOTE(review): An ``asynccontextmanager``-wrapped function is not an async
    generator function. Passing it straight to FastAPI ``Depends`` would
    presumably inject the context-manager object rather than a session.
    Confirm before using it as a dependency.
    """
    if _session_factory is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )

    session = _session_factory()
    try:
        yield session
        await session.commit()
    except Exception:
        # Undo any partial work before re-raising to the caller.
        await session.rollback()
        raise
    finally:
        await session.close()
def get_sync_session() -> Session:
    """Create a new sync database session.

    Intended for synchronous code outside FastAPI endpoints. The caller owns
    the returned session and must close it when finished.

    Returns:
        Session: A fresh session from the sync session factory.

    Raises:
        RuntimeError: If ``init_db()`` has not been called, or if
            ``close_db()`` has since reset the factory.

    Example:
        session = get_sync_session()
        try:
            result = session.execute(select(AnimeSeries))
            return result.scalars().all()
        finally:
            session.close()
    """
    factory = _sync_session_factory
    if factory is None:
        raise RuntimeError("Database not initialized. Call init_db() first.")
    return factory()