"""FastAPI dependency providers and composition root.

This module is BanGUI's dependency injection composition root. All injectable
resources — database connections, settings, services, repositories, and
authentication guards — are defined here as provider functions.

**Key Principles:**

1. **Composition Root Pattern**: No heavyweight DI container is used. Instead,
   FastAPI's `Depends()` framework manages all dependencies, keeping the
   pattern lightweight and explicit.

2. **Explicit Over Implicit**: Every dependency is declared in function
   signatures. There is no hidden coupling or magic. This makes the dependency
   graph visible to type checkers, debuggers, and developers.

3. **Service Context Dependencies**: Related resources (e.g., db + repository)
   are bundled into context objects (SessionServiceContext,
   BlocklistServiceContext) to prevent routers from accessing raw database
   connections.

4. **Repository Boundary Enforcement**: Routers must NOT import DbDep. They
   depend on service context dependencies instead, which contain both the
   database connection and the necessary repositories. This ensures
   repositories are the only modules executing SQL.

See Architekture.md § 2.3 (Dependency Wiring and Service Composition) for a
complete guide to the DI pattern, including examples of adding new services.
See Backend-Development.md § 6 for dependency layering rules.
"""

import asyncio
import datetime
from collections.abc import AsyncGenerator, Awaitable, Callable
from dataclasses import dataclass
from typing import Annotated, cast

import aiohttp
import aiosqlite
from apscheduler.schedulers.asyncio import AsyncIOScheduler  # type: ignore[import-untyped]
from fastapi import Depends, FastAPI, HTTPException, Request, status

from app.config import Settings
from app.exceptions import RateLimitError
from app.models.auth import Session
from app.models.config import PendingRecovery
from app.models.server import ServerStatus

# Module-level imports for repositories and services.
# These are safe at module level since no circular dependencies exist.
from app.repositories import (
    blocklist_repo,
    fail2ban_db_repo,
    geo_cache_repo,
    history_archive_repo,
    import_log_repo,
    import_run_repo,
    session_repo,
    settings_repo,
)
from app.repositories.protocols import (
    BlocklistRepository,
    Fail2BanDbRepository,
    GeoCacheRepository,
    HistoryArchiveRepository,
    ImportLogRepository,
    ImportRunRepository,
    SessionRepository,
    SettingsRepository,
)
from app.services import auth_service, health_service
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.services.geo_cache import GeoCache
from app.services.protocols import Fail2BanMetadataService
from app.utils.constants import SESSION_COOKIE_NAME
from app.utils.logging_compat import get_logger
from app.utils.rate_limiter import GlobalRateLimiter
from app.utils.runtime_state import ApplicationState, JailServiceState, RuntimeState
from app.utils.session_cache import NoOpSessionCache, SessionCache

log = get_logger(__name__)


@dataclass
class ApplicationContext:
    """A typed wrapper around shared application lifecycle resources.

    Built per request by :func:`_build_app_context` from ``request.app.state``.
    """

    settings: Settings
    http_session: aiohttp.ClientSession | None
    scheduler: AsyncIOScheduler | None
    server_status: ServerStatus
    pending_recovery: PendingRecovery | None
    last_activation: dict[str, datetime.datetime] | None
    # Settings overrides stored at runtime; preferred over ``settings`` when set.
    runtime_settings: Settings | None
    runtime_state: RuntimeState
    session_cache: SessionCache | None
    global_rate_limiter: GlobalRateLimiter


# ---------------------------------------------------------------------------
# Session validation cache
# ---------------------------------------------------------------------------

# A validated session token may be served from the configured session cache
# for ``settings.session_cache_ttl_seconds`` seconds instead of re-querying
# SQLite, which avoids repeated DB lookups when the same token arrives in
# near-simultaneous parallel requests.
#
# NOTE: an in-process cache is not cluster-safe. In multi-worker or
# distributed deployments, the configured cache backend should provide
# invalidation semantics appropriate for the deployment.


def _session_cache_enabled(settings: Settings) -> bool:
    """Return whether the session validation cache should be used.

    The cache is used only when it is explicitly enabled AND has a positive TTL.
    """
    return settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0


