Move session cache initialization from per-request _build_app_context to startup lifespan handler. The session cache type is now decided once at app startup based on settings, making _build_app_context pure (read-only). Changes: - Move cache initialization logic to new _update_session_cache() in main.py - Call _update_session_cache() during lifespan startup to initialize cache - Remove three if/elif/elif branches mutating state.session_cache from _build_app_context - Add cache swap logic to set_runtime_settings() in runtime_state.py to handle runtime settings changes (e.g., setup wizard updates) - Keep app.state.session_cache initialization in create_app() for test compatibility This ensures: - _build_app_context is pure and doesn't mutate app state on each request - Session cache configuration decisions are centralized at startup - Settings changes during runtime (via setup wizard) also trigger cache swap - Cache initialization logic is isolated in one place Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
397 lines
15 KiB
Python
397 lines
15 KiB
Python
"""FastAPI dependency providers.
|
|
|
|
All ``Depends()`` callables that inject shared resources (database
|
|
connection, settings, services, auth guard) are defined here.
|
|
Routers import their dependencies from this module rather than reading
``app.state`` themselves, which keeps coupling explicit and testable.
|
|
"""
|
|
|
|
import datetime
|
|
from collections.abc import AsyncGenerator
|
|
from dataclasses import dataclass
|
|
from typing import Annotated, Protocol, cast
|
|
|
|
import aiohttp
|
|
import aiosqlite
|
|
import structlog
|
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
|
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
|
|
|
from app.config import Settings
|
|
from app.models.auth import Session
|
|
from app.models.config import PendingRecovery
|
|
from app.models.geo import GeoBatchLookup
|
|
from app.models.server import ServerStatus
|
|
from app.repositories.protocols import (
|
|
BlocklistRepository,
|
|
Fail2BanDbRepository,
|
|
GeoCacheRepository,
|
|
HistoryArchiveRepository,
|
|
ImportLogRepository,
|
|
SessionRepository,
|
|
SettingsRepository,
|
|
)
|
|
from app.services.geo_cache import GeoCache
|
|
from app.utils.constants import SESSION_COOKIE_NAME
|
|
from app.utils.runtime_state import RuntimeState
|
|
from app.utils.session_cache import NoOpSessionCache, SessionCache
|
|
|
|
#: Module-level structured logger; the providers below use it to report
#: unavailable resources (database, HTTP session, scheduler, session cache).
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
|
|
|
|
class AppState(Protocol):
    """Partial view of the FastAPI application state used by dependencies.

    ``request.app.state`` is an untyped attribute bag; it is cast to this
    protocol so type checkers can verify attribute access. Declaring an
    attribute here does not guarantee it is present at runtime: several of
    them are read with ``getattr(..., default)`` when the per-request
    application context is built.
    """

    # Base application settings.
    settings: Settings
    # Shared aiohttp client session used for outbound HTTP calls.
    http_session: aiohttp.ClientSession
    # Shared APScheduler instance.
    scheduler: AsyncIOScheduler
    # Cached fail2ban server status snapshot.
    server_status: ServerStatus
    # Pending recovery record, if any.
    pending_recovery: PendingRecovery | None
    # Activation timestamps keyed by string; None when not tracked.
    last_activation: dict[str, datetime.datetime] | None
    # Settings overrides; when not None they take precedence over ``settings``.
    runtime_settings: Settings | None
    runtime_state: RuntimeState
    # Session validation cache backend.
    session_cache: SessionCache
    # Geo lookup cache; GeoCache is imported at module level, so no noqa is needed.
    geo_cache: GeoCache
