diff --git a/Docs/Architekture.md b/Docs/Architekture.md index d7be7ca..551bd47 100644 --- a/Docs/Architekture.md +++ b/Docs/Architekture.md @@ -196,7 +196,8 @@ The business logic layer. Services orchestrate operations, enforce rules, and co | `log_service.py` | Log preview and regex test operations (extracted from config_service) | | `history_service.py` | Queries the fail2ban database for historical ban records, builds per-IP timelines, computes ban counts and repeat-offender flags, and syncs new records into BanGUI's archive table | | `blocklist_service.py` | Downloads blocklists via aiohttp, validates IPs/CIDRs, applies bans through fail2ban or iptables, logs import results | -| `geo_service.py` | Resolves IP addresses to country, ASN, and RIR using external APIs or a local database, caches results, and re-resolves unresolved geo cache entries | +| `geo_cache.py` | **GeoCache** class that encapsulates all IP geolocation caching: resolves IP addresses to country, ASN, and organization using external APIs or a local MaxMind database, maintains in-memory and persistent caches with negative cache support, and manages background re-resolution. Instantiated once at startup and stored on `app.state.geo_cache` | +| `geo_service.py` | (Deprecated) Backward-compatibility wrappers that delegate to the `GeoCache` instance. Kept for compatibility with existing code. New code should use `GeoCache` directly or via dependency injection | | `server_service.py` | Reads and writes fail2ban server-level settings (log level, log target, syslog socket, DB location, purge age) | | `health_service.py` | Probes fail2ban socket connectivity, retrieves server version and global stats, reports online/offline status | @@ -667,6 +668,7 @@ BanGUI maintains its **own SQLite database** (separate from the fail2ban databas - The frontend `AuthProvider` checks session validity on mount and redirects to `/login` if invalid. 
- The backend `dependencies.py` provides an `authenticated` dependency that validates the session cookie on every protected endpoint. - **Session validation cache** — validated session tokens are cached in memory for 10 seconds (`_session_cache` dict in `dependencies.py`) to avoid a SQLite round-trip on every request from the same browser. The cache is invalidated immediately on logout. This cache is process-local and not safe for multi-worker or distributed deployments. A clustered deployment should replace `_session_cache` with a shared cache or remove it entirely. +- **GeoCache** — a single `GeoCache` instance is created at startup and stored on `app.state.geo_cache`. It encapsulates all IP geolocation caching: the in-memory lookup cache, a negative cache for unresolvable IPs (with TTL), the dirty set of entries awaiting persistence, and an `asyncio.Lock` that serialises cache mutations between coroutines (the lock is not thread-safe, so the instance must only be used from the event loop that owns it). The cache is loaded from the `geo_cache` SQLite table on startup. New resolutions are accumulated in memory and periodically flushed to the database by the `geo_cache_flush` background task. Unresolved entries are re-resolved by the `geo_re_resolve` task. The instance is injected into routes and tasks via FastAPI's dependency system. - **Runtime state** — `RuntimeState` is process-local and only safe when BanGUI runs as a single asyncio worker. Mutating runtime state must not span `await` points because the current design relies on cooperative scheduling. Multi-worker or multi-process deployments must replace this runtime state with a shared coordination backend such as Redis, shared memory, or a database-backed store. - **Setup-completion flag** — once `is_setup_complete()` returns `True`, the result is stored in `app.state._setup_complete_cached`. The `SetupRedirectMiddleware` skips the DB query on all subsequent requests, removing 1 SQL query per request for the common post-setup case. 
The completion flag is only written after the runtime database is successfully initialized and all initial setup settings are persisted, preventing a failed setup from permanently bypassing the setup wizard. diff --git a/Docs/Tasks.md b/Docs/Tasks.md index f08392e..4441bb8 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -1,23 +1,3 @@ -### T-03 · Centralise `_DEFAULT_PAGE_SIZE` constant - -**Where found:** `backend/app/routers/dashboard.py:45`, `routers/history.py:34`, `services/ban_service.py:70`, `services/history_service.py:49` - -**Why this is needed:** Four independent definitions can drift. The router default and service default are currently coincidentally aligned at 100, but nothing enforces this. - -**Goal:** Single definition in `app/utils/constants.py`, imported everywhere. - -**What to do:** -1. Add `DEFAULT_PAGE_SIZE: Final[int] = 100` and `MAX_PAGE_SIZE: Final[int] = 500` to `app/utils/constants.py`. -2. Replace all four local `_DEFAULT_PAGE_SIZE` and `_MAX_PAGE_SIZE` declarations with imports. - -**Possible traps and issues:** None significant. Pure search-and-replace. - -**Docs changes needed:** None. 
- -**Doc references:** `app/utils/constants.py` - ---- - ### T-04 · Encapsulate `geo_service` module-level mutable state in a class **Where found:** `backend/app/services/geo_service.py` — module globals `_cache`, `_neg_cache`, `_dirty`, `_geoip_reader`, `_geoip_initialized`, `_cache_lock` diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 3b0563e..7d69898 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -31,6 +31,7 @@ from app.repositories.protocols import ( SettingsRepository, SessionRepository, ) +from app.services.geo_cache import GeoCache from app.utils.constants import SESSION_COOKIE_NAME from app.utils.runtime_state import RuntimeState from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache, SessionCache @@ -50,6 +51,7 @@ class AppState(Protocol): runtime_settings: Settings | None runtime_state: RuntimeState session_cache: SessionCache + geo_cache: GeoCache # noqa: F821 @dataclass @@ -214,11 +216,15 @@ async def get_fail2ban_start_command(settings: Settings = Depends(get_settings)) return settings.fail2ban_start_command -async def get_geo_batch_lookup() -> GeoBatchLookup: - """Provide the concrete geo batch lookup callable used by routers.""" - from app.services import geo_service # noqa: PLC0415 +async def get_geo_batch_lookup(request: Request) -> GeoBatchLookup: + """Provide the geo batch lookup method from the application's GeoCache instance.""" + geo_cache: GeoCache = request.app.state.geo_cache + return geo_cache.lookup_batch # type: ignore[return-value] - return geo_service.lookup_batch + +async def get_geo_cache(request: Request) -> GeoCache: + """Provide the application's GeoCache instance.""" + return request.app.state.geo_cache async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> SessionCache: diff --git a/backend/app/services/geo_cache.py b/backend/app/services/geo_cache.py new file mode 100644 index 0000000..73a9ead --- 
/dev/null +++ b/backend/app/services/geo_cache.py @@ -0,0 +1,737 @@ +"""GeoCache service — encapsulates geo service mutable state. + +This module defines the :class:`GeoCache` class which encapsulates all +module-level mutable state from the original geo_service into a single +injectable instance. The cache manages: + +- In-memory positive results cache (``ip → GeoInfo``) +- Negative cache (failed lookups with TTL) +- Dirty set (entries pending persistence) +- Lock protecting cache mutations +- MaxMind GeoLite2 reader initialization + +An instance should be created once at startup and stored on ``app.state.geo_cache``. +""" + +from __future__ import annotations + +import asyncio +import time +from typing import TYPE_CHECKING + +import aiohttp +import structlog + +from app.models.geo import GeoInfo +from app.repositories import geo_cache_repo + +if TYPE_CHECKING: + import aiosqlite + import geoip2.database + import geoip2.errors + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier). +_API_URL: str = ( + "http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as" +) + +#: ip-api.com batch endpoint — accepts up to 100 IPs per POST. +_BATCH_API_URL: str = ( + "http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query" +) + +#: Maximum IPs per batch request (ip-api.com hard limit is 100). +_BATCH_SIZE: int = 100 + +#: Maximum number of entries kept in the in-process cache before it is +#: flushed completely. A simple eviction strategy — the cache is cheap to +#: rebuild from the persistent store. +_MAX_CACHE_SIZE: int = 50_000 + +#: Timeout for outgoing geo API requests in seconds. 
_REQUEST_TIMEOUT: float = 5.0

#: How many seconds a failed lookup result is suppressed before the IP is
#: eligible for a new API attempt.  Default: 5 minutes.
_NEG_CACHE_TTL: float = 300.0

#: Minimum delay in seconds between consecutive batch HTTP requests to
#: ip-api.com.  The free tier allows 45 requests/min; 1.5 s ≈ 40 req/min.
_BATCH_DELAY: float = 1.5

#: Maximum number of retries for a batch chunk that fails with a
#: transient error (e.g. connection reset due to rate limiting).
_BATCH_MAX_RETRIES: int = 2


