"""Metrics collection middleware for BanGUI. Tracks HTTP request count, latency, and active requests. Excludes the /metrics endpoint to prevent recursive metrics collection. """ from __future__ import annotations import re import time from typing import TYPE_CHECKING from app.utils.logging_compat import get_logger from starlette.middleware.base import BaseHTTPMiddleware from app.utils.metrics import http_active_requests, http_request_count, http_request_latency if TYPE_CHECKING: from collections.abc import Awaitable, Callable from starlette.requests import Request from starlette.responses import Response log = get_logger(__name__) # Paths excluded from detailed metrics (to avoid cardinality explosion) EXCLUDED_PATHS = {"/metrics", "/health", "/api/health"} # Pattern to normalize endpoint paths (convert IDs to placeholders) PATH_PATTERN = re.compile(r"/api/[^/]+/[a-f0-9\-]{36}|/api/[^/]+/\d+") def _normalize_path(path: str) -> str: """Normalize path by replacing IDs with placeholders. Converts paths like /api/resource/123 to /api/resource/{id} to prevent cardinality explosion from dynamic IDs. Args: path: The request path. Returns: Normalized path with IDs replaced by {id}. """ return PATH_PATTERN.sub(r"/api/{id}", path) class MetricsMiddleware(BaseHTTPMiddleware): """Middleware to collect Prometheus metrics for HTTP requests.""" async def dispatch( self, request: Request, call_next: Callable[[Request], Awaitable[Response]], ) -> Response: """Collect metrics for the request and response. Args: request: The incoming request. call_next: The next middleware/route handler. Returns: The response. """ # Skip metrics for excluded paths if request.url.path in EXCLUDED_PATHS: return await call_next(request) method: str = request.method endpoint: str = _normalize_path(request.url.path) # Track active requests http_active_requests.labels(method=method, endpoint=endpoint).inc() start_time = time.perf_counter() status_code = 500 try: response: Response = await call_next(request) status_code = response.status_code return response finally: # Record metrics duration: float = time.perf_counter() - start_time http_request_latency.labels(method=method, endpoint=endpoint).observe(duration) http_request_count.labels(method=method, endpoint=endpoint, status_code=status_code).inc() http_active_requests.labels(method=method, endpoint=endpoint).dec() log.debug( "http_request_recorded", method=method, endpoint=endpoint, status_code=status_code, duration_ms=duration * 1000, )