"""Unit tests for correlation ID middleware and distributed tracing."""

import uuid
from typing import Any

import pytest
from httpx import AsyncClient
from starlette.testclient import TestClient

from app.config import Settings
from app.main import create_app
from app.middleware.correlation import CORRELATION_ID_CONTEXT_KEY
from app.models.server import ServerStatus


def test_correlation_middleware_generates_uuid_when_header_absent(
    test_settings: Settings,
) -> None:
    """Correlation middleware generates a UUID4 when X-Correlation-ID header is missing."""
    app = create_app(settings=test_settings)

    # Test with TestClient (synchronous)
    client = TestClient(app)
    response = client.get("/api/v1/health")

    # Should have correlation ID header in response
    assert "X-Correlation-ID" in response.headers
    correlation_id = response.headers["X-Correlation-ID"]

    # Parse instead of counting characters: a length/dash check accepts any
    # 36-char string with four dashes (e.g. non-hex text or a non-v4 UUID).
    # uuid.UUID raises ValueError on malformed input; the str() round-trip
    # rejects alternate spellings it would otherwise accept (braces,
    # "urn:uuid:" prefix, undashed hex), pinning the canonical 8-4-4-4-12
    # form; .version pins it to UUID4.
    parsed = uuid.UUID(correlation_id)
    assert str(parsed) == correlation_id
    assert parsed.version == 4


def test_correlation_middleware_preserves_header_from_request(
    test_settings: Settings,
) -> None:
    """Correlation middleware preserves X-Correlation-ID header from client request."""
    app = create_app(settings=test_settings)
    client = TestClient(app)

    test_correlation_id = "550e8400-e29b-41d4-a716-446655440000"
    response = client.get("/api/v1/health", headers={"X-Correlation-ID": test_correlation_id})

    # Should return the same correlation ID in response
    assert response.headers["X-Correlation-ID"] == test_correlation_id


def test_correlation_middleware_stores_in_request_state(
    test_settings: Settings,
) -> None:
    """Correlation middleware stores correlation ID in request.state for handlers."""
    from unittest.mock import MagicMock

    app = create_app(settings=test_settings)

    # Populate the app state the health endpoint reads so the request can
    # complete with a 200.
    app.state.server_status = ServerStatus(online=True)
    mock_scheduler = MagicMock()
    mock_scheduler.running = True
    app.state.scheduler = mock_scheduler

    client = TestClient(app)

    test_correlation_id = "550e8400-e29b-41d4-a716-446655440000"
    response = client.get("/api/v1/health", headers={"X-Correlation-ID": test_correlation_id})

    # NOTE(review): these assertions only show that the request succeeded and
    # that the middleware echoed the incoming ID on the response. They do not
    # observe request.state directly, so they do not prove the ID was stored
    # there for handlers. TODO: add a test route that reads the value from
    # request.state (presumably under CORRELATION_ID_CONTEXT_KEY; confirm
    # against the middleware) and returns it, then assert on that body.
    assert response.status_code == 200
    assert response.headers["X-Correlation-ID"] == test_correlation_id


def test_correlation_id_in_response_headers(
    test_settings: Settings,
) -> None:
    """Correlation ID is included in all response headers."""
    app = create_app(settings=test_settings)
    client = TestClient(app)

    # Test without providing header (should generate one)
    response = client.get("/api/v1/health")
    assert "X-Correlation-ID" in response.headers

    # Test with providing header (should preserve it)
    test_id = "test-correlation-id-12345"
    response = client.get("/api/v1/health", headers={"X-Correlation-ID": test_id})
    assert response.headers["X-Correlation-ID"] == test_id