Files
BanGUI/backend/app/main.py
Lukas 7ec80fdeec refactor(logging): replace structlog with stdlib logging compat layer
- Remove structlog dependency from backend/pyproject.toml
- Add app.utils.logging_compat shim for keyword-arg logging API
- Add app.utils.json_formatter for JSON log output with extra fields
- Update all backend modules to use logging_compat.get_logger()
- Update docstrings in log_sanitizer.py and json_formatter.py
- Update test comment in test_async_utils.py
- Record 406 failing tests in Docs/Tasks.md for tracking
2026-05-10 13:37:54 +02:00

1221 lines
44 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""BanGUI FastAPI application factory.
Call :func:`create_app` to obtain a configured :class:`fastapi.FastAPI`
instance suitable for direct use with an ASGI server (e.g. ``uvicorn``) or
in tests via ``httpx.AsyncClient``.
The lifespan handler manages all shared resources — database connection, HTTP
session, and scheduler — so every component can rely on them being available
on ``app.state`` throughout the request lifecycle.
"""
from __future__ import annotations

import asyncio
import logging
import math
import os
import re
import sys
from contextlib import asynccontextmanager
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    from collections.abc import AsyncGenerator, Awaitable, Callable

    from starlette.responses import Response as StarletteResponse

    from app.models.response import ErrorMetadata

from fastapi import FastAPI, HTTPException, Request, status
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, RedirectResponse
from starlette.middleware.base import BaseHTTPMiddleware

from app import __version__
from app.config import Settings, get_settings
from app.exceptions import (
    AuthenticationError,
    BadRequestError,
    ConflictError,
    DomainError,
    Fail2BanConnectionError,
    Fail2BanProtocolError,
    NotFoundError,
    OperationError,
    RateLimitError,
    ServiceUnavailableError,
)
from app.middleware.correlation import CorrelationIdMiddleware
from app.middleware.csrf import CsrfMiddleware
from app.middleware.deprecation import DeprecationHeaderMiddleware
from app.middleware.metrics import MetricsMiddleware
from app.middleware.rate_limit import RateLimitMiddleware
from app.models.response import ErrorResponse
from app.routers import (
    auth,
    bans,
    blocklist,
    config,
    dashboard,
    file_config,
    geo,
    health,
    history,
    jails,
    jails_v2,
    metrics,
    server,
    setup,
)
from app.startup import startup_shared_resources
from app.utils.external_logging import (
    ExternalLogHandler,
    create_external_log_handler,
)
from app.utils.json_formatter import JSONFormatter
from app.utils.logging_compat import get_logger
from app.utils.rate_limiter import GlobalRateLimiter
from app.utils.runtime_state import ApplicationState, RuntimeState
from app.utils.scheduler_lock import release_scheduler_lock
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache
from app.utils.setup_state import is_setup_complete_cached, set_setup_complete_cache
# Module logger. Every call in this module passes structured fields as keyword
# arguments (e.g. ``log.info("bangui_starting_up", database_path=...)``). A plain
# ``logging.getLogger`` Logger rejects such keywords with ``TypeError``, so the
# keyword-arg compatible shim from ``app.utils.logging_compat`` is used instead.
log = get_logger("bangui")
# ---------------------------------------------------------------------------
# Logging configuration
# ---------------------------------------------------------------------------
# External log shipper created by ``_lifespan``. It is ``None`` while external
# logging is disabled or has not been initialised. ``_external_logging_processor``
# reads it on every forwarded record.
_external_log_handler: ExternalLogHandler | None = None
def _external_logging_processor(record: logging.LogRecord) -> None:
    """Hand *record* to the module-level external log handler, if one is set.

    The record is reduced to a small dict before queueing:

    - ``event``: the fully formatted message (``record.getMessage()``)
    - ``level``: the lower-cased level name
    - ``logger``: the originating logger's name
    - ``timestamp``: ``record.created`` (seconds since the epoch)

    Nothing happens while ``_external_log_handler`` is ``None``.

    Args:
        record: The log record to forward.
    """
    sink = _external_log_handler
    if sink is None:
        return
    payload = {
        "event": record.getMessage(),
        "level": record.levelname.lower(),
        "logger": record.name,
        "timestamp": record.created,
    }
    sink.queue_log(payload)
class _ExternalLoggingHandler(logging.Handler):
    """Logging handler that forwards records to the external log handler.

    The forwarding itself is done by :func:`_external_logging_processor`,
    which does nothing while no external log handler is active.
    """

    def emit(self, record: logging.LogRecord) -> None:
        """Forward *record*, reporting failures through :meth:`handleError`.

        A failure in the external log shipper must not propagate into the code
        that made the logging call (for example an exception handler building
        an error response). Errors are therefore routed to the standard
        ``handleError`` hook, following the pattern used by the stdlib
        handlers.

        Args:
            record: The log record to forward.
        """
        try:
            _external_logging_processor(record)
        except RecursionError:
            # Same policy as logging.StreamHandler.emit: never mask recursion.
            raise
        except Exception:
            self.handleError(record)
def _ensure_log_directory(log_file: str) -> None:
    """Create the directory that will contain *log_file*, if the path names one.

    ``os.path.dirname`` returns ``""`` for a bare file name such as
    ``"bangui.log"``, and ``os.makedirs("")`` raises ``FileNotFoundError``.
    In that case nothing is created, and the file is opened relative to the
    current working directory.

    Args:
        log_file: Path of the log file that is about to be opened.
    """
    directory = os.path.dirname(log_file)
    if directory:
        os.makedirs(directory, exist_ok=True)


def _configure_logging(log_level: str, log_file: str | None, settings: Settings | None = None) -> None:
    """Configure stdlib logging for production JSON output.

    A stdout handler, plus a file handler when *log_file* is given, both using
    :class:`JSONFormatter`, are passed to :func:`logging.basicConfig` together
    with the requested level. When *settings* enables external logging with a
    provider, an :class:`_ExternalLoggingHandler` is also attached to the root
    logger.

    The function can run more than once per process (once per application
    lifespan, e.g. in tests), so it guards against two side effects:

    - ``logging.basicConfig`` does nothing when the root logger already has
      handlers. Any handler it did not install is closed here, so the log file
      is not left open.
    - A previously attached :class:`_ExternalLoggingHandler` is removed first,
      so records are never forwarded twice.

    Args:
        log_level: One of ``debug``, ``info``, ``warning``, ``error``,
            ``critical`` (case-insensitive).
        log_file: Optional file path to write logs to (in addition to stdout).
        settings: Optional Settings object to configure external logging.

    Raises:
        ValueError: If *log_level* is not a level name known to :mod:`logging`.
    """
    # getLevelName maps a known name to its int value but an unknown one to the
    # string "Level <NAME>". Reject the latter before any file is opened.
    level = logging.getLevelName(log_level.upper())
    if not isinstance(level, int):
        msg = f"Unknown log level: {log_level!r}"
        raise ValueError(msg)

    handlers: list[logging.Handler] = [logging.StreamHandler(sys.stdout)]
    if log_file:
        _ensure_log_directory(log_file)
        handlers.append(logging.FileHandler(log_file))

    # Suppress verbose third-party library logs that emit plain text
    # through the standard library logging module.
    if settings is None or settings.suppress_third_party_logs:
        logging.getLogger("apscheduler").setLevel(logging.WARNING)
        logging.getLogger("aiosqlite").setLevel(logging.WARNING)

    formatter = JSONFormatter()
    for handler in handlers:
        handler.setFormatter(formatter)

    root = logging.getLogger()
    logging.basicConfig(level=level, handlers=handlers)
    # basicConfig silently skips installation when the root logger already
    # has handlers. Close whatever it did not install (e.g. an open FileHandler).
    for handler in handlers:
        if handler not in root.handlers:
            handler.close()

    # Keep at most one forwarding handler across repeated configuration.
    for stale in [h for h in root.handlers if isinstance(h, _ExternalLoggingHandler)]:
        root.removeHandler(stale)
        stale.close()

    if settings and settings.external_logging_enabled and settings.external_logging_provider:
        external_handler = _ExternalLoggingHandler()
        external_handler.setLevel(logging.DEBUG)
        root.addHandler(external_handler)
# ---------------------------------------------------------------------------
# Lifespan
# ---------------------------------------------------------------------------
def _update_session_cache(app: FastAPI, settings: Settings) -> None:
    """Install the session cache backend selected by *settings* on ``app.state``.

    An :class:`InMemorySessionCache` is used only when session caching is
    enabled and ``session_cache_ttl_seconds`` is strictly positive. In every
    other case a :class:`NoOpSessionCache` is installed. Any previous cache
    object on ``app.state.session_cache`` is replaced.

    Args:
        app: The :class:`fastapi.FastAPI` instance.
        settings: The effective application settings.
    """
    if settings.session_cache_enabled and settings.session_cache_ttl_seconds > 0.0:
        app.state.session_cache = InMemorySessionCache()
    else:
        app.state.session_cache = NoOpSessionCache()
@asynccontextmanager
async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Manage the lifetime of all shared application resources.

    Startup order:

    1. :func:`startup_shared_resources` returns the HTTP session, scheduler
       and startup database connection. These are stored on ``app.state`` for
       dependency providers and tests.
    2. When ``settings.external_logging_enabled`` is set and a provider is
       configured, the external log handler is created and its periodic flush
       started. A ``ValueError`` from handler creation has three effects:
       - it increments ``external_logging_init_failures``;
       - it sets ``runtime_state.external_log_init_failed``;
       - it is logged.
       It aborts startup only when ``settings.external_log_required`` is true.
    3. Logging is configured. Then the session cache and the global rate
       limiter are (re)created on ``app.state``.

    Shutdown (the ``finally`` block) runs in this order:

    1. Stop the scheduler from accepting jobs.
    2. Wait up to 25 s for other pending asyncio tasks; on timeout they are
       cancelled.
    3. Close the HTTP session.
    4. Shut down the external log handler.
    5. Release the scheduler lock, so another instance can acquire it without
       waiting for the TTL to expire.
    6. Close the startup database.

    A failure in step 3, 4 or 5 is logged and does not prevent the later steps.

    Args:
        app: The :class:`fastapi.FastAPI` instance being started.

    Raises:
        RuntimeError: If ``settings.external_log_required`` is set but the
            external log handler could not be created.
    """
    global _external_log_handler  # noqa: PLW0603
    settings: Settings = app.state.settings
    runtime_state = app.state.runtime_state

    # NOTE(review): if anything between here and ``yield`` raises (e.g. the
    # RuntimeError below, or _configure_logging failing), the resources created
    # by startup_shared_resources are not released. The try/finally only
    # begins at ``yield``. Confirm whether that is acceptable.
    http_session, scheduler, startup_db = await startup_shared_resources(app, settings)
    app.state.http_session = http_session
    app.state.scheduler = scheduler
    app.state.startup_db = startup_db

    # Initialize the external logging handler before configuring logging, so
    # the forwarding handler installed by _configure_logging has a target.
    _external_log_handler = None
    if settings.external_logging_enabled and settings.external_logging_provider:
        try:
            _external_log_handler = create_external_log_handler(
                provider=settings.external_logging_provider,
                api_key=settings.datadog_api_key,
                datadog_site=settings.datadog_site,
                datadog_batch_size=settings.datadog_batch_size,
                papertrail_host=settings.papertrail_host,
                papertrail_port=settings.papertrail_port,
                papertrail_program_name=settings.papertrail_program_name,
                elasticsearch_hosts=settings.elasticsearch_hosts,
                elasticsearch_index_prefix=settings.elasticsearch_index_prefix,
                elasticsearch_batch_size=settings.elasticsearch_batch_size,
                flush_interval_seconds=settings.external_logging_flush_interval_seconds,
                buffer_size=settings.external_logging_buffer_size,
                http_session=http_session,
            )
            if _external_log_handler:
                _external_log_handler.start_periodic_flush()
        except ValueError as exc:
            from app.utils import metrics as _metrics_mod

            _metrics_mod.external_logging_init_failures.inc()
            runtime_state.external_log_init_failed = True
            log.error(
                "external_logging_initialization_failed",
                error=str(exc),
            )
            if settings.external_log_required:
                msg = f"External logging is required but handler creation failed: {exc}"
                log.critical("external_logging_required_but_unavailable", error=str(exc))
                raise RuntimeError(msg) from exc

    # Now configure logging with the handler in place.
    _configure_logging(settings.log_level, settings.log_file, settings)
    log.info("bangui_starting_up", database_path=settings.database_path)

    # Ensure session cache is initialized based on effective settings.
    # This cache is process-local and not cluster-safe. In multi-worker
    # deployments, it should be replaced with a shared backend.
    _update_session_cache(app, settings)

    # Initialize the global rate limiter (200 requests per 60 seconds per IP).
    # Applied to all endpoints via middleware. Process-local implementation.
    app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
    log.info("bangui_started")

    try:
        yield
    finally:
        # Grace period for pending tasks before hard shutdown. 25 s is meant to
        # leave a safety margin before Docker's 30 s kill timeout after SIGTERM.
        graceful_timeout: float = 25.0
        log.info("bangui_shutting_down", timeout_seconds=graceful_timeout)

        # 1. Signal the scheduler to stop accepting new jobs.
        scheduler.shutdown(wait=False)
        log.debug("scheduler_stopped_accepting_jobs")

        # 2. Drain in-flight tasks (blocklist imports, geo resolutions, history
        # syncs). On timeout, wait_for cancels the gather, which cancels every
        # task that has not finished yet.
        # NOTE(review): asyncio.all_tasks() is not limited to the app's own
        # background jobs. It also contains any task that is itself awaiting
        # this lifespan's shutdown (with uvicorn, presumably the server's main
        # task). Such a task cannot finish before this block returns, so it
        # holds shutdown for the full timeout and is then cancelled. Consider
        # tracking the app's own tasks explicitly; verify against the server.
        current_task = asyncio.current_task()
        pending_tasks: list[asyncio.Task[Any]] = [
            t for t in asyncio.all_tasks() if not t.done() and t is not current_task
        ]
        if pending_tasks:
            log.info(
                "waiting_for_pending_tasks",
                count=len(pending_tasks),
                timeout_seconds=graceful_timeout,
            )
            try:
                await asyncio.wait_for(
                    asyncio.gather(*pending_tasks, return_exceptions=True),
                    timeout=graceful_timeout,
                )
                log.debug("pending_tasks_completed")
            except TimeoutError:
                # wait_for only raises once the cancelled gather has finished.
                # Report the tasks that actually ended cancelled, not every
                # task that was waited on.
                log.warning(
                    "pending_tasks_timeout",
                    cancelled_count=sum(1 for t in pending_tasks if t.cancelled()),
                )

        # 3. Close the HTTP session to release connections. Any failure is
        # logged so the remaining steps (log flush, lock release, DB close)
        # still run, matching steps 4 and 5.
        try:
            await http_session.close()
            log.debug("http_session_closed")
        except asyncio.CancelledError:
            log.debug("http_session_close_cancelled")
        except Exception as exc:
            log.error("http_session_close_failed", error=str(exc))

        # 4. Shut down the external logging handler. The module global is
        # cleared first, so records logged from here on (including during the
        # handler's own shutdown) are not forwarded to a closing handler.
        external_handler = _external_log_handler
        _external_log_handler = None
        if external_handler:
            try:
                await external_handler.shutdown()
                log.debug("external_logging_shutdown_complete")
            except Exception as exc:
                log.error("external_logging_shutdown_failed", error=str(exc))

        # 5. Release the scheduler lock so other instances can take over
        # immediately instead of waiting for the lock's TTL to expire.
        try:
            await release_scheduler_lock(startup_db)
            log.debug("scheduler_lock_released")
        except Exception as e:
            log.error("scheduler_lock_release_failed", error=str(e))

        # 6. Close the database connection.
        await startup_db.close()
        log.info("bangui_shut_down")
# ---------------------------------------------------------------------------
# Exception handlers
# ---------------------------------------------------------------------------
def _get_error_code(exc: Exception) -> str:
    """Return the machine-readable error code for *exc*.

    An ``error_code`` attribute on the exception takes precedence (normally a
    class attribute on the project's exception types). Without one, the class
    name is converted to snake_case:

    - an underscore is inserted before every ASCII capital except the first
      character;
    - the result is lower-cased.

    For example ``NotFoundError`` becomes ``not_found_error``. Runs of
    capitals are split letter by letter: ``HTTPError`` becomes
    ``h_t_t_p_error``.

    Args:
        exc: The exception instance.

    Returns:
        The error code string.
    """
    if not hasattr(exc, "error_code"):
        class_name = exc.__class__.__name__
        return re.sub(r"(?<!^)(?=[A-Z])", "_", class_name).lower()
    return exc.error_code
def _get_error_metadata(exc: Exception) -> ErrorMetadata:
    """Collect structured metadata from *exc* for the error envelope.

    If the exception exposes a callable ``get_error_metadata`` attribute, its
    return value is used as-is. Otherwise an empty dict is returned.

    Args:
        exc: The exception instance.

    Returns:
        The exception's metadata mapping, or ``{}``.
    """
    metadata_getter = getattr(exc, "get_error_metadata", None)
    if not callable(metadata_getter):
        return {}
    return metadata_getter()
