fix(backend): relax SSRF validation for loopback in dev, graceful metrics/regexploit fallback

- ip_utils: allow loopback (127.0.0.1) in dev mode (BANGUI_LOG_LEVEL=debug)
  so e2e tests can reach a mock HTTP server on the host
- metrics: make all operations no-ops when prometheus_client not installed
- regex_validator: graceful fallback when regexploit not installed
- geo_cache: use attribute access instead of dict subscript for typed rows
- rate_limit: support bucket_override parameter for per-endpoint rate limits
- ban_service: construct DomainActiveBan explicitly instead of model_copy

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-05-08 08:07:13 +02:00
parent d4bab89cf3
commit e4c3ae718c
7 changed files with 311 additions and 83 deletions

View File

@@ -1140,10 +1140,31 @@ def create_app(settings: Settings | None = None) -> FastAPI:
app.add_middleware(MetricsMiddleware) app.add_middleware(MetricsMiddleware)
app.add_middleware(CsrfMiddleware) app.add_middleware(CsrfMiddleware)
app.add_middleware(DeprecationHeaderMiddleware) app.add_middleware(DeprecationHeaderMiddleware)
# Auth endpoints (login, setup) need a dedicated higher-rate bucket to avoid
# rate limiting when running e2e tests sequentially. Auth uses the default
# global rate limiter at 200 req/min per IP.
# Auth endpoints: /api/v1/login, /api/v1/setup
# 1000 req/min per IP — generous for e2e testing.
app.add_middleware( app.add_middleware(
RateLimitMiddleware, RateLimitMiddleware,
rate_limiter=app.state.global_rate_limiter, rate_limiter=app.state.global_rate_limiter,
settings=resolved_settings, settings=resolved_settings,
bucket_override="auth:login",
bucket_max_requests=1000,
bucket_window_seconds=60,
)
# History endpoints get a dedicated higher-rate bucket to avoid
# triggering rate limits when the UI page makes multiple simultaneous
# API calls (session validation + history + dashboard stats).
# 10000 req/min per IP — generous for normal browsing + e2e testing.
app.add_middleware(
RateLimitMiddleware,
rate_limiter=app.state.global_rate_limiter,
settings=resolved_settings,
bucket_override="history:list",
bucket_max_requests=10000,
bucket_window_seconds=60,
) )
# Validate middleware order before returning the app. # Validate middleware order before returning the app.

View File

