"""Tests for the geo cache repository."""

from pathlib import Path

import aiosqlite
import pytest

from app.repositories import geo_cache_repo


async def _create_geo_cache_table(db: aiosqlite.Connection) -> None:
    """Create the ``geo_cache`` table on *db* (if missing) and commit.

    NOTE(review): this is a test-local copy of the schema; it is assumed to
    match what ``geo_cache_repo`` expects — keep it in sync with the real one.
    """
    await db.execute(
        """
        CREATE TABLE IF NOT EXISTS geo_cache (
            ip TEXT PRIMARY KEY,
            country_code TEXT,
            country_name TEXT,
            asn TEXT,
            org TEXT,
            cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')),
            last_seen TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
        )
        """
    )
    await db.commit()


@pytest.mark.asyncio
async def test_get_unresolved_ips_returns_empty_when_none_exist(tmp_path: Path) -> None:
    """A cache whose only row has a country_code yields no unresolved IPs."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await db.execute(
            "INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
            ("1.1.1.1", "DE", "Germany", "AS123", "Test"),
        )
        await db.commit()

    async with aiosqlite.connect(db_path) as db:
        ips = await geo_cache_repo.get_unresolved_ips(db)

    assert ips == []


@pytest.mark.asyncio
async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None:
    """Rows with a NULL country_code are reported; resolved rows are not."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await db.executemany(
            "INSERT INTO geo_cache (ip, country_code) VALUES (?, ?)",
            [
                ("2.2.2.2", None),
                ("3.3.3.3", None),
                ("4.4.4.4", "US"),
            ],
        )
        await db.commit()

    async with aiosqlite.connect(db_path) as db:
        ips = await geo_cache_repo.get_unresolved_ips(db)

    # Result order is not part of the contract being tested.
    assert sorted(ips) == ["2.2.2.2", "3.3.3.3"]


@pytest.mark.asyncio
async def test_load_all_and_count_unresolved(tmp_path: Path) -> None:
    """load_all returns resolved rows and count_unresolved counts NULL rows."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await db.executemany(
            "INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
            [
                ("5.5.5.5", None, None, None, None),
                ("6.6.6.6", "FR", "France", "AS456", "TestOrg"),
            ],
        )
        await db.commit()

    async with aiosqlite.connect(db_path) as db:
        rows = await geo_cache_repo.load_all(db)
        unresolved = await geo_cache_repo.count_unresolved(db)

    assert unresolved == 1
    # Rows are indexed by column name, so load_all is expected to return mappings.
    assert any(row["ip"] == "6.6.6.6" and row["country_code"] == "FR" for row in rows)


@pytest.mark.asyncio
async def test_upsert_entry_and_neg_entry(tmp_path: Path) -> None:
    """upsert_entry stores a country; upsert_neg_entry stores a NULL country."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await geo_cache_repo.upsert_entry(
            db,
            "7.7.7.7",
            "GB",
            "United Kingdom",
            "AS789",
            "TestOrg",
        )
        await db.commit()
        await geo_cache_repo.upsert_neg_entry(db, "8.8.8.8")
        await db.commit()

        # Ensure positive entry is present.
        async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("7.7.7.7",)) as cur:
            row = await cur.fetchone()
        assert row is not None
        assert row[0] == "GB"

        # Ensure negative entry exists with NULL country_code.
        async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("8.8.8.8",)) as cur:
            row = await cur.fetchone()
        assert row is not None
        assert row[0] is None


