Reduce per-request DB overhead (Task 4)

- Cache setup_completed flag in app.state._setup_complete_cached after
  first successful is_setup_complete() call; all subsequent API requests
  skip the DB query entirely (one-way transition, cleared on restart).
- Add in-memory session token TTL cache (10 s) in require_auth; the second
  request with the same token within the window skips session_repo.get_session.
- Call invalidate_session_cache() on logout so revoked tokens are evicted
  immediately rather than waiting for TTL expiry.
- Add clear_session_cache() for test isolation.
- 5 new tests covering the cached fast-path for both optimisations.
- 460 tests pass, 83% coverage, zero ruff/mypy warnings.
This commit is contained in:
2026-03-10 19:16:00 +01:00
parent 44a5a3d70e
commit d931e8c6a3
7 changed files with 428 additions and 17 deletions

View File

@@ -6,6 +6,7 @@ Routers import directly from this module — never from ``app.state``
directly — to keep coupling explicit and testable.
"""
import time
from typing import Annotated
import aiosqlite
@@ -14,11 +15,44 @@ from fastapi import Depends, HTTPException, Request, status
from app.config import Settings
from app.models.auth import Session
from app.utils.time_utils import utc_now
log: structlog.stdlib.BoundLogger = structlog.get_logger()
_COOKIE_NAME = "bangui_session"
# ---------------------------------------------------------------------------
# Session validation cache
# ---------------------------------------------------------------------------
#: How long (seconds) a validated session token is served from the in-memory
#: cache without re-querying SQLite. Eliminates repeated DB lookups for the
#: same token arriving in near-simultaneous parallel requests.
_SESSION_CACHE_TTL: float = 10.0
#: ``token → (Session, cache_expiry_monotonic_time)``
_session_cache: dict[str, tuple[Session, float]] = {}
def clear_session_cache() -> None:
"""Flush the entire in-memory session validation cache.
Useful in tests to prevent stale state from leaking between test cases.
"""
_session_cache.clear()
def invalidate_session_cache(token: str) -> None:
"""Evict *token* from the in-memory session cache.
Must be called during logout so the revoked token is no longer served
from cache without a DB round-trip.
Args:
token: The session token to remove.
"""
_session_cache.pop(token, None)
async def get_db(request: Request) -> aiosqlite.Connection:
"""Provide the shared :class:`aiosqlite.Connection` from ``app.state``.
@@ -63,6 +97,11 @@ async def require_auth(
The token is read from the ``bangui_session`` cookie or the
``Authorization: Bearer`` header.
Validated tokens are cached in memory for :data:`_SESSION_CACHE_TTL`
seconds so that concurrent requests sharing the same token avoid repeated
SQLite round-trips. The cache is bypassed on expiry and explicitly
cleared by :func:`invalidate_session_cache` on logout.
Args:
request: The incoming FastAPI request.
db: Injected aiosqlite connection.
@@ -88,8 +127,18 @@ async def require_auth(
headers={"WWW-Authenticate": "Bearer"},
)
# Fast path: serve from in-memory cache when the entry is still fresh and
# the session itself has not yet exceeded its own expiry time.
cached = _session_cache.get(token)
if cached is not None:
session, cache_expires_at = cached
if time.monotonic() < cache_expires_at and session.expires_at > utc_now().isoformat():
return session
# Stale cache entry — evict and fall through to DB.
_session_cache.pop(token, None)
try:
return await auth_service.validate_session(db, token)
session = await auth_service.validate_session(db, token)
except ValueError as exc:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@@ -97,6 +146,9 @@ async def require_auth(
headers={"WWW-Authenticate": "Bearer"},
) from exc
_session_cache[token] = (session, time.monotonic() + _SESSION_CACHE_TTL)
return session
# Convenience type aliases for route signatures.
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]

View File

