"""Tests for the geo cache repository."""

from pathlib import Path

import aiosqlite
import pytest

from app.repositories import geo_cache_repo


async def _create_geo_cache_table(db: aiosqlite.Connection) -> None:
    """Create the ``geo_cache`` table used by these tests and commit.

    ``ip`` is the primary key and every geo column is nullable; throughout
    these tests a NULL ``country_code`` is how an unresolved (negative) entry
    is represented. ``cached_at`` defaults to SQLite's current UTC time,
    formatted as ISO-8601 with millisecond precision and a ``Z`` suffix.
    """
    # NOTE(review): this DDL is a hand-maintained copy of the production
    # schema. If the app exposes its own schema/migration helper, prefer it so
    # these tests cannot silently drift from the real table definition.
    await db.execute(
        """
        CREATE TABLE IF NOT EXISTS geo_cache (
            ip           TEXT PRIMARY KEY,
            country_code TEXT,
            country_name TEXT,
            asn          TEXT,
            org          TEXT,
            cached_at    TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
        )
        """
    )
    await db.commit()


@pytest.mark.asyncio
async def test_get_unresolved_ips_returns_empty_when_none_exist(tmp_path: Path) -> None:
    """A table holding only resolved rows yields no unresolved IPs."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await db.execute(
            "INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
            ("1.1.1.1", "DE", "Germany", "AS123", "Test"),
        )
        await db.commit()

    # Reopen so the repository reads committed data through a fresh connection.
    async with aiosqlite.connect(db_path) as db:
        ips = await geo_cache_repo.get_unresolved_ips(db)

    assert ips == []


@pytest.mark.asyncio
async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None:
    """Rows with a NULL ``country_code`` are reported as unresolved."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await db.executemany(
            "INSERT INTO geo_cache (ip, country_code) VALUES (?, ?)",
            [
                ("2.2.2.2", None),
                ("3.3.3.3", None),
                ("4.4.4.4", "US"),
            ],
        )
        await db.commit()

    async with aiosqlite.connect(db_path) as db:
        ips = await geo_cache_repo.get_unresolved_ips(db)

    # Result order is not part of the contract, so compare sorted.
    assert sorted(ips) == ["2.2.2.2", "3.3.3.3"]


@pytest.mark.asyncio
async def test_load_all_and_count_unresolved(tmp_path: Path) -> None:
    """``load_all`` exposes resolved rows; ``count_unresolved`` counts NULL-country rows."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await db.executemany(
            "INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
            [
                ("5.5.5.5", None, None, None, None),
                ("6.6.6.6", "FR", "France", "AS456", "TestOrg"),
            ],
        )
        await db.commit()

    async with aiosqlite.connect(db_path) as db:
        rows = await geo_cache_repo.load_all(db)
        unresolved = await geo_cache_repo.count_unresolved(db)

    assert unresolved == 1
    # Rows are accessed by column name, so load_all must return mapping-like rows.
    assert any(row["ip"] == "6.6.6.6" and row["country_code"] == "FR" for row in rows)


@pytest.mark.asyncio
async def test_upsert_entry_and_neg_entry(tmp_path: Path) -> None:
    """Single positive and negative upserts are persisted with the expected country code."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await geo_cache_repo.upsert_entry(
            db,
            "7.7.7.7",
            "GB",
            "United Kingdom",
            "AS789",
            "TestOrg",
        )
        await db.commit()
        await geo_cache_repo.upsert_neg_entry(db, "8.8.8.8")
        await db.commit()

        # Ensure positive entry is present.
        async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("7.7.7.7",)) as cur:
            row = await cur.fetchone()
        assert row is not None
        assert row[0] == "GB"

        # Ensure negative entry exists with NULL country_code.
        async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("8.8.8.8",)) as cur:
            row = await cur.fetchone()
        assert row is not None
        assert row[0] is None


@pytest.mark.asyncio
async def test_upsert_entry_overwrites_existing_row(tmp_path: Path) -> None:
    """A second ``upsert_entry`` for the same IP replaces the stored data.

    Exercises the conflict ("update") half of the upsert, which the other
    tests never reach because each IP is written only once there.
    """
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await geo_cache_repo.upsert_entry(db, "13.13.13.13", "US", "United States", "AS1", "OldOrg")
        await db.commit()
        await geo_cache_repo.upsert_entry(db, "13.13.13.13", "CA", "Canada", "AS2", "NewOrg")
        await db.commit()

        async with db.execute(
            "SELECT country_code, country_name, asn, org FROM geo_cache WHERE ip = ?",
            ("13.13.13.13",),
        ) as cur:
            rows = await cur.fetchall()

    # Exactly one row (ip is the primary key) carrying the second write's data.
    # tuple() normalises the row in case a row_factory was installed.
    assert [tuple(r) for r in rows] == [("CA", "Canada", "AS2", "NewOrg")]


@pytest.mark.asyncio
async def test_bulk_upsert_entries_and_neg_entries(tmp_path: Path) -> None:
    """Bulk upserts report their counts and store positive/negative rows correctly."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        rows = [
            ("9.9.9.9", "NL", "Netherlands", "AS101", "Test"),
            ("10.10.10.10", "JP", "Japan", "AS102", "Test"),
        ]
        count = await geo_cache_repo.bulk_upsert_entries(db, rows)
        assert count == 2

        neg_count = await geo_cache_repo.bulk_upsert_neg_entries(db, ["11.11.11.11", "12.12.12.12"])
        assert neg_count == 2
        await db.commit()

        async with db.execute("SELECT ip, country_code FROM geo_cache") as cur:
            stored = dict(await cur.fetchall())

    # An exact mapping checks the row count (4) as well as each row's content:
    # positive rows keep their country, negative rows have a NULL country_code.
    assert stored == {
        "9.9.9.9": "NL",
        "10.10.10.10": "JP",
        "11.11.11.11": None,
        "12.12.12.12": None,
    }