- Add WebSocket shutdown() with client notification and graceful close
- Enhance download service stop() with pending state persistence
- Expand FastAPI lifespan shutdown with proper cleanup sequence
- Add SQLite WAL checkpoint before database close
- Update stop_server.sh to use SIGTERM with timeout fallback
- Configure uvicorn timeout_graceful_shutdown=30s
- Update ARCHITECTURE.md with shutdown documentation
593 lines
18 KiB
Python
593 lines
18 KiB
Python
"""Database connection and session management for SQLAlchemy.

This module provides database engine creation, session factory configuration,
and session helpers for FastAPI endpoints and other application code.

Functions:
    - init_db: Initialize database engines and session factories, and
      create tables
    - close_db: Checkpoint the SQLite WAL (SQLite only) and close database
      connections
    - get_engine: Get the async database engine instance
    - get_sync_engine: Get the sync database engine instance
    - get_db_session: Async context manager yielding a session that is
      committed on success and rolled back on error
    - get_sync_session: Create a sync session (the caller must close it)
    - get_async_session_factory: Create an async session (the caller
      commits/rolls back and closes it)
    - get_transactional_session: Session without auto-commit for explicit
      transaction control
    - is_session_in_transaction: Check whether a session has an active
      transaction
    - get_session_transaction_depth: Report a session's transaction depth

Classes:
    - TransactionManager: Helper class for manual transaction control
    - SavepointHandle: Handle for rolling back or releasing a savepoint
"""
|
|
from __future__ import annotations

import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator, Optional

from sqlalchemy import create_engine, event, pool
from sqlalchemy.engine import Engine
from sqlalchemy.ext.asyncio import (
    AsyncEngine,
    AsyncSession,
    async_sessionmaker,
    create_async_engine,
)
from sqlalchemy.orm import Session, sessionmaker

from src.config.settings import settings
from src.server.database.base import Base
|
|
|
|
logger = logging.getLogger(__name__)

# Global engine and session factory instances.
# init_db() populates them; close_db() resets them to None.
_engine: Optional[AsyncEngine] = None
# Annotated with the Engine *type*; create_engine is the factory function.
_sync_engine: Optional[Engine] = None
_session_factory: Optional[async_sessionmaker[AsyncSession]] = None
_sync_session_factory: Optional[sessionmaker[Session]] = None
|
|
|
|
|
|
def _get_database_url(url: Optional[str] = None) -> str:
    """Get the database URL for the async engine.

    Plain SQLite URLs are switched to the ``aiosqlite`` driver so they can
    be used with ``create_async_engine``. This covers both the file form
    (``sqlite:///path.db``) and the in-memory form (``sqlite://``). Every
    other URL is returned unchanged, including SQLite URLs that already
    name a driver (``sqlite+...``).

    Args:
        url: URL to convert. Defaults to ``settings.database_url``.

    Returns:
        Database URL string suitable for the async engine.
    """
    if url is None:
        url = settings.database_url

    # Rewrite only the scheme prefix. str.replace() would also alter any
    # later occurrence of the pattern elsewhere in the URL.
    plain_scheme = "sqlite:"
    if url.startswith(plain_scheme):
        url = "sqlite+aiosqlite:" + url[len(plain_scheme):]

    return url
|
|
|
|
|
|
def _configure_sqlite_engine(engine: AsyncEngine) -> None:
    """Register SQLite pragmas on every new connection of ``engine``.

    Each raw DBAPI connection opened by the engine gets foreign-key
    enforcement switched on and the WAL journal mode selected.

    Args:
        engine: SQLAlchemy async engine whose underlying sync engine
            receives the ``"connect"`` listener.
    """
    pragma_statements = (
        "PRAGMA foreign_keys=ON",
        "PRAGMA journal_mode=WAL",
    )

    def _apply_pragmas(dbapi_conn, connection_record):
        """Enable foreign keys and set pragmas for SQLite."""
        cur = dbapi_conn.cursor()
        for statement in pragma_statements:
            cur.execute(statement)
        cur.close()

    # Equivalent to decorating _apply_pragmas with event.listens_for(...).
    event.listen(engine.sync_engine, "connect", _apply_pragmas)
