Add async lock protection to geo service cache and mark Task 16 done

This commit is contained in:
2026-04-17 16:51:05 +02:00
parent 04b2e2f700
commit 1e2850a34e
6 changed files with 77 additions and 56 deletions

View File

@@ -109,25 +109,30 @@ _dirty: set[str] = set()
#: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`.
_geoip_reader: geoip2.database.Reader | None = None
#: Lock protecting mutations to the in-memory geo caches.
_cache_lock: asyncio.Lock = asyncio.Lock()
def clear_cache() -> None:
async def clear_cache() -> None:
"""Flush both the positive and negative lookup caches.
Also clears the dirty set so any pending-but-unpersisted entries are
discarded. Useful in tests and when the operator suspects stale data.
"""
_cache.clear()
_neg_cache.clear()
_dirty.clear()
async with _cache_lock:
_cache.clear()
_neg_cache.clear()
_dirty.clear()
def clear_neg_cache() -> None:
async def clear_neg_cache() -> None:
"""Flush only the negative (failed-lookups) cache.
Useful when triggering a manual re-resolve so that previously failed
IPs are immediately eligible for a new API attempt.
"""
_neg_cache.clear()
async with _cache_lock:
_neg_cache.clear()
def is_cached(ip: str) -> bool:
@@ -208,7 +213,7 @@ async def re_resolve_all(
if not unresolved:
return {"resolved": 0, "total": 0}
clear_neg_cache()
await clear_neg_cache()
geo_map = await lookup_batch(unresolved, http_session, db=db)
resolved_count = sum(
1 for info in geo_map.values() if info.country_code is not None
@@ -292,18 +297,28 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None:
database (not the fail2ban database).
"""
count = 0
cache_entries: list[tuple[str, GeoInfo]] = []
for row in await geo_cache_repo.load_all(db):
country_code: str | None = row["country_code"]
if country_code is None:
continue
ip: str = row["ip"]
_cache[ip] = GeoInfo(
country_code=country_code,
country_name=row["country_name"],
asn=row["asn"],
org=row["org"],
cache_entries.append(
(
ip,
GeoInfo(
country_code=country_code,
country_name=row["country_name"],
asn=row["asn"],
org=row["org"],
),
)
)
count += 1
async with _cache_lock:
for ip, info in cache_entries:
_cache[ip] = info
log.info("geo_cache_loaded_from_db", entries=count)
@@ -395,7 +410,7 @@ async def lookup(
if data.get("status") == "success":
api_ok = True
result = _parse_single_response(data)
_store(ip, result)
await _store(ip, result)
if result.country_code is not None and db is not None:
try:
await _persist_entry(db, ip, result)
@@ -421,7 +436,7 @@ async def lookup(
# Try local MaxMind database as fallback.
fallback = _geoip_lookup(ip)
if fallback is not None:
_store(ip, fallback)
await _store(ip, fallback)
if fallback.country_code is not None and db is not None:
try:
await _persist_entry(db, ip, fallback)
@@ -432,7 +447,8 @@ async def lookup(
return fallback
# Both resolvers failed — record in negative cache to avoid hammering.
_neg_cache[ip] = time.monotonic()
async with _cache_lock:
_neg_cache[ip] = time.monotonic()
if db is not None:
try:
await _persist_neg_entry(db, ip)
@@ -566,7 +582,7 @@ async def lookup_batch(
for ip, info in chunk_result.items():
if info.country_code is not None:
# Successful API resolution.
_store(ip, info)
await _store(ip, info)
geo_result[ip] = info
if db is not None:
pos_rows.append(
@@ -590,7 +606,8 @@ async def lookup_batch(
)
else:
# Both resolvers failed — record in negative cache.
_neg_cache[ip] = time.monotonic()
async with _cache_lock:
_neg_cache[ip] = time.monotonic()
geo_result[ip] = _empty
if db is not None:
neg_ips.append(ip)
@@ -735,7 +752,7 @@ def _str_or_none(value: object) -> str | None:
return s if s else None
def _store(ip: str, info: GeoInfo) -> None:
async def _store(ip: str, info: GeoInfo) -> None:
"""Insert *info* into the module-level cache, flushing if over capacity.
When the IP resolved successfully (``country_code is not None``) it is
@@ -746,13 +763,14 @@ def _store(ip: str, info: GeoInfo) -> None:
ip: The IP address key.
info: The :class:`GeoInfo` to store.
"""
if len(_cache) >= _MAX_CACHE_SIZE:
_cache.clear()
_dirty.clear()
log.info("geo_cache_flushed", reason="capacity")
_cache[ip] = info
if info.country_code is not None:
_dirty.add(ip)
async with _cache_lock:
if len(_cache) >= _MAX_CACHE_SIZE:
_cache.clear()
_dirty.clear()
log.info("geo_cache_flushed", reason="capacity")
_cache[ip] = info
if info.country_code is not None:
_dirty.add(ip)
async def flush_dirty(db: aiosqlite.Connection) -> int:
@@ -773,19 +791,20 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
Returns:
The number of rows successfully upserted.
"""
if not _dirty:
return 0
async with _cache_lock:
if not _dirty:
return 0
# Atomically snapshot and clear in a single-threaded async context.
# No ``await`` between copy and clear ensures no interleaving.
to_flush = _dirty.copy()
_dirty.clear()
# Atomically snapshot and clear while holding the cache lock.
to_flush = _dirty.copy()
_dirty.clear()
rows = [
(ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org)
for ip in to_flush
if ip in _cache
]
rows = [
(ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org)
for ip in to_flush
if ip in _cache
]
if not rows:
return 0

View File

@@ -67,7 +67,7 @@ async def _run_re_resolve_with_resources(settings: "Settings", http_session: "Cl
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
# Clear the negative cache so these IPs are eligible for fresh API calls.
geo_service.clear_neg_cache()
await geo_service.clear_neg_cache()
# lookup_batch handles throttling, retries, and persistence when db is
# passed. This is a background task so DB writes are allowed.