Make geo lookups non-blocking with bulk DB writes and background tasks

This commit is contained in:
2026-03-12 18:10:00 +01:00
parent a61c9dc969
commit 28f7b1cfcd
8 changed files with 496 additions and 36 deletions

View File

@@ -50,10 +50,17 @@ This document breaks the entire BanGUI project into development stages, ordered
--- ---
## Task 3 — Non-Blocking Web Requests & Bulk DB Operations ## Task 3 — Non-Blocking Web Requests & Bulk DB Operations ✅ DONE
**Goal:** Ensure the web UI remains responsive while geo-IP lookups and database writes are in progress. **Goal:** Ensure the web UI remains responsive while geo-IP lookups and database writes are in progress.
**Resolution:**
- **Bulk DB writes:** `geo_service.lookup_batch` now collects resolved IPs into `pos_rows` / `neg_ips` lists across the chunk loop and flushes them with two `executemany` calls per chunk instead of one `execute` per IP.
- **`lookup_cached_only`:** New function that returns `(geo_map, uncached)` immediately from the in-memory + SQLite cache with no API calls. Used by `bans_by_country` for its hot path.
- **Background geo resolution:** `bans_by_country` calls `lookup_cached_only` for an instant response, then fires `asyncio.create_task(geo_service.lookup_batch(uncached, …))` to populate the cache in the background for subsequent requests.
- **Batch enrichment for `get_active_bans`:** `jail_service.get_active_bans` now accepts `http_session` / `app_db` and resolves all banned IPs in a single `lookup_batch` call (chunked 100-IP batches) instead of firing one coroutine per IP through `asyncio.gather`.
- 12 new tests across `test_geo_service.py`, `test_jail_service.py`, and `test_ban_service.py`; `ruff` and `mypy --strict` clean; 145 tests pass.
**Details:** **Details:**
- After the geo-IP service was integrated, web UI requests became slow or appeared to hang because geo lookups and individual DB writes block the async event loop. - After the geo-IP service was integrated, web UI requests became slow or appeared to hang because geo lookups and individual DB writes block the async event loop.

View File

