diff --git a/Docs/Backend-Development.md b/Docs/Backend-Development.md index 7cf9388..e567e31 100644 --- a/Docs/Backend-Development.md +++ b/Docs/Backend-Development.md @@ -364,6 +364,91 @@ assert escape_like("10.0.0.1") == "10.0.0.1" # Unchanged --- +## 6.2 Database Migrations + +The application database schema is versioned and migrated automatically on startup via `app.db.init_db()`. + +### Migration Design Principles + +**Migrations must be atomic.** All schema changes for a single version (DDL statements) and the `schema_migrations` record insert must be wrapped in a single `BEGIN IMMEDIATE ... COMMIT` transaction. This prevents partial migrations if a process crashes mid-migration. + +If a crash occurs between migration steps, the next startup will: +1. Detect the missing `schema_migrations` record. +2. Re-apply the entire migration in a single transaction (all-or-nothing). +3. Avoid data corruption or schema inconsistency. + +### Writing a New Migration + +1. **Add the DDL statements** to `_MIGRATIONS` dict in `app/db.py`: + +```python +_MIGRATIONS: dict[int, str] = { + 1: _CREATE_INITIAL_SCHEMA, + 2: """ +-- Migration 2: Add new_column to users table. +ALTER TABLE users ADD COLUMN new_column TEXT DEFAULT 'default_value'; +CREATE INDEX idx_users_new_column ON users(new_column); +""", +} +``` + +2. **Update `_CURRENT_SCHEMA_VERSION`** to the new version number: + +```python +_CURRENT_SCHEMA_VERSION: int = 2 # was 1 +``` + +3. **Ensure idempotency where possible:** + - Use `CREATE TABLE IF NOT EXISTS` and `CREATE INDEX IF NOT EXISTS`. + - For `ALTER TABLE ADD COLUMN`, check if the column exists first using `PRAGMA table_info()` if re-applying the migration is a concern. + +4. 
**Verify atomicity in tests:** + +```python +async def test_migration_2_is_atomic(tmp_path: Path) -> None: + """Verify migration 2 rolls back on failure.""" + db = await open_db(str(tmp_path / "test.db")) + try: + await db.execute("CREATE TABLE schema_migrations (version INTEGER PRIMARY KEY);") + await db.commit() + + # Add a test migration that fails mid-way + original = db_module._MIGRATIONS.copy() + db_module._MIGRATIONS[99] = """ + CREATE TABLE test_table (id INTEGER PRIMARY KEY); + INSERT INTO nonexistent_table VALUES (1); + """ + + try: + with pytest.raises(Exception): + await _apply_migration(db, 99) + + # Verify rollback: migration NOT recorded + async with db.execute( + "SELECT version FROM schema_migrations WHERE version = 99;" + ) as cursor: + assert await cursor.fetchone() is None + + # Verify rollback: table NOT created + async with db.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name='test_table';" + ) as cursor: + assert await cursor.fetchone() is None + finally: + db_module._MIGRATIONS = original + finally: + await db.close() +``` + +### Common Pitfalls + +- **Non-idempotent statements** — SQLite's `ALTER TABLE ... ADD COLUMN` has no `IF NOT EXISTS` clause, so it fails with a "duplicate column name" error on re-run. Check `PRAGMA table_info()` first if the statement may be re-applied. +- **Comments containing semicolons** — the migration parser strips comments correctly, but avoid unusual comment syntax. +- **String literals with semicolons** — the parser handles these; no special escaping needed. +- **Multiple operations in one migration** — keep migrations focused. Combine related DDL but split unrelated changes. + +--- + ## 7. Logging - Use **structlog** for every log message. 
async def _parse_migration_statements(script: str) -> list[str]:
    """Split a migration script into individual, executable SQL statements.

    The script is scanned once:

    * ``--`` line comments and ``/* ... */`` block comments are removed.  A
      comment still separates the tokens around it, as it does in SQL, so
      ``SELECT a--note`` followed by ``FROM t`` never collapses to ``aFROM``.
      An unterminated block comment runs to the end of the script.
    * Quoted text (``'...'``, ``"..."``, `` `...` `` and ``[...]``) is copied
      through verbatim, so semicolons and comment markers inside it are
      ignored.
    * The remaining text is cut at semicolons, and
      ``sqlite3.complete_statement`` decides whether a cut really ends a
      statement.  This keeps the semicolons inside a
      ``CREATE TRIGGER ... BEGIN ... END`` body from splitting the trigger.

    The function performs no I/O; it is a coroutine only because its callers
    ``await`` it.

    Args:
        script: The raw migration script.

    Returns:
        The statements in source order, stripped of surrounding whitespace
        and without their terminating semicolon.  Empty statements are
        dropped.  Trailing text that never forms a complete statement is
        returned as a final entry so that SQLite reports the error when the
        statement is executed.
    """
    # Function-scope import: only this helper needs the stdlib module.
    import sqlite3

    # Pass 1: strip comments and cut at semicolons that are outside quotes.
    fragments: list[str] = []
    buffer: list[str] = []
    pos = 0
    end = len(script)

    while pos < end:
        char = script[pos]

        # Line comment: drop everything up to the newline, but leave the
        # newline itself in place so neighbouring tokens stay separated.
        if script.startswith("--", pos):
            newline = script.find("\n", pos)
            pos = end if newline == -1 else newline
            continue

        # Block comment: replace it with a single space.  Without a closing
        # "*/" the comment extends to the end of the script.
        if script.startswith("/*", pos):
            close = script.find("*/", pos + 2)
            pos = end if close == -1 else close + 2
            buffer.append(" ")
            continue

        # String literal or quoted identifier.  A doubled quote character is
        # an escaped quote and does not terminate the literal.
        if char in ("'", '"', "`"):
            stop = pos + 1
            while True:
                stop = script.find(char, stop)
                if stop == -1:
                    stop = end  # unterminated: consume the rest of the script
                    break
                if script.startswith(char, stop + 1):
                    stop += 2
                    continue
                stop += 1
                break
            buffer.append(script[pos:stop])
            pos = stop
            continue

        # Bracket-quoted identifier: runs to the next "]" with no escapes.
        if char == "[":
            close = script.find("]", pos + 1)
            stop = end if close == -1 else close + 1
            buffer.append(script[pos:stop])
            pos = stop
            continue

        if char == ";":
            fragments.append("".join(buffer))
            buffer = []
        else:
            buffer.append(char)
        pos += 1

    fragments.append("".join(buffer))

    # Pass 2: re-join fragments that belong to one statement.  Inside a
    # trigger body complete_statement() stays False until the final "END;".
    statements: list[str] = []
    pending = ""
    for fragment in fragments:
        if not pending and not fragment.strip():
            continue  # empty statement such as ";;" or trailing whitespace
        pending = f"{pending};{fragment}" if pending else fragment
        if sqlite3.complete_statement(pending + ";"):
            statements.append(pending.strip())
            pending = ""

    leftover = pending.strip()
    if leftover:
        statements.append(leftover)

    return statements
