refactoring-backend #4
@@ -475,14 +475,75 @@ async def init_db(db: aiosqlite.Connection) -> None:
|
||||
async def open_db(database_path: str) -> aiosqlite.Connection:
|
||||
"""Open a new application SQLite connection with the standard settings.
|
||||
|
||||
Creates the parent directory if it does not exist.
|
||||
|
||||
Args:
|
||||
database_path: Path to the BanGUI SQLite database.
|
||||
|
||||
Returns:
|
||||
A configured :class:`aiosqlite.Connection` instance.
|
||||
|
||||
Raises:
|
||||
DatabasePathInvalidError: If the directory cannot be created or is inaccessible.
|
||||
DatabasePermissionDeniedError: If aiosqlite.connect raises PermissionError.
|
||||
DatabaseCorruptedError: If the database file is corrupted.
|
||||
DatabaseUnavailableError: For any other unexpected error.
|
||||
"""
|
||||
await _cleanup_wal_files(database_path)
|
||||
db = await aiosqlite.connect(database_path)
|
||||
from app.exceptions import (
|
||||
DatabaseCorruptedError,
|
||||
DatabasePathInvalidError,
|
||||
DatabasePermissionDeniedError,
|
||||
DatabaseUnavailableError,
|
||||
)
|
||||
|
||||
db_dir = Path(database_path).parent
|
||||
if not db_dir.exists():
|
||||
try:
|
||||
db_dir.mkdir(parents=True, exist_ok=True)
|
||||
except PermissionError as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabasePathInvalidError(database_path) from exc
|
||||
except OSError as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
|
||||
try:
|
||||
db = await aiosqlite.connect(database_path)
|
||||
except PermissionError as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabasePermissionDeniedError(database_path) from exc
|
||||
except aiosqlite.OperationalError as exc:
|
||||
error_msg = str(exc).lower()
|
||||
sqlite_code = getattr(exc, "sqlite_errorcode", None)
|
||||
log.error(
|
||||
"database_open_failed",
|
||||
error=str(exc),
|
||||
sqlite_errorcode=sqlite_code,
|
||||
database_path=database_path,
|
||||
)
|
||||
if "database is locked" in error_msg or "busy" in error_msg:
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
if "unable to open database file" in error_msg:
|
||||
raise DatabasePathInvalidError(database_path) from exc
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
except aiosqlite.DatabaseError as exc:
|
||||
log.error(
|
||||
"database_open_failed",
|
||||
error=str(exc),
|
||||
database_path=database_path,
|
||||
)
|
||||
raise DatabaseCorruptedError(database_path) from exc
|
||||
except OSError as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
except Exception as exc:
|
||||
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||
|
||||
db.row_factory = aiosqlite.Row
|
||||
await _configure_connection(db)
|
||||
try:
|
||||
await _configure_connection(db)
|
||||
except Exception:
|
||||
await db.close()
|
||||
raise
|
||||
return db
|
||||
|
||||
@@ -165,22 +165,61 @@ async def get_db(
|
||||
|
||||
Yields:
|
||||
An open :class:`aiosqlite.Connection` for the request.
|
||||
|
||||
Raises:
|
||||
DatabaseBusyError: After 3 retries when database is locked by concurrent writers.
|
||||
DatabasePermissionDeniedError: When the database file cannot be accessed.
|
||||
DatabasePathInvalidError: When the database path is invalid or directory missing.
|
||||
DatabaseCorruptedError: When the database file is corrupted.
|
||||
DatabaseUnavailableError: For any other unexpected database error.
|
||||
"""
|
||||
from app.db import open_db # noqa: PLC0415
|
||||
from app.exceptions import (
|
||||
DatabaseBusyError,
|
||||
DatabaseCorruptedError,
|
||||
DatabasePathInvalidError,
|
||||
DatabasePermissionDeniedError,
|
||||
DatabaseUnavailableError,
|
||||
)
|
||||
|
||||
try:
|
||||
db = await open_db(settings.database_path)
|
||||
except Exception as exc:
|
||||
log.error("database_open_failed", error=str(exc))
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Database is not available.",
|
||||
) from exc
|
||||
db = None
|
||||
retries = 3
|
||||
retry_delay = 0.1
|
||||
last_exc = None
|
||||
|
||||
for attempt in range(1, retries + 1):
|
||||
try:
|
||||
db = await open_db(settings.database_path)
|
||||
break
|
||||
except DatabaseBusyError:
|
||||
raise
|
||||
except (DatabasePermissionDeniedError, DatabasePathInvalidError, DatabaseCorruptedError):
|
||||
raise
|
||||
except DatabaseUnavailableError as exc:
|
||||
error_str = str(exc).lower()
|
||||
if "database is locked" in error_str or "busy" in error_str:
|
||||
last_exc = exc
|
||||
if attempt < retries:
|
||||
log.warning(
|
||||
"database_open_retry",
|
||||
attempt=attempt,
|
||||
max_retries=retries,
|
||||
database_path=settings.database_path,
|
||||
)
|
||||
import asyncio
|
||||
await asyncio.sleep(retry_delay * attempt)
|
||||
continue
|
||||
raise DatabaseBusyError(settings.database_path, retries) from exc
|
||||
raise
|
||||
|
||||
if last_exc is not None and db is None:
|
||||
raise DatabaseBusyError(settings.database_path, retries)
|
||||
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
await db.close()
|
||||
if db is not None:
|
||||
await db.close()
|
||||
|
||||
|
||||
async def get_http_session(
|
||||
|
||||
@@ -473,6 +473,75 @@ class SetupAlreadyCompleteError(ConflictError):
|
||||
super().__init__("Setup has already been completed.")
|
||||
|
||||
|
||||
class DatabaseBusyError(ServiceUnavailableError):
|
||||
"""Raised when the SQLite database is locked or busy after all retries."""
|
||||
|
||||
error_code: str = "database_busy"
|
||||
|
||||
def __init__(self, database_path: str, retries: int) -> None:
|
||||
self.database_path = database_path
|
||||
self.retries = retries
|
||||
super().__init__(
|
||||
f"Database is temporarily busy after {retries} retries."
|
||||
)
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path, "retries": self.retries}
|
||||
|
||||
|
||||
class DatabasePermissionDeniedError(ServiceUnavailableError):
|
||||
"""Raised when the database file cannot be accessed due to insufficient permissions."""
|
||||
|
||||
error_code: str = "database_permission_denied"
|
||||
|
||||
def __init__(self, database_path: str) -> None:
|
||||
self.database_path = database_path
|
||||
super().__init__("Insufficient permissions to access the database file.")
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path}
|
||||
|
||||
|
||||
class DatabasePathInvalidError(ServiceUnavailableError):
|
||||
"""Raised when the database directory does not exist or the path is invalid."""
|
||||
|
||||
error_code: str = "database_path_invalid"
|
||||
|
||||
def __init__(self, database_path: str) -> None:
|
||||
self.database_path = database_path
|
||||
super().__init__("Database directory does not exist or path is invalid.")
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path}
|
||||
|
||||
|
||||
class DatabaseCorruptedError(ServiceUnavailableError):
|
||||
"""Raised when the database file is corrupted."""
|
||||
|
||||
error_code: str = "database_corrupted"
|
||||
|
||||
def __init__(self, database_path: str) -> None:
|
||||
self.database_path = database_path
|
||||
super().__init__("Database file is corrupted.")
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path}
|
||||
|
||||
|
||||
class DatabaseUnavailableError(ServiceUnavailableError):
|
||||
"""Raised for any other unexpected database error."""
|
||||
|
||||
error_code: str = "database_unavailable"
|
||||
|
||||
def __init__(self, database_path: str, error: str) -> None:
|
||||
self.database_path = database_path
|
||||
self.error = error
|
||||
super().__init__(f"Database is not available: {error}")
|
||||
|
||||
def get_error_metadata(self) -> ErrorMetadata:
|
||||
return {"database_path": self.database_path, "error": self.error}
|
||||
|
||||
|
||||
class BlocklistSourceNotFoundError(NotFoundError):
|
||||
"""Raised when a blocklist source is not found."""
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from starlette.requests import Request
|
||||
@@ -19,6 +20,13 @@ from app.dependencies import (
|
||||
get_settings,
|
||||
get_settings_repo,
|
||||
)
|
||||
from app.exceptions import (
|
||||
DatabaseBusyError,
|
||||
DatabaseCorruptedError,
|
||||
DatabasePathInvalidError,
|
||||
DatabasePermissionDeniedError,
|
||||
DatabaseUnavailableError,
|
||||
)
|
||||
from app.main import create_app
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
@@ -98,3 +106,184 @@ async def test_get_db_uses_effective_runtime_database_path(test_settings: Settin
|
||||
await gen.aclose()
|
||||
|
||||
mock_open_db.assert_awaited_once_with("/tmp/runtime.db")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Database error handling tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_raises_database_permission_denied_on_permission_error(
|
||||
test_settings: Settings,
|
||||
) -> None:
|
||||
"""PermissionError from open_db raises DatabasePermissionDeniedError."""
|
||||
with patch(
|
||||
"app.db.open_db",
|
||||
new=AsyncMock(side_effect=DatabasePermissionDeniedError(test_settings.database_path)),
|
||||
):
|
||||
gen = get_db(settings=test_settings)
|
||||
with pytest.raises(DatabasePermissionDeniedError) as exc_info:
|
||||
await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert exc_info.value.error_code == "database_permission_denied"
|
||||
assert exc_info.value.database_path == test_settings.database_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_raises_database_path_invalid_on_missing_directory(
|
||||
test_settings: Settings,
|
||||
) -> None:
|
||||
"""sqlite3.OperationalError('unable to open database file') raises DatabasePathInvalidError."""
|
||||
with patch(
|
||||
"app.db.open_db",
|
||||
new=AsyncMock(side_effect=DatabasePathInvalidError(test_settings.database_path)),
|
||||
):
|
||||
gen = get_db(settings=test_settings)
|
||||
with pytest.raises(DatabasePathInvalidError) as exc_info:
|
||||
await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert exc_info.value.error_code == "database_path_invalid"
|
||||
assert exc_info.value.database_path == test_settings.database_path
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_retries_on_database_locked(test_settings: Settings) -> None:
|
||||
"""get_db retries up to 3 times when database is locked."""
|
||||
mock_connection = MagicMock()
|
||||
mock_connection.close = AsyncMock()
|
||||
|
||||
locked_err = DatabaseUnavailableError(
|
||||
test_settings.database_path, "database is locked"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.db.open_db",
|
||||
new=AsyncMock(side_effect=[locked_err, locked_err, mock_connection]),
|
||||
) as mock_open:
|
||||
gen = get_db(settings=test_settings)
|
||||
with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep:
|
||||
connection = await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert mock_open.call_count == 3
|
||||
assert connection is mock_connection
|
||||
assert mock_sleep.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_fails_after_max_retries_on_database_locked(
|
||||
test_settings: Settings,
|
||||
) -> None:
|
||||
"""After 3 retries on database locked, raises DatabaseBusyError."""
|
||||
locked_err = DatabaseUnavailableError(
|
||||
test_settings.database_path, "database is locked"
|
||||
)
|
||||
|
||||
with patch("app.db.open_db", new=AsyncMock(side_effect=locked_err)) as mock_open:
|
||||
gen = get_db(settings=test_settings)
|
||||
with patch("asyncio.sleep", new=AsyncMock()):
|
||||
with pytest.raises(DatabaseBusyError) as exc_info:
|
||||
await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert mock_open.call_count == 3
|
||||
assert exc_info.value.error_code == "database_busy"
|
||||
assert exc_info.value.retries == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_raises_database_corrupted_on_malformed_db(
|
||||
test_settings: Settings,
|
||||
) -> None:
|
||||
"""sqlite3.DatabaseError('database disk image is malformed') raises DatabaseCorruptedError."""
|
||||
with patch(
|
||||
"app.db.open_db",
|
||||
new=AsyncMock(side_effect=DatabaseCorruptedError(test_settings.database_path)),
|
||||
):
|
||||
gen = get_db(settings=test_settings)
|
||||
with pytest.raises(DatabaseCorruptedError) as exc_info:
|
||||
await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
assert exc_info.value.error_code == "database_corrupted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_db_creates_parent_directory_if_missing(tmp_path: pytest.Path) -> None:
|
||||
"""open_db creates the parent directory when it does not exist."""
|
||||
from pathlib import Path
|
||||
|
||||
from app.db import open_db
|
||||
|
||||
db_path = str(Path(str(tmp_path)) / "subdir" / "deeper" / "bangui.db")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.close = AsyncMock()
|
||||
mock_conn.execute = AsyncMock()
|
||||
mock_conn.commit = AsyncMock()
|
||||
|
||||
with patch("aiosqlite.connect", new=AsyncMock(return_value=mock_conn)), \
|
||||
patch("app.db._configure_connection", new=AsyncMock()):
|
||||
connection = await open_db(db_path)
|
||||
|
||||
assert connection is mock_conn
|
||||
assert Path(db_path).parent.exists()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_open_db_logs_specific_sqlite_error_code() -> None:
|
||||
"""open_db logs the SQLite error code when available."""
|
||||
from app.db import open_db
|
||||
|
||||
exc = aiosqlite.OperationalError("database is locked")
|
||||
exc.sqlite_errorcode = 5 # SQLITE_BUSY
|
||||
|
||||
with patch("aiosqlite.connect", new=AsyncMock(side_effect=exc)), \
|
||||
pytest.raises(DatabaseUnavailableError):
|
||||
await open_db("/tmp/test.db")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error metadata tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_database_busy_error_metadata() -> None:
|
||||
"""DatabaseBusyError returns correct metadata."""
|
||||
err = DatabaseBusyError("/data/bangui.db", retries=3)
|
||||
assert err.error_code == "database_busy"
|
||||
metadata = err.get_error_metadata()
|
||||
assert metadata["database_path"] == "/data/bangui.db"
|
||||
assert metadata["retries"] == 3
|
||||
|
||||
|
||||
def test_database_permission_denied_error_metadata() -> None:
|
||||
"""DatabasePermissionDeniedError returns correct metadata."""
|
||||
err = DatabasePermissionDeniedError("/data/bangui.db")
|
||||
assert err.error_code == "database_permission_denied"
|
||||
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
|
||||
|
||||
|
||||
def test_database_path_invalid_error_metadata() -> None:
|
||||
"""DatabasePathInvalidError returns correct metadata."""
|
||||
err = DatabasePathInvalidError("/data/bangui.db")
|
||||
assert err.error_code == "database_path_invalid"
|
||||
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
|
||||
|
||||
|
||||
def test_database_corrupted_error_metadata() -> None:
|
||||
"""DatabaseCorruptedError returns correct metadata."""
|
||||
err = DatabaseCorruptedError("/data/bangui.db")
|
||||
assert err.error_code == "database_corrupted"
|
||||
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
|
||||
|
||||
|
||||
def test_database_unavailable_error_metadata() -> None:
|
||||
"""DatabaseUnavailableError returns correct metadata."""
|
||||
err = DatabaseUnavailableError("/data/bangui.db", "some error")
|
||||
assert err.error_code == "database_unavailable"
|
||||
metadata = err.get_error_metadata()
|
||||
assert metadata["database_path"] == "/data/bangui.db"
|
||||
assert metadata["error"] == "some error"
|
||||
|
||||
Reference in New Issue
Block a user