class GeoCache:
    """Manage IP geolocation lookups with positive and negative caches.

    Encapsulates all mutable state needed for geo-IP resolution and offers
    single lookups, batch lookups, persistence, and cache management.

    State:
        _cache: In-memory results cache (``ip → GeoInfo``).
        _neg_cache: Failed-lookup timestamps (``ip → time.monotonic()``).
            These are monotonic clock readings, not wall-clock epochs.
        _dirty: IPs stored in ``_cache`` but not yet persisted by
            :meth:`flush_dirty`.
        _geoip_reader: Optional MaxMind GeoLite2 reader.
        _geoip_initialized: ``True`` once :meth:`init_geoip` has
            successfully loaded a reader.
        _cache_lock: ``asyncio.Lock`` guarding cache mutations.  It only
            coordinates coroutines on one event loop; it is not a
            thread lock.
    """

    def __init__(self) -> None:
        """Initialise an empty GeoCache with no GeoIP reader attached."""
        self._cache: dict[str, GeoInfo] = {}
        self._neg_cache: dict[str, float] = {}
        self._dirty: set[str] = set()
        self._geoip_reader: geoip2.database.Reader | None = None
        self._geoip_initialized: bool = False
        self._cache_lock: asyncio.Lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Cache management
    # ------------------------------------------------------------------

    async def clear(self) -> None:
        """Flush the positive cache, the negative cache, and the dirty set.

        Pending-but-unpersisted entries are discarded along with the dirty
        set.
        """
        async with self._cache_lock:
            self._cache.clear()
            self._neg_cache.clear()
            self._dirty.clear()

    async def clear_neg_cache(self) -> None:
        """Flush only the negative (failed-lookups) cache.

        After this call previously failed IPs are immediately eligible for
        a new resolution attempt.
        """
        async with self._cache_lock:
            self._neg_cache.clear()

    def is_cached(self, ip: str) -> bool:
        """Return ``True`` if *ip* has a positive entry in the in-memory cache.

        A positive entry is one with a non-``None`` ``country_code``.

        Args:
            ip: IPv4 or IPv6 address string.

        Returns:
            ``True`` when *ip* is in the cache with a known country code.
        """
        info = self._cache.get(ip)
        return info is not None and info.country_code is not None

    def _in_cooldown(self, ip: str, now: float) -> bool:
        """Return ``True`` while *ip* is still inside its negative-cache TTL.

        Args:
            ip: IP address string.
            now: Current ``time.monotonic()`` reading.
        """
        failed_at = self._neg_cache.get(ip)
        return failed_at is not None and (now - failed_at) < _NEG_CACHE_TTL

    async def cache_stats(self, db: aiosqlite.Connection) -> dict[str, int]:
        """Return diagnostic counters for the geo cache subsystem.

        Args:
            db: Open BanGUI application database connection.

        Returns:
            Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``,
            and ``dirty_size``.  ``unresolved`` comes from the persistent
            store; the other counters are in-memory sizes.
        """
        unresolved = await geo_cache_repo.count_unresolved(db)

        return {
            "cache_size": len(self._cache),
            "unresolved": unresolved,
            "neg_cache_size": len(self._neg_cache),
            "dirty_size": len(self._dirty),
        }

    async def count_unresolved(self, db: aiosqlite.Connection) -> int:
        """Return the number of unresolved entries in the persistent geo cache."""
        return await geo_cache_repo.count_unresolved(db)

    async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]:
        """Return geo cache IPs whose country code has not yet been resolved.

        Args:
            db: Open BanGUI application database connection.

        Returns:
            List of IP addresses that are candidates for re-resolution.
        """
        return await geo_cache_repo.get_unresolved_ips(db)

    async def re_resolve_all(
        self,
        db: aiosqlite.Connection,
        http_session: aiohttp.ClientSession,
    ) -> dict[str, int]:
        """Retry geo resolution for all unresolved cache entries.

        The in-memory negative cache is cleared first so that previously
        failed IPs are not skipped by :meth:`lookup_batch`.

        Args:
            db: BanGUI application database connection.
            http_session: Shared aiohttp client session.

        Returns:
            A dict with ``resolved`` and ``total`` counts.
        """
        unresolved = await self.get_unresolved_ips(db)
        if not unresolved:
            return {"resolved": 0, "total": 0}

        await self.clear_neg_cache()
        geo_map = await self.lookup_batch(unresolved, http_session, db=db)
        resolved_count = sum(
            1 for info in geo_map.values() if info.country_code is not None
        )

        log.info(
            "geo_re_resolve_complete",
            total=len(unresolved),
            resolved=resolved_count,
        )
        return {"resolved": resolved_count, "total": len(unresolved)}

    # ------------------------------------------------------------------
    # Local MaxMind fallback
    # ------------------------------------------------------------------

    def init_geoip(self, mmdb_path: str | None) -> None:
        """Initialise the MaxMind GeoLite2-Country database reader.

        If *mmdb_path* is ``None``, empty, or does not point to a file, no
        reader is loaded and the local fallback stays disabled.  In those
        cases the "initialised" flag is left unset, so a later call with a
        valid path is still accepted.

        Args:
            mmdb_path: Path to a ``GeoLite2-Country.mmdb`` file.

        Raises:
            RuntimeError: If a reader has already been loaded by an earlier
                call.
        """
        if self._geoip_initialized:
            raise RuntimeError("GeoIP reader already initialised")
        if not mmdb_path:
            return
        from pathlib import Path  # noqa: PLC0415

        import geoip2.database  # noqa: PLC0415

        if not Path(mmdb_path).is_file():
            log.warning("geoip_mmdb_not_found", path=mmdb_path)
            return
        self._geoip_reader = geoip2.database.Reader(mmdb_path)
        self._geoip_initialized = True
        log.info("geoip_mmdb_loaded", path=mmdb_path)

    def _geoip_lookup(self, ip: str) -> GeoInfo | None:
        """Attempt a local MaxMind GeoLite2 lookup for *ip*.

        Args:
            ip: IPv4 or IPv6 address string.

        Returns:
            A :class:`GeoInfo` with ``country_code`` populated (``asn`` and
            ``org`` are always ``None`` here), or ``None`` when no reader is
            loaded, the address is unknown, the record has no ISO code, or
            the reader raises.
        """
        if self._geoip_reader is None:
            return None
        import geoip2.errors  # noqa: PLC0415

        try:
            response = self._geoip_reader.country(ip)
            code: str | None = response.country.iso_code or None
            name: str | None = response.country.name or None
            if code is None:
                return None
            return GeoInfo(country_code=code, country_name=name, asn=None, org=None)
        except geoip2.errors.AddressNotFoundError:
            return None
        except Exception as exc:  # noqa: BLE001
            # Best-effort fallback: a broken reader must not fail the lookup.
            log.warning("geoip_lookup_failed", ip=ip, error=str(exc))
            return None

    # ------------------------------------------------------------------
    # Persistence helpers
    # ------------------------------------------------------------------

    async def load_cache_from_db(self, db: aiosqlite.Connection) -> None:
        """Pre-populate the in-memory cache from the ``geo_cache`` table.

        Rows without a ``country_code`` are skipped, so only resolved
        entries are loaded.

        Args:
            db: Open :class:`aiosqlite.Connection` to the BanGUI application
                database.
        """
        cache_entries: list[tuple[str, GeoInfo]] = []
        for row in await geo_cache_repo.load_all(db):
            country_code: str | None = row["country_code"]
            if country_code is None:
                continue
            cache_entries.append(
                (
                    row["ip"],
                    GeoInfo(
                        country_code=country_code,
                        country_name=row["country_name"],
                        asn=row["asn"],
                        org=row["org"],
                    ),
                )
            )

        async with self._cache_lock:
            self._cache.update(cache_entries)
        log.info("geo_cache_loaded_from_db", entries=len(cache_entries))

    async def _store(self, ip: str, info: GeoInfo) -> None:
        """Insert *info* into the cache, flushing first if at capacity.

        Resolved entries (``country_code is not None``) are also added to
        the dirty set so :meth:`flush_dirty` can persist them.

        Args:
            ip: The IP address key.
            info: The :class:`GeoInfo` to store.
        """
        async with self._cache_lock:
            if len(self._cache) >= _MAX_CACHE_SIZE:
                # Whole-cache eviction.  The dirty set is dropped with it,
                # so entries not yet flushed are lost from the pending
                # persistence queue.
                self._cache.clear()
                self._dirty.clear()
                log.info("geo_cache_flushed", reason="capacity")
            self._cache[ip] = info
            if info.country_code is not None:
                self._dirty.add(ip)

    async def _store_and_persist(
        self,
        ip: str,
        info: GeoInfo,
        db: aiosqlite.Connection | None,
    ) -> None:
        """Cache *info* and, when *db* is given, upsert it immediately.

        A failed database write is logged and swallowed: the entry stays in
        the in-memory cache (and the dirty set) regardless.
        """
        await self._store(ip, info)
        if info.country_code is None or db is None:
            return
        try:
            await geo_cache_repo.upsert_entry_and_commit(
                db=db,
                ip=ip,
                country_code=info.country_code,
                country_name=info.country_name,
                asn=info.asn,
                org=info.org,
            )
        except Exception as exc:  # noqa: BLE001
            log.warning("geo_persist_failed", ip=ip, error=str(exc))

    # ------------------------------------------------------------------
    # Single lookup
    # ------------------------------------------------------------------

    async def _query_single(
        self,
        ip: str,
        http_session: aiohttp.ClientSession,
    ) -> GeoInfo | None:
        """Query the single-IP endpoint; return ``None`` on any failure.

        Non-200 responses, request/JSON errors, unexpected payload shapes
        and ``status != "success"`` answers all yield ``None`` after being
        logged.
        """
        url: str = _API_URL.format(ip=ip)
        try:
            async with http_session.get(
                url, timeout=aiohttp.ClientTimeout(total=_REQUEST_TIMEOUT)
            ) as resp:
                if resp.status != 200:
                    log.warning("geo_lookup_non_200", ip=ip, status=resp.status)
                    return None
                data: object = await resp.json(content_type=None)
        except Exception as exc:  # noqa: BLE001
            log.warning(
                "geo_lookup_request_failed",
                ip=ip,
                exc_type=type(exc).__name__,
                error=repr(exc),
            )
            return None

        if not isinstance(data, dict):
            log.warning(
                "geo_lookup_request_failed",
                ip=ip,
                exc_type="TypeError",
                error=f"unexpected payload type {type(data).__name__}",
            )
            return None
        if data.get("status") != "success":
            log.debug(
                "geo_lookup_failed",
                ip=ip,
                message=data.get("message", "unknown"),
            )
            return None
        return self._parse_single_response(data)

    async def lookup(
        self,
        ip: str,
        http_session: aiohttp.ClientSession,
        db: aiosqlite.Connection | None = None,
    ) -> GeoInfo | None:
        """Resolve an IP address to country, ASN, and organisation metadata.

        Resolution order: in-memory cache, negative cache (returns an empty
        result without any I/O), ip-api.com, then the local MaxMind reader.
        When both resolvers fail the IP is recorded in the negative cache
        and is not retried for ``_NEG_CACHE_TTL`` seconds.

        Args:
            ip: IPv4 or IPv6 address string.
            http_session: Shared :class:`aiohttp.ClientSession`.
            db: Optional BanGUI application database.  When provided,
                successful lookups are upserted immediately and failures are
                recorded as negative entries.

        Returns:
            A :class:`GeoInfo`; failed lookups yield one with all fields
            ``None``.  The ``| None`` in the annotation is kept for
            interface compatibility — no path in this method returns
            ``None``.
        """
        if ip in self._cache:
            return self._cache[ip]

        # Negative cache: skip IPs that recently failed to avoid hammering the API.
        if self._in_cooldown(ip, time.monotonic()):
            return GeoInfo(country_code=None, country_name=None, asn=None, org=None)

        result = await self._query_single(ip, http_session)
        if result is not None:
            await self._store_and_persist(ip, result, db)
            log.debug("geo_lookup_success", ip=ip, country=result.country_code, asn=result.asn)
            return result

        # API gave nothing usable — try the local MaxMind database.
        fallback = self._geoip_lookup(ip)
        if fallback is not None:
            await self._store_and_persist(ip, fallback, db)
            log.debug("geo_geoip_fallback_success", ip=ip, country=fallback.country_code)
            return fallback

        # Both resolvers failed — record in negative cache to avoid hammering.
        async with self._cache_lock:
            self._neg_cache[ip] = time.monotonic()
        if db is not None:
            try:
                await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip)
            except Exception as exc:  # noqa: BLE001
                log.warning("geo_persist_neg_failed", ip=ip, error=str(exc))

        return GeoInfo(country_code=None, country_name=None, asn=None, org=None)

    def lookup_cached_only(
        self,
        ips: list[str],
    ) -> tuple[dict[str, GeoInfo], list[str]]:
        """Return cached geo data for *ips* without any external calls.

        Args:
            ips: IP address strings to look up.  Duplicates are ignored.

        Returns:
            A ``(geo_map, uncached)`` tuple.  *geo_map* maps every IP found
            in the in-memory cache to its :class:`GeoInfo`.  *uncached*
            lists, in input order, the IPs that are neither cached nor
            inside their negative-cache cool-down window.
        """
        geo_map: dict[str, GeoInfo] = {}
        uncached: list[str] = []
        now = time.monotonic()

        for ip in dict.fromkeys(ips):  # deduplicate, preserve order
            if ip in self._cache:
                geo_map[ip] = self._cache[ip]
            elif not self._in_cooldown(ip, now):
                uncached.append(ip)
            # else: still cooling down — neither reported nor re-queued.

        return geo_map, uncached

    # ------------------------------------------------------------------
    # Batch lookup
    # ------------------------------------------------------------------

    async def lookup_batch(
        self,
        ips: list[str],
        http_session: aiohttp.ClientSession,
        db: aiosqlite.Connection | None = None,
    ) -> dict[str, GeoInfo]:
        """Resolve multiple IP addresses via the ip-api.com batch endpoint.

        Cached IPs are answered from memory; IPs inside their negative-cache
        window get an empty result.  The rest are sent in chunks of up to
        :data:`_BATCH_SIZE`, with ``_BATCH_DELAY`` seconds between chunks.
        IPs the API could not resolve are retried against the local MaxMind
        reader before being recorded in the negative cache.

        When *db* is provided, each chunk's resolved rows and negative
        entries are written together in one bulk repository call.  A failed
        write is logged and does not affect the returned mapping.

        Args:
            ips: List of IP address strings to resolve.  Duplicates are ignored.
            http_session: Shared :class:`aiohttp.ClientSession`.
            db: Optional BanGUI application database for persistent cache writes.

        Returns:
            Dict mapping ``ip → GeoInfo`` for every input IP.  IPs whose
            resolution failed have a ``GeoInfo`` with all-``None`` fields.
        """
        geo_result: dict[str, GeoInfo] = {}
        uncached: list[str] = []
        _empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)

        now = time.monotonic()
        for ip in dict.fromkeys(ips):  # deduplicate, preserve order
            if ip in self._cache:
                geo_result[ip] = self._cache[ip]
            elif self._in_cooldown(ip, now):
                # Recently failed — skip API call, return empty result.
                geo_result[ip] = _empty
            else:
                uncached.append(ip)

        if not uncached:
            return geo_result

        log.info("geo_batch_lookup_start", total=len(uncached))

        for batch_idx, chunk_start in enumerate(range(0, len(uncached), _BATCH_SIZE)):
            chunk = uncached[chunk_start : chunk_start + _BATCH_SIZE]

            # Throttle: pause between consecutive HTTP calls to stay within the
            # ip-api.com free-tier rate limit (45 req/min).
            if batch_idx > 0:
                await asyncio.sleep(_BATCH_DELAY)

            # A chunk in which *no* IP resolved is treated as a rejected
            # request (connection reset / 429) and retried with exponential
            # back-off.  NOTE: the heuristic cannot tell a rejected request
            # from a chunk whose IPs are all legitimately unresolvable, so
            # such chunks also pay the retry delay.
            chunk_result = await self._batch_api_call(chunk, http_session)
            for attempt in range(1, _BATCH_MAX_RETRIES + 1):
                if any(info.country_code is not None for info in chunk_result.values()):
                    break
                backoff = _BATCH_DELAY * (2 ** attempt)
                log.warning(
                    "geo_batch_retry",
                    attempt=attempt,
                    chunk_size=len(chunk),
                    backoff=backoff,
                )
                await asyncio.sleep(backoff)
                chunk_result = await self._batch_api_call(chunk, http_session)

            # Collect bulk-write rows instead of one execute per IP.
            pos_rows: list[tuple[str, str | None, str | None, str | None, str | None]] = []
            neg_ips: list[str] = []

            for ip, info in chunk_result.items():
                # Prefer the API answer; otherwise try the local GeoIP reader.
                resolved = info if info.country_code is not None else self._geoip_lookup(ip)
                if resolved is not None:
                    await self._store(ip, resolved)
                    geo_result[ip] = resolved
                    if db is not None:
                        pos_rows.append(
                            (
                                ip,
                                resolved.country_code,
                                resolved.country_name,
                                resolved.asn,
                                resolved.org,
                            )
                        )
                else:
                    # Both resolvers failed — record in negative cache.
                    async with self._cache_lock:
                        self._neg_cache[ip] = time.monotonic()
                    geo_result[ip] = _empty
                    if db is not None:
                        neg_ips.append(ip)

            if db is not None and (pos_rows or neg_ips):
                try:
                    await geo_cache_repo.bulk_upsert_entries_and_neg_entries_and_commit(
                        db,
                        pos_rows,
                        neg_ips,
                    )
                except Exception as exc:  # noqa: BLE001
                    log.warning(
                        "geo_batch_persist_failed",
                        positive_count=len(pos_rows),
                        negative_count=len(neg_ips),
                        error=str(exc),
                    )

        # Count only the IPs that were actually requested from the resolvers
        # so that ``resolved`` can never exceed ``requested``.
        log.info(
            "geo_batch_lookup_complete",
            requested=len(uncached),
            resolved=sum(
                1
                for ip in uncached
                if geo_result.get(ip, _empty).country_code is not None
            ),
        )
        return geo_result

    def _valid_batch_entries(self, data: object) -> list[dict[str, object]]:
        """Return the dict entries of a batch payload, dropping anything else.

        Args:
            data: Decoded JSON body of a batch response.

        Returns:
            The ``dict`` items of *data* in order when *data* is a list;
            an empty list for any other payload type.
        """
        if not isinstance(data, list):
            return []
        return [entry for entry in data if isinstance(entry, dict)]

    async def _batch_api_call(
        self,
        ips: list[str],
        http_session: aiohttp.ClientSession,
    ) -> dict[str, GeoInfo]:
        """Send one batch request to the ip-api.com batch endpoint.

        Args:
            ips: Up to :data:`_BATCH_SIZE` IP address strings.
            http_session: Shared HTTP session.

        Returns:
            Dict containing an entry for every IP in *ips*.  IPs for which
            the API returned a failure record, that are missing from the
            response, or whose request failed altogether (non-200, raised
            exception, malformed payload) map to an all-``None``
            :class:`GeoInfo`.  Keys are taken from each record's ``query``
            field, so the dict may also hold keys the API echoed back in a
            different form.
        """
        empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
        fallback: dict[str, GeoInfo] = dict.fromkeys(ips, empty)

        payload = [{"query": ip} for ip in ips]
        try:
            async with http_session.post(
                _BATCH_API_URL,
                json=payload,
                timeout=aiohttp.ClientTimeout(total=_REQUEST_TIMEOUT * 2),
            ) as resp:
                if resp.status != 200:
                    log.warning("geo_batch_non_200", status=resp.status, count=len(ips))
                    return fallback
                data: object = await resp.json(content_type=None)
        except Exception as exc:  # noqa: BLE001
            log.warning(
                "geo_batch_request_failed",
                count=len(ips),
                exc_type=type(exc).__name__,
                error=repr(exc),
            )
            return fallback

        # The body comes from an external service: a non-list payload (for
        # example an error object) must degrade to "all failed" instead of
        # raising out of the batch lookup.
        if not isinstance(data, list):
            log.warning(
                "geo_batch_request_failed",
                count=len(ips),
                exc_type="TypeError",
                error=f"unexpected payload type {type(data).__name__}",
            )
            return fallback

        out: dict[str, GeoInfo] = {}
        for entry in self._valid_batch_entries(data):
            ip_str: str = str(entry.get("query", ""))
            if not ip_str:
                continue
            if entry.get("status") != "success":
                out[ip_str] = empty
                log.debug(
                    "geo_batch_entry_failed",
                    ip=ip_str,
                    message=entry.get("message", "unknown"),
                )
                continue
            out[ip_str] = self._parse_single_response(entry)

        # Fill any IPs missing from the response.
        for ip in ips:
            out.setdefault(ip, empty)

        return out

    # ------------------------------------------------------------------
    # Response parsing
    # ------------------------------------------------------------------

    def _parse_single_response(self, data: dict[str, object]) -> GeoInfo:
        """Build a :class:`GeoInfo` from a single ip-api.com response dict.

        Args:
            data: A ``status == "success"`` JSON record.

        Returns:
            :class:`GeoInfo` built from the ``countryCode``, ``country``,
            ``as`` and ``org`` fields; blank or missing fields become
            ``None``.
        """
        country_code: str | None = self._str_or_none(data.get("countryCode"))
        country_name: str | None = self._str_or_none(data.get("country"))
        asn_raw: str | None = self._str_or_none(data.get("as"))
        org_raw: str | None = self._str_or_none(data.get("org"))

        # Keep only the first whitespace-separated token of "as"
        # (e.g. "AS12345" out of "AS12345 Some Org").
        asn: str | None = asn_raw.split()[0] if asn_raw else None

        return GeoInfo(
            country_code=country_code,
            country_name=country_name,
            asn=asn,
            org=org_raw,
        )

    def _str_or_none(self, value: object) -> str | None:
        """Return *value* as a stripped, non-empty string, or ``None``.

        Args:
            value: Raw JSON value; non-string values are converted with
                ``str()``.

        Returns:
            Stripped string if non-empty, else ``None``.
        """
        if value is None:
            return None
        s = str(value).strip()
        return s if s else None

    # ------------------------------------------------------------------
    # Periodic flush
    # ------------------------------------------------------------------

    async def flush_dirty(self, db: aiosqlite.Connection) -> int:
        """Persist all dirty in-memory geo entries to the ``geo_cache`` table.

        The dirty set is snapshotted and cleared under the cache lock; rows
        are built only for IPs still present in the cache, then written
        with one bulk repository call.  If that write fails, the snapshot
        is merged back into the dirty set so the entries are retried on the
        next flush.

        Args:
            db: Open :class:`aiosqlite.Connection` to the BanGUI application
                database.

        Returns:
            The number of rows written, or ``0`` when there was nothing to
            flush or the write failed.
        """
        async with self._cache_lock:
            if not self._dirty:
                return 0

            # Snapshot and clear while holding the cache lock.
            to_flush = self._dirty.copy()
            self._dirty.clear()

            rows = []
            for ip in to_flush:
                info = self._cache.get(ip)
                if info is None:
                    continue  # evicted since it was marked dirty
                rows.append((ip, info.country_code, info.country_name, info.asn, info.org))

        if not rows:
            return 0

        try:
            await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows)
        except Exception as exc:  # noqa: BLE001
            log.warning("geo_flush_dirty_failed", error=str(exc))
            # Re-add to dirty so they are retried on the next flush cycle.
            self._dirty.update(to_flush)
            return 0

        log.info("geo_flush_dirty_complete", count=len(rows))
        return len(rows)
