Add async lock protection to geo service cache and mark Task 16 done
This commit is contained in:
@@ -310,6 +310,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
|||||||
|
|
||||||
**Docs changes needed:** Update `Docs/Refactoring.md`.
|
**Docs changes needed:** Update `Docs/Refactoring.md`.
|
||||||
|
|
||||||
|
**Status:** Done ✅
|
||||||
|
|
||||||
**Why this is needed:** The current safety is implicit and fragile. A future change that adds an `await` inside the critical section (e.g. logging to a remote sink) would silently introduce data loss in the dirty-flush path. An explicit lock documents the intent and makes the safety guarantee unconditional.
|
**Why this is needed:** The current safety is implicit and fragile. A future change that adds an `await` inside the critical section (e.g. logging to a remote sink) would silently introduce data loss in the dirty-flush path. An explicit lock documents the intent and makes the safety guarantee unconditional.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|||||||
@@ -109,25 +109,30 @@ _dirty: set[str] = set()
|
|||||||
#: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`.
|
#: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`.
|
||||||
_geoip_reader: geoip2.database.Reader | None = None
|
_geoip_reader: geoip2.database.Reader | None = None
|
||||||
|
|
||||||
|
#: Lock protecting mutations to the in-memory geo caches.
|
||||||
|
_cache_lock: asyncio.Lock = asyncio.Lock()
|
||||||
|
|
||||||
def clear_cache() -> None:
|
|
||||||
|
async def clear_cache() -> None:
|
||||||
"""Flush both the positive and negative lookup caches.
|
"""Flush both the positive and negative lookup caches.
|
||||||
|
|
||||||
Also clears the dirty set so any pending-but-unpersisted entries are
|
Also clears the dirty set so any pending-but-unpersisted entries are
|
||||||
discarded. Useful in tests and when the operator suspects stale data.
|
discarded. Useful in tests and when the operator suspects stale data.
|
||||||
"""
|
"""
|
||||||
_cache.clear()
|
async with _cache_lock:
|
||||||
_neg_cache.clear()
|
_cache.clear()
|
||||||
_dirty.clear()
|
_neg_cache.clear()
|
||||||
|
_dirty.clear()
|
||||||
|
|
||||||
|
|
||||||
def clear_neg_cache() -> None:
|
async def clear_neg_cache() -> None:
|
||||||
"""Flush only the negative (failed-lookups) cache.
|
"""Flush only the negative (failed-lookups) cache.
|
||||||
|
|
||||||
Useful when triggering a manual re-resolve so that previously failed
|
Useful when triggering a manual re-resolve so that previously failed
|
||||||
IPs are immediately eligible for a new API attempt.
|
IPs are immediately eligible for a new API attempt.
|
||||||
"""
|
"""
|
||||||
_neg_cache.clear()
|
async with _cache_lock:
|
||||||
|
_neg_cache.clear()
|
||||||
|
|
||||||
|
|
||||||
def is_cached(ip: str) -> bool:
|
def is_cached(ip: str) -> bool:
|
||||||
@@ -208,7 +213,7 @@ async def re_resolve_all(
|
|||||||
if not unresolved:
|
if not unresolved:
|
||||||
return {"resolved": 0, "total": 0}
|
return {"resolved": 0, "total": 0}
|
||||||
|
|
||||||
clear_neg_cache()
|
await clear_neg_cache()
|
||||||
geo_map = await lookup_batch(unresolved, http_session, db=db)
|
geo_map = await lookup_batch(unresolved, http_session, db=db)
|
||||||
resolved_count = sum(
|
resolved_count = sum(
|
||||||
1 for info in geo_map.values() if info.country_code is not None
|
1 for info in geo_map.values() if info.country_code is not None
|
||||||
@@ -292,18 +297,28 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None:
|
|||||||
database (not the fail2ban database).
|
database (not the fail2ban database).
|
||||||
"""
|
"""
|
||||||
count = 0
|
count = 0
|
||||||
|
cache_entries: list[tuple[str, GeoInfo]] = []
|
||||||
for row in await geo_cache_repo.load_all(db):
|
for row in await geo_cache_repo.load_all(db):
|
||||||
country_code: str | None = row["country_code"]
|
country_code: str | None = row["country_code"]
|
||||||
if country_code is None:
|
if country_code is None:
|
||||||
continue
|
continue
|
||||||
ip: str = row["ip"]
|
ip: str = row["ip"]
|
||||||
_cache[ip] = GeoInfo(
|
cache_entries.append(
|
||||||
country_code=country_code,
|
(
|
||||||
country_name=row["country_name"],
|
ip,
|
||||||
asn=row["asn"],
|
GeoInfo(
|
||||||
org=row["org"],
|
country_code=country_code,
|
||||||
|
country_name=row["country_name"],
|
||||||
|
asn=row["asn"],
|
||||||
|
org=row["org"],
|
||||||
|
),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
count += 1
|
count += 1
|
||||||
|
|
||||||
|
async with _cache_lock:
|
||||||
|
for ip, info in cache_entries:
|
||||||
|
_cache[ip] = info
|
||||||
log.info("geo_cache_loaded_from_db", entries=count)
|
log.info("geo_cache_loaded_from_db", entries=count)
|
||||||
|
|
||||||
|
|
||||||
@@ -395,7 +410,7 @@ async def lookup(
|
|||||||
if data.get("status") == "success":
|
if data.get("status") == "success":
|
||||||
api_ok = True
|
api_ok = True
|
||||||
result = _parse_single_response(data)
|
result = _parse_single_response(data)
|
||||||
_store(ip, result)
|
await _store(ip, result)
|
||||||
if result.country_code is not None and db is not None:
|
if result.country_code is not None and db is not None:
|
||||||
try:
|
try:
|
||||||
await _persist_entry(db, ip, result)
|
await _persist_entry(db, ip, result)
|
||||||
@@ -421,7 +436,7 @@ async def lookup(
|
|||||||
# Try local MaxMind database as fallback.
|
# Try local MaxMind database as fallback.
|
||||||
fallback = _geoip_lookup(ip)
|
fallback = _geoip_lookup(ip)
|
||||||
if fallback is not None:
|
if fallback is not None:
|
||||||
_store(ip, fallback)
|
await _store(ip, fallback)
|
||||||
if fallback.country_code is not None and db is not None:
|
if fallback.country_code is not None and db is not None:
|
||||||
try:
|
try:
|
||||||
await _persist_entry(db, ip, fallback)
|
await _persist_entry(db, ip, fallback)
|
||||||
@@ -432,7 +447,8 @@ async def lookup(
|
|||||||
return fallback
|
return fallback
|
||||||
|
|
||||||
# Both resolvers failed — record in negative cache to avoid hammering.
|
# Both resolvers failed — record in negative cache to avoid hammering.
|
||||||
_neg_cache[ip] = time.monotonic()
|
async with _cache_lock:
|
||||||
|
_neg_cache[ip] = time.monotonic()
|
||||||
if db is not None:
|
if db is not None:
|
||||||
try:
|
try:
|
||||||
await _persist_neg_entry(db, ip)
|
await _persist_neg_entry(db, ip)
|
||||||
@@ -566,7 +582,7 @@ async def lookup_batch(
|
|||||||
for ip, info in chunk_result.items():
|
for ip, info in chunk_result.items():
|
||||||
if info.country_code is not None:
|
if info.country_code is not None:
|
||||||
# Successful API resolution.
|
# Successful API resolution.
|
||||||
_store(ip, info)
|
await _store(ip, info)
|
||||||
geo_result[ip] = info
|
geo_result[ip] = info
|
||||||
if db is not None:
|
if db is not None:
|
||||||
pos_rows.append(
|
pos_rows.append(
|
||||||
@@ -590,7 +606,8 @@ async def lookup_batch(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Both resolvers failed — record in negative cache.
|
# Both resolvers failed — record in negative cache.
|
||||||
_neg_cache[ip] = time.monotonic()
|
async with _cache_lock:
|
||||||
|
_neg_cache[ip] = time.monotonic()
|
||||||
geo_result[ip] = _empty
|
geo_result[ip] = _empty
|
||||||
if db is not None:
|
if db is not None:
|
||||||
neg_ips.append(ip)
|
neg_ips.append(ip)
|
||||||
@@ -735,7 +752,7 @@ def _str_or_none(value: object) -> str | None:
|
|||||||
return s if s else None
|
return s if s else None
|
||||||
|
|
||||||
|
|
||||||
def _store(ip: str, info: GeoInfo) -> None:
|
async def _store(ip: str, info: GeoInfo) -> None:
|
||||||
"""Insert *info* into the module-level cache, flushing if over capacity.
|
"""Insert *info* into the module-level cache, flushing if over capacity.
|
||||||
|
|
||||||
When the IP resolved successfully (``country_code is not None``) it is
|
When the IP resolved successfully (``country_code is not None``) it is
|
||||||
@@ -746,13 +763,14 @@ def _store(ip: str, info: GeoInfo) -> None:
|
|||||||
ip: The IP address key.
|
ip: The IP address key.
|
||||||
info: The :class:`GeoInfo` to store.
|
info: The :class:`GeoInfo` to store.
|
||||||
"""
|
"""
|
||||||
if len(_cache) >= _MAX_CACHE_SIZE:
|
async with _cache_lock:
|
||||||
_cache.clear()
|
if len(_cache) >= _MAX_CACHE_SIZE:
|
||||||
_dirty.clear()
|
_cache.clear()
|
||||||
log.info("geo_cache_flushed", reason="capacity")
|
_dirty.clear()
|
||||||
_cache[ip] = info
|
log.info("geo_cache_flushed", reason="capacity")
|
||||||
if info.country_code is not None:
|
_cache[ip] = info
|
||||||
_dirty.add(ip)
|
if info.country_code is not None:
|
||||||
|
_dirty.add(ip)
|
||||||
|
|
||||||
|
|
||||||
async def flush_dirty(db: aiosqlite.Connection) -> int:
|
async def flush_dirty(db: aiosqlite.Connection) -> int:
|
||||||
@@ -773,19 +791,20 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
|
|||||||
Returns:
|
Returns:
|
||||||
The number of rows successfully upserted.
|
The number of rows successfully upserted.
|
||||||
"""
|
"""
|
||||||
if not _dirty:
|
async with _cache_lock:
|
||||||
return 0
|
if not _dirty:
|
||||||
|
return 0
|
||||||
|
|
||||||
# Atomically snapshot and clear in a single-threaded async context.
|
# Atomically snapshot and clear while holding the cache lock.
|
||||||
# No ``await`` between copy and clear ensures no interleaving.
|
to_flush = _dirty.copy()
|
||||||
to_flush = _dirty.copy()
|
_dirty.clear()
|
||||||
_dirty.clear()
|
|
||||||
|
rows = [
|
||||||
|
(ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org)
|
||||||
|
for ip in to_flush
|
||||||
|
if ip in _cache
|
||||||
|
]
|
||||||
|
|
||||||
rows = [
|
|
||||||
(ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org)
|
|
||||||
for ip in to_flush
|
|
||||||
if ip in _cache
|
|
||||||
]
|
|
||||||
if not rows:
|
if not rows:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|||||||
@@ -67,7 +67,7 @@ async def _run_re_resolve_with_resources(settings: "Settings", http_session: "Cl
|
|||||||
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
|
log.info("geo_re_resolve_start", unresolved=len(unresolved_ips))
|
||||||
|
|
||||||
# Clear the negative cache so these IPs are eligible for fresh API calls.
|
# Clear the negative cache so these IPs are eligible for fresh API calls.
|
||||||
geo_service.clear_neg_cache()
|
await geo_service.clear_neg_cache()
|
||||||
|
|
||||||
# lookup_batch handles throttling, retries, and persistence when db is
|
# lookup_batch handles throttling, retries, and persistence when db is
|
||||||
# passed. This is a background task so DB writes are allowed.
|
# passed. This is a background task so DB writes are allowed.
|
||||||
|
|||||||
@@ -700,7 +700,7 @@ class TestOriginFilter:
|
|||||||
assert len(result.bans) == 205
|
assert len(result.bans) == 205
|
||||||
assert all(b.country_code == "DE" for b in result.bans)
|
assert all(b.country_code == "DE" for b in result.bans)
|
||||||
|
|
||||||
geo_service.clear_cache()
|
await geo_service.clear_cache()
|
||||||
|
|
||||||
async def test_bans_by_country_source_archive_reads_archive(
|
async def test_bans_by_country_source_archive_reads_archive(
|
||||||
self, app_db_with_archive: aiosqlite.Connection
|
self, app_db_with_archive: aiosqlite.Connection
|
||||||
@@ -769,7 +769,7 @@ class TestBansbyCountryBackground:
|
|||||||
assert result.total == 3
|
assert result.total == 3
|
||||||
# Country counts should reflect the cached data.
|
# Country counts should reflect the cached data.
|
||||||
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
|
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
|
||||||
geo_service.clear_cache()
|
await geo_service.clear_cache()
|
||||||
|
|
||||||
async def test_uncached_ips_trigger_background_task(
|
async def test_uncached_ips_trigger_background_task(
|
||||||
self, mixed_origin_db_path: str
|
self, mixed_origin_db_path: str
|
||||||
@@ -778,7 +778,7 @@ class TestBansbyCountryBackground:
|
|||||||
resolution and the response returns without blocking."""
|
resolution and the response returns without blocking."""
|
||||||
from app.services import geo_service
|
from app.services import geo_service
|
||||||
|
|
||||||
geo_service.clear_cache() # ensure cache is empty
|
await geo_service.clear_cache() # ensure cache is empty
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
@@ -810,7 +810,7 @@ class TestBansbyCountryBackground:
|
|||||||
"""When http_session is None, no background task is created."""
|
"""When http_session is None, no background task is created."""
|
||||||
from app.services import geo_service
|
from app.services import geo_service
|
||||||
|
|
||||||
geo_service.clear_cache()
|
await geo_service.clear_cache()
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
|
|||||||
@@ -130,7 +130,7 @@ async def perf_db_path(tmp_path_factory: Any) -> str:
|
|||||||
ips = await _seed_f2b_db(path, _BAN_COUNT)
|
ips = await _seed_f2b_db(path, _BAN_COUNT)
|
||||||
|
|
||||||
# Pre-populate the in-memory geo cache so no network calls are made.
|
# Pre-populate the in-memory geo cache so no network calls are made.
|
||||||
geo_service.clear_cache()
|
await geo_service.clear_cache()
|
||||||
country_cycle = _COUNTRIES * (_BAN_COUNT // len(_COUNTRIES) + 1)
|
country_cycle = _COUNTRIES * (_BAN_COUNT // len(_COUNTRIES) + 1)
|
||||||
for i, ip in enumerate(ips):
|
for i, ip in enumerate(ips):
|
||||||
cc, cn = country_cycle[i]
|
cc, cn = country_cycle[i]
|
||||||
|
|||||||
@@ -45,9 +45,9 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def clear_geo_cache() -> None:
|
async def clear_geo_cache() -> None:
|
||||||
"""Flush the module-level geo cache before every test."""
|
"""Flush the module-level geo cache before every test."""
|
||||||
geo_service.clear_cache()
|
await geo_service.clear_cache()
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -162,7 +162,7 @@ class TestLookupCaching:
|
|||||||
)
|
)
|
||||||
|
|
||||||
await geo_service.lookup("2.3.4.5", session)
|
await geo_service.lookup("2.3.4.5", session)
|
||||||
geo_service.clear_cache()
|
await geo_service.clear_cache()
|
||||||
await geo_service.lookup("2.3.4.5", session)
|
await geo_service.lookup("2.3.4.5", session)
|
||||||
|
|
||||||
assert session.get.call_count == 2
|
assert session.get.call_count == 2
|
||||||
@@ -259,7 +259,7 @@ class TestNegativeCache:
|
|||||||
session = _make_session({"status": "fail", "message": "private range"})
|
session = _make_session({"status": "fail", "message": "private range"})
|
||||||
|
|
||||||
await geo_service.lookup("192.0.2.3", session)
|
await geo_service.lookup("192.0.2.3", session)
|
||||||
geo_service.clear_neg_cache()
|
await geo_service.clear_neg_cache()
|
||||||
await geo_service.lookup("192.0.2.3", session)
|
await geo_service.lookup("192.0.2.3", session)
|
||||||
|
|
||||||
assert session.get.call_count == 2
|
assert session.get.call_count == 2
|
||||||
@@ -474,27 +474,27 @@ class TestLookupBatchSingleCommit:
|
|||||||
class TestDirtySetTracking:
|
class TestDirtySetTracking:
|
||||||
"""_store() marks successfully resolved IPs as dirty."""
|
"""_store() marks successfully resolved IPs as dirty."""
|
||||||
|
|
||||||
def test_successful_resolution_adds_to_dirty(self) -> None:
|
async def test_successful_resolution_adds_to_dirty(self) -> None:
|
||||||
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
|
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
|
||||||
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
|
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
|
||||||
geo_service._store("1.2.3.4", info)
|
await geo_service._store("1.2.3.4", info)
|
||||||
|
|
||||||
assert "1.2.3.4" in geo_service._dirty
|
assert "1.2.3.4" in geo_service._dirty
|
||||||
|
|
||||||
def test_null_country_does_not_add_to_dirty(self) -> None:
|
async def test_null_country_does_not_add_to_dirty(self) -> None:
|
||||||
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
|
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
|
||||||
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||||
geo_service._store("10.0.0.1", info)
|
await geo_service._store("10.0.0.1", info)
|
||||||
|
|
||||||
assert "10.0.0.1" not in geo_service._dirty
|
assert "10.0.0.1" not in geo_service._dirty
|
||||||
|
|
||||||
def test_clear_cache_also_clears_dirty(self) -> None:
|
async def test_clear_cache_also_clears_dirty(self) -> None:
|
||||||
"""clear_cache() must discard any pending dirty entries."""
|
"""clear_cache() must discard any pending dirty entries."""
|
||||||
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
|
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
|
||||||
geo_service._store("8.8.8.8", info)
|
await geo_service._store("8.8.8.8", info)
|
||||||
assert geo_service._dirty
|
assert geo_service._dirty
|
||||||
|
|
||||||
geo_service.clear_cache()
|
await geo_service.clear_cache()
|
||||||
|
|
||||||
assert not geo_service._dirty
|
assert not geo_service._dirty
|
||||||
|
|
||||||
@@ -519,7 +519,7 @@ class TestFlushDirty:
|
|||||||
async def test_flush_writes_and_clears_dirty(self) -> None:
|
async def test_flush_writes_and_clears_dirty(self) -> None:
|
||||||
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
|
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
|
||||||
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
|
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
|
||||||
geo_service._store("100.0.0.1", info)
|
await geo_service._store("100.0.0.1", info)
|
||||||
assert "100.0.0.1" in geo_service._dirty
|
assert "100.0.0.1" in geo_service._dirty
|
||||||
|
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
@@ -542,7 +542,7 @@ class TestFlushDirty:
|
|||||||
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
|
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
|
||||||
"""When the DB write fails, entries are re-added to _dirty for retry."""
|
"""When the DB write fails, entries are re-added to _dirty for retry."""
|
||||||
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
|
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
|
||||||
geo_service._store("200.0.0.1", info)
|
await geo_service._store("200.0.0.1", info)
|
||||||
|
|
||||||
db = _make_async_db()
|
db = _make_async_db()
|
||||||
db.executemany = AsyncMock(side_effect=OSError("disk full"))
|
db.executemany = AsyncMock(side_effect=OSError("disk full"))
|
||||||
@@ -881,7 +881,7 @@ class TestReResolveAll:
|
|||||||
AsyncMock(return_value=geo_map),
|
AsyncMock(return_value=geo_map),
|
||||||
) as mock_lookup, patch(
|
) as mock_lookup, patch(
|
||||||
"app.services.geo_service.clear_neg_cache",
|
"app.services.geo_service.clear_neg_cache",
|
||||||
MagicMock(),
|
AsyncMock(),
|
||||||
) as mock_clear:
|
) as mock_clear:
|
||||||
result = await geo_service.re_resolve_all(db, session)
|
result = await geo_service.re_resolve_all(db, session)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user