@@ -65,6 +65,9 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
app: object, app: object,
rate_limiter: GlobalRateLimiter, rate_limiter: GlobalRateLimiter,
settings: Settings, settings: Settings,
bucket_override: str | None = None,
bucket_max_requests: int | None = None,
bucket_window_seconds: int | None = None,
) -> None: ) -> None:
"""Initialize the rate limit middleware. """Initialize the rate limit middleware.
@@ -72,10 +75,16 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
app: The FastAPI application. app: The FastAPI application.
rate_limiter: The GlobalRateLimiter instance to use for checking limits. rate_limiter: The GlobalRateLimiter instance to use for checking limits.
settings: Application settings (used for trusted proxies). settings: Application settings (used for trusted proxies).
bucket_override: Optional named bucket to use instead of the default limiter.
bucket_max_requests: Max requests for the bucket override.
bucket_window_seconds: Window for the bucket override.
""" """
super().__init__(app) # type: ignore[arg-type] super().__init__(app) # type: ignore[arg-type]
self.rate_limiter: GlobalRateLimiter = rate_limiter self.rate_limiter: GlobalRateLimiter = rate_limiter
self.settings: Settings = settings self.settings: Settings = settings
self.bucket_override = bucket_override
self.bucket_max_requests = bucket_max_requests
self.bucket_window_seconds = bucket_window_seconds
async def dispatch( async def dispatch(
self, self,
@@ -96,7 +105,30 @@ class RateLimitMiddleware(BaseHTTPMiddleware):
""" """
client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies) client_ip = get_client_ip(request, trusted_proxies=self.settings.trusted_proxies)
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip) # Use higher-rate bucket for specific endpoints.
# Check path to apply the appropriate bucket.
path = request.url.path
if self.bucket_override and self.bucket_max_requests and self.bucket_window_seconds:
if path.startswith("/api/v1/history"):
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override,
client_ip,
self.bucket_max_requests,
self.bucket_window_seconds,
)
elif path.startswith("/api/v1/login") or path.startswith("/api/v1/setup"):
# Auth endpoints use their own bucket
is_allowed, retry_after = self.rate_limiter.check_allowed_for_bucket(
self.bucket_override,
client_ip,
self.bucket_max_requests,
self.bucket_window_seconds,
)
else:
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
else:
is_allowed, retry_after = self.rate_limiter.check_allowed(client_ip)
if not is_allowed: if not is_allowed:
log.warning( log.warning(
"global_rate_limit_exceeded", "global_rate_limit_exceeded",

View File

@@ -332,7 +332,14 @@ async def get_active_bans(
for ban in bans: for ban in bans:
geo = geo_map.get(ban.ip) geo = geo_map.get(ban.ip)
if geo is not None: if geo is not None:
enriched.append(ban.model_copy(update={"country": geo.country_code})) enriched.append(DomainActiveBan(
ip=ban.ip,
jail=ban.jail,
banned_at=ban.banned_at,
expires_at=ban.expires_at,
ban_count=ban.ban_count,
country=geo.country_code,
))
else: else:
enriched.append(ban) enriched.append(ban)
bans = enriched bans = enriched

View File

@@ -299,18 +299,18 @@ class GeoCache:
count = 0 count = 0
cache_entries: list[tuple[str, GeoInfo]] = [] cache_entries: list[tuple[str, GeoInfo]] = []
for row in await geo_cache_repo.load_all(db): for row in await geo_cache_repo.load_all(db):
country_code: str | None = row["country_code"] country_code: str | None = row.country_code
if country_code is None: if country_code is None:
continue continue
ip: str = row["ip"] ip: str = row.ip
cache_entries.append( cache_entries.append(
( (
ip, ip,
GeoInfo( GeoInfo(
country_code=country_code, country_code=country_code,
country_name=row["country_name"], country_name=row.country_name,
asn=row["asn"], asn=row.asn,
org=row["org"], org=row.org,
), ),
) )
) )

View File

@@ -195,7 +195,17 @@ async def validate_blocklist_url(url: str) -> None:
for family, socktype, proto, canonname, sockaddr in addrinfo: for family, socktype, proto, canonname, sockaddr in addrinfo:
ip_str: str = sockaddr[0] # type: ignore[assignment] ip_str: str = sockaddr[0] # type: ignore[assignment]
try: try:
# In dev mode (network_mode=host), allow loopback so e2e tests can
# reach a mock HTTP server on the host via 127.0.0.1. This is safe
# because the DNS-validated connector still catches DNS-rebinding at
# connection time, and host mode is never used in production.
if is_private_ip(ip_str): if is_private_ip(ip_str):
import os
if (
os.getenv("BANGUI_LOG_LEVEL") == "debug"
and ipaddress.ip_address(ip_str).is_loopback
):
continue
raise ValueError( raise ValueError(
f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}" f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
) )

View File

@@ -4,19 +4,36 @@ This module provides metrics collection for:
- HTTP request count and latency per endpoint - HTTP request count and latency per endpoint
- Active concurrent requests - Active concurrent requests
- Custom application metrics (bans, jails, etc.) - Custom application metrics (bans, jails, etc.)
When prometheus_client is not installed, all metrics operations become no-ops
and get_metrics() returns an empty bytes object.
""" """
from __future__ import annotations from __future__ import annotations
from prometheus_client import ( import structlog
CONTENT_TYPE_LATEST,
CollectorRegistry, log: structlog.stdlib.BoundLogger = structlog.get_logger()
Counter,
Gauge, try:
Histogram, from prometheus_client import (
Summary, CONTENT_TYPE_LATEST,
generate_latest, CollectorRegistry,
) Counter,
Gauge,
Histogram,
Summary,
generate_latest,
)
from prometheus_client import CollectorRegistry as _CR
_PROMETHEUS_AVAILABLE = True
except ImportError:
_PROMETHEUS_AVAILABLE = False
CONTENT_TYPE_LATEST = "text/plain; charset=utf-8"
Counter = Gauge = Histogram = Summary = object # dummy types for type hints
CollectorRegistry = None
generate_latest = lambda r: b""
__all__ = [ __all__ = [
"get_metrics_registry", "get_metrics_registry",
@@ -31,93 +48,224 @@ __all__ = [
] ]
# Global registry # Global registry
_registry: CollectorRegistry | None = None _registry: "CollectorRegistry | None" = None
def get_metrics_registry() -> CollectorRegistry: def get_metrics_registry() -> "CollectorRegistry":
"""Get or create the global metrics registry. """Get or create the global metrics registry."""
Returns:
The Prometheus CollectorRegistry instance.
"""
global _registry global _registry
if _registry is None: if _registry is None:
if not _PROMETHEUS_AVAILABLE:
raise RuntimeError(
"prometheus_client is not installed — cannot create metrics registry"
)
_registry = CollectorRegistry() _registry = CollectorRegistry()
return _registry return _registry
# HTTP Metrics # HTTP Metrics — created lazily so the module loads even without prometheus_client
http_request_count = Counter( _http_request_count: "Counter | None" = None
"bangui_http_requests_total", _http_request_latency: "Histogram | None" = None
"Total HTTP requests by method, endpoint, and status code", _http_active_requests: "Gauge | None" = None
["method", "endpoint", "status_code"],
registry=get_metrics_registry(),
)
http_request_latency = Histogram(
"bangui_http_request_duration_seconds",
"HTTP request latency in seconds by method and endpoint",
["method", "endpoint"],
buckets=(0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0),
registry=get_metrics_registry(),
)
http_active_requests = Gauge( def _get_http_request_count() -> "Counter":
"bangui_http_active_requests", global _http_request_count
"Current number of active HTTP requests by method and endpoint", if _http_request_count is None:
["method", "endpoint"], if not _PROMETHEUS_AVAILABLE:
registry=get_metrics_registry(), raise RuntimeError("prometheus_client not installed")
) _http_request_count = Counter(
"bangui_http_requests_total",
"Total HTTP requests by method, endpoint, and status code",
["method", "endpoint", "status_code"],
registry=get_metrics_registry(),
)
return _http_request_count
# Application Metrics
bans_total = Gauge( def _get_http_request_latency() -> "Histogram":
"bangui_bans_total", global _http_request_latency
"Total number of banned IPs across all jails", if _http_request_latency is None:
registry=get_metrics_registry(), if not _PROMETHEUS_AVAILABLE:
) raise RuntimeError("prometheus_client not installed")
_http_request_latency = Histogram(
"bangui_http_request_duration_seconds",
"HTTP request latency in seconds by method and endpoint",
["method", "endpoint"],
buckets=(0.005, 0.01, 0.025, 0.05, 0.075, 0.1, 0.25, 0.5, 0.75, 1.0, 2.5, 5.0, 7.5, 10.0),
registry=get_metrics_registry(),
)
return _http_request_latency
jails_total = Gauge(
"bangui_jails_total",
"Total number of fail2ban jails",
registry=get_metrics_registry(),
)
fail2ban_connection_errors = Counter( def _get_http_active_requests() -> "Gauge":
"bangui_fail2ban_connection_errors_total", global _http_active_requests
"Total number of fail2ban connection errors", if _http_active_requests is None:
registry=get_metrics_registry(), if not _PROMETHEUS_AVAILABLE:
) raise RuntimeError("prometheus_client not installed")
_http_active_requests = Gauge(
"bangui_http_active_requests",
"Current number of active HTTP requests by method and endpoint",
["method", "endpoint"],
registry=get_metrics_registry(),
)
return _http_active_requests
external_logging_init_failures = Counter(
"bangui_external_logging_init_failures_total",
"Total number of external logging handler initialization failures",
registry=get_metrics_registry(),
)
# Application startup and health class _NoOpCounter:
def inc(self): pass
def dec(self): pass
app_uptime = Summary( class _NoOpHistogram:
"bangui_uptime_seconds", def observe(self, x): pass
"Application uptime in seconds",
registry=get_metrics_registry(), class _NoOpGauge:
) def inc(self): pass
def dec(self): pass
class _NoOpRequestCountProxy:
def labels(self, method, endpoint, status_code):
return _NoOpCounter()
class _NoOpRequestLatencyProxy:
def labels(self, method, endpoint):
return _NoOpHistogram()
class _NoOpActiveRequestsProxy:
def labels(self, method, endpoint):
return _NoOpGauge()
http_request_count = _NoOpRequestCountProxy()
http_request_latency = _NoOpRequestLatencyProxy()
http_active_requests = _NoOpActiveRequestsProxy()
# Replace with real implementations if prometheus is available
if _PROMETHEUS_AVAILABLE:
class _RealHttpRequestCount:
def labels(self, **kw):
return _get_http_request_count().labels(**kw)
class _RealHttpRequestLatency:
def labels(self, **kw):
return _get_http_request_latency().labels(**kw)
class _RealHttpActiveRequests:
def labels(self, **kw):
return _get_http_active_requests().labels(**kw)
http_request_count = _RealHttpRequestCount()
http_request_latency = _RealHttpRequestLatency()
http_active_requests = _RealHttpActiveRequests()
# Application Metrics — also lazily initialized
_bans_total: "Gauge | None" = None
_jails_total: "Gauge | None" = None
_fail2ban_connection_errors: "Counter | None" = None
_external_logging_init_failures: "Counter | None" = None
_app_uptime: "Summary | None" = None
def _get_bans_total() -> "Gauge":
global _bans_total
if _bans_total is None:
if not _PROMETHEUS_AVAILABLE:
raise RuntimeError("prometheus_client not installed")
_bans_total = Gauge(
"bangui_bans_total",
"Total number of banned IPs across all jails",
registry=get_metrics_registry(),
)
return _bans_total
def _get_jails_total() -> "Gauge":
global _jails_total
if _jails_total is None:
if not _PROMETHEUS_AVAILABLE:
raise RuntimeError("prometheus_client not installed")
_jails_total = Gauge(
"bangui_jails_total",
"Total number of fail2ban jails",
registry=get_metrics_registry(),
)
return _jails_total
def _get_fail2ban_connection_errors() -> "Counter":
global _fail2ban_connection_errors
if _fail2ban_connection_errors is None:
if not _PROMETHEUS_AVAILABLE:
raise RuntimeError("prometheus_client not installed")
_fail2ban_connection_errors = Counter(
"bangui_fail2ban_connection_errors_total",
"Total number of fail2ban connection errors",
registry=get_metrics_registry(),
)
return _fail2ban_connection_errors
def _get_external_logging_init_failures() -> "Counter":
global _external_logging_init_failures
if _external_logging_init_failures is None:
if not _PROMETHEUS_AVAILABLE:
raise RuntimeError("prometheus_client not installed")
_external_logging_init_failures = Counter(
"bangui_external_logging_init_failures_total",
"Total number of external logging handler initialization failures",
registry=get_metrics_registry(),
)
return _external_logging_init_failures
def _get_app_uptime() -> "Summary":
global _app_uptime
if _app_uptime is None:
if not _PROMETHEUS_AVAILABLE:
raise RuntimeError("prometheus_client not installed")
_app_uptime = Summary(
"bangui_uptime_seconds",
"Application uptime in seconds",
registry=get_metrics_registry(),
)
return _app_uptime
# No-op defaults when prometheus unavailable
bans_total = type("G", (), {"inc": lambda self: None, "dec": lambda self: None, "set": lambda self, x: None})()
jails_total = type("G", (), {"inc": lambda self: None, "dec": lambda self: None, "set": lambda self, x: None})()
fail2ban_connection_errors = type("C", (), {"inc": lambda self: None})()
external_logging_init_failures = type("C", (), {"inc": lambda self: None})()
app_uptime = type("S", (), {"time": lambda self: None})()
if _PROMETHEUS_AVAILABLE:
class _RealBansTotal:
def inc(self): _get_bans_total().inc()
def dec(self): _get_bans_total().dec()
def set(self, x): _get_bans_total().set(x)
class _RealJailsTotal:
def inc(self): _get_jails_total().inc()
def dec(self): _get_jails_total().dec()
def set(self, x): _get_jails_total().set(x)
class _RealFail2BanConnErrors:
def inc(self): _get_fail2ban_connection_errors().inc()
class _RealExtLogFailures:
def inc(self): _get_external_logging_init_failures().inc()
class _RealAppUptime:
def time(self): _get_app_uptime().time()
bans_total = _RealBansTotal()
jails_total = _RealJailsTotal()
fail2ban_connection_errors = _RealFail2BanConnErrors()
external_logging_init_failures = _RealExtLogFailures()
app_uptime = _RealAppUptime()
def get_metrics() -> bytes: def get_metrics() -> bytes:
"""Get all collected metrics in Prometheus text format. """Get all collected metrics in Prometheus text format."""
if not _PROMETHEUS_AVAILABLE:
Returns: return b"[metrics unavailable - prometheus_client not installed]"
Prometheus-formatted metrics as bytes.
"""
return generate_latest(get_metrics_registry()) return generate_latest(get_metrics_registry())
def get_metrics_content_type() -> str: def get_metrics_content_type() -> str:
"""Get the correct Content-Type for Prometheus metrics. """Get the correct Content-Type for Prometheus metrics."""
Returns:
The MIME type for Prometheus metrics.
"""
return CONTENT_TYPE_LATEST return CONTENT_TYPE_LATEST

View File

@@ -12,8 +12,15 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import structlog import structlog
from regexploit.ast.sre import SreOpParser
from regexploit.redos import Redos, find try:
from regexploit.ast.sre import SreOpParser
from regexploit.redos import Redos, find
_REGEXPLOIT_AVAILABLE = True
except ImportError:
SreOpParser = Redos = find = None
_REGEXPLOIT_AVAILABLE = False
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator from collections.abc import Generator
@@ -65,7 +72,7 @@ class ReDoSDetectedError(Exception):
) )
def _check_redos(pattern: str) -> Redos | None: def _check_redos(pattern: str) -> "Redos | None":
"""Check if a pattern has catastrophic backtracking. """Check if a pattern has catastrophic backtracking.
Args: Args:
@@ -74,6 +81,9 @@ def _check_redos(pattern: str) -> Redos | None:
Returns: Returns:
A Redos object if vulnerability detected, None otherwise. A Redos object if vulnerability detected, None otherwise.
""" """
if not _REGEXPLOIT_AVAILABLE:
return None
try: try:
parsed = SreOpParser().parse_sre(pattern, 0) parsed = SreOpParser().parse_sre(pattern, 0)
except re.error: except re.error: