- Added db_session parameter to SeriesApp.__init__() - Added db_session property and set_db_session() method - Added init_from_db_async() for async database initialization - Pass db_session to SerieList and SerieScanner during construction - Added get_series_app_with_db() dependency for FastAPI endpoints - All 815 unit tests and 55 API tests pass
492 lines
14 KiB
Python
492 lines
14 KiB
Python
"""
|
|
Dependency injection utilities for FastAPI.
|
|
|
|
This module provides dependency injection functions for the FastAPI
|
|
application, including SeriesApp instances, AnimeService, DownloadService,
|
|
database sessions, and authentication dependencies.
|
|
"""
|
|
import logging
|
|
import time
|
|
from asyncio import Lock
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Optional
|
|
|
|
from fastapi import Depends, HTTPException, Request, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
|
|
try:
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
except Exception: # pragma: no cover - optional dependency
|
|
AsyncSession = object
|
|
|
|
from src.config.settings import settings
|
|
from src.core.SeriesApp import SeriesApp
|
|
from src.server.services.auth_service import AuthError, auth_service
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
if TYPE_CHECKING:
|
|
from src.server.services.anime_service import AnimeService
|
|
from src.server.services.download_service import DownloadService
|
|
|
|
# Bearer-token security scheme shared by the authentication dependencies.
# auto_error=False keeps FastAPI from rejecting a request with a missing
# header by itself, so the dependencies below can answer with 401 rather
# than FastAPI's default 403.
http_bearer_security = HTTPBearer(auto_error=False)


# Lazily-created process-wide singletons. Each one is populated on first use
# by its get_*() dependency and cleared again by the matching reset_*() helper.
_series_app: Optional[SeriesApp] = None
_anime_service: Optional["AnimeService"] = None
_download_service: Optional["DownloadService"] = None
|
|
|
|
|
|
@dataclass
class RateLimitRecord:
    """Request counter for a single client within one fixed time window.

    Instances are updated in place by ``rate_limit_dependency`` while the
    window is open and replaced by a fresh record once it has expired.
    """

    # Number of requests counted since ``window_start``.
    count: int
    # Epoch timestamp in seconds (from ``time.time()``) at which the
    # current window opened.
    window_start: float
|
|
|
|
|
|
# Per-client rate-limit state, keyed by client host (or "unknown").
# All access goes through ``_rate_limit_lock``; the window length is seconds.
_RATE_LIMIT_BUCKETS: Dict[str, RateLimitRecord] = {}
_rate_limit_lock = Lock()
_RATE_LIMIT_WINDOW_SECONDS = 60.0
|
|
|
|
|
|
def get_series_app() -> SeriesApp:
    """
    Dependency to get SeriesApp instance.

    The instance is created lazily on first call and cached in the
    module-level ``_series_app`` singleton. When ``settings.anime_directory``
    is empty, the directory is looked up in the persisted configuration
    (``config.other["anime_directory"]``) before giving up.

    Returns:
        SeriesApp: The main application instance for anime management

    Raises:
        HTTPException: 503 if the anime directory is not configured;
            500 if SeriesApp construction fails.

    Note:
        This creates a SeriesApp without database support. For database-
        backed storage, use get_series_app_with_db() instead.
    """
    global _series_app

    # Try to load anime_directory from config.json if not in settings
    if not settings.anime_directory:
        try:
            from src.server.services.config_service import get_config_service

            config = get_config_service().load_config()
            if config.other and config.other.get("anime_directory"):
                settings.anime_directory = str(config.other["anime_directory"])
        except Exception:
            # Best-effort lookup: a 503 is raised below if the directory is
            # still unset, but record why the configuration could not be read
            # instead of discarding the error silently.
            logger.warning(
                "Could not load anime_directory from configuration",
                exc_info=True,
            )

    if not settings.anime_directory:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Anime directory not configured. Please complete setup."
        )

    if _series_app is None:
        try:
            _series_app = SeriesApp(settings.anime_directory)
        except Exception as e:
            logger.exception("Failed to initialize SeriesApp")
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to initialize SeriesApp: {str(e)}"
            ) from e

    return _series_app
|
|
|
|
|
|
def reset_series_app() -> None:
    """Forget the cached SeriesApp so the next get_series_app() rebuilds it.

    Meant for tests and for picking up configuration changes.
    """
    global _series_app
    _series_app = None
