Files
BanGUI/backend/tests/test_dependencies.py
Lukas 9e765c6cb7 Add granular DB error types with retry logic
New exceptions: DatabaseBusyError, DatabasePermissionDeniedError,
DatabasePathInvalidError, DatabaseCorruptedError, DatabaseUnavailableError.

open_db creates parent directory if missing. Catches all aiosqlite errors
and maps to specific exception types.

get_db makes up to 3 attempts when the database is locked, backing off between attempts.
It propagates the specific exceptions instead of raising a generic HTTPException.

Tests for all new error types and retry behavior.
2026-05-24 22:05:34 +02:00

290 lines
10 KiB
Python

from __future__ import annotations

from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
import aiosqlite
import pytest
from fastapi import FastAPI
from starlette.requests import Request

from app.config import Settings
from app.dependencies import (
    ApplicationContext,
    get_app_context,
    get_db,
    get_history_archive_repo,
    get_http_session,
    get_scheduler,
    get_session_cache,
    get_settings,
    get_settings_repo,
)
from app.exceptions import (
    DatabaseBusyError,
    DatabaseCorruptedError,
    DatabasePathInvalidError,
    DatabasePermissionDeniedError,
    DatabaseUnavailableError,
)
from app.main import create_app
from app.models.server import ServerStatus
def _make_test_request(app: FastAPI) -> Request:
    """Return a minimal HTTP ``GET /`` Request whose ASGI scope references *app*.

    Only the scope keys needed by the dependency tests are filled in.
    """
    scope = dict(
        type="http",
        method="GET",
        path="/",
        headers=[],
        query_string=b"",
        client=("test", 0),
        server=("test", 0),
        scheme="http",
        app=app,
    )
    return Request(scope)
@pytest.mark.asyncio
async def test_app_context_dependency_exposes_shared_resources(test_settings: Settings) -> None:
    """get_app_context bundles the app.state resources; the getters unwrap them.

    Checks that the ApplicationContext built from a request carries the same
    settings, HTTP session, scheduler, session cache and runtime state that
    were placed on ``app.state``, and that each per-resource dependency
    returns the corresponding object from that context.
    """
    app = create_app(settings=test_settings)
    scheduler = MagicMock()
    # ``async with`` closes the session even when an assertion below fails,
    # so a failing run does not leak an unclosed aiohttp ClientSession.
    async with aiohttp.ClientSession() as session:
        app.state.http_session = session
        app.state.scheduler = scheduler
        app.state.server_status = ServerStatus(online=False)
        app.state.pending_recovery = None
        app.state.last_activation = None

        request = _make_test_request(app)
        app_context = await get_app_context(request)

        assert isinstance(app_context, ApplicationContext)
        assert app_context.settings is test_settings
        assert app_context.http_session is session
        assert app_context.scheduler is scheduler
        assert app_context.session_cache is app.state.session_cache
        assert app_context.runtime_state is app.state.runtime_state

        assert await get_settings(app_context) is test_settings
        assert await get_http_session(app_context) is session
        assert await get_scheduler(app_context) is scheduler
        assert await get_session_cache(app_context) is app.state.session_cache
@pytest.mark.asyncio
async def test_settings_and_history_archive_repo_dependencies_return_modules() -> None:
    """The settings and history-archive repo dependencies expose their repository functions."""
    settings_repo = await get_settings_repo()
    archive_repo = await get_history_archive_repo()

    for attr in ("get_setting", "set_setting", "delete_setting", "get_all_settings"):
        assert hasattr(settings_repo, attr)
    for attr in ("archive_ban_event", "get_max_timeofban", "get_archived_history"):
        assert hasattr(archive_repo, attr)
@pytest.mark.asyncio
async def test_get_db_uses_effective_runtime_database_path(test_settings: Settings) -> None:
    """get_db opens the database_path of the (runtime-overridden) settings it receives."""
    override_path = "/tmp/runtime.db"
    runtime_settings = test_settings.model_copy(update={"database_path": override_path})

    fake_conn = MagicMock()
    fake_conn.close = AsyncMock()
    open_db_mock = AsyncMock(return_value=fake_conn)

    with patch("app.db.open_db", new=open_db_mock):
        db_gen = get_db(settings=runtime_settings)
        try:
            yielded = await db_gen.__anext__()
            assert yielded is fake_conn
        finally:
            await db_gen.aclose()

    open_db_mock.assert_awaited_once_with(override_path)