|
|
|
|
|
|
@dataclass
class ApplicationContext:
    """A typed wrapper around shared application lifecycle resources.

    A fresh instance is built from ``request.app.state`` for each request and
    injected through :func:`get_app_context`. Resources that may be missing
    from the application state are typed as optional so that providers such
    as :func:`get_http_session` and :func:`get_scheduler` can answer with an
    HTTP 503 instead of failing on a missing attribute.
    """

    # Base application settings (see ``runtime_settings`` for overrides).
    settings: Settings
    # Shared aiohttp client session; None when absent from app.state.
    http_session: aiohttp.ClientSession | None
    # Shared APScheduler instance; None when absent from app.state.
    scheduler: AsyncIOScheduler | None
    # Cached fail2ban server status snapshot.
    server_status: ServerStatus
    # Pending recovery record, or None when there is none.
    pending_recovery: PendingRecovery | None
    # Activation timestamps keyed by string, or None.
    last_activation: dict[str, datetime.datetime] | None
    # Runtime overrides (e.g. stored during setup); when not None,
    # ``get_settings`` returns these instead of ``settings``.
    runtime_settings: Settings | None
    runtime_state: RuntimeState
    # Session validation cache backend; get_session_cache answers 503 if None.
    session_cache: SessionCache | None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session validation cache
|
|
# ---------------------------------------------------------------------------
|
|
|
|
#: How long (seconds) a validated session token is served from the in-memory
|
|
#: cache without re-querying SQLite. Eliminates repeated DB lookups for the
|
|
#: same token arriving in near-simultaneous parallel requests.
|
|
#:
|
|
#: NOTE: this cache is process-local and is not cluster-safe. In multi-worker
|
|
#: or distributed deployments, the configured cache backend should provide
|
|
#: invalidation semantics appropriate for the deployment.
|
|
|
|
def _session_cache_enabled(settings: Settings) -> bool:
|
|
"""Return whether the session validation cache should be used."""
|
|
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
|
|
|
|
|
|
def _build_app_context(request: Request) -> ApplicationContext:
    """Snapshot ``request.app.state`` into a typed :class:`ApplicationContext`.

    This function only reads application state; it never assigns to it.
    Optional resources that are missing, or explicitly ``None``, are replaced
    with fallbacks for fields the context requires:

    * ``session_cache`` falls back to a :class:`NoOpSessionCache`.
    * ``server_status`` falls back to an offline :class:`ServerStatus`.

    ``settings`` and ``runtime_state`` are read without any fallback.

    Args:
        request: The incoming request whose application state is read.

    Returns:
        A new :class:`ApplicationContext` for this request.
    """
    state = cast("AppState", request.app.state)

    session_cache = getattr(state, "session_cache", None)
    if session_cache is None:
        session_cache = NoOpSessionCache()

    # Build the offline default only when it is actually needed, rather than
    # on every request via a getattr default. Treat an explicit None like a
    # missing attribute, as for session_cache above, because
    # ApplicationContext.server_status is not optional.
    server_status = getattr(state, "server_status", None)
    if server_status is None:
        server_status = ServerStatus(online=False)

    return ApplicationContext(
        settings=state.settings,
        http_session=getattr(state, "http_session", None),
        scheduler=getattr(state, "scheduler", None),
        server_status=server_status,
        pending_recovery=getattr(state, "pending_recovery", None),
        last_activation=getattr(state, "last_activation", None),
        runtime_settings=getattr(state, "runtime_settings", None),
        runtime_state=state.runtime_state,
        session_cache=session_cache,
    )
|
|
|
|
|
|
async def get_app_context(request: Request) -> ApplicationContext:
    """Build and return the :class:`ApplicationContext` for this request.

    This is the root dependency the other context-based providers depend on.
    """
    context = _build_app_context(request)
    return context
