Task 13: move ban_ip, unban_ip, and get_active_bans from jail_service to ban_service and update routers/tests

This commit is contained in:
2026-04-17 16:22:20 +02:00
parent 6e1e3c4546
commit 8c6950afc1
9 changed files with 366 additions and 247 deletions

View File

@@ -241,6 +241,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
### Task 13 — Move `ban_ip`, `unban_ip`, and `get_active_bans` from `jail_service` to `ban_service` ### Task 13 — Move `ban_ip`, `unban_ip`, and `get_active_bans` from `jail_service` to `ban_service`
**Status:** Completed ✅
**Found in:** `backend/app/services/jail_service.py` contains `ban_ip`, `unban_ip`, and `get_active_bans`. These operations conceptually belong in `ban_service.py`, which is the declared home for ban management. Routers `bans.py` and `blocklist.py` already import `jail_service` specifically for these functions. **Found in:** `backend/app/services/jail_service.py` contains `ban_ip`, `unban_ip`, and `get_active_bans`. These operations conceptually belong in `ban_service.py`, which is the declared home for ban management. Routers `bans.py` and `blocklist.py` already import `jail_service` specifically for these functions.
**Goal:** Move the three functions into `ban_service.py`. Update `bans.py` and `blocklist.py` to import from `ban_service` instead of `jail_service`. Remove the now-redundant `jail_service` import from those routers if it is no longer needed. **Goal:** Move the three functions into `ban_service.py`. Update `bans.py` and `blocklist.py` to import from `ban_service` instead of `jail_service`. Remove the now-redundant `jail_service` import from those routers if it is no longer needed.

View File

