346 lines
12 KiB
Python
346 lines
12 KiB
Python
"""Unit tests for the FastAPI error-handler middleware.

Covers error-response formatting, exception-handler registration,
handling of the project's custom exceptions, and the catch-all handler
for unexpected exceptions.
"""
|
|
|
|
from typing import Any, Dict, Optional
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from src.server.exceptions import (
|
|
AniWorldAPIException,
|
|
AuthenticationError,
|
|
AuthorizationError,
|
|
BadRequestError,
|
|
ConflictError,
|
|
NotFoundError,
|
|
RateLimitError,
|
|
ValidationError,
|
|
)
|
|
from src.server.middleware.error_handler import (
|
|
create_error_response,
|
|
register_exception_handlers,
|
|
)
|
|
|
|
|
|
class TestCreateErrorResponse:
    """Unit tests for the ``create_error_response`` payload builder.

    Every test calls the builder directly (no app, no HTTP round-trip)
    and inspects the mapping it returns.
    """

    def test_basic_error_response_structure(self):
        """The payload carries ``success``, ``error`` and ``message``."""
        payload = create_error_response(
            status_code=400,
            error="BAD_REQUEST",
            message="Invalid input",
        )

        assert payload["success"] is False
        assert payload["error"] == "BAD_REQUEST"
        assert payload["message"] == "Invalid input"

    def test_response_includes_details_when_provided(self):
        """A supplied ``details`` mapping appears in the payload."""
        expected_details = {"field": "name", "reason": "too long"}

        payload = create_error_response(
            status_code=422,
            error="VALIDATION",
            message="Bad",
            details=expected_details,
        )

        assert payload["details"] == expected_details

    def test_response_excludes_details_when_none(self):
        """Omitting ``details`` leaves the key out of the payload entirely."""
        payload = create_error_response(status_code=400, error="ERR", message="msg")

        assert "details" not in payload

    def test_response_includes_request_id(self):
        """A supplied ``request_id`` appears in the payload."""
        payload = create_error_response(
            status_code=500,
            error="ERR",
            message="msg",
            request_id="req-123",
        )

        assert payload["request_id"] == "req-123"

    def test_response_excludes_request_id_when_none(self):
        """Omitting ``request_id`` leaves the key out of the payload entirely."""
        payload = create_error_response(status_code=500, error="ERR", message="msg")

        assert "request_id" not in payload
|
|
|
|
|
|
class TestExceptionHandlerRegistration:
    """Tests that exception handlers are correctly registered on a FastAPI app.

    Each test mounts a ``GET /test`` route that raises a specific exception,
    calls it in-process through httpx's ASGI transport, and checks the
    status code and JSON body produced by the registered handlers.
    """

    @pytest.fixture
    def app_with_handlers(self) -> FastAPI:
        """Create a FastAPI app with registered exception handlers."""
        app = FastAPI()
        register_exception_handlers(app)
        return app

    def _add_route_raising(self, app: FastAPI, exc: Exception):
        """Add a GET /test route that raises the given exception."""

        @app.get("/test")
        async def route():
            raise exc

    @staticmethod
    async def _get_test_route(app: FastAPI, *, raise_app_exceptions: bool = True):
        """Issue ``GET /test`` against ``app`` in-process and return the response.

        Args:
            app: The application under test.
            raise_app_exceptions: Forwarded to ``ASGITransport``. Pass ``False``
                for unhandled-exception tests: the catch-all ``Exception``
                handler still produces the 500 response, but the exception is
                re-raised afterwards and would otherwise abort the test.
        """
        transport = ASGITransport(app=app, raise_app_exceptions=raise_app_exceptions)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            return await client.get("/test")

    @pytest.mark.asyncio
    async def test_authentication_error_returns_401(self, app_with_handlers):
        """AuthenticationError maps to HTTP 401."""
        self._add_route_raising(app_with_handlers, AuthenticationError("bad creds"))

        resp = await self._get_test_route(app_with_handlers)

        assert resp.status_code == 401
        body = resp.json()
        assert body["success"] is False
        assert body["error"] == "AUTHENTICATION_ERROR"
        assert body["message"] == "bad creds"

    @pytest.mark.asyncio
    async def test_authorization_error_returns_403(self, app_with_handlers):
        """AuthorizationError maps to HTTP 403."""
        self._add_route_raising(app_with_handlers, AuthorizationError("forbidden"))

        resp = await self._get_test_route(app_with_handlers)

        assert resp.status_code == 403
        assert resp.json()["error"] == "AUTHORIZATION_ERROR"

    @pytest.mark.asyncio
    async def test_bad_request_error_returns_400(self, app_with_handlers):
        """BadRequestError maps to HTTP 400."""
        self._add_route_raising(app_with_handlers, BadRequestError("invalid"))

        resp = await self._get_test_route(app_with_handlers)

        assert resp.status_code == 400
        assert resp.json()["error"] == "BAD_REQUEST"

    @pytest.mark.asyncio
    async def test_not_found_error_returns_404(self, app_with_handlers):
        """NotFoundError maps to HTTP 404 and exposes resource info in details."""
        self._add_route_raising(
            app_with_handlers,
            NotFoundError("anime not found", resource_type="anime", resource_id=42),
        )

        resp = await self._get_test_route(app_with_handlers)

        assert resp.status_code == 404
        body = resp.json()
        assert body["error"] == "NOT_FOUND"
        assert body["details"]["resource_type"] == "anime"
        assert body["details"]["resource_id"] == 42

    @pytest.mark.asyncio
    async def test_validation_error_returns_422(self, app_with_handlers):
        """ValidationError maps to HTTP 422."""
        self._add_route_raising(app_with_handlers, ValidationError("bad data"))

        resp = await self._get_test_route(app_with_handlers)

        assert resp.status_code == 422
        assert resp.json()["error"] == "VALIDATION_ERROR"

    @pytest.mark.asyncio
    async def test_conflict_error_returns_409(self, app_with_handlers):
        """ConflictError maps to HTTP 409."""
        self._add_route_raising(app_with_handlers, ConflictError("duplicate"))

        resp = await self._get_test_route(app_with_handlers)

        assert resp.status_code == 409
        assert resp.json()["error"] == "CONFLICT"

    @pytest.mark.asyncio
    async def test_rate_limit_error_returns_429(self, app_with_handlers):
        """RateLimitError maps to HTTP 429 and exposes retry_after in details."""
        self._add_route_raising(
            app_with_handlers, RateLimitError("too many", retry_after=60)
        )

        resp = await self._get_test_route(app_with_handlers)

        assert resp.status_code == 429
        body = resp.json()
        assert body["error"] == "RATE_LIMIT_EXCEEDED"
        assert body["details"]["retry_after"] == 60

    @pytest.mark.asyncio
    async def test_generic_api_exception_returns_status(self, app_with_handlers):
        """AniWorldAPIException uses its status_code."""
        self._add_route_raising(
            app_with_handlers,
            AniWorldAPIException("custom error", status_code=418),
        )

        resp = await self._get_test_route(app_with_handlers)

        assert resp.status_code == 418

    @pytest.mark.asyncio
    async def test_unexpected_exception_returns_500(self, app_with_handlers):
        """Unhandled exceptions map to HTTP 500 with generic message."""
        self._add_route_raising(app_with_handlers, RuntimeError("unexpected crash"))

        resp = await self._get_test_route(
            app_with_handlers, raise_app_exceptions=False
        )

        assert resp.status_code == 500
        body = resp.json()
        assert body["error"] == "INTERNAL_SERVER_ERROR"
        assert body["message"] == "An unexpected error occurred"

    @pytest.mark.asyncio
    async def test_unexpected_exception_hides_stack_trace(self, app_with_handlers):
        """Stack traces are not leaked in 500 error responses."""
        self._add_route_raising(app_with_handlers, RuntimeError("internal secret"))

        resp = await self._get_test_route(
            app_with_handlers, raise_app_exceptions=False
        )

        # Confirm we are really inspecting the catch-all 500 response, not
        # some other error page that happens to omit the secret.
        assert resp.status_code == 500
        body = resp.json()
        assert "internal secret" not in body["message"]
        assert "Traceback" not in str(body)

    @pytest.mark.asyncio
    async def test_error_response_is_json(self, app_with_handlers):
        """All error responses are JSON formatted."""
        self._add_route_raising(app_with_handlers, NotFoundError("missing"))

        resp = await self._get_test_route(app_with_handlers)

        # Compare only the media type, so an optional "; charset=..." parameter
        # does not make the test fail.
        assert resp.headers["content-type"].startswith("application/json")
