"""Unit tests for the FastAPI middleware error handler.

Covers error response formatting, exception handler registration, custom
exception handling, and the catch-all (general) exception handler.
"""

from typing import Any, Dict, Optional
from unittest.mock import MagicMock, patch

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient, Response

from src.server.exceptions import (
    AniWorldAPIException,
    AuthenticationError,
    AuthorizationError,
    BadRequestError,
    ConflictError,
    NotFoundError,
    RateLimitError,
    ValidationError,
)
from src.server.middleware.error_handler import (
    create_error_response,
    register_exception_handlers,
)


class TestCreateErrorResponse:
    """Tests for the create_error_response utility function."""

    def test_basic_error_response_structure(self):
        """Error response has success, error, and message keys."""
        resp = create_error_response(
            status_code=400, error="BAD_REQUEST", message="Invalid input"
        )
        assert resp["success"] is False
        assert resp["error"] == "BAD_REQUEST"
        assert resp["message"] == "Invalid input"

    def test_response_includes_details_when_provided(self):
        """Details dict is included when specified."""
        details = {"field": "name", "reason": "too long"}
        resp = create_error_response(
            status_code=422,
            error="VALIDATION",
            message="Bad",
            details=details,
        )
        assert resp["details"] == details

    def test_response_excludes_details_when_none(self):
        """Details key absent when not specified."""
        resp = create_error_response(
            status_code=400, error="ERR", message="msg"
        )
        assert "details" not in resp

    def test_response_includes_request_id(self):
        """Request ID is included when provided."""
        resp = create_error_response(
            status_code=500,
            error="ERR",
            message="msg",
            request_id="req-123",
        )
        assert resp["request_id"] == "req-123"

    def test_response_excludes_request_id_when_none(self):
        """Request ID key absent when not specified."""
        resp = create_error_response(
            status_code=500, error="ERR", message="msg"
        )
        assert "request_id" not in resp


class TestExceptionHandlerRegistration:
    """Tests that exception handlers are correctly registered on a FastAPI app."""

    @pytest.fixture
    def app_with_handlers(self) -> FastAPI:
        """Create a FastAPI app with registered exception handlers."""
        app = FastAPI()
        register_exception_handlers(app)
        return app

    @staticmethod
    def _add_route_raising(app: FastAPI, exc: Exception) -> None:
        """Add a GET /test route that raises the given exception."""

        @app.get("/test")
        async def route():
            raise exc

    @staticmethod
    async def _get(
        app: FastAPI, *, raise_app_exceptions: bool = True
    ) -> Response:
        """Issue ``GET /test`` against *app* through an in-process ASGI client.

        Args:
            app: The application under test.
            raise_app_exceptions: Forwarded to ``ASGITransport``. Pass
                ``False`` for unhandled-exception tests. Starlette re-raises
                the original exception after the catch-all handler has sent
                its 500 response, and disabling propagation lets the test
                inspect that response instead of seeing the exception.

        Returns:
            The fully read response. The body stays available after the
            client is closed.
        """
        transport = ASGITransport(
            app=app, raise_app_exceptions=raise_app_exceptions
        )
        async with AsyncClient(
            transport=transport, base_url="http://test"
        ) as client:
            return await client.get("/test")

    @pytest.mark.asyncio
    async def test_authentication_error_returns_401(self, app_with_handlers):
        """AuthenticationError maps to HTTP 401."""
        self._add_route_raising(
            app_with_handlers, AuthenticationError("bad creds")
        )
        resp = await self._get(app_with_handlers)
        assert resp.status_code == 401
        body = resp.json()
        assert body["success"] is False
        assert body["error"] == "AUTHENTICATION_ERROR"
        assert body["message"] == "bad creds"

    @pytest.mark.asyncio
    async def test_authorization_error_returns_403(self, app_with_handlers):
        """AuthorizationError maps to HTTP 403."""
        self._add_route_raising(
            app_with_handlers, AuthorizationError("forbidden")
        )
        resp = await self._get(app_with_handlers)
        assert resp.status_code == 403
        assert resp.json()["error"] == "AUTHORIZATION_ERROR"

    @pytest.mark.asyncio
    async def test_bad_request_error_returns_400(self, app_with_handlers):
        """BadRequestError maps to HTTP 400."""
        self._add_route_raising(app_with_handlers, BadRequestError("invalid"))
        resp = await self._get(app_with_handlers)
        assert resp.status_code == 400
        assert resp.json()["error"] == "BAD_REQUEST"

    @pytest.mark.asyncio
    async def test_not_found_error_returns_404(self, app_with_handlers):
        """NotFoundError maps to HTTP 404."""
        self._add_route_raising(
            app_with_handlers,
            NotFoundError(
                "anime not found", resource_type="anime", resource_id=42
            ),
        )
        resp = await self._get(app_with_handlers)
        assert resp.status_code == 404
        body = resp.json()
        assert body["error"] == "NOT_FOUND"
        assert body["details"]["resource_type"] == "anime"
        assert body["details"]["resource_id"] == 42

    @pytest.mark.asyncio
    async def test_validation_error_returns_422(self, app_with_handlers):
        """ValidationError maps to HTTP 422."""
        self._add_route_raising(app_with_handlers, ValidationError("bad data"))
        resp = await self._get(app_with_handlers)
        assert resp.status_code == 422
        assert resp.json()["error"] == "VALIDATION_ERROR"

    @pytest.mark.asyncio
    async def test_conflict_error_returns_409(self, app_with_handlers):
        """ConflictError maps to HTTP 409."""
        self._add_route_raising(app_with_handlers, ConflictError("duplicate"))
        resp = await self._get(app_with_handlers)
        assert resp.status_code == 409
        assert resp.json()["error"] == "CONFLICT"

    @pytest.mark.asyncio
    async def test_rate_limit_error_returns_429(self, app_with_handlers):
        """RateLimitError maps to HTTP 429."""
        self._add_route_raising(
            app_with_handlers, RateLimitError("too many", retry_after=60)
        )
        resp = await self._get(app_with_handlers)
        assert resp.status_code == 429
        body = resp.json()
        assert body["error"] == "RATE_LIMIT_EXCEEDED"
        assert body["details"]["retry_after"] == 60

    @pytest.mark.asyncio
    async def test_generic_api_exception_returns_status(
        self, app_with_handlers
    ):
        """AniWorldAPIException uses its status_code."""
        self._add_route_raising(
            app_with_handlers,
            AniWorldAPIException("custom error", status_code=418),
        )
        resp = await self._get(app_with_handlers)
        assert resp.status_code == 418

    @pytest.mark.asyncio
    async def test_unexpected_exception_returns_500(self, app_with_handlers):
        """Unhandled exceptions map to HTTP 500 with generic message."""
        self._add_route_raising(
            app_with_handlers, RuntimeError("unexpected crash")
        )
        resp = await self._get(app_with_handlers, raise_app_exceptions=False)
        assert resp.status_code == 500
        body = resp.json()
        assert body["error"] == "INTERNAL_SERVER_ERROR"
        assert body["message"] == "An unexpected error occurred"

    @pytest.mark.asyncio
    async def test_unexpected_exception_hides_stack_trace(
        self, app_with_handlers
    ):
        """Stack traces are not leaked in 500 error responses."""
        self._add_route_raising(
            app_with_handlers, RuntimeError("internal secret")
        )
        resp = await self._get(app_with_handlers, raise_app_exceptions=False)
        # Make sure we are inspecting the catch-all handler's output and not
        # some other response.
        assert resp.status_code == 500
        body = resp.json()
        assert "internal secret" not in body["message"]
        # Scan the raw payload so a traceback hidden in any field is caught.
        assert "Traceback" not in resp.text

    @pytest.mark.asyncio
    async def test_error_response_is_json(self, app_with_handlers):
        """All error responses are JSON formatted."""
        self._add_route_raising(app_with_handlers, NotFoundError("missing"))
        resp = await self._get(app_with_handlers)
        # Prefix match: a "; charset=..." suffix is still a JSON response.
        assert resp.headers["content-type"].startswith("application/json")


