Replace process-local session cache with pluggable session cache backend

This commit is contained in:
2026-04-10 19:22:02 +02:00
parent 2157502670
commit 1dfc17f4f5
6 changed files with 100 additions and 64 deletions

View File

@@ -27,6 +27,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
- Issue: `_session_cache` in `app/dependencies.py` is a process-local dict, so logout invalidation and session revocation only work within a single process.
- Propose: Define a cache interface and provide a default in-memory implementation, with the option to swap in shared cache storage (Redis, Memcached) for clustered production deployments.
- Test: Add unit tests for the cache abstraction and verify logout/invalidation behaves correctly through the configured cache implementation.
- Status: completed
4. Harden SQLite connection configuration and lifecycle
- Goal: Make application database access robust under concurrent requests and background task load.

View File

@@ -7,7 +7,6 @@ directly — to keep coupling explicit and testable.
"""
import datetime
import time
from collections.abc import AsyncGenerator
from typing import Annotated, Protocol, cast
@@ -22,7 +21,7 @@ from app.models.auth import Session
from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.utils.runtime_state import RuntimeState
from app.utils.time_utils import utc_now
from app.utils.session_cache import SessionCache
log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -37,6 +36,7 @@ class AppState(Protocol):
pending_recovery: PendingRecovery | None
last_activation: dict[str, datetime.datetime] | None
runtime_state: RuntimeState
session_cache: SessionCache
_COOKIE_NAME = "bangui_session"
@@ -50,40 +50,14 @@ _COOKIE_NAME = "bangui_session"
#: same token arriving in near-simultaneous parallel requests.
#:
#: NOTE: this cache is process-local and is not cluster-safe. In multi-worker
#: or distributed deployments, each process maintains its own cache, so logout
#: invalidation and revocation may be delayed unless a shared cache is used.
#: ``token → (Session, cache_expiry_monotonic_time)``
_session_cache: dict[str, tuple[Session, float]] = {}
def clear_session_cache() -> None:
"""Flush the entire in-memory session validation cache.
Useful in tests to prevent stale state from leaking between test cases.
This only affects the current process.
"""
_session_cache.clear()
#: or distributed deployments, the configured cache backend should provide
#: invalidation semantics appropriate for the deployment.
def _session_cache_enabled(settings: Settings) -> bool:
"""Return whether the in-memory session cache should be used."""
"""Return whether the session validation cache should be used."""
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
def invalidate_session_cache(token: str) -> None:
"""Evict *token* from the in-memory session cache.
Must be called during logout so the revoked token is no longer served
from cache without a DB round-trip. This invalidation is local to the
current process; a clustered deployment would need a shared cache for
global invalidation.
Args:
token: The session token to remove.
"""
_session_cache.pop(token, None)
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
"""Provide a request-scoped :class:`aiosqlite.Connection` for the current request.
@@ -188,6 +162,20 @@ async def get_fail2ban_start_command(settings: Settings = Depends(get_settings))
"""Provide the configured fail2ban start command."""
return settings.fail2ban_start_command
async def get_session_cache(request: Request) -> SessionCache:
"""Provide the configured session cache backend from application state."""
state = cast("AppState", request.app.state)
session_cache = getattr(state, "session_cache", None)
if session_cache is None:
log.error("session_cache_unavailable")
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Session cache is not available.",
)
return session_cache
async def get_app_state(request: Request) -> AppState:
"""Provide the application state object for the current request."""
return cast("AppState", request.app.state)
@@ -212,6 +200,7 @@ async def require_auth(
request: Request,
db: Annotated[aiosqlite.Connection, Depends(get_db)],
settings: Annotated[Settings, Depends(get_settings)],
session_cache: Annotated[SessionCache, Depends(get_session_cache)],
) -> Session:
"""Validate the session token and return the active session.
@@ -223,7 +212,7 @@ async def require_auth(
round-trips. This cache is disabled by default because process-local
invalidation is not safe in multi-worker or clustered deployments.
When enabled, entries are bypassed on expiry and explicitly cleared by
:func:`invalidate_session_cache` on logout.
the configured session cache backend on logout.
Args:
request: The incoming FastAPI request.
@@ -253,15 +242,9 @@ async def require_auth(
cache_enabled = _session_cache_enabled(settings)
if cache_enabled:
# Fast path: serve from in-memory cache when the entry is still fresh and
# the session itself has not yet exceeded its own expiry time.
cached = _session_cache.get(token)
cached = session_cache.get(token)
if cached is not None:
session, cache_expires_at = cached
if time.monotonic() < cache_expires_at and session.expires_at > utc_now().isoformat():
return session
# Stale cache entry — evict and fall through to DB.
_session_cache.pop(token, None)
return cached
try:
session = await auth_service.validate_session(db, token, settings.session_secret)
@@ -273,10 +256,7 @@ async def require_auth(
) from exc
if cache_enabled:
_session_cache[token] = (
session,
time.monotonic() + settings.session_cache_ttl_seconds,
)
session_cache.set(token, session, settings.session_cache_ttl_seconds)
return session
@@ -290,6 +270,7 @@ Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
AppStateDep = Annotated[AppState, Depends(get_app_state)]
AppDep = Annotated[FastAPI, Depends(get_app)]
AuthDep = Annotated[Session, Depends(require_auth)]

View File

@@ -46,6 +46,7 @@ from app.routers import (
from app.startup import startup_shared_resources
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
from app.utils.runtime_state import ApplicationState, RuntimeState
from app.utils.session_cache import InMemorySessionCache
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -289,6 +290,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
# shared Starlette state bag itself does not hold mutable business state.
app.state = ApplicationState(RuntimeState())
app.state.settings = resolved_settings
app.state.session_cache = InMemorySessionCache()
set_setup_complete_cache(app, False)
# --- CORS ---

View File

@@ -12,7 +12,7 @@ from __future__ import annotations
import structlog
from fastapi import APIRouter, HTTPException, Request, Response, status
from app.dependencies import DbDep, SettingsDep, invalidate_session_cache
from app.dependencies import DbDep, SessionCacheDep, SettingsDep
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
from app.services import auth_service
@@ -88,6 +88,7 @@ async def logout(
response: Response,
db: DbDep,
settings: SettingsDep,
session_cache: SessionCacheDep,
) -> LogoutResponse:
"""Invalidate the active session.
@@ -108,8 +109,8 @@ async def logout(
if token:
raw_token = await auth_service.logout(db, token, settings.session_secret)
if raw_token:
invalidate_session_cache(raw_token)
invalidate_session_cache(token)
session_cache.invalidate(raw_token)
session_cache.invalidate(token)
response.delete_cookie(key=_COOKIE_NAME)
return LogoutResponse()

View File

@@ -0,0 +1,57 @@
"""Pluggable session cache abstraction.
This module defines a cache interface for authenticated sessions and a default
process-local in-memory implementation. The backend can swap the cache
implementation without changing the authentication dependency logic.
"""
from __future__ import annotations
import time
from typing import TYPE_CHECKING, Protocol
if TYPE_CHECKING: # pragma: no cover
from app.models.auth import Session
class SessionCache(Protocol):
"""Interface for session token validation cache backends."""
def get(self, token: str) -> Session | None:
"""Return the cached session for *token*, or ``None`` if missing."""
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
"""Cache the validated *session* for *token* for *ttl_seconds*."""
def invalidate(self, token: str) -> None:
"""Remove *token* from the cache if it exists."""
def clear(self) -> None:
"""Remove all entries from the cache."""
class InMemorySessionCache:
"""A process-local session cache implementation."""
def __init__(self) -> None:
self._entries: dict[str, tuple[Session, float]] = {}
def get(self, token: str) -> Session | None:
entry = self._entries.get(token)
if entry is None:
return None
session, expires_at = entry
if time.monotonic() >= expires_at:
self._entries.pop(token, None)
return None
return session
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
self._entries[token] = (session, time.monotonic() + ttl_seconds)
def invalidate(self, token: str) -> None:
self._entries.pop(token, None)
def clear(self) -> None:
self._entries.clear()

View File

@@ -209,13 +209,11 @@ class TestRequireAuthSessionCache:
"""In-memory session token cache inside ``require_auth``."""
@pytest.fixture(autouse=True)
def reset_cache(self) -> Generator[None, None, None]:
def reset_cache(self, client: AsyncClient) -> Generator[None, None, None]:
"""Flush the session cache before and after every test in this class."""
from app import dependencies
dependencies.clear_session_cache()
client._transport.app.state.session_cache.clear()
yield
dependencies.clear_session_cache()
client._transport.app.state.session_cache.clear()
@pytest.fixture(autouse=True)
def enable_session_cache(self, client: AsyncClient) -> Generator[None, None, None]:
@@ -237,9 +235,7 @@ class TestRequireAuthSessionCache:
token = await _login(client)
# Ensure cache is empty so the first request definitely hits the DB.
from app import dependencies
dependencies.clear_session_cache()
client._transport.app.state.session_cache.clear()
call_count = 0
original_get_session = session_repo.get_session
@@ -268,27 +264,25 @@ class TestRequireAuthSessionCache:
async def test_token_enters_cache_after_first_auth(
self, client: AsyncClient
) -> None:
"""A successful auth request places the token in ``_session_cache``."""
from app import dependencies
"""A successful auth request places the token in the session cache."""
await _do_setup(client)
token = await _login(client)
dependencies.clear_session_cache()
assert token not in dependencies._session_cache
client._transport.app.state.session_cache.clear()
assert client._transport.app.state.session_cache.get(token) is None
await client.get(
"/api/dashboard/status",
headers={"Authorization": f"Bearer {token}"},
)
assert token in dependencies._session_cache
assert client._transport.app.state.session_cache.get(token) is not None
async def test_logout_evicts_token_from_cache(
self, client: AsyncClient
) -> None:
"""Logout removes the session token from the in-memory cache immediately."""
from app import dependencies
"""Logout removes the session token from the session cache immediately."""
await _do_setup(client)
token = await _login(client)
@@ -298,14 +292,14 @@ class TestRequireAuthSessionCache:
"/api/dashboard/status",
headers={"Authorization": f"Bearer {token}"},
)
assert token in dependencies._session_cache
assert client._transport.app.state.session_cache.get(token) is not None
# Logout must evict the entry.
await client.post(
"/api/auth/logout",
headers={"Authorization": f"Bearer {token}"},
)
assert token not in dependencies._session_cache
assert client._transport.app.state.session_cache.get(token) is None
response = await client.get("/api/health")
assert response.status_code == 200