602 lines
19 KiB
Python

"""Database initialization and setup module.
This module provides comprehensive database initialization functionality:
- Schema creation and validation
- Database health checks
- Schema versioning support
"""
from __future__ import annotations
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from sqlalchemy import inspect, text
from sqlalchemy.ext.asyncio import AsyncEngine
from src.config.settings import settings
from src.server.database.base import Base
from src.server.database.connection import get_engine
logger = logging.getLogger(__name__)
# =============================================================================
# Schema Version Constants
# =============================================================================
# Version string reported for a database whose tables match EXPECTED_TABLES.
CURRENT_SCHEMA_VERSION = "1.0.0"

# Name of the bookkeeping table created by create_schema_version_table().
SCHEMA_VERSION_TABLE = "schema_version"

# Tables that schema validation requires to be present.
EXPECTED_TABLES = {
    "anime_series",
    "episodes",
    "download_queue",
    "user_sessions",
}

# Index names that schema validation looks for, keyed by table name.
EXPECTED_INDEXES = {
    "anime_series": [
        "ix_anime_series_key",
        "ix_anime_series_name",
    ],
    "episodes": [
        "ix_episodes_series_id",
    ],
    "download_queue": [
        "ix_download_queue_series_id",
        "ix_download_queue_episode_id",
    ],
    "user_sessions": [
        "ix_user_sessions_session_id",
        "ix_user_sessions_user_id",
        "ix_user_sessions_is_active",
    ],
}
# =============================================================================
# Database Initialization
# =============================================================================
async def initialize_database(
    engine: Optional[AsyncEngine] = None,
    create_schema: bool = True,
    validate_schema: bool = True,
    seed_data: bool = False,
) -> Dict[str, Any]:
    """Run the full database initialization sequence.

    The steps run in this order, each optional one only when its flag is set:

    1. Schema creation (``create_schema``)
    2. Schema validation (``validate_schema``)
    3. Initial data seeding (``seed_data``)
    4. Schema version lookup (always)
    5. Health check (always)

    A failed validation is only logged as a warning; it does not abort the
    sequence and does not by itself turn ``success`` into ``False``.

    Args:
        engine: Database engine to use; the default engine from
            ``get_engine()`` is used when omitted.
        create_schema: Whether to create the database schema.
        validate_schema: Whether to validate the schema.
        seed_data: Whether to seed initial data.

    Returns:
        Dictionary with the keys:
            - success: ``True`` when the sequence finished without raising
            - schema_version: Value returned by ``get_schema_version``
            - tables_created: Value returned by ``create_database_schema``
              (empty list when schema creation was skipped)
            - validation_result: Value returned by
              ``validate_database_schema`` (``None`` when skipped)
            - health_check: Value returned by ``check_database_health``

    Raises:
        RuntimeError: If any step raises; the original exception is chained.

    Example:
        result = await initialize_database(
            create_schema=True,
            validate_schema=True,
            seed_data=True
        )
        if result["success"]:
            logger.info(f"Database initialized: {result['schema_version']}")
    """
    db_engine = get_engine() if engine is None else engine

    logger.info("Starting database initialization...")

    summary: Dict[str, Any] = dict(
        success=False,
        schema_version=None,
        tables_created=[],
        validation_result=None,
        health_check=None,
    )

    try:
        # Step 1: optional schema creation.
        if create_schema:
            tables = await create_database_schema(db_engine)
            summary["tables_created"] = tables
            logger.info(f"Created {len(tables)} tables")

        # Step 2: optional validation; problems are reported, not fatal.
        if validate_schema:
            validation = await validate_database_schema(db_engine)
            summary["validation_result"] = validation
            if not validation["valid"]:
                logger.warning(
                    f"Schema validation issues: {validation['issues']}"
                )

        # Step 3: optional seeding.
        if seed_data:
            await seed_initial_data(db_engine)
            logger.info("Initial data seeding complete")

        # Steps 4 and 5 always run.
        summary["schema_version"] = await get_schema_version(db_engine)
        summary["health_check"] = await check_database_health(db_engine)

        summary["success"] = True
        logger.info("Database initialization complete")
        return summary
    except Exception as e:
        logger.error(f"Database initialization failed: {e}", exc_info=True)
        raise RuntimeError(f"Failed to initialize database: {e}") from e
