From 6eab47f7ba5c7a33a262d84cfa92f7cbf19799c9 Mon Sep 17 00:00:00 2001 From: Lukas Date: Tue, 7 Apr 2026 21:41:55 +0200 Subject: [PATCH] Fix setup persistence and load persisted runtime configuration --- Docs/Tasks.md | 1 + backend/app/routers/setup.py | 23 +++++---- backend/app/services/setup_service.py | 57 ++++++++++++++++++++++ backend/app/startup.py | 24 ++++++++++ backend/tests/test_main.py | 61 +++++++++++++++++++++++- backend/tests/test_routers/test_setup.py | 37 +++++++++++--- 6 files changed, 188 insertions(+), 15 deletions(-) diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 0e16aa0..5202c14 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -11,6 +11,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. ### 1. Fix setup persistence - Where found: `backend/app/config.py`, `backend/app/startup.py`, `backend/app/services/setup_service.py`, `backend/app/routers/setup.py` - Goal: runtime configuration should use the values persisted during setup for `database_path`, `fail2ban_socket`, `timezone`, and `session_duration_minutes` rather than only environment defaults. +- Status: completed - Possible traps and issues: - Setup may appear successful but later use a different DB/socket on restart. - A partially persisted setup run must not leave the app in a broken or half-configured state. diff --git a/backend/app/routers/setup.py b/backend/app/routers/setup.py index cd4f9cb..5c02bd4 100644 --- a/backend/app/routers/setup.py +++ b/backend/app/routers/setup.py @@ -10,10 +10,10 @@ from __future__ import annotations import structlog from fastapi import APIRouter, HTTPException, status -from app.dependencies import AppDep, DbDep +from app.dependencies import AppDep, DbDep, SettingsDep from app.models.setup import SetupRequest, SetupResponse, SetupStatusResponse, SetupTimezoneResponse from app.services import setup_service -from app.utils.setup_state import set_setup_complete_cache +from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -25,14 +25,14 @@ router = APIRouter(prefix="/api/setup", tags=["setup"]) response_model=SetupStatusResponse, summary="Check whether setup has been completed", ) -async def get_setup_status(db: DbDep) -> SetupStatusResponse: +async def get_setup_status(app: AppDep) -> SetupStatusResponse: """Return whether the initial setup wizard has been completed. Returns: :class:`~app.models.setup.SetupStatusResponse` with ``completed`` set to ``True`` if setup is done, ``False`` otherwise. """ - done = await setup_service.is_setup_complete(db) + done = is_setup_complete_cached(app) return SetupStatusResponse(completed=done) @@ -60,7 +60,7 @@ async def post_setup( Raises: HTTPException: 409 if setup has already been completed. """ - if await setup_service.is_setup_complete(db): + if is_setup_complete_cached(app) or await setup_service.is_setup_complete(db): raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail="Setup has already been completed.", @@ -75,6 +75,14 @@ async def post_setup( session_duration_minutes=body.session_duration_minutes, ) set_setup_complete_cache(app, True) + app.state.settings = app.state.settings.model_copy( + update={ + "database_path": body.database_path, + "fail2ban_socket": body.fail2ban_socket, + "timezone": body.timezone, + "session_duration_minutes": body.session_duration_minutes, + } + ) return SetupResponse() @@ -83,7 +91,7 @@ async def post_setup( response_model=SetupTimezoneResponse, summary="Return the configured IANA timezone", ) -async def get_timezone(db: DbDep) -> SetupTimezoneResponse: +async def get_timezone(settings: SettingsDep) -> SetupTimezoneResponse: """Return the IANA timezone configured during the initial setup wizard. The frontend uses this to convert UTC timestamps to the local time zone @@ -94,5 +102,4 @@ async def get_timezone(db: DbDep) -> SetupTimezoneResponse: set to the stored IANA identifier (e.g. ``"UTC"`` or ``"Europe/Berlin"``), defaulting to ``"UTC"`` if unset. """ - tz = await setup_service.get_timezone(db) - return SetupTimezoneResponse(timezone=tz) + return SetupTimezoneResponse(timezone=settings.timezone) diff --git a/backend/app/services/setup_service.py b/backend/app/services/setup_service.py index 5254fce..a1f8631 100644 --- a/backend/app/services/setup_service.py +++ b/backend/app/services/setup_service.py @@ -8,6 +8,7 @@ enforcing the rule that setup can only run once. from __future__ import annotations import asyncio +from pathlib import Path from typing import TYPE_CHECKING import bcrypt @@ -16,6 +17,7 @@ import structlog if TYPE_CHECKING: import aiosqlite +from app.db import init_db, open_db from app.repositories import settings_repo log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -95,6 +97,9 @@ async def run_setup( await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH, "100") await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, "50") await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, "20") + + await _ensure_database_initialized(database_path) + # Mark setup as complete — must be last so a partial failure leaves # setup_completed unset and does not lock out the user. await settings_repo.set_setting(db, _KEY_SETUP_DONE, "1") @@ -114,6 +119,58 @@ async def get_password_hash(db: aiosqlite.Connection) -> str | None: return await util_get_password_hash(db) +async def get_persisted_runtime_settings(db: aiosqlite.Connection) -> dict[str, str | int]: + """Return runtime configuration values persisted during initial setup.""" + runtime_settings: dict[str, str | int] = {} + + database_path = await settings_repo.get_setting(db, _KEY_DATABASE_PATH) + if database_path: + runtime_settings["database_path"] = database_path + + fail2ban_socket = await settings_repo.get_setting(db, _KEY_FAIL2BAN_SOCKET) + if fail2ban_socket: + runtime_settings["fail2ban_socket"] = fail2ban_socket + + timezone = await settings_repo.get_setting(db, _KEY_TIMEZONE) + if timezone: + runtime_settings["timezone"] = timezone + + session_duration = await settings_repo.get_setting(db, _KEY_SESSION_DURATION) + if session_duration is not None: + try: + runtime_settings["session_duration_minutes"] = int(session_duration) + except ValueError: + log.warning( + "invalid_setup_setting", + key=_KEY_SESSION_DURATION, + value=session_duration, + ) + + return runtime_settings + + +async def _ensure_database_initialized(database_path: str) -> None: + """Create and initialise the configured runtime database if it does not exist.""" + database_path_obj = Path(database_path) + parent_dir = database_path_obj.parent + + try: + parent_dir.mkdir(parents=True, exist_ok=True) + except PermissionError: + log.warning( + "cannot_create_runtime_database_parent", + database_path=database_path, + parent=str(parent_dir), + ) + return + + db = await open_db(str(database_path_obj)) + try: + await init_db(db) + finally: + await db.close() + + async def get_timezone(db: aiosqlite.Connection) -> str: """Return the configured IANA timezone string.""" tz = await settings_repo.get_setting(db, _KEY_TIMEZONE) diff --git a/backend/app/startup.py b/backend/app/startup.py index fda34c5..1a76027 100644 --- a/backend/app/startup.py +++ b/backend/app/startup.py @@ -24,6 +24,15 @@ from app.utils.setup_state import set_setup_complete_cache log: structlog.stdlib.BoundLogger = structlog.get_logger() +async def _ensure_database_schema(database_path: str) -> None: + """Create the configured runtime database if it does not already exist.""" + db = await open_db(database_path) + try: + await init_db(db) + finally: + await db.close() + + async def startup_shared_resources( app: FastAPI, settings: Settings, @@ -44,6 +53,7 @@ async def startup_shared_resources( log.debug("database_directory_ensured", directory=str(db_path.parent)) + original_db_path = db_path.resolve() startup_db = await open_db(settings.database_path) try: await init_db(startup_db) @@ -52,6 +62,20 @@ async def startup_shared_resources( setup_complete = await setup_service.is_setup_complete(startup_db) set_setup_complete_cache(app, setup_complete) log.debug("setup_completion_cached", completed=setup_complete) + + if setup_complete: + persisted_runtime_settings = ( + await setup_service.get_persisted_runtime_settings(startup_db) + ) + if persisted_runtime_settings: + updated_settings = settings.model_copy(update=persisted_runtime_settings) + if Path(updated_settings.database_path).resolve() != original_db_path: + await _ensure_database_schema(updated_settings.database_path) + app.state.settings = updated_settings + log.info( + "runtime_settings_overridden_from_setup", + overrides=persisted_runtime_settings, + ) finally: await startup_db.close() diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 366fe5f..9608bf0 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -4,10 +4,13 @@ import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch +import aiosqlite from httpx import ASGITransport, AsyncClient from app.config import Settings +from app.db import init_db from app.main import CORSMiddleware, _lifespan, create_app +from app.services import setup_service def test_create_app_configures_cors_from_settings() -> None: @@ -120,6 +123,61 @@ async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Pat mock_scheduler.shutdown.assert_called_once_with(wait=False) +async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None: + """Startup should replace env defaults with values persisted by setup.""" + env_settings = Settings( + database_path=str(tmp_path / "pointer.db"), + fail2ban_socket="/tmp/fake_fail2ban.sock", + fail2ban_config_dir=str(tmp_path / "fail2ban"), + session_secret="test-startup-secret", + session_duration_minutes=60, + timezone="UTC", + log_level="debug", + ) + app = create_app(settings=env_settings) + + runtime_db_path = str(tmp_path / "runtime.db") + db = await aiosqlite.connect(env_settings.database_path) + db.row_factory = aiosqlite.Row + await init_db(db) + await setup_service.run_setup( + db, + master_password="supersecret123", + database_path=runtime_db_path, + fail2ban_socket="/tmp/persisted.sock", + timezone="Europe/Berlin", + session_duration_minutes=123, + ) + await db.close() + + mock_scheduler = MagicMock() + mock_scheduler.start = MagicMock() + mock_scheduler.shutdown = MagicMock() + + mock_http_session = MagicMock() + mock_http_session.close = AsyncMock() + + with ( + patch("app.startup.ensure_jail_configs"), + patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session), + patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), + patch("app.services.geo_service.init_geoip"), + patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), + patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), + patch("app.tasks.health_check.register"), + patch("app.tasks.blocklist_import.register"), + patch("app.tasks.geo_cache_flush.register"), + patch("app.tasks.geo_re_resolve.register"), + patch("app.tasks.history_sync.register"), + ): + async with _lifespan(app): + assert app.state.settings.database_path == runtime_db_path + assert app.state.settings.fail2ban_socket == "/tmp/persisted.sock" + assert app.state.settings.timezone == "Europe/Berlin" + assert app.state.settings.session_duration_minutes == 123 + assert Path(runtime_db_path).exists() + + async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: Path) -> None: """Concurrent requests each open and close their own database connection.""" settings = Settings( @@ -167,7 +225,8 @@ async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: P ): transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: - responses = await asyncio.gather(*(client.get("/api/setup") for _ in range(5))) + app.state.setup_complete_cached = True + responses = await asyncio.gather(*(client.post("/api/auth/logout") for _ in range(5))) assert len(connections) == 5 assert len({id(connection) for connection in connections}) == 5 diff --git a/backend/tests/test_routers/test_setup.py b/backend/tests/test_routers/test_setup.py index 846b494..7f3ccd8 100644 --- a/backend/tests/test_routers/test_setup.py +++ b/backend/tests/test_routers/test_setup.py @@ -143,6 +143,28 @@ class TestPostSetup: assert response.status_code == 201 +class TestPostSetupRuntimeState: + async def test_updates_runtime_settings_after_setup( + self, app_and_client: tuple[object, AsyncClient] + ) -> None: + """App state should reflect setup settings immediately after setup.""" + app, client = app_and_client + payload = { + "master_password": "supersecret123", + "database_path": "bangui.db", + "fail2ban_socket": "/tmp/persisted.sock", + "timezone": "Europe/Berlin", + "session_duration_minutes": 90, + } + + response = await client.post("/api/setup", json=payload) + assert response.status_code == 201 + assert app.state.settings.database_path == payload["database_path"] + assert app.state.settings.fail2ban_socket == payload["fail2ban_socket"] + assert app.state.settings.timezone == payload["timezone"] + assert app.state.settings.session_duration_minutes == payload["session_duration_minutes"] + + class TestSetupRedirectMiddleware: """Verify that the setup-redirect middleware enforces setup-first.""" @@ -316,8 +338,9 @@ class TestLifespanDatabaseDirectoryCreation: patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), - patch("app.main.AsyncIOScheduler", return_value=mock_scheduler), - patch("app.main.ensure_jail_configs"), + patch("app.tasks.history_sync.register"), + patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), + patch("app.startup.ensure_jail_configs"), ): async with _lifespan(app): assert nested_db.parent.exists(), ( @@ -359,8 +382,9 @@ class TestLifespanDatabaseDirectoryCreation: patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), - patch("app.main.AsyncIOScheduler", return_value=mock_scheduler), - patch("app.main.ensure_jail_configs"), + patch("app.tasks.history_sync.register"), + patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), + patch("app.startup.ensure_jail_configs"), ): # Should not raise FileExistsError or similar. async with _lifespan(app): @@ -409,8 +433,9 @@ class TestLifespanSetupCache: patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), - patch("app.main.AsyncIOScheduler", return_value=mock_scheduler), - patch("app.main.ensure_jail_configs"), + patch("app.tasks.history_sync.register"), + patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), + patch("app.startup.ensure_jail_configs"), ): async with _lifespan(app): assert app.state.setup_complete_cached is True