- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite)
- Reorganize imports to follow PEP 8 conventions
- Convert TypeAlias to modern PEP 695 type syntax (where appropriate)
- Use Sequence/Mapping from collections.abc for type hints (covariant)
- Replace string literals with cast() for improved type inference
- Fix casting of Fail2BanResponse and TypedDict patterns
- Add IpLookupResult TypedDict for precise return type annotation
- Reformat overlong lines for readability (120 char limit)
- Add asyncio_mode and filterwarnings to pytest config
- Update test fixtures with improved type hints

This improves mypy type checking and makes type relationships explicit.
149 lines
4.1 KiB
Python
149 lines
4.1 KiB
Python
"""Repository for the geo cache persistent store.
|
|
|
|
This module provides typed, async helpers for querying and mutating the
|
|
``geo_cache`` table in the BanGUI application database.
|
|
|
|
All functions accept an open :class:`aiosqlite.Connection` and do not manage
|
|
connection lifetimes.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING, TypedDict
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Sequence
|
|
|
|
import aiosqlite
|
|
|
|
|
|
class GeoCacheRow(TypedDict):
    """Typed mapping describing one record of the ``geo_cache`` table.

    ``ip`` is always a string; every other field is either a string or
    ``None``.
    """

    # IP address the record belongs to.
    ip: str
    # Country code for the address, or None.
    country_code: str | None
    # Country name for the address, or None.
    country_name: str | None
    # ASN value for the address, or None.
    asn: str | None
    # Organisation value for the address, or None.
    org: str | None
|
|
|
|
|
|
async def load_all(db: aiosqlite.Connection) -> list[GeoCacheRow]:
    """Fetch every record stored in the ``geo_cache`` table.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        One :class:`GeoCacheRow` per table row, in the order the cursor
        yields them. The list is empty when the table has no rows.
    """
    query = "SELECT ip, country_code, country_name, asn, org FROM geo_cache"
    async with db.execute(query) as cursor:
        # Column positions follow the SELECT list above; only ``ip`` is
        # coerced to ``str``, the other columns are passed through as-is.
        return [
            GeoCacheRow(
                ip=str(record[0]),
                country_code=record[1],
                country_name=record[2],
                asn=record[3],
                org=record[4],
            )
            async for record in cursor
        ]
|
|
|
|
|
|
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
    """Collect the IPs of all ``geo_cache`` rows whose ``country_code`` is NULL.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        The ``ip`` column of every matching row, each converted to ``str``,
        in the order the cursor yields them.
    """
    query = "SELECT ip FROM geo_cache WHERE country_code IS NULL"
    async with db.execute(query) as cursor:
        return [str(record[0]) async for record in cursor]
|
|
|
|
|
|
async def count_unresolved(db: aiosqlite.Connection) -> int:
    """Count the ``geo_cache`` rows whose ``country_code`` is NULL.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        The ``COUNT(*)`` result as an ``int``, or ``0`` when the cursor
        returns no row.
    """
    query = "SELECT COUNT(*) FROM geo_cache WHERE country_code IS NULL"
    async with db.execute(query) as cursor:
        result = await cursor.fetchone()

    # Fall back to zero if the cursor produced nothing.
    if not result:
        return 0
    return int(result[0])
|
|
|
|
|
|
async def upsert_entry(
    db: aiosqlite.Connection,
    ip: str,
    country_code: str | None,
    country_name: str | None,
    asn: str | None,
    org: str | None,
) -> None:
    """Write one geo cache entry, replacing the stored values on conflict.

    A new row is inserted for ``ip``. If the insert conflicts on ``ip``, the
    existing row's ``country_code``, ``country_name``, ``asn`` and ``org``
    are overwritten with the supplied values and ``cached_at`` is set from
    SQLite's ``strftime(..., 'now')``.

    This function only executes the statement; it does not call ``commit``
    on the connection.

    Args:
        db: Open BanGUI application database connection.
        ip: IP address identifying the row.
        country_code: Value for the ``country_code`` column, or ``None``.
        country_name: Value for the ``country_name`` column, or ``None``.
        asn: Value for the ``asn`` column, or ``None``.
        org: Value for the ``org`` column, or ``None``.
    """
    statement = """
        INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT(ip) DO UPDATE SET
            country_code = excluded.country_code,
            country_name = excluded.country_name,
            asn = excluded.asn,
            org = excluded.org,
            cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
        """
    # Parameter order must match the column list of the INSERT above.
    params = (ip, country_code, country_name, asn, org)
    await db.execute(statement, params)
|
|
|
|
|
|
async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
    """Store a negative entry for an IP whose lookup failed.

    Only the ``ip`` column is supplied. Because the statement is
    ``INSERT OR IGNORE``, the insert is skipped rather than raising when it
    would violate a constraint, so an already stored row is left as it is.

    This function only executes the statement; it does not call ``commit``
    on the connection.

    Args:
        db: Open BanGUI application database connection.
        ip: IP address to record.
    """
    statement = "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)"
    await db.execute(statement, (ip,))
|
|
|
|
|
|
async def bulk_upsert_entries(
    db: aiosqlite.Connection,
    rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
) -> int:
    """Write many geo cache entries with a single ``executemany`` call.

    Each tuple is inserted as a new row. When an insert conflicts on ``ip``,
    the existing row's ``country_code``, ``country_name``, ``asn`` and
    ``org`` are overwritten and ``cached_at`` is set from SQLite's
    ``strftime(..., 'now')``.

    This function only executes the statement; it does not call ``commit``
    on the connection.

    Args:
        db: Open BanGUI application database connection.
        rows: Tuples of ``(ip, country_code, country_name, asn, org)``.

    Returns:
        The number of tuples submitted (``len(rows)``), or ``0`` when
        ``rows`` is empty, in which case the database is not touched.
    """
    # Nothing to write: skip the round-trip to the database entirely.
    if not rows:
        return 0

    statement = """
        INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
        VALUES (?, ?, ?, ?, ?)
        ON CONFLICT(ip) DO UPDATE SET
            country_code = excluded.country_code,
            country_name = excluded.country_name,
            asn = excluded.asn,
            org = excluded.org,
            cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
        """
    await db.executemany(statement, rows)
    return len(rows)
|
|
|
|
|
|
async def bulk_upsert_neg_entries(
    db: aiosqlite.Connection,
    ips: Sequence[str],
) -> int:
    """Store negative lookup entries for many IPs in one ``executemany`` call.

    Only the ``ip`` column is supplied for each row. Because the statement
    is ``INSERT OR IGNORE``, an insert that would violate a constraint is
    skipped rather than raising.

    This function only executes the statement; it does not call ``commit``
    on the connection.

    Args:
        db: Open BanGUI application database connection.
        ips: IP addresses to record. Any sequence is accepted (list, tuple,
            ...), matching the ``Sequence`` parameter of
            :func:`bulk_upsert_entries`.

    Returns:
        The number of IPs submitted (``len(ips)``), or ``0`` when ``ips`` is
        empty, in which case the database is not touched. This is the
        submitted count, not the count of rows actually inserted: ignored
        inserts are still included.
    """
    # Nothing to write: skip the round-trip to the database entirely.
    if not ips:
        return 0

    statement = "INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)"
    # executemany expects one parameter tuple per statement execution.
    await db.executemany(statement, [(ip,) for ip in ips])
    return len(ips)
|