"""FastAPI dependency providers. All ``Depends()`` callables that inject shared resources (database connection, settings, services, auth guard) are defined here. Routers import directly from this module — never from ``app.state`` directly — to keep coupling explicit and testable. """ import datetime from collections.abc import AsyncGenerator from dataclasses import dataclass from typing import Annotated, Protocol, cast import aiohttp import aiosqlite import structlog from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped] from fastapi import Depends, FastAPI, HTTPException, Request, status from app.config import Settings from app.models.auth import Session from app.models.config import PendingRecovery from app.models.server import ServerStatus from app.repositories.protocols import SessionRepository from app.services.protocols import AuthService, JailService from app.utils.constants import SESSION_COOKIE_NAME from app.utils.runtime_state import RuntimeState from app.utils.session_cache import SessionCache log: structlog.stdlib.BoundLogger = structlog.get_logger() class AppState(Protocol): """Partial view of the FastAPI application state used by dependencies.""" settings: Settings http_session: aiohttp.ClientSession scheduler: AsyncIOScheduler server_status: ServerStatus pending_recovery: PendingRecovery | None last_activation: dict[str, datetime.datetime] | None runtime_settings: Settings | None runtime_state: RuntimeState session_cache: SessionCache @dataclass class ApplicationContext: """A typed wrapper around shared application lifecycle resources.""" settings: Settings http_session: aiohttp.ClientSession | None scheduler: AsyncIOScheduler | None server_status: ServerStatus pending_recovery: PendingRecovery | None last_activation: dict[str, datetime.datetime] | None runtime_settings: Settings | None runtime_state: RuntimeState session_cache: SessionCache | None # --------------------------------------------------------------------------- # Session validation cache # --------------------------------------------------------------------------- #: How long (seconds) a validated session token is served from the in-memory #: cache without re-querying SQLite. Eliminates repeated DB lookups for the #: same token arriving in near-simultaneous parallel requests. #: #: NOTE: this cache is process-local and is not cluster-safe. In multi-worker #: or distributed deployments, the configured cache backend should provide #: invalidation semantics appropriate for the deployment. def _session_cache_enabled(settings: Settings) -> bool: """Return whether the session validation cache should be used.""" return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0 def _build_app_context(request: Request) -> ApplicationContext: state = cast("AppState", request.app.state) return ApplicationContext( settings=state.settings, http_session=getattr(state, "http_session", None), scheduler=getattr(state, "scheduler", None), server_status=getattr(state, "server_status", ServerStatus(online=False)), pending_recovery=getattr(state, "pending_recovery", None), last_activation=getattr(state, "last_activation", None), runtime_settings=getattr(state, "runtime_settings", None), runtime_state=state.runtime_state, session_cache=getattr(state, "session_cache", None), ) async def get_app_context(request: Request) -> ApplicationContext: """Provide the typed application context for the current request.""" return _build_app_context(request) async def get_settings(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> Settings: """Provide the effective application settings for the current request.""" return app_context.runtime_settings if app_context.runtime_settings is not None else app_context.settings async def get_db( settings: Annotated[Settings, Depends(get_settings)], ) -> AsyncGenerator[aiosqlite.Connection, None]: """Provide a request-scoped :class:`aiosqlite.Connection` for the current request. Opens a fresh connection for every request and closes it when the request is finished. This avoids contention and locking issues from a single shared SQLite connection across concurrent requests. The database path is taken from the effective application settings so runtime overrides stored during setup are respected. Args: settings: The effective application settings for the current request. Yields: An open :class:`aiosqlite.Connection` for the request. """ from app.db import open_db # noqa: PLC0415 try: db = await open_db(settings.database_path) except Exception as exc: log.error("database_open_failed", error=str(exc)) raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Database is not available.", ) from exc try: yield db finally: await db.close() async def get_http_session( app_context: Annotated[ApplicationContext, Depends(get_app_context)], ) -> aiohttp.ClientSession: """Provide the shared HTTP client session from application context. Args: app_context: The injected shared application context. Returns: A shared :class:`aiohttp.ClientSession` managed by the lifespan. Raises: HTTPException: If the session is unavailable. """ if app_context.http_session is None: log.error("http_session_unavailable") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="HTTP session is not available.", ) return app_context.http_session async def get_scheduler(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> AsyncIOScheduler: """Provide the shared scheduler from application context. Args: app_context: The injected shared application context. Returns: The :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler` instance. Raises: HTTPException: If the scheduler is unavailable. """ if app_context.scheduler is None: log.error("scheduler_unavailable") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Scheduler is not available.", ) return app_context.scheduler async def get_fail2ban_socket(settings: Settings = Depends(get_settings)) -> str: """Provide the configured path to the fail2ban Unix domain socket.""" return settings.fail2ban_socket async def get_fail2ban_config_dir(settings: Settings = Depends(get_settings)) -> str: """Provide the configured fail2ban configuration directory.""" return settings.fail2ban_config_dir async def get_fail2ban_start_command(settings: Settings = Depends(get_settings)) -> str: """Provide the configured fail2ban start command.""" return settings.fail2ban_start_command async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> SessionCache: """Provide the configured session cache backend from application context.""" if app_context.session_cache is None: log.error("session_cache_unavailable") raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Session cache is not available.", ) return app_context.session_cache async def get_auth_service() -> AuthService: """Provide the concrete authentication service implementation.""" from app.services import auth_service # noqa: PLC0415 return cast("AuthService", auth_service) async def get_jail_service() -> JailService: """Provide the concrete jail service implementation.""" from app.services import jail_service # noqa: PLC0415 return cast("JailService", jail_service) async def get_session_repo() -> SessionRepository: """Provide the concrete session repository implementation.""" from app.repositories import session_repo # noqa: PLC0415 return session_repo async def get_app_state(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ApplicationContext: """Provide the application state object for the current request.""" return app_context async def get_app(request: Request) -> FastAPI: """Provide the FastAPI application instance for the current request.""" return request.app async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus: """Return the cached fail2ban server status snapshot from application context.""" return app_context.server_status async def get_pending_recovery( app_context: Annotated[ApplicationContext, Depends(get_app_context)], ) -> PendingRecovery | None: """Return the current pending recovery record from application context.""" return app_context.pending_recovery async def require_auth( request: Request, db: Annotated[aiosqlite.Connection, Depends(get_db)], settings: Annotated[Settings, Depends(get_settings)], session_cache: Annotated[SessionCache, Depends(get_session_cache)], auth_service: Annotated[AuthService, Depends(get_auth_service)], session_repo: Annotated[SessionRepository, Depends(get_session_repo)], ) -> Session: """Validate the session token and return the active session. The token is read from the ``bangui_session`` cookie or the ``Authorization: Bearer`` header. Validated tokens may be cached in memory for a short period so that concurrent requests sharing the same token avoid repeated SQLite round-trips. This cache is disabled by default because process-local invalidation is not safe in multi-worker or clustered deployments. When enabled, entries are bypassed on expiry and explicitly cleared by the configured session cache backend on logout. Args: request: The incoming FastAPI request. db: Injected aiosqlite connection. settings: Application settings used for signed session token validation. Returns: The active :class:`~app.models.auth.Session`. Raises: HTTPException: 401 if no valid session token is found. """ token: str | None = request.cookies.get(SESSION_COOKIE_NAME) if not token: auth_header: str = request.headers.get("Authorization", "") if auth_header.startswith("Bearer "): token = auth_header[len("Bearer "):] if not token: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Authentication required.", headers={"WWW-Authenticate": "Bearer"}, ) cache_enabled = _session_cache_enabled(settings) if cache_enabled: cached = session_cache.get(token) if cached is not None: return cached try: session = await auth_service.validate_session( db, token, settings.session_secret, session_repo=session_repo, ) except ValueError as exc: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail=str(exc), headers={"WWW-Authenticate": "Bearer"}, ) from exc if cache_enabled: session_cache.set(token, session, settings.session_cache_ttl_seconds) return session # Convenience type aliases for route signatures. DbDep = Annotated[aiosqlite.Connection, Depends(get_db)] SettingsDep = Annotated[Settings, Depends(get_settings)] HttpSessionDep = Annotated[aiohttp.ClientSession, Depends(get_http_session)] SchedulerDep = Annotated[AsyncIOScheduler, Depends(get_scheduler)] Fail2BanSocketDep = Annotated[str, Depends(get_fail2ban_socket)] Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)] Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)] ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)] PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)] SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)] AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)] JailServiceDep = Annotated[JailService, Depends(get_jail_service)] SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)] AppStateDep = Annotated[AppState, Depends(get_app_state)] AppDep = Annotated[FastAPI, Depends(get_app)] AuthDep = Annotated[Session, Depends(require_auth)]