def _get_correlation_id(request: Request) -> str | None:
    """Return the request's correlation ID, or ``None`` when it has none.

    The value is read from ``request.state.correlation_id``, which
    ``CorrelationIdMiddleware`` is expected to populate.

    Args:
        request: The incoming FastAPI request.

    Returns:
        The correlation ID string, or ``None`` if the attribute is absent.
    """
    state = request.state
    try:
        return state.correlation_id
    except AttributeError:
        return None
async def _unhandled_exception_handler(
    request: Request,
    exc: Exception,
) -> JSONResponse:
    """Turn any otherwise-unhandled exception into a generic 500 response.

    The exception is logged with ``exc_info`` alongside the request path and
    method. The client only receives a fixed ``internal_error`` envelope, so
    neither the exception text nor a traceback is exposed.

    Args:
        request: The request that was being processed.
        exc: The exception that escaped the route handler.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 500.
    """
    log_fields = {
        "path": request.url.path,
        "method": request.method,
        "exc_info": exc,
    }
    log.error("unhandled_exception", **log_fields)
    envelope = ErrorResponse(
        code="internal_error",
        detail="An unexpected error occurred. Please try again later.",
        metadata={},
        correlation_id=_get_correlation_id(request),
    )
    return JSONResponse(content=envelope.model_dump(), status_code=500)
