Refactor geo cache persistence into repository + remove raw SQL from tasks/main, update task list

This commit is contained in:
2026-03-17 09:18:05 +01:00
parent 7866f9cbb2
commit 68114924bb
6 changed files with 230 additions and 78 deletions

View File

@@ -161,11 +161,7 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
await geo_service.load_cache_from_db(db)
# Log unresolved geo entries so the operator can see the scope of the issue.
async with db.execute(
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
) as cur:
row = await cur.fetchone()
unresolved_count: int = int(row[0]) if row else 0
unresolved_count = await geo_service.count_unresolved(db)
if unresolved_count > 0:
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)

View File

@@ -9,12 +9,48 @@ connection lifetimes.
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
import aiosqlite
class GeoCacheRow(TypedDict):
    """A single row from the ``geo_cache`` table.

    All geo fields are nullable: a row whose ``country_code`` is ``None``
    represents an IP whose lookup has not (yet) succeeded — the upserts in
    this module only ever write the bare ``ip`` for failed lookups.
    """

    # The cached IP address; carries a unique constraint (the upsert SQL in
    # this module uses ON CONFLICT(ip)).
    ip: str
    country_code: str | None
    country_name: str | None
    asn: str | None
    org: str | None
async def load_all(db: aiosqlite.Connection) -> list[GeoCacheRow]:
    """Load all geo cache rows from the database.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        Every row currently stored in the ``geo_cache`` table.
    """
    query = "SELECT ip, country_code, country_name, asn, org FROM geo_cache"
    async with db.execute(query) as cursor:
        # Materialise the whole table in one pass; column order matches the
        # SELECT list above.
        return [
            GeoCacheRow(
                ip=str(record[0]),
                country_code=record[1],
                country_name=record[2],
                asn=record[3],
                org=record[4],
            )
            async for record in cursor
        ]
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
"""Return all IPs in ``geo_cache`` where ``country_code`` is NULL.
@@ -31,3 +67,80 @@ async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
async for row in cur:
ips.append(str(row[0]))
return ips
async def count_unresolved(db: aiosqlite.Connection) -> int:
    """Return the number of unresolved rows (country_code IS NULL).

    Args:
        db: Open BanGUI application database connection.

    Returns:
        Count of ``geo_cache`` rows that still lack a country code.
    """
    query = "SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
    async with db.execute(query) as cursor:
        first = await cursor.fetchone()
    # COUNT(*) always yields one row; the guard only covers a broken driver.
    if not first:
        return 0
    return int(first[0])
async def upsert_entry(
    db: aiosqlite.Connection,
    ip: str,
    country_code: str | None,
    country_name: str | None,
    asn: str | None,
    org: str | None,
) -> None:
    """Insert or update a resolved geo cache entry.

    On conflict the existing row is overwritten field-by-field and its
    ``cached_at`` timestamp is refreshed (SQLite's ``'now'`` is UTC).

    Args:
        db: Open BanGUI application database connection.
        ip: IP address that was resolved.
        country_code: ISO country code, or ``None``.
        country_name: Human-readable country name, or ``None``.
        asn: Autonomous system number, or ``None``.
        org: Organisation name, or ``None``.
    """
    statement = """
        INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT(ip) DO UPDATE SET
            country_code = excluded.country_code,
            country_name = excluded.country_name,
            asn = excluded.asn,
            org = excluded.org,
            cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
    """
    values = (ip, country_code, country_name, asn, org)
    await db.execute(statement, values)
async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
    """Record a failed lookup attempt as a negative entry.

    Only the IP is written; all geo columns stay NULL, which is how the rest
    of this module identifies unresolved entries. An existing row is left
    untouched (INSERT OR IGNORE).

    Args:
        db: Open BanGUI application database connection.
        ip: IP address string whose resolution failed.
    """
    statement = "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)"
    await db.execute(statement, (ip,))
async def bulk_upsert_entries(
    db: aiosqlite.Connection,
    rows: list[tuple[str, str | None, str | None, str | None, str | None]],
) -> int:
    """Bulk insert or update multiple geo cache entries.

    Args:
        db: Open BanGUI application database connection.
        rows: ``(ip, country_code, country_name, asn, org)`` tuples.

    Returns:
        Number of rows handed to the database (0 for empty input).
    """
    # Guard clause: executemany with an empty sequence is pointless work.
    if not rows:
        return 0
    statement = """
        INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT(ip) DO UPDATE SET
            country_code = excluded.country_code,
            country_name = excluded.country_name,
            asn = excluded.asn,
            org = excluded.org,
            cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
    """
    await db.executemany(statement, rows)
    return len(rows)
async def bulk_upsert_neg_entries(db: aiosqlite.Connection, ips: list[str]) -> int:
    """Bulk insert negative lookup entries.

    Args:
        db: Open BanGUI application database connection.
        ips: IP addresses whose resolution failed.

    Returns:
        Number of IPs handed to the database (0 for empty input).
    """
    if not ips:
        return 0
    parameter_rows = [(address,) for address in ips]
    await db.executemany(
        "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
        parameter_rows,
    )
    return len(ips)

View File

@@ -186,11 +186,7 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``,
and ``dirty_size``.
"""
async with db.execute(
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
) as cur:
row = await cur.fetchone()
unresolved: int = int(row[0]) if row else 0
unresolved = await geo_cache_repo.count_unresolved(db)
return {
"cache_size": len(_cache),
@@ -200,6 +196,12 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
}
async def count_unresolved(db: aiosqlite.Connection) -> int:
    """Return the number of unresolved entries in the persistent geo cache.

    Thin delegation to the geo cache repository so service callers never
    touch SQL directly.
    """
    unresolved_total = await geo_cache_repo.count_unresolved(db)
    return unresolved_total
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
"""Return geo cache IPs where the country code has not yet been resolved.
@@ -282,21 +284,18 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None:
database (not the fail2ban database).
"""
count = 0
async with db.execute(
"SELECT ip, country_code, country_name, asn, org FROM geo_cache"
) as cur:
async for row in cur:
ip: str = str(row[0])
country_code: str | None = row[1]
if country_code is None:
continue
_cache[ip] = GeoInfo(
country_code=country_code,
country_name=row[2],
asn=row[3],
org=row[4],
)
count += 1
for row in await geo_cache_repo.load_all(db):
country_code: str | None = row["country_code"]
if country_code is None:
continue
ip: str = row["ip"]
_cache[ip] = GeoInfo(
country_code=country_code,
country_name=row["country_name"],
asn=row["asn"],
org=row["org"],
)
count += 1
log.info("geo_cache_loaded_from_db", entries=count)
@@ -315,18 +314,13 @@ async def _persist_entry(
ip: IP address string.
info: Resolved geo data to persist.
"""
await db.execute(
"""
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
country_code = excluded.country_code,
country_name = excluded.country_name,
asn = excluded.asn,
org = excluded.org,
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
""",
(ip, info.country_code, info.country_name, info.asn, info.org),
await geo_cache_repo.upsert_entry(
db=db,
ip=ip,
country_code=info.country_code,
country_name=info.country_name,
asn=info.asn,
org=info.org,
)
@@ -340,10 +334,7 @@ async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
db: BanGUI application database connection.
ip: IP address string whose resolution failed.
"""
await db.execute(
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
(ip,),
)
await geo_cache_repo.upsert_neg_entry(db=db, ip=ip)
# ---------------------------------------------------------------------------
@@ -599,19 +590,7 @@ async def lookup_batch(
if db is not None:
if pos_rows:
try:
await db.executemany(
"""
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
country_code = excluded.country_code,
country_name = excluded.country_name,
asn = excluded.asn,
org = excluded.org,
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
""",
pos_rows,
)
await geo_cache_repo.bulk_upsert_entries(db, pos_rows)
except Exception as exc: # noqa: BLE001
log.warning(
"geo_batch_persist_failed",
@@ -620,10 +599,7 @@ async def lookup_batch(
)
if neg_ips:
try:
await db.executemany(
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
[(ip,) for ip in neg_ips],
)
await geo_cache_repo.bulk_upsert_neg_entries(db, neg_ips)
except Exception as exc: # noqa: BLE001
log.warning(
"geo_batch_persist_neg_failed",
@@ -806,19 +782,7 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
return 0
try:
await db.executemany(
"""
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(ip) DO UPDATE SET
country_code = excluded.country_code,
country_name = excluded.country_name,
asn = excluded.asn,
org = excluded.org,
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
""",
rows,
)
await geo_cache_repo.bulk_upsert_entries(db, rows)
await db.commit()
except Exception as exc: # noqa: BLE001
log.warning("geo_flush_dirty_failed", error=str(exc))

View File

@@ -49,12 +49,7 @@ async def _run_re_resolve(app: Any) -> None:
http_session = app.state.http_session
# Fetch all IPs with NULL country_code from the persistent cache.
unresolved_ips: list[str] = []
async with db.execute(
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
) as cursor:
async for row in cursor:
unresolved_ips.append(str(row[0]))
unresolved_ips = await geo_service.get_unresolved_ips(db)
if not unresolved_ips:
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")