refactor: improve backend type safety and import organization

- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite)
- Reorganize imports to follow PEP 8 conventions
- Convert TypeAlias to modern PEP 695 type syntax (where appropriate)
- Use Sequence/Mapping from collections.abc for type hints (covariant)
- Replace string literals with cast() for improved type inference
- Fix casting of Fail2BanResponse and TypedDict patterns
- Add IpLookupResult TypedDict for precise return type annotation
- Reformat overlong lines for readability (120-character limit)
- Add asyncio_mode and filterwarnings to pytest config
- Update test fixtures with improved type hints

This improves mypy type checking and makes type relationships explicit.
This commit is contained in:
2026-03-20 13:44:14 +01:00
parent 6515164d53
commit 250bb1a2e5
30 changed files with 431 additions and 644 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -44,7 +45,7 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
@pytest.fixture(autouse=True)
def clear_geo_cache() -> None: # type: ignore[misc]
def clear_geo_cache() -> None:
"""Flush the module-level geo cache before every test."""
geo_service.clear_cache()
@@ -68,7 +69,7 @@ class TestLookupSuccess:
"org": "AS3320 Deutsche Telekom AG",
}
)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
assert result is not None
assert result.country_code == "DE"
@@ -84,7 +85,7 @@ class TestLookupSuccess:
"org": "Google LLC",
}
)
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
result = await geo_service.lookup("8.8.8.8", session)
assert result is not None
assert result.country_name == "United States"
@@ -100,7 +101,7 @@ class TestLookupSuccess:
"org": "Deutsche Telekom",
}
)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
assert result is not None
assert result.asn == "AS3320"
@@ -116,7 +117,7 @@ class TestLookupSuccess:
"org": "Google LLC",
}
)
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
result = await geo_service.lookup("8.8.8.8", session)
assert result is not None
assert result.org == "Google LLC"
@@ -142,8 +143,8 @@ class TestLookupCaching:
}
)
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
await geo_service.lookup("1.2.3.4", session)
await geo_service.lookup("1.2.3.4", session)
# The session.get() should only have been called once.
assert session.get.call_count == 1
@@ -160,9 +161,9 @@ class TestLookupCaching:
}
)
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
await geo_service.lookup("2.3.4.5", session)
geo_service.clear_cache()
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
await geo_service.lookup("2.3.4.5", session)
assert session.get.call_count == 2
@@ -172,8 +173,8 @@ class TestLookupCaching:
{"status": "fail", "message": "reserved range"}
)
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
await geo_service.lookup("192.168.1.1", session)
await geo_service.lookup("192.168.1.1", session)
# Second call is blocked by the negative cache — only one API hit.
assert session.get.call_count == 1
@@ -190,7 +191,7 @@ class TestLookupFailures:
async def test_non_200_response_returns_null_geo_info(self) -> None:
"""A 429 or 500 status returns GeoInfo with null fields (not None)."""
session = _make_session({}, status=429)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
assert result is not None
assert isinstance(result, GeoInfo)
assert result.country_code is None
@@ -203,7 +204,7 @@ class TestLookupFailures:
mock_ctx.__aexit__ = AsyncMock(return_value=False)
session.get = MagicMock(return_value=mock_ctx)
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
result = await geo_service.lookup("10.0.0.1", session)
assert result is not None
assert isinstance(result, GeoInfo)
assert result.country_code is None
@@ -211,7 +212,7 @@ class TestLookupFailures:
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
session = _make_session({"status": "fail", "message": "private range"})
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
result = await geo_service.lookup("10.0.0.1", session)
assert result is not None
assert isinstance(result, GeoInfo)
@@ -231,8 +232,8 @@ class TestNegativeCache:
"""After a failed lookup the second call is served from the neg cache."""
session = _make_session({"status": "fail", "message": "private range"})
r1 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
r2 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type]
r1 = await geo_service.lookup("192.0.2.1", session)
r2 = await geo_service.lookup("192.0.2.1", session)
# Only one HTTP call should have been made; second served from neg cache.
assert session.get.call_count == 1
@@ -243,12 +244,12 @@ class TestNegativeCache:
"""When the neg-cache entry is older than the TTL a new API call is made."""
session = _make_session({"status": "fail", "message": "private range"})
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
await geo_service.lookup("192.0.2.2", session)
# Manually expire the neg-cache entry.
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 # type: ignore[attr-defined]
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type]
await geo_service.lookup("192.0.2.2", session)
# Both calls should have hit the API.
assert session.get.call_count == 2
@@ -257,9 +258,9 @@ class TestNegativeCache:
"""After clearing the neg cache the IP is eligible for a new API call."""
session = _make_session({"status": "fail", "message": "private range"})
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
await geo_service.lookup("192.0.2.3", session)
geo_service.clear_neg_cache()
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type]
await geo_service.lookup("192.0.2.3", session)
assert session.get.call_count == 2
@@ -275,9 +276,9 @@ class TestNegativeCache:
}
)
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
await geo_service.lookup("1.2.3.4", session)
assert "1.2.3.4" not in geo_service._neg_cache # type: ignore[attr-defined]
assert "1.2.3.4" not in geo_service._neg_cache
# ---------------------------------------------------------------------------
@@ -307,7 +308,7 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("DE", "Germany")
with patch.object(geo_service, "_geoip_reader", mock_reader):
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
mock_reader.country.assert_called_once_with("1.2.3.4")
assert result is not None
@@ -320,12 +321,12 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("US", "United States")
with patch.object(geo_service, "_geoip_reader", mock_reader):
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
await geo_service.lookup("8.8.8.8", session)
# Second call must be served from positive cache without hitting API.
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
await geo_service.lookup("8.8.8.8", session)
assert session.get.call_count == 1
assert "8.8.8.8" in geo_service._cache # type: ignore[attr-defined]
assert "8.8.8.8" in geo_service._cache
async def test_geoip_fallback_not_called_on_api_success(self) -> None:
"""When ip-api succeeds, the geoip2 reader must not be consulted."""
@@ -341,7 +342,7 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("XX", "Nowhere")
with patch.object(geo_service, "_geoip_reader", mock_reader):
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
result = await geo_service.lookup("1.2.3.4", session)
mock_reader.country.assert_not_called()
assert result is not None
@@ -352,7 +353,7 @@ class TestGeoipFallback:
session = _make_session({"status": "fail", "message": "private range"})
with patch.object(geo_service, "_geoip_reader", None):
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
result = await geo_service.lookup("10.0.0.1", session)
assert result is not None
assert result.country_code is None
@@ -363,7 +364,7 @@ class TestGeoipFallback:
# ---------------------------------------------------------------------------
def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock:
def _make_batch_session(batch_response: Sequence[Mapping[str, object]]) -> MagicMock:
"""Build a mock aiohttp.ClientSession for batch POST calls.
Args:
@@ -412,7 +413,7 @@ class TestLookupBatchSingleCommit:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
db.commit.assert_awaited_once()
@@ -426,7 +427,7 @@ class TestLookupBatchSingleCommit:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
db.commit.assert_awaited_once()
@@ -452,13 +453,13 @@ class TestLookupBatchSingleCommit:
async def test_no_commit_for_all_cached_ips(self) -> None:
"""When all IPs are already cached, no HTTP call and no commit occur."""
geo_service._cache["5.5.5.5"] = GeoInfo( # type: ignore[attr-defined]
geo_service._cache["5.5.5.5"] = GeoInfo(
country_code="FR", country_name="France", asn="AS1", org="ISP"
)
db = _make_async_db()
session = _make_batch_session([])
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) # type: ignore[arg-type]
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)
assert result["5.5.5.5"].country_code == "FR"
db.commit.assert_not_awaited()
@@ -476,26 +477,26 @@ class TestDirtySetTracking:
def test_successful_resolution_adds_to_dirty(self) -> None:
"""Storing a GeoInfo with a country_code adds the IP to _dirty."""
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
geo_service._store("1.2.3.4", info) # type: ignore[attr-defined]
geo_service._store("1.2.3.4", info)
assert "1.2.3.4" in geo_service._dirty # type: ignore[attr-defined]
assert "1.2.3.4" in geo_service._dirty
def test_null_country_does_not_add_to_dirty(self) -> None:
"""Storing a GeoInfo with country_code=None must not pollute _dirty."""
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
geo_service._store("10.0.0.1", info) # type: ignore[attr-defined]
geo_service._store("10.0.0.1", info)
assert "10.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
assert "10.0.0.1" not in geo_service._dirty
def test_clear_cache_also_clears_dirty(self) -> None:
"""clear_cache() must discard any pending dirty entries."""
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
geo_service._store("8.8.8.8", info) # type: ignore[attr-defined]
assert geo_service._dirty # type: ignore[attr-defined]
geo_service._store("8.8.8.8", info)
assert geo_service._dirty
geo_service.clear_cache()
assert not geo_service._dirty # type: ignore[attr-defined]
assert not geo_service._dirty
async def test_lookup_batch_populates_dirty(self) -> None:
"""After lookup_batch() with db=None, resolved IPs appear in _dirty."""
@@ -509,7 +510,7 @@ class TestDirtySetTracking:
await geo_service.lookup_batch(ips, session, db=None)
for ip in ips:
assert ip in geo_service._dirty # type: ignore[attr-defined]
assert ip in geo_service._dirty
class TestFlushDirty:
@@ -518,8 +519,8 @@ class TestFlushDirty:
async def test_flush_writes_and_clears_dirty(self) -> None:
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
geo_service._store("100.0.0.1", info) # type: ignore[attr-defined]
assert "100.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
geo_service._store("100.0.0.1", info)
assert "100.0.0.1" in geo_service._dirty
db = _make_async_db()
count = await geo_service.flush_dirty(db)
@@ -527,7 +528,7 @@ class TestFlushDirty:
assert count == 1
db.executemany.assert_awaited_once()
db.commit.assert_awaited_once()
assert "100.0.0.1" not in geo_service._dirty # type: ignore[attr-defined]
assert "100.0.0.1" not in geo_service._dirty
async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
@@ -541,7 +542,7 @@ class TestFlushDirty:
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
"""When the DB write fails, entries are re-added to _dirty for retry."""
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
geo_service._store("200.0.0.1", info) # type: ignore[attr-defined]
geo_service._store("200.0.0.1", info)
db = _make_async_db()
db.executemany = AsyncMock(side_effect=OSError("disk full"))
@@ -549,7 +550,7 @@ class TestFlushDirty:
count = await geo_service.flush_dirty(db)
assert count == 0
assert "200.0.0.1" in geo_service._dirty # type: ignore[attr-defined]
assert "200.0.0.1" in geo_service._dirty
async def test_flush_batch_and_lookup_batch_integration(self) -> None:
"""lookup_batch() populates _dirty; flush_dirty() then persists them."""
@@ -562,14 +563,14 @@ class TestFlushDirty:
# Resolve without DB to populate only in-memory cache and _dirty.
await geo_service.lookup_batch(ips, session, db=None)
assert geo_service._dirty == set(ips) # type: ignore[attr-defined]
assert geo_service._dirty == set(ips)
# Now flush to the DB.
db = _make_async_db()
count = await geo_service.flush_dirty(db)
assert count == 2
assert not geo_service._dirty # type: ignore[attr-defined]
assert not geo_service._dirty
db.commit.assert_awaited_once()
@@ -585,7 +586,7 @@ class TestLookupBatchThrottling:
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
between consecutive batch HTTP calls with at least _BATCH_DELAY."""
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
batch_size: int = geo_service._BATCH_SIZE # type: ignore[attr-defined]
batch_size: int = geo_service._BATCH_SIZE
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
@@ -608,7 +609,7 @@ class TestLookupBatchThrottling:
assert mock_batch.call_count == 2
mock_sleep.assert_awaited_once()
delay_arg: float = mock_sleep.call_args[0][0]
assert delay_arg >= geo_service._BATCH_DELAY # type: ignore[attr-defined]
assert delay_arg >= geo_service._BATCH_DELAY
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
"""When a chunk returns all-None on first try, it retries and succeeds."""
@@ -650,7 +651,7 @@ class TestLookupBatchThrottling:
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
_failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty)
max_retries: int = geo_service._BATCH_MAX_RETRIES # type: ignore[attr-defined]
max_retries: int = geo_service._BATCH_MAX_RETRIES
with (
patch(
@@ -667,11 +668,11 @@ class TestLookupBatchThrottling:
# IP should have no country.
assert result["9.9.9.9"].country_code is None
# Negative cache should contain the IP.
assert "9.9.9.9" in geo_service._neg_cache # type: ignore[attr-defined]
assert "9.9.9.9" in geo_service._neg_cache
# Sleep called for each retry with exponential backoff.
assert mock_sleep.call_count == max_retries
backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
batch_delay: float = geo_service._BATCH_DELAY # type: ignore[attr-defined]
batch_delay: float = geo_service._BATCH_DELAY
for i, val in enumerate(backoff_values):
expected = batch_delay * (2 ** (i + 1))
assert val == pytest.approx(expected)
@@ -709,7 +710,7 @@ class TestErrorLogging:
import structlog.testing
with structlog.testing.capture_logs() as captured:
result = await geo_service.lookup("197.221.98.153", session) # type: ignore[arg-type]
result = await geo_service.lookup("197.221.98.153", session)
assert result is not None
assert result.country_code is None
@@ -733,7 +734,7 @@ class TestErrorLogging:
import structlog.testing
with structlog.testing.capture_logs() as captured:
await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
await geo_service.lookup("10.0.0.1", session)
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
assert len(request_failed) == 1
@@ -757,7 +758,7 @@ class TestErrorLogging:
import structlog.testing
with structlog.testing.capture_logs() as captured:
result = await geo_service._batch_api_call(["1.2.3.4"], session) # type: ignore[attr-defined]
result = await geo_service._batch_api_call(["1.2.3.4"], session)
assert result["1.2.3.4"].country_code is None
@@ -778,7 +779,7 @@ class TestLookupCachedOnly:
def test_returns_cached_ips(self) -> None:
"""IPs already in the cache are returned in the geo_map."""
geo_service._cache["1.1.1.1"] = GeoInfo( # type: ignore[attr-defined]
geo_service._cache["1.1.1.1"] = GeoInfo(
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
)
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
@@ -798,7 +799,7 @@ class TestLookupCachedOnly:
"""IPs in the negative cache within TTL are not re-queued as uncached."""
import time
geo_service._neg_cache["10.0.0.1"] = time.monotonic() # type: ignore[attr-defined]
geo_service._neg_cache["10.0.0.1"] = time.monotonic()
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
@@ -807,7 +808,7 @@ class TestLookupCachedOnly:
def test_expired_neg_cache_requeued(self) -> None:
"""IPs whose neg-cache entry has expired are listed as uncached."""
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired # type: ignore[attr-defined]
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
@@ -815,12 +816,12 @@ class TestLookupCachedOnly:
def test_mixed_ips(self) -> None:
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
geo_service._cache["1.2.3.4"] = GeoInfo(
country_code="DE", country_name="Germany", asn=None, org=None
)
import time
geo_service._neg_cache["5.5.5.5"] = time.monotonic() # type: ignore[attr-defined]
geo_service._neg_cache["5.5.5.5"] = time.monotonic()
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
@@ -829,7 +830,7 @@ class TestLookupCachedOnly:
def test_deduplication(self) -> None:
"""Duplicate IPs in the input appear at most once in the output."""
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined]
geo_service._cache["1.2.3.4"] = GeoInfo(
country_code="US", country_name="United States", asn=None, org=None
)
@@ -866,7 +867,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
# One executemany for the positive rows.
assert db.executemany.await_count >= 1
@@ -883,7 +884,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
assert db.executemany.await_count >= 1
db.execute.assert_not_awaited()
@@ -905,7 +906,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type]
await geo_service.lookup_batch(ips, session, db=db)
# One executemany for positives, one for negatives.
assert db.executemany.await_count == 2