# use the GeoInfo object in your application - - # bulk lookup (more efficient for large sets) - geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session) +See :class:`app.services.geo_cache.GeoCache` for the implementation. """ from __future__ import annotations -import asyncio -import time from typing import TYPE_CHECKING -import aiohttp -import structlog - +from app.services.geo_cache import GeoCache from app.models.geo import GeoInfo -from app.repositories import geo_cache_repo if TYPE_CHECKING: + import aiohttp import aiosqlite - import geoip2.database - import geoip2.errors -log: structlog.stdlib.BoundLogger = structlog.get_logger() +__all__ = [ + "GeoCache", + "clear_cache", + "clear_neg_cache", + "is_cached", + "lookup", + "lookup_batch", + "lookup_cached_only", + "cache_stats", + "count_unresolved", + "get_unresolved_ips", + "load_cache_from_db", + "flush_dirty", + "re_resolve_all", + "init_geoip", +] -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- - -#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier). -_API_URL: str = ( - "http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as" -) - -#: ip-api.com batch endpoint — accepts up to 100 IPs per POST. -_BATCH_API_URL: str = ( - "http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query" -) - -#: Maximum IPs per batch request (ip-api.com hard limit is 100). -_BATCH_SIZE: int = 100 - -#: Maximum number of entries kept in the in-process cache before it is -#: flushed completely. A simple eviction strategy — the cache is cheap to -#: rebuild from the persistent store. -_MAX_CACHE_SIZE: int = 50_000 - -#: Timeout for outgoing geo API requests in seconds. -_REQUEST_TIMEOUT: float = 5.0 - -#: How many seconds a failed lookup result is suppressed before the IP is -#: eligible for a new API attempt. Default: 5 minutes. 
-_NEG_CACHE_TTL: float = 300.0 - -#: Minimum delay in seconds between consecutive batch HTTP requests to -#: ip-api.com. The free tier allows 45 requests/min; 1.5 s ≈ 40 req/min. -_BATCH_DELAY: float = 1.5 - -#: Maximum number of retries for a batch chunk that fails with a -#: transient error (e.g. connection reset due to rate limiting). -_BATCH_MAX_RETRIES: int = 2 - -# --------------------------------------------------------------------------- -# Internal cache -# --------------------------------------------------------------------------- - -#: Module-level in-memory cache: ``ip → GeoInfo`` (positive results only). -_cache: dict[str, GeoInfo] = {} - -#: Negative cache: ``ip → epoch timestamp`` of last failed lookup attempt. -#: Entries within :data:`_NEG_CACHE_TTL` seconds are not re-queried. -_neg_cache: dict[str, float] = {} - -#: IPs added to :data:`_cache` but not yet persisted to the database. -#: Consumed and cleared atomically by :func:`flush_dirty`. -_dirty: set[str] = set() - -#: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`. -_geoip_reader: geoip2.database.Reader | None = None - -#: Indicates whether :func:`init_geoip` has already been called. -#: This function is startup-only and must not be invoked again while the -#: process is handling requests. -_geoip_initialized: bool = False - -#: Lock protecting mutations to the in-memory geo caches. -_cache_lock: asyncio.Lock = asyncio.Lock() +#: Deprecated: Module-level cache instance for backward compatibility. +#: This exists only to provide backward-compatible module-level functions. +#: New code should inject GeoCache instances directly. +_default_geo_cache: GeoCache = GeoCache() async def clear_cache() -> None: - """Flush both the positive and negative lookup caches. + """(DEPRECATED) Flush both the positive and negative lookup caches. - Also clears the dirty set so any pending-but-unpersisted entries are - discarded. Useful in tests and when the operator suspects stale data. 
+ Use :meth:`GeoCache.clear` instead. This function delegates to the + default module-level instance for backward compatibility only. """ - async with _cache_lock: - _cache.clear() - _neg_cache.clear() - _dirty.clear() + await _default_geo_cache.clear() async def clear_neg_cache() -> None: - """Flush only the negative (failed-lookups) cache. + """(DEPRECATED) Flush only the negative (failed-lookups) cache. - Useful when triggering a manual re-resolve so that previously failed - IPs are immediately eligible for a new API attempt. + Use :meth:`GeoCache.clear_neg_cache` instead. This function delegates to + the default module-level instance for backward compatibility only. """ - async with _cache_lock: - _neg_cache.clear() + await _default_geo_cache.clear_neg_cache() def is_cached(ip: str) -> bool: - """Return ``True`` if *ip* has a positive entry in the in-memory cache. + """(DEPRECATED) Return ``True`` if *ip* has a positive cache entry. - A positive entry is one with a non-``None`` ``country_code``. This is - useful for skipping IPs that have already been resolved when building - a list for :func:`lookup_batch`. + Use :meth:`GeoCache.is_cached` instead. This function delegates to the + default module-level instance for backward compatibility only. Args: ip: IPv4 or IPv6 address string. @@ -153,14 +74,14 @@ def is_cached(ip: str) -> bool: Returns: ``True`` when *ip* is in the cache with a known country code. """ - return ip in _cache and _cache[ip].country_code is not None + return _default_geo_cache.is_cached(ip) async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]: - """Return diagnostic counters for the geo cache subsystem. + """(DEPRECATED) Return diagnostic counters for the geo cache subsystem. - Queries the persistent store for the number of unresolved entries and - combines it with in-memory counters. + Use :meth:`GeoCache.cache_stats` instead. This function delegates to the + default module-level instance for backward compatibility only. 
Args: db: Open BanGUI application database connection. @@ -169,24 +90,21 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]: Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``, and ``dirty_size``. """ - unresolved = await geo_cache_repo.count_unresolved(db) - - return { - "cache_size": len(_cache), - "unresolved": unresolved, - "neg_cache_size": len(_neg_cache), - "dirty_size": len(_dirty), - } + return await _default_geo_cache.cache_stats(db) async def count_unresolved(db: aiosqlite.Connection) -> int: - """Return the number of unresolved entries in the persistent geo cache.""" + """(DEPRECATED) Return the number of unresolved entries in the geo cache. - return await geo_cache_repo.count_unresolved(db) + Use :meth:`GeoCache.count_unresolved` instead. + """ + return await _default_geo_cache.count_unresolved(db) async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]: - """Return geo cache IPs where the country code has not yet been resolved. + """(DEPRECATED) Return IPs with NULL country_code in the persistent cache. + + Use :meth:`GeoCache.get_unresolved_ips` instead. Args: db: Open BanGUI application database connection. @@ -194,187 +112,32 @@ async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]: Returns: List of IP addresses that are candidates for re-resolution. """ - return await geo_cache_repo.get_unresolved_ips(db) - - -async def re_resolve_all( - db: aiosqlite.Connection, - http_session: aiohttp.ClientSession, -) -> dict[str, int]: - """Retry geo resolution for all unresolved cache entries. - - This helper clears the in-memory negative cache before attempting a - fresh batch lookup, then returns counters for how many IPs were retried - and how many gained a resolved country code. - - Args: - db: BanGUI application database connection. - http_session: Shared aiohttp client session. - - Returns: - A dict with ``resolved`` and ``total`` counts. 
- """ - unresolved = await get_unresolved_ips(db) - if not unresolved: - return {"resolved": 0, "total": 0} - - await clear_neg_cache() - geo_map = await lookup_batch(unresolved, http_session, db=db) - resolved_count = sum( - 1 for info in geo_map.values() if info.country_code is not None - ) - - log.info( - "geo_re_resolve_complete", - total=len(unresolved), - resolved=resolved_count, - ) - return {"resolved": resolved_count, "total": len(unresolved)} + return await _default_geo_cache.get_unresolved_ips(db) def init_geoip(mmdb_path: str | None) -> None: - """Initialise the MaxMind GeoLite2-Country database reader. + """(DEPRECATED) Initialise the MaxMind GeoLite2-Country database reader. - This function is startup-only and must be called before request handling - begins. A second initialization attempt is considered a programming error - and raises ``RuntimeError``. - - If *mmdb_path* is ``None``, empty, or the file does not exist the - fallback is silently disabled — ip-api.com remains the sole resolver. + Use :meth:`GeoCache.init_geoip` instead. This function delegates to the + default module-level instance for backward compatibility only. Args: mmdb_path: Absolute path to a ``GeoLite2-Country.mmdb`` file. """ - global _geoip_reader, _geoip_initialized # noqa: PLW0603 - if _geoip_initialized: - raise RuntimeError("GeoIP reader already initialised") - if not mmdb_path: - return - from pathlib import Path # noqa: PLC0415 - - import geoip2.database # noqa: PLC0415 - - if not Path(mmdb_path).is_file(): - log.warning("geoip_mmdb_not_found", path=mmdb_path) - return - _geoip_reader = geoip2.database.Reader(mmdb_path) - _geoip_initialized = True - log.info("geoip_mmdb_loaded", path=mmdb_path) - - -def _geoip_lookup(ip: str) -> GeoInfo | None: - """Attempt a local MaxMind GeoLite2 lookup for *ip*. - - Returns ``None`` when the reader is not initialised, the IP is not in - the database, or any other error occurs. - - Args: - ip: IPv4 or IPv6 address string. 
- - Returns: - A :class:`GeoInfo` with at least ``country_code`` populated, or - ``None`` when resolution is impossible. - """ - if _geoip_reader is None: - return None - import geoip2.errors # noqa: PLC0415 - - try: - response = _geoip_reader.country(ip) - code: str | None = response.country.iso_code or None - name: str | None = response.country.name or None - if code is None: - return None - return GeoInfo(country_code=code, country_name=name, asn=None, org=None) - except geoip2.errors.AddressNotFoundError: - return None - except Exception as exc: # noqa: BLE001 - log.warning("geoip_lookup_failed", ip=ip, error=str(exc)) - return None - - -# --------------------------------------------------------------------------- -# Persistent cache I/O -# --------------------------------------------------------------------------- + _default_geo_cache.init_geoip(mmdb_path) async def load_cache_from_db(db: aiosqlite.Connection) -> None: - """Pre-populate the in-memory cache from the ``geo_cache`` table. + """(DEPRECATED) Pre-populate the in-memory cache from ``geo_cache`` table. - Should be called once during application startup so the service starts - with a warm cache instead of making cold API calls on the first request. + Use :meth:`GeoCache.load_cache_from_db` instead. This function delegates + to the default module-level instance for backward compatibility only. Args: db: Open :class:`aiosqlite.Connection` to the BanGUI application database (not the fail2ban database). 
""" - count = 0 - cache_entries: list[tuple[str, GeoInfo]] = [] - for row in await geo_cache_repo.load_all(db): - country_code: str | None = row["country_code"] - if country_code is None: - continue - ip: str = row["ip"] - cache_entries.append( - ( - ip, - GeoInfo( - country_code=country_code, - country_name=row["country_name"], - asn=row["asn"], - org=row["org"], - ), - ) - ) - count += 1 - - async with _cache_lock: - for ip, info in cache_entries: - _cache[ip] = info - log.info("geo_cache_loaded_from_db", entries=count) - - -async def _persist_entry( - db: aiosqlite.Connection, - ip: str, - info: GeoInfo, -) -> None: - """Upsert a resolved :class:`GeoInfo` into the ``geo_cache`` table. - - Only called when ``info.country_code`` is not ``None`` so the persistent - store never contains empty placeholder rows. - - Args: - db: BanGUI application database connection. - ip: IP address string. - info: Resolved geo data to persist. - """ - await geo_cache_repo.upsert_entry( - db=db, - ip=ip, - country_code=info.country_code, - country_name=info.country_name, - asn=info.asn, - org=info.org, - ) - - -async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None: - """Record a failed lookup attempt in ``geo_cache`` with all-NULL fields. - - Uses ``INSERT OR IGNORE`` so that an existing *positive* entry (one that - has a ``country_code``) is never overwritten by a later failure. - - Args: - db: BanGUI application database connection. - ip: IP address string whose resolution failed. 
- """ - await geo_cache_repo.upsert_neg_entry(db=db, ip=ip) - - -# --------------------------------------------------------------------------- -# Public API — single lookup -# --------------------------------------------------------------------------- + await _default_geo_cache.load_cache_from_db(db) async def lookup( @@ -382,144 +145,37 @@ async def lookup( http_session: aiohttp.ClientSession, db: aiosqlite.Connection | None = None, ) -> GeoInfo | None: - """Resolve an IP address to country, ASN, and organisation metadata. + """(DEPRECATED) Resolve an IP address to country, ASN, organisation metadata. - Results are cached in-process. If the cache exceeds ``_MAX_CACHE_SIZE`` - entries it is flushed before the new result is stored. - - Only successful resolutions (``country_code is not None``) are written to - the persistent cache when *db* is provided. Failed lookups are **not** - cached so they are retried on the next call. + Use :meth:`GeoCache.lookup` instead. This function delegates to the + default module-level instance for backward compatibility only. Args: ip: IPv4 or IPv6 address string. - http_session: Shared :class:`aiohttp.ClientSession` (from - ``app.state.http_session``). - db: Optional BanGUI application database. When provided, successful - lookups are persisted for cross-restart cache warming. + http_session: Shared :class:`aiohttp.ClientSession`. + db: Optional BanGUI application database for persistence. Returns: - A :class:`GeoInfo` instance, or ``None`` when the lookup fails - in a way that should prevent the caller from caching a bad result - (e.g. network timeout). + A :class:`GeoInfo` instance, or ``None`` on lookup failure. """ - if ip in _cache: - return _cache[ip] - - # Negative cache: skip IPs that recently failed to avoid hammering the API. 
- neg_ts = _neg_cache.get(ip) - if neg_ts is not None and (time.monotonic() - neg_ts) < _NEG_CACHE_TTL: - return GeoInfo(country_code=None, country_name=None, asn=None, org=None) - - url: str = _API_URL.format(ip=ip) - api_ok = False - try: - async with http_session.get(url, timeout=aiohttp.ClientTimeout(total=_REQUEST_TIMEOUT)) as resp: - if resp.status != 200: - log.warning("geo_lookup_non_200", ip=ip, status=resp.status) - else: - data: dict[str, object] = await resp.json(content_type=None) - if data.get("status") == "success": - api_ok = True - result = _parse_single_response(data) - await _store(ip, result) - if result.country_code is not None and db is not None: - try: - await geo_cache_repo.upsert_entry_and_commit( - db=db, - ip=ip, - country_code=result.country_code, - country_name=result.country_name, - asn=result.asn, - org=result.org, - ) - except Exception as exc: # noqa: BLE001 - log.warning("geo_persist_failed", ip=ip, error=str(exc)) - log.debug("geo_lookup_success", ip=ip, country=result.country_code, asn=result.asn) - return result - log.debug( - "geo_lookup_failed", - ip=ip, - message=data.get("message", "unknown"), - ) - except Exception as exc: # noqa: BLE001 - log.warning( - "geo_lookup_request_failed", - ip=ip, - exc_type=type(exc).__name__, - error=repr(exc), - ) - - if not api_ok: - # Try local MaxMind database as fallback. - fallback = _geoip_lookup(ip) - if fallback is not None: - await _store(ip, fallback) - if fallback.country_code is not None and db is not None: - try: - await geo_cache_repo.upsert_entry_and_commit( - db=db, - ip=ip, - country_code=fallback.country_code, - country_name=fallback.country_name, - asn=fallback.asn, - org=fallback.org, - ) - except Exception as exc: # noqa: BLE001 - log.warning("geo_persist_failed", ip=ip, error=str(exc)) - log.debug("geo_geoip_fallback_success", ip=ip, country=fallback.country_code) - return fallback - - # Both resolvers failed — record in negative cache to avoid hammering. 
- async with _cache_lock: - _neg_cache[ip] = time.monotonic() - if db is not None: - try: - await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip) - except Exception as exc: # noqa: BLE001 - log.warning("geo_persist_neg_failed", ip=ip, error=str(exc)) - - return GeoInfo(country_code=None, country_name=None, asn=None, org=None) - - -# --------------------------------------------------------------------------- -# Public API — batch lookup -# --------------------------------------------------------------------------- + return await _default_geo_cache.lookup(ip, http_session, db=db) def lookup_cached_only( ips: list[str], ) -> tuple[dict[str, GeoInfo], list[str]]: - """Return cached geo data for *ips* without making any external API calls. + """(DEPRECATED) Return cached geo data without making external API calls. - Used by callers that want to return a fast response using only what is - already in memory, while deferring resolution of uncached IPs to a - background task. + Use :meth:`GeoCache.lookup_cached_only` instead. This function delegates + to the default module-level instance for backward compatibility only. Args: ips: IP address strings to look up. Returns: - A ``(geo_map, uncached)`` tuple where *geo_map* maps every IP that - was already in the in-memory cache to its :class:`GeoInfo`, and - *uncached* is the list of IPs that were not found in the cache. - Entries in the negative cache (recently failed) are **not** included - in *uncached* so they are not re-queued immediately. + A ``(geo_map, uncached)`` tuple. """ - geo_map: dict[str, GeoInfo] = {} - uncached: list[str] = [] - now = time.monotonic() - - for ip in dict.fromkeys(ips): # deduplicate, preserve order - if ip in _cache: - geo_map[ip] = _cache[ip] - elif ip in _neg_cache and (now - _neg_cache[ip]) < _NEG_CACHE_TTL: - # Still within the cool-down window — do not re-queue. 
- pass - else: - uncached.append(ip) - - return geo_map, uncached + return _default_geo_cache.lookup_cached_only(ips) async def lookup_batch( @@ -527,274 +183,27 @@ async def lookup_batch( http_session: aiohttp.ClientSession, db: aiosqlite.Connection | None = None, ) -> dict[str, GeoInfo]: - """Resolve multiple IP addresses in bulk using ip-api.com batch endpoint. + """(DEPRECATED) Resolve multiple IPs in bulk using ip-api.com batch endpoint. - IPs already present in the in-memory cache are returned immediately - without making an HTTP request. Uncached IPs are sent to - ``http://ip-api.com/batch`` in chunks of up to :data:`_BATCH_SIZE`. - - Only successful resolutions (``country_code is not None``) are written to - the persistent cache when *db* is provided. Both positive and negative - entries are written in bulk using ``executemany`` (one round-trip per - chunk) rather than one ``execute`` per IP. + Use :meth:`GeoCache.lookup_batch` instead. This function delegates to the + default module-level instance for backward compatibility only. Args: - ips: List of IP address strings to resolve. Duplicates are ignored. + ips: List of IP address strings to resolve. http_session: Shared :class:`aiohttp.ClientSession`. db: Optional BanGUI application database for persistent cache writes. Returns: - Dict mapping ``ip → GeoInfo`` for every input IP. IPs whose - resolution failed will have a ``GeoInfo`` with all-``None`` fields. + Dict mapping ``ip → GeoInfo`` for every input IP. """ - geo_result: dict[str, GeoInfo] = {} - uncached: list[str] = [] - _empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None) - - unique_ips = list(dict.fromkeys(ips)) # deduplicate, preserve order - now = time.monotonic() - for ip in unique_ips: - if ip in _cache: - geo_result[ip] = _cache[ip] - elif ip in _neg_cache and (now - _neg_cache[ip]) < _NEG_CACHE_TTL: - # Recently failed — skip API call, return empty result. 
- geo_result[ip] = _empty - else: - uncached.append(ip) - - if not uncached: - return geo_result - - log.info("geo_batch_lookup_start", total=len(uncached)) - - for batch_idx, chunk_start in enumerate(range(0, len(uncached), _BATCH_SIZE)): - chunk = uncached[chunk_start : chunk_start + _BATCH_SIZE] - - # Throttle: pause between consecutive HTTP calls to stay within the - # ip-api.com free-tier rate limit (45 req/min). - if batch_idx > 0: - await asyncio.sleep(_BATCH_DELAY) - - # Retry transient failures (e.g. connection-reset from rate limit). - chunk_result: dict[str, GeoInfo] | None = None - for attempt in range(_BATCH_MAX_RETRIES + 1): - chunk_result = await _batch_api_call(chunk, http_session) - # If every IP in the chunk came back with country_code=None and the - # batch wasn't tiny, that almost certainly means the whole request - # was rejected (connection reset / 429). Retry after a back-off. - all_failed = all( - info.country_code is None for info in chunk_result.values() - ) - if not all_failed or attempt >= _BATCH_MAX_RETRIES: - break - backoff = _BATCH_DELAY * (2 ** (attempt + 1)) - log.warning( - "geo_batch_retry", - attempt=attempt + 1, - chunk_size=len(chunk), - backoff=backoff, - ) - await asyncio.sleep(backoff) - - assert chunk_result is not None # noqa: S101 - - # Collect bulk-write rows instead of one execute per IP. - pos_rows: list[tuple[str, str | None, str | None, str | None, str | None]] = [] - neg_ips: list[str] = [] - - for ip, info in chunk_result.items(): - if info.country_code is not None: - # Successful API resolution. - await _store(ip, info) - geo_result[ip] = info - if db is not None: - pos_rows.append( - (ip, info.country_code, info.country_name, info.asn, info.org) - ) - else: - # API failed — try local GeoIP fallback. 
- fallback = _geoip_lookup(ip) - if fallback is not None: - await _store(ip, fallback) - geo_result[ip] = fallback - if db is not None: - pos_rows.append( - ( - ip, - fallback.country_code, - fallback.country_name, - fallback.asn, - fallback.org, - ) - ) - else: - # Both resolvers failed — record in negative cache. - async with _cache_lock: - _neg_cache[ip] = time.monotonic() - geo_result[ip] = _empty - if db is not None: - neg_ips.append(ip) - - if db is not None and (pos_rows or neg_ips): - try: - await geo_cache_repo.bulk_upsert_entries_and_neg_entries_and_commit( - db, - pos_rows, - neg_ips, - ) - except Exception as exc: # noqa: BLE001 - log.warning( - "geo_batch_persist_failed", - positive_count=len(pos_rows), - negative_count=len(neg_ips), - error=str(exc), - ) - - log.info( - "geo_batch_lookup_complete", - requested=len(uncached), - resolved=sum(1 for g in geo_result.values() if g.country_code is not None), - ) - return geo_result - - -async def _batch_api_call( - ips: list[str], - http_session: aiohttp.ClientSession, -) -> dict[str, GeoInfo]: - """Send one batch request to the ip-api.com batch endpoint. - - Args: - ips: Up to :data:`_BATCH_SIZE` IP address strings. - http_session: Shared HTTP session. - - Returns: - Dict mapping ``ip → GeoInfo`` for every IP in *ips*. IPs where the - API returned a failure record or the request raised an exception get - an all-``None`` :class:`GeoInfo`. 
- """ - empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None) - fallback: dict[str, GeoInfo] = dict.fromkeys(ips, empty) - - payload = [{"query": ip} for ip in ips] - try: - async with http_session.post( - _BATCH_API_URL, - json=payload, - timeout=aiohttp.ClientTimeout(total=_REQUEST_TIMEOUT * 2), - ) as resp: - if resp.status != 200: - log.warning("geo_batch_non_200", status=resp.status, count=len(ips)) - return fallback - data: list[dict[str, object]] = await resp.json(content_type=None) - except Exception as exc: # noqa: BLE001 - log.warning( - "geo_batch_request_failed", - count=len(ips), - exc_type=type(exc).__name__, - error=repr(exc), - ) - return fallback - - out: dict[str, GeoInfo] = {} - for entry in data: - ip_str: str = str(entry.get("query", "")) - if not ip_str: - continue - if entry.get("status") != "success": - out[ip_str] = empty - log.debug( - "geo_batch_entry_failed", - ip=ip_str, - message=entry.get("message", "unknown"), - ) - continue - out[ip_str] = _parse_single_response(entry) - - # Fill any IPs missing from the response. - for ip in ips: - if ip not in out: - out[ip] = empty - - return out - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def _parse_single_response(data: dict[str, object]) -> GeoInfo: - """Build a :class:`GeoInfo` from a single ip-api.com response dict. - - Args: - data: A ``status == "success"`` JSON response from ip-api.com. - - Returns: - Populated :class:`GeoInfo`. - """ - country_code: str | None = _str_or_none(data.get("countryCode")) - country_name: str | None = _str_or_none(data.get("country")) - asn_raw: str | None = _str_or_none(data.get("as")) - org_raw: str | None = _str_or_none(data.get("org")) - - # ip-api returns "AS12345 Some Org" in both "as" and "org". 
- asn: str | None = asn_raw.split()[0] if asn_raw else None - - return GeoInfo( - country_code=country_code, - country_name=country_name, - asn=asn, - org=org_raw, - ) - - -def _str_or_none(value: object) -> str | None: - """Return *value* as a non-empty string, or ``None``. - - Args: - value: Raw JSON value which may be ``None``, empty, or a string. - - Returns: - Stripped string if non-empty, else ``None``. - """ - if value is None: - return None - s = str(value).strip() - return s if s else None - - -async def _store(ip: str, info: GeoInfo) -> None: - """Insert *info* into the module-level cache, flushing if over capacity. - - When the IP resolved successfully (``country_code is not None``) it is - also added to the :data:`_dirty` set so :func:`flush_dirty` can persist - it to the database on the next scheduled flush. - - Args: - ip: The IP address key. - info: The :class:`GeoInfo` to store. - """ - async with _cache_lock: - if len(_cache) >= _MAX_CACHE_SIZE: - _cache.clear() - _dirty.clear() - log.info("geo_cache_flushed", reason="capacity") - _cache[ip] = info - if info.country_code is not None: - _dirty.add(ip) + return await _default_geo_cache.lookup_batch(ips, http_session, db=db) async def flush_dirty(db: aiosqlite.Connection) -> int: - """Persist all new in-memory geo entries to the ``geo_cache`` table. + """(DEPRECATED) Persist all new in-memory geo entries to the database. - Takes an atomic snapshot of :data:`_dirty`, clears it, then batch-inserts - all entries that are still present in :data:`_cache` using a single - ``executemany`` call and one ``COMMIT``. This is the only place that - writes to the persistent cache during normal operation after startup. - - If the database write fails the entries are re-added to :data:`_dirty` - so they will be retried on the next flush cycle. + Use :meth:`GeoCache.flush_dirty` instead. This function delegates to the + default module-level instance for backward compatibility only. 
Args: db: Open :class:`aiosqlite.Connection` to the BanGUI application @@ -803,30 +212,23 @@ async def flush_dirty(db: aiosqlite.Connection) -> int: Returns: The number of rows successfully upserted. """ - async with _cache_lock: - if not _dirty: - return 0 + return await _default_geo_cache.flush_dirty(db) - # Atomically snapshot and clear while holding the cache lock. - to_flush = _dirty.copy() - _dirty.clear() - rows = [ - (ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org) - for ip in to_flush - if ip in _cache - ] +async def re_resolve_all( + db: aiosqlite.Connection, + http_session: aiohttp.ClientSession, +) -> dict[str, int]: + """(DEPRECATED) Retry geo resolution for all unresolved cache entries. - if not rows: - return 0 + Use :meth:`GeoCache.re_resolve_all` instead. This function delegates to + the default module-level instance for backward compatibility only. - try: - await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows) - except Exception as exc: # noqa: BLE001 - log.warning("geo_flush_dirty_failed", error=str(exc)) - # Re-add to dirty so they are retried on the next flush cycle. - _dirty.update(to_flush) - return 0 + Args: + db: BanGUI application database connection. + http_session: Shared aiohttp client session. - log.info("geo_flush_dirty_complete", count=len(rows)) - return len(rows) + Returns: + A dict with ``resolved`` and ``total`` counts. 
+ """ + return await _default_geo_cache.re_resolve_all(db, http_session) diff --git a/backend/app/startup.py b/backend/app/startup.py index f07bc74..84b8133 100644 --- a/backend/app/startup.py +++ b/backend/app/startup.py @@ -16,7 +16,8 @@ import structlog from apscheduler.schedulers.asyncio import AsyncIOScheduler # type: ignore[import-untyped] from app.db import init_db, open_db -from app.services import geo_service, setup_service +from app.services import setup_service +from app.services.geo_cache import GeoCache from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_check, history_sync from app.utils.async_utils import run_blocking from app.utils.jail_config import ensure_jail_configs @@ -105,16 +106,18 @@ async def startup_shared_resources( overrides=persisted_runtime_settings, ) + # Create and initialize the GeoCache instance + geo_cache = GeoCache() if Path(settings.database_path).resolve() != original_db_path: runtime_db = await open_db(settings.database_path) try: - await geo_service.load_cache_from_db(runtime_db) - unresolved_count = await geo_service.count_unresolved(runtime_db) + await geo_cache.load_cache_from_db(runtime_db) + unresolved_count = await geo_cache.count_unresolved(runtime_db) finally: await runtime_db.close() else: - await geo_service.load_cache_from_db(startup_db) - unresolved_count = await geo_service.count_unresolved(startup_db) + await geo_cache.load_cache_from_db(startup_db) + unresolved_count = await geo_cache.count_unresolved(startup_db) finally: await startup_db.close() @@ -124,7 +127,8 @@ async def startup_shared_resources( log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count) http_session: aiohttp.ClientSession = _create_http_session(settings) - geo_service.init_geoip(settings.geoip_db_path) + geo_cache.init_geoip(settings.geoip_db_path) + app.state.geo_cache = geo_cache scheduler: AsyncIOScheduler | None = None try: diff --git a/backend/app/tasks/geo_cache_flush.py 
b/backend/app/tasks/geo_cache_flush.py index 2d3f7f8..d5ccd22 100644 --- a/backend/app/tasks/geo_cache_flush.py +++ b/backend/app/tasks/geo_cache_flush.py @@ -1,7 +1,7 @@ """Geo cache flush background task. Registers an APScheduler job that periodically persists newly resolved IP -geo entries from the in-memory ``_dirty`` set to the ``geo_cache`` table. +geo entries from the in-memory dirty set to the ``geo_cache`` table. After Task 2 removed geo cache writes from GET requests, newly resolved IPs are only held in the in-memory cache until this task flushes them. With the @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING import structlog -from app.services import geo_service +from app.services.geo_cache import GeoCache from app.tasks.db import task_db from app.utils.runtime_state import get_effective_settings @@ -33,21 +33,23 @@ GEO_FLUSH_INTERVAL: int = 60 JOB_ID: str = "geo_cache_flush" -async def _run_flush_with_settings(settings: Settings) -> None: - """Flush the geo service dirty set to the application database. +async def _run_flush_with_resources(geo_cache: GeoCache, settings: Settings) -> None: + """Flush the geo cache dirty set to the application database. Args: + geo_cache: The application's GeoCache instance. settings: The resolved application settings used for database access. """ async with task_db(settings) as db: - count = await geo_service.flush_dirty(db) + count = await geo_cache.flush_dirty(db) if count > 0: log.debug("geo_cache_flush_ran", flushed=count) async def _run_flush(app: FastAPI) -> None: - await _run_flush_with_settings(get_effective_settings(app)) + geo_cache: GeoCache = app.state.geo_cache + await _run_flush_with_resources(geo_cache, get_effective_settings(app)) def register(app: FastAPI) -> None: @@ -60,12 +62,13 @@ def register(app: FastAPI) -> None: app: The :class:`fastapi.FastAPI` application instance whose ``app.state.scheduler`` will receive the job. 
""" + geo_cache: GeoCache = app.state.geo_cache settings = get_effective_settings(app) app.state.scheduler.add_job( - _run_flush_with_settings, + _run_flush_with_resources, trigger="interval", seconds=GEO_FLUSH_INTERVAL, - kwargs={"settings": settings}, + kwargs={"geo_cache": geo_cache, "settings": settings}, id=JOB_ID, replace_existing=True, ) diff --git a/backend/app/tasks/geo_re_resolve.py b/backend/app/tasks/geo_re_resolve.py index 67847b3..774900f 100644 --- a/backend/app/tasks/geo_re_resolve.py +++ b/backend/app/tasks/geo_re_resolve.py @@ -10,8 +10,8 @@ The task runs every 10 minutes. On each invocation it: 1. Queries all ``NULL``-country rows from ``geo_cache``. 2. Clears the in-memory negative cache so those IPs are eligible for a fresh API attempt. -3. Delegates to :func:`~app.services.geo_service.lookup_batch` which already - handles rate-limit throttling and retries. +3. Delegates to :meth:`~app.services.geo_cache.GeoCache.lookup_batch` which + already handles rate-limit throttling and retries. 4. Logs how many IPs were retried and how many resolved successfully. """ @@ -21,7 +21,7 @@ from typing import TYPE_CHECKING import structlog -from app.services import geo_service +from app.services.geo_cache import GeoCache from app.tasks.db import task_db from app.utils.runtime_state import get_effective_settings @@ -40,16 +40,19 @@ GEO_RE_RESOLVE_INTERVAL: int = 600 JOB_ID: str = "geo_re_resolve" -async def _run_re_resolve_with_resources(settings: Settings, http_session: ClientSession) -> None: +async def _run_re_resolve_with_resources( + geo_cache: GeoCache, settings: Settings, http_session: ClientSession +) -> None: """Query NULL-country IPs from the database and re-resolve them. Args: + geo_cache: The application's GeoCache instance. settings: The resolved application settings used for database access. http_session: The shared aiohttp session used for external lookups. 
""" async with task_db(settings) as db: # Fetch all IPs with NULL country_code from the persistent cache. - unresolved_ips = await geo_service.get_unresolved_ips(db) + unresolved_ips = await geo_cache.get_unresolved_ips(db) if not unresolved_ips: log.debug("geo_re_resolve_skip", reason="no_unresolved_ips") @@ -58,11 +61,11 @@ async def _run_re_resolve_with_resources(settings: Settings, http_session: Clien log.info("geo_re_resolve_start", unresolved=len(unresolved_ips)) # Clear the negative cache so these IPs are eligible for fresh API calls. - await geo_service.clear_neg_cache() + await geo_cache.clear_neg_cache() # lookup_batch handles throttling, retries, and persistence when db is # passed. This is a background task so DB writes are allowed. - results = await geo_service.lookup_batch(unresolved_ips, http_session, db=db) + results = await geo_cache.lookup_batch(unresolved_ips, http_session, db=db) resolved_count: int = sum( 1 for info in results.values() if info.country_code is not None @@ -75,7 +78,10 @@ async def _run_re_resolve_with_resources(settings: Settings, http_session: Clien async def _run_re_resolve(app: FastAPI) -> None: - await _run_re_resolve_with_resources(get_effective_settings(app), app.state.http_session) + geo_cache: GeoCache = app.state.geo_cache + await _run_re_resolve_with_resources( + geo_cache, get_effective_settings(app), app.state.http_session + ) def register(app: FastAPI) -> None: @@ -91,12 +97,13 @@ def register(app: FastAPI) -> None: app: The :class:`fastapi.FastAPI` application instance whose ``app.state.scheduler`` will receive the job. 
""" + geo_cache: GeoCache = app.state.geo_cache settings = get_effective_settings(app) app.state.scheduler.add_job( _run_re_resolve_with_resources, trigger="interval", seconds=GEO_RE_RESOLVE_INTERVAL, - kwargs={"settings": settings, "http_session": app.state.http_session}, + kwargs={"geo_cache": geo_cache, "settings": settings, "http_session": app.state.http_session}, id=JOB_ID, replace_existing=True, ) diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 1a5699c..4b48689 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -168,9 +168,9 @@ async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Pat patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session), patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), patch("app.startup.init_db", new=AsyncMock()), - patch("app.services.geo_service.init_geoip"), - patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), - patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), + patch("app.services.geo_cache.GeoCache.init_geoip"), + patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=AsyncMock(return_value=None)), + patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)), patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)), patch("app.tasks.health_check.register"), patch("app.tasks.blocklist_import.register"), diff --git a/backend/tests/test_services/test_geo_service.py b/backend/tests/test_services/test_geo_service.py index 8359b3e..2684c5b 100644 --- a/backend/tests/test_services/test_geo_service.py +++ b/backend/tests/test_services/test_geo_service.py @@ -1,4 +1,4 @@ -"""Tests for geo_service.lookup().""" +"""Tests for geo_service and GeoCache.""" from __future__ import annotations @@ -8,7 +8,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.models.geo 
import GeoInfo -from app.services import geo_service +from app.services.geo_cache import GeoCache # --------------------------------------------------------------------------- # Helpers @@ -44,35 +44,33 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM # --------------------------------------------------------------------------- -@pytest.fixture(autouse=True) -async def clear_geo_cache() -> None: - """Flush the module-level geo cache before every test.""" - await geo_service.clear_cache() - geo_service._geoip_reader = None - geo_service._geoip_initialized = False +@pytest.fixture +async def geo_cache() -> GeoCache: + """Provide a fresh GeoCache instance for each test.""" + return GeoCache() -def test_init_geoip_is_startup_only(tmp_path) -> None: +def test_init_geoip_is_startup_only(geo_cache: GeoCache, tmp_path) -> None: """A second init_geoip() call raises when the reader was already loaded.""" path = tmp_path / "GeoLite2-Country.mmdb" path.write_text("dummy") with patch("geoip2.database.Reader", MagicMock(name="Reader")) as mock_reader: - geo_service.init_geoip(str(path)) - assert geo_service._geoip_reader is not None - assert geo_service._geoip_initialized is True + geo_cache.init_geoip(str(path)) + assert geo_cache._geoip_reader is not None + assert geo_cache._geoip_initialized is True with pytest.raises(RuntimeError, match="already initialised"): - geo_service.init_geoip(str(path)) + geo_cache.init_geoip(str(path)) assert mock_reader.call_count == 1 -def test_init_geoip_no_path_leaves_reader_uninitialised() -> None: +def test_init_geoip_no_path_leaves_reader_uninitialised(geo_cache: GeoCache) -> None: """No active reader is created when no path is supplied.""" - geo_service.init_geoip("") - assert geo_service._geoip_reader is None - assert geo_service._geoip_initialized is False + geo_cache.init_geoip("") + assert geo_cache._geoip_reader is None + assert geo_cache._geoip_initialized is False # 
--------------------------------------------------------------------------- @@ -83,7 +81,7 @@ def test_init_geoip_no_path_leaves_reader_uninitialised() -> None: class TestLookupSuccess: """geo_service.lookup() under normal conditions.""" - async def test_returns_country_code(self) -> None: + async def test_returns_country_code(self, geo_cache: GeoCache) -> None: """country_code is populated from the ``countryCode`` field.""" session = _make_session( { @@ -94,12 +92,12 @@ class TestLookupSuccess: "org": "AS3320 Deutsche Telekom AG", } ) - result = await geo_service.lookup("1.2.3.4", session) + result = await geo_cache.lookup("1.2.3.4", session) assert result is not None assert result.country_code == "DE" - async def test_returns_country_name(self) -> None: + async def test_returns_country_name(self, geo_cache: GeoCache) -> None: """country_name is populated from the ``country`` field.""" session = _make_session( { @@ -110,12 +108,12 @@ class TestLookupSuccess: "org": "Google LLC", } ) - result = await geo_service.lookup("8.8.8.8", session) + result = await geo_cache.lookup("8.8.8.8", session) assert result is not None assert result.country_name == "United States" - async def test_asn_extracted_without_org_suffix(self) -> None: + async def test_asn_extracted_without_org_suffix(self, geo_cache: GeoCache) -> None: """The ASN field contains only the ``AS`` prefix, not the full string.""" session = _make_session( { @@ -126,12 +124,12 @@ class TestLookupSuccess: "org": "Deutsche Telekom", } ) - result = await geo_service.lookup("1.2.3.4", session) + result = await geo_cache.lookup("1.2.3.4", session) assert result is not None assert result.asn == "AS3320" - async def test_org_populated(self) -> None: + async def test_org_populated(self, geo_cache: GeoCache) -> None: """org field is populated from the ``org`` key.""" session = _make_session( { @@ -142,7 +140,7 @@ class TestLookupSuccess: "org": "Google LLC", } ) - result = await geo_service.lookup("8.8.8.8", session) + 
result = await geo_cache.lookup("8.8.8.8", session) assert result is not None assert result.org == "Google LLC" @@ -156,7 +154,7 @@ class TestLookupSuccess: class TestLookupCaching: """Verify that results are cached and the cache can be cleared.""" - async def test_second_call_uses_cache(self) -> None: + async def test_second_call_uses_cache(self, geo_cache: GeoCache) -> None: """Subsequent lookups for the same IP do not make additional HTTP requests.""" session = _make_session( { @@ -168,13 +166,13 @@ class TestLookupCaching: } ) - await geo_service.lookup("1.2.3.4", session) - await geo_service.lookup("1.2.3.4", session) + await geo_cache.lookup("1.2.3.4", session) + await geo_cache.lookup("1.2.3.4", session) # The session.get() should only have been called once. assert session.get.call_count == 1 - async def test_clear_cache_forces_refetch(self) -> None: + async def test_clear_cache_forces_refetch(self, geo_cache: GeoCache) -> None: """After clearing the cache a new HTTP request is made.""" session = _make_session( { @@ -186,20 +184,20 @@ class TestLookupCaching: } ) - await geo_service.lookup("2.3.4.5", session) - await geo_service.clear_cache() - await geo_service.lookup("2.3.4.5", session) + await geo_cache.lookup("2.3.4.5", session) + await geo_cache.clear() + await geo_cache.lookup("2.3.4.5", session) assert session.get.call_count == 2 - async def test_negative_result_stored_in_neg_cache(self) -> None: + async def test_negative_result_stored_in_neg_cache(self, geo_cache: GeoCache) -> None: """A failed lookup is stored in the negative cache, so the second call is blocked.""" session = _make_session( {"status": "fail", "message": "reserved range"} ) - await geo_service.lookup("192.168.1.1", session) - await geo_service.lookup("192.168.1.1", session) + await geo_cache.lookup("192.168.1.1", session) + await geo_cache.lookup("192.168.1.1", session) # Second call is blocked by the negative cache — only one API hit. 
assert session.get.call_count == 1 @@ -213,15 +211,15 @@ class TestLookupCaching: class TestLookupFailures: """geo_service.lookup() when things go wrong.""" - async def test_non_200_response_returns_null_geo_info(self) -> None: + async def test_non_200_response_returns_null_geo_info(self, geo_cache: GeoCache) -> None: """A 429 or 500 status returns GeoInfo with null fields (not None).""" session = _make_session({}, status=429) - result = await geo_service.lookup("1.2.3.4", session) + result = await geo_cache.lookup("1.2.3.4", session) assert result is not None assert isinstance(result, GeoInfo) assert result.country_code is None - async def test_network_error_returns_null_geo_info(self) -> None: + async def test_network_error_returns_null_geo_info(self, geo_cache: GeoCache) -> None: """A network exception returns GeoInfo with null fields (not None).""" session = MagicMock() mock_ctx = AsyncMock() @@ -229,15 +227,15 @@ class TestLookupFailures: mock_ctx.__aexit__ = AsyncMock(return_value=False) session.get = MagicMock(return_value=mock_ctx) - result = await geo_service.lookup("10.0.0.1", session) + result = await geo_cache.lookup("10.0.0.1", session) assert result is not None assert isinstance(result, GeoInfo) assert result.country_code is None - async def test_failed_status_returns_geo_info_with_nulls(self) -> None: + async def test_failed_status_returns_geo_info_with_nulls(self, geo_cache: GeoCache) -> None: """When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached).""" session = _make_session({"status": "fail", "message": "private range"}) - result = await geo_service.lookup("10.0.0.1", session) + result = await geo_cache.lookup("10.0.0.1", session) assert result is not None assert isinstance(result, GeoInfo) @@ -253,43 +251,43 @@ class TestLookupFailures: class TestNegativeCache: """Verify the negative cache throttles retries for failing IPs.""" - async def test_neg_cache_blocks_second_lookup(self) -> None: + async def 
test_neg_cache_blocks_second_lookup(self, geo_cache: GeoCache) -> None: """After a failed lookup the second call is served from the neg cache.""" session = _make_session({"status": "fail", "message": "private range"}) - r1 = await geo_service.lookup("192.0.2.1", session) - r2 = await geo_service.lookup("192.0.2.1", session) + r1 = await geo_cache.lookup("192.0.2.1", session) + r2 = await geo_cache.lookup("192.0.2.1", session) # Only one HTTP call should have been made; second served from neg cache. assert session.get.call_count == 1 assert r1 is not None and r1.country_code is None assert r2 is not None and r2.country_code is None - async def test_neg_cache_retries_after_ttl(self) -> None: + async def test_neg_cache_retries_after_ttl(self, geo_cache: GeoCache) -> None: """When the neg-cache entry is older than the TTL a new API call is made.""" session = _make_session({"status": "fail", "message": "private range"}) - await geo_service.lookup("192.0.2.2", session) + await geo_cache.lookup("192.0.2.2", session) # Manually expire the neg-cache entry. - geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 + geo_cache._neg_cache["192.0.2.2"] -= _NEG_CACHE_TTL + 1 - await geo_service.lookup("192.0.2.2", session) + await geo_cache.lookup("192.0.2.2", session) # Both calls should have hit the API. 
assert session.get.call_count == 2 - async def test_clear_neg_cache_allows_immediate_retry(self) -> None: + async def test_clear_neg_cache_allows_immediate_retry(self, geo_cache: GeoCache) -> None: """After clearing the neg cache the IP is eligible for a new API call.""" session = _make_session({"status": "fail", "message": "private range"}) - await geo_service.lookup("192.0.2.3", session) - await geo_service.clear_neg_cache() - await geo_service.lookup("192.0.2.3", session) + await geo_cache.lookup("192.0.2.3", session) + await geo_cache.clear_neg_cache() + await geo_cache.lookup("192.0.2.3", session) assert session.get.call_count == 2 - async def test_successful_lookup_does_not_pollute_neg_cache(self) -> None: + async def test_successful_lookup_does_not_pollute_neg_cache(self, geo_cache: GeoCache) -> None: """A successful lookup must not create a neg-cache entry.""" session = _make_session( { @@ -301,9 +299,9 @@ class TestNegativeCache: } ) - await geo_service.lookup("1.2.3.4", session) + await geo_cache.lookup("1.2.3.4", session) - assert "1.2.3.4" not in geo_service._neg_cache + assert "1.2.3.4" not in geo_cache._neg_cache # --------------------------------------------------------------------------- @@ -327,33 +325,33 @@ class TestGeoipFallback: reader.country = MagicMock(return_value=response_mock) return reader - async def test_geoip_fallback_called_when_api_fails(self) -> None: + async def test_geoip_fallback_called_when_api_fails(self, geo_cache: GeoCache) -> None: """When ip-api returns status=fail, the geoip2 reader is consulted.""" session = _make_session({"status": "fail", "message": "reserved range"}) mock_reader = self._make_geoip_reader("DE", "Germany") with patch.object(geo_service, "_geoip_reader", mock_reader): - result = await geo_service.lookup("1.2.3.4", session) + result = await geo_cache.lookup("1.2.3.4", session) mock_reader.country.assert_called_once_with("1.2.3.4") assert result is not None assert result.country_code == "DE" assert 
result.country_name == "Germany" - async def test_geoip_fallback_result_stored_in_cache(self) -> None: + async def test_geoip_fallback_result_stored_in_cache(self, geo_cache: GeoCache) -> None: """A successful geoip2 fallback result is stored in the positive cache.""" session = _make_session({"status": "fail", "message": "reserved range"}) mock_reader = self._make_geoip_reader("US", "United States") with patch.object(geo_service, "_geoip_reader", mock_reader): - await geo_service.lookup("8.8.8.8", session) + await geo_cache.lookup("8.8.8.8", session) # Second call must be served from positive cache without hitting API. - await geo_service.lookup("8.8.8.8", session) + await geo_cache.lookup("8.8.8.8", session) assert session.get.call_count == 1 - assert "8.8.8.8" in geo_service._cache + assert "8.8.8.8" in geo_cache._cache - async def test_geoip_fallback_not_called_on_api_success(self) -> None: + async def test_geoip_fallback_not_called_on_api_success(self, geo_cache: GeoCache) -> None: """When ip-api succeeds, the geoip2 reader must not be consulted.""" session = _make_session( { @@ -367,18 +365,18 @@ class TestGeoipFallback: mock_reader = self._make_geoip_reader("XX", "Nowhere") with patch.object(geo_service, "_geoip_reader", mock_reader): - result = await geo_service.lookup("1.2.3.4", session) + result = await geo_cache.lookup("1.2.3.4", session) mock_reader.country.assert_not_called() assert result is not None assert result.country_code == "JP" - async def test_geoip_fallback_not_called_when_no_reader(self) -> None: + async def test_geoip_fallback_not_called_when_no_reader(self, geo_cache: GeoCache) -> None: """When no geoip2 reader is configured, the fallback silently does nothing.""" session = _make_session({"status": "fail", "message": "private range"}) with patch.object(geo_service, "_geoip_reader", None): - result = await geo_service.lookup("10.0.0.1", session) + result = await geo_cache.lookup("10.0.0.1", session) assert result is not None assert 
result.country_code is None @@ -428,7 +426,7 @@ def _make_async_db() -> MagicMock: class TestLookupBatchSingleCommit: """lookup_batch() issues exactly one commit per call, not one per IP.""" - async def test_single_commit_for_multiple_ips(self) -> None: + async def test_single_commit_for_multiple_ips(self, geo_cache: GeoCache) -> None: """A batch of N IPs produces exactly one db.commit(), not N.""" ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"] batch_response = [ @@ -438,11 +436,11 @@ class TestLookupBatchSingleCommit: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) + await geo_cache.lookup_batch(ips, session, db=db) db.commit.assert_awaited_once() - async def test_commit_called_even_on_failed_lookups(self) -> None: + async def test_commit_called_even_on_failed_lookups(self, geo_cache: GeoCache) -> None: """A batch with all-failed lookups still triggers one commit.""" ips = ["10.0.0.1", "10.0.0.2"] batch_response = [ @@ -452,11 +450,11 @@ class TestLookupBatchSingleCommit: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) + await geo_cache.lookup_batch(ips, session, db=db) db.commit.assert_awaited_once() - async def test_no_commit_when_db_is_none(self) -> None: + async def test_no_commit_when_db_is_none(self, geo_cache: GeoCache) -> None: """When db=None, no commit is attempted.""" ips = ["1.1.1.1"] batch_response = [ @@ -472,19 +470,19 @@ class TestLookupBatchSingleCommit: session = _make_batch_session(batch_response) # Should not raise; without db there is nothing to commit. 
- result = await geo_service.lookup_batch(ips, session, db=None) + result = await geo_cache.lookup_batch(ips, session, db=None) assert result["1.1.1.1"].country_code == "US" - async def test_no_commit_for_all_cached_ips(self) -> None: + async def test_no_commit_for_all_cached_ips(self, geo_cache: GeoCache) -> None: """When all IPs are already cached, no HTTP call and no commit occur.""" - geo_service._cache["5.5.5.5"] = GeoInfo( + geo_cache._cache["5.5.5.5"] = GeoInfo( country_code="FR", country_name="France", asn="AS1", org="ISP" ) db = _make_async_db() session = _make_batch_session([]) - result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) + result = await geo_cache.lookup_batch(["5.5.5.5"], session, db=db) assert result["5.5.5.5"].country_code == "FR" db.commit.assert_not_awaited() @@ -499,31 +497,31 @@ class TestLookupBatchSingleCommit: class TestDirtySetTracking: """_store() marks successfully resolved IPs as dirty.""" - async def test_successful_resolution_adds_to_dirty(self) -> None: + async def test_successful_resolution_adds_to_dirty(self, geo_cache: GeoCache) -> None: """Storing a GeoInfo with a country_code adds the IP to _dirty.""" info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP") await geo_service._store("1.2.3.4", info) - assert "1.2.3.4" in geo_service._dirty + assert "1.2.3.4" in geo_cache._dirty - async def test_null_country_does_not_add_to_dirty(self) -> None: + async def test_null_country_does_not_add_to_dirty(self, geo_cache: GeoCache) -> None: """Storing a GeoInfo with country_code=None must not pollute _dirty.""" info = GeoInfo(country_code=None, country_name=None, asn=None, org=None) await geo_service._store("10.0.0.1", info) - assert "10.0.0.1" not in geo_service._dirty + assert "10.0.0.1" not in geo_cache._dirty - async def test_clear_cache_also_clears_dirty(self) -> None: + async def test_clear_cache_also_clears_dirty(self, geo_cache: GeoCache) -> None: """clear_cache() must discard any pending 
dirty entries.""" info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP") await geo_service._store("8.8.8.8", info) - assert geo_service._dirty + assert geo_cache._dirty - await geo_service.clear_cache() + await geo_cache.clear() - assert not geo_service._dirty + assert not geo_cache._dirty - async def test_lookup_batch_populates_dirty(self) -> None: + async def test_lookup_batch_populates_dirty(self, geo_cache: GeoCache) -> None: """After lookup_batch() with db=None, resolved IPs appear in _dirty.""" ips = ["1.1.1.1", "2.2.2.2"] batch_response = [ @@ -532,39 +530,39 @@ class TestDirtySetTracking: ] session = _make_batch_session(batch_response) - await geo_service.lookup_batch(ips, session, db=None) + await geo_cache.lookup_batch(ips, session, db=None) for ip in ips: - assert ip in geo_service._dirty + assert ip in geo_cache._dirty class TestFlushDirty: """flush_dirty() persists dirty entries and clears the set.""" - async def test_flush_writes_and_clears_dirty(self) -> None: + async def test_flush_writes_and_clears_dirty(self, geo_cache: GeoCache) -> None: """flush_dirty() inserts all dirty IPs and clears _dirty afterwards.""" info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT") await geo_service._store("100.0.0.1", info) - assert "100.0.0.1" in geo_service._dirty + assert "100.0.0.1" in geo_cache._dirty db = _make_async_db() - count = await geo_service.flush_dirty(db) + count = await geo_cache.flush_dirty(db) assert count == 1 db.executemany.assert_awaited_once() db.commit.assert_awaited_once() - assert "100.0.0.1" not in geo_service._dirty + assert "100.0.0.1" not in geo_cache._dirty - async def test_flush_returns_zero_when_nothing_dirty(self) -> None: + async def test_flush_returns_zero_when_nothing_dirty(self, geo_cache: GeoCache) -> None: """flush_dirty() returns 0 and makes no DB calls when _dirty is empty.""" db = _make_async_db() - count = await geo_service.flush_dirty(db) + count = await 
geo_cache.flush_dirty(db) assert count == 0 db.executemany.assert_not_awaited() db.commit.assert_not_awaited() - async def test_flush_re_adds_to_dirty_on_db_error(self) -> None: + async def test_flush_re_adds_to_dirty_on_db_error(self, geo_cache: GeoCache) -> None: """When the DB write fails, entries are re-added to _dirty for retry.""" info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP") await geo_service._store("200.0.0.1", info) @@ -572,12 +570,12 @@ class TestFlushDirty: db = _make_async_db() db.executemany = AsyncMock(side_effect=OSError("disk full")) - count = await geo_service.flush_dirty(db) + count = await geo_cache.flush_dirty(db) assert count == 0 - assert "200.0.0.1" in geo_service._dirty + assert "200.0.0.1" in geo_cache._dirty - async def test_flush_batch_and_lookup_batch_integration(self) -> None: + async def test_flush_batch_and_lookup_batch_integration(self, geo_cache: GeoCache) -> None: """lookup_batch() populates _dirty; flush_dirty() then persists them.""" ips = ["10.1.2.3", "10.1.2.4"] batch_response = [ @@ -587,15 +585,15 @@ class TestFlushDirty: session = _make_batch_session(batch_response) # Resolve without DB to populate only in-memory cache and _dirty. - await geo_service.lookup_batch(ips, session, db=None) - assert geo_service._dirty == set(ips) + await geo_cache.lookup_batch(ips, session, db=None) + assert geo_cache._dirty == set(ips) # Now flush to the DB. 
db = _make_async_db() - count = await geo_service.flush_dirty(db) + count = await geo_cache.flush_dirty(db) assert count == 2 - assert not geo_service._dirty + assert not geo_cache._dirty db.commit.assert_awaited_once() @@ -607,7 +605,7 @@ class TestFlushDirty: class TestLookupBatchThrottling: """Verify the inter-batch delay, retry, and give-up behaviour.""" - async def test_lookup_batch_throttles_between_chunks(self) -> None: + async def test_lookup_batch_throttles_between_chunks(self, geo_cache: GeoCache) -> None: """When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called between consecutive batch HTTP calls with at least _BATCH_DELAY.""" # Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls. @@ -628,7 +626,7 @@ class TestLookupBatchThrottling: ) as mock_batch, patch("app.services.geo_service.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, ): - await geo_service.lookup_batch(ips, MagicMock()) + await geo_cache.lookup_batch(ips, MagicMock()) # Two chunks → one sleep between them. 
assert mock_batch.call_count == 2 @@ -636,7 +634,7 @@ class TestLookupBatchThrottling: delay_arg: float = mock_sleep.call_args[0][0] assert delay_arg >= geo_service._BATCH_DELAY - async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None: + async def test_lookup_batch_retries_on_full_chunk_failure(self, geo_cache: GeoCache) -> None: """When a chunk returns all-None on first try, it retries and succeeds.""" ips = ["1.2.3.4", "5.6.7.8"] @@ -664,13 +662,13 @@ class TestLookupBatchThrottling: ), patch("app.services.geo_service.asyncio.sleep", new_callable=AsyncMock), ): - result = await geo_service.lookup_batch(ips, MagicMock()) + result = await geo_cache.lookup_batch(ips, MagicMock()) assert call_count == 2 assert result["1.2.3.4"].country_code == "DE" assert result["5.6.7.8"].country_code == "US" - async def test_lookup_batch_gives_up_after_max_retries(self) -> None: + async def test_lookup_batch_gives_up_after_max_retries(self, geo_cache: GeoCache) -> None: """After _BATCH_MAX_RETRIES + 1 attempts, IPs end up in the neg cache.""" ips = ["9.9.9.9"] _empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None) @@ -686,14 +684,14 @@ class TestLookupBatchThrottling: ) as mock_batch, patch("app.services.geo_service.asyncio.sleep", new_callable=AsyncMock) as mock_sleep, ): - result = await geo_service.lookup_batch(ips, MagicMock()) + result = await geo_cache.lookup_batch(ips, MagicMock()) # Initial attempt + max_retries retries. assert mock_batch.call_count == max_retries + 1 # IP should have no country. assert result["9.9.9.9"].country_code is None # Negative cache should contain the IP. - assert "9.9.9.9" in geo_service._neg_cache + assert "9.9.9.9" in geo_cache._neg_cache # Sleep called for each retry with exponential backoff. 
 assert mock_sleep.call_count == max_retries backoff_values = [call.args[0] for call in mock_sleep.call_args_list] @@ -717,7 +715,7 @@ class TestErrorLogging: always present, and adds an ``exc_type`` field for easy log filtering. """ - async def test_empty_message_exception_logs_exc_type(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache, caplog: pytest.LogCaptureFixture) -> None: """When exception str() is empty, exc_type and repr are still logged.""" class _EmptyMessageError(Exception): @@ -735,7 +733,7 @@ class TestErrorLogging: import structlog.testing with structlog.testing.capture_logs() as captured: - result = await geo_service.lookup("197.221.98.153", session) + result = await geo_cache.lookup("197.221.98.153", session) assert result is not None assert result.country_code is None @@ -748,7 +746,7 @@ class TestErrorLogging: # repr() must include the class name even when str() is empty. assert "_EmptyMessageError" in event["error"] - async def test_connection_error_logs_exc_type(self, caplog: pytest.LogCaptureFixture) -> None: + async def test_connection_error_logs_exc_type(self, geo_cache: GeoCache, caplog: pytest.LogCaptureFixture) -> None: """A standard OSError with message is logged both in error and exc_type.""" session = MagicMock() mock_ctx = AsyncMock() @@ -759,7 +757,7 @@ class TestErrorLogging: import structlog.testing with structlog.testing.capture_logs() as captured: - await geo_service.lookup("10.0.0.1", session) + await geo_cache.lookup("10.0.0.1", session) request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"] assert len(request_failed) == 1 @@ -767,7 +765,7 @@ class TestErrorLogging: assert event["exc_type"] == "OSError" assert "connection refused" in event["error"] - async def test_batch_empty_message_exception_logs_exc_type(self) -> None: + async def test_batch_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> 
