Make geo lookups non-blocking with bulk DB writes and background tasks

This commit is contained in:
2026-03-12 18:10:00 +01:00
parent a61c9dc969
commit 28f7b1cfcd
8 changed files with 496 additions and 36 deletions

View File

@@ -614,6 +614,108 @@ class TestOriginFilter:
assert result.total == 3
# ---------------------------------------------------------------------------
# bans_by_country — background geo resolution (Task 3)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestBansbyCountryBackground:
    """bans_by_country() with http_session uses cache-only geo and fires a
    background task for uncached IPs instead of blocking on API calls.

    Every test points ``ban_service._get_fail2ban_db_path`` at the
    ``mixed_origin_db_path`` fixture and replaces ``asyncio.create_task`` (as
    reached through ``ban_service``) with a mock, so that scheduling can be
    asserted without running any background work.
    """

    async def test_cached_geo_returned_without_api_call(
        self, mixed_origin_db_path: str
    ) -> None:
        """When all IPs are in the cache, lookup_cached_only returns them and
        no background task is created."""
        from app.services import geo_service

        # Start from an empty cache so entries left behind by other tests
        # cannot influence the outcome.
        geo_service.clear_cache()

        # Pre-populate the cache for all three IPs in the fixture.
        seeded_countries = {
            "10.0.0.1": ("DE", "Germany"),
            "10.0.0.2": ("US", "United States"),
            "10.0.0.3": ("JP", "Japan"),
        }
        for ip, (code, name) in seeded_countries.items():
            geo_service._cache[ip] = geo_service.GeoInfo(  # type: ignore[attr-defined]
                country_code=code, country_name=name, asn=None, org=None
            )

        try:
            with (
                patch(
                    "app.services.ban_service._get_fail2ban_db_path",
                    new=AsyncMock(return_value=mixed_origin_db_path),
                ),
                patch(
                    "app.services.ban_service.asyncio.create_task"
                ) as mock_create_task,
            ):
                mock_session = AsyncMock()
                result = await ban_service.bans_by_country(
                    "/fake/sock", "24h", http_session=mock_session
                )

            # All countries resolved from cache — no background task needed.
            mock_create_task.assert_not_called()
            assert result.total == 3
            # Country counts should reflect the cached data.
            # NOTE(review): this only requires ONE of the seeded codes to be
            # present; tighten to all(...) once the fixture is confirmed to
            # contain all three IPs.
            assert any(code in result.countries for code in ("DE", "US", "JP"))
        finally:
            # The cache is module-global: always reset it, even when an
            # assertion above fails, so later tests are not polluted.
            geo_service.clear_cache()

    async def test_uncached_ips_trigger_background_task(
        self, mixed_origin_db_path: str
    ) -> None:
        """When IPs are NOT in the cache, create_task is called for background
        resolution and the response returns without blocking."""
        import inspect

        from app.services import geo_service

        geo_service.clear_cache()  # ensure cache is empty

        with (
            patch(
                "app.services.ban_service._get_fail2ban_db_path",
                new=AsyncMock(return_value=mixed_origin_db_path),
            ),
            patch(
                "app.services.ban_service.asyncio.create_task"
            ) as mock_create_task,
        ):
            mock_session = AsyncMock()
            result = await ban_service.bans_by_country(
                "/fake/sock", "24h", http_session=mock_session
            )

        # Background task must have been scheduled for uncached IPs.
        mock_create_task.assert_called_once()

        # create_task is mocked, so whatever was handed to it never runs.  If
        # that is a coroutine object, close it explicitly to avoid a
        # "coroutine ... was never awaited" RuntimeWarning at collection time.
        scheduled_args = mock_create_task.call_args.args
        if scheduled_args and inspect.iscoroutine(scheduled_args[0]):
            scheduled_args[0].close()

        # Response is still valid with empty country map (IPs not cached yet).
        assert result.total == 3

    async def test_no_background_task_without_http_session(
        self, mixed_origin_db_path: str
    ) -> None:
        """When http_session is None, no background task is created."""
        from app.services import geo_service

        geo_service.clear_cache()

        with (
            patch(
                "app.services.ban_service._get_fail2ban_db_path",
                new=AsyncMock(return_value=mixed_origin_db_path),
            ),
            patch(
                "app.services.ban_service.asyncio.create_task"
            ) as mock_create_task,
        ):
            result = await ban_service.bans_by_country(
                "/fake/sock", "24h", http_session=None
            )

        mock_create_task.assert_not_called()
        assert result.total == 3
# ---------------------------------------------------------------------------
# ban_trend
# ---------------------------------------------------------------------------

View File

