New exceptions: DatabaseBusyError, DatabasePermissionDeniedError, DatabasePathInvalidError, DatabaseCorruptedError, DatabaseUnavailableError. open_db now creates the parent directory if it is missing, catches all aiosqlite errors, and maps them to these specific exception types. When the database is locked, get_db makes up to 3 attempts with backoff, and it propagates the specific exceptions instead of a generic HTTPException. Adds tests for every new error type and for the retry behavior.
290 lines
10 KiB
Python
290 lines
10 KiB
Python
from __future__ import annotations

from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiohttp
import aiosqlite
import pytest
from fastapi import FastAPI
from starlette.requests import Request

from app.config import Settings
from app.dependencies import (
    ApplicationContext,
    get_app_context,
    get_db,
    get_history_archive_repo,
    get_http_session,
    get_scheduler,
    get_session_cache,
    get_settings,
    get_settings_repo,
)
from app.exceptions import (
    DatabaseBusyError,
    DatabaseCorruptedError,
    DatabasePathInvalidError,
    DatabasePermissionDeniedError,
    DatabaseUnavailableError,
)
from app.main import create_app
from app.models.server import ServerStatus
|
|
|
|
|
|
def _make_test_request(app: FastAPI) -> Request:
    """Build a minimal ASGI HTTP ``Request`` bound to *app*.

    The scope describes a bare ``GET /`` with no headers and no query string.
    It lets request-based dependencies such as ``get_app_context`` be called
    directly, without a running server.
    """
    http_scope = dict(
        type="http",
        method="GET",
        path="/",
        headers=[],
        query_string=b"",
        client=("test", 0),
        server=("test", 0),
        scheme="http",
        app=app,
    )
    return Request(http_scope)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_app_context_dependency_exposes_shared_resources(test_settings: Settings) -> None:
    """get_app_context and its accessor dependencies return the objects stored on app.state."""
    app = create_app(settings=test_settings)
    scheduler = MagicMock()

    # ``async with`` closes the session even if an assertion below fails,
    # so a failing test does not also leak an unclosed aiohttp session.
    async with aiohttp.ClientSession() as session:
        app.state.http_session = session
        app.state.scheduler = scheduler
        app.state.server_status = ServerStatus(online=False)
        app.state.pending_recovery = None
        app.state.last_activation = None

        request = _make_test_request(app)
        app_context = await get_app_context(request)

        assert isinstance(app_context, ApplicationContext)
        assert app_context.settings is test_settings
        assert app_context.http_session is session
        assert app_context.scheduler is scheduler
        assert app_context.session_cache is app.state.session_cache
        assert app_context.runtime_state is app.state.runtime_state

        # The thin accessor dependencies must hand back the same shared objects.
        assert await get_settings(app_context) is test_settings
        assert await get_http_session(app_context) is session
        assert await get_scheduler(app_context) is scheduler
        assert await get_session_cache(app_context) is app.state.session_cache
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_settings_and_history_archive_repo_dependencies_return_modules() -> None:
    """Repository dependencies expose the functions their callers rely on."""
    settings_repo = await get_settings_repo()
    history_archive_repo = await get_history_archive_repo()

    required_api = (
        (settings_repo, ("get_setting", "set_setting", "delete_setting", "get_all_settings")),
        (history_archive_repo, ("archive_ban_event", "get_max_timeofban", "get_archived_history")),
    )
    for repo, attribute_names in required_api:
        for attribute_name in attribute_names:
            assert hasattr(repo, attribute_name)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_db_uses_effective_runtime_database_path(test_settings: Settings) -> None:
    """get_db opens the database at the path from the settings it is given.

    A runtime copy of the settings with an overridden ``database_path`` must win
    over the path in the base test settings.
    """
    runtime_path = "/tmp/runtime.db"
    runtime_settings = test_settings.model_copy(update={"database_path": runtime_path})

    fake_connection = MagicMock()
    fake_connection.close = AsyncMock()
    fake_open_db = AsyncMock(return_value=fake_connection)

    with patch("app.db.open_db", new=fake_open_db):
        dependency = get_db(settings=runtime_settings)
        try:
            assert await dependency.__anext__() is fake_connection
        finally:
            await dependency.aclose()

    fake_open_db.assert_awaited_once_with(runtime_path)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Database error handling tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_db_raises_database_permission_denied_on_permission_error(
    test_settings: Settings,
) -> None:
    """get_db propagates DatabasePermissionDeniedError raised by open_db.

    open_db is mocked to raise the domain error directly. This test therefore
    checks that get_db surfaces it unchanged. Translating a raw PermissionError
    into this exception is open_db's job and is not exercised here.
    """
    with patch(
        "app.db.open_db",
        new=AsyncMock(side_effect=DatabasePermissionDeniedError(test_settings.database_path)),
    ):
        gen = get_db(settings=test_settings)
        try:
            with pytest.raises(DatabasePermissionDeniedError) as exc_info:
                await gen.__anext__()
        finally:
            # Cleanup must live outside ``pytest.raises``: statements after the
            # raising line inside that block never execute.
            await gen.aclose()

    assert exc_info.value.error_code == "database_permission_denied"
    assert exc_info.value.database_path == test_settings.database_path
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_db_raises_database_path_invalid_on_missing_directory(
    test_settings: Settings,
) -> None:
    """get_db propagates DatabasePathInvalidError raised by open_db.

    open_db is mocked to raise the domain error directly. Mapping
    sqlite3.OperationalError('unable to open database file') onto it is
    open_db's job and is not exercised here.
    """
    with patch(
        "app.db.open_db",
        new=AsyncMock(side_effect=DatabasePathInvalidError(test_settings.database_path)),
    ):
        gen = get_db(settings=test_settings)
        try:
            with pytest.raises(DatabasePathInvalidError) as exc_info:
                await gen.__anext__()
        finally:
            # Cleanup must live outside ``pytest.raises``: statements after the
            # raising line inside that block never execute.
            await gen.aclose()

    assert exc_info.value.error_code == "database_path_invalid"
    assert exc_info.value.database_path == test_settings.database_path
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_db_retries_on_database_locked(test_settings: Settings) -> None:
    """get_db retries a locked database and yields the connection once it opens.

    open_db fails twice with "database is locked" and succeeds on the third
    attempt. get_db is expected to sleep once before each retry (sleep is
    patched out) and to yield the connection from the successful attempt.
    """
    expected_connection = MagicMock()
    expected_connection.close = AsyncMock()

    busy = DatabaseUnavailableError(test_settings.database_path, "database is locked")
    fake_open_db = AsyncMock(side_effect=[busy, busy, expected_connection])
    fake_sleep = AsyncMock()

    with patch("app.db.open_db", new=fake_open_db), patch("asyncio.sleep", new=fake_sleep):
        dependency = get_db(settings=test_settings)
        connection = await dependency.__anext__()
        await dependency.aclose()

    assert fake_open_db.call_count == 3
    assert connection is expected_connection
    assert fake_sleep.call_count == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_db_fails_after_max_retries_on_database_locked(
    test_settings: Settings,
) -> None:
    """get_db gives up with DatabaseBusyError after 3 locked open attempts.

    The error's ``retries`` field is expected to equal the number of open_db
    attempts made (3), i.e. ``mock_open.call_count``.
    """
    locked_err = DatabaseUnavailableError(
        test_settings.database_path, "database is locked"
    )

    with patch("app.db.open_db", new=AsyncMock(side_effect=locked_err)) as mock_open:
        gen = get_db(settings=test_settings)
        try:
            with patch("asyncio.sleep", new=AsyncMock()), pytest.raises(
                DatabaseBusyError
            ) as exc_info:
                await gen.__anext__()
        finally:
            # Cleanup must live outside ``pytest.raises``: statements after the
            # raising line inside that block never execute.
            await gen.aclose()

    assert mock_open.call_count == 3
    assert exc_info.value.error_code == "database_busy"
    assert exc_info.value.retries == 3
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_db_raises_database_corrupted_on_malformed_db(
    test_settings: Settings,
) -> None:
    """get_db propagates DatabaseCorruptedError raised by open_db.

    open_db is mocked to raise the domain error directly. Mapping
    sqlite3.DatabaseError('database disk image is malformed') onto it is
    open_db's job and is not exercised here.
    """
    with patch(
        "app.db.open_db",
        new=AsyncMock(side_effect=DatabaseCorruptedError(test_settings.database_path)),
    ):
        gen = get_db(settings=test_settings)
        try:
            with pytest.raises(DatabaseCorruptedError) as exc_info:
                await gen.__anext__()
        finally:
            # Cleanup must live outside ``pytest.raises``: statements after the
            # raising line inside that block never execute.
            await gen.aclose()

    assert exc_info.value.error_code == "database_corrupted"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_open_db_creates_parent_directory_if_missing(tmp_path: Path) -> None:
    """open_db creates missing parent directories for the database path.

    aiosqlite.connect and connection configuration are mocked, so no real
    database file is opened; only the directory handling in open_db is checked.
    """
    from app.db import open_db

    # tmp_path is already a pathlib.Path; no str()/Path() round-trip is needed.
    db_path = tmp_path / "subdir" / "deeper" / "bangui.db"
    # Guard the premise so the test cannot pass vacuously.
    assert not db_path.parent.exists()

    mock_conn = MagicMock()
    mock_conn.close = AsyncMock()
    mock_conn.execute = AsyncMock()
    mock_conn.commit = AsyncMock()

    with patch("aiosqlite.connect", new=AsyncMock(return_value=mock_conn)), patch(
        "app.db._configure_connection", new=AsyncMock()
    ):
        connection = await open_db(str(db_path))

    assert connection is mock_conn
    assert db_path.parent.is_dir()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_open_db_logs_specific_sqlite_error_code() -> None:
    """A SQLITE_BUSY OperationalError from aiosqlite surfaces as DatabaseUnavailableError.

    NOTE(review): despite the test name, no log output is asserted here; only
    the mapped exception type is checked. Consider adding a log assertion
    (e.g. via ``caplog``) once the log format is confirmed.
    """
    from app.db import open_db

    busy_error = aiosqlite.OperationalError("database is locked")
    busy_error.sqlite_errorcode = 5  # SQLITE_BUSY

    with patch("aiosqlite.connect", new=AsyncMock(side_effect=busy_error)):
        with pytest.raises(DatabaseUnavailableError):
            await open_db("/tmp/test.db")
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Error metadata tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_database_busy_error_metadata() -> None:
    """DatabaseBusyError exposes its error code, database path and retry count."""
    busy = DatabaseBusyError("/data/bangui.db", retries=3)
    assert busy.error_code == "database_busy"

    meta = busy.get_error_metadata()
    assert (meta["database_path"], meta["retries"]) == ("/data/bangui.db", 3)