|
|
|
|
|
|
async def get_database_session() -> AsyncGenerator:
    """
    Dependency to get database session.

    Only failures that occur while *obtaining* the session are translated
    into HTTP errors. Exceptions raised afterwards by the endpoint using the
    session propagate unchanged (the session context still exits normally).

    Yields:
        AsyncSession: Database session for async operations

    Raises:
        HTTPException: 501 if the database module cannot be imported;
            503 if opening the session raises ``RuntimeError``.

    Example:
        @app.get("/anime")
        async def get_anime(db: AsyncSession = Depends(get_database_session)):
            result = await db.execute(select(AnimeSeries))
            return result.scalars().all()
    """
    from contextlib import AsyncExitStack

    async with AsyncExitStack() as stack:
        # Keep the try narrow: only importing the module and entering the
        # session context can mean "database unavailable". Errors thrown into
        # this generator at the ``yield`` come from the endpoint and must not
        # be rewritten into a misleading 501/503.
        try:
            from src.server.database import get_db_session

            session = await stack.enter_async_context(get_db_session())
        except ImportError as e:
            raise HTTPException(
                status_code=status.HTTP_501_NOT_IMPLEMENTED,
                detail="Database functionality not installed"
            ) from e
        except RuntimeError as e:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=f"Database not available: {str(e)}"
            ) from e

        yield session
|
|
|
|
|
|
async def get_optional_database_session() -> AsyncGenerator:
    """
    Dependency to get optional database session.

    Unlike get_database_session(), this yields None if the database
    is not available, allowing endpoints to fall back to other storage.
    The generator yields exactly once; exceptions raised by the endpoint
    propagate through the session context instead of being swallowed.

    Yields:
        AsyncSession or None: Database session if available, None otherwise

    Example:
        @app.post("/anime/add")
        async def add_anime(
            db: Optional[AsyncSession] = Depends(get_optional_database_session)
        ):
            if db:
                # Use database
                await AnimeSeriesService.create(db, ...)
            else:
                # Fall back to file-based storage
                series_app.list.add(serie)
    """
    from contextlib import AsyncExitStack

    async with AsyncExitStack() as stack:
        session = None
        try:
            from src.server.database import get_db_session

            session = await stack.enter_async_context(get_db_session())
        except (ImportError, RuntimeError) as e:
            # Database not available - callers receive None and fall back.
            logger.debug("Database session unavailable: %s", e)

        # Single yield: previously an ImportError/RuntimeError raised by the
        # endpoint was caught here and triggered a second ``yield``, which
        # breaks the dependency protocol and hides the original error.
        yield session
|
|
|
|
|
|
async def get_series_app_with_db(
    db: Optional[AsyncSession] = Depends(get_optional_database_session),
) -> SeriesApp:
    """
    Dependency to get SeriesApp instance with database support.

    This creates or returns a SeriesApp instance and injects the
    database session for database-backed storage.

    Args:
        db: Database session from get_optional_database_session(), or None
            when the database is not available.

    Returns:
        SeriesApp: The main application instance, with the session attached
        when one was provided.

    Raises:
        HTTPException: If SeriesApp is not initialized or anime directory
            is not configured

    Note:
        NOTE(review): get_series_app() returns a process-wide singleton, so
        set_db_session() attaches this request's session to an object shared
        by all concurrent requests, and the session stays attached after its
        context in get_optional_database_session() has exited. Confirm that
        SeriesApp tolerates this, or scope the session per request.

    Example:
        @app.post("/api/anime/scan")
        async def scan_anime(
            series_app: SeriesApp = Depends(get_series_app_with_db)
        ):
            # series_app has db_session configured
            await series_app.serie_scanner.scan_async()
    """
    # Get the base SeriesApp
    app = get_series_app()

    # The optional-session dependency yields None when the database is
    # unavailable; test for that explicitly rather than relying on the
    # truthiness of a session object.
    if db is not None:
        app.set_db_session(db)

    return app
|
|
|
|
|
|
def get_current_user(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        http_bearer_security
    ),
) -> dict:
    """
    Resolve the authenticated user from the request's bearer token.

    Args:
        credentials: Bearer credentials parsed from the Authorization header,
            or None when the header is absent.

    Returns:
        dict: The session model produced by the auth service, dumped to a dict.

    Raises:
        HTTPException: 401 when the header is missing or the auth service
            rejects the token with an AuthError.
    """
    if credentials:
        raw_token = credentials.credentials
        try:
            # The auth service validates/decodes the token and builds the
            # session model; its AuthError is surfaced as a 401.
            session_model = auth_service.create_session_model(raw_token)
            return session_model.model_dump()
        except AuthError as err:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail=str(err),
            ) from err

    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Missing authorization credentials",
    )
