diff --git a/Docs/Tasks.md b/Docs/Tasks.md index ac2bb82..f773def 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -92,6 +92,8 @@ This document breaks the entire BanGUI project into development stages, ordered #### TASK B-5 — Create a `geo_cache_repo` and remove direct SQL from `geo_service.py` +**Status:** Completed ✅ + **Violated rule:** Refactoring.md §2.2 — Services must not execute raw SQL; go through a repository. **Files affected:** @@ -113,6 +115,8 @@ This document breaks the entire BanGUI project into development stages, ordered #### TASK B-6 — Remove direct SQL from `tasks/geo_re_resolve.py` +**Status:** Completed ✅ + **Violated rule:** Refactoring.md §2.5 — Tasks must not use repositories directly; they must call a service method. **Files affected:** @@ -163,6 +167,8 @@ Remove or rewrite the docstring snippet so it does not contain a bare `print()` #### TASK B-9 — Remove direct SQL from `main.py` lifespan into `geo_service` +**Status:** Completed ✅ + **Violated rule:** Refactoring.md §2 — Application startup code must not execute raw SQL; data-access logic belongs in a repository (or, when count semantics belong to a domain concern, a service method). **Files affected:** diff --git a/backend/app/main.py b/backend/app/main.py index d486cde..db5531f 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -162,11 +162,7 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]: await geo_service.load_cache_from_db(db) # Log unresolved geo entries so the operator can see the scope of the issue. - async with db.execute( - "SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL" - ) as cur: - row = await cur.fetchone() - unresolved_count: int = int(row[0]) if row else 0 + unresolved_count = await geo_service.count_unresolved(db) if unresolved_count > 0: log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count) diff --git a/backend/app/repositories/geo_cache_repo.py b/backend/app/repositories/geo_cache_repo.py index 8e7ed8d..51de260 100644 --- a/backend/app/repositories/geo_cache_repo.py +++ b/backend/app/repositories/geo_cache_repo.py @@ -9,12 +9,48 @@ connection lifetimes. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, TypedDict if TYPE_CHECKING: import aiosqlite +class GeoCacheRow(TypedDict): + """A single row from the ``geo_cache`` table.""" + + ip: str + country_code: str | None + country_name: str | None + asn: str | None + org: str | None + + +async def load_all(db: aiosqlite.Connection) -> list[GeoCacheRow]: + """Load all geo cache rows from the database. + + Args: + db: Open BanGUI application database connection. + + Returns: + List of rows from the ``geo_cache`` table. + """ + rows: list[GeoCacheRow] = [] + async with db.execute( + "SELECT ip, country_code, country_name, asn, org FROM geo_cache" + ) as cur: + async for row in cur: + rows.append( + GeoCacheRow( + ip=str(row[0]), + country_code=row[1], + country_name=row[2], + asn=row[3], + org=row[4], + ) + ) + return rows + + async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]: """Return all IPs in ``geo_cache`` where ``country_code`` is NULL. @@ -31,3 +67,80 @@ async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]: async for row in cur: ips.append(str(row[0])) return ips + + +async def count_unresolved(db: aiosqlite.Connection) -> int: + """Return the number of unresolved rows (country_code IS NULL).""" + async with db.execute( + "SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL" + ) as cur: + row = await cur.fetchone() + return int(row[0]) if row else 0 + + +async def upsert_entry( + db: aiosqlite.Connection, + ip: str, + country_code: str | None, + country_name: str | None, + asn: str | None, + org: str | None, +) -> None: + """Insert or update a resolved geo cache entry.""" + await db.execute( + """ + INSERT INTO geo_cache (ip, country_code, country_name, asn, org) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(ip) DO UPDATE SET + country_code = excluded.country_code, + country_name = excluded.country_name, + asn = excluded.asn, + org = excluded.org, + cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') + """, + (ip, country_code, country_name, asn, org), + ) + + +async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None: + """Record a failed lookup attempt as a negative entry.""" + await db.execute( + "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", + (ip,), + ) + + +async def bulk_upsert_entries( + db: aiosqlite.Connection, + rows: list[tuple[str, str | None, str | None, str | None, str | None]], +) -> int: + """Bulk insert or update multiple geo cache entries.""" + if not rows: + return 0 + + await db.executemany( + """ + INSERT INTO geo_cache (ip, country_code, country_name, asn, org) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(ip) DO UPDATE SET + country_code = excluded.country_code, + country_name = excluded.country_name, + asn = excluded.asn, + org = excluded.org, + cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') + """, + rows, + ) + return len(rows) + + +async def bulk_upsert_neg_entries(db: aiosqlite.Connection, ips: list[str]) -> int: + """Bulk insert negative lookup entries.""" + if not ips: + return 0 + + await db.executemany( + "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", + [(ip,) for ip in ips], + ) + return len(ips) diff --git a/backend/app/services/geo_service.py b/backend/app/services/geo_service.py index 95f5927..fa66b92 100644 --- a/backend/app/services/geo_service.py +++ b/backend/app/services/geo_service.py @@ -186,11 +186,7 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]: Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``, and ``dirty_size``. """ - async with db.execute( - "SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL" - ) as cur: - row = await cur.fetchone() - unresolved: int = int(row[0]) if row else 0 + unresolved = await geo_cache_repo.count_unresolved(db) return { "cache_size": len(_cache), @@ -200,6 +196,12 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]: } +async def count_unresolved(db: aiosqlite.Connection) -> int: + """Return the number of unresolved entries in the persistent geo cache.""" + + return await geo_cache_repo.count_unresolved(db) + + async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]: """Return geo cache IPs where the country code has not yet been resolved. @@ -282,21 +284,18 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None: database (not the fail2ban database). """ count = 0 - async with db.execute( - "SELECT ip, country_code, country_name, asn, org FROM geo_cache" - ) as cur: - async for row in cur: - ip: str = str(row[0]) - country_code: str | None = row[1] - if country_code is None: - continue - _cache[ip] = GeoInfo( - country_code=country_code, - country_name=row[2], - asn=row[3], - org=row[4], - ) - count += 1 + for row in await geo_cache_repo.load_all(db): + country_code: str | None = row["country_code"] + if country_code is None: + continue + ip: str = row["ip"] + _cache[ip] = GeoInfo( + country_code=country_code, + country_name=row["country_name"], + asn=row["asn"], + org=row["org"], + ) + count += 1 log.info("geo_cache_loaded_from_db", entries=count) @@ -315,18 +314,13 @@ async def _persist_entry( ip: IP address string. info: Resolved geo data to persist. """ - await db.execute( - """ - INSERT INTO geo_cache (ip, country_code, country_name, asn, org) - VALUES (?, ?, ?, ?, ?) - ON CONFLICT(ip) DO UPDATE SET - country_code = excluded.country_code, - country_name = excluded.country_name, - asn = excluded.asn, - org = excluded.org, - cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') - """, - (ip, info.country_code, info.country_name, info.asn, info.org), + await geo_cache_repo.upsert_entry( + db=db, + ip=ip, + country_code=info.country_code, + country_name=info.country_name, + asn=info.asn, + org=info.org, ) @@ -340,10 +334,7 @@ async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None: db: BanGUI application database connection. ip: IP address string whose resolution failed. """ - await db.execute( - "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", - (ip,), - ) + await geo_cache_repo.upsert_neg_entry(db=db, ip=ip) # --------------------------------------------------------------------------- @@ -599,19 +590,7 @@ async def lookup_batch( if db is not None: if pos_rows: try: - await db.executemany( - """ - INSERT INTO geo_cache (ip, country_code, country_name, asn, org) - VALUES (?, ?, ?, ?, ?) - ON CONFLICT(ip) DO UPDATE SET - country_code = excluded.country_code, - country_name = excluded.country_name, - asn = excluded.asn, - org = excluded.org, - cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') - """, - pos_rows, - ) + await geo_cache_repo.bulk_upsert_entries(db, pos_rows) except Exception as exc: # noqa: BLE001 log.warning( "geo_batch_persist_failed", @@ -620,10 +599,7 @@ async def lookup_batch( ) if neg_ips: try: - await db.executemany( - "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", - [(ip,) for ip in neg_ips], - ) + await geo_cache_repo.bulk_upsert_neg_entries(db, neg_ips) except Exception as exc: # noqa: BLE001 log.warning( "geo_batch_persist_neg_failed", @@ -806,19 +782,7 @@ async def flush_dirty(db: aiosqlite.Connection) -> int: return 0 try: - await db.executemany( - """ - INSERT INTO geo_cache (ip, country_code, country_name, asn, org) - VALUES (?, ?, ?, ?, ?) - ON CONFLICT(ip) DO UPDATE SET - country_code = excluded.country_code, - country_name = excluded.country_name, - asn = excluded.asn, - org = excluded.org, - cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now') - """, - rows, - ) + await geo_cache_repo.bulk_upsert_entries(db, rows) await db.commit() except Exception as exc: # noqa: BLE001 log.warning("geo_flush_dirty_failed", error=str(exc)) diff --git a/backend/app/tasks/geo_re_resolve.py b/backend/app/tasks/geo_re_resolve.py index b0880e6..c01f6fc 100644 --- a/backend/app/tasks/geo_re_resolve.py +++ b/backend/app/tasks/geo_re_resolve.py @@ -49,12 +49,7 @@ async def _run_re_resolve(app: Any) -> None: http_session = app.state.http_session # Fetch all IPs with NULL country_code from the persistent cache. - unresolved_ips: list[str] = [] - async with db.execute( - "SELECT ip FROM geo_cache WHERE country_code IS NULL" - ) as cursor: - async for row in cursor: - unresolved_ips.append(str(row[0])) + unresolved_ips = await geo_service.get_unresolved_ips(db) if not unresolved_ips: log.debug("geo_re_resolve_skip", reason="no_unresolved_ips") diff --git a/backend/tests/test_repositories/test_geo_cache_repo.py b/backend/tests/test_repositories/test_geo_cache_repo.py index 2e070b9..fac8277 100644 --- a/backend/tests/test_repositories/test_geo_cache_repo.py +++ b/backend/tests/test_repositories/test_geo_cache_repo.py @@ -60,3 +60,81 @@ async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None: ips = await geo_cache_repo.get_unresolved_ips(db) assert sorted(ips) == ["2.2.2.2", "3.3.3.3"] + + +@pytest.mark.asyncio +async def test_load_all_and_count_unresolved(tmp_path: Path) -> None: + db_path = str(tmp_path / "geo_cache.db") + async with aiosqlite.connect(db_path) as db: + await _create_geo_cache_table(db) + await db.executemany( + "INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)", + [ + ("5.5.5.5", None, None, None, None), + ("6.6.6.6", "FR", "France", "AS456", "TestOrg"), + ], + ) + await db.commit() + + async with aiosqlite.connect(db_path) as db: + rows = await geo_cache_repo.load_all(db) + unresolved = await geo_cache_repo.count_unresolved(db) + + assert unresolved == 1 + assert any(row["ip"] == "6.6.6.6" and row["country_code"] == "FR" for row in rows) + + +@pytest.mark.asyncio +async def test_upsert_entry_and_neg_entry(tmp_path: Path) -> None: + db_path = str(tmp_path / "geo_cache.db") + async with aiosqlite.connect(db_path) as db: + await _create_geo_cache_table(db) + + await geo_cache_repo.upsert_entry( + db, + "7.7.7.7", + "GB", + "United Kingdom", + "AS789", + "TestOrg", + ) + await db.commit() + + await geo_cache_repo.upsert_neg_entry(db, "8.8.8.8") + await db.commit() + + # Ensure positive entry is present. + async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("7.7.7.7",)) as cur: + row = await cur.fetchone() + assert row is not None + assert row[0] == "GB" + + # Ensure negative entry exists with NULL country_code. + async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("8.8.8.8",)) as cur: + row = await cur.fetchone() + assert row is not None + assert row[0] is None + + +@pytest.mark.asyncio +async def test_bulk_upsert_entries_and_neg_entries(tmp_path: Path) -> None: + db_path = str(tmp_path / "geo_cache.db") + async with aiosqlite.connect(db_path) as db: + await _create_geo_cache_table(db) + + rows = [ + ("9.9.9.9", "NL", "Netherlands", "AS101", "Test"), + ("10.10.10.10", "JP", "Japan", "AS102", "Test"), + ] + count = await geo_cache_repo.bulk_upsert_entries(db, rows) + assert count == 2 + + neg_count = await geo_cache_repo.bulk_upsert_neg_entries(db, ["11.11.11.11", "12.12.12.12"]) + assert neg_count == 2 + + await db.commit() + + async with db.execute("SELECT COUNT(*) FROM geo_cache") as cur: + row = await cur.fetchone() + assert row is not None + assert int(row[0]) == 4