refactor: complete Task 2/3 geo decouple + exceptions centralization; mark as done

This commit is contained in:
2026-03-21 17:15:02 +01:00
parent 3aba2b6446
commit a442836c5c
28 changed files with 803 additions and 571 deletions

23
backend/app/exceptions.py Normal file
View File

@@ -0,0 +1,23 @@
"""Shared domain exception classes used across routers and services."""
from __future__ import annotations
class JailNotFoundError(Exception):
"""Raised when a requested jail name does not exist."""
class JailOperationError(Exception):
"""Raised when a fail2ban jail operation fails."""
class ConfigValidationError(Exception):
"""Raised when config values fail validation before applying."""
class ConfigOperationError(Exception):
"""Raised when a config payload update or command fails."""
class ServerOperationError(Exception):
"""Raised when a server control command (e.g. refresh) fails."""

View File

@@ -3,8 +3,18 @@
Response models for the ``GET /api/geo/lookup/{ip}`` endpoint.
"""
from __future__ import annotations
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field
if TYPE_CHECKING:
import aiohttp
import aiosqlite
class GeoDetail(BaseModel):
"""Enriched geolocation data for an IP address.
@@ -64,3 +74,26 @@ class IpLookupResponse(BaseModel):
default=None,
description="Enriched geographical and network information.",
)
# ---------------------------------------------------------------------------
# shared service types
# ---------------------------------------------------------------------------
@dataclass
class GeoInfo:
"""Geo resolution result used throughout backend services."""
country_code: str | None
country_name: str | None
asn: str | None
org: str | None
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
GeoBatchLookup = Callable[
[list[str], "aiohttp.ClientSession", "aiosqlite.Connection | None"],
Awaitable[dict[str, GeoInfo]],
]
GeoCacheLookup = Callable[[list[str]], tuple[dict[str, GeoInfo], list[str]]]

View File

