484 lines
17 KiB
Python
484 lines
17 KiB
Python
"""
Unit tests for the dependency injection system.

This module tests the FastAPI dependency injection utilities, including the
SeriesApp dependency, the database session dependencies (required and
optional), authentication dependencies, common query parameters, and the
rate-limit and request-logging utility dependencies.
"""
|
|
from unittest.mock import MagicMock, Mock, patch
|
|
|
|
import pytest
|
|
from fastapi import HTTPException, status
|
|
from fastapi.security import HTTPAuthorizationCredentials
|
|
|
|
from src.server.utils.dependencies import (
|
|
CommonQueryParams,
|
|
common_parameters,
|
|
get_current_user,
|
|
get_database_session,
|
|
get_optional_database_session,
|
|
get_series_app,
|
|
log_request_dependency,
|
|
optional_auth,
|
|
rate_limit_dependency,
|
|
require_auth,
|
|
reset_series_app,
|
|
)
|
|
|
|
|
|
class TestSeriesAppDependency:
    """Test cases for SeriesApp dependency injection."""

    def setup_method(self):
        """Reset the cached SeriesApp so each test starts from a clean slate."""
        reset_series_app()

    def teardown_method(self):
        """Reset the cached SeriesApp so no instance leaks out of this test.

        Tests in this class leave a Mock (or, in the no-directory test, a
        SeriesApp built on the temp directory) cached by get_series_app().
        Clearing it afterwards keeps later tests, including ones that do not
        reset the cache themselves, independent of execution order.
        """
        reset_series_app()

    @patch('os.path.isdir', return_value=True)
    @patch('src.server.utils.dependencies.settings')
    @patch('src.server.utils.dependencies.SeriesApp')
    def test_get_series_app_success(self, mock_series_app_class,
                                    mock_settings, mock_isdir):
        """Test successful SeriesApp dependency injection."""
        # Arrange
        mock_settings.anime_directory = "/path/to/anime"
        mock_series_app_instance = Mock()
        mock_series_app_class.return_value = mock_series_app_instance

        # Act
        result = get_series_app()

        # Assert
        assert result is mock_series_app_instance
        mock_series_app_class.assert_called_once_with("/path/to/anime")

    @patch('src.server.services.config_service.get_config_service')
    @patch('src.server.utils.dependencies.settings')
    def test_get_series_app_no_directory_configured(
        self, mock_settings, mock_get_config_service
    ):
        """Test SeriesApp dependency when directory is not configured.

        In test mode (pytest running), get_series_app() falls back to
        tempdir instead of raising 503. This test verifies the fallback
        produces a valid SeriesApp (using tempdir).
        """
        import tempfile

        # Arrange
        mock_settings.anime_directory = ""

        # Mock config service to return an empty config
        mock_config = Mock()
        mock_config.other = {}
        mock_config_service = Mock()
        mock_config_service.load_config.return_value = mock_config
        mock_get_config_service.return_value = mock_config_service

        # Act - in test mode, fall back to tempdir instead of 503
        result = get_series_app()

        # Assert
        assert result is not None
        # settings.anime_directory should have been set to tempdir
        assert mock_settings.anime_directory == tempfile.gettempdir()

    @patch('os.path.isdir', return_value=True)
    @patch('src.server.utils.dependencies.settings')
    @patch('src.server.utils.dependencies.SeriesApp')
    def test_get_series_app_initialization_error(self, mock_series_app_class,
                                                 mock_settings, mock_isdir):
        """Test SeriesApp dependency when initialization fails.

        The directory is reported as existing (as in the success test) so
        the only failure exercised is the SeriesApp constructor raising.
        """
        # Arrange
        mock_settings.anime_directory = "/path/to/anime"
        mock_series_app_class.side_effect = Exception("Initialization failed")

        # Act & Assert
        with pytest.raises(HTTPException) as exc_info:
            get_series_app()

        assert (exc_info.value.status_code ==
                status.HTTP_500_INTERNAL_SERVER_ERROR)
        assert "Failed to initialize SeriesApp" in str(exc_info.value.detail)

    @patch('os.path.isdir', return_value=True)
    @patch('src.server.utils.dependencies.settings')
    @patch('src.server.utils.dependencies.SeriesApp')
    def test_get_series_app_singleton_behavior(self, mock_series_app_class,
                                               mock_settings, mock_isdir):
        """Test SeriesApp dependency returns the same instance on calls."""
        # Arrange
        mock_settings.anime_directory = "/path/to/anime"
        mock_series_app_instance = Mock()
        mock_series_app_class.return_value = mock_series_app_instance

        # Act
        result1 = get_series_app()
        result2 = get_series_app()

        # Assert - identity, not mere equality, is what a singleton promises
        assert result1 is result2
        assert result1 is mock_series_app_instance
        # SeriesApp should only be instantiated once
        mock_series_app_class.assert_called_once_with("/path/to/anime")

    @patch('os.path.isdir', return_value=True)
    @patch('src.server.utils.dependencies.settings')
    @patch('src.server.utils.dependencies.SeriesApp')
    def test_reset_series_app(self, mock_series_app_class, mock_settings,
                              mock_isdir):
        """Test that resetting drops the cached SeriesApp instance.

        After a reset, the next get_series_app() call must build a new
        instance rather than return the previously cached one. Resetting an
        already-empty cache must also be safe.
        """
        # Arrange - a distinct mock per construction
        mock_settings.anime_directory = "/path/to/anime"
        first_instance = Mock()
        second_instance = Mock()
        mock_series_app_class.side_effect = [first_instance, second_instance]

        # Act & Assert
        assert get_series_app() is first_instance
        reset_series_app()
        assert get_series_app() is second_instance
        assert mock_series_app_class.call_count == 2

        # Repeated resets must complete without error
        reset_series_app()
        reset_series_app()
