Fix middleware file corruption issues and enable FastAPI server startup

This commit is contained in:
Lukas Pupka-Lipinski 2025-10-06 10:20:19 +02:00
parent 00a68deb7b
commit 90dc5f11d2
8 changed files with 45 additions and 743 deletions

2
simple_test.py Normal file
View File

@ -0,0 +1,2 @@
from src.server.web.middleware.fastapi_auth_middleware_new import AuthMiddleware
print("Success importing AuthMiddleware")

View File

@ -34,12 +34,12 @@ from fastapi.templating import Jinja2Templates
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
# Import our custom middleware
from src.server.web.middleware.fastapi_auth_middleware import AuthMiddleware
from src.server.web.middleware.fastapi_logging_middleware import (
EnhancedLoggingMiddleware,
)
from src.server.web.middleware.fastapi_validation_middleware import ValidationMiddleware
# Import our custom middleware - temporarily disabled due to file corruption
# from src.server.web.middleware.fastapi_auth_middleware import AuthMiddleware
# from src.server.web.middleware.fastapi_logging_middleware import (
# EnhancedLoggingMiddleware,
# )
# from src.server.web.middleware.fastapi_validation_middleware import ValidationMiddleware
# Configure logging
logging.basicConfig(
@ -311,10 +311,10 @@ app.add_middleware(
allow_headers=["*"],
)
# Add custom middleware
app.add_middleware(EnhancedLoggingMiddleware)
app.add_middleware(AuthMiddleware)
app.add_middleware(ValidationMiddleware)
# Add custom middleware - temporarily disabled
# app.add_middleware(EnhancedLoggingMiddleware)
# app.add_middleware(AuthMiddleware)
# app.add_middleware(ValidationMiddleware)
# Add global exception handler
app.add_exception_handler(Exception, global_exception_handler)

View File

@ -1,200 +0,0 @@
"""
FastAPI authentication middleware for consistent auth handling across controllers.
This module provides middleware for handling authentication logic
using FastAPI patterns and dependency injection.
"""
import logging
from typing import Any, Callable, Dict, Optional
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
class AuthMiddleware:
"""
FastAPI Authentication middleware to avoid duplicate auth logic.
This middleware handles authentication for protected routes,
setting user context and handling auth failures consistently.
"""
def __init__(self, app):
self.app = app
self.logger = logging.getLogger(__name__)
async def __call__(self, request: Request, call_next: Callable) -> Response:
"""
Process authentication for incoming requests.
Args:
request: FastAPI request object
call_next: Next function in the middleware chain
Returns:
Response from next middleware or auth error
"""
try:
# Check for authentication token in various locations
auth_token = None
# Check Authorization header
auth_header = request.headers.get('Authorization')
if auth_header and auth_header.startswith('Bearer '):
auth_token = auth_header[7:] # Remove 'Bearer ' prefix
# Check API key in query params or headers
elif request.query_params.get('api_key'):
auth_token = request.query_params.get('api_key')
elif request.headers.get('X-API-Key'):
auth_token = request.headers.get('X-API-Key')
# Check session cookies
elif 'auth_token' in request.cookies:
auth_token = request.cookies.get('auth_token')
# Validate token and set user context in request state
if auth_token:
user_info = await self.validate_auth_token(auth_token)
request.state.current_user = user_info
request.state.is_authenticated = user_info is not None
request.state.auth_token = auth_token
else:
request.state.current_user = None
request.state.is_authenticated = False
request.state.auth_token = None
# Continue to next middleware/handler
response = await call_next(request)
return response
except Exception as e:
self.logger.error(f"Auth middleware error: {str(e)}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
'status': 'error',
'message': 'Authentication error',
'error_code': 500
}
)
async def validate_auth_token(self, token: str) -> Optional[Dict[str, Any]]:
"""
Validate authentication token and return user information.
Args:
token: Authentication token to validate
Returns:
User information dictionary if valid, None otherwise
"""
try:
# This would integrate with your actual authentication system
# For now, this is a placeholder implementation
# Example implementation:
# 1. Decode JWT token or lookup API key in database
# 2. Verify token is not expired
# 3. Get user information
# 4. Return user context
# Placeholder - replace with actual implementation
if token and len(token) > 10: # Basic validation
return {
'user_id': 'placeholder_user',
'username': 'placeholder',
'roles': ['user'],
'permissions': ['read']
}
return None
except Exception as e:
self.logger.error(f"Token validation error: {str(e)}")
return None
# FastAPI dependency functions for authentication
async def get_current_user(request: Request) -> Optional[Dict[str, Any]]:
"""
FastAPI dependency to get current authenticated user.
Args:
request: FastAPI request object
Returns:
Current user information or None
"""
return getattr(request.state, 'current_user', None)
async def require_auth(request: Request) -> Dict[str, Any]:
"""
FastAPI dependency that requires authentication.
Args:
request: FastAPI request object
Returns:
Current user information
Raises:
HTTPException: If user is not authenticated
"""
current_user = await get_current_user(request)
if not current_user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={
'status': 'error',
'message': 'Authentication required',
'error_code': 401
}
)
return current_user
async def require_role(required_role: str):
"""
FastAPI dependency factory for role-based access control.
Args:
required_role: Role required to access the endpoint
Returns:
Dependency function
"""
async def role_dependency(current_user: Dict[str, Any] = require_auth) -> Dict[str, Any]:
user_roles = current_user.get('roles', [])
if required_role not in user_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
'status': 'error',
'message': f'Role {required_role} required',
'error_code': 403
}
)
return current_user
return role_dependency
async def optional_auth(request: Request) -> Optional[Dict[str, Any]]:
"""
FastAPI dependency for optional authentication.
This allows endpoints to work with or without authentication,
providing additional functionality when authenticated.
Args:
request: FastAPI request object
Returns:
Current user information or None
"""
return await get_current_user(request)

