Refactor geo re-resolve to use geo_cache repo and move data-access out of router

This commit is contained in:
2026-03-16 21:12:07 +01:00
parent 8f515893ea
commit dcd8059b27
6 changed files with 157 additions and 193 deletions

View File

@@ -16,7 +16,9 @@ This document breaks the entire BanGUI project into development stages, ordered
--- ---
#### TASK B-1 — Create a `fail2ban_db` repository for direct fail2ban database queries #### TASK B-1 — Create a `fail2ban_db` repository for direct fail2ban database queries
**Status:** Completed ✅
**Violated rule:** Refactoring.md §2.2 — Services must not perform direct `aiosqlite` calls; go through a repository. **Violated rule:** Refactoring.md §2.2 — Services must not perform direct `aiosqlite` calls; go through a repository.
@@ -41,6 +43,8 @@ This document breaks the entire BanGUI project into development stages, ordered
#### TASK B-2 — Remove direct SQL query from `routers/geo.py` #### TASK B-2 — Remove direct SQL query from `routers/geo.py`
**Status:** Completed ✅
**Violated rule:** Refactoring.md §2.1 — Routers must contain zero business logic; no SQL or repository imports. **Violated rule:** Refactoring.md §2.1 — Routers must contain zero business logic; no SQL or repository imports.
**Files affected:** **Files affected:**

View File

