Files
BanGUI/backend/app/dependencies.py

270 lines
9.7 KiB
Python

"""FastAPI dependency providers.
All ``Depends()`` callables that inject shared resources (database
connection, settings, services, auth guard) are defined here.
Routers import these providers from this module instead of reading
``app.state`` themselves, which keeps coupling explicit and testable.
"""
import datetime
from collections.abc import AsyncGenerator
from typing import Annotated, Protocol, cast
import aiohttp
import aiosqlite
import structlog
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
from fastapi import Depends, FastAPI, HTTPException, Request, status
from app.config import Settings
from app.models.auth import Session
from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.utils.runtime_state import RuntimeState, get_effective_settings
from app.utils.session_cache import SessionCache
# Module-level structlog logger; the providers below use it to report
# unavailable resources (database, HTTP session, scheduler, session cache).
log: structlog.stdlib.BoundLogger = structlog.get_logger()
class AppState(Protocol):
    """Partial, structural view of the FastAPI application state used by dependencies.

    This Protocol exists only for static typing: providers apply
    ``typing.cast`` to ``request.app.state``, and nothing verifies at runtime
    that these attributes are present. Several providers in this module
    therefore read them with ``getattr(..., default)`` and handle a missing
    value explicitly.
    """

    # Application settings object; ``get_db`` reads ``database_path`` from it directly.
    settings: Settings
    # Shared aiohttp client session, returned by ``get_http_session``.
    http_session: aiohttp.ClientSession
    # Shared APScheduler instance, returned by ``get_scheduler``.
    scheduler: AsyncIOScheduler
    # Cached fail2ban server status snapshot, returned by ``get_server_status``.
    server_status: ServerStatus
    # Current pending recovery record (or None), returned by ``get_pending_recovery``.
    pending_recovery: PendingRecovery | None
    # Not read in this module; NOTE(review): key/value meaning not visible here — confirm.
    last_activation: dict[str, datetime.datetime] | None
    # Not read directly here; presumably consulted by ``get_effective_settings`` — TODO confirm.
    runtime_settings: Settings | None
    # Not read in this module; see ``app.utils.runtime_state``.
    runtime_state: RuntimeState
    # Session validation cache backend, returned by ``get_session_cache``.
    session_cache: SessionCache
# Name of the cookie carrying the session token. ``require_auth`` checks this
# cookie first and only falls back to the ``Authorization`` header if it is empty.
_COOKIE_NAME = "bangui_session"
# ---------------------------------------------------------------------------
# Session validation cache
# ---------------------------------------------------------------------------
# Validated session tokens may be served from the configured session cache
# backend for up to ``settings.session_cache_ttl_seconds`` seconds without
# re-querying SQLite. This eliminates repeated DB lookups for the same token
# arriving in near-simultaneous parallel requests.
#
# NOTE: an in-memory cache backend is process-local and therefore not
# cluster-safe. In multi-worker or distributed deployments, the configured
# cache backend should provide invalidation semantics appropriate for the
# deployment.
def _session_cache_enabled(settings: Settings) -> bool:
"""Return whether the session validation cache should be used."""
return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
    """Yield a fresh SQLite connection scoped to a single request.

    Each request gets its own connection, which is closed once the request
    has finished. Concurrent requests therefore never share one SQLite
    connection, avoiding contention and locking problems.

    Args:
        request: The current FastAPI request (injected automatically).

    Yields:
        An open :class:`aiosqlite.Connection` for the request.

    Raises:
        HTTPException: 503 if the database connection cannot be opened.
    """
    from app.db import open_db  # noqa: PLC0415

    # NOTE(review): this reads ``state.settings`` directly rather than
    # ``get_effective_settings`` — confirm ``database_path`` is not meant to be
    # runtime-overridable.
    app_settings = cast("AppState", request.app.state).settings
    try:
        connection = await open_db(app_settings.database_path)
    except Exception as exc:
        log.error("database_open_failed", error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Database is not available.",
        ) from exc

    try:
        yield connection
    finally:
        # Always release the per-request connection, even if the handler raised.
        await connection.close()
async def get_settings(request: Request) -> Settings:
    """Resolve the settings that apply to the current request.

    Delegates to :func:`get_effective_settings`, passing it the FastAPI
    application that owns this request.
    """
    application = request.app
    return get_effective_settings(application)
async def get_http_session(request: Request) -> aiohttp.ClientSession:
    """Return the shared aiohttp client session stored on application state.

    Args:
        request: The current FastAPI request.

    Returns:
        The shared :class:`aiohttp.ClientSession` managed by the lifespan.

    Raises:
        HTTPException: 503 if no session is stored on the application state.
    """
    shared_session = getattr(cast("AppState", request.app.state), "http_session", None)
    if shared_session is not None:
        return shared_session

    log.error("http_session_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="HTTP session is not available.",
    )
async def get_scheduler(request: Request) -> AsyncIOScheduler:
    """Return the shared APScheduler instance stored on application state.

    Args:
        request: The current FastAPI request.

    Returns:
        The :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler` instance.

    Raises:
        HTTPException: 503 if no scheduler is stored on the application state.
    """
    shared_scheduler = getattr(cast("AppState", request.app.state), "scheduler", None)
    if shared_scheduler is not None:
        return shared_scheduler

    log.error("scheduler_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Scheduler is not available.",
    )
async def get_fail2ban_socket(settings: Settings = Depends(get_settings)) -> str:
    """Return the fail2ban Unix domain socket path from the effective settings."""
    socket_path: str = settings.fail2ban_socket
    return socket_path