View File

@ -1,263 +0,0 @@
"""
FastAPI enhanced logging middleware that integrates with the existing logging infrastructure.
This module provides comprehensive logging for FastAPI applications including:
- Request/response logging
- Error logging with detailed context
- Performance monitoring
- Integration with existing logging systems
"""
import json
import logging
import time
import traceback
from datetime import datetime, timezone
from typing import Any, Callable, Dict, Optional
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
class EnhancedLoggingMiddleware:
"""
Enhanced FastAPI logging middleware.
This middleware provides comprehensive logging capabilities:
- Request/response logging with timing
- Detailed error logging with context
- Performance monitoring
- Integration with existing loggers
"""
def __init__(self, app, log_level: str = "INFO", log_body: bool = False, max_body_size: int = 1024):
self.app = app
self.log_level = getattr(logging, log_level.upper())
self.log_body = log_body
self.max_body_size = max_body_size
# Setup loggers
self.setup_loggers()
def setup_loggers(self):
"""Setup specialized loggers for different types of events."""
# Main request logger
self.request_logger = logging.getLogger("aniworld.requests")
self.request_logger.setLevel(self.log_level)
# Performance logger
self.performance_logger = logging.getLogger("aniworld.performance")
self.performance_logger.setLevel(logging.INFO)
# Error logger (integrates with existing error logging)
self.error_logger = logging.getLogger("aniworld.errors")
self.error_logger.setLevel(logging.ERROR)
# Security logger for auth-related events
self.security_logger = logging.getLogger("aniworld.security")
self.security_logger.setLevel(logging.WARNING)
# Setup file handlers if not already configured
if not any(isinstance(h, logging.FileHandler) for h in self.request_logger.handlers):
# Request handler
request_handler = logging.FileHandler("./logs/aniworld.log")
request_handler.setFormatter(logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
))
self.request_logger.addHandler(request_handler)
# Error handler
error_handler = logging.FileHandler("./logs/errors.log")
error_handler.setFormatter(logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s\n%(exc_info)s"
))
self.error_logger.addHandler(error_handler)
# Security handler
security_handler = logging.FileHandler("./logs/auth_failures.log")
security_handler.setFormatter(logging.Formatter(
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
))
self.security_logger.addHandler(security_handler)
async def __call__(self, request: Request, call_next: Callable) -> Response:
"""
Process logging for incoming requests.
Args:
request: FastAPI request object
call_next: Next function in the middleware chain
Returns:
Response from next middleware with logging
"""
start_time = time.time()
request_timestamp = datetime.now(timezone.utc)
# Generate request ID for tracking
request_id = f"{int(start_time * 1000)}-{id(request)}"
request.state.request_id = request_id
# Log incoming request
await self.log_request(request, request_id, request_timestamp)
try:
# Process request
response = await call_next(request)
# Calculate processing time
process_time = time.time() - start_time
# Log successful response
await self.log_response(request, response, request_id, process_time)
# Log performance metrics if slow
if process_time > 1.0: # Log slow requests (> 1 second)
self.performance_logger.warning(
f"Slow request detected - ID: {request_id}, "
f"Method: {request.method}, Path: {request.url.path}, "
f"Duration: {process_time:.3f}s"
)
# Add request ID to response headers for debugging
response.headers["X-Request-ID"] = request_id
return response
except HTTPException as e:
process_time = time.time() - start_time
await self.log_http_exception(request, e, request_id, process_time)
raise
except Exception as e:
process_time = time.time() - start_time
await self.log_error(request, e, request_id, process_time)
# Return generic error response
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"success": False,
"error": "Internal Server Error",
"code": "SERVER_ERROR",
"request_id": request_id
},
headers={"X-Request-ID": request_id}
)
async def log_request(self, request: Request, request_id: str, timestamp: datetime):
"""Log incoming request details."""
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
user_agent = request.headers.get('user-agent', 'unknown')
log_data = {
"request_id": request_id,
"timestamp": timestamp.isoformat(),
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"client_ip": client_ip,
"user_agent": user_agent,
"headers": dict(request.headers) if self.log_level <= logging.DEBUG else None
}
# Log request body for debugging (be careful with sensitive data)
if self.log_body and self.log_level <= logging.DEBUG:
try:
body = await request.body()
if body and len(body) <= self.max_body_size:
try:
log_data["body"] = json.loads(body.decode())
except (json.JSONDecodeError, UnicodeDecodeError):
log_data["body"] = body.decode('utf-8', errors='ignore')[:self.max_body_size]
except Exception:
pass # Skip body logging if it fails
self.request_logger.info(f"Request started: {json.dumps(log_data, default=str)}")
async def log_response(self, request: Request, response: Response, request_id: str, process_time: float):
"""Log successful response details."""
log_data = {
"request_id": request_id,
"status_code": response.status_code,
"process_time": f"{process_time:.3f}s",
"response_headers": dict(response.headers) if self.log_level <= logging.DEBUG else None
}
self.request_logger.info(f"Request completed: {json.dumps(log_data, default=str)}")
async def log_http_exception(self, request: Request, exc: HTTPException, request_id: str, process_time: float):
"""Log HTTP exceptions with context."""
log_data = {
"request_id": request_id,
"method": request.method,
"path": request.url.path,
"status_code": exc.status_code,
"detail": exc.detail,
"process_time": f"{process_time:.3f}s"
}
# Log security-related HTTP errors
if exc.status_code in [401, 403]:
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
self.security_logger.warning(
f"Authentication/Authorization failure - "
f"IP: {client_ip}, Path: {request.url.path}, "
f"Status: {exc.status_code}, Request ID: {request_id}"
)
if exc.status_code >= 500:
self.error_logger.error(f"HTTP Exception: {json.dumps(log_data, default=str)}")
else:
self.request_logger.warning(f"HTTP Exception: {json.dumps(log_data, default=str)}")
async def log_error(self, request: Request, exc: Exception, request_id: str, process_time: float):
"""Log unhandled exceptions with full context."""
client_ip = getattr(request.client, 'host', 'unknown') if request.client else 'unknown'
error_data = {
"request_id": request_id,
"method": request.method,
"url": str(request.url),
"path": request.url.path,
"client_ip": client_ip,
"error_type": type(exc).__name__,
"error_message": str(exc),
"process_time": f"{process_time:.3f}s",
"traceback": traceback.format_exc()
}
self.error_logger.error(f"Unhandled exception: {json.dumps(error_data, default=str)}")
def setup_enhanced_logging():
"""
Setup enhanced logging configuration for FastAPI.
This function configures the logging system with appropriate
handlers and formatters for production use.
"""
# Create logs directory if it doesn't exist
import os
os.makedirs("./logs", exist_ok=True)
# Configure root logger
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler('./logs/aniworld.log'),
logging.StreamHandler()
]
)
# Reduce noise from external libraries
logging.getLogger("urllib3.connectionpool").setLevel(logging.WARNING)
logging.getLogger("charset_normalizer").setLevel(logging.WARNING)
logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
return logging.getLogger("aniworld")
# Initialize enhanced logging
enhanced_logger = setup_enhanced_logging()