|
|
|
|
|
|
class TestDatabaseDependency:
    """Tests for the required database session dependency."""

    def test_get_database_session_not_implemented(self):
        """Check that get_database_session is an async generator function.

        Despite its name, this test verifies only the kind of callable that
        get_database_session is; it does not open a session.
        """
        from inspect import isasyncgenfunction, isfunction

        is_plain_function = isfunction(get_database_session)
        is_async_generator = isasyncgenfunction(get_database_session)

        assert is_plain_function
        assert is_async_generator
|
|
|
|
|
|
class TestAuthenticationDependencies:
    """Tests for the authentication-related dependencies."""

    def test_get_current_user_not_implemented(self):
        """get_current_user must reject a bogus bearer token with HTTP 401."""
        bogus_token = HTTPAuthorizationCredentials(
            scheme="Bearer", credentials="invalid-token"
        )

        with pytest.raises(HTTPException) as raised:
            get_current_user(bogus_token)

        # An invalid token is an authentication failure, i.e. 401
        assert raised.value.status_code == status.HTTP_401_UNAUTHORIZED

    def test_require_auth_with_user(self):
        """require_auth hands back the authenticated user it receives."""
        user = {"user_id": 123, "username": "testuser"}

        returned = require_auth(user)

        assert returned == user

    def test_optional_auth_without_credentials(self):
        """optional_auth returns None when no credentials are supplied."""
        returned = optional_auth(None)

        assert returned is None

    @patch('src.server.utils.dependencies.get_current_user')
    def test_optional_auth_with_valid_credentials(self, mock_get_current_user):
        """optional_auth returns whatever get_current_user resolves to."""
        token = HTTPAuthorizationCredentials(
            scheme="Bearer", credentials="valid-token"
        )
        expected_user = {"user_id": 123, "username": "testuser"}
        mock_get_current_user.return_value = expected_user

        returned = optional_auth(token)

        assert returned == expected_user
        mock_get_current_user.assert_called_once_with(token)

    @patch('src.server.utils.dependencies.get_current_user')
    def test_optional_auth_with_invalid_credentials(
        self,
        mock_get_current_user
    ):
        """optional_auth maps an HTTPException from get_current_user to None."""
        token = HTTPAuthorizationCredentials(
            scheme="Bearer", credentials="invalid-token"
        )
        rejection = HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token"
        )
        mock_get_current_user.side_effect = rejection

        returned = optional_auth(token)

        assert returned is None
        mock_get_current_user.assert_called_once_with(token)
