Move geo cache commit handling into repository layer

This commit is contained in:
2026-04-18 20:10:05 +02:00
parent be1d66988f
commit c1f188643c
5 changed files with 168 additions and 34 deletions

View File

@@ -104,6 +104,19 @@ async def upsert_entry(
)
async def upsert_entry_and_commit(
db: aiosqlite.Connection,
ip: str,
country_code: str | None,
country_name: str | None,
asn: str | None,
org: str | None,
) -> None:
"""Insert or update a resolved geo cache entry and commit."""
await upsert_entry(db, ip, country_code, country_name, asn, org)
await db.commit()
async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
"""Record a failed lookup attempt as a negative entry."""
await db.execute(
@@ -112,6 +125,12 @@ async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
)
async def upsert_neg_entry_and_commit(db: aiosqlite.Connection, ip: str) -> None:
"""Record a failed lookup attempt and commit the transaction."""
await upsert_neg_entry(db, ip)
await db.commit()
async def bulk_upsert_entries(
db: aiosqlite.Connection,
rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
@@ -146,3 +165,40 @@ async def bulk_upsert_neg_entries(db: aiosqlite.Connection, ips: list[str]) -> i
[(ip,) for ip in ips],
)
return len(ips)
async def bulk_upsert_entries_and_commit(
db: aiosqlite.Connection,
rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
) -> int:
"""Bulk insert or update multiple geo cache entries and commit."""
count = await bulk_upsert_entries(db, rows)
await db.commit()
return count
async def bulk_upsert_neg_entries_and_commit(db: aiosqlite.Connection, ips: list[str]) -> int:
"""Bulk insert negative lookup entries and commit."""
count = await bulk_upsert_neg_entries(db, ips)
await db.commit()
return count
async def bulk_upsert_entries_and_neg_entries_and_commit(
db: aiosqlite.Connection,
rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
ips: list[str],
) -> tuple[int, int]:
"""Persist positive and negative geo cache rows together, then commit."""
positive_count = 0
negative_count = 0
if rows:
positive_count = await bulk_upsert_entries(db, rows)
if ips:
negative_count = await bulk_upsert_neg_entries(db, ips)
if rows or ips:
await db.commit()
return positive_count, negative_count

View File

@@ -156,9 +156,23 @@ class GeoCacheRepository(Protocol):
) -> None:
...
async def upsert_entry_and_commit(
self,
db: aiosqlite.Connection,
ip: str,
country_code: str | None,
country_name: str | None,
asn: str | None,
org: str | None,
) -> None:
...
async def upsert_neg_entry(self, db: aiosqlite.Connection, ip: str) -> None:
...
async def upsert_neg_entry_and_commit(self, db: aiosqlite.Connection, ip: str) -> None:
...
async def bulk_upsert_entries(
self,
db: aiosqlite.Connection,
@@ -166,9 +180,27 @@ class GeoCacheRepository(Protocol):
) -> int:
...
async def bulk_upsert_entries_and_commit(
self,
db: aiosqlite.Connection,
rows: Iterable[tuple[str, str | None, str | None, str | None, str | None]],
) -> int:
...
async def bulk_upsert_neg_entries(self, db: aiosqlite.Connection, ips: list[str]) -> int:
...
async def bulk_upsert_neg_entries_and_commit(self, db: aiosqlite.Connection, ips: list[str]) -> int:
...
async def bulk_upsert_entries_and_neg_entries_and_commit(
self,
db: aiosqlite.Connection,
rows: Iterable[tuple[str, str | None, str | None, str | None, str | None]],
ips: list[str],
) -> tuple[int, int]:
...
class HistoryArchiveRepository(Protocol):
"""Protocol for archived ban history persistence operations."""

View File