|
|
|
|
|
|
def require_auth(
    current_user: dict = Depends(get_current_user)
) -> dict:
    """
    Guard an endpoint so that only authenticated requests reach it.

    FastAPI resolves get_current_user() first, which raises a 401 for
    unauthenticated requests; this dependency then passes its result through.

    Args:
        current_user: User data resolved by get_current_user().

    Returns:
        dict: The same user data, unchanged.
    """
    authenticated_user = current_user
    return authenticated_user
|
|
|
|
|
|
def optional_auth(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        http_bearer_security
    )
) -> Optional[dict]:
    """
    Dependency for optional authentication.

    Depends on the module-level ``http_bearer_security`` scheme, the same
    object get_current_user() uses, so FastAPI resolves and caches one
    bearer dependency per request instead of evaluating a second, separate
    HTTPBearer instance.

    Args:
        credentials: Optional JWT token from Authorization header

    Returns:
        Optional[dict]: User information if authenticated, None otherwise
    """
    if credentials is None:
        return None

    try:
        return get_current_user(credentials)
    except HTTPException:
        # Invalid or rejected tokens are treated as anonymous access.
        return None
|
|
|
|
|
|
def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        http_bearer_security
    )
) -> Optional[str]:
    """
    Dependency to get optional current user ID.

    Reuses the shared ``http_bearer_security`` scheme (instead of creating a
    new HTTPBearer instance) so the bearer credentials are resolved once per
    request alongside the other auth dependencies.

    Args:
        credentials: Optional JWT token from Authorization header

    Returns:
        Optional[str]: The ``user_id`` entry of the session data if the
        request is authenticated, None otherwise.
    """
    user_dict = optional_auth(credentials)
    return user_dict.get("user_id") if user_dict else None
|
|
|
|
|
|
class CommonQueryParams:
    """Offset/page-size pair reused by collection endpoints for pagination."""

    def __init__(self, skip: int = 0, limit: int = 100) -> None:
        """Store the pagination values.

        Args:
            skip: How many leading records to skip over.
            limit: Upper bound on the number of records returned at once.
        """
        self.skip, self.limit = skip, limit
|
|
|
|
|
|
def common_parameters(
    skip: int = 0,
    limit: int = 100
) -> CommonQueryParams:
    """
    Collect the shared pagination query parameters into one object.

    Args:
        skip: Offset into the result set.
        limit: Maximum size of the returned page.

    Returns:
        CommonQueryParams: Container holding ``skip`` and ``limit``.
    """
    pagination = CommonQueryParams(skip=skip, limit=limit)
    return pagination
|
|
|
|
|
|
# Dependency enforcing a per-client fixed-window rate limit
async def rate_limit_dependency(request: Request) -> None:
    """Apply a simple fixed-window rate limit to incoming requests.

    Each client, keyed by ``request.client.host`` (or ``"unknown"`` when no
    host is available), may make at most ``settings.api_rate_limit``
    requests (never fewer than 1) per ``_RATE_LIMIT_WINDOW_SECONDS`` window.

    Args:
        request: The incoming request; only its client host is inspected.

    Raises:
        HTTPException: 429 when the client exceeds its quota. The response
            includes a ``Retry-After`` header with the whole seconds left in
            the client's current window.
    """
    import math

    # Once this many clients are tracked, buckets whose window has expired
    # are pruned. Without pruning, every distinct client address would stay
    # in ``_RATE_LIMIT_BUCKETS`` for the lifetime of the process.
    max_tracked_clients = 10_000

    client_id = "unknown"
    if request.client and request.client.host:
        client_id = request.client.host

    max_requests = max(1, settings.api_rate_limit)
    now = time.time()

    async with _rate_limit_lock:
        if len(_RATE_LIMIT_BUCKETS) >= max_tracked_clients:
            expired = [
                key
                for key, bucket in _RATE_LIMIT_BUCKETS.items()
                if now - bucket.window_start >= _RATE_LIMIT_WINDOW_SECONDS
            ]
            for key in expired:
                del _RATE_LIMIT_BUCKETS[key]

        record = _RATE_LIMIT_BUCKETS.get(client_id)
        window_expired = (
            record is None
            or now - record.window_start >= _RATE_LIMIT_WINDOW_SECONDS
        )
        if window_expired:
            # Open a fresh window; the current request is its first hit.
            _RATE_LIMIT_BUCKETS[client_id] = RateLimitRecord(
                count=1,
                window_start=now,
            )
            return

        record.count += 1
        if record.count > max_requests:
            retry_after = max(
                1,
                math.ceil(
                    _RATE_LIMIT_WINDOW_SECONDS - (now - record.window_start)
                ),
            )
            logger.warning(
                "Rate limit exceeded", extra={"client": client_id}
            )
            raise HTTPException(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                detail="Too many requests. Please slow down.",
                headers={"Retry-After": str(retry_after)},
            )