async def _fail2ban_connection_handler(
    request: Request,
    exc: Fail2BanConnectionError,
) -> JSONResponse:
    """Answer with ``502 Bad Gateway`` when the fail2ban socket is unreachable.

    The envelope uses the ``fail2ban_unreachable`` code and carries the
    socket path from the exception in its metadata.

    Args:
        request: The request that needed fail2ban.
        exc: The :class:`~app.exceptions.Fail2BanConnectionError`.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 502.
    """
    log_fields = {
        "path": request.url.path,
        "method": request.method,
        "error": str(exc),
    }
    log.warning("fail2ban_connection_error", **log_fields)
    body = ErrorResponse(
        code="fail2ban_unreachable",
        detail="Cannot reach the fail2ban service. Check the server status page.",
        metadata={"socket_path": exc.socket_path},
        correlation_id=_get_correlation_id(request),
    ).model_dump()
    return JSONResponse(content=body, status_code=502)
async def _fail2ban_protocol_handler(
    request: Request,
    exc: Fail2BanProtocolError,
) -> JSONResponse:
    """Return a ``502 Bad Gateway`` response for fail2ban protocol errors.

    A protocol error means fail2ban was reached but its reply could not be
    understood. The client-facing ``detail`` therefore says so, instead of
    repeating the "cannot reach" wording used for connection failures. The
    distinct ``fail2ban_protocol_error`` code already separates the two cases.

    Args:
        request: The incoming FastAPI request.
        exc: The :class:`~app.exceptions.Fail2BanProtocolError`.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 502.
    """
    log.warning(
        "fail2ban_protocol_error",
        path=request.url.path,
        method=request.method,
        error=str(exc),
    )
    error_response = ErrorResponse(
        code="fail2ban_protocol_error",
        detail="Received an unexpected response from the fail2ban service. Check the server status page.",
        metadata={},
        correlation_id=_get_correlation_id(request),
    )
    return JSONResponse(
        status_code=502,
        content=error_response.model_dump(),
    )
async def _not_found_handler(
    request: Request,
    exc: NotFoundError,
) -> JSONResponse:
    """Translate a :class:`~app.exceptions.NotFoundError` into ``404 Not Found``.

    The envelope carries the exception's error code, its message, its
    metadata and the request's correlation ID.

    Args:
        request: The request whose handler raised the error.
        exc: The not-found exception.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 404.
    """
    log.warning(
        "domain_not_found",
        path=request.url.path,
        method=request.method,
        error=str(exc),
    )
    body = ErrorResponse(
        code=_get_error_code(exc),
        detail=str(exc),
        metadata=_get_error_metadata(exc),
        correlation_id=_get_correlation_id(request),
    ).model_dump()
    return JSONResponse(content=body, status_code=status.HTTP_404_NOT_FOUND)
async def _bad_request_handler(
    request: Request,
    exc: BadRequestError,
) -> JSONResponse:
    """Translate a :class:`~app.exceptions.BadRequestError` into ``400 Bad Request``.

    Used for validation and domain-contract violations. The envelope carries
    the exception's error code, message, metadata and the correlation ID.

    Args:
        request: The request whose handler raised the error.
        exc: The bad-request exception.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 400.
    """
    log_fields = {
        "path": request.url.path,
        "method": request.method,
        "error": str(exc),
    }
    log.warning("domain_bad_request", **log_fields)
    envelope = ErrorResponse(
        code=_get_error_code(exc),
        detail=str(exc),
        metadata=_get_error_metadata(exc),
        correlation_id=_get_correlation_id(request),
    )
    payload = envelope.model_dump()
    return JSONResponse(content=payload, status_code=status.HTTP_400_BAD_REQUEST)
async def _conflict_handler(
    request: Request,
    exc: ConflictError,
) -> JSONResponse:
    """Translate a :class:`~app.exceptions.ConflictError` into ``409 Conflict``.

    Args:
        request: The request whose handler raised the error.
        exc: The domain-state conflict.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 409 whose
        envelope carries the exception's code, message, metadata and the
        correlation ID.
    """
    log.warning(
        "domain_conflict",
        path=request.url.path,
        method=request.method,
        error=str(exc),
    )
    body = ErrorResponse(
        code=_get_error_code(exc),
        detail=str(exc),
        metadata=_get_error_metadata(exc),
        correlation_id=_get_correlation_id(request),
    ).model_dump()
    return JSONResponse(content=body, status_code=status.HTTP_409_CONFLICT)
async def _domain_error_handler(
    request: Request,
    exc: DomainError,
) -> JSONResponse:
    """Translate a :class:`~app.exceptions.DomainError` into ``500 Internal Server Error``.

    Used for domain write failures. The error is logged at error level with
    ``exc_info``.

    NOTE(review): unlike ``_unhandled_exception_handler``, the exception
    message is returned to the client as ``detail``. Confirm that DomainError
    messages are safe to expose.

    Args:
        request: The request whose handler raised the error.
        exc: The domain error.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 500.
    """
    log_fields = {
        "path": request.url.path,
        "method": request.method,
        "error": str(exc),
        "exc_info": exc,
    }
    log.error("domain_internal_error", **log_fields)
    envelope = ErrorResponse(
        code=_get_error_code(exc),
        detail=str(exc),
        metadata=_get_error_metadata(exc),
        correlation_id=_get_correlation_id(request),
    )
    return JSONResponse(
        content=envelope.model_dump(),
        status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
    )
async def _value_error_handler(
    request: Request,
    exc: ValueError,
) -> JSONResponse:
    """Translate a :class:`ValueError` into ``400 Bad Request``.

    The envelope always uses the ``invalid_input`` code, echoes the exception
    message as ``detail`` and carries no metadata.

    Args:
        request: The request whose handler raised the error.
        exc: The :class:`ValueError`.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 400.
    """
    log.warning(
        "value_error",
        path=request.url.path,
        method=request.method,
        error=str(exc),
    )
    body = ErrorResponse(
        code="invalid_input",
        detail=str(exc),
        metadata={},
        correlation_id=_get_correlation_id(request),
    ).model_dump()
    return JSONResponse(content=body, status_code=status.HTTP_400_BAD_REQUEST)