class TestExceptionClasses:
    """Tests for custom exception class properties."""

    def test_aniworld_exception_defaults(self):
        """AniWorldAPIException has sensible defaults."""
        exc = AniWorldAPIException("test")
        assert exc.message == "test"
        assert exc.status_code == 500
        assert exc.error_code == "AniWorldAPIException"
        assert exc.details == {}

    def test_to_dict_format(self):
        """to_dict returns proper structure."""
        exc = AniWorldAPIException(
            "fail",
            status_code=400,
            error_code="FAIL",
            details={"reason": "bad"},
        )
        d = exc.to_dict()
        assert d["error"] == "FAIL"
        assert d["message"] == "fail"
        assert d["details"]["reason"] == "bad"

    def test_not_found_with_resource_info(self):
        """NotFoundError includes resource_type and resource_id in details."""
        exc = NotFoundError(
            "not found", resource_type="anime", resource_id="abc-123"
        )
        assert exc.details["resource_type"] == "anime"
        assert exc.details["resource_id"] == "abc-123"

    def test_rate_limit_with_retry_after(self):
        """RateLimitError includes retry_after in details."""
        exc = RateLimitError("slow down", retry_after=30)
        assert exc.details["retry_after"] == 30

    def test_authentication_error_defaults(self):
        """AuthenticationError defaults to 401 status."""
        exc = AuthenticationError()
        assert exc.status_code == 401
        assert exc.error_code == "AUTHENTICATION_ERROR"

    def test_authorization_error_defaults(self):
        """AuthorizationError defaults to 403 status."""
        exc = AuthorizationError()
        assert exc.status_code == 403

    def test_validation_error_defaults(self):
        """ValidationError defaults to 422 status."""
        exc = ValidationError()
        assert exc.status_code == 422

    def test_bad_request_error_defaults(self):
        """BadRequestError defaults to 400 status."""
        exc = BadRequestError()
        assert exc.status_code == 400

    def test_conflict_error_defaults(self):
        """ConflictError defaults to 409 status."""
        exc = ConflictError()
        assert exc.status_code == 409

    def test_exception_inheritance_chain(self):
        """All custom exceptions inherit from AniWorldAPIException."""
        assert issubclass(AuthenticationError, AniWorldAPIException)
        assert issubclass(AuthorizationError, AniWorldAPIException)
        assert issubclass(NotFoundError, AniWorldAPIException)
        assert issubclass(ValidationError, AniWorldAPIException)
        assert issubclass(BadRequestError, AniWorldAPIException)
        assert issubclass(ConflictError, AniWorldAPIException)
        assert issubclass(RateLimitError, AniWorldAPIException)