def _build_app_context(request: Request) -> ApplicationContext:
    """Assemble an :class:`ApplicationContext` from ``request.app.state``.

    Optional resources missing from the state fall back to defaults. A missing
    global rate limiter is created once and stored back on the app state. A
    fresh limiter per request would start with no recorded history, so limits
    could never be reached.
    """
    state = cast("ApplicationState", request.app.state)

    session_cache: SessionCache | None = getattr(state, "session_cache", None)
    if session_cache is None:
        session_cache = NoOpSessionCache()

    global_rate_limiter: GlobalRateLimiter | None = getattr(state, "global_rate_limiter", None)
    if global_rate_limiter is None:
        global_rate_limiter = GlobalRateLimiter()
        # Persist so subsequent requests share the same counters.
        request.app.state.global_rate_limiter = global_rate_limiter

    return ApplicationContext(
        settings=state.settings,
        http_session=getattr(state, "http_session", None),
        scheduler=getattr(state, "scheduler", None),
        server_status=getattr(state, "server_status", ServerStatus(online=False)),
        pending_recovery=getattr(state, "pending_recovery", None),
        last_activation=getattr(state, "last_activation", None),
        runtime_settings=getattr(state, "runtime_settings", None),
        runtime_state=state.runtime_state,
        session_cache=session_cache,
        global_rate_limiter=global_rate_limiter,
    )


async def get_app_context(request: Request) -> ApplicationContext:
    """Provide the typed application context for the current request."""
    return _build_app_context(request)


async def get_settings(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> Settings:
    """Provide the effective application settings for the current request.

    Runtime overrides (``runtime_settings``) take precedence over the settings
    the application was started with.
    """
    return app_context.runtime_settings if app_context.runtime_settings is not None else app_context.settings


async def get_db(
    settings: Annotated[Settings, Depends(get_settings)],
) -> AsyncGenerator[aiosqlite.Connection, None]:
    """Provide a request-scoped :class:`aiosqlite.Connection` for the current request.

    Opens a fresh connection for every request and closes it when the request
    is finished. This avoids contention and locking issues from a single
    shared SQLite connection across concurrent requests.

    The database path is taken from the effective application settings so
    runtime overrides stored during setup are respected.

    Opening is attempted up to 3 times when the failure looks like a lock
    ("database is locked" / "busy"), with a linear backoff of 0.1s, 0.2s.

    Args:
        settings: The effective application settings for the current request.

    Yields:
        An open :class:`aiosqlite.Connection` for the request.

    Raises:
        DatabaseBusyError: When the database is still locked after 3 open
            attempts, or when ``open_db`` raises it directly.
        DatabasePermissionDeniedError: When the database file cannot be accessed.
        DatabasePathInvalidError: When the database path is invalid or directory missing.
        DatabaseCorruptedError: When the database file is corrupted.
        DatabaseUnavailableError: For any other unexpected database error.
    """
    from app.db import open_db  # noqa: PLC0415
    from app.exceptions import (  # noqa: PLC0415
        DatabaseBusyError,
        DatabaseCorruptedError,
        DatabasePathInvalidError,
        DatabasePermissionDeniedError,
        DatabaseUnavailableError,
    )

    max_attempts = 3
    retry_delay = 0.1
    db: aiosqlite.Connection | None = None

    for attempt in range(1, max_attempts + 1):
        try:
            db = await open_db(settings.database_path)
            break
        except (
            DatabaseBusyError,
            DatabasePermissionDeniedError,
            DatabasePathInvalidError,
            DatabaseCorruptedError,
        ):
            # Specific errors propagate unchanged. They are matched before the
            # generic handler below in case they subclass DatabaseUnavailableError.
            raise
        except DatabaseUnavailableError as exc:
            error_str = str(exc).lower()
            if "database is locked" not in error_str and "busy" not in error_str:
                raise
            if attempt >= max_attempts:
                raise DatabaseBusyError(settings.database_path, max_attempts) from exc
            log.warning(
                "database_open_retry",
                attempt=attempt,
                max_retries=max_attempts,
                database_path=settings.database_path,
            )
            await asyncio.sleep(retry_delay * attempt)
    else:  # pragma: no cover - every failing iteration above raises or retries
        raise DatabaseBusyError(settings.database_path, max_attempts)

    try:
        yield db
    finally:
        if db is not None:
            await db.close()


async def get_http_session(
    app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> aiohttp.ClientSession:
    """Provide the shared HTTP client session from application context.

    Args:
        app_context: The injected shared application context.

    Returns:
        A shared :class:`aiohttp.ClientSession` managed by the lifespan.

    Raises:
        HTTPException: 503 if the session is unavailable.
    """
    if app_context.http_session is None:
        log.error("http_session_unavailable")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="HTTP session is not available.",
        )
    return app_context.http_session


async def get_scheduler(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> AsyncIOScheduler:
    """Provide the shared scheduler from application context.

    Args:
        app_context: The injected shared application context.

    Returns:
        The :class:`apscheduler.schedulers.asyncio.AsyncIOScheduler` instance.

    Raises:
        HTTPException: 503 if the scheduler is unavailable.
    """
    if app_context.scheduler is None:
        log.error("scheduler_unavailable")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Scheduler is not available.",
        )
    return app_context.scheduler


async def get_fail2ban_socket(settings: Annotated[Settings, Depends(get_settings)]) -> str:
    """Provide the configured path to the fail2ban Unix domain socket."""
    return settings.fail2ban_socket


async def get_fail2ban_config_dir(settings: Annotated[Settings, Depends(get_settings)]) -> str:
    """Provide the configured fail2ban configuration directory."""
    return settings.fail2ban_config_dir


async def get_fail2ban_start_command(settings: Annotated[Settings, Depends(get_settings)]) -> str:
    """Provide the configured fail2ban start command."""
    return settings.fail2ban_start_command


async def get_geo_cache(request: Request) -> GeoCache:
    """Provide the application's GeoCache instance from ``app.state.geo_cache``."""
    geo_cache: GeoCache = cast("GeoCache", request.app.state.geo_cache)
    return geo_cache


async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> SessionCache:
    """Provide the configured session cache backend from application context.

    Raises:
        HTTPException: 503 if no session cache is available.
    """
    if app_context.session_cache is None:
        log.error("session_cache_unavailable")
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Session cache is not available.",
        )
    return app_context.session_cache


