refactor: complete Task 2/3 geo decouple + exceptions centralization; mark as done
This commit is contained in:
@@ -11,11 +11,8 @@ so BanGUI never modifies or locks the fail2ban database.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -39,18 +36,16 @@ from app.models.ban import (
|
||||
JailBanCount as JailBanCountModel,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanResponse
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -102,98 +97,6 @@ def _since_unix(range_: TimeRange) -> int:
|
||||
return int(time.time()) - seconds
|
||||
|
||||
|
||||
def _ts_to_iso(unix_ts: int) -> str:
|
||||
"""Convert a Unix timestamp to an ISO 8601 UTC string.
|
||||
|
||||
Args:
|
||||
unix_ts: Seconds since the Unix epoch.
|
||||
|
||||
Returns:
|
||||
ISO 8601 UTC timestamp, e.g. ``"2026-03-01T12:00:00+00:00"``.
|
||||
"""
|
||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||
|
||||
|
||||
async def _get_fail2ban_db_path(socket_path: str) -> str:
    """Ask fail2ban over its control socket where its SQLite database lives.

    Issues the ``get dbfile`` command through the fail2ban Unix socket and
    returns the reported ``dbfile`` value.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.

    Returns:
        Absolute path to the fail2ban SQLite database file.

    Raises:
        RuntimeError: If fail2ban reports that no database is configured
            or if the socket response is unexpected.
        ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
            cannot be reached.
    """
    client = Fail2BanClient(socket_path, timeout=_SOCKET_TIMEOUT)
    async with client:
        response = await client.send(["get", "dbfile"])

    # The wire response should unpack into a (code, data) pair; anything
    # else means the protocol was violated.
    try:
        code, data = cast("Fail2BanResponse", response)
    except (TypeError, ValueError) as exc:
        raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc

    if code != 0:
        raise RuntimeError(f"fail2ban error code {code}: {data!r}")
    if data is None:
        raise RuntimeError("fail2ban has no database configured (dbfile is None)")
    return str(data)
|
||||
|
||||
|
||||
def _parse_data_json(raw: object) -> tuple[list[str], int]:
|
||||
"""Extract matches and failure count from the ``bans.data`` column.
|
||||
|
||||
The ``data`` column stores a JSON blob with optional keys:
|
||||
|
||||
* ``matches`` — list of raw matched log lines.
|
||||
* ``failures`` — total failure count that triggered the ban.
|
||||
|
||||
Args:
|
||||
raw: The raw ``data`` column value (string, dict, or ``None``).
|
||||
|
||||
Returns:
|
||||
A ``(matches, failures)`` tuple. Both default to empty/zero when
|
||||
parsing fails or the column is absent.
|
||||
"""
|
||||
if raw is None:
|
||||
return [], 0
|
||||
|
||||
obj: dict[str, object] = {}
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed: object = json.loads(raw)
|
||||
if isinstance(parsed, dict):
|
||||
obj = parsed
|
||||
# json.loads("null") → None, or other non-dict — treat as empty
|
||||
except json.JSONDecodeError:
|
||||
return [], 0
|
||||
elif isinstance(raw, dict):
|
||||
obj = raw
|
||||
|
||||
raw_matches = obj.get("matches")
|
||||
if isinstance(raw_matches, list):
|
||||
matches: list[str] = [str(m) for m in raw_matches]
|
||||
else:
|
||||
matches = []
|
||||
|
||||
raw_failures = obj.get("failures")
|
||||
failures: int = 0
|
||||
if isinstance(raw_failures, (int, float, str)):
|
||||
try:
|
||||
failures = int(raw_failures)
|
||||
except (ValueError, TypeError):
|
||||
failures = 0
|
||||
|
||||
return matches, failures
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -209,6 +112,7 @@ async def list_bans(
|
||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> DashboardBanListResponse:
|
||||
@@ -248,14 +152,13 @@ async def list_bans(
|
||||
:class:`~app.models.ban.DashboardBanListResponse` containing the
|
||||
paginated items and total count.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_list_bans",
|
||||
db_path=db_path,
|
||||
@@ -276,10 +179,10 @@ async def list_bans(
|
||||
# This avoids hitting the 45 req/min single-IP rate limit when the
|
||||
# page contains many bans (e.g. after a large blocklist import).
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
if http_session is not None and rows:
|
||||
if http_session is not None and rows and geo_batch_lookup is not None:
|
||||
page_ips: list[str] = [r.ip for r in rows]
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
|
||||
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("ban_service_batch_geo_failed_list_bans")
|
||||
|
||||
@@ -287,9 +190,9 @@ async def list_bans(
|
||||
for row in rows:
|
||||
jail: str = row.jail
|
||||
ip: str = row.ip
|
||||
banned_at: str = _ts_to_iso(row.timeofban)
|
||||
banned_at: str = ts_to_iso(row.timeofban)
|
||||
ban_count: int = row.bancount
|
||||
matches, _ = _parse_data_json(row.data)
|
||||
matches, _ = parse_data_json(row.data)
|
||||
service: str | None = matches[0] if matches else None
|
||||
|
||||
country_code: str | None = None
|
||||
@@ -350,6 +253,8 @@ async def bans_by_country(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
geo_cache_lookup: GeoCacheLookup | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
@@ -389,11 +294,10 @@ async def bans_by_country(
|
||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||
aggregation and the companion ban list.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_bans_by_country",
|
||||
db_path=db_path,
|
||||
@@ -429,23 +333,24 @@ async def bans_by_country(
|
||||
unique_ips: list[str] = [r.ip for r in agg_rows]
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
|
||||
if http_session is not None and unique_ips:
|
||||
if http_session is not None and unique_ips and geo_cache_lookup is not None:
|
||||
# Serve only what is already in the in-memory cache — no API calls on
|
||||
# the hot path. Uncached IPs are resolved asynchronously in the
|
||||
# background so subsequent requests benefit from a warmer cache.
|
||||
geo_map, uncached = geo_service.lookup_cached_only(unique_ips)
|
||||
geo_map, uncached = geo_cache_lookup(unique_ips)
|
||||
if uncached:
|
||||
log.info(
|
||||
"ban_service_geo_background_scheduled",
|
||||
uncached=len(uncached),
|
||||
cached=len(geo_map),
|
||||
)
|
||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||
# The dirty-set flush task persists results to the DB.
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
geo_service.lookup_batch(uncached, http_session, db=app_db),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
if geo_batch_lookup is not None:
|
||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||
# The dirty-set flush task persists results to the DB.
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
geo_batch_lookup(uncached, http_session, db=app_db),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
elif geo_enricher is not None and unique_ips:
|
||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
|
||||
@@ -483,13 +388,13 @@ async def bans_by_country(
|
||||
cn = geo.country_name if geo else None
|
||||
asn: str | None = geo.asn if geo else None
|
||||
org: str | None = geo.org if geo else None
|
||||
matches, _ = _parse_data_json(companion_row.data)
|
||||
matches, _ = parse_data_json(companion_row.data)
|
||||
|
||||
bans.append(
|
||||
DashboardBanItem(
|
||||
ip=ip,
|
||||
jail=companion_row.jail,
|
||||
banned_at=_ts_to_iso(companion_row.timeofban),
|
||||
banned_at=ts_to_iso(companion_row.timeofban),
|
||||
service=matches[0] if matches else None,
|
||||
country_code=cc,
|
||||
country_name=cn,
|
||||
@@ -550,7 +455,7 @@ async def ban_trend(
|
||||
num_buckets: int = bucket_count(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_ban_trend",
|
||||
db_path=db_path,
|
||||
@@ -571,7 +476,7 @@ async def ban_trend(
|
||||
|
||||
buckets: list[BanTrendBucket] = [
|
||||
BanTrendBucket(
|
||||
timestamp=_ts_to_iso(since + i * bucket_secs),
|
||||
timestamp=ts_to_iso(since + i * bucket_secs),
|
||||
count=counts[i],
|
||||
)
|
||||
for i in range(num_buckets)
|
||||
@@ -615,12 +520,12 @@ async def bans_by_jail(
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
since_iso=_ts_to_iso(since),
|
||||
since_iso=ts_to_iso(since),
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user