|
|
|
|
|
|
class TestExceptionClasses:
    """Checks on the attributes and hierarchy of the custom API exceptions.

    These tests build exception instances directly; no app or handler
    is involved.
    """

    def test_aniworld_exception_defaults(self):
        """A bare AniWorldAPIException falls back to 500 and its class name."""
        error = AniWorldAPIException("test")

        assert error.message == "test"
        assert error.status_code == 500
        assert error.error_code == "AniWorldAPIException"
        assert error.details == {}

    def test_to_dict_format(self):
        """``to_dict`` exposes error code, message and details."""
        error = AniWorldAPIException(
            "fail",
            status_code=400,
            error_code="FAIL",
            details={"reason": "bad"},
        )

        serialized = error.to_dict()

        assert serialized["error"] == "FAIL"
        assert serialized["message"] == "fail"
        assert serialized["details"]["reason"] == "bad"

    def test_not_found_with_resource_info(self):
        """NotFoundError records resource_type and resource_id in details."""
        error = NotFoundError("not found", resource_type="anime", resource_id="abc-123")
        info = error.details

        assert info["resource_type"] == "anime"
        assert info["resource_id"] == "abc-123"

    def test_rate_limit_with_retry_after(self):
        """RateLimitError records retry_after in details."""
        error = RateLimitError("slow down", retry_after=30)

        assert error.details["retry_after"] == 30

    def test_authentication_error_defaults(self):
        """AuthenticationError with no arguments is a 401."""
        error = AuthenticationError()

        assert error.status_code == 401
        assert error.error_code == "AUTHENTICATION_ERROR"

    def test_authorization_error_defaults(self):
        """AuthorizationError with no arguments is a 403."""
        assert AuthorizationError().status_code == 403

    def test_validation_error_defaults(self):
        """ValidationError with no arguments is a 422."""
        assert ValidationError().status_code == 422

    def test_bad_request_error_defaults(self):
        """BadRequestError with no arguments is a 400."""
        assert BadRequestError().status_code == 400

    def test_conflict_error_defaults(self):
        """ConflictError with no arguments is a 409."""
        assert ConflictError().status_code == 409

    def test_exception_inheritance_chain(self):
        """Every custom exception derives from AniWorldAPIException."""
        subclasses = (
            AuthenticationError,
            AuthorizationError,
            NotFoundError,
            ValidationError,
            BadRequestError,
            ConflictError,
            RateLimitError,
        )
        for exc_type in subclasses:
            assert issubclass(exc_type, AniWorldAPIException)
|