Remove 335-line task specification from Docs/Tasks.md. Add test confirming _cleanup_wal_files skips recently-modified WAL/SHM files. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
285 lines
9.7 KiB
Python
285 lines
9.7 KiB
Python
import asyncio
|
|
import os
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
from app.db import (
|
|
_apply_migration,
|
|
_cleanup_wal_files,
|
|
_parse_migration_statements,
|
|
init_db,
|
|
open_db,
|
|
)
|
|
|
|
|
|
async def test_open_db_applies_hardening_pragmas(tmp_path: Path) -> None:
    """Check that ``open_db`` leaves the connection with the expected PRAGMAs.

    The connection must report WAL journaling, foreign-key enforcement
    switched on, and a busy timeout of 5000.
    """
    connection = await open_db(str(tmp_path / "bangui_test.db"))

    async def fetch_single_row(statement: str):
        # Run one statement and hand back its first result row (or None).
        async with connection.execute(statement) as result:
            return await result.fetchone()

    try:
        journal_row = await fetch_single_row("PRAGMA journal_mode;")
        assert journal_row is not None
        # Compare case-insensitively, as the original check did.
        assert journal_row[0].lower() == "wal"

        foreign_keys_row = await fetch_single_row("PRAGMA foreign_keys;")
        assert foreign_keys_row is not None
        assert foreign_keys_row[0] == 1

        busy_timeout_row = await fetch_single_row("PRAGMA busy_timeout;")
        assert busy_timeout_row is not None
        assert busy_timeout_row[0] == 5000
    finally:
        await connection.close()
|
|
|
|
|
|
async def test_open_db_respects_busy_timeout_for_concurrent_writes(tmp_path: Path) -> None:
    """Check that a write from a second connection survives an exclusive lock.

    One connection opens an exclusive transaction, a background task tries to
    insert a row through a second connection, and the lock is released
    shortly afterwards.  The test passes when the background insert finishes
    without error and the inserted row can be read back.
    """
    db_file = str(tmp_path / "bangui_lock.db")

    async def insert_from_second_connection() -> None:
        # Uses its own connection so it competes with the lock holder.
        writer = await open_db(db_file)
        try:
            await writer.execute("INSERT INTO test_lock (value) VALUES ('locked');")
            await writer.commit()
        finally:
            await writer.close()

    lock_holder = await open_db(db_file)
    try:
        await lock_holder.execute("CREATE TABLE IF NOT EXISTS test_lock (id INTEGER PRIMARY KEY, value TEXT);")
        await lock_holder.commit()

        # Take the exclusive lock before the competing writer is started.
        await lock_holder.execute("BEGIN EXCLUSIVE;")

        pending_insert = asyncio.create_task(insert_from_second_connection())
        # Give the background task time to start while the lock is held.
        await asyncio.sleep(0.1)
        # Committing ends the exclusive transaction and releases the lock.
        await lock_holder.commit()
        await pending_insert

        async with lock_holder.execute("SELECT value FROM test_lock;") as result:
            stored = await result.fetchone()
        assert stored is not None
        assert stored[0] == "locked"
    finally:
        await lock_holder.close()
|
|
|
|
|
|
async def test_parse_migration_statements_single_statement() -> None:
    """A lone statement is returned as one entry without its trailing semicolon."""
    parsed = await _parse_migration_statements("CREATE TABLE test (id INTEGER PRIMARY KEY);")

    assert len(parsed) == 1
    only_statement = parsed[0]
    assert only_statement == "CREATE TABLE test (id INTEGER PRIMARY KEY)"
|
|
|
|
|
|
async def test_parse_migration_statements_multiple_statements() -> None:
    """Two semicolon-separated statements come back as two entries, in order."""
    migration_sql = """
    CREATE TABLE test (id INTEGER PRIMARY KEY);
    CREATE INDEX idx_test ON test(id);
    """

    parsed = await _parse_migration_statements(migration_sql)

    assert len(parsed) == 2
    table_statement = parsed[0]
    index_statement = parsed[1]
    assert "CREATE TABLE test" in table_statement
    assert "CREATE INDEX idx_test" in index_statement