@@ -20,8 +20,8 @@ from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
from app.models.jail import JailCommandResponse
from app.services import jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError
from app.services import geo_service, jail_service
from app.exceptions import JailNotFoundError, JailOperationError
from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
@@ -73,6 +73,7 @@ async def get_active_bans(
try:
return await jail_service.get_active_bans(
socket_path,
geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session,
app_db=app_db,
)

View File

@@ -42,7 +42,7 @@ from app.models.blocklist import (
ScheduleConfig,
ScheduleInfo,
)
from app.services import blocklist_service
from app.services import blocklist_service, geo_service
from app.tasks import blocklist_import as blocklist_import_task
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
@@ -131,7 +131,13 @@ async def run_import_now(
"""
http_session: aiohttp.ClientSession = request.app.state.http_session
socket_path: str = request.app.state.settings.fail2ban_socket
return await blocklist_service.import_all(db, http_session, socket_path)
return await blocklist_service.import_all(
db,
http_session,
socket_path,
geo_is_cached=geo_service.is_cached,
geo_batch_lookup=geo_service.lookup_batch,
)
@router.get(

View File

@@ -93,12 +93,7 @@ from app.services.config_file_service import (
JailNameError,
JailNotFoundInConfigError,
)
from app.services.config_service import (
ConfigOperationError,
ConfigValidationError,
JailNotFoundError,
)
from app.services.jail_service import JailOperationError
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError, JailOperationError
from app.tasks.health_check import _run_probe
from app.utils.fail2ban_client import Fail2BanConnectionError

View File

@@ -30,7 +30,7 @@ from app.models.ban import (
TimeRange,
)
from app.models.server import ServerStatus, ServerStatusResponse
from app.services import ban_service
from app.services import ban_service, geo_service
router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
@@ -120,6 +120,7 @@ async def get_dashboard_bans(
page_size=page_size,
http_session=http_session,
app_db=None,
geo_batch_lookup=geo_service.lookup_batch,
origin=origin,
)
@@ -163,6 +164,8 @@ async def get_bans_by_country(
socket_path,
range,
http_session=http_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=geo_service.lookup_batch,
app_db=None,
origin=origin,
)

View File

@@ -19,9 +19,8 @@ import aiosqlite
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
from app.dependencies import AuthDep, get_db
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, IpLookupResponse
from app.services import geo_service, jail_service
from app.services.geo_service import GeoInfo
from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"])

View File

@@ -31,8 +31,8 @@ from app.models.jail import (
JailDetailResponse,
JailListResponse,
)
from app.services import jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError
from app.services import geo_service, jail_service
from app.exceptions import JailNotFoundError, JailOperationError
from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
@@ -606,6 +606,7 @@ async def get_jail_banned_ips(
page=page,
page_size=page_size,
search=search,
geo_batch_lookup=geo_service.lookup_batch,
http_session=http_session,
app_db=app_db,
)

View File

@@ -15,7 +15,7 @@ from fastapi import APIRouter, HTTPException, Request, status
from app.dependencies import AuthDep
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
from app.services import server_service
from app.services.server_service import ServerOperationError
from app.exceptions import ServerOperationError
from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/server", tags=["Server"])

View File

@@ -11,11 +11,8 @@ so BanGUI never modifies or locks the fail2ban database.
from __future__ import annotations
import asyncio
import json
import time
from collections.abc import Awaitable, Callable
from datetime import UTC, datetime
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING
import structlog
@@ -39,18 +36,16 @@ from app.models.ban import (
JailBanCount as JailBanCountModel,
)
from app.repositories import fail2ban_db_repo
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanResponse
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
if TYPE_CHECKING:
import aiohttp
import aiosqlite
from app.services.geo_service import GeoInfo
from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
log: structlog.stdlib.BoundLogger = structlog.get_logger()
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
@@ -102,98 +97,6 @@ def _since_unix(range_: TimeRange) -> int:
return int(time.time()) - seconds
def _ts_to_iso(unix_ts: int) -> str:
"""Convert a Unix timestamp to an ISO 8601 UTC string.
Args:
unix_ts: Seconds since the Unix epoch.
Returns:
ISO 8601 UTC timestamp, e.g. ``"2026-03-01T12:00:00+00:00"``.
"""
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
async def _get_fail2ban_db_path(socket_path: str) -> str:
"""Query fail2ban for the path to its SQLite database.
Sends the ``get dbfile`` command via the fail2ban socket and returns
the value of the ``dbfile`` setting.
Args:
socket_path: Path to the fail2ban Unix domain socket.
Returns:
Absolute path to the fail2ban SQLite database file.
Raises:
RuntimeError: If fail2ban reports that no database is configured
or if the socket response is unexpected.
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
async with Fail2BanClient(socket_path, timeout=_SOCKET_TIMEOUT) as client:
response = await client.send(["get", "dbfile"])
try:
code, data = cast("Fail2BanResponse", response)
except (TypeError, ValueError) as exc:
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
if code != 0:
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
if data is None:
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
return str(data)
def _parse_data_json(raw: object) -> tuple[list[str], int]:
"""Extract matches and failure count from the ``bans.data`` column.
The ``data`` column stores a JSON blob with optional keys:
* ``matches`` — list of raw matched log lines.
* ``failures`` — total failure count that triggered the ban.
Args:
raw: The raw ``data`` column value (string, dict, or ``None``).
Returns:
A ``(matches, failures)`` tuple. Both default to empty/zero when
parsing fails or the column is absent.
"""
if raw is None:
return [], 0
obj: dict[str, object] = {}
if isinstance(raw, str):
try:
parsed: object = json.loads(raw)
if isinstance(parsed, dict):
obj = parsed
# json.loads("null") → None, or other non-dict — treat as empty
except json.JSONDecodeError:
return [], 0
elif isinstance(raw, dict):
obj = raw
raw_matches = obj.get("matches")
if isinstance(raw_matches, list):
matches: list[str] = [str(m) for m in raw_matches]
else:
matches = []
raw_failures = obj.get("failures")
failures: int = 0
if isinstance(raw_failures, (int, float, str)):
try:
failures = int(raw_failures)
except (ValueError, TypeError):
failures = 0
return matches, failures
# ---------------------------------------------------------------------------
@@ -209,6 +112,7 @@ async def list_bans(
page_size: int = _DEFAULT_PAGE_SIZE,
http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None,
origin: BanOrigin | None = None,
) -> DashboardBanListResponse:
@@ -248,14 +152,13 @@ async def list_bans(
:class:`~app.models.ban.DashboardBanListResponse` containing the
paginated items and total count.
"""
from app.services import geo_service # noqa: PLC0415
since: int = _since_unix(range_)
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
offset: int = (page - 1) * effective_page_size
origin_clause, origin_params = _origin_sql_filter(origin)
db_path: str = await _get_fail2ban_db_path(socket_path)
db_path: str = await get_fail2ban_db_path(socket_path)
log.info(
"ban_service_list_bans",
db_path=db_path,
@@ -276,10 +179,10 @@ async def list_bans(
# This avoids hitting the 45 req/min single-IP rate limit when the
# page contains many bans (e.g. after a large blocklist import).
geo_map: dict[str, GeoInfo] = {}
if http_session is not None and rows:
if http_session is not None and rows and geo_batch_lookup is not None:
page_ips: list[str] = [r.ip for r in rows]
try:
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
except Exception: # noqa: BLE001
log.warning("ban_service_batch_geo_failed_list_bans")
@@ -287,9 +190,9 @@ async def list_bans(
for row in rows:
jail: str = row.jail
ip: str = row.ip
banned_at: str = _ts_to_iso(row.timeofban)
banned_at: str = ts_to_iso(row.timeofban)
ban_count: int = row.bancount
matches, _ = _parse_data_json(row.data)
matches, _ = parse_data_json(row.data)
service: str | None = matches[0] if matches else None
country_code: str | None = None
@@ -350,6 +253,8 @@ async def bans_by_country(
socket_path: str,
range_: TimeRange,
http_session: aiohttp.ClientSession | None = None,
geo_cache_lookup: GeoCacheLookup | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None,
app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None,
@@ -389,11 +294,10 @@ async def bans_by_country(
:class:`~app.models.ban.BansByCountryResponse` with per-country
aggregation and the companion ban list.
"""
from app.services import geo_service # noqa: PLC0415
since: int = _since_unix(range_)
origin_clause, origin_params = _origin_sql_filter(origin)
db_path: str = await _get_fail2ban_db_path(socket_path)
db_path: str = await get_fail2ban_db_path(socket_path)
log.info(
"ban_service_bans_by_country",
db_path=db_path,
@@ -429,23 +333,24 @@ async def bans_by_country(
unique_ips: list[str] = [r.ip for r in agg_rows]
geo_map: dict[str, GeoInfo] = {}
if http_session is not None and unique_ips:
if http_session is not None and unique_ips and geo_cache_lookup is not None:
# Serve only what is already in the in-memory cache — no API calls on
# the hot path. Uncached IPs are resolved asynchronously in the
# background so subsequent requests benefit from a warmer cache.
geo_map, uncached = geo_service.lookup_cached_only(unique_ips)
geo_map, uncached = geo_cache_lookup(unique_ips)
if uncached:
log.info(
"ban_service_geo_background_scheduled",
uncached=len(uncached),
cached=len(geo_map),
)
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
# The dirty-set flush task persists results to the DB.
asyncio.create_task( # noqa: RUF006
geo_service.lookup_batch(uncached, http_session, db=app_db),
name="geo_bans_by_country",
)
if geo_batch_lookup is not None:
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
# The dirty-set flush task persists results to the DB.
asyncio.create_task( # noqa: RUF006
geo_batch_lookup(uncached, http_session, db=app_db),
name="geo_bans_by_country",
)
elif geo_enricher is not None and unique_ips:
# Fallback: legacy per-IP enricher (used in tests / older callers).
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
@@ -483,13 +388,13 @@ async def bans_by_country(
cn = geo.country_name if geo else None
asn: str | None = geo.asn if geo else None
org: str | None = geo.org if geo else None
matches, _ = _parse_data_json(companion_row.data)
matches, _ = parse_data_json(companion_row.data)
bans.append(
DashboardBanItem(
ip=ip,
jail=companion_row.jail,
banned_at=_ts_to_iso(companion_row.timeofban),
banned_at=ts_to_iso(companion_row.timeofban),
service=matches[0] if matches else None,
country_code=cc,
country_name=cn,
@@ -550,7 +455,7 @@ async def ban_trend(
num_buckets: int = bucket_count(range_)
origin_clause, origin_params = _origin_sql_filter(origin)
db_path: str = await _get_fail2ban_db_path(socket_path)
db_path: str = await get_fail2ban_db_path(socket_path)
log.info(
"ban_service_ban_trend",
db_path=db_path,
@@ -571,7 +476,7 @@ async def ban_trend(
buckets: list[BanTrendBucket] = [
BanTrendBucket(
timestamp=_ts_to_iso(since + i * bucket_secs),
timestamp=ts_to_iso(since + i * bucket_secs),
count=counts[i],
)
for i in range(num_buckets)
@@ -615,12 +520,12 @@ async def bans_by_jail(
since: int = _since_unix(range_)
origin_clause, origin_params = _origin_sql_filter(origin)
db_path: str = await _get_fail2ban_db_path(socket_path)
db_path: str = await get_fail2ban_db_path(socket_path)
log.debug(
"ban_service_bans_by_jail",
db_path=db_path,
since=since,
since_iso=_ts_to_iso(since),
since_iso=ts_to_iso(since),
range=range_,
origin=origin,
)

View File

@@ -33,9 +33,13 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
from app.utils.ip_utils import is_valid_ip, is_valid_network
if TYPE_CHECKING:
from collections.abc import Callable
import aiohttp
import aiosqlite
from app.models.geo import GeoBatchLookup
log: structlog.stdlib.BoundLogger = structlog.get_logger()
#: Settings key used to persist the schedule config.
@@ -238,6 +242,8 @@ async def import_source(
http_session: aiohttp.ClientSession,
socket_path: str,
db: aiosqlite.Connection,
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
) -> ImportSourceResult:
"""Download and apply bans from a single blocklist source.
@@ -339,12 +345,8 @@ async def import_source(
)
# --- Pre-warm geo cache for newly imported IPs ---
if imported_ips:
from app.services import geo_service # noqa: PLC0415
uncached_ips: list[str] = [
ip for ip in imported_ips if not geo_service.is_cached(ip)
]
if imported_ips and geo_is_cached is not None:
uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)]
skipped_geo: int = len(imported_ips) - len(uncached_ips)
if skipped_geo > 0:
@@ -355,9 +357,9 @@ async def import_source(
to_lookup=len(uncached_ips),
)
if uncached_ips:
if uncached_ips and geo_batch_lookup is not None:
try:
await geo_service.lookup_batch(uncached_ips, http_session, db=db)
await geo_batch_lookup(uncached_ips, http_session, db=db)
log.info(
"blocklist_geo_prewarm_complete",
source_id=source.id,
@@ -383,6 +385,8 @@ async def import_all(
db: aiosqlite.Connection,
http_session: aiohttp.ClientSession,
socket_path: str,
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
) -> ImportRunResult:
"""Import all enabled blocklist sources.
@@ -406,7 +410,14 @@ async def import_all(
for row in sources:
source = _row_to_source(row)
result = await import_source(source, http_session, socket_path, db)
result = await import_source(
source,
http_session,
socket_path,
db,
geo_is_cached=geo_is_cached,
geo_batch_lookup=geo_batch_lookup,
)
results.append(result)
total_imported += result.ips_imported
total_skipped += result.ips_skipped

View File

@@ -54,8 +54,8 @@ from app.models.config import (
JailValidationResult,
RollbackResponse,
)
from app.exceptions import JailNotFoundError
from app.services import jail_service
from app.services.jail_service import JailNotFoundError as JailNotFoundError
from app.utils import conffile_parser
from app.utils.fail2ban_client import (
Fail2BanClient,

View File

@@ -44,6 +44,7 @@ from app.models.config import (
RegexTestResponse,
ServiceStatusResponse,
)
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
from app.services import setup_service
from app.utils.fail2ban_client import Fail2BanClient
@@ -55,26 +56,7 @@ _SOCKET_TIMEOUT: float = 10.0
# Custom exceptions
# ---------------------------------------------------------------------------
class JailNotFoundError(Exception):
"""Raised when a requested jail name does not exist in fail2ban."""
def __init__(self, name: str) -> None:
"""Initialise with the jail name that was not found.
Args:
name: The jail name that could not be located.
"""
self.name: str = name
super().__init__(f"Jail not found: {name!r}")
class ConfigValidationError(Exception):
"""Raised when a configuration value fails validation before writing."""
class ConfigOperationError(Exception):
"""Raised when a configuration write command fails."""
# (exceptions are now defined in app.exceptions and imported above)
# ---------------------------------------------------------------------------

View File

@@ -41,13 +41,12 @@ from __future__ import annotations
import asyncio
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING
import aiohttp
import structlog
from app.models.geo import GeoInfo
from app.repositories import geo_cache_repo
if TYPE_CHECKING:
@@ -94,40 +93,6 @@ _BATCH_DELAY: float = 1.5
#: transient error (e.g. connection reset due to rate limiting).
_BATCH_MAX_RETRIES: int = 2
# ---------------------------------------------------------------------------
# Domain model
# ---------------------------------------------------------------------------
@dataclass
class GeoInfo:
"""Geographical and network metadata for a single IP address.
All fields default to ``None`` when the information is unavailable or
the lookup fails gracefully.
"""
country_code: str | None
"""ISO 3166-1 alpha-2 country code, e.g. ``"DE"``."""
country_name: str | None
"""Human-readable country name, e.g. ``"Germany"``."""
asn: str | None
"""Autonomous System Number string, e.g. ``"AS3320"``."""
org: str | None
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
"""Async callable used to enrich IPs with :class:`~app.services.geo_service.GeoInfo`.
This is a shared type alias used by services that optionally accept a geo
lookup callable (for example, :mod:`app.services.history_service`).
"""
# ---------------------------------------------------------------------------
# Internal cache
# ---------------------------------------------------------------------------

View File

@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
import structlog
if TYPE_CHECKING:
from app.services.geo_service import GeoEnricher
from app.models.geo import GeoEnricher
from app.models.ban import TIME_RANGE_SECONDS, TimeRange
from app.models.history import (
@@ -26,7 +26,7 @@ from app.models.history import (
IpTimelineEvent,
)
from app.repositories import fail2ban_db_repo
from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -93,7 +93,7 @@ async def list_history(
if range_ is not None:
since = _since_unix(range_)
db_path: str = await _get_fail2ban_db_path(socket_path)
db_path: str = await get_fail2ban_db_path(socket_path)
log.info(
"history_service_list",
db_path=db_path,
@@ -116,9 +116,9 @@ async def list_history(
for row in rows:
jail_name: str = row.jail
ip: str = row.ip
banned_at: str = _ts_to_iso(row.timeofban)
banned_at: str = ts_to_iso(row.timeofban)
ban_count: int = row.bancount
matches, failures = _parse_data_json(row.data)
matches, failures = parse_data_json(row.data)
country_code: str | None = None
country_name: str | None = None
@@ -180,7 +180,7 @@ async def get_ip_detail(
:class:`~app.models.history.IpDetailResponse` if any records exist
for *ip*, or ``None`` if the IP has no history in the database.
"""
db_path: str = await _get_fail2ban_db_path(socket_path)
db_path: str = await get_fail2ban_db_path(socket_path)
log.info("history_service_ip_detail", db_path=db_path, ip=ip)
rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip)
@@ -193,9 +193,9 @@ async def get_ip_detail(
for row in rows:
jail_name: str = row.jail
banned_at: str = _ts_to_iso(row.timeofban)
banned_at: str = ts_to_iso(row.timeofban)
ban_count: int = row.bancount
matches, failures = _parse_data_json(row.data)
matches, failures = parse_data_json(row.data)
total_failures += failures
timeline.append(
IpTimelineEvent(

View File

@@ -14,11 +14,11 @@ from __future__ import annotations
import asyncio
import contextlib
import ipaddress
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, TypedDict, cast
import structlog
from app.exceptions import JailNotFoundError, JailOperationError
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
from app.models.config import BantimeEscalation
from app.models.jail import (
@@ -28,7 +28,6 @@ from app.models.jail import (
JailStatus,
JailSummary,
)
from app.services.geo_service import GeoInfo
from app.utils.fail2ban_client import (
Fail2BanClient,
Fail2BanCommand,
@@ -38,9 +37,13 @@ from app.utils.fail2ban_client import (
)
if TYPE_CHECKING:
from collections.abc import Awaitable
import aiohttp
import aiosqlite
from app.models.geo import GeoBatchLookup, GeoEnricher, GeoInfo
log: structlog.stdlib.BoundLogger = structlog.get_logger()
class IpLookupResult(TypedDict):
@@ -55,8 +58,6 @@ class IpLookupResult(TypedDict):
geo: GeoInfo | None
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
@@ -81,23 +82,6 @@ _backend_cmd_lock: asyncio.Lock = asyncio.Lock()
# ---------------------------------------------------------------------------
class JailNotFoundError(Exception):
"""Raised when a requested jail name does not exist in fail2ban."""
def __init__(self, name: str) -> None:
"""Initialise with the jail name that was not found.
Args:
name: The jail name that could not be located.
"""
self.name: str = name
super().__init__(f"Jail not found: {name!r}")
class JailOperationError(Exception):
"""Raised when a jail control command fails for a non-auth reason."""
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
@@ -820,6 +804,7 @@ async def unban_ip(
async def get_active_bans(
socket_path: str,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None,
http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None,
@@ -857,7 +842,6 @@ async def get_active_bans(
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
from app.services import geo_service # noqa: PLC0415
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
@@ -905,10 +889,10 @@ async def get_active_bans(
bans.append(ban)
# Enrich with geo data — prefer batch lookup over per-IP enricher.
if http_session is not None and bans:
if http_session is not None and bans and geo_batch_lookup is not None:
all_ips: list[str] = [ban.ip for ban in bans]
try:
geo_map = await geo_service.lookup_batch(all_ips, http_session, db=app_db)
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
except Exception: # noqa: BLE001
log.warning("active_bans_batch_geo_failed")
geo_map = {}
@@ -1017,6 +1001,7 @@ async def get_jail_banned_ips(
page: int = 1,
page_size: int = 25,
search: str | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None,
) -> JailBannedIpsResponse:
@@ -1044,8 +1029,6 @@ async def get_jail_banned_ips(
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket is
unreachable.
"""
from app.services import geo_service # noqa: PLC0415
# Clamp page_size to the allowed maximum.
page_size = min(page_size, _MAX_PAGE_SIZE)
@@ -1086,10 +1069,10 @@ async def get_jail_banned_ips(
page_bans = all_bans[start : start + page_size]
# Geo-enrich only the page slice.
if http_session is not None and page_bans:
if http_session is not None and page_bans and geo_batch_lookup is not None:
page_ips = [b.ip for b in page_bans]
try:
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
except Exception: # noqa: BLE001
log.warning("jail_banned_ips_geo_failed", jail=jail_name)
geo_map = {}

View File

@@ -14,6 +14,8 @@ from typing import cast
import structlog
from app.exceptions import ServerOperationError
from app.exceptions import ServerOperationError
from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanResponse
@@ -54,15 +56,6 @@ def _to_str(value: object | None, default: str) -> str:
return str(value)
# ---------------------------------------------------------------------------
# Custom exceptions
# ---------------------------------------------------------------------------
class ServerOperationError(Exception):
"""Raised when a server-level set command fails."""
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,63 @@
"""Utilities shared by fail2ban-related services."""
from __future__ import annotations
import json
from datetime import UTC, datetime
def ts_to_iso(unix_ts: int) -> str:
"""Convert a Unix timestamp to an ISO 8601 UTC string."""
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
async def get_fail2ban_db_path(socket_path: str) -> str:
"""Query fail2ban for the path to its SQLite database file."""
from app.utils.fail2ban_client import Fail2BanClient # pragma: no cover
socket_timeout: float = 5.0
async with Fail2BanClient(socket_path, timeout=socket_timeout) as client:
response = await client.send(["get", "dbfile"])
if not isinstance(response, tuple) or len(response) != 2:
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}")
code, data = response
if code != 0:
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
if data is None:
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
return str(data)
def parse_data_json(raw: object) -> tuple[list[str], int]:
"""Extract matches and failure count from the fail2ban bans.data value."""
if raw is None:
return [], 0
obj: dict[str, object] = {}
if isinstance(raw, str):
try:
parsed = json.loads(raw)
if isinstance(parsed, dict):
obj = parsed
except json.JSONDecodeError:
return [], 0
elif isinstance(raw, dict):
obj = raw
raw_matches = obj.get("matches")
matches = [str(m) for m in raw_matches] if isinstance(raw_matches, list) else []
raw_failures = obj.get("failures")
failures = 0
if isinstance(raw_failures, (int, float, str)):
try:
failures = int(raw_failures)
except (ValueError, TypeError):
failures = 0
return matches, failures