# ---------------------------------------------------------------------------
# Database error handling tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_db_raises_database_permission_denied_on_permission_error(
    test_settings: Settings,
) -> None:
    """get_db propagates DatabasePermissionDeniedError from open_db unchanged.

    open_db is mocked to raise the already-mapped exception, so this checks
    that get_db surfaces it as-is with its error code and database path
    intact. The PermissionError -> DatabasePermissionDeniedError mapping done
    by the real open_db is not exercised here.
    """
    with patch(
        "app.db.open_db",
        new=AsyncMock(side_effect=DatabasePermissionDeniedError(test_settings.database_path)),
    ):
        gen = get_db(settings=test_settings)
        with pytest.raises(DatabasePermissionDeniedError) as exc_info:
            await gen.__anext__()
        # Must sit outside pytest.raises: inside it this line would be unreachable.
        # The generator already finished by raising, so aclose() is a no-op.
        await gen.aclose()
    assert exc_info.value.error_code == "database_permission_denied"
    assert exc_info.value.database_path == test_settings.database_path
@pytest.mark.asyncio
async def test_get_db_raises_database_path_invalid_on_missing_directory(
    test_settings: Settings,
) -> None:
    """get_db propagates DatabasePathInvalidError from open_db unchanged.

    open_db is mocked to raise the already-mapped exception, so this checks
    that get_db surfaces it as-is with its error code and database path
    intact. The sqlite3 "unable to open database file" -> DatabasePathInvalidError
    mapping done by the real open_db is not exercised here.
    """
    with patch(
        "app.db.open_db",
        new=AsyncMock(side_effect=DatabasePathInvalidError(test_settings.database_path)),
    ):
        gen = get_db(settings=test_settings)
        with pytest.raises(DatabasePathInvalidError) as exc_info:
            await gen.__anext__()
        # Must sit outside pytest.raises: inside it this line would be unreachable.
        # The generator already finished by raising, so aclose() is a no-op.
        await gen.aclose()
    assert exc_info.value.error_code == "database_path_invalid"
    assert exc_info.value.database_path == test_settings.database_path
@pytest.mark.asyncio
async def test_get_db_retries_on_database_locked(test_settings: Settings) -> None:
    """get_db makes up to 3 attempts when the database is locked.

    open_db fails twice with "database is locked" and succeeds on the third
    attempt. get_db must therefore yield the connection after 3 open_db calls,
    with 2 awaited backoff sleeps in between.
    """
    mock_connection = MagicMock()
    mock_connection.close = AsyncMock()

    def _locked() -> DatabaseUnavailableError:
        # A fresh instance per attempt; re-raising one shared instance keeps
        # extending its __traceback__ across attempts.
        return DatabaseUnavailableError(test_settings.database_path, "database is locked")

    with patch(
        "app.db.open_db",
        new=AsyncMock(side_effect=[_locked(), _locked(), mock_connection]),
    ) as mock_open:
        gen = get_db(settings=test_settings)
        with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep:
            connection = await gen.__anext__()
            await gen.aclose()

    assert mock_open.call_count == 3
    assert connection is mock_connection
    # await_count (rather than call_count) proves the backoff coroutines were
    # actually awaited, not merely created.
    assert mock_sleep.await_count == 2
@pytest.mark.asyncio
async def test_get_db_fails_after_max_retries_on_database_locked(
    test_settings: Settings,
) -> None:
    """A database that stays locked for all three attempts ends in DatabaseBusyError."""
    always_locked = AsyncMock(
        side_effect=DatabaseUnavailableError(test_settings.database_path, "database is locked")
    )
    with patch("app.db.open_db", new=always_locked):
        db_gen = get_db(settings=test_settings)
        with patch("asyncio.sleep", new=AsyncMock()), pytest.raises(DatabaseBusyError) as caught:
            await db_gen.__anext__()
        await db_gen.aclose()

    assert always_locked.call_count == 3
    busy = caught.value
    assert busy.error_code == "database_busy"
    assert busy.retries == 3