async def create_database_schema(
    engine: Optional[AsyncEngine] = None
) -> List[str]:
    """Create the tables declared on ``Base.metadata``.

    Runs ``Base.metadata.create_all`` inside a single transaction and logs
    which tables were not present beforehand. Intended to be safe to call
    repeatedly.

    Args:
        engine: Database engine to use; the default engine from
            ``get_engine()`` is used when omitted.

    Returns:
        Names of ALL tables present in the database after creation, not only
        the ones that were newly created by this call.

    Raises:
        RuntimeError: If schema creation fails; the original exception is
            chained.
    """
    if engine is None:
        engine = get_engine()

    logger.info("Creating database schema...")

    def _table_names(sync_conn):
        # Inspection has to run on the synchronous connection facade.
        return inspect(sync_conn).get_table_names()

    try:
        async with engine.begin() as conn:
            tables_before = await conn.run_sync(_table_names)
            await conn.run_sync(Base.metadata.create_all)
            tables_after = await conn.run_sync(_table_names)

        # Anything present now but absent before was created by this call.
        created_tables = [
            name for name in tables_after if name not in tables_before
        ]

        if not created_tables:
            logger.info("All tables already exist")
        else:
            logger.info(f"Created tables: {', '.join(created_tables)}")

        return tables_after
    except Exception as e:
        logger.error(f"Failed to create schema: {e}", exc_info=True)
        raise RuntimeError(f"Schema creation failed: {e}") from e
async def validate_database_schema(
    engine: Optional[AsyncEngine] = None
) -> Dict[str, Any]:
    """Validate database schema integrity.

    Checks that every table in ``EXPECTED_TABLES`` exists and that each
    existing expected table has the indexes listed for it in
    ``EXPECTED_INDEXES``. Columns are not inspected.

    Tables outside ``EXPECTED_TABLES`` are reported in ``extra_tables`` and
    ``issues`` but do not make the schema invalid. SQLite internal tables
    (``sqlite_*``) and this module's own ``SCHEMA_VERSION_TABLE`` are not
    reported as unexpected.

    All returned lists are sorted so that results and log messages are
    deterministic across runs.

    Args:
        engine: Database engine to use; the default engine from
            ``get_engine()`` is used when omitted.

    Returns:
        Dictionary with validation results containing:
            - valid: Whether schema is valid
            - missing_tables: Sorted list of missing tables
            - extra_tables: Sorted list of unexpected tables
            - missing_indexes: Dict of missing index names by table
            - issues: List of human-readable validation issues

        If inspection itself raises, the error is logged and a result with
        ``valid=False`` and a single "Validation error" issue is returned
        instead of propagating the exception.
    """
    if engine is None:
        engine = get_engine()

    logger.info("Validating database schema...")

    result: Dict[str, Any] = {
        "valid": True,
        "missing_tables": [],
        "extra_tables": [],
        "missing_indexes": {},
        "issues": [],
    }

    def _snapshot(sync_conn):
        # Build the inspector once and read everything in one sync hop
        # instead of creating a new inspector for every table.
        inspector = inspect(sync_conn)
        tables = set(inspector.get_table_names())
        indexes = {
            name: {idx["name"] for idx in inspector.get_indexes(name)}
            for name in EXPECTED_TABLES & tables
        }
        return tables, indexes

    try:
        async with engine.connect() as conn:
            existing_tables, indexes_by_table = await conn.run_sync(_snapshot)

        # Missing tables invalidate the schema.
        missing = sorted(EXPECTED_TABLES - existing_tables)
        if missing:
            result["missing_tables"] = missing
            result["valid"] = False
            result["issues"].append(
                f"Missing tables: {', '.join(missing)}"
            )

        # Extra tables are informational only. SQLite internals and the
        # module's own version-tracking table are not "unexpected".
        extra = sorted(
            name
            for name in existing_tables - EXPECTED_TABLES
            if not name.startswith("sqlite_")
            and name != SCHEMA_VERSION_TABLE
        )
        if extra:
            result["extra_tables"] = extra
            result["issues"].append(
                f"Unexpected tables: {', '.join(extra)}"
            )

        # Index check, in sorted table order for stable issue ordering.
        for table_name in sorted(indexes_by_table):
            existing_indexes = indexes_by_table[table_name]
            missing_indexes = [
                idx
                for idx in EXPECTED_INDEXES.get(table_name, [])
                if idx not in existing_indexes
            ]
            if missing_indexes:
                result["missing_indexes"][table_name] = missing_indexes
                result["valid"] = False
                result["issues"].append(
                    f"Missing indexes on {table_name}: "
                    f"{', '.join(missing_indexes)}"
                )

        if result["valid"]:
            logger.info("Schema validation passed")
        else:
            logger.warning(
                f"Schema validation issues found: {len(result['issues'])}"
            )

        return result
    except Exception as e:
        # Validation is a diagnostic: report the failure instead of raising.
        logger.error(f"Schema validation failed: {e}", exc_info=True)
        return {
            "valid": False,
            "missing_tables": [],
            "extra_tables": [],
            "missing_indexes": {},
            "issues": [f"Validation error: {str(e)}"],
        }
