diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 824fb35..5dd3676 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -27,6 +27,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. - Issue: `_session_cache` in `app/dependencies.py` is a process-local dict, so logout invalidation and session revocation only work within a single process. - Propose: Define a cache interface and provide a default in-memory implementation, with the option to swap in shared cache storage (Redis, Memcached) for clustered production deployments. - Test: Add unit tests for the cache abstraction and verify logout/invalidation behaves correctly through the configured cache implementation. + - Status: completed 4. Harden SQLite connection configuration and lifecycle - Goal: Make application database access robust under concurrent requests and background task load. diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 810ee93..85805d6 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -7,7 +7,6 @@ directly — to keep coupling explicit and testable. """ import datetime -import time from collections.abc import AsyncGenerator from typing import Annotated, Protocol, cast @@ -22,7 +21,7 @@ from app.models.auth import Session from app.models.config import PendingRecovery from app.models.server import ServerStatus from app.utils.runtime_state import RuntimeState -from app.utils.time_utils import utc_now +from app.utils.session_cache import SessionCache log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -37,6 +36,7 @@ class AppState(Protocol): pending_recovery: PendingRecovery | None last_activation: dict[str, datetime.datetime] | None runtime_state: RuntimeState + session_cache: SessionCache _COOKIE_NAME = "bangui_session" @@ -50,40 +50,14 @@ _COOKIE_NAME = "bangui_session" #: same token arriving in near-simultaneous parallel requests. #: #: NOTE: this cache is process-local and is not cluster-safe. In multi-worker -#: or distributed deployments, each process maintains its own cache, so logout -#: invalidation and revocation may be delayed unless a shared cache is used. -#: ``token → (Session, cache_expiry_monotonic_time)`` -_session_cache: dict[str, tuple[Session, float]] = {} - - -def clear_session_cache() -> None: - """Flush the entire in-memory session validation cache. - - Useful in tests to prevent stale state from leaking between test cases. - This only affects the current process. - """ - _session_cache.clear() - +#: or distributed deployments, the configured cache backend should provide +#: invalidation semantics appropriate for the deployment. def _session_cache_enabled(settings: Settings) -> bool: - """Return whether the in-memory session cache should be used.""" + """Return whether the session validation cache should be used.""" return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0 -def invalidate_session_cache(token: str) -> None: - """Evict *token* from the in-memory session cache. - - Must be called during logout so the revoked token is no longer served - from cache without a DB round-trip. This invalidation is local to the - current process; a clustered deployment would need a shared cache for - global invalidation. - - Args: - token: The session token to remove. - """ - _session_cache.pop(token, None) - - async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]: """Provide a request-scoped :class:`aiosqlite.Connection` for the current request. @@ -188,6 +162,20 @@ async def get_fail2ban_start_command(settings: Settings = Depends(get_settings)) """Provide the configured fail2ban start command.""" return settings.fail2ban_start_command + +async def get_session_cache(request: Request) -> SessionCache: + """Provide the configured session cache backend from application state.""" + state = cast("AppState", request.app.state) + session_cache = getattr(state, "session_cache", None) + if session_cache is None: + log.error("session_cache_unavailable") + raise HTTPException( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + detail="Session cache is not available.", + ) + return session_cache + + async def get_app_state(request: Request) -> AppState: """Provide the application state object for the current request.""" return cast("AppState", request.app.state) @@ -212,6 +200,7 @@ async def require_auth( request: Request, db: Annotated[aiosqlite.Connection, Depends(get_db)], settings: Annotated[Settings, Depends(get_settings)], + session_cache: Annotated[SessionCache, Depends(get_session_cache)], ) -> Session: """Validate the session token and return the active session. @@ -223,7 +212,7 @@ async def require_auth( round-trips. This cache is disabled by default because process-local invalidation is not safe in multi-worker or clustered deployments. When enabled, entries are bypassed on expiry and explicitly cleared by - :func:`invalidate_session_cache` on logout. + the configured session cache backend on logout. Args: request: The incoming FastAPI request. @@ -253,15 +242,9 @@ async def require_auth( cache_enabled = _session_cache_enabled(settings) if cache_enabled: - # Fast path: serve from in-memory cache when the entry is still fresh and - # the session itself has not yet exceeded its own expiry time. - cached = _session_cache.get(token) + cached = session_cache.get(token) if cached is not None: - session, cache_expires_at = cached - if time.monotonic() < cache_expires_at and session.expires_at > utc_now().isoformat(): - return session - # Stale cache entry — evict and fall through to DB. - _session_cache.pop(token, None) + return cached try: session = await auth_service.validate_session(db, token, settings.session_secret) @@ -273,10 +256,7 @@ async def require_auth( ) from exc if cache_enabled: - _session_cache[token] = ( - session, - time.monotonic() + settings.session_cache_ttl_seconds, - ) + session_cache.set(token, session, settings.session_cache_ttl_seconds) return session @@ -290,6 +270,7 @@ Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)] Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)] ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)] PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)] +SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)] AppStateDep = Annotated[AppState, Depends(get_app_state)] AppDep = Annotated[FastAPI, Depends(get_app)] AuthDep = Annotated[Session, Depends(require_auth)] diff --git a/backend/app/main.py b/backend/app/main.py index ea0c1b0..f5d6700 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -46,6 +46,7 @@ from app.routers import ( from app.startup import startup_shared_resources from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError from app.utils.runtime_state import ApplicationState, RuntimeState +from app.utils.session_cache import InMemorySessionCache from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -289,6 +290,7 @@ def create_app(settings: Settings | None = None) -> FastAPI: # shared Starlette state bag itself does not hold mutable business state. app.state = ApplicationState(RuntimeState()) app.state.settings = resolved_settings + app.state.session_cache = InMemorySessionCache() set_setup_complete_cache(app, False) # --- CORS --- diff --git a/backend/app/routers/auth.py b/backend/app/routers/auth.py index 32fae53..33237e1 100644 --- a/backend/app/routers/auth.py +++ b/backend/app/routers/auth.py @@ -12,7 +12,7 @@ from __future__ import annotations import structlog from fastapi import APIRouter, HTTPException, Request, Response, status -from app.dependencies import DbDep, SettingsDep, invalidate_session_cache +from app.dependencies import DbDep, SessionCacheDep, SettingsDep from app.models.auth import LoginRequest, LoginResponse, LogoutResponse from app.services import auth_service @@ -88,6 +88,7 @@ async def logout( response: Response, db: DbDep, settings: SettingsDep, + session_cache: SessionCacheDep, ) -> LogoutResponse: """Invalidate the active session. @@ -108,8 +109,8 @@ async def logout( if token: raw_token = await auth_service.logout(db, token, settings.session_secret) if raw_token: - invalidate_session_cache(raw_token) - invalidate_session_cache(token) + session_cache.invalidate(raw_token) + session_cache.invalidate(token) response.delete_cookie(key=_COOKIE_NAME) return LogoutResponse() diff --git a/backend/app/utils/session_cache.py b/backend/app/utils/session_cache.py new file mode 100644 index 0000000..6c06675 --- /dev/null +++ b/backend/app/utils/session_cache.py @@ -0,0 +1,57 @@ +"""Pluggable session cache abstraction. + +This module defines a cache interface for authenticated sessions and a default +process-local in-memory implementation. The backend can swap the cache +implementation without changing the authentication dependency logic. +""" + +from __future__ import annotations + +import time +from typing import TYPE_CHECKING, Protocol + +if TYPE_CHECKING: # pragma: no cover + from app.models.auth import Session + + +class SessionCache(Protocol): + """Interface for session token validation cache backends.""" + + def get(self, token: str) -> Session | None: + """Return the cached session for *token*, or ``None`` if missing.""" + + def set(self, token: str, session: Session, ttl_seconds: float) -> None: + """Cache the validated *session* for *token* for *ttl_seconds*.""" + + def invalidate(self, token: str) -> None: + """Remove *token* from the cache if it exists.""" + + def clear(self) -> None: + """Remove all entries from the cache.""" + + +class InMemorySessionCache: + """A process-local session cache implementation.""" + + def __init__(self) -> None: + self._entries: dict[str, tuple[Session, float]] = {} + + def get(self, token: str) -> Session | None: + entry = self._entries.get(token) + if entry is None: + return None + + session, expires_at = entry + if time.monotonic() >= expires_at: + self._entries.pop(token, None) + return None + return session + + def set(self, token: str, session: Session, ttl_seconds: float) -> None: + self._entries[token] = (session, time.monotonic() + ttl_seconds) + + def invalidate(self, token: str) -> None: + self._entries.pop(token, None) + + def clear(self) -> None: + self._entries.clear() diff --git a/backend/tests/test_routers/test_auth.py b/backend/tests/test_routers/test_auth.py index 25e33eb..1c0d583 100644 --- a/backend/tests/test_routers/test_auth.py +++ b/backend/tests/test_routers/test_auth.py @@ -209,13 +209,11 @@ class TestRequireAuthSessionCache: """In-memory session token cache inside ``require_auth``.""" @pytest.fixture(autouse=True) - def reset_cache(self) -> Generator[None, None, None]: + def reset_cache(self, client: AsyncClient) -> Generator[None, None, None]: """Flush the session cache before and after every test in this class.""" - from app import dependencies - - dependencies.clear_session_cache() + client._transport.app.state.session_cache.clear() yield - dependencies.clear_session_cache() + client._transport.app.state.session_cache.clear() @pytest.fixture(autouse=True) def enable_session_cache(self, client: AsyncClient) -> Generator[None, None, None]: @@ -237,9 +235,7 @@ class TestRequireAuthSessionCache: token = await _login(client) # Ensure cache is empty so the first request definitely hits the DB. - from app import dependencies - - dependencies.clear_session_cache() + client._transport.app.state.session_cache.clear() call_count = 0 original_get_session = session_repo.get_session @@ -268,27 +264,25 @@ class TestRequireAuthSessionCache: async def test_token_enters_cache_after_first_auth( self, client: AsyncClient ) -> None: - """A successful auth request places the token in ``_session_cache``.""" - from app import dependencies + """A successful auth request places the token in the session cache.""" await _do_setup(client) token = await _login(client) - dependencies.clear_session_cache() - assert token not in dependencies._session_cache + client._transport.app.state.session_cache.clear() + assert client._transport.app.state.session_cache.get(token) is None await client.get( "/api/dashboard/status", headers={"Authorization": f"Bearer {token}"}, ) - assert token in dependencies._session_cache + assert client._transport.app.state.session_cache.get(token) is not None async def test_logout_evicts_token_from_cache( self, client: AsyncClient ) -> None: - """Logout removes the session token from the in-memory cache immediately.""" - from app import dependencies + """Logout removes the session token from the session cache immediately.""" await _do_setup(client) token = await _login(client) @@ -298,14 +292,14 @@ class TestRequireAuthSessionCache: "/api/dashboard/status", headers={"Authorization": f"Bearer {token}"}, ) - assert token in dependencies._session_cache + assert client._transport.app.state.session_cache.get(token) is not None # Logout must evict the entry. await client.post( "/api/auth/logout", headers={"Authorization": f"Bearer {token}"}, ) - assert token not in dependencies._session_cache + assert client._transport.app.state.session_cache.get(token) is None response = await client.get("/api/health") assert response.status_code == 200