refactoring-backend #4
@@ -475,14 +475,75 @@ async def init_db(db: aiosqlite.Connection) -> None:
|
|||||||
async def open_db(database_path: str) -> aiosqlite.Connection:
|
async def open_db(database_path: str) -> aiosqlite.Connection:
|
||||||
"""Open a new application SQLite connection with the standard settings.
|
"""Open a new application SQLite connection with the standard settings.
|
||||||
|
|
||||||
|
Creates the parent directory if it does not exist.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
database_path: Path to the BanGUI SQLite database.
|
database_path: Path to the BanGUI SQLite database.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A configured :class:`aiosqlite.Connection` instance.
|
A configured :class:`aiosqlite.Connection` instance.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabasePathInvalidError: If the directory cannot be created or is inaccessible.
|
||||||
|
DatabasePermissionDeniedError: If aiosqlite.connect raises PermissionError.
|
||||||
|
DatabaseCorruptedError: If the database file is corrupted.
|
||||||
|
DatabaseUnavailableError: For any other unexpected error.
|
||||||
"""
|
"""
|
||||||
await _cleanup_wal_files(database_path)
|
from app.exceptions import (
|
||||||
db = await aiosqlite.connect(database_path)
|
DatabaseCorruptedError,
|
||||||
|
DatabasePathInvalidError,
|
||||||
|
DatabasePermissionDeniedError,
|
||||||
|
DatabaseUnavailableError,
|
||||||
|
)
|
||||||
|
|
||||||
|
db_dir = Path(database_path).parent
|
||||||
|
if not db_dir.exists():
|
||||||
|
try:
|
||||||
|
db_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
except PermissionError as exc:
|
||||||
|
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||||
|
raise DatabasePathInvalidError(database_path) from exc
|
||||||
|
except OSError as exc:
|
||||||
|
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||||
|
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||||
|
|
||||||
|
try:
|
||||||
|
db = await aiosqlite.connect(database_path)
|
||||||
|
except PermissionError as exc:
|
||||||
|
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||||
|
raise DatabasePermissionDeniedError(database_path) from exc
|
||||||
|
except aiosqlite.OperationalError as exc:
|
||||||
|
error_msg = str(exc).lower()
|
||||||
|
sqlite_code = getattr(exc, "sqlite_errorcode", None)
|
||||||
|
log.error(
|
||||||
|
"database_open_failed",
|
||||||
|
error=str(exc),
|
||||||
|
sqlite_errorcode=sqlite_code,
|
||||||
|
database_path=database_path,
|
||||||
|
)
|
||||||
|
if "database is locked" in error_msg or "busy" in error_msg:
|
||||||
|
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||||
|
if "unable to open database file" in error_msg:
|
||||||
|
raise DatabasePathInvalidError(database_path) from exc
|
||||||
|
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||||
|
except aiosqlite.DatabaseError as exc:
|
||||||
|
log.error(
|
||||||
|
"database_open_failed",
|
||||||
|
error=str(exc),
|
||||||
|
database_path=database_path,
|
||||||
|
)
|
||||||
|
raise DatabaseCorruptedError(database_path) from exc
|
||||||
|
except OSError as exc:
|
||||||
|
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||||
|
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
log.error("database_open_failed", error=str(exc), database_path=database_path)
|
||||||
|
raise DatabaseUnavailableError(database_path, str(exc)) from exc
|
||||||
|
|
||||||
db.row_factory = aiosqlite.Row
|
db.row_factory = aiosqlite.Row
|
||||||
await _configure_connection(db)
|
try:
|
||||||
|
await _configure_connection(db)
|
||||||
|
except Exception:
|
||||||
|
await db.close()
|
||||||
|
raise
|
||||||
return db
|
return db
|
||||||
|
|||||||
@@ -165,22 +165,61 @@ async def get_db(
|
|||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
An open :class:`aiosqlite.Connection` for the request.
|
An open :class:`aiosqlite.Connection` for the request.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
DatabaseBusyError: After 3 retries when database is locked by concurrent writers.
|
||||||
|
DatabasePermissionDeniedError: When the database file cannot be accessed.
|
||||||
|
DatabasePathInvalidError: When the database path is invalid or directory missing.
|
||||||
|
DatabaseCorruptedError: When the database file is corrupted.
|
||||||
|
DatabaseUnavailableError: For any other unexpected database error.
|
||||||
"""
|
"""
|
||||||
from app.db import open_db # noqa: PLC0415
|
from app.db import open_db # noqa: PLC0415
|
||||||
|
from app.exceptions import (
|
||||||
|
DatabaseBusyError,
|
||||||
|
DatabaseCorruptedError,
|
||||||
|
DatabasePathInvalidError,
|
||||||
|
DatabasePermissionDeniedError,
|
||||||
|
DatabaseUnavailableError,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
db = None
|
||||||
db = await open_db(settings.database_path)
|
retries = 3
|
||||||
except Exception as exc:
|
retry_delay = 0.1
|
||||||
log.error("database_open_failed", error=str(exc))
|
last_exc = None
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
for attempt in range(1, retries + 1):
|
||||||
detail="Database is not available.",
|
try:
|
||||||
) from exc
|
db = await open_db(settings.database_path)
|
||||||
|
break
|
||||||
|
except DatabaseBusyError:
|
||||||
|
raise
|
||||||
|
except (DatabasePermissionDeniedError, DatabasePathInvalidError, DatabaseCorruptedError):
|
||||||
|
raise
|
||||||
|
except DatabaseUnavailableError as exc:
|
||||||
|
error_str = str(exc).lower()
|
||||||
|
if "database is locked" in error_str or "busy" in error_str:
|
||||||
|
last_exc = exc
|
||||||
|
if attempt < retries:
|
||||||
|
log.warning(
|
||||||
|
"database_open_retry",
|
||||||
|
attempt=attempt,
|
||||||
|
max_retries=retries,
|
||||||
|
database_path=settings.database_path,
|
||||||
|
)
|
||||||
|
import asyncio
|
||||||
|
await asyncio.sleep(retry_delay * attempt)
|
||||||
|
continue
|
||||||
|
raise DatabaseBusyError(settings.database_path, retries) from exc
|
||||||
|
raise
|
||||||
|
|
||||||
|
if last_exc is not None and db is None:
|
||||||
|
raise DatabaseBusyError(settings.database_path, retries)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
yield db
|
yield db
|
||||||
finally:
|
finally:
|
||||||
await db.close()
|
if db is not None:
|
||||||
|
await db.close()
|
||||||
|
|
||||||
|
|
||||||
async def get_http_session(
|
async def get_http_session(
|
||||||
|
|||||||
@@ -473,6 +473,75 @@ class SetupAlreadyCompleteError(ConflictError):
|
|||||||
super().__init__("Setup has already been completed.")
|
super().__init__("Setup has already been completed.")
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseBusyError(ServiceUnavailableError):
|
||||||
|
"""Raised when the SQLite database is locked or busy after all retries."""
|
||||||
|
|
||||||
|
error_code: str = "database_busy"
|
||||||
|
|
||||||
|
def __init__(self, database_path: str, retries: int) -> None:
|
||||||
|
self.database_path = database_path
|
||||||
|
self.retries = retries
|
||||||
|
super().__init__(
|
||||||
|
f"Database is temporarily busy after {retries} retries."
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_error_metadata(self) -> ErrorMetadata:
|
||||||
|
return {"database_path": self.database_path, "retries": self.retries}
|
||||||
|
|
||||||
|
|
||||||
|
class DatabasePermissionDeniedError(ServiceUnavailableError):
|
||||||
|
"""Raised when the database file cannot be accessed due to insufficient permissions."""
|
||||||
|
|
||||||
|
error_code: str = "database_permission_denied"
|
||||||
|
|
||||||
|
def __init__(self, database_path: str) -> None:
|
||||||
|
self.database_path = database_path
|
||||||
|
super().__init__("Insufficient permissions to access the database file.")
|
||||||
|
|
||||||
|
def get_error_metadata(self) -> ErrorMetadata:
|
||||||
|
return {"database_path": self.database_path}
|
||||||
|
|
||||||
|
|
||||||
|
class DatabasePathInvalidError(ServiceUnavailableError):
|
||||||
|
"""Raised when the database directory does not exist or the path is invalid."""
|
||||||
|
|
||||||
|
error_code: str = "database_path_invalid"
|
||||||
|
|
||||||
|
def __init__(self, database_path: str) -> None:
|
||||||
|
self.database_path = database_path
|
||||||
|
super().__init__("Database directory does not exist or path is invalid.")
|
||||||
|
|
||||||
|
def get_error_metadata(self) -> ErrorMetadata:
|
||||||
|
return {"database_path": self.database_path}
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseCorruptedError(ServiceUnavailableError):
|
||||||
|
"""Raised when the database file is corrupted."""
|
||||||
|
|
||||||
|
error_code: str = "database_corrupted"
|
||||||
|
|
||||||
|
def __init__(self, database_path: str) -> None:
|
||||||
|
self.database_path = database_path
|
||||||
|
super().__init__("Database file is corrupted.")
|
||||||
|
|
||||||
|
def get_error_metadata(self) -> ErrorMetadata:
|
||||||
|
return {"database_path": self.database_path}
|
||||||
|
|
||||||
|
|
||||||
|
class DatabaseUnavailableError(ServiceUnavailableError):
|
||||||
|
"""Raised for any other unexpected database error."""
|
||||||
|
|
||||||
|
error_code: str = "database_unavailable"
|
||||||
|
|
||||||
|
def __init__(self, database_path: str, error: str) -> None:
|
||||||
|
self.database_path = database_path
|
||||||
|
self.error = error
|
||||||
|
super().__init__(f"Database is not available: {error}")
|
||||||
|
|
||||||
|
def get_error_metadata(self) -> ErrorMetadata:
|
||||||
|
return {"database_path": self.database_path, "error": self.error}
|
||||||
|
|
||||||
|
|
||||||
class BlocklistSourceNotFoundError(NotFoundError):
|
class BlocklistSourceNotFoundError(NotFoundError):
|
||||||
"""Raised when a blocklist source is not found."""
|
"""Raised when a blocklist source is not found."""
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import aiosqlite
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from starlette.requests import Request
|
from starlette.requests import Request
|
||||||
@@ -19,6 +20,13 @@ from app.dependencies import (
|
|||||||
get_settings,
|
get_settings,
|
||||||
get_settings_repo,
|
get_settings_repo,
|
||||||
)
|
)
|
||||||
|
from app.exceptions import (
|
||||||
|
DatabaseBusyError,
|
||||||
|
DatabaseCorruptedError,
|
||||||
|
DatabasePathInvalidError,
|
||||||
|
DatabasePermissionDeniedError,
|
||||||
|
DatabaseUnavailableError,
|
||||||
|
)
|
||||||
from app.main import create_app
|
from app.main import create_app
|
||||||
from app.models.server import ServerStatus
|
from app.models.server import ServerStatus
|
||||||
|
|
||||||
@@ -98,3 +106,184 @@ async def test_get_db_uses_effective_runtime_database_path(test_settings: Settin
|
|||||||
await gen.aclose()
|
await gen.aclose()
|
||||||
|
|
||||||
mock_open_db.assert_awaited_once_with("/tmp/runtime.db")
|
mock_open_db.assert_awaited_once_with("/tmp/runtime.db")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Database error handling tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_raises_database_permission_denied_on_permission_error(
|
||||||
|
test_settings: Settings,
|
||||||
|
) -> None:
|
||||||
|
"""PermissionError from open_db raises DatabasePermissionDeniedError."""
|
||||||
|
with patch(
|
||||||
|
"app.db.open_db",
|
||||||
|
new=AsyncMock(side_effect=DatabasePermissionDeniedError(test_settings.database_path)),
|
||||||
|
):
|
||||||
|
gen = get_db(settings=test_settings)
|
||||||
|
with pytest.raises(DatabasePermissionDeniedError) as exc_info:
|
||||||
|
await gen.__anext__()
|
||||||
|
await gen.aclose()
|
||||||
|
|
||||||
|
assert exc_info.value.error_code == "database_permission_denied"
|
||||||
|
assert exc_info.value.database_path == test_settings.database_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_raises_database_path_invalid_on_missing_directory(
|
||||||
|
test_settings: Settings,
|
||||||
|
) -> None:
|
||||||
|
"""sqlite3.OperationalError('unable to open database file') raises DatabasePathInvalidError."""
|
||||||
|
with patch(
|
||||||
|
"app.db.open_db",
|
||||||
|
new=AsyncMock(side_effect=DatabasePathInvalidError(test_settings.database_path)),
|
||||||
|
):
|
||||||
|
gen = get_db(settings=test_settings)
|
||||||
|
with pytest.raises(DatabasePathInvalidError) as exc_info:
|
||||||
|
await gen.__anext__()
|
||||||
|
await gen.aclose()
|
||||||
|
|
||||||
|
assert exc_info.value.error_code == "database_path_invalid"
|
||||||
|
assert exc_info.value.database_path == test_settings.database_path
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_retries_on_database_locked(test_settings: Settings) -> None:
|
||||||
|
"""get_db retries up to 3 times when database is locked."""
|
||||||
|
mock_connection = MagicMock()
|
||||||
|
mock_connection.close = AsyncMock()
|
||||||
|
|
||||||
|
locked_err = DatabaseUnavailableError(
|
||||||
|
test_settings.database_path, "database is locked"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.db.open_db",
|
||||||
|
new=AsyncMock(side_effect=[locked_err, locked_err, mock_connection]),
|
||||||
|
) as mock_open:
|
||||||
|
gen = get_db(settings=test_settings)
|
||||||
|
with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep:
|
||||||
|
connection = await gen.__anext__()
|
||||||
|
await gen.aclose()
|
||||||
|
|
||||||
|
assert mock_open.call_count == 3
|
||||||
|
assert connection is mock_connection
|
||||||
|
assert mock_sleep.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_fails_after_max_retries_on_database_locked(
|
||||||
|
test_settings: Settings,
|
||||||
|
) -> None:
|
||||||
|
"""After 3 retries on database locked, raises DatabaseBusyError."""
|
||||||
|
locked_err = DatabaseUnavailableError(
|
||||||
|
test_settings.database_path, "database is locked"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.db.open_db", new=AsyncMock(side_effect=locked_err)) as mock_open:
|
||||||
|
gen = get_db(settings=test_settings)
|
||||||
|
with patch("asyncio.sleep", new=AsyncMock()):
|
||||||
|
with pytest.raises(DatabaseBusyError) as exc_info:
|
||||||
|
await gen.__anext__()
|
||||||
|
await gen.aclose()
|
||||||
|
|
||||||
|
assert mock_open.call_count == 3
|
||||||
|
assert exc_info.value.error_code == "database_busy"
|
||||||
|
assert exc_info.value.retries == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_db_raises_database_corrupted_on_malformed_db(
|
||||||
|
test_settings: Settings,
|
||||||
|
) -> None:
|
||||||
|
"""sqlite3.DatabaseError('database disk image is malformed') raises DatabaseCorruptedError."""
|
||||||
|
with patch(
|
||||||
|
"app.db.open_db",
|
||||||
|
new=AsyncMock(side_effect=DatabaseCorruptedError(test_settings.database_path)),
|
||||||
|
):
|
||||||
|
gen = get_db(settings=test_settings)
|
||||||
|
with pytest.raises(DatabaseCorruptedError) as exc_info:
|
||||||
|
await gen.__anext__()
|
||||||
|
await gen.aclose()
|
||||||
|
|
||||||
|
assert exc_info.value.error_code == "database_corrupted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_open_db_creates_parent_directory_if_missing(tmp_path: pytest.Path) -> None:
|
||||||
|
"""open_db creates the parent directory when it does not exist."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from app.db import open_db
|
||||||
|
|
||||||
|
db_path = str(Path(str(tmp_path)) / "subdir" / "deeper" / "bangui.db")
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
mock_conn.close = AsyncMock()
|
||||||
|
mock_conn.execute = AsyncMock()
|
||||||
|
mock_conn.commit = AsyncMock()
|
||||||
|
|
||||||
|
with patch("aiosqlite.connect", new=AsyncMock(return_value=mock_conn)), \
|
||||||
|
patch("app.db._configure_connection", new=AsyncMock()):
|
||||||
|
connection = await open_db(db_path)
|
||||||
|
|
||||||
|
assert connection is mock_conn
|
||||||
|
assert Path(db_path).parent.exists()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_open_db_logs_specific_sqlite_error_code() -> None:
|
||||||
|
"""open_db logs the SQLite error code when available."""
|
||||||
|
from app.db import open_db
|
||||||
|
|
||||||
|
exc = aiosqlite.OperationalError("database is locked")
|
||||||
|
exc.sqlite_errorcode = 5 # SQLITE_BUSY
|
||||||
|
|
||||||
|
with patch("aiosqlite.connect", new=AsyncMock(side_effect=exc)), \
|
||||||
|
pytest.raises(DatabaseUnavailableError):
|
||||||
|
await open_db("/tmp/test.db")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Error metadata tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_busy_error_metadata() -> None:
|
||||||
|
"""DatabaseBusyError returns correct metadata."""
|
||||||
|
err = DatabaseBusyError("/data/bangui.db", retries=3)
|
||||||
|
assert err.error_code == "database_busy"
|
||||||
|
metadata = err.get_error_metadata()
|
||||||
|
assert metadata["database_path"] == "/data/bangui.db"
|
||||||
|
assert metadata["retries"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_permission_denied_error_metadata() -> None:
|
||||||
|
"""DatabasePermissionDeniedError returns correct metadata."""
|
||||||
|
err = DatabasePermissionDeniedError("/data/bangui.db")
|
||||||
|
assert err.error_code == "database_permission_denied"
|
||||||
|
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_path_invalid_error_metadata() -> None:
|
||||||
|
"""DatabasePathInvalidError returns correct metadata."""
|
||||||
|
err = DatabasePathInvalidError("/data/bangui.db")
|
||||||
|
assert err.error_code == "database_path_invalid"
|
||||||
|
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_corrupted_error_metadata() -> None:
|
||||||
|
"""DatabaseCorruptedError returns correct metadata."""
|
||||||
|
err = DatabaseCorruptedError("/data/bangui.db")
|
||||||
|
assert err.error_code == "database_corrupted"
|
||||||
|
assert err.get_error_metadata()["database_path"] == "/data/bangui.db"
|
||||||
|
|
||||||
|
|
||||||
|
def test_database_unavailable_error_metadata() -> None:
|
||||||
|
"""DatabaseUnavailableError returns correct metadata."""
|
||||||
|
err = DatabaseUnavailableError("/data/bangui.db", "some error")
|
||||||
|
assert err.error_code == "database_unavailable"
|
||||||
|
metadata = err.get_error_metadata()
|
||||||
|
assert metadata["database_path"] == "/data/bangui.db"
|
||||||
|
assert metadata["error"] == "some error"
|
||||||
|
|||||||
Reference in New Issue
Block a user