|
|
|
|
|
|
# Dependency that records basic request metadata in the log
async def log_request_dependency(request: Request) -> None:
    """Emit one INFO record describing the incoming request.

    Method, path, client host (``"unknown"`` when absent) and query
    parameters are attached through ``extra`` for structured log handlers.
    NOTE(review): query parameters are logged verbatim; confirm no endpoint
    passes secrets (e.g. tokens) in the query string.
    """
    client_host = request.client.host if request.client else "unknown"
    request_details = {
        "method": request.method,
        "path": request.url.path,
        "client": client_host,
        "query": dict(request.query_params),
    }
    logger.info("API request", extra=request_details)
|
|
|
|
|
|
def get_anime_service() -> "AnimeService":
    """
    Dependency to get AnimeService instance.

    The service wraps the singleton returned by get_series_app() and is
    cached in the module-level ``_anime_service`` after first creation.

    Returns:
        AnimeService: The anime service for async operations

    Raises:
        HTTPException: 503 if the anime directory is not configured and the
            process is not running under tests; any HTTPException raised by
            get_series_app() is propagated unchanged; 500 if AnimeService
            initialization fails for any other reason.
    """
    global _anime_service

    if not settings.anime_directory:
        # During test runs we allow a fallback to the system temp dir so
        # fixtures that patch SeriesApp/AnimeService can still initialize
        # the service even when no anime directory is configured. In
        # production we still treat this as a configuration error.
        import os
        import sys
        import tempfile

        # Prefer explicit test mode opt-in via ANIWORLD_TESTING=1; fall back
        # to legacy heuristics for backwards compatibility with CI.
        running_tests = (
            os.getenv("ANIWORLD_TESTING") == "1"
            or "PYTEST_CURRENT_TEST" in os.environ
            or "pytest" in sys.modules
        )

        if running_tests:
            settings.anime_directory = tempfile.gettempdir()
        else:
            raise HTTPException(
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
                detail=(
                    "Anime directory not configured. "
                    "Please complete setup."
                ),
            )

    if _anime_service is None:
        try:
            from src.server.services.anime_service import AnimeService

            # Get the singleton SeriesApp instance
            series_app = get_series_app()
            _anime_service = AnimeService(series_app)
        except HTTPException:
            # Already a meaningful HTTP error (e.g. the 500 raised by
            # get_series_app() when SeriesApp construction fails); re-wrapping
            # it would produce a double-prefixed detail and could change its
            # status code. Mirrors get_download_service().
            raise
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=(
                    "Failed to initialize AnimeService: "
                    f"{str(e)}"
                ),
            ) from e

    return _anime_service
|
|
|
|
|
|
def get_download_service() -> "DownloadService":
    """
    Return the shared DownloadService, building it on first use.

    The service is constructed around get_anime_service() and cached in the
    module-level ``_download_service``.

    Returns:
        DownloadService: The download queue service

    Raises:
        HTTPException: Any HTTPException from get_anime_service() unchanged,
            or 500 if DownloadService construction fails otherwise.
    """
    global _download_service

    if _download_service is not None:
        return _download_service

    try:
        from src.server.services.download_service import DownloadService

        # No broadcast callbacks are wired here: progress updates travel
        # SeriesApp → AnimeService → ProgressService → WebSocketService.
        _download_service = DownloadService(get_anime_service())
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=(
                "Failed to initialize DownloadService: "
                f"{str(e)}"
            ),
        ) from e

    return _download_service
|
|
|
|
|
|
def reset_anime_service() -> None:
    """Drop the cached AnimeService so get_anime_service() recreates it.

    Meant for tests and for configuration changes.
    """
    global _anime_service
    _anime_service = None
|
|
|
|
|
|
def reset_download_service() -> None:
    """Drop the cached DownloadService so get_download_service() recreates it.

    Meant for tests and for configuration changes.
    """
    global _download_service
    _download_service = None
|