Files
BanGUI/backend/app/dependencies.py

268 lines
9.1 KiB
Python

"""FastAPI dependency providers.
All ``Depends()`` callables that inject shared resources (database
connection, settings, services, auth guard) are defined here.
Routers import from this module — never from ``app.state`` directly —
to keep coupling explicit and testable.
"""
import datetime
import time
from collections.abc import AsyncGenerator
from typing import Annotated, Protocol, cast
import aiohttp
import aiosqlite
import structlog
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
from fastapi import Depends, HTTPException, Request, status
from app.config import Settings
from app.models.auth import Session
from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.utils.time_utils import utc_now
#: Module-level structlog logger shared by the dependency providers below.
log: structlog.stdlib.BoundLogger = structlog.get_logger()
class AppState(Protocol):
    """Structural type for the part of ``app.state`` that dependencies use.

    The providers in this module ``cast`` ``request.app.state`` to this
    protocol before reading attributes from it.
    """

    # Application settings; returned by ``get_settings`` and read by ``get_db``.
    settings: Settings
    # Shared outbound HTTP client session; returned by ``get_http_session``.
    http_session: aiohttp.ClientSession
    # Shared APScheduler instance; returned by ``get_scheduler``.
    scheduler: AsyncIOScheduler
    # Cached fail2ban server status snapshot; returned by ``get_server_status``.
    server_status: ServerStatus
    # Pending recovery record, or ``None``; returned by ``get_pending_recovery``.
    pending_recovery: PendingRecovery | None
    # Mapping of ``str`` to ``datetime``, or ``None``. Not read by any provider
    # visible in this module — NOTE(review): confirm the key semantics at the
    # call sites that use it.
    last_activation: dict[str, datetime.datetime] | None
#: Name of the cookie from which ``require_auth`` reads the session token.
_COOKIE_NAME = "bangui_session"
# ---------------------------------------------------------------------------
# Session validation cache
# ---------------------------------------------------------------------------
#: How long (in seconds) a validated session token is served from the
#: in-memory cache without re-querying SQLite. This eliminates repeated DB
#: lookups when the same token arrives in near-simultaneous parallel requests.
_SESSION_CACHE_TTL: float = 10.0
#: ``token → (Session, cache_expiry)`` where ``cache_expiry`` is a
#: ``time.monotonic()`` deadline. Entries are added by ``require_auth`` after
#: a successful validation and removed by ``require_auth`` (when stale),
#: ``invalidate_session_cache`` and ``clear_session_cache``.
_session_cache: dict[str, tuple[Session, float]] = {}
def clear_session_cache() -> None:
    """Drop every entry from the in-memory session validation cache.

    Intended mainly for tests, so that sessions cached by one test case
    cannot leak into the next.
    """
    _session_cache.clear()
def invalidate_session_cache(token: str) -> None:
    """Remove a single token from the in-memory session cache.

    Logout must call this so that a revoked token stops being accepted from
    the cache and is re-checked against the database instead.

    Args:
        token: The session token to evict. Unknown tokens are ignored.
    """
    try:
        del _session_cache[token]
    except KeyError:
        # Token was never cached (or already evicted) — nothing to do.
        pass
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
    """Yield a per-request :class:`aiosqlite.Connection`.

    A new connection is opened for each request and closed once the request
    has finished, so concurrent requests never contend for one shared SQLite
    connection.

    Args:
        request: The current FastAPI request (injected automatically).

    Yields:
        An open :class:`aiosqlite.Connection` scoped to this request.

    Raises:
        HTTPException: 503 if the database cannot be opened.
    """
    # Imported lazily — presumably to avoid a circular import; confirm.
    from app.db import open_db  # noqa: PLC0415

    app_settings = cast("AppState", request.app.state).settings

    try:
        connection = await open_db(app_settings.database_path)
    except Exception as exc:
        # Any failure to open the database is reported to the client as 503.
        log.error("database_open_failed", error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Database is not available.",
        ) from exc

    try:
        yield connection
    finally:
        # Runs whether the request handler succeeded or raised.
        await connection.close()
async def get_settings(request: Request) -> Settings:
    """Return the :class:`~app.config.Settings` stored on ``app.state``.

    Args:
        request: The current FastAPI request (injected automatically).

    Returns:
        The ``settings`` attribute of the application state.
    """
    return cast("AppState", request.app.state).settings
async def get_http_session(request: Request) -> aiohttp.ClientSession:
    """Return the shared :class:`aiohttp.ClientSession` from application state.

    Args:
        request: The current FastAPI request.

    Returns:
        The ``http_session`` attribute of the application state.

    Raises:
        HTTPException: 503 if the attribute is missing or ``None``.
    """
    app_state = cast("AppState", request.app.state)
    client_session = getattr(app_state, "http_session", None)
    if client_session is not None:
        return client_session

    log.error("http_session_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="HTTP session is not available.",
    )
async def get_scheduler(request: Request) -> AsyncIOScheduler:
    """Return the shared scheduler from application state.

    Args:
        request: The current FastAPI request.

    Returns:
        The ``scheduler`` attribute of the application state, an
        :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler`.

    Raises:
        HTTPException: 503 if the attribute is missing or ``None``.
    """
    app_state = cast("AppState", request.app.state)
    shared_scheduler = getattr(app_state, "scheduler", None)
    if shared_scheduler is not None:
        return shared_scheduler

    log.error("scheduler_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Scheduler is not available.",
    )
