import asyncio
import os
import time
from pathlib import Path

import pytest

from app.db import (
    _apply_migration,
    _cleanup_wal_files,
    _parse_migration_statements,
    init_db,
    open_db,
)


async def test_open_db_applies_hardening_pragmas(tmp_path: Path) -> None:
    """Check the PRAGMA values reported by a connection from ``open_db``.

    The connection must report WAL journaling, enabled foreign keys and a
    busy timeout of 5000.
    """
    database_path = str(tmp_path / "bangui_test.db")
    conn = await open_db(database_path)
    try:
        async with conn.execute("PRAGMA journal_mode;") as cur:
            journal_mode = await cur.fetchone()
        assert journal_mode is not None
        assert journal_mode[0].lower() == "wal"

        async with conn.execute("PRAGMA foreign_keys;") as cur:
            foreign_keys = await cur.fetchone()
        assert foreign_keys is not None
        assert foreign_keys[0] == 1

        async with conn.execute("PRAGMA busy_timeout;") as cur:
            busy_timeout = await cur.fetchone()
        assert busy_timeout is not None
        assert busy_timeout[0] == 5000
    finally:
        await conn.close()


async def test_open_db_respects_busy_timeout_for_concurrent_writes(tmp_path: Path) -> None:
    """Check that a second writer succeeds once an exclusive lock is released.

    One connection holds ``BEGIN EXCLUSIVE`` while a background task opens a
    second connection and inserts a row. The lock is released after a short
    sleep; the insert must then complete and the row must be readable.
    """
    database_path = str(tmp_path / "bangui_lock.db")

    async def write_after_lock() -> None:
        # Runs while the first connection still holds the exclusive lock.
        writer = await open_db(database_path)
        try:
            await writer.execute("INSERT INTO test_lock (value) VALUES ('locked');")
            await writer.commit()
        finally:
            await writer.close()

    lock_holder = await open_db(database_path)
    try:
        await lock_holder.execute("CREATE TABLE IF NOT EXISTS test_lock (id INTEGER PRIMARY KEY, value TEXT);")
        await lock_holder.commit()
        await lock_holder.execute("BEGIN EXCLUSIVE;")

        pending_write = asyncio.create_task(write_after_lock())
        # Give the writer time to start before the lock is released.
        await asyncio.sleep(0.1)
        await lock_holder.commit()
        await pending_write

        async with lock_holder.execute("SELECT value FROM test_lock;") as cur:
            stored = await cur.fetchone()
        assert stored is not None
        assert stored[0] == "locked"
    finally:
        await lock_holder.close()


async def test_parse_migration_statements_single_statement() -> None:
    """Test parsing a single statement without comments."""
    parsed = await _parse_migration_statements("CREATE TABLE test (id INTEGER PRIMARY KEY);")
    assert len(parsed) == 1
    # The trailing semicolon is expected to be dropped from the statement.
    assert parsed[0] == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
async def test_parse_migration_statements_multiple_statements() -> None:
    """Test parsing multiple statements separated by semicolons."""
    script = """
    CREATE TABLE test (id INTEGER PRIMARY KEY);
    CREATE INDEX idx_test ON test(id);
    """
    statements = await _parse_migration_statements(script)
    assert len(statements) == 2
    assert "CREATE TABLE test" in statements[0]
    assert "CREATE INDEX idx_test" in statements[1]


async def test_parse_migration_statements_with_line_comments() -> None:
    """Test parsing statements with line comments."""
    script = """
    -- This is a comment
    CREATE TABLE test (id INTEGER PRIMARY KEY);
    -- Another comment
    """
    statements = await _parse_migration_statements(script)
    # Comment-only lines must not produce statements of their own.
    assert len(statements) == 1
    assert "CREATE TABLE test" in statements[0]


async def test_parse_migration_statements_with_block_comments() -> None:
    """Test parsing statements with block comments."""
    script = """
    /* Block comment */
    CREATE TABLE test (id INTEGER PRIMARY KEY);
    /* Another block */
    """
    statements = await _parse_migration_statements(script)
    assert len(statements) == 1
    assert "CREATE TABLE test" in statements[0]


async def test_parse_migration_statements_with_string_literals() -> None:
    """Test parsing statements with string literals containing semicolons."""
    script = """
    CREATE TABLE test (id INTEGER PRIMARY KEY, data TEXT);
    INSERT INTO test (data) VALUES ('value; with; semicolons');
    """
    statements = await _parse_migration_statements(script)
    # Semicolons inside the quoted literal must not split the INSERT.
    assert len(statements) == 2
    assert "CREATE TABLE test" in statements[0]
    assert "INSERT INTO test" in statements[1]
    assert "value; with; semicolons" in statements[1]


async def test_parse_migration_statements_with_escaped_quotes() -> None:
    """Test parsing statements with escaped quotes in string literals."""
    script = """
    INSERT INTO test (data) VALUES ('it''s a test');
    """
    statements = await _parse_migration_statements(script)
    assert len(statements) == 1
    # The doubled single quote must survive parsing unchanged.
    assert "it''s a test" in statements[0]