|
|
|
|
|
|
class TestCommonQueryParams:
    """Tests for the shared skip/limit query parameters."""

    def test_common_query_params_initialization(self):
        """Explicit skip and limit values are stored on the instance."""
        query = CommonQueryParams(skip=10, limit=50)

        assert query.skip == 10
        assert query.limit == 50

    def test_common_query_params_defaults(self):
        """Without arguments, skip defaults to 0 and limit to 100."""
        query = CommonQueryParams()

        assert query.skip == 0
        assert query.limit == 100

    def test_common_parameters_dependency(self):
        """common_parameters builds a CommonQueryParams from its arguments."""
        query = common_parameters(skip=20, limit=30)

        assert isinstance(query, CommonQueryParams)
        assert query.skip == 20
        assert query.limit == 30

    def test_common_parameters_dependency_defaults(self):
        """common_parameters falls back to skip=0 and limit=100."""
        query = common_parameters()

        assert isinstance(query, CommonQueryParams)
        assert query.skip == 0
        assert query.limit == 100
|
|
|
|
|
|
class TestUtilityDependencies:
    """Test cases for utility dependencies."""

    @pytest.mark.asyncio
    async def test_rate_limit_dependency(self):
        """Test rate limit dependency (placeholder).

        Only checks that awaiting the dependency with a request that has a
        client host completes without raising.
        """
        # Arrange - minimal request stand-in
        mock_request = Mock()
        mock_request.client = Mock()
        mock_request.client.host = "127.0.0.1"

        # Act - should complete without error
        await rate_limit_dependency(mock_request)

        # Assert - reaching this point means no exception was raised

    @pytest.mark.asyncio
    async def test_log_request_dependency(self):
        """Test log request dependency (placeholder).

        Only checks that awaiting the dependency with a GET request to
        ``/test`` completes without raising.
        """
        # Arrange - request stand-in with the attributes a logger may read
        mock_request = Mock()
        mock_request.method = "GET"
        mock_request.url = Mock()
        mock_request.url.path = "/test"
        mock_request.client = Mock()
        mock_request.client.host = "127.0.0.1"
        mock_request.query_params = {}

        # Act - should complete without error
        await log_request_dependency(mock_request)

        # Assert - reaching this point means no exception was raised