async def _service_unavailable_handler(
    request: Request,
    exc: ServiceUnavailableError,
) -> JSONResponse:
    """Translate a :class:`~app.exceptions.ServiceUnavailableError` into ``503``.

    Used for infrastructure failures (e.g. a missing config directory). The
    envelope carries the exception's code, message, metadata and the
    correlation ID.

    Args:
        request: The request whose handler raised the error.
        exc: The infrastructure exception.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 503.
    """
    log_fields = {
        "path": request.url.path,
        "method": request.method,
        "error": str(exc),
    }
    log.warning("service_unavailable", **log_fields)
    envelope = ErrorResponse(
        code=_get_error_code(exc),
        detail=str(exc),
        metadata=_get_error_metadata(exc),
        correlation_id=_get_correlation_id(request),
    )
    payload = envelope.model_dump()
    return JSONResponse(content=payload, status_code=status.HTTP_503_SERVICE_UNAVAILABLE)
async def _authentication_error_handler(
    request: Request,
    exc: AuthenticationError,
) -> JSONResponse:
    """Translate an :class:`~app.exceptions.AuthenticationError` into ``401 Unauthorized``.

    Args:
        request: The request that failed authentication.
        exc: The :class:`~app.exceptions.AuthenticationError`.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 401 whose
        envelope carries the exception's code, message, metadata and the
        correlation ID.
    """
    log.warning(
        "authentication_error",
        path=request.url.path,
        method=request.method,
        error=str(exc),
    )
    body = ErrorResponse(
        code=_get_error_code(exc),
        detail=str(exc),
        metadata=_get_error_metadata(exc),
        correlation_id=_get_correlation_id(request),
    ).model_dump()
    return JSONResponse(content=body, status_code=status.HTTP_401_UNAUTHORIZED)
def _retry_after_header_value(retry_after_seconds: float) -> str:
    """Format a delay as an HTTP ``Retry-After`` delay-seconds value.

    The delay is rounded *up*, so a client that honours the header does not
    retry before the limiter window has reopened. Truncating with ``int()``
    would turn 0.4 s into ``"0"``. Negative values are clamped to ``0``,
    because delay-seconds is a non-negative integer.

    Args:
        retry_after_seconds: Remaining wait time in seconds (may be fractional).

    Returns:
        The header value as a decimal integer string.
    """
    return str(max(0, math.ceil(retry_after_seconds)))


async def _rate_limit_error_handler(
    request: Request,
    exc: RateLimitError,
) -> JSONResponse:
    """Return a ``429 Too Many Requests`` response for rate limit exceeded errors.

    The ``Retry-After`` header is derived from ``exc.retry_after_seconds``,
    rounded up to whole seconds.

    Args:
        request: The incoming FastAPI request.
        exc: The :class:`~app.exceptions.RateLimitError`.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 429 and a
        Retry-After header.
    """
    log.warning(
        "rate_limit_exceeded",
        path=request.url.path,
        method=request.method,
        error=str(exc),
        retry_after_seconds=exc.retry_after_seconds,
    )
    error_response = ErrorResponse(
        code=_get_error_code(exc),
        detail=str(exc),
        metadata=_get_error_metadata(exc),
        correlation_id=_get_correlation_id(request),
    )
    return JSONResponse(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        content=error_response.model_dump(),
        headers={"Retry-After": _retry_after_header_value(exc.retry_after_seconds)},
    )
async def _http_exception_handler(
    request: Request,
    exc: HTTPException,
) -> JSONResponse:
    """Wrap a :class:`fastapi.HTTPException` in the standard error envelope.

    The exception's status code, ``detail`` and any headers it carries are
    preserved. The status code is also translated into a machine-readable
    ``code`` via the table below. Status codes without an entry (for example
    405) fall back to ``"internal_error"``.

    Args:
        request: The incoming FastAPI request.
        exc: The :class:`fastapi.HTTPException`.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with the original status code.
    """
    log.warning(
        "http_exception",
        path=request.url.path,
        method=request.method,
        status_code=exc.status_code,
        error=exc.detail,
    )
    codes_by_status = {
        status.HTTP_400_BAD_REQUEST: "invalid_input",
        status.HTTP_401_UNAUTHORIZED: "authentication_required",
        status.HTTP_403_FORBIDDEN: "forbidden",
        status.HTTP_404_NOT_FOUND: "not_found",
        status.HTTP_409_CONFLICT: "conflict",
        status.HTTP_422_UNPROCESSABLE_ENTITY: "invalid_input",
        status.HTTP_429_TOO_MANY_REQUESTS: "rate_limit_exceeded",
        status.HTTP_500_INTERNAL_SERVER_ERROR: "internal_error",
        status.HTTP_503_SERVICE_UNAVAILABLE: "service_unavailable",
    }
    envelope = ErrorResponse(
        code=codes_by_status.get(exc.status_code, "internal_error"),
        detail=exc.detail,
        metadata={},
        correlation_id=_get_correlation_id(request),
    )
    response_headers = exc.headers or {}
    return JSONResponse(
        status_code=exc.status_code,
        content=envelope.model_dump(),
        headers=response_headers,
    )
async def _request_validation_error_handler(
    request: Request,
    exc: RequestValidationError,
) -> JSONResponse:
    """Convert a :class:`RequestValidationError` into the unified error envelope.

    The response status is 400 and the code is ``invalid_input``. When there
    are validation errors, the metadata records two things:

    - ``field_errors``: how many errors were found;
    - ``first_field``: the dotted location of the first one.

    Args:
        request: The incoming FastAPI request.
        exc: The :class:`fastapi.exceptions.RequestValidationError`.

    Returns:
        A :class:`fastapi.responses.JSONResponse` with status 400.
    """
    problems = exc.errors()
    log.warning(
        "request_validation_error",
        path=request.url.path,
        method=request.method,
        error_count=len(problems),
    )
    metadata: dict[str, str | int | float | bool | None] = {}
    if problems:
        metadata["field_errors"] = len(problems)
        metadata["first_field"] = ".".join(map(str, problems[0]["loc"]))
    body = ErrorResponse(
        code="invalid_input",
        detail="Request validation failed.",
        metadata=metadata,
        correlation_id=_get_correlation_id(request),
    ).model_dump()
    return JSONResponse(content=body, status_code=status.HTTP_400_BAD_REQUEST)
