from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch import aiohttp import aiosqlite import pytest from fastapi import FastAPI from starlette.requests import Request from app.config import Settings from app.dependencies import ( ApplicationContext, get_app_context, get_db, get_history_archive_repo, get_http_session, get_scheduler, get_session_cache, get_settings, get_settings_repo, ) from app.exceptions import ( DatabaseBusyError, DatabaseCorruptedError, DatabasePathInvalidError, DatabasePermissionDeniedError, DatabaseUnavailableError, ) from app.main import create_app from app.models.server import ServerStatus def _make_test_request(app: FastAPI) -> Request: scope = { "type": "http", "method": "GET", "path": "/", "headers": [], "query_string": b"", "client": ("test", 0), "server": ("test", 0), "scheme": "http", "app": app, } return Request(scope) @pytest.mark.asyncio async def test_app_context_dependency_exposes_shared_resources(test_settings: Settings) -> None: app = create_app(settings=test_settings) session = aiohttp.ClientSession() scheduler = MagicMock() app.state.http_session = session app.state.scheduler = scheduler app.state.server_status = ServerStatus(online=False) app.state.pending_recovery = None app.state.last_activation = None request = _make_test_request(app) app_context = await get_app_context(request) assert isinstance(app_context, ApplicationContext) assert app_context.settings is test_settings assert app_context.http_session is session assert app_context.scheduler is scheduler assert app_context.session_cache is app.state.session_cache assert app_context.runtime_state is app.state.runtime_state assert await get_settings(app_context) is test_settings assert await get_http_session(app_context) is session assert await get_scheduler(app_context) is scheduler assert await get_session_cache(app_context) is app.state.session_cache await session.close() @pytest.mark.asyncio async def test_settings_and_history_archive_repo_dependencies_return_modules() -> None: settings_repo = await get_settings_repo() history_archive_repo = await get_history_archive_repo() assert hasattr(settings_repo, "get_setting") assert hasattr(settings_repo, "set_setting") assert hasattr(settings_repo, "delete_setting") assert hasattr(settings_repo, "get_all_settings") assert hasattr(history_archive_repo, "archive_ban_event") assert hasattr(history_archive_repo, "get_max_timeofban") assert hasattr(history_archive_repo, "get_archived_history") @pytest.mark.asyncio async def test_get_db_uses_effective_runtime_database_path(test_settings: Settings) -> None: """Database connections should use effective runtime settings when overridden.""" runtime_settings = test_settings.model_copy(update={"database_path": "/tmp/runtime.db"}) mock_connection = MagicMock() mock_connection.close = AsyncMock() with patch("app.db.open_db", new=AsyncMock(return_value=mock_connection)) as mock_open_db: gen = get_db(settings=runtime_settings) try: connection = await gen.__anext__() assert connection is mock_connection finally: await gen.aclose() mock_open_db.assert_awaited_once_with("/tmp/runtime.db") # --------------------------------------------------------------------------- # Database error handling tests # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_get_db_raises_database_permission_denied_on_permission_error( test_settings: Settings, ) -> None: """PermissionError from open_db raises DatabasePermissionDeniedError.""" with patch( "app.db.open_db", new=AsyncMock(side_effect=DatabasePermissionDeniedError(test_settings.database_path)), ): gen = get_db(settings=test_settings) with pytest.raises(DatabasePermissionDeniedError) as exc_info: await gen.__anext__() await gen.aclose() assert exc_info.value.error_code == "database_permission_denied" assert exc_info.value.database_path == test_settings.database_path @pytest.mark.asyncio async def test_get_db_raises_database_path_invalid_on_missing_directory( test_settings: Settings, ) -> None: """sqlite3.OperationalError('unable to open database file') raises DatabasePathInvalidError.""" with patch( "app.db.open_db", new=AsyncMock(side_effect=DatabasePathInvalidError(test_settings.database_path)), ): gen = get_db(settings=test_settings) with pytest.raises(DatabasePathInvalidError) as exc_info: await gen.__anext__() await gen.aclose() assert exc_info.value.error_code == "database_path_invalid" assert exc_info.value.database_path == test_settings.database_path @pytest.mark.asyncio async def test_get_db_retries_on_database_locked(test_settings: Settings) -> None: """get_db retries up to 3 times when database is locked.""" mock_connection = MagicMock() mock_connection.close = AsyncMock() locked_err = DatabaseUnavailableError( test_settings.database_path, "database is locked" ) with patch( "app.db.open_db", new=AsyncMock(side_effect=[locked_err, locked_err, mock_connection]), ) as mock_open: gen = get_db(settings=test_settings) with patch("asyncio.sleep", new=AsyncMock()) as mock_sleep: connection = await gen.__anext__() await gen.aclose() assert mock_open.call_count == 3 assert connection is mock_connection assert mock_sleep.call_count == 2 @pytest.mark.asyncio async def test_get_db_fails_after_max_retries_on_database_locked( test_settings: Settings, ) -> None: """After 3 retries on database locked, raises DatabaseBusyError.""" locked_err = DatabaseUnavailableError( test_settings.database_path, "database is locked" ) with patch("app.db.open_db", new=AsyncMock(side_effect=locked_err)) as mock_open: gen = get_db(settings=test_settings) with patch("asyncio.sleep", new=AsyncMock()): with pytest.raises(DatabaseBusyError) as exc_info: await gen.__anext__() await gen.aclose() assert mock_open.call_count == 3 assert exc_info.value.error_code == "database_busy" assert exc_info.value.retries == 3 @pytest.mark.asyncio async def test_get_db_raises_database_corrupted_on_malformed_db( test_settings: Settings, ) -> None: """sqlite3.DatabaseError('database disk image is malformed') raises DatabaseCorruptedError.""" with patch( "app.db.open_db", new=AsyncMock(side_effect=DatabaseCorruptedError(test_settings.database_path)), ): gen = get_db(settings=test_settings) with pytest.raises(DatabaseCorruptedError) as exc_info: await gen.__anext__() await gen.aclose() assert exc_info.value.error_code == "database_corrupted" @pytest.mark.asyncio async def test_open_db_creates_parent_directory_if_missing(tmp_path: pytest.Path) -> None: """open_db creates the parent directory when it does not exist.""" from pathlib import Path from app.db import open_db db_path = str(Path(str(tmp_path)) / "subdir" / "deeper" / "bangui.db") mock_conn = MagicMock() mock_conn.close = AsyncMock() mock_conn.execute = AsyncMock() mock_conn.commit = AsyncMock() with patch("aiosqlite.connect", new=AsyncMock(return_value=mock_conn)), \ patch("app.db._configure_connection", new=AsyncMock()): connection = await open_db(db_path) assert connection is mock_conn assert Path(db_path).parent.exists() @pytest.mark.asyncio async def test_open_db_logs_specific_sqlite_error_code() -> None: """open_db logs the SQLite error code when available.""" from app.db import open_db exc = aiosqlite.OperationalError("database is locked") exc.sqlite_errorcode = 5 # SQLITE_BUSY with patch("aiosqlite.connect", new=AsyncMock(side_effect=exc)), \ pytest.raises(DatabaseUnavailableError): await open_db("/tmp/test.db") # --------------------------------------------------------------------------- # Error metadata tests # --------------------------------------------------------------------------- def test_database_busy_error_metadata() -> None: """DatabaseBusyError returns correct metadata.""" err = DatabaseBusyError("/data/bangui.db", retries=3) assert err.error_code == "database_busy" metadata = err.get_error_metadata() assert metadata["database_path"] == "/data/bangui.db" assert metadata["retries"] == 3 def test_database_permission_denied_error_metadata() -> None: """DatabasePermissionDeniedError returns correct metadata.""" err = DatabasePermissionDeniedError("/data/bangui.db") assert err.error_code == "database_permission_denied" assert err.get_error_metadata()["database_path"] == "/data/bangui.db" def test_database_path_invalid_error_metadata() -> None: """DatabasePathInvalidError returns correct metadata.""" err = DatabasePathInvalidError("/data/bangui.db") assert err.error_code == "database_path_invalid" assert err.get_error_metadata()["database_path"] == "/data/bangui.db" def test_database_corrupted_error_metadata() -> None: """DatabaseCorruptedError returns correct metadata.""" err = DatabaseCorruptedError("/data/bangui.db") assert err.error_code == "database_corrupted" assert err.get_error_metadata()["database_path"] == "/data/bangui.db" def test_database_unavailable_error_metadata() -> None: """DatabaseUnavailableError returns correct metadata.""" err = DatabaseUnavailableError("/data/bangui.db", "some error") assert err.error_code == "database_unavailable" metadata = err.get_error_metadata() assert metadata["database_path"] == "/data/bangui.db" assert metadata["error"] == "some error"