Add async lock protection to geo service cache and mark Task 16 done

This commit is contained in:
2026-04-17 16:51:05 +02:00
parent 04b2e2f700
commit 1e2850a34e
6 changed files with 77 additions and 56 deletions

View File

@@ -310,6 +310,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
**Docs changes needed:** Update `Docs/Refactoring.md`. **Docs changes needed:** Update `Docs/Refactoring.md`.
**Status:** Done ✅
**Why this is needed:** The current safety is implicit and fragile. A future change that adds an `await` inside the critical section (e.g. logging to a remote sink) would silently introduce data loss in the dirty-flush path. An explicit lock documents the intent and makes the safety guarantee unconditional. **Why this is needed:** The current safety is implicit and fragile. A future change that adds an `await` inside the critical section (e.g. logging to a remote sink) would silently introduce data loss in the dirty-flush path. An explicit lock documents the intent and makes the safety guarantee unconditional.
--- ---

View File

@@ -109,25 +109,30 @@ _dirty: set[str] = set()
#: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`. #: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`.
_geoip_reader: geoip2.database.Reader | None = None _geoip_reader: geoip2.database.Reader | None = None
#: Lock protecting mutations to the in-memory geo caches.
_cache_lock: asyncio.Lock = asyncio.Lock()
def clear_cache() -> None:
async def clear_cache() -> None:
"""Flush both the positive and negative lookup caches. """Flush both the positive and negative lookup caches.
Also clears the dirty set so any pending-but-unpersisted entries are Also clears the dirty set so any pending-but-unpersisted entries are
discarded. Useful in tests and when the operator suspects stale data. discarded. Useful in tests and when the operator suspects stale data.
""" """
_cache.clear() async with _cache_lock:
_neg_cache.clear() _cache.clear()
_dirty.clear() _neg_cache.clear()
_dirty.clear()
def clear_neg_cache() -> None: async def clear_neg_cache() -> None:
"""Flush only the negative (failed-lookups) cache. """Flush only the negative (failed-lookups) cache.
Useful when triggering a manual re-resolve so that previously failed Useful when triggering a manual re-resolve so that previously failed
IPs are immediately eligible for a new API attempt. IPs are immediately eligible for a new API attempt.
""" """
_neg_cache.clear() async with _cache_lock:
_neg_cache.clear()
def is_cached(ip: str) -> bool: def is_cached(ip: str) -> bool:
@@ -208,7 +213,7 @@ async def re_resolve_all(
if not unresolved: if not unresolved:
return {"resolved": 0, "total": 0} return {"resolved": 0, "total": 0}
clear_neg_cache() await clear_neg_cache()
geo_map = await lookup_batch(unresolved, http_session, db=db) geo_map = await lookup_batch(unresolved, http_session, db=db)
resolved_count = sum( resolved_count = sum(
1 for info in geo_map.values() if info.country_code is not None 1 for info in geo_map.values() if info.country_code is not None
@@ -292,18 +297,28 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None:
database (not the fail2ban database). database (not the fail2ban database).
""" """
count = 0 count = 0
cache_entries: list[tuple[str, GeoInfo]] = []
for row in await geo_cache_repo.load_all(db): for row in await geo_cache_repo.load_all(db):
country_code: str | None = row["country_code"] country_code: str | None = row["country_code"]
if country_code is None: if country_code is None:
continue continue
ip: str = row["ip"] ip: str = row["ip"]
_cache[ip] = GeoInfo( cache_entries.append(
country_code=country_code, (
country_name=row["country_name"], ip,
asn=row["asn"], GeoInfo(
org=row["org"], country_code=country_code,
country_name=row["country_name"],
asn=row["asn"],
org=row["org"],
),
)
) )
count += 1 count += 1
async with _cache_lock:
for ip, info in cache_entries:
_cache[ip] = info
log.info("geo_cache_loaded_from_db", entries=count) log.info("geo_cache_loaded_from_db", entries=count)
@@ -395,7 +410,7 @@ async def lookup(
if data.get("status") == "success": if data.get("status") == "success":
api_ok = True api_ok = True
result = _parse_single_response(data) result = _parse_single_response(data)
_store(ip, result) await _store(ip, result)
if result.country_code is not None and db is not None: if result.country_code is not None and db is not None:
try: try:
await _persist_entry(db, ip, result) await _persist_entry(db, ip, result)
@@ -421,7 +436,7 @@ async def lookup(
# Try local MaxMind database as fallback. # Try local MaxMind database as fallback.
fallback = _geoip_lookup(ip) fallback = _geoip_lookup(ip)
if fallback is not None: if fallback is not None:
_store(ip, fallback) await _store(ip, fallback)
if fallback.country_code is not None and db is not None: if fallback.country_code is not None and db is not None:
try: try:
await _persist_entry(db, ip, fallback) await _persist_entry(db, ip, fallback)
@@ -432,7 +447,8 @@ async def lookup(
return fallback return fallback
# Both resolvers failed — record in negative cache to avoid hammering. # Both resolvers failed — record in negative cache to avoid hammering.
_neg_cache[ip] = time.monotonic() async with _cache_lock:
_neg_cache[ip] = time.monotonic()
if db is not None: if db is not None:
try: try:
await _persist_neg_entry(db, ip) await _persist_neg_entry(db, ip)
@@ -566,7 +582,7 @@ async def lookup_batch(
for ip, info in chunk_result.items(): for ip, info in chunk_result.items():
if info.country_code is not None: if info.country_code is not None:
# Successful API resolution. # Successful API resolution.
_store(ip, info) await _store(ip, info)
geo_result[ip] = info geo_result[ip] = info
if db is not None: if db is not None:
pos_rows.append( pos_rows.append(
@@ -590,7 +606,8 @@ async def lookup_batch(
) )
else: else:
# Both resolvers failed — record in negative cache. # Both resolvers failed — record in negative cache.
_neg_cache[ip] = time.monotonic() async with _cache_lock:
_neg_cache[ip] = time.monotonic()
geo_result[ip] = _empty geo_result[ip] = _empty
if db is not None: if db is not None:
neg_ips.append(ip) neg_ips.append(ip)
@@ -735,7 +752,7 @@ def _str_or_none(value: object) -> str | None:
return s if s else None return s if s else None
def _store(ip: str, info: GeoInfo) -> None: async def _store(ip: str, info: GeoInfo) -> None:
"""Insert *info* into the module-level cache, flushing if over capacity. """Insert *info* into the module-level cache, flushing if over capacity.
When the IP resolved successfully (``country_code is not None``) it is When the IP resolved successfully (``country_code is not None``) it is
@@ -746,13 +763,14 @@ def _store(ip: str, info: GeoInfo) -> None:
ip: The IP address key. ip: The IP address key.
info: The :class:`GeoInfo` to store. info: The :class:`GeoInfo` to store.
""" """
if len(_cache) >= _MAX_CACHE_SIZE: async with _cache_lock:
_cache.clear() if len(_cache) >= _MAX_CACHE_SIZE:
_dirty.clear() _cache.clear()
log.info("geo_cache_flushed", reason="capacity") _dirty.clear()
_cache[ip] = info log.info("geo_cache_flushed", reason="capacity")
if info.country_code is not None: _cache[ip] = info
_dirty.add(ip) if info.country_code is not None:
_dirty.add(ip)
async def flush_dirty(db: aiosqlite.Connection) -> int: async def flush_dirty(db: aiosqlite.Connection) -> int:
@@ -773,19 +791,20 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
Returns: Returns:
The number of rows successfully upserted. The number of rows successfully upserted.
""" """
if not _dirty: async with _cache_lock:
return 0 if not _dirty:
return 0
# Atomically snapshot and clear in a single-threaded async context. # Atomically snapshot and clear while holding the cache lock.
# No ``await`` between copy and clear ensures no interleaving. to_flush = _dirty.copy()
to_flush = _dirty.copy() _dirty.clear()
_dirty.clear()
rows = [
(ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org)
for ip in to_flush
if ip in _cache
]
rows = [
(ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org)
for ip in to_flush
if ip in _cache
]
if not rows: if not rows:
return 0 return 0

View File

@@ -67,7 +67,7 @@ async def _run_re_resolve_with_resources(settings: "Settings", http_session: "Cl
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips)) log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
# Clear the negative cache so these IPs are eligible for fresh API calls. # Clear the negative cache so these IPs are eligible for fresh API calls.
geo_service.clear_neg_cache() await geo_service.clear_neg_cache()
# lookup_batch handles throttling, retries, and persistence when db is # lookup_batch handles throttling, retries, and persistence when db is
# passed. This is a background task so DB writes are allowed. # passed. This is a background task so DB writes are allowed.

View File

@@ -700,7 +700,7 @@ class TestOriginFilter:
assert len(result.bans) == 205 assert len(result.bans) == 205
assert all(b.country_code == "DE" for b in result.bans) assert all(b.country_code == "DE" for b in result.bans)
geo_service.clear_cache() await geo_service.clear_cache()
async def test_bans_by_country_source_archive_reads_archive( async def test_bans_by_country_source_archive_reads_archive(
self, app_db_with_archive: aiosqlite.Connection self, app_db_with_archive: aiosqlite.Connection
@@ -769,7 +769,7 @@ class TestBansbyCountryBackground:
assert result.total == 3 assert result.total == 3
# Country counts should reflect the cached data. # Country counts should reflect the cached data.
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
geo_service.clear_cache() await geo_service.clear_cache()
async def test_uncached_ips_trigger_background_task( async def test_uncached_ips_trigger_background_task(
self, mixed_origin_db_path: str self, mixed_origin_db_path: str
@@ -778,7 +778,7 @@ class TestBansbyCountryBackground:
resolution and the response returns without blocking.""" resolution and the response returns without blocking."""
from app.services import geo_service from app.services import geo_service
geo_service.clear_cache() # ensure cache is empty await geo_service.clear_cache() # ensure cache is empty
with ( with (
patch( patch(
@@ -810,7 +810,7 @@ class TestBansbyCountryBackground:
"""When http_session is None, no background task is created.""" """When http_session is None, no background task is created."""
from app.services import geo_service from app.services import geo_service
geo_service.clear_cache() await geo_service.clear_cache()
with ( with (
patch( patch(

View File

@@ -130,7 +130,7 @@ async def perf_db_path(tmp_path_factory: Any) -> str:
ips = await _seed_f2b_db(path, _BAN_COUNT) ips = await _seed_f2b_db(path, _BAN_COUNT)
# Pre-populate the in-memory geo cache so no network calls are made. # Pre-populate the in-memory geo cache so no network calls are made.
geo_service.clear_cache() await geo_service.clear_cache()
country_cycle = _COUNTRIES * (_BAN_COUNT // len(_COUNTRIES) + 1) country_cycle = _COUNTRIES * (_BAN_COUNT // len(_COUNTRIES) + 1)
for i, ip in enumerate(ips): for i, ip in enumerate(ips):
cc, cn = country_cycle[i] cc, cn = country_cycle[i]

View File

@@ -45,9 +45,9 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_geo_cache() -> None: async def clear_geo_cache() -> None:
"""Flush the module-level geo cache before every test.""" """Flush the module-level geo cache before every test."""
geo_service.clear_cache() await geo_service.clear_cache()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -162,7 +162,7 @@ class TestLookupCaching:
) )
await geo_service.lookup("2.3.4.5", session) await geo_service.lookup("2.3.4.5", session)
geo_service.clear_cache() await geo_service.clear_cache()
await geo_service.lookup("2.3.4.5", session) await geo_service.lookup("2.3.4.5", session)
assert session.get.call_count == 2 assert session.get.call_count == 2
@@ -259,7 +259,7 @@ class TestNegativeCache:
session = _make_session({"status": "fail", "message": "private range"}) session = _make_session({"status": "fail", "message": "private range"})
await geo_service.lookup("192.0.2.3", session) await geo_service.lookup("192.0.2.3", session)
geo_service.clear_neg_cache() await geo_service.clear_neg_cache()
await geo_service.lookup("192.0.2.3", session) await geo_service.lookup("192.0.2.3", session)
assert session.get.call_count == 2 assert session.get.call_count == 2
@@ -474,27 +474,27 @@ class TestLookupBatchSingleCommit:
class TestDirtySetTracking: class TestDirtySetTracking:
"""_store() marks successfully resolved IPs as dirty.""" """_store() marks successfully resolved IPs as dirty."""
def test_successful_resolution_adds_to_dirty(self) -> None: async def test_successful_resolution_adds_to_dirty(self) -> None:
"""Storing a GeoInfo with a country_code adds the IP to _dirty.""" """Storing a GeoInfo with a country_code adds the IP to _dirty."""
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP") info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
geo_service._store("1.2.3.4", info) await geo_service._store("1.2.3.4", info)
assert "1.2.3.4" in geo_service._dirty assert "1.2.3.4" in geo_service._dirty
def test_null_country_does_not_add_to_dirty(self) -> None: async def test_null_country_does_not_add_to_dirty(self) -> None:
"""Storing a GeoInfo with country_code=None must not pollute _dirty.""" """Storing a GeoInfo with country_code=None must not pollute _dirty."""
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None) info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
geo_service._store("10.0.0.1", info) await geo_service._store("10.0.0.1", info)
assert "10.0.0.1" not in geo_service._dirty assert "10.0.0.1" not in geo_service._dirty
def test_clear_cache_also_clears_dirty(self) -> None: async def test_clear_cache_also_clears_dirty(self) -> None:
"""clear_cache() must discard any pending dirty entries.""" """clear_cache() must discard any pending dirty entries."""
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP") info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
geo_service._store("8.8.8.8", info) await geo_service._store("8.8.8.8", info)
assert geo_service._dirty assert geo_service._dirty
geo_service.clear_cache() await geo_service.clear_cache()
assert not geo_service._dirty assert not geo_service._dirty
@@ -519,7 +519,7 @@ class TestFlushDirty:
async def test_flush_writes_and_clears_dirty(self) -> None: async def test_flush_writes_and_clears_dirty(self) -> None:
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards.""" """flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT") info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
geo_service._store("100.0.0.1", info) await geo_service._store("100.0.0.1", info)
assert "100.0.0.1" in geo_service._dirty assert "100.0.0.1" in geo_service._dirty
db = _make_async_db() db = _make_async_db()
@@ -542,7 +542,7 @@ class TestFlushDirty:
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None: async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
"""When the DB write fails, entries are re-added to _dirty for retry.""" """When the DB write fails, entries are re-added to _dirty for retry."""
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP") info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
geo_service._store("200.0.0.1", info) await geo_service._store("200.0.0.1", info)
db = _make_async_db() db = _make_async_db()
db.executemany = AsyncMock(side_effect=OSError("disk full")) db.executemany = AsyncMock(side_effect=OSError("disk full"))
@@ -881,7 +881,7 @@ class TestReResolveAll:
AsyncMock(return_value=geo_map), AsyncMock(return_value=geo_map),
) as mock_lookup, patch( ) as mock_lookup, patch(
"app.services.geo_service.clear_neg_cache", "app.services.geo_service.clear_neg_cache",
MagicMock(), AsyncMock(),
) as mock_clear: ) as mock_clear:
result = await geo_service.re_resolve_all(db, session) result = await geo_service.re_resolve_all(db, session)