Migrate request/response interceptors to FastAPI middleware - Created FastAPI-compatible auth and validation middleware
This commit is contained in:
parent
e0c80c178d
commit
721326ecaf
@ -34,6 +34,10 @@ from fastapi.templating import Jinja2Templates
|
|||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
# Import our custom middleware
|
||||||
|
from web.middleware.fastapi_auth_middleware import AuthMiddleware
|
||||||
|
from web.middleware.fastapi_validation_middleware import ValidationMiddleware
|
||||||
|
|
||||||
# Configure logging
|
# Configure logging
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
@ -246,6 +250,10 @@ app.add_middleware(
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Add custom middleware
|
||||||
|
app.add_middleware(AuthMiddleware)
|
||||||
|
app.add_middleware(ValidationMiddleware)
|
||||||
|
|
||||||
# Request logging middleware
|
# Request logging middleware
|
||||||
@app.middleware("http")
|
@app.middleware("http")
|
||||||
async def log_requests(request: Request, call_next):
|
async def log_requests(request: Request, call_next):
|
||||||
|
|||||||
199
src/server/web/middleware/fastapi_auth_middleware.py
Normal file
199
src/server/web/middleware/fastapi_auth_middleware.py
Normal file
@ -0,0 +1,199 @@
|
|||||||
|
"""
|
||||||
|
FastAPI authentication middleware for consistent auth handling across controllers.
|
||||||
|
|
||||||
|
This module provides middleware for handling authentication logic
|
||||||
|
using FastAPI patterns and dependency injection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Callable, Optional, Dict, Any
|
||||||
|
from fastapi import Request, Response, HTTPException, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMiddleware:
|
||||||
|
"""
|
||||||
|
FastAPI Authentication middleware to avoid duplicate auth logic.
|
||||||
|
|
||||||
|
This middleware handles authentication for protected routes,
|
||||||
|
setting user context and handling auth failures consistently.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
"""
|
||||||
|
Process authentication for incoming requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
call_next: Next function in the middleware chain
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response from next middleware or auth error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Check for authentication token in various locations
|
||||||
|
auth_token = None
|
||||||
|
|
||||||
|
# Check Authorization header
|
||||||
|
auth_header = request.headers.get('Authorization')
|
||||||
|
if auth_header and auth_header.startswith('Bearer '):
|
||||||
|
auth_token = auth_header[7:] # Remove 'Bearer ' prefix
|
||||||
|
|
||||||
|
# Check API key in query params or headers
|
||||||
|
elif request.query_params.get('api_key'):
|
||||||
|
auth_token = request.query_params.get('api_key')
|
||||||
|
elif request.headers.get('X-API-Key'):
|
||||||
|
auth_token = request.headers.get('X-API-Key')
|
||||||
|
|
||||||
|
# Check session cookies
|
||||||
|
elif 'auth_token' in request.cookies:
|
||||||
|
auth_token = request.cookies.get('auth_token')
|
||||||
|
|
||||||
|
# Validate token and set user context in request state
|
||||||
|
if auth_token:
|
||||||
|
user_info = await self.validate_auth_token(auth_token)
|
||||||
|
request.state.current_user = user_info
|
||||||
|
request.state.is_authenticated = user_info is not None
|
||||||
|
request.state.auth_token = auth_token
|
||||||
|
else:
|
||||||
|
request.state.current_user = None
|
||||||
|
request.state.is_authenticated = False
|
||||||
|
request.state.auth_token = None
|
||||||
|
|
||||||
|
# Continue to next middleware/handler
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Auth middleware error: {str(e)}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
content={
|
||||||
|
'status': 'error',
|
||||||
|
'message': 'Authentication error',
|
||||||
|
'error_code': 500
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
async def validate_auth_token(self, token: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
Validate authentication token and return user information.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
token: Authentication token to validate
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
User information dictionary if valid, None otherwise
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# This would integrate with your actual authentication system
|
||||||
|
# For now, this is a placeholder implementation
|
||||||
|
|
||||||
|
# Example implementation:
|
||||||
|
# 1. Decode JWT token or lookup API key in database
|
||||||
|
# 2. Verify token is not expired
|
||||||
|
# 3. Get user information
|
||||||
|
# 4. Return user context
|
||||||
|
|
||||||
|
# Placeholder - replace with actual implementation
|
||||||
|
if token and len(token) > 10: # Basic validation
|
||||||
|
return {
|
||||||
|
'user_id': 'placeholder_user',
|
||||||
|
'username': 'placeholder',
|
||||||
|
'roles': ['user'],
|
||||||
|
'permissions': ['read']
|
||||||
|
}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Token validation error: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# FastAPI dependency functions for authentication
|
||||||
|
async def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
FastAPI dependency to get current authenticated user.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Current user information or None
|
||||||
|
"""
|
||||||
|
return getattr(request.state, 'current_user', None)
|
||||||
|
|
||||||
|
|
||||||
|
async def require_auth(request: Request) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
FastAPI dependency that requires authentication.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Current user information
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If user is not authenticated
|
||||||
|
"""
|
||||||
|
current_user = await get_current_user(request)
|
||||||
|
if not current_user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail={
|
||||||
|
'status': 'error',
|
||||||
|
'message': 'Authentication required',
|
||||||
|
'error_code': 401
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
async def require_role(required_role: str):
|
||||||
|
"""
|
||||||
|
FastAPI dependency factory for role-based access control.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
required_role: Role required to access the endpoint
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dependency function
|
||||||
|
"""
|
||||||
|
async def role_dependency(current_user: Dict[str, Any] = require_auth) -> Dict[str, Any]:
|
||||||
|
user_roles = current_user.get('roles', [])
|
||||||
|
|
||||||
|
if required_role not in user_roles:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail={
|
||||||
|
'status': 'error',
|
||||||
|
'message': f'Role {required_role} required',
|
||||||
|
'error_code': 403
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
return role_dependency
|
||||||
|
|
||||||
|
|
||||||
|
async def optional_auth(request: Request) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
FastAPI dependency for optional authentication.
|
||||||
|
|
||||||
|
This allows endpoints to work with or without authentication,
|
||||||
|
providing additional functionality when authenticated.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Current user information or None
|
||||||
|
"""
|
||||||
|
return await get_current_user(request)
|
||||||
269
src/server/web/middleware/fastapi_validation_middleware.py
Normal file
269
src/server/web/middleware/fastapi_validation_middleware.py
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
"""
|
||||||
|
FastAPI request validation middleware for consistent validation across controllers.
|
||||||
|
|
||||||
|
This module provides middleware for handling request validation logic
|
||||||
|
using FastAPI patterns and dependency injection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Callable, Dict, Any, Optional, Union
|
||||||
|
from fastapi import Request, Response, HTTPException, status
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
import html
|
||||||
|
import re
|
||||||
|
|
||||||
|
|
||||||
|
class ValidationMiddleware:
|
||||||
|
"""
|
||||||
|
FastAPI Request validation middleware.
|
||||||
|
|
||||||
|
This middleware handles common request validation tasks:
|
||||||
|
- Content-Type validation
|
||||||
|
- JSON parsing and validation
|
||||||
|
- Basic input sanitization
|
||||||
|
- Request size limits
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app, max_request_size: int = 10 * 1024 * 1024):
|
||||||
|
self.app = app
|
||||||
|
self.max_request_size = max_request_size
|
||||||
|
self.logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
async def __call__(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
"""
|
||||||
|
Process validation for incoming requests.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
call_next: Next function in the middleware chain
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Response from next middleware or validation error
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Store processed request data in request state
|
||||||
|
request.state.validated_data = None
|
||||||
|
request.state.query_params = dict(request.query_params)
|
||||||
|
request.state.request_headers = dict(request.headers)
|
||||||
|
|
||||||
|
# Validate request size
|
||||||
|
content_length = request.headers.get('content-length')
|
||||||
|
if content_length and int(content_length) > self.max_request_size:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
|
||||||
|
content={
|
||||||
|
'status': 'error',
|
||||||
|
'message': 'Request too large',
|
||||||
|
'error_code': 413
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle JSON requests
|
||||||
|
content_type = request.headers.get('content-type', '')
|
||||||
|
if 'application/json' in content_type:
|
||||||
|
try:
|
||||||
|
body = await request.body()
|
||||||
|
if body:
|
||||||
|
data = json.loads(body.decode('utf-8'))
|
||||||
|
# Basic sanitization
|
||||||
|
request.state.validated_data = self.sanitize_json_data(data)
|
||||||
|
else:
|
||||||
|
request.state.validated_data = {}
|
||||||
|
except json.JSONDecodeError as e:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
content={
|
||||||
|
'status': 'error',
|
||||||
|
'message': 'Invalid JSON format',
|
||||||
|
'details': str(e),
|
||||||
|
'error_code': 400
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"JSON processing error: {str(e)}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
content={
|
||||||
|
'status': 'error',
|
||||||
|
'message': 'Error processing JSON data',
|
||||||
|
'error_code': 400
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle form data
|
||||||
|
elif 'application/x-www-form-urlencoded' in content_type or 'multipart/form-data' in content_type:
|
||||||
|
try:
|
||||||
|
form_data = await request.form()
|
||||||
|
request.state.validated_data = {}
|
||||||
|
for key, value in form_data.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
request.state.validated_data[key] = self.sanitize_string(value)
|
||||||
|
else:
|
||||||
|
request.state.validated_data[key] = value
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Form data processing error: {str(e)}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
content={
|
||||||
|
'status': 'error',
|
||||||
|
'message': 'Error processing form data',
|
||||||
|
'error_code': 400
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Sanitize query parameters
|
||||||
|
sanitized_params = {}
|
||||||
|
for key, value in request.state.query_params.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
sanitized_params[key] = self.sanitize_string(value)
|
||||||
|
else:
|
||||||
|
sanitized_params[key] = value
|
||||||
|
request.state.query_params = sanitized_params
|
||||||
|
|
||||||
|
# Continue to next middleware/handler
|
||||||
|
response = await call_next(request)
|
||||||
|
return response
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"Validation middleware error: {str(e)}")
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
content={
|
||||||
|
'status': 'error',
|
||||||
|
'message': 'Validation error',
|
||||||
|
'error_code': 500
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def sanitize_string(self, value: str, max_length: int = 1000) -> str:
|
||||||
|
"""
|
||||||
|
Sanitize string input by removing/escaping dangerous characters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: String to sanitize
|
||||||
|
max_length: Maximum allowed length
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized string
|
||||||
|
"""
|
||||||
|
if not isinstance(value, str):
|
||||||
|
return str(value)
|
||||||
|
|
||||||
|
# Truncate if too long
|
||||||
|
if len(value) > max_length:
|
||||||
|
value = value[:max_length]
|
||||||
|
|
||||||
|
# HTML escape to prevent XSS
|
||||||
|
value = html.escape(value)
|
||||||
|
|
||||||
|
# Remove potentially dangerous patterns
|
||||||
|
value = re.sub(r'<script[^>]*>.*?</script>', '', value, flags=re.IGNORECASE | re.DOTALL)
|
||||||
|
value = re.sub(r'javascript:', '', value, flags=re.IGNORECASE)
|
||||||
|
value = re.sub(r'on\w+\s*=', '', value, flags=re.IGNORECASE)
|
||||||
|
|
||||||
|
return value.strip()
|
||||||
|
|
||||||
|
def sanitize_json_data(self, data: Union[Dict, list, str, int, float, bool, None]) -> Union[Dict, list, str, int, float, bool, None]:
|
||||||
|
"""
|
||||||
|
Recursively sanitize JSON data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: JSON data to sanitize
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized JSON data
|
||||||
|
"""
|
||||||
|
if isinstance(data, dict):
|
||||||
|
return {key: self.sanitize_json_data(value) for key, value in data.items()}
|
||||||
|
elif isinstance(data, list):
|
||||||
|
return [self.sanitize_json_data(item) for item in data]
|
||||||
|
elif isinstance(data, str):
|
||||||
|
return self.sanitize_string(data)
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
# FastAPI dependency functions for validation
|
||||||
|
async def get_validated_data(request: Request) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
FastAPI dependency to get validated request data.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated request data or None
|
||||||
|
"""
|
||||||
|
return getattr(request.state, 'validated_data', None)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_query_params(request: Request) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
FastAPI dependency to get sanitized query parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Sanitized query parameters
|
||||||
|
"""
|
||||||
|
return getattr(request.state, 'query_params', {})
|
||||||
|
|
||||||
|
|
||||||
|
async def require_json_data(request: Request) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
FastAPI dependency that requires JSON data to be present.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: FastAPI request object
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Validated JSON data
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException: If no JSON data is present
|
||||||
|
"""
|
||||||
|
data = await get_validated_data(request)
|
||||||
|
if not data:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={
|
||||||
|
'status': 'error',
|
||||||
|
'message': 'JSON data required',
|
||||||
|
'error_code': 400
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def validate_required_fields(required_fields: list):
|
||||||
|
"""
|
||||||
|
FastAPI dependency factory for validating required fields.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
required_fields: List of required field names
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dependency function
|
||||||
|
"""
|
||||||
|
async def field_validation_dependency(data: Dict[str, Any] = require_json_data) -> Dict[str, Any]:
|
||||||
|
missing_fields = []
|
||||||
|
for field in required_fields:
|
||||||
|
if field not in data or data[field] is None or data[field] == '':
|
||||||
|
missing_fields.append(field)
|
||||||
|
|
||||||
|
if missing_fields:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail={
|
||||||
|
'status': 'error',
|
||||||
|
'message': f'Missing required fields: {", ".join(missing_fields)}',
|
||||||
|
'missing_fields': missing_fields,
|
||||||
|
'error_code': 400
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
return field_validation_dependency
|
||||||
@ -105,7 +105,7 @@ This document contains tasks for migrating the web application from Flask to Fas
|
|||||||
|
|
||||||
- [x] Convert Flask middleware to FastAPI middleware
|
- [x] Convert Flask middleware to FastAPI middleware
|
||||||
- [x] Update error handling from Flask error handlers to FastAPI exception handlers
|
- [x] Update error handling from Flask error handlers to FastAPI exception handlers
|
||||||
- [ ] Migrate request/response interceptors
|
- [x] Migrate request/response interceptors
|
||||||
- [ ] Update logging middleware if used
|
- [ ] Update logging middleware if used
|
||||||
|
|
||||||
## 🧪 Testing and Validation
|
## 🧪 Testing and Validation
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user