diff --git a/backend/app/db.py b/backend/app/db.py index 6f4b74e..85d7567 100644 --- a/backend/app/db.py +++ b/backend/app/db.py @@ -475,14 +475,75 @@ async def init_db(db: aiosqlite.Connection) -> None: async def open_db(database_path: str) -> aiosqlite.Connection: """Open a new application SQLite connection with the standard settings. + Creates the parent directory if it does not exist. + Args: database_path: Path to the BanGUI SQLite database. Returns: A configured :class:`aiosqlite.Connection` instance. + + Raises: + DatabasePathInvalidError: If the directory cannot be created or is inaccessible. + DatabasePermissionDeniedError: If aiosqlite.connect raises PermissionError. + DatabaseCorruptedError: If the database file is corrupted. + DatabaseUnavailableError: For any other unexpected error. """ - await _cleanup_wal_files(database_path) - db = await aiosqlite.connect(database_path) + from app.exceptions import ( + DatabaseCorruptedError, + DatabasePathInvalidError, + DatabasePermissionDeniedError, + DatabaseUnavailableError, + ) + + db_dir = Path(database_path).parent + if not db_dir.exists(): + try: + db_dir.mkdir(parents=True, exist_ok=True) + except PermissionError as exc: + log.error("database_open_failed", error=str(exc), database_path=database_path) + raise DatabasePathInvalidError(database_path) from exc + except OSError as exc: + log.error("database_open_failed", error=str(exc), database_path=database_path) + raise DatabaseUnavailableError(database_path, str(exc)) from exc + + try: + db = await aiosqlite.connect(database_path) + except PermissionError as exc: + log.error("database_open_failed", error=str(exc), database_path=database_path) + raise DatabasePermissionDeniedError(database_path) from exc + except aiosqlite.OperationalError as exc: + error_msg = str(exc).lower() + sqlite_code = getattr(exc, "sqlite_errorcode", None) + log.error( + "database_open_failed", + error=str(exc), + sqlite_errorcode=sqlite_code, + database_path=database_path, + ) + if "database is locked" in error_msg or "busy" in error_msg: + raise DatabaseUnavailableError(database_path, str(exc)) from exc + if "unable to open database file" in error_msg: + raise DatabasePathInvalidError(database_path) from exc + raise DatabaseUnavailableError(database_path, str(exc)) from exc + except aiosqlite.DatabaseError as exc: + log.error( + "database_open_failed", + error=str(exc), + database_path=database_path, + ) + raise DatabaseCorruptedError(database_path) from exc + except OSError as exc: + log.error("database_open_failed", error=str(exc), database_path=database_path) + raise DatabaseUnavailableError(database_path, str(exc)) from exc + except Exception as exc: + log.error("database_open_failed", error=str(exc), database_path=database_path) + raise DatabaseUnavailableError(database_path, str(exc)) from exc + db.row_factory = aiosqlite.Row - await _configure_connection(db) + try: + await _configure_connection(db) + except Exception: + await db.close() + raise return db diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 1244012..21c0471 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -165,22 +165,61 @@ async def get_db( Yields: An open :class:`aiosqlite.Connection` for the request. + + Raises: + DatabaseBusyError: After 3 retries when database is locked by concurrent writers. + DatabasePermissionDeniedError: When the database file cannot be accessed. + DatabasePathInvalidError: When the database path is invalid or directory missing. + DatabaseCorruptedError: When the database file is corrupted. + DatabaseUnavailableError: For any other unexpected database error. """ from app.db import open_db # noqa: PLC0415 + from app.exceptions import ( + DatabaseBusyError, + DatabaseCorruptedError, + DatabasePathInvalidError, + DatabasePermissionDeniedError, + DatabaseUnavailableError, + ) - try: - db = await open_db(settings.database_path) - except Exception as exc: - log.error("database_open_failed", error=str(exc)) - raise HTTPException( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - detail="Database is not available.", - ) from exc + db = None + retries = 3 + retry_delay = 0.1 + last_exc = None + + for attempt in range(1, retries + 1): + try: + db = await open_db(settings.database_path) + break + except DatabaseBusyError: + raise + except (DatabasePermissionDeniedError, DatabasePathInvalidError, DatabaseCorruptedError): + raise + except DatabaseUnavailableError as exc: + error_str = str(exc).lower() + if "database is locked" in error_str or "busy" in error_str: + last_exc = exc + if attempt < retries: + log.warning( + "database_open_retry", + attempt=attempt, + max_retries=retries, + database_path=settings.database_path, + ) + import asyncio + await asyncio.sleep(retry_delay * attempt) + continue + raise DatabaseBusyError(settings.database_path, retries) from exc + raise + + if last_exc is not None and db is None: + raise DatabaseBusyError(settings.database_path, retries) try: yield db finally: - await db.close() + if db is not None: + await db.close() async def get_http_session( diff --git a/backend/app/exceptions.py b/backend/app/exceptions.py index f8a8682..caabab4 100644 --- a/backend/app/exceptions.py +++ b/backend/app/exceptions.py @@ -473,6 +473,75 @@ class SetupAlreadyCompleteError(ConflictError): super().__init__("Setup has already been completed.") +class DatabaseBusyError(ServiceUnavailableError): + """Raised when the SQLite database is locked or busy after all retries.""" + + error_code: str = "database_busy" + + def __init__(self, database_path: str, retries: int) -> None: + self.database_path = database_path + self.retries = retries + super().__init__( + f"Database is temporarily busy after {retries} retries." + ) + + def get_error_metadata(self) -> ErrorMetadata: + return {"database_path": self.database_path, "retries": self.retries} + + +class DatabasePermissionDeniedError(ServiceUnavailableError): + """Raised when the database file cannot be accessed due to insufficient permissions.""" + + error_code: str = "database_permission_denied" + + def __init__(self, database_path: str) -> None: + self.database_path = database_path + super().__init__("Insufficient permissions to access the database file.") + + def get_error_metadata(self) -> ErrorMetadata: + return {"database_path": self.database_path} + + +class DatabasePathInvalidError(ServiceUnavailableError): + """Raised when the database directory does not exist or the path is invalid.""" + + error_code: str = "database_path_invalid" + + def __init__(self, database_path: str) -> None: + self.database_path = database_path + super().__init__("Database directory does not exist or path is invalid.") + + def get_error_metadata(self) -> ErrorMetadata: + return {"database_path": self.database_path} + + +class DatabaseCorruptedError(ServiceUnavailableError): + """Raised when the database file is corrupted.""" + + error_code: str = "database_corrupted" + + def __init__(self, database_path: str) -> None: + self.database_path = database_path + super().__init__("Database file is corrupted.") + + def get_error_metadata(self) -> ErrorMetadata: + return {"database_path": self.database_path} + + +class DatabaseUnavailableError(ServiceUnavailableError): + """Raised for any other unexpected database error.""" + + error_code: str = "database_unavailable" + + def __init__(self, database_path: str, error: str) -> None: + self.database_path = database_path + self.error = error + super().__init__(f"Database is not available: {error}") + + def get_error_metadata(self) -> ErrorMetadata: + return {"database_path": self.database_path, "error": self.error} + + class BlocklistSourceNotFoundError(NotFoundError): """Raised when a blocklist source is not found.""" diff --git a/backend/tests/test_dependencies.py b/backend/tests/test_dependencies.py index df3ef56..154e75f 100644 --- a/backend/tests/test_dependencies.py +++ b/backend/tests/test_dependencies.py @@ -3,6 +3,7 @@ from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch import aiohttp +import aiosqlite import pytest from fastapi import FastAPI from starlette.requests import Request @@ -19,6 +20,13 @@ from app.dependencies import ( get_settings, get_settings_repo, ) +from app.exceptions import ( + DatabaseBusyError, + DatabaseCorruptedError, + DatabasePathInvalidError, + DatabasePermissionDeniedError, + DatabaseUnavailableError, +) from app.main import create_app from app.models.server import ServerStatus @@ -98,3 +106,184 @@ async def test_get_db_uses_effective_runtime_database_path(test_settings: Settin await gen.aclose() mock_open_db.assert_awaited_once_with("/tmp/runtime.db") + + +# --------------------------------------------------------------------------- +# Database error handling tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_db_raises_database_permission_denied_on_permission_error( + test_settings: Settings, +) -> None: + """PermissionError from open_db raises DatabasePermissionDeniedError.""" + with patch( + "app.db.open_db", + new=AsyncMock(side_effect=DatabasePermissionDeniedError(test_settings.database_path)), + ): + gen = get_db(settings=test_settings) + with pytest.raises(DatabasePermissionDeniedError) as exc_info: + await gen.__anext__() + await gen.aclose() + + assert exc_info.value.error_code == "database_permission_denied" + assert exc_info.value.database_path == test_settings.database_path + + +@pytest.mark.asyncio +async def test_get_db_raises_database_path_invalid_on_missing_directory( + test_settings: Settings, +) -> None: + """sqlite3.OperationalError('unable to open database file') raises DatabasePathInvalidError.""" + with patch( + "app.db.open_db", + new=AsyncMock(side_effect=DatabasePathInvalidError(test_settings.database_path)), + ): + gen = get_db(settings=test_settings) + with pytest.raises(DatabasePathInvalidError) as exc_info: + await gen.__anext__() + await gen.aclose() + + assert exc_info.value.error_code == "database_path_invalid" + assert exc_info.value.database_path == test_settings.database_path + + +@pytest.mark.asyncio +async def test_get_db_retries_on_database_locked(test_settings: Settings) -> None: + """get_db retries up to 3 times when database is locked.""" + mock_connection = MagicMock() + mock_connection.close = AsyncMock() + + locked_err = DatabaseUnavailableError( + test_settings.database_path, "database is locked" + ) + + with patch( + "app.db.open_db", + new=AsyncMock(side_effect=[locked_err, locked_err, mock_connection]), + ) as mock_open: + gen = get_db(settings=test_settings) + with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep: + connection = await gen.__anext__() + await gen.aclose() + + assert mock_open.call_count == 3 + assert connection is mock_connection + assert mock_sleep.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_db_fails_after_max_retries_on_database_locked( + test_settings: Settings, +) -> None: + """After 3 retries on database locked, raises DatabaseBusyError.""" + locked_err = DatabaseUnavailableError( + test_settings.database_path, "database is locked" + ) + + with patch("app.db.open_db", new=AsyncMock(side_effect=locked_err)) as mock_open: + gen = get_db(settings=test_settings) + with patch("asyncio.sleep", new=AsyncMock()): + with pytest.raises(DatabaseBusyError) as exc_info: + await gen.__anext__() + await gen.aclose() + + assert mock_open.call_count == 3 + assert exc_info.value.error_code == "database_busy" + assert exc_info.value.retries == 3 + + +@pytest.mark.asyncio +async def test_get_db_raises_database_corrupted_on_malformed_db( + test_settings: Settings, +) -> None: + """sqlite3.DatabaseError('database disk image is malformed') raises DatabaseCorruptedError.""" + with patch( + "app.db.open_db", + new=AsyncMock(side_effect=DatabaseCorruptedError(test_settings.database_path)), + ): + gen = get_db(settings=test_settings) + with pytest.raises(DatabaseCorruptedError) as exc_info: + await gen.__anext__() + await gen.aclose() + + assert exc_info.value.error_code == "database_corrupted" + + +@pytest.mark.asyncio +async def test_open_db_creates_parent_directory_if_missing(tmp_path: pytest.Path) -> None: + """open_db creates the parent directory when it does not exist.""" + from pathlib import Path + + from app.db import open_db + + db_path = str(Path(str(tmp_path)) / "subdir" / "deeper" / "bangui.db") + mock_conn = MagicMock() + mock_conn.close = AsyncMock() + mock_conn.execute = AsyncMock() + mock_conn.commit = AsyncMock() + + with patch("aiosqlite.connect", new=AsyncMock(return_value=mock_conn)), \ + patch("app.db._configure_connection", new=AsyncMock()): + connection = await open_db(db_path) + + assert connection is mock_conn + assert Path(db_path).parent.exists() + + +@pytest.mark.asyncio +async def test_open_db_logs_specific_sqlite_error_code() -> None: + """open_db logs the SQLite error code when available.""" + from app.db import open_db + + exc = aiosqlite.OperationalError("database is locked") + exc.sqlite_errorcode = 5 # SQLITE_BUSY + + with patch("aiosqlite.connect", new=AsyncMock(side_effect=exc)), \ + pytest.raises(DatabaseUnavailableError): + await open_db("/tmp/test.db") + + +# --------------------------------------------------------------------------- +# Error metadata tests +# --------------------------------------------------------------------------- + + +def test_database_busy_error_metadata() -> None: + """DatabaseBusyError returns correct metadata.""" + err = DatabaseBusyError("/data/bangui.db", retries=3) + assert err.error_code == "database_busy" + metadata = err.get_error_metadata() + assert metadata["database_path"] == "/data/bangui.db" + assert metadata["retries"] == 3 + + +def test_database_permission_denied_error_metadata() -> None: + """DatabasePermissionDeniedError returns correct metadata.""" + err = DatabasePermissionDeniedError("/data/bangui.db") + assert err.error_code == "database_permission_denied" + assert err.get_error_metadata()["database_path"] == "/data/bangui.db" + + +def test_database_path_invalid_error_metadata() -> None: + """DatabasePathInvalidError returns correct metadata.""" + err = DatabasePathInvalidError("/data/bangui.db") + assert err.error_code == "database_path_invalid" + assert err.get_error_metadata()["database_path"] == "/data/bangui.db" + + +def test_database_corrupted_error_metadata() -> None: + """DatabaseCorruptedError returns correct metadata.""" + err = DatabaseCorruptedError("/data/bangui.db") + assert err.error_code == "database_corrupted" + assert err.get_error_metadata()["database_path"] == "/data/bangui.db" + + +def test_database_unavailable_error_metadata() -> None: + """DatabaseUnavailableError returns correct metadata.""" + err = DatabaseUnavailableError("/data/bangui.db", "some error") + assert err.error_code == "database_unavailable" + metadata = err.get_error_metadata() + assert metadata["database_path"] == "/data/bangui.db" + assert metadata["error"] == "some error"