|
|
|
|
|
|
class TestOptionalDatabaseSession:
    """Test cases for optional database session dependency."""

    @pytest.mark.asyncio
    async def test_optional_database_session_success(self):
        """Test successful database session creation."""
        from unittest.mock import AsyncMock

        # Arrange - __aenter__ returns the session itself, so entering it as
        # an async context manager hands back the same object.
        mock_session = AsyncMock()
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=None)
        mock_get_db = Mock(return_value=mock_session)

        with patch('src.server.database.get_db_session', mock_get_db):
            gen = get_optional_database_session()
            try:
                # Act
                session = await gen.__anext__()

                # Assert
                assert session is mock_session
                mock_get_db.assert_called_once()
            finally:
                # Close the generator even if an assertion failed. aclose()
                # never raises StopAsyncIteration, so no handler is needed.
                await gen.aclose()

    @pytest.mark.asyncio
    async def test_optional_database_session_not_available(self):
        """Test when database is not available (ImportError)."""
        # Calling the patched get_db_session raises ImportError
        with patch(
            'src.server.database.get_db_session',
            side_effect=ImportError("No module named 'database'")
        ):
            gen = get_optional_database_session()
            try:
                # Act
                session = await gen.__anext__()

                # Assert - should yield None when database not available
                assert session is None
            finally:
                await gen.aclose()

    @pytest.mark.asyncio
    async def test_optional_database_session_runtime_error(self):
        """Test when database raises RuntimeError."""
        # Calling the patched get_db_session raises RuntimeError
        with patch(
            'src.server.database.get_db_session',
            side_effect=RuntimeError("Database connection failed")
        ):
            gen = get_optional_database_session()
            try:
                # Act
                session = await gen.__anext__()

                # Assert - should yield None on RuntimeError
                assert session is None
            finally:
                await gen.aclose()

    @pytest.mark.asyncio
    async def test_optional_database_session_exception_during_use(self):
        """
        Test that exceptions during database operations are properly propagated.

        This test specifically addresses the "generator didn't stop after athrow()"
        error that occurred when exceptions were caught and re-raised within the
        yield block of an async context manager.
        """
        from unittest.mock import AsyncMock

        # Arrange - session whose context-manager exit does not suppress
        # exceptions (__aexit__ returns None, which is falsy)
        mock_session = AsyncMock()
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = AsyncMock(return_value=None)
        mock_get_db = Mock(return_value=mock_session)

        with patch('src.server.database.get_db_session', mock_get_db):
            gen = get_optional_database_session()
            try:
                session = await gen.__anext__()
                assert session is mock_session

                # Simulate an endpoint raising after the dependency yielded.
                # The exception must come back out of athrow():
                # - StopAsyncIteration would mean the generator swallowed it;
                # - a normal return would mean it yielded again, which is
                #   what makes contextlib raise "generator didn't stop after
                #   athrow()".
                # Both are failures, so only ValueError is accepted.
                test_exception = ValueError("Database operation failed")
                with pytest.raises(ValueError, match="Database operation failed"):
                    await gen.athrow(test_exception)
            finally:
                await gen.aclose()

    @pytest.mark.asyncio
    async def test_optional_database_session_cleanup_on_exception(self):
        """Test that database session is properly cleaned up when exception occurs."""
        from unittest.mock import AsyncMock

        # Track calls to the session's context-manager exit
        cleanup_called = []

        async def mock_exit(*args):
            cleanup_called.append(True)
            return None

        # Arrange - session with tracked cleanup
        mock_session = AsyncMock()
        mock_session.__aenter__ = AsyncMock(return_value=mock_session)
        mock_session.__aexit__ = mock_exit
        mock_get_db = Mock(return_value=mock_session)

        with patch('src.server.database.get_db_session', mock_get_db):
            gen = get_optional_database_session()
            try:
                session = await gen.__anext__()
                assert session is mock_session

                # Throw an exception to simulate endpoint failure. Whether it
                # propagates is pinned by the previous test; here only the
                # cleanup matters, so either outcome is accepted.
                try:
                    await gen.athrow(RuntimeError("Simulated endpoint error"))
                except (RuntimeError, StopAsyncIteration):
                    pass
            finally:
                # Finalize the generator in case it yielded again
                await gen.aclose()

        # Assert cleanup was called
        assert cleanup_called, "Session cleanup should have been called"
|
|
|
|
|
|
class TestIntegrationScenarios:
    """Integration test scenarios for dependency injection."""

    def setup_method(self):
        """Start from an empty SeriesApp cache regardless of test order."""
        reset_series_app()

    def teardown_method(self):
        """Do not leak the mocked SeriesApp instance into later tests."""
        reset_series_app()

    @patch('os.path.isdir', return_value=True)
    @patch('src.server.utils.dependencies.settings')
    @patch('src.server.utils.dependencies.SeriesApp')
    def test_series_app_lifecycle(self, mock_series_app_class, mock_settings,
                                  mock_isdir):
        """Test the complete SeriesApp dependency lifecycle.

        The first call creates an instance, the second returns the cached
        one, and a reset forces a fresh instance on the next call.
        """
        # Arrange
        mock_settings.anime_directory = "/path/to/anime"
        # A distinct mock per construction, so a rebuilt instance can be
        # told apart from the cached one
        mock_instance1 = MagicMock()
        mock_instance2 = MagicMock()
        mock_series_app_class.side_effect = [mock_instance1, mock_instance2]

        # Act
        app1 = get_series_app()
        app2 = get_series_app()  # cached: same instance expected
        reset_series_app()
        app3 = get_series_app()  # rebuilt after reset

        # Assert
        assert app1 is mock_instance1
        assert app2 is app1  # same instance due to singleton behavior
        assert app3 is mock_instance2  # different instance after reset
        # Constructed twice: once initially, once after the reset
        assert mock_series_app_class.call_count == 2
        mock_series_app_class.assert_called_with("/path/to/anime")
|