|
|
|
|
|
|
async def init_db() -> None:
    """Initialize database engines, session factories and tables.

    Creates two pairs of objects:

    * an async engine and async session factory, used by the async
      session helpers in this module;
    * a sync engine and sync session factory, used by ``get_sync_session``
      and for table creation.

    It then creates every table registered on ``Base.metadata``. Should be
    called during application startup.

    For SQLite, both engines use ``StaticPool`` and foreign-key enforcement
    is enabled on each new connection. For other databases the async engine
    uses SQLAlchemy's default asyncio-compatible pool.

    Raises:
        Exception: Any error raised while building the engines or creating
            tables is logged and re-raised.
    """
    global _engine, _sync_engine, _session_factory, _sync_session_factory

    from sqlalchemy.engine import make_url

    try:
        # Get database URL (plain sqlite URLs are switched to aiosqlite)
        db_url = _get_database_url()
        # Mask the password so database credentials never reach the logs.
        safe_url = make_url(db_url).render_as_string(hide_password=True)
        logger.info("Initializing database: %s", safe_url)

        echo = settings.log_level == "DEBUG"

        # Detect SQLite from the URL scheme. A substring test could misfire
        # on a host, user or database name that happens to contain "sqlite".
        is_sqlite = db_url.startswith("sqlite")
        engine_kwargs = {
            "echo": echo,
            "pool_pre_ping": True,
        }
        if is_sqlite:
            engine_kwargs["poolclass"] = pool.StaticPool
        else:
            # Do NOT pass pool.QueuePool here: async engines reject pool
            # classes that are not asyncio-compatible. Leaving poolclass
            # unset selects SQLAlchemy's async default queue pool, which
            # accepts the same sizing options.
            engine_kwargs["pool_size"] = 5
            engine_kwargs["max_overflow"] = 10

        # Create async engine
        _engine = create_async_engine(db_url, **engine_kwargs)

        # Foreign keys + WAL pragmas on async connections
        if is_sqlite:
            _configure_sqlite_engine(_engine)

        # Create async session factory
        _session_factory = async_sessionmaker(
            bind=_engine,
            class_=AsyncSession,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )

        # Create sync engine for initial setup and sync sessions
        sync_url = settings.database_url
        is_sqlite_sync = sync_url.startswith("sqlite")
        _sync_engine = create_engine(
            sync_url,
            echo=echo,
            poolclass=pool.StaticPool if is_sqlite_sync else pool.QueuePool,
            pool_pre_ping=True,
        )

        if is_sqlite_sync:
            # SQLite enforces foreign keys only on connections that enable
            # them. Do it here too, so sync sessions enforce the same
            # constraints as async ones.
            @event.listens_for(_sync_engine, "connect")
            def _enable_sync_foreign_keys(dbapi_conn, connection_record):
                cursor = dbapi_conn.cursor()
                cursor.execute("PRAGMA foreign_keys=ON")
                cursor.close()

        # Create sync session factory
        _sync_session_factory = sessionmaker(
            bind=_sync_engine,
            expire_on_commit=False,
            autoflush=False,
            autocommit=False,
        )

        # Create all tables
        logger.info("Creating database tables...")
        Base.metadata.create_all(bind=_sync_engine)
        logger.info("Database initialization complete")

    except Exception as e:
        # logger.exception keeps the traceback for startup diagnostics.
        logger.exception("Failed to initialize database: %s", e)
        raise
|
|
|
|
|
|
def _checkpoint_sqlite_wal(engine) -> None:
    """Checkpoint the SQLite WAL into the main database file (best effort).

    Runs ``PRAGMA wal_checkpoint(TRUNCATE)`` on a connection from
    ``engine``. Failures and incomplete checkpoints are only logged as
    warnings, because a missed checkpoint must not abort shutdown.

    Args:
        engine: Sync SQLAlchemy engine connected to the SQLite database.
    """
    from sqlalchemy import text

    logger.info("Running SQLite WAL checkpoint before shutdown...")
    try:
        with engine.connect() as conn:
            # TRUNCATE mode: checkpoint and truncate WAL file
            row = conn.execute(
                text("PRAGMA wal_checkpoint(TRUNCATE)")
            ).fetchone()
            conn.commit()
    except Exception as e:
        logger.warning(f"WAL checkpoint failed (non-critical): {e}")
        return

    # The pragma returns one row: (busy, wal_frames, checkpointed_frames).
    # A non-zero busy flag means the checkpoint could not fully complete.
    if row is not None and row[0]:
        logger.warning(
            "SQLite WAL checkpoint incomplete (database busy): %s", tuple(row)
        )
    else:
        logger.info("SQLite WAL checkpoint completed")


