"""
Security Middleware for AniWorld.
This module provides security-related middleware including CORS, CSP,
security headers, and request sanitization.
"""
import logging
import re
from typing import Callable, List, Optional
from fastapi import FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
logger = logging.getLogger(__name__)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Middleware to add security headers to all responses."""

    # Headers that may reveal implementation details and are stripped
    # from every response that passes through this middleware.
    _REVEALING_HEADERS = ("Server", "X-Powered-By")

    def __init__(
        self,
        app: ASGIApp,
        hsts_max_age: int = 31536000,  # 1 year
        hsts_include_subdomains: bool = True,
        hsts_preload: bool = False,
        frame_options: str = "DENY",
        content_type_options: bool = True,
        xss_protection: bool = True,
        referrer_policy: str = "strict-origin-when-cross-origin",
        permissions_policy: Optional[str] = None,
    ):
        """
        Initialize security headers middleware.

        Args:
            app: ASGI application
            hsts_max_age: HSTS max-age in seconds
            hsts_include_subdomains: Include subdomains in HSTS
            hsts_preload: Enable HSTS preload
            frame_options: X-Frame-Options value (DENY or SAMEORIGIN;
                ALLOW-FROM is obsolete and ignored by modern browsers)
            content_type_options: Enable X-Content-Type-Options: nosniff
            xss_protection: Enable X-XSS-Protection
            referrer_policy: Referrer-Policy value
            permissions_policy: Permissions-Policy value; the header is
                omitted when this is None or empty
        """
        super().__init__(app)
        self.hsts_max_age = hsts_max_age
        self.hsts_include_subdomains = hsts_include_subdomains
        self.hsts_preload = hsts_preload
        self.frame_options = frame_options
        self.content_type_options = content_type_options
        self.xss_protection = xss_protection
        self.referrer_policy = referrer_policy
        self.permissions_policy = permissions_policy

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request and add security headers to response.

        Args:
            request: Incoming request
            call_next: Next middleware in chain

        Returns:
            Response with security headers
        """
        response = await call_next(request)
        headers = response.headers

        # HSTS Header
        hsts_parts = [f"max-age={self.hsts_max_age}"]
        if self.hsts_include_subdomains:
            hsts_parts.append("includeSubDomains")
        if self.hsts_preload:
            hsts_parts.append("preload")
        headers["Strict-Transport-Security"] = "; ".join(hsts_parts)

        # X-Frame-Options
        headers["X-Frame-Options"] = self.frame_options

        # X-Content-Type-Options
        if self.content_type_options:
            headers["X-Content-Type-Options"] = "nosniff"

        # X-XSS-Protection (deprecated but still useful for older browsers)
        if self.xss_protection:
            headers["X-XSS-Protection"] = "1; mode=block"

        # Referrer-Policy
        headers["Referrer-Policy"] = self.referrer_policy

        # Permissions-Policy
        if self.permissions_policy:
            headers["Permissions-Policy"] = self.permissions_policy

        # Remove potentially revealing headers.
        # The response header container supports item deletion but has no
        # dict-style ``pop``; calling ``pop`` raised AttributeError on every
        # request, so check membership and delete instead.
        # NOTE(review): a Server header injected by the ASGI server itself
        # after the response leaves the app would not be visible here -
        # confirm against the deployment's server settings.
        for header_name in self._REVEALING_HEADERS:
            if header_name in headers:
                del headers[header_name]

        return response
