Aniworld/src/server/utils/dependencies.py

196 lines
5.0 KiB
Python

"""
Dependency injection utilities for FastAPI.
This module provides dependency injection functions for the FastAPI
application, including SeriesApp instances, database sessions, and
authentication dependencies.
"""
from typing import AsyncGenerator, Optional
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
try:
from sqlalchemy.ext.asyncio import AsyncSession
except Exception: # pragma: no cover - optional dependency
AsyncSession = object
from src.config.settings import settings
from src.core.SeriesApp import SeriesApp
from src.server.services.auth_service import AuthError, auth_service
# Bearer-token (JWT) security scheme that get_current_user depends on to read
# the Authorization header. Built with HTTPBearer's default arguments, unlike
# the HTTPBearer(auto_error=False) instance used by optional_auth.
security = HTTPBearer()
# Module-level SeriesApp instance: None until get_series_app() creates it
# lazily, and set back to None by reset_series_app().
_series_app: Optional[SeriesApp] = None
def get_series_app() -> SeriesApp:
    """
    Dependency that returns the shared SeriesApp instance.

    The instance is created lazily on first use from
    ``settings.anime_directory`` and cached in the module-level
    ``_series_app`` until ``reset_series_app()`` clears it.

    Returns:
        SeriesApp: The main application instance for anime management.

    Raises:
        HTTPException: 503 if the anime directory is not configured;
            500 if constructing SeriesApp fails (the original error is
            chained as ``__cause__``).
    """
    global _series_app

    # The configuration check runs on every call, before the cache lookup,
    # so an unset directory is reported even if an instance is cached.
    if not settings.anime_directory:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Anime directory not configured. Please complete setup."
        )

    if _series_app is None:
        try:
            _series_app = SeriesApp(settings.anime_directory)
        except Exception as e:
            # Chain the root cause so server-side tracebacks keep it.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to initialize SeriesApp: {e}"
            ) from e

    return _series_app
def reset_series_app() -> None:
    """
    Discard the cached SeriesApp instance.

    The next successful ``get_series_app()`` call constructs a fresh
    instance. Intended for tests and for picking up configuration changes.
    """
    global _series_app
    # Dropping the reference is enough; get_series_app() recreates it lazily.
    _series_app = None
async def get_database_session() -> AsyncGenerator[Optional[object], None]:
    """
    Placeholder dependency for a database session.

    Database support is not implemented yet, so this dependency always
    rejects the request instead of providing a session.

    Note:
        The body contains no ``yield``, so despite the ``AsyncGenerator``
        annotation this is currently a plain coroutine function: awaiting
        it raises immediately. The annotation describes the intended future
        implementation.

    Raises:
        HTTPException: Always, with status 501 (Not Implemented).
    """
    # TODO: Implement database session management; once this yields a real
    # session, the docstring and the Note above must be updated.
    raise HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail="Database functionality not yet implemented"
    )
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> dict:
    """
    Dependency that resolves the authenticated user from a bearer token.

    The token is validated with ``auth_service.create_session_model`` and
    the resulting session model is returned as a dict.

    Args:
        credentials: Bearer credentials parsed from the Authorization header.

    Returns:
        dict: Session/user information produced by the auth service.

    Raises:
        HTTPException: 401 (with a ``WWW-Authenticate: Bearer`` challenge)
            if credentials are missing or the auth service raises
            ``AuthError``; the ``AuthError`` is chained as ``__cause__``.
    """
    # RFC 6750 requires bearer-protected resources to send a
    # WWW-Authenticate challenge with 401 responses.
    challenge = {"WWW-Authenticate": "Bearer"}

    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authorization credentials",
            headers=challenge,
        )

    token = credentials.credentials
    try:
        # Validate and decode the token using the auth service.
        session = auth_service.create_session_model(token)
    except AuthError as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
            headers=challenge,
        ) from e
    # NOTE(review): .dict() is the pydantic v1 API; if the session model is
    # pydantic v2, .model_dump() is the non-deprecated equivalent — confirm.
    return session.dict()
def require_auth(
    current_user: dict = Depends(get_current_user)
) -> dict:
    """
    Guard dependency for endpoints that must be authenticated.

    All token validation is delegated to ``get_current_user``, which raises
    on failure; this function simply hands the resolved user through.

    Args:
        current_user: User dict resolved by ``get_current_user``.

    Returns:
        dict: The same user dict, unchanged.
    """
    authenticated_user = current_user
    return authenticated_user
def optional_auth(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        HTTPBearer(auto_error=False)
    )
) -> Optional[dict]:
    """
    Dependency that authenticates the caller if possible, never failing.

    Uses a bearer scheme built with ``auto_error=False`` so a missing
    Authorization header arrives here as ``None``.

    Args:
        credentials: Bearer credentials, or None when no token was sent.

    Returns:
        Optional[dict]: User information when the token is valid; None when
        no token was sent or ``get_current_user`` rejected it.
    """
    if credentials is None:
        return None

    try:
        user = get_current_user(credentials)
    except HTTPException:
        # A rejected token is treated exactly like an anonymous request.
        user = None
    return user
class CommonQueryParams:
    """
    Pagination query parameters shared across API endpoints.

    Attributes:
        skip: Number of leading items to skip.
        limit: Maximum number of items to return.
    """

    def __init__(self, skip: int = 0, limit: int = 100):
        self.skip, self.limit = skip, limit
def common_parameters(
    skip: int = 0,
    limit: int = 100
) -> CommonQueryParams:
    """
    Dependency that bundles the shared ``skip``/``limit`` query parameters.

    Args:
        skip: Number of items to skip (for pagination).
        limit: Maximum number of items to return.

    Returns:
        CommonQueryParams: Object carrying both values.
    """
    params = CommonQueryParams(skip=skip, limit=limit)
    return params
# Placeholder dependency for rate limiting.
async def rate_limit_dependency():
    """
    Rate-limiting hook for API requests (not implemented yet).

    Currently a no-op: it returns None and never rejects a request.

    TODO: Implement rate limiting logic
    """
    return None
# Placeholder dependency for request logging.
async def log_request_dependency():
    """
    Request-logging hook for API requests (not implemented yet).

    Currently a no-op: it logs nothing and returns None.

    TODO: Implement request logging logic
    """
    return None