Convert setup guard to startup-driven cache and update tests

This commit is contained in:
2026-04-06 20:38:15 +02:00
parent 3ccfc20c64
commit 89ab41cc9e
5 changed files with 109 additions and 59 deletions

View File

@@ -50,6 +50,10 @@ from app.routers import (
from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check, history_sync
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
from app.utils.jail_config import ensure_jail_configs
from app.utils.setup_state import (
is_setup_complete_cached,
set_setup_complete_cache,
)
log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -122,6 +126,11 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await init_db(db)
await geo_service.load_cache_from_db(db)
unresolved_count = await geo_service.count_unresolved(db)
from app.services import setup_service # noqa: PLC0415
setup_complete = await setup_service.is_setup_complete(db)
set_setup_complete_cache(app, setup_complete)
log.debug("setup_completion_cached", completed=setup_complete)
finally:
await db.close()
@@ -133,8 +142,6 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
app.state.http_session = http_session
# --- Pre-warm geo cache from the persistent store ---
from app.services import geo_service # noqa: PLC0415
geo_service.init_geoip(settings.geoip_db_path)
# --- Background task scheduler ---
@@ -292,39 +299,14 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# If setup is not complete, block all other API requests.
# Fast path: setup completion is a one-way transition. Once it is
# True it is cached on app.state so all subsequent requests skip the
# DB query entirely. The flag is reset only when the app restarts.
if path.startswith("/api") and not getattr(
request.app.state, "_setup_complete_cached", False
):
from app.db import open_db # noqa: PLC0415
from app.services import setup_service # noqa: PLC0415
db = getattr(request.app.state, "db", None)
if db is None:
settings = request.app.state.settings
db = await open_db(settings.database_path)
try:
is_complete = await setup_service.is_setup_complete(db)
except Exception:
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
is_complete = False
finally:
await db.close()
else:
try:
is_complete = await setup_service.is_setup_complete(db)
except Exception:
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
is_complete = False
if not is_complete:
return RedirectResponse(
url="/api/setup",
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
)
request.app.state._setup_complete_cached = True
# The setup completion state is resolved at startup and stored in
# ``app.state.setup_complete_cached`` so this middleware does not
# perform any database queries during normal request handling.
if path.startswith("/api") and not is_setup_complete_cached(request.app):
return RedirectResponse(
url="/api/setup",
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
)
return await call_next(request)
@@ -360,6 +342,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
# Store settings on app.state so the lifespan handler can access them.
app.state.settings = resolved_settings
set_setup_complete_cache(app, False)
# --- CORS ---
# In production the frontend is served by the same origin.

View File

@@ -8,11 +8,12 @@ return ``409 Conflict``.
from __future__ import annotations
import structlog
from fastapi import APIRouter, HTTPException, status
from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import DbDep
from app.models.setup import SetupRequest, SetupResponse, SetupStatusResponse, SetupTimezoneResponse
from app.services import setup_service
from app.utils.setup_state import set_setup_complete_cache
log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -41,10 +42,15 @@ async def get_setup_status(db: DbDep) -> SetupStatusResponse:
status_code=status.HTTP_201_CREATED,
summary="Run the initial setup wizard",
)
async def post_setup(body: SetupRequest, db: DbDep) -> SetupResponse:
async def post_setup(
request: Request,
body: SetupRequest,
db: DbDep,
) -> SetupResponse:
"""Persist the initial BanGUI configuration.
Args:
request: The incoming HTTP request.
body: Setup request payload validated by Pydantic.
db: Injected aiosqlite connection.
@@ -68,6 +74,7 @@ async def post_setup(body: SetupRequest, db: DbDep) -> SetupResponse:
timezone=body.timezone,
session_duration_minutes=body.session_duration_minutes,
)
set_setup_complete_cache(request.app, True)
return SetupResponse()

View File

@@ -0,0 +1,24 @@
"""Manage the cached setup completion flag stored on application state."""
from __future__ import annotations
from fastapi import FastAPI
def is_setup_complete_cached(app: FastAPI) -> bool:
"""Return the cached setup completion state from application state."""
return getattr(app.state, "setup_complete_cached", False)
def set_setup_complete_cache(app: FastAPI, completed: bool) -> None:
"""Set the cached setup completion state on application state."""
app.state.setup_complete_cached = completed
def invalidate_setup_complete_cache(app: FastAPI) -> None:
"""Reset the cached setup completion state.
This helper exists so the cache can be invalidated explicitly if the
application state changes during runtime.
"""
app.state.setup_complete_cached = False

View File

@@ -12,6 +12,7 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import _lifespan, create_app
from app.services import setup_service
# ---------------------------------------------------------------------------
# Shared setup payload
@@ -224,32 +225,18 @@ class TestGetTimezone:
class TestSetupCompleteCaching:
"""SetupRedirectMiddleware caches the setup_complete flag in ``app.state``."""
async def test_cache_flag_set_after_first_post_setup_request(
self,
app_and_client: tuple[object, AsyncClient],
async def test_cache_flag_set_after_post_setup(
self, app_and_client: tuple[object, AsyncClient]
) -> None:
"""``_setup_complete_cached`` is set to True on the first request after setup.
The ``/api/setup`` path is in ``_ALWAYS_ALLOWED`` so it bypasses the
middleware check. The first request to a non-exempt endpoint triggers
the DB query and, when setup is complete, populates the cache flag.
"""
"""``setup_complete_cached`` is set to True immediately after setup."""
from fastapi import FastAPI
app, client = app_and_client
assert isinstance(app, FastAPI)
# Complete setup (exempt from middleware, no flag set yet).
resp = await client.post("/api/setup", json=_SETUP_PAYLOAD)
assert resp.status_code == 201
# Flag not yet cached — setup was via an exempt path.
assert not getattr(app.state, "_setup_complete_cached", False)
# First non-exempt request — middleware queries DB and sets the flag.
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
assert app.state._setup_complete_cached is True
assert app.state.setup_complete_cached is True
async def test_cached_path_skips_is_setup_complete(
self,
@@ -268,11 +255,11 @@ class TestSetupCompleteCaching:
# Do setup and warm the cache.
await client.post("/api/setup", json=_SETUP_PAYLOAD)
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
assert app.state._setup_complete_cached is True
assert app.state.setup_complete_cached is True
call_count = 0
async def _counting(db: aiosqlite.Connection) -> bool:
async def _counting(_db: aiosqlite.Connection) -> bool:
nonlocal call_count
call_count += 1
return True
@@ -380,6 +367,55 @@ class TestLifespanDatabaseDirectoryCreation:
assert tmp_path.exists()
class TestLifespanSetupCache:
"""Verify that app startup resolves setup completion into app.state."""
async def test_startup_caches_setup_completion(self, tmp_path: Path) -> None:
"""Lifespan should populate ``setup_complete_cached`` based on the DB."""
settings = Settings(
database_path=str(tmp_path / "bangui.db"),
fail2ban_socket="/tmp/fake.sock",
session_secret="test-lifespan-setup-cache-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
db = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row
await init_db(db)
await setup_service.run_setup(
db,
master_password="supersecret123",
database_path=settings.database_path,
fail2ban_socket=settings.fail2ban_socket,
timezone=settings.timezone,
session_duration_minutes=settings.session_duration_minutes,
)
await db.close()
mock_scheduler = MagicMock()
mock_scheduler.start = MagicMock()
mock_scheduler.shutdown = MagicMock()
with (
patch("app.services.geo_service.init_geoip"),
patch(
"app.services.geo_service.load_cache_from_db",
new=AsyncMock(return_value=None),
),
patch("app.tasks.health_check.register"),
patch("app.tasks.blocklist_import.register"),
patch("app.tasks.geo_cache_flush.register"),
patch("app.tasks.geo_re_resolve.register"),
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
patch("app.main.ensure_jail_configs"),
):
async with _lifespan(app):
assert app.state.setup_complete_cached is True
# ---------------------------------------------------------------------------
# Task 0.2 — Middleware redirects when app.state.db is None
# ---------------------------------------------------------------------------