refactoring-backend #3

Merged
lukas.pupkalipinski merged 403 commits from refactoring-backend into main 2026-05-20 20:23:46 +02:00
10 changed files with 1035 additions and 889 deletions
Showing only changes of commit 654dbdb000 - Show all commits

View File

@@ -196,7 +196,8 @@ The business logic layer. Services orchestrate operations, enforce rules, and co
| `log_service.py` | Log preview and regex test operations (extracted from config_service) |
| `history_service.py` | Queries the fail2ban database for historical ban records, builds per-IP timelines, computes ban counts and repeat-offender flags, and syncs new records into BanGUI's archive table |
| `blocklist_service.py` | Downloads blocklists via aiohttp, validates IPs/CIDRs, applies bans through fail2ban or iptables, logs import results |
| `geo_service.py` | Resolves IP addresses to country, ASN, and RIR using external APIs or a local database, caches results, and re-resolves unresolved geo cache entries |
| `geo_cache.py` | **GeoCache** class that encapsulates all IP geolocation caching: resolves IP addresses to country, ASN, and organization using external APIs or a local MaxMind database, maintains in-memory and persistent caches with negative cache support, and manages background re-resolution. Instantiated once at startup and stored on `app.state.geo_cache` |
| `geo_service.py` | (Deprecated) Backward-compatibility wrappers that delegate to the `GeoCache` instance. Kept for compatibility with existing code. New code should use `GeoCache` directly or via dependency injection |
| `server_service.py` | Reads and writes fail2ban server-level settings (log level, log target, syslog socket, DB location, purge age) |
| `health_service.py` | Probes fail2ban socket connectivity, retrieves server version and global stats, reports online/offline status |
@@ -667,6 +668,7 @@ BanGUI maintains its **own SQLite database** (separate from the fail2ban databas
- The frontend `AuthProvider` checks session validity on mount and redirects to `/login` if invalid.
- The backend `dependencies.py` provides an `authenticated` dependency that validates the session cookie on every protected endpoint.
- **Session validation cache** — validated session tokens are cached in memory for 10 seconds (`_session_cache` dict in `dependencies.py`) to avoid a SQLite round-trip on every request from the same browser. The cache is invalidated immediately on logout. This cache is process-local and not safe for multi-worker or distributed deployments. A clustered deployment should replace `_session_cache` with a shared cache or remove it entirely.
- **GeoCache** — A `GeoCache` instance is created at startup and stored on `app.state.geo_cache`. It encapsulates all IP geolocation caching: in-memory lookup cache, negative cache for unresolvable IPs (with TTL), dirty set for persistence, and an `asyncio.Lock` that serialises cache mutations between coroutines (it is not a thread lock, so the cache is only safe when used from a single event loop). The cache is loaded from the `geo_cache` SQLite table on startup. New resolutions are accumulated in memory and periodically flushed to the database by the `geo_cache_flush` background task. Stale entries are re-resolved by the `geo_re_resolve` task. The instance is injected into routes and tasks via FastAPI's dependency system.
- **Runtime state** — `RuntimeState` is process-local and only safe when BanGUI runs as a single asyncio worker. Mutating runtime state must not span `await` points because the current design relies on cooperative scheduling. Multi-worker or multi-process deployments must replace this runtime state with a shared coordination backend such as Redis, shared memory, or a database-backed store.
- **Setup-completion flag** — once `is_setup_complete()` returns `True`, the result is stored in `app.state._setup_complete_cached`. The `SetupRedirectMiddleware` skips the DB query on all subsequent requests, removing 1 SQL query per request for the common post-setup case. The completion flag is only written after the runtime database is successfully initialized and all initial setup settings are persisted, preventing a failed setup from permanently bypassing the setup wizard.

View File

@@ -1,23 +1,3 @@
### T-03 · Centralise `_DEFAULT_PAGE_SIZE` constant
**Where found:** `backend/app/routers/dashboard.py:45`, `routers/history.py:34`, `services/ban_service.py:70`, `services/history_service.py:49`
**Why this is needed:** Four independent definitions can drift. The router default and service default are currently coincidentally aligned at 100, but nothing enforces this.
**Goal:** Single definition in `app/utils/constants.py`, imported everywhere.
**What to do:**
1. Add `DEFAULT_PAGE_SIZE: Final[int] = 100` and `MAX_PAGE_SIZE: Final[int] = 500` to `app/utils/constants.py`.
2. Replace all four local `_DEFAULT_PAGE_SIZE` and `_MAX_PAGE_SIZE` declarations with imports.
**Possible traps and issues:** None significant. Pure search-and-replace.
**Docs changes needed:** None.
**Doc references:** `app/utils/constants.py`
---
### T-04 · Encapsulate `geo_service` module-level mutable state in a class
**Where found:** `backend/app/services/geo_service.py` — module globals `_cache`, `_neg_cache`, `_dirty`, `_geoip_reader`, `_geoip_initialized`, `_cache_lock`

View File

@@ -31,6 +31,7 @@ from app.repositories.protocols import (
SettingsRepository,
SessionRepository,
)
from app.services.geo_cache import GeoCache
from app.utils.constants import SESSION_COOKIE_NAME
from app.utils.runtime_state import RuntimeState
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache, SessionCache
@@ -50,6 +51,7 @@ class AppState(Protocol):
runtime_settings: Settings | None
runtime_state: RuntimeState
session_cache: SessionCache
geo_cache: GeoCache # noqa: F821
@dataclass
@@ -214,11 +216,15 @@ async def get_fail2ban_start_command(settings: Settings = Depends(get_settings))
return settings.fail2ban_start_command
async def get_geo_batch_lookup() -> GeoBatchLookup:
"""Provide the concrete geo batch lookup callable used by routers."""
from app.services import geo_service # noqa: PLC0415
async def get_geo_batch_lookup(request: Request) -> GeoBatchLookup:
    """Return the batch-lookup callable bound to the application's GeoCache.

    Args:
        request: Incoming request; ``request.app.state.geo_cache`` is read.

    Returns:
        The ``lookup_batch`` bound method of the stored GeoCache instance.
    """
    app_state = request.app.state
    cache: GeoCache = app_state.geo_cache
    return cache.lookup_batch  # type: ignore[return-value]
return geo_service.lookup_batch
async def get_geo_cache(request: Request) -> GeoCache:
    """Return the GeoCache instance stored on the application state.

    Args:
        request: Incoming request; ``request.app.state.geo_cache`` is read.

    Returns:
        Whatever object is stored under ``app.state.geo_cache``.
    """
    cache: GeoCache = request.app.state.geo_cache
    return cache
async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> SessionCache:

View File

@@ -0,0 +1,737 @@
"""GeoCache service — encapsulates geo service mutable state.
This module defines the :class:`GeoCache` class which encapsulates all
module-level mutable state from the original geo_service into a single
injectable instance. The cache manages:
- In-memory positive results cache (``ip → GeoInfo``)
- Negative cache (failed lookups with TTL)
- Dirty set (entries pending persistence)
- Lock protecting cache mutations
- MaxMind GeoLite2 reader initialization
An instance should be created once at startup and stored on ``app.state.geo_cache``.
"""
from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING
import aiohttp
import structlog
from app.models.geo import GeoInfo
from app.repositories import geo_cache_repo
if TYPE_CHECKING:
import aiosqlite
import geoip2.database
import geoip2.errors
log: structlog.stdlib.BoundLogger = structlog.get_logger()
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier).
#: The ``{ip}`` placeholder is filled with ``str.format`` by
#: :meth:`GeoCache.lookup`.
_API_URL: str = (
"http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
)
#: ip-api.com batch endpoint — accepts up to 100 IPs per POST.
#: Requests the extra ``query`` field so each result can be matched back to
#: the IP it belongs to.
_BATCH_API_URL: str = (
"http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
)
#: Maximum IPs per batch request (ip-api.com hard limit is 100).
#: :meth:`GeoCache.lookup_batch` slices its work list into chunks of this size.
_BATCH_SIZE: int = 100
#: Maximum number of entries kept in the in-process cache before it is
#: flushed completely. A simple eviction strategy — the cache is cheap to
#: rebuild from the persistent store.
#: The check in :meth:`GeoCache._store` is ``len(cache) >= _MAX_CACHE_SIZE``
#: and the dirty set is cleared together with the cache.
_MAX_CACHE_SIZE: int = 50_000
#: Timeout for outgoing geo API requests in seconds.
#: Single lookups use this value as the total timeout; batch requests use
#: twice this value.
_REQUEST_TIMEOUT: float = 5.0
#: How many seconds a failed lookup result is suppressed before the IP is
#: eligible for a new API attempt. Default: 5 minutes.
#: Compared against differences of ``time.monotonic()`` readings.
_NEG_CACHE_TTL: float = 300.0
#: Minimum delay in seconds between consecutive batch HTTP requests to
#: ip-api.com. The free tier allows 45 requests/min; 1.5 s ≈ 40 req/min.
#: Also the base of the retry back-off: ``_BATCH_DELAY * 2 ** (attempt + 1)``.
_BATCH_DELAY: float = 1.5
#: Maximum number of retries for a batch chunk that fails with a
#: transient error (e.g. connection reset due to rate limiting).
#: A chunk is retried only when every IP in it came back unresolved.
_BATCH_MAX_RETRIES: int = 2
class GeoCache:
    """Injectable container for all mutable geo-IP resolution state.

    One instance owns the positive cache, the negative cache, the set of
    entries awaiting persistence, the optional MaxMind reader and the lock
    that guards cache mutations. Single lookups, batch lookups, persistence
    and cache maintenance are all exposed as methods.

    State:
        _cache: In-memory positive results (``ip → GeoInfo``).
        _neg_cache: Time of the last failed lookup per IP, taken from
            ``time.monotonic()`` (``ip → timestamp``).
        _dirty: IPs stored in ``_cache`` that have not been persisted yet.
        _geoip_reader: Optional MaxMind GeoLite2 reader; ``None`` until one
            is loaded.
        _geoip_initialized: Flag consulted by ``init_geoip()`` to reject a
            repeated initialisation.
        _cache_lock: ``asyncio.Lock`` held while the caches are mutated.
    """

    def __init__(self) -> None:
        """Create a cache with no entries and no GeoIP reader attached."""
        # Synchronisation primitive shared by every mutating coroutine.
        self._cache_lock: asyncio.Lock = asyncio.Lock()
        # Local MaxMind fallback starts out disabled.
        self._geoip_reader: geoip2.database.Reader | None = None
        self._geoip_initialized: bool = False
        # Lookup state: resolved entries, failed-lookup timestamps, and the
        # keys waiting to be written to the database.
        self._cache: dict[str, GeoInfo] = {}
        self._neg_cache: dict[str, float] = {}
        self._dirty: set[str] = set()
async def clear(self) -> None:
    """Empty the positive cache, the negative cache and the dirty set.

    Entries that were waiting to be persisted are discarded along with
    everything else. All three containers are emptied while the cache lock
    is held.
    """
    async with self._cache_lock:
        for container in (self._cache, self._neg_cache, self._dirty):
            container.clear()
async def clear_neg_cache(self) -> None:
    """Empty the negative (failed-lookup) cache only.

    Afterwards no IP is inside its cool-down window, so previously failed
    addresses are eligible for a new lookup attempt straight away. The
    positive cache and the dirty set are left untouched.
    """
    lock = self._cache_lock
    async with lock:
        self._neg_cache.clear()
def is_cached(self, ip: str) -> bool:
    """Tell whether *ip* has a resolved entry in the in-memory cache.

    An entry counts as resolved only when its ``country_code`` is not
    ``None``.

    Args:
        ip: IPv4 or IPv6 address string.

    Returns:
        ``True`` if *ip* is cached with a known country code, else ``False``.
    """
    if ip not in self._cache:
        return False
    return self._cache[ip].country_code is not None
async def cache_stats(self, db: aiosqlite.Connection) -> dict[str, int]:
    """Collect diagnostic counters for the geo cache.

    The number of unresolved rows comes from the persistent store via
    ``geo_cache_repo.count_unresolved``; the remaining counters are the
    current sizes of the in-memory containers.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        Dict with the keys ``cache_size``, ``unresolved``,
        ``neg_cache_size`` and ``dirty_size``.
    """
    pending_in_db = await geo_cache_repo.count_unresolved(db)
    stats: dict[str, int] = {}
    stats["cache_size"] = len(self._cache)
    stats["unresolved"] = pending_in_db
    stats["neg_cache_size"] = len(self._neg_cache)
    stats["dirty_size"] = len(self._dirty)
    return stats
async def count_unresolved(self, db: aiosqlite.Connection) -> int:
    """Return how many persistent geo cache entries are still unresolved.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        The value reported by ``geo_cache_repo.count_unresolved``.
    """
    unresolved_total = await geo_cache_repo.count_unresolved(db)
    return unresolved_total