@@ -425,8 +425,14 @@ async def lookup(
await _store(ip, result)
if result.country_code is not None and db is not None:
try:
await _persist_entry(db, ip, result)
await db.commit()
await geo_cache_repo.upsert_entry_and_commit(
db=db,
ip=ip,
country_code=result.country_code,
country_name=result.country_name,
asn=result.asn,
org=result.org,
)
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_failed", ip=ip, error=str(exc))
log.debug("geo_lookup_success", ip=ip, country=result.country_code, asn=result.asn)
@@ -451,8 +457,14 @@ async def lookup(
await _store(ip, fallback)
if fallback.country_code is not None and db is not None:
try:
await _persist_entry(db, ip, fallback)
await db.commit()
await geo_cache_repo.upsert_entry_and_commit(
db=db,
ip=ip,
country_code=fallback.country_code,
country_name=fallback.country_name,
asn=fallback.asn,
org=fallback.org,
)
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_failed", ip=ip, error=str(exc))
log.debug("geo_geoip_fallback_success", ip=ip, country=fallback.country_code)
@@ -463,8 +475,7 @@ async def lookup(
_neg_cache[ip] = time.monotonic()
if db is not None:
try:
await _persist_neg_entry(db, ip)
await db.commit()
await geo_cache_repo.upsert_neg_entry_and_commit(db=db, ip=ip)
except Exception as exc: # noqa: BLE001
log.warning("geo_persist_neg_failed", ip=ip, error=str(exc))
@@ -604,7 +615,7 @@ async def lookup_batch(
# API failed — try local GeoIP fallback.
fallback = _geoip_lookup(ip)
if fallback is not None:
_store(ip, fallback)
await _store(ip, fallback)
geo_result[ip] = fallback
if db is not None:
pos_rows.append(
@@ -624,31 +635,20 @@ async def lookup_batch(
if db is not None:
neg_ips.append(ip)
if db is not None:
if pos_rows:
try:
await geo_cache_repo.bulk_upsert_entries(db, pos_rows)
except Exception as exc: # noqa: BLE001
log.warning(
"geo_batch_persist_failed",
count=len(pos_rows),
error=str(exc),
)
if neg_ips:
try:
await geo_cache_repo.bulk_upsert_neg_entries(db, neg_ips)
except Exception as exc: # noqa: BLE001
log.warning(
"geo_batch_persist_neg_failed",
count=len(neg_ips),
error=str(exc),
)
if db is not None:
try:
await db.commit()
except Exception as exc: # noqa: BLE001
log.warning("geo_batch_commit_failed", error=str(exc))
if db is not None and (pos_rows or neg_ips):
try:
await geo_cache_repo.bulk_upsert_entries_and_neg_entries_and_commit(
db,
pos_rows,
neg_ips,
)
except Exception as exc: # noqa: BLE001
log.warning(
"geo_batch_persist_failed",
positive_count=len(pos_rows),
negative_count=len(neg_ips),
error=str(exc),
)
log.info(
"geo_batch_lookup_complete",
@@ -821,8 +821,7 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
return 0
try:
await geo_cache_repo.bulk_upsert_entries(db, rows)
await db.commit()
await geo_cache_repo.bulk_upsert_entries_and_commit(db, rows)
except Exception as exc: # noqa: BLE001
log.warning("geo_flush_dirty_failed", error=str(exc))
# Re-add to dirty so they are retried on the next flush cycle.

View File

@@ -116,6 +116,51 @@ async def test_upsert_entry_and_neg_entry(tmp_path: Path) -> None:
assert row[0] is None
@pytest.mark.asyncio
async def test_upsert_entry_and_commit_commits_transaction(tmp_path: Path) -> None:
db_path = str(tmp_path / "geo_cache.db")
async with aiosqlite.connect(db_path) as db:
await _create_geo_cache_table(db)
await geo_cache_repo.upsert_entry_and_commit(
db,
"13.13.13.13",
"NL",
"Netherlands",
"AS1313",
"TestOrg",
)
async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("13.13.13.13",)) as cur:
row = await cur.fetchone()
assert row is not None
assert row[0] == "NL"
@pytest.mark.asyncio
async def test_bulk_upsert_entries_and_neg_entries_and_commit_commits_once(tmp_path: Path) -> None:
db_path = str(tmp_path / "geo_cache.db")
async with aiosqlite.connect(db_path) as db:
await _create_geo_cache_table(db)
rows = [
("14.14.14.14", "BE", "Belgium", "AS1414", "Test"),
]
count, neg_count = await geo_cache_repo.bulk_upsert_entries_and_neg_entries_and_commit(
db,
rows,
["15.15.15.15"],
)
assert count == 1
assert neg_count == 1
async with db.execute("SELECT COUNT(*) FROM geo_cache") as cur:
row = await cur.fetchone()
assert row is not None
assert int(row[0]) == 2
@pytest.mark.asyncio
async def test_bulk_upsert_entries_and_neg_entries(tmp_path: Path) -> None:
db_path = str(tmp_path / "geo_cache.db")