Convert setup guard to startup-driven cache and update tests
This commit is contained in:
@@ -27,7 +27,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
|||||||
- Replace it with a deterministic packaging or configuration model so the backend does not depend on repository layout.
|
- Replace it with a deterministic packaging or configuration model so the backend does not depend on repository layout.
|
||||||
- Refactor `Fail2BanClient` so concurrency control is instance-based and not backed by hidden module globals.
|
- Refactor `Fail2BanClient` so concurrency control is instance-based and not backed by hidden module globals.
|
||||||
|
|
||||||
- **Improve startup / setup guard behavior.**
|
- **Improve startup / setup guard behavior.** ✅
|
||||||
- Convert `SetupRedirectMiddleware` from an on-demand DB check into a startup/initialisation guard where possible.
|
- Convert `SetupRedirectMiddleware` from an on-demand DB check into a startup/initialisation guard where possible.
|
||||||
- Cache setup completion in a safe way and provide an explicit invalidation path if the application state changes.
|
- Cache setup completion in a safe way and provide an explicit invalidation path if the application state changes.
|
||||||
- Reduce middleware responsibility and avoid DB access during normal request dispatch.
|
- Reduce middleware responsibility and avoid DB access during normal request dispatch.
|
||||||
@@ -75,6 +75,6 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
|||||||
1. ✅ Fix the global SQLite connection pattern and tests.
|
1. ✅ Fix the global SQLite connection pattern and tests.
|
||||||
2. ✅ Refactor dependency injection / explicit shared resources.
|
2. ✅ Refactor dependency injection / explicit shared resources.
|
||||||
3. ✅ Harden fail2ban client concurrency and packaging.
|
3. ✅ Harden fail2ban client concurrency and packaging.
|
||||||
4. Convert setup guard to a safer startup-driven model.
|
4. ✅ Convert setup guard to a safer startup-driven model.
|
||||||
5. Add deployment-safe configuration and production-ready CORS.
|
5. Add deployment-safe configuration and production-ready CORS.
|
||||||
6. Add lifecycle and concurrency regression tests.
|
6. Add lifecycle and concurrency regression tests.
|
||||||
|
|||||||
@@ -50,6 +50,10 @@ from app.routers import (
|
|||||||
from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check, history_sync
|
from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check, history_sync
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
|
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
|
||||||
from app.utils.jail_config import ensure_jail_configs
|
from app.utils.jail_config import ensure_jail_configs
|
||||||
|
from app.utils.setup_state import (
|
||||||
|
is_setup_complete_cached,
|
||||||
|
set_setup_complete_cache,
|
||||||
|
)
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -122,6 +126,11 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
await init_db(db)
|
await init_db(db)
|
||||||
await geo_service.load_cache_from_db(db)
|
await geo_service.load_cache_from_db(db)
|
||||||
unresolved_count = await geo_service.count_unresolved(db)
|
unresolved_count = await geo_service.count_unresolved(db)
|
||||||
|
from app.services import setup_service # noqa: PLC0415
|
||||||
|
|
||||||
|
setup_complete = await setup_service.is_setup_complete(db)
|
||||||
|
set_setup_complete_cache(app, setup_complete)
|
||||||
|
log.debug("setup_completion_cached", completed=setup_complete)
|
||||||
finally:
|
finally:
|
||||||
await db.close()
|
await db.close()
|
||||||
|
|
||||||
@@ -133,8 +142,6 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
app.state.http_session = http_session
|
app.state.http_session = http_session
|
||||||
|
|
||||||
# --- Pre-warm geo cache from the persistent store ---
|
# --- Pre-warm geo cache from the persistent store ---
|
||||||
from app.services import geo_service # noqa: PLC0415
|
|
||||||
|
|
||||||
geo_service.init_geoip(settings.geoip_db_path)
|
geo_service.init_geoip(settings.geoip_db_path)
|
||||||
|
|
||||||
# --- Background task scheduler ---
|
# --- Background task scheduler ---
|
||||||
@@ -292,39 +299,14 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware):
|
|||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
# If setup is not complete, block all other API requests.
|
# If setup is not complete, block all other API requests.
|
||||||
# Fast path: setup completion is a one-way transition. Once it is
|
# The setup completion state is resolved at startup and stored in
|
||||||
# True it is cached on app.state so all subsequent requests skip the
|
# ``app.state.setup_complete_cached`` so this middleware does not
|
||||||
# DB query entirely. The flag is reset only when the app restarts.
|
# perform any database queries during normal request handling.
|
||||||
if path.startswith("/api") and not getattr(
|
if path.startswith("/api") and not is_setup_complete_cached(request.app):
|
||||||
request.app.state, "_setup_complete_cached", False
|
return RedirectResponse(
|
||||||
):
|
url="/api/setup",
|
||||||
from app.db import open_db # noqa: PLC0415
|
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
|
||||||
from app.services import setup_service # noqa: PLC0415
|
)
|
||||||
|
|
||||||
db = getattr(request.app.state, "db", None)
|
|
||||||
if db is None:
|
|
||||||
settings = request.app.state.settings
|
|
||||||
db = await open_db(settings.database_path)
|
|
||||||
try:
|
|
||||||
is_complete = await setup_service.is_setup_complete(db)
|
|
||||||
except Exception:
|
|
||||||
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
|
|
||||||
is_complete = False
|
|
||||||
finally:
|
|
||||||
await db.close()
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
is_complete = await setup_service.is_setup_complete(db)
|
|
||||||
except Exception:
|
|
||||||
log.debug("setup_check_failed", reason="db_uninitialised_or_inaccessible")
|
|
||||||
is_complete = False
|
|
||||||
|
|
||||||
if not is_complete:
|
|
||||||
return RedirectResponse(
|
|
||||||
url="/api/setup",
|
|
||||||
status_code=status.HTTP_307_TEMPORARY_REDIRECT,
|
|
||||||
)
|
|
||||||
request.app.state._setup_complete_cached = True
|
|
||||||
|
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
@@ -360,6 +342,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
|||||||
|
|
||||||
# Store settings on app.state so the lifespan handler can access them.
|
# Store settings on app.state so the lifespan handler can access them.
|
||||||
app.state.settings = resolved_settings
|
app.state.settings = resolved_settings
|
||||||
|
set_setup_complete_cache(app, False)
|
||||||
|
|
||||||
# --- CORS ---
|
# --- CORS ---
|
||||||
# In production the frontend is served by the same origin.
|
# In production the frontend is served by the same origin.
|
||||||
|
|||||||
@@ -8,11 +8,12 @@ return ``409 Conflict``.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from fastapi import APIRouter, HTTPException, status
|
from fastapi import APIRouter, HTTPException, Request, status
|
||||||
|
|
||||||
from app.dependencies import DbDep
|
from app.dependencies import DbDep
|
||||||
from app.models.setup import SetupRequest, SetupResponse, SetupStatusResponse, SetupTimezoneResponse
|
from app.models.setup import SetupRequest, SetupResponse, SetupStatusResponse, SetupTimezoneResponse
|
||||||
from app.services import setup_service
|
from app.services import setup_service
|
||||||
|
from app.utils.setup_state import set_setup_complete_cache
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -41,10 +42,15 @@ async def get_setup_status(db: DbDep) -> SetupStatusResponse:
|
|||||||
status_code=status.HTTP_201_CREATED,
|
status_code=status.HTTP_201_CREATED,
|
||||||
summary="Run the initial setup wizard",
|
summary="Run the initial setup wizard",
|
||||||
)
|
)
|
||||||
async def post_setup(body: SetupRequest, db: DbDep) -> SetupResponse:
|
async def post_setup(
|
||||||
|
request: Request,
|
||||||
|
body: SetupRequest,
|
||||||
|
db: DbDep,
|
||||||
|
) -> SetupResponse:
|
||||||
"""Persist the initial BanGUI configuration.
|
"""Persist the initial BanGUI configuration.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
request: The incoming HTTP request.
|
||||||
body: Setup request payload validated by Pydantic.
|
body: Setup request payload validated by Pydantic.
|
||||||
db: Injected aiosqlite connection.
|
db: Injected aiosqlite connection.
|
||||||
|
|
||||||
@@ -68,6 +74,7 @@ async def post_setup(body: SetupRequest, db: DbDep) -> SetupResponse:
|
|||||||
timezone=body.timezone,
|
timezone=body.timezone,
|
||||||
session_duration_minutes=body.session_duration_minutes,
|
session_duration_minutes=body.session_duration_minutes,
|
||||||
)
|
)
|
||||||
|
set_setup_complete_cache(request.app, True)
|
||||||
return SetupResponse()
|
return SetupResponse()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
24
backend/app/utils/setup_state.py
Normal file
24
backend/app/utils/setup_state.py
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
"""Manage the cached setup completion flag stored on application state."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
|
||||||
|
def is_setup_complete_cached(app: FastAPI) -> bool:
|
||||||
|
"""Return the cached setup completion state from application state."""
|
||||||
|
return getattr(app.state, "setup_complete_cached", False)
|
||||||
|
|
||||||
|
|
||||||
|
def set_setup_complete_cache(app: FastAPI, completed: bool) -> None:
|
||||||
|
"""Set the cached setup completion state on application state."""
|
||||||
|
app.state.setup_complete_cached = completed
|
||||||
|
|
||||||
|
|
||||||
|
def invalidate_setup_complete_cache(app: FastAPI) -> None:
|
||||||
|
"""Reset the cached setup completion state.
|
||||||
|
|
||||||
|
This helper exists so the cache can be invalidated explicitly if the
|
||||||
|
application state changes during runtime.
|
||||||
|
"""
|
||||||
|
app.state.setup_complete_cached = False
|
||||||
@@ -12,6 +12,7 @@ from httpx import ASGITransport, AsyncClient
|
|||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
from app.db import init_db
|
from app.db import init_db
|
||||||
from app.main import _lifespan, create_app
|
from app.main import _lifespan, create_app
|
||||||
|
from app.services import setup_service
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Shared setup payload
|
# Shared setup payload
|
||||||
@@ -224,32 +225,18 @@ class TestGetTimezone:
|
|||||||
class TestSetupCompleteCaching:
|
class TestSetupCompleteCaching:
|
||||||
"""SetupRedirectMiddleware caches the setup_complete flag in ``app.state``."""
|
"""SetupRedirectMiddleware caches the setup_complete flag in ``app.state``."""
|
||||||
|
|
||||||
async def test_cache_flag_set_after_first_post_setup_request(
|
async def test_cache_flag_set_after_post_setup(
|
||||||
self,
|
self, app_and_client: tuple[object, AsyncClient]
|
||||||
app_and_client: tuple[object, AsyncClient],
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""``_setup_complete_cached`` is set to True on the first request after setup.
|
"""``setup_complete_cached`` is set to True immediately after setup."""
|
||||||
|
|
||||||
The ``/api/setup`` path is in ``_ALWAYS_ALLOWED`` so it bypasses the
|
|
||||||
middleware check. The first request to a non-exempt endpoint triggers
|
|
||||||
the DB query and, when setup is complete, populates the cache flag.
|
|
||||||
"""
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
|
||||||
app, client = app_and_client
|
app, client = app_and_client
|
||||||
assert isinstance(app, FastAPI)
|
assert isinstance(app, FastAPI)
|
||||||
|
|
||||||
# Complete setup (exempt from middleware, no flag set yet).
|
|
||||||
resp = await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
resp = await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||||
assert resp.status_code == 201
|
assert resp.status_code == 201
|
||||||
|
assert app.state.setup_complete_cached is True
|
||||||
# Flag not yet cached — setup was via an exempt path.
|
|
||||||
assert not getattr(app.state, "_setup_complete_cached", False)
|
|
||||||
|
|
||||||
# First non-exempt request — middleware queries DB and sets the flag.
|
|
||||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
|
||||||
|
|
||||||
assert app.state._setup_complete_cached is True
|
|
||||||
|
|
||||||
async def test_cached_path_skips_is_setup_complete(
|
async def test_cached_path_skips_is_setup_complete(
|
||||||
self,
|
self,
|
||||||
@@ -268,11 +255,11 @@ class TestSetupCompleteCaching:
|
|||||||
# Do setup and warm the cache.
|
# Do setup and warm the cache.
|
||||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||||
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
|
||||||
assert app.state._setup_complete_cached is True
|
assert app.state.setup_complete_cached is True
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
async def _counting(db: aiosqlite.Connection) -> bool:
|
async def _counting(_db: aiosqlite.Connection) -> bool:
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
return True
|
return True
|
||||||
@@ -380,6 +367,55 @@ class TestLifespanDatabaseDirectoryCreation:
|
|||||||
assert tmp_path.exists()
|
assert tmp_path.exists()
|
||||||
|
|
||||||
|
|
||||||
|
class TestLifespanSetupCache:
|
||||||
|
"""Verify that app startup resolves setup completion into app.state."""
|
||||||
|
|
||||||
|
async def test_startup_caches_setup_completion(self, tmp_path: Path) -> None:
|
||||||
|
"""Lifespan should populate ``setup_complete_cached`` based on the DB."""
|
||||||
|
settings = Settings(
|
||||||
|
database_path=str(tmp_path / "bangui.db"),
|
||||||
|
fail2ban_socket="/tmp/fake.sock",
|
||||||
|
session_secret="test-lifespan-setup-cache-secret",
|
||||||
|
session_duration_minutes=60,
|
||||||
|
timezone="UTC",
|
||||||
|
log_level="debug",
|
||||||
|
)
|
||||||
|
app = create_app(settings=settings)
|
||||||
|
|
||||||
|
db = await aiosqlite.connect(settings.database_path)
|
||||||
|
db.row_factory = aiosqlite.Row
|
||||||
|
await init_db(db)
|
||||||
|
await setup_service.run_setup(
|
||||||
|
db,
|
||||||
|
master_password="supersecret123",
|
||||||
|
database_path=settings.database_path,
|
||||||
|
fail2ban_socket=settings.fail2ban_socket,
|
||||||
|
timezone=settings.timezone,
|
||||||
|
session_duration_minutes=settings.session_duration_minutes,
|
||||||
|
)
|
||||||
|
await db.close()
|
||||||
|
|
||||||
|
mock_scheduler = MagicMock()
|
||||||
|
mock_scheduler.start = MagicMock()
|
||||||
|
mock_scheduler.shutdown = MagicMock()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.geo_service.init_geoip"),
|
||||||
|
patch(
|
||||||
|
"app.services.geo_service.load_cache_from_db",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
),
|
||||||
|
patch("app.tasks.health_check.register"),
|
||||||
|
patch("app.tasks.blocklist_import.register"),
|
||||||
|
patch("app.tasks.geo_cache_flush.register"),
|
||||||
|
patch("app.tasks.geo_re_resolve.register"),
|
||||||
|
patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
|
||||||
|
patch("app.main.ensure_jail_configs"),
|
||||||
|
):
|
||||||
|
async with _lifespan(app):
|
||||||
|
assert app.state.setup_complete_cached is True
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Task 0.2 — Middleware redirects when app.state.db is None
|
# Task 0.2 — Middleware redirects when app.state.db is None
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user