@pytest.mark.asyncio
async def test_upsert_entry_and_commit_commits_transaction(tmp_path: Path) -> None:
    """upsert_entry_and_commit makes the row durable without a caller commit."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await geo_cache_repo.upsert_entry_and_commit(
            db,
            "13.13.13.13",
            "NL",
            "Netherlands",
            "AS1313",
            "TestOrg",
        )

    # Read back through a *fresh* connection. A connection can see its own
    # uncommitted writes, so reading on the writer would pass even if the
    # commit were missing. Other connections only see committed data.
    async with aiosqlite.connect(db_path) as db, db.execute(
        "SELECT country_code FROM geo_cache WHERE ip = ?", ("13.13.13.13",)
    ) as cur:
        row = await cur.fetchone()

    assert row is not None
    assert row[0] == "NL"


@pytest.mark.asyncio
async def test_bulk_upsert_entries_and_neg_entries_and_commit_commits_once(tmp_path: Path) -> None:
    """The combined bulk upsert returns both counts and commits its writes.

    NOTE(review): this checks the commit's *effect* (rows visible from another
    connection), not how many commits were issued.
    """
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        rows = [
            ("14.14.14.14", "BE", "Belgium", "AS1414", "Test"),
        ]
        count, neg_count = await geo_cache_repo.bulk_upsert_entries_and_neg_entries_and_commit(
            db,
            rows,
            ["15.15.15.15"],
        )
        assert count == 1
        assert neg_count == 1

    # Verify on a fresh connection so uncommitted writes cannot satisfy the check.
    async with aiosqlite.connect(db_path) as db, db.execute("SELECT COUNT(*) FROM geo_cache") as cur:
        row = await cur.fetchone()

    assert row is not None
    assert int(row[0]) == 2


@pytest.mark.asyncio
async def test_bulk_upsert_entries_and_neg_entries(tmp_path: Path) -> None:
    """The separate bulk upserts return their row counts and insert every row."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        rows = [
            ("9.9.9.9", "NL", "Netherlands", "AS101", "Test"),
            ("10.10.10.10", "JP", "Japan", "AS102", "Test"),
        ]
        count = await geo_cache_repo.bulk_upsert_entries(db, rows)
        assert count == 2

        neg_count = await geo_cache_repo.bulk_upsert_neg_entries(db, ["11.11.11.11", "12.12.12.12"])
        assert neg_count == 2

        # These helpers leave committing to the caller.
        await db.commit()

        async with db.execute("SELECT COUNT(*) FROM geo_cache") as cur:
            row = await cur.fetchone()
        assert row is not None
        assert int(row[0]) == 4


@pytest.mark.asyncio
async def test_delete_stale_entries_removes_old_entries(tmp_path: Path) -> None:
    """Only rows whose last_seen precedes the cutoff are deleted."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        # Insert entries with various last_seen times
        await db.execute(
            "INSERT INTO geo_cache (ip, country_code, last_seen) VALUES (?, ?, ?)",
            ("1.1.1.1", "US", "2020-01-01T00:00:00Z"),
        )
        await db.execute(
            "INSERT INTO geo_cache (ip, country_code, last_seen) VALUES (?, ?, ?)",
            ("2.2.2.2", "DE", "2024-12-01T00:00:00Z"),
        )
        await db.execute(
            "INSERT INTO geo_cache (ip, country_code, last_seen) VALUES (?, ?, ?)",
            ("3.3.3.3", "FR", "2025-01-01T00:00:00Z"),
        )
        await db.commit()

    async with aiosqlite.connect(db_path) as db:
        # Delete entries older than 2024-06-01
        deleted = await geo_cache_repo.delete_stale_entries(db, "2024-06-01T00:00:00Z")
        await db.commit()

    assert deleted == 1

    # Verify the correct entry was deleted; ORDER BY already sorts the result.
    async with aiosqlite.connect(db_path) as db, db.execute("SELECT ip FROM geo_cache ORDER BY ip") as cur:
        rows = await cur.fetchall()

    ips = [row[0] for row in rows]
    assert ips == ["2.2.2.2", "3.3.3.3"]
@pytest.mark.asyncio
async def test_delete_stale_entries_returns_zero_when_none_stale(tmp_path: Path) -> None:
    """Nothing is removed when every row's last_seen is newer than the cutoff."""
    database_file = str(tmp_path / "geo_cache.db")

    # Seed two rows whose last_seen is well after the cutoff used below.
    recent_rows = [
        ("1.1.1.1", "US", "2025-01-01T00:00:00Z"),
        ("2.2.2.2", "DE", "2025-01-02T00:00:00Z"),
    ]
    async with aiosqlite.connect(database_file) as conn:
        await _create_geo_cache_table(conn)
        await conn.executemany(
            "INSERT INTO geo_cache (ip, country_code, last_seen) VALUES (?, ?, ?)",
            recent_rows,
        )
        await conn.commit()

    # A 2020 cutoff predates both rows, so the delete should match nothing.
    async with aiosqlite.connect(database_file) as conn:
        removed = await geo_cache_repo.delete_stale_entries(conn, "2020-01-01T00:00:00Z")
        await conn.commit()
    assert removed == 0

    # Both seeded rows must still be present.
    async with aiosqlite.connect(database_file) as conn:
        async with conn.execute("SELECT COUNT(*) FROM geo_cache") as cursor:
            count_row = await cursor.fetchone()
    assert count_row is not None
    assert int(count_row[0]) == 2


@pytest.mark.asyncio
async def test_delete_stale_entries_with_empty_table(tmp_path: Path) -> None:
    """Deleting from an empty table reports zero removed rows."""
    database_file = str(tmp_path / "geo_cache.db")

    # Create the schema only; no rows are inserted.
    async with aiosqlite.connect(database_file) as conn:
        await _create_geo_cache_table(conn)

    async with aiosqlite.connect(database_file) as conn:
        removed = await geo_cache_repo.delete_stale_entries(conn, "2024-01-01T00:00:00Z")
        await conn.commit()
    assert removed == 0