async def close_db() -> None:
    """Close database connections and cleanup resources.

    Shutdown runs in this order:

    1. Dispose the async engine, closing its pooled connection.
    2. For SQLite, run a TRUNCATE WAL checkpoint through the sync engine,
       so pending WAL content is copied into the main database file.
    3. Dispose the sync engine.

    The module-level engines and session factories are reset to ``None``
    even if disposing an engine fails. Errors are logged, not raised, so
    the rest of application shutdown can proceed.

    Should be called during application shutdown.
    """
    global _engine, _sync_engine, _session_factory, _sync_session_factory

    # Dispose the async engine first so its connection is closed before the
    # checkpoint. A connection that is still reading can stop a TRUNCATE
    # checkpoint from completing.
    if _engine is not None:
        logger.info("Closing async database engine...")
        try:
            await _engine.dispose()
        except Exception as e:
            logger.exception(f"Error closing database: {e}")
        finally:
            _engine = None
            _session_factory = None

    if _sync_engine is not None:
        try:
            if _sync_engine.dialect.name == "sqlite":
                _checkpoint_sqlite_wal(_sync_engine)
            logger.info("Closing sync database engine...")
            _sync_engine.dispose()
        except Exception as e:
            logger.exception(f"Error closing database: {e}")
        finally:
            _sync_engine = None
            _sync_session_factory = None

    logger.info("Database connections closed")
|
|
|
|
|
|
def get_engine() -> AsyncEngine:
    """Return the async engine created by ``init_db()``.

    Returns:
        The module-level AsyncEngine instance.

    Raises:
        RuntimeError: If the engine is not set, i.e. ``init_db()`` has not
            run or ``close_db()`` has since reset it.
    """
    if _engine is not None:
        return _engine
    raise RuntimeError("Database not initialized. Call init_db() first.")
|
|
|
|
|
|
def get_sync_engine():
    """Return the sync engine created by ``init_db()``.

    Returns:
        The module-level sync Engine instance.

    Raises:
        RuntimeError: If the sync engine is not set, i.e. ``init_db()`` has
            not run or ``close_db()`` has since reset it.
    """
    if _sync_engine is not None:
        return _sync_engine
    raise RuntimeError("Database not initialized. Call init_db() first.")