async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]:
    """List the persisted IPs whose country code has not been resolved yet.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        The IP addresses reported by ``geo_cache_repo.get_unresolved_ips``;
        these are the candidates for re-resolution.
    """
    candidates = await geo_cache_repo.get_unresolved_ips(db)
    return candidates
async def re_resolve_all(
    self,
    db: aiosqlite.Connection,
    http_session: aiohttp.ClientSession,
) -> dict[str, int]:
    """Retry geo resolution for every unresolved cache entry.

    The unresolved IPs are read from the persistent store. When there are
    any, the in-memory negative cache is cleared first so that none of them
    is skipped by the cool-down window, and then they are resolved with one
    :meth:`lookup_batch` call.

    Args:
        db: BanGUI application database connection.
        http_session: Shared aiohttp client session.

    Returns:
        A dict with ``resolved`` (IPs that now have a country code) and
        ``total`` (IPs that were retried). Both are ``0`` when nothing was
        unresolved.
    """
    unresolved = await self.get_unresolved_ips(db)
    if not unresolved:
        return {"resolved": 0, "total": 0}

    # Without this, recently failed IPs would be answered from the negative
    # cache and never reach the resolvers again.
    await self.clear_neg_cache()
    geo_map = await self.lookup_batch(unresolved, http_session, db=db)

    total = len(unresolved)
    resolved_count = sum(
        1 for info in geo_map.values() if info.country_code is not None
    )
    # Uses the module-level logger; no function-local logger is needed.
    log.info(
        "geo_re_resolve_complete",
        total=total,
        resolved=resolved_count,
    )
    return {"resolved": resolved_count, "total": total}
def init_geoip(self, mmdb_path: str | None) -> None:
"""Initialise the MaxMind GeoLite2-Country database reader.
This function is startup-only and must be called before request handling
begins. A second initialization attempt is considered a programming error
and raises ``RuntimeError``.
If *mmdb_path* is ``None``, empty, or the file does not exist the
fallback is silently disabled — ip-api.com remains the sole resolver.
Args:
mmdb_path: Absolute path to a ``GeoLite2-Country.mmdb`` file.
"""
if self._geoip_initialized:
raise RuntimeError("GeoIP reader already initialised")
if not mmdb_path:
return
from pathlib import Path # noqa: PLC0415
import geoip2.database # noqa: PLC0415
if not Path(mmdb_path).is_file():
log.warning("geoip_mmdb_not_found", path=mmdb_path)
return
self._geoip_reader = geoip2.database.Reader(mmdb_path)
self._geoip_initialized = True
log.info("geoip_mmdb_loaded", path=mmdb_path)
def _geoip_lookup(self, ip: str) -> GeoInfo | None:
"""Attempt a local MaxMind GeoLite2 lookup for *ip*.
Returns ``None`` when the reader is not initialised, the IP is not in
the database, or any other error occurs.
Args:
ip: IPv4 or IPv6 address string.
Returns:
A :class:`GeoInfo` with at least ``country_code`` populated, or
``None`` when resolution is impossible.
"""
if self._geoip_reader is None:
return None
import geoip2.errors # noqa: PLC0415
try:
response = self._geoip_reader.country(ip)
code: str | None = response.country.iso_code or None
name: str | None = response.country.name or None
if code is None:
return None
return GeoInfo(country_code=code, country_name=name, asn=None, org=None)
except geoip2.errors.AddressNotFoundError:
return None
except Exception as exc: # noqa: BLE001
log.warning("geoip_lookup_failed", ip=ip, error=str(exc))
return None
async def load_cache_from_db(self, db: aiosqlite.Connection) -> None:
    """Warm the in-memory cache from the persistent ``geo_cache`` table.

    Rows whose ``country_code`` is ``None`` are skipped. The remaining rows
    are converted to :class:`GeoInfo` objects first and copied into the
    cache afterwards while the cache lock is held.

    Args:
        db: Open :class:`aiosqlite.Connection` to the BanGUI application
            database (not the fail2ban database).
    """
    stored_rows = await geo_cache_repo.load_all(db)
    resolved_entries: list[tuple[str, GeoInfo]] = [
        (
            row["ip"],
            GeoInfo(
                country_code=row["country_code"],
                country_name=row["country_name"],
                asn=row["asn"],
                org=row["org"],
            ),
        )
        for row in stored_rows
        if row["country_code"] is not None
    ]

    async with self._cache_lock:
        self._cache.update(resolved_entries)

    log.info("geo_cache_loaded_from_db", entries=len(resolved_entries))
async def _store(self, ip: str, info: GeoInfo) -> None:
    """Put *info* into the in-memory cache under *ip*.

    If the cache already holds ``_MAX_CACHE_SIZE`` entries or more, the
    cache and the dirty set are both emptied before the new entry is
    inserted. Entries with a country code are additionally marked dirty so
    that :meth:`flush_dirty` writes them to the database later.

    Args:
        ip: The IP address key.
        info: The :class:`GeoInfo` to store.
    """
    async with self._cache_lock:
        at_capacity = len(self._cache) >= _MAX_CACHE_SIZE
        if at_capacity:
            for container in (self._cache, self._dirty):
                container.clear()
            log.info("geo_cache_flushed", reason="capacity")

        self._cache[ip] = info
        if info.country_code is None:
            # Unresolved entries are cached but never persisted.
            return
        self._dirty.add(ip)
async def lookup(
self,
ip: str,
http_session: aiohttp.ClientSession,
db: aiosqlite.Connection | None = None,
) -> GeoInfo | None:
"""Resolve an IP address to country, ASN, and organisation metadata.
Resolution order: the in-memory cache, the negative cache, the ip-api.com
single-IP endpoint, and finally the local MaxMind reader.
Results are cached in-process through :meth:`_store`, which empties the
cache first when it has reached ``_MAX_CACHE_SIZE`` entries.
Only resolutions with a non-``None`` ``country_code`` are written to the
persistent cache, and only when *db* is provided. When both resolvers fail
the IP is recorded in the in-memory negative cache, so further calls within
``_NEG_CACHE_TTL`` seconds return an empty result without contacting the
API; with *db* provided a negative entry is persisted as well.
Errors from the HTTP request and from the database writes are logged and
swallowed.
Args:
ip: IPv4 or IPv6 address string.
http_session: Shared :class:`aiohttp.ClientSession` (from
``app.state.http_session``).
db: Optional BanGUI application database. When provided, successful
lookups are persisted for cross-restart cache warming.
Returns:
A :class:`GeoInfo` instance. Every ``return`` statement in this method
returns a :class:`GeoInfo`; a failed lookup yields one whose fields are
all ``None``. The ``| None`` in the annotation is kept for interface
compatibility.
"""
# Fast path: any cached entry is returned as-is.
if ip in self._cache:
return self._cache[ip]
# Negative cache: skip IPs that recently failed to avoid hammering the API.
neg_ts = self._neg_cache.get(ip)
if neg_ts is not None and (time.monotonic() - neg_ts) < _NEG_CACHE_TTL:
return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
url: str = _API_URL.format(ip=ip)
# Set once the API has reported status "success"; consulted after the
# request block to decide whether the local fallback should run.
api_ok = False
try:
async with http_session.get(url, timeout=aiohttp.ClientTimeout(total=_REQUEST_TIMEOUT)) as resp:
if resp.status != 200:
log.warning("geo_lookup_non_200", ip=ip, status=resp.status)
else:
# content_type=None disables aiohttp's content-type check.
data: dict[str, object] = await resp.json(content_type=None)
if data.get("status") == "success":
api_ok = True
result = self._parse_single_response(data)
await self._store(ip, result)
if result.country_code is not None and db is not None:
# Persistence is best-effort: a failed write is only logged.
try:
await geo_cache_repo.upsert_entry_and_commit(
db=db,
ip=ip,
country_code=result.country_code,
country_name=result.country_name,
asn=result.asn,
org=result.org,
)
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_failed", ip=ip, error=str(exc))
log.debug("geo_lookup_success", ip=ip, country=result.country_code, asn=result.asn)
return result
log.debug(
"geo_lookup_failed",
ip=ip,
message=data.get("message", "unknown"),
)
except Exception as exc: # noqa: BLE001
# Any error raised inside the request block ends up here and is logged
# instead of being propagated to the caller.
log.warning(
"geo_lookup_request_failed",
ip=ip,
exc_type=type(exc).__name__,
error=repr(exc),
)
if not api_ok:
# Try local MaxMind database as fallback.
fallback = self._geoip_lookup(ip)
if fallback is not None:
await self._store(ip, fallback)
if fallback.country_code is not None and db is not None:
# Same best-effort persistence as for an API result.
try:
await geo_cache_repo.upsert_entry_and_commit(
db=db,
ip=ip,
country_code=fallback.country_code,
country_name=fallback.country_name,
asn=fallback.asn,
org=fallback.org,
)
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_failed", ip=ip, error=str(exc))
log.debug("geo_geoip_fallback_success", ip=ip, country=fallback.country_code)
return fallback
# Both resolvers failed — record in negative cache to avoid hammering.
# The timestamp comes from time.monotonic(), matching the TTL check above.
async with self._cache_lock:
self._neg_cache[ip] = time.monotonic()
if db is not None:
try:
await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip)
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_neg_failed", ip=ip, error=str(exc))
# Unresolved: hand back an all-None GeoInfo rather than None.
return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
def lookup_cached_only(
    self,
    ips: list[str],
) -> tuple[dict[str, GeoInfo], list[str]]:
    """Answer *ips* from memory only, without any external API call.

    Intended for callers that want a fast response from what is already
    cached and resolve the remaining IPs later in the background.

    Args:
        ips: IP address strings to look up. Duplicates are collapsed and
            first-seen order is kept.

    Returns:
        A ``(geo_map, uncached)`` tuple. *geo_map* maps every IP found in
        the in-memory cache to its :class:`GeoInfo`. *uncached* lists the
        IPs that were not found. IPs whose negative-cache entry is younger
        than ``_NEG_CACHE_TTL`` seconds appear in neither, so they are not
        re-queued immediately.
    """
    hits: dict[str, GeoInfo] = {}
    misses: list[str] = []
    checked_at = time.monotonic()

    for address in dict.fromkeys(ips):  # deduplicate, preserve order
        if address in self._cache:
            hits[address] = self._cache[address]
            continue
        cooling_down = (
            address in self._neg_cache
            and (checked_at - self._neg_cache[address]) < _NEG_CACHE_TTL
        )
        if not cooling_down:
            misses.append(address)

    return hits, misses
