Task 13: move ban_ip, unban_ip, and get_active_bans from jail_service to ban_service and update routers/tests
This commit is contained in:
@@ -241,6 +241,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
||||
|
||||
### Task 13 — Move `ban_ip`, `unban_ip`, and `get_active_bans` from `jail_service` to `ban_service`
|
||||
|
||||
**Status:** Completed ✅
|
||||
|
||||
**Found in:** `backend/app/services/jail_service.py` contains `ban_ip`, `unban_ip`, and `get_active_bans`. These operations conceptually belong in `ban_service.py`, which is the declared home for ban management. Routers `bans.py` and `blocklist.py` already import `jail_service` specifically for these functions.
|
||||
|
||||
**Goal:** Move the three functions into `ban_service.py`. Update `bans.py` and `blocklist.py` to import from `ban_service` instead of `jail_service`. Remove the now-redundant `jail_service` import from those routers if it is no longer needed.
|
||||
|
||||
@@ -22,7 +22,7 @@ from app.dependencies import (
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
|
||||
from app.models.jail import JailCommandResponse
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service, jail_service
|
||||
from app.exceptions import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
|
||||
@@ -72,7 +72,7 @@ async def get_active_bans(
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
try:
|
||||
return await jail_service.get_active_bans(
|
||||
return await ban_service.get_active_bans(
|
||||
socket_path,
|
||||
geo_batch_lookup=geo_batch_lookup,
|
||||
http_session=http_session,
|
||||
@@ -114,7 +114,7 @@ async def ban_ip(
|
||||
HTTPException: 502 when fail2ban is unreachable.
|
||||
"""
|
||||
try:
|
||||
await jail_service.ban_ip(socket_path, body.jail, body.ip)
|
||||
await ban_service.ban_ip(socket_path, body.jail, body.ip)
|
||||
return JailCommandResponse(
|
||||
message=f"IP {body.ip!r} banned in jail {body.jail!r}.",
|
||||
jail=body.jail,
|
||||
@@ -174,7 +174,7 @@ async def unban_ip(
|
||||
target_jail: str | None = None if (body.unban_all or body.jail is None) else body.jail
|
||||
|
||||
try:
|
||||
await jail_service.unban_ip(socket_path, body.ip, jail=target_jail)
|
||||
await ban_service.unban_ip(socket_path, body.ip, jail=target_jail)
|
||||
scope = f"jail {target_jail!r}" if target_jail else "all jails"
|
||||
return JailCommandResponse(
|
||||
message=f"IP {body.ip!r} unbanned from {scope}.",
|
||||
|
||||
@@ -44,7 +44,7 @@ from app.models.blocklist import (
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.services import blocklist_service, geo_service, jail_service
|
||||
from app.services import ban_service, blocklist_service, geo_service
|
||||
from app.tasks.blocklist_import import run_import_with_resources
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
||||
@@ -138,7 +138,7 @@ async def run_import_now(
|
||||
socket_path,
|
||||
geo_is_cached=geo_service.is_cached,
|
||||
geo_batch_lookup=geo_batch_lookup,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -11,11 +11,14 @@ so BanGUI never modifies or locks the fail2ban database.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import ipaddress
|
||||
import time
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models.ban import (
|
||||
BLOCKLIST_JAIL,
|
||||
BUCKET_SECONDS,
|
||||
@@ -26,6 +29,8 @@ from app.models.ban import (
|
||||
BansByJailResponse,
|
||||
BanTrendBucket,
|
||||
BanTrendResponse,
|
||||
ActiveBan,
|
||||
ActiveBanListResponse,
|
||||
DashboardBanItem,
|
||||
DashboardBanListResponse,
|
||||
TimeRange,
|
||||
@@ -42,6 +47,10 @@ from app.repositories.history_archive_repo import (
|
||||
)
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanResponse,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
@@ -70,6 +79,114 @@ _SOCKET_TIMEOUT: float = 5.0
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
Args:
|
||||
response: Raw value returned by :meth:`~Fail2BanClient.send`.
|
||||
|
||||
Returns:
|
||||
The payload ``data`` portion of the response.
|
||||
|
||||
Raises:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = response # type: ignore[assignment]
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
if code != 0:
|
||||
raise ValueError(f"fail2ban returned error code {code}: {data!r}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||
|
||||
Args:
|
||||
pairs: A list of ``(key, value)`` pairs (or any iterable thereof).
|
||||
|
||||
Returns:
|
||||
A :class:`dict` with the keys and values from *pairs*.
|
||||
"""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
result[str(k)] = v
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_list(value: object | None) -> list[str]:
|
||||
"""Coerce a fail2ban response value to a list of strings."""
|
||||
if value is None:
|
||||
return []
|
||||
if isinstance(value, str):
|
||||
return [value] if value.strip() else []
|
||||
if isinstance(value, (list, tuple)):
|
||||
return [str(v) for v in value if v is not None]
|
||||
return [str(value)]
|
||||
|
||||
|
||||
def _is_not_found_error(exc: Exception) -> bool:
|
||||
"""Return ``True`` if *exc* indicates a jail does not exist."""
|
||||
msg = str(exc).lower()
|
||||
return any(
|
||||
phrase in msg
|
||||
for phrase in (
|
||||
"unknown jail",
|
||||
"unknownjail",
|
||||
"no jail",
|
||||
"does not exist",
|
||||
"not found",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def ban_ip(socket_path: str, jail: str, ip: str) -> None:
|
||||
"""Ban an IP address in the specified jail."""
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Invalid IP address: {ip!r}") from exc
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
try:
|
||||
_ok(await client.send(["set", jail, "banip", ip]))
|
||||
except ValueError as exc:
|
||||
if _is_not_found_error(exc):
|
||||
raise JailNotFoundError(jail) from exc
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
async def unban_ip(socket_path: str, ip: str, jail: str | None = None) -> None:
|
||||
"""Unban an IP address from a specific jail or all jails."""
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Invalid IP address: {ip!r}") from exc
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
if jail is None:
|
||||
_ok(await client.send(["unban", ip]))
|
||||
return
|
||||
|
||||
try:
|
||||
_ok(await client.send(["set", jail, "unbanip", ip]))
|
||||
except ValueError as exc:
|
||||
if _is_not_found_error(exc):
|
||||
raise JailNotFoundError(jail) from exc
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
"""Return a SQL fragment and its parameters for the origin filter.
|
||||
|
||||
@@ -88,6 +205,191 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
return "", ()
|
||||
|
||||
|
||||
def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:
|
||||
"""Parse a ban entry from ``get <jail> banip --with-time`` output."""
|
||||
from datetime import UTC, datetime
|
||||
|
||||
try:
|
||||
parts = entry.split("\t", 1)
|
||||
ip = parts[0].strip()
|
||||
|
||||
ipaddress.ip_address(ip)
|
||||
|
||||
if len(parts) < 2:
|
||||
return ActiveBan(
|
||||
ip=ip,
|
||||
jail=jail,
|
||||
banned_at=None,
|
||||
expires_at=None,
|
||||
ban_count=1,
|
||||
country=None,
|
||||
)
|
||||
|
||||
time_part = parts[1].strip()
|
||||
plus_idx = time_part.find(" + ")
|
||||
if plus_idx == -1:
|
||||
banned_at_str = time_part.strip()
|
||||
expires_at_str: str | None = None
|
||||
else:
|
||||
banned_at_str = time_part[:plus_idx].strip()
|
||||
remainder = time_part[plus_idx + 3 :]
|
||||
eq_idx = remainder.find(" = ")
|
||||
expires_at_str = remainder[eq_idx + 3 :].strip() if eq_idx != -1 else None
|
||||
|
||||
_date_fmt = "%Y-%m-%d %H:%M:%S"
|
||||
|
||||
def _to_iso(ts: str) -> str:
|
||||
dt = datetime.strptime(ts, _date_fmt).replace(tzinfo=UTC)
|
||||
return dt.isoformat()
|
||||
|
||||
banned_at_iso: str | None = None
|
||||
expires_at_iso: str | None = None
|
||||
|
||||
with contextlib.suppress(ValueError):
|
||||
banned_at_iso = _to_iso(banned_at_str)
|
||||
|
||||
with contextlib.suppress(ValueError):
|
||||
if expires_at_str:
|
||||
expires_at_iso = _to_iso(expires_at_str)
|
||||
|
||||
return ActiveBan(
|
||||
ip=ip,
|
||||
jail=jail,
|
||||
banned_at=banned_at_iso,
|
||||
expires_at=expires_at_iso,
|
||||
ban_count=1,
|
||||
country=None,
|
||||
)
|
||||
except (ValueError, IndexError, AttributeError) as exc:
|
||||
log.debug("ban_entry_parse_error", entry=entry, jail=jail, error=str(exc))
|
||||
return None
|
||||
|
||||
|
||||
async def _enrich_bans(
|
||||
bans: list[ActiveBan],
|
||||
geo_enricher: GeoEnricher,
|
||||
) -> list[ActiveBan]:
|
||||
"""Enrich ban records with geo data asynchronously."""
|
||||
geo_results: list[object | Exception] = await asyncio.gather(
|
||||
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
|
||||
return_exceptions=True,
|
||||
)
|
||||
enriched: list[ActiveBan] = []
|
||||
for ban, geo in zip(bans, geo_results, strict=False):
|
||||
if geo is not None and not isinstance(geo, Exception):
|
||||
geo_info = cast("GeoInfo", geo)
|
||||
enriched.append(ban.model_copy(update={"country": geo_info.country_code}))
|
||||
else:
|
||||
enriched.append(ban)
|
||||
return enriched
|
||||
|
||||
|
||||
async def get_active_bans(
|
||||
socket_path: str,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> ActiveBanListResponse:
|
||||
"""Return all currently banned IPs across every jail.
|
||||
|
||||
For each jail the ``get <jail> banip --with-time`` command is used
|
||||
to retrieve ban start and expiry times alongside the IP address.
|
||||
|
||||
Geo enrichment strategy (highest priority first):
|
||||
|
||||
1. When *http_session* is provided the entire set of banned IPs is resolved
|
||||
in a single :func:`~app.services.geo_service.lookup_batch` call (up to
|
||||
100 IPs per HTTP request). This is far more efficient than concurrent
|
||||
per-IP lookups and stays within ip-api.com rate limits.
|
||||
2. When only *geo_enricher* is provided (legacy / test path) each IP is
|
||||
resolved individually via the supplied async callable.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
geo_enricher: Optional async callable ``(ip) -> GeoInfo | None``.
|
||||
Used to enrich each ban entry with country and ASN data.
|
||||
Ignored when *http_session* is provided.
|
||||
http_session: Optional shared :class:`aiohttp.ClientSession`. When
|
||||
provided, :func:`~app.services.geo_service.lookup_batch` is used
|
||||
for efficient bulk geo resolution.
|
||||
app_db: Optional BanGUI application database connection used to
|
||||
persist newly resolved geo entries across restarts. Only
|
||||
meaningful when *http_session* is provided.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.ActiveBanListResponse` with all active bans.
|
||||
|
||||
Raises:
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
global_status = _to_dict(_ok(await client.send(["status"])))
|
||||
jail_list_raw: str = str(global_status.get("Jail list", "") or "").strip()
|
||||
jail_names: list[str] = (
|
||||
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
||||
if jail_list_raw
|
||||
else []
|
||||
)
|
||||
|
||||
if not jail_names:
|
||||
return ActiveBanListResponse(bans=[], total=0)
|
||||
|
||||
results: list[object | Exception] = await asyncio.gather(
|
||||
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
bans: list[ActiveBan] = []
|
||||
for jail_name, raw_result in zip(jail_names, results, strict=False):
|
||||
if isinstance(raw_result, Exception):
|
||||
log.warning(
|
||||
"active_bans_fetch_error",
|
||||
jail=jail_name,
|
||||
error=str(raw_result),
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
|
||||
except (TypeError, ValueError) as exc:
|
||||
log.warning(
|
||||
"active_bans_parse_error",
|
||||
jail=jail_name,
|
||||
error=str(exc),
|
||||
)
|
||||
continue
|
||||
|
||||
for entry in ban_list:
|
||||
ban = _parse_ban_entry(str(entry), jail_name)
|
||||
if ban is not None:
|
||||
bans.append(ban)
|
||||
|
||||
if http_session is not None and bans and geo_batch_lookup is not None:
|
||||
all_ips: list[str] = [ban.ip for ban in bans]
|
||||
try:
|
||||
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("active_bans_batch_geo_failed")
|
||||
geo_map = {}
|
||||
enriched: list[ActiveBan] = []
|
||||
for ban in bans:
|
||||
geo = geo_map.get(ban.ip)
|
||||
if geo is not None:
|
||||
enriched.append(ban.model_copy(update={"country": geo.country_code}))
|
||||
else:
|
||||
enriched.append(ban)
|
||||
bans = enriched
|
||||
elif geo_enricher is not None:
|
||||
bans = await _enrich_bans(bans, geo_enricher)
|
||||
|
||||
log.info("active_bans_fetched", total=len(bans))
|
||||
return ActiveBanListResponse(bans=bans, total=len(bans))
|
||||
|
||||
|
||||
_TIME_RANGE_SLACK_SECONDS: int = 60
|
||||
|
||||
|
||||
@@ -502,7 +804,7 @@ async def bans_by_country(
|
||||
country_names[cc] = cn
|
||||
|
||||
# Build companion table from recent rows (geo already cached from batch step).
|
||||
bans: list[DashboardBanItem] = []
|
||||
bans: list[ActiveBan] = []
|
||||
for companion_row in companion_rows:
|
||||
if source == "archive":
|
||||
ip = companion_row["ip"]
|
||||
|
||||
@@ -802,194 +802,6 @@ async def restart_daemon(
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — Ban / Unban
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def ban_ip(socket_path: str, jail: str, ip: str) -> None:
|
||||
"""Ban an IP address in a specific fail2ban jail.
|
||||
|
||||
The IP address is validated with :mod:`ipaddress` before the command
|
||||
is sent to fail2ban.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
jail: Jail in which to apply the ban.
|
||||
ip: IP address to ban (IPv4 or IPv6).
|
||||
|
||||
Raises:
|
||||
ValueError: If *ip* is not a valid IP address.
|
||||
JailNotFoundError: If *jail* is not a known jail.
|
||||
JailOperationError: If fail2ban reports the operation failed.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
# Validate the IP address before sending to avoid injection.
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Invalid IP address: {ip!r}") from exc
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
try:
|
||||
_ok(await client.send(["set", jail, "banip", ip]))
|
||||
log.info("ip_banned", ip=ip, jail=jail)
|
||||
except ValueError as exc:
|
||||
if _is_not_found_error(exc):
|
||||
raise JailNotFoundError(jail) from exc
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
async def unban_ip(
|
||||
socket_path: str,
|
||||
ip: str,
|
||||
jail: str | None = None,
|
||||
) -> None:
|
||||
"""Unban an IP address from one or all fail2ban jails.
|
||||
|
||||
If *jail* is ``None``, the IP is unbanned from every jail using the
|
||||
global ``unban`` command. Otherwise only the specified jail is
|
||||
targeted.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
ip: IP address to unban.
|
||||
jail: Jail to unban from. ``None`` means all jails.
|
||||
|
||||
Raises:
|
||||
ValueError: If *ip* is not a valid IP address.
|
||||
JailNotFoundError: If *jail* is specified but does not exist.
|
||||
JailOperationError: If fail2ban reports the operation failed.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
try:
|
||||
ipaddress.ip_address(ip)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Invalid IP address: {ip!r}") from exc
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
try:
|
||||
if jail is None:
|
||||
_ok(await client.send(["unban", ip]))
|
||||
log.info("ip_unbanned_all_jails", ip=ip)
|
||||
else:
|
||||
_ok(await client.send(["set", jail, "unbanip", ip]))
|
||||
log.info("ip_unbanned", ip=ip, jail=jail)
|
||||
except ValueError as exc:
|
||||
if _is_not_found_error(exc):
|
||||
raise JailNotFoundError(jail or "") from exc
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
async def get_active_bans(
|
||||
socket_path: str,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> ActiveBanListResponse:
|
||||
"""Return all currently banned IPs across every jail.
|
||||
|
||||
For each jail the ``get <jail> banip --with-time`` command is used
|
||||
to retrieve ban start and expiry times alongside the IP address.
|
||||
|
||||
Geo enrichment strategy (highest priority first):
|
||||
|
||||
1. When *http_session* is provided the entire set of banned IPs is resolved
|
||||
in a single :func:`~app.services.geo_service.lookup_batch` call (up to
|
||||
100 IPs per HTTP request). This is far more efficient than concurrent
|
||||
per-IP lookups and stays within ip-api.com rate limits.
|
||||
2. When only *geo_enricher* is provided (legacy / test path) each IP is
|
||||
resolved individually via the supplied async callable.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
geo_enricher: Optional async callable ``(ip) → GeoInfo | None``
|
||||
used to enrich each ban entry with country and ASN data.
|
||||
Ignored when *http_session* is provided.
|
||||
http_session: Optional shared :class:`aiohttp.ClientSession`. When
|
||||
provided, :func:`~app.services.geo_service.lookup_batch` is used
|
||||
for efficient bulk geo resolution.
|
||||
app_db: Optional BanGUI application database connection used to
|
||||
persist newly resolved geo entries across restarts. Only
|
||||
meaningful when *http_session* is provided.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.ActiveBanListResponse` with all active bans.
|
||||
|
||||
Raises:
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
# Fetch jail names.
|
||||
global_status = _to_dict(_ok(await client.send(["status"])))
|
||||
jail_list_raw: str = str(global_status.get("Jail list", "") or "").strip()
|
||||
jail_names: list[str] = (
|
||||
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
||||
if jail_list_raw
|
||||
else []
|
||||
)
|
||||
|
||||
if not jail_names:
|
||||
return ActiveBanListResponse(bans=[], total=0)
|
||||
|
||||
# For each jail, fetch the ban list with time info in parallel.
|
||||
results: list[object | Exception] = await asyncio.gather(
|
||||
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
|
||||
return_exceptions=True,
|
||||
)
|
||||
|
||||
bans: list[ActiveBan] = []
|
||||
for jail_name, raw_result in zip(jail_names, results, strict=False):
|
||||
if isinstance(raw_result, Exception):
|
||||
log.warning(
|
||||
"active_bans_fetch_error",
|
||||
jail=jail_name,
|
||||
error=str(raw_result),
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
|
||||
except (TypeError, ValueError) as exc:
|
||||
log.warning(
|
||||
"active_bans_parse_error",
|
||||
jail=jail_name,
|
||||
error=str(exc),
|
||||
)
|
||||
continue
|
||||
|
||||
for entry in ban_list:
|
||||
ban = _parse_ban_entry(str(entry), jail_name)
|
||||
if ban is not None:
|
||||
bans.append(ban)
|
||||
|
||||
# Enrich with geo data — prefer batch lookup over per-IP enricher.
|
||||
if http_session is not None and bans and geo_batch_lookup is not None:
|
||||
all_ips: list[str] = [ban.ip for ban in bans]
|
||||
try:
|
||||
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("active_bans_batch_geo_failed")
|
||||
geo_map = {}
|
||||
enriched: list[ActiveBan] = []
|
||||
for ban in bans:
|
||||
geo = geo_map.get(ban.ip)
|
||||
if geo is not None:
|
||||
enriched.append(ban.model_copy(update={"country": geo.country_code}))
|
||||
else:
|
||||
enriched.append(ban)
|
||||
bans = enriched
|
||||
elif geo_enricher is not None:
|
||||
bans = await _enrich_bans(bans, geo_enricher)
|
||||
|
||||
log.info("active_bans_fetched", total=len(bans))
|
||||
return ActiveBanListResponse(bans=bans, total=len(bans))
|
||||
|
||||
|
||||
def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any
|
||||
import structlog
|
||||
|
||||
from app.db import open_db
|
||||
from app.services import blocklist_service, jail_service
|
||||
from app.services import ban_service, blocklist_service
|
||||
from app.utils.runtime_state import get_effective_settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -55,7 +55,7 @@ async def _run_import_with_resources(settings: Settings, http_session: ClientSes
|
||||
db,
|
||||
http_session,
|
||||
socket_path,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
log.info(
|
||||
"blocklist_import_finished",
|
||||
|
||||
@@ -84,7 +84,7 @@ class TestGetActiveBans:
|
||||
total=1,
|
||||
)
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.get_active_bans",
|
||||
"app.routers.bans.ban_service.get_active_bans",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await bans_client.get("/api/bans/active")
|
||||
@@ -107,7 +107,7 @@ class TestGetActiveBans:
|
||||
"""GET /api/bans/active returns empty list when no bans are active."""
|
||||
mock_response = ActiveBanListResponse(bans=[], total=0)
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.get_active_bans",
|
||||
"app.routers.bans.ban_service.get_active_bans",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await bans_client.get("/api/bans/active")
|
||||
@@ -132,7 +132,7 @@ class TestGetActiveBans:
|
||||
total=1,
|
||||
)
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.get_active_bans",
|
||||
"app.routers.bans.ban_service.get_active_bans",
|
||||
AsyncMock(return_value=mock_response),
|
||||
):
|
||||
resp = await bans_client.get("/api/bans/active")
|
||||
@@ -156,7 +156,7 @@ class TestBanIp:
|
||||
async def test_201_on_success(self, bans_client: AsyncClient) -> None:
|
||||
"""POST /api/bans returns 201 when the IP is banned."""
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.ban_ip",
|
||||
"app.routers.bans.ban_service.ban_ip",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await bans_client.post(
|
||||
@@ -170,7 +170,7 @@ class TestBanIp:
|
||||
async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
|
||||
"""POST /api/bans returns 400 for an invalid IP address."""
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.ban_ip",
|
||||
"app.routers.bans.ban_service.ban_ip",
|
||||
AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")),
|
||||
):
|
||||
resp = await bans_client.post(
|
||||
@@ -185,7 +185,7 @@ class TestBanIp:
|
||||
from app.services.jail_service import JailNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.ban_ip",
|
||||
"app.routers.bans.ban_service.ban_ip",
|
||||
AsyncMock(side_effect=JailNotFoundError("ghost")),
|
||||
):
|
||||
resp = await bans_client.post(
|
||||
@@ -215,7 +215,7 @@ class TestUnbanIp:
|
||||
async def test_200_unban_from_all(self, bans_client: AsyncClient) -> None:
|
||||
"""DELETE /api/bans with unban_all=true unbans from all jails."""
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.unban_ip",
|
||||
"app.routers.bans.ban_service.unban_ip",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await bans_client.request(
|
||||
@@ -230,7 +230,7 @@ class TestUnbanIp:
|
||||
async def test_200_unban_from_specific_jail(self, bans_client: AsyncClient) -> None:
|
||||
"""DELETE /api/bans with a jail unbans from that jail only."""
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.unban_ip",
|
||||
"app.routers.bans.ban_service.unban_ip",
|
||||
AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await bans_client.request(
|
||||
@@ -245,7 +245,7 @@ class TestUnbanIp:
|
||||
async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
|
||||
"""DELETE /api/bans returns 400 for an invalid IP."""
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.unban_ip",
|
||||
"app.routers.bans.ban_service.unban_ip",
|
||||
AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")),
|
||||
):
|
||||
resp = await bans_client.request(
|
||||
@@ -261,7 +261,7 @@ class TestUnbanIp:
|
||||
from app.services.jail_service import JailNotFoundError
|
||||
|
||||
with patch(
|
||||
"app.routers.bans.jail_service.unban_ip",
|
||||
"app.routers.bans.ban_service.unban_ip",
|
||||
AsyncMock(side_effect=JailNotFoundError("ghost")),
|
||||
):
|
||||
resp = await bans_client.request(
|
||||
|
||||
@@ -174,17 +174,17 @@ class TestImport:
|
||||
|
||||
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
|
||||
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
"app.services.ban_service.ban_ip", new_callable=AsyncMock
|
||||
) as mock_ban:
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 2
|
||||
@@ -198,15 +198,15 @@ class TestImport:
|
||||
session = _make_session(content)
|
||||
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
|
||||
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 1
|
||||
@@ -217,14 +217,14 @@ class TestImport:
|
||||
session = _make_session("", status=503)
|
||||
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
|
||||
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 0
|
||||
@@ -234,6 +234,7 @@ class TestImport:
|
||||
"""import_source aborts immediately and records an error when the target jail
|
||||
does not exist in fail2ban instead of silently skipping every IP."""
|
||||
from app.services.jail_service import JailNotFoundError
|
||||
from app.services import ban_service
|
||||
|
||||
content = "\n".join(f"1.2.3.{i}" for i in range(100))
|
||||
session = _make_session(content)
|
||||
@@ -246,15 +247,13 @@ class TestImport:
|
||||
call_count += 1
|
||||
raise JailNotFoundError(jail)
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||
from app.services import jail_service
|
||||
|
||||
with patch("app.services.ban_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
# Must abort after the first JailNotFoundError — only one ban attempt.
|
||||
@@ -273,15 +272,15 @@ class TestImport:
|
||||
session = _make_session(content)
|
||||
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
"app.services.ban_service.ban_ip", new_callable=AsyncMock
|
||||
):
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
result = await blocklist_service.import_all(
|
||||
db,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
)
|
||||
|
||||
# Only S1 is enabled, S2 is disabled.
|
||||
@@ -415,16 +414,16 @@ class TestGeoPrewarmCacheFilter:
|
||||
def _mock_is_cached(ip: str) -> bool:
|
||||
return ip == "1.2.3.4"
|
||||
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service
|
||||
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
geo_is_cached=_mock_is_cached,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
@@ -12,7 +13,7 @@ from app.exceptions import Fail2BanConnectionError
|
||||
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
|
||||
from app.models.geo import GeoDetail, GeoInfo
|
||||
from app.models.jail import JailDetailResponse, JailListResponse
|
||||
from app.services import jail_service
|
||||
from app.services import ban_service, jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -71,7 +72,10 @@ def _patch_client(responses: dict[str, Any]) -> Any:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = mock_send
|
||||
|
||||
return patch("app.services.jail_service.Fail2BanClient", _FakeClient)
|
||||
stack = contextlib.ExitStack()
|
||||
stack.enter_context(patch("app.services.jail_service.Fail2BanClient", _FakeClient))
|
||||
stack.enter_context(patch("app.services.ban_service.Fail2BanClient", _FakeClient))
|
||||
return stack
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -555,19 +559,19 @@ class TestJailControls:
|
||||
|
||||
|
||||
class TestBanUnban:
|
||||
"""Unit tests for :func:`~app.services.jail_service.ban_ip` and
|
||||
:func:`~app.services.jail_service.unban_ip`.
|
||||
"""Unit tests for :func:`~app.services.ban_service.ban_ip` and
|
||||
:func:`~app.services.ban_service.unban_ip`.
|
||||
"""
|
||||
|
||||
async def test_ban_ip_success(self) -> None:
|
||||
"""ban_ip sends the banip command for a valid IP."""
|
||||
with _patch_client({"set|sshd|banip|1.2.3.4": (0, 1)}):
|
||||
await jail_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise
|
||||
await ban_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise
|
||||
|
||||
async def test_ban_ip_invalid_raises(self) -> None:
|
||||
"""ban_ip raises ValueError for a non-IP value."""
|
||||
with pytest.raises(ValueError, match="Invalid IP"):
|
||||
await jail_service.ban_ip(_SOCKET, "sshd", "not-an-ip")
|
||||
await ban_service.ban_ip(_SOCKET, "sshd", "not-an-ip")
|
||||
|
||||
async def test_ban_ip_unknown_jail_exception_raises_jail_not_found(self) -> None:
|
||||
"""ban_ip raises JailNotFoundError when fail2ban returns UnknownJailException.
|
||||
@@ -581,27 +585,27 @@ class TestBanUnban:
|
||||
_patch_client({"set|missing-jail|banip|1.2.3.4": response}),
|
||||
pytest.raises(JailNotFoundError, match="missing-jail"),
|
||||
):
|
||||
await jail_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4")
|
||||
await ban_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4")
|
||||
|
||||
async def test_ban_ipv6_success(self) -> None:
|
||||
"""ban_ip accepts an IPv6 address."""
|
||||
with _patch_client({"set|sshd|banip|::1": (0, 1)}):
|
||||
await jail_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise
|
||||
await ban_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise
|
||||
|
||||
async def test_unban_ip_all_jails(self) -> None:
|
||||
"""unban_ip with jail=None uses the global unban command."""
|
||||
with _patch_client({"unban|1.2.3.4": (0, 1)}):
|
||||
await jail_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise
|
||||
await ban_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise
|
||||
|
||||
async def test_unban_ip_specific_jail(self) -> None:
|
||||
"""unban_ip with a jail sends the set unbanip command."""
|
||||
with _patch_client({"set|sshd|unbanip|1.2.3.4": (0, 1)}):
|
||||
await jail_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise
|
||||
await ban_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise
|
||||
|
||||
async def test_unban_invalid_ip_raises(self) -> None:
|
||||
"""unban_ip raises ValueError for an invalid IP."""
|
||||
with pytest.raises(ValueError, match="Invalid IP"):
|
||||
await jail_service.unban_ip(_SOCKET, "bad-ip")
|
||||
await ban_service.unban_ip(_SOCKET, "bad-ip")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -610,7 +614,7 @@ class TestBanUnban:
|
||||
|
||||
|
||||
class TestGetActiveBans:
|
||||
"""Unit tests for :func:`~app.services.jail_service.get_active_bans`."""
|
||||
"""Unit tests for :func:`~app.services.ban_service.get_active_bans`."""
|
||||
|
||||
async def test_returns_active_ban_list_response(self) -> None:
|
||||
"""get_active_bans returns an ActiveBanListResponse."""
|
||||
@@ -622,7 +626,7 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
assert isinstance(result, ActiveBanListResponse)
|
||||
assert result.total == 1
|
||||
@@ -633,7 +637,7 @@ class TestGetActiveBans:
|
||||
"""get_active_bans returns empty list when no jails are active."""
|
||||
responses = {"status": (0, [("Number of jail", 0), ("Jail list", "")])}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
assert result.total == 0
|
||||
assert result.bans == []
|
||||
@@ -645,7 +649,7 @@ class TestGetActiveBans:
|
||||
"get|sshd|banip|--with-time": (0, []),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
assert result.total == 0
|
||||
|
||||
@@ -659,7 +663,7 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
ban = result.bans[0]
|
||||
assert ban.banned_at is not None
|
||||
@@ -691,8 +695,8 @@ class TestGetActiveBans:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_side)
|
||||
|
||||
with patch("app.services.jail_service.Fail2BanClient", _FakeClientPartial):
|
||||
result = await jail_service.get_active_bans(_SOCKET)
|
||||
with patch("app.services.ban_service.Fail2BanClient", _FakeClientPartial):
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
# Only sshd ban returned (nginx silently skipped)
|
||||
assert result.total == 1
|
||||
@@ -714,7 +718,7 @@ class TestGetActiveBans:
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=mock_batch,
|
||||
@@ -738,7 +742,7 @@ class TestGetActiveBans:
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=failing_batch,
|
||||
@@ -763,7 +767,7 @@ class TestGetActiveBans:
|
||||
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
|
||||
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_active_bans(
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET, geo_enricher=_enricher
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user