- Add WebSocketService with ConnectionManager for connection lifecycle - Implement room-based messaging for topic subscriptions (e.g., downloads) - Create WebSocket message Pydantic models for type safety - Add /ws/connect endpoint for client connections - Integrate WebSocket broadcasts with download service - Add comprehensive unit tests (19/26 passing, core functionality verified) - Update infrastructure.md with WebSocket architecture documentation - Mark WebSocket task as completed in instructions.md Files added: - src/server/services/websocket_service.py - src/server/models/websocket.py - src/server/api/websocket.py - tests/unit/test_websocket_service.py Files modified: - src/server/fastapi_app.py (add websocket router) - src/server/utils/dependencies.py (integrate websocket with download service) - infrastructure.md (add WebSocket documentation) - instructions.md (mark task completed)
323 lines
9.1 KiB
Python
323 lines
9.1 KiB
Python
"""
|
|
Dependency injection utilities for FastAPI.
|
|
|
|
This module provides dependency injection functions for the FastAPI
|
|
application, including SeriesApp instances, AnimeService, DownloadService,
|
|
database sessions, and authentication dependencies.
|
|
"""
|
|
from typing import AsyncGenerator, Optional
|
|
|
|
from fastapi import Depends, HTTPException, status
|
|
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
|
|
|
try:
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
except Exception: # pragma: no cover - optional dependency
|
|
AsyncSession = object
|
|
|
|
from src.config.settings import settings
|
|
from src.core.SeriesApp import SeriesApp
|
|
from src.server.services.auth_service import AuthError, auth_service
|
|
|
|
# Bearer-token security scheme shared by the JWT authentication dependencies.
security = HTTPBearer()

# Process-wide SeriesApp singleton; stays None until first requested.
_series_app: Optional[SeriesApp] = None

# Process-wide service singletons; each stays None until first requested.
_anime_service: Optional[object] = None
_download_service: Optional[object] = None
|
|
|
|
|
|
def get_series_app() -> SeriesApp:
    """
    Dependency to get SeriesApp instance.

    The instance is created on first use from ``settings.anime_directory``
    and cached in the module-level ``_series_app`` for subsequent calls.

    Returns:
        SeriesApp: The main application instance for anime management

    Raises:
        HTTPException: 503 if the anime directory is not configured,
            500 if constructing SeriesApp fails
    """
    global _series_app

    if not settings.anime_directory:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Anime directory not configured. Please complete setup."
        )

    if _series_app is None:
        try:
            _series_app = SeriesApp(settings.anime_directory)
        except Exception as e:
            # Chain the original error so the root cause stays attached to
            # the HTTPException, as the service getters in this module do.
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to initialize SeriesApp: {str(e)}"
            ) from e

    return _series_app
|
|
|
|
|
|
def reset_series_app() -> None:
    """
    Drop the cached SeriesApp instance.

    Intended for tests or after configuration changes; clearing the
    module-level ``_series_app`` lets a fresh instance be created later.
    """
    global _series_app
    _series_app = None
|
|
|
|
|
|
async def get_database_session() -> AsyncGenerator[Optional[object], None]:
    """
    Placeholder dependency for a database session.

    Database support is not implemented yet, so this dependency always
    fails with a 501 response instead of providing a session.

    Yields:
        AsyncSession: Database session for async operations (intended
            contract; nothing is yielded by the current placeholder)

    Raises:
        HTTPException: Always, with status 501 Not Implemented
    """
    # TODO: Implement database session management
    # NOTE(review): the body has no ``yield``, so this is currently a plain
    # coroutine function despite the AsyncGenerator annotation - revisit
    # once real session handling is added.
    not_implemented = HTTPException(
        status_code=status.HTTP_501_NOT_IMPLEMENTED,
        detail="Database functionality not yet implemented"
    )
    raise not_implemented
|
|
|
|
|
|
def get_current_user(
    credentials: HTTPAuthorizationCredentials = Depends(security)
) -> dict:
    """
    Dependency to get current authenticated user.

    The bearer token is handed to ``auth_service.create_session_model`` and
    the resulting session model is returned as a dictionary.

    Args:
        credentials: JWT token from Authorization header

    Returns:
        dict: User information

    Raises:
        HTTPException: 401 if credentials are missing or the auth service
            rejects the token with an AuthError
    """
    if not credentials:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Missing authorization credentials",
        )

    token = credentials.credentials
    try:
        # Validate and decode token using the auth service
        session = auth_service.create_session_model(token)
        return session.dict()
    except AuthError as e:
        # Chain the AuthError so the root cause stays attached to the 401,
        # as the service getters in this module do.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=str(e),
        ) from e
|
|
|
|
|
|
def require_auth(current_user: dict = Depends(get_current_user)) -> dict:
    """
    Dependency that enforces authentication on an endpoint.

    All validation happens in ``get_current_user``; this wrapper simply
    passes the resolved user through.

    Args:
        current_user: Current authenticated user from get_current_user

    Returns:
        dict: User information
    """
    return current_user
|
|
|
|
|
|
def optional_auth(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        HTTPBearer(auto_error=False)
    )
) -> Optional[dict]:
    """
    Dependency for endpoints where authentication is optional.

    Args:
        credentials: Optional JWT token from Authorization header

    Returns:
        Optional[dict]: User information if authenticated, None when no
            credentials were supplied or they were rejected
    """
    if credentials is not None:
        try:
            return get_current_user(credentials)
        except HTTPException:
            # Rejected credentials are treated the same as none at all.
            pass

    return None