None: """Batch API call: empty-message exceptions include exc_type in the log.""" class _EmptyMessageError(Exception): @@ -804,10 +802,10 @@ class TestLookupCachedOnly: def test_returns_cached_ips(self) -> None: """IPs already in the cache are returned in the geo_map.""" - geo_service._cache["1.1.1.1"] = GeoInfo( + geo_cache._cache["1.1.1.1"] = GeoInfo( country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare" ) - geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"]) + geo_map, uncached = geo_cache.lookup_cached_only(["1.1.1.1"]) assert "1.1.1.1" in geo_map assert geo_map["1.1.1.1"].country_code == "AU" @@ -815,7 +813,7 @@ class TestLookupCachedOnly: def test_returns_uncached_ips(self) -> None: """IPs not in the cache appear in the uncached list.""" - geo_map, uncached = geo_service.lookup_cached_only(["9.9.9.9"]) + geo_map, uncached = geo_cache.lookup_cached_only(["9.9.9.9"]) assert "9.9.9.9" not in geo_map assert "9.9.9.9" in uncached @@ -824,42 +822,42 @@ class TestLookupCachedOnly: """IPs in the negative cache within TTL are not re-queued as uncached.""" import time - geo_service._neg_cache["10.0.0.1"] = time.monotonic() + geo_cache._neg_cache["10.0.0.1"] = time.monotonic() - geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"]) + geo_map, uncached = geo_cache.lookup_cached_only(["10.0.0.1"]) assert "10.0.0.1" not in geo_map assert "10.0.0.1" not in uncached def test_expired_neg_cache_requeued(self) -> None: """IPs whose neg-cache entry has expired are listed as uncached.""" - geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired + geo_cache._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired - _geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"]) + _geo_map, uncached = geo_cache.lookup_cached_only(["10.0.0.2"]) assert "10.0.0.2" in uncached def test_mixed_ips(self) -> None: """A mix of cached, neg-cached, and unknown IPs is split correctly.""" - geo_service._cache["1.2.3.4"] = GeoInfo( + 
 geo_cache._cache["1.2.3.4"] = GeoInfo( country_code="DE", country_name="Germany", asn=None, org=None ) import time - geo_service._neg_cache["5.5.5.5"] = time.monotonic() + geo_cache._neg_cache["5.5.5.5"] = time.monotonic() - geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"]) + geo_map, uncached = geo_cache.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"]) assert list(geo_map.keys()) == ["1.2.3.4"] assert uncached == ["9.9.9.9"] def test_deduplication(self) -> None: """Duplicate IPs in the input appear at most once in the output.""" - geo_service._cache["1.2.3.4"] = GeoInfo( + geo_cache._cache["1.2.3.4"] = GeoInfo( country_code="US", country_name="United States", asn=None, org=None ) - geo_map, uncached = geo_service.lookup_cached_only( + geo_map, uncached = geo_cache.lookup_cached_only( ["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"] ) @@ -868,27 +866,32 @@ class TestLookupCachedOnly: class TestReResolveAll: - """Tests for :func:`~app.services.geo_service.re_resolve_all`.""" + """Tests for :meth:`~app.services.geo_cache.GeoCache.re_resolve_all`.""" - async def test_returns_zero_when_no_unresolved_ips(self) -> None: + async def test_returns_zero_when_no_unresolved_ips(self, geo_cache: GeoCache) -> None: """The service returns zero counts when there are no unresolved IPs.""" db = MagicMock() session = MagicMock() with patch( - "app.services.geo_service.get_unresolved_ips", + "app.repositories.geo_cache_repo.get_unresolved_ips", AsyncMock(return_value=[]), - ), patch("app.services.geo_service.lookup_batch", AsyncMock()) as mock_lookup, patch( - "app.services.geo_service.clear_neg_cache", - MagicMock(), + ), patch.object( + geo_cache, + "lookup_batch", + AsyncMock(), + ) as mock_lookup, patch.object( + geo_cache, + "clear_neg_cache", + AsyncMock(), ) as mock_clear: - result = await geo_service.re_resolve_all(db, session) + result = await geo_cache.re_resolve_all(db, session) assert result == {"resolved": 0, "total": 0} 
mock_clear.assert_not_called() mock_lookup.assert_not_called() - async def test_clears_neg_cache_and_returns_counts(self) -> None: + async def test_clears_neg_cache_and_returns_counts(self, geo_cache: GeoCache) -> None: """The service clears negative cache and returns resolved and total counts.""" db = MagicMock() session = MagicMock() @@ -899,16 +902,18 @@ class TestReResolveAll: } with patch( - "app.services.geo_service.get_unresolved_ips", + "app.repositories.geo_cache_repo.get_unresolved_ips", AsyncMock(return_value=ips), - ), patch( - "app.services.geo_service.lookup_batch", + ), patch.object( + geo_cache, + "lookup_batch", AsyncMock(return_value=geo_map), - ) as mock_lookup, patch( - "app.services.geo_service.clear_neg_cache", + ) as mock_lookup, patch.object( + geo_cache, + "clear_neg_cache", AsyncMock(), ) as mock_clear: - result = await geo_service.re_resolve_all(db, session) + result = await geo_cache.re_resolve_all(db, session) assert result == {"resolved": 1, "total": 2} mock_clear.assert_called_once() @@ -923,7 +928,7 @@ class TestReResolveAll: class TestLookupBatchBulkWrites: """lookup_batch() uses executemany for bulk DB writes, not per-IP execute.""" - async def test_executemany_called_for_successful_ips(self) -> None: + async def test_executemany_called_for_successful_ips(self, geo_cache: GeoCache) -> None: """When multiple IPs resolve successfully, a single executemany write occurs.""" ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"] batch_response = [ @@ -940,14 +945,14 @@ class TestLookupBatchBulkWrites: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) + await geo_cache.lookup_batch(ips, session, db=db) # One executemany for the positive rows. assert db.executemany.await_count >= 1 # High-level: execute() must NOT be called for the batch writes. 
db.execute.assert_not_awaited() - async def test_executemany_called_for_failed_ips(self) -> None: + async def test_executemany_called_for_failed_ips(self, geo_cache: GeoCache) -> None: """When IPs fail resolution, a single executemany write covers neg entries.""" ips = ["10.0.0.1", "10.0.0.2"] batch_response = [ @@ -957,12 +962,12 @@ class TestLookupBatchBulkWrites: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) + await geo_cache.lookup_batch(ips, session, db=db) assert db.executemany.await_count >= 1 db.execute.assert_not_awaited() - async def test_mixed_results_two_executemany_calls(self) -> None: + async def test_mixed_results_two_executemany_calls(self, geo_cache: GeoCache) -> None: """A mix of successful and failed IPs produces two executemany calls.""" ips = ["1.1.1.1", "10.0.0.1"] batch_response = [ @@ -979,7 +984,7 @@ class TestLookupBatchBulkWrites: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) + await geo_cache.lookup_batch(ips, session, db=db) # One executemany for positives, one for negatives. assert db.executemany.await_count == 2