class ContentSecurityPolicyMiddleware(BaseHTTPMiddleware):
    """Middleware to add Content Security Policy headers."""

    def __init__(
        self,
        app: ASGIApp,
        default_src: Optional[List[str]] = None,
        script_src: Optional[List[str]] = None,
        style_src: Optional[List[str]] = None,
        img_src: Optional[List[str]] = None,
        font_src: Optional[List[str]] = None,
        connect_src: Optional[List[str]] = None,
        frame_src: Optional[List[str]] = None,
        object_src: Optional[List[str]] = None,
        media_src: Optional[List[str]] = None,
        worker_src: Optional[List[str]] = None,
        form_action: Optional[List[str]] = None,
        frame_ancestors: Optional[List[str]] = None,
        base_uri: Optional[List[str]] = None,
        upgrade_insecure_requests: bool = True,
        block_all_mixed_content: bool = True,
        report_only: bool = False,
    ):
        """
        Initialize CSP middleware.

        Every ``*_src`` / directive argument falls back to the default shown
        in the code below when it is ``None`` or an empty list.

        Args:
            app: ASGI application
            default_src: default-src directive values
            script_src: script-src directive values
            style_src: style-src directive values
            img_src: img-src directive values
            font_src: font-src directive values
            connect_src: connect-src directive values
            frame_src: frame-src directive values
            object_src: object-src directive values
            media_src: media-src directive values
            worker_src: worker-src directive values
            form_action: form-action directive values
            frame_ancestors: frame-ancestors directive values
            base_uri: base-uri directive values
            upgrade_insecure_requests: Enable upgrade-insecure-requests
            block_all_mixed_content: Enable block-all-mixed-content
            report_only: Use Content-Security-Policy-Report-Only header
        """
        super().__init__(app)

        # Default secure CSP. ``or`` is deliberate: any falsy value
        # (None or an empty list) selects the default for that directive.
        # NOTE(review): the script-src/style-src defaults include
        # 'unsafe-inline', which permits inline scripts and styles -
        # tighten where the front end allows it.
        self.directives = {
            "default-src": default_src or ["'self'"],
            "script-src": script_src or ["'self'", "'unsafe-inline'"],
            "style-src": style_src or ["'self'", "'unsafe-inline'"],
            "img-src": img_src or ["'self'", "data:", "https:"],
            "font-src": font_src or ["'self'", "data:"],
            "connect-src": connect_src or ["'self'", "ws:", "wss:"],
            "frame-src": frame_src or ["'none'"],
            "object-src": object_src or ["'none'"],
            "media-src": media_src or ["'self'"],
            "worker-src": worker_src or ["'self'"],
            "form-action": form_action or ["'self'"],
            "frame-ancestors": frame_ancestors or ["'none'"],
            "base-uri": base_uri or ["'self'"],
        }
        self.upgrade_insecure_requests = upgrade_insecure_requests
        self.block_all_mixed_content = block_all_mixed_content
        self.report_only = report_only

    def _build_csp_header(self) -> str:
        """
        Build the CSP header value.

        Directives are emitted in the insertion order of ``self.directives``;
        a directive whose value list is empty is left out. The two value-less
        flag directives are appended last when enabled.

        Returns:
            CSP header string
        """
        parts = [
            f"{directive} {' '.join(values)}"
            for directive, values in self.directives.items()
            if values
        ]
        if self.upgrade_insecure_requests:
            parts.append("upgrade-insecure-requests")
        if self.block_all_mixed_content:
            parts.append("block-all-mixed-content")
        return "; ".join(parts)

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request and add CSP header to response.

        Args:
            request: Incoming request
            call_next: Next middleware in chain

        Returns:
            Response with CSP header
        """
        response = await call_next(request)
        header_name = (
            "Content-Security-Policy-Report-Only"
            if self.report_only
            else "Content-Security-Policy"
        )
        # The value is rebuilt per request so that changes made to
        # ``self.directives`` after construction are reflected.
        response.headers[header_name] = self._build_csp_header()
        return response
class RequestSanitizationMiddleware(BaseHTTPMiddleware):
"""Middleware to sanitize and validate incoming requests."""
# Common SQL injection patterns
SQL_INJECTION_PATTERNS = [
re.compile(r"(\bunion\b.*\bselect\b)", re.IGNORECASE),
re.compile(r"(\bselect\b.*\bfrom\b)", re.IGNORECASE),
re.compile(r"(\binsert\b.*\binto\b)", re.IGNORECASE),
re.compile(r"(\bupdate\b.*\bset\b)", re.IGNORECASE),
re.compile(r"(\bdelete\b.*\bfrom\b)", re.IGNORECASE),
re.compile(r"(\bdrop\b.*\btable\b)", re.IGNORECASE),
re.compile(r"(\bexec\b|\bexecute\b)", re.IGNORECASE),
re.compile(r"(--|\#|\/\*|\*\/)", re.IGNORECASE),
]
# Common XSS patterns
XSS_PATTERNS = [
re.compile(r"", re.IGNORECASE | re.DOTALL),
re.compile(r"javascript:", re.IGNORECASE),
re.compile(r"on\w+\s*=", re.IGNORECASE), # Event handlers like onclick=
re.compile(r"