"""Database initialization and setup module.

This module provides comprehensive database initialization functionality:
- Schema creation and validation
- Database health checks
- Schema versioning support
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from sqlalchemy import inspect, text
|
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
|
|
|
from src.config.settings import settings
|
|
from src.server.database.base import Base
|
|
from src.server.database.connection import get_engine
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# =============================================================================
|
|
# Schema Version Constants
|
|
# =============================================================================
|
|
|
|
# Version string reported when the database holds exactly the expected tables.
CURRENT_SCHEMA_VERSION = "1.0.0"

# Name of the table used for schema version tracking.
SCHEMA_VERSION_TABLE = "schema_version"

# Expected tables in the current schema
EXPECTED_TABLES = {
    "anime_series",
    "episodes",
    "download_queue",
    "user_sessions",
}

# Expected indexes for performance, keyed by the table they belong to
EXPECTED_INDEXES = {
    "anime_series": ["ix_anime_series_key", "ix_anime_series_name"],
    "episodes": ["ix_episodes_series_id"],
    "download_queue": [
        "ix_download_queue_series_id",
        "ix_download_queue_episode_id",
    ],
    "user_sessions": [
        "ix_user_sessions_session_id",
        "ix_user_sessions_user_id",
        "ix_user_sessions_is_active",
    ],
}
|
|
|
|
|
|
# =============================================================================
|
|
# Database Initialization
|
|
# =============================================================================
|
|
|
|
|
|
async def initialize_database(
    engine: Optional[AsyncEngine] = None,
    create_schema: bool = True,
    validate_schema: bool = True,
    seed_data: bool = False,
) -> Dict[str, Any]:
    """Run the full database initialization sequence.

    The steps run in a fixed order, each optional one guarded by its flag:

    1. Schema creation (``create_schema``)
    2. Schema validation (``validate_schema``)
    3. Initial data seeding (``seed_data``)
    4. Schema version lookup (always)
    5. Health check (always)

    A failed validation is only logged as a warning; it does not abort
    initialization or clear the ``success`` flag.

    Args:
        engine: Optional database engine (uses default if not provided)
        create_schema: Whether to create database schema
        validate_schema: Whether to validate schema after creation
        seed_data: Whether to seed initial data

    Returns:
        Dictionary with initialization results containing:
        - success: True once every requested step finished without raising
        - schema_version: Value returned by get_schema_version
        - tables_created: Table names returned by create_database_schema
          (empty list when create_schema is False)
        - validation_result: Result of validate_database_schema, or None
          when validation was not requested
        - health_check: Result of check_database_health

    Raises:
        RuntimeError: If any step raises; the original exception is chained

    Example:
        result = await initialize_database(
            create_schema=True,
            validate_schema=True,
            seed_data=True
        )
        if result["success"]:
            logger.info(f"Database initialized: {result['schema_version']}")
    """
    db_engine = get_engine() if engine is None else engine

    logger.info("Starting database initialization...")
    summary: Dict[str, Any] = {
        "success": False,
        "schema_version": None,
        "tables_created": [],
        "validation_result": None,
        "health_check": None,
    }

    try:
        # Step 1: schema creation
        if create_schema:
            table_names = await create_database_schema(db_engine)
            summary["tables_created"] = table_names
            logger.info(f"Created {len(table_names)} tables")

        # Step 2: schema validation (problems are reported, not fatal)
        if validate_schema:
            report = await validate_database_schema(db_engine)
            summary["validation_result"] = report
            if not report["valid"]:
                logger.warning(
                    f"Schema validation issues: {report['issues']}"
                )

        # Step 3: initial data
        if seed_data:
            await seed_initial_data(db_engine)
            logger.info("Initial data seeding complete")

        # Steps 4 and 5 always run
        summary["schema_version"] = await get_schema_version(db_engine)
        summary["health_check"] = await check_database_health(db_engine)

        summary["success"] = True
        logger.info("Database initialization complete")
        return summary

    except Exception as exc:
        logger.error(f"Database initialization failed: {exc}", exc_info=True)
        raise RuntimeError(f"Failed to initialize database: {exc}") from exc
|
|
|
|
|
|
async def create_database_schema(
    engine: Optional[AsyncEngine] = None
) -> List[str]:
    """Create every table registered on ``Base.metadata``.

    Table names are read before and after ``Base.metadata.create_all`` inside
    a single ``engine.begin()`` transaction. The difference between the two
    listings is only used for logging.

    Args:
        engine: Optional database engine (uses default if not provided)

    Returns:
        Names of all tables present after creation. This includes tables
        that already existed before the call, not only the newly created
        ones.

    Raises:
        RuntimeError: If schema creation fails
    """
    db_engine = get_engine() if engine is None else engine

    logger.info("Creating database schema...")

    def _table_names(sync_conn):
        # Inspection needs a synchronous connection, hence run_sync below.
        return inspect(sync_conn).get_table_names()

    try:
        async with db_engine.begin() as conn:
            tables_before = await conn.run_sync(_table_names)
            await conn.run_sync(Base.metadata.create_all)
            tables_after = await conn.run_sync(_table_names)

            # Work out what this call actually added, for the log message.
            newly_created = [
                name for name in tables_after if name not in tables_before
            ]

            if not newly_created:
                logger.info("All tables already exist")
            else:
                logger.info(f"Created tables: {', '.join(newly_created)}")

            return tables_after

    except Exception as exc:
        logger.error(f"Failed to create schema: {exc}", exc_info=True)
        raise RuntimeError(f"Schema creation failed: {exc}") from exc
|
|
|
|
|
|
async def validate_database_schema(
    engine: Optional[AsyncEngine] = None
) -> Dict[str, Any]:
    """Validate database schema integrity.

    Compares the tables and indexes found in the database against
    EXPECTED_TABLES and EXPECTED_INDEXES. Missing tables or indexes mark the
    schema as invalid. Unexpected tables are reported as an issue but do not
    affect validity. All reported names are sorted so the result and log
    output are deterministic.

    Args:
        engine: Optional database engine (uses default if not provided)

    Returns:
        Dictionary with validation results containing:
        - valid: Whether schema is valid
        - missing_tables: Sorted list of missing tables
        - extra_tables: Sorted list of unexpected tables (SQLite internal
          tables and the schema version tracking table are ignored)
        - missing_indexes: Dict of missing indexes by table
        - issues: List of validation issues

        If inspection itself raises, the error is logged and a result with
        ``valid`` set to False and the error text in ``issues`` is returned
        instead of propagating the exception.
    """
    if engine is None:
        engine = get_engine()

    logger.info("Validating database schema...")

    result: Dict[str, Any] = {
        "valid": True,
        "missing_tables": [],
        "extra_tables": [],
        "missing_indexes": {},
        "issues": [],
    }

    try:
        async with engine.connect() as conn:
            # Get existing tables
            existing_tables = await conn.run_sync(
                lambda sync_conn: set(inspect(sync_conn).get_table_names())
            )

            # Check for missing tables (sorted: set order is not stable)
            missing = sorted(EXPECTED_TABLES - existing_tables)
            if missing:
                result["missing_tables"] = missing
                result["valid"] = False
                result["issues"].append(
                    f"Missing tables: {', '.join(missing)}"
                )

            # Check for extra tables. SQLite internal tables and this
            # module's own version tracking table are not "unexpected".
            extra = sorted(
                name
                for name in existing_tables - EXPECTED_TABLES
                if name != SCHEMA_VERSION_TABLE
                and not name.startswith("sqlite_")
            )
            if extra:
                result["extra_tables"] = extra
                result["issues"].append(
                    f"Unexpected tables: {', '.join(extra)}"
                )

            # Check indexes for each expected table that exists
            for table_name in sorted(EXPECTED_TABLES & existing_tables):
                # Bind the table name as a default so the callable does not
                # depend on the loop variable's later value.
                existing_indexes = await conn.run_sync(
                    lambda sync_conn, table=table_name: {
                        idx["name"]
                        for idx in inspect(sync_conn).get_indexes(table)
                    }
                )

                missing_indexes = [
                    idx
                    for idx in EXPECTED_INDEXES.get(table_name, [])
                    if idx not in existing_indexes
                ]

                if missing_indexes:
                    result["missing_indexes"][table_name] = missing_indexes
                    result["valid"] = False
                    result["issues"].append(
                        f"Missing indexes on {table_name}: "
                        f"{', '.join(missing_indexes)}"
                    )

        if result["valid"]:
            logger.info("Schema validation passed")
        else:
            logger.warning(
                f"Schema validation issues found: {len(result['issues'])}"
            )

        return result

    except Exception as e:
        logger.error(f"Schema validation failed: {e}", exc_info=True)
        return {
            "valid": False,
            "missing_tables": [],
            "extra_tables": [],
            "missing_indexes": {},
            "issues": [f"Validation error: {str(e)}"],
        }
|
|
|
|
|
|
# =============================================================================
|
|
# Schema Version Management
|
|
# =============================================================================
|
|
|
|
|
|
async def get_schema_version(engine: Optional[AsyncEngine] = None) -> str:
    """Get current database schema version.

    The version is inferred from the set of table names in the database.
    SQLite internal tables and the schema version tracking table
    (SCHEMA_VERSION_TABLE) are ignored, so creating the tracking table does
    not turn an otherwise current schema into "unknown".

    Args:
        engine: Optional database engine (uses default if not provided)

    Returns:
        One of:
        - "empty": no application tables exist
        - CURRENT_SCHEMA_VERSION: the tables match EXPECTED_TABLES exactly
        - "unknown": any other combination of tables
        - "error": inspecting the database raised (the error is logged)
    """
    if engine is None:
        engine = get_engine()

    try:
        async with engine.connect() as conn:
            # Get existing tables
            tables = await conn.run_sync(
                lambda sync_conn: set(inspect(sync_conn).get_table_names())
            )

        # Drop bookkeeping tables that are not part of the application schema
        tables = {
            name
            for name in tables
            if name != SCHEMA_VERSION_TABLE
            and not name.startswith("sqlite_")
        }

        if not tables:
            return "empty"
        if tables == EXPECTED_TABLES:
            return CURRENT_SCHEMA_VERSION
        return "unknown"

    except Exception as e:
        logger.error(f"Failed to get schema version: {e}")
        return "error"
|
|
|
|
|
|
async def create_schema_version_table(
    engine: Optional[AsyncEngine] = None
) -> None:
    """Create the schema version tracking table if it does not exist.

    The table is named by SCHEMA_VERSION_TABLE and has three columns:
    ``version`` (primary key), ``applied_at`` (defaults to the current
    timestamp) and ``description``. The statement runs inside an
    ``engine.begin()`` transaction.

    Args:
        engine: Optional database engine (uses default if not provided)
    """
    db_engine = get_engine() if engine is None else engine

    # The table name comes from a module constant, not from user input.
    ddl = text(
        f"""
                CREATE TABLE IF NOT EXISTS {SCHEMA_VERSION_TABLE} (
                    version VARCHAR(20) PRIMARY KEY,
                    applied_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
                    description TEXT
                )
                """
    )

    async with db_engine.begin() as conn:
        await conn.execute(ddl)
|
|
|
|
|
|
# =============================================================================
|
|
# Initial Data Seeding
|
|
# =============================================================================
|
|
|
|
|
|
async def seed_initial_data(engine: Optional[AsyncEngine] = None) -> None:
    """Check whether the database needs seeding and log the outcome.

    Counts the rows in ``anime_series``. If any exist, seeding is skipped.
    If the table is empty, nothing is inserted either: there is currently no
    sample data, so the function only logs that data will arrive through
    normal application usage. It is therefore safe to call repeatedly.

    Args:
        engine: Optional database engine (uses default if not provided)

    Raises:
        Exception: Any error raised while querying is logged and re-raised
            unchanged.
    """
    db_engine = get_engine() if engine is None else engine

    logger.info("Seeding initial data...")

    try:
        # Query through the engine directly so no session factory is needed.
        async with db_engine.connect() as conn:
            count_result = await conn.execute(
                text("SELECT COUNT(*) FROM anime_series")
            )
            existing_rows = count_result.scalar()

            if existing_rows > 0:
                logger.info("Database already contains data, skipping seed")
                return

            # Empty database: there is no sample data to insert.
            logger.info("Database is empty, but no sample data to seed")
            logger.info("Data will be populated via normal application usage")

    except Exception as exc:
        logger.error(f"Failed to seed initial data: {exc}", exc_info=True)
        raise
|
|
|
|
|
|
# =============================================================================
|
|
# Database Health Check
|
|
# =============================================================================
|
|
|
|
|
|
async def check_database_health(
    engine: Optional[AsyncEngine] = None
) -> Dict[str, Any]:
    """Check database health and connectivity.

    Opens a connection, runs ``SELECT 1``, lists the tables and measures how
    long that took. The database is reported unhealthy when it holds fewer
    tables than EXPECTED_TABLES has entries.

    Args:
        engine: Optional database engine (uses default if not provided)

    Returns:
        Dictionary with health check results containing:
        - healthy: Overall health status
        - accessible: Whether database is accessible
        - tables: Number of tables
        - connectivity_ms: Elapsed time in whole milliseconds (at least 1
          on success, 0 on failure)
        - issues: List of any health issues

        Errors are not raised: if connecting or querying fails, the error is
        logged and a result with ``healthy`` and ``accessible`` set to False
        is returned, with the error text in ``issues``.
    """
    import time

    if engine is None:
        engine = get_engine()

    result: Dict[str, Any] = {
        "healthy": True,
        "accessible": False,
        "tables": 0,
        "connectivity_ms": 0,
        "issues": [],
    }

    try:
        # Use a monotonic clock: wall-clock time can jump (NTP, DST) and
        # would then yield a wrong or even negative duration.
        start_time = time.perf_counter()

        async with engine.connect() as conn:
            # Test basic query
            await conn.execute(text("SELECT 1"))

            # Get table count
            tables = await conn.run_sync(
                lambda sync_conn: inspect(sync_conn).get_table_names()
            )
            result["tables"] = len(tables)

        elapsed = time.perf_counter() - start_time
        # Ensure at least 1ms for timing (avoid 0 for very fast operations)
        result["connectivity_ms"] = max(1, int(elapsed * 1000))
        result["accessible"] = True

        # Check for expected tables (by count only)
        if result["tables"] < len(EXPECTED_TABLES):
            result["healthy"] = False
            result["issues"].append(
                f"Expected {len(EXPECTED_TABLES)} tables, "
                f"found {result['tables']}"
            )

        if result["healthy"]:
            logger.info(
                f"Database health check passed "
                f"(connectivity: {result['connectivity_ms']}ms)"
            )
        else:
            logger.warning(f"Database health issues: {result['issues']}")

        return result

    except Exception as e:
        logger.error(f"Database health check failed: {e}")
        return {
            "healthy": False,
            "accessible": False,
            "tables": 0,
            "connectivity_ms": 0,
            "issues": [str(e)],
        }
|
|
|
|
|
|
# =============================================================================
|
|
# Database Backup and Restore
|
|
# =============================================================================
|
|
|
|
|
|
async def create_database_backup(
    backup_path: Optional[Path] = None
) -> Path:
    """Create database backup.

    For SQLite databases, creates a copy of the database file named in
    ``settings.database_url``. Both plain (``sqlite:///...``) and
    driver-qualified (e.g. ``sqlite+aiosqlite:///...``) URLs are accepted,
    and any ``?query`` suffix is ignored when locating the file.
    For other databases, this should be extended to use appropriate tools.

    Args:
        backup_path: Optional path for backup file
            (defaults to data/backups/aniworld_YYYYMMDD_HHMMSS.db)

    Returns:
        Path to created backup file

    Raises:
        NotImplementedError: If the configured database is not SQLite
        RuntimeError: If the database file cannot be found or the copy fails
    """
    import shutil

    # Get database path from settings
    db_url = settings.database_url

    if not db_url.startswith("sqlite"):
        raise NotImplementedError(
            "Backup currently only supported for SQLite databases"
        )

    # Everything after the first ":///" is the file location, whatever the
    # driver suffix on the scheme. A plain replace("sqlite:///", "") would
    # leave "sqlite+aiosqlite:///..." URLs untouched.
    _scheme, _separator, location = db_url.partition(":///")
    db_path = Path(location.split("?", 1)[0])

    # is_file() also rejects an empty location (which resolves to the
    # current directory) and in-memory databases.
    if not db_path.is_file():
        raise RuntimeError(f"Database file not found: {db_path}")

    # Create backup path
    if backup_path is None:
        backup_dir = Path("data/backups")
        backup_dir.mkdir(parents=True, exist_ok=True)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_path = backup_dir / f"aniworld_{timestamp}.db"

    try:
        logger.info(f"Creating database backup: {backup_path}")
        # NOTE(review): this is a plain file copy; presumably the database
        # should not be written to while it runs - confirm with callers.
        shutil.copy2(db_path, backup_path)
        logger.info(f"Backup created successfully: {backup_path}")
        return backup_path

    except Exception as e:
        logger.error(f"Failed to create backup: {e}", exc_info=True)
        raise RuntimeError(f"Backup creation failed: {e}") from e
|
|
|
|
|
|
# =============================================================================
|
|
# Utility Functions
|
|
# =============================================================================
|
|
|
|
|
|
def get_database_info() -> Dict[str, Any]:
    """Summarize the database configuration.

    The database type is the first of "sqlite", "postgresql" or "mysql" that
    occurs anywhere in ``settings.database_url``; "unknown" if none does.

    Returns:
        Dictionary with the keys ``database_url``, ``database_type``,
        ``schema_version``, ``expected_tables`` and ``log_level``.
    """
    url = settings.database_url

    # Order matters: the first matching name wins.
    known_types = ("sqlite", "postgresql", "mysql")
    database_type = next(
        (name for name in known_types if name in url),
        "unknown",
    )

    # NOTE(review): the URL is returned verbatim; if it can embed
    # credentials they are exposed to callers - confirm this is intended.
    return {
        "database_url": url,
        "database_type": database_type,
        "schema_version": CURRENT_SCHEMA_VERSION,
        "expected_tables": list(EXPECTED_TABLES),
        "log_level": settings.log_level,
    }
|
|
|
|
|
|
# =============================================================================
|
|
# Public API
|
|
# =============================================================================
|
|
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = [
    # Initialization and schema management
    "initialize_database",
    "create_database_schema",
    "validate_database_schema",
    "get_schema_version",
    "create_schema_version_table",
    # Seeding, health and backup
    "seed_initial_data",
    "check_database_health",
    "create_database_backup",
    "get_database_info",
    # Constants
    "CURRENT_SCHEMA_VERSION",
    "EXPECTED_TABLES",
]
|