async def test_apply_migration_is_atomic_success(tmp_path: Path) -> None:
    """Test that migration is atomic when all statements succeed.

    Applies migration 1 and checks that it is recorded in
    ``schema_migrations`` and that the ``settings`` table exists afterwards.
    """
    database_path = str(tmp_path / "bangui_atomic.db")
    db = await open_db(database_path)
    try:
        # Initialize schema_migrations table
        await db.execute(
            "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, migrated_at TEXT);"
        )
        await db.commit()

        # Apply a test migration
        await _apply_migration(db, 1)

        # Verify the migration was recorded
        async with db.execute("SELECT version FROM schema_migrations WHERE version = 1;") as cursor:
            row = await cursor.fetchone()
            assert row is not None and row[0] == 1

        # Verify the schema tables exist
        async with db.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='settings';") as cursor:
            row = await cursor.fetchone()
            assert row is not None
    finally:
        await db.close()


async def test_apply_migration_is_atomic_rollback(tmp_path: Path) -> None:
    """Test that migration is rolled back when a statement fails.

    This test verifies that when an error occurs mid-migration, the
    transaction is rolled back and the schema_migrations table is NOT updated.

    A temporary migration 99 is registered in ``app.db._MIGRATIONS`` for the
    duration of the test and the registry is restored in place afterwards.
    """
    database_path = str(tmp_path / "bangui_rollback.db")
    db = await open_db(database_path)
    try:
        # Initialize schema_migrations table
        await db.execute(
            "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, migrated_at TEXT);"
        )
        await db.commit()

        # Create a custom migration that will fail
        from app import db as db_module

        original_migrations = db_module._MIGRATIONS.copy()

        # Add a migration that will fail on the second statement
        db_module._MIGRATIONS[99] = """
        CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
        INSERT INTO nonexistent_table VALUES (1);
        """

        try:
            # Attempt migration; it should fail
            with pytest.raises(Exception):  # sqlite3 will raise an error
                await _apply_migration(db, 99)

            # Verify the migration was NOT recorded
            async with db.execute("SELECT version FROM schema_migrations WHERE version = 99;") as cursor:
                row = await cursor.fetchone()
                assert row is None

            # Verify the test table was NOT created (rollback occurred)
            async with db.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='test_rollback';"
            ) as cursor:
                row = await cursor.fetchone()
                assert row is None
        finally:
            # Restore the registry in place. Rebinding the module attribute to
            # the copy would leave the original dict object (still containing
            # key 99) visible to any code that holds a reference to it.
            db_module._MIGRATIONS.clear()
            db_module._MIGRATIONS.update(original_migrations)
    finally:
        await db.close()


async def test_init_db_idempotent(tmp_path: Path) -> None:
    """Test that init_db is idempotent.

    Runs ``init_db`` twice on the same connection and checks that the highest
    recorded schema version is the same after both runs.
    """
    database_path = str(tmp_path / "bangui_idempotent.db")
    db = await open_db(database_path)
    try:
        # Initialize once
        await init_db(db)

        # Get schema version
        async with db.execute("SELECT MAX(version) FROM schema_migrations;") as cursor:
            row1 = await cursor.fetchone()

        # Initialize again (should be no-op)
        await init_db(db)

        # Verify schema version is unchanged
        async with db.execute("SELECT MAX(version) FROM schema_migrations;") as cursor:
            row2 = await cursor.fetchone()

        assert row1 == row2
    finally:
        await db.close()


async def test_cleanup_wal_files_removes_orphaned_files(tmp_path: Path) -> None:
    """Test that _cleanup_wal_files removes orphaned WAL and SHM files."""
    db_path = str(tmp_path / "test_wal.db")
    wal_path = Path(db_path + "-wal")
    shm_path = Path(db_path + "-shm")

    # Create the orphaned files with an old mtime so they look stale
    wal_path.write_text("orphan")
    shm_path.write_text("orphan")
    old_mtime = time.time() - 20
    os.utime(wal_path, (old_mtime, old_mtime))
    os.utime(shm_path, (old_mtime, old_mtime))

    assert wal_path.exists()
    assert shm_path.exists()

    # Run cleanup
    await _cleanup_wal_files(db_path)

    # Both files should be removed
    assert not wal_path.exists()
    assert not shm_path.exists()


async def test_cleanup_wal_files_skips_recent_files(tmp_path: Path) -> None:
    """Test that _cleanup_wal_files skips files modified within 10 seconds."""
    db_path = str(tmp_path / "test_wal_recent.db")
    wal_path = Path(db_path + "-wal")
    shm_path = Path(db_path + "-shm")

    # Create files with recent mtime
    wal_path.write_text("recent")
    shm_path.write_text("recent")
    recent_mtime = time.time() - 5
    os.utime(wal_path, (recent_mtime, recent_mtime))
    os.utime(shm_path, (recent_mtime, recent_mtime))

    assert wal_path.exists()
    assert shm_path.exists()

    # Run cleanup
    await _cleanup_wal_files(db_path)

    # Files should NOT be removed (recent)
    assert wal_path.exists()
    assert shm_path.exists()


async def test_cleanup_wal_files_handles_missing_files(tmp_path: Path) -> None:
    """Test that _cleanup_wal_files handles non-existent files gracefully."""
    db_path = str(tmp_path / "nonexistent.db")

    # Should not raise
    await _cleanup_wal_files(db_path)