# ---------------------------------------------------------------------------
# Setup-redirect middleware
# ---------------------------------------------------------------------------
# Paths that stay reachable before setup is complete. They are matched exactly,
# after SetupRedirectMiddleware has stripped any trailing slash. Exact matching
# means a future route such as /api/v1/setup-debug is never let through by a
# loose startswith() check unless it is listed here explicitly.
_EXACT_ALLOWED: frozenset[str] = frozenset(
    (
        "/api/v1/setup",  # GET/POST /api/v1/setup
        "/api/v1/health",  # combined health check
        "/api/v1/health/live",  # Kubernetes liveness probe
        "/api/v1/health/ready",  # Kubernetes readiness probe
        "/api/docs",  # Swagger UI
        "/api/redoc",  # ReDoc
        "/api/openapi.json",  # OpenAPI schema
    )
)
# Path prefixes that stay reachable before setup is complete. Every entry MUST
# end with "/". That way "/api/v1/setup/" matches nested routes such as
# "/api/v1/setup/timezone" but not a sibling such as "/api/v1/setup-debug".
_PREFIX_ALLOWED: frozenset[str] = frozenset(("/api/v1/setup/",))
# Security header values injected by SecurityHeadersMiddleware.
# Only allow resources from the page's own origin.
_CSP_POLICY: str = "default-src 'self'"
# Forbid rendering the application inside any frame (clickjacking defence).
_X_FRAME_OPTIONS: str = "DENY"
# Tell browsers not to MIME-sniff responses away from their declared type.
_X_CONTENT_TYPE_OPTIONS: str = "nosniff"
# Legacy browser XSS-auditor directive.
# NOTE(review): current OWASP guidance is to send "0" (or omit the header) and
# rely on CSP instead. Confirm before changing; clients and tests may expect
# the current value.
_X_XSS_PROTECTION: str = "1; mode=block"


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Stamp defensive security headers onto every HTTP response.

    After the downstream application has produced a response, the following
    headers are set on it, replacing any value already present:

    - ``Content-Security-Policy``: :data:`_CSP_POLICY`
    - ``X-Frame-Options``: :data:`_X_FRAME_OPTIONS`
    - ``X-Content-Type-Options``: :data:`_X_CONTENT_TYPE_OPTIONS`
    - ``X-XSS-Protection``: :data:`_X_XSS_PROTECTION`

    These ask the browser to enforce restrictions, as a defence-in-depth layer
    on top of server-side checks.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[StarletteResponse]],
    ) -> StarletteResponse:
        """Run the rest of the stack, then add the security headers.

        Args:
            request: The incoming HTTP request.
            call_next: The next middleware / router handler.

        Returns:
            The downstream response with the security headers applied.
        """
        response: StarletteResponse = await call_next(request)
        security_headers = (
            ("Content-Security-Policy", _CSP_POLICY),
            ("X-Frame-Options", _X_FRAME_OPTIONS),
            ("X-Content-Type-Options", _X_CONTENT_TYPE_OPTIONS),
            ("X-XSS-Protection", _X_XSS_PROTECTION),
        )
        for header_name, header_value in security_headers:
            response.headers[header_name] = header_value
        return response
