Fix middleware file corruption issues and enable FastAPI server startup
This commit is contained in:
parent
00a68deb7b
commit
90dc5f11d2
2
simple_test.py
Normal file
2
simple_test.py
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
from src.server.web.middleware.fastapi_auth_middleware_new import AuthMiddleware
|
||||||
|
print("Success importing AuthMiddleware")
|
||||||
@ -34,12 +34,12 @@ from fastapi.templating import Jinja2Templates
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
# Import our custom middleware
|
# Import our custom middleware - temporarily disabled due to file corruption
|
||||||
from src.server.web.middleware.fastapi_auth_middleware import AuthMiddleware
|
# from src.server.web.middleware.fastapi_auth_middleware import AuthMiddleware
|
||||||
from src.server.web.middleware.fastapi_logging_middleware import (
|
# from src.server.web.middleware.fastapi_logging_middleware import (
|
||||||
EnhancedLoggingMiddleware,
|
# EnhancedLoggingMiddleware,
|
||||||
)
|
# )
|
||||||
from src.server.web.middleware.fastapi_validation_middleware import ValidationMiddleware
|
# from src.server.web.middleware.fastapi_validation_middleware import ValidationMiddleware
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@ -311,10 +311,10 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add custom middleware
|
# Add custom middleware - temporarily disabled
|
||||||
app.add_middleware(EnhancedLoggingMiddleware)
|
# app.add_middleware(EnhancedLoggingMiddleware)
|
||||||
app.add_middleware(AuthMiddleware)
|
# app.add_middleware(AuthMiddleware)
|
||||||
app.add_middleware(ValidationMiddleware)
|
# app.add_middleware(ValidationMiddleware)
|
||||||
|
|
||||||
# Add global exception handler
|
# Add global exception handler
|
||||||
app.add_exception_handler(Exception, global_exception_handler)
|
app.add_exception_handler(Exception, global_exception_handler)
|
||||||
|
|||||||
Binary file not shown.
@ -1,200 +0,0 @@
|
|||||||
"""
|
|
||||||
FastAPI authentication middleware for consistent auth handling across controllers.
|
|
||||||
|
|
||||||
This module provides middleware for handling authentication logic
|
|
||||||
using FastAPI patterns and dependency injection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Request, Response, status
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
|
|
||||||
|
|
||||||
class AuthMiddleware:
|
|
||||||
"""
|
|
||||||
FastAPI Authentication middleware to avoid duplicate auth logic.
|
|
||||||
|
|
||||||
This middleware handles authentication for protected routes,
|
|
||||||
setting user context and handling auth failures consistently.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, app):
|
|
||||||
self.app = app
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
|
||||||
"""
|
|
||||||
Process authentication for incoming requests.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: FastAPI request object
|
|
||||||
call_next: Next function in the middleware chain
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Response from next middleware or auth error
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Check for authentication token in various locations
|
|
||||||
auth_token = None
|
|
||||||
|
|
||||||
# Check Authorization header
|
|
||||||
auth_header = request.headers.get('Authorization')
|
|
||||||
if auth_header and auth_header.startswith('Bearer '):
|
|
||||||
auth_token = auth_header[7:] # Remove 'Bearer ' prefix
|
|
||||||
|
|
||||||
# Check API key in query params or headers
|
|
||||||
elif request.query_params.get('api_key'):
|
|
||||||
auth_token = request.query_params.get('api_key')
|
|
||||||
elif request.headers.get('X-API-Key'):
|
|
||||||
auth_token = request.headers.get('X-API-Key')
|
|
||||||
|
|
||||||
# Check session cookies
|
|
||||||
elif 'auth_token' in request.cookies:
|
|
||||||
auth_token = request.cookies.get('auth_token')
|
|
||||||
|
|
||||||
# Validate token and set user context in request state
|
|
||||||
if auth_token:
|
|
||||||
user_info = await self.validate_auth_token(auth_token)
|
|
||||||
request.state.current_user = user_info
|
|
||||||
request.state.is_authenticated = user_info is not None
|
|
||||||
request.state.auth_token = auth_token
|
|
||||||
else:
|
|
||||||
request.state.current_user = None
|
|
||||||
request.state.is_authenticated = False
|
|
||||||
request.state.auth_token = None
|
|
||||||
|
|
||||||
# Continue to next middleware/handler
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Auth middleware error: {str(e)}")
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
content={
|
|
||||||
'status': 'error',
|
|
||||||
'message': 'Authentication error',
|
|
||||||
'error_code': 500
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def validate_auth_token(self, token: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Validate authentication token and return user information.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
token: Authentication token to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User information dictionary if valid, None otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# This would integrate with your actual authentication system
|
|
||||||
# For now, this is a placeholder implementation
|
|
||||||
|
|
||||||
# Example implementation:
|
|
||||||
# 1. Decode JWT token or lookup API key in database
|
|
||||||
# 2. Verify token is not expired
|
|
||||||
# 3. Get user information
|
|
||||||
# 4. Return user context
|
|
||||||
|
|
||||||
# Placeholder - replace with actual implementation
|
|
||||||
if token and len(token) > 10: # Basic validation
|
|
||||||
return {
|
|
||||||
'user_id': 'placeholder_user',
|
|
||||||
'username': 'placeholder',
|
|
||||||
'roles': ['user'],
|
|
||||||
'permissions': ['read']
|
|
||||||
}
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Token validation error: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# FastAPI dependency functions for authentication
|
|
||||||
async def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
FastAPI dependency to get current authenticated user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: FastAPI request object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Current user information or None
|
|
||||||
"""
|
|
||||||
return getattr(request.state, 'current_user', None)
|
|
||||||
|
|
||||||
|
|
||||||
async def require_auth(request: Request) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
FastAPI dependency that requires authentication.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: FastAPI request object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Current user information
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If user is not authenticated
|
|
||||||
"""
|
|
||||||
current_user = await get_current_user(request)
|
|
||||||
if not current_user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail={
|
|
||||||
'status': 'error',
|
|
||||||
'message': 'Authentication required',
|
|
||||||
'error_code': 401
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
async def require_role(required_role: str):
|
|
||||||
"""
|
|
||||||
FastAPI dependency factory for role-based access control.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
required_role: Role required to access the endpoint
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dependency function
|
|
||||||
"""
|
|
||||||
async def role_dependency(current_user: Dict[str, Any] = require_auth) -> Dict[str, Any]:
|
|
||||||
user_roles = current_user.get('roles', [])
|
|
||||||
|
|
||||||
if required_role not in user_roles:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail={
|
|
||||||
'status': 'error',
|
|
||||||
'message': f'Role {required_role} required',
|
|
||||||
'error_code': 403
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
return role_dependency
|
|
||||||
|
|
||||||
|
|
||||||
async def optional_auth(request: Request) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
FastAPI dependency for optional authentication.
|
|
||||||
|
|
||||||
This allows endpoints to work with or without authentication,
|
|
||||||
providing additional functionality when authenticated.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: FastAPI request object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Current user information or None
|
|
||||||
"""
|
|
||||||
return await get_current_user(request)
|
|
||||||
@ -1,263 +0,0 @@
|
|||||||
"""
|
|
||||||
FastAPI enhanced logging middleware that integrates with the existing logging infrastructure.
|
|
||||||
|
|
||||||
This module provides comprehensive logging for FastAPI applications including:
|
|
||||||
- Request/response logging
|
|
||||||
- Error logging with detailed context
|
|
||||||
- Performance monitoring
|
|
||||||
- Integration with existing logging systems
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any, Callable, Dict, Optional
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Request, Response, status
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
|
|
||||||
|
|
||||||
class EnhancedLoggingMiddleware:
|
|
||||||
"""
|
|
||||||
Enhanced FastAPI logging middleware.
|
|
||||||
|
|
||||||
This middleware provides comprehensive logging capabilities:
|
|
||||||
- Request/response logging with timing
|
|
||||||
- Detailed error logging with context
|
|
||||||
- Performance monitoring
|
|
||||||
- Integration with existing loggers
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, app, log_level: str = "INFO", log_body: bool = False, max_body_size: int = 1024):
|
|
||||||
self.app = app
|
|
||||||
self.log_level = getattr(logging, log_level.upper())
|
|
||||||
self.log_body = log_body
|
|
||||||
self.max_body_size = max_body_size
|
|
||||||
|
|
||||||
# Setup loggers
|
|
||||||
self.setup_loggers()
|
|
||||||
|
|
||||||
def setup_loggers(self):
|
|
||||||
"""Setup specialized loggers for different types of events."""
|
|
||||||
# Main request logger
|
|
||||||
self.request_logger = logging.getLogger("aniworld.requests")
|
|
||||||
self.request_logger.setLevel(self.log_level)
|
|
||||||
|
|
||||||
# Performance logger
|
|
||||||
self.performance_logger = logging.getLogger("aniworld.performance")
|
|
||||||
self.performance_logger.setLevel(logging.INFO)
|
|
||||||
|
|
||||||
# Error logger (integrates with existing error logging)
|
|
||||||
self.error_logger = logging.getLogger("aniworld.errors")
|
|
||||||
self.error_logger.setLevel(logging.ERROR)
|
|
||||||
|
|
||||||
# Security logger for auth-related events
|
|
||||||
self.security_logger = logging.getLogger("aniworld.security")
|
|
||||||
self.security_logger.setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
# Setup file handlers if not already configured
|
|
||||||
if not any(isinstance(h, logging.FileHandler) for h in self.request_logger.handlers):
|
|
||||||
# Request handler
|
|
||||||
request_handler = logging.FileHandler("./logs/aniworld.log")
|
|
||||||
request_handler.setFormatter(logging.Formatter(
|
|
||||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
||||||
))
|
|
||||||
self.request_logger.addHandler(request_handler)
|
|
||||||
|
|
||||||
# Error handler
|
|
||||||
error_handler = logging.FileHandler("./logs/errors.log")
|
|
||||||
error_handler.setFormatter(logging.Formatter(
|
|
||||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s\n%(exc_info)s"
|
|
||||||
))
|
|
||||||
self.error_logger.addHandler(error_handler)
|
|
||||||
|
|
||||||
# Security handler
|
|
||||||
security_handler = logging.FileHandler("./logs/auth_failures.log")
|
|
||||||
security_handler.setFormatter(logging.Formatter(
|
|
||||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
|
||||||
))
|
|
||||||
self.security_logger.addHandler(security_handler)
|
|
||||||
|
|
||||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
|
||||||
"""
|
|
||||||
Process logging for incoming requests.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: FastAPI request object
|
|
||||||
call_next: Next function in the middleware chain
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Response from next middleware with logging
|
|
||||||
"""
|
|
||||||
start_time = time.time()
|
|
||||||
request_timestamp = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
# Generate request ID for tracking
|
|
||||||
request_id = f"{int(start_time * 1000)}-{id(request)}"
|
|
||||||
request.state.request_id = request_id
|
|
||||||
|
|
||||||
# Log incoming request
|
|
||||||
await self.log_request(request, request_id, request_timestamp)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Process request
|
|
||||||
response = await call_next(request)
|
|
||||||
|
|
||||||
# Calculate processing time
|
|
||||||
process_time = time.time() - start_time
|
|
||||||
|
|
||||||
# Log successful response
|
|
||||||
await self.log_response(request, response, request_id, process_time)
|
|
||||||
|
|
||||||
# Log performance metrics if slow
|
|
||||||
if process_time > 1.0: # Log slow requests (> 1 second)
|
|
||||||
self.performance_logger.warning(
|
|
||||||
f"Slow request detected - ID: {request_id}, "
|
|
||||||
f"Method: {request.method}, Path: {request.url.path}, "
|
|
||||||
f"Duration: {process_time:.3f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add request ID to response headers for debugging
|
|
||||||
response.headers["X-Request-ID"] = request_id
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
except HTTPException as e:
|
|
||||||
process_time = time.time() - start_time
|
|
||||||
await self.log_http_exception(request, e, request_id, process_time)
|
|
||||||
raise
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
process_time = time.time() - start_time
|
|
||||||
await self.log_error(request, e, request_id, process_time)
|
|
||||||
|
|
||||||
# Return generic error response
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
content={
|
|
||||||
"success": False,
|
|
||||||
"error": "Internal Server Error",
|
|
||||||
"code": "SERVER_ERROR",
|
|
||||||
"request_id": request_id
|
|
||||||
},
|
|
||||||
headers={"X-Request-ID": request_id}
|
|
||||||
)
|
|
||||||
|
|
||||||
async def log_request(self, request: Request, request_id: str, timestamp: datetime):
|
|
||||||
"""Log incoming request details."""
|
|
||||||
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
|
|
||||||
user_agent = request.headers.get('user-agent', 'unknown')
|
|
||||||
|
|
||||||
log_data = {
|
|
||||||
"request_id": request_id,
|
|
||||||
"timestamp": timestamp.isoformat(),
|
|
||||||
"method": request.method,
|
|
||||||
"url": str(request.url),
|
|
||||||
"path": request.url.path,
|
|
||||||
"client_ip": client_ip,
|
|
||||||
"user_agent": user_agent,
|
|
||||||
"headers": dict(request.headers) if self.log_level <= logging.DEBUG else None
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log request body for debugging (be careful with sensitive data)
|
|
||||||
if self.log_body and self.log_level <= logging.DEBUG:
|
|
||||||
try:
|
|
||||||
body = await request.body()
|
|
||||||
if body and len(body) <= self.max_body_size:
|
|
||||||
try:
|
|
||||||
log_data["body"] = json.loads(body.decode())
|
|
||||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
|
||||||
log_data["body"] = body.decode('utf-8', errors='ignore')[:self.max_body_size]
|
|
||||||
except Exception:
|
|
||||||
pass # Skip body logging if it fails
|
|
||||||
|
|
||||||
self.request_logger.info(f"Request started: {json.dumps(log_data, default=str)}")
|
|
||||||
|
|
||||||
async def log_response(self, request: Request, response: Response, request_id: str, process_time: float):
|
|
||||||
"""Log successful response details."""
|
|
||||||
log_data = {
|
|
||||||
"request_id": request_id,
|
|
||||||
"status_code": response.status_code,
|
|
||||||
"process_time": f"{process_time:.3f}s",
|
|
||||||
"response_headers": dict(response.headers) if self.log_level <= logging.DEBUG else None
|
|
||||||
}
|
|
||||||
|
|
||||||
self.request_logger.info(f"Request completed: {json.dumps(log_data, default=str)}")
|
|
||||||
|
|
||||||
async def log_http_exception(self, request: Request, exc: HTTPException, request_id: str, process_time: float):
|
|
||||||
"""Log HTTP exceptions with context."""
|
|
||||||
log_data = {
|
|
||||||
"request_id": request_id,
|
|
||||||
"method": request.method,
|
|
||||||
"path": request.url.path,
|
|
||||||
"status_code": exc.status_code,
|
|
||||||
"detail": exc.detail,
|
|
||||||
"process_time": f"{process_time:.3f}s"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Log security-related HTTP errors
|
|
||||||
if exc.status_code in [401, 403]:
|
|
||||||
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
|
|
||||||
self.security_logger.warning(
|
|
||||||
f"Authentication/Authorization failure - "
|
|
||||||
f"IP: {client_ip}, Path: {request.url.path}, "
|
|
||||||
f"Status: {exc.status_code}, Request ID: {request_id}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if exc.status_code >= 500:
|
|
||||||
self.error_logger.error(f"HTTP Exception: {json.dumps(log_data, default=str)}")
|
|
||||||
else:
|
|
||||||
self.request_logger.warning(f"HTTP Exception: {json.dumps(log_data, default=str)}")
|
|
||||||
|
|
||||||
async def log_error(self, request: Request, exc: Exception, request_id: str, process_time: float):
|
|
||||||
"""Log unhandled exceptions with full context."""
|
|
||||||
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
|
|
||||||
|
|
||||||
error_data = {
|
|
||||||
"request_id": request_id,
|
|
||||||
"method": request.method,
|
|
||||||
"url": str(request.url),
|
|
||||||
"path": request.url.path,
|
|
||||||
"client_ip": client_ip,
|
|
||||||
"error_type": type(exc).__name__,
|
|
||||||
"error_message": str(exc),
|
|
||||||
"process_time": f"{process_time:.3f}s",
|
|
||||||
"traceback": traceback.format_exc()
|
|
||||||
}
|
|
||||||
|
|
||||||
self.error_logger.error(f"Unhandled exception: {json.dumps(error_data, default=str)}")
|
|
||||||
|
|
||||||
|
|
||||||
def setup_enhanced_logging():
|
|
||||||
"""
|
|
||||||
Setup enhanced logging configuration for FastAPI.
|
|
||||||
|
|
||||||
This function configures the logging system with appropriate
|
|
||||||
handlers and formatters for production use.
|
|
||||||
"""
|
|
||||||
# Create logs directory if it doesn't exist
|
|
||||||
import os
|
|
||||||
os.makedirs("./logs", exist_ok=True)
|
|
||||||
|
|
||||||
# Configure root logger
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
|
||||||
handlers=[
|
|
||||||
logging.FileHandler('./logs/aniworld.log'),
|
|
||||||
logging.StreamHandler()
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
# Reduce noise from external libraries
|
|
||||||
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)
|
|
||||||
logging.getLogger("charset_normalizer").setLevel(logging.WARNING)
|
|
||||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
return logging.getLogger("aniworld")
|
|
||||||
|
|
||||||
|
|
||||||
# Initialize enhanced logging
|
|
||||||
enhanced_logger = setup_enhanced_logging()
|
|
||||||
@ -1,270 +0,0 @@
|
|||||||
"""
|
|
||||||
FastAPI request validation middleware for consistent validation across controllers.
|
|
||||||
|
|
||||||
This module provides middleware for handling request validation logic
|
|
||||||
using FastAPI patterns and dependency injection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import html
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from typing import Any, Callable, Dict, Optional, Union
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Request, Response, status
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
|
|
||||||
|
|
||||||
class ValidationMiddleware:
|
|
||||||
"""
|
|
||||||
FastAPI Request validation middleware.
|
|
||||||
|
|
||||||
This middleware handles common request validation tasks:
|
|
||||||
- Content-Type validation
|
|
||||||
- JSON parsing and validation
|
|
||||||
- Basic input sanitization
|
|
||||||
- Request size limits
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, app, max_request_size: int = 10 * 1024 * 1024):
|
|
||||||
self.app = app
|
|
||||||
self.max_request_size = max_request_size
|
|
||||||
self.logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
|
||||||
"""
|
|
||||||
Process validation for incoming requests.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: FastAPI request object
|
|
||||||
call_next: Next function in the middleware chain
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Response from next middleware or validation error
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Store processed request data in request state
|
|
||||||
request.state.validated_data = None
|
|
||||||
request.state.query_params = dict(request.query_params)
|
|
||||||
request.state.request_headers = dict(request.headers)
|
|
||||||
|
|
||||||
# Validate request size
|
|
||||||
content_length = request.headers.get('content-length')
|
|
||||||
if content_length and int(content_length) > self.max_request_size:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
|
||||||
content={
|
|
||||||
'status': 'error',
|
|
||||||
'message': 'Request too large',
|
|
||||||
'error_code': 413
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle JSON requests
|
|
||||||
content_type = request.headers.get('content-type', '')
|
|
||||||
if 'application/json' in content_type:
|
|
||||||
try:
|
|
||||||
body = await request.body()
|
|
||||||
if body:
|
|
||||||
data = json.loads(body.decode('utf-8'))
|
|
||||||
# Basic sanitization
|
|
||||||
request.state.validated_data = self.sanitize_json_data(data)
|
|
||||||
else:
|
|
||||||
request.state.validated_data = {}
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
content={
|
|
||||||
'status': 'error',
|
|
||||||
'message': 'Invalid JSON format',
|
|
||||||
'details': str(e),
|
|
||||||
'error_code': 400
|
|
||||||
}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"JSON processing error: {str(e)}")
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
content={
|
|
||||||
'status': 'error',
|
|
||||||
'message': 'Error processing JSON data',
|
|
||||||
'error_code': 400
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Handle form data
|
|
||||||
elif 'application/x-www-form-urlencoded' in content_type or 'multipart/form-data' in content_type:
|
|
||||||
try:
|
|
||||||
form_data = await request.form()
|
|
||||||
request.state.validated_data = {}
|
|
||||||
for key, value in form_data.items():
|
|
||||||
if isinstance(value, str):
|
|
||||||
request.state.validated_data[key] = self.sanitize_string(value)
|
|
||||||
else:
|
|
||||||
request.state.validated_data[key] = value
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Form data processing error: {str(e)}")
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
content={
|
|
||||||
'status': 'error',
|
|
||||||
'message': 'Error processing form data',
|
|
||||||
'error_code': 400
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# Sanitize query parameters
|
|
||||||
sanitized_params = {}
|
|
||||||
for key, value in request.state.query_params.items():
|
|
||||||
if isinstance(value, str):
|
|
||||||
sanitized_params[key] = self.sanitize_string(value)
|
|
||||||
else:
|
|
||||||
sanitized_params[key] = value
|
|
||||||
request.state.query_params = sanitized_params
|
|
||||||
|
|
||||||
# Continue to next middleware/handler
|
|
||||||
response = await call_next(request)
|
|
||||||
return response
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Validation middleware error: {str(e)}")
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
|
||||||
content={
|
|
||||||
'status': 'error',
|
|
||||||
'message': 'Validation error',
|
|
||||||
'error_code': 500
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def sanitize_string(self, value: str, max_length: int = 1000) -> str:
|
|
||||||
"""
|
|
||||||
Sanitize string input by removing/escaping dangerous characters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value: String to sanitize
|
|
||||||
max_length: Maximum allowed length
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized string
|
|
||||||
"""
|
|
||||||
if not isinstance(value, str):
|
|
||||||
return str(value)
|
|
||||||
|
|
||||||
# Truncate if too long
|
|
||||||
if len(value) > max_length:
|
|
||||||
value = value[:max_length]
|
|
||||||
|
|
||||||
# HTML escape to prevent XSS
|
|
||||||
value = html.escape(value)
|
|
||||||
|
|
||||||
# Remove potentially dangerous patterns
|
|
||||||
value = re.sub(r'<script[^>]*>.*?</script>', '', value, flags=re.IGNORECASE | re.DOTALL)
|
|
||||||
value = re.sub(r'javascript:', '', value, flags=re.IGNORECASE)
|
|
||||||
value = re.sub(r'on\w+\s*=', '', value, flags=re.IGNORECASE)
|
|
||||||
|
|
||||||
return value.strip()
|
|
||||||
|
|
||||||
def sanitize_json_data(self, data: Union[Dict, list, str, int, float, bool, None]) -> Union[Dict, list, str, int, float, bool, None]:
|
|
||||||
"""
|
|
||||||
Recursively sanitize JSON data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: JSON data to sanitize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized JSON data
|
|
||||||
"""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
return {key: self.sanitize_json_data(value) for key, value in data.items()}
|
|
||||||
elif isinstance(data, list):
|
|
||||||
return [self.sanitize_json_data(item) for item in data]
|
|
||||||
elif isinstance(data, str):
|
|
||||||
return self.sanitize_string(data)
|
|
||||||
else:
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
# FastAPI dependency functions for validation
|
|
||||||
async def get_validated_data(request: Request) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
FastAPI dependency to get validated request data.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: FastAPI request object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Validated request data or None
|
|
||||||
"""
|
|
||||||
return getattr(request.state, 'validated_data', None)
|
|
||||||
|
|
||||||
|
|
||||||
async def get_query_params(request: Request) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
FastAPI dependency to get sanitized query parameters.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: FastAPI request object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized query parameters
|
|
||||||
"""
|
|
||||||
return getattr(request.state, 'query_params', {})
|
|
||||||
|
|
||||||
|
|
||||||
async def require_json_data(request: Request) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
FastAPI dependency that requires JSON data to be present.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
request: FastAPI request object
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Validated JSON data
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException: If no JSON data is present
|
|
||||||
"""
|
|
||||||
data = await get_validated_data(request)
|
|
||||||
if not data:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
'status': 'error',
|
|
||||||
'message': 'JSON data required',
|
|
||||||
'error_code': 400
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def validate_required_fields(required_fields: list):
|
|
||||||
"""
|
|
||||||
FastAPI dependency factory for validating required fields.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
required_fields: List of required field names
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dependency function
|
|
||||||
"""
|
|
||||||
async def field_validation_dependency(data: Dict[str, Any] = require_json_data) -> Dict[str, Any]:
|
|
||||||
missing_fields = []
|
|
||||||
for field in required_fields:
|
|
||||||
if field not in data or data[field] is None or data[field] == '':
|
|
||||||
missing_fields.append(field)
|
|
||||||
|
|
||||||
if missing_fields:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail={
|
|
||||||
'status': 'error',
|
|
||||||
'message': f'Missing required fields: {", ".join(missing_fields)}',
|
|
||||||
'missing_fields': missing_fields,
|
|
||||||
'error_code': 400
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
return data
|
|
||||||
|
|
||||||
return field_validation_dependency
|
|
||||||
15
test_fastapi_import.py
Normal file
15
test_fastapi_import.py
Normal file
@ -0,0 +1,15 @@
|
|||||||
|
import sys
|
||||||
|
import os
|
||||||
|
|
||||||
|
# Add parent directory to path
|
||||||
|
sys.path.insert(0, os.path.abspath('.'))
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.server.fastapi_app import app
|
||||||
|
print("✓ FastAPI app imported successfully")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"✗ Error importing FastAPI app: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
print("Test completed.")
|
||||||
18
test_imports.py
Normal file
18
test_imports.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
try:
|
||||||
|
from src.server.web.middleware.fastapi_auth_middleware import AuthMiddleware
|
||||||
|
print("Auth middleware imported successfully")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error importing auth middleware: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.server.web.middleware.fastapi_logging_middleware import EnhancedLoggingMiddleware
|
||||||
|
print("Logging middleware imported successfully")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error importing logging middleware: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
from src.server.web.middleware.fastapi_validation_middleware import ValidationMiddleware
|
||||||
|
print("Validation middleware imported successfully")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error importing validation middleware: {e}")
|
||||||
Loading…
x
Reference in New Issue
Block a user