@pytest.mark.asyncio
async def test_get_db_raises_database_corrupted_on_malformed_db(
    test_settings: Settings,
) -> None:
    """get_db propagates DatabaseCorruptedError from open_db unchanged.

    open_db is mocked to raise the already-mapped exception, so this checks
    that get_db surfaces it as-is with its error code and database path
    intact. The sqlite3 "database disk image is malformed" ->
    DatabaseCorruptedError mapping done by the real open_db is not exercised here.
    """
    with patch(
        "app.db.open_db",
        new=AsyncMock(side_effect=DatabaseCorruptedError(test_settings.database_path)),
    ):
        gen = get_db(settings=test_settings)
        with pytest.raises(DatabaseCorruptedError) as exc_info:
            await gen.__anext__()
        # Must sit outside pytest.raises: inside it this line would be unreachable.
        # The generator already finished by raising, so aclose() is a no-op.
        await gen.aclose()
    assert exc_info.value.error_code == "database_corrupted"
    # Same path check as the sibling permission/path-invalid tests.
    assert exc_info.value.database_path == test_settings.database_path
@pytest.mark.asyncio
async def test_open_db_creates_parent_directory_if_missing(tmp_path: Path) -> None:
    """open_db creates missing parent directories (several levels deep).

    aiosqlite.connect and the connection setup are patched out, so no real
    SQLite file is opened; only the directory side effect is checked.
    """
    from app.db import open_db

    # tmp_path is already a pathlib.Path, so no Path(str(...)) round-trip is needed.
    db_path = tmp_path / "subdir" / "deeper" / "bangui.db"
    # Guard against a false pass: the directory must be absent before open_db runs.
    assert not db_path.parent.exists()

    mock_conn = MagicMock()
    mock_conn.close = AsyncMock()
    mock_conn.execute = AsyncMock()
    mock_conn.commit = AsyncMock()
    with patch("aiosqlite.connect", new=AsyncMock(return_value=mock_conn)):
        with patch("app.db._configure_connection", new=AsyncMock()):
            connection = await open_db(str(db_path))

    assert connection is mock_conn
    assert db_path.parent.is_dir()
@pytest.mark.asyncio
async def test_open_db_logs_specific_sqlite_error_code(tmp_path: Path) -> None:
    """open_db maps a SQLITE_BUSY ("database is locked") OperationalError to DatabaseUnavailableError.

    NOTE(review): despite the test name, no assertion is made on log output;
    add a ``caplog`` check once the expected log message format is pinned down.
    """
    from app.db import open_db

    exc = aiosqlite.OperationalError("database is locked")
    exc.sqlite_errorcode = 5  # SQLITE_BUSY
    # Use pytest's per-test temp dir instead of a hard-coded /tmp path: open_db
    # creates the parent directory if missing, so a fixed path would touch the
    # real filesystem and is not portable.
    db_path = tmp_path / "test.db"
    with patch("aiosqlite.connect", new=AsyncMock(side_effect=exc)):
        with pytest.raises(DatabaseUnavailableError):
            await open_db(str(db_path))
# ---------------------------------------------------------------------------
# Error metadata tests
# ---------------------------------------------------------------------------
def test_database_busy_error_metadata() -> None:
    """DatabaseBusyError reports its error code, database path and retry count."""
    error = DatabaseBusyError("/data/bangui.db", retries=3)
    assert error.error_code == "database_busy"
    meta = error.get_error_metadata()
    assert (meta["database_path"], meta["retries"]) == ("/data/bangui.db", 3)
def test_database_permission_denied_error_metadata() -> None:
    """DatabasePermissionDeniedError reports its error code and database path."""
    path = "/data/bangui.db"
    error = DatabasePermissionDeniedError(path)
    assert error.error_code == "database_permission_denied"
    metadata = error.get_error_metadata()
    assert metadata["database_path"] == path
def test_database_path_invalid_error_metadata() -> None:
    """DatabasePathInvalidError reports its error code and database path."""
    path = "/data/bangui.db"
    error = DatabasePathInvalidError(path)
    assert error.error_code == "database_path_invalid"
    assert error.get_error_metadata()["database_path"] == path
def test_database_corrupted_error_metadata() -> None:
    """DatabaseCorruptedError reports its error code and database path."""
    expected_path = "/data/bangui.db"
    corrupted = DatabaseCorruptedError(expected_path)
    assert corrupted.error_code == "database_corrupted"
    metadata = corrupted.get_error_metadata()
    assert metadata["database_path"] == expected_path
def test_database_unavailable_error_metadata() -> None:
    """DatabaseUnavailableError reports its error code, database path and error text."""
    path, detail = "/data/bangui.db", "some error"
    error = DatabaseUnavailableError(path, detail)
    assert error.error_code == "database_unavailable"
    metadata = error.get_error_metadata()
    assert metadata["database_path"] == path
    assert metadata["error"] == detail