222 lines
7.8 KiB
Python
222 lines
7.8 KiB
Python
"""Tests for the geo re-resolve background task.
|
|
|
|
Validates that :func:`~app.tasks.geo_re_resolve._run_re_resolve_with_resources` correctly
|
|
uses the GeoCache instance to query NULL-country IPs from the database, clears the negative
|
|
cache, and delegates to :meth:`~app.services.geo_cache.GeoCache.lookup_batch` for a fresh
|
|
resolution attempt.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.models.geo import GeoInfo
|
|
from app.services.geo_cache import GeoCache
|
|
from app.tasks.geo_re_resolve import _run_re_resolve_with_resources
|
|
|
|
|
|
class _AsyncRowIterator:
    """Async-iterable stand-in for a database cursor.

    Hands out the supplied row tuples one per ``async for`` step, then ends
    the iteration with :class:`StopAsyncIteration`. The instance is its own
    iterator, so it can be consumed only once.
    """

    def __init__(self, rows: list[tuple[str]]) -> None:
        # A plain synchronous iterator; each __anext__ call pulls one row from it.
        self._pending = iter(rows)

    def __aiter__(self) -> _AsyncRowIterator:
        return self

    async def __anext__(self) -> tuple[str]:
        # A fresh sentinel can never collide with a real row.
        done = object()
        row = next(self._pending, done)
        if row is done:
            raise StopAsyncIteration
        return row
|
|
|
|
|
|
def _make_app(
    unresolved_ips: list[str],
    lookup_result: dict[str, GeoInfo] | None = None,
) -> MagicMock:
    """Build a minimal mock ``app`` exposing ``state.db`` and ``state.http_session``.

    ``app.state.db.execute(...)`` returns an async context manager whose
    ``__aenter__`` yields an async cursor over ``(ip,)`` rows built from
    *unresolved_ips*. The mock ignores the SQL text passed to ``execute``.

    NOTE(review): none of the tests visible in this file call this helper;
    they invoke ``_run_re_resolve_with_resources`` directly. It may be a
    leftover from an app-based entry point — confirm before relying on it.

    Args:
        unresolved_ips: IPs yielded (as 1-tuples) by the mocked cursor.
        lookup_result: Accepted for backward compatibility but currently
            unused. No mocked ``lookup_batch`` is attached to the returned app.

    Returns:
        A :class:`unittest.mock.MagicMock` loosely mimicking ``fastapi.FastAPI``.
        ``state.db`` and ``state.http_session`` hold mocks, and
        ``state.settings.database_path`` is ``"/tmp/fake.db"``.
        ``state.runtime_settings`` is ``None``.
    """
    rows = [(ip,) for ip in unresolved_ips]
    # Every execute() call yields this same single-use cursor, so the rows
    # are produced only once across all queries made against the mock db.
    cursor = _AsyncRowIterator(rows)

    # db.execute() is a plain (non-awaited) call that returns an async context
    # manager yielding the cursor. __aexit__ returning False lets exceptions
    # raised inside the ``async with`` propagate.
    ctx = AsyncMock()
    ctx.__aenter__ = AsyncMock(return_value=cursor)
    ctx.__aexit__ = AsyncMock(return_value=False)

    db = AsyncMock()
    db.execute = MagicMock(return_value=ctx)

    app = MagicMock()
    app.state.db = db
    app.state.http_session = MagicMock()
    app.state.settings = MagicMock(database_path="/tmp/fake.db")
    app.state.runtime_settings = None
    return app
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_re_resolve_no_unresolved_ips_skips() -> None:
    """With zero NULL-country IPs the task neither clears the negative cache nor looks anything up."""
    cache = GeoCache()
    fake_settings = MagicMock(database_path="/tmp/fake.db")
    session = MagicMock()

    get_patch = patch.object(
        cache, "get_unresolved_ips", new_callable=AsyncMock, return_value=[]
    )
    clear_patch = patch.object(cache, "clear_neg_cache", new_callable=AsyncMock)
    lookup_patch = patch.object(cache, "lookup_batch", new_callable=AsyncMock)

    with get_patch, clear_patch as clear_mock, lookup_patch as lookup_mock:
        await _run_re_resolve_with_resources(cache, fake_settings, session)

    # An empty unresolved list must short-circuit before any cache work.
    assert clear_mock.call_count == 0
    assert lookup_mock.call_count == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_re_resolve_clears_neg_cache() -> None:
    """The task must clear the negative cache before calling lookup_batch.

    The patched ``lookup_batch`` records how many times ``clear_neg_cache``
    had been called at the moment the lookup runs. This checks the relative
    order of the two calls, not merely that each happened.
    """
    ips = ["1.2.3.4", "5.6.7.8"]
    result: dict[str, GeoInfo] = {
        "1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="DTAG"),
        "5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
    }
    geo_cache = GeoCache()
    settings = MagicMock(database_path="/tmp/fake.db")
    http_session = MagicMock()

    # One entry per lookup_batch call: clear_neg_cache's call count at that time.
    clear_calls_at_lookup: list[int] = []

    async def _lookup_side_effect(*args: Any, **kwargs: Any) -> dict[str, GeoInfo]:
        # ``mock_clear`` is bound by the ``with`` statement before the task runs.
        clear_calls_at_lookup.append(mock_clear.call_count)
        return result

    with patch.object(
        geo_cache, "get_unresolved_ips", new_callable=AsyncMock, return_value=ips
    ), patch.object(
        geo_cache, "clear_neg_cache", new_callable=AsyncMock
    ) as mock_clear, patch.object(
        geo_cache, "lookup_batch", new_callable=AsyncMock, side_effect=_lookup_side_effect
    ):
        await _run_re_resolve_with_resources(geo_cache, settings, http_session)

    mock_clear.assert_called_once()
    assert clear_calls_at_lookup, "lookup_batch was never called"
    assert clear_calls_at_lookup[0] == 1, "lookup_batch ran before clear_neg_cache"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_re_resolve_calls_lookup_batch_with_db() -> None:
    """lookup_batch must receive the IPs, the HTTP session and a ``db`` keyword so results are persisted."""
    unresolved = ["10.0.0.1", "10.0.0.2"]
    lookup_result: dict[str, GeoInfo] = {
        "10.0.0.1": GeoInfo(country_code="FR", country_name="France", asn=None, org=None),
        "10.0.0.2": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
    }
    cache = GeoCache()
    fake_settings = MagicMock(database_path="/tmp/fake.db")
    session = MagicMock()

    get_patch = patch.object(
        cache, "get_unresolved_ips", new_callable=AsyncMock, return_value=unresolved
    )
    clear_patch = patch.object(cache, "clear_neg_cache", new_callable=AsyncMock)
    lookup_patch = patch.object(
        cache, "lookup_batch", new_callable=AsyncMock, return_value=lookup_result
    )

    with get_patch, clear_patch, lookup_patch as lookup_mock:
        await _run_re_resolve_with_resources(cache, fake_settings, session)

    # The db handle is created inside the task (by task_db), so only its
    # presence as a keyword argument is checked, not its identity.
    assert lookup_mock.called
    positional, keywords = lookup_mock.call_args
    assert positional[0] == unresolved
    assert positional[1] is session
    assert "db" in keywords
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_re_resolve_logs_correct_counts(caplog: Any) -> None:
    """The task should complete for a multi-IP batch and pass every IP to lookup_batch.

    The mocked lookup resolves two of the three IPs. The third comes back
    with ``country_code=None``.

    NOTE(review): despite its name, this test asserts nothing about log
    output, and the ``caplog`` fixture is unused. Checking the logged
    resolved/unresolved counts requires the task's exact log message
    format — TODO confirm it and add a ``caplog`` assertion.
    """
    ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
    result: dict[str, GeoInfo] = {
        "1.1.1.1": GeoInfo(country_code="AU", country_name="Australia", asn=None, org=None),
        "2.2.2.2": GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None),
        "3.3.3.3": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
    }
    geo_cache = GeoCache()
    settings = MagicMock(database_path="/tmp/fake.db")
    http_session = MagicMock()

    db = AsyncMock()

    # NOTE(review): this replaces ``task_db`` on ``app.tasks.db``. If
    # ``app.tasks.geo_re_resolve`` binds it with ``from app.tasks.db import
    # task_db``, the patch does not reach the task — confirm the target.
    with patch(
        "app.tasks.db.task_db",
        MagicMock(
            return_value=AsyncMock(
                __aenter__=AsyncMock(return_value=db),
                __aexit__=AsyncMock(return_value=False),
            )
        ),
    ), patch.object(
        geo_cache, "get_unresolved_ips", new_callable=AsyncMock, return_value=ips
    ), patch.object(
        geo_cache, "clear_neg_cache", new_callable=AsyncMock
    ), patch.object(
        geo_cache, "lookup_batch", new_callable=AsyncMock, return_value=result
    ) as mock_lookup:
        await _run_re_resolve_with_resources(geo_cache, settings, http_session)

    # Fail with a readable message rather than a TypeError on ``None`` if the
    # task never reached lookup_batch.
    assert mock_lookup.called, "lookup_batch was never called"
    # All three IPs are forwarded, in the order get_unresolved_ips returned them.
    assert mock_lookup.call_args.args[0] == ips
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_re_resolve_handles_all_resolved() -> None:
    """A batch in which every IP gets a country should run to completion.

    Expects exactly one negative-cache clear and exactly one lookup_batch call.
    """
    unresolved = ["4.4.4.4"]
    lookup_result: dict[str, GeoInfo] = {
        "4.4.4.4": GeoInfo(country_code="GB", country_name="United Kingdom", asn=None, org=None),
    }
    cache = GeoCache()
    fake_settings = MagicMock(database_path="/tmp/fake.db")
    session = MagicMock()

    # Stand-in DB connection: the patched task_db(...) returns an async
    # context manager yielding this mock.
    # NOTE(review): patching ``app.tasks.db.task_db`` only affects the task if
    # it looks the name up on that module at call time — confirm the target.
    fake_db = AsyncMock()
    db_context = AsyncMock(
        __aenter__=AsyncMock(return_value=fake_db),
        __aexit__=AsyncMock(return_value=False),
    )
    task_db_patch = patch("app.tasks.db.task_db", MagicMock(return_value=db_context))
    get_patch = patch.object(
        cache, "get_unresolved_ips", new_callable=AsyncMock, return_value=unresolved
    )
    clear_patch = patch.object(cache, "clear_neg_cache", new_callable=AsyncMock)
    lookup_patch = patch.object(
        cache, "lookup_batch", new_callable=AsyncMock, return_value=lookup_result
    )

    with task_db_patch, get_patch, clear_patch as clear_mock, lookup_patch as lookup_mock:
        await _run_re_resolve_with_resources(cache, fake_settings, session)

    clear_mock.assert_called_once()
    lookup_mock.assert_called_once()
|