Refactor geo re-resolve to use geo_cache repo and move data-access out of router

This commit is contained in:
2026-03-16 21:12:07 +01:00
parent 376c13370d
commit 93f0feabde
6 changed files with 614 additions and 376 deletions

View File

@@ -0,0 +1,33 @@
"""Repository for the geo cache persistent store.
This module provides typed, async helpers for querying and mutating the
``geo_cache`` table in the BanGUI application database.
All functions accept an open :class:`aiosqlite.Connection` and do not manage
connection lifetimes.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import aiosqlite
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
    """Return all IPs in ``geo_cache`` where ``country_code`` is NULL.

    The connection is only used to run a single read-only ``SELECT``; it is
    neither committed nor closed here.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        List of IP address strings that still need geo resolution, in the
        order the cursor yields them.  Empty when no row matches.
    """
    async with db.execute(
        "SELECT ip FROM geo_cache WHERE country_code IS NULL"
    ) as cur:
        # Coerce the first column to ``str`` so callers always receive plain
        # strings regardless of the row type the connection is configured with.
        return [str(row[0]) async for row in cur]

View File

@@ -153,12 +153,7 @@ async def re_resolve_geo(
that were retried.
"""
# Collect all IPs in geo_cache that still lack a country code.
unresolved: list[str] = []
async with db.execute(
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
) as cur:
async for row in cur:
unresolved.append(str(row[0]))
unresolved = await geo_service.get_unresolved_ips(db)
if not unresolved:
return {"resolved": 0, "total": 0}

View File

