196 lines
5.0 KiB
Python
196 lines
5.0 KiB
Python
"""
|
|
Dependency injection utilities for FastAPI.
|
|
|
|
This module provides dependency injection functions for the FastAPI
|
|
application, including SeriesApp instances, database sessions, and
|
|
authentication dependencies.
|
|
"""
|
|
from typing import AsyncGenerator, Optional
|
|
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
|
|
try:
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
except Exception: # pragma: no cover - optional dependency
|
|
AsyncSession = object
|
|
|
|
from src.config.settings import settings
|
|
from src.core.SeriesApp import SeriesApp
|
|
from src.server.services.auth_service import AuthError, auth_service
|
|
|
|
# Security scheme for JWT authentication
|
|
# Module-level bearer scheme instance; get_current_user receives its result
# through Depends(security).
security = HTTPBearer()
|
|
|
|
|
|
# Global SeriesApp instance
|
|
# Starts unset; get_series_app() assigns it the first time construction
# succeeds, and reset_series_app() sets it back to None.
_series_app: Optional[SeriesApp] = None
|
|
|
|
|
|
def get_series_app() -> SeriesApp:
    """
    Dependency to get the shared SeriesApp instance.

    The instance is built from ``settings.anime_directory`` the first time
    it is needed and then kept in the module-level ``_series_app`` so later
    calls return the same object.

    Returns:
        SeriesApp: The cached application instance.

    Raises:
        HTTPException: 503 if ``settings.anime_directory`` is empty/unset,
            or 500 if constructing SeriesApp raises any exception.
    """
    global _series_app

    # This check runs on every call, even when an instance is already cached.
    if not settings.anime_directory:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Anime directory not configured. Please complete setup."
        )

    if _series_app is None:
        try:
            _series_app = SeriesApp(settings.anime_directory)
        except Exception as e:
            # Chain explicitly so the original failure is recorded as the
            # direct cause of the HTTP error in tracebacks and logs.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to initialize SeriesApp: {str(e)}"
            ) from e

    return _series_app
|
|
|
|
|
|
def reset_series_app() -> None:
    """
    Drop the cached SeriesApp instance.

    Sets the module-level ``_series_app`` back to None (for testing or
    config changes).
    """
    global _series_app
    _series_app = None
|
|
|
|
|
|
async def get_database_session() -> AsyncGenerator[Optional[object], None]:
    """
    Placeholder dependency for a database session.

    Session management is not implemented yet: the body contains no
    ``yield`` and unconditionally raises.

    Raises:
        HTTPException: Always, with status 501 (Not Implemented).
    """
    # TODO: Implement database session management
    not_implemented = HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail="Database functionality not yet implemented"
    )
    raise not_implemented
|
|
|
|
|
|
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> dict:
    """
    Dependency to get the current authenticated user.

    Args:
        credentials: Bearer credentials resolved from the ``security``
            scheme; the token is read from ``credentials.credentials``.

    Returns:
        dict: The result of calling ``.dict()`` on the session model that
            ``auth_service.create_session_model`` builds from the token.

    Raises:
        HTTPException: 401 if credentials are missing, or if the auth
            service raises AuthError for the token.
    """
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authorization credentials",
        )

    token = credentials.credentials
    try:
        # Validate and decode token using the auth service
        session = auth_service.create_session_model(token)
        return session.dict()
    except AuthError as e:
        # Chain explicitly so the AuthError is kept as the direct cause of
        # the 401 in tracebacks and logs.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
        ) from e
|
|
|
|
|
|
def require_auth(
    current_user: dict = Depends(get_current_user)
) -> dict:
    """
    Dependency that requires authentication.

    This function adds no checks of its own; it depends on
    get_current_user and passes the resolved value through unchanged.

    Args:
        current_user: Value resolved by the get_current_user dependency

    Returns:
        dict: The same ``current_user`` object that was passed in
    """
    authenticated_user = current_user
    return authenticated_user
|
|
|
|
|
|
def optional_auth(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        HTTPBearer(auto_error=False)
    )
) -> Optional[dict]:
    """
    Dependency for optional authentication.

    Delegates to get_current_user when credentials are present and turns
    any HTTPException it raises into a None result.

    Args:
        credentials: Bearer credentials, or None when none were resolved

    Returns:
        Optional[dict]: What get_current_user returns, or None when
            credentials are None or get_current_user raises HTTPException
    """
    if credentials is not None:
        try:
            return get_current_user(credentials)
        except HTTPException:
            # Rejected credentials are treated the same as absent ones.
            pass

    return None
|
|
|
|
|
|
class CommonQueryParams:
    """Container for the ``skip`` / ``limit`` query parameters of endpoints."""

    def __init__(self, skip: int = 0, limit: int = 100):
        # Both values are stored exactly as given; no validation is applied.
        self.skip, self.limit = skip, limit
|
|
|
|
|
|
def common_parameters(
    skip: int = 0,
    limit: int = 100
) -> CommonQueryParams:
    """
    Dependency that bundles the common query parameters.

    Args:
        skip: Number of items to skip (for pagination)
        limit: Maximum number of items to return

    Returns:
        CommonQueryParams: New object holding the given ``skip`` and ``limit``
    """
    params = CommonQueryParams(skip=skip, limit=limit)
    return params
|
|
|
|
|
|
# Dependency for rate limiting (placeholder)
|
|
async def rate_limit_dependency():
    """
    Placeholder dependency for rate limiting API requests.

    Currently a no-op: awaiting it does nothing and produces None.

    TODO: Implement rate limiting logic
    """
    return None
|
|
|
|
|
|
# Dependency for request logging (placeholder)
|
|
async def log_request_dependency():
    """
    Placeholder dependency for logging API requests.

    Currently a no-op: awaiting it does nothing and produces None.

    TODO: Implement request logging logic
    """
    return None
|