- Add in-memory rate limiter with per-IP deque tracking of attempt timestamps - Limit login attempts to 5 per 60 seconds per IP, return 429 on excess - Add Retry-After header to rate limit responses - Implement IP extraction utility with proxy trust validation (prevent X-Forwarded-For spoofing) - Integrate rate limiter into auth router and dependencies - Add 10-second asyncio.sleep on failed login attempts to further slow brute-force - Add comprehensive tests for rate limiting (9 new tests, all passing) - Update Features.md to document login rate limiting - Update Backend-Development.md with rate limiting conventions and design patterns - Fix test infrastructure issues: update password to meet complexity requirements - Fix TestValidateSession tests to use Bearer token authentication - All tests passing: 23 auth tests + full test suite coverage Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
428 lines
16 KiB
Python
428 lines
16 KiB
Python
"""FastAPI dependency providers.
|
|
|
|
All ``Depends()`` callables that inject shared resources (database
|
|
connection, settings, services, auth guard) are defined here.
|
|
Routers import directly from this module — never from ``app.state``
|
|
directly — to keep coupling explicit and testable.
|
|
"""
|
|
|
|
import datetime
|
|
from collections.abc import AsyncGenerator
|
|
from dataclasses import dataclass
|
|
from typing import Annotated, cast
|
|
|
|
import aiohttp
|
|
import aiosqlite
|
|
import structlog
|
|
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
|
|
from fastapi import Depends, FastAPI, HTTPException, Request, status
|
|
|
|
from app.config import Settings
|
|
from app.models.auth import Session
|
|
from app.models.config import PendingRecovery
|
|
from app.models.server import ServerStatus
|
|
from app.repositories.protocols import (
|
|
BlocklistRepository,
|
|
Fail2BanDbRepository,
|
|
GeoCacheRepository,
|
|
HistoryArchiveRepository,
|
|
ImportLogRepository,
|
|
SessionRepository,
|
|
SettingsRepository,
|
|
)
|
|
from app.services.geo_cache import GeoCache
|
|
from app.utils.constants import SESSION_COOKIE_NAME
|
|
from app.utils.rate_limiter import RateLimiter
|
|
from app.utils.runtime_state import ApplicationState, RuntimeState
|
|
from app.utils.session_cache import NoOpSessionCache, SessionCache
|
|
|
|
# Module-level structlog logger used by the dependency providers in this file.
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
|
|
|
|
@dataclass
class ApplicationContext:
    """Typed bundle of the shared resources read from ``app.state`` per request.

    Field order is part of the constructor interface and must not change.
    """

    # Startup settings; always present on application state.
    settings: Settings
    # Shared outbound HTTP client, or ``None`` when not set on state.
    http_session: aiohttp.ClientSession | None
    # Shared background scheduler, or ``None`` when not set on state.
    scheduler: AsyncIOScheduler | None
    # Cached fail2ban server status snapshot.
    server_status: ServerStatus
    # Pending recovery record, if any.
    pending_recovery: PendingRecovery | None
    # Mapping of string keys to timestamps, or ``None`` when not set on state.
    last_activation: dict[str, datetime.datetime] | None
    # Runtime settings override; takes precedence over ``settings`` when set.
    runtime_settings: Settings | None
    # Mutable runtime state container; always present on application state.
    runtime_state: RuntimeState
    # Session validation cache backend.
    session_cache: SessionCache | None
    # Rate limiter guarding the login endpoint.
    login_rate_limiter: RateLimiter
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Session validation cache
|
|
# ---------------------------------------------------------------------------
|
|
|
|
#: ``Settings.session_cache_ttl_seconds`` controls how long (in seconds) a
#: validated session token is served from the in-memory cache without
#: re-querying SQLite. This eliminates repeated DB lookups for the same token
#: arriving in near-simultaneous parallel requests.
#:
#: NOTE: this cache is process-local and is not cluster-safe. In multi-worker
#: or distributed deployments, the configured cache backend should provide
#: invalidation semantics appropriate for the deployment.
|
|
|
|
def _session_cache_enabled(settings: Settings) -> bool:
    """Tell whether session validation results should go through the cache.

    The cache is used only when it is switched on in *settings* and its TTL
    is strictly positive.
    """
    switched_on = settings.session_cache_enabled
    if not switched_on:
        # Hand back the falsy flag itself, exactly as ``and`` would.
        return switched_on
    return settings.session_cache_ttl_seconds > 0.0
