Replace process-local session cache with pluggable session cache backend

This commit is contained in:
2026-04-10 19:22:02 +02:00
parent 2157502670
commit 1dfc17f4f5
6 changed files with 100 additions and 64 deletions

View File

@@ -27,6 +27,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
- Issue: `_session_cache` in `app/dependencies.py` is a process-local dict, so logout invalidation and session revocation only work within a single process. - Issue: `_session_cache` in `app/dependencies.py` is a process-local dict, so logout invalidation and session revocation only work within a single process.
- Propose: Define a cache interface and provide a default in-memory implementation, with the option to swap in shared cache storage (Redis, Memcached) for clustered production deployments. - Propose: Define a cache interface and provide a default in-memory implementation, with the option to swap in shared cache storage (Redis, Memcached) for clustered production deployments.
- Test: Add unit tests for the cache abstraction and verify logout/invalidation behaves correctly through the configured cache implementation. - Test: Add unit tests for the cache abstraction and verify logout/invalidation behaves correctly through the configured cache implementation.
- Status: completed
4. Harden SQLite connection configuration and lifecycle 4. Harden SQLite connection configuration and lifecycle
- Goal: Make application database access robust under concurrent requests and background task load. - Goal: Make application database access robust under concurrent requests and background task load.

View File

@@ -7,7 +7,6 @@ directly — to keep coupling explicit and testable.
""" """
import datetime import datetime
import time
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Annotated, Protocol, cast from typing import Annotated, Protocol, cast
@@ -22,7 +21,7 @@ from app.models.auth import Session
from app.models.config import PendingRecovery from app.models.config import PendingRecovery
from app.models.server import ServerStatus from app.models.server import ServerStatus
from app.utils.runtime_state import RuntimeState from app.utils.runtime_state import RuntimeState
from app.utils.time_utils import utc_now from app.utils.session_cache import SessionCache
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -37,6 +36,7 @@ class AppState(Protocol):
pending_recovery: PendingRecovery | None pending_recovery: PendingRecovery | None
last_activation: dict[str, datetime.datetime] | None last_activation: dict[str, datetime.datetime] | None
runtime_state: RuntimeState runtime_state: RuntimeState
session_cache: SessionCache
_COOKIE_NAME = "bangui_session" _COOKIE_NAME = "bangui_session"
@@ -50,40 +50,14 @@ _COOKIE_NAME = "bangui_session"
#: same token arriving in near-simultaneous parallel requests. #: same token arriving in near-simultaneous parallel requests.
#: #:
#: NOTE: this cache is process-local and is not cluster-safe. In multi-worker #: NOTE: this cache is process-local and is not cluster-safe. In multi-worker
#: or distributed deployments, each process maintains its own cache, so logout #: or distributed deployments, the configured cache backend should provide
#: invalidation and revocation may be delayed unless a shared cache is used. #: invalidation semantics appropriate for the deployment.
#: ``token → (Session, cache_expiry_monotonic_time)``
_session_cache: dict[str, tuple[Session, float]] = {}
def clear_session_cache() -> None:
"""Flush the entire in-memory session validation cache.
Useful in tests to prevent stale state from leaking between test cases.
This only affects the current process.
"""
_session_cache.clear()
def _session_cache_enabled(settings: Settings) -> bool: def _session_cache_enabled(settings: Settings) -> bool:
"""Return whether the in-memory session cache should be used.""" """Return whether the session validation cache should be used."""
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0 return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
def invalidate_session_cache(token: str) -> None:
"""Evict *token* from the in-memory session cache.
Must be called during logout so the revoked token is no longer served
from cache without a DB round-trip. This invalidation is local to the
current process; a clustered deployment would need a shared cache for
global invalidation.
Args:
token: The session token to remove.
"""
_session_cache.pop(token, None)
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]: async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
"""Provide a request-scoped :class:`aiosqlite.Connection` for the current request. """Provide a request-scoped :class:`aiosqlite.Connection` for the current request.
@@ -188,6 +162,20 @@ async def get_fail2ban_start_command(settings: Settings = Depends(get_settings))
"""Provide the configured fail2ban start command.""" """Provide the configured fail2ban start command."""
return settings.fail2ban_start_command return settings.fail2ban_start_command
async def get_session_cache(request: Request) -> SessionCache:
"""Provide the configured session cache backend from application state."""
state = cast("AppState", request.app.state)
session_cache = getattr(state, "session_cache", None)
if session_cache is None:
log.error("session_cache_unavailable")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Session cache is not available.",
)
return session_cache
async def get_app_state(request: Request) -> AppState: async def get_app_state(request: Request) -> AppState:
"""Provide the application state object for the current request.""" """Provide the application state object for the current request."""
return cast("AppState", request.app.state) return cast("AppState", request.app.state)
@@ -212,6 +200,7 @@ async def require_auth(
request: Request, request: Request,
db: Annotated[aiosqlite.Connection, Depends(get_db)], db: Annotated[aiosqlite.Connection, Depends(get_db)],
settings: Annotated[Settings, Depends(get_settings)], settings: Annotated[Settings, Depends(get_settings)],
session_cache: Annotated[SessionCache, Depends(get_session_cache)],
) -> Session: ) -> Session:
"""Validate the session token and return the active session. """Validate the session token and return the active session.
@@ -223,7 +212,7 @@ async def require_auth(
round-trips. This cache is disabled by default because process-local round-trips. This cache is disabled by default because process-local
invalidation is not safe in multi-worker or clustered deployments. invalidation is not safe in multi-worker or clustered deployments.
When enabled, entries are bypassed on expiry and explicitly cleared by When enabled, entries are bypassed on expiry and explicitly cleared by
:func:`invalidate_session_cache` on logout. the configured session cache backend on logout.
Args: Args:
request: The incoming FastAPI request. request: The incoming FastAPI request.
@@ -253,15 +242,9 @@ async def require_auth(
cache_enabled = _session_cache_enabled(settings) cache_enabled = _session_cache_enabled(settings)
if cache_enabled: if cache_enabled:
# Fast path: serve from in-memory cache when the entry is still fresh and cached = session_cache.get(token)
# the session itself has not yet exceeded its own expiry time.
cached = _session_cache.get(token)
if cached is not None: if cached is not None:
session, cache_expires_at = cached return cached
if time.monotonic() < cache_expires_at and session.expires_at > utc_now().isoformat():
return session
# Stale cache entry — evict and fall through to DB.
_session_cache.pop(token, None)
try: try:
session = await auth_service.validate_session(db, token, settings.session_secret) session = await auth_service.validate_session(db, token, settings.session_secret)
@@ -273,10 +256,7 @@ async def require_auth(
) from exc ) from exc
if cache_enabled: if cache_enabled:
_session_cache[token] = ( session_cache.set(token, session, settings.session_cache_ttl_seconds)
session,
time.monotonic() + settings.session_cache_ttl_seconds,
)
return session return session
@@ -290,6 +270,7 @@ Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)] Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)] ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)] PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
AppStateDep = Annotated[AppState, Depends(get_app_state)] AppStateDep = Annotated[AppState, Depends(get_app_state)]
AppDep = Annotated[FastAPI, Depends(get_app)] AppDep = Annotated[FastAPI, Depends(get_app)]
AuthDep = Annotated[Session, Depends(require_auth)] AuthDep = Annotated[Session, Depends(require_auth)]