@@ -289,12 +289,19 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware):
return await call_next(request)
# If setup is not complete, block all other API requests.
if path.startswith("/api"):
# Fast path: setup completion is a one-way transition. Once it is
# True it is cached on app.state so all subsequent requests skip the
# DB query entirely. The flag is reset only when the app restarts.
if path.startswith("/api") and not getattr(
request.app.state, "_setup_complete_cached", False
):
db: aiosqlite.Connection | None = getattr(request.app.state, "db", None)
if db is not None:
from app.services import setup_service # noqa: PLC0415
if not await setup_service.is_setup_complete(db):
if await setup_service.is_setup_complete(db):
request.app.state._setup_complete_cached = True
else:
return RedirectResponse(
url="/api/setup",
status_code=status.HTTP_307_TEMPORARY_REDIRECT,

View File

@@ -12,7 +12,7 @@ from __future__ import annotations
import structlog
from fastapi import APIRouter, HTTPException, Request, Response, status
from app.dependencies import DbDep, SettingsDep
from app.dependencies import DbDep, SettingsDep, invalidate_session_cache
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
from app.services import auth_service
@@ -101,6 +101,7 @@ async def logout(
token = _extract_token(request)
if token:
await auth_service.logout(db, token)
invalidate_session_cache(token)
response.delete_cookie(key=_COOKIE_NAME)
return LogoutResponse()

View File

@@ -2,6 +2,9 @@
from __future__ import annotations
from unittest.mock import patch
import pytest
from httpx import AsyncClient
# ---------------------------------------------------------------------------
@@ -143,5 +146,107 @@ class TestRequireAuth:
self, client: AsyncClient
) -> None:
"""Health endpoint is accessible without authentication."""
# ---------------------------------------------------------------------------
# Session-token cache (Task 4)
# ---------------------------------------------------------------------------
class TestRequireAuthSessionCache:
"""In-memory session token cache inside ``require_auth``."""
@pytest.fixture(autouse=True)
def reset_cache(self) -> None: # type: ignore[misc]
"""Flush the session cache before and after every test in this class."""
from app import dependencies
dependencies.clear_session_cache()
yield # type: ignore[misc]
dependencies.clear_session_cache()
async def test_second_request_skips_db(self, client: AsyncClient) -> None:
"""Second authenticated request within TTL skips the session DB query.
The first request populates the in-memory cache via ``require_auth``.
The second request — using the same token before the TTL expires —
must return ``session_repo.get_session`` *without* calling it.
"""
from app.repositories import session_repo
await _do_setup(client)
token = await _login(client)
# Ensure cache is empty so the first request definitely hits the DB.
from app import dependencies
dependencies.clear_session_cache()
call_count = 0
original_get_session = session_repo.get_session
async def _tracking(db, tok): # type: ignore[no-untyped-def]
nonlocal call_count
call_count += 1
return await original_get_session(db, tok)
with patch.object(session_repo, "get_session", side_effect=_tracking):
resp1 = await client.get(
"/api/dashboard/status",
headers={"Authorization": f"Bearer {token}"},
)
resp2 = await client.get(
"/api/dashboard/status",
headers={"Authorization": f"Bearer {token}"},
)
assert resp1.status_code == 200
assert resp2.status_code == 200
# DB queried exactly once: the first request populates the cache,
# the second request is served entirely from memory.
assert call_count == 1
async def test_token_enters_cache_after_first_auth(
self, client: AsyncClient
) -> None:
"""A successful auth request places the token in ``_session_cache``."""
from app import dependencies
await _do_setup(client)
token = await _login(client)
dependencies.clear_session_cache()
assert token not in dependencies._session_cache
await client.get(
"/api/dashboard/status",
headers={"Authorization": f"Bearer {token}"},
)
assert token in dependencies._session_cache
async def test_logout_evicts_token_from_cache(
self, client: AsyncClient
) -> None:
"""Logout removes the session token from the in-memory cache immediately."""
from app import dependencies
await _do_setup(client)
token = await _login(client)
# Warm the cache.
await client.get(
"/api/dashboard/status",
headers={"Authorization": f"Bearer {token}"},
)
assert token in dependencies._session_cache
# Logout must evict the entry.
await client.post(
"/api/auth/logout",
headers={"Authorization": f"Bearer {token}"},
)
assert token not in dependencies._session_cache
response = await client.get("/api/health")
assert response.status_code == 200

View File

@@ -2,7 +2,65 @@
from __future__ import annotations
from httpx import AsyncClient
from pathlib import Path
from unittest.mock import patch
import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import create_app
# ---------------------------------------------------------------------------
# Shared setup payload
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD: dict[str, object] = {
"master_password": "supersecret123",
"database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC",
"session_duration_minutes": 60,
}
# ---------------------------------------------------------------------------
# Fixture for tests that need direct access to app.state
# ---------------------------------------------------------------------------
@pytest.fixture
async def app_and_client(tmp_path: Path) -> tuple[object, AsyncClient]: # type: ignore[misc]
"""Yield ``(app, client)`` for tests that inspect ``app.state`` directly.
Args:
tmp_path: Pytest-provided isolated temporary directory.
Yields:
A tuple of ``(FastAPI app instance, AsyncClient)``.
"""
settings = Settings(
database_path=str(tmp_path / "setup_cache_test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
session_secret="test-setup-cache-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row
await init_db(db)
app.state.db = db
transport: ASGITransport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield app, ac
await db.close()
class TestGetSetupStatus:
@@ -156,3 +214,75 @@ class TestGetTimezone:
# Should return 200, not a 307 redirect, because /api/setup paths
# are always allowed by the SetupRedirectMiddleware.
assert response.status_code == 200
# ---------------------------------------------------------------------------
# Setup-complete flag caching in SetupRedirectMiddleware (Task 4)
# ---------------------------------------------------------------------------
class TestSetupCompleteCaching:
"""SetupRedirectMiddleware caches the setup_complete flag in ``app.state``."""
async def test_cache_flag_set_after_first_post_setup_request(
self,
app_and_client: tuple[object, AsyncClient],
) -> None:
"""``_setup_complete_cached`` is set to True on the first request after setup.
The ``/api/setup`` path is in ``_ALWAYS_ALLOWED`` so it bypasses the
middleware check. The first request to a non-exempt endpoint triggers
the DB query and, when setup is complete, populates the cache flag.
"""
from fastapi import FastAPI
app, client = app_and_client
assert isinstance(app, FastAPI)
# Complete setup (exempt from middleware, no flag set yet).
resp = await client.post("/api/setup", json=_SETUP_PAYLOAD)
assert resp.status_code == 201
# Flag not yet cached — setup was via an exempt path.
assert not getattr(app.state, "_setup_complete_cached", False)
# First non-exempt request — middleware queries DB and sets the flag.
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
async def test_cached_path_skips_is_setup_complete(
self,
app_and_client: tuple[object, AsyncClient],
) -> None:
"""Subsequent requests do not call ``is_setup_complete`` once flag is cached.
After the flag is set, the middleware must not touch the database for
any further requests — even if ``is_setup_complete`` would raise.
"""
from fastapi import FastAPI
app, client = app_and_client
assert isinstance(app, FastAPI)
# Do setup and warm the cache.
await client.post("/api/setup", json=_SETUP_PAYLOAD)
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload]
assert app.state._setup_complete_cached is True # type: ignore[attr-defined]
call_count = 0
async def _counting(db): # type: ignore[no-untyped-def]
nonlocal call_count
call_count += 1
return True
with patch("app.services.setup_service.is_setup_complete", side_effect=_counting):
await client.post(
"/api/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]},
)
# Cache was warm — is_setup_complete must not have been called.
assert call_count == 0