@@ -767,3 +767,147 @@ class TestErrorLogging:
assert event["exc_type"] == "_EmptyMessageError"
assert "_EmptyMessageError" in event["error"]
# ---------------------------------------------------------------------------
# lookup_cached_only (Task 3)
# ---------------------------------------------------------------------------
class TestLookupCachedOnly:
    """lookup_cached_only() returns cache hits without making API calls.

    The function is called with a list of IPs and returns a
    ``(geo_map, uncached)`` pair; each test seeds the module-level caches
    directly and checks how the input is split between the two.

    NOTE(review): these tests write straight into ``geo_service._cache`` and
    ``geo_service._neg_cache`` and never reset them; presumably a fixture
    elsewhere in this file clears the caches between tests — confirm.
    """

    def test_returns_cached_ips(self) -> None:
        """IPs already in the cache are returned in the geo_map."""
        geo_service._cache["1.1.1.1"] = GeoInfo(  # type: ignore[attr-defined]
            country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
        )

        geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])

        assert "1.1.1.1" in geo_map
        assert geo_map["1.1.1.1"].country_code == "AU"
        assert uncached == []

    def test_returns_uncached_ips(self) -> None:
        """IPs not in the cache appear in the uncached list."""
        geo_map, uncached = geo_service.lookup_cached_only(["9.9.9.9"])

        assert "9.9.9.9" not in geo_map
        assert "9.9.9.9" in uncached

    def test_neg_cached_ips_excluded_from_uncached(self) -> None:
        """IPs in the negative cache within TTL are not re-queued as uncached."""
        import time

        # A timestamp of "now" is the freshest possible negative entry.
        geo_service._neg_cache["10.0.0.1"] = time.monotonic()  # type: ignore[attr-defined]

        geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])

        assert "10.0.0.1" not in geo_map
        assert "10.0.0.1" not in uncached

    def test_expired_neg_cache_requeued(self) -> None:
        """IPs whose neg-cache entry has expired are listed as uncached."""
        import time

        # The sibling tests store time.monotonic() readings in the negative
        # cache.  That clock has an undefined reference point, so a literal
        # 0.0 is only "old" if the clock currently reads more than the TTL.
        # Back-date relative to the current reading instead; this assumes the
        # negative-cache TTL is far below ten years.
        ten_years_in_seconds = 10 * 365 * 24 * 60 * 60
        long_ago = time.monotonic() - ten_years_in_seconds
        geo_service._neg_cache["10.0.0.2"] = long_ago  # type: ignore[attr-defined]

        _geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])

        assert "10.0.0.2" in uncached

    def test_mixed_ips(self) -> None:
        """A mix of cached, neg-cached, and unknown IPs is split correctly."""
        import time

        geo_service._cache["1.2.3.4"] = GeoInfo(  # type: ignore[attr-defined]
            country_code="DE", country_name="Germany", asn=None, org=None
        )
        geo_service._neg_cache["5.5.5.5"] = time.monotonic()  # type: ignore[attr-defined]

        geo_map, uncached = geo_service.lookup_cached_only(
            ["1.2.3.4", "5.5.5.5", "9.9.9.9"]
        )

        # Only the positively cached IP is mapped; the neg-cached one is
        # dropped from both outputs, leaving just the unknown IP to resolve.
        assert list(geo_map) == ["1.2.3.4"]
        assert uncached == ["9.9.9.9"]

    def test_deduplication(self) -> None:
        """Duplicate IPs in the input appear at most once in the output."""
        geo_service._cache["1.2.3.4"] = GeoInfo(  # type: ignore[attr-defined]
            country_code="US", country_name="United States", asn=None, org=None
        )

        geo_map, uncached = geo_service.lookup_cached_only(
            ["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"]
        )

        assert list(geo_map).count("1.2.3.4") == 1
        assert uncached.count("9.9.9.9") == 1