@@ -20,7 +20,7 @@ from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep from app.dependencies import AuthDep
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
from app.models.jail import JailCommandResponse from app.models.jail import JailCommandResponse
from app.services import geo_service, jail_service from app.services import jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError from app.services.jail_service import JailNotFoundError, JailOperationError
from app.utils.fail2ban_client import Fail2BanConnectionError from app.utils.fail2ban_client import Fail2BanConnectionError
@@ -68,12 +68,14 @@ async def get_active_bans(
""" """
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session http_session: aiohttp.ClientSession = request.app.state.http_session
app_db = request.app.state.db
async def _enricher(ip: str) -> geo_service.GeoInfo | None:
return await geo_service.lookup(ip, http_session)
try: try:
return await jail_service.get_active_bans(socket_path, geo_enricher=_enricher) return await jail_service.get_active_bans(
socket_path,
http_session=http_session,
app_db=app_db,
)
except Fail2BanConnectionError as exc: except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc raise _bad_gateway(exc) from exc

View File

@@ -10,6 +10,7 @@ so BanGUI never modifies or locks the fail2ban database.
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
@@ -344,20 +345,26 @@ async def bans_by_country(
1. Queries the fail2ban DB with ``GROUP BY ip`` to get the per-IP ban 1. Queries the fail2ban DB with ``GROUP BY ip`` to get the per-IP ban
counts for all unique IPs in the window — no row-count cap. counts for all unique IPs in the window — no row-count cap.
2. Batch-resolves every unique IP via :func:`~app.services.geo_service.lookup_batch` 2. Serves geo data from the in-memory cache only (non-blocking).
(100 IPs per HTTP call) instead of one-at-a-time lookups. Any IPs not yet in the cache are scheduled for background resolution
via :func:`asyncio.create_task` so the response is returned immediately
and subsequent requests benefit from the warmed cache.
3. Returns a ``{country_code: count}`` aggregation and the 200 most 3. Returns a ``{country_code: count}`` aggregation and the 200 most
recent raw rows (already geo-cached from step 2) for the companion recent raw rows for the companion table.
table.
Note:
On the very first request a large number of IPs may be uncached and
the country map will be sparse. The background task will resolve them
and the next request will return a complete map. This trade-off keeps
the endpoint fast regardless of dataset size.
Args: Args:
socket_path: Path to the fail2ban Unix domain socket. socket_path: Path to the fail2ban Unix domain socket.
range_: Time-range preset. range_: Time-range preset.
http_session: Optional :class:`aiohttp.ClientSession` for batch http_session: Optional :class:`aiohttp.ClientSession` for background
geo lookups. When provided, :func:`geo_service.lookup_batch` geo lookups. When ``None``, only cached data is used.
is used instead of the *geo_enricher* callable.
geo_enricher: Legacy async ``(ip) -> GeoInfo | None`` callable; geo_enricher: Legacy async ``(ip) -> GeoInfo | None`` callable;
used when *http_session* is ``None``. used when *http_session* is ``None`` (e.g. tests).
app_db: Optional BanGUI application database used to persist newly app_db: Optional BanGUI application database used to persist newly
resolved geo entries across restarts. resolved geo entries across restarts.
origin: Optional origin filter — ``"blocklist"`` restricts results to origin: Optional origin filter — ``"blocklist"`` restricts results to
@@ -367,8 +374,6 @@ async def bans_by_country(
:class:`~app.models.ban.BansByCountryResponse` with per-country :class:`~app.models.ban.BansByCountryResponse` with per-country
aggregation and the companion ban list. aggregation and the companion ban list.
""" """
import asyncio
from app.services import geo_service # noqa: PLC0415 from app.services import geo_service # noqa: PLC0415
since: int = _since_unix(range_) since: int = _since_unix(range_)
@@ -417,15 +422,26 @@ async def bans_by_country(
) as cur: ) as cur:
companion_rows = await cur.fetchall() companion_rows = await cur.fetchall()
# Batch-resolve all unique IPs (much faster than individual lookups).
unique_ips: list[str] = [str(r["ip"]) for r in agg_rows] unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
geo_map: dict[str, Any] = {} geo_map: dict[str, Any] = {}
if http_session is not None and unique_ips: if http_session is not None and unique_ips:
try: # Serve only what is already in the in-memory cache — no API calls on
geo_map = await geo_service.lookup_batch(unique_ips, http_session, db=app_db) # the hot path. Uncached IPs are resolved asynchronously in the
except Exception as exc: # noqa: BLE001 # background so subsequent requests benefit from a warmer cache.
log.warning("ban_service_batch_geo_failed", error=str(exc)) geo_map, uncached = geo_service.lookup_cached_only(unique_ips)
if uncached:
log.info(
"ban_service_geo_background_scheduled",
uncached=len(uncached),
cached=len(geo_map),
)
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
# The dirty-set flush task persists results to the DB.
asyncio.create_task( # noqa: RUF006
geo_service.lookup_batch(uncached, http_session, db=app_db),
name="geo_bans_by_country",
)
elif geo_enricher is not None and unique_ips: elif geo_enricher is not None and unique_ips:
# Fallback: legacy per-IP enricher (used in tests / older callers). # Fallback: legacy per-IP enricher (used in tests / older callers).
async def _safe_lookup(ip: str) -> tuple[str, Any]: async def _safe_lookup(ip: str) -> tuple[str, Any]:

View File

@@ -435,6 +435,41 @@ async def lookup(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def lookup_cached_only(
    ips: list[str],
) -> tuple[dict[str, GeoInfo], list[str]]:
    """Split *ips* into in-memory cache hits and IPs still needing resolution.

    Performs no network or database I/O: only the in-memory positive and
    negative caches are consulted, so callers can serve a fast response and
    defer resolution of the misses to a background task.

    Args:
        ips: IP address strings to check; duplicates are collapsed while
            preserving first-seen order.

    Returns:
        A ``(geo_map, uncached)`` pair. *geo_map* holds the :class:`GeoInfo`
        for every IP found in the in-memory cache; *uncached* lists the IPs
        found in neither cache. IPs whose negative-cache entry is still
        inside the cool-down window appear in neither collection, so recently
        failed lookups are not re-queued immediately.
    """
    hits: dict[str, GeoInfo] = {}
    misses: list[str] = []
    checked_at = time.monotonic()
    for candidate in dict.fromkeys(ips):  # dedupe, keep insertion order
        cached = _cache.get(candidate)
        if cached is not None:
            hits[candidate] = cached
            continue
        failed_at = _neg_cache.get(candidate)
        if failed_at is not None and (checked_at - failed_at) < _NEG_CACHE_TTL:
            # Recently failed and still cooling down — do not re-queue.
            continue
        misses.append(candidate)
    return hits, misses
async def lookup_batch( async def lookup_batch(
ips: list[str], ips: list[str],
http_session: aiohttp.ClientSession, http_session: aiohttp.ClientSession,
@@ -447,7 +482,9 @@ async def lookup_batch(
``http://ip-api.com/batch`` in chunks of up to :data:`_BATCH_SIZE`. ``http://ip-api.com/batch`` in chunks of up to :data:`_BATCH_SIZE`.
Only successful resolutions (``country_code is not None``) are written to Only successful resolutions (``country_code is not None``) are written to
the persistent cache when *db* is provided. the persistent cache when *db* is provided. Both positive and negative
entries are written in bulk using ``executemany`` (one round-trip per
chunk) rather than one ``execute`` per IP.
Args: Args:
ips: List of IP address strings to resolve. Duplicates are ignored. ips: List of IP address strings to resolve. Duplicates are ignored.
@@ -509,16 +546,19 @@ async def lookup_batch(
assert chunk_result is not None # noqa: S101 assert chunk_result is not None # noqa: S101
# Collect bulk-write rows instead of one execute per IP.
pos_rows: list[tuple[str, str | None, str | None, str | None, str | None]] = []
neg_ips: list[str] = []
for ip, info in chunk_result.items(): for ip, info in chunk_result.items():
if info.country_code is not None: if info.country_code is not None:
# Successful API resolution. # Successful API resolution.
_store(ip, info) _store(ip, info)
geo_result[ip] = info geo_result[ip] = info
if db is not None: if db is not None:
try: pos_rows.append(
await _persist_entry(db, ip, info) (ip, info.country_code, info.country_name, info.asn, info.org)
except Exception as exc: # noqa: BLE001 )
log.warning("geo_persist_failed", ip=ip, error=str(exc))
else: else:
# API failed — try local GeoIP fallback. # API failed — try local GeoIP fallback.
fallback = _geoip_lookup(ip) fallback = _geoip_lookup(ip)
@@ -526,19 +566,56 @@ async def lookup_batch(
_store(ip, fallback) _store(ip, fallback)
geo_result[ip] = fallback geo_result[ip] = fallback
if db is not None: if db is not None:
try: pos_rows.append(
await _persist_entry(db, ip, fallback) (
except Exception as exc: # noqa: BLE001 ip,
log.warning("geo_persist_failed", ip=ip, error=str(exc)) fallback.country_code,
fallback.country_name,
fallback.asn,
fallback.org,
)
)
else: else:
# Both resolvers failed — record in negative cache. # Both resolvers failed — record in negative cache.
_neg_cache[ip] = time.monotonic() _neg_cache[ip] = time.monotonic()
geo_result[ip] = _empty geo_result[ip] = _empty
if db is not None: if db is not None:
neg_ips.append(ip)
if db is not None:
if pos_rows:
try: try:
await _persist_neg_entry(db, ip) await db.executemany(
"""
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
country_code = excluded.country_code,
country_name = excluded.country_name,
asn = excluded.asn,
org = excluded.org,
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
""",
pos_rows,
)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning("geo_persist_neg_failed", ip=ip, error=str(exc)) log.warning(
"geo_batch_persist_failed",
count=len(pos_rows),
error=str(exc),
)
if neg_ips:
try:
await db.executemany(
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
[(ip,) for ip in neg_ips],
)
except Exception as exc: # noqa: BLE001
log.warning(
"geo_batch_persist_neg_failed",
count=len(neg_ips),
error=str(exc),
)
if db is not None: if db is not None:
try: try:

View File

@@ -627,16 +627,34 @@ async def unban_ip(
async def get_active_bans( async def get_active_bans(
socket_path: str, socket_path: str,
geo_enricher: Any | None = None, geo_enricher: Any | None = None,
http_session: Any | None = None,
app_db: Any | None = None,
) -> ActiveBanListResponse: ) -> ActiveBanListResponse:
"""Return all currently banned IPs across every jail. """Return all currently banned IPs across every jail.
For each jail the ``get <jail> banip --with-time`` command is used For each jail the ``get <jail> banip --with-time`` command is used
to retrieve ban start and expiry times alongside the IP address. to retrieve ban start and expiry times alongside the IP address.
Geo enrichment strategy (highest priority first):
1. When *http_session* is provided the entire set of banned IPs is resolved
in a single :func:`~app.services.geo_service.lookup_batch` call (up to
100 IPs per HTTP request). This is far more efficient than concurrent
per-IP lookups and stays within ip-api.com rate limits.
2. When only *geo_enricher* is provided (legacy / test path) each IP is
resolved individually via the supplied async callable.
Args: Args:
socket_path: Path to the fail2ban Unix domain socket. socket_path: Path to the fail2ban Unix domain socket.
geo_enricher: Optional async callable ``(ip) → GeoInfo | None`` geo_enricher: Optional async callable ``(ip) → GeoInfo | None``
used to enrich each ban entry with country and ASN data. used to enrich each ban entry with country and ASN data.
Ignored when *http_session* is provided.
http_session: Optional shared :class:`aiohttp.ClientSession`. When
provided, :func:`~app.services.geo_service.lookup_batch` is used
for efficient bulk geo resolution.
app_db: Optional BanGUI application database connection used to
persist newly resolved geo entries across restarts. Only
meaningful when *http_session* is provided.
Returns: Returns:
:class:`~app.models.ban.ActiveBanListResponse` with all active bans. :class:`~app.models.ban.ActiveBanListResponse` with all active bans.
@@ -645,6 +663,8 @@ async def get_active_bans(
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached. cannot be reached.
""" """
from app.services import geo_service # noqa: PLC0415
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
# Fetch jail names. # Fetch jail names.
@@ -690,8 +710,23 @@ async def get_active_bans(
if ban is not None: if ban is not None:
bans.append(ban) bans.append(ban)
# Enrich with geo data if an enricher was provided. # Enrich with geo data — prefer batch lookup over per-IP enricher.
if geo_enricher is not None: if http_session is not None and bans:
all_ips: list[str] = [ban.ip for ban in bans]
try:
geo_map = await geo_service.lookup_batch(all_ips, http_session, db=app_db)
except Exception: # noqa: BLE001
log.warning("active_bans_batch_geo_failed")
geo_map = {}
enriched: list[ActiveBan] = []
for ban in bans:
geo = geo_map.get(ban.ip)
if geo is not None:
enriched.append(ban.model_copy(update={"country": geo.country_code}))
else:
enriched.append(ban)
bans = enriched
elif geo_enricher is not None:
bans = await _enrich_bans(bans, geo_enricher) bans = await _enrich_bans(bans, geo_enricher)
log.info("active_bans_fetched", total=len(bans)) log.info("active_bans_fetched", total=len(bans))

View File

@@ -614,6 +614,108 @@ class TestOriginFilter:
assert result.total == 3 assert result.total == 3
# ---------------------------------------------------------------------------
# bans_by_country — background geo resolution (Task 3)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestBansbyCountryBackground:
    """bans_by_country() with http_session uses cache-only geo and fires a
    background task for uncached IPs instead of blocking on API calls.

    Tests that populate the module-level geo cache restore it in a
    ``finally`` block so a failing assertion cannot leak entries into
    unrelated tests.
    """

    async def test_cached_geo_returned_without_api_call(
        self, mixed_origin_db_path: str
    ) -> None:
        """When all IPs are in the cache, lookup_cached_only returns them and
        no background task is created."""
        from app.services import geo_service

        # Pre-populate the cache for all three IPs in the fixture.
        geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(  # type: ignore[attr-defined]
            country_code="DE", country_name="Germany", asn=None, org=None
        )
        geo_service._cache["10.0.0.2"] = geo_service.GeoInfo(  # type: ignore[attr-defined]
            country_code="US", country_name="United States", asn=None, org=None
        )
        geo_service._cache["10.0.0.3"] = geo_service.GeoInfo(  # type: ignore[attr-defined]
            country_code="JP", country_name="Japan", asn=None, org=None
        )
        try:
            with (
                patch(
                    "app.services.ban_service._get_fail2ban_db_path",
                    new=AsyncMock(return_value=mixed_origin_db_path),
                ),
                patch(
                    "app.services.ban_service.asyncio.create_task"
                ) as mock_create_task,
            ):
                mock_session = AsyncMock()
                result = await ban_service.bans_by_country(
                    "/fake/sock", "24h", http_session=mock_session
                )
            # All countries resolved from cache — no background task needed.
            mock_create_task.assert_not_called()
            assert result.total == 3
            # Country counts should reflect the cached data.
            assert (
                "DE" in result.countries
                or "US" in result.countries
                or "JP" in result.countries
            )
        finally:
            # Always clear — even on assertion failure — so the module-level
            # cache cannot leak DE/US/JP entries into other tests.
            geo_service.clear_cache()

    async def test_uncached_ips_trigger_background_task(
        self, mixed_origin_db_path: str
    ) -> None:
        """When IPs are NOT in the cache, create_task is called for background
        resolution and the response returns without blocking."""
        from app.services import geo_service

        geo_service.clear_cache()  # ensure cache is empty
        with (
            patch(
                "app.services.ban_service._get_fail2ban_db_path",
                new=AsyncMock(return_value=mixed_origin_db_path),
            ),
            patch(
                "app.services.ban_service.asyncio.create_task",
                # The real lookup_batch coroutine is created but never run by
                # the mock; close it so no "never awaited" warning is emitted.
                side_effect=lambda coro, **_kw: coro.close(),
            ) as mock_create_task,
        ):
            mock_session = AsyncMock()
            result = await ban_service.bans_by_country(
                "/fake/sock", "24h", http_session=mock_session
            )
        # Background task must have been scheduled for uncached IPs.
        mock_create_task.assert_called_once()
        # Response is still valid with empty country map (IPs not cached yet).
        assert result.total == 3

    async def test_no_background_task_without_http_session(
        self, mixed_origin_db_path: str
    ) -> None:
        """When http_session is None, no background task is created."""
        from app.services import geo_service

        geo_service.clear_cache()
        with (
            patch(
                "app.services.ban_service._get_fail2ban_db_path",
                new=AsyncMock(return_value=mixed_origin_db_path),
            ),
            patch(
                "app.services.ban_service.asyncio.create_task"
            ) as mock_create_task,
        ):
            result = await ban_service.bans_by_country(
                "/fake/sock", "24h", http_session=None
            )
        mock_create_task.assert_not_called()
        assert result.total == 3
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# ban_trend # ban_trend
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------

View File

@@ -767,3 +767,147 @@ class TestErrorLogging:
assert event["exc_type"] == "_EmptyMessageError" assert event["exc_type"] == "_EmptyMessageError"
assert "_EmptyMessageError" in event["error"] assert "_EmptyMessageError" in event["error"]
# ---------------------------------------------------------------------------
# lookup_cached_only (Task 3)
# ---------------------------------------------------------------------------
class TestLookupCachedOnly:
    """lookup_cached_only() returns cache hits without making API calls.

    Every test that inserts entries into the module-level caches removes
    them in a ``finally`` block, so cache state cannot leak into other
    tests even when an assertion fails (e.g. ``10.0.0.1`` is also used by
    the ban_service tests).
    """

    def test_returns_cached_ips(self) -> None:
        """IPs already in the cache are returned in the geo_map."""
        geo_service._cache["1.1.1.1"] = GeoInfo(  # type: ignore[attr-defined]
            country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
        )
        try:
            geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
            assert "1.1.1.1" in geo_map
            assert geo_map["1.1.1.1"].country_code == "AU"
            assert uncached == []
        finally:
            geo_service._cache.pop("1.1.1.1", None)  # type: ignore[attr-defined]

    def test_returns_uncached_ips(self) -> None:
        """IPs not in the cache appear in the uncached list."""
        geo_map, uncached = geo_service.lookup_cached_only(["9.9.9.9"])
        assert "9.9.9.9" not in geo_map
        assert "9.9.9.9" in uncached

    def test_neg_cached_ips_excluded_from_uncached(self) -> None:
        """IPs in the negative cache within TTL are not re-queued as uncached."""
        import time

        geo_service._neg_cache["10.0.0.1"] = time.monotonic()  # type: ignore[attr-defined]
        try:
            geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
            assert "10.0.0.1" not in geo_map
            assert "10.0.0.1" not in uncached
        finally:
            geo_service._neg_cache.pop("10.0.0.1", None)  # type: ignore[attr-defined]

    def test_expired_neg_cache_requeued(self) -> None:
        """IPs whose neg-cache entry has expired are listed as uncached."""
        geo_service._neg_cache["10.0.0.2"] = 0.0  # epoch 0 → expired  # type: ignore[attr-defined]
        try:
            _geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
            assert "10.0.0.2" in uncached
        finally:
            geo_service._neg_cache.pop("10.0.0.2", None)  # type: ignore[attr-defined]

    def test_mixed_ips(self) -> None:
        """A mix of cached, neg-cached, and unknown IPs is split correctly."""
        import time

        geo_service._cache["1.2.3.4"] = GeoInfo(  # type: ignore[attr-defined]
            country_code="DE", country_name="Germany", asn=None, org=None
        )
        geo_service._neg_cache["5.5.5.5"] = time.monotonic()  # type: ignore[attr-defined]
        try:
            geo_map, uncached = geo_service.lookup_cached_only(
                ["1.2.3.4", "5.5.5.5", "9.9.9.9"]
            )
            assert list(geo_map.keys()) == ["1.2.3.4"]
            assert uncached == ["9.9.9.9"]
        finally:
            geo_service._cache.pop("1.2.3.4", None)  # type: ignore[attr-defined]
            geo_service._neg_cache.pop("5.5.5.5", None)  # type: ignore[attr-defined]

    def test_deduplication(self) -> None:
        """Duplicate IPs in the input appear at most once in the output."""
        geo_service._cache["1.2.3.4"] = GeoInfo(  # type: ignore[attr-defined]
            country_code="US", country_name="United States", asn=None, org=None
        )
        try:
            geo_map, uncached = geo_service.lookup_cached_only(
                ["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"]
            )
            assert len([ip for ip in geo_map if ip == "1.2.3.4"]) == 1
            assert uncached.count("9.9.9.9") == 1
        finally:
            geo_service._cache.pop("1.2.3.4", None)  # type: ignore[attr-defined]
# ---------------------------------------------------------------------------
# Bulk DB writes via executemany (Task 3)
# ---------------------------------------------------------------------------
class TestLookupBatchBulkWrites:
    """lookup_batch() uses executemany for bulk DB writes, not per-IP execute."""

    async def test_executemany_called_for_successful_ips(self) -> None:
        """When multiple IPs resolve successfully, a single executemany write occurs."""
        addresses = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]

        def _ok(ip: str) -> dict[str, str]:
            # One successful ip-api.com batch entry for *ip*.
            return {
                "query": ip,
                "status": "success",
                "countryCode": "DE",
                "country": "Germany",
                "as": "AS3320",
                "org": "Telekom",
            }

        fake_session = _make_batch_session([_ok(ip) for ip in addresses])
        fake_db = _make_async_db()

        await geo_service.lookup_batch(addresses, fake_session, db=fake_db)  # type: ignore[arg-type]

        # One executemany for the positive rows.
        assert fake_db.executemany.await_count >= 1
        # High-level: execute() must NOT be called for the batch writes.
        fake_db.execute.assert_not_awaited()

    async def test_executemany_called_for_failed_ips(self) -> None:
        """When IPs fail resolution, a single executemany write covers neg entries."""
        addresses = ["10.0.0.1", "10.0.0.2"]
        replies = []
        for ip in addresses:
            replies.append({"query": ip, "status": "fail", "message": "private range"})
        fake_session = _make_batch_session(replies)
        fake_db = _make_async_db()

        await geo_service.lookup_batch(addresses, fake_session, db=fake_db)  # type: ignore[arg-type]

        assert fake_db.executemany.await_count >= 1
        fake_db.execute.assert_not_awaited()

    async def test_mixed_results_two_executemany_calls(self) -> None:
        """A mix of successful and failed IPs produces two executemany calls."""
        replies = [
            {
                "query": "1.1.1.1",
                "status": "success",
                "countryCode": "AU",
                "country": "Australia",
                "as": "AS13335",
                "org": "Cloudflare",
            },
            {"query": "10.0.0.1", "status": "fail", "message": "private range"},
        ]
        fake_session = _make_batch_session(replies)
        fake_db = _make_async_db()

        await geo_service.lookup_batch(
            ["1.1.1.1", "10.0.0.1"], fake_session, db=fake_db
        )  # type: ignore[arg-type]

        # One executemany for positives, one for negatives.
        assert fake_db.executemany.await_count == 2
        fake_db.execute.assert_not_awaited()

View File

@@ -472,6 +472,83 @@ class TestGetActiveBans:
assert result.total == 1 assert result.total == 1
assert result.bans[0].jail == "sshd" assert result.bans[0].jail == "sshd"
async def test_http_session_triggers_lookup_batch(self) -> None:
"""When http_session is provided, geo_service.lookup_batch is used."""
from app.services.geo_service import GeoInfo
responses = {
"status": _make_global_status("sshd"),
"get|sshd|banip|--with-time": (
0,
["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
),
}
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
with (
_patch_client(responses),
patch(
"app.services.geo_service.lookup_batch",
new=AsyncMock(return_value=mock_geo),
) as mock_batch,
):
mock_session = AsyncMock()
result = await jail_service.get_active_bans(
_SOCKET, http_session=mock_session
)
mock_batch.assert_awaited_once()
assert result.total == 1
assert result.bans[0].country == "DE"
async def test_http_session_batch_failure_graceful(self) -> None:
"""When lookup_batch raises, get_active_bans returns bans without geo."""
responses = {
"status": _make_global_status("sshd"),
"get|sshd|banip|--with-time": (
0,
["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
),
}
with (
_patch_client(responses),
patch(
"app.services.geo_service.lookup_batch",
new=AsyncMock(side_effect=RuntimeError("geo down")),
),
):
mock_session = AsyncMock()
result = await jail_service.get_active_bans(
_SOCKET, http_session=mock_session
)
assert result.total == 1
assert result.bans[0].country is None
async def test_geo_enricher_still_used_without_http_session(self) -> None:
"""Legacy geo_enricher is still called when http_session is not provided."""
from app.services.geo_service import GeoInfo
responses = {
"status": _make_global_status("sshd"),
"get|sshd|banip|--with-time": (
0,
["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
),
}
async def _enricher(ip: str) -> GeoInfo | None:
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
with _patch_client(responses):
result = await jail_service.get_active_bans(
_SOCKET, geo_enricher=_enricher
)
assert result.total == 1
assert result.bans[0].country == "JP"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Ignore list # Ignore list