"""CSRF protection middleware for cookie-authenticated state-mutating requests.

This middleware enforces explicit CSRF protection on POST, PUT, DELETE, and
PATCH requests that use cookie-based authentication. Requests must include the
custom header ``X-BanGUI-Request: 1`` to proceed.

Bearer token authentication (via the ``Authorization`` header) bypasses this
check because it is not CSRF-vulnerable: browsers never attach a Bearer token
automatically, and ``Authorization`` is not a CORS-safelisted request header,
so a cross-site page cannot send it without a CORS preflight. The auth scheme
is matched case-insensitively, as required by RFC 7235 section 2.1. Requests
using any method outside the protected set (e.g. GET, HEAD, OPTIONS) are also
exempt.

Cross-site requests cannot set custom headers without a CORS preflight, which
the backend rejects for non-allowed origins, providing defense-in-depth.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import structlog
from fastapi import status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

if TYPE_CHECKING:
    from collections.abc import Awaitable, Callable

    from starlette.requests import Request
    from starlette.responses import Response as StarletteResponse

log: structlog.stdlib.BoundLogger = structlog.get_logger()

# Header name and value that clients must provide for state-mutating requests.
_CSRF_HEADER_NAME: str = "X-BanGUI-Request"
_CSRF_HEADER_VALUE: str = "1"

# HTTP methods that require CSRF protection.
_CSRF_PROTECTED_METHODS: frozenset[str] = frozenset({"POST", "PUT", "DELETE", "PATCH"})

# Session cookie name for detecting cookie-based authentication.
_SESSION_COOKIE_NAME: str = "bangui_session"

# Lower-cased auth scheme whose presence exempts a request from the CSRF check.
_BEARER_SCHEME: str = "bearer"


def _is_bearer_authorization(authorization: str) -> bool:
    """Return whether an ``Authorization`` header value uses the Bearer scheme.

    The scheme token is compared case-insensitively (RFC 7235 section 2.1), so
    ``Bearer abc`` and ``bearer abc`` both qualify. A space must separate the
    scheme from the credentials; the credentials themselves are not validated
    here.

    Args:
        authorization: The raw ``Authorization`` header value, or ``""`` when
            the header is absent.

    Returns:
        True if the value starts with the Bearer scheme followed by a space.
    """
    scheme, separator, _credentials = authorization.partition(" ")
    return bool(separator) and scheme.lower() == _BEARER_SCHEME


class CsrfMiddleware(BaseHTTPMiddleware):
    """Protect cookie-authenticated state-mutating requests with a custom header check.

    For requests using POST, PUT, DELETE, or PATCH that are authenticated via
    the session cookie (not a Bearer token), this middleware requires the
    presence of a custom header to prevent CSRF attacks. Bearer token requests
    and requests using other HTTP methods are passed through unchanged.
    """

    async def dispatch(
        self,
        request: Request,
        call_next: Callable[[Request], Awaitable[StarletteResponse]],
    ) -> StarletteResponse:
        """Intercept requests to enforce CSRF protection.

        Args:
            request: The incoming HTTP request.
            call_next: The next middleware / router handler.

        Returns:
            Either a 403 Forbidden response if CSRF validation fails, or the
            normal router response.
        """
        # Only state-mutating methods are checked; everything else passes.
        if request.method not in _CSRF_PROTECTED_METHODS:
            return await call_next(request)

        # Bearer-authenticated requests are not CSRF-vulnerable: a cross-site
        # page cannot attach an Authorization header without a CORS preflight.
        if _is_bearer_authorization(request.headers.get("Authorization", "")):
            return await call_next(request)

        # Only requests carrying the session cookie are treated as
        # cookie-authenticated and therefore need the CSRF header.
        if _SESSION_COOKIE_NAME not in request.cookies:
            return await call_next(request)

        # Enforce the CSRF header for cookie-authenticated state-mutating requests.
        csrf_header: str | None = request.headers.get(_CSRF_HEADER_NAME)
        if csrf_header != _CSRF_HEADER_VALUE:
            log.warning(
                "csrf_validation_failed",
                method=request.method,
                path=request.url.path,
                has_cookie=True,
                csrf_header_present=csrf_header is not None,
            )
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "CSRF validation failed. Request rejected."},
            )

        return await call_next(request)