async def _apply_migration(db: aiosqlite.Connection, version: int) -> None:
    """Apply a single migration step and record its completion atomically.

    Every statement of the migration and the ``schema_migrations`` insert run
    inside one ``BEGIN IMMEDIATE ... COMMIT`` transaction.  If anything fails
    after the transaction has been opened, including cancellation of the
    surrounding task, the transaction is rolled back and the error is
    re-raised, so a migration is either fully applied and recorded or not
    applied at all.

    Args:
        db: An open aiosqlite.Connection.
        version: The migration version number; must be a key of
            ``_MIGRATIONS``.

    Raises:
        KeyError: If ``version`` is not present in ``_MIGRATIONS``.
        Exception: Whatever ``BEGIN IMMEDIATE``, a migration statement, the
            ``schema_migrations`` insert or the commit raises.  Errors raised
            after ``BEGIN IMMEDIATE`` succeeded trigger a rollback first.
    """
    statements = await _parse_migration_statements(_MIGRATIONS[version])

    # BEGIN is deliberately outside the try block: if it fails, this function
    # has not opened a transaction, and calling rollback() could discard
    # uncommitted work that belongs to the caller.
    await db.execute("BEGIN IMMEDIATE;")
    try:
        for statement in statements:
            await db.execute(statement)

        await db.execute(
            "INSERT INTO schema_migrations (version) VALUES (?);", (version,)
        )
        await db.commit()
    except BaseException:
        # BaseException rather than Exception: asyncio.CancelledError does not
        # derive from Exception, and a cancelled migration must not leave the
        # write transaction open on the connection.
        await db.rollback()
        raise
async def test_parse_migration_statements_with_line_comments() -> None:
    """Line (``--``) comments are removed and yield no statements."""
    script = """
    -- This is a comment
    CREATE TABLE test (id INTEGER PRIMARY KEY);
    -- Another comment
    """
    statements = await _parse_migration_statements(script)
    # Exact match: also proves that no comment text leaks into the statement.
    assert statements == ["CREATE TABLE test (id INTEGER PRIMARY KEY)"]


async def test_parse_migration_statements_with_block_comments() -> None:
    """Block (``/* ... */``) comments are removed and yield no statements."""
    script = """
    /* Block comment */
    CREATE TABLE test (id INTEGER PRIMARY KEY);
    /* Another block */
    """
    statements = await _parse_migration_statements(script)
    assert statements == ["CREATE TABLE test (id INTEGER PRIMARY KEY)"]


async def test_parse_migration_statements_with_string_literals() -> None:
    """Semicolons inside a string literal do not split the statement."""
    script = """
    CREATE TABLE test (id INTEGER PRIMARY KEY, data TEXT);
    INSERT INTO test (data) VALUES ('value; with; semicolons');
    """
    statements = await _parse_migration_statements(script)
    assert statements == [
        "CREATE TABLE test (id INTEGER PRIMARY KEY, data TEXT)",
        "INSERT INTO test (data) VALUES ('value; with; semicolons')",
    ]


async def test_parse_migration_statements_with_escaped_quotes() -> None:
    """A doubled quote inside a literal is kept and does not end the literal."""
    script = """
    INSERT INTO test (data) VALUES ('it''s a test');
    """
    statements = await _parse_migration_statements(script)
    assert statements == ["INSERT INTO test (data) VALUES ('it''s a test')"]


async def test_apply_migration_is_atomic_success(tmp_path: Path) -> None:
    """A successful migration creates its schema and records its version."""
    database_path = str(tmp_path / "bangui_atomic.db")
    db = await open_db(database_path)
    try:
        # _apply_migration expects the bookkeeping table to exist already.
        await db.execute(
            "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, migrated_at TEXT);"
        )
        await db.commit()

        await _apply_migration(db, 1)

        # The migration was recorded ...
        async with db.execute(
            "SELECT version FROM schema_migrations WHERE version = 1;"
        ) as cursor:
            row = await cursor.fetchone()
            assert row is not None and row[0] == 1

        # ... and its DDL took effect.
        async with db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='settings';"
        ) as cursor:
            row = await cursor.fetchone()
            assert row is not None
    finally:
        await db.close()


async def test_apply_migration_is_atomic_rollback(
    tmp_path: Path, monkeypatch: pytest.MonkeyPatch
) -> None:
    """A failing migration leaves neither schema changes nor a version record.

    The injected migration creates a table and then inserts into a table that
    does not exist.  After the failure the first statement's table must be
    gone again and ``schema_migrations`` must not contain the version.
    """
    from app import db as db_module

    database_path = str(tmp_path / "bangui_rollback.db")
    db = await open_db(database_path)
    try:
        await db.execute(
            "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, migrated_at TEXT);"
        )
        await db.commit()

        # monkeypatch removes the entry again after the test, even on failure,
        # so the module-level registry never leaks state into other tests.
        monkeypatch.setitem(
            db_module._MIGRATIONS,
            99,
            """
            CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
            INSERT INTO nonexistent_table VALUES (1);
            """,
        )

        # Pin the specific SQLite error so an unrelated failure (for example a
        # typo in this test) cannot satisfy the expectation.
        with pytest.raises(aiosqlite.OperationalError, match="nonexistent_table"):
            await _apply_migration(db, 99)

        # The migration was NOT recorded.
        async with db.execute(
            "SELECT version FROM schema_migrations WHERE version = 99;"
        ) as cursor:
            assert await cursor.fetchone() is None

        # The table from the first statement was rolled back.
        async with db.execute(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='test_rollback';"
        ) as cursor:
            assert await cursor.fetchone() is None
    finally:
        await db.close()


async def test_init_db_idempotent(tmp_path: Path) -> None:
    """Running init_db twice leaves the recorded schema version unchanged."""
    database_path = str(tmp_path / "bangui_idempotent.db")
    db = await open_db(database_path)
    try:
        await init_db(db)
        async with db.execute(
            "SELECT MAX(version) FROM schema_migrations;"
        ) as cursor:
            first = await cursor.fetchone()

        # Guard against a vacuous pass: the first run must record a version.
        assert first is not None and first[0] is not None

        # Second run should be a no-op.
        await init_db(db)
        async with db.execute(
            "SELECT MAX(version) FROM schema_migrations;"
        ) as cursor:
            second = await cursor.fetchone()

        assert first == second
    finally:
        await db.close()