""" Unit tests for dependency exception handling in FastAPI dependencies. This module tests that async generator dependencies properly handle exceptions thrown back into them, preventing the "generator didn't stop after athrow()" error. """ import pytest from fastapi import FastAPI, HTTPException, Depends from httpx import AsyncClient, ASGITransport from typing import AsyncGenerator, Optional @pytest.mark.asyncio async def test_get_optional_database_session_handles_http_exception(): """Test that get_optional_database_session properly handles HTTPException. This test verifies the fix for the "generator didn't stop after athrow()" error that occurred when an HTTPException was raised after yielding a database session. """ from src.server.utils.dependencies import get_optional_database_session # Create a test app app = FastAPI() @app.post("/test") async def test_endpoint( db: Optional[object] = Depends(get_optional_database_session) ): """Test endpoint that raises HTTPException after dependency yields.""" # Simulate validation error that raises HTTPException raise HTTPException(status_code=400, detail="Validation error") # Test the endpoint transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post("/test") # Should return 400, not 500 (internal server error) assert response.status_code == 400 assert response.json()["detail"] == "Validation error" @pytest.mark.asyncio async def test_get_database_session_handles_http_exception(): """Test that get_database_session properly handles HTTPException. This test verifies the fix for the "generator didn't stop after athrow()" error that occurred when an HTTPException was raised after yielding a database session. """ from src.server.utils.dependencies import get_database_session # Create a test app app = FastAPI() @app.post("/test") async def test_endpoint(db: object = Depends(get_database_session)): """Test endpoint that raises HTTPException after dependency yields.""" # Simulate validation error that raises HTTPException raise HTTPException(status_code=400, detail="Validation error") # Test the endpoint - may get 501/503 if DB not available, or 400 if it is transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post("/test") # Should return proper HTTP error, not 500 with generator error assert response.status_code in (400, 501, 503) assert "generator didn't stop" not in str(response.json()) @pytest.mark.asyncio async def test_multiple_exceptions_in_optional_session(): """Test that multiple different exceptions are handled correctly.""" from src.server.utils.dependencies import get_optional_database_session app = FastAPI() @app.post("/test-400") async def test_400(db: Optional[object] = Depends(get_optional_database_session)): raise HTTPException(status_code=400, detail="Bad request") @app.post("/test-404") async def test_404(db: Optional[object] = Depends(get_optional_database_session)): raise HTTPException(status_code=404, detail="Not found") @app.post("/test-422") async def test_422(db: Optional[object] = Depends(get_optional_database_session)): raise HTTPException(status_code=422, detail="Validation error") transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: # Test 400 response = await client.post("/test-400") assert response.status_code == 400 # Test 404 response = await client.post("/test-404") assert response.status_code == 404 # Test 422 response = await client.post("/test-422") assert response.status_code == 422 @pytest.mark.asyncio async def test_successful_request_with_optional_session(): """Test that successful requests still work properly.""" from src.server.utils.dependencies import get_optional_database_session app = FastAPI() @app.post("/test") async def test_endpoint(db: Optional[object] = Depends(get_optional_database_session)): """Test endpoint that succeeds.""" return {"status": "success", "db_available": db is not None} transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post("/test") # Should return 200 for successful request assert response.status_code == 200 data = response.json() assert data["status"] == "success" assert isinstance(data["db_available"], bool) @pytest.mark.asyncio async def test_exception_after_using_session(): """Test exception raised after actually using the database session.""" from src.server.utils.dependencies import get_optional_database_session app = FastAPI() @app.post("/test") async def test_endpoint(db: Optional[object] = Depends(get_optional_database_session)): """Test endpoint that uses db then raises exception.""" # Simulate using the database if db is not None: # In real code, would do: await db.execute(...) pass # Then raise an exception raise HTTPException(status_code=400, detail="After using db") transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as client: response = await client.post("/test") # Should properly handle the exception assert response.status_code == 400 assert response.json()["detail"] == "After using db"