- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite) - Reorganize imports to follow PEP 8 conventions - Convert TypeAlias to modern PEP 695 type syntax (where appropriate) - Use Sequence/Mapping from collections.abc for type hints (covariant) - Replace string literals with cast() for improved type inference - Fix casting of Fail2BanResponse and TypedDict patterns - Add IpLookupResult TypedDict for precise return type annotation - Reformat overlong lines for readability (120 char limit) - Add asyncio_mode and filterwarnings to pytest config - Update test fixtures with improved type hints This improves mypy type checking and makes type relationships explicit.
281 lines
10 KiB
Python
281 lines
10 KiB
Python
"""Tests for the geo/IP-lookup router endpoints."""
|
|
|
|
from __future__ import annotations

from collections.abc import AsyncGenerator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient

from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.services.geo_service import GeoInfo
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SETUP_PAYLOAD = {
|
|
"master_password": "testpassword1",
|
|
"database_path": "bangui.db",
|
|
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
|
"timezone": "UTC",
|
|
"session_duration_minutes": 60,
|
|
}
|
|
|
|
|
|
@pytest.fixture
async def geo_client(tmp_path: Path) -> AsyncGenerator[AsyncClient, None]:
    """Provide an authenticated ``AsyncClient`` for geo endpoint tests.

    The app is built against a throw-away SQLite database under *tmp_path*.
    The fixture then posts ``_SETUP_PAYLOAD`` to ``/api/setup`` and logs in
    with its master password, so requests made through the yielded client
    carry the session established by that login.

    The database connection is closed on teardown in a ``finally`` clause, so
    it is released even when setup/login fails before the client is yielded.
    """
    settings = Settings(
        database_path=str(tmp_path / "geo_test.db"),
        fail2ban_socket="/tmp/fake.sock",
        session_secret="test-geo-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    try:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        app.state.db = db
        # NOTE(review): presumably stands in for the app's outbound HTTP
        # session; tests patch the geo/jail service calls they rely on instead
        # of driving real HTTP through this mock.
        app.state.http_session = MagicMock()

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            await ac.post("/api/setup", json=_SETUP_PAYLOAD)
            login = await ac.post(
                "/api/auth/login",
                json={"password": _SETUP_PAYLOAD["master_password"]},
            )
            assert login.status_code == 200
            yield ac
    finally:
        await db.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# GET /api/geo/lookup/{ip}
# ---------------------------------------------------------------------------


class TestGeoLookup:
    """Tests for ``GET /api/geo/lookup/{ip}``.

    Tests that reach the service layer patch ``jail_service.lookup_ip`` as
    seen from the router module, so they exercise request handling, error
    mapping and response serialisation rather than fail2ban itself.
    """

    # Dotted path of the service function the router delegates to.
    _LOOKUP_IP = "app.routers.geo.jail_service.lookup_ip"

    async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/lookup/{ip} returns 200 with enriched result."""
        geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
        result: dict[str, object] = {
            "ip": "1.2.3.4",
            "currently_banned_in": ["sshd"],
            "geo": geo,
        }
        with patch(self._LOOKUP_IP, AsyncMock(return_value=result)):
            resp = await geo_client.get("/api/geo/lookup/1.2.3.4")

        assert resp.status_code == 200
        data = resp.json()
        assert data["ip"] == "1.2.3.4"
        assert data["currently_banned_in"] == ["sshd"]
        # The GeoInfo in the service result is serialised as a nested object.
        assert data["geo"]["country_code"] == "DE"
        assert data["geo"]["country_name"] == "Germany"
        assert data["geo"]["asn"] == "12345"
        assert data["geo"]["org"] == "Acme"

    async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere."""
        result: dict[str, object] = {
            "ip": "8.8.8.8",
            "currently_banned_in": [],
            "geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
        }
        with patch(self._LOOKUP_IP, AsyncMock(return_value=result)):
            resp = await geo_client.get("/api/geo/lookup/8.8.8.8")

        assert resp.status_code == 200
        assert resp.json()["currently_banned_in"] == []

    async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/lookup/{ip} returns null geo when enricher fails."""
        result: dict[str, object] = {
            "ip": "1.2.3.4",
            "currently_banned_in": [],
            "geo": None,
        }
        with patch(self._LOOKUP_IP, AsyncMock(return_value=result)):
            resp = await geo_client.get("/api/geo/lookup/1.2.3.4")

        assert resp.status_code == 200
        assert resp.json()["geo"] is None

    async def test_400_for_invalid_ip(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/lookup/{ip} returns 400 for an invalid IP address."""
        # A ValueError from the service layer must surface as HTTP 400.
        failing_lookup = AsyncMock(side_effect=ValueError("Invalid IP address: 'bad_ip'"))
        with patch(self._LOOKUP_IP, failing_lookup):
            resp = await geo_client.get("/api/geo/lookup/bad_ip")

        assert resp.status_code == 400
        assert "detail" in resp.json()

    async def test_401_when_unauthenticated(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/lookup/{ip} returns 401 without a session."""
        # Reuse the fixture's app (and database) through a fresh client that
        # carries no session cookie; close it deterministically afterwards.
        app = geo_client._transport.app  # type: ignore[attr-defined]
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as anonymous:
            resp = await anonymous.get("/api/geo/lookup/1.2.3.4")
        assert resp.status_code == 401

    async def test_ipv6_address(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/lookup/{ip} handles IPv6 addresses."""
        result: dict[str, object] = {
            "ip": "2001:db8::1",
            "currently_banned_in": [],
            "geo": None,
        }
        with patch(self._LOOKUP_IP, AsyncMock(return_value=result)):
            resp = await geo_client.get("/api/geo/lookup/2001:db8::1")

        assert resp.status_code == 200
        assert resp.json()["ip"] == "2001:db8::1"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# POST /api/geo/re-resolve
# ---------------------------------------------------------------------------


class TestReResolve:
    """Tests for ``POST /api/geo/re-resolve``."""

    # Dotted path of the batch lookup the router delegates to.
    _LOOKUP_BATCH = "app.routers.geo.geo_service.lookup_batch"

    async def test_returns_200_with_counts(self, geo_client: AsyncClient) -> None:
        """POST /api/geo/re-resolve returns 200 with resolved/total counts."""
        with patch(self._LOOKUP_BATCH, AsyncMock(return_value={})):
            resp = await geo_client.post("/api/geo/re-resolve")

        assert resp.status_code == 200
        data = resp.json()
        assert "resolved" in data
        assert "total" in data

    async def test_empty_when_no_unresolved_ips(self, geo_client: AsyncClient) -> None:
        """Returns resolved=0, total=0 when geo_cache has no NULL country_code rows."""
        resp = await geo_client.post("/api/geo/re-resolve")

        assert resp.status_code == 200
        assert resp.json() == {"resolved": 0, "total": 0}

    async def test_re_resolves_null_ips(self, geo_client: AsyncClient) -> None:
        """IPs with null country_code in geo_cache are re-resolved via lookup_batch."""
        # Seed one geo_cache row with no country data so it counts as unresolved.
        app = geo_client._transport.app  # type: ignore[attr-defined]
        db: aiosqlite.Connection = app.state.db
        await db.execute("INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", ("5.5.5.5",))
        await db.commit()

        geo_result = {"5.5.5.5": GeoInfo(country_code="FR", country_name="France", asn=None, org=None)}
        with patch(self._LOOKUP_BATCH, AsyncMock(return_value=geo_result)):
            resp = await geo_client.post("/api/geo/re-resolve")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["resolved"] == 1

    async def test_401_when_unauthenticated(self, geo_client: AsyncClient) -> None:
        """POST /api/geo/re-resolve requires authentication."""
        # Fresh client without the fixture's session cookie, closed on exit.
        app = geo_client._transport.app  # type: ignore[attr-defined]
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as anonymous:
            resp = await anonymous.post("/api/geo/re-resolve")
        assert resp.status_code == 401
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# GET /api/geo/stats
# ---------------------------------------------------------------------------


class TestGeoStats:
    """Tests for ``GET /api/geo/stats``."""

    async def test_returns_200_with_stats(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/stats returns 200 with the expected keys."""
        stats = {
            "cache_size": 100,
            "unresolved": 5,
            "neg_cache_size": 2,
            "dirty_size": 0,
        }
        with patch(
            "app.routers.geo.geo_service.cache_stats",
            AsyncMock(return_value=stats),
        ):
            resp = await geo_client.get("/api/geo/stats")

        assert resp.status_code == 200
        data = resp.json()
        assert data["cache_size"] == 100
        assert data["unresolved"] == 5
        assert data["neg_cache_size"] == 2
        assert data["dirty_size"] == 0

    async def test_stats_empty_cache(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/stats reports no unresolved rows on a fresh database.

        The other counters are only required to be non-negative.
        """
        resp = await geo_client.get("/api/geo/stats")

        assert resp.status_code == 200
        data = resp.json()
        assert data["cache_size"] >= 0
        assert data["unresolved"] == 0
        assert data["neg_cache_size"] >= 0
        assert data["dirty_size"] >= 0

    async def test_stats_counts_unresolved(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/stats counts NULL-country rows correctly."""
        app = geo_client._transport.app  # type: ignore[attr-defined]
        db: aiosqlite.Connection = app.state.db
        for ip in ("7.7.7.7", "8.8.8.8"):
            await db.execute("INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", (ip,))
        await db.commit()

        resp = await geo_client.get("/api/geo/stats")

        assert resp.status_code == 200
        # A fresh database reports 0 unresolved rows (see test_stats_empty_cache),
        # so exactly the two rows inserted above must be counted.
        assert resp.json()["unresolved"] == 2

    async def test_401_when_unauthenticated(self, geo_client: AsyncClient) -> None:
        """GET /api/geo/stats requires authentication."""
        # Fresh client without the fixture's session cookie, closed on exit.
        app = geo_client._transport.app  # type: ignore[attr-defined]
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as anonymous:
            resp = await anonymous.get("/api/geo/stats")
        assert resp.status_code == 401
|