class SetupRedirectMiddleware(BaseHTTPMiddleware):
    """Redirect ``/api/v1`` requests to ``/api/v1/setup`` until setup is done.

    Allowed paths are always passed through, so the setup endpoints, health
    probes and API docs stay reachable:

    - :data:`_EXACT_ALLOWED`, matched exactly;
    - :data:`_PREFIX_ALLOWED`, matched by prefix (each prefix ends with ``/``).

    While :func:`is_setup_complete_cached` reports setup as incomplete, every
    other path under ``/api/v1`` gets a ``307 Temporary Redirect`` to
    ``/api/v1/setup``. Once setup is complete, requests are simply passed on.

    Exact and explicit-prefix matching is used instead of a bare
    ``startswith`` check. A route such as ``/api/v1/setup-debug`` therefore
    cannot bypass the guard by accident.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[StarletteResponse]],
    ) -> StarletteResponse:
        """Intercept requests before they reach the router.

        Args:
            request: The incoming HTTP request.
            call_next: The next middleware / router handler.

        Returns:
            Either a ``307 Temporary Redirect`` to ``/api/v1/setup`` or the
            normal router response.
        """
        # Remove a trailing slash for consistent comparison. request.url.path
        # carries no query string, so nothing else needs stripping.
        path: str = request.url.path.rstrip("/") or "/"

        # Explicit allowlist (exact match). This also covers the health
        # endpoints needed by Docker / load-balancer probes before setup.
        if path in _EXACT_ALLOWED:
            return await call_next(request)

        # Allowed prefixes end with "/" to avoid accidental matches:
        #   - "/api/v1/setup/" matches "/api/v1/setup/timezone"
        #   - "/api/v1/setup/" does NOT match "/api/v1/setup-debug"
        for prefix in _PREFIX_ALLOWED:
            if path == prefix.rstrip("/") or path.startswith(prefix):
                return await call_next(request)

        # The setup-completion state is resolved at startup and cached on
        # app.state, so this check performs no database query per request.
        if path.startswith("/api/v1") and not is_setup_complete_cached(request.app):
            return RedirectResponse(
                url="/api/v1/setup",
                status_code=status.HTTP_307_TEMPORARY_REDIRECT,
            )
        return await call_next(request)
def _enforce_single_worker() -> None:
    """Raise if the environment asks for more than one worker process.

    Two environment variables are inspected, in this order:

    * ``WEB_CONCURRENCY``: the conventional worker-count variable.
    * ``BANGUI_WORKERS``: an explicit BanGUI override.

    An unset variable is ignored, and any integer value of 1 or less is
    accepted. This function does not look at ``TESTING`` itself; the caller
    decides whether to run it.

    NOTE(review): the check only sees the variables above. A worker count
    passed solely as a server CLI flag (e.g. ``--workers``) may not be
    visible here. Confirm against the deployment tooling.

    Raises:
        RuntimeError: If either variable is not an integer, or is greater than 1.
    """
    raw_concurrency = os.environ.get("WEB_CONCURRENCY")
    if raw_concurrency is not None:
        try:
            concurrency = int(raw_concurrency)
        except ValueError as e:
            raise RuntimeError(
                f"WEB_CONCURRENCY must be an integer, got: {raw_concurrency}"
            ) from e
        if concurrency > 1:
            raise RuntimeError(
                "BanGUI cannot run with multiple workers.\n"
                f"WEB_CONCURRENCY is set to {concurrency}. Expected 1.\n"
                "\n"
                "Why: in-memory session cache, rate-limit windows, and runtime "
                "state are process-local. Multiple workers cause stale rate "
                "limits, ghost sessions, and inconsistent server status.\n"
                "\n"
                "Fix: run with a single worker process. Use container "
                "orchestration for horizontal scaling.\n"
                "\n"
                "See Docs/Deployment.md § Single-Worker Requirement."
            )

    # Explicit BANGUI_WORKERS override (discouraged, but still enforced).
    raw_override = os.environ.get("BANGUI_WORKERS")
    if raw_override is not None:
        try:
            override = int(raw_override)
        except ValueError as e:
            raise RuntimeError(
                f"BANGUI_WORKERS must be an integer, got: {raw_override}"
            ) from e
        if override > 1:
            raise RuntimeError(
                "BanGUI cannot run with multiple workers.\n"
                f"BANGUI_WORKERS is set to {override}. Expected 1.\n"
                "\n"
                "Fix: set BANGUI_WORKERS=1 or remove from environment.\n"
                "\n"
                "See Docs/Deployment.md § Single-Worker Requirement."
            )
# ---------------------------------------------------------------------------
# Application factory
# ---------------------------------------------------------------------------
def _assert_middleware_order(app: FastAPI) -> None:
    """Check the relative order of the security-critical middleware.

    The class names in ``app.user_middleware`` must satisfy two constraints:

    - ``RateLimitMiddleware`` appears before ``CsrfMiddleware``;
    - ``CsrfMiddleware`` appears before ``CorrelationIdMiddleware``.

    Each constraint is checked only when both of its middleware are present.
    Middleware not listed here are ignored.

    NOTE(review): Starlette's ``add_middleware`` inserts at the front of
    ``user_middleware``, so a lower position may mean "added later / outer".
    Confirm that this is the intended direction.

    Args:
        app: The application whose middleware stack is checked.

    Raises:
        AssertionError: If a required pair is in the wrong order.
    """
    registered = [m.cls.__name__ for m in app.user_middleware]
    order: tuple[str, ...] = (
        "RateLimitMiddleware",
        "CsrfMiddleware",
        "CorrelationIdMiddleware",
    )
    positions = {name: registered.index(name) for name in order if name in registered}
    # Only neighbouring entries of ``order`` are compared, exactly as required.
    for earlier, later in zip(order, order[1:]):
        if earlier not in positions or later not in positions:
            continue
        if positions[earlier] > positions[later]:
            raise AssertionError(
                f"Middleware order violation: {earlier} (position {positions[earlier]}) "
                f"must be registered before {later} (position {positions[later]}). "
                f"Current order: {registered}"
            )
def create_app(settings: Settings | None = None) -> FastAPI:
    """Create and configure the BanGUI FastAPI application.

    This factory is the single entry point for creating the application.
    Tests can pass a custom ``settings`` object to override defaults
    without touching environment variables.

    Args:
        settings: Optional pre-built :class:`~app.config.Settings` instance.
            If ``None``, settings are loaded from the environment via
            :func:`~app.config.get_settings`.

    Returns:
        A fully configured :class:`fastapi.FastAPI` application ready for use.

    Raises:
        RuntimeError: If ``_enforce_single_worker`` detects a multi-worker
            configuration (WEB_CONCURRENCY or --workers > 1). The check is
            skipped when the ``TESTING`` environment variable is set.
        AssertionError: If the security-critical middleware are registered
            in the wrong order (raised by ``_assert_middleware_order``).
    """
    # Enforce single-worker constraint before anything else.
    # Skip in test mode (TESTING env var set by test framework or explicitly).
    if not os.environ.get("TESTING"):
        _enforce_single_worker()

    resolved_settings: Settings = settings if settings is not None else get_settings()

    # Configure API docs based on the enable_docs setting. When it is false
    # (intended for production) the docs and OpenAPI schema routes are not
    # served at all; otherwise they live under /api/*.
    docs_url = "/api/docs" if resolved_settings.enable_docs else None
    redoc_url = "/api/redoc" if resolved_settings.enable_docs else None
    openapi_url = "/api/openapi.json" if resolved_settings.enable_docs else None

    app: FastAPI = FastAPI(
        title="BanGUI",
        description="Web interface for monitoring, managing, and configuring fail2ban.",
        version=__version__,
        lifespan=_lifespan,
        docs_url=docs_url,
        redoc_url=redoc_url,
        openapi_url=openapi_url,
    )

    # Store immutable configuration and the dedicated runtime state manager on
    # app.state. Runtime state values are proxied through the wrapper so the
    # shared Starlette state bag itself does not hold mutable business state.
    app.state = ApplicationState(RuntimeState())
    app.state.settings = resolved_settings
    app.state.session_cache = (
        InMemorySessionCache()
        if resolved_settings.session_cache_enabled and resolved_settings.session_cache_ttl_seconds > 0.0
        else NoOpSessionCache()
    )

    # Initialize the global rate limiter (200 requests per 60 seconds per IP).
    # This is also re-initialized in the lifespan, but must be present here
    # for tests that bypass the lifespan via ASGITransport.
    # NOTE(review): the RateLimitMiddleware instances below are handed this
    # exact object. If the lifespan rebinds app.state.global_rate_limiter to a
    # new instance, the middleware keep using this one; confirm the lifespan
    # resets or reuses this object rather than replacing it.
    app.state.global_rate_limiter = GlobalRateLimiter(max_requests=200, window_seconds=60)
    set_setup_complete_cache(app, False)

    # --- CORS ---
    # Allow origins configured by the runtime environment. In production,
    # this should be explicitly set to the frontend origin(s) or left empty
    # when the UI is served from the same origin as the API.
    # NOTE(review): CORS is registered first, which makes it the innermost
    # user middleware. Responses produced by outer middleware (e.g. rate-limit
    # or CSRF rejections, setup redirects) therefore do not pass through it
    # and may reach cross-origin browsers without CORS headers.
    if resolved_settings.cors_allowed_origins:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=resolved_settings.cors_allowed_origins,
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # --- Middleware ---
    # Starlette's add_middleware() inserts at the front of app.user_middleware,
    # so the middleware registered LAST is the OUTERMOST layer: it sees the
    # request first and the response last.
    #
    # With the registration order below, requests are processed
    # (outermost -> innermost) as:
    #   RateLimitMiddleware ("history:list") -> RateLimitMiddleware ("auth:login")
    #   -> DeprecationHeaderMiddleware -> CsrfMiddleware -> MetricsMiddleware
    #   -> SetupRedirectMiddleware -> SecurityHeadersMiddleware
    #   -> CorrelationIdMiddleware -> CORSMiddleware (if enabled) -> application
    #
    # _assert_middleware_order() enforces the security-critical part of this
    # chain at startup: RateLimit wraps Csrf, which wraps CorrelationId.
    # NOTE(review): because CorrelationIdMiddleware is inside both RateLimit
    # and Csrf, requests those two reject never get a correlation ID; confirm
    # that is intended.
    app.add_middleware(CorrelationIdMiddleware)
    app.add_middleware(SecurityHeadersMiddleware)
    app.add_middleware(SetupRedirectMiddleware)
    app.add_middleware(MetricsMiddleware)
    app.add_middleware(CsrfMiddleware)
    app.add_middleware(DeprecationHeaderMiddleware)

    # Auth endpoints (/api/v1/login, /api/v1/setup) get a dedicated, more
    # generous bucket of 1000 req/min per IP so that e2e tests running
    # sequentially are not throttled. The shared limiter itself is configured
    # above for 200 req/min per IP.
    app.add_middleware(
        RateLimitMiddleware,
        rate_limiter=app.state.global_rate_limiter,
        settings=resolved_settings,
        bucket_override="auth:login",
        bucket_max_requests=1000,
        bucket_window_seconds=60,
    )

    # History endpoints get a dedicated higher-rate bucket to avoid
    # triggering rate limits when the UI page makes multiple simultaneous
    # API calls (session validation + history + dashboard stats).
    # 10000 req/min per IP — generous for normal browsing + e2e testing.
    app.add_middleware(
        RateLimitMiddleware,
        rate_limiter=app.state.global_rate_limiter,
        settings=resolved_settings,
        bucket_override="history:list",
        bucket_max_requests=10000,
        bucket_window_seconds=60,
    )

    # Validate middleware order before returning the app.
    # Raising loud errors at startup is intentional — a misconfigured middleware
    # stack is a security-critical defect that must not slip through silently.
    _assert_middleware_order(app)

    # --- Exception handlers ---
    #
    # Starlette picks a handler by walking the raised exception's MRO and
    # using the first class that has one registered, so the most specific
    # registered class always wins regardless of registration order. The
    # handlers are listed from most to least specific purely for readability:
    #   1. Network-specific errors (Fail2BanConnectionError, Fail2BanProtocolError) → HTTP 502
    #   2. Auth/rate-limit errors (AuthenticationError, RateLimitError) → HTTP 401/429
    #   3. Category handlers (NotFoundError, BadRequestError, ConflictError) → HTTP 404/400/409
    #   4. OperationError handler → HTTP 500
    #   5. ServiceUnavailableError handler → HTTP 503
    #   6. Generic DomainError handler (catch-all for any unregistered DomainError subclass) → HTTP 500
    #   7. RequestValidationError handler (request validation errors) → HTTP 400
    #   8. HTTPException (FastAPI's and Starlette's) → HTTP varies
    #   9. ValueError → HTTP 400
    #  10. Exception (absolute catch-all for unexpected errors) → HTTP 500
    #
    # Because lookup follows the MRO, any new DomainError subclass that
    # inherits from a registered category is handled by that category's
    # handler. A DomainError subclass outside every category falls through to
    # the generic DomainError handler rather than _unhandled_exception_handler.
    #
    # Starlette raises its own HTTPException (the base class of FastAPI's) for
    # routing errors such as unknown paths (404) and disallowed methods (405).
    # Registering the base class sends those through _http_exception_handler
    # too, instead of Starlette's default response envelope.
    from starlette.exceptions import HTTPException as StarletteHTTPException

    app.add_exception_handler(Fail2BanConnectionError, _fail2ban_connection_handler)  # type: ignore[arg-type]
    app.add_exception_handler(Fail2BanProtocolError, _fail2ban_protocol_handler)  # type: ignore[arg-type]
    app.add_exception_handler(AuthenticationError, _authentication_error_handler)  # type: ignore[arg-type]
    app.add_exception_handler(RateLimitError, _rate_limit_error_handler)  # type: ignore[arg-type]
    app.add_exception_handler(NotFoundError, _not_found_handler)  # type: ignore[arg-type]
    app.add_exception_handler(BadRequestError, _bad_request_handler)  # type: ignore[arg-type]
    app.add_exception_handler(ConflictError, _conflict_handler)  # type: ignore[arg-type]
    app.add_exception_handler(OperationError, _domain_error_handler)  # type: ignore[arg-type]
    app.add_exception_handler(ServiceUnavailableError, _service_unavailable_handler)  # type: ignore[arg-type]
    app.add_exception_handler(DomainError, _domain_error_handler)  # type: ignore[arg-type]
    app.add_exception_handler(RequestValidationError, _request_validation_error_handler)  # type: ignore[arg-type]
    app.add_exception_handler(HTTPException, _http_exception_handler)  # type: ignore[arg-type]
    app.add_exception_handler(StarletteHTTPException, _http_exception_handler)  # type: ignore[arg-type]
    app.add_exception_handler(ValueError, _value_error_handler)  # type: ignore[arg-type]
    app.add_exception_handler(Exception, _unhandled_exception_handler)

    # --- Routers ---
    app.include_router(metrics.router, prefix="/api/v1")
    app.include_router(health.router)
    app.include_router(setup.router)
    app.include_router(auth.router)
    app.include_router(dashboard.router)
    app.include_router(jails.router)
    app.include_router(bans.router)
    app.include_router(geo.router)
    app.include_router(config.router)
    app.include_router(file_config.router)
    app.include_router(server.router)
    app.include_router(history.router)
    app.include_router(blocklist.router)
    app.include_router(jails_v2.router)

    return app