Refactor geo cache persistence into repository + remove raw SQL from tasks/main, update task list
This commit is contained in:
@@ -92,6 +92,8 @@ This document breaks the entire BanGUI project into development stages, ordered
|
|||||||
|
|
||||||
#### TASK B-5 — Create a `geo_cache_repo` and remove direct SQL from `geo_service.py`
|
#### TASK B-5 — Create a `geo_cache_repo` and remove direct SQL from `geo_service.py`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Refactoring.md §2.2 — Services must not execute raw SQL; go through a repository.
|
**Violated rule:** Refactoring.md §2.2 — Services must not execute raw SQL; go through a repository.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -113,6 +115,8 @@ This document breaks the entire BanGUI project into development stages, ordered
|
|||||||
|
|
||||||
#### TASK B-6 — Remove direct SQL from `tasks/geo_re_resolve.py`
|
#### TASK B-6 — Remove direct SQL from `tasks/geo_re_resolve.py`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Refactoring.md §2.5 — Tasks must not use repositories directly; they must call a service method.
|
**Violated rule:** Refactoring.md §2.5 — Tasks must not use repositories directly; they must call a service method.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
@@ -163,6 +167,8 @@ Remove or rewrite the docstring snippet so it does not contain a bare `print()`
|
|||||||
|
|
||||||
#### TASK B-9 — Remove direct SQL from `main.py` lifespan into `geo_service`
|
#### TASK B-9 — Remove direct SQL from `main.py` lifespan into `geo_service`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Refactoring.md §2 — Application startup code must not execute raw SQL; data-access logic belongs in a repository (or, when count semantics belong to a domain concern, a service method).
|
**Violated rule:** Refactoring.md §2 — Application startup code must not execute raw SQL; data-access logic belongs in a repository (or, when count semantics belong to a domain concern, a service method).
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
|
|||||||
@@ -161,11 +161,7 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
await geo_service.load_cache_from_db(db)
|
await geo_service.load_cache_from_db(db)
|
||||||
|
|
||||||
# Log unresolved geo entries so the operator can see the scope of the issue.
|
# Log unresolved geo entries so the operator can see the scope of the issue.
|
||||||
async with db.execute(
|
unresolved_count = await geo_service.count_unresolved(db)
|
||||||
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
|
||||||
) as cur:
|
|
||||||
row = await cur.fetchone()
|
|
||||||
unresolved_count: int = int(row[0]) if row else 0
|
|
||||||
if unresolved_count > 0:
|
if unresolved_count > 0:
|
||||||
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
|
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
|
||||||
|
|
||||||
|
|||||||
@@ -9,12 +9,48 @@ connection lifetimes.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, TypedDict
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
|
||||||
|
class GeoCacheRow(TypedDict):
    """A single row from the ``geo_cache`` table.

    The keys correspond to the columns selected by :func:`load_all`.
    """

    # IP address the row describes; ``load_all`` always coerces it to ``str``.
    ip: str
    # ``None`` (SQL NULL) marks a row that has not been resolved yet.
    country_code: str | None
    country_name: str | None
    asn: str | None
    org: str | None
|
||||||
|
|
||||||
|
|
||||||
|
async def load_all(db: aiosqlite.Connection) -> list[GeoCacheRow]:
    """Fetch every row stored in the ``geo_cache`` table.

    No filtering is applied: rows whose ``country_code`` is NULL are
    returned alongside resolved ones.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        One :class:`GeoCacheRow` per table row, in the order the cursor
        yields them.
    """
    query = "SELECT ip, country_code, country_name, asn, org FROM geo_cache"
    async with db.execute(query) as cursor:
        return [
            GeoCacheRow(
                ip=str(record[0]),
                country_code=record[1],
                country_name=record[2],
                asn=record[3],
                org=record[4],
            )
            async for record in cursor
        ]
