"""Tests for geo_service.lookup().""" from __future__ import annotations from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.services import geo_service from app.services.geo_service import GeoInfo # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _make_session(response_json: dict[str, object], status: int = 200) -> MagicMock: """Build a mock aiohttp.ClientSession that returns *response_json*. Args: response_json: The dict that the mock response's ``json()`` returns. status: HTTP status code for the mock response. Returns: A :class:`MagicMock` that behaves like an ``aiohttp.ClientSession`` in an ``async with`` context. """ mock_resp = AsyncMock() mock_resp.status = status mock_resp.json = AsyncMock(return_value=response_json) mock_ctx = AsyncMock() mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp) mock_ctx.__aexit__ = AsyncMock(return_value=False) session = MagicMock() session.get = MagicMock(return_value=mock_ctx) return session # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- @pytest.fixture(autouse=True) def clear_geo_cache() -> None: # type: ignore[misc] """Flush the module-level geo cache before every test.""" geo_service.clear_cache() # --------------------------------------------------------------------------- # Happy path # --------------------------------------------------------------------------- class TestLookupSuccess: """geo_service.lookup() under normal conditions.""" async def test_returns_country_code(self) -> None: """country_code is populated from the ``countryCode`` field.""" session = _make_session( { "status": "success", "countryCode": "DE", "country": "Germany", "as": "AS3320 Deutsche Telekom AG", "org": "AS3320 Deutsche Telekom AG", } ) result = await 
geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] assert result is not None assert result.country_code == "DE" async def test_returns_country_name(self) -> None: """country_name is populated from the ``country`` field.""" session = _make_session( { "status": "success", "countryCode": "US", "country": "United States", "as": "AS15169 Google LLC", "org": "Google LLC", } ) result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] assert result is not None assert result.country_name == "United States" async def test_asn_extracted_without_org_suffix(self) -> None: """The ASN field contains only the ``AS`` prefix, not the full string.""" session = _make_session( { "status": "success", "countryCode": "DE", "country": "Germany", "as": "AS3320 Deutsche Telekom AG", "org": "Deutsche Telekom", } ) result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] assert result is not None assert result.asn == "AS3320" async def test_org_populated(self) -> None: """org field is populated from the ``org`` key.""" session = _make_session( { "status": "success", "countryCode": "US", "country": "United States", "as": "AS15169 Google LLC", "org": "Google LLC", } ) result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] assert result is not None assert result.org == "Google LLC" # --------------------------------------------------------------------------- # Cache behaviour # --------------------------------------------------------------------------- class TestLookupCaching: """Verify that results are cached and the cache can be cleared.""" async def test_second_call_uses_cache(self) -> None: """Subsequent lookups for the same IP do not make additional HTTP requests.""" session = _make_session( { "status": "success", "countryCode": "DE", "country": "Germany", "as": "AS3320 Deutsche Telekom AG", "org": "Deutsche Telekom", } ) await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] await 
geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] # The session.get() should only have been called once. assert session.get.call_count == 1 async def test_clear_cache_forces_refetch(self) -> None: """After clearing the cache a new HTTP request is made.""" session = _make_session( { "status": "success", "countryCode": "DE", "country": "Germany", "as": "AS3320", "org": "Telekom", } ) await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type] geo_service.clear_cache() await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type] assert session.get.call_count == 2 async def test_negative_result_stored_in_neg_cache(self) -> None: """A failed lookup is stored in the negative cache, so the second call is blocked.""" session = _make_session( {"status": "fail", "message": "reserved range"} ) await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type] await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type] # Second call is blocked by the negative cache — only one API hit. 
assert session.get.call_count == 1 # --------------------------------------------------------------------------- # Failure modes # --------------------------------------------------------------------------- class TestLookupFailures: """geo_service.lookup() when things go wrong.""" async def test_non_200_response_returns_null_geo_info(self) -> None: """A 429 or 500 status returns GeoInfo with null fields (not None).""" session = _make_session({}, status=429) result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] assert result is not None assert isinstance(result, GeoInfo) assert result.country_code is None async def test_network_error_returns_null_geo_info(self) -> None: """A network exception returns GeoInfo with null fields (not None).""" session = MagicMock() mock_ctx = AsyncMock() mock_ctx.__aenter__ = AsyncMock(side_effect=OSError("connection refused")) mock_ctx.__aexit__ = AsyncMock(return_value=False) session.get = MagicMock(return_value=mock_ctx) result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] assert result is not None assert isinstance(result, GeoInfo) assert result.country_code is None async def test_failed_status_returns_geo_info_with_nulls(self) -> None: """When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached).""" session = _make_session({"status": "fail", "message": "private range"}) result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] assert result is not None assert isinstance(result, GeoInfo) assert result.country_code is None assert result.country_name is None # --------------------------------------------------------------------------- # Negative cache # --------------------------------------------------------------------------- class TestNegativeCache: """Verify the negative cache throttles retries for failing IPs.""" async def test_neg_cache_blocks_second_lookup(self) -> None: """After a failed lookup the second call is served 
from the neg cache.""" session = _make_session({"status": "fail", "message": "private range"}) r1 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type] r2 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type] # Only one HTTP call should have been made; second served from neg cache. assert session.get.call_count == 1 assert r1 is not None and r1.country_code is None assert r2 is not None and r2.country_code is None async def test_neg_cache_retries_after_ttl(self) -> None: """When the neg-cache entry is older than the TTL a new API call is made.""" session = _make_session({"status": "fail", "message": "private range"}) await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type] # Manually expire the neg-cache entry. geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 # type: ignore[attr-defined] await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type] # Both calls should have hit the API. assert session.get.call_count == 2 async def test_clear_neg_cache_allows_immediate_retry(self) -> None: """After clearing the neg cache the IP is eligible for a new API call.""" session = _make_session({"status": "fail", "message": "private range"}) await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type] geo_service.clear_neg_cache() await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type] assert session.get.call_count == 2 async def test_successful_lookup_does_not_pollute_neg_cache(self) -> None: """A successful lookup must not create a neg-cache entry.""" session = _make_session( { "status": "success", "countryCode": "DE", "country": "Germany", "as": "AS3320", "org": "Telekom", } ) await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] assert "1.2.3.4" not in geo_service._neg_cache # type: ignore[attr-defined] # --------------------------------------------------------------------------- # GeoIP2 (MaxMind) fallback # 
# ---------------------------------------------------------------------------


class TestGeoipFallback:
    """Verify the MaxMind GeoLite2 fallback is used when ip-api fails."""

    def _make_geoip_reader(self, iso_code: str, name: str) -> MagicMock:
        """Return a fake ``geoip2.database.Reader``.

        ``reader.country(ip)`` returns a response object whose ``country``
        attribute carries the given *iso_code* and *name*.
        """
        fake_country = MagicMock(iso_code=iso_code)
        # ``name`` is a reserved Mock constructor argument (it sets the mock's
        # repr name), so it must be assigned as an attribute afterwards.
        fake_country.name = name
        fake_response = MagicMock(country=fake_country)
        return MagicMock(country=MagicMock(return_value=fake_response))

    async def test_geoip_fallback_called_when_api_fails(self) -> None:
        """When ip-api returns status=fail, the geoip2 reader is consulted."""
        api_session = _make_session({"status": "fail", "message": "reserved range"})
        reader = self._make_geoip_reader("DE", "Germany")

        with patch.object(geo_service, "_geoip_reader", reader):
            info = await geo_service.lookup("1.2.3.4", api_session)  # type: ignore[arg-type]

        reader.country.assert_called_once_with("1.2.3.4")
        assert info is not None
        assert (info.country_code, info.country_name) == ("DE", "Germany")

    async def test_geoip_fallback_result_stored_in_cache(self) -> None:
        """A successful geoip2 fallback result is stored in the positive cache."""
        api_session = _make_session({"status": "fail", "message": "reserved range"})
        reader = self._make_geoip_reader("US", "United States")

        with patch.object(geo_service, "_geoip_reader", reader):
            await geo_service.lookup("8.8.8.8", api_session)  # type: ignore[arg-type]
            # The repeat lookup must be answered by the positive cache, not the API.
            await geo_service.lookup("8.8.8.8", api_session)  # type: ignore[arg-type]

        assert api_session.get.call_count == 1
        assert "8.8.8.8" in geo_service._cache  # type: ignore[attr-defined]

    async def test_geoip_fallback_not_called_on_api_success(self) -> None:
        """When ip-api succeeds, the geoip2 reader must not be consulted."""
        payload: dict[str, object] = {
            "status": "success",
            "countryCode": "JP",
            "country": "Japan",
            "as": "AS12345",
            "org": "NTT",
        }
        api_session = _make_session(payload)
        reader = self._make_geoip_reader("XX", "Nowhere")

        with patch.object(geo_service, "_geoip_reader", reader):
            info = await geo_service.lookup("1.2.3.4", api_session)  # type: ignore[arg-type]

        reader.country.assert_not_called()
        assert info is not None
        assert info.country_code == "JP"

    async def test_geoip_fallback_not_called_when_no_reader(self) -> None:
        """When no geoip2 reader is configured, the fallback silently does nothing."""
        api_session = _make_session({"status": "fail", "message": "private range"})

        with patch.object(geo_service, "_geoip_reader", None):
            info = await geo_service.lookup("10.0.0.1", api_session)  # type: ignore[arg-type]

        assert info is not None
        assert info.country_code is None


# ---------------------------------------------------------------------------
# Batch single-commit behaviour (Task 1)
# ---------------------------------------------------------------------------


def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock:
    """Build a mock aiohttp.ClientSession for batch POST calls.

    Args:
        batch_response: The list that the mock response's ``json()`` returns.

    Returns:
        A :class:`MagicMock` whose ``post`` method returns an async context
        manager yielding a status-200 response.
    """
    response = AsyncMock(status=200, json=AsyncMock(return_value=batch_response))

    post_ctx = AsyncMock()
    post_ctx.__aenter__ = AsyncMock(return_value=response)
    post_ctx.__aexit__ = AsyncMock(return_value=False)

    return MagicMock(post=MagicMock(return_value=post_ctx))


def _make_async_db() -> MagicMock:
    """Build a minimal mock :class:`aiosqlite.Connection`.

    Returns:
        MagicMock whose ``execute``, ``executemany`` and ``commit`` attributes
        are :class:`AsyncMock` instances, so they can be awaited and inspected.
    """
    return MagicMock(
        execute=AsyncMock(),
        executemany=AsyncMock(),
        commit=AsyncMock(),
    )


def _success_rows(
    ips: list[str], country_code: str, country: str, asn: str, org: str
) -> list[dict[str, object]]:
    """Return one successful ip-api batch row per IP in *ips*."""
    return [
        {
            "query": ip,
            "status": "success",
            "countryCode": country_code,
            "country": country,
            "as": asn,
            "org": org,
        }
        for ip in ips
    ]


class TestLookupBatchSingleCommit:
    """lookup_batch() issues exactly one commit per call, not one per IP."""

    async def test_single_commit_for_multiple_ips(self) -> None:
        """A batch of N IPs produces exactly one db.commit(), not N."""
        ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
        session = _make_batch_session(_success_rows(ips, "DE", "Germany", "AS1", "Org"))
        db = _make_async_db()

        await geo_service.lookup_batch(ips, session, db=db)  # type: ignore[arg-type]

        db.commit.assert_awaited_once()

    async def test_commit_called_even_on_failed_lookups(self) -> None:
        """A batch with all-failed lookups still triggers one commit."""
        ips = ["10.0.0.1", "10.0.0.2"]
        failures: list[dict[str, object]] = [
            {"query": ip, "status": "fail", "message": "private range"} for ip in ips
        ]
        session = _make_batch_session(failures)
        db = _make_async_db()

        await geo_service.lookup_batch(ips, session, db=db)  # type: ignore[arg-type]

        db.commit.assert_awaited_once()

    async def test_no_commit_when_db_is_none(self) -> None:
        """When db=None, lookup_batch() still resolves IPs without a database."""
        ips = ["1.1.1.1"]
        session = _make_batch_session(
            _success_rows(ips, "US", "United States", "AS15169", "Google LLC")
        )

        # Should not raise; without db there is nothing to commit.
        result = await geo_service.lookup_batch(ips, session, db=None)

        assert result["1.1.1.1"].country_code == "US"

    async def test_no_commit_for_all_cached_ips(self) -> None:
        """When all IPs are already cached, no HTTP call and no commit occur."""
        geo_service._cache["5.5.5.5"] = GeoInfo(  # type: ignore[attr-defined]
            country_code="FR", country_name="France", asn="AS1", org="ISP"
        )
        db = _make_async_db()
        session = _make_batch_session([])

        result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)  # type: ignore[arg-type]

        assert result["5.5.5.5"].country_code == "FR"
        db.commit.assert_not_awaited()
        session.post.assert_not_called()


# ---------------------------------------------------------------------------
# Dirty-set tracking and flush_dirty (Task 3)
# ---------------------------------------------------------------------------


class TestDirtySetTracking:
    """_store() marks successfully resolved IPs as dirty."""

    def test_successful_resolution_adds_to_dirty(self) -> None:
        """Storing a GeoInfo with a country_code adds the IP to _dirty."""
        resolved = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")

        geo_service._store("1.2.3.4", resolved)  # type: ignore[attr-defined]

        assert "1.2.3.4" in geo_service._dirty  # type: ignore[attr-defined]

    def test_null_country_does_not_add_to_dirty(self) -> None:
        """Storing a GeoInfo with country_code=None must not pollute _dirty."""
        unresolved = GeoInfo(country_code=None, country_name=None, asn=None, org=None)

        geo_service._store("10.0.0.1", unresolved)  # type: ignore[attr-defined]

        assert "10.0.0.1" not in geo_service._dirty  # type: ignore[attr-defined]

    def test_clear_cache_also_clears_dirty(self) -> None:
        """clear_cache() must discard any pending dirty entries."""
        resolved = GeoInfo(
            country_code="US", country_name="United States", asn="AS1", org="ISP"
        )
        geo_service._store("8.8.8.8", resolved)  # type: ignore[attr-defined]
        assert geo_service._dirty  # type: ignore[attr-defined]

        geo_service.clear_cache()

        assert not geo_service._dirty  # type: ignore[attr-defined]

    async def test_lookup_batch_populates_dirty(self) -> None:
        """After lookup_batch() with db=None, resolved IPs appear in _dirty."""
        ips = ["1.1.1.1", "2.2.2.2"]
        session = _make_batch_session(_success_rows(ips, "JP", "Japan", "AS7500", "IIJ"))

        await geo_service.lookup_batch(ips, session, db=None)

        not_dirty = [ip for ip in ips if ip not in geo_service._dirty]  # type: ignore[attr-defined]
        assert not not_dirty


class TestFlushDirty:
    """flush_dirty() persists dirty entries and clears the set."""

    async def test_flush_writes_and_clears_dirty(self) -> None:
        """flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
        resolved = GeoInfo(
            country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT"
        )
        geo_service._store("100.0.0.1", resolved)  # type: ignore[attr-defined]
        assert "100.0.0.1" in geo_service._dirty  # type: ignore[attr-defined]
        db = _make_async_db()

        flushed = await geo_service.flush_dirty(db)

        assert flushed == 1
        db.executemany.assert_awaited_once()
        db.commit.assert_awaited_once()
        assert "100.0.0.1" not in geo_service._dirty  # type: ignore[attr-defined]

    async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
        """flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
        db = _make_async_db()

        flushed = await geo_service.flush_dirty(db)

        assert flushed == 0
        db.executemany.assert_not_awaited()
        db.commit.assert_not_awaited()

    async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
        """When the DB write fails, entries are re-added to _dirty for retry."""
        resolved = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
        geo_service._store("200.0.0.1", resolved)  # type: ignore[attr-defined]
        db = _make_async_db()
        # Make the bulk insert blow up as if the disk were full.
        db.executemany.side_effect = OSError("disk full")

        flushed = await geo_service.flush_dirty(db)

        assert flushed == 0
        assert "200.0.0.1" in geo_service._dirty  # type: ignore[attr-defined]

    async def test_flush_batch_and_lookup_batch_integration(self) -> None:
        """lookup_batch() populates _dirty; flush_dirty() then persists them."""
        ips = ["10.1.2.3", "10.1.2.4"]
        session = _make_batch_session(_success_rows(ips, "CA", "Canada", "AS812", "Bell"))

        # Resolve without DB to populate only in-memory cache and _dirty.
        await geo_service.lookup_batch(ips, session, db=None)
        assert geo_service._dirty == set(ips)  # type: ignore[attr-defined]

        # Now flush to the DB.
        db = _make_async_db()
        flushed = await geo_service.flush_dirty(db)

        assert flushed == 2
        assert not geo_service._dirty  # type: ignore[attr-defined]
        db.commit.assert_awaited_once()