Files
BanGUI/backend/app/services/ban_service.py
Lukas 2f2e5a7419 fix: retry, semaphore, reload lock, activation verify, bans_by_jail diagnostics
Stage 1.1-1.3: reload_all include/exclude_jails params already implemented;
  added keyword-arg assertions in router and service tests.

Stage 2.1/6.1: _send_command_sync retry loop (3 attempts, 150ms exp backoff)
  retrying on EAGAIN/ECONNREFUSED/ENOBUFS; immediate raise on all other errors.

Stage 2.2: asyncio.Lock at module level in jail_service.reload_all to
  serialize concurrent reload--all commands.

Stage 3.1: activate_jail re-queries _get_active_jail_names after reload;
  returns active=False with descriptive message if jail did not start.

Stage 4.1/6.2: asyncio.Semaphore (max 10) in Fail2BanClient.send, lazy-
  initialized; logs fail2ban_command_waiting_semaphore at debug when waiting.

Stage 5.1/5.2: unit tests asserting reload_all is called with include_jails
  and exclude_jails; activation verification happy/sad path tests.

Stage 6.3: TestSendCommandSyncRetry (5 cases) + TestFail2BanClientSemaphore
  concurrency test.

Stage 7.1-7.3: _since_unix uses time.time(); bans_by_jail debug logging with
  since_iso; diagnostic warning when total==0 despite table rows; unit test
  verifying the warning fires for stale data.
2026-03-14 11:09:55 +01:00

693 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""Ban service.
Queries the fail2ban SQLite database for ban history. The fail2ban database
path is obtained at runtime by sending ``get dbfile`` to the fail2ban daemon
via the Unix domain socket.
All database I/O is performed through aiosqlite opened in **read-only** mode
so BanGUI never modifies or locks the fail2ban database.
"""
from __future__ import annotations
import asyncio
import json
import time
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
import aiosqlite
import structlog
from app.models.ban import (
BLOCKLIST_JAIL,
BUCKET_SECONDS,
BUCKET_SIZE_LABEL,
TIME_RANGE_SECONDS,
BanOrigin,
BansByCountryResponse,
BansByJailResponse,
BanTrendBucket,
BanTrendResponse,
DashboardBanItem,
DashboardBanListResponse,
JailBanCount,
TimeRange,
_derive_origin,
bucket_count,
)
from app.utils.fail2ban_client import Fail2BanClient
if TYPE_CHECKING:
import aiohttp
# Module-wide structured logger shared by every function in this service.
log: structlog.stdlib.BoundLogger = structlog.get_logger()
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Page size ``list_bans`` uses when the caller does not pass ``page_size``.
_DEFAULT_PAGE_SIZE: int = 100
# Upper bound ``list_bans`` applies to any requested ``page_size``.
_MAX_PAGE_SIZE: int = 500
# Timeout passed to ``Fail2BanClient`` when querying the daemon
# (presumably seconds — confirm against the client implementation).
_SOCKET_TIMEOUT: float = 5.0
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
"""Return a SQL fragment and its parameters for the origin filter.
Args:
origin: ``"blocklist"`` to restrict to the blocklist-import jail,
``"selfblock"`` to exclude it, or ``None`` for no restriction.
Returns:
A ``(sql_fragment, params)`` pair — the fragment starts with ``" AND"``
so it can be appended directly to an existing WHERE clause.
"""
if origin == "blocklist":
return " AND jail = ?", (BLOCKLIST_JAIL,)
if origin == "selfblock":
return " AND jail != ?", (BLOCKLIST_JAIL,)
return "", ()
def _since_unix(range_: TimeRange) -> int:
"""Return the Unix timestamp representing the start of the time window.
Uses :func:`time.time` (always UTC epoch seconds on all platforms) to be
consistent with how fail2ban stores ``timeofban`` values in its SQLite
database. fail2ban records ``time.time()`` values directly, so
comparing against a timezone-aware ``datetime.now(UTC).timestamp()`` would
theoretically produce the same number but using :func:`time.time` avoids
any tz-aware datetime pitfalls on misconfigured systems.
Args:
range_: One of the supported time-range presets.
Returns:
Unix timestamp (seconds since epoch) equal to *now range_*.
"""
seconds: int = TIME_RANGE_SECONDS[range_]
return int(time.time()) - seconds
def _ts_to_iso(unix_ts: int) -> str:
"""Convert a Unix timestamp to an ISO 8601 UTC string.
Args:
unix_ts: Seconds since the Unix epoch.
Returns:
ISO 8601 UTC timestamp, e.g. ``"2026-03-01T12:00:00+00:00"``.
"""
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
async def _get_fail2ban_db_path(socket_path: str) -> str:
    """Ask the fail2ban daemon where its SQLite database lives.

    Sends ``get dbfile`` over the fail2ban socket and returns the reported
    value as a string.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.

    Returns:
        The database file path reported by fail2ban.

    Raises:
        RuntimeError: If the reply is not a ``(code, data)`` pair, if the
            code is non-zero, or if fail2ban reports no database
            (``dbfile`` is ``None``).
        ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
            cannot be reached.
    """
    async with Fail2BanClient(socket_path, timeout=_SOCKET_TIMEOUT) as f2b:
        reply = await f2b.send(["get", "dbfile"])

    # The reply is expected to unpack into exactly two items.
    try:
        status, payload = reply
    except (TypeError, ValueError) as exc:
        raise RuntimeError(f"Unexpected response from fail2ban: {reply!r}") from exc

    if status != 0:
        raise RuntimeError(f"fail2ban error code {status}: {payload!r}")
    if payload is None:
        raise RuntimeError("fail2ban has no database configured (dbfile is None)")
    return str(payload)
def _parse_data_json(raw: Any) -> tuple[list[str], int]:
"""Extract matches and failure count from the ``bans.data`` column.
The ``data`` column stores a JSON blob with optional keys:
* ``matches`` — list of raw matched log lines.
* ``failures`` — total failure count that triggered the ban.
Args:
raw: The raw ``data`` column value (string, dict, or ``None``).
Returns:
A ``(matches, failures)`` tuple. Both default to empty/zero when
parsing fails or the column is absent.
"""
if raw is None:
return [], 0
obj: dict[str, Any] = {}
if isinstance(raw, str):
try:
parsed: Any = json.loads(raw)
if isinstance(parsed, dict):
obj = parsed
# json.loads("null") → None, or other non-dict — treat as empty
except json.JSONDecodeError:
return [], 0
elif isinstance(raw, dict):
obj = raw
matches: list[str] = [str(m) for m in (obj.get("matches") or [])]
failures: int = int(obj.get("failures", 0))
return matches, failures
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def list_bans(
    socket_path: str,
    range_: TimeRange,
    *,
    page: int = 1,
    page_size: int = _DEFAULT_PAGE_SIZE,
    http_session: aiohttp.ClientSession | None = None,
    app_db: aiosqlite.Connection | None = None,
    geo_enricher: Any | None = None,
    origin: BanOrigin | None = None,
) -> DashboardBanListResponse:
    """Return a paginated list of bans within the selected time window.

    Queries the fail2ban database ``bans`` table for records whose
    ``timeofban`` falls within the specified *range_*.  Results are ordered
    newest-first.

    Geo enrichment strategy (highest priority first):

    1. When *http_session* is provided the entire page of IPs is resolved in
       one :func:`~app.services.geo_service.lookup_batch` call.  This avoids
       the rate limit of the single-IP endpoint and is the preferred
       production path.
    2. When the batch lookup yields nothing and *geo_enricher* is provided
       (legacy / test path) each IP is resolved individually via the
       supplied async callable.

    Geo failures are best-effort: they are logged and the affected items
    are returned without geo fields.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        range_: Time-range preset (``"24h"``, ``"7d"``, ``"30d"``, or
            ``"365d"``).
        page: 1-based page number (default: ``1``).  Values below ``1`` are
            treated as ``1``.
        page_size: Maximum items per page, clamped to the range
            ``1 .. _MAX_PAGE_SIZE`` (default: ``100``).
        http_session: Optional shared :class:`aiohttp.ClientSession`.  When
            provided, :func:`~app.services.geo_service.lookup_batch` is used
            for bulk geo resolution.
        app_db: Optional BanGUI application database handed to the batch
            geo lookup.
        geo_enricher: Optional async callable ``(ip: str) -> GeoInfo | None``.
            Used as a fallback when no batch geo data is available.
        origin: Optional origin filter — ``"blocklist"`` restricts results to
            the ``blocklist-import`` jail, ``"selfblock"`` excludes it.

    Returns:
        :class:`~app.models.ban.DashboardBanListResponse` containing the
        paginated items, the total count, and the effective page and page
        size that were actually used.
    """
    from app.services import geo_service  # noqa: PLC0415

    since: int = _since_unix(range_)

    # SQLite treats a negative LIMIT as "no limit", so a non-positive
    # page_size would either dump the whole window or return nothing.
    # Clamp both paging inputs to sane lower bounds before building the query.
    effective_page: int = max(page, 1)
    effective_page_size: int = max(1, min(page_size, _MAX_PAGE_SIZE))
    offset: int = (effective_page - 1) * effective_page_size

    origin_clause, origin_params = _origin_sql_filter(origin)
    db_path: str = await _get_fail2ban_db_path(socket_path)
    log.info(
        "ban_service_list_bans",
        db_path=db_path,
        since=since,
        range=range_,
        origin=origin,
    )

    # Read-only URI connection so the fail2ban database is never modified.
    async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
        f2b_db.row_factory = aiosqlite.Row

        async with f2b_db.execute(
            "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
            (since, *origin_params),
        ) as cur:
            count_row = await cur.fetchone()
            total: int = int(count_row[0]) if count_row else 0

        async with f2b_db.execute(
            "SELECT jail, ip, timeofban, bancount, data "
            "FROM bans "
            "WHERE timeofban >= ?"
            + origin_clause
            + " ORDER BY timeofban DESC "
            "LIMIT ? OFFSET ?",
            (since, *origin_params, effective_page_size, offset),
        ) as cur:
            rows = await cur.fetchall()

    # Batch-resolve geo data for all IPs on this page in a single call.
    # This avoids hitting the single-IP rate limit when the page contains
    # many bans (e.g. after a large blocklist import).
    geo_map: dict[str, Any] = {}
    if http_session is not None and rows:
        page_ips: list[str] = [str(r["ip"]) for r in rows]
        try:
            geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
        except Exception:  # noqa: BLE001
            # Best-effort: keep serving the page, but record why geo failed.
            log.warning("ban_service_batch_geo_failed_list_bans", exc_info=True)

    items: list[DashboardBanItem] = []
    for row in rows:
        jail: str = str(row["jail"])
        ip: str = str(row["ip"])
        banned_at: str = _ts_to_iso(int(row["timeofban"]))
        ban_count: int = int(row["bancount"])
        matches, _ = _parse_data_json(row["data"])
        # The first matched log line is what gets reported as ``service``.
        service: str | None = matches[0] if matches else None

        country_code: str | None = None
        country_name: str | None = None
        asn: str | None = None
        org: str | None = None

        if geo_map:
            geo = geo_map.get(ip)
            if geo is not None:
                country_code = geo.country_code
                country_name = geo.country_name
                asn = geo.asn
                org = geo.org
        elif geo_enricher is not None:
            try:
                geo = await geo_enricher(ip)
                if geo is not None:
                    country_code = geo.country_code
                    country_name = geo.country_name
                    asn = geo.asn
                    org = geo.org
            except Exception:  # noqa: BLE001
                log.warning("ban_service_geo_lookup_failed", ip=ip, exc_info=True)

        items.append(
            DashboardBanItem(
                ip=ip,
                jail=jail,
                banned_at=banned_at,
                service=service,
                country_code=country_code,
                country_name=country_name,
                asn=asn,
                org=org,
                ban_count=ban_count,
                origin=_derive_origin(jail),
            )
        )

    return DashboardBanListResponse(
        items=items,
        total=total,
        page=effective_page,
        page_size=effective_page_size,
    )
# ---------------------------------------------------------------------------
# bans_by_country
# ---------------------------------------------------------------------------
#: Maximum rows returned in the companion table alongside the map.
_MAX_COMPANION_BANS: int = 200

#: Strong references to in-flight background geo lookups.  The event loop
#: only keeps weak references to tasks, so a fire-and-forget task that is
#: not referenced elsewhere may be garbage-collected before it finishes.
_background_geo_tasks: set[asyncio.Task[Any]] = set()


def _on_background_geo_done(task: asyncio.Task[Any]) -> None:
    """Release a finished background geo task and log its failure, if any."""
    _background_geo_tasks.discard(task)
    if task.cancelled():
        return
    # Retrieving the exception also prevents asyncio's "Task exception was
    # never retrieved" message at garbage-collection time.
    exc = task.exception()
    if exc is not None:
        log.warning("ban_service_geo_background_failed", error=repr(exc))


async def bans_by_country(
    socket_path: str,
    range_: TimeRange,
    http_session: aiohttp.ClientSession | None = None,
    geo_enricher: Any | None = None,
    app_db: aiosqlite.Connection | None = None,
    origin: BanOrigin | None = None,
) -> BansByCountryResponse:
    """Aggregate ban counts per country for the selected time window.

    Uses a three-step strategy optimised for large datasets:

    1. Queries the fail2ban DB with ``GROUP BY ip`` to get the per-IP ban
       counts for all unique IPs in the window — no row-count cap.
    2. Serves geo data from the in-memory cache only (non-blocking).
       Any IPs not yet in the cache are scheduled for background resolution
       via :func:`asyncio.create_task` so the response is returned
       immediately and subsequent requests benefit from the warmed cache.
    3. Returns a ``{country_code: count}`` aggregation and the
       ``_MAX_COMPANION_BANS`` most recent raw rows for the companion table.

    Note:
        On the very first request a large number of IPs may be uncached and
        the country map will be sparse.  The background task will resolve
        them and a later request will return a more complete map.  This
        trade-off keeps the endpoint fast regardless of dataset size.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        range_: Time-range preset.
        http_session: Optional :class:`aiohttp.ClientSession` for background
            geo lookups.  When ``None``, no background lookup is scheduled.
        geo_enricher: Legacy async ``(ip) -> GeoInfo | None`` callable;
            used when *http_session* is ``None`` (e.g. tests).
        app_db: Optional BanGUI application database handed to the
            background batch geo lookup.
        origin: Optional origin filter — ``"blocklist"`` restricts results to
            the ``blocklist-import`` jail, ``"selfblock"`` excludes it.

    Returns:
        :class:`~app.models.ban.BansByCountryResponse` with per-country
        aggregation and the companion ban list.
    """
    from app.services import geo_service  # noqa: PLC0415

    since: int = _since_unix(range_)
    origin_clause, origin_params = _origin_sql_filter(origin)
    db_path: str = await _get_fail2ban_db_path(socket_path)
    log.info(
        "ban_service_bans_by_country",
        db_path=db_path,
        since=since,
        range=range_,
        origin=origin,
    )

    # Read-only URI connection so the fail2ban database is never modified.
    async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
        f2b_db.row_factory = aiosqlite.Row

        # Total count for the window.
        async with f2b_db.execute(
            "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
            (since, *origin_params),
        ) as cur:
            count_row = await cur.fetchone()
            total: int = int(count_row[0]) if count_row else 0

        # Aggregation: unique IPs + their total event count.
        # No LIMIT here — we need all unique source IPs for accurate country counts.
        async with f2b_db.execute(
            "SELECT ip, COUNT(*) AS event_count "
            "FROM bans "
            "WHERE timeofban >= ?"
            + origin_clause
            + " GROUP BY ip",
            (since, *origin_params),
        ) as cur:
            agg_rows = await cur.fetchall()

        # Companion table: most recent raw rows for display alongside the map.
        async with f2b_db.execute(
            "SELECT jail, ip, timeofban, bancount, data "
            "FROM bans "
            "WHERE timeofban >= ?"
            + origin_clause
            + " ORDER BY timeofban DESC "
            "LIMIT ?",
            (since, *origin_params, _MAX_COMPANION_BANS),
        ) as cur:
            companion_rows = await cur.fetchall()

    unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
    geo_map: dict[str, Any] = {}

    if http_session is not None and unique_ips:
        # Serve only what is already in the in-memory cache — no API calls on
        # the hot path.  Uncached IPs are resolved asynchronously in the
        # background so subsequent requests benefit from a warmer cache.
        geo_map, uncached = geo_service.lookup_cached_only(unique_ips)
        if uncached:
            log.info(
                "ban_service_geo_background_scheduled",
                uncached=len(uncached),
                cached=len(geo_map),
            )
            # Fire-and-forget, but keep a strong reference until completion
            # so the task cannot be garbage-collected mid-flight; the done
            # callback drops the reference and logs any failure.
            background_task = asyncio.create_task(
                geo_service.lookup_batch(uncached, http_session, db=app_db),
                name="geo_bans_by_country",
            )
            _background_geo_tasks.add(background_task)
            background_task.add_done_callback(_on_background_geo_done)
    elif geo_enricher is not None and unique_ips:
        # Fallback: legacy per-IP enricher (used in tests / older callers).
        async def _safe_lookup(ip: str) -> tuple[str, Any]:
            try:
                return ip, await geo_enricher(ip)
            except Exception:  # noqa: BLE001
                log.warning("ban_service_geo_lookup_failed", ip=ip)
                return ip, None

        results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
        geo_map = dict(results)

    # Build country aggregation from the SQL-grouped rows.
    countries: dict[str, int] = {}
    country_names: dict[str, str] = {}
    for row in agg_rows:
        ip: str = str(row["ip"])
        geo = geo_map.get(ip)
        cc: str | None = geo.country_code if geo else None
        cn: str | None = geo.country_name if geo else None
        event_count: int = int(row["event_count"])
        # IPs without a resolved country code are left out of the map.
        if cc:
            countries[cc] = countries.get(cc, 0) + event_count
            if cn and cc not in country_names:
                country_names[cc] = cn

    # Build companion table from recent rows, reusing the geo data gathered above.
    bans: list[DashboardBanItem] = []
    for row in companion_rows:
        ip = str(row["ip"])
        geo = geo_map.get(ip)
        cc = geo.country_code if geo else None
        cn = geo.country_name if geo else None
        asn: str | None = geo.asn if geo else None
        org: str | None = geo.org if geo else None
        matches, _ = _parse_data_json(row["data"])
        bans.append(
            DashboardBanItem(
                ip=ip,
                jail=str(row["jail"]),
                banned_at=_ts_to_iso(int(row["timeofban"])),
                service=matches[0] if matches else None,
                country_code=cc,
                country_name=cn,
                asn=asn,
                org=org,
                ban_count=int(row["bancount"]),
                origin=_derive_origin(str(row["jail"])),
            )
        )

    return BansByCountryResponse(
        countries=countries,
        country_names=country_names,
        bans=bans,
        total=total,
    )
# ---------------------------------------------------------------------------
# ban_trend
# ---------------------------------------------------------------------------
async def ban_trend(
    socket_path: str,
    range_: TimeRange,
    *,
    origin: BanOrigin | None = None,
) -> BanTrendResponse:
    """Count bans per equal-width time bucket across the selected window.

    The fail2ban ``bans`` table is grouped by a bucket index computed in
    SQL.  Every bucket of the window appears in the result, including those
    with no bans (reported as zero), so the frontend always receives a
    complete, gap-free series.

    Bucket sizes per time-range preset:

    * ``24h`` → 1-hour buckets (24 total)
    * ``7d`` → 6-hour buckets (28 total)
    * ``30d`` → 1-day buckets (30 total)
    * ``365d`` → 7-day buckets (~53 total)

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        range_: Time-range preset (``"24h"``, ``"7d"``, ``"30d"``, or
            ``"365d"``).
        origin: Optional origin filter — ``"blocklist"`` restricts to the
            ``blocklist-import`` jail, ``"selfblock"`` excludes it.

    Returns:
        :class:`~app.models.ban.BanTrendResponse` holding one entry per
        bucket plus the human-readable bucket-size label.
    """
    window_start: int = _since_unix(range_)
    step: int = BUCKET_SECONDS[range_]
    bucket_total: int = bucket_count(range_)
    origin_clause, origin_params = _origin_sql_filter(origin)
    db_path: str = await _get_fail2ban_db_path(socket_path)

    log.info(
        "ban_service_ban_trend",
        db_path=db_path,
        since=window_start,
        range=range_,
        origin=origin,
        bucket_secs=step,
        num_buckets=bucket_total,
    )

    trend_sql: str = (
        "SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
        "COUNT(*) AS cnt "
        "FROM bans "
        "WHERE timeofban >= ?"
    ) + origin_clause + " GROUP BY bucket_idx ORDER BY bucket_idx"
    trend_params = (window_start, step, window_start, *origin_params)

    # Read-only URI connection so the fail2ban database is never modified.
    async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
        f2b_db.row_factory = aiosqlite.Row
        async with f2b_db.execute(trend_sql, trend_params) as cur:
            grouped = await cur.fetchall()

    # bucket index → count, dropping any index that falls outside the window.
    per_bucket: dict[int, int] = {
        int(entry["bucket_idx"]): int(entry["cnt"])
        for entry in grouped
        if 0 <= int(entry["bucket_idx"]) < bucket_total
    }

    # Emit every bucket in order, filling the gaps with zero.
    series: list[BanTrendBucket] = []
    for position in range(bucket_total):
        series.append(
            BanTrendBucket(
                timestamp=_ts_to_iso(window_start + position * step),
                count=per_bucket.get(position, 0),
            )
        )

    return BanTrendResponse(
        buckets=series,
        bucket_size=BUCKET_SIZE_LABEL[range_],
    )
# ---------------------------------------------------------------------------
# bans_by_jail
# ---------------------------------------------------------------------------
async def bans_by_jail(
    socket_path: str,
    range_: TimeRange,
    *,
    origin: BanOrigin | None = None,
) -> BansByJailResponse:
    """Count bans per jail for the selected time window.

    Groups the fail2ban ``bans`` table by jail name and returns the groups
    ordered by count, largest first.  When *origin* is given the same
    filter is applied to both the total and the per-jail counts, so callers
    can look at blocklist-imported bans or all other bans separately.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        range_: Time-range preset (``"24h"``, ``"7d"``, ``"30d"``, or
            ``"365d"``).
        origin: Optional origin filter — ``"blocklist"`` restricts to the
            ``blocklist-import`` jail, ``"selfblock"`` excludes it.

    Returns:
        :class:`~app.models.ban.BansByJailResponse` with per-jail counts
        sorted descending and the total ban count.
    """
    window_start: int = _since_unix(range_)
    origin_clause, origin_params = _origin_sql_filter(origin)
    db_path: str = await _get_fail2ban_db_path(socket_path)

    log.debug(
        "ban_service_bans_by_jail",
        db_path=db_path,
        since=window_start,
        since_iso=_ts_to_iso(window_start),
        range=range_,
        origin=origin,
    )

    # Both windowed queries share the same WHERE clause and bind values.
    window_filter: str = "WHERE timeofban >= ?" + origin_clause
    window_params = (window_start, *origin_params)

    # Read-only URI connection so the fail2ban database is never modified.
    async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
        f2b_db.row_factory = aiosqlite.Row

        async with f2b_db.execute(
            "SELECT COUNT(*) FROM bans " + window_filter,
            window_params,
        ) as cur:
            count_row = await cur.fetchone()
        total: int = int(count_row[0]) if count_row else 0

        # Diagnostic guard: an empty window is only suspicious when the
        # table itself holds rows.  In that case log the overall min/max
        # timeofban so operators can diagnose timezone or filter mismatches.
        if total == 0:
            async with f2b_db.execute(
                "SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
            ) as cur:
                diag_row = await cur.fetchone()
            if diag_row and diag_row[0] > 0:
                log.warning(
                    "ban_service_bans_by_jail_empty_despite_data",
                    table_row_count=diag_row[0],
                    min_timeofban=diag_row[1],
                    max_timeofban=diag_row[2],
                    since=window_start,
                    range=range_,
                )

        async with f2b_db.execute(
            "SELECT jail, COUNT(*) AS cnt FROM bans "
            + window_filter
            + " GROUP BY jail ORDER BY cnt DESC",
            window_params,
        ) as cur:
            grouped = await cur.fetchall()

    per_jail: list[JailBanCount] = []
    for entry in grouped:
        per_jail.append(JailBanCount(jail=str(entry["jail"]), count=int(entry["cnt"])))

    log.debug(
        "ban_service_bans_by_jail_result",
        total=total,
        jail_count=len(per_jail),
    )
    return BansByJailResponse(jails=per_jail, total=total)