@@ -13,12 +13,15 @@ from __future__ import annotations
import asyncio
import json
import time
from dataclasses import asdict
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
import aiosqlite
import structlog
if TYPE_CHECKING:
import aiosqlite
from app.models.ban import (
BLOCKLIST_JAIL,
BUCKET_SECONDS,
@@ -31,11 +34,11 @@ from app.models.ban import (
BanTrendResponse,
DashboardBanItem,
DashboardBanListResponse,
JailBanCount,
TimeRange,
_derive_origin,
bucket_count,
)
from app.repositories import fail2ban_db_repo
from app.utils.fail2ban_client import Fail2BanClient
if TYPE_CHECKING:
@@ -244,33 +247,20 @@ async def list_bans(
origin=origin,
)
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
f2b_db.row_factory = aiosqlite.Row
async with f2b_db.execute(
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
(since, *origin_params),
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
async with f2b_db.execute(
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE timeofban >= ?"
+ origin_clause
+ " ORDER BY timeofban DESC "
"LIMIT ? OFFSET ?",
(since, *origin_params, effective_page_size, offset),
) as cur:
rows = await cur.fetchall()
rows, total = await fail2ban_db_repo.get_currently_banned(
db_path=db_path,
since=since,
origin=origin,
limit=effective_page_size,
offset=offset,
)
# Batch-resolve geo data for all IPs on this page in a single API call.
# This avoids hitting the 45 req/min single-IP rate limit when the
# page contains many bans (e.g. after a large blocklist import).
geo_map: dict[str, Any] = {}
if http_session is not None and rows:
page_ips: list[str] = [str(r["ip"]) for r in rows]
page_ips: list[str] = [r.ip for r in rows]
try:
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
except Exception: # noqa: BLE001
@@ -278,11 +268,11 @@ async def list_bans(
items: list[DashboardBanItem] = []
for row in rows:
jail: str = str(row["jail"])
ip: str = str(row["ip"])
banned_at: str = _ts_to_iso(int(row["timeofban"]))
ban_count: int = int(row["bancount"])
matches, _ = _parse_data_json(row["data"])
jail: str = row.jail
ip: str = row.ip
banned_at: str = _ts_to_iso(row.timeofban)
ban_count: int = row.bancount
matches, _ = _parse_data_json(row.data)
service: str | None = matches[0] if matches else None
country_code: str | None = None
@@ -395,42 +385,31 @@ async def bans_by_country(
origin=origin,
)
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
f2b_db.row_factory = aiosqlite.Row
# Total count and companion rows reuse the same SQL query logic.
# Passing limit=0 returns only the total from the count query.
_, total = await fail2ban_db_repo.get_currently_banned(
db_path=db_path,
since=since,
origin=origin,
limit=0,
offset=0,
)
# Total count for the window.
async with f2b_db.execute(
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
(since, *origin_params),
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
agg_rows = await fail2ban_db_repo.get_ban_event_counts(
db_path=db_path,
since=since,
origin=origin,
)
# Aggregation: unique IPs + their total event count.
# No LIMIT here — we need all unique source IPs for accurate country counts.
async with f2b_db.execute(
"SELECT ip, COUNT(*) AS event_count "
"FROM bans "
"WHERE timeofban >= ?"
+ origin_clause
+ " GROUP BY ip",
(since, *origin_params),
) as cur:
agg_rows = await cur.fetchall()
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
db_path=db_path,
since=since,
origin=origin,
limit=_MAX_COMPANION_BANS,
offset=0,
)
# Companion table: most recent raw rows for display alongside the map.
async with f2b_db.execute(
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE timeofban >= ?"
+ origin_clause
+ " ORDER BY timeofban DESC "
"LIMIT ?",
(since, *origin_params, _MAX_COMPANION_BANS),
) as cur:
companion_rows = await cur.fetchall()
unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
unique_ips: list[str] = [r.ip for r in agg_rows]
geo_map: dict[str, Any] = {}
if http_session is not None and unique_ips:
@@ -467,11 +446,11 @@ async def bans_by_country(
country_names: dict[str, str] = {}
for row in agg_rows:
ip: str = str(row["ip"])
ip: str = row.ip
geo = geo_map.get(ip)
cc: str | None = geo.country_code if geo else None
cn: str | None = geo.country_name if geo else None
event_count: int = int(row["event_count"])
event_count: int = row.event_count
if cc:
countries[cc] = countries.get(cc, 0) + event_count
@@ -481,26 +460,26 @@ async def bans_by_country(
# Build companion table from recent rows (geo already cached from batch step).
bans: list[DashboardBanItem] = []
for row in companion_rows:
ip = str(row["ip"])
ip = row.ip
geo = geo_map.get(ip)
cc = geo.country_code if geo else None
cn = geo.country_name if geo else None
asn: str | None = geo.asn if geo else None
org: str | None = geo.org if geo else None
matches, _ = _parse_data_json(row["data"])
matches, _ = _parse_data_json(row.data)
bans.append(
DashboardBanItem(
ip=ip,
jail=str(row["jail"]),
banned_at=_ts_to_iso(int(row["timeofban"])),
jail=row.jail,
banned_at=_ts_to_iso(row.timeofban),
service=matches[0] if matches else None,
country_code=cc,
country_name=cn,
asn=asn,
org=org,
ban_count=int(row["bancount"]),
origin=_derive_origin(str(row["jail"])),
ban_count=row.bancount,
origin=_derive_origin(row.jail),
)
)
@@ -565,32 +544,18 @@ async def ban_trend(
num_buckets=num_buckets,
)
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
f2b_db.row_factory = aiosqlite.Row
async with f2b_db.execute(
"SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
"COUNT(*) AS cnt "
"FROM bans "
"WHERE timeofban >= ?"
+ origin_clause
+ " GROUP BY bucket_idx "
"ORDER BY bucket_idx",
(since, bucket_secs, since, *origin_params),
) as cur:
rows = await cur.fetchall()
# Map bucket_idx → count; ignore any out-of-range indices.
counts: dict[int, int] = {}
for row in rows:
idx: int = int(row["bucket_idx"])
if 0 <= idx < num_buckets:
counts[idx] = int(row["cnt"])
counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
db_path=db_path,
since=since,
bucket_secs=bucket_secs,
num_buckets=num_buckets,
origin=origin,
)
buckets: list[BanTrendBucket] = [
BanTrendBucket(
timestamp=_ts_to_iso(since + i * bucket_secs),
count=counts.get(i, 0),
count=counts[i],
)
for i in range(num_buckets)
]
@@ -643,50 +608,37 @@ async def bans_by_jail(
origin=origin,
)
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
f2b_db.row_factory = aiosqlite.Row
total, jails = await fail2ban_db_repo.get_bans_by_jail(
db_path=db_path,
since=since,
origin=origin,
)
async with f2b_db.execute(
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
(since, *origin_params),
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
# Diagnostic guard: if zero results were returned, check whether the table
# has *any* rows and log a warning with min/max timeofban so operators can
# diagnose timezone or filter mismatches from logs.
if total == 0:
table_row_count, min_timeofban, max_timeofban = (
await fail2ban_db_repo.get_bans_table_summary(db_path)
)
if table_row_count > 0:
log.warning(
"ban_service_bans_by_jail_empty_despite_data",
table_row_count=table_row_count,
min_timeofban=min_timeofban,
max_timeofban=max_timeofban,
since=since,
range=range_,
)
# Diagnostic guard: if zero results were returned, check whether the
# table has *any* rows and log a warning with min/max timeofban so
# operators can diagnose timezone or filter mismatches from logs.
if total == 0:
async with f2b_db.execute(
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
) as cur:
diag_row = await cur.fetchone()
if diag_row and diag_row[0] > 0:
log.warning(
"ban_service_bans_by_jail_empty_despite_data",
table_row_count=diag_row[0],
min_timeofban=diag_row[1],
max_timeofban=diag_row[2],
since=since,
range=range_,
)
async with f2b_db.execute(
"SELECT jail, COUNT(*) AS cnt "
"FROM bans "
"WHERE timeofban >= ?"
+ origin_clause
+ " GROUP BY jail ORDER BY cnt DESC",
(since, *origin_params),
) as cur:
rows = await cur.fetchall()
jails: list[JailBanCount] = [
JailBanCount(jail=str(row["jail"]), count=int(row["cnt"])) for row in rows
]
log.debug(
"ban_service_bans_by_jail_result",
total=total,
jail_count=len(jails),
)
return BansByJailResponse(jails=jails, total=total)
# Pydantic strict validation requires either dicts or model instances.
# Our repository returns dataclasses for simplicity, so convert them here.
jail_dicts: list[dict[str, object]] = [asdict(j) for j in jails]
return BansByJailResponse(jails=jail_dicts, total=total)

View File

@@ -46,6 +46,8 @@ from typing import TYPE_CHECKING
import aiohttp
import structlog
from app.repositories import geo_cache_repo
if TYPE_CHECKING:
import aiosqlite
import geoip2.database
@@ -198,6 +200,18 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
}
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
    """List the geo cache IPs whose country code is still unresolved.

    Thin service-layer wrapper: the lookup itself is delegated to
    :func:`geo_cache_repo.get_unresolved_ips`.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        The IP addresses reported by the repository as candidates for
        re-resolution.
    """
    pending_ips: list[str] = await geo_cache_repo.get_unresolved_ips(db)
    return pending_ips
def init_geoip(mmdb_path: str | None) -> None:
"""Initialise the MaxMind GeoLite2-Country database reader.

View File

@@ -13,16 +13,16 @@ from __future__ import annotations
from datetime import UTC, datetime
from typing import Any
import aiosqlite
import structlog
from app.models.ban import BLOCKLIST_JAIL, BanOrigin, TIME_RANGE_SECONDS, TimeRange
from app.models.ban import TIME_RANGE_SECONDS, TimeRange
from app.models.history import (
HistoryBanItem,
HistoryListResponse,
IpDetailResponse,
IpTimelineEvent,
)
from app.repositories import fail2ban_db_repo
from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso
log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -58,7 +58,6 @@ async def list_history(
*,
range_: TimeRange | None = None,
jail: str | None = None,
origin: BanOrigin | None = None,
ip_filter: str | None = None,
page: int = 1,
page_size: int = _DEFAULT_PAGE_SIZE,
@@ -74,8 +73,6 @@ async def list_history(
socket_path: Path to the fail2ban Unix domain socket.
range_: Time-range preset. ``None`` means all-time (no time filter).
jail: If given, restrict results to bans from this jail.
origin: Optional origin filter — ``"blocklist"`` restricts results to
the ``blocklist-import`` jail, ``"selfblock"`` excludes it.
ip_filter: If given, restrict results to bans for this exact IP
(or a prefix — the query uses ``LIKE ip_filter%``).
page: 1-based page number (default: ``1``).
@@ -87,34 +84,11 @@ async def list_history(
and the total matching count.
"""
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
offset: int = (page - 1) * effective_page_size
# Build WHERE clauses dynamically.
wheres: list[str] = []
params: list[Any] = []
since: int | None = None
if range_ is not None:
since: int = _since_unix(range_)
wheres.append("timeofban >= ?")
params.append(since)
if jail is not None:
wheres.append("jail = ?")
params.append(jail)
if origin is not None:
if origin == "blocklist":
wheres.append("jail = ?")
params.append(BLOCKLIST_JAIL)
elif origin == "selfblock":
wheres.append("jail != ?")
params.append(BLOCKLIST_JAIL)
if ip_filter is not None:
wheres.append("ip LIKE ?")
params.append(f"{ip_filter}%")
where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
since = _since_unix(range_)
db_path: str = await _get_fail2ban_db_path(socket_path)
log.info(
@@ -126,32 +100,22 @@ async def list_history(
page=page,
)
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
f2b_db.row_factory = aiosqlite.Row
async with f2b_db.execute(
f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608
params,
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
async with f2b_db.execute(
f"SELECT jail, ip, timeofban, bancount, data " # noqa: S608
f"FROM bans {where_sql} "
"ORDER BY timeofban DESC "
"LIMIT ? OFFSET ?",
[*params, effective_page_size, offset],
) as cur:
rows = await cur.fetchall()
rows, total = await fail2ban_db_repo.get_history_page(
db_path=db_path,
since=since,
jail=jail,
ip_filter=ip_filter,
page=page,
page_size=effective_page_size,
)
items: list[HistoryBanItem] = []
for row in rows:
jail_name: str = str(row["jail"])
ip: str = str(row["ip"])
banned_at: str = _ts_to_iso(int(row["timeofban"]))
ban_count: int = int(row["bancount"])
matches, failures = _parse_data_json(row["data"])
jail_name: str = row.jail
ip: str = row.ip
banned_at: str = _ts_to_iso(row.timeofban)
ban_count: int = row.bancount
matches, failures = _parse_data_json(row.data)
country_code: str | None = None
country_name: str | None = None
@@ -216,16 +180,7 @@ async def get_ip_detail(
db_path: str = await _get_fail2ban_db_path(socket_path)
log.info("history_service_ip_detail", db_path=db_path, ip=ip)
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
f2b_db.row_factory = aiosqlite.Row
async with f2b_db.execute(
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE ip = ? "
"ORDER BY timeofban DESC",
(ip,),
) as cur:
rows = await cur.fetchall()
rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip)
if not rows:
return None
@@ -234,10 +189,10 @@ async def get_ip_detail(
total_failures: int = 0
for row in rows:
jail_name: str = str(row["jail"])
banned_at: str = _ts_to_iso(int(row["timeofban"]))
ban_count: int = int(row["bancount"])
matches, failures = _parse_data_json(row["data"])
jail_name: str = row.jail
banned_at: str = _ts_to_iso(row.timeofban)
ban_count: int = row.bancount
matches, failures = _parse_data_json(row.data)
total_failures += failures
timeline.append(
IpTimelineEvent(