|
|
|
|
|
|
def _build_app_context(request: Request) -> ApplicationContext:
    """Assemble an :class:`ApplicationContext` from ``request.app.state``.

    ``settings`` and ``runtime_state`` are read directly and must exist on the
    application state. Every other attribute is optional and falls back to a
    default when missing.

    Args:
        request: The incoming request whose application state is read.

    Returns:
        A populated :class:`ApplicationContext` for the current request.
    """
    state = cast("ApplicationState", request.app.state)

    session_cache = getattr(state, "session_cache", None)
    if session_cache is None:
        session_cache = NoOpSessionCache()

    login_rate_limiter: RateLimiter | None = getattr(state, "login_rate_limiter", None)
    if login_rate_limiter is None:
        # Persist the fallback limiter on application state. A limiter that is
        # rebuilt on every request starts with no recorded attempts, so it
        # could never actually limit anything.
        login_rate_limiter = RateLimiter()
        request.app.state.login_rate_limiter = login_rate_limiter

    return ApplicationContext(
        settings=state.settings,
        http_session=getattr(state, "http_session", None),
        scheduler=getattr(state, "scheduler", None),
        server_status=getattr(state, "server_status", ServerStatus(online=False)),
        pending_recovery=getattr(state, "pending_recovery", None),
        last_activation=getattr(state, "last_activation", None),
        runtime_settings=getattr(state, "runtime_settings", None),
        runtime_state=state.runtime_state,
        session_cache=session_cache,
        login_rate_limiter=login_rate_limiter,
    )
|
|
|
|
|
|
async def get_app_context(request: Request) -> ApplicationContext:
    """Dependency: build the typed application context for this request."""
    context = _build_app_context(request)
    return context
|
|
|
|
|
|
async def get_settings(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> Settings:
    """Dependency: resolve the settings that apply to the current request.

    Runtime settings win when they are set; otherwise the base settings
    from the application context are returned.
    """
    if app_context.runtime_settings is None:
        return app_context.settings
    return app_context.runtime_settings
|
|
|
|
|
|
async def get_db(
    settings: Annotated[Settings, Depends(get_settings)],
) -> AsyncGenerator[aiosqlite.Connection, None]:
    """Dependency: yield a database connection scoped to one request.

    A new connection is opened via ``open_db`` using
    ``settings.database_path`` and is always closed once the request
    finishes, whether or not the handler raised.

    Args:
        settings: The effective application settings for the current request.

    Yields:
        An open :class:`aiosqlite.Connection` for the request.

    Raises:
        HTTPException: 503 if the database cannot be opened.
    """
    from app.db import open_db  # noqa: PLC0415

    try:
        connection = await open_db(settings.database_path)
    except Exception as open_error:
        log.error("database_open_failed", error=str(open_error))
        unavailable = HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Database is not available.",
        )
        raise unavailable from open_error

    try:
        yield connection
    finally:
        # Runs on normal completion and on errors raised by the handler.
        await connection.close()
