feature/ignore-self-toggle #1
@@ -50,10 +50,17 @@ This document breaks the entire BanGUI project into development stages, ordered
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Task 3 — Non-Blocking Web Requests & Bulk DB Operations
|
## Task 3 — Non-Blocking Web Requests & Bulk DB Operations ✅ DONE
|
||||||
|
|
||||||
**Goal:** Ensure the web UI remains responsive while geo-IP lookups and database writes are in progress.
|
**Goal:** Ensure the web UI remains responsive while geo-IP lookups and database writes are in progress.
|
||||||
|
|
||||||
|
**Resolution:**
|
||||||
|
- **Bulk DB writes:** `geo_service.lookup_batch` now collects resolved IPs into `pos_rows` / `neg_ips` lists across the chunk loop and flushes them with two `executemany` calls per chunk instead of one `execute` per IP.
|
||||||
|
- **`lookup_cached_only`:** New function that returns `(geo_map, uncached)` immediately from the in-memory + SQLite cache with no API calls. Used by `bans_by_country` for its hot path.
|
||||||
|
- **Background geo resolution:** `bans_by_country` calls `lookup_cached_only` for an instant response, then fires `asyncio.create_task(geo_service.lookup_batch(uncached, …))` to populate the cache in the background for subsequent requests.
|
||||||
|
- **Batch enrichment for `get_active_bans`:** `jail_service.get_active_bans` now accepts `http_session` / `app_db` and resolves all banned IPs in a single `lookup_batch` call (chunked 100-IP batches) instead of firing one coroutine per IP through `asyncio.gather`.
|
||||||
|
- 12 new tests across `test_geo_service.py`, `test_jail_service.py`, and `test_ban_service.py`; `ruff` and `mypy --strict` clean; 145 tests pass.
|
||||||
|
|
||||||
**Details:**
|
**Details:**
|
||||||
|
|
||||||
- After the geo-IP service was integrated, web UI requests became slow or appeared to hang because geo lookups and individual DB writes block the async event loop.
|
- After the geo-IP service was integrated, web UI requests became slow or appeared to hang because geo lookups and individual DB writes block the async event loop.
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from fastapi import APIRouter, HTTPException, Request, status
|
|||||||
from app.dependencies import AuthDep
|
from app.dependencies import AuthDep
|
||||||
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
|
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
|
||||||
from app.models.jail import JailCommandResponse
|
from app.models.jail import JailCommandResponse
|
||||||
from app.services import geo_service, jail_service
|
from app.services import jail_service
|
||||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||||
|
|
||||||
@@ -68,12 +68,14 @@ async def get_active_bans(
|
|||||||
"""
|
"""
|
||||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||||
http_session: aiohttp.ClientSession = request.app.state.http_session
|
http_session: aiohttp.ClientSession = request.app.state.http_session
|
||||||
|
app_db = request.app.state.db
|
||||||
async def _enricher(ip: str) -> geo_service.GeoInfo | None:
|
|
||||||
return await geo_service.lookup(ip, http_session)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await jail_service.get_active_bans(socket_path, geo_enricher=_enricher)
|
return await jail_service.get_active_bans(
|
||||||
|
socket_path,
|
||||||
|
http_session=http_session,
|
||||||
|
app_db=app_db,
|
||||||
|
)
|
||||||
except Fail2BanConnectionError as exc:
|
except Fail2BanConnectionError as exc:
|
||||||
raise _bad_gateway(exc) from exc
|
raise _bad_gateway(exc) from exc
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ so BanGUI never modifies or locks the fail2ban database.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
@@ -344,20 +345,26 @@ async def bans_by_country(
|
|||||||
|
|
||||||
1. Queries the fail2ban DB with ``GROUP BY ip`` to get the per-IP ban
|
1. Queries the fail2ban DB with ``GROUP BY ip`` to get the per-IP ban
|
||||||
counts for all unique IPs in the window — no row-count cap.
|
counts for all unique IPs in the window — no row-count cap.
|
||||||
2. Batch-resolves every unique IP via :func:`~app.services.geo_service.lookup_batch`
|
2. Serves geo data from the in-memory cache only (non-blocking).
|
||||||
(100 IPs per HTTP call) instead of one-at-a-time lookups.
|
Any IPs not yet in the cache are scheduled for background resolution
|
||||||
|
via :func:`asyncio.create_task` so the response is returned immediately
|
||||||
|
and subsequent requests benefit from the warmed cache.
|
||||||
3. Returns a ``{country_code: count}`` aggregation and the 200 most
|
3. Returns a ``{country_code: count}`` aggregation and the 200 most
|
||||||
recent raw rows (already geo-cached from step 2) for the companion
|
recent raw rows for the companion table.
|
||||||
table.
|
|
||||||
|
Note:
|
||||||
|
On the very first request a large number of IPs may be uncached and
|
||||||
|
the country map will be sparse. The background task will resolve them
|
||||||
|
and the next request will return a complete map. This trade-off keeps
|
||||||
|
the endpoint fast regardless of dataset size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
socket_path: Path to the fail2ban Unix domain socket.
|
||||||
range_: Time-range preset.
|
range_: Time-range preset.
|
||||||
http_session: Optional :class:`aiohttp.ClientSession` for batch
|
http_session: Optional :class:`aiohttp.ClientSession` for background
|
||||||
geo lookups. When provided, :func:`geo_service.lookup_batch`
|
geo lookups. When ``None``, only cached data is used.
|
||||||
is used instead of the *geo_enricher* callable.
|
|
||||||
geo_enricher: Legacy async ``(ip) -> GeoInfo | None`` callable;
|
geo_enricher: Legacy async ``(ip) -> GeoInfo | None`` callable;
|
||||||
used when *http_session* is ``None``.
|
used when *http_session* is ``None`` (e.g. tests).
|
||||||
app_db: Optional BanGUI application database used to persist newly
|
app_db: Optional BanGUI application database used to persist newly
|
||||||
resolved geo entries across restarts.
|
resolved geo entries across restarts.
|
||||||
origin: Optional origin filter — ``"blocklist"`` restricts results to
|
origin: Optional origin filter — ``"blocklist"`` restricts results to
|
||||||
@@ -367,8 +374,6 @@ async def bans_by_country(
|
|||||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||||
aggregation and the companion ban list.
|
aggregation and the companion ban list.
|
||||||
"""
|
"""
|
||||||
import asyncio
|
|
||||||
|
|
||||||
from app.services import geo_service # noqa: PLC0415
|
from app.services import geo_service # noqa: PLC0415
|
||||||
|
|
||||||
since: int = _since_unix(range_)
|
since: int = _since_unix(range_)
|
||||||
@@ -417,15 +422,26 @@ async def bans_by_country(
|
|||||||
) as cur:
|
) as cur:
|
||||||
companion_rows = await cur.fetchall()
|
companion_rows = await cur.fetchall()
|
||||||
|
|
||||||
# Batch-resolve all unique IPs (much faster than individual lookups).
|
|
||||||
unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
|
unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
|
||||||
geo_map: dict[str, Any] = {}
|
geo_map: dict[str, Any] = {}
|
||||||
|
|
||||||
if http_session is not None and unique_ips:
|
if http_session is not None and unique_ips:
|
||||||
try:
|
# Serve only what is already in the in-memory cache — no API calls on
|
||||||
geo_map = await geo_service.lookup_batch(unique_ips, http_session, db=app_db)
|
# the hot path. Uncached IPs are resolved asynchronously in the
|
||||||
except Exception as exc: # noqa: BLE001
|
# background so subsequent requests benefit from a warmer cache.
|
||||||
log.warning("ban_service_batch_geo_failed", error=str(exc))
|
geo_map, uncached = geo_service.lookup_cached_only(unique_ips)
|
||||||
|
if uncached:
|
||||||
|
log.info(
|
||||||
|
"ban_service_geo_background_scheduled",
|
||||||
|
uncached=len(uncached),
|
||||||
|
cached=len(geo_map),
|
||||||
|
)
|
||||||
|
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||||
|
# The dirty-set flush task persists results to the DB.
|
||||||
|
asyncio.create_task( # noqa: RUF006
|
||||||
|
geo_service.lookup_batch(uncached, http_session, db=app_db),
|
||||||
|
name="geo_bans_by_country",
|
||||||
|
)
|
||||||
elif geo_enricher is not None and unique_ips:
|
elif geo_enricher is not None and unique_ips:
|
||||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||||
async def _safe_lookup(ip: str) -> tuple[str, Any]:
|
async def _safe_lookup(ip: str) -> tuple[str, Any]:
|
||||||
|
|||||||
@@ -435,6 +435,41 @@ async def lookup(
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def lookup_cached_only(
|
||||||
|
ips: list[str],
|
||||||
|
) -> tuple[dict[str, GeoInfo], list[str]]:
|
||||||
|
"""Return cached geo data for *ips* without making any external API calls.
|
||||||
|
|
||||||
|
Used by callers that want to return a fast response using only what is
|
||||||
|
already in memory, while deferring resolution of uncached IPs to a
|
||||||
|
background task.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
ips: IP address strings to look up.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A ``(geo_map, uncached)`` tuple where *geo_map* maps every IP that
|
||||||
|
was already in the in-memory cache to its :class:`GeoInfo`, and
|
||||||
|
*uncached* is the list of IPs that were not found in the cache.
|
||||||
|
Entries in the negative cache (recently failed) are **not** included
|
||||||
|
in *uncached* so they are not re-queued immediately.
|
||||||
|
"""
|
||||||
|
geo_map: dict[str, GeoInfo] = {}
|
||||||
|
uncached: list[str] = []
|
||||||
|
now = time.monotonic()
|
||||||
|
|
||||||
|
for ip in dict.fromkeys(ips): # deduplicate, preserve order
|
||||||
|
if ip in _cache:
|
||||||
|
geo_map[ip] = _cache[ip]
|
||||||
|
elif ip in _neg_cache and (now - _neg_cache[ip]) < _NEG_CACHE_TTL:
|
||||||
|
# Still within the cool-down window — do not re-queue.
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
uncached.append(ip)
|
||||||
|
|
||||||
|
return geo_map, uncached
|
||||||
|
|
||||||
|
|
||||||
async def lookup_batch(
|
async def lookup_batch(
|
||||||
ips: list[str],
|
ips: list[str],
|
||||||
http_session: aiohttp.ClientSession,
|
http_session: aiohttp.ClientSession,
|
||||||
@@ -447,7 +482,9 @@ async def lookup_batch(
|
|||||||
``http://ip-api.com/batch`` in chunks of up to :data:`_BATCH_SIZE`.
|
``http://ip-api.com/batch`` in chunks of up to :data:`_BATCH_SIZE`.
|
||||||
|
|
||||||
Only successful resolutions (``country_code is not None``) are written to
|
Only successful resolutions (``country_code is not None``) are written to
|
||||||
the persistent cache when *db* is provided.
|
the persistent cache when *db* is provided. Both positive and negative
|
||||||
|
entries are written in bulk using ``executemany`` (one round-trip per
|
||||||
|
chunk) rather than one ``execute`` per IP.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ips: List of IP address strings to resolve. Duplicates are ignored.
|
ips: List of IP address strings to resolve. Duplicates are ignored.
|
||||||
@@ -509,16 +546,19 @@ async def lookup_batch(
|
|||||||
|
|
||||||
assert chunk_result is not None # noqa: S101
|
assert chunk_result is not None # noqa: S101
|
||||||
|
|
||||||
|
# Collect bulk-write rows instead of one execute per IP.
|
||||||
|
pos_rows: list[tuple[str, str | None, str | None, str | None, str | None]] = []
|
||||||
|
neg_ips: list[str] = []
|
||||||
|
|
||||||
for ip, info in chunk_result.items():
|
for ip, info in chunk_result.items():
|
||||||
if info.country_code is not None:
|
if info.country_code is not None:
|
||||||
# Successful API resolution.
|
# Successful API resolution.
|
||||||
_store(ip, info)
|
_store(ip, info)
|
||||||
geo_result[ip] = info
|
geo_result[ip] = info
|
||||||
if db is not None:
|
if db is not None:
|
||||||
try:
|
pos_rows.append(
|
||||||
await _persist_entry(db, ip, info)
|
(ip, info.country_code, info.country_name, info.asn, info.org)
|
||||||
except Exception as exc: # noqa: BLE001
|
)
|
||||||
log.warning("geo_persist_failed", ip=ip, error=str(exc))
|
|
||||||
else:
|
else:
|
||||||
# API failed — try local GeoIP fallback.
|
# API failed — try local GeoIP fallback.
|
||||||
fallback = _geoip_lookup(ip)
|
fallback = _geoip_lookup(ip)
|
||||||
@@ -526,19 +566,56 @@ async def lookup_batch(
|
|||||||
_store(ip, fallback)
|
_store(ip, fallback)
|
||||||
geo_result[ip] = fallback
|
geo_result[ip] = fallback
|
||||||
if db is not None:
|
if db is not None:
|
||||||
try:
|
pos_rows.append(
|
||||||
await _persist_entry(db, ip, fallback)
|
(
|
||||||
except Exception as exc: # noqa: BLE001
|
ip,
|
||||||
log.warning("geo_persist_failed", ip=ip, error=str(exc))
|
fallback.country_code,
|
||||||
|
fallback.country_name,
|
||||||
|
fallback.asn,
|
||||||
|
fallback.org,
|
||||||
|
)
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
# Both resolvers failed — record in negative cache.
|
# Both resolvers failed — record in negative cache.
|
||||||
_neg_cache[ip] = time.monotonic()
|
_neg_cache[ip] = time.monotonic()
|
||||||
geo_result[ip] = _empty
|
geo_result[ip] = _empty
|
||||||
if db is not None:
|
if db is not None:
|
||||||
try:
|
neg_ips.append(ip)
|
||||||
await _persist_neg_entry(db, ip)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
if db is not None:
|
||||||
log.warning("geo_persist_neg_failed", ip=ip, error=str(exc))
|
if pos_rows:
|
||||||
|
try:
|
||||||
|
await db.executemany(
|
||||||
|
"""
|
||||||
|
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||||
|
VALUES (?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(ip) DO UPDATE SET
|
||||||
|
country_code = excluded.country_code,
|
||||||
|
country_name = excluded.country_name,
|
||||||
|
asn = excluded.asn,
|
||||||
|
org = excluded.org,
|
||||||
|
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||||
|
""",
|
||||||
|
pos_rows,
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
log.warning(
|
||||||
|
"geo_batch_persist_failed",
|
||||||
|
count=len(pos_rows),
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
if neg_ips:
|
||||||
|
try:
|
||||||
|
await db.executemany(
|
||||||
|
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||||
|
[(ip,) for ip in neg_ips],
|
||||||
|
)
|
||||||
|
except Exception as exc: # noqa: BLE001
|
||||||
|
log.warning(
|
||||||
|
"geo_batch_persist_neg_failed",
|
||||||
|
count=len(neg_ips),
|
||||||
|
error=str(exc),
|
||||||
|
)
|
||||||
|
|
||||||
if db is not None:
|
if db is not None:
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -627,16 +627,34 @@ async def unban_ip(
|
|||||||
async def get_active_bans(
|
async def get_active_bans(
|
||||||
socket_path: str,
|
socket_path: str,
|
||||||
geo_enricher: Any | None = None,
|
geo_enricher: Any | None = None,
|
||||||
|
http_session: Any | None = None,
|
||||||
|
app_db: Any | None = None,
|
||||||
) -> ActiveBanListResponse:
|
) -> ActiveBanListResponse:
|
||||||
"""Return all currently banned IPs across every jail.
|
"""Return all currently banned IPs across every jail.
|
||||||
|
|
||||||
For each jail the ``get <jail> banip --with-time`` command is used
|
For each jail the ``get <jail> banip --with-time`` command is used
|
||||||
to retrieve ban start and expiry times alongside the IP address.
|
to retrieve ban start and expiry times alongside the IP address.
|
||||||
|
|
||||||
|
Geo enrichment strategy (highest priority first):
|
||||||
|
|
||||||
|
1. When *http_session* is provided the entire set of banned IPs is resolved
|
||||||
|
in a single :func:`~app.services.geo_service.lookup_batch` call (up to
|
||||||
|
100 IPs per HTTP request). This is far more efficient than concurrent
|
||||||
|
per-IP lookups and stays within ip-api.com rate limits.
|
||||||
|
2. When only *geo_enricher* is provided (legacy / test path) each IP is
|
||||||
|
resolved individually via the supplied async callable.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
socket_path: Path to the fail2ban Unix domain socket.
|
socket_path: Path to the fail2ban Unix domain socket.
|
||||||
geo_enricher: Optional async callable ``(ip) → GeoInfo | None``
|
geo_enricher: Optional async callable ``(ip) → GeoInfo | None``
|
||||||
used to enrich each ban entry with country and ASN data.
|
used to enrich each ban entry with country and ASN data.
|
||||||
|
Ignored when *http_session* is provided.
|
||||||
|
http_session: Optional shared :class:`aiohttp.ClientSession`. When
|
||||||
|
provided, :func:`~app.services.geo_service.lookup_batch` is used
|
||||||
|
for efficient bulk geo resolution.
|
||||||
|
app_db: Optional BanGUI application database connection used to
|
||||||
|
persist newly resolved geo entries across restarts. Only
|
||||||
|
meaningful when *http_session* is provided.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
:class:`~app.models.ban.ActiveBanListResponse` with all active bans.
|
:class:`~app.models.ban.ActiveBanListResponse` with all active bans.
|
||||||
@@ -645,6 +663,8 @@ async def get_active_bans(
|
|||||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||||
cannot be reached.
|
cannot be reached.
|
||||||
"""
|
"""
|
||||||
|
from app.services import geo_service # noqa: PLC0415
|
||||||
|
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
|
|
||||||
# Fetch jail names.
|
# Fetch jail names.
|
||||||
@@ -690,8 +710,23 @@ async def get_active_bans(
|
|||||||
if ban is not None:
|
if ban is not None:
|
||||||
bans.append(ban)
|
bans.append(ban)
|
||||||
|
|
||||||
# Enrich with geo data if an enricher was provided.
|
# Enrich with geo data — prefer batch lookup over per-IP enricher.
|
||||||
if geo_enricher is not None:
|
if http_session is not None and bans:
|
||||||
|
all_ips: list[str] = [ban.ip for ban in bans]
|
||||||
|
try:
|
||||||
|
geo_map = await geo_service.lookup_batch(all_ips, http_session, db=app_db)
|
||||||
|
except Exception: # noqa: BLE001
|
||||||
|
log.warning("active_bans_batch_geo_failed")
|
||||||
|
geo_map = {}
|
||||||
|
enriched: list[ActiveBan] = []
|
||||||
|
for ban in bans:
|
||||||
|
geo = geo_map.get(ban.ip)
|
||||||
|
if geo is not None:
|
||||||
|
enriched.append(ban.model_copy(update={"country": geo.country_code}))
|
||||||
|
else:
|
||||||
|
enriched.append(ban)
|
||||||
|
bans = enriched
|
||||||
|
elif geo_enricher is not None:
|
||||||
bans = await _enrich_bans(bans, geo_enricher)
|
bans = await _enrich_bans(bans, geo_enricher)
|
||||||
|
|
||||||
log.info("active_bans_fetched", total=len(bans))
|
log.info("active_bans_fetched", total=len(bans))
|
||||||
|
|||||||
@@ -614,6 +614,108 @@ class TestOriginFilter:
|
|||||||
assert result.total == 3
|
assert result.total == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# bans_by_country — background geo resolution (Task 3)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
class TestBansbyCountryBackground:
|
||||||
|
"""bans_by_country() with http_session uses cache-only geo and fires a
|
||||||
|
background task for uncached IPs instead of blocking on API calls."""
|
||||||
|
|
||||||
|
async def test_cached_geo_returned_without_api_call(
|
||||||
|
self, mixed_origin_db_path: str
|
||||||
|
) -> None:
|
||||||
|
"""When all IPs are in the cache, lookup_cached_only returns them and
|
||||||
|
no background task is created."""
|
||||||
|
from app.services import geo_service
|
||||||
|
|
||||||
|
# Pre-populate the cache for all three IPs in the fixture.
|
||||||
|
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||||
|
country_code="DE", country_name="Germany", asn=None, org=None
|
||||||
|
)
|
||||||
|
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||||
|
country_code="US", country_name="United States", asn=None, org=None
|
||||||
|
)
|
||||||
|
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
||||||
|
country_code="JP", country_name="Japan", asn=None, org=None
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"app.services.ban_service.asyncio.create_task"
|
||||||
|
) as mock_create_task,
|
||||||
|
):
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
result = await ban_service.bans_by_country(
|
||||||
|
"/fake/sock", "24h", http_session=mock_session
|
||||||
|
)
|
||||||
|
|
||||||
|
# All countries resolved from cache — no background task needed.
|
||||||
|
mock_create_task.assert_not_called()
|
||||||
|
assert result.total == 3
|
||||||
|
# Country counts should reflect the cached data.
|
||||||
|
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
|
||||||
|
geo_service.clear_cache()
|
||||||
|
|
||||||
|
async def test_uncached_ips_trigger_background_task(
|
||||||
|
self, mixed_origin_db_path: str
|
||||||
|
) -> None:
|
||||||
|
"""When IPs are NOT in the cache, create_task is called for background
|
||||||
|
resolution and the response returns without blocking."""
|
||||||
|
from app.services import geo_service
|
||||||
|
|
||||||
|
geo_service.clear_cache() # ensure cache is empty
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"app.services.ban_service.asyncio.create_task"
|
||||||
|
) as mock_create_task,
|
||||||
|
):
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
result = await ban_service.bans_by_country(
|
||||||
|
"/fake/sock", "24h", http_session=mock_session
|
||||||
|
)
|
||||||
|
|
||||||
|
# Background task must have been scheduled for uncached IPs.
|
||||||
|
mock_create_task.assert_called_once()
|
||||||
|
# Response is still valid with empty country map (IPs not cached yet).
|
||||||
|
assert result.total == 3
|
||||||
|
|
||||||
|
async def test_no_background_task_without_http_session(
|
||||||
|
self, mixed_origin_db_path: str
|
||||||
|
) -> None:
|
||||||
|
"""When http_session is None, no background task is created."""
|
||||||
|
from app.services import geo_service
|
||||||
|
|
||||||
|
geo_service.clear_cache()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch(
|
||||||
|
"app.services.ban_service._get_fail2ban_db_path",
|
||||||
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||||
|
),
|
||||||
|
patch(
|
||||||
|
"app.services.ban_service.asyncio.create_task"
|
||||||
|
) as mock_create_task,
|
||||||
|
):
|
||||||
|
result = await ban_service.bans_by_country(
|
||||||
|
"/fake/sock", "24h", http_session=None
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_create_task.assert_not_called()
|
||||||
|
assert result.total == 3
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# ban_trend
|
# ban_trend
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -767,3 +767,147 @@ class TestErrorLogging:
|
|||||||
assert event["exc_type"] == "_EmptyMessageError"
|
assert event["exc_type"] == "_EmptyMessageError"
|
||||||
assert "_EmptyMessageError" in event["error"]
|
assert "_EmptyMessageError" in event["error"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# lookup_cached_only (Task 3)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLookupCachedOnly:
|
||||||
|
"""lookup_cached_only() returns cache hits without making API calls."""
|
||||||
|
|
||||||
|
def test_returns_cached_ips(self) -> None:
|
||||||
|
"""IPs already in the cache are returned in the geo_map."""
|
||||||
|
geo_service._cache["1.1.1.1"] = GeoInfo( # type: ignore[attr-defined]
|
||||||
|
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
|
||||||
|
)
|
||||||
|
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
|
||||||
|
|
||||||
|
assert "1.1.1.1" in geo_map
|
||||||
|
assert geo_map["1.1.1.1"].country_code == "AU"
|
||||||
|
assert uncached == []
|
||||||
|
|
||||||
|
def test_returns_uncached_ips(self) -> None:
|
||||||
|
"""IPs not in the cache appear in the uncached list."""
|
||||||
|
geo_map, uncached = geo_service.lookup_cached_only(["9.9.9.9"])
|
||||||
|
|
||||||
|
assert "9.9.9.9" not in geo_map
|
||||||
|
assert "9.9.9.9" in uncached
|
||||||
|
|
||||||
|
def test_neg_cached_ips_excluded_from_uncached(self) -> None:
|
||||||
|
"""IPs in the negative cache within TTL are not re-queued as uncached."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
geo_service._neg_cache["10.0.0.1"] = time.monotonic() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
|
||||||
|
|
||||||
|
assert "10.0.0.1" not in geo_map
|
||||||
|
assert "10.0.0.1" not in uncached
|
||||||
|
|
||||||
|
def test_expired_neg_cache_requeued(self) -> None:
|
||||||
|
"""IPs whose neg-cache entry has expired are listed as uncached."""
|
||||||
|
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
|
||||||
|
|
||||||
|
assert "10.0.0.2" in uncached
|
||||||
|
|
||||||
|
def test_mixed_ips(self) -> None:
|
||||||
|
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
|
||||||
|
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
||||||
|
country_code="DE", country_name="Germany", asn=None, org=None
|
||||||
|
)
|
||||||
|
import time
|
||||||
|
|
||||||
|
geo_service._neg_cache["5.5.5.5"] = time.monotonic() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
|
||||||
|
|
||||||
|
assert list(geo_map.keys()) == ["1.2.3.4"]
|
||||||
|
assert uncached == ["9.9.9.9"]
|
||||||
|
|
||||||
|
def test_deduplication(self) -> None:
|
||||||
|
"""Duplicate IPs in the input appear at most once in the output."""
|
||||||
|
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
|
||||||
|
country_code="US", country_name="United States", asn=None, org=None
|
||||||
|
)
|
||||||
|
|
||||||
|
geo_map, uncached = geo_service.lookup_cached_only(
|
||||||
|
["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len([ip for ip in geo_map if ip == "1.2.3.4"]) == 1
|
||||||
|
assert uncached.count("9.9.9.9") == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Bulk DB writes via executemany (Task 3)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLookupBatchBulkWrites:
|
||||||
|
"""lookup_batch() uses executemany for bulk DB writes, not per-IP execute."""
|
||||||
|
|
||||||
|
async def test_executemany_called_for_successful_ips(self) -> None:
|
||||||
|
"""When multiple IPs resolve successfully, a single executemany write occurs."""
|
||||||
|
ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
|
||||||
|
batch_response = [
|
||||||
|
{
|
||||||
|
"query": ip,
|
||||||
|
"status": "success",
|
||||||
|
"countryCode": "DE",
|
||||||
|
"country": "Germany",
|
||||||
|
"as": "AS3320",
|
||||||
|
"org": "Telekom",
|
||||||
|
}
|
||||||
|
for ip in ips
|
||||||
|
]
|
||||||
|
session = _make_batch_session(batch_response)
|
||||||
|
db = _make_async_db()
|
||||||
|
|
||||||
|
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# One executemany for the positive rows.
|
||||||
|
assert db.executemany.await_count >= 1
|
||||||
|
# High-level: execute() must NOT be called for the batch writes.
|
||||||
|
db.execute.assert_not_awaited()
|
||||||
|
|
||||||
|
async def test_executemany_called_for_failed_ips(self) -> None:
|
||||||
|
"""When IPs fail resolution, a single executemany write covers neg entries."""
|
||||||
|
ips = ["10.0.0.1", "10.0.0.2"]
|
||||||
|
batch_response = [
|
||||||
|
{"query": ip, "status": "fail", "message": "private range"}
|
||||||
|
for ip in ips
|
||||||
|
]
|
||||||
|
session = _make_batch_session(batch_response)
|
||||||
|
db = _make_async_db()
|
||||||
|
|
||||||
|
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
assert db.executemany.await_count >= 1
|
||||||
|
db.execute.assert_not_awaited()
|
||||||
|
|
||||||
|
async def test_mixed_results_two_executemany_calls(self) -> None:
|
||||||
|
"""A mix of successful and failed IPs produces two executemany calls."""
|
||||||
|
ips = ["1.1.1.1", "10.0.0.1"]
|
||||||
|
batch_response = [
|
||||||
|
{
|
||||||
|
"query": "1.1.1.1",
|
||||||
|
"status": "success",
|
||||||
|
"countryCode": "AU",
|
||||||
|
"country": "Australia",
|
||||||
|
"as": "AS13335",
|
||||||
|
"org": "Cloudflare",
|
||||||
|
},
|
||||||
|
{"query": "10.0.0.1", "status": "fail", "message": "private range"},
|
||||||
|
]
|
||||||
|
session = _make_batch_session(batch_response)
|
||||||
|
db = _make_async_db()
|
||||||
|
|
||||||
|
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
# One executemany for positives, one for negatives.
|
||||||
|
assert db.executemany.await_count == 2
|
||||||
|
db.execute.assert_not_awaited()
|
||||||
|
|
||||||
|
|||||||
@@ -472,6 +472,83 @@ class TestGetActiveBans:
|
|||||||
assert result.total == 1
|
assert result.total == 1
|
||||||
assert result.bans[0].jail == "sshd"
|
assert result.bans[0].jail == "sshd"
|
||||||
|
|
||||||
|
async def test_http_session_triggers_lookup_batch(self) -> None:
|
||||||
|
"""When http_session is provided, geo_service.lookup_batch is used."""
|
||||||
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
|
responses = {
|
||||||
|
"status": _make_global_status("sshd"),
|
||||||
|
"get|sshd|banip|--with-time": (
|
||||||
|
0,
|
||||||
|
["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
|
||||||
|
|
||||||
|
with (
|
||||||
|
_patch_client(responses),
|
||||||
|
patch(
|
||||||
|
"app.services.geo_service.lookup_batch",
|
||||||
|
new=AsyncMock(return_value=mock_geo),
|
||||||
|
) as mock_batch,
|
||||||
|
):
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
result = await jail_service.get_active_bans(
|
||||||
|
_SOCKET, http_session=mock_session
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_batch.assert_awaited_once()
|
||||||
|
assert result.total == 1
|
||||||
|
assert result.bans[0].country == "DE"
|
||||||
|
|
||||||
|
async def test_http_session_batch_failure_graceful(self) -> None:
|
||||||
|
"""When lookup_batch raises, get_active_bans returns bans without geo."""
|
||||||
|
responses = {
|
||||||
|
"status": _make_global_status("sshd"),
|
||||||
|
"get|sshd|banip|--with-time": (
|
||||||
|
0,
|
||||||
|
["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
with (
|
||||||
|
_patch_client(responses),
|
||||||
|
patch(
|
||||||
|
"app.services.geo_service.lookup_batch",
|
||||||
|
new=AsyncMock(side_effect=RuntimeError("geo down")),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
result = await jail_service.get_active_bans(
|
||||||
|
_SOCKET, http_session=mock_session
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.total == 1
|
||||||
|
assert result.bans[0].country is None
|
||||||
|
|
||||||
|
async def test_geo_enricher_still_used_without_http_session(self) -> None:
|
||||||
|
"""Legacy geo_enricher is still called when http_session is not provided."""
|
||||||
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
|
responses = {
|
||||||
|
"status": _make_global_status("sshd"),
|
||||||
|
"get|sshd|banip|--with-time": (
|
||||||
|
0,
|
||||||
|
["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _enricher(ip: str) -> GeoInfo | None:
|
||||||
|
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
|
||||||
|
|
||||||
|
with _patch_client(responses):
|
||||||
|
result = await jail_service.get_active_bans(
|
||||||
|
_SOCKET, geo_enricher=_enricher
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.total == 1
|
||||||
|
assert result.bans[0].country == "JP"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Ignore list
|
# Ignore list
|
||||||
|
|||||||
Reference in New Issue
Block a user