Files
BanGUI/backend/app/dependencies.py

259 lines
8.8 KiB
Python

"""FastAPI dependency providers.
All ``Depends()`` callables that inject shared resources (database
connection, settings, services, auth guard) are defined here.
Routers obtain these resources from this module — never from ``app.state``
directly — to keep coupling explicit and testable.
"""
import time
from collections.abc import AsyncGenerator
from typing import Annotated, Protocol, cast
import aiosqlite
import structlog
from fastapi import Depends, HTTPException, Request, status
from app.config import Settings
from app.models.auth import Session
from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.utils.time_utils import utc_now
import aiohttp
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
log: structlog.stdlib.BoundLogger = structlog.get_logger()
class AppState(Protocol):
    """Partial, typed view of ``app.state`` covering what this module reads.

    Used only as the target of ``cast("AppState", request.app.state)``; it
    asserts nothing at runtime. The attributes themselves are populated
    elsewhere (presumably by the application lifespan — confirm).
    """
    settings: Settings
    http_session: aiohttp.ClientSession
    scheduler: AsyncIOScheduler
    # The two attributes below are read with ``getattr(..., default)`` in this
    # module, so they may be absent at runtime; they are declared here so the
    # typed view covers every attribute the dependencies touch.
    server_status: ServerStatus
    pending_recovery: PendingRecovery | None
#: Name of the cookie that carries the session token.
_COOKIE_NAME = "bangui_session"
# ---------------------------------------------------------------------------
# In-memory cache of validated sessions
# ---------------------------------------------------------------------------
#: Lifetime, in seconds, of a cache entry. While an entry is fresh,
#: :func:`require_auth` answers from memory instead of querying SQLite, so
#: repeated requests carrying the same token within the window skip the lookup.
_SESSION_CACHE_TTL: float = 10.0
#: Maps a session token to ``(Session, expiry)``, where *expiry* is a
#: :func:`time.monotonic` timestamp after which the entry is considered stale.
_session_cache: dict[str, tuple[Session, float]] = dict()
def clear_session_cache() -> None:
    """Remove every entry from the in-memory session cache.

    Meant for test isolation: calling it between test cases guarantees a
    session validated in one test is never served from cache in another.
    """
    # Empty the dict in place rather than rebinding the module global, so any
    # other reference to the same dict object also observes the reset.
    _session_cache.clear()
def invalidate_session_cache(token: str) -> None:
    """Forget any cached validation result for *token*.

    Logout must call this; otherwise a revoked token could keep being
    accepted from the cache without consulting the database. Removing a
    token that is not cached is a harmless no-op.

    Args:
        token: Session token whose cache entry should be dropped.
    """
    _session_cache.pop(token, None)  # default avoids KeyError for unknown tokens
async def get_db(request: Request) -> AsyncGenerator[aiosqlite.Connection, None]:
    """Yield a fresh :class:`aiosqlite.Connection` scoped to a single request.

    Every request gets its own connection, which is closed once the request
    finishes, instead of sharing one SQLite connection between concurrent
    requests (avoids contention and locking problems).

    Args:
        request: Incoming request; used to reach ``app.state.settings``.

    Yields:
        An open connection to ``settings.database_path``.

    Raises:
        HTTPException: 503 when the database connection cannot be opened.
    """
    # Imported lazily (presumably to avoid an import cycle with app.db — confirm).
    from app.db import open_db  # noqa: PLC0415
    cfg = cast("AppState", request.app.state).settings
    try:
        connection = await open_db(cfg.database_path)
    except Exception as exc:
        log.error("database_open_failed", error=str(exc))
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Database is not available.",
        ) from exc
    try:
        yield connection
    finally:
        # Runs after the response is produced, even if the handler raised.
        await connection.close()
async def get_settings(request: Request) -> Settings:
    """Return the application-wide :class:`~app.config.Settings` object.

    Args:
        request: Incoming request; settings are read from ``request.app.state``.

    Returns:
        The settings instance stored on ``app.state`` at startup.
    """
    return cast("AppState", request.app.state).settings
async def get_http_session(request: Request) -> aiohttp.ClientSession:
    """Return the shared :class:`aiohttp.ClientSession` stored on ``app.state``.

    Args:
        request: Incoming request; the session is read from ``request.app.state``.

    Returns:
        The application-wide HTTP client session.

    Raises:
        HTTPException: 503 when ``app.state`` has no ``http_session`` or it is ``None``.
    """
    client = getattr(cast("AppState", request.app.state), "http_session", None)
    if client is not None:
        return client
    log.error("http_session_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="HTTP session is not available.",
    )
async def get_scheduler(request: Request) -> AsyncIOScheduler:
    """Return the shared :class:`AsyncIOScheduler` stored on ``app.state``.

    Args:
        request: Incoming request; the scheduler is read from ``request.app.state``.

    Returns:
        The application-wide APScheduler instance.

    Raises:
        HTTPException: 503 when ``app.state`` has no ``scheduler`` or it is ``None``.
    """
    app_state = cast("AppState", request.app.state)
    sched = getattr(app_state, "scheduler", None)
    if sched is not None:
        return sched
    log.error("scheduler_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Scheduler is not available.",
    )
async def get_fail2ban_socket(settings: Settings = Depends(get_settings)) -> str:
    """Return the path of the fail2ban Unix domain socket.

    Args:
        settings: Application settings, injected via :func:`get_settings`.

    Returns:
        ``settings.fail2ban_socket``, unchanged.
    """
    socket_path = settings.fail2ban_socket
    return socket_path
