"""Tests for the geo/IP-lookup router endpoints.""" from __future__ import annotations from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import aiosqlite import pytest from httpx import ASGITransport, AsyncClient from app.config import Settings from app.db import init_db from app.main import create_app from app.services.geo_service import GeoInfo # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- _SETUP_PAYLOAD = { "master_password": "testpassword1", "database_path": "bangui.db", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "timezone": "UTC", "session_duration_minutes": 60, } @pytest.fixture async def geo_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] """Provide an authenticated ``AsyncClient`` for geo endpoint tests.""" settings = Settings( database_path=str(tmp_path / "geo_test.db"), fail2ban_socket="/tmp/fake.sock", session_secret="test-geo-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path) db.row_factory = aiosqlite.Row await init_db(db) app.state.db = db app.state.http_session = MagicMock() transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: await ac.post("/api/setup", json=_SETUP_PAYLOAD) login = await ac.post( "/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}, ) assert login.status_code == 200 yield ac await db.close() # --------------------------------------------------------------------------- # GET /api/geo/lookup/{ip} # --------------------------------------------------------------------------- class TestGeoLookup: """Tests for ``GET /api/geo/lookup/{ip}``.""" async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} returns 200 with enriched result.""" geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme") result = { "ip": "1.2.3.4", "currently_banned_in": ["sshd"], "geo": geo, } with patch( "app.routers.geo.jail_service.lookup_ip", AsyncMock(return_value=result), ): resp = await geo_client.get("/api/geo/lookup/1.2.3.4") assert resp.status_code == 200 data = resp.json() assert data["ip"] == "1.2.3.4" assert data["currently_banned_in"] == ["sshd"] assert data["geo"]["country_code"] == "DE" assert data["geo"]["country_name"] == "Germany" assert data["geo"]["asn"] == "12345" assert data["geo"]["org"] == "Acme" async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere.""" result = { "ip": "8.8.8.8", "currently_banned_in": [], "geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None), } with patch( "app.routers.geo.jail_service.lookup_ip", AsyncMock(return_value=result), ): resp = await geo_client.get("/api/geo/lookup/8.8.8.8") assert resp.status_code == 200 assert resp.json()["currently_banned_in"] == [] async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} returns null geo when enricher fails.""" result = { "ip": "1.2.3.4", "currently_banned_in": [], "geo": None, } with patch( "app.routers.geo.jail_service.lookup_ip", AsyncMock(return_value=result), ): resp = await geo_client.get("/api/geo/lookup/1.2.3.4") assert resp.status_code == 200 assert resp.json()["geo"] is None async def test_400_for_invalid_ip(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} returns 400 for an invalid IP address.""" with patch( "app.routers.geo.jail_service.lookup_ip", AsyncMock(side_effect=ValueError("Invalid IP address: 'bad_ip'")), ): resp = await geo_client.get("/api/geo/lookup/bad_ip") assert resp.status_code == 400 assert "detail" in resp.json() async def test_401_when_unauthenticated(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} returns 401 without a session.""" app = geo_client._transport.app # type: ignore[attr-defined] resp = await AsyncClient( transport=ASGITransport(app=app), base_url="http://test", ).get("/api/geo/lookup/1.2.3.4") assert resp.status_code == 401 async def test_ipv6_address(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} handles IPv6 addresses.""" result = { "ip": "2001:db8::1", "currently_banned_in": [], "geo": None, } with patch( "app.routers.geo.jail_service.lookup_ip", AsyncMock(return_value=result), ): resp = await geo_client.get("/api/geo/lookup/2001:db8::1") assert resp.status_code == 200 assert resp.json()["ip"] == "2001:db8::1" # --------------------------------------------------------------------------- # POST /api/geo/re-resolve # --------------------------------------------------------------------------- class TestReResolve: """Tests for ``POST /api/geo/re-resolve``.""" async def test_returns_200_with_counts(self, geo_client: AsyncClient) -> None: """POST /api/geo/re-resolve returns 200 with resolved/total counts.""" with patch( "app.routers.geo.geo_service.lookup_batch", AsyncMock(return_value={}), ): resp = await geo_client.post("/api/geo/re-resolve") assert resp.status_code == 200 data = resp.json() assert "resolved" in data assert "total" in data async def test_empty_when_no_unresolved_ips(self, geo_client: AsyncClient) -> None: """Returns resolved=0, total=0 when geo_cache has no NULL country_code rows.""" resp = await geo_client.post("/api/geo/re-resolve") assert resp.status_code == 200 assert resp.json() == {"resolved": 0, "total": 0} async def test_re_resolves_null_ips(self, geo_client: AsyncClient) -> None: """IPs with null country_code in geo_cache are re-resolved via lookup_batch.""" # Insert a NULL entry into geo_cache. app = geo_client._transport.app # type: ignore[attr-defined] db: aiosqlite.Connection = app.state.db await db.execute("INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", ("5.5.5.5",)) await db.commit() geo_result = {"5.5.5.5": GeoInfo(country_code="FR", country_name="France", asn=None, org=None)} with patch( "app.routers.geo.geo_service.lookup_batch", AsyncMock(return_value=geo_result), ): resp = await geo_client.post("/api/geo/re-resolve") assert resp.status_code == 200 data = resp.json() assert data["total"] == 1 assert data["resolved"] == 1 async def test_401_when_unauthenticated(self, geo_client: AsyncClient) -> None: """POST /api/geo/re-resolve requires authentication.""" app = geo_client._transport.app # type: ignore[attr-defined] resp = await AsyncClient( transport=ASGITransport(app=app), base_url="http://test", ).post("/api/geo/re-resolve") assert resp.status_code == 401 # --------------------------------------------------------------------------- # GET /api/geo/stats # --------------------------------------------------------------------------- class TestGeoStats: """Tests for ``GET /api/geo/stats``.""" async def test_returns_200_with_stats(self, geo_client: AsyncClient) -> None: """GET /api/geo/stats returns 200 with the expected keys.""" stats = { "cache_size": 100, "unresolved": 5, "neg_cache_size": 2, "dirty_size": 0, } with patch( "app.routers.geo.geo_service.cache_stats", AsyncMock(return_value=stats), ): resp = await geo_client.get("/api/geo/stats") assert resp.status_code == 200 data = resp.json() assert data["cache_size"] == 100 assert data["unresolved"] == 5 assert data["neg_cache_size"] == 2 assert data["dirty_size"] == 0 async def test_stats_empty_cache(self, geo_client: AsyncClient) -> None: """GET /api/geo/stats returns all zeros on a fresh database.""" resp = await geo_client.get("/api/geo/stats") assert resp.status_code == 200 data = resp.json() assert data["cache_size"] >= 0 assert data["unresolved"] == 0 assert data["neg_cache_size"] >= 0 assert data["dirty_size"] >= 0 async def test_stats_counts_unresolved(self, geo_client: AsyncClient) -> None: """GET /api/geo/stats counts NULL-country rows correctly.""" app = geo_client._transport.app # type: ignore[attr-defined] db: aiosqlite.Connection = app.state.db await db.execute("INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", ("7.7.7.7",)) await db.execute("INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", ("8.8.8.8",)) await db.commit() resp = await geo_client.get("/api/geo/stats") assert resp.status_code == 200 assert resp.json()["unresolved"] >= 2 async def test_401_when_unauthenticated(self, geo_client: AsyncClient) -> None: """GET /api/geo/stats requires authentication.""" app = geo_client._transport.app # type: ignore[attr-defined] resp = await AsyncClient( transport=ASGITransport(app=app), base_url="http://test", ).get("/api/geo/stats") assert resp.status_code == 401