""" Security Middleware for AniWorld. This module provides security-related middleware including CORS, CSP, security headers, and request sanitization. """ import logging import re from typing import Callable, List, Optional from fastapi import FastAPI, Request, Response from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from starlette.middleware.base import BaseHTTPMiddleware from starlette.types import ASGIApp logger = logging.getLogger(__name__) class SecurityHeadersMiddleware(BaseHTTPMiddleware): """Middleware to add security headers to all responses.""" def __init__( self, app: ASGIApp, hsts_max_age: int = 31536000, # 1 year hsts_include_subdomains: bool = True, hsts_preload: bool = False, frame_options: str = "DENY", content_type_options: bool = True, xss_protection: bool = True, referrer_policy: str = "strict-origin-when-cross-origin", permissions_policy: Optional[str] = None, ): """ Initialize security headers middleware. Args: app: ASGI application hsts_max_age: HSTS max-age in seconds hsts_include_subdomains: Include subdomains in HSTS hsts_preload: Enable HSTS preload frame_options: X-Frame-Options value (DENY, SAMEORIGIN, or ALLOW-FROM) content_type_options: Enable X-Content-Type-Options: nosniff xss_protection: Enable X-XSS-Protection referrer_policy: Referrer-Policy value permissions_policy: Permissions-Policy value """ super().__init__(app) self.hsts_max_age = hsts_max_age self.hsts_include_subdomains = hsts_include_subdomains self.hsts_preload = hsts_preload self.frame_options = frame_options self.content_type_options = content_type_options self.xss_protection = xss_protection self.referrer_policy = referrer_policy self.permissions_policy = permissions_policy async def dispatch(self, request: Request, call_next: Callable) -> Response: """ Process request and add security headers to response. Args: request: Incoming request call_next: Next middleware in chain Returns: Response with security headers """ response = await call_next(request) # HSTS Header hsts_value = f"max-age={self.hsts_max_age}" if self.hsts_include_subdomains: hsts_value += "; includeSubDomains" if self.hsts_preload: hsts_value += "; preload" response.headers["Strict-Transport-Security"] = hsts_value # X-Frame-Options response.headers["X-Frame-Options"] = self.frame_options # X-Content-Type-Options if self.content_type_options: response.headers["X-Content-Type-Options"] = "nosniff" # X-XSS-Protection (deprecated but still useful for older browsers) if self.xss_protection: response.headers["X-XSS-Protection"] = "1; mode=block" # Referrer-Policy response.headers["Referrer-Policy"] = self.referrer_policy # Permissions-Policy if self.permissions_policy: response.headers["Permissions-Policy"] = self.permissions_policy # Remove potentially revealing headers response.headers.pop("Server", None) response.headers.pop("X-Powered-By", None) return response class ContentSecurityPolicyMiddleware(BaseHTTPMiddleware): """Middleware to add Content Security Policy headers.""" def __init__( self, app: ASGIApp, default_src: List[str] = None, script_src: List[str] = None, style_src: List[str] = None, img_src: List[str] = None, font_src: List[str] = None, connect_src: List[str] = None, frame_src: List[str] = None, object_src: List[str] = None, media_src: List[str] = None, worker_src: List[str] = None, form_action: List[str] = None, frame_ancestors: List[str] = None, base_uri: List[str] = None, upgrade_insecure_requests: bool = True, block_all_mixed_content: bool = True, report_only: bool = False, ): """ Initialize CSP middleware. Args: app: ASGI application default_src: default-src directive values script_src: script-src directive values style_src: style-src directive values img_src: img-src directive values font_src: font-src directive values connect_src: connect-src directive values frame_src: frame-src directive values object_src: object-src directive values media_src: media-src directive values worker_src: worker-src directive values form_action: form-action directive values frame_ancestors: frame-ancestors directive values base_uri: base-uri directive values upgrade_insecure_requests: Enable upgrade-insecure-requests block_all_mixed_content: Enable block-all-mixed-content report_only: Use Content-Security-Policy-Report-Only header """ super().__init__(app) # Default secure CSP self.directives = { "default-src": default_src or ["'self'"], "script-src": script_src or ["'self'", "'unsafe-inline'"], "style-src": style_src or ["'self'", "'unsafe-inline'"], "img-src": img_src or ["'self'", "data:", "https:"], "font-src": font_src or ["'self'", "data:"], "connect-src": connect_src or ["'self'", "ws:", "wss:"], "frame-src": frame_src or ["'none'"], "object-src": object_src or ["'none'"], "media-src": media_src or ["'self'"], "worker-src": worker_src or ["'self'"], "form-action": form_action or ["'self'"], "frame-ancestors": frame_ancestors or ["'none'"], "base-uri": base_uri or ["'self'"], } self.upgrade_insecure_requests = upgrade_insecure_requests self.block_all_mixed_content = block_all_mixed_content self.report_only = report_only def _build_csp_header(self) -> str: """ Build the CSP header value. Returns: CSP header string """ parts = [] for directive, values in self.directives.items(): if values: parts.append(f"{directive} {' '.join(values)}") if self.upgrade_insecure_requests: parts.append("upgrade-insecure-requests") if self.block_all_mixed_content: parts.append("block-all-mixed-content") return "; ".join(parts) async def dispatch(self, request: Request, call_next: Callable) -> Response: """ Process request and add CSP header to response. Args: request: Incoming request call_next: Next middleware in chain Returns: Response with CSP header """ response = await call_next(request) header_name = ( "Content-Security-Policy-Report-Only" if self.report_only else "Content-Security-Policy" ) response.headers[header_name] = self._build_csp_header() return response class RequestSanitizationMiddleware(BaseHTTPMiddleware): """Middleware to sanitize and validate incoming requests.""" # Common SQL injection patterns SQL_INJECTION_PATTERNS = [ re.compile(r"(\bunion\b.*\bselect\b)", re.IGNORECASE), re.compile(r"(\bselect\b.*\bfrom\b)", re.IGNORECASE), re.compile(r"(\binsert\b.*\binto\b)", re.IGNORECASE), re.compile(r"(\bupdate\b.*\bset\b)", re.IGNORECASE), re.compile(r"(\bdelete\b.*\bfrom\b)", re.IGNORECASE), re.compile(r"(\bdrop\b.*\btable\b)", re.IGNORECASE), re.compile(r"(\bexec\b|\bexecute\b)", re.IGNORECASE), re.compile(r"(--|\#|\/\*|\*\/)", re.IGNORECASE), ] # Common XSS patterns XSS_PATTERNS = [ re.compile(r"]*>.*?", re.IGNORECASE | re.DOTALL), re.compile(r"javascript:", re.IGNORECASE), re.compile(r"on\w+\s*=", re.IGNORECASE), # Event handlers like onclick= re.compile(r"]*>", re.IGNORECASE), ] def __init__( self, app: ASGIApp, check_sql_injection: bool = True, check_xss: bool = True, max_request_size: int = 10 * 1024 * 1024, # 10 MB allowed_content_types: Optional[List[str]] = None, ): """ Initialize request sanitization middleware. Args: app: ASGI application check_sql_injection: Enable SQL injection checks check_xss: Enable XSS checks max_request_size: Maximum request body size in bytes allowed_content_types: List of allowed content types """ super().__init__(app) self.check_sql_injection = check_sql_injection self.check_xss = check_xss self.max_request_size = max_request_size self.allowed_content_types = allowed_content_types or [ "application/json", "application/x-www-form-urlencoded", "multipart/form-data", "text/plain", ] def _check_sql_injection(self, value: str) -> bool: """ Check if string contains SQL injection patterns. Args: value: String to check Returns: True if potential SQL injection detected """ for pattern in self.SQL_INJECTION_PATTERNS: if pattern.search(value): return True return False def _check_xss(self, value: str) -> bool: """ Check if string contains XSS patterns. Args: value: String to check Returns: True if potential XSS detected """ for pattern in self.XSS_PATTERNS: if pattern.search(value): return True return False def _sanitize_value(self, value: str) -> Optional[str]: """ Sanitize a string value. Args: value: Value to sanitize Returns: None if malicious content detected, sanitized value otherwise """ if self.check_sql_injection and self._check_sql_injection(value): logger.warning(f"Potential SQL injection detected: {value[:100]}") return None if self.check_xss and self._check_xss(value): logger.warning(f"Potential XSS detected: {value[:100]}") return None return value async def dispatch(self, request: Request, call_next: Callable) -> Response: """ Process and sanitize request. Args: request: Incoming request call_next: Next middleware in chain Returns: Response or error response if request is malicious """ # Check content type content_type = request.headers.get("content-type", "").split(";")[0].strip() if ( content_type and not any(ct in content_type for ct in self.allowed_content_types) ): logger.warning(f"Unsupported content type: {content_type}") return JSONResponse( status_code=415, content={"detail": "Unsupported Media Type"}, ) # Check request size content_length = request.headers.get("content-length") if content_length and int(content_length) > self.max_request_size: logger.warning(f"Request too large: {content_length} bytes") return JSONResponse( status_code=413, content={"detail": "Request Entity Too Large"}, ) # Check query parameters for key, value in request.query_params.items(): if isinstance(value, str): sanitized = self._sanitize_value(value) if sanitized is None: logger.warning(f"Malicious query parameter detected: {key}") return JSONResponse( status_code=400, content={"detail": "Malicious request detected"}, ) # Check path parameters for key, value in request.path_params.items(): if isinstance(value, str): sanitized = self._sanitize_value(value) if sanitized is None: logger.warning(f"Malicious path parameter detected: {key}") return JSONResponse( status_code=400, content={"detail": "Malicious request detected"}, ) return await call_next(request) def configure_security_middleware( app: FastAPI, cors_origins: List[str] = None, cors_allow_credentials: bool = True, enable_hsts: bool = True, enable_csp: bool = True, enable_sanitization: bool = True, csp_report_only: bool = False, ) -> None: """ Configure all security middleware for the FastAPI application. Args: app: FastAPI application instance cors_origins: List of allowed CORS origins cors_allow_credentials: Allow credentials in CORS requests enable_hsts: Enable HSTS and other security headers enable_csp: Enable Content Security Policy enable_sanitization: Enable request sanitization csp_report_only: Use CSP in report-only mode """ # CORS Middleware if cors_origins is None: cors_origins = ["http://localhost:3000", "http://localhost:8000"] app.add_middleware( CORSMiddleware, allow_origins=cors_origins, allow_credentials=cors_allow_credentials, allow_methods=["*"], allow_headers=["*"], expose_headers=["*"], ) # Security Headers Middleware if enable_hsts: app.add_middleware( SecurityHeadersMiddleware, hsts_max_age=31536000, hsts_include_subdomains=True, frame_options="DENY", content_type_options=True, xss_protection=True, referrer_policy="strict-origin-when-cross-origin", ) # Content Security Policy Middleware if enable_csp: app.add_middleware( ContentSecurityPolicyMiddleware, report_only=csp_report_only, # Allow inline scripts and styles for development # In production, use nonces or hashes script_src=["'self'", "'unsafe-inline'", "'unsafe-eval'"], style_src=["'self'", "'unsafe-inline'", "https://cdnjs.cloudflare.com"], font_src=["'self'", "data:", "https://cdnjs.cloudflare.com"], img_src=["'self'", "data:", "https:"], connect_src=["'self'", "ws://localhost:*", "wss://localhost:*"], ) # Request Sanitization Middleware if enable_sanitization: app.add_middleware( RequestSanitizationMiddleware, check_sql_injection=True, check_xss=True, max_request_size=10 * 1024 * 1024, # 10 MB ) logger.info("Security middleware configured successfully")