async def lookup_batch(
    self,
    ips: list[str],
    http_session: aiohttp.ClientSession,
    db: aiosqlite.Connection | None = None,
) -> dict[str, GeoInfo]:
    """Resolve multiple IP addresses in bulk using the ip-api.com batch endpoint.

    IPs already present in the in-memory cache are returned immediately
    without an HTTP request, and IPs inside their negative-cache cool-down
    window get an empty result. All other IPs are sent to
    ``http://ip-api.com/batch`` in chunks of up to :data:`_BATCH_SIZE`,
    with a pause of :data:`_BATCH_DELAY` seconds between chunks.

    For every IP the API could not resolve, the local MaxMind reader is
    tried. IPs that neither resolver could resolve are recorded in the
    negative cache.

    When *db* is provided, resolved entries and negative entries of a chunk
    are handed to the repository together in a single bulk call per chunk.
    A failed write is logged and does not affect the returned mapping.

    Args:
        ips: List of IP address strings to resolve. Duplicates are ignored.
        http_session: Shared :class:`aiohttp.ClientSession`.
        db: Optional BanGUI application database for persistent cache writes.

    Returns:
        Dict mapping ``ip → GeoInfo`` for every input IP. IPs whose
        resolution failed have a ``GeoInfo`` with all-``None`` fields.
    """
    geo_result: dict[str, GeoInfo] = {}
    uncached: list[str] = []
    _empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)

    unique_ips = list(dict.fromkeys(ips))  # deduplicate, preserve order
    now = time.monotonic()
    for ip in unique_ips:
        if ip in self._cache:
            geo_result[ip] = self._cache[ip]
        elif ip in self._neg_cache and (now - self._neg_cache[ip]) < _NEG_CACHE_TTL:
            # Recently failed — skip API call, return empty result.
            geo_result[ip] = _empty
        else:
            uncached.append(ip)

    if not uncached:
        return geo_result

    log.info("geo_batch_lookup_start", total=len(uncached))

    for batch_idx, chunk_start in enumerate(range(0, len(uncached), _BATCH_SIZE)):
        chunk = uncached[chunk_start : chunk_start + _BATCH_SIZE]

        # Throttle: pause between consecutive HTTP calls to stay within the
        # ip-api.com free-tier rate limit (45 req/min).
        if batch_idx > 0:
            await asyncio.sleep(_BATCH_DELAY)

        # Retry transient failures (e.g. connection-reset from rate limit).
        chunk_result: dict[str, GeoInfo] | None = None
        for attempt in range(_BATCH_MAX_RETRIES + 1):
            chunk_result = await self._batch_api_call(chunk, http_session)
            # A chunk in which every IP came back unresolved is treated as
            # a rejected request (connection reset / 429) and retried after
            # a back-off. There is no size threshold, so a chunk made up
            # solely of unresolvable IPs is retried as well.
            all_failed = all(
                info.country_code is None for info in chunk_result.values()
            )
            if not all_failed or attempt >= _BATCH_MAX_RETRIES:
                break
            backoff = _BATCH_DELAY * (2 ** (attempt + 1))
            log.warning(
                "geo_batch_retry",
                attempt=attempt + 1,
                chunk_size=len(chunk),
                backoff=backoff,
            )
            await asyncio.sleep(backoff)

        # Only narrows the Optional for type checkers; the loop above runs
        # at least once for any non-negative retry count.
        assert chunk_result is not None  # noqa: S101

        # Collect bulk-write rows instead of one execute per IP.
        pos_rows: list[tuple[str, str | None, str | None, str | None, str | None]] = []
        neg_ips: list[str] = []

        for ip, info in chunk_result.items():
            if info.country_code is not None:
                # Successful API resolution.
                await self._store(ip, info)
                geo_result[ip] = info
                if db is not None:
                    pos_rows.append(
                        (ip, info.country_code, info.country_name, info.asn, info.org)
                    )
                continue

            # API failed — try local GeoIP fallback.
            fallback = self._geoip_lookup(ip)
            if fallback is not None:
                await self._store(ip, fallback)
                geo_result[ip] = fallback
                if db is not None:
                    pos_rows.append(
                        (
                            ip,
                            fallback.country_code,
                            fallback.country_name,
                            fallback.asn,
                            fallback.org,
                        )
                    )
                continue

            # Both resolvers failed — record in negative cache.
            async with self._cache_lock:
                self._neg_cache[ip] = time.monotonic()
            geo_result[ip] = _empty
            if db is not None:
                neg_ips.append(ip)

        if db is not None and (pos_rows or neg_ips):
            try:
                await geo_cache_repo.bulk_upsert_entries_and_neg_entries_and_commit(
                    db,
                    pos_rows,
                    neg_ips,
                )
            except Exception as exc:  # noqa: BLE001
                log.warning(
                    "geo_batch_persist_failed",
                    positive_count=len(pos_rows),
                    negative_count=len(neg_ips),
                    error=str(exc),
                )

    # Count only the IPs that were actually requested from the resolvers.
    # Counting every entry of geo_result would include cache hits and could
    # report more resolved IPs than were requested.
    resolved_now = sum(
        1
        for ip in uncached
        if ip in geo_result and geo_result[ip].country_code is not None
    )
    log.info(
        "geo_batch_lookup_complete",
        requested=len(uncached),
        resolved=resolved_now,
    )
    return geo_result
async def _batch_api_call(
    self,
    ips: list[str],
    http_session: aiohttp.ClientSession,
) -> dict[str, GeoInfo]:
    """Send one batch request to the ip-api.com batch endpoint.

    This method does not raise for transport or payload-shape problems: a
    non-200 status, an exception while sending or decoding, and a JSON body
    that is not a list all produce an all-empty result for the chunk.

    Args:
        ips: Up to :data:`_BATCH_SIZE` IP address strings.
        http_session: Shared HTTP session.

    Returns:
        Dict mapping ``ip → GeoInfo`` for every IP in *ips*. IPs where the
        API returned a failure record, that are missing from the response,
        or whose request failed get an all-``None`` :class:`GeoInfo`.
    """
    empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
    fallback: dict[str, GeoInfo] = dict.fromkeys(ips, empty)

    payload = [{"query": ip} for ip in ips]
    try:
        async with http_session.post(
            _BATCH_API_URL,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=_REQUEST_TIMEOUT * 2),
        ) as resp:
            if resp.status != 200:
                log.warning("geo_batch_non_200", status=resp.status, count=len(ips))
                return fallback
            # content_type=None disables aiohttp's content-type check.
            data: object = await resp.json(content_type=None)
    except Exception as exc:  # noqa: BLE001
        log.warning(
            "geo_batch_request_failed",
            count=len(ips),
            exc_type=type(exc).__name__,
            error=repr(exc),
        )
        return fallback

    # The loop below runs outside the try block, so the payload shape has to
    # be checked here: a JSON object or scalar would otherwise raise while
    # being iterated and abort the caller's whole batch.
    if not isinstance(data, list):
        log.warning(
            "geo_batch_unexpected_payload",
            count=len(ips),
            payload_type=type(data).__name__,
        )
        return fallback

    out: dict[str, GeoInfo] = {}
    for entry in data:
        if not isinstance(entry, dict):
            # Skip malformed elements; the affected IPs are filled in below.
            continue
        ip_str: str = str(entry.get("query", ""))
        if not ip_str:
            continue
        if entry.get("status") != "success":
            out[ip_str] = empty
            log.debug(
                "geo_batch_entry_failed",
                ip=ip_str,
                message=entry.get("message", "unknown"),
            )
            continue
        out[ip_str] = self._parse_single_response(entry)

    # Fill any IPs missing from the response.
    for ip in ips:
        out.setdefault(ip, empty)
    return out
def _parse_single_response(self, data: dict[str, object]) -> GeoInfo:
    """Convert one ip-api.com result record into a :class:`GeoInfo`.

    Every field is normalised with :meth:`_str_or_none`, so missing and
    blank values become ``None``.

    Args:
        data: A ``status == "success"`` JSON response from ip-api.com.

    Returns:
        Populated :class:`GeoInfo`.
    """
    # Per the file's own note the "as" field looks like "AS12345 Some Org";
    # only its first whitespace-separated token is kept as the ASN.
    as_field = self._str_or_none(data.get("as"))
    return GeoInfo(
        country_code=self._str_or_none(data.get("countryCode")),
        country_name=self._str_or_none(data.get("country")),
        asn=as_field.split()[0] if as_field else None,
        org=self._str_or_none(data.get("org")),
    )
def _str_or_none(self, value: object) -> str | None:
"""Return *value* as a non-empty string, or ``None``.
Args:
value: Raw JSON value which may be ``None``, empty, or a string.
Returns:
Stripped string if non-empty, else ``None``.
"""
if value is None:
return None
s = str(value).strip()
return s if s else None
async def flush_dirty(self, db: aiosqlite.Connection) -> int:
"""Persist all new in-memory geo entries to the ``geo_cache`` table.
Takes an atomic snapshot of the dirty set, clears it, then batch-inserts
all entries that are still present in the cache using a single
``executemany`` call and one ``COMMIT``. This is the only place that
writes to the persistent cache during normal operation after startup.
If the database write fails the entries are re-added to the dirty set
so they will be retried on the next flush cycle.
Args:
db: Open :class:`aiosqlite.Connection` to the BanGUI application
database.
Returns:
The number of rows successfully upserted.
"""
async with self._cache_lock:
if not self._dirty:
return 0
# Atomically snapshot and clear while holding the cache lock.
to_flush = self._dirty.copy()
self._dirty.clear()
rows = [
(ip, self._cache[ip].country_code, self._cache[ip].country_name, self._cache[ip].asn, self._cache[ip].org)
for ip in to_flush
if ip in self._cache
]
if not rows:
return 0
try:
await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows)
except Exception as exc: # noqa: BLE001
log.warning("geo_flush_dirty_failed", error=str(exc))
# Re-add to dirty so they are retried on the next flush cycle.
self._dirty.update(to_flush)
return 0
log.info("geo_flush_dirty_complete", count=len(rows))
return len(rows)

View File

@@ -1,151 +1,72 @@
"""Geo service.
"""Geo service — backward compatibility wrappers.
Resolves IP addresses to their country, ASN, and organisation using the
`ip-api.com <http://ip-api.com>`_ JSON API. Results are cached in two tiers:
DEPRECATED: This module is kept for backward compatibility only. New code should
use :class:`GeoCache` directly from ``app.services.geo_cache``. The underlying
implementation has been refactored to eliminate module-level mutable state.
1. **In-memory dict** — fastest; survives for the life of the process.
2. **Persistent SQLite table** (``geo_cache``) — survives restarts; loaded
into the in-memory dict during application startup via
:func:`load_cache_from_db`.
The :class:`GeoCache` instance should be injected into services and routers
via dependency injection, and stored on ``app.state.geo_cache``.
Only *successful* lookups (those returning a non-``None`` ``country_code``)
are written to the persistent cache. Failed lookups are **not** cached so
they will be retried on the next request.
For bulk operations the batch endpoint ``http://ip-api.com/batch`` is used
(up to 100 IPs per HTTP call) which is far more efficient than one-at-a-time
requests. Use :func:`lookup_batch` from the ban or blocklist services.
Usage::
import aiohttp
import aiosqlite
# Use the geo_service directly in application startup
async with aiosqlite.connect("bangui.db") as db:
await geo_service.load_cache_from_db(db)
async with aiohttp.ClientSession() as session:
# single lookup
info = await geo_service.lookup("1.2.3.4", session)
if info:
# info.country_code == "DE"
... # use the GeoInfo object in your application
# bulk lookup (more efficient for large sets)
geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session)
See :class:`app.services.geo_cache.GeoCache` for the implementation.
"""
from __future__ import annotations
import asyncio
import time
from typing import TYPE_CHECKING
import aiohttp
import structlog
from app.services.geo_cache import GeoCache
from app.models.geo import GeoInfo
from app.repositories import geo_cache_repo
if TYPE_CHECKING:
import aiohttp
import aiosqlite
import geoip2.database
import geoip2.errors
log: structlog.stdlib.BoundLogger = structlog.get_logger()
__all__ = [
"GeoCache",
"clear_cache",
"clear_neg_cache",
"is_cached",
"lookup",
"lookup_batch",
"lookup_cached_only",
"cache_stats",
"count_unresolved",
"get_unresolved_ips",
"load_cache_from_db",
"flush_dirty",
"re_resolve_all",
"init_geoip",
]
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier).
_API_URL: str = (
"http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
)
#: ip-api.com batch endpoint — accepts up to 100 IPs per POST.
_BATCH_API_URL: str = (
"http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
)
#: Maximum IPs per batch request (ip-api.com hard limit is 100).
_BATCH_SIZE: int = 100
#: Maximum number of entries kept in the in-process cache before it is
#: flushed completely. A simple eviction strategy — the cache is cheap to
#: rebuild from the persistent store.
_MAX_CACHE_SIZE: int = 50_000
#: Timeout for outgoing geo API requests in seconds.
_REQUEST_TIMEOUT: float = 5.0
#: How many seconds a failed lookup result is suppressed before the IP is
#: eligible for a new API attempt. Default: 5 minutes.
_NEG_CACHE_TTL: float = 300.0
#: Minimum delay in seconds between consecutive batch HTTP requests to
#: ip-api.com. The free tier allows 45 requests/min; 1.5 s ≈ 40 req/min.
_BATCH_DELAY: float = 1.5
#: Maximum number of retries for a batch chunk that fails with a
#: transient error (e.g. connection reset due to rate limiting).
_BATCH_MAX_RETRIES: int = 2
# ---------------------------------------------------------------------------
# Internal cache
# ---------------------------------------------------------------------------
#: Module-level in-memory cache: ``ip → GeoInfo`` (positive results only).
_cache: dict[str, GeoInfo] = {}
#: Negative cache: ``ip → epoch timestamp`` of last failed lookup attempt.
#: Entries within :data:`_NEG_CACHE_TTL` seconds are not re-queried.
_neg_cache: dict[str, float] = {}
#: IPs added to :data:`_cache` but not yet persisted to the database.
#: Consumed and cleared atomically by :func:`flush_dirty`.
_dirty: set[str] = set()
#: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`.
_geoip_reader: geoip2.database.Reader | None = None
#: Indicates whether :func:`init_geoip` has already been called.
#: This function is startup-only and must not be invoked again while the
#: process is handling requests.
_geoip_initialized: bool = False
#: Lock protecting mutations to the in-memory geo caches.
_cache_lock: asyncio.Lock = asyncio.Lock()
#: Deprecated: Module-level cache instance for backward compatibility.
#: This exists only to provide backward-compatible module-level functions.
#: New code should inject GeoCache instances directly.
_default_geo_cache: GeoCache = GeoCache()
async def clear_cache() -> None:
"""Flush both the positive and negative lookup caches.
"""(DEPRECATED) Flush both the positive and negative lookup caches.
Also clears the dirty set so any pending-but-unpersisted entries are
discarded. Useful in tests and when the operator suspects stale data.
Use :meth:`GeoCache.clear` instead. This function delegates to the
default module-level instance for backward compatibility only.
"""
async with _cache_lock:
_cache.clear()
_neg_cache.clear()
_dirty.clear()
await _default_geo_cache.clear()
async def clear_neg_cache() -> None:
"""Flush only the negative (failed-lookups) cache.
"""(DEPRECATED) Flush only the negative (failed-lookups) cache.
Useful when triggering a manual re-resolve so that previously failed
IPs are immediately eligible for a new API attempt.
Use :meth:`GeoCache.clear_neg_cache` instead. This function delegates to
the default module-level instance for backward compatibility only.
"""
async with _cache_lock:
_neg_cache.clear()
await _default_geo_cache.clear_neg_cache()
def is_cached(ip: str) -> bool:
"""Return ``True`` if *ip* has a positive entry in the in-memory cache.
"""(DEPRECATED) Return ``True`` if *ip* has a positive cache entry.
A positive entry is one with a non-``None`` ``country_code``. This is
useful for skipping IPs that have already been resolved when building
a list for :func:`lookup_batch`.
Use :meth:`GeoCache.is_cached` instead. This function delegates to the
default module-level instance for backward compatibility only.
Args:
ip: IPv4 or IPv6 address string.
@@ -153,14 +74,14 @@ def is_cached(ip: str) -> bool:
Returns:
``True`` when *ip* is in the cache with a known country code.
"""
return ip in _cache and _cache[ip].country_code is not None
return _default_geo_cache.is_cached(ip)
async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
"""Return diagnostic counters for the geo cache subsystem.
"""(DEPRECATED) Return diagnostic counters for the geo cache subsystem.
Queries the persistent store for the number of unresolved entries and
combines it with in-memory counters.
Use :meth:`GeoCache.cache_stats` instead. This function delegates to the
default module-level instance for backward compatibility only.
Args:
db: Open BanGUI application database connection.
@@ -169,24 +90,21 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``,
and ``dirty_size``.
"""
unresolved = await geo_cache_repo.count_unresolved(db)
return {
"cache_size": len(_cache),
"unresolved": unresolved,
"neg_cache_size": len(_neg_cache),
"dirty_size": len(_dirty),
}
return await _default_geo_cache.cache_stats(db)
async def count_unresolved(db: aiosqlite.Connection) -> int:
"""Return the number of unresolved entries in the persistent geo cache."""
"""(DEPRECATED) Return the number of unresolved entries in the geo cache.
return await geo_cache_repo.count_unresolved(db)
Use :meth:`GeoCache.count_unresolved` instead.
"""
return await _default_geo_cache.count_unresolved(db)
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
"""Return geo cache IPs where the country code has not yet been resolved.
"""(DEPRECATED) Return IPs with NULL country_code in the persistent cache.
Use :meth:`GeoCache.get_unresolved_ips` instead.
Args:
db: Open BanGUI application database connection.
@@ -194,187 +112,32 @@ async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
Returns:
List of IP addresses that are candidates for re-resolution.
"""
return await geo_cache_repo.get_unresolved_ips(db)
async def re_resolve_all(
    db: aiosqlite.Connection,
    http_session: aiohttp.ClientSession,
) -> dict[str, int]:
    """Run a fresh resolution attempt for every unresolved geo cache entry.

    The unresolved IPs are read from the persistent store. When there are
    any, the in-memory negative cache is cleared first so that none of them
    is skipped by the cool-down window, and the whole list is handed to
    :func:`lookup_batch` with *db* so results are persisted.

    Args:
        db: BanGUI application database connection.
        http_session: Shared aiohttp client session.

    Returns:
        A dict with ``resolved`` (IPs that now have a country code) and
        ``total`` (IPs that were retried) counts.
    """
    pending = await get_unresolved_ips(db)
    if not pending:
        # Nothing to retry — leave the negative cache untouched.
        return {"resolved": 0, "total": 0}

    await clear_neg_cache()
    lookups = await lookup_batch(pending, http_session, db=db)

    total = len(pending)
    resolved = len(
        [entry for entry in lookups.values() if entry.country_code is not None]
    )

    log.info(
        "geo_re_resolve_complete",
        total=total,
        resolved=resolved,
    )
    return {"resolved": resolved, "total": total}