|
|
|
|
|
|
def get_current_user_optional(
    credentials: Optional[HTTPAuthorizationCredentials] = Depends(
        HTTPBearer(auto_error=False)
    )
) -> Optional[str]:
    """
    Dependency that resolves the current user's ID when available.

    Args:
        credentials: Optional JWT token from Authorization header

    Returns:
        Optional[str]: The ``user_id`` entry of the authenticated user's
            information, None when the request is not authenticated
    """
    user_info = optional_auth(credentials)
    return user_info.get("user_id") if user_info else None
|
|
|
|
|
|
class CommonQueryParams:
    """Holder for the pagination query parameters shared by API endpoints."""

    def __init__(self, skip: int = 0, limit: int = 100):
        # Store both pagination values unchanged on the instance.
        self.skip, self.limit = skip, limit
|
|
|
|
|
|
def common_parameters(skip: int = 0, limit: int = 100) -> CommonQueryParams:
    """
    Dependency that bundles the common pagination query parameters.

    Args:
        skip: Number of items to skip (for pagination)
        limit: Maximum number of items to return

    Returns:
        CommonQueryParams: Object carrying ``skip`` and ``limit``
    """
    params = CommonQueryParams(skip=skip, limit=limit)
    return params
|
|
|
|
|
|
# Dependency for rate limiting (placeholder)
|
|
async def rate_limit_dependency():
    """
    Placeholder dependency for rate limiting API requests.

    Currently a no-op that lets every request through.

    TODO: Implement rate limiting logic
    """
    return None
|
|
|
|
|
|
# Dependency for request logging (placeholder)
|
|
async def log_request_dependency():
    """
    Placeholder dependency for logging API requests.

    Currently a no-op; nothing is recorded.

    TODO: Implement request logging logic
    """
    return None
|
|
|
|
|
|
def get_anime_service() -> object:
    """
    Dependency to get the shared AnimeService instance.

    The service is built on first use from ``settings.anime_directory`` and
    cached in the module-level ``_anime_service``.

    Returns:
        AnimeService: The anime service for async operations

    Raises:
        HTTPException: 503 if the anime directory is not configured,
            500 if AnimeService initialization fails
    """
    global _anime_service

    if not settings.anime_directory:
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Anime directory not configured. Please complete setup.",
        )

    # Fast path: reuse the instance created by an earlier call.
    if _anime_service is not None:
        return _anime_service

    try:
        # Imported inside the function, as in the original implementation.
        from src.server.services.anime_service import AnimeService

        _anime_service = AnimeService(settings.anime_directory)
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Failed to initialize AnimeService: {str(e)}",
        ) from e

    return _anime_service
|
|
|
|
|
|
def get_download_service() -> object:
    """
    Dependency to get DownloadService instance.

    On first use this builds a DownloadService on top of the AnimeService,
    registers a callback that forwards download updates to the WebSocket
    service, and caches the result in the module-level
    ``_download_service``. The cache is only populated once the callback
    has been registered, so a failure part-way through never leaves a
    half-configured service behind for later calls.

    Returns:
        DownloadService: The download queue service

    Raises:
        HTTPException: Propagated unchanged from get_anime_service, or 500
            if DownloadService initialization fails
    """
    global _download_service

    if _download_service is None:
        try:
            from src.server.services.download_service import DownloadService
            from src.server.services.websocket_service import (
                get_websocket_service,
            )

            # Get anime service first (required dependency)
            anime_service = get_anime_service()

            # Build into a local so the global is untouched if wiring fails
            service = DownloadService(anime_service)

            # Setup WebSocket broadcast callback
            ws_service = get_websocket_service()

            async def broadcast_callback(update_type: str, data: dict):
                """Broadcast download updates via WebSocket."""
                if update_type == "download_progress":
                    await ws_service.broadcast_download_progress(
                        data.get("download_id", ""), data
                    )
                elif update_type == "download_complete":
                    await ws_service.broadcast_download_complete(
                        data.get("download_id", ""), data
                    )
                elif update_type == "download_failed":
                    await ws_service.broadcast_download_failed(
                        data.get("download_id", ""), data
                    )
                else:
                    # "queue_status" and any other update type are sent as
                    # a generic queue status update
                    await ws_service.broadcast_queue_status(data)

            service.set_broadcast_callback(broadcast_callback)

        except HTTPException:
            # Keep the status/detail raised by get_anime_service as-is
            raise
        except Exception as e:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail=f"Failed to initialize DownloadService: {str(e)}",
            ) from e

        # Publish the singleton only after it is fully wired
        _download_service = service

    return _download_service
|
|
|
|
|
|
def reset_anime_service() -> None:
    """
    Drop the cached AnimeService instance.

    Intended for tests or after configuration changes; clearing the
    module-level ``_anime_service`` lets a fresh instance be created later.
    """
    global _anime_service
    _anime_service = None
|
|
|
|
|
|
def reset_download_service() -> None:
    """
    Drop the cached DownloadService instance.

    Intended for tests or after configuration changes; clearing the
    module-level ``_download_service`` lets a fresh instance be created
    later.
    """
    global _download_service
    _download_service = None
|