async def get_fail2ban_config_dir(settings: Settings = Depends(get_settings)) -> str:
    """Return the fail2ban configuration directory from the effective settings."""
    config_dir: str = settings.fail2ban_config_dir
    return config_dir
async def get_fail2ban_start_command(settings: Settings = Depends(get_settings)) -> str:
    """Return the command used to start fail2ban, taken from the effective settings."""
    start_command: str = settings.fail2ban_start_command
    return start_command
async def get_session_cache(request: Request) -> SessionCache:
    """Return the session cache backend configured on application state.

    Raises:
        HTTPException: 503 if no session cache is stored on the application state.
    """
    cache_backend = getattr(cast("AppState", request.app.state), "session_cache", None)
    if cache_backend is not None:
        return cache_backend

    log.error("session_cache_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Session cache is not available.",
    )
async def get_app_state(request: Request) -> AppState:
    """Return the application state object, typed as :class:`AppState`.

    The ``cast`` is purely for static typing; the state object is returned as-is.
    """
    application_state = request.app.state
    return cast("AppState", application_state)
async def get_app(request: Request) -> FastAPI:
    """Return the FastAPI application that is serving the current request."""
    application: FastAPI = request.app
    return application
async def get_server_status(request: Request) -> ServerStatus:
    """Return the cached fail2ban server status snapshot from app state.

    If no snapshot is stored yet, or it is explicitly ``None``, an offline
    status (``ServerStatus(online=False)``) is returned instead. The fallback
    model is built only when needed. Passing it as a ``getattr`` default would
    construct a throwaway instance on every request.

    Args:
        request: The current FastAPI request.

    Returns:
        The stored :class:`ServerStatus` snapshot, or a fresh offline status.
    """
    state = cast("AppState", request.app.state)
    snapshot = getattr(state, "server_status", None)
    if snapshot is None:
        # No snapshot available: honour the declared return type and report offline.
        return ServerStatus(online=False)
    return snapshot
async def get_pending_recovery(request: Request) -> PendingRecovery | None:
    """Return the pending recovery record from app state, or ``None`` if unset."""
    app_state = cast("AppState", request.app.state)
    try:
        return app_state.pending_recovery
    except AttributeError:
        # Attribute never assigned on the state object: nothing is pending.
        return None
async def require_auth(
    request: Request,
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
    settings: Annotated[Settings, Depends(get_settings)],
    session_cache: Annotated[SessionCache, Depends(get_session_cache)],
) -> Session:
    """Validate the session token and return the active session.

    The token is read from the ``bangui_session`` cookie. If that cookie is
    absent or empty, it is read from an ``Authorization: Bearer <token>``
    header. The ``Bearer`` scheme name is matched case-insensitively
    (RFC 7235 section 2.1), and whitespace around the token is ignored.

    Validated tokens may be cached in memory for a short period so that
    concurrent requests sharing the same token avoid repeated SQLite
    round-trips. This cache is disabled by default because process-local
    invalidation is not safe in multi-worker or clustered deployments.
    When enabled, entries are bypassed on expiry and explicitly cleared by
    the configured session cache backend on logout.

    Args:
        request: The incoming FastAPI request.
        db: Injected aiosqlite connection passed to ``auth_service.validate_session``.
        settings: Effective settings; supplies ``session_secret`` and the
            session-cache switches (``session_cache_enabled``,
            ``session_cache_ttl_seconds``).
        session_cache: Cache backend consulted before, and populated after,
            validation when caching is enabled.

    Returns:
        The active :class:`~app.models.auth.Session`.

    Raises:
        HTTPException: 401 if no token is supplied, or if
            ``auth_service.validate_session`` rejects it with ``ValueError``.
    """
    from app.services import auth_service  # noqa: PLC0415

    token: str | None = request.cookies.get(_COOKIE_NAME)
    if not token:
        # Auth schemes are case-insensitive (RFC 7235), so accept "bearer",
        # "BEARER", etc.
        scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
        if scheme.lower() == "bearer":
            token = credentials.strip()
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    cache_enabled = _session_cache_enabled(settings)
    if cache_enabled:
        # NOTE(review): a cache hit is returned without re-validation, so a
        # session may be served for up to ``session_cache_ttl_seconds`` after it
        # stops being valid in the DB — confirm this window is acceptable.
        # FastAPI also resolves ``db`` before this body runs, so a connection is
        # opened even on cache hits.
        cached = session_cache.get(token)
        if cached is not None:
            return cached

    try:
        session = await auth_service.validate_session(db, token, settings.session_secret)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    if cache_enabled:
        session_cache.set(token, session, settings.session_cache_ttl_seconds)
    return session
# Convenience ``Annotated`` aliases for route signatures, so a route can
# declare a dependency with a single name, e.g. ``async def handler(db: DbDep)``.

# Request-scoped resources and the authenticated session.
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]
AuthDep = Annotated[Session, Depends(require_auth)]

# Objects read from the application / application state.
AppDep = Annotated[FastAPI, Depends(get_app)]
AppStateDep = Annotated[AppState, Depends(get_app_state)]
HttpSessionDep = Annotated[aiohttp.ClientSession, Depends(get_http_session)]
SchedulerDep = Annotated[AsyncIOScheduler, Depends(get_scheduler)]
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]

# Effective settings and values derived from them.
SettingsDep = Annotated[Settings, Depends(get_settings)]
Fail2BanSocketDep = Annotated[str, Depends(get_fail2ban_socket)]
Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]