return await _default_geo_cache.get_unresolved_ips(db)
def init_geoip(mmdb_path: str | None) -> None:
"""Initialise the MaxMind GeoLite2-Country database reader.
"""(DEPRECATED) Initialise the MaxMind GeoLite2-Country database reader.
This function is startup-only and must be called before request handling
begins. A second initialization attempt is considered a programming error
and raises ``RuntimeError``.
If *mmdb_path* is ``None``, empty, or the file does not exist the
fallback is silently disabled — ip-api.com remains the sole resolver.
Use :meth:`GeoCache.init_geoip` instead. This function delegates to the
default module-level instance for backward compatibility only.
Args:
mmdb_path: Absolute path to a ``GeoLite2-Country.mmdb`` file.
"""
global _geoip_reader, _geoip_initialized # noqa: PLW0603
if _geoip_initialized:
raise RuntimeError("GeoIP reader already initialised")
if not mmdb_path:
return
from pathlib import Path # noqa: PLC0415
import geoip2.database # noqa: PLC0415
if not Path(mmdb_path).is_file():
log.warning("geoip_mmdb_not_found", path=mmdb_path)
return
_geoip_reader = geoip2.database.Reader(mmdb_path)
_geoip_initialized = True
log.info("geoip_mmdb_loaded", path=mmdb_path)
def _geoip_lookup(ip: str) -> GeoInfo | None:
    """Resolve *ip* against the local MaxMind GeoLite2 reader, if one exists.

    This is a best-effort helper: it never raises. ``None`` is returned when
    no reader has been initialised, when the address is unknown to the
    database, when the record carries no ISO country code, or when the
    lookup fails for any other reason (logged as ``geoip_lookup_failed``).

    Args:
        ip: IPv4 or IPv6 address string.

    Returns:
        A :class:`GeoInfo` with ``country_code`` set (``asn`` and ``org`` are
        always ``None`` here), or ``None`` when resolution is impossible.
    """
    reader = _geoip_reader
    if reader is None:
        return None

    import geoip2.errors  # noqa: PLC0415

    try:
        record = reader.country(ip)
        # Empty strings are normalised to None.
        iso_code: str | None = record.country.iso_code or None
        country: str | None = record.country.name or None
    except geoip2.errors.AddressNotFoundError:
        return None
    except Exception as exc:  # noqa: BLE001
        log.warning("geoip_lookup_failed", ip=ip, error=str(exc))
        return None

    if iso_code is None:
        return None

    try:
        return GeoInfo(country_code=iso_code, country_name=country, asn=None, org=None)
    except geoip2.errors.AddressNotFoundError:
        return None
    except Exception as exc:  # noqa: BLE001
        log.warning("geoip_lookup_failed", ip=ip, error=str(exc))
        return None
# ---------------------------------------------------------------------------
# Persistent cache I/O
# ---------------------------------------------------------------------------
_default_geo_cache.init_geoip(mmdb_path)
async def load_cache_from_db(db: aiosqlite.Connection) -> None:
"""Pre-populate the in-memory cache from the ``geo_cache`` table.
"""(DEPRECATED) Pre-populate the in-memory cache from ``geo_cache`` table.
Should be called once during application startup so the service starts
with a warm cache instead of making cold API calls on the first request.
Use :meth:`GeoCache.load_cache_from_db` instead. This function delegates
to the default module-level instance for backward compatibility only.
Args:
db: Open :class:`aiosqlite.Connection` to the BanGUI application
database (not the fail2ban database).
"""
count = 0
cache_entries: list[tuple[str, GeoInfo]] = []
for row in await geo_cache_repo.load_all(db):
country_code: str | None = row["country_code"]
if country_code is None:
continue
ip: str = row["ip"]
cache_entries.append(
(
ip,
GeoInfo(
country_code=country_code,
country_name=row["country_name"],
asn=row["asn"],
org=row["org"],
),
)
)
count += 1
async with _cache_lock:
for ip, info in cache_entries:
_cache[ip] = info
log.info("geo_cache_loaded_from_db", entries=count)
async def _persist_entry(
    db: aiosqlite.Connection,
    ip: str,
    info: GeoInfo,
) -> None:
    """Write one resolved :class:`GeoInfo` to the ``geo_cache`` table.

    Thin wrapper around ``geo_cache_repo.upsert_entry``. Callers are expected
    to invoke it only for entries whose ``country_code`` is set, so that the
    persistent store does not accumulate empty placeholder rows.

    Args:
        db: BanGUI application database connection.
        ip: IP address string used as the row key.
        info: Resolved geo data to persist.
    """
    columns = {
        "country_code": info.country_code,
        "country_name": info.country_name,
        "asn": info.asn,
        "org": info.org,
    }
    await geo_cache_repo.upsert_entry(db=db, ip=ip, **columns)
async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
    """Persist a failed lookup for *ip* as an all-NULL ``geo_cache`` row.

    Delegates to ``geo_cache_repo.upsert_neg_entry``. That repository call is
    expected to use ``INSERT OR IGNORE`` so an existing positive row (one
    with a ``country_code``) is never replaced by a later failure.

    Args:
        db: BanGUI application database connection.
        ip: IP address string whose resolution failed.
    """
    repo_call = geo_cache_repo.upsert_neg_entry
    await repo_call(db=db, ip=ip)
