Add unit tests for dependency exception handling
- Created test_dependency_exception_handling.py with 5 comprehensive tests - Tests verify proper handling of HTTPException in async generator dependencies - All tests pass, confirming fix for 'generator didn't stop after athrow()' error - Updated instructions with complete task documentation
This commit is contained in:
152
tests/unit/test_dependency_exception_handling.py
Normal file
152
tests/unit/test_dependency_exception_handling.py
Normal file
@@ -0,0 +1,152 @@
|
||||
"""
|
||||
Unit tests for dependency exception handling in FastAPI dependencies.
|
||||
|
||||
This module tests that async generator dependencies properly handle exceptions
|
||||
thrown back into them, preventing the "generator didn't stop after athrow()" error.
|
||||
"""
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException, Depends
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_optional_database_session_handles_http_exception():
|
||||
"""Test that get_optional_database_session properly handles HTTPException.
|
||||
|
||||
This test verifies the fix for the "generator didn't stop after athrow()" error
|
||||
that occurred when an HTTPException was raised after yielding a database session.
|
||||
"""
|
||||
from src.server.utils.dependencies import get_optional_database_session
|
||||
|
||||
# Create a test app
|
||||
app = FastAPI()
|
||||
|
||||
@app.post("/test")
|
||||
async def test_endpoint(
|
||||
db: Optional[object] = Depends(get_optional_database_session)
|
||||
):
|
||||
"""Test endpoint that raises HTTPException after dependency yields."""
|
||||
# Simulate validation error that raises HTTPException
|
||||
raise HTTPException(status_code=400, detail="Validation error")
|
||||
|
||||
# Test the endpoint
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.post("/test")
|
||||
|
||||
# Should return 400, not 500 (internal server error)
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "Validation error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_database_session_handles_http_exception():
|
||||
"""Test that get_database_session properly handles HTTPException.
|
||||
|
||||
This test verifies the fix for the "generator didn't stop after athrow()" error
|
||||
that occurred when an HTTPException was raised after yielding a database session.
|
||||
"""
|
||||
from src.server.utils.dependencies import get_database_session
|
||||
|
||||
# Create a test app
|
||||
app = FastAPI()
|
||||
|
||||
@app.post("/test")
|
||||
async def test_endpoint(db: object = Depends(get_database_session)):
|
||||
"""Test endpoint that raises HTTPException after dependency yields."""
|
||||
# Simulate validation error that raises HTTPException
|
||||
raise HTTPException(status_code=400, detail="Validation error")
|
||||
|
||||
# Test the endpoint - may get 501/503 if DB not available, or 400 if it is
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.post("/test")
|
||||
|
||||
# Should return proper HTTP error, not 500 with generator error
|
||||
assert response.status_code in (400, 501, 503)
|
||||
assert "generator didn't stop" not in str(response.json())
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_exceptions_in_optional_session():
|
||||
"""Test that multiple different exceptions are handled correctly."""
|
||||
from src.server.utils.dependencies import get_optional_database_session
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.post("/test-400")
|
||||
async def test_400(db: Optional[object] = Depends(get_optional_database_session)):
|
||||
raise HTTPException(status_code=400, detail="Bad request")
|
||||
|
||||
@app.post("/test-404")
|
||||
async def test_404(db: Optional[object] = Depends(get_optional_database_session)):
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
|
||||
@app.post("/test-422")
|
||||
async def test_422(db: Optional[object] = Depends(get_optional_database_session)):
|
||||
raise HTTPException(status_code=422, detail="Validation error")
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
# Test 400
|
||||
response = await client.post("/test-400")
|
||||
assert response.status_code == 400
|
||||
|
||||
# Test 404
|
||||
response = await client.post("/test-404")
|
||||
assert response.status_code == 404
|
||||
|
||||
# Test 422
|
||||
response = await client.post("/test-422")
|
||||
assert response.status_code == 422
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_successful_request_with_optional_session():
|
||||
"""Test that successful requests still work properly."""
|
||||
from src.server.utils.dependencies import get_optional_database_session
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.post("/test")
|
||||
async def test_endpoint(db: Optional[object] = Depends(get_optional_database_session)):
|
||||
"""Test endpoint that succeeds."""
|
||||
return {"status": "success", "db_available": db is not None}
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.post("/test")
|
||||
|
||||
# Should return 200 for successful request
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "success"
|
||||
assert isinstance(data["db_available"], bool)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exception_after_using_session():
|
||||
"""Test exception raised after actually using the database session."""
|
||||
from src.server.utils.dependencies import get_optional_database_session
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.post("/test")
|
||||
async def test_endpoint(db: Optional[object] = Depends(get_optional_database_session)):
|
||||
"""Test endpoint that uses db then raises exception."""
|
||||
# Simulate using the database
|
||||
if db is not None:
|
||||
# In real code, would do: await db.execute(...)
|
||||
pass
|
||||
|
||||
# Then raise an exception
|
||||
raise HTTPException(status_code=400, detail="After using db")
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
response = await client.post("/test")
|
||||
|
||||
# Should properly handle the exception
|
||||
assert response.status_code == 400
|
||||
assert response.json()["detail"] == "After using db"
|
||||
Reference in New Issue
Block a user