@@ -0,0 +1,33 @@
"""Repository for the geo cache persistent store.
This module provides typed, async helpers for querying and mutating the
``geo_cache`` table in the BanGUI application database.
All functions accept an open :class:`aiosqlite.Connection` and do not manage
connection lifetimes.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import aiosqlite
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
    """Fetch every cached IP whose country code has not been resolved.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        List of IPv4/IPv6 strings that need geo resolution.
    """
    # NULL country_code marks rows that were inserted but never resolved.
    async with db.execute(
        "SELECT ip FROM geo_cache WHERE country_code IS NULL"
    ) as cur:
        rows = await cur.fetchall()
    return [str(row[0]) for row in rows]

View File

@@ -153,12 +153,7 @@ async def re_resolve_geo(
that were retried. that were retried.
""" """
# Collect all IPs in geo_cache that still lack a country code. # Collect all IPs in geo_cache that still lack a country code.
unresolved: list[str] = [] unresolved = await geo_service.get_unresolved_ips(db)
async with db.execute(
"SELECT ip FROM geo_cache WHERE country_code IS NULL"
) as cur:
async for row in cur:
unresolved.append(str(row[0]))
if not unresolved: if not unresolved:
return {"resolved": 0, "total": 0} return {"resolved": 0, "total": 0}

View File

@@ -13,12 +13,15 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import time import time
from dataclasses import asdict
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import aiosqlite
import structlog import structlog
if TYPE_CHECKING:
import aiosqlite
from app.models.ban import ( from app.models.ban import (
BLOCKLIST_JAIL, BLOCKLIST_JAIL,
BUCKET_SECONDS, BUCKET_SECONDS,
@@ -31,11 +34,11 @@ from app.models.ban import (
BanTrendResponse, BanTrendResponse,
DashboardBanItem, DashboardBanItem,
DashboardBanListResponse, DashboardBanListResponse,
JailBanCount,
TimeRange, TimeRange,
_derive_origin, _derive_origin,
bucket_count, bucket_count,
) )
from app.repositories import fail2ban_db_repo
from app.utils.fail2ban_client import Fail2BanClient from app.utils.fail2ban_client import Fail2BanClient
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -244,33 +247,20 @@ async def list_bans(
origin=origin, origin=origin,
) )
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: rows, total = await fail2ban_db_repo.get_currently_banned(
f2b_db.row_factory = aiosqlite.Row db_path=db_path,
since=since,
async with f2b_db.execute( origin=origin,
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, limit=effective_page_size,
(since, *origin_params), offset=offset,
) as cur: )
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
async with f2b_db.execute(
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE timeofban >= ?"
+ origin_clause
+ " ORDER BY timeofban DESC "
"LIMIT ? OFFSET ?",
(since, *origin_params, effective_page_size, offset),
) as cur:
rows = await cur.fetchall()
# Batch-resolve geo data for all IPs on this page in a single API call. # Batch-resolve geo data for all IPs on this page in a single API call.
# This avoids hitting the 45 req/min single-IP rate limit when the # This avoids hitting the 45 req/min single-IP rate limit when the
# page contains many bans (e.g. after a large blocklist import). # page contains many bans (e.g. after a large blocklist import).
geo_map: dict[str, Any] = {} geo_map: dict[str, Any] = {}
if http_session is not None and rows: if http_session is not None and rows:
page_ips: list[str] = [str(r["ip"]) for r in rows] page_ips: list[str] = [r.ip for r in rows]
try: try:
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db) geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
@@ -278,11 +268,11 @@ async def list_bans(
items: list[DashboardBanItem] = [] items: list[DashboardBanItem] = []
for row in rows: for row in rows:
jail: str = str(row["jail"]) jail: str = row.jail
ip: str = str(row["ip"]) ip: str = row.ip
banned_at: str = _ts_to_iso(int(row["timeofban"])) banned_at: str = _ts_to_iso(row.timeofban)
ban_count: int = int(row["bancount"]) ban_count: int = row.bancount
matches, _ = _parse_data_json(row["data"]) matches, _ = _parse_data_json(row.data)
service: str | None = matches[0] if matches else None service: str | None = matches[0] if matches else None
country_code: str | None = None country_code: str | None = None
@@ -395,42 +385,31 @@ async def bans_by_country(
origin=origin, origin=origin,
) )
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: # Total count and companion rows reuse the same SQL query logic.
f2b_db.row_factory = aiosqlite.Row # Passing limit=0 returns only the total from the count query.
_, total = await fail2ban_db_repo.get_currently_banned(
db_path=db_path,
since=since,
origin=origin,
limit=0,
offset=0,
)
# Total count for the window. agg_rows = await fail2ban_db_repo.get_ban_event_counts(
async with f2b_db.execute( db_path=db_path,
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, since=since,
(since, *origin_params), origin=origin,
) as cur: )
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
# Aggregation: unique IPs + their total event count. companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
# No LIMIT here — we need all unique source IPs for accurate country counts. db_path=db_path,
async with f2b_db.execute( since=since,
"SELECT ip, COUNT(*) AS event_count " origin=origin,
"FROM bans " limit=_MAX_COMPANION_BANS,
"WHERE timeofban >= ?" offset=0,
+ origin_clause )
+ " GROUP BY ip",
(since, *origin_params),
) as cur:
agg_rows = await cur.fetchall()
# Companion table: most recent raw rows for display alongside the map. unique_ips: list[str] = [r.ip for r in agg_rows]
async with f2b_db.execute(
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE timeofban >= ?"
+ origin_clause
+ " ORDER BY timeofban DESC "
"LIMIT ?",
(since, *origin_params, _MAX_COMPANION_BANS),
) as cur:
companion_rows = await cur.fetchall()
unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
geo_map: dict[str, Any] = {} geo_map: dict[str, Any] = {}
if http_session is not None and unique_ips: if http_session is not None and unique_ips:
@@ -467,11 +446,11 @@ async def bans_by_country(
country_names: dict[str, str] = {} country_names: dict[str, str] = {}
for row in agg_rows: for row in agg_rows:
ip: str = str(row["ip"]) ip: str = row.ip
geo = geo_map.get(ip) geo = geo_map.get(ip)
cc: str | None = geo.country_code if geo else None cc: str | None = geo.country_code if geo else None
cn: str | None = geo.country_name if geo else None cn: str | None = geo.country_name if geo else None
event_count: int = int(row["event_count"]) event_count: int = row.event_count
if cc: if cc:
countries[cc] = countries.get(cc, 0) + event_count countries[cc] = countries.get(cc, 0) + event_count
@@ -481,26 +460,26 @@ async def bans_by_country(
# Build companion table from recent rows (geo already cached from batch step). # Build companion table from recent rows (geo already cached from batch step).
bans: list[DashboardBanItem] = [] bans: list[DashboardBanItem] = []
for row in companion_rows: for row in companion_rows:
ip = str(row["ip"]) ip = row.ip
geo = geo_map.get(ip) geo = geo_map.get(ip)
cc = geo.country_code if geo else None cc = geo.country_code if geo else None
cn = geo.country_name if geo else None cn = geo.country_name if geo else None
asn: str | None = geo.asn if geo else None asn: str | None = geo.asn if geo else None
org: str | None = geo.org if geo else None org: str | None = geo.org if geo else None
matches, _ = _parse_data_json(row["data"]) matches, _ = _parse_data_json(row.data)
bans.append( bans.append(
DashboardBanItem( DashboardBanItem(
ip=ip, ip=ip,
jail=str(row["jail"]), jail=row.jail,
banned_at=_ts_to_iso(int(row["timeofban"])), banned_at=_ts_to_iso(row.timeofban),
service=matches[0] if matches else None, service=matches[0] if matches else None,
country_code=cc, country_code=cc,
country_name=cn, country_name=cn,
asn=asn, asn=asn,
org=org, org=org,
ban_count=int(row["bancount"]), ban_count=row.bancount,
origin=_derive_origin(str(row["jail"])), origin=_derive_origin(row.jail),
) )
) )
@@ -565,32 +544,18 @@ async def ban_trend(
num_buckets=num_buckets, num_buckets=num_buckets,
) )
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
f2b_db.row_factory = aiosqlite.Row db_path=db_path,
since=since,
async with f2b_db.execute( bucket_secs=bucket_secs,
"SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, " num_buckets=num_buckets,
"COUNT(*) AS cnt " origin=origin,
"FROM bans " )
"WHERE timeofban >= ?"
+ origin_clause
+ " GROUP BY bucket_idx "
"ORDER BY bucket_idx",
(since, bucket_secs, since, *origin_params),
) as cur:
rows = await cur.fetchall()
# Map bucket_idx → count; ignore any out-of-range indices.
counts: dict[int, int] = {}
for row in rows:
idx: int = int(row["bucket_idx"])
if 0 <= idx < num_buckets:
counts[idx] = int(row["cnt"])
buckets: list[BanTrendBucket] = [ buckets: list[BanTrendBucket] = [
BanTrendBucket( BanTrendBucket(
timestamp=_ts_to_iso(since + i * bucket_secs), timestamp=_ts_to_iso(since + i * bucket_secs),
count=counts.get(i, 0), count=counts[i],
) )
for i in range(num_buckets) for i in range(num_buckets)
] ]
@@ -643,50 +608,37 @@ async def bans_by_jail(
origin=origin, origin=origin,
) )
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: total, jails = await fail2ban_db_repo.get_bans_by_jail(
f2b_db.row_factory = aiosqlite.Row db_path=db_path,
since=since,
origin=origin,
)
async with f2b_db.execute( # Diagnostic guard: if zero results were returned, check whether the table
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, # has *any* rows and log a warning with min/max timeofban so operators can
(since, *origin_params), # diagnose timezone or filter mismatches from logs.
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
# Diagnostic guard: if zero results were returned, check whether the
# table has *any* rows and log a warning with min/max timeofban so
# operators can diagnose timezone or filter mismatches from logs.
if total == 0: if total == 0:
async with f2b_db.execute( table_row_count, min_timeofban, max_timeofban = (
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans" await fail2ban_db_repo.get_bans_table_summary(db_path)
) as cur: )
diag_row = await cur.fetchone() if table_row_count > 0:
if diag_row and diag_row[0] > 0:
log.warning( log.warning(
"ban_service_bans_by_jail_empty_despite_data", "ban_service_bans_by_jail_empty_despite_data",
table_row_count=diag_row[0], table_row_count=table_row_count,
min_timeofban=diag_row[1], min_timeofban=min_timeofban,
max_timeofban=diag_row[2], max_timeofban=max_timeofban,
since=since, since=since,
range=range_, range=range_,
) )
async with f2b_db.execute(
"SELECT jail, COUNT(*) AS cnt "
"FROM bans "
"WHERE timeofban >= ?"
+ origin_clause
+ " GROUP BY jail ORDER BY cnt DESC",
(since, *origin_params),
) as cur:
rows = await cur.fetchall()
jails: list[JailBanCount] = [
JailBanCount(jail=str(row["jail"]), count=int(row["cnt"])) for row in rows
]
log.debug( log.debug(
"ban_service_bans_by_jail_result", "ban_service_bans_by_jail_result",
total=total, total=total,
jail_count=len(jails), jail_count=len(jails),
) )
return BansByJailResponse(jails=jails, total=total)
# Pydantic strict validation requires either dicts or model instances.
# Our repository returns dataclasses for simplicity, so convert them here.
jail_dicts: list[dict[str, object]] = [asdict(j) for j in jails]
return BansByJailResponse(jails=jail_dicts, total=total)

View File

@@ -46,6 +46,8 @@ from typing import TYPE_CHECKING
import aiohttp import aiohttp
import structlog import structlog
from app.repositories import geo_cache_repo
if TYPE_CHECKING: if TYPE_CHECKING:
import aiosqlite import aiosqlite
import geoip2.database import geoip2.database
@@ -198,6 +200,18 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]:
} }
async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]:
    """Return geo cache IPs where the country code has not yet been resolved.

    Thin service-layer wrapper that delegates to ``geo_cache_repo`` so that
    routers depend only on the service layer, never on repositories directly.

    Args:
        db: Open BanGUI application database connection.

    Returns:
        List of IP addresses that are candidates for re-resolution.
    """
    return await geo_cache_repo.get_unresolved_ips(db)
def init_geoip(mmdb_path: str | None) -> None: def init_geoip(mmdb_path: str | None) -> None:
"""Initialise the MaxMind GeoLite2-Country database reader. """Initialise the MaxMind GeoLite2-Country database reader.

View File

@@ -13,7 +13,6 @@ from __future__ import annotations
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
import aiosqlite
import structlog import structlog
from app.models.ban import TIME_RANGE_SECONDS, TimeRange from app.models.ban import TIME_RANGE_SECONDS, TimeRange
@@ -23,6 +22,7 @@ from app.models.history import (
IpDetailResponse, IpDetailResponse,
IpTimelineEvent, IpTimelineEvent,
) )
from app.repositories import fail2ban_db_repo
from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -84,26 +84,11 @@ async def list_history(
and the total matching count. and the total matching count.
""" """
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE) effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
offset: int = (page - 1) * effective_page_size
# Build WHERE clauses dynamically. # Build WHERE clauses dynamically.
wheres: list[str] = [] since: int | None = None
params: list[Any] = []
if range_ is not None: if range_ is not None:
since: int = _since_unix(range_) since = _since_unix(range_)
wheres.append("timeofban >= ?")
params.append(since)
if jail is not None:
wheres.append("jail = ?")
params.append(jail)
if ip_filter is not None:
wheres.append("ip LIKE ?")
params.append(f"{ip_filter}%")
where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
db_path: str = await _get_fail2ban_db_path(socket_path) db_path: str = await _get_fail2ban_db_path(socket_path)
log.info( log.info(
@@ -115,32 +100,22 @@ async def list_history(
page=page, page=page,
) )
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: rows, total = await fail2ban_db_repo.get_history_page(
f2b_db.row_factory = aiosqlite.Row db_path=db_path,
since=since,
async with f2b_db.execute( jail=jail,
f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608 ip_filter=ip_filter,
params, page=page,
) as cur: page_size=effective_page_size,
count_row = await cur.fetchone() )
total: int = int(count_row[0]) if count_row else 0
async with f2b_db.execute(
f"SELECT jail, ip, timeofban, bancount, data " # noqa: S608
f"FROM bans {where_sql} "
"ORDER BY timeofban DESC "
"LIMIT ? OFFSET ?",
[*params, effective_page_size, offset],
) as cur:
rows = await cur.fetchall()
items: list[HistoryBanItem] = [] items: list[HistoryBanItem] = []
for row in rows: for row in rows:
jail_name: str = str(row["jail"]) jail_name: str = row.jail
ip: str = str(row["ip"]) ip: str = row.ip
banned_at: str = _ts_to_iso(int(row["timeofban"])) banned_at: str = _ts_to_iso(row.timeofban)
ban_count: int = int(row["bancount"]) ban_count: int = row.bancount
matches, failures = _parse_data_json(row["data"]) matches, failures = _parse_data_json(row.data)
country_code: str | None = None country_code: str | None = None
country_name: str | None = None country_name: str | None = None
@@ -205,16 +180,7 @@ async def get_ip_detail(
db_path: str = await _get_fail2ban_db_path(socket_path) db_path: str = await _get_fail2ban_db_path(socket_path)
log.info("history_service_ip_detail", db_path=db_path, ip=ip) log.info("history_service_ip_detail", db_path=db_path, ip=ip)
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip)
f2b_db.row_factory = aiosqlite.Row
async with f2b_db.execute(
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE ip = ? "
"ORDER BY timeofban DESC",
(ip,),
) as cur:
rows = await cur.fetchall()
if not rows: if not rows:
return None return None
@@ -223,10 +189,10 @@ async def get_ip_detail(
total_failures: int = 0 total_failures: int = 0
for row in rows: for row in rows:
jail_name: str = str(row["jail"]) jail_name: str = row.jail
banned_at: str = _ts_to_iso(int(row["timeofban"])) banned_at: str = _ts_to_iso(row.timeofban)
ban_count: int = int(row["bancount"]) ban_count: int = row.bancount
matches, failures = _parse_data_json(row["data"]) matches, failures = _parse_data_json(row.data)
total_failures += failures total_failures += failures
timeline.append( timeline.append(
IpTimelineEvent( IpTimelineEvent(