# ---------------------------------------------------------------------------
# Public API — single lookup
# ---------------------------------------------------------------------------
await _default_geo_cache.load_cache_from_db(db)
async def lookup(
@@ -382,144 +145,37 @@ async def lookup(
http_session: aiohttp.ClientSession,
db: aiosqlite.Connection | None = None,
) -> GeoInfo | None:
"""Resolve an IP address to country, ASN, and organisation metadata.
"""(DEPRECATED) Resolve an IP address to country, ASN, organisation metadata.
Results are cached in-process. If the cache exceeds ``_MAX_CACHE_SIZE``
entries it is flushed before the new result is stored.
Only successful resolutions (``country_code is not None``) are written to
the persistent cache when *db* is provided. Failed lookups are **not**
cached so they are retried on the next call.
Use :meth:`GeoCache.lookup` instead. This function delegates to the
default module-level instance for backward compatibility only.
Args:
ip: IPv4 or IPv6 address string.
http_session: Shared :class:`aiohttp.ClientSession` (from
``app.state.http_session``).
db: Optional BanGUI application database. When provided, successful
lookups are persisted for cross-restart cache warming.
http_session: Shared :class:`aiohttp.ClientSession`.
db: Optional BanGUI application database for persistence.
Returns:
A :class:`GeoInfo` instance, or ``None`` when the lookup fails
in a way that should prevent the caller from caching a bad result
(e.g. network timeout).
A :class:`GeoInfo` instance, or ``None`` on lookup failure.
"""
if ip in _cache:
return _cache[ip]
# Negative cache: skip IPs that recently failed to avoid hammering the API.
neg_ts = _neg_cache.get(ip)
if neg_ts is not None and (time.monotonic() - neg_ts) < _NEG_CACHE_TTL:
return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
url: str = _API_URL.format(ip=ip)
api_ok = False
try:
async with http_session.get(url, timeout=aiohttp.ClientTimeout(total=_REQUEST_TIMEOUT)) as resp:
if resp.status != 200:
log.warning("geo_lookup_non_200", ip=ip, status=resp.status)
else:
data: dict[str, object] = await resp.json(content_type=None)
if data.get("status") == "success":
api_ok = True
result = _parse_single_response(data)
await _store(ip, result)
if result.country_code is not None and db is not None:
try:
await geo_cache_repo.upsert_entry_and_commit(
db=db,
ip=ip,
country_code=result.country_code,
country_name=result.country_name,
asn=result.asn,
org=result.org,
)
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_failed", ip=ip, error=str(exc))
log.debug("geo_lookup_success", ip=ip, country=result.country_code, asn=result.asn)
return result
log.debug(
"geo_lookup_failed",
ip=ip,
message=data.get("message", "unknown"),
)
except Exception as exc: # noqa: BLE001
log.warning(
"geo_lookup_request_failed",
ip=ip,
exc_type=type(exc).__name__,
error=repr(exc),
)
if not api_ok:
# Try local MaxMind database as fallback.
fallback = _geoip_lookup(ip)
if fallback is not None:
await _store(ip, fallback)
if fallback.country_code is not None and db is not None:
try:
await geo_cache_repo.upsert_entry_and_commit(
db=db,
ip=ip,
country_code=fallback.country_code,
country_name=fallback.country_name,
asn=fallback.asn,
org=fallback.org,
)
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_failed", ip=ip, error=str(exc))
log.debug("geo_geoip_fallback_success", ip=ip, country=fallback.country_code)
return fallback
# Both resolvers failed — record in negative cache to avoid hammering.
async with _cache_lock:
_neg_cache[ip] = time.monotonic()
if db is not None:
try:
await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip)
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_neg_failed", ip=ip, error=str(exc))
return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
# ---------------------------------------------------------------------------
# Public API — batch lookup
# ---------------------------------------------------------------------------
return await _default_geo_cache.lookup(ip, http_session, db=db)
def lookup_cached_only(
ips: list[str],
) -> tuple[dict[str, GeoInfo], list[str]]:
"""Return cached geo data for *ips* without making any external API calls.
"""(DEPRECATED) Return cached geo data without making external API calls.
Used by callers that want to return a fast response using only what is
already in memory, while deferring resolution of uncached IPs to a
background task.
Use :meth:`GeoCache.lookup_cached_only` instead. This function delegates
to the default module-level instance for backward compatibility only.
Args:
ips: IP address strings to look up.
Returns:
A ``(geo_map, uncached)`` tuple where *geo_map* maps every IP that
was already in the in-memory cache to its :class:`GeoInfo`, and
*uncached* is the list of IPs that were not found in the cache.
Entries in the negative cache (recently failed) are **not** included
in *uncached* so they are not re-queued immediately.
A ``(geo_map, uncached)`` tuple.
"""
geo_map: dict[str, GeoInfo] = {}
uncached: list[str] = []
now = time.monotonic()
for ip in dict.fromkeys(ips): # deduplicate, preserve order
if ip in _cache:
geo_map[ip] = _cache[ip]
elif ip in _neg_cache and (now - _neg_cache[ip]) < _NEG_CACHE_TTL:
# Still within the cool-down window — do not re-queue.
pass
else:
uncached.append(ip)
return geo_map, uncached
return _default_geo_cache.lookup_cached_only(ips)
async def lookup_batch(
@@ -527,274 +183,27 @@ async def lookup_batch(
http_session: aiohttp.ClientSession,
db: aiosqlite.Connection | None = None,
) -> dict[str, GeoInfo]:
"""Resolve multiple IP addresses in bulk using ip-api.com batch endpoint.
"""(DEPRECATED) Resolve multiple IPs in bulk using ip-api.com batch endpoint.
IPs already present in the in-memory cache are returned immediately
without making an HTTP request. Uncached IPs are sent to
``http://ip-api.com/batch`` in chunks of up to :data:`_BATCH_SIZE`.
Only successful resolutions (``country_code is not None``) are written to
the persistent cache when *db* is provided. Both positive and negative
entries are written in bulk using ``executemany`` (one round-trip per
chunk) rather than one ``execute`` per IP.
Use :meth:`GeoCache.lookup_batch` instead. This function delegates to the
default module-level instance for backward compatibility only.
Args:
ips: List of IP address strings to resolve. Duplicates are ignored.
ips: List of IP address strings to resolve.
http_session: Shared :class:`aiohttp.ClientSession`.
db: Optional BanGUI application database for persistent cache writes.
Returns:
Dict mapping ``ip → GeoInfo`` for every input IP. IPs whose
resolution failed will have a ``GeoInfo`` with all-``None`` fields.
Dict mapping ``ip → GeoInfo`` for every input IP.
"""
geo_result: dict[str, GeoInfo] = {}
uncached: list[str] = []
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
unique_ips = list(dict.fromkeys(ips)) # deduplicate, preserve order
now = time.monotonic()
for ip in unique_ips:
if ip in _cache:
geo_result[ip] = _cache[ip]
elif ip in _neg_cache and (now - _neg_cache[ip]) < _NEG_CACHE_TTL:
# Recently failed — skip API call, return empty result.
geo_result[ip] = _empty
else:
uncached.append(ip)
if not uncached:
return geo_result
log.info("geo_batch_lookup_start", total=len(uncached))
for batch_idx, chunk_start in enumerate(range(0, len(uncached), _BATCH_SIZE)):
chunk = uncached[chunk_start : chunk_start + _BATCH_SIZE]
# Throttle: pause between consecutive HTTP calls to stay within the
# ip-api.com free-tier rate limit (45 req/min).
if batch_idx > 0:
await asyncio.sleep(_BATCH_DELAY)
# Retry transient failures (e.g. connection-reset from rate limit).
chunk_result: dict[str, GeoInfo] | None = None
for attempt in range(_BATCH_MAX_RETRIES + 1):
chunk_result = await _batch_api_call(chunk, http_session)
# If every IP in the chunk came back with country_code=None and the
# batch wasn't tiny, that almost certainly means the whole request
# was rejected (connection reset / 429). Retry after a back-off.
all_failed = all(
info.country_code is None for info in chunk_result.values()
)
if not all_failed or attempt >= _BATCH_MAX_RETRIES:
break
backoff = _BATCH_DELAY * (2 ** (attempt + 1))
log.warning(
"geo_batch_retry",
attempt=attempt + 1,
chunk_size=len(chunk),
backoff=backoff,
)
await asyncio.sleep(backoff)
assert chunk_result is not None # noqa: S101
# Collect bulk-write rows instead of one execute per IP.
pos_rows: list[tuple[str, str | None, str | None, str | None, str | None]] = []
neg_ips: list[str] = []
for ip, info in chunk_result.items():
if info.country_code is not None:
# Successful API resolution.
await _store(ip, info)
geo_result[ip] = info
if db is not None:
pos_rows.append(
(ip, info.country_code, info.country_name, info.asn, info.org)
)
else:
# API failed — try local GeoIP fallback.
fallback = _geoip_lookup(ip)
if fallback is not None:
await _store(ip, fallback)
geo_result[ip] = fallback
if db is not None:
pos_rows.append(
(
ip,
fallback.country_code,
fallback.country_name,
fallback.asn,
fallback.org,
)
)
else:
# Both resolvers failed — record in negative cache.
async with _cache_lock:
_neg_cache[ip] = time.monotonic()
geo_result[ip] = _empty
if db is not None:
neg_ips.append(ip)
if db is not None and (pos_rows or neg_ips):
try:
await geo_cache_repo.bulk_upsert_entries_and_neg_entries_and_commit(
db,
pos_rows,
neg_ips,
)
except Exception as exc: # noqa: BLE001
log.warning(
"geo_batch_persist_failed",
positive_count=len(pos_rows),
negative_count=len(neg_ips),
error=str(exc),
)
log.info(
"geo_batch_lookup_complete",
requested=len(uncached),
resolved=sum(1 for g in geo_result.values() if g.country_code is not None),
)
return geo_result
async def _batch_api_call(
    ips: list[str],
    http_session: aiohttp.ClientSession,
) -> dict[str, GeoInfo]:
    """Send one batch request to the ip-api.com batch endpoint.

    The function never raises: transport errors, non-200 responses,
    undecodable bodies and bodies that are not a JSON array all yield the
    all-``None`` result for every requested IP.

    Args:
        ips: Up to :data:`_BATCH_SIZE` IP address strings.
        http_session: Shared HTTP session.

    Returns:
        Dict mapping ``ip → GeoInfo`` for every IP in *ips*. IPs where the
        API returned a failure record, were missing from the response, or
        where the request failed get an all-``None`` :class:`GeoInfo`.
    """
    empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
    fallback: dict[str, GeoInfo] = dict.fromkeys(ips, empty)

    payload = [{"query": ip} for ip in ips]

    try:
        async with http_session.post(
            _BATCH_API_URL,
            json=payload,
            timeout=aiohttp.ClientTimeout(total=_REQUEST_TIMEOUT * 2),
        ) as resp:
            if resp.status != 200:
                log.warning("geo_batch_non_200", status=resp.status, count=len(ips))
                return fallback
            data: object = await resp.json(content_type=None)
    except Exception as exc:  # noqa: BLE001
        log.warning(
            "geo_batch_request_failed",
            count=len(ips),
            exc_type=type(exc).__name__,
            error=repr(exc),
        )
        return fallback

    # The body decoded as JSON but may still not be the expected array of
    # objects (e.g. an error object or ``null``). Iterating such a value
    # would raise outside the ``try`` above, so treat it as a failed batch.
    if not isinstance(data, list):
        log.warning(
            "geo_batch_unexpected_payload",
            count=len(ips),
            payload_type=type(data).__name__,
        )
        return fallback

    out: dict[str, GeoInfo] = {}
    for entry in data:
        if not isinstance(entry, dict):
            # Malformed element — skip it; the IP is back-filled below.
            continue
        ip_str: str = str(entry.get("query", ""))
        if not ip_str:
            continue
        if entry.get("status") != "success":
            out[ip_str] = empty
            log.debug(
                "geo_batch_entry_failed",
                ip=ip_str,
                message=entry.get("message", "unknown"),
            )
            continue
        out[ip_str] = _parse_single_response(entry)

    # Fill any IPs missing from the response.
    for ip in ips:
        out.setdefault(ip, empty)

    return out
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _parse_single_response(data: dict[str, object]) -> GeoInfo:
    """Convert one ip-api.com response record into a :class:`GeoInfo`.

    Every field is normalised through :func:`_str_or_none`, so missing or
    blank values become ``None``.

    Args:
        data: A ``status == "success"`` JSON response from ip-api.com.

    Returns:
        Populated :class:`GeoInfo`.
    """
    as_field: str | None = _str_or_none(data.get("as"))

    # Only the first whitespace-separated token of the "as" value is kept
    # as the ASN (e.g. "AS12345" out of "AS12345 Some Org").
    asn: str | None
    if as_field:
        asn = as_field.split()[0]
    else:
        asn = None

    return GeoInfo(
        country_code=_str_or_none(data.get("countryCode")),
        country_name=_str_or_none(data.get("country")),
        asn=asn,
        org=_str_or_none(data.get("org")),
    )
def _str_or_none(value: object) -> str | None:
"""Return *value* as a non-empty string, or ``None``.
Args:
value: Raw JSON value which may be ``None``, empty, or a string.
Returns:
Stripped string if non-empty, else ``None``.
"""
if value is None:
return None
s = str(value).strip()
return s if s else None
async def _store(ip: str, info: GeoInfo) -> None:
    """Record *info* for *ip* in the module-level cache.

    Runs entirely under :data:`_cache_lock`. If the cache already holds
    :data:`_MAX_CACHE_SIZE` entries or more, both the cache and the dirty set
    are emptied before the new entry is inserted. Entries with a known
    country code are also marked dirty so :func:`flush_dirty` can persist
    them on its next run.

    Args:
        ip: The IP address key.
        info: The :class:`GeoInfo` to store.
    """
    async with _cache_lock:
        over_capacity = len(_cache) >= _MAX_CACHE_SIZE
        if over_capacity:
            # Whole-cache eviction: drop everything, including pending
            # dirty markers, rather than tracking per-entry age.
            _cache.clear()
            _dirty.clear()
            log.info("geo_cache_flushed", reason="capacity")

        _cache[ip] = info

        if info.country_code is None:
            return
        _dirty.add(ip)
return await _default_geo_cache.lookup_batch(ips, http_session, db=db)
async def flush_dirty(db: aiosqlite.Connection) -> int:
"""Persist all new in-memory geo entries to the ``geo_cache`` table.
"""(DEPRECATED) Persist all new in-memory geo entries to the database.
Takes an atomic snapshot of :data:`_dirty`, clears it, then batch-inserts
all entries that are still present in :data:`_cache` using a single
``executemany`` call and one ``COMMIT``. This is the only place that
writes to the persistent cache during normal operation after startup.
If the database write fails the entries are re-added to :data:`_dirty`
so they will be retried on the next flush cycle.
Use :meth:`GeoCache.flush_dirty` instead. This function delegates to the
default module-level instance for backward compatibility only.
Args:
db: Open :class:`aiosqlite.Connection` to the BanGUI application
@@ -803,30 +212,23 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
Returns:
The number of rows successfully upserted.
"""
async with _cache_lock:
if not _dirty:
return 0
return await _default_geo_cache.flush_dirty(db)
# Atomically snapshot and clear while holding the cache lock.
to_flush = _dirty.copy()
_dirty.clear()
rows = [
(ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org)
for ip in to_flush
if ip in _cache
]
async def re_resolve_all(
db: aiosqlite.Connection,
http_session: aiohttp.ClientSession,
) -> dict[str, int]:
"""(DEPRECATED) Retry geo resolution for all unresolved cache entries.
if not rows:
return 0
Use :meth:`GeoCache.re_resolve_all` instead. This function delegates to
the default module-level instance for backward compatibility only.
try:
await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows)
except Exception as exc: # noqa: BLE001
log.warning("geo_flush_dirty_failed", error=str(exc))
# Re-add to dirty so they are retried on the next flush cycle.
_dirty.update(to_flush)
return 0
Args:
db: BanGUI application database connection.
http_session: Shared aiohttp client session.
log.info("geo_flush_dirty_complete", count=len(rows))
return len(rows)
Returns:
A dict with ``resolved`` and ``total`` counts.
"""
return await _default_geo_cache.re_resolve_all(db, http_session)

View File

@@ -16,7 +16,8 @@ import structlog
from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped]
from app.db import init_db, open_db
from app.services import geo_service, setup_service
from app.services import setup_service
from app.services.geo_cache import GeoCache
from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check, history_sync
from app.utils.async_utils import run_blocking
from app.utils.jail_config import ensure_jail_configs
@@ -105,16 +106,18 @@ async def startup_shared_resources(
overrides=persisted_runtime_settings,
)
# Create and initialize the GeoCache instance
geo_cache = GeoCache()
if Path(settings.database_path).resolve() != original_db_path:
runtime_db = await open_db(settings.database_path)
try:
await geo_service.load_cache_from_db(runtime_db)
unresolved_count = await geo_service.count_unresolved(runtime_db)
await geo_cache.load_cache_from_db(runtime_db)
unresolved_count = await geo_cache.count_unresolved(runtime_db)
finally:
await runtime_db.close()
else:
await geo_service.load_cache_from_db(startup_db)
unresolved_count = await geo_service.count_unresolved(startup_db)
await geo_cache.load_cache_from_db(startup_db)
unresolved_count = await geo_cache.count_unresolved(startup_db)
finally:
await startup_db.close()
@@ -124,7 +127,8 @@ async def startup_shared_resources(
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
http_session: aiohttp.ClientSession = _create_http_session(settings)
geo_service.init_geoip(settings.geoip_db_path)
geo_cache.init_geoip(settings.geoip_db_path)
app.state.geo_cache = geo_cache
scheduler: AsyncIOScheduler | None = None
try:

View File

@@ -1,7 +1,7 @@
"""Geo cache flush background task.
Registers an APScheduler job that periodically persists newly resolved IP
geo entries from the in-memory ``_dirty`` set to the ``geo_cache`` table.
geo entries from the in-memory dirty set to the ``geo_cache`` table.
After Task 2 removed geo cache writes from GET requests, newly resolved IPs
are only held in the in-memory cache until this task flushes them. With the
@@ -15,7 +15,7 @@ from typing import TYPE_CHECKING
import structlog
from app.services import geo_service
from app.services.geo_cache import GeoCache
from app.tasks.db import task_db
from app.utils.runtime_state import get_effective_settings
@@ -33,21 +33,23 @@ GEO_FLUSH_INTERVAL: int = 60
JOB_ID: str = "geo_cache_flush"
async def _run_flush_with_settings(settings: Settings) -> None:
"""Flush the geo service dirty set to the application database.
async def _run_flush_with_resources(geo_cache: GeoCache, settings: Settings) -> None:
"""Flush the geo cache dirty set to the application database.
Args:
geo_cache: The application's GeoCache instance.
settings: The resolved application settings used for database access.
"""
async with task_db(settings) as db:
count = await geo_service.flush_dirty(db)
count = await geo_cache.flush_dirty(db)
if count > 0:
log.debug("geo_cache_flush_ran", flushed=count)
async def _run_flush(app: FastAPI) -> None:
await _run_flush_with_settings(get_effective_settings(app))
geo_cache: GeoCache = app.state.geo_cache
await _run_flush_with_resources(geo_cache, get_effective_settings(app))
def register(app: FastAPI) -> None:
@@ -60,12 +62,13 @@ def register(app: FastAPI) -> None:
app: The :class:`fastapi.FastAPI` application instance whose
``app.state.scheduler`` will receive the job.
"""
geo_cache: GeoCache = app.state.geo_cache
settings = get_effective_settings(app)
app.state.scheduler.add_job(
_run_flush_with_settings,
_run_flush_with_resources,
trigger="interval",
seconds=GEO_FLUSH_INTERVAL,
kwargs={"settings": settings},
kwargs={"geo_cache": geo_cache, "settings": settings},
id=JOB_ID,
replace_existing=True,
)

View File

@@ -10,8 +10,8 @@ The task runs every 10 minutes. On each invocation it:
1. Queries all ``NULL``-country rows from ``geo_cache``.
2. Clears the in-memory negative cache so those IPs are eligible for a fresh
API attempt.
3. Delegates to :func:`~app.services.geo_service.lookup_batch` which already
handles rate-limit throttling and retries.
3. Delegates to :meth:`~app.services.geo_cache.GeoCache.lookup_batch` which
already handles rate-limit throttling and retries.
4. Logs how many IPs were retried and how many resolved successfully.
"""
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
import structlog
from app.services import geo_service
from app.services.geo_cache import GeoCache
from app.tasks.db import task_db
from app.utils.runtime_state import get_effective_settings
@@ -40,16 +40,19 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
JOB_ID: str = "geo_re_resolve"
async def _run_re_resolve_with_resources(settings: Settings, http_session: ClientSession) -> None:
async def _run_re_resolve_with_resources(
geo_cache: GeoCache, settings: Settings, http_session: ClientSession
) -> None:
"""Query NULL-country IPs from the database and re-resolve them.
Args:
geo_cache: The application's GeoCache instance.
settings: The resolved application settings used for database access.
http_session: The shared aiohttp session used for external lookups.
"""
async with task_db(settings) as db:
# Fetch all IPs with NULL country_code from the persistent cache.
unresolved_ips = await geo_service.get_unresolved_ips(db)
unresolved_ips = await geo_cache.get_unresolved_ips(db)
if not unresolved_ips:
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
@@ -58,11 +61,11 @@ async def _run_re_resolve_with_resources(settings: Settings, http_session: Clien
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
# Clear the negative cache so these IPs are eligible for fresh API calls.
await geo_service.clear_neg_cache()
await geo_cache.clear_neg_cache()
# lookup_batch handles throttling, retries, and persistence when db is
# passed. This is a background task so DB writes are allowed.
results = await geo_service.lookup_batch(unresolved_ips, http_session, db=db)
results = await geo_cache.lookup_batch(unresolved_ips, http_session, db=db)
resolved_count: int = sum(
1 for info in results.values() if info.country_code is not None
@@ -75,7 +78,10 @@ async def _run_re_resolve_with_resources(settings: Settings, http_session: Clien
async def _run_re_resolve(app: FastAPI) -> None:
await _run_re_resolve_with_resources(get_effective_settings(app), app.state.http_session)
geo_cache: GeoCache = app.state.geo_cache
await _run_re_resolve_with_resources(
geo_cache, get_effective_settings(app), app.state.http_session
)
def register(app: FastAPI) -> None:
@@ -91,12 +97,13 @@ def register(app: FastAPI) -> None:
app: The :class:`fastapi.FastAPI` application instance whose
``app.state.scheduler`` will receive the job.
"""
geo_cache: GeoCache = app.state.geo_cache
settings = get_effective_settings(app)
app.state.scheduler.add_job(
_run_re_resolve_with_resources,
trigger="interval",
seconds=GEO_RE_RESOLVE_INTERVAL,
kwargs={"settings": settings, "http_session": app.state.http_session},
kwargs={"geo_cache": geo_cache, "settings": settings, "http_session": app.state.http_session},
id=JOB_ID,
replace_existing=True,
)

View File

@@ -168,9 +168,9 @@ async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Pat
patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
patch("app.startup.init_db", new=AsyncMock()),
patch("app.services.geo_service.init_geoip"),
patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
patch("app.services.geo_cache.GeoCache.init_geoip"),
patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)),
patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)),
patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
patch("app.tasks.health_check.register"),
patch("app.tasks.blocklist_import.register"),

View File

@@ -1,4 +1,4 @@
"""Tests for geo_service.lookup()."""
"""Tests for geo_service and GeoCache."""
from __future__ import annotations
@@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.models.geo import GeoInfo
from app.services import geo_service
from app.services.geo_cache import GeoCache
# ---------------------------------------------------------------------------
# Helpers
@@ -44,35 +44,33 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
async def clear_geo_cache() -> None:
"""Flush the module-level geo cache before every test."""
await geo_service.clear_cache()
geo_service._geoip_reader = None
geo_service._geoip_initialized = False
@pytest.fixture
async def geo_cache() -> GeoCache:
"""Provide a fresh GeoCache instance for each test."""
return GeoCache()
def test_init_geoip_is_startup_only(tmp_path) -> None:
def test_init_geoip_is_startup_only(geo_cache: GeoCache, tmp_path) -> None:
"""A second init_geoip() call raises when the reader was already loaded."""
path = tmp_path / "GeoLite2-Country.mmdb"
path.write_text("dummy")
with patch("geoip2.database.Reader", MagicMock(name="Reader")) as mock_reader:
geo_service.init_geoip(str(path))
assert geo_service._geoip_reader is not None
assert geo_service._geoip_initialized is True
geo_cache.init_geoip(str(path))
assert geo_cache._geoip_reader is not None
assert geo_cache._geoip_initialized is True
with pytest.raises(RuntimeError, match="already initialised"):
geo_service.init_geoip(str(path))
geo_cache.init_geoip(str(path))
assert mock_reader.call_count == 1
def test_init_geoip_no_path_leaves_reader_uninitialised() -> None:
def test_init_geoip_no_path_leaves_reader_uninitialised(geo_cache: GeoCache) -> None:
"""No active reader is created when no path is supplied."""
geo_service.init_geoip("")
assert geo_service._geoip_reader is None
assert geo_service._geoip_initialized is False
geo_cache.init_geoip("")
assert geo_cache._geoip_reader is None
assert geo_cache._geoip_initialized is False
# ---------------------------------------------------------------------------
@@ -83,7 +81,7 @@ def test_init_geoip_no_path_leaves_reader_uninitialised() -> None:
class TestLookupSuccess:
"""GeoCache.lookup() under normal conditions."""
async def test_returns_country_code(self) -> None:
async def test_returns_country_code(self, geo_cache: GeoCache) -> None:
"""country_code is populated from the ``countryCode`` field."""
session = _make_session(
{
@@ -94,12 +92,12 @@ class TestLookupSuccess:
"org": "AS3320 Deutsche Telekom AG",
}
)
result = await geo_service.lookup("1.2.3.4", session)
result = await geo_cache.lookup("1.2.3.4", session)
assert result is not None
assert result.country_code == "DE"
async def test_returns_country_name(self) -> None:
async def test_returns_country_name(self, geo_cache: GeoCache) -> None:
"""country_name is populated from the ``country`` field."""
session = _make_session(
{
@@ -110,12 +108,12 @@ class TestLookupSuccess:
"org": "Google LLC",
}
)
result = await geo_service.lookup("8.8.8.8", session)
result = await geo_cache.lookup("8.8.8.8", session)
assert result is not None
assert result.country_name == "United States"
async def test_asn_extracted_without_org_suffix(self) -> None:
async def test_asn_extracted_without_org_suffix(self, geo_cache: GeoCache) -> None:
"""The ASN field contains only the ``AS<N>`` prefix, not the full string."""
session = _make_session(
{
@@ -126,12 +124,12 @@ class TestLookupSuccess:
"org": "Deutsche Telekom",
}
)
result = await geo_service.lookup("1.2.3.4", session)
result = await geo_cache.lookup("1.2.3.4", session)
assert result is not None
assert result.asn == "AS3320"
async def test_org_populated(self) -> None:
async def test_org_populated(self, geo_cache: GeoCache) -> None:
"""org field is populated from the ``org`` key."""
session = _make_session(
{
@@ -142,7 +140,7 @@ class TestLookupSuccess:
"org": "Google LLC",
}
)
result = await geo_service.lookup("8.8.8.8", session)
result = await geo_cache.lookup("8.8.8.8", session)
assert result is not None
assert result.org == "Google LLC"
@@ -156,7 +154,7 @@ class TestLookupSuccess:
class TestLookupCaching:
"""Verify that results are cached and the cache can be cleared."""
async def test_second_call_uses_cache(self) -> None:
async def test_second_call_uses_cache(self, geo_cache: GeoCache) -> None:
"""Subsequent lookups for the same IP do not make additional HTTP requests."""
session = _make_session(
{
@@ -168,13 +166,13 @@ class TestLookupCaching:
}
)
await geo_service.lookup("1.2.3.4", session)
await geo_service.lookup("1.2.3.4", session)
await geo_cache.lookup("1.2.3.4", session)
await geo_cache.lookup("1.2.3.4", session)
# The session.get() should only have been called once.
assert session.get.call_count == 1
async def test_clear_cache_forces_refetch(self) -> None:
async def test_clear_cache_forces_refetch(self, geo_cache: GeoCache) -> None:
"""After clearing the cache a new HTTP request is made."""
session = _make_session(
{
@@ -186,20 +184,20 @@ class TestLookupCaching:
}
)
await geo_service.lookup("2.3.4.5", session)
await geo_service.clear_cache()
await geo_service.lookup("2.3.4.5", session)
await geo_cache.lookup("2.3.4.5", session)
await geo_cache.clear()
await geo_cache.lookup("2.3.4.5", session)
assert session.get.call_count == 2
async def test_negative_result_stored_in_neg_cache(self) -> None:
async def test_negative_result_stored_in_neg_cache(self, geo_cache: GeoCache) -> None:
"""A failed lookup is stored in the negative cache, so the second call is blocked."""
session = _make_session(
{"status": "fail", "message": "reserved range"}
)
await geo_service.lookup("192.168.1.1", session)
await geo_service.lookup("192.168.1.1", session)
await geo_cache.lookup("192.168.1.1", session)
await geo_cache.lookup("192.168.1.1", session)
# Second call is blocked by the negative cache — only one API hit.
assert session.get.call_count == 1
@@ -213,15 +211,15 @@ class TestLookupCaching:
class TestLookupFailures:
"""GeoCache.lookup() when things go wrong."""
async def test_non_200_response_returns_null_geo_info(self) -> None:
async def test_non_200_response_returns_null_geo_info(self, geo_cache: GeoCache) -> None:
"""A 429 or 500 status returns GeoInfo with null fields (not None)."""
session = _make_session({}, status=429)
result = await geo_service.lookup("1.2.3.4", session)
result = await geo_cache.lookup("1.2.3.4", session)
assert result is not None
assert isinstance(result, GeoInfo)
assert result.country_code is None
async def test_network_error_returns_null_geo_info(self) -> None:
async def test_network_error_returns_null_geo_info(self, geo_cache: GeoCache) -> None:
"""A network exception returns GeoInfo with null fields (not None)."""
session = MagicMock()
mock_ctx = AsyncMock()
@@ -229,15 +227,15 @@ class TestLookupFailures:
mock_ctx.__aexit__ = AsyncMock(return_value=False)
session.get = MagicMock(return_value=mock_ctx)
result = await geo_service.lookup("10.0.0.1", session)
result = await geo_cache.lookup("10.0.0.1", session)
assert result is not None
assert isinstance(result, GeoInfo)
assert result.country_code is None
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
async def test_failed_status_returns_geo_info_with_nulls(self, geo_cache: GeoCache) -> None:
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
session = _make_session({"status": "fail", "message": "private range"})
result = await geo_service.lookup("10.0.0.1", session)
result = await geo_cache.lookup("10.0.0.1", session)
assert result is not None
assert isinstance(result, GeoInfo)
@@ -253,43 +251,43 @@ class TestLookupFailures:
class TestNegativeCache:
"""Verify the negative cache throttles retries for failing IPs."""
async def test_neg_cache_blocks_second_lookup(self) -> None:
async def test_neg_cache_blocks_second_lookup(self, geo_cache: GeoCache) -> None:
"""After a failed lookup the second call is served from the neg cache."""
session = _make_session({"status": "fail", "message": "private range"})
r1 = await geo_service.lookup("192.0.2.1", session)
r2 = await geo_service.lookup("192.0.2.1", session)
r1 = await geo_cache.lookup("192.0.2.1", session)
r2 = await geo_cache.lookup("192.0.2.1", session)
# Only one HTTP call should have been made; second served from neg cache.
assert session.get.call_count == 1
assert r1 is not None and r1.country_code is None
assert r2 is not None and r2.country_code is None
async def test_neg_cache_retries_after_ttl(self) -> None:
async def test_neg_cache_retries_after_ttl(self, geo_cache: GeoCache) -> None:
"""When the neg-cache entry is older than the TTL a new API call is made."""
session = _make_session({"status": "fail", "message": "private range"})
await geo_service.lookup("192.0.2.2", session)
await geo_cache.lookup("192.0.2.2", session)
# Manually expire the neg-cache entry.
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1
geo_cache._neg_cache["192.0.2.2"] -= _NEG_CACHE_TTL + 1
await geo_service.lookup("192.0.2.2", session)
await geo_cache.lookup("192.0.2.2", session)
# Both calls should have hit the API.
assert session.get.call_count == 2
async def test_clear_neg_cache_allows_immediate_retry(self) -> None:
async def test_clear_neg_cache_allows_immediate_retry(self, geo_cache: GeoCache) -> None:
"""After clearing the neg cache the IP is eligible for a new API call."""
session = _make_session({"status": "fail", "message": "private range"})
await geo_service.lookup("192.0.2.3", session)
await geo_service.clear_neg_cache()
await geo_service.lookup("192.0.2.3", session)
await geo_cache.lookup("192.0.2.3", session)
await geo_cache.clear_neg_cache()
await geo_cache.lookup("192.0.2.3", session)
assert session.get.call_count == 2
async def test_successful_lookup_does_not_pollute_neg_cache(self) -> None:
async def test_successful_lookup_does_not_pollute_neg_cache(self, geo_cache: GeoCache) -> None:
"""A successful lookup must not create a neg-cache entry."""
session = _make_session(
{
@@ -301,9 +299,9 @@ class TestNegativeCache:
}
)
await geo_service.lookup("1.2.3.4", session)
await geo_cache.lookup("1.2.3.4", session)
assert "1.2.3.4" not in geo_service._neg_cache
assert "1.2.3.4" not in geo_cache._neg_cache
# ---------------------------------------------------------------------------
@@ -327,33 +325,33 @@ class TestGeoipFallback:
reader.country = MagicMock(return_value=response_mock)
return reader
async def test_geoip_fallback_called_when_api_fails(self) -> None:
async def test_geoip_fallback_called_when_api_fails(self, geo_cache: GeoCache) -> None:
"""When ip-api returns status=fail, the geoip2 reader is consulted."""
session = _make_session({"status": "fail", "message": "reserved range"})
mock_reader = self._make_geoip_reader("DE", "Germany")
with patch.object(geo_cache, "_geoip_reader", mock_reader):
result = await geo_service.lookup("1.2.3.4", session)
result = await geo_cache.lookup("1.2.3.4", session)
mock_reader.country.assert_called_once_with("1.2.3.4")
assert result is not None
assert result.country_code == "DE"
assert result.country_name == "Germany"
async def test_geoip_fallback_result_stored_in_cache(self) -> None:
async def test_geoip_fallback_result_stored_in_cache(self, geo_cache: GeoCache) -> None:
"""A successful geoip2 fallback result is stored in the positive cache."""
session = _make_session({"status": "fail", "message": "reserved range"})
mock_reader = self._make_geoip_reader("US", "United States")
with patch.object(geo_cache, "_geoip_reader", mock_reader):
await geo_service.lookup("8.8.8.8", session)
await geo_cache.lookup("8.8.8.8", session)
# Second call must be served from positive cache without hitting API.
await geo_service.lookup("8.8.8.8", session)
await geo_cache.lookup("8.8.8.8", session)
assert session.get.call_count == 1
assert "8.8.8.8" in geo_service._cache
assert "8.8.8.8" in geo_cache._cache
async def test_geoip_fallback_not_called_on_api_success(self) -> None:
async def test_geoip_fallback_not_called_on_api_success(self, geo_cache: GeoCache) -> None:
"""When ip-api succeeds, the geoip2 reader must not be consulted."""
session = _make_session(
{
@@ -367,18 +365,18 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("XX", "Nowhere")
with patch.object(geo_cache, "_geoip_reader", mock_reader):
result = await geo_service.lookup("1.2.3.4", session)
result = await geo_cache.lookup("1.2.3.4", session)
mock_reader.country.assert_not_called()
assert result is not None
assert result.country_code == "JP"
async def test_geoip_fallback_not_called_when_no_reader(self) -> None:
async def test_geoip_fallback_not_called_when_no_reader(self, geo_cache: GeoCache) -> None:
"""When no geoip2 reader is configured, the fallback silently does nothing."""
session = _make_session({"status": "fail", "message": "private range"})
with patch.object(geo_cache, "_geoip_reader", None):
result = await geo_service.lookup("10.0.0.1", session)
result = await geo_cache.lookup("10.0.0.1", session)
assert result is not None
assert result.country_code is None
@@ -428,7 +426,7 @@ def _make_async_db() -> MagicMock:
class TestLookupBatchSingleCommit:
"""lookup_batch() issues exactly one commit per call, not one per IP."""
async def test_single_commit_for_multiple_ips(self) -> None:
async def test_single_commit_for_multiple_ips(self, geo_cache: GeoCache) -> None:
"""A batch of N IPs produces exactly one db.commit(), not N."""
ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
batch_response = [
@@ -438,11 +436,11 @@ class TestLookupBatchSingleCommit:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db)
await geo_cache.lookup_batch(ips, session, db=db)
db.commit.assert_awaited_once()
async def test_commit_called_even_on_failed_lookups(self) -> None:
async def test_commit_called_even_on_failed_lookups(self, geo_cache: GeoCache) -> None:
"""A batch with all-failed lookups still triggers one commit."""
ips = ["10.0.0.1", "10.0.0.2"]
batch_response = [
@@ -452,11 +450,11 @@ class TestLookupBatchSingleCommit:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db)
await geo_cache.lookup_batch(ips, session, db=db)
db.commit.assert_awaited_once()
async def test_no_commit_when_db_is_none(self) -> None:
async def test_no_commit_when_db_is_none(self, geo_cache: GeoCache) -> None:
"""When db=None, no commit is attempted."""
ips = ["1.1.1.1"]
batch_response = [
@@ -472,19 +470,19 @@ class TestLookupBatchSingleCommit:
session = _make_batch_session(batch_response)
# Should not raise; without db there is nothing to commit.
result = await geo_service.lookup_batch(ips, session, db=None)
result = await geo_cache.lookup_batch(ips, session, db=None)
assert result["1.1.1.1"].country_code == "US"
async def test_no_commit_for_all_cached_ips(self) -> None:
async def test_no_commit_for_all_cached_ips(self, geo_cache: GeoCache) -> None:
"""When all IPs are already cached, no HTTP call and no commit occur."""
geo_service._cache["5.5.5.5"] = GeoInfo(
geo_cache._cache["5.5.5.5"] = GeoInfo(
country_code="FR", country_name="France", asn="AS1", org="ISP"
)
db = _make_async_db()
session = _make_batch_session([])
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)
result = await geo_cache.lookup_batch(["5.5.5.5"], session, db=db)
assert result["5.5.5.5"].country_code == "FR"
db.commit.assert_not_awaited()
@@ -499,31 +497,31 @@ class TestLookupBatchSingleCommit:
class TestDirtySetTracking:
"""_store() marks successfully resolved IPs as dirty."""
async def test_successful_resolution_adds_to_dirty(self) -> None:
async def test_successful_resolution_adds_to_dirty(self, geo_cache: GeoCache) -> None:
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
await geo_cache._store("1.2.3.4", info)
assert "1.2.3.4" in geo_service._dirty
assert "1.2.3.4" in geo_cache._dirty
async def test_null_country_does_not_add_to_dirty(self) -> None:
async def test_null_country_does_not_add_to_dirty(self, geo_cache: GeoCache) -> None:
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
await geo_cache._store("10.0.0.1", info)
assert "10.0.0.1" not in geo_service._dirty
assert "10.0.0.1" not in geo_cache._dirty
async def test_clear_cache_also_clears_dirty(self) -> None:
async def test_clear_cache_also_clears_dirty(self, geo_cache: GeoCache) -> None:
"""clear() must discard any pending dirty entries."""
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
await geo_cache._store("8.8.8.8", info)
assert geo_service._dirty
assert geo_cache._dirty
await geo_service.clear_cache()
await geo_cache.clear()
assert not geo_service._dirty
assert not geo_cache._dirty
async def test_lookup_batch_populates_dirty(self) -> None:
async def test_lookup_batch_populates_dirty(self, geo_cache: GeoCache) -> None:
"""After lookup_batch() with db=None, resolved IPs appear in _dirty."""
ips = ["1.1.1.1", "2.2.2.2"]
batch_response = [
@@ -532,39 +530,39 @@ class TestDirtySetTracking:
]
session = _make_batch_session(batch_response)
await geo_service.lookup_batch(ips, session, db=None)
await geo_cache.lookup_batch(ips, session, db=None)
for ip in ips:
assert ip in geo_service._dirty
assert ip in geo_cache._dirty
class TestFlushDirty:
"""flush_dirty() persists dirty entries and clears the set."""
async def test_flush_writes_and_clears_dirty(self) -> None:
async def test_flush_writes_and_clears_dirty(self, geo_cache: GeoCache) -> None:
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
await geo_cache._store("100.0.0.1", info)
assert "100.0.0.1" in geo_service._dirty
assert "100.0.0.1" in geo_cache._dirty
db = _make_async_db()
count = await geo_service.flush_dirty(db)
count = await geo_cache.flush_dirty(db)
assert count == 1
db.executemany.assert_awaited_once()
db.commit.assert_awaited_once()
assert "100.0.0.1" not in geo_service._dirty
assert "100.0.0.1" not in geo_cache._dirty
async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
async def test_flush_returns_zero_when_nothing_dirty(self, geo_cache: GeoCache) -> None:
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
db = _make_async_db()
count = await geo_service.flush_dirty(db)
count = await geo_cache.flush_dirty(db)
assert count == 0
db.executemany.assert_not_awaited()
db.commit.assert_not_awaited()
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
async def test_flush_re_adds_to_dirty_on_db_error(self, geo_cache: GeoCache) -> None:
"""When the DB write fails, entries are re-added to _dirty for retry."""
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
await geo_cache._store("200.0.0.1", info)
@@ -572,12 +570,12 @@ class TestFlushDirty:
db = _make_async_db()
db.executemany = AsyncMock(side_effect=OSError("disk full"))
count = await geo_service.flush_dirty(db)
count = await geo_cache.flush_dirty(db)
assert count == 0
assert "200.0.0.1" in geo_service._dirty
assert "200.0.0.1" in geo_cache._dirty
async def test_flush_batch_and_lookup_batch_integration(self) -> None:
async def test_flush_batch_and_lookup_batch_integration(self, geo_cache: GeoCache) -> None:
"""lookup_batch() populates _dirty; flush_dirty() then persists them."""
ips = ["10.1.2.3", "10.1.2.4"]
batch_response = [
@@ -587,15 +585,15 @@ class TestFlushDirty:
session = _make_batch_session(batch_response)
# Resolve without DB to populate only in-memory cache and _dirty.
await geo_service.lookup_batch(ips, session, db=None)
assert geo_service._dirty == set(ips)
await geo_cache.lookup_batch(ips, session, db=None)
assert geo_cache._dirty == set(ips)
# Now flush to the DB.
db = _make_async_db()
count = await geo_service.flush_dirty(db)
count = await geo_cache.flush_dirty(db)
assert count == 2
assert not geo_service._dirty
assert not geo_cache._dirty
db.commit.assert_awaited_once()
@@ -607,7 +605,7 @@ class TestFlushDirty:
class TestLookupBatchThrottling:
"""Verify the inter-batch delay, retry, and give-up behaviour."""
async def test_lookup_batch_throttles_between_chunks(self) -> None:
async def test_lookup_batch_throttles_between_chunks(self, geo_cache: GeoCache) -> None:
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
between consecutive batch HTTP calls with at least _BATCH_DELAY."""
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
@@ -628,7 +626,7 @@ class TestLookupBatchThrottling:
) as mock_batch,
patch("app.services.geo_service.asyncio.sleep", new_callable=AsyncMock) as mock_sleep,
):
await geo_service.lookup_batch(ips, MagicMock())
await geo_cache.lookup_batch(ips, MagicMock())
# Two chunks → one sleep between them.
assert mock_batch.call_count == 2
@@ -636,7 +634,7 @@ class TestLookupBatchThrottling:
delay_arg: float = mock_sleep.call_args[0][0]
assert delay_arg >= geo_service._BATCH_DELAY
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
async def test_lookup_batch_retries_on_full_chunk_failure(self, geo_cache: GeoCache) -> None:
"""When a chunk returns all-None on first try, it retries and succeeds."""
ips = ["1.2.3.4", "5.6.7.8"]
@@ -664,13 +662,13 @@ class TestLookupBatchThrottling:
),
patch("app.services.geo_service.asyncio.sleep", new_callable=AsyncMock),
):
result = await geo_service.lookup_batch(ips, MagicMock())
result = await geo_cache.lookup_batch(ips, MagicMock())
assert call_count == 2
assert result["1.2.3.4"].country_code == "DE"
assert result["5.6.7.8"].country_code == "US"
async def test_lookup_batch_gives_up_after_max_retries(self) -> None:
async def test_lookup_batch_gives_up_after_max_retries(self, geo_cache: GeoCache) -> None:
"""After _BATCH_MAX_RETRIES + 1 attempts, IPs end up in the neg cache."""
ips = ["9.9.9.9"]
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
@@ -686,14 +684,14 @@ class TestLookupBatchThrottling:
) as mock_batch,
patch("app.services.geo_service.asyncio.sleep", new_callable=AsyncMock) as mock_sleep,
):
result = await geo_service.lookup_batch(ips, MagicMock())
result = await geo_cache.lookup_batch(ips, MagicMock())
# Initial attempt + max_retries retries.
assert mock_batch.call_count == max_retries + 1
# IP should have no country.
assert result["9.9.9.9"].country_code is None
# Negative cache should contain the IP.
assert "9.9.9.9" in geo_service._neg_cache
assert "9.9.9.9" in geo_cache._neg_cache
# Sleep called for each retry with exponential backoff.
assert mock_sleep.call_count == max_retries
backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
@@ -717,7 +715,7 @@ class TestErrorLogging:
always present, and adds an ``exc_type`` field for easy log filtering.
"""
async def test_empty_message_exception_logs_exc_type(self, caplog: pytest.LogCaptureFixture) -> None:
async def test_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache, caplog: pytest.LogCaptureFixture) -> None:
"""When exception str() is empty, exc_type and repr are still logged."""
class _EmptyMessageError(Exception):
@@ -735,7 +733,7 @@ class TestErrorLogging:
import structlog.testing
with structlog.testing.capture_logs() as captured:
result = await geo_service.lookup("197.221.98.153", session)
result = await geo_cache.lookup("197.221.98.153", session)
assert result is not None
assert result.country_code is None
@@ -748,7 +746,7 @@ class TestErrorLogging:
# repr() must include the class name even when str() is empty.
assert "_EmptyMessageError" in event["error"]
async def test_connection_error_logs_exc_type(self, caplog: pytest.LogCaptureFixture) -> None:
async def test_connection_error_logs_exc_type(self, geo_cache: GeoCache, caplog: pytest.LogCaptureFixture) -> None:
"""A standard OSError with message is logged both in error and exc_type."""
session = MagicMock()
mock_ctx = AsyncMock()
@@ -759,7 +757,7 @@ class TestErrorLogging:
import structlog.testing
with structlog.testing.capture_logs() as captured:
await geo_service.lookup("10.0.0.1", session)
await geo_cache.lookup("10.0.0.1", session)
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
assert len(request_failed) == 1
@@ -767,7 +765,7 @@ class TestErrorLogging:
assert event["exc_type"] == "OSError"
assert "connection refused" in event["error"]
async def test_batch_empty_message_exception_logs_exc_type(self) -> None:
async def test_batch_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None:
"""Batch API call: empty-message exceptions include exc_type in the log."""
class _EmptyMessageError(Exception):
@@ -804,10 +802,10 @@ class TestLookupCachedOnly:
def test_returns_cached_ips(self, geo_cache: GeoCache) -> None:
"""IPs already in the cache are returned in the geo_map."""
geo_service._cache["1.1.1.1"] = GeoInfo(
geo_cache._cache["1.1.1.1"] = GeoInfo(
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
)
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
geo_map, uncached = geo_cache.lookup_cached_only(["1.1.1.1"])
assert "1.1.1.1" in geo_map
assert geo_map["1.1.1.1"].country_code == "AU"
@@ -815,7 +813,7 @@ class TestLookupCachedOnly:
def test_returns_uncached_ips(self, geo_cache: GeoCache) -> None:
"""IPs not in the cache appear in the uncached list."""
geo_map, uncached = geo_service.lookup_cached_only(["9.9.9.9"])
geo_map, uncached = geo_cache.lookup_cached_only(["9.9.9.9"])
assert "9.9.9.9" not in geo_map
assert "9.9.9.9" in uncached
@@ -824,42 +822,42 @@ class TestLookupCachedOnly:
"""IPs in the negative cache within TTL are not re-queued as uncached."""
import time
geo_service._neg_cache["10.0.0.1"] = time.monotonic()
geo_cache._neg_cache["10.0.0.1"] = time.monotonic()
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
geo_map, uncached = geo_cache.lookup_cached_only(["10.0.0.1"])
assert "10.0.0.1" not in geo_map
assert "10.0.0.1" not in uncached
def test_expired_neg_cache_requeued(self, geo_cache: GeoCache) -> None:
"""IPs whose neg-cache entry has expired are listed as uncached."""
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired
geo_cache._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
_geo_map, uncached = geo_cache.lookup_cached_only(["10.0.0.2"])
assert "10.0.0.2" in uncached
def test_mixed_ips(self, geo_cache: GeoCache) -> None:
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
geo_service._cache["1.2.3.4"] = GeoInfo(
geo_cache._cache["1.2.3.4"] = GeoInfo(
country_code="DE", country_name="Germany", asn=None, org=None
)
import time
geo_service._neg_cache["5.5.5.5"] = time.monotonic()
geo_cache._neg_cache["5.5.5.5"] = time.monotonic()
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
geo_map, uncached = geo_cache.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
assert list(geo_map.keys()) == ["1.2.3.4"]
assert uncached == ["9.9.9.9"]
def test_deduplication(self, geo_cache: GeoCache) -> None:
"""Duplicate IPs in the input appear at most once in the output."""
geo_service._cache["1.2.3.4"] = GeoInfo(
geo_cache._cache["1.2.3.4"] = GeoInfo(
country_code="US", country_name="United States", asn=None, org=None
)
geo_map, uncached = geo_service.lookup_cached_only(
geo_map, uncached = geo_cache.lookup_cached_only(
["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"]
)
@@ -868,27 +866,32 @@ class TestLookupCachedOnly:
class TestReResolveAll:
"""Tests for :func:`~app.services.geo_service.re_resolve_all`."""
"""Tests for :meth:`~app.services.geo_cache.GeoCache.re_resolve_all`."""
async def test_returns_zero_when_no_unresolved_ips(self) -> None:
async def test_returns_zero_when_no_unresolved_ips(self, geo_cache: GeoCache) -> None:
"""The service returns zero counts when there are no unresolved IPs."""
db = MagicMock()
session = MagicMock()
with patch(
"app.services.geo_service.get_unresolved_ips",
"app.repositories.geo_cache_repo.get_unresolved_ips",
AsyncMock(return_value=[]),
), patch("app.services.geo_service.lookup_batch", AsyncMock()) as mock_lookup, patch(
"app.services.geo_service.clear_neg_cache",
MagicMock(),
), patch.object(
geo_cache,
"lookup_batch",
AsyncMock(),
) as mock_lookup, patch.object(
geo_cache,
"clear_neg_cache",
AsyncMock(),
) as mock_clear:
result = await geo_service.re_resolve_all(db, session)
result = await geo_cache.re_resolve_all(db, session)
assert result == {"resolved": 0, "total": 0}
mock_clear.assert_not_called()
mock_lookup.assert_not_called()
async def test_clears_neg_cache_and_returns_counts(self) -> None:
async def test_clears_neg_cache_and_returns_counts(self, geo_cache: GeoCache) -> None:
"""The service clears negative cache and returns resolved and total counts."""
db = MagicMock()
session = MagicMock()
@@ -899,16 +902,18 @@ class TestReResolveAll:
}
with patch(
"app.services.geo_service.get_unresolved_ips",
"app.repositories.geo_cache_repo.get_unresolved_ips",
AsyncMock(return_value=ips),
), patch(
"app.services.geo_service.lookup_batch",
), patch.object(
geo_cache,
"lookup_batch",
AsyncMock(return_value=geo_map),
) as mock_lookup, patch(
"app.services.geo_service.clear_neg_cache",
) as mock_lookup, patch.object(
geo_cache,
"clear_neg_cache",
AsyncMock(),
) as mock_clear:
result = await geo_service.re_resolve_all(db, session)
result = await geo_cache.re_resolve_all(db, session)
assert result == {"resolved": 1, "total": 2}
mock_clear.assert_called_once()
@@ -923,7 +928,7 @@ class TestReResolveAll:
class TestLookupBatchBulkWrites:
"""lookup_batch() uses executemany for bulk DB writes, not per-IP execute."""
async def test_executemany_called_for_successful_ips(self) -> None:
async def test_executemany_called_for_successful_ips(self, geo_cache: GeoCache) -> None:
"""When multiple IPs resolve successfully, a single executemany write occurs."""
ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
batch_response = [
@@ -940,14 +945,14 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db)
await geo_cache.lookup_batch(ips, session, db=db)
# One executemany for the positive rows.
assert db.executemany.await_count >= 1
# High-level: execute() must NOT be called for the batch writes.
db.execute.assert_not_awaited()
async def test_executemany_called_for_failed_ips(self) -> None:
async def test_executemany_called_for_failed_ips(self, geo_cache: GeoCache) -> None:
"""When IPs fail resolution, a single executemany write covers neg entries."""
ips = ["10.0.0.1", "10.0.0.2"]
batch_response = [
@@ -957,12 +962,12 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db)
await geo_cache.lookup_batch(ips, session, db=db)
assert db.executemany.await_count >= 1
db.execute.assert_not_awaited()
async def test_mixed_results_two_executemany_calls(self) -> None:
async def test_mixed_results_two_executemany_calls(self, geo_cache: GeoCache) -> None:
"""A mix of successful and failed IPs produces two executemany calls."""
ips = ["1.1.1.1", "10.0.0.1"]
batch_response = [
@@ -979,7 +984,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db)
await geo_cache.lookup_batch(ips, session, db=db)
# One executemany for positives, one for negatives.
assert db.executemany.await_count == 2