From 66cc2fdfcbc8eac8cd712990a217e9950697993d Mon Sep 17 00:00:00 2001 From: Lukas Date: Mon, 27 Oct 2025 20:15:07 +0100 Subject: [PATCH] fix connection test --- src/server/api/diagnostics.py | 41 ++++-- tests/unit/test_diagnostics.py | 227 +++++++++++++++++++++++++++++++++ 2 files changed, 259 insertions(+), 9 deletions(-) create mode 100644 tests/unit/test_diagnostics.py diff --git a/src/server/api/diagnostics.py b/src/server/api/diagnostics.py index 1db05c8..8255167 100644 --- a/src/server/api/diagnostics.py +++ b/src/server/api/diagnostics.py @@ -35,6 +35,9 @@ class NetworkDiagnostics(BaseModel): ..., description="Overall internet connectivity status" ) dns_working: bool = Field(..., description="DNS resolution status") + aniworld_reachable: bool = Field( + ..., description="Aniworld.to connectivity status" + ) tests: List[NetworkTestResult] = Field( ..., description="Individual network tests" ) @@ -109,19 +112,20 @@ async def test_host_connectivity( ) -@router.get("/network", response_model=NetworkDiagnostics) +@router.get("/network") async def network_diagnostics( auth: Optional[dict] = Depends(require_auth), -) -> NetworkDiagnostics: +) -> Dict: """Run network connectivity diagnostics. - Tests DNS resolution and connectivity to common services. + Tests DNS resolution and connectivity to common services including + aniworld.to. Args: auth: Authentication token (optional) Returns: - NetworkDiagnostics with test results + Dict with status and diagnostics data Raises: HTTPException: If diagnostics fail @@ -132,11 +136,12 @@ async def network_diagnostics( # Check DNS dns_working = await check_dns() - # Test connectivity to various hosts + # Test connectivity to various hosts including aniworld.to test_hosts = [ ("google.com", 80), ("cloudflare.com", 80), ("github.com", 443), + ("aniworld.to", 443), ] # Run all tests concurrently @@ -148,17 +153,35 @@ async def network_diagnostics( # Determine overall internet connectivity internet_connected = any(result.reachable for result in test_results) - logger.info( - f"Network diagnostics complete: " - f"DNS={dns_working}, Internet={internet_connected}" + # Check if aniworld.to is reachable + aniworld_result = next( + (r for r in test_results if r.host == "aniworld.to"), + None + ) + aniworld_reachable = ( + aniworld_result.reachable if aniworld_result else False ) - return NetworkDiagnostics( + logger.info( + f"Network diagnostics complete: " + f"DNS={dns_working}, Internet={internet_connected}, " + f"Aniworld={aniworld_reachable}" + ) + + # Create diagnostics data + diagnostics_data = NetworkDiagnostics( internet_connected=internet_connected, dns_working=dns_working, + aniworld_reachable=aniworld_reachable, tests=test_results, ) + # Return in standard format expected by frontend + return { + "status": "success", + "data": diagnostics_data.model_dump(), + } + except Exception as e: logger.exception("Failed to run network diagnostics") raise HTTPException( diff --git a/tests/unit/test_diagnostics.py b/tests/unit/test_diagnostics.py new file mode 100644 index 0000000..eca0e08 --- /dev/null +++ b/tests/unit/test_diagnostics.py @@ -0,0 +1,227 @@ +"""Unit tests for diagnostics endpoints.""" +from unittest.mock import MagicMock, patch + +import pytest + +from src.server.api.diagnostics import ( + NetworkTestResult, + check_dns, + network_diagnostics, + test_host_connectivity, +) + + +class TestDiagnosticsEndpoint: + """Test diagnostics API endpoints.""" + + @pytest.mark.asyncio + async def test_network_diagnostics_returns_standard_format(self): + """Test that network diagnostics returns the expected format.""" + # Mock authentication + mock_auth = {"user_id": "test_user"} + + # Mock the helper functions + with patch( + "src.server.api.diagnostics.check_dns", + return_value=True + ), patch( + "src.server.api.diagnostics.test_host_connectivity", + side_effect=[ + NetworkTestResult( + host="google.com", + reachable=True, + response_time_ms=50.5 + ), + NetworkTestResult( + host="cloudflare.com", + reachable=True, + response_time_ms=30.2 + ), + NetworkTestResult( + host="github.com", + reachable=True, + response_time_ms=100.0 + ), + NetworkTestResult( + host="aniworld.to", + reachable=True, + response_time_ms=75.3 + ), + ] + ): + # Call the endpoint + result = await network_diagnostics(auth=mock_auth) + + # Verify response structure + assert isinstance(result, dict) + assert "status" in result + assert "data" in result + assert result["status"] == "success" + + # Verify data structure + data = result["data"] + assert "internet_connected" in data + assert "dns_working" in data + assert "aniworld_reachable" in data + assert "tests" in data + + # Verify values + assert data["internet_connected"] is True + assert data["dns_working"] is True + assert data["aniworld_reachable"] is True + assert len(data["tests"]) == 4 + + @pytest.mark.asyncio + async def test_network_diagnostics_aniworld_unreachable(self): + """Test diagnostics when aniworld.to is unreachable.""" + mock_auth = {"user_id": "test_user"} + + with patch( + "src.server.api.diagnostics.check_dns", + return_value=True + ), patch( + "src.server.api.diagnostics.test_host_connectivity", + side_effect=[ + NetworkTestResult( + host="google.com", + reachable=True, + response_time_ms=50.5 + ), + NetworkTestResult( + host="cloudflare.com", + reachable=True, + response_time_ms=30.2 + ), + NetworkTestResult( + host="github.com", + reachable=True, + response_time_ms=100.0 + ), + NetworkTestResult( + host="aniworld.to", + reachable=False, + error="Connection timeout" + ), + ] + ): + result = await network_diagnostics(auth=mock_auth) + + # Verify aniworld is marked as unreachable + assert result["status"] == "success" + assert result["data"]["aniworld_reachable"] is False + assert result["data"]["internet_connected"] is True + + @pytest.mark.asyncio + async def test_network_diagnostics_all_unreachable(self): + """Test diagnostics when all hosts are unreachable.""" + mock_auth = {"user_id": "test_user"} + + with patch( + "src.server.api.diagnostics.check_dns", + return_value=False + ), patch( + "src.server.api.diagnostics.test_host_connectivity", + side_effect=[ + NetworkTestResult( + host="google.com", + reachable=False, + error="Connection timeout" + ), + NetworkTestResult( + host="cloudflare.com", + reachable=False, + error="Connection timeout" + ), + NetworkTestResult( + host="github.com", + reachable=False, + error="Connection timeout" + ), + NetworkTestResult( + host="aniworld.to", + reachable=False, + error="Connection timeout" + ), + ] + ): + result = await network_diagnostics(auth=mock_auth) + + # Verify all are unreachable + assert result["status"] == "success" + assert result["data"]["internet_connected"] is False + assert result["data"]["dns_working"] is False + assert result["data"]["aniworld_reachable"] is False + + +class TestNetworkHelpers: + """Test network helper functions.""" + + @pytest.mark.asyncio + async def test_check_dns_success(self): + """Test DNS check when DNS is working.""" + with patch("socket.gethostbyname", return_value="142.250.185.78"): + result = await check_dns() + assert result is True + + @pytest.mark.asyncio + async def test_check_dns_failure(self): + """Test DNS check when DNS fails.""" + import socket + with patch( + "socket.gethostbyname", + side_effect=socket.gaierror("DNS lookup failed") + ): + result = await check_dns() + assert result is False + + @pytest.mark.asyncio + async def test_host_connectivity_success(self): + """Test host connectivity check when host is reachable.""" + with patch( + "socket.create_connection", + return_value=MagicMock() + ): + result = await test_host_connectivity("google.com", 80) + assert result.host == "google.com" + assert result.reachable is True + assert result.response_time_ms is not None + assert result.response_time_ms >= 0 + assert result.error is None + + @pytest.mark.asyncio + async def test_host_connectivity_timeout(self): + """Test host connectivity when connection times out.""" + import asyncio + with patch( + "socket.create_connection", + side_effect=asyncio.TimeoutError() + ): + result = await test_host_connectivity("example.com", 80, 1.0) + assert result.host == "example.com" + assert result.reachable is False + assert result.error == "Connection timeout" + + @pytest.mark.asyncio + async def test_host_connectivity_dns_failure(self): + """Test host connectivity when DNS resolution fails.""" + import socket + with patch( + "socket.create_connection", + side_effect=socket.gaierror("Name resolution failed") + ): + result = await test_host_connectivity("invalid.host", 80) + assert result.host == "invalid.host" + assert result.reachable is False + assert "DNS resolution failed" in result.error + + @pytest.mark.asyncio + async def test_host_connectivity_connection_refused(self): + """Test host connectivity when connection is refused.""" + with patch( + "socket.create_connection", + side_effect=ConnectionRefusedError() + ): + result = await test_host_connectivity("localhost", 12345) + assert result.host == "localhost" + assert result.reachable is False + assert result.error == "Connection refused"