View File

@@ -46,6 +46,7 @@ from app.routers import (
from app.startup import startup_shared_resources from app.startup import startup_shared_resources
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
from app.utils.runtime_state import ApplicationState, RuntimeState from app.utils.runtime_state import ApplicationState, RuntimeState
from app.utils.session_cache import InMemorySessionCache
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -289,6 +290,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
# shared Starlette state bag itself does not hold mutable business state. # shared Starlette state bag itself does not hold mutable business state.
app.state = ApplicationState(RuntimeState()) app.state = ApplicationState(RuntimeState())
app.state.settings = resolved_settings app.state.settings = resolved_settings
app.state.session_cache = InMemorySessionCache()
set_setup_complete_cache(app, False) set_setup_complete_cache(app, False)
# --- CORS --- # --- CORS ---

View File

@@ -12,7 +12,7 @@ from __future__ import annotations
import structlog import structlog
from fastapi import APIRouter, HTTPException, Request, Response, status from fastapi import APIRouter, HTTPException, Request, Response, status
from app.dependencies import DbDep, SettingsDep, invalidate_session_cache from app.dependencies import DbDep, SessionCacheDep, SettingsDep
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
from app.services import auth_service from app.services import auth_service
@@ -88,6 +88,7 @@ async def logout(
response: Response, response: Response,
db: DbDep, db: DbDep,
settings: SettingsDep, settings: SettingsDep,
session_cache: SessionCacheDep,
) -> LogoutResponse: ) -> LogoutResponse:
"""Invalidate the active session. """Invalidate the active session.
@@ -108,8 +109,8 @@ async def logout(
if token: if token:
raw_token = await auth_service.logout(db, token, settings.session_secret) raw_token = await auth_service.logout(db, token, settings.session_secret)
if raw_token: if raw_token:
invalidate_session_cache(raw_token) session_cache.invalidate(raw_token)
invalidate_session_cache(token) session_cache.invalidate(token)
response.delete_cookie(key=_COOKIE_NAME) response.delete_cookie(key=_COOKIE_NAME)
return LogoutResponse() return LogoutResponse()

View File

@@ -0,0 +1,57 @@
"""Pluggable session cache abstraction.
This module defines a cache interface for authenticated sessions and a default
process-local in-memory implementation. The backend can swap the cache
implementation without changing the authentication dependency logic.
"""
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING: # pragma: no cover
from app.models.auth import Session
class SessionCache(Protocol):
"""Interface for session token validation cache backends."""
def get(self, token: str) -> Session | None:
"""Return the cached session for *token*, or ``None`` if missing."""
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
"""Cache the validated *session* for *token* for *ttl_seconds*."""
def invalidate(self, token: str) -> None:
"""Remove *token* from the cache if it exists."""
def clear(self) -> None:
"""Remove all entries from the cache."""
class InMemorySessionCache:
"""A process-local session cache implementation."""
def __init__(self) -> None:
self._entries: dict[str, tuple[Session, float]] = {}
def get(self, token: str) -> Session | None:
entry = self._entries.get(token)
if entry is None:
return None
session, expires_at = entry
if time.monotonic() >= expires_at:
self._entries.pop(token, None)
return None
return session
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
self._entries[token] = (session, time.monotonic() + ttl_seconds)
def invalidate(self, token: str) -> None:
self._entries.pop(token, None)
def clear(self) -> None:
self._entries.clear()

View File

@@ -209,13 +209,11 @@ class TestRequireAuthSessionCache:
"""In-memory session token cache inside ``require_auth``.""" """In-memory session token cache inside ``require_auth``."""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_cache(self) -> Generator[None, None, None]: def reset_cache(self, client: AsyncClient) -> Generator[None, None, None]:
"""Flush the session cache before and after every test in this class.""" """Flush the session cache before and after every test in this class."""
from app import dependencies client._transport.app.state.session_cache.clear()
dependencies.clear_session_cache()
yield yield
dependencies.clear_session_cache() client._transport.app.state.session_cache.clear()
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def enable_session_cache(self, client: AsyncClient) -> Generator[None, None, None]: def enable_session_cache(self, client: AsyncClient) -> Generator[None, None, None]:
@@ -237,9 +235,7 @@ class TestRequireAuthSessionCache:
token = await _login(client) token = await _login(client)
# Ensure cache is empty so the first request definitely hits the DB. # Ensure cache is empty so the first request definitely hits the DB.
from app import dependencies client._transport.app.state.session_cache.clear()
dependencies.clear_session_cache()
call_count = 0 call_count = 0
original_get_session = session_repo.get_session original_get_session = session_repo.get_session
@@ -268,27 +264,25 @@ class TestRequireAuthSessionCache:
async def test_token_enters_cache_after_first_auth( async def test_token_enters_cache_after_first_auth(
self, client: AsyncClient self, client: AsyncClient
) -> None: ) -> None:
"""A successful auth request places the token in ``_session_cache``.""" """A successful auth request places the token in the session cache."""
from app import dependencies
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
dependencies.clear_session_cache() client._transport.app.state.session_cache.clear()
assert token not in dependencies._session_cache assert client._transport.app.state.session_cache.get(token) is None
await client.get( await client.get(
"/api/dashboard/status", "/api/dashboard/status",
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}"},
) )
assert token in dependencies._session_cache assert client._transport.app.state.session_cache.get(token) is not None
async def test_logout_evicts_token_from_cache( async def test_logout_evicts_token_from_cache(
self, client: AsyncClient self, client: AsyncClient
) -> None: ) -> None:
"""Logout removes the session token from the in-memory cache immediately.""" """Logout removes the session token from the session cache immediately."""
from app import dependencies
await _do_setup(client) await _do_setup(client)
token = await _login(client) token = await _login(client)
@@ -298,14 +292,14 @@ class TestRequireAuthSessionCache:
"/api/dashboard/status", "/api/dashboard/status",
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}"},
) )
assert token in dependencies._session_cache assert client._transport.app.state.session_cache.get(token) is not None
# Logout must evict the entry. # Logout must evict the entry.
await client.post( await client.post(
"/api/auth/logout", "/api/auth/logout",
headers={"Authorization": f"Bearer {token}"}, headers={"Authorization": f"Bearer {token}"},
) )
assert token not in dependencies._session_cache assert client._transport.app.state.session_cache.get(token) is None
response = await client.get("/api/health") response = await client.get("/api/health")
assert response.status_code == 200 assert response.status_code == 200