diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 50be70c..199ded2 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -310,6 +310,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. **Docs changes needed:** Update `Docs/Refactoring.md`. +**Status:** Done ✅ + **Why this is needed:** The current safety is implicit and fragile. A future change that adds an `await` inside the critical section (e.g. logging to a remote sink) would silently introduce data loss in the dirty-flush path. An explicit lock documents the intent and makes the safety guarantee unconditional. --- diff --git a/backend/app/services/geo_service.py b/backend/app/services/geo_service.py index 90f8e4a..89df8ab 100644 --- a/backend/app/services/geo_service.py +++ b/backend/app/services/geo_service.py @@ -109,25 +109,30 @@ _dirty: set[str] = set() #: Optional MaxMind GeoLite2 reader initialised by :func:`init_geoip`. _geoip_reader: geoip2.database.Reader | None = None +#: Lock protecting mutations to the in-memory geo caches. +_cache_lock: asyncio.Lock = asyncio.Lock() -def clear_cache() -> None: + +async def clear_cache() -> None: """Flush both the positive and negative lookup caches. Also clears the dirty set so any pending-but-unpersisted entries are discarded. Useful in tests and when the operator suspects stale data. """ - _cache.clear() - _neg_cache.clear() - _dirty.clear() + async with _cache_lock: + _cache.clear() + _neg_cache.clear() + _dirty.clear() -def clear_neg_cache() -> None: +async def clear_neg_cache() -> None: """Flush only the negative (failed-lookups) cache. Useful when triggering a manual re-resolve so that previously failed IPs are immediately eligible for a new API attempt. """ - _neg_cache.clear() + async with _cache_lock: + _neg_cache.clear() def is_cached(ip: str) -> bool: @@ -208,7 +213,7 @@ async def re_resolve_all( if not unresolved: return {"resolved": 0, "total": 0} - clear_neg_cache() + await clear_neg_cache() geo_map = await lookup_batch(unresolved, http_session, db=db) resolved_count = sum( 1 for info in geo_map.values() if info.country_code is not None @@ -292,18 +297,28 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None: database (not the fail2ban database). """ count = 0 + cache_entries: list[tuple[str, GeoInfo]] = [] for row in await geo_cache_repo.load_all(db): country_code: str | None = row["country_code"] if country_code is None: continue ip: str = row["ip"] - _cache[ip] = GeoInfo( - country_code=country_code, - country_name=row["country_name"], - asn=row["asn"], - org=row["org"], + cache_entries.append( + ( + ip, + GeoInfo( + country_code=country_code, + country_name=row["country_name"], + asn=row["asn"], + org=row["org"], + ), + ) ) count += 1 + + async with _cache_lock: + for ip, info in cache_entries: + _cache[ip] = info log.info("geo_cache_loaded_from_db", entries=count) @@ -395,7 +410,7 @@ async def lookup( if data.get("status") == "success": api_ok = True result = _parse_single_response(data) - _store(ip, result) + await _store(ip, result) if result.country_code is not None and db is not None: try: await _persist_entry(db, ip, result) @@ -421,7 +436,7 @@ async def lookup( # Try local MaxMind database as fallback. fallback = _geoip_lookup(ip) if fallback is not None: - _store(ip, fallback) + await _store(ip, fallback) if fallback.country_code is not None and db is not None: try: await _persist_entry(db, ip, fallback) @@ -432,7 +447,8 @@ async def lookup( return fallback # Both resolvers failed — record in negative cache to avoid hammering. - _neg_cache[ip] = time.monotonic() + async with _cache_lock: + _neg_cache[ip] = time.monotonic() if db is not None: try: await _persist_neg_entry(db, ip) @@ -566,7 +582,7 @@ async def lookup_batch( for ip, info in chunk_result.items(): if info.country_code is not None: # Successful API resolution. - _store(ip, info) + await _store(ip, info) geo_result[ip] = info if db is not None: pos_rows.append( @@ -590,7 +606,8 @@ async def lookup_batch( ) else: # Both resolvers failed — record in negative cache. - _neg_cache[ip] = time.monotonic() + async with _cache_lock: + _neg_cache[ip] = time.monotonic() geo_result[ip] = _empty if db is not None: neg_ips.append(ip) @@ -735,7 +752,7 @@ def _str_or_none(value: object) -> str | None: return s if s else None -def _store(ip: str, info: GeoInfo) -> None: +async def _store(ip: str, info: GeoInfo) -> None: """Insert *info* into the module-level cache, flushing if over capacity. When the IP resolved successfully (``country_code is not None``) it is @@ -746,13 +763,14 @@ def _store(ip: str, info: GeoInfo) -> None: ip: The IP address key. info: The :class:`GeoInfo` to store. """ - if len(_cache) >= _MAX_CACHE_SIZE: - _cache.clear() - _dirty.clear() - log.info("geo_cache_flushed", reason="capacity") - _cache[ip] = info - if info.country_code is not None: - _dirty.add(ip) + async with _cache_lock: + if len(_cache) >= _MAX_CACHE_SIZE: + _cache.clear() + _dirty.clear() + log.info("geo_cache_flushed", reason="capacity") + _cache[ip] = info + if info.country_code is not None: + _dirty.add(ip) async def flush_dirty(db: aiosqlite.Connection) -> int: @@ -773,19 +791,20 @@ async def flush_dirty(db: aiosqlite.Connection) -> int: Returns: The number of rows successfully upserted. """ - if not _dirty: - return 0 + async with _cache_lock: + if not _dirty: + return 0 - # Atomically snapshot and clear in a single-threaded async context. - # No ``await`` between copy and clear ensures no interleaving. - to_flush = _dirty.copy() - _dirty.clear() + # Atomically snapshot and clear while holding the cache lock. + to_flush = _dirty.copy() + _dirty.clear() + + rows = [ + (ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org) + for ip in to_flush + if ip in _cache + ] - rows = [ - (ip, _cache[ip].country_code, _cache[ip].country_name, _cache[ip].asn, _cache[ip].org) - for ip in to_flush - if ip in _cache - ] if not rows: return 0 diff --git a/backend/app/tasks/geo_re_resolve.py b/backend/app/tasks/geo_re_resolve.py index 770fe0d..53e1ed9 100644 --- a/backend/app/tasks/geo_re_resolve.py +++ b/backend/app/tasks/geo_re_resolve.py @@ -67,7 +67,7 @@ async def _run_re_resolve_with_resources(settings: "Settings", http_session: "Cl log.info("geo_re_resolve_start", unresolved=len(unresolved_ips)) # Clear the negative cache so these IPs are eligible for fresh API calls. - geo_service.clear_neg_cache() + await geo_service.clear_neg_cache() # lookup_batch handles throttling, retries, and persistence when db is # passed. This is a background task so DB writes are allowed. diff --git a/backend/tests/test_services/test_ban_service.py b/backend/tests/test_services/test_ban_service.py index 87be876..e28a2f3 100644 --- a/backend/tests/test_services/test_ban_service.py +++ b/backend/tests/test_services/test_ban_service.py @@ -700,7 +700,7 @@ class TestOriginFilter: assert len(result.bans) == 205 assert all(b.country_code == "DE" for b in result.bans) - geo_service.clear_cache() + await geo_service.clear_cache() async def test_bans_by_country_source_archive_reads_archive( self, app_db_with_archive: aiosqlite.Connection @@ -769,7 +769,7 @@ class TestBansbyCountryBackground: assert result.total == 3 # Country counts should reflect the cached data. assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries - geo_service.clear_cache() + await geo_service.clear_cache() async def test_uncached_ips_trigger_background_task( self, mixed_origin_db_path: str @@ -778,7 +778,7 @@ class TestBansbyCountryBackground: resolution and the response returns without blocking.""" from app.services import geo_service - geo_service.clear_cache() # ensure cache is empty + await geo_service.clear_cache() # ensure cache is empty with ( patch( @@ -810,7 +810,7 @@ class TestBansbyCountryBackground: """When http_session is None, no background task is created.""" from app.services import geo_service - geo_service.clear_cache() + await geo_service.clear_cache() with ( patch( diff --git a/backend/tests/test_services/test_ban_service_perf.py b/backend/tests/test_services/test_ban_service_perf.py index 9468716..058c868 100644 --- a/backend/tests/test_services/test_ban_service_perf.py +++ b/backend/tests/test_services/test_ban_service_perf.py @@ -130,7 +130,7 @@ async def perf_db_path(tmp_path_factory: Any) -> str: ips = await _seed_f2b_db(path, _BAN_COUNT) # Pre-populate the in-memory geo cache so no network calls are made. - geo_service.clear_cache() + await geo_service.clear_cache() country_cycle = _COUNTRIES * (_BAN_COUNT // len(_COUNTRIES) + 1) for i, ip in enumerate(ips): cc, cn = country_cycle[i] diff --git a/backend/tests/test_services/test_geo_service.py b/backend/tests/test_services/test_geo_service.py index 9c99c3f..c58bf19 100644 --- a/backend/tests/test_services/test_geo_service.py +++ b/backend/tests/test_services/test_geo_service.py @@ -45,9 +45,9 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM @pytest.fixture(autouse=True) -def clear_geo_cache() -> None: +async def clear_geo_cache() -> None: """Flush the module-level geo cache before every test.""" - geo_service.clear_cache() + await geo_service.clear_cache() # --------------------------------------------------------------------------- @@ -162,7 +162,7 @@ class TestLookupCaching: ) await geo_service.lookup("2.3.4.5", session) - geo_service.clear_cache() + await geo_service.clear_cache() await geo_service.lookup("2.3.4.5", session) assert session.get.call_count == 2 @@ -259,7 +259,7 @@ class TestNegativeCache: session = _make_session({"status": "fail", "message": "private range"}) await geo_service.lookup("192.0.2.3", session) - geo_service.clear_neg_cache() + await geo_service.clear_neg_cache() await geo_service.lookup("192.0.2.3", session) assert session.get.call_count == 2 @@ -474,27 +474,27 @@ class TestLookupBatchSingleCommit: class TestDirtySetTracking: """_store() marks successfully resolved IPs as dirty.""" - def test_successful_resolution_adds_to_dirty(self) -> None: + async def test_successful_resolution_adds_to_dirty(self) -> None: """Storing a GeoInfo with a country_code adds the IP to _dirty.""" info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP") - geo_service._store("1.2.3.4", info) + await geo_service._store("1.2.3.4", info) assert "1.2.3.4" in geo_service._dirty - def test_null_country_does_not_add_to_dirty(self) -> None: + async def test_null_country_does_not_add_to_dirty(self) -> None: """Storing a GeoInfo with country_code=None must not pollute _dirty.""" info = GeoInfo(country_code=None, country_name=None, asn=None, org=None) - geo_service._store("10.0.0.1", info) + await geo_service._store("10.0.0.1", info) assert "10.0.0.1" not in geo_service._dirty - def test_clear_cache_also_clears_dirty(self) -> None: + async def test_clear_cache_also_clears_dirty(self) -> None: """clear_cache() must discard any pending dirty entries.""" info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP") - geo_service._store("8.8.8.8", info) + await geo_service._store("8.8.8.8", info) assert geo_service._dirty - geo_service.clear_cache() + await geo_service.clear_cache() assert not geo_service._dirty @@ -519,7 +519,7 @@ class TestFlushDirty: async def test_flush_writes_and_clears_dirty(self) -> None: """flush_dirty() inserts all dirty IPs and clears _dirty afterwards.""" info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT") - geo_service._store("100.0.0.1", info) + await geo_service._store("100.0.0.1", info) assert "100.0.0.1" in geo_service._dirty db = _make_async_db() @@ -542,7 +542,7 @@ class TestFlushDirty: async def test_flush_re_adds_to_dirty_on_db_error(self) -> None: """When the DB write fails, entries are re-added to _dirty for retry.""" info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP") - geo_service._store("200.0.0.1", info) + await geo_service._store("200.0.0.1", info) db = _make_async_db() db.executemany = AsyncMock(side_effect=OSError("disk full")) @@ -881,7 +881,7 @@ class TestReResolveAll: AsyncMock(return_value=geo_map), ) as mock_lookup, patch( "app.services.geo_service.clear_neg_cache", - MagicMock(), + AsyncMock(), ) as mock_clear: result = await geo_service.re_resolve_all(db, session)