@@ -22,7 +22,7 @@ from app.dependencies import (
from app.exceptions import JailNotFoundError, JailOperationError from app.exceptions import JailNotFoundError, JailOperationError
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
from app.models.jail import JailCommandResponse from app.models.jail import JailCommandResponse
from app.services import jail_service from app.services import ban_service, jail_service
from app.exceptions import Fail2BanConnectionError from app.exceptions import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"]) router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
@@ -72,7 +72,7 @@ async def get_active_bans(
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
try: try:
return await jail_service.get_active_bans( return await ban_service.get_active_bans(
socket_path, socket_path,
geo_batch_lookup=geo_batch_lookup, geo_batch_lookup=geo_batch_lookup,
http_session=http_session, http_session=http_session,
@@ -114,7 +114,7 @@ async def ban_ip(
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
try: try:
await jail_service.ban_ip(socket_path, body.jail, body.ip) await ban_service.ban_ip(socket_path, body.jail, body.ip)
return JailCommandResponse( return JailCommandResponse(
message=f"IP {body.ip!r} banned in jail {body.jail!r}.", message=f"IP {body.ip!r} banned in jail {body.jail!r}.",
jail=body.jail, jail=body.jail,
@@ -174,7 +174,7 @@ async def unban_ip(
target_jail: str | None = None if (body.unban_all or body.jail is None) else body.jail target_jail: str | None = None if (body.unban_all or body.jail is None) else body.jail
try: try:
await jail_service.unban_ip(socket_path, body.ip, jail=target_jail) await ban_service.unban_ip(socket_path, body.ip, jail=target_jail)
scope = f"jail {target_jail!r}" if target_jail else "all jails" scope = f"jail {target_jail!r}" if target_jail else "all jails"
return JailCommandResponse( return JailCommandResponse(
message=f"IP {body.ip!r} unbanned from {scope}.", message=f"IP {body.ip!r} unbanned from {scope}.",

View File

@@ -44,7 +44,7 @@ from app.models.blocklist import (
ScheduleConfig, ScheduleConfig,
ScheduleInfo, ScheduleInfo,
) )
from app.services import blocklist_service, geo_service, jail_service from app.services import ban_service, blocklist_service, geo_service
from app.tasks.blocklist_import import run_import_with_resources from app.tasks.blocklist_import import run_import_with_resources
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"]) router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
@@ -138,7 +138,7 @@ async def run_import_now(
socket_path, socket_path,
geo_is_cached=geo_service.is_cached, geo_is_cached=geo_service.is_cached,
geo_batch_lookup=geo_batch_lookup, geo_batch_lookup=geo_batch_lookup,
ban_ip=jail_service.ban_ip, ban_ip=ban_service.ban_ip,
) )

View File

@@ -11,11 +11,14 @@ so BanGUI never modifies or locks the fail2ban database.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib
import ipaddress
import time import time
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, cast
import structlog import structlog
from app.exceptions import JailNotFoundError, JailOperationError
from app.models.ban import ( from app.models.ban import (
BLOCKLIST_JAIL, BLOCKLIST_JAIL,
BUCKET_SECONDS, BUCKET_SECONDS,
@@ -26,6 +29,8 @@ from app.models.ban import (
BansByJailResponse, BansByJailResponse,
BanTrendBucket, BanTrendBucket,
BanTrendResponse, BanTrendResponse,
ActiveBan,
ActiveBanListResponse,
DashboardBanItem, DashboardBanItem,
DashboardBanListResponse, DashboardBanListResponse,
TimeRange, TimeRange,
@@ -42,6 +47,10 @@ from app.repositories.history_archive_repo import (
) )
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
from app.utils.fail2ban_client import (
Fail2BanClient,
Fail2BanResponse,
)
if TYPE_CHECKING: if TYPE_CHECKING:
import aiohttp import aiohttp
@@ -70,6 +79,114 @@ _SOCKET_TIMEOUT: float = 5.0
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _ok(response: object) -> object:
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
Args:
response: Raw value returned by :meth:`~Fail2BanClient.send`.
Returns:
The payload ``data`` portion of the response.
Raises:
ValueError: If the response indicates an error (return code ≠ 0).
"""
try:
code, data = response # type: ignore[assignment]
except (TypeError, ValueError) as exc:
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
if code != 0:
raise ValueError(f"fail2ban returned error code {code}: {data!r}")
return data
def _to_dict(pairs: object) -> dict[str, object]:
"""Convert a list of ``(key, value)`` pairs to a plain dict.
Args:
pairs: A list of ``(key, value)`` pairs (or any iterable thereof).
Returns:
A :class:`dict` with the keys and values from *pairs*.
"""
if not isinstance(pairs, (list, tuple)):
return {}
result: dict[str, object] = {}
for item in pairs:
try:
k, v = item
result[str(k)] = v
except (TypeError, ValueError):
pass
return result
def _ensure_list(value: object | None) -> list[str]:
"""Coerce a fail2ban response value to a list of strings."""
if value is None:
return []
if isinstance(value, str):
return [value] if value.strip() else []
if isinstance(value, (list, tuple)):
return [str(v) for v in value if v is not None]
return [str(value)]
def _is_not_found_error(exc: Exception) -> bool:
"""Return ``True`` if *exc* indicates a jail does not exist."""
msg = str(exc).lower()
return any(
phrase in msg
for phrase in (
"unknown jail",
"unknownjail",
"no jail",
"does not exist",
"not found",
)
)
async def ban_ip(socket_path: str, jail: str, ip: str) -> None:
"""Ban an IP address in the specified jail."""
try:
ipaddress.ip_address(ip)
except ValueError as exc:
raise ValueError(f"Invalid IP address: {ip!r}") from exc
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
try:
_ok(await client.send(["set", jail, "banip", ip]))
except ValueError as exc:
if _is_not_found_error(exc):
raise JailNotFoundError(jail) from exc
raise JailOperationError(str(exc)) from exc
async def unban_ip(socket_path: str, ip: str, jail: str | None = None) -> None:
"""Unban an IP address from a specific jail or all jails."""
try:
ipaddress.ip_address(ip)
except ValueError as exc:
raise ValueError(f"Invalid IP address: {ip!r}") from exc
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
if jail is None:
_ok(await client.send(["unban", ip]))
return
try:
_ok(await client.send(["set", jail, "unbanip", ip]))
except ValueError as exc:
if _is_not_found_error(exc):
raise JailNotFoundError(jail) from exc
raise JailOperationError(str(exc)) from exc
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]: def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
"""Return a SQL fragment and its parameters for the origin filter. """Return a SQL fragment and its parameters for the origin filter.
@@ -88,6 +205,191 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
return "", () return "", ()
def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:
"""Parse a ban entry from ``get <jail> banip --with-time`` output."""
from datetime import UTC, datetime
try:
parts = entry.split("\t", 1)
ip = parts[0].strip()
ipaddress.ip_address(ip)
if len(parts) < 2:
return ActiveBan(
ip=ip,
jail=jail,
banned_at=None,
expires_at=None,
ban_count=1,
country=None,
)
time_part = parts[1].strip()
plus_idx = time_part.find(" + ")
if plus_idx == -1:
banned_at_str = time_part.strip()
expires_at_str: str | None = None
else:
banned_at_str = time_part[:plus_idx].strip()
remainder = time_part[plus_idx + 3 :]
eq_idx = remainder.find(" = ")
expires_at_str = remainder[eq_idx + 3 :].strip() if eq_idx != -1 else None
_date_fmt = "%Y-%m-%d %H:%M:%S"
def _to_iso(ts: str) -> str:
dt = datetime.strptime(ts, _date_fmt).replace(tzinfo=UTC)
return dt.isoformat()
banned_at_iso: str | None = None
expires_at_iso: str | None = None
with contextlib.suppress(ValueError):
banned_at_iso = _to_iso(banned_at_str)
with contextlib.suppress(ValueError):
if expires_at_str:
expires_at_iso = _to_iso(expires_at_str)
return ActiveBan(
ip=ip,
jail=jail,
banned_at=banned_at_iso,
expires_at=expires_at_iso,
ban_count=1,
country=None,
)
except (ValueError, IndexError, AttributeError) as exc:
log.debug("ban_entry_parse_error", entry=entry, jail=jail, error=str(exc))
return None
async def _enrich_bans(
bans: list[ActiveBan],
geo_enricher: GeoEnricher,
) -> list[ActiveBan]:
"""Enrich ban records with geo data asynchronously."""
geo_results: list[object | Exception] = await asyncio.gather(
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
return_exceptions=True,
)
enriched: list[ActiveBan] = []
for ban, geo in zip(bans, geo_results, strict=False):
if geo is not None and not isinstance(geo, Exception):
geo_info = cast("GeoInfo", geo)
enriched.append(ban.model_copy(update={"country": geo_info.country_code}))
else:
enriched.append(ban)
return enriched
async def get_active_bans(
socket_path: str,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None,
http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None,
) -> ActiveBanListResponse:
"""Return all currently banned IPs across every jail.
For each jail the ``get <jail> banip --with-time`` command is used
to retrieve ban start and expiry times alongside the IP address.
Geo enrichment strategy (highest priority first):
1. When *http_session* is provided the entire set of banned IPs is resolved
in a single :func:`~app.services.geo_service.lookup_batch` call (up to
100 IPs per HTTP request). This is far more efficient than concurrent
per-IP lookups and stays within ip-api.com rate limits.
2. When only *geo_enricher* is provided (legacy / test path) each IP is
resolved individually via the supplied async callable.
Args:
socket_path: Path to the fail2ban Unix domain socket.
geo_enricher: Optional async callable ``(ip) -> GeoInfo | None``.
Used to enrich each ban entry with country and ASN data.
Ignored when *http_session* is provided.
http_session: Optional shared :class:`aiohttp.ClientSession`. When
provided, :func:`~app.services.geo_service.lookup_batch` is used
for efficient bulk geo resolution.
app_db: Optional BanGUI application database connection used to
persist newly resolved geo entries across restarts. Only
meaningful when *http_session* is provided.
Returns:
:class:`~app.models.ban.ActiveBanListResponse` with all active bans.
Raises:
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
global_status = _to_dict(_ok(await client.send(["status"])))
jail_list_raw: str = str(global_status.get("Jail list", "") or "").strip()
jail_names: list[str] = (
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
if jail_list_raw
else []
)
if not jail_names:
return ActiveBanListResponse(bans=[], total=0)
results: list[object | Exception] = await asyncio.gather(
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
return_exceptions=True,
)
bans: list[ActiveBan] = []
for jail_name, raw_result in zip(jail_names, results, strict=False):
if isinstance(raw_result, Exception):
log.warning(
"active_bans_fetch_error",
jail=jail_name,
error=str(raw_result),
)
continue
try:
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
except (TypeError, ValueError) as exc:
log.warning(
"active_bans_parse_error",
jail=jail_name,
error=str(exc),
)
continue
for entry in ban_list:
ban = _parse_ban_entry(str(entry), jail_name)
if ban is not None:
bans.append(ban)
if http_session is not None and bans and geo_batch_lookup is not None:
all_ips: list[str] = [ban.ip for ban in bans]
try:
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
except Exception: # noqa: BLE001
log.warning("active_bans_batch_geo_failed")
geo_map = {}
enriched: list[ActiveBan] = []
for ban in bans:
geo = geo_map.get(ban.ip)
if geo is not None:
enriched.append(ban.model_copy(update={"country": geo.country_code}))
else:
enriched.append(ban)
bans = enriched
elif geo_enricher is not None:
bans = await _enrich_bans(bans, geo_enricher)
log.info("active_bans_fetched", total=len(bans))
return ActiveBanListResponse(bans=bans, total=len(bans))
_TIME_RANGE_SLACK_SECONDS: int = 60 _TIME_RANGE_SLACK_SECONDS: int = 60
@@ -502,7 +804,7 @@ async def bans_by_country(
country_names[cc] = cn country_names[cc] = cn
# Build companion table from recent rows (geo already cached from batch step). # Build companion table from recent rows (geo already cached from batch step).
bans: list[DashboardBanItem] = [] bans: list[ActiveBan] = []
for companion_row in companion_rows: for companion_row in companion_rows:
if source == "archive": if source == "archive":
ip = companion_row["ip"] ip = companion_row["ip"]

View File

@@ -802,194 +802,6 @@ async def restart_daemon(
) )
# ---------------------------------------------------------------------------
# Public API — Ban / Unban
# ---------------------------------------------------------------------------
async def ban_ip(socket_path: str, jail: str, ip: str) -> None:
"""Ban an IP address in a specific fail2ban jail.
The IP address is validated with :mod:`ipaddress` before the command
is sent to fail2ban.
Args:
socket_path: Path to the fail2ban Unix domain socket.
jail: Jail in which to apply the ban.
ip: IP address to ban (IPv4 or IPv6).
Raises:
ValueError: If *ip* is not a valid IP address.
JailNotFoundError: If *jail* is not a known jail.
JailOperationError: If fail2ban reports the operation failed.
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
# Validate the IP address before sending to avoid injection.
try:
ipaddress.ip_address(ip)
except ValueError as exc:
raise ValueError(f"Invalid IP address: {ip!r}") from exc
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
try:
_ok(await client.send(["set", jail, "banip", ip]))
log.info("ip_banned", ip=ip, jail=jail)
except ValueError as exc:
if _is_not_found_error(exc):
raise JailNotFoundError(jail) from exc
raise JailOperationError(str(exc)) from exc
async def unban_ip(
socket_path: str,
ip: str,
jail: str | None = None,
) -> None:
"""Unban an IP address from one or all fail2ban jails.
If *jail* is ``None``, the IP is unbanned from every jail using the
global ``unban`` command. Otherwise only the specified jail is
targeted.
Args:
socket_path: Path to the fail2ban Unix domain socket.
ip: IP address to unban.
jail: Jail to unban from. ``None`` means all jails.
Raises:
ValueError: If *ip* is not a valid IP address.
JailNotFoundError: If *jail* is specified but does not exist.
JailOperationError: If fail2ban reports the operation failed.
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
try:
ipaddress.ip_address(ip)
except ValueError as exc:
raise ValueError(f"Invalid IP address: {ip!r}") from exc
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
try:
if jail is None:
_ok(await client.send(["unban", ip]))
log.info("ip_unbanned_all_jails", ip=ip)
else:
_ok(await client.send(["set", jail, "unbanip", ip]))
log.info("ip_unbanned", ip=ip, jail=jail)
except ValueError as exc:
if _is_not_found_error(exc):
raise JailNotFoundError(jail or "") from exc
raise JailOperationError(str(exc)) from exc
async def get_active_bans(
socket_path: str,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None,
http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None,
) -> ActiveBanListResponse:
"""Return all currently banned IPs across every jail.
For each jail the ``get <jail> banip --with-time`` command is used
to retrieve ban start and expiry times alongside the IP address.
Geo enrichment strategy (highest priority first):
1. When *http_session* is provided the entire set of banned IPs is resolved
in a single :func:`~app.services.geo_service.lookup_batch` call (up to
100 IPs per HTTP request). This is far more efficient than concurrent
per-IP lookups and stays within ip-api.com rate limits.
2. When only *geo_enricher* is provided (legacy / test path) each IP is
resolved individually via the supplied async callable.
Args:
socket_path: Path to the fail2ban Unix domain socket.
geo_enricher: Optional async callable ``(ip) → GeoInfo | None``
used to enrich each ban entry with country and ASN data.
Ignored when *http_session* is provided.
http_session: Optional shared :class:`aiohttp.ClientSession`. When
provided, :func:`~app.services.geo_service.lookup_batch` is used
for efficient bulk geo resolution.
app_db: Optional BanGUI application database connection used to
persist newly resolved geo entries across restarts. Only
meaningful when *http_session* is provided.
Returns:
:class:`~app.models.ban.ActiveBanListResponse` with all active bans.
Raises:
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
# Fetch jail names.
global_status = _to_dict(_ok(await client.send(["status"])))
jail_list_raw: str = str(global_status.get("Jail list", "") or "").strip()
jail_names: list[str] = (
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
if jail_list_raw
else []
)
if not jail_names:
return ActiveBanListResponse(bans=[], total=0)
# For each jail, fetch the ban list with time info in parallel.
results: list[object | Exception] = await asyncio.gather(
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
return_exceptions=True,
)
bans: list[ActiveBan] = []
for jail_name, raw_result in zip(jail_names, results, strict=False):
if isinstance(raw_result, Exception):
log.warning(
"active_bans_fetch_error",
jail=jail_name,
error=str(raw_result),
)
continue
try:
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
except (TypeError, ValueError) as exc:
log.warning(
"active_bans_parse_error",
jail=jail_name,
error=str(exc),
)
continue
for entry in ban_list:
ban = _parse_ban_entry(str(entry), jail_name)
if ban is not None:
bans.append(ban)
# Enrich with geo data — prefer batch lookup over per-IP enricher.
if http_session is not None and bans and geo_batch_lookup is not None:
all_ips: list[str] = [ban.ip for ban in bans]
try:
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
except Exception: # noqa: BLE001
log.warning("active_bans_batch_geo_failed")
geo_map = {}
enriched: list[ActiveBan] = []
for ban in bans:
geo = geo_map.get(ban.ip)
if geo is not None:
enriched.append(ban.model_copy(update={"country": geo.country_code}))
else:
enriched.append(ban)
bans = enriched
elif geo_enricher is not None:
bans = await _enrich_bans(bans, geo_enricher)
log.info("active_bans_fetched", total=len(bans))
return ActiveBanListResponse(bans=bans, total=len(bans))
def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None: def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:

View File

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any
import structlog import structlog
from app.db import open_db from app.db import open_db
from app.services import blocklist_service, jail_service from app.services import ban_service, blocklist_service
from app.utils.runtime_state import get_effective_settings from app.utils.runtime_state import get_effective_settings
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -55,7 +55,7 @@ async def _run_import_with_resources(settings: Settings, http_session: ClientSes
db, db,
http_session, http_session,
socket_path, socket_path,
ban_ip=jail_service.ban_ip, ban_ip=ban_service.ban_ip,
) )
log.info( log.info(
"blocklist_import_finished", "blocklist_import_finished",

View File

@@ -84,7 +84,7 @@ class TestGetActiveBans:
total=1, total=1,
) )
with patch( with patch(
"app.routers.bans.jail_service.get_active_bans", "app.routers.bans.ban_service.get_active_bans",
AsyncMock(return_value=mock_response), AsyncMock(return_value=mock_response),
): ):
resp = await bans_client.get("/api/bans/active") resp = await bans_client.get("/api/bans/active")
@@ -107,7 +107,7 @@ class TestGetActiveBans:
"""GET /api/bans/active returns empty list when no bans are active.""" """GET /api/bans/active returns empty list when no bans are active."""
mock_response = ActiveBanListResponse(bans=[], total=0) mock_response = ActiveBanListResponse(bans=[], total=0)
with patch( with patch(
"app.routers.bans.jail_service.get_active_bans", "app.routers.bans.ban_service.get_active_bans",
AsyncMock(return_value=mock_response), AsyncMock(return_value=mock_response),
): ):
resp = await bans_client.get("/api/bans/active") resp = await bans_client.get("/api/bans/active")
@@ -132,7 +132,7 @@ class TestGetActiveBans:
total=1, total=1,
) )
with patch( with patch(
"app.routers.bans.jail_service.get_active_bans", "app.routers.bans.ban_service.get_active_bans",
AsyncMock(return_value=mock_response), AsyncMock(return_value=mock_response),
): ):
resp = await bans_client.get("/api/bans/active") resp = await bans_client.get("/api/bans/active")
@@ -156,7 +156,7 @@ class TestBanIp:
async def test_201_on_success(self, bans_client: AsyncClient) -> None: async def test_201_on_success(self, bans_client: AsyncClient) -> None:
"""POST /api/bans returns 201 when the IP is banned.""" """POST /api/bans returns 201 when the IP is banned."""
with patch( with patch(
"app.routers.bans.jail_service.ban_ip", "app.routers.bans.ban_service.ban_ip",
AsyncMock(return_value=None), AsyncMock(return_value=None),
): ):
resp = await bans_client.post( resp = await bans_client.post(
@@ -170,7 +170,7 @@ class TestBanIp:
async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None: async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
"""POST /api/bans returns 400 for an invalid IP address.""" """POST /api/bans returns 400 for an invalid IP address."""
with patch( with patch(
"app.routers.bans.jail_service.ban_ip", "app.routers.bans.ban_service.ban_ip",
AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")), AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")),
): ):
resp = await bans_client.post( resp = await bans_client.post(
@@ -185,7 +185,7 @@ class TestBanIp:
from app.services.jail_service import JailNotFoundError from app.services.jail_service import JailNotFoundError
with patch( with patch(
"app.routers.bans.jail_service.ban_ip", "app.routers.bans.ban_service.ban_ip",
AsyncMock(side_effect=JailNotFoundError("ghost")), AsyncMock(side_effect=JailNotFoundError("ghost")),
): ):
resp = await bans_client.post( resp = await bans_client.post(
@@ -215,7 +215,7 @@ class TestUnbanIp:
async def test_200_unban_from_all(self, bans_client: AsyncClient) -> None: async def test_200_unban_from_all(self, bans_client: AsyncClient) -> None:
"""DELETE /api/bans with unban_all=true unbans from all jails.""" """DELETE /api/bans with unban_all=true unbans from all jails."""
with patch( with patch(
"app.routers.bans.jail_service.unban_ip", "app.routers.bans.ban_service.unban_ip",
AsyncMock(return_value=None), AsyncMock(return_value=None),
): ):
resp = await bans_client.request( resp = await bans_client.request(
@@ -230,7 +230,7 @@ class TestUnbanIp:
async def test_200_unban_from_specific_jail(self, bans_client: AsyncClient) -> None: async def test_200_unban_from_specific_jail(self, bans_client: AsyncClient) -> None:
"""DELETE /api/bans with a jail unbans from that jail only.""" """DELETE /api/bans with a jail unbans from that jail only."""
with patch( with patch(
"app.routers.bans.jail_service.unban_ip", "app.routers.bans.ban_service.unban_ip",
AsyncMock(return_value=None), AsyncMock(return_value=None),
): ):
resp = await bans_client.request( resp = await bans_client.request(
@@ -245,7 +245,7 @@ class TestUnbanIp:
async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None: async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
"""DELETE /api/bans returns 400 for an invalid IP.""" """DELETE /api/bans returns 400 for an invalid IP."""
with patch( with patch(
"app.routers.bans.jail_service.unban_ip", "app.routers.bans.ban_service.unban_ip",
AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")), AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")),
): ):
resp = await bans_client.request( resp = await bans_client.request(
@@ -261,7 +261,7 @@ class TestUnbanIp:
from app.services.jail_service import JailNotFoundError from app.services.jail_service import JailNotFoundError
with patch( with patch(
"app.routers.bans.jail_service.unban_ip", "app.routers.bans.ban_service.unban_ip",
AsyncMock(side_effect=JailNotFoundError("ghost")), AsyncMock(side_effect=JailNotFoundError("ghost")),
): ):
resp = await bans_client.request( resp = await bans_client.request(

View File

@@ -174,17 +174,17 @@ class TestImport:
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/") source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
from app.services import jail_service from app.services import ban_service
with patch( with patch(
"app.services.jail_service.ban_ip", new_callable=AsyncMock "app.services.ban_service.ban_ip", new_callable=AsyncMock
) as mock_ban: ) as mock_ban:
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, source,
session, session,
"/tmp/fake.sock", "/tmp/fake.sock",
db, db,
ban_ip=jail_service.ban_ip, ban_ip=ban_service.ban_ip,
) )
assert result.ips_imported == 2 assert result.ips_imported == 2
@@ -198,15 +198,15 @@ class TestImport:
session = _make_session(content) session = _make_session(content)
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/") source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
from app.services import jail_service from app.services import ban_service
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock): with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, source,
session, session,
"/tmp/fake.sock", "/tmp/fake.sock",
db, db,
ban_ip=jail_service.ban_ip, ban_ip=ban_service.ban_ip,
) )
assert result.ips_imported == 1 assert result.ips_imported == 1
@@ -217,14 +217,14 @@ class TestImport:
session = _make_session("", status=503) session = _make_session("", status=503)
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/") source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
from app.services import jail_service from app.services import ban_service
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, source,
session, session,
"/tmp/fake.sock", "/tmp/fake.sock",
db, db,
ban_ip=jail_service.ban_ip, ban_ip=ban_service.ban_ip,
) )
assert result.ips_imported == 0 assert result.ips_imported == 0
@@ -234,6 +234,7 @@ class TestImport:
"""import_source aborts immediately and records an error when the target jail """import_source aborts immediately and records an error when the target jail
does not exist in fail2ban instead of silently skipping every IP.""" does not exist in fail2ban instead of silently skipping every IP."""
from app.services.jail_service import JailNotFoundError from app.services.jail_service import JailNotFoundError
from app.services import ban_service
content = "\n".join(f"1.2.3.{i}" for i in range(100)) content = "\n".join(f"1.2.3.{i}" for i in range(100))
session = _make_session(content) session = _make_session(content)
@@ -246,15 +247,13 @@ class TestImport:
call_count += 1 call_count += 1
raise JailNotFoundError(jail) raise JailNotFoundError(jail)
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found): with patch("app.services.ban_service.ban_ip", side_effect=_raise_jail_not_found):
from app.services import jail_service
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, source,
session, session,
"/tmp/fake.sock", "/tmp/fake.sock",
db, db,
ban_ip=jail_service.ban_ip, ban_ip=ban_service.ban_ip,
) )
# Must abort after the first JailNotFoundError — only one ban attempt. # Must abort after the first JailNotFoundError — only one ban attempt.
@@ -273,15 +272,15 @@ class TestImport:
session = _make_session(content) session = _make_session(content)
with patch( with patch(
"app.services.jail_service.ban_ip", new_callable=AsyncMock "app.services.ban_service.ban_ip", new_callable=AsyncMock
): ):
from app.services import jail_service from app.services import ban_service
result = await blocklist_service.import_all( result = await blocklist_service.import_all(
db, db,
session, session,
"/tmp/fake.sock", "/tmp/fake.sock",
ban_ip=jail_service.ban_ip, ban_ip=ban_service.ban_ip,
) )
# Only S1 is enabled, S2 is disabled. # Only S1 is enabled, S2 is disabled.
@@ -415,16 +414,16 @@ class TestGeoPrewarmCacheFilter:
def _mock_is_cached(ip: str) -> bool: def _mock_is_cached(ip: str) -> bool:
return ip == "1.2.3.4" return ip == "1.2.3.4"
from app.services import jail_service from app.services import ban_service
mock_batch = AsyncMock(return_value={}) mock_batch = AsyncMock(return_value={})
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock): with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, source,
session, session,
"/tmp/fake.sock", "/tmp/fake.sock",
db, db,
ban_ip=jail_service.ban_ip, ban_ip=ban_service.ban_ip,
geo_is_cached=_mock_is_cached, geo_is_cached=_mock_is_cached,
geo_batch_lookup=mock_batch, geo_batch_lookup=mock_batch,
) )

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import contextlib
from typing import Any from typing import Any
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@@ -12,7 +13,7 @@ from app.exceptions import Fail2BanConnectionError
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
from app.models.geo import GeoDetail, GeoInfo from app.models.geo import GeoDetail, GeoInfo
from app.models.jail import JailDetailResponse, JailListResponse from app.models.jail import JailDetailResponse, JailListResponse
from app.services import jail_service from app.services import ban_service, jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError from app.services.jail_service import JailNotFoundError, JailOperationError
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -71,7 +72,10 @@ def _patch_client(responses: dict[str, Any]) -> Any:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:
self.send = mock_send self.send = mock_send
return patch("app.services.jail_service.Fail2BanClient", _FakeClient) stack = contextlib.ExitStack()
stack.enter_context(patch("app.services.jail_service.Fail2BanClient", _FakeClient))
stack.enter_context(patch("app.services.ban_service.Fail2BanClient", _FakeClient))
return stack
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -555,19 +559,19 @@ class TestJailControls:
class TestBanUnban: class TestBanUnban:
"""Unit tests for :func:`~app.services.jail_service.ban_ip` and """Unit tests for :func:`~app.services.ban_service.ban_ip` and
:func:`~app.services.jail_service.unban_ip`. :func:`~app.services.ban_service.unban_ip`.
""" """
async def test_ban_ip_success(self) -> None: async def test_ban_ip_success(self) -> None:
"""ban_ip sends the banip command for a valid IP.""" """ban_ip sends the banip command for a valid IP."""
with _patch_client({"set|sshd|banip|1.2.3.4": (0, 1)}): with _patch_client({"set|sshd|banip|1.2.3.4": (0, 1)}):
await jail_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise await ban_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise
async def test_ban_ip_invalid_raises(self) -> None: async def test_ban_ip_invalid_raises(self) -> None:
"""ban_ip raises ValueError for a non-IP value.""" """ban_ip raises ValueError for a non-IP value."""
with pytest.raises(ValueError, match="Invalid IP"): with pytest.raises(ValueError, match="Invalid IP"):
await jail_service.ban_ip(_SOCKET, "sshd", "not-an-ip") await ban_service.ban_ip(_SOCKET, "sshd", "not-an-ip")
async def test_ban_ip_unknown_jail_exception_raises_jail_not_found(self) -> None: async def test_ban_ip_unknown_jail_exception_raises_jail_not_found(self) -> None:
"""ban_ip raises JailNotFoundError when fail2ban returns UnknownJailException. """ban_ip raises JailNotFoundError when fail2ban returns UnknownJailException.
@@ -581,27 +585,27 @@ class TestBanUnban:
_patch_client({"set|missing-jail|banip|1.2.3.4": response}), _patch_client({"set|missing-jail|banip|1.2.3.4": response}),
pytest.raises(JailNotFoundError, match="missing-jail"), pytest.raises(JailNotFoundError, match="missing-jail"),
): ):
await jail_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4") await ban_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4")
async def test_ban_ipv6_success(self) -> None: async def test_ban_ipv6_success(self) -> None:
"""ban_ip accepts an IPv6 address.""" """ban_ip accepts an IPv6 address."""
with _patch_client({"set|sshd|banip|::1": (0, 1)}): with _patch_client({"set|sshd|banip|::1": (0, 1)}):
await jail_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise await ban_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise
async def test_unban_ip_all_jails(self) -> None: async def test_unban_ip_all_jails(self) -> None:
"""unban_ip with jail=None uses the global unban command.""" """unban_ip with jail=None uses the global unban command."""
with _patch_client({"unban|1.2.3.4": (0, 1)}): with _patch_client({"unban|1.2.3.4": (0, 1)}):
await jail_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise await ban_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise
async def test_unban_ip_specific_jail(self) -> None: async def test_unban_ip_specific_jail(self) -> None:
"""unban_ip with a jail sends the set unbanip command.""" """unban_ip with a jail sends the set unbanip command."""
with _patch_client({"set|sshd|unbanip|1.2.3.4": (0, 1)}): with _patch_client({"set|sshd|unbanip|1.2.3.4": (0, 1)}):
await jail_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise await ban_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise
async def test_unban_invalid_ip_raises(self) -> None: async def test_unban_invalid_ip_raises(self) -> None:
"""unban_ip raises ValueError for an invalid IP.""" """unban_ip raises ValueError for an invalid IP."""
with pytest.raises(ValueError, match="Invalid IP"): with pytest.raises(ValueError, match="Invalid IP"):
await jail_service.unban_ip(_SOCKET, "bad-ip") await ban_service.unban_ip(_SOCKET, "bad-ip")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -610,7 +614,7 @@ class TestBanUnban:
class TestGetActiveBans: class TestGetActiveBans:
"""Unit tests for :func:`~app.services.jail_service.get_active_bans`.""" """Unit tests for :func:`~app.services.ban_service.get_active_bans`."""
async def test_returns_active_ban_list_response(self) -> None: async def test_returns_active_ban_list_response(self) -> None:
"""get_active_bans returns an ActiveBanListResponse.""" """get_active_bans returns an ActiveBanListResponse."""
@@ -622,7 +626,7 @@ class TestGetActiveBans:
), ),
} }
with _patch_client(responses): with _patch_client(responses):
result = await jail_service.get_active_bans(_SOCKET) result = await ban_service.get_active_bans(_SOCKET)
assert isinstance(result, ActiveBanListResponse) assert isinstance(result, ActiveBanListResponse)
assert result.total == 1 assert result.total == 1
@@ -633,7 +637,7 @@ class TestGetActiveBans:
"""get_active_bans returns empty list when no jails are active.""" """get_active_bans returns empty list when no jails are active."""
responses = {"status": (0, [("Number of jail", 0), ("Jail list", "")])} responses = {"status": (0, [("Number of jail", 0), ("Jail list", "")])}
with _patch_client(responses): with _patch_client(responses):
result = await jail_service.get_active_bans(_SOCKET) result = await ban_service.get_active_bans(_SOCKET)
assert result.total == 0 assert result.total == 0
assert result.bans == [] assert result.bans == []
@@ -645,7 +649,7 @@ class TestGetActiveBans:
"get|sshd|banip|--with-time": (0, []), "get|sshd|banip|--with-time": (0, []),
} }
with _patch_client(responses): with _patch_client(responses):
result = await jail_service.get_active_bans(_SOCKET) result = await ban_service.get_active_bans(_SOCKET)
assert result.total == 0 assert result.total == 0
@@ -659,7 +663,7 @@ class TestGetActiveBans:
), ),
} }
with _patch_client(responses): with _patch_client(responses):
result = await jail_service.get_active_bans(_SOCKET) result = await ban_service.get_active_bans(_SOCKET)
ban = result.bans[0] ban = result.bans[0]
assert ban.banned_at is not None assert ban.banned_at is not None
@@ -691,8 +695,8 @@ class TestGetActiveBans:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock(side_effect=_side) self.send = AsyncMock(side_effect=_side)
with patch("app.services.jail_service.Fail2BanClient", _FakeClientPartial): with patch("app.services.ban_service.Fail2BanClient", _FakeClientPartial):
result = await jail_service.get_active_bans(_SOCKET) result = await ban_service.get_active_bans(_SOCKET)
# Only sshd ban returned (nginx silently skipped) # Only sshd ban returned (nginx silently skipped)
assert result.total == 1 assert result.total == 1
@@ -714,7 +718,7 @@ class TestGetActiveBans:
with _patch_client(responses): with _patch_client(responses):
mock_session = AsyncMock() mock_session = AsyncMock()
result = await jail_service.get_active_bans( result = await ban_service.get_active_bans(
_SOCKET, _SOCKET,
http_session=mock_session, http_session=mock_session,
geo_batch_lookup=mock_batch, geo_batch_lookup=mock_batch,
@@ -738,7 +742,7 @@ class TestGetActiveBans:
with _patch_client(responses): with _patch_client(responses):
mock_session = AsyncMock() mock_session = AsyncMock()
result = await jail_service.get_active_bans( result = await ban_service.get_active_bans(
_SOCKET, _SOCKET,
http_session=mock_session, http_session=mock_session,
geo_batch_lookup=failing_batch, geo_batch_lookup=failing_batch,
@@ -763,7 +767,7 @@ class TestGetActiveBans:
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None) return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
with _patch_client(responses): with _patch_client(responses):
result = await jail_service.get_active_bans( result = await ban_service.get_active_bans(
_SOCKET, geo_enricher=_enricher _SOCKET, geo_enricher=_enricher
) )