|
|
|
|
|
|
async def get_settings(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> Settings:
    """Return the settings in effect for this request.

    Runtime overrides win whenever they are present; otherwise the base
    application settings are returned.
    """
    overrides = app_context.runtime_settings
    if overrides is None:
        return app_context.settings
    return overrides
|
|
|
|
|
|
async def get_db(
    settings: Annotated[Settings, Depends(get_settings)],
) -> AsyncGenerator[aiosqlite.Connection, None]:
    """Yield a fresh :class:`aiosqlite.Connection` scoped to one request.

    Each request gets its own connection, which is closed once the request
    is done. Sharing a single SQLite connection between concurrent requests
    would invite contention and locking problems.

    The connection targets ``settings.database_path`` taken from the
    *effective* settings, so runtime overrides stored during setup apply.

    Args:
        settings: Effective settings for this request.

    Yields:
        An open :class:`aiosqlite.Connection`.

    Raises:
        HTTPException: 503 when the database cannot be opened.
    """
    # Function-local import, as in the original provider (presumably to avoid
    # an import cycle or import-time work; confirm).
    from app.db import open_db  # noqa: PLC0415

    try:
        connection = await open_db(settings.database_path)
    except Exception as exc:  # any failure to open is reported as a 503
        log.error("database_open_failed", error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Database is not available.",
        ) from exc

    try:
        yield connection
    finally:
        # Close the connection however the request ends.
        await connection.close()
|
|
|
|
|
|
async def get_http_session(
    app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> aiohttp.ClientSession:
    """Return the application-wide :class:`aiohttp.ClientSession`.

    Args:
        app_context: Application context injected for this request.

    Returns:
        The shared client session held by the application context.

    Raises:
        HTTPException: 503 when the context holds no session.
    """
    session = app_context.http_session
    if session is not None:
        return session

    log.error("http_session_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="HTTP session is not available.",
    )
|
|
|
|
|
|
async def get_scheduler(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> AsyncIOScheduler:
    """Return the application-wide APScheduler instance.

    Args:
        app_context: Application context injected for this request.

    Returns:
        The shared :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler`.

    Raises:
        HTTPException: 503 when the context holds no scheduler.
    """
    scheduler = app_context.scheduler
    if scheduler is not None:
        return scheduler

    log.error("scheduler_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Scheduler is not available.",
    )
|
|
|
|
|
|
async def get_fail2ban_socket(settings: Annotated[Settings, Depends(get_settings)]) -> str:
    """Provide the configured path to the fail2ban Unix domain socket.

    Uses the ``Annotated`` dependency style like the rest of this module,
    instead of a ``Depends()`` call as the parameter default.

    Args:
        settings: Effective settings for the current request.

    Returns:
        ``settings.fail2ban_socket``.
    """
    return settings.fail2ban_socket
|
|
|
|
|
|
async def get_fail2ban_config_dir(settings: Annotated[Settings, Depends(get_settings)]) -> str:
    """Provide the configured fail2ban configuration directory.

    Uses the ``Annotated`` dependency style like the rest of this module,
    instead of a ``Depends()`` call as the parameter default.

    Args:
        settings: Effective settings for the current request.

    Returns:
        ``settings.fail2ban_config_dir``.
    """
    return settings.fail2ban_config_dir
|
|
|
|
|
|
async def get_fail2ban_start_command(settings: Annotated[Settings, Depends(get_settings)]) -> str:
    """Provide the configured fail2ban start command.

    Uses the ``Annotated`` dependency style like the rest of this module,
    instead of a ``Depends()`` call as the parameter default.

    Args:
        settings: Effective settings for the current request.

    Returns:
        ``settings.fail2ban_start_command``.
    """
    return settings.fail2ban_start_command
|
|
|
|
|
|
async def get_geo_batch_lookup(request: Request) -> GeoBatchLookup:
    """Return the ``lookup_batch`` callable of the app's shared GeoCache.

    The bound method itself is returned (not called), so routes can invoke it
    as a :class:`GeoBatchLookup`.
    """
    cache: GeoCache = request.app.state.geo_cache
    batch_lookup = cache.lookup_batch
    return batch_lookup  # type: ignore[return-value]
|
|
|
|
|
|
async def get_geo_cache(request: Request) -> GeoCache:
    """Return the :class:`GeoCache` stored on the application state."""
    app_state = request.app.state
    return app_state.geo_cache
|
|
|
|
|
|
async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> SessionCache:
    """Return the session cache backend held by the application context.

    Raises:
        HTTPException: 503 when the context holds no session cache.
    """
    cache = app_context.session_cache
    if cache is not None:
        return cache

    log.error("session_cache_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Session cache is not available.",
    )
|
|
|
|
|
|
async def get_session_repo() -> SessionRepository:
    """Return the ``session_repo`` module typed as a :class:`SessionRepository`."""
    # Function-local import, matching the other repository providers.
    from app.repositories import session_repo as repo_module  # noqa: PLC0415

    # cast() is a runtime no-op; it only narrows the static type, as the
    # sibling repository providers do.
    return cast("SessionRepository", repo_module)
|
|
|
|
|
|
async def get_blocklist_repo() -> BlocklistRepository:
    """Return the ``blocklist_repo`` module typed as a :class:`BlocklistRepository`."""
    from app.repositories import blocklist_repo as repo_module  # noqa: PLC0415

    return cast("BlocklistRepository", repo_module)
|
|
|
|
|
|
async def get_import_log_repo() -> ImportLogRepository:
    """Return the ``import_log_repo`` module typed as an :class:`ImportLogRepository`."""
    from app.repositories import import_log_repo as repo_module  # noqa: PLC0415

    return cast("ImportLogRepository", repo_module)
|
|
|
|
|
|
async def get_settings_repo() -> SettingsRepository:
    """Return the ``settings_repo`` module typed as a :class:`SettingsRepository`."""
    from app.repositories import settings_repo as repo_module  # noqa: PLC0415

    return cast("SettingsRepository", repo_module)
|
|
|
|
|
|
async def get_history_archive_repo() -> HistoryArchiveRepository:
    """Return the ``history_archive_repo`` module typed as a :class:`HistoryArchiveRepository`."""
    from app.repositories import history_archive_repo as repo_module  # noqa: PLC0415

    return cast("HistoryArchiveRepository", repo_module)
|
|
|
|
|
|
async def get_geo_cache_repo() -> GeoCacheRepository:
    """Return the ``geo_cache_repo`` module typed as a :class:`GeoCacheRepository`."""
    from app.repositories import geo_cache_repo as repo_module  # noqa: PLC0415

    return cast("GeoCacheRepository", repo_module)
|
|
|
|
|
|
async def get_fail2ban_db_repo() -> Fail2BanDbRepository:
    """Return the ``fail2ban_db_repo`` module typed as a :class:`Fail2BanDbRepository`."""
    from app.repositories import fail2ban_db_repo as repo_module  # noqa: PLC0415

    return cast("Fail2BanDbRepository", repo_module)
|
|
|
|
|
|
async def get_app_state(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ApplicationContext:
    """Pass through the injected :class:`ApplicationContext` unchanged.

    This backs the ``AppStateDep`` alias; the object returned is the same one
    produced by :func:`get_app_context` for this request.
    """
    context = app_context
    return context
|
|
|
|
|
|
async def get_app(request: Request) -> FastAPI:
    """Return the :class:`FastAPI` application that is serving this request."""
    application = request.app
    return application
|
|
|
|
|
|
async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus:
    """Return the fail2ban server status snapshot cached in the application context.

    No live query is made here; the snapshot is whatever the context holds.
    """
    snapshot = app_context.server_status
    return snapshot
|
|
|
|
|
|
async def get_pending_recovery(
    app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> PendingRecovery | None:
    """Return the pending recovery record held by the application context.

    Returns:
        The :class:`PendingRecovery` record, or ``None`` when there is none.
    """
    recovery = app_context.pending_recovery
    return recovery
|
|
|
|
def _extract_session_token(request: Request) -> str | None:
    """Return the session token from the cookie or ``Authorization`` header, if any.

    The session cookie wins when it is present and non-empty. Otherwise an
    ``Authorization: Bearer <token>`` header is accepted. The scheme name is
    compared case-insensitively, as RFC 7235 requires for auth-schemes, and
    surrounding whitespace is stripped from the token.
    """
    token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
    if token:
        return token

    auth_header: str = request.headers.get("Authorization", "")
    scheme, _, credentials = auth_header.partition(" ")
    if scheme.lower() != "bearer":
        return None
    # An empty credential (e.g. "Bearer " alone) counts as no token.
    return credentials.strip() or None


async def require_auth(
    request: Request,
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
    settings: Annotated[Settings, Depends(get_settings)],
    session_cache: Annotated[SessionCache, Depends(get_session_cache)],
    session_repo: Annotated[SessionRepository, Depends(get_session_repo)],
) -> Session:
    """Validate the request's session token and return the active session.

    The token is read from the session cookie (``SESSION_COOKIE_NAME``) or,
    when that cookie is absent or empty, from an ``Authorization: Bearer``
    header.

    When :func:`_session_cache_enabled` is true for ``settings``, a session
    already in ``session_cache`` is returned without querying SQLite. A newly
    validated session is stored for ``settings.session_cache_ttl_seconds``
    seconds. Cached entries are returned without re-validation, so a revoked
    session may keep working until the backend drops its entry
    (NOTE(review): invalidation on logout is presumably done by the cache
    backend or the logout route; confirm).

    Args:
        request: The incoming FastAPI request.
        db: Injected aiosqlite connection.
        settings: Effective settings (session secret and cache configuration).
        session_cache: Session validation cache backend.
        session_repo: Repository used by the auth service to look up sessions.

    Returns:
        The active :class:`~app.models.auth.Session`.

    Raises:
        HTTPException: 401 if no token is supplied or the token fails validation.
    """
    token = _extract_session_token(request)
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    cache_enabled = _session_cache_enabled(settings)
    if cache_enabled:
        cached = session_cache.get(token)
        if cached is not None:
            return cached

    from app.services import auth_service  # noqa: PLC0415

    try:
        session = await auth_service.validate_session(
            db,
            token,
            settings.session_secret,
            session_repo=session_repo,
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    if cache_enabled:
        session_cache.set(token, session, settings.session_cache_ttl_seconds)
    return session
|
|
|
|
|
|
# Convenience type aliases for route signatures.
# Every public provider in this module has a matching ``…Dep`` alias; the
# ``GeoCacheDep`` alias for ``get_geo_cache`` completes that set.
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]
SettingsDep = Annotated[Settings, Depends(get_settings)]
HttpSessionDep = Annotated[aiohttp.ClientSession, Depends(get_http_session)]
SchedulerDep = Annotated[AsyncIOScheduler, Depends(get_scheduler)]
Fail2BanSocketDep = Annotated[str, Depends(get_fail2ban_socket)]
Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
GeoBatchLookupDep = Annotated[GeoBatchLookup, Depends(get_geo_batch_lookup)]
GeoCacheDep = Annotated[GeoCache, Depends(get_geo_cache)]
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
SettingsRepoDep = Annotated[SettingsRepository, Depends(get_settings_repo)]
HistoryArchiveRepositoryDep = Annotated[HistoryArchiveRepository, Depends(get_history_archive_repo)]
BlocklistRepositoryDep = Annotated[BlocklistRepository, Depends(get_blocklist_repo)]
ImportLogRepositoryDep = Annotated[ImportLogRepository, Depends(get_import_log_repo)]
GeoCacheRepositoryDep = Annotated[GeoCacheRepository, Depends(get_geo_cache_repo)]
Fail2BanDbRepositoryDep = Annotated[Fail2BanDbRepository, Depends(get_fail2ban_db_repo)]
AppStateDep = Annotated[ApplicationContext, Depends(get_app_state)]
AppDep = Annotated[FastAPI, Depends(get_app)]
AuthDep = Annotated[Session, Depends(require_auth)]
|