# =============================================================================
# Schema Version Management
# =============================================================================
async def get_schema_version(engine: Optional[AsyncEngine] = None) -> str:
    """Get current database schema version.

    The version is inferred from the set of table names in the database.
    SQLite internal tables (``sqlite_*``) and this module's own
    ``SCHEMA_VERSION_TABLE`` are ignored, so creating the version-tracking
    table does not turn a current schema into "unknown".

    Args:
        engine: Database engine to use; the default engine from
            ``get_engine()`` is used when omitted.

    Returns:
        One of:
            - ``CURRENT_SCHEMA_VERSION`` when the remaining tables are
              exactly ``EXPECTED_TABLES``
            - ``"empty"`` when no tables remain after filtering
            - ``"unknown"`` for any other set of tables
            - ``"error"`` when the database could not be inspected (the
              exception is logged, not raised)
    """
    if engine is None:
        engine = get_engine()

    try:
        async with engine.connect() as conn:
            tables = await conn.run_sync(
                lambda sync_conn: set(inspect(sync_conn).get_table_names())
            )

        # Drop bookkeeping tables that are not part of the application
        # schema before comparing against the expected set.
        tables = {
            name
            for name in tables
            if not name.startswith("sqlite_")
            and name != SCHEMA_VERSION_TABLE
        }

        if not tables:
            return "empty"
        if tables == EXPECTED_TABLES:
            return CURRENT_SCHEMA_VERSION
        return "unknown"
    except Exception as e:
        logger.error(f"Failed to get schema version: {e}")
        return "error"
async def create_schema_version_table(
    engine: Optional[AsyncEngine] = None
) -> None:
    """Create the schema version tracking table if it does not exist.

    Issues a ``CREATE TABLE IF NOT EXISTS`` for ``SCHEMA_VERSION_TABLE``
    inside a transaction. The table has three columns: ``version``
    (primary key), ``applied_at`` (defaults to the current timestamp) and
    ``description``.

    Args:
        engine: Database engine to use; the default engine from
            ``get_engine()`` is used when omitted.
    """
    db_engine = get_engine() if engine is None else engine

    # The table name comes from a module constant, not from user input.
    ddl = text(
        f"""
CREATE TABLE IF NOT EXISTS {SCHEMA_VERSION_TABLE} (
version VARCHAR(20) PRIMARY KEY,
applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
description TEXT
)
"""
    )

    async with db_engine.begin() as conn:
        await conn.execute(ddl)
