Add async lock protection to geo service cache and mark Task 16 done
This commit is contained in:
@@ -109,25 +109,30 @@ _dirty: set[str] = set()
|
||||
#: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`.
|
||||
_geoip_reader: geoip2.database.Reader | None = None
|
||||
|
||||
#: Lock protecting mutations to the in-memory geo caches.
|
||||
_cache_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
def clear_cache() -> None:
|
||||
|
||||
async def clear_cache() -> None:
|
||||
"""Flush both the positive and negative lookup caches.
|
||||
|
||||
Also clears the dirty set so any pending-but-unpersisted entries are
|
||||
discarded. Useful in tests and when the operator suspects stale data.
|
||||
"""
|
||||
_cache.clear()
|
||||
_neg_cache.clear()
|
||||
_dirty.clear()
|
||||
async with _cache_lock:
|
||||
_cache.clear()
|
||||
_neg_cache.clear()
|
||||
_dirty.clear()
|
||||
|
||||
|
||||
def clear_neg_cache() -> None:
|
||||
async def clear_neg_cache() -> None:
|
||||
"""Flush only the negative (failed-lookups) cache.
|
||||
|
||||
Useful when triggering a manual re-resolve so that previously failed
|
||||
IPs are immediately eligible for a new API attempt.
|
||||
"""
|
||||
_neg_cache.clear()
|
||||
async with _cache_lock:
|
||||
_neg_cache.clear()
|
||||
|
||||
|
||||
def is_cached(ip: str) -> bool:
|
||||
@@ -208,7 +213,7 @@ async def re_resolve_all(
|
||||
if not unresolved:
|
||||
return {"resolved": 0, "total": 0}
|
||||
|
||||
clear_neg_cache()
|
||||
await clear_neg_cache()
|
||||
geo_map = await lookup_batch(unresolved, http_session, db=db)
|
||||
resolved_count = sum(
|
||||
1 for info in geo_map.values() if info.country_code is not None
|
||||
@@ -292,18 +297,28 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None:
|
||||
database (not the fail2ban database).
|
||||
"""
|
||||
count = 0
|
||||
cache_entries: list[tuple[str, GeoInfo]] = []
|
||||
for row in await geo_cache_repo.load_all(db):
|
||||
country_code: str | None = row["country_code"]
|
||||
if country_code is None:
|
||||
continue
|
||||
ip: str = row["ip"]
|
||||
_cache[ip] = GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=row["country_name"],
|
||||
asn=row["asn"],
|
||||
org=row["org"],
|
||||
cache_entries.append(
|
||||
(
|
||||
ip,
|
||||
GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=row["country_name"],
|
||||
asn=row["asn"],
|
||||
org=row["org"],
|
||||
),
|
||||
)
|
||||
)
|
||||
count += 1
|
||||
|
||||
async with _cache_lock:
|
||||
for ip, info in cache_entries:
|
||||
_cache[ip] = info
|
||||
log.info("geo_cache_loaded_from_db", entries=count)
|
||||
|
||||
|
||||
@@ -395,7 +410,7 @@ async def lookup(
|
||||
if data.get("status") == "success":
|
||||
api_ok = True
|
||||
result = _parse_single_response(data)
|
||||
_store(ip, result)
|
||||
await _store(ip, result)
|
||||
if result.country_code is not None and db is not None:
|
||||
try:
|
||||
await _persist_entry(db, ip, result)
|
||||
@@ -421,7 +436,7 @@ async def lookup(
|
||||
# Try local MaxMind database as fallback.
|
||||
fallback = _geoip_lookup(ip)
|
||||
if fallback is not None:
|
||||
_store(ip, fallback)
|
||||
await _store(ip, fallback)
|
||||
if fallback.country_code is not None and db is not None:
|
||||
try:
|
||||
await _persist_entry(db, ip, fallback)
|
||||
@@ -432,7 +447,8 @@ async def lookup(
|
||||
return fallback
|
||||
|
||||
# Both resolvers failed — record in negative cache to avoid hammering.
|
||||
_neg_cache[ip] = time.monotonic()
|
||||
async with _cache_lock:
|
||||
_neg_cache[ip] = time.monotonic()
|
||||
if db is not None:
|
||||
try:
|
||||
await _persist_neg_entry(db, ip)
|
||||
@@ -566,7 +582,7 @@ async def lookup_batch(
|
||||
for ip, info in chunk_result.items():
|
||||
if info.country_code is not None:
|
||||
# Successful API resolution.
|
||||
_store(ip, info)
|
||||
await _store(ip, info)
|
||||
geo_result[ip] = info
|
||||
if db is not None:
|
||||
pos_rows.append(
|
||||
@@ -590,7 +606,8 @@ async def lookup_batch(
|
||||
)
|
||||
else:
|
||||
# Both resolvers failed — record in negative cache.
|
||||
_neg_cache[ip] = time.monotonic()
|
||||
async with _cache_lock:
|
||||
_neg_cache[ip] = time.monotonic()
|
||||
geo_result[ip] = _empty
|
||||
if db is not None:
|
||||
neg_ips.append(ip)
|
||||
@@ -735,7 +752,7 @@ def _str_or_none(value: object) -> str | None:
|
||||
return s if s else None
|
||||
|
||||
|
||||
def _store(ip: str, info: GeoInfo) -> None:
|
||||
async def _store(ip: str, info: GeoInfo) -> None:
|
||||
"""Insert *info* into the module-level cache, flushing if over capacity.
|
||||
|
||||
When the IP resolved successfully (``country_code is not None``) it is
|
||||
@@ -746,13 +763,14 @@ def _store(ip: str, info: GeoInfo) -> None:
|
||||
ip: The IP address key.
|
||||
info: The :class:`GeoInfo` to store.
|
||||
"""
|
||||
if len(_cache) >= _MAX_CACHE_SIZE:
|
||||
_cache.clear()
|
||||
_dirty.clear()
|
||||
log.info("geo_cache_flushed", reason="capacity")
|
||||
_cache[ip] = info
|
||||
if info.country_code is not None:
|
||||
_dirty.add(ip)
|
||||
async with _cache_lock:
|
||||
if len(_cache) >= _MAX_CACHE_SIZE:
|
||||
_cache.clear()
|
||||
_dirty.clear()
|
||||
log.info("geo_cache_flushed", reason="capacity")
|
||||
_cache[ip] = info
|
||||
if info.country_code is not None:
|
||||
_dirty.add(ip)
|
||||
|
||||
|
||||
async def flush_dirty(db: aiosqlite.Connection) -> int:
|
||||
@@ -773,19 +791,20 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
|
||||
Returns:
|
||||
The number of rows successfully upserted.
|
||||
"""
|
||||
if not _dirty:
|
||||
return 0
|
||||
async with _cache_lock:
|
||||
if not _dirty:
|
||||
return 0
|
||||
|
||||
# Atomically snapshot and clear in a single-threaded async context.
|
||||
# No ``await`` between copy and clear ensures no interleaving.
|
||||
to_flush = _dirty.copy()
|
||||
_dirty.clear()
|
||||
# Atomically snapshot and clear while holding the cache lock.
|
||||
to_flush = _dirty.copy()
|
||||
_dirty.clear()
|
||||
|
||||
rows = [
|
||||
(ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org)
|
||||
for ip in to_flush
|
||||
if ip in _cache
|
||||
]
|
||||
|
||||
rows = [
|
||||
(ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org)
|
||||
for ip in to_flush
|
||||
if ip in _cache
|
||||
]
|
||||
if not rows:
|
||||
return 0
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ async def _run_re_resolve_with_resources(settings: "Settings", http_session: "Cl
|
||||
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
|
||||
|
||||
# Clear the negative cache so these IPs are eligible for fresh API calls.
|
||||
geo_service.clear_neg_cache()
|
||||
await geo_service.clear_neg_cache()
|
||||
|
||||
# lookup_batch handles throttling, retries, and persistence when db is
|
||||
# passed. This is a background task so DB writes are allowed.
|
||||
|
||||
Reference in New Issue
Block a user