Replace process-local session cache with pluggable session cache backend
This commit is contained in:
@@ -27,6 +27,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
||||
- Issue: `_session_cache` in `app/dependencies.py` is a process-local dict, so logout invalidation and session revocation only work within a single process.
|
||||
- Propose: Define a cache interface and provide a default in-memory implementation, with the option to swap in shared cache storage (Redis, Memcached) for clustered production deployments.
|
||||
- Test: Add unit tests for the cache abstraction and verify logout/invalidation behaves correctly through the configured cache implementation.
|
||||
- Status: completed
|
||||
|
||||
4. Harden SQLite connection configuration and lifecycle
|
||||
- Goal: Make application database access robust under concurrent requests and background task load.
|
||||
|
||||
@@ -7,7 +7,6 @@ directly — to keep coupling explicit and testable.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Annotated, Protocol, cast
|
||||
|
||||
@@ -22,7 +21,7 @@ from app.models.auth import Session
|
||||
from app.models.config import PendingRecovery
|
||||
from app.models.server import ServerStatus
|
||||
from app.utils.runtime_state import RuntimeState
|
||||
from app.utils.time_utils import utc_now
|
||||
from app.utils.session_cache import SessionCache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -37,6 +36,7 @@ class AppState(Protocol):
|
||||
pending_recovery: PendingRecovery | None
|
||||
last_activation: dict[str, datetime.datetime] | None
|
||||
runtime_state: RuntimeState
|
||||
session_cache: SessionCache
|
||||
|
||||
|
||||
_COOKIE_NAME = "bangui_session"
|
||||
@@ -50,40 +50,14 @@ _COOKIE_NAME = "bangui_session"
|
||||
#: same token arriving in near-simultaneous parallel requests.
|
||||
#:
|
||||
#: NOTE: this cache is process-local and is not cluster-safe. In multi-worker
|
||||
#: or distributed deployments, each process maintains its own cache, so logout
|
||||
#: invalidation and revocation may be delayed unless a shared cache is used.
|
||||
#: ``token → (Session, cache_expiry_monotonic_time)``
|
||||
_session_cache: dict[str, tuple[Session, float]] = {}
|
||||
|
||||
|
||||
def clear_session_cache() -> None:
|
||||
"""Flush the entire in-memory session validation cache.
|
||||
|
||||
Useful in tests to prevent stale state from leaking between test cases.
|
||||
This only affects the current process.
|
||||
"""
|
||||
_session_cache.clear()
|
||||
|
||||
#: or distributed deployments, the configured cache backend should provide
|
||||
#: invalidation semantics appropriate for the deployment.
|
||||
|
||||
def _session_cache_enabled(settings: Settings) -> bool:
|
||||
"""Return whether the in-memory session cache should be used."""
|
||||
"""Return whether the session validation cache should be used."""
|
||||
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
||||
|
||||
|
||||
def invalidate_session_cache(token: str) -> None:
|
||||
"""Evict *token* from the in-memory session cache.
|
||||
|
||||
Must be called during logout so the revoked token is no longer served
|
||||
from cache without a DB round-trip. This invalidation is local to the
|
||||
current process; a clustered deployment would need a shared cache for
|
||||
global invalidation.
|
||||
|
||||
Args:
|
||||
token: The session token to remove.
|
||||
"""
|
||||
_session_cache.pop(token, None)
|
||||
|
||||
|
||||
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
|
||||
"""Provide a request-scoped :class:`aiosqlite.Connection` for the current request.
|
||||
|
||||
@@ -188,6 +162,20 @@ async def get_fail2ban_start_command(settings: Settings = Depends(get_settings))
|
||||
"""Provide the configured fail2ban start command."""
|
||||
return settings.fail2ban_start_command
|
||||
|
||||
|
||||
async def get_session_cache(request: Request) -> SessionCache:
|
||||
"""Provide the configured session cache backend from application state."""
|
||||
state = cast("AppState", request.app.state)
|
||||
session_cache = getattr(state, "session_cache", None)
|
||||
if session_cache is None:
|
||||
log.error("session_cache_unavailable")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
detail="Session cache is not available.",
|
||||
)
|
||||
return session_cache
|
||||
|
||||
|
||||
async def get_app_state(request: Request) -> AppState:
|
||||
"""Provide the application state object for the current request."""
|
||||
return cast("AppState", request.app.state)
|
||||
@@ -212,6 +200,7 @@ async def require_auth(
|
||||
request: Request,
|
||||
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
||||
settings: Annotated[Settings, Depends(get_settings)],
|
||||
session_cache: Annotated[SessionCache, Depends(get_session_cache)],
|
||||
) -> Session:
|
||||
"""Validate the session token and return the active session.
|
||||
|
||||
@@ -223,7 +212,7 @@ async def require_auth(
|
||||
round-trips. This cache is disabled by default because process-local
|
||||
invalidation is not safe in multi-worker or clustered deployments.
|
||||
When enabled, entries are bypassed on expiry and explicitly cleared by
|
||||
:func:`invalidate_session_cache` on logout.
|
||||
the configured session cache backend on logout.
|
||||
|
||||
Args:
|
||||
request: The incoming FastAPI request.
|
||||
@@ -253,15 +242,9 @@ async def require_auth(
|
||||
|
||||
cache_enabled = _session_cache_enabled(settings)
|
||||
if cache_enabled:
|
||||
# Fast path: serve from in-memory cache when the entry is still fresh and
|
||||
# the session itself has not yet exceeded its own expiry time.
|
||||
cached = _session_cache.get(token)
|
||||
cached = session_cache.get(token)
|
||||
if cached is not None:
|
||||
session, cache_expires_at = cached
|
||||
if time.monotonic() < cache_expires_at and session.expires_at > utc_now().isoformat():
|
||||
return session
|
||||
# Stale cache entry — evict and fall through to DB.
|
||||
_session_cache.pop(token, None)
|
||||
return cached
|
||||
|
||||
try:
|
||||
session = await auth_service.validate_session(db, token, settings.session_secret)
|
||||
@@ -273,10 +256,7 @@ async def require_auth(
|
||||
) from exc
|
||||
|
||||
if cache_enabled:
|
||||
_session_cache[token] = (
|
||||
session,
|
||||
time.monotonic() + settings.session_cache_ttl_seconds,
|
||||
)
|
||||
session_cache.set(token, session, settings.session_cache_ttl_seconds)
|
||||
return session
|
||||
|
||||
|
||||
@@ -290,6 +270,7 @@ Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
|
||||
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
|
||||
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
|
||||
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
|
||||
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
|
||||
AppStateDep = Annotated[AppState, Depends(get_app_state)]
|
||||
AppDep = Annotated[FastAPI, Depends(get_app)]
|
||||
AuthDep = Annotated[Session, Depends(require_auth)]
|
||||
|
||||
@@ -46,6 +46,7 @@ from app.routers import (
|
||||
from app.startup import startup_shared_resources
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
|
||||
from app.utils.runtime_state import ApplicationState, RuntimeState
|
||||
from app.utils.session_cache import InMemorySessionCache
|
||||
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
@@ -289,6 +290,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
# shared Starlette state bag itself does not hold mutable business state.
|
||||
app.state = ApplicationState(RuntimeState())
|
||||
app.state.settings = resolved_settings
|
||||
app.state.session_cache = InMemorySessionCache()
|
||||
set_setup_complete_cache(app, False)
|
||||
|
||||
# --- CORS ---
|
||||
|
||||
@@ -12,7 +12,7 @@ from __future__ import annotations
|
||||
import structlog
|
||||
from fastapi import APIRouter, HTTPException, Request, Response, status
|
||||
|
||||
from app.dependencies import DbDep, SettingsDep, invalidate_session_cache
|
||||
from app.dependencies import DbDep, SessionCacheDep, SettingsDep
|
||||
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
|
||||
from app.services import auth_service
|
||||
|
||||
@@ -88,6 +88,7 @@ async def logout(
|
||||
response: Response,
|
||||
db: DbDep,
|
||||
settings: SettingsDep,
|
||||
session_cache: SessionCacheDep,
|
||||
) -> LogoutResponse:
|
||||
"""Invalidate the active session.
|
||||
|
||||
@@ -108,8 +109,8 @@ async def logout(
|
||||
if token:
|
||||
raw_token = await auth_service.logout(db, token, settings.session_secret)
|
||||
if raw_token:
|
||||
invalidate_session_cache(raw_token)
|
||||
invalidate_session_cache(token)
|
||||
session_cache.invalidate(raw_token)
|
||||
session_cache.invalidate(token)
|
||||
response.delete_cookie(key=_COOKIE_NAME)
|
||||
return LogoutResponse()
|
||||
|
||||
|
||||
57
backend/app/utils/session_cache.py
Normal file
57
backend/app/utils/session_cache.py
Normal file
@@ -0,0 +1,57 @@
|
||||
"""Pluggable session cache abstraction.
|
||||
|
||||
This module defines a cache interface for authenticated sessions and a default
|
||||
process-local in-memory implementation. The backend can swap the cache
|
||||
implementation without changing the authentication dependency logic.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from app.models.auth import Session
|
||||
|
||||
|
||||
class SessionCache(Protocol):
|
||||
"""Interface for session token validation cache backends."""
|
||||
|
||||
def get(self, token: str) -> Session | None:
|
||||
"""Return the cached session for *token*, or ``None`` if missing."""
|
||||
|
||||
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
|
||||
"""Cache the validated *session* for *token* for *ttl_seconds*."""
|
||||
|
||||
def invalidate(self, token: str) -> None:
|
||||
"""Remove *token* from the cache if it exists."""
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all entries from the cache."""
|
||||
|
||||
|
||||
class InMemorySessionCache:
|
||||
"""A process-local session cache implementation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: dict[str, tuple[Session, float]] = {}
|
||||
|
||||
def get(self, token: str) -> Session | None:
|
||||
entry = self._entries.get(token)
|
||||
if entry is None:
|
||||
return None
|
||||
|
||||
session, expires_at = entry
|
||||
if time.monotonic() >= expires_at:
|
||||
self._entries.pop(token, None)
|
||||
return None
|
||||
return session
|
||||
|
||||
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
|
||||
self._entries[token] = (session, time.monotonic() + ttl_seconds)
|
||||
|
||||
def invalidate(self, token: str) -> None:
|
||||
self._entries.pop(token, None)
|
||||
|
||||
def clear(self) -> None:
|
||||
self._entries.clear()
|
||||
@@ -209,13 +209,11 @@ class TestRequireAuthSessionCache:
|
||||
"""In-memory session token cache inside ``require_auth``."""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_cache(self) -> Generator[None, None, None]:
|
||||
def reset_cache(self, client: AsyncClient) -> Generator[None, None, None]:
|
||||
"""Flush the session cache before and after every test in this class."""
|
||||
from app import dependencies
|
||||
|
||||
dependencies.clear_session_cache()
|
||||
client._transport.app.state.session_cache.clear()
|
||||
yield
|
||||
dependencies.clear_session_cache()
|
||||
client._transport.app.state.session_cache.clear()
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def enable_session_cache(self, client: AsyncClient) -> Generator[None, None, None]:
|
||||
@@ -237,9 +235,7 @@ class TestRequireAuthSessionCache:
|
||||
token = await _login(client)
|
||||
|
||||
# Ensure cache is empty so the first request definitely hits the DB.
|
||||
from app import dependencies
|
||||
|
||||
dependencies.clear_session_cache()
|
||||
client._transport.app.state.session_cache.clear()
|
||||
|
||||
call_count = 0
|
||||
original_get_session = session_repo.get_session
|
||||
@@ -268,27 +264,25 @@ class TestRequireAuthSessionCache:
|
||||
async def test_token_enters_cache_after_first_auth(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""A successful auth request places the token in ``_session_cache``."""
|
||||
from app import dependencies
|
||||
"""A successful auth request places the token in the session cache."""
|
||||
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
|
||||
dependencies.clear_session_cache()
|
||||
assert token not in dependencies._session_cache
|
||||
client._transport.app.state.session_cache.clear()
|
||||
assert client._transport.app.state.session_cache.get(token) is None
|
||||
|
||||
await client.get(
|
||||
"/api/dashboard/status",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
|
||||
assert token in dependencies._session_cache
|
||||
assert client._transport.app.state.session_cache.get(token) is not None
|
||||
|
||||
async def test_logout_evicts_token_from_cache(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Logout removes the session token from the in-memory cache immediately."""
|
||||
from app import dependencies
|
||||
"""Logout removes the session token from the session cache immediately."""
|
||||
|
||||
await _do_setup(client)
|
||||
token = await _login(client)
|
||||
@@ -298,14 +292,14 @@ class TestRequireAuthSessionCache:
|
||||
"/api/dashboard/status",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert token in dependencies._session_cache
|
||||
assert client._transport.app.state.session_cache.get(token) is not None
|
||||
|
||||
# Logout must evict the entry.
|
||||
await client.post(
|
||||
"/api/auth/logout",
|
||||
headers={"Authorization": f"Bearer {token}"},
|
||||
)
|
||||
assert token not in dependencies._session_cache
|
||||
assert client._transport.app.state.session_cache.get(token) is None
|
||||
|
||||
response = await client.get("/api/health")
|
||||
assert response.status_code == 200
|
||||
|
||||
Reference in New Issue
Block a user