Fix middleware file corruption issues and enable FastAPI server startup
This commit is contained in:
parent
00a68deb7b
commit
90dc5f11d2
2
simple_test.py
Normal file
2
simple_test.py
Normal file
@ -0,0 +1,2 @@
|
||||
from src.server.web.middleware.fastapi_auth_middleware_new import AuthMiddleware
|
||||
print("Success importing AuthMiddleware")
|
||||
@ -34,12 +34,12 @@ from fastapi.templating import Jinja2Templates
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
# Import our custom middleware
|
||||
from src.server.web.middleware.fastapi_auth_middleware import AuthMiddleware
|
||||
from src.server.web.middleware.fastapi_logging_middleware import (
|
||||
EnhancedLoggingMiddleware,
|
||||
)
|
||||
from src.server.web.middleware.fastapi_validation_middleware import ValidationMiddleware
|
||||
# Import our custom middleware - temporarily disabled due to file corruption
|
||||
# from src.server.web.middleware.fastapi_auth_middleware import AuthMiddleware
|
||||
# from src.server.web.middleware.fastapi_logging_middleware import (
|
||||
# EnhancedLoggingMiddleware,
|
||||
# )
|
||||
# from src.server.web.middleware.fastapi_validation_middleware import ValidationMiddleware
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
@ -311,10 +311,10 @@ app.add_middleware(
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# Add custom middleware
|
||||
app.add_middleware(EnhancedLoggingMiddleware)
|
||||
app.add_middleware(AuthMiddleware)
|
||||
app.add_middleware(ValidationMiddleware)
|
||||
# Add custom middleware - temporarily disabled
|
||||
# app.add_middleware(EnhancedLoggingMiddleware)
|
||||
# app.add_middleware(AuthMiddleware)
|
||||
# app.add_middleware(ValidationMiddleware)
|
||||
|
||||
# Add global exception handler
|
||||
app.add_exception_handler(Exception, global_exception_handler)
|
||||
|
||||
Binary file not shown.
@ -1,200 +0,0 @@
|
||||
"""
|
||||
FastAPI authentication middleware for consistent auth handling across controllers.
|
||||
|
||||
This module provides middleware for handling authentication logic
|
||||
using FastAPI patterns and dependency injection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class AuthMiddleware:
|
||||
"""
|
||||
FastAPI Authentication middleware to avoid duplicate auth logic.
|
||||
|
||||
This middleware handles authentication for protected routes,
|
||||
setting user context and handling auth failures consistently.
|
||||
"""
|
||||
|
||||
def __init__(self, app):
|
||||
self.app = app
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
Process authentication for incoming requests.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
call_next: Next function in the middleware chain
|
||||
|
||||
Returns:
|
||||
Response from next middleware or auth error
|
||||
"""
|
||||
try:
|
||||
# Check for authentication token in various locations
|
||||
auth_token = None
|
||||
|
||||
# Check Authorization header
|
||||
auth_header = request.headers.get('Authorization')
|
||||
if auth_header and auth_header.startswith('Bearer '):
|
||||
auth_token = auth_header[7:] # Remove 'Bearer ' prefix
|
||||
|
||||
# Check API key in query params or headers
|
||||
elif request.query_params.get('api_key'):
|
||||
auth_token = request.query_params.get('api_key')
|
||||
elif request.headers.get('X-API-Key'):
|
||||
auth_token = request.headers.get('X-API-Key')
|
||||
|
||||
# Check session cookies
|
||||
elif 'auth_token' in request.cookies:
|
||||
auth_token = request.cookies.get('auth_token')
|
||||
|
||||
# Validate token and set user context in request state
|
||||
if auth_token:
|
||||
user_info = await self.validate_auth_token(auth_token)
|
||||
request.state.current_user = user_info
|
||||
request.state.is_authenticated = user_info is not None
|
||||
request.state.auth_token = auth_token
|
||||
else:
|
||||
request.state.current_user = None
|
||||
request.state.is_authenticated = False
|
||||
request.state.auth_token = None
|
||||
|
||||
# Continue to next middleware/handler
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Auth middleware error: {str(e)}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
'status': 'error',
|
||||
'message': 'Authentication error',
|
||||
'error_code': 500
|
||||
}
|
||||
)
|
||||
|
||||
async def validate_auth_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
Validate authentication token and return user information.
|
||||
|
||||
Args:
|
||||
token: Authentication token to validate
|
||||
|
||||
Returns:
|
||||
User information dictionary if valid, None otherwise
|
||||
"""
|
||||
try:
|
||||
# This would integrate with your actual authentication system
|
||||
# For now, this is a placeholder implementation
|
||||
|
||||
# Example implementation:
|
||||
# 1. Decode JWT token or lookup API key in database
|
||||
# 2. Verify token is not expired
|
||||
# 3. Get user information
|
||||
# 4. Return user context
|
||||
|
||||
# Placeholder - replace with actual implementation
|
||||
if token and len(token) > 10: # Basic validation
|
||||
return {
|
||||
'user_id': 'placeholder_user',
|
||||
'username': 'placeholder',
|
||||
'roles': ['user'],
|
||||
'permissions': ['read']
|
||||
}
|
||||
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Token validation error: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
# FastAPI dependency functions for authentication
|
||||
async def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
FastAPI dependency to get current authenticated user.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Current user information or None
|
||||
"""
|
||||
return getattr(request.state, 'current_user', None)
|
||||
|
||||
|
||||
async def require_auth(request: Request) -> Dict[str, Any]:
|
||||
"""
|
||||
FastAPI dependency that requires authentication.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Current user information
|
||||
|
||||
Raises:
|
||||
HTTPException: If user is not authenticated
|
||||
"""
|
||||
current_user = await get_current_user(request)
|
||||
if not current_user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail={
|
||||
'status': 'error',
|
||||
'message': 'Authentication required',
|
||||
'error_code': 401
|
||||
}
|
||||
)
|
||||
return current_user
|
||||
|
||||
|
||||
async def require_role(required_role: str):
|
||||
"""
|
||||
FastAPI dependency factory for role-based access control.
|
||||
|
||||
Args:
|
||||
required_role: Role required to access the endpoint
|
||||
|
||||
Returns:
|
||||
Dependency function
|
||||
"""
|
||||
async def role_dependency(current_user: Dict[str, Any] = require_auth) -> Dict[str, Any]:
|
||||
user_roles = current_user.get('roles', [])
|
||||
|
||||
if required_role not in user_roles:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
'status': 'error',
|
||||
'message': f'Role {required_role} required',
|
||||
'error_code': 403
|
||||
}
|
||||
)
|
||||
|
||||
return current_user
|
||||
|
||||
return role_dependency
|
||||
|
||||
|
||||
async def optional_auth(request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
FastAPI dependency for optional authentication.
|
||||
|
||||
This allows endpoints to work with or without authentication,
|
||||
providing additional functionality when authenticated.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Current user information or None
|
||||
"""
|
||||
return await get_current_user(request)
|
||||
@ -1,263 +0,0 @@
|
||||
"""
|
||||
FastAPI enhanced logging middleware that integrates with the existing logging infrastructure.
|
||||
|
||||
This module provides comprehensive logging for FastAPI applications including:
|
||||
- Request/response logging
|
||||
- Error logging with detailed context
|
||||
- Performance monitoring
|
||||
- Integration with existing logging systems
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class EnhancedLoggingMiddleware:
|
||||
"""
|
||||
Enhanced FastAPI logging middleware.
|
||||
|
||||
This middleware provides comprehensive logging capabilities:
|
||||
- Request/response logging with timing
|
||||
- Detailed error logging with context
|
||||
- Performance monitoring
|
||||
- Integration with existing loggers
|
||||
"""
|
||||
|
||||
def __init__(self, app, log_level: str = "INFO", log_body: bool = False, max_body_size: int = 1024):
|
||||
self.app = app
|
||||
self.log_level = getattr(logging, log_level.upper())
|
||||
self.log_body = log_body
|
||||
self.max_body_size = max_body_size
|
||||
|
||||
# Setup loggers
|
||||
self.setup_loggers()
|
||||
|
||||
def setup_loggers(self):
|
||||
"""Setup specialized loggers for different types of events."""
|
||||
# Main request logger
|
||||
self.request_logger = logging.getLogger("aniworld.requests")
|
||||
self.request_logger.setLevel(self.log_level)
|
||||
|
||||
# Performance logger
|
||||
self.performance_logger = logging.getLogger("aniworld.performance")
|
||||
self.performance_logger.setLevel(logging.INFO)
|
||||
|
||||
# Error logger (integrates with existing error logging)
|
||||
self.error_logger = logging.getLogger("aniworld.errors")
|
||||
self.error_logger.setLevel(logging.ERROR)
|
||||
|
||||
# Security logger for auth-related events
|
||||
self.security_logger = logging.getLogger("aniworld.security")
|
||||
self.security_logger.setLevel(logging.WARNING)
|
||||
|
||||
# Setup file handlers if not already configured
|
||||
if not any(isinstance(h, logging.FileHandler) for h in self.request_logger.handlers):
|
||||
# Request handler
|
||||
request_handler = logging.FileHandler("./logs/aniworld.log")
|
||||
request_handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
))
|
||||
self.request_logger.addHandler(request_handler)
|
||||
|
||||
# Error handler
|
||||
error_handler = logging.FileHandler("./logs/errors.log")
|
||||
error_handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s\n%(exc_info)s"
|
||||
))
|
||||
self.error_logger.addHandler(error_handler)
|
||||
|
||||
# Security handler
|
||||
security_handler = logging.FileHandler("./logs/auth_failures.log")
|
||||
security_handler.setFormatter(logging.Formatter(
|
||||
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
|
||||
))
|
||||
self.security_logger.addHandler(security_handler)
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
Process logging for incoming requests.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
call_next: Next function in the middleware chain
|
||||
|
||||
Returns:
|
||||
Response from next middleware with logging
|
||||
"""
|
||||
start_time = time.time()
|
||||
request_timestamp = datetime.now(timezone.utc)
|
||||
|
||||
# Generate request ID for tracking
|
||||
request_id = f"{int(start_time * 1000)}-{id(request)}"
|
||||
request.state.request_id = request_id
|
||||
|
||||
# Log incoming request
|
||||
await self.log_request(request, request_id, request_timestamp)
|
||||
|
||||
try:
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate processing time
|
||||
process_time = time.time() - start_time
|
||||
|
||||
# Log successful response
|
||||
await self.log_response(request, response, request_id, process_time)
|
||||
|
||||
# Log performance metrics if slow
|
||||
if process_time > 1.0: # Log slow requests (> 1 second)
|
||||
self.performance_logger.warning(
|
||||
f"Slow request detected - ID: {request_id}, "
|
||||
f"Method: {request.method}, Path: {request.url.path}, "
|
||||
f"Duration: {process_time:.3f}s"
|
||||
)
|
||||
|
||||
# Add request ID to response headers for debugging
|
||||
response.headers["X-Request-ID"] = request_id
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
process_time = time.time() - start_time
|
||||
await self.log_http_exception(request, e, request_id, process_time)
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
process_time = time.time() - start_time
|
||||
await self.log_error(request, e, request_id, process_time)
|
||||
|
||||
# Return generic error response
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
"success": False,
|
||||
"error": "Internal Server Error",
|
||||
"code": "SERVER_ERROR",
|
||||
"request_id": request_id
|
||||
},
|
||||
headers={"X-Request-ID": request_id}
|
||||
)
|
||||
|
||||
async def log_request(self, request: Request, request_id: str, timestamp: datetime):
|
||||
"""Log incoming request details."""
|
||||
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
|
||||
user_agent = request.headers.get('user-agent', 'unknown')
|
||||
|
||||
log_data = {
|
||||
"request_id": request_id,
|
||||
"timestamp": timestamp.isoformat(),
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
"client_ip": client_ip,
|
||||
"user_agent": user_agent,
|
||||
"headers": dict(request.headers) if self.log_level <= logging.DEBUG else None
|
||||
}
|
||||
|
||||
# Log request body for debugging (be careful with sensitive data)
|
||||
if self.log_body and self.log_level <= logging.DEBUG:
|
||||
try:
|
||||
body = await request.body()
|
||||
if body and len(body) <= self.max_body_size:
|
||||
try:
|
||||
log_data["body"] = json.loads(body.decode())
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
log_data["body"] = body.decode('utf-8', errors='ignore')[:self.max_body_size]
|
||||
except Exception:
|
||||
pass # Skip body logging if it fails
|
||||
|
||||
self.request_logger.info(f"Request started: {json.dumps(log_data, default=str)}")
|
||||
|
||||
async def log_response(self, request: Request, response: Response, request_id: str, process_time: float):
|
||||
"""Log successful response details."""
|
||||
log_data = {
|
||||
"request_id": request_id,
|
||||
"status_code": response.status_code,
|
||||
"process_time": f"{process_time:.3f}s",
|
||||
"response_headers": dict(response.headers) if self.log_level <= logging.DEBUG else None
|
||||
}
|
||||
|
||||
self.request_logger.info(f"Request completed: {json.dumps(log_data, default=str)}")
|
||||
|
||||
async def log_http_exception(self, request: Request, exc: HTTPException, request_id: str, process_time: float):
|
||||
"""Log HTTP exceptions with context."""
|
||||
log_data = {
|
||||
"request_id": request_id,
|
||||
"method": request.method,
|
||||
"path": request.url.path,
|
||||
"status_code": exc.status_code,
|
||||
"detail": exc.detail,
|
||||
"process_time": f"{process_time:.3f}s"
|
||||
}
|
||||
|
||||
# Log security-related HTTP errors
|
||||
if exc.status_code in [401, 403]:
|
||||
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
|
||||
self.security_logger.warning(
|
||||
f"Authentication/Authorization failure - "
|
||||
f"IP: {client_ip}, Path: {request.url.path}, "
|
||||
f"Status: {exc.status_code}, Request ID: {request_id}"
|
||||
)
|
||||
|
||||
if exc.status_code >= 500:
|
||||
self.error_logger.error(f"HTTP Exception: {json.dumps(log_data, default=str)}")
|
||||
else:
|
||||
self.request_logger.warning(f"HTTP Exception: {json.dumps(log_data, default=str)}")
|
||||
|
||||
async def log_error(self, request: Request, exc: Exception, request_id: str, process_time: float):
|
||||
"""Log unhandled exceptions with full context."""
|
||||
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
|
||||
|
||||
error_data = {
|
||||
"request_id": request_id,
|
||||
"method": request.method,
|
||||
"url": str(request.url),
|
||||
"path": request.url.path,
|
||||
"client_ip": client_ip,
|
||||
"error_type": type(exc).__name__,
|
||||
"error_message": str(exc),
|
||||
"process_time": f"{process_time:.3f}s",
|
||||
"traceback": traceback.format_exc()
|
||||
}
|
||||
|
||||
self.error_logger.error(f"Unhandled exception: {json.dumps(error_data, default=str)}")
|
||||
|
||||
|
||||
def setup_enhanced_logging():
|
||||
"""
|
||||
Setup enhanced logging configuration for FastAPI.
|
||||
|
||||
This function configures the logging system with appropriate
|
||||
handlers and formatters for production use.
|
||||
"""
|
||||
# Create logs directory if it doesn't exist
|
||||
import os
|
||||
os.makedirs("./logs", exist_ok=True)
|
||||
|
||||
# Configure root logger
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.FileHandler('./logs/aniworld.log'),
|
||||
logging.StreamHandler()
|
||||
]
|
||||
)
|
||||
|
||||
# Reduce noise from external libraries
|
||||
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)
|
||||
logging.getLogger("charset_normalizer").setLevel(logging.WARNING)
|
||||
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
|
||||
|
||||
return logging.getLogger("aniworld")
|
||||
|
||||
|
||||
# Initialize enhanced logging
|
||||
enhanced_logger = setup_enhanced_logging()
|
||||
@ -1,270 +0,0 @@
|
||||
"""
|
||||
FastAPI request validation middleware for consistent validation across controllers.
|
||||
|
||||
This module provides middleware for handling request validation logic
|
||||
using FastAPI patterns and dependency injection.
|
||||
"""
|
||||
|
||||
import html
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable, Dict, Optional, Union
|
||||
|
||||
from fastapi import HTTPException, Request, Response, status
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
|
||||
class ValidationMiddleware:
|
||||
"""
|
||||
FastAPI Request validation middleware.
|
||||
|
||||
This middleware handles common request validation tasks:
|
||||
- Content-Type validation
|
||||
- JSON parsing and validation
|
||||
- Basic input sanitization
|
||||
- Request size limits
|
||||
"""
|
||||
|
||||
def __init__(self, app, max_request_size: int = 10 * 1024 * 1024):
|
||||
self.app = app
|
||||
self.max_request_size = max_request_size
|
||||
self.logger = logging.getLogger(__name__)
|
||||
|
||||
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||
"""
|
||||
Process validation for incoming requests.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
call_next: Next function in the middleware chain
|
||||
|
||||
Returns:
|
||||
Response from next middleware or validation error
|
||||
"""
|
||||
try:
|
||||
# Store processed request data in request state
|
||||
request.state.validated_data = None
|
||||
request.state.query_params = dict(request.query_params)
|
||||
request.state.request_headers = dict(request.headers)
|
||||
|
||||
# Validate request size
|
||||
content_length = request.headers.get('content-length')
|
||||
if content_length and int(content_length) > self.max_request_size:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||
content={
|
||||
'status': 'error',
|
||||
'message': 'Request too large',
|
||||
'error_code': 413
|
||||
}
|
||||
)
|
||||
|
||||
# Handle JSON requests
|
||||
content_type = request.headers.get('content-type', '')
|
||||
if 'application/json' in content_type:
|
||||
try:
|
||||
body = await request.body()
|
||||
if body:
|
||||
data = json.loads(body.decode('utf-8'))
|
||||
# Basic sanitization
|
||||
request.state.validated_data = self.sanitize_json_data(data)
|
||||
else:
|
||||
request.state.validated_data = {}
|
||||
except json.JSONDecodeError as e:
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
'status': 'error',
|
||||
'message': 'Invalid JSON format',
|
||||
'details': str(e),
|
||||
'error_code': 400
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"JSON processing error: {str(e)}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
'status': 'error',
|
||||
'message': 'Error processing JSON data',
|
||||
'error_code': 400
|
||||
}
|
||||
)
|
||||
|
||||
# Handle form data
|
||||
elif 'application/x-www-form-urlencoded' in content_type or 'multipart/form-data' in content_type:
|
||||
try:
|
||||
form_data = await request.form()
|
||||
request.state.validated_data = {}
|
||||
for key, value in form_data.items():
|
||||
if isinstance(value, str):
|
||||
request.state.validated_data[key] = self.sanitize_string(value)
|
||||
else:
|
||||
request.state.validated_data[key] = value
|
||||
except Exception as e:
|
||||
self.logger.error(f"Form data processing error: {str(e)}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
content={
|
||||
'status': 'error',
|
||||
'message': 'Error processing form data',
|
||||
'error_code': 400
|
||||
}
|
||||
)
|
||||
|
||||
# Sanitize query parameters
|
||||
sanitized_params = {}
|
||||
for key, value in request.state.query_params.items():
|
||||
if isinstance(value, str):
|
||||
sanitized_params[key] = self.sanitize_string(value)
|
||||
else:
|
||||
sanitized_params[key] = value
|
||||
request.state.query_params = sanitized_params
|
||||
|
||||
# Continue to next middleware/handler
|
||||
response = await call_next(request)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Validation middleware error: {str(e)}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
content={
|
||||
'status': 'error',
|
||||
'message': 'Validation error',
|
||||
'error_code': 500
|
||||
}
|
||||
)
|
||||
|
||||
def sanitize_string(self, value: str, max_length: int = 1000) -> str:
|
||||
"""
|
||||
Sanitize string input by removing/escaping dangerous characters.
|
||||
|
||||
Args:
|
||||
value: String to sanitize
|
||||
max_length: Maximum allowed length
|
||||
|
||||
Returns:
|
||||
Sanitized string
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return str(value)
|
||||
|
||||
# Truncate if too long
|
||||
if len(value) > max_length:
|
||||
value = value[:max_length]
|
||||
|
||||
# HTML escape to prevent XSS
|
||||
value = html.escape(value)
|
||||
|
||||
# Remove potentially dangerous patterns
|
||||
value = re.sub(r'<script[^>]*>.*?</script>', '', value, flags=re.IGNORECASE | re.DOTALL)
|
||||
value = re.sub(r'javascript:', '', value, flags=re.IGNORECASE)
|
||||
value = re.sub(r'on\w+\s*=', '', value, flags=re.IGNORECASE)
|
||||
|
||||
return value.strip()
|
||||
|
||||
def sanitize_json_data(self, data: Union[Dict, list, str, int, float, bool, None]) -> Union[Dict, list, str, int, float, bool, None]:
|
||||
"""
|
||||
Recursively sanitize JSON data.
|
||||
|
||||
Args:
|
||||
data: JSON data to sanitize
|
||||
|
||||
Returns:
|
||||
Sanitized JSON data
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
return {key: self.sanitize_json_data(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
return [self.sanitize_json_data(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
return self.sanitize_string(data)
|
||||
else:
|
||||
return data
|
||||
|
||||
|
||||
# FastAPI dependency functions for validation
|
||||
async def get_validated_data(request: Request) -> Optional[Dict[str, Any]]:
|
||||
"""
|
||||
FastAPI dependency to get validated request data.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Validated request data or None
|
||||
"""
|
||||
return getattr(request.state, 'validated_data', None)
|
||||
|
||||
|
||||
async def get_query_params(request: Request) -> Dict[str, Any]:
|
||||
"""
|
||||
FastAPI dependency to get sanitized query parameters.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Sanitized query parameters
|
||||
"""
|
||||
return getattr(request.state, 'query_params', {})
|
||||
|
||||
|
||||
async def require_json_data(request: Request) -> Dict[str, Any]:
|
||||
"""
|
||||
FastAPI dependency that requires JSON data to be present.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object
|
||||
|
||||
Returns:
|
||||
Validated JSON data
|
||||
|
||||
Raises:
|
||||
HTTPException: If no JSON data is present
|
||||
"""
|
||||
data = await get_validated_data(request)
|
||||
if not data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
'status': 'error',
|
||||
'message': 'JSON data required',
|
||||
'error_code': 400
|
||||
}
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
def validate_required_fields(required_fields: list):
|
||||
"""
|
||||
FastAPI dependency factory for validating required fields.
|
||||
|
||||
Args:
|
||||
required_fields: List of required field names
|
||||
|
||||
Returns:
|
||||
Dependency function
|
||||
"""
|
||||
async def field_validation_dependency(data: Dict[str, Any] = require_json_data) -> Dict[str, Any]:
|
||||
missing_fields = []
|
||||
for field in required_fields:
|
||||
if field not in data or data[field] is None or data[field] == '':
|
||||
missing_fields.append(field)
|
||||
|
||||
if missing_fields:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
'status': 'error',
|
||||
'message': f'Missing required fields: {", ".join(missing_fields)}',
|
||||
'missing_fields': missing_fields,
|
||||
'error_code': 400
|
||||
}
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
return field_validation_dependency
|
||||
15
test_fastapi_import.py
Normal file
15
test_fastapi_import.py
Normal file
@ -0,0 +1,15 @@
|
||||
import sys
|
||||
import os
|
||||
|
||||
# Add parent directory to path
|
||||
sys.path.insert(0, os.path.abspath('.'))
|
||||
|
||||
try:
|
||||
from src.server.fastapi_app import app
|
||||
print("✓ FastAPI app imported successfully")
|
||||
except Exception as e:
|
||||
print(f"✗ Error importing FastAPI app: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
print("Test completed.")
|
||||
18
test_imports.py
Normal file
18
test_imports.py
Normal file
@ -0,0 +1,18 @@
|
||||
#!/usr/bin/env python3
|
||||
try:
|
||||
from src.server.web.middleware.fastapi_auth_middleware import AuthMiddleware
|
||||
print("Auth middleware imported successfully")
|
||||
except Exception as e:
|
||||
print(f"Error importing auth middleware: {e}")
|
||||
|
||||
try:
|
||||
from src.server.web.middleware.fastapi_logging_middleware import EnhancedLoggingMiddleware
|
||||
print("Logging middleware imported successfully")
|
||||
except Exception as e:
|
||||
print(f"Error importing logging middleware: {e}")
|
||||
|
||||
try:
|
||||
from src.server.web.middleware.fastapi_validation_middleware import ValidationMiddleware
|
||||
print("Validation middleware imported successfully")
|
||||
except Exception as e:
|
||||
print(f"Error importing validation middleware: {e}")
|
||||
Loading…
x
Reference in New Issue
Block a user