async def get_fail2ban_socket(settings: Settings = Depends(get_settings)) -> str:
    """Return the configured path of the fail2ban Unix domain socket.

    Args:
        settings: Application settings, injected via ``get_settings``.

    Returns:
        The value of ``settings.fail2ban_socket``.
    """
    socket_path = settings.fail2ban_socket
    return socket_path
async def get_fail2ban_config_dir(settings: Settings = Depends(get_settings)) -> str:
    """Return the configured fail2ban configuration directory.

    Args:
        settings: Application settings, injected via ``get_settings``.

    Returns:
        The value of ``settings.fail2ban_config_dir``.
    """
    config_dir = settings.fail2ban_config_dir
    return config_dir
async def get_fail2ban_start_command(settings: Settings = Depends(get_settings)) -> str:
    """Return the configured command used to start fail2ban.

    Args:
        settings: Application settings, injected via ``get_settings``.

    Returns:
        The value of ``settings.fail2ban_start_command``.
    """
    start_command = settings.fail2ban_start_command
    return start_command
async def get_app_state(request: Request) -> AppState:
    """Return ``request.app.state`` typed as :class:`AppState`.

    Args:
        request: The current FastAPI request.

    Returns:
        The application state object; the ``cast`` affects static typing
        only and performs no runtime check.
    """
    app_state = cast("AppState", request.app.state)
    return app_state
async def get_server_status(request: Request) -> ServerStatus:
    """Return the fail2ban server status snapshot held in app state.

    Args:
        request: The current FastAPI request.

    Returns:
        The ``server_status`` attribute of the application state, or
        ``ServerStatus(online=False)`` when that attribute is not set.
    """
    app_state = cast("AppState", request.app.state)
    # Fallback used when no status has been stored on app state.
    offline_fallback = ServerStatus(online=False)
    return getattr(app_state, "server_status", offline_fallback)
async def get_pending_recovery(request: Request) -> PendingRecovery | None:
    """Return the pending recovery record held in app state, if any.

    Args:
        request: The current FastAPI request.

    Returns:
        The ``pending_recovery`` attribute of the application state, or
        ``None`` when the attribute is not set.
    """
    app_state = cast("AppState", request.app.state)
    pending = getattr(app_state, "pending_recovery", None)
    return pending
async def require_auth(
    request: Request,
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
) -> Session:
    """Validate the session token and return the active session.

    The token is read from the ``bangui_session`` cookie or, when the cookie
    is absent or empty, from the ``Authorization: Bearer <token>`` header.
    The ``Bearer`` scheme name is matched case-insensitively and surrounding
    whitespace around the token is ignored.

    Validated tokens are cached in memory for :data:`_SESSION_CACHE_TTL`
    seconds so that concurrent requests sharing the same token avoid repeated
    SQLite round-trips. A cache entry is bypassed once it is older than the
    TTL or once the session itself has expired, and it is explicitly removed
    by :func:`invalidate_session_cache` on logout. Lapsed entries for other
    tokens are pruned whenever a new entry is stored, so the cache does not
    grow without bound.

    Args:
        request: The incoming FastAPI request.
        db: Injected aiosqlite connection.

    Returns:
        The active :class:`~app.models.auth.Session`.

    Raises:
        HTTPException: 401 if no token is supplied or the token is rejected
            by ``auth_service.validate_session``.
    """
    from app.services import auth_service  # noqa: PLC0415

    token: str | None = request.cookies.get(_COOKIE_NAME)
    if not token:
        # HTTP authentication scheme names are case-insensitive, so accept
        # "bearer", "BEARER", etc. and tolerate extra whitespace.
        auth_header: str = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        if scheme.lower() == "bearer":
            token = credentials.strip()

    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    # Fast path: serve from the in-memory cache when the entry is still fresh
    # and the session itself has not yet exceeded its own expiry time.
    cached = _session_cache.get(token)
    if cached is not None:
        session, cache_expires_at = cached
        # NOTE(review): ``expires_at`` is compared with an ISO-8601 string
        # lexicographically; this presumably relies on both sides using the
        # same UTC format — confirm against ``Session`` and ``utc_now``.
        if time.monotonic() < cache_expires_at and session.expires_at > utc_now().isoformat():
            return session
        # Stale cache entry — evict and fall through to the DB.
        _session_cache.pop(token, None)

    try:
        session = await auth_service.validate_session(db, token)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    # Prune entries whose TTL has lapsed before inserting, so tokens that are
    # never seen again do not stay in memory for the life of the process.
    now = time.monotonic()
    lapsed = [key for key, (_, deadline) in _session_cache.items() if deadline <= now]
    for key in lapsed:
        _session_cache.pop(key, None)

    _session_cache[token] = (session, now + _SESSION_CACHE_TTL)
    return session
# Convenience ``Annotated`` aliases for route signatures: each pairs the
# injected value's type with the ``Depends()`` provider defined above, so a
# route can declare e.g. ``db: DbDep`` instead of repeating the pair.
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]
SettingsDep = Annotated[Settings, Depends(get_settings)]
HttpSessionDep = Annotated[aiohttp.ClientSession, Depends(get_http_session)]
SchedulerDep = Annotated[AsyncIOScheduler, Depends(get_scheduler)]
Fail2BanSocketDep = Annotated[str, Depends(get_fail2ban_socket)]
Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
AppStateDep = Annotated[AppState, Depends(get_app_state)]
# Declaring a parameter as ``AuthDep`` makes the route require a valid session.
AuthDep = Annotated[Session, Depends(require_auth)]