# =============================================================================
# Initial Data Seeding
# =============================================================================
async def seed_initial_data(engine: Optional[AsyncEngine] = None) -> None:
    """Seed the database with initial data.

    Counts the rows in ``anime_series``. If any exist, seeding is skipped.
    If the table is empty, nothing is inserted either: there is currently
    no sample data, so the function only logs that fact.

    Args:
        engine: Database engine to use; the default engine from
            ``get_engine()`` is used when omitted.

    Raises:
        Exception: Any error raised while querying is logged and re-raised
            unchanged.
    """
    db_engine = get_engine() if engine is None else engine

    logger.info("Seeding initial data...")

    try:
        # Query through the engine directly so seeding does not depend on
        # the session factory.
        async with db_engine.connect() as conn:
            row_count = (
                await conn.execute(text("SELECT COUNT(*) FROM anime_series"))
            ).scalar()

            if row_count > 0:
                logger.info("Database already contains data, skipping seed")
            else:
                # Note: In production, you may want to skip this
                logger.info("Database is empty, but no sample data to seed")
                logger.info(
                    "Data will be populated via normal application usage"
                )
    except Exception as e:
        logger.error(f"Failed to seed initial data: {e}", exc_info=True)
        raise
# =============================================================================
# Database Health Check
# =============================================================================
async def check_database_health(
    engine: Optional[AsyncEngine] = None
) -> Dict[str, Any]:
    """Check database health and connectivity.

    Opens a connection, runs ``SELECT 1``, lists the tables and measures how
    long that took. The database is reported unhealthy when it holds fewer
    tables than ``EXPECTED_TABLES`` or when any expected table is missing by
    name.

    Args:
        engine: Database engine to use; the default engine from
            ``get_engine()`` is used when omitted.

    Returns:
        Dictionary with health check results containing:
            - healthy: Overall health status
            - accessible: Whether database is accessible
            - tables: Number of tables found
            - connectivity_ms: Elapsed time in milliseconds (at least 1)
            - issues: List of any health issues

        If connecting or querying raises, the error is logged and a result
        with ``healthy=False``, ``accessible=False`` and the error text as
        the only issue is returned instead of propagating the exception.
    """
    import time

    if engine is None:
        engine = get_engine()

    result: Dict[str, Any] = {
        "healthy": True,
        "accessible": False,
        "tables": 0,
        "connectivity_ms": 0,
        "issues": [],
    }

    try:
        # perf_counter is monotonic; the wall clock (time.time) can jump
        # and yield negative or wildly wrong durations.
        started = time.perf_counter()
        async with engine.connect() as conn:
            # Test basic query
            await conn.execute(text("SELECT 1"))
            # List tables for the count and presence checks
            tables = await conn.run_sync(
                lambda sync_conn: inspect(sync_conn).get_table_names()
            )
            elapsed = time.perf_counter() - started

        result["tables"] = len(tables)
        # Ensure at least 1ms for timing (avoid 0 for very fast operations)
        result["connectivity_ms"] = max(1, int(elapsed * 1000))
        result["accessible"] = True

        missing = sorted(EXPECTED_TABLES.difference(tables))
        if result["tables"] < len(EXPECTED_TABLES):
            result["healthy"] = False
            result["issues"].append(
                f"Expected {len(EXPECTED_TABLES)} tables, "
                f"found {result['tables']}"
            )
        elif missing:
            # Enough tables by count, but not the ones the application
            # needs (e.g. unrelated tables are present instead).
            result["healthy"] = False
            result["issues"].append(
                f"Missing expected tables: {', '.join(missing)}"
            )

        if result["healthy"]:
            logger.info(
                f"Database health check passed "
                f"(connectivity: {result['connectivity_ms']}ms)"
            )
        else:
            logger.warning(f"Database health issues: {result['issues']}")

        return result
    except Exception as e:
        logger.error(f"Database health check failed: {e}")
        return {
            "healthy": False,
            "accessible": False,
            "tables": 0,
            "connectivity_ms": 0,
            "issues": [str(e)],
        }
