diff --git a/tests/unit/test_error_tracking.py b/tests/unit/test_error_tracking.py new file mode 100644 index 0000000..6659e4e --- /dev/null +++ b/tests/unit/test_error_tracking.py @@ -0,0 +1,619 @@ +"""Unit tests for Error Tracking utilities. + +Tests cover: +- Error tracking and history management +- Error statistics calculation +- Request context management +- Context stack operations +- Global singleton instances +- Error deduplication and cleanup +""" +from datetime import datetime, timezone +from typing import Any, Dict + +import pytest + +from src.server.utils.error_tracking import ( + ErrorTracker, + RequestContextManager, + get_context_manager, + get_error_tracker, + reset_error_tracker, +) + + +@pytest.fixture +def error_tracker(): + """Create error tracker instance.""" + return ErrorTracker() + + +@pytest.fixture +def context_manager(): + """Create request context manager instance.""" + return RequestContextManager() + + +class TestErrorTrackerInitialization: + """Tests for ErrorTracker initialization.""" + + def test_initialization(self, error_tracker): + """Test error tracker initialization.""" + assert error_tracker.error_history == [] + assert error_tracker.max_history_size == 1000 + + def test_default_max_history_size(self, error_tracker): + """Test default max history size.""" + assert error_tracker.max_history_size == 1000 + + +class TestTrackError: + """Tests for error tracking functionality.""" + + def test_track_error_basic(self, error_tracker): + """Test basic error tracking.""" + error_id = error_tracker.track_error( + error_type="ValueError", + message="Test error", + request_path="/api/test", + request_method="GET", + ) + + assert error_id is not None + assert len(error_tracker.error_history) == 1 + + error_entry = error_tracker.error_history[0] + assert error_entry["id"] == error_id + assert error_entry["type"] == "ValueError" + assert error_entry["message"] == "Test error" + assert error_entry["request_path"] == "/api/test" + assert error_entry["request_method"] == "GET" + assert error_entry["status_code"] == 500 # default + + def test_track_error_with_user_id(self, error_tracker): + """Test error tracking with user ID.""" + error_id = error_tracker.track_error( + error_type="AuthError", + message="Unauthorized access", + request_path="/api/protected", + request_method="POST", + user_id="user123", + ) + + error_entry = error_tracker.error_history[0] + assert error_entry["user_id"] == "user123" + + def test_track_error_with_custom_status_code(self, error_tracker): + """Test error tracking with custom status code.""" + error_id = error_tracker.track_error( + error_type="NotFoundError", + message="Resource not found", + request_path="/api/resource/123", + request_method="GET", + status_code=404, + ) + + error_entry = error_tracker.error_history[0] + assert error_entry["status_code"] == 404 + + def test_track_error_with_details(self, error_tracker): + """Test error tracking with additional details.""" + details = { + "stack_trace": "line 1\nline 2", + "user_agent": "Mozilla/5.0", + } + + error_id = error_tracker.track_error( + error_type="RuntimeError", + message="Runtime error occurred", + request_path="/api/action", + request_method="PUT", + details=details, + ) + + error_entry = error_tracker.error_history[0] + assert error_entry["details"] == details + + def test_track_error_with_request_id(self, error_tracker): + """Test error tracking with request ID for correlation.""" + error_id = error_tracker.track_error( + error_type="DatabaseError", + message="Connection failed", + request_path="/api/data", + request_method="GET", + request_id="req-12345", + ) + + error_entry = error_tracker.error_history[0] + assert error_entry["request_id"] == "req-12345" + + def test_track_multiple_errors(self, error_tracker): + """Test tracking multiple errors.""" + error_id1 = error_tracker.track_error( + error_type="Error1", + message="First error", + request_path="/api/1", + request_method="GET", + ) + + error_id2 = error_tracker.track_error( + error_type="Error2", + message="Second error", + request_path="/api/2", + request_method="POST", + ) + + assert len(error_tracker.error_history) == 2 + assert error_id1 != error_id2 + assert error_tracker.error_history[0]["id"] == error_id1 + assert error_tracker.error_history[1]["id"] == error_id2 + + def test_error_has_timestamp(self, error_tracker): + """Test that errors have ISO formatted timestamps.""" + error_id = error_tracker.track_error( + error_type="TestError", + message="Test", + request_path="/test", + request_method="GET", + ) + + error_entry = error_tracker.error_history[0] + timestamp = error_entry["timestamp"] + + # Should be valid ISO format with timezone + parsed = datetime.fromisoformat(timestamp) + assert parsed.tzinfo is not None + + +class TestErrorHistoryManagement: + """Tests for error history management.""" + + def test_history_size_limit(self, error_tracker): + """Test that history size is limited to max_history_size.""" + error_tracker.max_history_size = 5 + + # Track 10 errors + for i in range(10): + error_tracker.track_error( + error_type=f"Error{i}", + message=f"Error {i}", + request_path=f"/api/{i}", + request_method="GET", + ) + + # Only last 5 should remain + assert len(error_tracker.error_history) == 5 + + # Should be errors 5-9 + assert error_tracker.error_history[0]["type"] == "Error5" + assert error_tracker.error_history[-1]["type"] == "Error9" + + def test_clear_history(self, error_tracker): + """Test clearing error history.""" + # Track some errors + for i in range(3): + error_tracker.track_error( + error_type=f"Error{i}", + message=f"Error {i}", + request_path=f"/api/{i}", + request_method="GET", + ) + + assert len(error_tracker.error_history) == 3 + + error_tracker.clear_history() + + assert len(error_tracker.error_history) == 0 + + def test_get_recent_errors(self, error_tracker): + """Test getting recent errors.""" + # Track 5 errors + for i in range(5): + error_tracker.track_error( + error_type=f"Error{i}", + message=f"Error {i}", + request_path=f"/api/{i}", + request_method="GET", + ) + + # Get last 3 + recent = error_tracker.get_recent_errors(limit=3) + + assert len(recent) == 3 + assert recent[0]["type"] == "Error2" + assert recent[1]["type"] == "Error3" + assert recent[2]["type"] == "Error4" + + def test_get_recent_errors_with_empty_history(self, error_tracker): + """Test get_recent_errors with empty history.""" + recent = error_tracker.get_recent_errors() + assert recent == [] + + def test_get_recent_errors_default_limit(self, error_tracker): + """Test get_recent_errors default limit is 10.""" + # Track 15 errors + for i in range(15): + error_tracker.track_error( + error_type=f"Error{i}", + message=f"Error {i}", + request_path=f"/api/{i}", + request_method="GET", + ) + + # Default limit is 10 + recent = error_tracker.get_recent_errors() + assert len(recent) == 10 + + # Should be errors 5-14 + assert recent[0]["type"] == "Error5" + assert recent[-1]["type"] == "Error14" + + def test_get_recent_errors_limit_exceeds_history(self, error_tracker): + """Test get_recent_errors when limit exceeds history size.""" + # Track 3 errors + for i in range(3): + error_tracker.track_error( + error_type=f"Error{i}", + message=f"Error {i}", + request_path=f"/api/{i}", + request_method="GET", + ) + + # Request more than available + recent = error_tracker.get_recent_errors(limit=10) + assert len(recent) == 3 + + +class TestErrorStatistics: + """Tests for error statistics calculation.""" + + def test_get_error_stats_empty_history(self, error_tracker): + """Test error stats with empty history.""" + stats = error_tracker.get_error_stats() + + assert stats["total_errors"] == 0 + assert stats["error_types"] == {} + + def test_get_error_stats_single_error(self, error_tracker): + """Test error stats with single error.""" + error_tracker.track_error( + error_type="ValueError", + message="Test error", + request_path="/api/test", + request_method="GET", + status_code=400, + ) + + stats = error_tracker.get_error_stats() + + assert stats["total_errors"] == 1 + assert stats["error_types"] == {"ValueError": 1} + assert stats["status_codes"] == {400: 1} + assert stats["last_error"]["type"] == "ValueError" + + def test_get_error_stats_multiple_error_types(self, error_tracker): + """Test error stats with multiple error types.""" + error_tracker.track_error( + error_type="ValueError", + message="Error 1", + request_path="/api/1", + request_method="GET", + status_code=400, + ) + + error_tracker.track_error( + error_type="ValueError", + message="Error 2", + request_path="/api/2", + request_method="GET", + status_code=400, + ) + + error_tracker.track_error( + error_type="RuntimeError", + message="Error 3", + request_path="/api/3", + request_method="POST", + status_code=500, + ) + + stats = error_tracker.get_error_stats() + + assert stats["total_errors"] == 3 + assert stats["error_types"] == {"ValueError": 2, "RuntimeError": 1} + assert stats["status_codes"] == {400: 2, 500: 1} + + def test_get_error_stats_multiple_status_codes(self, error_tracker): + """Test error stats with multiple status codes.""" + status_codes = [400, 404, 500, 400, 404] + + for i, code in enumerate(status_codes): + error_tracker.track_error( + error_type=f"Error{i}", + message=f"Error {i}", + request_path=f"/api/{i}", + request_method="GET", + status_code=code, + ) + + stats = error_tracker.get_error_stats() + + assert stats["status_codes"] == {400: 2, 404: 2, 500: 1} + + def test_get_error_stats_last_error(self, error_tracker): + """Test that last_error contains most recent error.""" + error_tracker.track_error( + error_type="FirstError", + message="First", + request_path="/api/1", + request_method="GET", + ) + + error_tracker.track_error( + error_type="LastError", + message="Last", + request_path="/api/2", + request_method="GET", + ) + + stats = error_tracker.get_error_stats() + + assert stats["last_error"]["type"] == "LastError" + assert stats["last_error"]["message"] == "Last" + + +class TestRequestContextManager: + """Tests for RequestContextManager.""" + + def test_initialization(self, context_manager): + """Test context manager initialization.""" + assert context_manager.context_stack == [] + + def test_push_context(self, context_manager): + """Test pushing context onto stack.""" + context_manager.push_context( + request_id="req-123", + request_path="/api/test", + request_method="GET", + ) + + assert len(context_manager.context_stack) == 1 + context = context_manager.context_stack[0] + + assert context["request_id"] == "req-123" + assert context["request_path"] == "/api/test" + assert context["request_method"] == "GET" + assert context["user_id"] is None + assert "timestamp" in context + + def test_push_context_with_user_id(self, context_manager): + """Test pushing context with user ID.""" + context_manager.push_context( + request_id="req-123", + request_path="/api/protected", + request_method="POST", + user_id="user456", + ) + + context = context_manager.context_stack[0] + assert context["user_id"] == "user456" + + def test_push_multiple_contexts(self, context_manager): + """Test pushing multiple contexts.""" + context_manager.push_context( + request_id="req-1", + request_path="/api/1", + request_method="GET", + ) + + context_manager.push_context( + request_id="req-2", + request_path="/api/2", + request_method="POST", + ) + + assert len(context_manager.context_stack) == 2 + assert context_manager.context_stack[0]["request_id"] == "req-1" + assert context_manager.context_stack[1]["request_id"] == "req-2" + + def test_pop_context(self, context_manager): + """Test popping context from stack.""" + context_manager.push_context( + request_id="req-123", + request_path="/api/test", + request_method="GET", + ) + + popped = context_manager.pop_context() + + assert popped is not None + assert popped["request_id"] == "req-123" + assert len(context_manager.context_stack) == 0 + + def test_pop_context_empty_stack(self, context_manager): + """Test popping from empty stack returns None.""" + popped = context_manager.pop_context() + assert popped is None + + def test_pop_context_order(self, context_manager): + """Test that pop_context follows LIFO order.""" + context_manager.push_context( + request_id="req-1", + request_path="/api/1", + request_method="GET", + ) + + context_manager.push_context( + request_id="req-2", + request_path="/api/2", + request_method="POST", + ) + + # Pop should return last pushed + popped1 = context_manager.pop_context() + assert popped1["request_id"] == "req-2" + + popped2 = context_manager.pop_context() + assert popped2["request_id"] == "req-1" + + # Stack should be empty + assert len(context_manager.context_stack) == 0 + + def test_get_current_context(self, context_manager): + """Test getting current context without popping.""" + context_manager.push_context( + request_id="req-123", + request_path="/api/test", + request_method="GET", + ) + + current = context_manager.get_current_context() + + assert current is not None + assert current["request_id"] == "req-123" + # Stack should still have the context + assert len(context_manager.context_stack) == 1 + + def test_get_current_context_empty_stack(self, context_manager): + """Test getting current context from empty stack.""" + current = context_manager.get_current_context() + assert current is None + + def test_get_current_context_returns_last(self, context_manager): + """Test that get_current_context returns most recent.""" + context_manager.push_context( + request_id="req-1", + request_path="/api/1", + request_method="GET", + ) + + context_manager.push_context( + request_id="req-2", + request_path="/api/2", + request_method="POST", + ) + + current = context_manager.get_current_context() + assert current["request_id"] == "req-2" + + def test_context_has_timestamp(self, context_manager): + """Test that contexts have timestamps.""" + context_manager.push_context( + request_id="req-123", + request_path="/api/test", + request_method="GET", + ) + + context = context_manager.get_current_context() + timestamp = context["timestamp"] + + # Should be valid ISO format with timezone + parsed = datetime.fromisoformat(timestamp) + assert parsed.tzinfo is not None + + +class TestGlobalInstances: + """Tests for global singleton instances.""" + + def test_get_error_tracker_singleton(self): + """Test that get_error_tracker returns singleton.""" + reset_error_tracker() + + tracker1 = get_error_tracker() + tracker2 = get_error_tracker() + + assert tracker1 is tracker2 + + def test_reset_error_tracker(self): + """Test reset_error_tracker creates new instance.""" + tracker1 = get_error_tracker() + reset_error_tracker() + tracker2 = get_error_tracker() + + assert tracker1 is not tracker2 + + def test_get_context_manager_singleton(self): + """Test that get_context_manager returns singleton.""" + manager1 = get_context_manager() + manager2 = get_context_manager() + + assert manager1 is manager2 + + def test_error_tracker_state_persists(self): + """Test that error tracker state persists across calls.""" + reset_error_tracker() + + tracker1 = get_error_tracker() + tracker1.track_error( + error_type="TestError", + message="Test", + request_path="/test", + request_method="GET", + ) + + tracker2 = get_error_tracker() + assert len(tracker2.error_history) == 1 + assert tracker2.error_history[0]["type"] == "TestError" + + +class TestErrorTrackerEdgeCases: + """Tests for edge cases and error conditions.""" + + def test_track_error_without_details(self, error_tracker): + """Test that details default to empty dict.""" + error_id = error_tracker.track_error( + error_type="Error", + message="Test", + request_path="/test", + request_method="GET", + ) + + error_entry = error_tracker.error_history[0] + assert error_entry["details"] == {} + + def test_track_error_with_none_user_id(self, error_tracker): + """Test that user_id can be None.""" + error_id = error_tracker.track_error( + error_type="Error", + message="Test", + request_path="/test", + request_method="GET", + user_id=None, + ) + + error_entry = error_tracker.error_history[0] + assert error_entry["user_id"] is None + + def test_unique_error_ids(self, error_tracker): + """Test that each error gets unique ID.""" + ids = set() + + for i in range(100): + error_id = error_tracker.track_error( + error_type="Error", + message="Test", + request_path="/test", + request_method="GET", + ) + ids.add(error_id) + + # All IDs should be unique + assert len(ids) == 100 + + def test_history_trimming_preserves_recent(self, error_tracker): + """Test that trimming preserves most recent errors.""" + error_tracker.max_history_size = 3 + + # Track errors with unique types + for i in range(5): + error_tracker.track_error( + error_type=f"Error{i}", + message=f"Error {i}", + request_path=f"/api/{i}", + request_method="GET", + ) + + # Should keep last 3 (errors 2, 3, 4) + assert len(error_tracker.error_history) == 3 + types = [e["type"] for e in error_tracker.error_history] + assert types == ["Error2", "Error3", "Error4"]