|
|
|
|
|
|
async def test_parse_migration_statements_with_line_comments() -> None:
    """``--`` line comments around a statement do not produce extra entries."""
    migration_sql = """
    -- This is a comment
    CREATE TABLE test (id INTEGER PRIMARY KEY);
    -- Another comment
    """

    parsed = await _parse_migration_statements(migration_sql)

    # Only the CREATE TABLE statement should be returned.
    assert len(parsed) == 1
    only_statement = parsed[0]
    assert "CREATE TABLE test" in only_statement
|
|
|
|
|
|
async def test_parse_migration_statements_with_block_comments() -> None:
    """``/* ... */`` block comments around a statement do not produce extra entries."""
    migration_sql = """
    /* Block comment */
    CREATE TABLE test (id INTEGER PRIMARY KEY);
    /* Another block */
    """

    parsed = await _parse_migration_statements(migration_sql)

    # Only the CREATE TABLE statement should be returned.
    assert len(parsed) == 1
    only_statement = parsed[0]
    assert "CREATE TABLE test" in only_statement
|
|
|
|
|
|
async def test_parse_migration_statements_with_string_literals() -> None:
    """Semicolons inside a quoted string literal must not split the statement."""
    migration_sql = """
    CREATE TABLE test (id INTEGER PRIMARY KEY, data TEXT);
    INSERT INTO test (data) VALUES ('value; with; semicolons');
    """

    parsed = await _parse_migration_statements(migration_sql)

    assert len(parsed) == 2
    create_statement = parsed[0]
    insert_statement = parsed[1]
    assert "CREATE TABLE test" in create_statement
    assert "INSERT INTO test" in insert_statement
    # The literal, including its embedded semicolons, must stay in one piece.
    assert "value; with; semicolons" in insert_statement
|
|
|
|
|
|
async def test_parse_migration_statements_with_escaped_quotes() -> None:
    """A doubled single quote inside a string literal is kept verbatim."""
    migration_sql = """
    INSERT INTO test (data) VALUES ('it''s a test');
    """

    parsed = await _parse_migration_statements(migration_sql)

    assert len(parsed) == 1
    only_statement = parsed[0]
    # The escaped quote must neither end the literal nor be altered.
    assert "it''s a test" in only_statement
|
|
|
|
|
|
async def test_apply_migration_is_atomic_success(tmp_path: Path) -> None:
    """Apply migration 1 and confirm both its bookkeeping row and schema exist.

    After ``_apply_migration(db, 1)`` the test expects version 1 to be listed
    in ``schema_migrations`` and a ``settings`` table to be present.
    """
    connection = await open_db(str(tmp_path / "bangui_atomic.db"))
    try:
        # The test creates the bookkeeping table itself rather than via init_db.
        await connection.execute(
            "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, migrated_at TEXT);"
        )
        await connection.commit()

        await _apply_migration(connection, 1)

        # The migration must have been recorded.
        version_query = "SELECT version FROM schema_migrations WHERE version = 1;"
        async with connection.execute(version_query) as result:
            recorded = await result.fetchone()
        assert recorded is not None
        assert recorded[0] == 1

        # The migration's schema objects must exist.
        table_query = "SELECT name FROM sqlite_master WHERE type='table' AND name='settings';"
        async with connection.execute(table_query) as result:
            settings_table = await result.fetchone()
        assert settings_table is not None
    finally:
        await connection.close()
|
|
|
|
|
|
async def test_apply_migration_is_atomic_rollback(tmp_path: Path) -> None:
    """Test that migration is rolled back when a statement fails.

    A temporary migration (version 99) is registered whose second statement
    targets a table that does not exist.  After the failed attempt the test
    expects no row for version 99 in ``schema_migrations`` and no
    ``test_rollback`` table, i.e. the first statement was undone as well.

    The registry ``app.db._MIGRATIONS`` is restored *in place* afterwards so
    that every holder of a reference to the original mapping sees the
    pre-test contents again.
    """
    database_path = str(tmp_path / "bangui_rollback.db")
    db = await open_db(database_path)
    try:
        # Initialize schema_migrations table
        await db.execute(
            "CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY, migrated_at TEXT);"
        )
        await db.commit()

        # Imported locally to reach the module-level migration registry.
        from app import db as db_module

        # Snapshot taken before the registry is touched.
        original_migrations = db_module._MIGRATIONS.copy()

        try:
            # Register a migration that will fail on the second statement.
            db_module._MIGRATIONS[99] = """
            CREATE TABLE test_rollback (id INTEGER PRIMARY KEY);
            INSERT INTO nonexistent_table VALUES (1);
            """

            # Attempt migration; it should fail.
            # NOTE(review): the concrete exception type raised by
            # _apply_migration is not visible here, so the broad match is
            # kept -- consider narrowing it once confirmed.
            with pytest.raises(Exception):  # sqlite3 will raise an error
                await _apply_migration(db, 99)

            # Verify the migration was NOT recorded
            async with db.execute("SELECT version FROM schema_migrations WHERE version = 99;") as cursor:
                row = await cursor.fetchone()
            assert row is None

            # Verify the test table was NOT created (rollback occurred)
            async with db.execute(
                "SELECT name FROM sqlite_master WHERE type='table' AND name='test_rollback';"
            ) as cursor:
                row = await cursor.fetchone()
            assert row is None
        finally:
            # Restore the registry in place.  Rebinding the module attribute
            # to the snapshot would leave the original mapping object still
            # containing version 99 for anyone who kept a reference to it.
            db_module._MIGRATIONS.clear()
            db_module._MIGRATIONS.update(original_migrations)
    finally:
        await db.close()