async def get_global_rate_limiter(
    app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> GlobalRateLimiter:
    """Provide the global rate limiter from application context."""
    return app_context.global_rate_limiter


def rate_limit_dependency(
    bucket: str,
    max_requests: int,
    window_seconds: int,
) -> Callable[..., Awaitable[None]]:
    """Create a rate limit dependency for a specific bucket and limit.

    Use this factory to create per-endpoint rate limit dependencies. Each call
    returns a configured dependency that enforces the specified rate limit
    before the endpoint handler runs.

    Args:
        bucket: Bucket name (e.g., "bans:ban", "blocklist:import").
        max_requests: Maximum requests allowed within the window.
        window_seconds: Time window for this bucket.

    Returns:
        An async callable for use with FastAPI ``Depends()``. It raises
        :class:`~app.exceptions.RateLimitError` when the client exceeds the limit.
    """

    async def check_rate_limit(
        request: Request,
        rate_limiter: Annotated[GlobalRateLimiter, Depends(get_global_rate_limiter)],
    ) -> None:
        from app.utils.client_ip import get_client_ip  # noqa: PLC0415

        # NOTE: reads the startup settings from app.state, not runtime overrides.
        settings: Settings = request.app.state.settings
        client_ip = get_client_ip(request, trusted_proxies=settings.trusted_proxies)
        is_allowed, retry_after = rate_limiter.check_allowed_for_bucket(bucket, client_ip, max_requests, window_seconds)
        if not is_allowed:
            log.warning(
                "operation_rate_limit_exceeded",
                client_ip=client_ip,
                bucket=bucket,
                path=request.url.path,
                method=request.method,
                retry_after=retry_after,
            )
            raise RateLimitError(
                f"Rate limit exceeded for {bucket}. Please try again later.",
                retry_after_seconds=retry_after,
            )

    return check_rate_limit


async def get_session_repo() -> SessionRepository:
    """Provide the concrete session repository implementation.

    The session_repo module uses structural typing to satisfy the
    SessionRepository Protocol interface — its top-level async functions must
    match the Protocol signatures exactly. This is documented in
    Backend-Development.md § 13.7.1.
    """
    return cast("SessionRepository", session_repo)


async def get_blocklist_repo() -> BlocklistRepository:
    """Provide the concrete blocklist repository implementation.

    The blocklist_repo module uses structural typing to satisfy the
    BlocklistRepository Protocol interface — its top-level async functions must
    match the Protocol signatures exactly. This is documented in
    Backend-Development.md § 13.7.1.
    """
    return cast("BlocklistRepository", blocklist_repo)


async def get_import_log_repo() -> ImportLogRepository:
    """Provide the concrete import log repository implementation.

    The import_log_repo module uses structural typing to satisfy the
    ImportLogRepository Protocol interface — its top-level async functions must
    match the Protocol signatures exactly. This is documented in
    Backend-Development.md § 13.7.1.
    """
    return cast("ImportLogRepository", import_log_repo)


async def get_import_run_repo() -> ImportRunRepository:
    """Provide the concrete import run repository implementation.

    The import_run_repo module uses structural typing to satisfy the
    ImportRunRepository Protocol interface for tracking blocklist imports for
    idempotency detection.
    """
    return cast("ImportRunRepository", import_run_repo)


async def get_settings_repo() -> SettingsRepository:
    """Provide the concrete settings repository implementation.

    The settings_repo module uses structural typing to satisfy the
    SettingsRepository Protocol interface — its top-level async functions must
    match the Protocol signatures exactly. This is documented in
    Backend-Development.md § 13.7.1.
    """
    return cast("SettingsRepository", settings_repo)


async def get_history_archive_repo() -> HistoryArchiveRepository:
    """Provide the concrete history archive repository implementation.

    The history_archive_repo module uses structural typing to satisfy the
    HistoryArchiveRepository Protocol interface — its top-level async functions
    must match the Protocol signatures exactly. This is documented in
    Backend-Development.md § 13.7.1.
    """
    return cast("HistoryArchiveRepository", history_archive_repo)


async def get_geo_cache_repo() -> GeoCacheRepository:
    """Provide the concrete geo cache repository implementation.

    The geo_cache_repo module uses structural typing to satisfy the
    GeoCacheRepository Protocol interface — its top-level async functions must
    match the Protocol signatures exactly. This is documented in
    Backend-Development.md § 13.7.1.
    """
    return cast("GeoCacheRepository", geo_cache_repo)


async def get_fail2ban_db_repo() -> Fail2BanDbRepository:
    """Provide the concrete fail2ban DB repository implementation.

    The fail2ban_db_repo module uses structural typing to satisfy the
    Fail2BanDbRepository Protocol interface — its top-level async functions
    must match the Protocol signatures exactly. This is documented in
    Backend-Development.md § 13.7.1.
    """
    return cast("Fail2BanDbRepository", fail2ban_db_repo)


async def get_app_state(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ApplicationContext:
    """Provide the application state object for the current request."""
    return app_context


async def get_app(request: Request) -> FastAPI:
    """Provide the FastAPI application instance for the current request."""
    return request.app


async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus:
    """Return the cached fail2ban server status snapshot from application context.

    Falls back to an offline status when the state holds an explicit ``None``.
    The ``getattr`` default in ``_build_app_context`` covers only a missing
    attribute.
    """
    if app_context.server_status is None:
        return ServerStatus(online=False)
    return app_context.server_status


async def get_pending_recovery(
    app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> PendingRecovery | None:
    """Return the current pending recovery record from application context."""
    return app_context.pending_recovery


async def get_jail_service_state(
    app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> JailServiceState:
    """Return the jail service state holder from runtime state.

    Returns:
        The JailServiceState containing capability detection cache and
        synchronization primitives for jail operations.
    """
    return app_context.runtime_state.jail_service_state


async def get_health_probe() -> Callable[[str], Awaitable[ServerStatus]]:
    """Provide the health probe function for checking fail2ban connectivity.

    Returns:
        A callable that probes the fail2ban socket and returns ServerStatus.
        This allows explicit dependency injection to avoid hidden service
        coupling.
    """
    return health_service.probe


async def get_fail2ban_metadata_service() -> Fail2BanMetadataService:
    """Provide the Fail2BanMetadataService instance.

    Returns:
        The module-level ``default_fail2ban_metadata_service`` used for
        resolving fail2ban metadata (such as the database path).
    """
    return cast("Fail2BanMetadataService", default_fail2ban_metadata_service)


# -----------------------------------------------------------------------
# Service facade dependencies (db + repositories combined)
# These are for routers that need database access through services.
# Routers should depend on these instead of raw database connections.
# -----------------------------------------------------------------------


@dataclass
class SessionServiceContext:
    """Context for session-related database operations.

    Combines the database connection and session repository so that routers
    don't need to import DbDep directly.
    """

    db: aiosqlite.Connection
    session_repo: SessionRepository


async def get_session_service_context(
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
    session_repo: Annotated[SessionRepository, Depends(get_session_repo)],
) -> SessionServiceContext:
    """Provide combined session database context for routers.

    Args:
        db: Request-scoped database connection.
        session_repo: Session repository implementation.

    Returns:
        SessionServiceContext with both db and repository.
    """
    return SessionServiceContext(db=db, session_repo=session_repo)


@dataclass
class BlocklistServiceContext:
    """Context for blocklist-related database operations.

    Combines the database connection and blocklist-related repositories so
    that routers don't need to import DbDep directly.
    """

    db: aiosqlite.Connection
    blocklist_repo: BlocklistRepository
    import_log_repo: ImportLogRepository
    settings_repo: SettingsRepository


async def get_blocklist_service_context(
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
    blocklist_repo: Annotated[BlocklistRepository, Depends(get_blocklist_repo)],
    import_log_repo: Annotated[ImportLogRepository, Depends(get_import_log_repo)],
    settings_repo: Annotated[SettingsRepository, Depends(get_settings_repo)],
) -> BlocklistServiceContext:
    """Provide combined blocklist database context for routers.

    Args:
        db: Request-scoped database connection.
        blocklist_repo: Blocklist repository implementation.
        import_log_repo: Import log repository implementation.
        settings_repo: Settings repository implementation.

    Returns:
        BlocklistServiceContext with db and all blocklist repositories.
    """
    return BlocklistServiceContext(
        db=db,
        blocklist_repo=blocklist_repo,
        import_log_repo=import_log_repo,
        settings_repo=settings_repo,
    )


@dataclass
class SettingsServiceContext:
    """Context for settings-related database operations.

    Combines the database connection and settings repository so that routers
    don't need to import DbDep directly.
    """

    db: aiosqlite.Connection
    settings_repo: SettingsRepository


async def get_settings_service_context(
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
    settings_repo: Annotated[SettingsRepository, Depends(get_settings_repo)],
) -> SettingsServiceContext:
    """Provide combined settings database context for routers.

    Args:
        db: Request-scoped database connection.
        settings_repo: Settings repository implementation.

    Returns:
        SettingsServiceContext with both db and repository.
    """
    return SettingsServiceContext(db=db, settings_repo=settings_repo)


@dataclass
class BanServiceContext:
    """Context for ban-related database operations.

    Combines the database connection and fail2ban DB repository.
    """

    db: aiosqlite.Connection
    fail2ban_db_repo: Fail2BanDbRepository


async def get_ban_service_context(
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
    fail2ban_db_repo: Annotated[Fail2BanDbRepository, Depends(get_fail2ban_db_repo)],
) -> BanServiceContext:
    """Provide combined ban database context for routers.

    Args:
        db: Request-scoped database connection.
        fail2ban_db_repo: Fail2Ban DB repository implementation.

    Returns:
        BanServiceContext with both db and repository.
    """
    return BanServiceContext(db=db, fail2ban_db_repo=fail2ban_db_repo)


@dataclass
class HistoryServiceContext:
    """Context for history-related database operations.

    Combines database connection and history-related repositories.
    """

    db: aiosqlite.Connection
    fail2ban_db_repo: Fail2BanDbRepository
    history_archive_repo: HistoryArchiveRepository


async def get_history_service_context(
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
    fail2ban_db_repo: Annotated[Fail2BanDbRepository, Depends(get_fail2ban_db_repo)],
    history_archive_repo: Annotated[HistoryArchiveRepository, Depends(get_history_archive_repo)],
) -> HistoryServiceContext:
    """Provide combined history database context for routers.

    Args:
        db: Request-scoped database connection.
        fail2ban_db_repo: Fail2Ban DB repository implementation.
        history_archive_repo: History archive repository implementation.

    Returns:
        HistoryServiceContext with db and all history repositories.
    """
    return HistoryServiceContext(
        db=db,
        fail2ban_db_repo=fail2ban_db_repo,
        history_archive_repo=history_archive_repo,
    )


# Internal database dependency for use by other dependencies only.
# Routers should NOT import this - they should use service context or
# repository dependencies instead.
_DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]


async def require_auth(
    request: Request,
    db: _DbDep,
    settings: Annotated[Settings, Depends(get_settings)],
    session_cache: Annotated[SessionCache, Depends(get_session_cache)],
    session_repo: Annotated[SessionRepository, Depends(get_session_repo)],
) -> Session:
    """Validate the session token and return the active session.

    The token is read from the session cookie (``SESSION_COOKIE_NAME``) or,
    failing that, from an ``Authorization: Bearer <token>`` header (the scheme
    name is matched case-insensitively).

    Validated tokens may be cached for ``settings.session_cache_ttl_seconds``
    so that concurrent requests sharing the same token avoid repeated SQLite
    round-trips. The cache is used only when ``_session_cache_enabled`` returns
    true. An in-process cache is not safe in multi-worker or clustered
    deployments.

    Args:
        request: The incoming FastAPI request.
        db: Injected aiosqlite connection (for repository operations).
        settings: Application settings used for signed session token validation.
        session_cache: Session validation cache backend.
        session_repo: Session repository for persistence operations.

    Returns:
        The active :class:`~app.models.auth.Session`.

    Raises:
        HTTPException: 401 if no valid session token is found.
    """
    token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
    if not token:
        scheme, _, credentials = request.headers.get("Authorization", "").partition(" ")
        # The HTTP auth-scheme token is case-insensitive (RFC 7235 §2.1).
        if scheme.lower() == "bearer":
            token = credentials.strip()

    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    cache_enabled = _session_cache_enabled(settings)
    if cache_enabled:
        cached = session_cache.get(token)
        if cached is not None:
            return cached

    try:
        session = await auth_service.validate_session(
            db,
            token,
            settings.session_secret,
            settings.session_secret_previous,
            session_repo=session_repo,
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    if cache_enabled:
        session_cache.set(token, session, settings.session_cache_ttl_seconds)

    return session


# Convenience type aliases for route signatures.
# NOTE: Database connections are NOT exported to routers. Routers should depend
# on service context dependencies (SessionServiceContextDep, ...) or repository
# dependencies (SessionRepoDep, BlocklistRepositoryDep, ...) instead.
# See Backend-Development.md for the dependency layering rules.
# -- Settings and application infrastructure ---------------------------------
SettingsDep = Annotated[Settings, Depends(get_settings)]
AppDep = Annotated[FastAPI, Depends(get_app)]
AppStateDep = Annotated[ApplicationContext, Depends(get_app_state)]
HttpSessionDep = Annotated[aiohttp.ClientSession, Depends(get_http_session)]
SchedulerDep = Annotated[AsyncIOScheduler, Depends(get_scheduler)]
GeoCacheDep = Annotated[GeoCache, Depends(get_geo_cache)]
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
GlobalRateLimiterDep = Annotated[GlobalRateLimiter, Depends(get_global_rate_limiter)]

# -- fail2ban integration ----------------------------------------------------
Fail2BanSocketDep = Annotated[str, Depends(get_fail2ban_socket)]
Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]
Fail2BanMetadataServiceDep = Annotated[Fail2BanMetadataService, Depends(get_fail2ban_metadata_service)]
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
JailServiceStateDep = Annotated[JailServiceState, Depends(get_jail_service_state)]
HealthProbeDep = Annotated[Callable[[str], Awaitable[ServerStatus]], Depends(get_health_probe)]

# -- Authentication ----------------------------------------------------------
AuthDep = Annotated[Session, Depends(require_auth)]

# -- Repositories (structurally-typed repository modules) --------------------
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
SettingsRepoDep = Annotated[SettingsRepository, Depends(get_settings_repo)]
BlocklistRepositoryDep = Annotated[BlocklistRepository, Depends(get_blocklist_repo)]
ImportLogRepositoryDep = Annotated[ImportLogRepository, Depends(get_import_log_repo)]
ImportRunRepositoryDep = Annotated[ImportRunRepository, Depends(get_import_run_repo)]
HistoryArchiveRepositoryDep = Annotated[HistoryArchiveRepository, Depends(get_history_archive_repo)]
GeoCacheRepositoryDep = Annotated[GeoCacheRepository, Depends(get_geo_cache_repo)]
Fail2BanDbRepositoryDep = Annotated[Fail2BanDbRepository, Depends(get_fail2ban_db_repo)]

# -- Service contexts --------------------------------------------------------
# Each bundles the request-scoped db connection with its repositories, so
# routers use these rather than importing DbDep themselves.
SessionServiceContextDep = Annotated[SessionServiceContext, Depends(get_session_service_context)]
BlocklistServiceContextDep = Annotated[BlocklistServiceContext, Depends(get_blocklist_service_context)]
SettingsServiceContextDep = Annotated[SettingsServiceContext, Depends(get_settings_service_context)]
BanServiceContextDep = Annotated[BanServiceContext, Depends(get_ban_service_context)]
HistoryServiceContextDep = Annotated[HistoryServiceContext, Depends(get_history_service_context)]

# -- Deprecated --------------------------------------------------------------
# DbDep is retained only for backward compatibility; new code must not use it.
# Prefer the service context or repository dependencies above.
# See Backend-Development.md § 6 for dependency layering rules.
DbDep = _DbDep