async def get_fail2ban_config_dir(settings: Settings = Depends(get_settings)) -> str:
    """Return the directory holding fail2ban's configuration files.

    Args:
        settings: Application settings, injected via :func:`get_settings`.

    Returns:
        ``settings.fail2ban_config_dir``, unchanged.
    """
    config_dir = settings.fail2ban_config_dir
    return config_dir
async def get_fail2ban_start_command(settings: Settings = Depends(get_settings)) -> str:
    """Return the command configured for starting fail2ban.

    Args:
        settings: Application settings, injected via :func:`get_settings`.

    Returns:
        ``settings.fail2ban_start_command``, unchanged.
    """
    start_cmd = settings.fail2ban_start_command
    return start_cmd
async def get_server_status(request: Request) -> ServerStatus:
    """Return the cached fail2ban server status snapshot from ``app.state``.

    Falls back to an offline snapshot, ``ServerStatus(online=False)``, when no
    snapshot has been stored yet or the stored value is ``None``. The fallback
    object is only constructed when it is actually needed.

    Args:
        request: Incoming request; the snapshot is read from ``request.app.state``.

    Returns:
        The stored :class:`ServerStatus`, or a fresh offline snapshot.
    """
    snapshot = getattr(cast("AppState", request.app.state), "server_status", None)
    if snapshot is None:
        # Nothing recorded yet: report "offline" rather than handing ``None``
        # to route handlers that are typed to receive a ServerStatus.
        return ServerStatus(online=False)
    return snapshot
async def get_pending_recovery(request: Request) -> PendingRecovery | None:
    """Return the pending-recovery record held on ``app.state``, if any.

    Args:
        request: Incoming request; the record is read from ``request.app.state``.

    Returns:
        The stored :class:`PendingRecovery`, or ``None`` when the attribute is
        absent (or was stored as ``None``).
    """
    app_state = cast("AppState", request.app.state)
    recovery: PendingRecovery | None = getattr(app_state, "pending_recovery", None)
    return recovery
def _token_from_request(request: Request) -> str | None:
    """Return the session token from the cookie, else from a Bearer header."""
    cookie_token = request.cookies.get(_COOKIE_NAME)
    if cookie_token:
        return cookie_token
    # RFC 7235: the auth-scheme name is case-insensitive, so accept "Bearer",
    # "bearer", etc., and tolerate stray whitespace around the credentials.
    scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
    if scheme.lower() != "bearer":
        return None
    return credentials.strip() or None


def _prune_session_cache(now: float) -> None:
    """Evict every cache entry whose TTL has lapsed as of monotonic time *now*."""
    expired = [tok for tok, (_, entry_expiry) in _session_cache.items() if entry_expiry <= now]
    for tok in expired:
        del _session_cache[tok]


async def require_auth(
    request: Request,
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
) -> Session:
    """Validate the session token and return the active session.

    The token is read from the ``bangui_session`` cookie or, failing that,
    from an ``Authorization: Bearer <token>`` header (scheme matched
    case-insensitively).

    Validated tokens are cached in memory for :data:`_SESSION_CACHE_TTL`
    seconds so that requests sharing the same token avoid repeated SQLite
    round-trips. Stale entries are evicted on access and swept whenever a
    new entry is stored, so the cache cannot grow without bound. Logout
    clears a token explicitly via :func:`invalidate_session_cache`.

    Args:
        request: The incoming FastAPI request.
        db: Injected aiosqlite connection.

    Returns:
        The active :class:`~app.models.auth.Session`.

    Raises:
        HTTPException: 401 if no valid session token is found.
    """
    from app.services import auth_service  # noqa: PLC0415
    token = _token_from_request(request)
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
            headers={"WWW-Authenticate": "Bearer"},
        )
    # Fast path: serve from the cache while the entry is fresh AND the session
    # itself has not expired. NOTE(review): this compares ``expires_at`` with an
    # ISO string lexicographically, which assumes ``expires_at`` is an ISO-8601
    # string in the same format/offset as ``utc_now().isoformat()`` — confirm.
    cached = _session_cache.get(token)
    if cached is not None:
        session, cache_expires_at = cached
        if time.monotonic() < cache_expires_at and session.expires_at > utc_now().isoformat():
            return session
        # Stale cache entry — evict and fall through to the database.
        _session_cache.pop(token, None)
    try:
        session = await auth_service.validate_session(db, token)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc
    now = time.monotonic()
    # Without this sweep, entries for tokens that are never presented again
    # (expired or abandoned sessions) would stay in the dict forever.
    _prune_session_cache(now)
    _session_cache[token] = (session, now + _SESSION_CACHE_TTL)
    return session
# ``Annotated`` shortcuts for route signatures, e.g.
# ``async def handler(db: DbDep, session: AuthDep) -> ...``.
# -- per-request resources --
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]
AuthDep = Annotated[Session, Depends(require_auth)]
# -- application-wide singletons on app.state --
SettingsDep = Annotated[Settings, Depends(get_settings)]
HttpSessionDep = Annotated[aiohttp.ClientSession, Depends(get_http_session)]
SchedulerDep = Annotated[AsyncIOScheduler, Depends(get_scheduler)]
# -- individual fail2ban settings values --
Fail2BanSocketDep = Annotated[str, Depends(get_fail2ban_socket)]
Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
# -- runtime state snapshots --
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]