Replace process-local session cache with pluggable session cache backend
This commit is contained in:
@@ -27,6 +27,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
|||||||
- Issue: `_session_cache` in `app/dependencies.py` is a process-local dict, so logout invalidation and session revocation only work within a single process.
|
- Issue: `_session_cache` in `app/dependencies.py` is a process-local dict, so logout invalidation and session revocation only work within a single process.
|
||||||
- Propose: Define a cache interface and provide a default in-memory implementation, with the option to swap in shared cache storage (Redis, Memcached) for clustered production deployments.
|
- Propose: Define a cache interface and provide a default in-memory implementation, with the option to swap in shared cache storage (Redis, Memcached) for clustered production deployments.
|
||||||
- Test: Add unit tests for the cache abstraction and verify logout/invalidation behaves correctly through the configured cache implementation.
|
- Test: Add unit tests for the cache abstraction and verify logout/invalidation behaves correctly through the configured cache implementation.
|
||||||
|
- Status: completed
|
||||||
|
|
||||||
4. Harden SQLite connection configuration and lifecycle
|
4. Harden SQLite connection configuration and lifecycle
|
||||||
- Goal: Make application database access robust under concurrent requests and background task load.
|
- Goal: Make application database access robust under concurrent requests and background task load.
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ directly — to keep coupling explicit and testable.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import time
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Annotated, Protocol, cast
|
from typing import Annotated, Protocol, cast
|
||||||
|
|
||||||
@@ -22,7 +21,7 @@ from app.models.auth import Session
|
|||||||
from app.models.config import PendingRecovery
|
from app.models.config import PendingRecovery
|
||||||
from app.models.server import ServerStatus
|
from app.models.server import ServerStatus
|
||||||
from app.utils.runtime_state import RuntimeState
|
from app.utils.runtime_state import RuntimeState
|
||||||
from app.utils.time_utils import utc_now
|
from app.utils.session_cache import SessionCache
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
@@ -37,6 +36,7 @@ class AppState(Protocol):
|
|||||||
pending_recovery: PendingRecovery | None
|
pending_recovery: PendingRecovery | None
|
||||||
last_activation: dict[str, datetime.datetime] | None
|
last_activation: dict[str, datetime.datetime] | None
|
||||||
runtime_state: RuntimeState
|
runtime_state: RuntimeState
|
||||||
|
session_cache: SessionCache
|
||||||
|
|
||||||
|
|
||||||
_COOKIE_NAME = "bangui_session"
|
_COOKIE_NAME = "bangui_session"
|
||||||
@@ -50,40 +50,14 @@ _COOKIE_NAME = "bangui_session"
|
|||||||
#: same token arriving in near-simultaneous parallel requests.
|
#: same token arriving in near-simultaneous parallel requests.
|
||||||
#:
|
#:
|
||||||
#: NOTE: this cache is process-local and is not cluster-safe. In multi-worker
|
#: NOTE: this cache is process-local and is not cluster-safe. In multi-worker
|
||||||
#: or distributed deployments, each process maintains its own cache, so logout
|
#: or distributed deployments, the configured cache backend should provide
|
||||||
#: invalidation and revocation may be delayed unless a shared cache is used.
|
#: invalidation semantics appropriate for the deployment.
|
||||||
#: ``token → (Session, cache_expiry_monotonic_time)``
|
|
||||||
_session_cache: dict[str, tuple[Session, float]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def clear_session_cache() -> None:
|
|
||||||
"""Flush the entire in-memory session validation cache.
|
|
||||||
|
|
||||||
Useful in tests to prevent stale state from leaking between test cases.
|
|
||||||
This only affects the current process.
|
|
||||||
"""
|
|
||||||
_session_cache.clear()
|
|
||||||
|
|
||||||
|
|
||||||
def _session_cache_enabled(settings: Settings) -> bool:
|
def _session_cache_enabled(settings: Settings) -> bool:
|
||||||
"""Return whether the in-memory session cache should be used."""
|
"""Return whether the session validation cache should be used."""
|
||||||
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
||||||
|
|
||||||
|
|
||||||
def invalidate_session_cache(token: str) -> None:
|
|
||||||
"""Evict *token* from the in-memory session cache.
|
|
||||||
|
|
||||||
Must be called during logout so the revoked token is no longer served
|
|
||||||
from cache without a DB round-trip. This invalidation is local to the
|
|
||||||
current process; a clustered deployment would need a shared cache for
|
|
||||||
global invalidation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: The session token to remove.
|
|
||||||
"""
|
|
||||||
_session_cache.pop(token, None)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
|
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
|
||||||
"""Provide a request-scoped :class:`aiosqlite.Connection` for the current request.
|
"""Provide a request-scoped :class:`aiosqlite.Connection` for the current request.
|
||||||
|
|
||||||
@@ -188,6 +162,20 @@ async def get_fail2ban_start_command(settings: Settings = Depends(get_settings))
|
|||||||
"""Provide the configured fail2ban start command."""
|
"""Provide the configured fail2ban start command."""
|
||||||
return settings.fail2ban_start_command
|
return settings.fail2ban_start_command
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session_cache(request: Request) -> SessionCache:
|
||||||
|
"""Provide the configured session cache backend from application state."""
|
||||||
|
state = cast("AppState", request.app.state)
|
||||||
|
session_cache = getattr(state, "session_cache", None)
|
||||||
|
if session_cache is None:
|
||||||
|
log.error("session_cache_unavailable")
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
detail="Session cache is not available.",
|
||||||
|
)
|
||||||
|
return session_cache
|
||||||
|
|
||||||
|
|
||||||
async def get_app_state(request: Request) -> AppState:
|
async def get_app_state(request: Request) -> AppState:
|
||||||
"""Provide the application state object for the current request."""
|
"""Provide the application state object for the current request."""
|
||||||
return cast("AppState", request.app.state)
|
return cast("AppState", request.app.state)
|
||||||
@@ -212,6 +200,7 @@ async def require_auth(
|
|||||||
request: Request,
|
request: Request,
|
||||||
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
||||||
settings: Annotated[Settings, Depends(get_settings)],
|
settings: Annotated[Settings, Depends(get_settings)],
|
||||||
|
session_cache: Annotated[SessionCache, Depends(get_session_cache)],
|
||||||
) -> Session:
|
) -> Session:
|
||||||
"""Validate the session token and return the active session.
|
"""Validate the session token and return the active session.
|
||||||
|
|
||||||
@@ -223,7 +212,7 @@ async def require_auth(
|
|||||||
round-trips. This cache is disabled by default because process-local
|
round-trips. This cache is disabled by default because process-local
|
||||||
invalidation is not safe in multi-worker or clustered deployments.
|
invalidation is not safe in multi-worker or clustered deployments.
|
||||||
When enabled, entries are bypassed on expiry and explicitly cleared by
|
When enabled, entries are bypassed on expiry and explicitly cleared by
|
||||||
:func:`invalidate_session_cache` on logout.
|
the configured session cache backend on logout.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The incoming FastAPI request.
|
request: The incoming FastAPI request.
|
||||||
@@ -253,15 +242,9 @@ async def require_auth(
|
|||||||
|
|
||||||
cache_enabled = _session_cache_enabled(settings)
|
cache_enabled = _session_cache_enabled(settings)
|
||||||
if cache_enabled:
|
if cache_enabled:
|
||||||
# Fast path: serve from in-memory cache when the entry is still fresh and
|
cached = session_cache.get(token)
|
||||||
# the session itself has not yet exceeded its own expiry time.
|
|
||||||
cached = _session_cache.get(token)
|
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
session, cache_expires_at = cached
|
return cached
|
||||||
if time.monotonic() < cache_expires_at and session.expires_at > utc_now().isoformat():
|
|
||||||
return session
|
|
||||||
# Stale cache entry — evict and fall through to DB.
|
|
||||||
_session_cache.pop(token, None)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
session = await auth_service.validate_session(db, token, settings.session_secret)
|
session = await auth_service.validate_session(db, token, settings.session_secret)
|
||||||
@@ -273,10 +256,7 @@ async def require_auth(
|
|||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
if cache_enabled:
|
if cache_enabled:
|
||||||
_session_cache[token] = (
|
session_cache.set(token, session, settings.session_cache_ttl_seconds)
|
||||||
session,
|
|
||||||
time.monotonic() + settings.session_cache_ttl_seconds,
|
|
||||||
)
|
|
||||||
return session
|
return session
|
||||||
|
|
||||||
|
|
||||||
@@ -290,6 +270,7 @@ Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
|
|||||||
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
|
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
|
||||||
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
|
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
|
||||||
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
|
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
|
||||||
|
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
|
||||||
AppStateDep = Annotated[AppState, Depends(get_app_state)]
|
AppStateDep = Annotated[AppState, Depends(get_app_state)]
|
||||||
AppDep = Annotated[FastAPI, Depends(get_app)]
|
AppDep = Annotated[FastAPI, Depends(get_app)]
|
||||||
AuthDep = Annotated[Session, Depends(require_auth)]
|
AuthDep = Annotated[Session, Depends(require_auth)]
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ from app.routers import (
|
|||||||
from app.startup import startup_shared_resources
|
from app.startup import startup_shared_resources
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
|
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
|
||||||
from app.utils.runtime_state import ApplicationState, RuntimeState
|
from app.utils.runtime_state import ApplicationState, RuntimeState
|
||||||
|
from app.utils.session_cache import InMemorySessionCache
|
||||||
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
|
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
|
||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
@@ -289,6 +290,7 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
|||||||
# shared Starlette state bag itself does not hold mutable business state.
|
# shared Starlette state bag itself does not hold mutable business state.
|
||||||
app.state = ApplicationState(RuntimeState())
|
app.state = ApplicationState(RuntimeState())
|
||||||
app.state.settings = resolved_settings
|
app.state.settings = resolved_settings
|
||||||
|
app.state.session_cache = InMemorySessionCache()
|
||||||
set_setup_complete_cache(app, False)
|
set_setup_complete_cache(app, False)
|
||||||
|
|
||||||
# --- CORS ---
|
# --- CORS ---
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from __future__ import annotations
|
|||||||
import structlog
|
import structlog
|
||||||
from fastapi import APIRouter, HTTPException, Request, Response, status
|
from fastapi import APIRouter, HTTPException, Request, Response, status
|
||||||
|
|
||||||
from app.dependencies import DbDep, SettingsDep, invalidate_session_cache
|
from app.dependencies import DbDep, SessionCacheDep, SettingsDep
|
||||||
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
|
from app.models.auth import LoginRequest, LoginResponse, LogoutResponse
|
||||||
from app.services import auth_service
|
from app.services import auth_service
|
||||||
|
|
||||||
@@ -88,6 +88,7 @@ async def logout(
|
|||||||
response: Response,
|
response: Response,
|
||||||
db: DbDep,
|
db: DbDep,
|
||||||
settings: SettingsDep,
|
settings: SettingsDep,
|
||||||
|
session_cache: SessionCacheDep,
|
||||||
) -> LogoutResponse:
|
) -> LogoutResponse:
|
||||||
"""Invalidate the active session.
|
"""Invalidate the active session.
|
||||||
|
|
||||||
@@ -108,8 +109,8 @@ async def logout(
|
|||||||
if token:
|
if token:
|
||||||
raw_token = await auth_service.logout(db, token, settings.session_secret)
|
raw_token = await auth_service.logout(db, token, settings.session_secret)
|
||||||
if raw_token:
|
if raw_token:
|
||||||
invalidate_session_cache(raw_token)
|
session_cache.invalidate(raw_token)
|
||||||
invalidate_session_cache(token)
|
session_cache.invalidate(token)
|
||||||
response.delete_cookie(key=_COOKIE_NAME)
|
response.delete_cookie(key=_COOKIE_NAME)
|
||||||
return LogoutResponse()
|
return LogoutResponse()
|
||||||
|
|
||||||
|
|||||||
57
backend/app/utils/session_cache.py
Normal file
57
backend/app/utils/session_cache.py
Normal file
@@ -0,0 +1,57 @@
|
|||||||
|
"""Pluggable session cache abstraction.
|
||||||
|
|
||||||
|
This module defines a cache interface for authenticated sessions and a default
|
||||||
|
process-local in-memory implementation. The backend can swap the cache
|
||||||
|
implementation without changing the authentication dependency logic.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
|
if TYPE_CHECKING: # pragma: no cover
|
||||||
|
from app.models.auth import Session
|
||||||
|
|
||||||
|
|
||||||
|
class SessionCache(Protocol):
|
||||||
|
"""Interface for session token validation cache backends."""
|
||||||
|
|
||||||
|
def get(self, token: str) -> Session | None:
|
||||||
|
"""Return the cached session for *token*, or ``None`` if missing."""
|
||||||
|
|
||||||
|
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
|
||||||
|
"""Cache the validated *session* for *token* for *ttl_seconds*."""
|
||||||
|
|
||||||
|
def invalidate(self, token: str) -> None:
|
||||||
|
"""Remove *token* from the cache if it exists."""
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""Remove all entries from the cache."""
|
||||||
|
|
||||||
|
|
||||||
|
class InMemorySessionCache:
|
||||||
|
"""A process-local session cache implementation."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._entries: dict[str, tuple[Session, float]] = {}
|
||||||
|
|
||||||
|
def get(self, token: str) -> Session | None:
|
||||||
|
entry = self._entries.get(token)
|
||||||
|
if entry is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
session, expires_at = entry
|
||||||
|
if time.monotonic() >= expires_at:
|
||||||
|
self._entries.pop(token, None)
|
||||||
|
return None
|
||||||
|
return session
|
||||||
|
|
||||||
|
def set(self, token: str, session: Session, ttl_seconds: float) -> None:
|
||||||
|
self._entries[token] = (session, time.monotonic() + ttl_seconds)
|
||||||
|
|
||||||
|
def invalidate(self, token: str) -> None:
|
||||||
|
self._entries.pop(token, None)
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
self._entries.clear()
|
||||||
@@ -209,13 +209,11 @@ class TestRequireAuthSessionCache:
|
|||||||
"""In-memory session token cache inside ``require_auth``."""
|
"""In-memory session token cache inside ``require_auth``."""
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def reset_cache(self) -> Generator[None, None, None]:
|
def reset_cache(self, client: AsyncClient) -> Generator[None, None, None]:
|
||||||
"""Flush the session cache before and after every test in this class."""
|
"""Flush the session cache before and after every test in this class."""
|
||||||
from app import dependencies
|
client._transport.app.state.session_cache.clear()
|
||||||
|
|
||||||
dependencies.clear_session_cache()
|
|
||||||
yield
|
yield
|
||||||
dependencies.clear_session_cache()
|
client._transport.app.state.session_cache.clear()
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def enable_session_cache(self, client: AsyncClient) -> Generator[None, None, None]:
|
def enable_session_cache(self, client: AsyncClient) -> Generator[None, None, None]:
|
||||||
@@ -237,9 +235,7 @@ class TestRequireAuthSessionCache:
|
|||||||
token = await _login(client)
|
token = await _login(client)
|
||||||
|
|
||||||
# Ensure cache is empty so the first request definitely hits the DB.
|
# Ensure cache is empty so the first request definitely hits the DB.
|
||||||
from app import dependencies
|
client._transport.app.state.session_cache.clear()
|
||||||
|
|
||||||
dependencies.clear_session_cache()
|
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
original_get_session = session_repo.get_session
|
original_get_session = session_repo.get_session
|
||||||
@@ -268,27 +264,25 @@ class TestRequireAuthSessionCache:
|
|||||||
async def test_token_enters_cache_after_first_auth(
|
async def test_token_enters_cache_after_first_auth(
|
||||||
self, client: AsyncClient
|
self, client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
"""A successful auth request places the token in ``_session_cache``."""
|
"""A successful auth request places the token in the session cache."""
|
||||||
from app import dependencies
|
|
||||||
|
|
||||||
await _do_setup(client)
|
await _do_setup(client)
|
||||||
token = await _login(client)
|
token = await _login(client)
|
||||||
|
|
||||||
dependencies.clear_session_cache()
|
client._transport.app.state.session_cache.clear()
|
||||||
assert token not in dependencies._session_cache
|
assert client._transport.app.state.session_cache.get(token) is None
|
||||||
|
|
||||||
await client.get(
|
await client.get(
|
||||||
"/api/dashboard/status",
|
"/api/dashboard/status",
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
)
|
)
|
||||||
|
|
||||||
assert token in dependencies._session_cache
|
assert client._transport.app.state.session_cache.get(token) is not None
|
||||||
|
|
||||||
async def test_logout_evicts_token_from_cache(
|
async def test_logout_evicts_token_from_cache(
|
||||||
self, client: AsyncClient
|
self, client: AsyncClient
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Logout removes the session token from the in-memory cache immediately."""
|
"""Logout removes the session token from the session cache immediately."""
|
||||||
from app import dependencies
|
|
||||||
|
|
||||||
await _do_setup(client)
|
await _do_setup(client)
|
||||||
token = await _login(client)
|
token = await _login(client)
|
||||||
@@ -298,14 +292,14 @@ class TestRequireAuthSessionCache:
|
|||||||
"/api/dashboard/status",
|
"/api/dashboard/status",
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
)
|
)
|
||||||
assert token in dependencies._session_cache
|
assert client._transport.app.state.session_cache.get(token) is not None
|
||||||
|
|
||||||
# Logout must evict the entry.
|
# Logout must evict the entry.
|
||||||
await client.post(
|
await client.post(
|
||||||
"/api/auth/logout",
|
"/api/auth/logout",
|
||||||
headers={"Authorization": f"Bearer {token}"},
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
)
|
)
|
||||||
assert token not in dependencies._session_cache
|
assert client._transport.app.state.session_cache.get(token) is None
|
||||||
|
|
||||||
response = await client.get("/api/health")
|
response = await client.get("/api/health")
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|||||||
Reference in New Issue
Block a user