|
|
|
|
|
|
def test_database_permission_denied_error_metadata() -> None:
    """DatabasePermissionDeniedError exposes its error code and database path."""
    path = "/data/bangui.db"
    denied = DatabasePermissionDeniedError(path)
    assert denied.error_code == "database_permission_denied"

    meta = denied.get_error_metadata()
    assert meta["database_path"] == path
|
|
|
|
|
|
def test_database_path_invalid_error_metadata() -> None:
    """DatabasePathInvalidError exposes its error code and database path."""
    path = "/data/bangui.db"
    invalid = DatabasePathInvalidError(path)
    assert invalid.error_code == "database_path_invalid"

    meta = invalid.get_error_metadata()
    assert meta["database_path"] == path
|
|
|
|
|
|
def test_database_corrupted_error_metadata() -> None:
    """DatabaseCorruptedError exposes its error code and database path."""
    path = "/data/bangui.db"
    corrupted = DatabaseCorruptedError(path)
    assert corrupted.error_code == "database_corrupted"

    meta = corrupted.get_error_metadata()
    assert meta["database_path"] == path
|
|
|
|
|
|
def test_database_unavailable_error_metadata() -> None:
    """DatabaseUnavailableError exposes its error code, database path and cause text."""
    unavailable = DatabaseUnavailableError("/data/bangui.db", "some error")
    assert unavailable.error_code == "database_unavailable"

    meta = unavailable.get_error_metadata()
    assert (meta["database_path"], meta["error"]) == ("/data/bangui.db", "some error")
|