|
||||||
|
|
||||||
|
|
||||||
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
||||||
"""Return all IPs in ``geo_cache`` where ``country_code`` is NULL.
|
"""Return all IPs in ``geo_cache`` where ``country_code`` is NULL.
|
||||||
|
|
||||||
@@ -31,3 +67,80 @@ async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
|||||||
async for row in cur:
|
async for row in cur:
|
||||||
ips.append(str(row[0]))
|
ips.append(str(row[0]))
|
||||||
return ips
|
return ips
|
||||||
|
|
||||||
|
|
||||||
|
async def count_unresolved(db: aiosqlite.Connection) -> int:
    """Count the ``geo_cache`` rows whose ``country_code`` is NULL.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        Number of unresolved rows, or ``0`` when the query yields no row.
    """
    query = "SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
    async with db.execute(query) as cursor:
        result = await cursor.fetchone()
    if not result:
        return 0
    return int(result[0])
|
||||||
|
|
||||||
|
|
||||||
|
async def upsert_entry(
    db: aiosqlite.Connection,
    ip: str,
    country_code: str | None,
    country_name: str | None,
    asn: str | None,
    org: str | None,
) -> None:
    """Insert a geo cache row for *ip*, or overwrite the existing one.

    On a conflict on ``ip`` every geo column is replaced with the supplied
    value and ``cached_at`` is refreshed to the current time. The statement
    is executed but not committed; committing is left to the caller.

    Args:
        db: Open BanGUI application database connection.
        ip: IP address string identifying the row.
        country_code: Country code to store, or ``None``.
        country_name: Country name to store, or ``None``.
        asn: Autonomous system identifier to store, or ``None``.
        org: Organisation name to store, or ``None``.
    """
    statement = """
        INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT(ip) DO UPDATE SET
            country_code = excluded.country_code,
            country_name = excluded.country_name,
            asn          = excluded.asn,
            org          = excluded.org,
            cached_at    = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
        """
    values = (ip, country_code, country_name, asn, org)
    await db.execute(statement, values)
|
||||||
|
|
||||||
|
|
||||||
|
async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
    """Record a failed lookup for *ip* as a negative cache entry.

    Only the ``ip`` column is supplied, so the geo columns keep whatever
    default the table defines. Despite the name, this is an
    ``INSERT OR IGNORE``: a row that would violate a table constraint is
    skipped rather than updated. The statement is not committed here.

    Args:
        db: Open BanGUI application database connection.
        ip: IP address string whose resolution failed.
    """
    statement = "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)"
    values = (ip,)
    await db.execute(statement, values)
|
||||||
|
|
||||||
|
|
||||||
|
async def bulk_upsert_entries(
    db: aiosqlite.Connection,
    rows: list[tuple[str, str | None, str | None, str | None, str | None]],
) -> int:
    """Insert or overwrite many geo cache rows with a single ``executemany``.

    Each tuple is ``(ip, country_code, country_name, asn, org)``. A conflict
    on ``ip`` replaces the geo columns and refreshes ``cached_at``. The
    statements are not committed here.

    Args:
        db: Open BanGUI application database connection.
        rows: Parameter tuples to write; may be empty.

    Returns:
        Number of tuples submitted (``len(rows)``), or ``0`` for empty input,
        in which case the database is not touched.
    """
    if not rows:
        # Nothing to write; skip the database call entirely.
        return 0

    statement = """
        INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT(ip) DO UPDATE SET
            country_code = excluded.country_code,
            country_name = excluded.country_name,
            asn          = excluded.asn,
            org          = excluded.org,
            cached_at    = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
        """
    await db.executemany(statement, rows)
    submitted = len(rows)
    return submitted
|
||||||
|
|
||||||
|
|
||||||
|
async def bulk_upsert_neg_entries(db: aiosqlite.Connection, ips: list[str]) -> int:
    """Record many failed lookups as negative cache entries.

    Runs ``INSERT OR IGNORE`` once per address through ``executemany``. The
    statements are not committed here.

    Args:
        db: Open BanGUI application database connection.
        ips: IP address strings whose resolution failed; may be empty.

    Returns:
        Number of addresses submitted (``len(ips)``), which can exceed the
        number of rows actually inserted when some are ignored. ``0`` for
        empty input, in which case the database is not touched.
    """
    if not ips:
        # Nothing to write; skip the database call entirely.
        return 0

    statement = "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)"
    params = [(address,) for address in ips]
    await db.executemany(statement, params)
    return len(ips)