|
|
|
|
|
|
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Provide an async database session with automatic commit/rollback.

    On normal exit from the ``async with`` block the session is committed.
    If an ``Exception`` escapes the block (or the commit itself fails), the
    session is rolled back and the exception re-raised. The session is
    always closed.

    Note:
        Because this function is wrapped with ``@asynccontextmanager``,
        calling it returns an async context manager, not a session. Use it
        with ``async with``. Passing it directly to FastAPI's ``Depends()``
        would inject the context-manager object rather than an
        ``AsyncSession``. Wrap it in an async-generator dependency if
        injection is needed.

    Yields:
        AsyncSession: Database session for async operations

    Raises:
        RuntimeError: If database is not initialized

    Example:
        async with get_db_session() as db:
            result = await db.execute(select(AnimeSeries))
            series = result.scalars().all()
    """
    if _session_factory is None:
        raise RuntimeError(
            "Database not initialized. Call init_db() first."
        )

    session = _session_factory()
    try:
        yield session
        await session.commit()
    except Exception:
        await session.rollback()
        raise
    finally:
        await session.close()
|
|
|
|
|
|
def get_sync_session() -> Session:
    """Create a new sync database session.

    Intended for synchronous code outside FastAPI endpoints. The caller
    owns the session and must close it.

    Returns:
        Session: Database session for sync operations

    Raises:
        RuntimeError: If database is not initialized

    Example:
        session = get_sync_session()
        try:
            result = session.execute(select(AnimeSeries))
            return result.scalars().all()
        finally:
            session.close()
    """
    factory = _sync_session_factory
    if factory is None:
        raise RuntimeError("Database not initialized. Call init_db() first.")
    return factory()
|
|
|
|
|
|
def get_async_session_factory() -> AsyncSession:
    """Create a new async database session (despite the name, not a factory).

    Each call builds a fresh session from the module-level async session
    factory, for use in repository-style code. The caller is responsible
    for committing or rolling back and for closing the session.

    Returns:
        AsyncSession: New database session for async operations

    Raises:
        RuntimeError: If database is not initialized

    Example:
        session = get_async_session_factory()
        try:
            result = await session.execute(select(AnimeSeries))
            await session.commit()
            return result.scalars().all()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()
    """
    factory = _session_factory
    if factory is None:
        raise RuntimeError("Database not initialized. Call init_db() first.")
    return factory()
|
|
|
|
|
|
@asynccontextmanager
async def get_transactional_session() -> AsyncGenerator[AsyncSession, None]:
    """Yield a session that is never committed automatically.

    Unlike ``get_db_session()``, leaving the block normally does NOT commit.
    Committing is left to explicit transaction control, such as the
    ``@transactional`` decorator or the ``atomic()`` context manager. If an
    ``Exception`` escapes the block, the session is rolled back before the
    exception propagates. The session is closed in every case.

    Yields:
        AsyncSession: Database session for async operations

    Raises:
        RuntimeError: If database is not initialized

    Example:
        async with get_transactional_session() as session:
            async with atomic(session) as tx:
                # Multiple operations in transaction
                await operation1(session)
                await operation2(session)
            # Committed when exiting atomic() context
    """
    factory = _session_factory
    if factory is None:
        raise RuntimeError("Database not initialized. Call init_db() first.")

    tx_session = factory()
    try:
        yield tx_session
    except Exception:
        # Discard any partial work before re-raising.
        await tx_session.rollback()
        raise
    finally:
        await tx_session.close()
|
|
|
|
|
|
class TransactionManager:
    """Helper class for manual transaction control.

    Provides a cleaner interface for managing transactions across
    multiple service calls within a single request.

    Attributes:
        _session_factory: Factory for creating new sessions
        _session: Current active session (None outside the context manager)
        _in_transaction: True between a successful begin() and the matching
            commit()/rollback()

    Example:
        async with TransactionManager() as tm:
            session = await tm.get_session()
            await tm.begin()
            try:
                await service1.operation(session)
                await service2.operation(session)
                await tm.commit()
            except Exception:
                await tm.rollback()
                raise
    """

    def __init__(
        self,
        session_factory: Optional[async_sessionmaker] = None
    ) -> None:
        """Initialize transaction manager.

        Args:
            session_factory: Optional custom session factory.
                Uses the module-level factory set by init_db() if not
                provided.

        Raises:
            RuntimeError: If no factory was given and the database has not
                been initialized.
        """
        self._session_factory = session_factory or _session_factory
        self._session: Optional[AsyncSession] = None
        self._in_transaction = False

        if self._session_factory is None:
            raise RuntimeError(
                "Database not initialized. Call init_db() first."
            )

    async def __aenter__(self) -> "TransactionManager":
        """Enter context manager and create session."""
        self._session = self._session_factory()
        logger.debug("TransactionManager: Created new session")
        return self

    async def __aexit__(
        self,
        exc_type: Optional[type],
        exc_val: Optional[BaseException],
        exc_tb: Optional[object],
    ) -> bool:
        """Exit context manager and cleanup session.

        Rolls back if an exception occurred while a transaction started by
        begin() was still open. The session is then always closed and the
        manager's state reset, even if the rollback itself raises.

        Returns:
            False, so exceptions from the ``async with`` body propagate.
        """
        session = self._session
        if session is None:
            return False

        try:
            if exc_type is not None and self._in_transaction:
                logger.warning(
                    "TransactionManager: Rolling back due to exception: %s",
                    exc_val,
                )
                await session.rollback()
        finally:
            # Close even when rollback() failed, so the session is not
            # leaked. Reset state even when close() itself fails.
            try:
                await session.close()
            finally:
                self._session = None
                self._in_transaction = False

        logger.debug("TransactionManager: Session closed")
        return False

    async def get_session(self) -> AsyncSession:
        """Get the current session.

        Returns:
            Current AsyncSession instance

        Raises:
            RuntimeError: If not within context manager
        """
        if self._session is None:
            raise RuntimeError(
                "TransactionManager must be used as async context manager"
            )
        return self._session

    async def begin(self) -> None:
        """Begin a new transaction.

        Raises:
            RuntimeError: If already in a transaction or no session
        """
        if self._session is None:
            raise RuntimeError("No active session")

        if self._in_transaction:
            raise RuntimeError("Already in a transaction")

        await self._session.begin()
        self._in_transaction = True
        logger.debug("TransactionManager: Transaction started")

    async def commit(self) -> None:
        """Commit the current transaction.

        If the commit raises, the manager still counts as in a transaction,
        so __aexit__ will roll it back.

        Raises:
            RuntimeError: If not in a transaction
        """
        if not self._in_transaction or self._session is None:
            raise RuntimeError("Not in a transaction")

        await self._session.commit()
        self._in_transaction = False
        logger.debug("TransactionManager: Transaction committed")

    async def rollback(self) -> None:
        """Rollback the current transaction.

        Rolls back whatever the session has pending, even if begin() was not
        called, and marks the manager as no longer in a transaction.

        Raises:
            RuntimeError: If there is no active session (not inside the
                context manager)
        """
        if self._session is None:
            raise RuntimeError("No active session")

        await self._session.rollback()
        self._in_transaction = False
        logger.debug("TransactionManager: Transaction rolled back")

    async def savepoint(self, name: Optional[str] = None) -> "SavepointHandle":
        """Create a savepoint within the current transaction.

        Args:
            name: Optional savepoint name (used for logging only)

        Returns:
            SavepointHandle for controlling the savepoint

        Raises:
            RuntimeError: If not in a transaction
        """
        if not self._in_transaction or self._session is None:
            raise RuntimeError("Must be in a transaction to create savepoint")

        nested = await self._session.begin_nested()
        return SavepointHandle(nested, name or "unnamed")

    def is_in_transaction(self) -> bool:
        """Check if currently in a transaction.

        Returns:
            True if in an active transaction
        """
        return self._in_transaction

    def get_transaction_depth(self) -> int:
        """Get current transaction nesting depth.

        Returns:
            0 if not in a transaction, otherwise 1. Savepoints created via
            savepoint() are not counted.
        """
        if not self._in_transaction:
            return 0
        return 1  # Basic implementation - could be extended
|
|
|
|
|
|
class SavepointHandle:
    """Handle for controlling a database savepoint.

    Attributes:
        _nested: SQLAlchemy nested transaction returned by begin_nested()
        _name: Savepoint name for logging
        _released: Whether the savepoint has been rolled back or released
    """

    def __init__(self, nested: object, name: str) -> None:
        """Initialize savepoint handle.

        Args:
            nested: SQLAlchemy nested transaction object
            name: Savepoint name
        """
        self._nested = nested
        self._name = name
        self._released = False
        logger.debug("Created savepoint: %s", name)

    async def rollback(self) -> None:
        """Rollback to this savepoint (no-op if already finished)."""
        if not self._released:
            await self._nested.rollback()
            self._released = True
            logger.debug("Rolled back savepoint: %s", self._name)

    async def release(self) -> None:
        """Release (commit) this savepoint (no-op if already finished).

        Committing the nested transaction emits RELEASE SAVEPOINT, so later
        work is no longer nested inside it. If the savepoint is already
        inactive, for example because the enclosing transaction has ended,
        nothing is sent to the database.
        """
        if not self._released:
            if getattr(self._nested, "is_active", False):
                await self._nested.commit()
            self._released = True
            logger.debug("Released savepoint: %s", self._name)
|
|
|
|
|
|
def is_session_in_transaction(session: AsyncSession | Session) -> bool:
    """Report whether ``session`` currently has an active transaction.

    A thin wrapper around ``session.in_transaction()``, so it accepts both
    sync and async SQLAlchemy sessions.

    Args:
        session: SQLAlchemy session (sync or async)

    Returns:
        The value of ``session.in_transaction()``: True while a transaction
        is active.
    """
    active = session.in_transaction()
    return active
|
|
|
|
|
|
def get_session_transaction_depth(session: AsyncSession | Session) -> int:
    """Return the transaction depth of ``session``.

    SQLAlchemy does not expose a nesting depth directly, so any active
    transaction (including one with savepoints) is reported as depth 1.

    Args:
        session: SQLAlchemy session (sync or async)

    Returns:
        0 when no transaction is active, otherwise 1.
    """
    return 1 if session.in_transaction() else 0
|
|
|