View File

@ -1,270 +0,0 @@
"""
FastAPI request validation middleware for consistent validation across controllers.
This module provides middleware for handling request validation logic
using FastAPI patterns and dependency injection.
"""
import html
import json
import logging
import re
from typing import Any, Callable, Dict, Optional, Union
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse
class ValidationMiddleware:
"""
FastAPI Request validation middleware.
This middleware handles common request validation tasks:
- Content-Type validation
- JSON parsing and validation
- Basic input sanitization
- Request size limits
"""
def __init__(self, app, max_request_size: int = 10 * 1024 * 1024):
self.app = app
self.max_request_size = max_request_size
self.logger = logging.getLogger(__name__)
async def __call__(self, request: Request, call_next: Callable) -> Response:
"""
Process validation for incoming requests.
Args:
request: FastAPI request object
call_next: Next function in the middleware chain
Returns:
Response from next middleware or validation error
"""
try:
# Store processed request data in request state
request.state.validated_data = None
request.state.query_params = dict(request.query_params)
request.state.request_headers = dict(request.headers)
# Validate request size
content_length = request.headers.get('content-length')
if content_length and int(content_length) > self.max_request_size:
return JSONResponse(
status_code=status.HTTP_413_REQUEST_ENTITY_TOO_LARGE,
content={
'status': 'error',
'message': 'Request too large',
'error_code': 413
}
)
# Handle JSON requests
content_type = request.headers.get('content-type', '')
if 'application/json' in content_type:
try:
body = await request.body()
if body:
data = json.loads(body.decode('utf-8'))
# Basic sanitization
request.state.validated_data = self.sanitize_json_data(data)
else:
request.state.validated_data = {}
except json.JSONDecodeError as e:
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
'status': 'error',
'message': 'Invalid JSON format',
'details': str(e),
'error_code': 400
}
)
except Exception as e:
self.logger.error(f"JSON processing error: {str(e)}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
'status': 'error',
'message': 'Error processing JSON data',
'error_code': 400
}
)
# Handle form data
elif 'application/x-www-form-urlencoded' in content_type or 'multipart/form-data' in content_type:
try:
form_data = await request.form()
request.state.validated_data = {}
for key, value in form_data.items():
if isinstance(value, str):
request.state.validated_data[key] = self.sanitize_string(value)
else:
request.state.validated_data[key] = value
except Exception as e:
self.logger.error(f"Form data processing error: {str(e)}")
return JSONResponse(
status_code=status.HTTP_400_BAD_REQUEST,
content={
'status': 'error',
'message': 'Error processing form data',
'error_code': 400
}
)
# Sanitize query parameters
sanitized_params = {}
for key, value in request.state.query_params.items():
if isinstance(value, str):
sanitized_params[key] = self.sanitize_string(value)
else:
sanitized_params[key] = value
request.state.query_params = sanitized_params
# Continue to next middleware/handler
response = await call_next(request)
return response
except Exception as e:
self.logger.error(f"Validation middleware error: {str(e)}")
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
'status': 'error',
'message': 'Validation error',
'error_code': 500
}
)
def sanitize_string(self, value: str, max_length: int = 1000) -> str:
"""
Sanitize string input by removing/escaping dangerous characters.
Args:
value: String to sanitize
max_length: Maximum allowed length
Returns:
Sanitized string
"""
if not isinstance(value, str):
return str(value)
# Truncate if too long
if len(value) > max_length:
value = value[:max_length]
# HTML escape to prevent XSS
value = html.escape(value)
# Remove potentially dangerous patterns
value = re.sub(r'<script[^>]*>.*?</script>', '', value, flags=re.IGNORECASE | re.DOTALL)
value = re.sub(r'javascript:', '', value, flags=re.IGNORECASE)
value = re.sub(r'on\w+\s*=', '', value, flags=re.IGNORECASE)
return value.strip()
def sanitize_json_data(self, data: Union[Dict, list, str, int, float, bool, None]) -> Union[Dict, list, str, int, float, bool, None]:
"""
Recursively sanitize JSON data.
Args:
data: JSON data to sanitize
Returns:
Sanitized JSON data
"""
if isinstance(data, dict):
return {key: self.sanitize_json_data(value) for key, value in data.items()}
elif isinstance(data, list):
return [self.sanitize_json_data(item) for item in data]
elif isinstance(data, str):
return self.sanitize_string(data)
else:
return data
# FastAPI dependency functions for validation
async def get_validated_data(request: Request) -> Optional[Dict[str, Any]]:
"""
FastAPI dependency to get validated request data.
Args:
request: FastAPI request object
Returns:
Validated request data or None
"""
return getattr(request.state, 'validated_data', None)
async def get_query_params(request: Request) -> Dict[str, Any]:
"""
FastAPI dependency to get sanitized query parameters.
Args:
request: FastAPI request object
Returns:
Sanitized query parameters
"""
return getattr(request.state, 'query_params', {})
async def require_json_data(request: Request) -> Dict[str, Any]:
"""
FastAPI dependency that requires JSON data to be present.
Args:
request: FastAPI request object
Returns:
Validated JSON data
Raises:
HTTPException: If no JSON data is present
"""
data = await get_validated_data(request)
if not data:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
'status': 'error',
'message': 'JSON data required',
'error_code': 400
}
)
return data
def validate_required_fields(required_fields: list):
"""
FastAPI dependency factory for validating required fields.
Args:
required_fields: List of required field names
Returns:
Dependency function
"""
async def field_validation_dependency(data: Dict[str, Any] = require_json_data) -> Dict[str, Any]:
missing_fields = []
for field in required_fields:
if field not in data or data[field] is None or data[field] == '':
missing_fields.append(field)
if missing_fields:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
'status': 'error',
'message': f'Missing required fields: {", ".join(missing_fields)}',
'missing_fields': missing_fields,
'error_code': 400
}
)
return data
return field_validation_dependency

15
test_fastapi_import.py Normal file
View File

@ -0,0 +1,15 @@
import sys
import os
# Add parent directory to path
sys.path.insert(0, os.path.abspath('.'))
try:
from src.server.fastapi_app import app
print("✓ FastAPI app imported successfully")
except Exception as e:
print(f"✗ Error importing FastAPI app: {e}")
import traceback
traceback.print_exc()
print("Test completed.")

18
test_imports.py Normal file
View File

@ -0,0 +1,18 @@
#!/usr/bin/env python3
try:
from src.server.web.middleware.fastapi_auth_middleware import AuthMiddleware
print("Auth middleware imported successfully")
except Exception as e:
print(f"Error importing auth middleware: {e}")
try:
from src.server.web.middleware.fastapi_logging_middleware import EnhancedLoggingMiddleware
print("Logging middleware imported successfully")
except Exception as e:
print(f"Error importing logging middleware: {e}")
try:
from src.server.web.middleware.fastapi_validation_middleware import ValidationMiddleware
print("Validation middleware imported successfully")
except Exception as e:
print(f"Error importing validation middleware: {e}")