Fix geo cache write performance: batch commits, read-only GETs, dirty flush
- Remove per-IP db.commit() from _persist_entry() and _persist_neg_entry(); add a single commit after the full lookup_batch() chunk loop instead. Reduces commits from ~5,200 to 1 per bans/by-country request. - Remove db dependency from GET /api/dashboard/bans and GET /api/dashboard/bans/by-country; pass app_db=None so no SQLite writes occur during read-only requests. - Add _dirty set to geo_service; _store() marks resolved IPs dirty. New flush_dirty(db) batch-upserts all dirty entries in one transaction. New geo_cache_flush APScheduler task flushes every 60 s so geo data is persisted without blocking requests.
This commit is contained in:
@@ -34,7 +34,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from app.config import Settings, get_settings
|
||||
from app.db import init_db
|
||||
from app.routers import auth, bans, blocklist, config, dashboard, geo, health, history, jails, server, setup
|
||||
from app.tasks import blocklist_import, health_check
|
||||
from app.tasks import blocklist_import, geo_cache_flush, health_check
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -151,6 +151,9 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# --- Blocklist import scheduled task ---
|
||||
blocklist_import.register(app)
|
||||
|
||||
# --- Periodic geo cache flush to SQLite ---
|
||||
geo_cache_flush.register(app)
|
||||
|
||||
log.info("bangui_started")
|
||||
|
||||
try:
|
||||
|
||||
@@ -9,16 +9,14 @@ Also provides ``GET /api/dashboard/bans`` for the dashboard ban-list table.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import aiosqlite
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
from fastapi import APIRouter, Query, Request
|
||||
|
||||
from app.dependencies import AuthDep, get_db
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import (
|
||||
BanOrigin,
|
||||
BansByCountryResponse,
|
||||
@@ -77,7 +75,6 @@ async def get_server_status(
|
||||
async def get_dashboard_bans(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
page: int = Query(default=1, ge=1, description="1-based page number."),
|
||||
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."),
|
||||
@@ -90,12 +87,13 @@ async def get_dashboard_bans(
|
||||
|
||||
Reads from the fail2ban database and enriches each entry with
|
||||
geolocation data (country, ASN, organisation) from the ip-api.com
|
||||
free API. Results are sorted newest-first.
|
||||
free API. Results are sorted newest-first. Geo lookups are served
|
||||
from the in-memory cache only; no database writes occur during this
|
||||
GET request.
|
||||
|
||||
Args:
|
||||
request: The incoming request (used to access ``app.state``).
|
||||
_auth: Validated session dependency.
|
||||
db: BanGUI application database (for persistent geo cache writes).
|
||||
range: Time-range preset — ``"24h"``, ``"7d"``, ``"30d"``, or
|
||||
``"365d"``.
|
||||
page: 1-based page number.
|
||||
@@ -115,7 +113,7 @@ async def get_dashboard_bans(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
http_session=http_session,
|
||||
app_db=db,
|
||||
app_db=None,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
@@ -128,7 +126,6 @@ async def get_dashboard_bans(
|
||||
async def get_bans_by_country(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
@@ -139,12 +136,13 @@ async def get_bans_by_country(
|
||||
|
||||
Uses SQL aggregation (``GROUP BY ip``) and batch geo-resolution to handle
|
||||
10 000+ banned IPs efficiently. Returns a ``{country_code: count}`` map
|
||||
and the 200 most recent raw ban rows for the companion access table.
|
||||
and the 200 most recent raw ban rows for the companion access table. Geo
|
||||
lookups are served from the in-memory cache only; no database writes occur
|
||||
during this GET request.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
_auth: Validated session dependency.
|
||||
db: BanGUI application database (for persistent geo cache writes).
|
||||
range: Time-range preset.
|
||||
origin: Optional filter by ban origin.
|
||||
|
||||
@@ -159,7 +157,7 @@ async def get_bans_by_country(
|
||||
socket_path,
|
||||
range,
|
||||
http_session=http_session,
|
||||
app_db=db,
|
||||
app_db=None,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
|
||||
@@ -118,6 +118,10 @@ _cache: dict[str, GeoInfo] = {}
|
||||
#: Entries within :data:`_NEG_CACHE_TTL` seconds are not re-queried.
|
||||
_neg_cache: dict[str, float] = {}
|
||||
|
||||
#: IPs added to :data:`_cache` but not yet persisted to the database.
|
||||
#: Consumed and cleared atomically by :func:`flush_dirty`.
|
||||
_dirty: set[str] = set()
|
||||
|
||||
#: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`.
|
||||
_geoip_reader: geoip2.database.Reader | None = None
|
||||
|
||||
@@ -125,10 +129,12 @@ _geoip_reader: geoip2.database.Reader | None = None
|
||||
def clear_cache() -> None:
    """Reset every in-memory geo lookup structure.

    Empties the positive cache, the negative cache, and the dirty set, so
    any pending-but-unpersisted entries are discarded. Useful in tests and
    when the operator suspects stale data.
    """
    for bucket in (_cache, _neg_cache, _dirty):
        bucket.clear()
|
||||
|
||||
|
||||
def clear_neg_cache() -> None:
|
||||
@@ -256,7 +262,6 @@ async def _persist_entry(
|
||||
""",
|
||||
(ip, info.country_code, info.country_name, info.asn, info.org),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
||||
@@ -273,7 +278,6 @@ async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
||||
(ip,),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -330,6 +334,7 @@ async def lookup(
|
||||
if result.country_code is not None and db is not None:
|
||||
try:
|
||||
await _persist_entry(db, ip, result)
|
||||
await db.commit()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_persist_failed", ip=ip, error=str(exc))
|
||||
log.debug("geo_lookup_success", ip=ip, country=result.country_code, asn=result.asn)
|
||||
@@ -350,6 +355,7 @@ async def lookup(
|
||||
if fallback.country_code is not None and db is not None:
|
||||
try:
|
||||
await _persist_entry(db, ip, fallback)
|
||||
await db.commit()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_persist_failed", ip=ip, error=str(exc))
|
||||
log.debug("geo_geoip_fallback_success", ip=ip, country=fallback.country_code)
|
||||
@@ -360,6 +366,7 @@ async def lookup(
|
||||
if db is not None:
|
||||
try:
|
||||
await _persist_neg_entry(db, ip)
|
||||
await db.commit()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_persist_neg_failed", ip=ip, error=str(exc))
|
||||
|
||||
@@ -449,6 +456,12 @@ async def lookup_batch(
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_persist_neg_failed", ip=ip, error=str(exc))
|
||||
|
||||
if db is not None:
|
||||
try:
|
||||
await db.commit()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_batch_commit_failed", error=str(exc))
|
||||
|
||||
log.info(
|
||||
"geo_batch_lookup_complete",
|
||||
requested=len(uncached),
|
||||
@@ -561,11 +574,77 @@ def _str_or_none(value: object) -> str | None:
|
||||
def _store(ip: str, info: GeoInfo) -> None:
    """Put *info* into the module-level cache, evicting everything if full.

    A successfully resolved entry (``country_code is not None``) is also
    recorded in :data:`_dirty` so :func:`flush_dirty` can persist it to the
    database on the next scheduled flush.

    Args:
        ip: The IP address key.
        info: The :class:`GeoInfo` to store.
    """
    at_capacity = len(_cache) >= _MAX_CACHE_SIZE
    if at_capacity:
        # Whole-cache eviction: pending dirty entries die with the cache.
        _cache.clear()
        _dirty.clear()
        log.info("geo_cache_flushed", reason="capacity")
    _cache[ip] = info
    resolved = info.country_code is not None
    if resolved:
        _dirty.add(ip)
|
||||
|
||||
|
||||
async def flush_dirty(db: aiosqlite.Connection) -> int:
    """Persist all new in-memory geo entries to the ``geo_cache`` table.

    Snapshots :data:`_dirty` and clears it — with no ``await`` between the
    two steps, so no other coroutine can interleave — then upserts every
    snapshot entry still present in :data:`_cache` using one ``executemany``
    call and a single ``COMMIT``. This is the only place that writes to the
    persistent cache during normal operation after startup.

    If the database write fails, the snapshot is merged back into
    :data:`_dirty` so those IPs are retried on the next flush cycle.

    Args:
        db: Open :class:`aiosqlite.Connection` to the BanGUI application
            database.

    Returns:
        The number of rows successfully upserted.
    """
    if not _dirty:
        return 0

    # Atomic snapshot-and-clear (single-threaded async context, no await).
    pending = set(_dirty)
    _dirty.clear()

    rows: list[tuple[str | None, ...]] = []
    for ip in pending:
        entry = _cache.get(ip)
        if entry is None:
            # Evicted from the cache since being marked dirty — nothing to write.
            continue
        rows.append((ip, entry.country_code, entry.country_name, entry.asn, entry.org))
    if not rows:
        return 0

    try:
        await db.executemany(
            """
            INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(ip) DO UPDATE SET
                country_code = excluded.country_code,
                country_name = excluded.country_name,
                asn = excluded.asn,
                org = excluded.org,
                cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
            """,
            rows,
        )
        await db.commit()
    except Exception as exc:  # noqa: BLE001
        log.warning("geo_flush_dirty_failed", error=str(exc))
        # Merge the snapshot back so the next cycle retries these IPs.
        _dirty.update(pending)
        return 0

    log.info("geo_flush_dirty_complete", count=len(rows))
    return len(rows)
|
||||
|
||||
66
backend/app/tasks/geo_cache_flush.py
Normal file
66
backend/app/tasks/geo_cache_flush.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Geo cache flush background task.
|
||||
|
||||
Registers an APScheduler job that periodically persists newly resolved IP
|
||||
geo entries from the in-memory ``_dirty`` set to the ``geo_cache`` table.
|
||||
|
||||
After Task 2 removed geo cache writes from GET requests, newly resolved IPs
|
||||
are only held in the in-memory cache until this task flushes them. With the
|
||||
default 60-second interval, at most one minute of new resolution results is
|
||||
at risk on an unexpected process restart.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
|
||||
from app.services import geo_service
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from fastapi import FastAPI
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: How often the flush job fires (seconds). Configurable tuning constant.
|
||||
GEO_FLUSH_INTERVAL: int = 60
|
||||
|
||||
#: Stable APScheduler job ID — ensures re-registration replaces, not duplicates.
|
||||
JOB_ID: str = "geo_cache_flush"
|
||||
|
||||
|
||||
async def _run_flush(app: Any) -> None:
    """Flush the geo service dirty set to the application database.

    Pulls the shared DB connection off ``app.state`` and delegates to
    :func:`~app.services.geo_service.flush_dirty`.

    Args:
        app: The :class:`fastapi.FastAPI` application instance passed via
            APScheduler ``kwargs``.
    """
    flushed = await geo_service.flush_dirty(app.state.db)
    if flushed > 0:
        log.debug("geo_cache_flush_ran", flushed=flushed)
|
||||
|
||||
|
||||
def register(app: FastAPI) -> None:
    """Add (or replace) the geo cache flush job in the application scheduler.

    Must be called after the scheduler has been started (i.e., inside the
    lifespan handler, after ``scheduler.start()``).

    Args:
        app: The :class:`fastapi.FastAPI` application instance whose
            ``app.state.scheduler`` will receive the job.
    """
    scheduler = app.state.scheduler
    # Stable job ID + replace_existing guarantees idempotent registration.
    scheduler.add_job(
        _run_flush,
        trigger="interval",
        seconds=GEO_FLUSH_INTERVAL,
        kwargs={"app": app},
        id=JOB_ID,
        replace_existing=True,
    )
    log.info("geo_cache_flush_scheduled", interval_seconds=GEO_FLUSH_INTERVAL)
|
||||
@@ -356,3 +356,212 @@ class TestGeoipFallback:
|
||||
|
||||
assert result is not None
|
||||
assert result.country_code is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch single-commit behaviour (Task 1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock:
|
||||
"""Build a mock aiohttp.ClientSession for batch POST calls.
|
||||
|
||||
Args:
|
||||
batch_response: The list that the mock response's ``json()`` returns.
|
||||
|
||||
Returns:
|
||||
A :class:`MagicMock` with a ``post`` method wired as an async context.
|
||||
"""
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 200
|
||||
mock_resp.json = AsyncMock(return_value=batch_response)
|
||||
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = MagicMock()
|
||||
session.post = MagicMock(return_value=mock_ctx)
|
||||
return session
|
||||
|
||||
|
||||
def _make_async_db() -> MagicMock:
|
||||
"""Build a minimal mock :class:`aiosqlite.Connection`.
|
||||
|
||||
Returns:
|
||||
MagicMock with ``execute``, ``executemany``, and ``commit`` wired as
|
||||
async coroutines.
|
||||
"""
|
||||
db = MagicMock()
|
||||
db.execute = AsyncMock()
|
||||
db.executemany = AsyncMock()
|
||||
db.commit = AsyncMock()
|
||||
return db
|
||||
|
||||
|
||||
class TestLookupBatchSingleCommit:
    """lookup_batch() issues exactly one commit per call, not one per IP."""

    async def test_single_commit_for_multiple_ips(self) -> None:
        """A batch of N IPs produces exactly one db.commit(), not N."""
        addresses = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
        payload = []
        for ip in addresses:
            payload.append(
                {"query": ip, "status": "success", "countryCode": "DE", "country": "Germany", "as": "AS1", "org": "Org"}
            )
        http = _make_batch_session(payload)
        mock_db = _make_async_db()

        await geo_service.lookup_batch(addresses, http, db=mock_db)  # type: ignore[arg-type]

        mock_db.commit.assert_awaited_once()

    async def test_commit_called_even_on_failed_lookups(self) -> None:
        """A batch with all-failed lookups still triggers one commit."""
        addresses = ["10.0.0.1", "10.0.0.2"]
        payload = [{"query": ip, "status": "fail", "message": "private range"} for ip in addresses]
        http = _make_batch_session(payload)
        mock_db = _make_async_db()

        await geo_service.lookup_batch(addresses, http, db=mock_db)  # type: ignore[arg-type]

        mock_db.commit.assert_awaited_once()

    async def test_no_commit_when_db_is_none(self) -> None:
        """When db=None, no commit is attempted."""
        payload = [
            {"query": "1.1.1.1", "status": "success", "countryCode": "US", "country": "United States", "as": "AS15169", "org": "Google LLC"},
        ]
        http = _make_batch_session(payload)

        # Should not raise; without db there is nothing to commit.
        resolved = await geo_service.lookup_batch(["1.1.1.1"], http, db=None)

        assert resolved["1.1.1.1"].country_code == "US"

    async def test_no_commit_for_all_cached_ips(self) -> None:
        """When all IPs are already cached, no HTTP call and no commit occur."""
        geo_service._cache["5.5.5.5"] = GeoInfo(  # type: ignore[attr-defined]
            country_code="FR", country_name="France", asn="AS1", org="ISP"
        )
        mock_db = _make_async_db()
        http = _make_batch_session([])

        resolved = await geo_service.lookup_batch(["5.5.5.5"], http, db=mock_db)  # type: ignore[arg-type]

        assert resolved["5.5.5.5"].country_code == "FR"
        mock_db.commit.assert_not_awaited()
        http.post.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dirty-set tracking and flush_dirty (Task 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDirtySetTracking:
    """_store() marks successfully resolved IPs as dirty."""

    def test_successful_resolution_adds_to_dirty(self) -> None:
        """Storing a GeoInfo with a country_code adds the IP to _dirty."""
        geo = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")

        geo_service._store("1.2.3.4", geo)  # type: ignore[attr-defined]

        assert "1.2.3.4" in geo_service._dirty  # type: ignore[attr-defined]

    def test_null_country_does_not_add_to_dirty(self) -> None:
        """Storing a GeoInfo with country_code=None must not pollute _dirty."""
        geo = GeoInfo(country_code=None, country_name=None, asn=None, org=None)

        geo_service._store("10.0.0.1", geo)  # type: ignore[attr-defined]

        assert "10.0.0.1" not in geo_service._dirty  # type: ignore[attr-defined]

    def test_clear_cache_also_clears_dirty(self) -> None:
        """clear_cache() must discard any pending dirty entries."""
        geo = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
        geo_service._store("8.8.8.8", geo)  # type: ignore[attr-defined]
        assert geo_service._dirty  # type: ignore[attr-defined]

        geo_service.clear_cache()

        assert not geo_service._dirty  # type: ignore[attr-defined]

    async def test_lookup_batch_populates_dirty(self) -> None:
        """After lookup_batch() with db=None, resolved IPs appear in _dirty."""
        addresses = ["1.1.1.1", "2.2.2.2"]
        payload = [
            {"query": ip, "status": "success", "countryCode": "JP", "country": "Japan", "as": "AS7500", "org": "IIJ"}
            for ip in addresses
        ]
        http = _make_batch_session(payload)

        await geo_service.lookup_batch(addresses, http, db=None)

        assert all(ip in geo_service._dirty for ip in addresses)  # type: ignore[attr-defined]
|
||||
|
||||
|
||||
class TestFlushDirty:
    """flush_dirty() persists dirty entries and clears the set."""

    async def test_flush_writes_and_clears_dirty(self) -> None:
        """flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
        geo = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
        geo_service._store("100.0.0.1", geo)  # type: ignore[attr-defined]
        assert "100.0.0.1" in geo_service._dirty  # type: ignore[attr-defined]

        mock_db = _make_async_db()
        flushed = await geo_service.flush_dirty(mock_db)

        assert flushed == 1
        mock_db.executemany.assert_awaited_once()
        mock_db.commit.assert_awaited_once()
        assert "100.0.0.1" not in geo_service._dirty  # type: ignore[attr-defined]

    async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
        """flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
        mock_db = _make_async_db()

        flushed = await geo_service.flush_dirty(mock_db)

        assert flushed == 0
        mock_db.executemany.assert_not_awaited()
        mock_db.commit.assert_not_awaited()

    async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
        """When the DB write fails, entries are re-added to _dirty for retry."""
        geo = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
        geo_service._store("200.0.0.1", geo)  # type: ignore[attr-defined]

        mock_db = _make_async_db()
        mock_db.executemany = AsyncMock(side_effect=OSError("disk full"))

        flushed = await geo_service.flush_dirty(mock_db)

        assert flushed == 0
        assert "200.0.0.1" in geo_service._dirty  # type: ignore[attr-defined]

    async def test_flush_batch_and_lookup_batch_integration(self) -> None:
        """lookup_batch() populates _dirty; flush_dirty() then persists them."""
        addresses = ["10.1.2.3", "10.1.2.4"]
        payload = [
            {"query": ip, "status": "success", "countryCode": "CA", "country": "Canada", "as": "AS812", "org": "Bell"}
            for ip in addresses
        ]
        http = _make_batch_session(payload)

        # Resolve without DB to populate only in-memory cache and _dirty.
        await geo_service.lookup_batch(addresses, http, db=None)
        assert geo_service._dirty == set(addresses)  # type: ignore[attr-defined]

        # Now flush to the DB.
        mock_db = _make_async_db()
        flushed = await geo_service.flush_dirty(mock_db)

        assert flushed == 2
        assert not geo_service._dirty  # type: ignore[attr-defined]
        mock_db.commit.assert_awaited_once()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user