|
|
|
|
|
|
async def test_init_db_idempotent(tmp_path: Path) -> None:
    """Run ``init_db`` twice and confirm the highest recorded version is unchanged."""
    connection = await open_db(str(tmp_path / "bangui_idempotent.db"))

    async def highest_version_row():
        # Fetch the single row holding MAX(version) from the bookkeeping table.
        async with connection.execute("SELECT MAX(version) FROM schema_migrations;") as result:
            return await result.fetchone()

    try:
        # First initialisation.
        await init_db(connection)
        after_first_run = await highest_version_row()

        # Second initialisation is expected to change nothing.
        await init_db(connection)
        after_second_run = await highest_version_row()

        assert after_first_run == after_second_run
    finally:
        await connection.close()
|
|
|
|
|
|
async def test_cleanup_wal_files_removes_orphaned_files(tmp_path: Path) -> None:
    """Stale ``-wal`` and ``-shm`` companions are deleted by ``_cleanup_wal_files``.

    Both files are created and then back-dated by 20 seconds so that they
    look stale before the cleanup runs.
    """
    db_path = str(tmp_path / "test_wal.db")
    sidecar_files = (Path(f"{db_path}-wal"), Path(f"{db_path}-shm"))

    for sidecar in sidecar_files:
        sidecar.write_text("orphan")

    # Back-date access and modification times so the files look stale.
    stale_timestamp = time.time() - 20
    for sidecar in sidecar_files:
        os.utime(sidecar, (stale_timestamp, stale_timestamp))

    # Confirm the fixtures exist before the cleanup runs.
    for sidecar in sidecar_files:
        assert sidecar.exists()

    await _cleanup_wal_files(db_path)

    # Both files should be gone.
    for sidecar in sidecar_files:
        assert not sidecar.exists()
|
|
|
|
|
|
async def test_cleanup_wal_files_skips_recent_files(tmp_path: Path) -> None:
    """Recently modified ``-wal`` and ``-shm`` files survive ``_cleanup_wal_files``.

    The files are given an mtime only five seconds in the past; the test was
    written against a 10-second recency window, so they must be left alone.
    """
    db_path = str(tmp_path / "test_wal_recent.db")
    sidecar_files = (Path(f"{db_path}-wal"), Path(f"{db_path}-shm"))

    for sidecar in sidecar_files:
        sidecar.write_text("recent")

    # Five seconds old: still inside the window the cleanup should respect.
    fresh_timestamp = time.time() - 5
    for sidecar in sidecar_files:
        os.utime(sidecar, (fresh_timestamp, fresh_timestamp))

    # Confirm the fixtures exist before the cleanup runs.
    for sidecar in sidecar_files:
        assert sidecar.exists()

    await _cleanup_wal_files(db_path)

    # Neither file may have been removed.
    for sidecar in sidecar_files:
        assert sidecar.exists()
|
|
|
|
|
|
async def test_cleanup_wal_files_handles_missing_files(tmp_path: Path) -> None:
    """``_cleanup_wal_files`` completes without raising when no sidecar files exist."""
    missing_database = tmp_path / "nonexistent.db"

    # No files are created for this path; the call simply must not raise.
    await _cleanup_wal_files(str(missing_database))
|