|
||||||
|
|||||||
@@ -186,11 +186,7 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
|
|||||||
Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``,
|
Dict with keys ``cache_size``, ``unresolved``, ``neg_cache_size``,
|
||||||
and ``dirty_size``.
|
and ``dirty_size``.
|
||||||
"""
|
"""
|
||||||
async with db.execute(
|
unresolved = await geo_cache_repo.count_unresolved(db)
|
||||||
"SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
|
|
||||||
) as cur:
|
|
||||||
row = await cur.fetchone()
|
|
||||||
unresolved: int = int(row[0]) if row else 0
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"cache_size": len(_cache),
|
"cache_size": len(_cache),
|
||||||
@@ -200,6 +196,12 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def count_unresolved(db: aiosqlite.Connection) -> int:
    """Report how many persistent geo cache entries are still unresolved.

    Thin service-level pass-through to ``geo_cache_repo.count_unresolved``.

    Args:
        db: BanGUI application database connection.

    Returns:
        The count reported by the repository.
    """
    pending = await geo_cache_repo.count_unresolved(db)
    return pending
|
||||||
|
|
||||||
|
|
||||||
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
|
||||||
"""Return geo cache IPs where the country code has not yet been resolved.
|
"""Return geo cache IPs where the country code has not yet been resolved.
|
||||||
|
|
||||||
@@ -282,19 +284,16 @@ async def load_cache_from_db(db: aiosqlite.Connection) -> None:
|
|||||||
database (not the fail2ban database).
|
database (not the fail2ban database).
|
||||||
"""
|
"""
|
||||||
count = 0
|
count = 0
|
||||||
async with db.execute(
|
for row in await geo_cache_repo.load_all(db):
|
||||||
"SELECT ip, country_code, country_name, asn, org FROM geo_cache"
|
country_code: str | None = row["country_code"]
|
||||||
) as cur:
|
|
||||||
async for row in cur:
|
|
||||||
ip: str = str(row[0])
|
|
||||||
country_code: str | None = row[1]
|
|
||||||
if country_code is None:
|
if country_code is None:
|
||||||
continue
|
continue
|
||||||
|
ip: str = row["ip"]
|
||||||
_cache[ip] = GeoInfo(
|
_cache[ip] = GeoInfo(
|
||||||
country_code=country_code,
|
country_code=country_code,
|
||||||
country_name=row[2],
|
country_name=row["country_name"],
|
||||||
asn=row[3],
|
asn=row["asn"],
|
||||||
org=row[4],
|
org=row["org"],
|
||||||
)
|
)
|
||||||
count += 1
|
count += 1
|
||||||
log.info("geo_cache_loaded_from_db", entries=count)
|
log.info("geo_cache_loaded_from_db", entries=count)
|
||||||
@@ -315,18 +314,13 @@ async def _persist_entry(
|
|||||||
ip: IP address string.
|
ip: IP address string.
|
||||||
info: Resolved geo data to persist.
|
info: Resolved geo data to persist.
|
||||||
"""
|
"""
|
||||||
await db.execute(
|
await geo_cache_repo.upsert_entry(
|
||||||
"""
|
db=db,
|
||||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
ip=ip,
|
||||||
VALUES (?, ?, ?, ?, ?)
|
country_code=info.country_code,
|
||||||
ON CONFLICT(ip) DO UPDATE SET
|
country_name=info.country_name,
|
||||||
country_code = excluded.country_code,
|
asn=info.asn,
|
||||||
country_name = excluded.country_name,
|
org=info.org,
|
||||||
asn = excluded.asn,
|
|
||||||
org = excluded.org,
|
|
||||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
|
||||||
""",
|
|
||||||
(ip, info.country_code, info.country_name, info.asn, info.org),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -340,10 +334,7 @@ async def _persist_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
|||||||
db: BanGUI application database connection.
|
db: BanGUI application database connection.
|
||||||
ip: IP address string whose resolution failed.
|
ip: IP address string whose resolution failed.
|
||||||
"""
|
"""
|
||||||
await db.execute(
|
await geo_cache_repo.upsert_neg_entry(db=db, ip=ip)
|
||||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
|
||||||
(ip,),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -599,19 +590,7 @@ async def lookup_batch(
|
|||||||
if db is not None:
|
if db is not None:
|
||||||
if pos_rows:
|
if pos_rows:
|
||||||
try:
|
try:
|
||||||
await db.executemany(
|
await geo_cache_repo.bulk_upsert_entries(db, pos_rows)
|
||||||
"""
|
|
||||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
ON CONFLICT(ip) DO UPDATE SET
|
|
||||||
country_code = excluded.country_code,
|
|
||||||
country_name = excluded.country_name,
|
|
||||||
asn = excluded.asn,
|
|
||||||
org = excluded.org,
|
|
||||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
|
||||||
""",
|
|
||||||
pos_rows,
|
|
||||||
)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"geo_batch_persist_failed",
|
"geo_batch_persist_failed",
|
||||||
@@ -620,10 +599,7 @@ async def lookup_batch(
|
|||||||
)
|
)
|
||||||
if neg_ips:
|
if neg_ips:
|
||||||
try:
|
try:
|
||||||
await db.executemany(
|
await geo_cache_repo.bulk_upsert_neg_entries(db, neg_ips)
|
||||||
"INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)",
|
|
||||||
[(ip,) for ip in neg_ips],
|
|
||||||
)
|
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning(
|
log.warning(
|
||||||
"geo_batch_persist_neg_failed",
|
"geo_batch_persist_neg_failed",
|
||||||
@@ -806,19 +782,7 @@ async def flush_dirty(db: aiosqlite.Connection) -> int:
|
|||||||
return 0
|
return 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await db.executemany(
|
await geo_cache_repo.bulk_upsert_entries(db, rows)
|
||||||
"""
|
|
||||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
|
||||||
VALUES (?, ?, ?, ?, ?)
|
|
||||||
ON CONFLICT(ip) DO UPDATE SET
|
|
||||||
country_code = excluded.country_code,
|
|
||||||
country_name = excluded.country_name,
|
|
||||||
asn = excluded.asn,
|
|
||||||
org = excluded.org,
|
|
||||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
|
||||||
""",
|
|
||||||
rows,
|
|
||||||
)
|
|
||||||
await db.commit()
|
await db.commit()
|
||||||
except Exception as exc: # noqa: BLE001
|
except Exception as exc: # noqa: BLE001
|
||||||
log.warning("geo_flush_dirty_failed", error=str(exc))
|
log.warning("geo_flush_dirty_failed", error=str(exc))
|
||||||
|
|||||||
@@ -49,12 +49,7 @@ async def _run_re_resolve(app: Any) -> None:
|
|||||||
http_session = app.state.http_session
|
http_session = app.state.http_session
|
||||||
|
|
||||||
# Fetch all IPs with NULL country_code from the persistent cache.
|
# Fetch all IPs with NULL country_code from the persistent cache.
|
||||||
unresolved_ips: list[str] = []
|
unresolved_ips = await geo_service.get_unresolved_ips(db)
|
||||||
async with db.execute(
|
|
||||||
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
|
|
||||||
) as cursor:
|
|
||||||
async for row in cursor:
|
|
||||||
unresolved_ips.append(str(row[0]))
|
|
||||||
|
|
||||||
if not unresolved_ips:
|
if not unresolved_ips:
|
||||||
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
|
log.debug("geo_re_resolve_skip", reason="no_unresolved_ips")
|
||||||
|
|||||||
@@ -60,3 +60,81 @@ async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None:
|
|||||||
ips = await geo_cache_repo.get_unresolved_ips(db)
|
ips = await geo_cache_repo.get_unresolved_ips(db)
|
||||||
|
|
||||||
assert sorted(ips) == ["2.2.2.2", "3.3.3.3"]
|
assert sorted(ips) == ["2.2.2.2", "3.3.3.3"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_load_all_and_count_unresolved(tmp_path: Path) -> None:
    """``load_all`` returns every row intact; ``count_unresolved`` counts NULL country codes."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await db.executemany(
            "INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
            [
                ("5.5.5.5", None, None, None, None),
                ("6.6.6.6", "FR", "France", "AS456", "TestOrg"),
            ],
        )
        await db.commit()

    # Re-open the database so the reads only observe committed data.
    async with aiosqlite.connect(db_path) as db:
        rows = await geo_cache_repo.load_all(db)
        unresolved = await geo_cache_repo.count_unresolved(db)

    assert unresolved == 1

    # load_all applies no filter: both the resolved and the unresolved row
    # must come back, exactly once each.
    assert len(rows) == 2
    by_ip = {row["ip"]: row for row in rows}
    assert by_ip["6.6.6.6"] == {
        "ip": "6.6.6.6",
        "country_code": "FR",
        "country_name": "France",
        "asn": "AS456",
        "org": "TestOrg",
    }
    assert by_ip["5.5.5.5"] == {
        "ip": "5.5.5.5",
        "country_code": None,
        "country_name": None,
        "asn": None,
        "org": None,
    }
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_upsert_entry_and_neg_entry(tmp_path: Path) -> None:
    """Cover insert, overwrite-on-conflict and ignore-on-conflict for single-row writes."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)

        await geo_cache_repo.upsert_entry(
            db,
            "7.7.7.7",
            "GB",
            "United Kingdom",
            "AS789",
            "TestOrg",
        )
        await db.commit()

        await geo_cache_repo.upsert_neg_entry(db, "8.8.8.8")
        await db.commit()

        # Ensure positive entry is present.
        async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("7.7.7.7",)) as cur:
            row = await cur.fetchone()
        assert row is not None
        assert row[0] == "GB"

        # Ensure negative entry exists with NULL country_code.
        async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("8.8.8.8",)) as cur:
            row = await cur.fetchone()
        assert row is not None
        assert row[0] is None

        # A negative entry for an already-resolved IP must be ignored, not
        # overwrite the stored geo data with NULLs.
        await geo_cache_repo.upsert_neg_entry(db, "7.7.7.7")
        await db.commit()
        async with db.execute("SELECT country_code FROM geo_cache WHERE ip = ?", ("7.7.7.7",)) as cur:
            row = await cur.fetchone()
        assert row is not None
        assert row[0] == "GB"

        # A second upsert for the same IP takes the ON CONFLICT branch and
        # must replace the stored values without creating a duplicate row.
        await geo_cache_repo.upsert_entry(
            db,
            "7.7.7.7",
            "IE",
            "Ireland",
            "AS790",
            "OtherOrg",
        )
        await db.commit()
        async with db.execute(
            "SELECT country_code, country_name, asn, org FROM geo_cache WHERE ip = ?",
            ("7.7.7.7",),
        ) as cur:
            updated = await cur.fetchall()
        assert [tuple(r) for r in updated] == [("IE", "Ireland", "AS790", "OtherOrg")]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_bulk_upsert_entries_and_neg_entries(tmp_path: Path) -> None:
    """Bulk writers report the submitted count and store resolved and negative rows."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)

        # Empty input is a no-op that reports zero submitted rows.
        assert await geo_cache_repo.bulk_upsert_entries(db, []) == 0
        assert await geo_cache_repo.bulk_upsert_neg_entries(db, []) == 0

        rows = [
            ("9.9.9.9", "NL", "Netherlands", "AS101", "Test"),
            ("10.10.10.10", "JP", "Japan", "AS102", "Test"),
        ]
        count = await geo_cache_repo.bulk_upsert_entries(db, rows)
        assert count == 2

        neg_count = await geo_cache_repo.bulk_upsert_neg_entries(db, ["11.11.11.11", "12.12.12.12"])
        assert neg_count == 2

        await db.commit()

        async with db.execute("SELECT COUNT(*) FROM geo_cache") as cur:
            row = await cur.fetchone()
        assert row is not None
        assert int(row[0]) == 4

        # The two negative entries are the only rows with a NULL country_code.
        assert await geo_cache_repo.count_unresolved(db) == 2

        # The resolved rows carry the submitted values, not just a row count.
        async with db.execute(
            "SELECT ip, country_code FROM geo_cache WHERE country_code IS NOT NULL"
        ) as cur:
            resolved = {str(r[0]): r[1] for r in await cur.fetchall()}
        assert resolved == {"9.9.9.9": "NL", "10.10.10.10": "JP"}
|
||||||
|
|||||||
Reference in New Issue
Block a user