diff --git a/Docs/Tasks.md b/Docs/Tasks.md index d536825..145dfc5 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -27,7 +27,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. - Replace it with a deterministic packaging or configuration model so the backend does not depend on repository layout. - Refactor `Fail2BanClient` so concurrency control is instance-based and not backed by hidden module globals. -- **Improve startup / setup guard behavior.** +- **Improve startup / setup guard behavior.** ✅ - Convert `SetupRedirectMiddleware` from an on-demand DB check into a startup/initialisation guard where possible. - Cache setup completion in a safe way and provide an explicit invalidation path if the application state changes. - Reduce middleware responsibility and avoid DB access during normal request dispatch. @@ -75,6 +75,6 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. 1. ✅ Fix the global SQLite connection pattern and tests. 2. ✅ Refactor dependency injection / explicit shared resources. 3. ✅ Harden fail2ban client concurrency and packaging. -4. Convert setup guard to a safer startup-driven model. +4. ✅ Convert setup guard to a safer startup-driven model. 5. Add deployment-safe configuration and production-ready CORS. 6. Add lifecycle and concurrency regression tests. diff --git a/backend/app/main.py b/backend/app/main.py index e11a73e..a59be51 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -50,6 +50,10 @@ from app.routers import ( from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check, history_sync from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError from app.utils.jail_config import ensure_jail_configs +from app.utils.setup_state import ( + is_setup_complete_cached, + set_setup_complete_cache, +) log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -122,6 +126,11 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]: await init_db(db) await geo_service.load_cache_from_db(db) unresolved_count = await geo_service.count_unresolved(db) + from app.services import setup_service # noqa: PLC0415 + + setup_complete = await setup_service.is_setup_complete(db) + set_setup_complete_cache(app, setup_complete) + log.debug("setup_completion_cached", completed=setup_complete) finally: await db.close() @@ -133,8 +142,6 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app.state.http_session = http_session # --- Pre-warm geo cache from the persistent store --- - from app.services import geo_service # noqa: PLC0415 - geo_service.init_geoip(settings.geoip_db_path) # --- Background task scheduler --- @@ -292,39 +299,14 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware): return await call_next(request) # If setup is not complete, block all other API requests. - # Fast path: setup completion is a one-way transition. Once it is - # True it is cached on app.state so all subsequent requests skip the - # DB query entirely. The flag is reset only when the app restarts. - if path.startswith("/api") and not getattr( - request.app.state, "_setup_complete_cached", False - ): - from app.db import open_db # noqa: PLC0415 - from app.services import setup_service # noqa: PLC0415 - - db = getattr(request.app.state, "db", None) - if db is None: - settings = request.app.state.settings - db = await open_db(settings.database_path) - try: - is_complete = await setup_service.is_setup_complete(db) - except Exception: - log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible") - is_complete = False - finally: - await db.close() - else: - try: - is_complete = await setup_service.is_setup_complete(db) - except Exception: - log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible") - is_complete = False - - if not is_complete: - return RedirectResponse( - url="/api/setup", - status_code=status.HTTP_307_TEMPORARY_REDIRECT, - ) - request.app.state._setup_complete_cached = True + # The setup completion state is resolved at startup and stored in + # ``app.state.setup_complete_cached`` so this middleware does not + # perform any database queries during normal request handling. + if path.startswith("/api") and not is_setup_complete_cached(request.app): + return RedirectResponse( + url="/api/setup", + status_code=status.HTTP_307_TEMPORARY_REDIRECT, + ) return await call_next(request) @@ -360,6 +342,7 @@ def create_app(settings: Settings | None = None) -> FastAPI: # Store settings on app.state so the lifespan handler can access them. app.state.settings = resolved_settings + set_setup_complete_cache(app, False) # --- CORS --- # In production the frontend is served by the same origin. diff --git a/backend/app/routers/setup.py b/backend/app/routers/setup.py index c42c6f6..36aa6f9 100644 --- a/backend/app/routers/setup.py +++ b/backend/app/routers/setup.py @@ -8,11 +8,12 @@ return ``409 Conflict``. from __future__ import annotations import structlog -from fastapi import APIRouter, HTTPException, status +from fastapi import APIRouter, HTTPException, Request, status from app.dependencies import DbDep from app.models.setup import SetupRequest, SetupResponse, SetupStatusResponse, SetupTimezoneResponse from app.services import setup_service +from app.utils.setup_state import set_setup_complete_cache log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -41,10 +42,15 @@ async def get_setup_status(db: DbDep) -> SetupStatusResponse: status_code=status.HTTP_201_CREATED, summary="Run the initial setup wizard", ) -async def post_setup(body: SetupRequest, db: DbDep) -> SetupResponse: +async def post_setup( + request: Request, + body: SetupRequest, + db: DbDep, +) -> SetupResponse: """Persist the initial BanGUI configuration. Args: + request: The incoming HTTP request. body: Setup request payload validated by Pydantic. db: Injected aiosqlite connection. @@ -68,6 +74,7 @@ async def post_setup(body: SetupRequest, db: DbDep) -> SetupResponse: timezone=body.timezone, session_duration_minutes=body.session_duration_minutes, ) + set_setup_complete_cache(request.app, True) return SetupResponse() diff --git a/backend/app/utils/setup_state.py b/backend/app/utils/setup_state.py new file mode 100644 index 0000000..5e30940 --- /dev/null +++ b/backend/app/utils/setup_state.py @@ -0,0 +1,24 @@ +"""Manage the cached setup completion flag stored on application state.""" + +from __future__ import annotations + +from fastapi import FastAPI + + +def is_setup_complete_cached(app: FastAPI) -> bool: + """Return the cached setup completion state from application state.""" + return getattr(app.state, "setup_complete_cached", False) + + +def set_setup_complete_cache(app: FastAPI, completed: bool) -> None: + """Set the cached setup completion state on application state.""" + app.state.setup_complete_cached = completed + + +def invalidate_setup_complete_cache(app: FastAPI) -> None: + """Reset the cached setup completion state. + + This helper exists so the cache can be invalidated explicitly if the + application state changes during runtime. + """ + app.state.setup_complete_cached = False diff --git a/backend/tests/test_routers/test_setup.py b/backend/tests/test_routers/test_setup.py index da9e623..846b494 100644 --- a/backend/tests/test_routers/test_setup.py +++ b/backend/tests/test_routers/test_setup.py @@ -12,6 +12,7 @@ from httpx import ASGITransport, AsyncClient from app.config import Settings from app.db import init_db from app.main import _lifespan, create_app +from app.services import setup_service # --------------------------------------------------------------------------- # Shared setup payload @@ -224,32 +225,18 @@ class TestGetTimezone: class TestSetupCompleteCaching: """SetupRedirectMiddleware caches the setup_complete flag in ``app.state``.""" - async def test_cache_flag_set_after_first_post_setup_request( - self, - app_and_client: tuple[object, AsyncClient], + async def test_cache_flag_set_after_post_setup( + self, app_and_client: tuple[object, AsyncClient] ) -> None: - """``_setup_complete_cached`` is set to True on the first request after setup. - - The ``/api/setup`` path is in ``_ALWAYS_ALLOWED`` so it bypasses the - middleware check. The first request to a non-exempt endpoint triggers - the DB query and, when setup is complete, populates the cache flag. - """ + """``setup_complete_cached`` is set to True immediately after setup.""" from fastapi import FastAPI app, client = app_and_client assert isinstance(app, FastAPI) - # Complete setup (exempt from middleware, no flag set yet). resp = await client.post("/api/setup", json=_SETUP_PAYLOAD) assert resp.status_code == 201 - - # Flag not yet cached — setup was via an exempt path. - assert not getattr(app.state, "_setup_complete_cached", False) - - # First non-exempt request — middleware queries DB and sets the flag. - await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) - - assert app.state._setup_complete_cached is True + assert app.state.setup_complete_cached is True async def test_cached_path_skips_is_setup_complete( self, @@ -268,11 +255,11 @@ class TestSetupCompleteCaching: # Do setup and warm the cache. await client.post("/api/setup", json=_SETUP_PAYLOAD) await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) - assert app.state._setup_complete_cached is True + assert app.state.setup_complete_cached is True call_count = 0 - async def _counting(db: aiosqlite.Connection) -> bool: + async def _counting(_db: aiosqlite.Connection) -> bool: nonlocal call_count call_count += 1 return True @@ -380,6 +367,55 @@ class TestLifespanDatabaseDirectoryCreation: assert tmp_path.exists() +class TestLifespanSetupCache: + """Verify that app startup resolves setup completion into app.state.""" + + async def test_startup_caches_setup_completion(self, tmp_path: Path) -> None: + """Lifespan should populate ``setup_complete_cached`` based on the DB.""" + settings = Settings( + database_path=str(tmp_path / "bangui.db"), + fail2ban_socket="/tmp/fake.sock", + session_secret="test-lifespan-setup-cache-secret", + session_duration_minutes=60, + timezone="UTC", + log_level="debug", + ) + app = create_app(settings=settings) + + db = await aiosqlite.connect(settings.database_path) + db.row_factory = aiosqlite.Row + await init_db(db) + await setup_service.run_setup( + db, + master_password="supersecret123", + database_path=settings.database_path, + fail2ban_socket=settings.fail2ban_socket, + timezone=settings.timezone, + session_duration_minutes=settings.session_duration_minutes, + ) + await db.close() + + mock_scheduler = MagicMock() + mock_scheduler.start = MagicMock() + mock_scheduler.shutdown = MagicMock() + + with ( + patch("app.services.geo_service.init_geoip"), + patch( + "app.services.geo_service.load_cache_from_db", + new=AsyncMock(return_value=None), + ), + patch("app.tasks.health_check.register"), + patch("app.tasks.blocklist_import.register"), + patch("app.tasks.geo_cache_flush.register"), + patch("app.tasks.geo_re_resolve.register"), + patch("app.main.AsyncIOScheduler", return_value=mock_scheduler), + patch("app.main.ensure_jail_configs"), + ): + async with _lifespan(app): + assert app.state.setup_complete_cached is True + + # --------------------------------------------------------------------------- # Task 0.2 — Middleware redirects when app.state.db is None # ---------------------------------------------------------------------------