# =============================================================================
# Database Backup and Restore
# =============================================================================
async def create_database_backup(
    backup_path: Optional[Path] = None
) -> Path:
    """Create database backup.

    For SQLite databases, creates a copy of the database file.
    For other databases, this should be extended to use appropriate tools.

    The database file is located from ``settings.database_url``. Both plain
    (``sqlite:///path``) and driver-qualified (``sqlite+aiosqlite:///path``)
    URLs are understood, and a trailing ``?query`` part is ignored.

    Args:
        backup_path: Optional path for backup file
            (defaults to data/backups/aniworld_YYYYMMDD_HHMMSS.db). Missing
            parent directories are created.

    Returns:
        Path to created backup file

    Raises:
        NotImplementedError: If the configured database is not SQLite.
        RuntimeError: If the database file cannot be found or the copy
            fails.
    """
    import shutil

    # Get database path from settings
    db_url = settings.database_url
    if not db_url.startswith("sqlite"):
        raise NotImplementedError(
            "Backup currently only supported for SQLite databases"
        )

    # Everything after "<dialect>[+driver]://" is "/<path>"; drop the query
    # string and the single slash that separates the empty host from the
    # path. A second leading slash (absolute path) is preserved.
    _, _, location = db_url.partition("://")
    location = location.split("?", 1)[0]
    if location.startswith("/"):
        location = location[1:]

    db_path = Path(location)
    # is_file() also rejects in-memory databases and directories.
    if not db_path.is_file():
        raise RuntimeError(f"Database file not found: {db_path}")

    # Create backup path
    if backup_path is None:
        backup_dir = Path("data/backups")
        backup_dir.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = backup_dir / f"aniworld_{timestamp}.db"
    else:
        backup_path = Path(backup_path)

    try:
        logger.info(f"Creating database backup: {backup_path}")
        # A caller-supplied destination may point into a directory that
        # does not exist yet.
        backup_path.parent.mkdir(parents=True, exist_ok=True)
        # NOTE(review): a plain file copy of a database that is being
        # written to may not be a consistent snapshot; consider the sqlite3
        # online backup API if backups run while the app is live.
        shutil.copy2(db_path, backup_path)
        logger.info(f"Backup created successfully: {backup_path}")
        return backup_path
    except Exception as e:
        logger.error(f"Failed to create backup: {e}", exc_info=True)
        raise RuntimeError(f"Backup creation failed: {e}") from e
# =============================================================================
# Utility Functions
# =============================================================================
def get_database_info() -> Dict[str, Any]:
    """Get database configuration information.

    The database type is derived from the URL scheme (the part before the
    first ``:``, with any ``+driver`` suffix removed) rather than from a
    substring search, so a host, user or database name that happens to
    contain e.g. "sqlite" cannot cause a misclassification.

    Returns:
        Dictionary with database configuration details:
            - database_url: The configured URL, returned verbatim
            - database_type: "sqlite", "postgresql", "mysql" or "unknown"
            - schema_version: ``CURRENT_SCHEMA_VERSION``
            - expected_tables: Sorted list of ``EXPECTED_TABLES``
            - log_level: The configured log level
    """
    database_url = settings.database_url

    # "sqlite+aiosqlite:///x.db" -> "sqlite"
    dialect = database_url.split(":", 1)[0].split("+", 1)[0].lower()
    known_dialects = {
        "sqlite": "sqlite",
        "postgresql": "postgresql",
        "mysql": "mysql",
        # MariaDB is reported as "mysql".
        "mariadb": "mysql",
    }

    return {
        # NOTE(review): the URL is returned as-is and may embed credentials;
        # confirm callers do not expose it to untrusted clients.
        "database_url": database_url,
        "database_type": known_dialects.get(dialect, "unknown"),
        "schema_version": CURRENT_SCHEMA_VERSION,
        "expected_tables": sorted(EXPECTED_TABLES),
        "log_level": settings.log_level,
    }
# =============================================================================
# Public API
# =============================================================================
# Names exported by ``from <module> import *``.
__all__ = [
    # Initialization, schema creation and validation
    "initialize_database",
    "create_database_schema",
    "validate_database_schema",
    # Schema version management
    "get_schema_version",
    "create_schema_version_table",
    # Seeding, health check and backup
    "seed_initial_data",
    "check_database_health",
    "create_database_backup",
    # Utilities
    "get_database_info",
    # Constants
    "CURRENT_SCHEMA_VERSION",
    "EXPECTED_TABLES",
]