# ---------------------------------------------------------------------------
# Bulk DB writes via executemany (Task 3)
# ---------------------------------------------------------------------------
class TestLookupBatchBulkWrites:
    """lookup_batch() persists results via executemany, never per-IP execute.

    Each test feeds a canned batch response through a fake session, hands a
    fake async DB to ``lookup_batch`` and then inspects which DB methods were
    awaited.
    """

    async def test_executemany_called_for_successful_ips(self) -> None:
        """Successful resolutions are written with executemany (at least one
        call) and execute() is never awaited."""
        addresses = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
        success_fields = {
            "status": "success",
            "countryCode": "DE",
            "country": "Germany",
            "as": "AS3320",
            "org": "Telekom",
        }
        payload = [{"query": address, **success_fields} for address in addresses]
        http = _make_batch_session(payload)
        fake_db = _make_async_db()

        await geo_service.lookup_batch(addresses, http, db=fake_db)  # type: ignore[arg-type]

        # Positive rows go through executemany ...
        assert fake_db.executemany.await_count >= 1
        # ... and the per-row execute() path must stay untouched.
        fake_db.execute.assert_not_awaited()

    async def test_executemany_called_for_failed_ips(self) -> None:
        """Failed resolutions are written with executemany (at least one call)
        and execute() is never awaited."""
        addresses = ["10.0.0.1", "10.0.0.2"]
        failure_fields = {"status": "fail", "message": "private range"}
        payload = [{"query": address, **failure_fields} for address in addresses]
        http = _make_batch_session(payload)
        fake_db = _make_async_db()

        await geo_service.lookup_batch(addresses, http, db=fake_db)  # type: ignore[arg-type]

        assert fake_db.executemany.await_count >= 1
        fake_db.execute.assert_not_awaited()

    async def test_mixed_results_two_executemany_calls(self) -> None:
        """One success plus one failure yields exactly two executemany calls."""
        resolved = {
            "query": "1.1.1.1",
            "status": "success",
            "countryCode": "AU",
            "country": "Australia",
            "as": "AS13335",
            "org": "Cloudflare",
        }
        rejected = {"query": "10.0.0.1", "status": "fail", "message": "private range"}
        addresses = ["1.1.1.1", "10.0.0.1"]
        http = _make_batch_session([resolved, rejected])
        fake_db = _make_async_db()

        await geo_service.lookup_batch(addresses, http, db=fake_db)  # type: ignore[arg-type]

        # One bulk write for the positive row, one for the negative row.
        assert fake_db.executemany.await_count == 2
        fake_db.execute.assert_not_awaited()

View File

@@ -472,6 +472,83 @@ class TestGetActiveBans:
assert result.total == 1
assert result.bans[0].jail == "sshd"
async def test_http_session_triggers_lookup_batch(self) -> None:
    """Passing http_session routes geo enrichment through lookup_batch.

    The fake client reports a single banned IP in jail ``sshd``;
    ``geo_service.lookup_batch`` is patched to return a German GeoInfo for it,
    which must surface as the ban's country.
    """
    from app.services.geo_service import GeoInfo

    ban_line = "1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"
    fake_responses = {
        "status": _make_global_status("sshd"),
        "get|sshd|banip|--with-time": (0, [ban_line]),
    }
    germany = GeoInfo(
        country_code="DE", country_name="Germany", asn="AS1", org="ISP"
    )
    canned_geo = {"1.2.3.4": germany}

    with (
        _patch_client(fake_responses),
        patch(
            "app.services.geo_service.lookup_batch",
            new=AsyncMock(return_value=canned_geo),
        ) as batch_mock,
    ):
        session = AsyncMock()
        outcome = await jail_service.get_active_bans(
            _SOCKET, http_session=session
        )

    batch_mock.assert_awaited_once()
    assert outcome.total == 1
    assert outcome.bans[0].country == "DE"
async def test_http_session_batch_failure_graceful(self) -> None:
    """A lookup_batch failure must not break get_active_bans.

    ``geo_service.lookup_batch`` is patched to raise ``RuntimeError``; the
    single ban is still returned, just without a country.
    """
    ban_line = "1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"
    fake_responses = {
        "status": _make_global_status("sshd"),
        "get|sshd|banip|--with-time": (0, [ban_line]),
    }
    exploding_batch = AsyncMock(side_effect=RuntimeError("geo down"))

    with (
        _patch_client(fake_responses),
        patch("app.services.geo_service.lookup_batch", new=exploding_batch),
    ):
        session = AsyncMock()
        outcome = await jail_service.get_active_bans(
            _SOCKET, http_session=session
        )

    assert outcome.total == 1
    assert outcome.bans[0].country is None
async def test_geo_enricher_still_used_without_http_session(self) -> None:
    """Without http_session, the geo_enricher callback supplies the country.

    The enricher below answers "JP" for any IP, so the single ban reported by
    the fake client must come back tagged with that country code.
    """
    from app.services.geo_service import GeoInfo

    ban_line = "1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"
    fake_responses = {
        "status": _make_global_status("sshd"),
        "get|sshd|banip|--with-time": (0, [ban_line]),
    }

    async def _enricher(ip: str) -> GeoInfo | None:
        # The argument is ignored on purpose: every address maps to Japan.
        japan = GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
        return japan

    with _patch_client(fake_responses):
        outcome = await jail_service.get_active_bans(
            _SOCKET, geo_enricher=_enricher
        )

    assert outcome.total == 1
    assert outcome.bans[0].country == "JP"
# ---------------------------------------------------------------------------
# Ignore list