|
|
|
|
|
|
async def get_http_session(
    app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> aiohttp.ClientSession:
    """Dependency: return the shared HTTP client session.

    Args:
        app_context: The injected shared application context.

    Returns:
        The :class:`aiohttp.ClientSession` stored on the application context.

    Raises:
        HTTPException: 503 if no session is present on the context.
    """
    client = app_context.http_session
    if client is not None:
        return client

    log.error("http_session_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="HTTP session is not available.",
    )
|
|
|
|
|
|
async def get_scheduler(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> AsyncIOScheduler:
    """Dependency: return the shared scheduler.

    Args:
        app_context: The injected shared application context.

    Returns:
        The ``AsyncIOScheduler`` stored on the application context.

    Raises:
        HTTPException: 503 if no scheduler is present on the context.
    """
    scheduler = app_context.scheduler
    if scheduler is not None:
        return scheduler

    log.error("scheduler_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Scheduler is not available.",
    )
|
|
|
|
|
|
async def get_fail2ban_socket(settings: Settings = Depends(get_settings)) -> str:
    """Dependency: return ``settings.fail2ban_socket``."""
    socket_path = settings.fail2ban_socket
    return socket_path
|
|
|
|
|
|
async def get_fail2ban_config_dir(settings: Settings = Depends(get_settings)) -> str:
    """Dependency: return ``settings.fail2ban_config_dir``."""
    config_dir = settings.fail2ban_config_dir
    return config_dir
|
|
|
|
|
|
async def get_fail2ban_start_command(settings: Settings = Depends(get_settings)) -> str:
    """Dependency: return ``settings.fail2ban_start_command``."""
    start_command = settings.fail2ban_start_command
    return start_command
|
|
|
|
|
|
async def get_geo_cache(request: Request) -> GeoCache:
    """Dependency: return the ``geo_cache`` object stored on ``app.state``."""
    return cast("GeoCache", request.app.state.geo_cache)
|
|
|
|
|
|
async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> SessionCache:
    """Dependency: return the session cache backend from the application context.

    Raises:
        HTTPException: 503 if the context carries no session cache.
    """
    backend = app_context.session_cache
    if backend is not None:
        return backend

    log.error("session_cache_unavailable")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Session cache is not available.",
    )
|
|
|
|
|
|
async def get_login_rate_limiter(
    app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> RateLimiter:
    """Dependency: return the login rate limiter held by the application context."""
    limiter = app_context.login_rate_limiter
    return limiter
|
|
|
|
|
|
async def get_session_repo() -> SessionRepository:
    """Dependency: return the ``session_repo`` module as the session repository.

    The module satisfies the ``SessionRepository`` Protocol structurally: its
    top-level async functions must match the Protocol signatures exactly
    (see Backend-Development.md § 13.7.1).
    """
    # Imported lazily inside the function rather than at module import time.
    from app.repositories import session_repo  # noqa: PLC0415

    repository = session_repo
    return repository
|
|
|
|
|
|
async def get_blocklist_repo() -> BlocklistRepository:
    """Dependency: return the ``blocklist_repo`` module as the blocklist repository.

    The module satisfies the ``BlocklistRepository`` Protocol structurally: its
    top-level async functions must match the Protocol signatures exactly
    (see Backend-Development.md § 13.7.1).
    """
    # Imported lazily inside the function rather than at module import time.
    from app.repositories import blocklist_repo  # noqa: PLC0415

    repository = cast("BlocklistRepository", blocklist_repo)
    return repository
|
|
|
|
|
|
async def get_import_log_repo() -> ImportLogRepository:
    """Dependency: return the ``import_log_repo`` module as the import log repository.

    The module satisfies the ``ImportLogRepository`` Protocol structurally: its
    top-level async functions must match the Protocol signatures exactly
    (see Backend-Development.md § 13.7.1).
    """
    # Imported lazily inside the function rather than at module import time.
    from app.repositories import import_log_repo  # noqa: PLC0415

    repository = cast("ImportLogRepository", import_log_repo)
    return repository
|
|
|
|
|
|
async def get_settings_repo() -> SettingsRepository:
    """Dependency: return the ``settings_repo`` module as the settings repository.

    The module satisfies the ``SettingsRepository`` Protocol structurally: its
    top-level async functions must match the Protocol signatures exactly
    (see Backend-Development.md § 13.7.1).
    """
    # Imported lazily inside the function rather than at module import time.
    from app.repositories import settings_repo  # noqa: PLC0415

    repository = cast("SettingsRepository", settings_repo)
    return repository
|
|
|
|
|
|
async def get_history_archive_repo() -> HistoryArchiveRepository:
    """Dependency: return the ``history_archive_repo`` module as the archive repository.

    The module satisfies the ``HistoryArchiveRepository`` Protocol
    structurally: its top-level async functions must match the Protocol
    signatures exactly (see Backend-Development.md § 13.7.1).
    """
    # Imported lazily inside the function rather than at module import time.
    from app.repositories import history_archive_repo  # noqa: PLC0415

    repository = cast("HistoryArchiveRepository", history_archive_repo)
    return repository
|
|
|
|
|
|
async def get_geo_cache_repo() -> GeoCacheRepository:
    """Dependency: return the ``geo_cache_repo`` module as the geo cache repository.

    The module satisfies the ``GeoCacheRepository`` Protocol structurally: its
    top-level async functions must match the Protocol signatures exactly
    (see Backend-Development.md § 13.7.1).
    """
    # Imported lazily inside the function rather than at module import time.
    from app.repositories import geo_cache_repo  # noqa: PLC0415

    repository = cast("GeoCacheRepository", geo_cache_repo)
    return repository
|
|
|
|
|
|
async def get_fail2ban_db_repo() -> Fail2BanDbRepository:
    """Dependency: return the ``fail2ban_db_repo`` module as the fail2ban DB repository.

    The module satisfies the ``Fail2BanDbRepository`` Protocol structurally:
    its top-level async functions must match the Protocol signatures exactly
    (see Backend-Development.md § 13.7.1).
    """
    # Imported lazily inside the function rather than at module import time.
    from app.repositories import fail2ban_db_repo  # noqa: PLC0415

    repository = cast("Fail2BanDbRepository", fail2ban_db_repo)
    return repository
|
|
|
|
|
|
async def get_app_state(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ApplicationContext:
    """Dependency: pass through the injected application context unchanged."""
    context = app_context
    return context
|
|
|
|
|
|
async def get_app(request: Request) -> FastAPI:
    """Dependency: return the application object attached to the request."""
    application = request.app
    return application
|
|
|
|
|
|
async def get_server_status(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ServerStatus:
    """Dependency: return the server status snapshot held by the application context."""
    snapshot = app_context.server_status
    return snapshot
|
|
|
|
|
|
async def get_pending_recovery(
    app_context: Annotated[ApplicationContext, Depends(get_app_context)],
) -> PendingRecovery | None:
    """Dependency: return the pending recovery record, or ``None`` if there is none."""
    record = app_context.pending_recovery
    return record
|
|
|
|
async def require_auth(
    request: Request,
    db: Annotated[aiosqlite.Connection, Depends(get_db)],
    settings: Annotated[Settings, Depends(get_settings)],
    session_cache: Annotated[SessionCache, Depends(get_session_cache)],
    session_repo: Annotated[SessionRepository, Depends(get_session_repo)],
) -> Session:
    """Validate the session token and return the active session.

    The token is read from the session cookie (``SESSION_COOKIE_NAME``) or,
    when that cookie is absent or empty, from an ``Authorization: Bearer``
    header. The ``Bearer`` scheme is matched case-insensitively and
    surrounding whitespace is stripped from the credentials.

    When the session cache is enabled (see ``_session_cache_enabled``), a
    cached session for the token is returned without validating again, and
    freshly validated sessions are stored with
    ``settings.session_cache_ttl_seconds`` as their TTL.

    Args:
        request: The incoming FastAPI request.
        db: Injected aiosqlite connection.
        settings: Application settings used for signed session token validation.
        session_cache: Cache backend for validated sessions.
        session_repo: Session repository passed to ``validate_session``.

    Returns:
        The active :class:`~app.models.auth.Session`.

    Raises:
        HTTPException: 401 if no valid session token is found.
    """
    token: str | None = request.cookies.get(SESSION_COOKIE_NAME)
    if not token:
        auth_header: str = request.headers.get("Authorization", "")
        scheme, _, credentials = auth_header.partition(" ")
        # HTTP authentication scheme names are case-insensitive, so accept
        # "bearer" / "BEARER" as well as "Bearer".
        if scheme.lower() == "bearer":
            token = credentials.strip()

    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication required.",
            headers={"WWW-Authenticate": "Bearer"},
        )

    cache_enabled = _session_cache_enabled(settings)
    if cache_enabled:
        cached = session_cache.get(token)
        if cached is not None:
            return cached

    # Imported lazily inside the function rather than at module import time.
    from app.services import auth_service  # noqa: PLC0415

    try:
        session = await auth_service.validate_session(
            db,
            token,
            settings.session_secret,
            session_repo=session_repo,
        )
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(exc),
            headers={"WWW-Authenticate": "Bearer"},
        ) from exc

    if cache_enabled:
        session_cache.set(token, session, settings.session_cache_ttl_seconds)
    return session
|
|
|
|
|
|
# Convenience ``Annotated`` aliases so route signatures stay short.

# --- Application-level resources -------------------------------------------
AppDep = Annotated[FastAPI, Depends(get_app)]
AppStateDep = Annotated[ApplicationContext, Depends(get_app_state)]
SettingsDep = Annotated[Settings, Depends(get_settings)]
DbDep = Annotated[aiosqlite.Connection, Depends(get_db)]
HttpSessionDep = Annotated[aiohttp.ClientSession, Depends(get_http_session)]
SchedulerDep = Annotated[AsyncIOScheduler, Depends(get_scheduler)]
GeoCacheDep = Annotated[GeoCache, Depends(get_geo_cache)]
ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]

# --- fail2ban configuration values -----------------------------------------
Fail2BanSocketDep = Annotated[str, Depends(get_fail2ban_socket)]
Fail2BanConfigDirDep = Annotated[str, Depends(get_fail2ban_config_dir)]
Fail2BanStartCommandDep = Annotated[str, Depends(get_fail2ban_start_command)]

# --- Repositories ------------------------------------------------------------
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
SettingsRepoDep = Annotated[SettingsRepository, Depends(get_settings_repo)]
HistoryArchiveRepositoryDep = Annotated[HistoryArchiveRepository, Depends(get_history_archive_repo)]
BlocklistRepositoryDep = Annotated[BlocklistRepository, Depends(get_blocklist_repo)]
ImportLogRepositoryDep = Annotated[ImportLogRepository, Depends(get_import_log_repo)]
GeoCacheRepositoryDep = Annotated[GeoCacheRepository, Depends(get_geo_cache_repo)]
Fail2BanDbRepositoryDep = Annotated[Fail2BanDbRepository, Depends(get_fail2ban_db_repo)]

# --- Authentication and login protection -----------------------------------
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
AuthDep = Annotated[Session, Depends(require_auth)]
LoginRateLimiterDep = Annotated[RateLimiter, Depends(get_login_rate_limiter)]
|