"""
Migration runner for executing database migrations.

This module handles the execution of migrations in the correct order,
tracks migration history, and provides rollback capabilities.
"""

import importlib.util
import logging
import time
from datetime import datetime
from pathlib import Path
from typing import List, Optional

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from .base import Migration, MigrationError, MigrationHistory

logger = logging.getLogger(__name__)


class MigrationRunner:
    """
    Manages database migration execution and tracking.

    This class handles loading migrations, executing them in order,
    tracking their status in the ``migration_history`` table, and rolling
    back when needed.
    """

    def __init__(self, migrations_dir: Path, session: AsyncSession):
        """
        Initialize migration runner.

        Args:
            migrations_dir: Directory containing migration files
            session: Database session for executing migrations
        """
        self.migrations_dir = migrations_dir
        self.session = session
        self._migrations: List[Migration] = []

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since a ``time.monotonic()`` reading."""
        # monotonic() is immune to wall-clock adjustments, so durations never go negative.
        return int((time.monotonic() - start) * 1000)

    async def _safe_rollback(self) -> None:
        """Roll back the session, logging (not raising) any rollback failure.

        Used on error paths so that a failing rollback does not mask the
        original error that callers are about to receive as MigrationError.
        """
        try:
            await self.session.rollback()
        except Exception as e:
            logger.error(f"Session rollback failed: {e}")

    async def _write_history(
        self,
        migration: Migration,
        execution_time_ms: int,
        success: bool,
        error_message: Optional[str],
    ) -> None:
        """Stage a ``migration_history`` row for *migration* on the session.

        Does NOT commit; the caller decides the transaction boundary.

        Any earlier *failed* row for the same version is deleted first. The
        ``version`` column is UNIQUE, so without this a successful retry after
        a failed attempt could never be recorded. Successful rows are never
        deleted here.
        """
        history_record = MigrationHistory(
            version=migration.version,
            description=migration.description,
            applied_at=datetime.now(),
            execution_time_ms=execution_time_ms,
            success=success,
            error_message=error_message,
        )
        await self.session.execute(
            text("DELETE FROM migration_history WHERE version = :version AND success = 0"),
            {"version": history_record.version},
        )
        await self.session.execute(
            text(
                "INSERT INTO migration_history "
                "(version, description, applied_at, execution_time_ms, success, error_message) "
                "VALUES (:version, :description, :applied_at, :execution_time_ms, "
                ":success, :error_message)"
            ),
            {
                "version": history_record.version,
                "description": history_record.description,
                "applied_at": history_record.applied_at,
                "execution_time_ms": history_record.execution_time_ms,
                "success": history_record.success,
                "error_message": history_record.error_message,
            },
        )

    async def _record_failure(
        self, migration: Migration, execution_time_ms: int, error_message: str
    ) -> None:
        """Best-effort: persist a failed-attempt history row in its own transaction.

        Errors are logged and the session is rolled back; nothing is raised,
        so the caller can still raise the original migration failure.
        """
        try:
            await self._write_history(
                migration, execution_time_ms, success=False, error_message=error_message
            )
            await self.session.commit()
        except Exception as e:
            logger.error(f"Failed to record migration history: {e}")
            await self._safe_rollback()

    @staticmethod
    def _load_migration_file(file_path: Path) -> Optional[Migration]:
        """Import one migration file and instantiate its Migration subclass.

        Returns:
            An instance of the first Migration subclass found in the module
            (scanning in ``dir()`` order), or None if the module defines none.

        Raises:
            MigrationError: If importing or instantiating the module fails
        """
        try:
            spec = importlib.util.spec_from_file_location(
                f"migration.{file_path.stem}", file_path
            )
            if spec is None or spec.loader is None:
                return None
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)

            # NOTE(review): dir() is alphabetical, so a module that also imports
            # another Migration subclass could pick that one instead; confirm
            # migration files define exactly one subclass.
            for attr_name in dir(module):
                attr = getattr(module, attr_name)
                if (
                    isinstance(attr, type)
                    and issubclass(attr, Migration)
                    and attr is not Migration
                ):
                    return attr()
        except Exception as e:
            logger.error(f"Failed to load migration {file_path.name}: {e}")
            raise MigrationError(f"Failed to load {file_path.name}: {e}") from e

        logger.warning(f"No Migration subclass found in {file_path.name}")
        return None

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #

    async def initialize(self) -> None:
        """
        Initialize migration system by creating tracking table if needed.

        The DDL uses SQLite syntax (``INTEGER PRIMARY KEY AUTOINCREMENT``).

        Raises:
            MigrationError: If initialization fails
        """
        create_table_sql = """
            CREATE TABLE IF NOT EXISTS migration_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                version TEXT NOT NULL UNIQUE,
                description TEXT NOT NULL,
                applied_at TIMESTAMP NOT NULL,
                execution_time_ms INTEGER NOT NULL,
                success BOOLEAN NOT NULL DEFAULT 1,
                error_message TEXT
            )
        """
        try:
            await self.session.execute(text(create_table_sql))
            await self.session.commit()
            logger.info("Migration system initialized")
        except Exception as e:
            logger.error(f"Failed to initialize migration system: {e}")
            await self._safe_rollback()
            raise MigrationError(f"Initialization failed: {e}") from e

    def load_migrations(self) -> None:
        """
        Load all migration files from the migrations directory.

        Migration files should be named in format: {version}_{description}.py
        and contain a Migration class that inherits from base.Migration.
        ``__init__.py`` is ignored. Loaded migrations are sorted by version.
        On any error, the previously loaded set is cleared rather than left
        partially populated.

        Raises:
            MigrationError: If a file fails to load or two migrations share a version
        """
        self._migrations.clear()

        if not self.migrations_dir.exists():
            logger.warning(f"Migrations directory does not exist: {self.migrations_dir}")
            return

        try:
            migration_files = [
                path
                for path in sorted(self.migrations_dir.glob("*.py"))
                if path.name != "__init__.py"
            ]

            loaded: List[Migration] = []
            seen_versions = {}
            for file_path in migration_files:
                migration = self._load_migration_file(file_path)
                if migration is None:
                    continue
                # Duplicate versions would collide on the UNIQUE history column
                # and make pending/applied detection ambiguous.
                if migration.version in seen_versions:
                    raise MigrationError(
                        f"Duplicate migration version {migration.version} in "
                        f"{seen_versions[migration.version]} and {file_path.name}"
                    )
                seen_versions[migration.version] = file_path.name
                loaded.append(migration)
                logger.debug(f"Loaded migration: {migration.version}")
        except MigrationError:
            # Already logged/wrapped by the helper; don't wrap a second time.
            raise
        except Exception as e:
            logger.error(f"Failed to load migrations: {e}")
            raise MigrationError(f"Loading migrations failed: {e}") from e

        # Populate only after every file loaded successfully.
        self._migrations.extend(sorted(loaded, key=lambda m: m.version))
        logger.info(f"Loaded {len(self._migrations)} migrations")

    async def get_applied_migrations(self) -> List[str]:
        """
        Get list of already applied migration versions.

        Only rows with ``success = 1`` count as applied.

        Returns:
            List of migration versions that have been applied, ordered by version

        Raises:
            MigrationError: If query fails
        """
        try:
            result = await self.session.execute(
                text("SELECT version FROM migration_history WHERE success = 1 ORDER BY version")
            )
            return [row[0] for row in result.fetchall()]
        except Exception as e:
            logger.error(f"Failed to get applied migrations: {e}")
            raise MigrationError(f"Query failed: {e}") from e

    async def get_pending_migrations(self) -> List[Migration]:
        """
        Get list of migrations that haven't been applied yet.

        Returns:
            Loaded migrations (in version order) whose version is not applied

        Raises:
            MigrationError: If check fails
        """
        applied = set(await self.get_applied_migrations())
        return [m for m in self._migrations if m.version not in applied]

    async def apply_migration(self, migration: Migration) -> None:
        """
        Apply a single migration.

        On success, the migration's changes and its history row are committed
        in a single transaction. On failure, the session is rolled back, a
        failed-attempt history row is recorded (best-effort), and
        MigrationError is raised.

        Args:
            migration: Migration to apply

        Raises:
            MigrationError: If migration fails
        """
        start_time = time.monotonic()
        try:
            logger.info(f"Applying migration: {migration.version} - {migration.description}")

            await migration.upgrade(self.session)
            execution_time_ms = self._elapsed_ms(start_time)

            # Same transaction as the upgrade: a migration is never committed
            # without its history row (which would make it re-run next time).
            await self._write_history(
                migration, execution_time_ms, success=True, error_message=None
            )
            await self.session.commit()
        except Exception as e:
            execution_time_ms = self._elapsed_ms(start_time)
            logger.error(f"Migration {migration.version} failed: {e}")
            await self._safe_rollback()
            await self._record_failure(migration, execution_time_ms, str(e))
            raise MigrationError(f"Migration {migration.version} failed: {e}") from e

        logger.info(
            f"Migration {migration.version} applied successfully in {execution_time_ms}ms"
        )

    async def run_migrations(self, target_version: Optional[str] = None) -> int:
        """
        Run all pending migrations up to target version.

        Versions are compared as strings, matching the sort order used when loading.

        Args:
            target_version: Stop at this version, inclusive (None = run all)

        Returns:
            Number of migrations applied

        Raises:
            MigrationError: If migrations fail
        """
        pending = await self.get_pending_migrations()

        if target_version:
            pending = [m for m in pending if m.version <= target_version]

        if not pending:
            logger.info("No pending migrations to apply")
            return 0

        logger.info(f"Applying {len(pending)} pending migrations")

        for migration in pending:
            await self.apply_migration(migration)

        return len(pending)

    async def rollback_migration(self, migration: Migration) -> None:
        """
        Rollback a single migration.

        The downgrade and the removal of the migration's history row are
        committed together in one transaction.

        Args:
            migration: Migration to rollback

        Raises:
            MigrationError: If rollback fails
        """
        start_time = time.monotonic()
        try:
            logger.info(f"Rolling back migration: {migration.version}")

            await migration.downgrade(self.session)
            await self.session.execute(
                text("DELETE FROM migration_history WHERE version = :version"),
                {"version": migration.version},
            )
            await self.session.commit()
            execution_time_ms = self._elapsed_ms(start_time)
        except Exception as e:
            logger.error(f"Rollback of {migration.version} failed: {e}")
            await self._safe_rollback()
            raise MigrationError(f"Rollback of {migration.version} failed: {e}") from e

        logger.info(
            f"Migration {migration.version} rolled back successfully in {execution_time_ms}ms"
        )

    async def rollback(self, steps: int = 1) -> int:
        """
        Rollback the last N applied migrations, newest first.

        Args:
            steps: Number of migrations to rollback; values <= 0 roll back nothing

        Returns:
            Number of migrations rolled back

        Raises:
            MigrationError: If rollback fails, or if any of the targeted applied
                versions has no loaded migration (skipping it would leave the
                schema in an out-of-order state)
        """
        if steps <= 0:
            # Guard explicitly: applied[-0:] would select *every* migration.
            logger.info("No migrations to rollback")
            return 0

        applied = await self.get_applied_migrations()

        if not applied:
            logger.info("No migrations to rollback")
            return 0

        # Newest first: applied is ordered ascending by version.
        to_rollback = list(reversed(applied[-steps:]))

        by_version = {m.version: m for m in self._migrations}
        missing = [version for version in to_rollback if version not in by_version]
        if missing:
            raise MigrationError(
                f"Cannot rollback: no loaded migration for version(s) {', '.join(missing)}"
            )

        migrations_to_rollback = [by_version[version] for version in to_rollback]

        logger.info(f"Rolling back {len(migrations_to_rollback)} migrations")

        for migration in migrations_to_rollback:
            await self.rollback_migration(migration)

        return len(migrations_to_rollback)