Task 13: move ban_ip, unban_ip, and get_active_bans from jail_service to ban_service and update routers/tests

This commit is contained in:
2026-04-17 16:22:20 +02:00
parent 6e1e3c4546
commit 8c6950afc1
9 changed files with 366 additions and 247 deletions

View File

@@ -241,6 +241,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
### Task 13 — Move `ban_ip`, `unban_ip`, and `get_active_bans` from `jail_service` to `ban_service`
**Status:** Completed ✅
**Found in:** `backend/app/services/jail_service.py` contains `ban_ip`, `unban_ip`, and `get_active_bans`. These operations conceptually belong in `ban_service.py`, which is the declared home for ban management. Routers `bans.py` and `blocklist.py` already import `jail_service` specifically for these functions.
**Goal:** Move the three functions into `ban_service.py`. Update `bans.py` and `blocklist.py` to import from `ban_service` instead of `jail_service`. Remove the now-redundant `jail_service` import from those routers if it is no longer needed.

View File

@@ -22,7 +22,7 @@ from app.dependencies import (
from app.exceptions import JailNotFoundError, JailOperationError
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
from app.models.jail import JailCommandResponse
from app.services import jail_service
from app.services import ban_service, jail_service
from app.exceptions import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
@@ -72,7 +72,7 @@ async def get_active_bans(
HTTPException: 502 when fail2ban is unreachable.
"""
try:
return await jail_service.get_active_bans(
return await ban_service.get_active_bans(
socket_path,
geo_batch_lookup=geo_batch_lookup,
http_session=http_session,
@@ -114,7 +114,7 @@ async def ban_ip(
HTTPException: 502 when fail2ban is unreachable.
"""
try:
await jail_service.ban_ip(socket_path, body.jail, body.ip)
await ban_service.ban_ip(socket_path, body.jail, body.ip)
return JailCommandResponse(
message=f"IP {body.ip!r} banned in jail {body.jail!r}.",
jail=body.jail,
@@ -174,7 +174,7 @@ async def unban_ip(
target_jail: str | None = None if (body.unban_all or body.jail is None) else body.jail
try:
await jail_service.unban_ip(socket_path, body.ip, jail=target_jail)
await ban_service.unban_ip(socket_path, body.ip, jail=target_jail)
scope = f"jail {target_jail!r}" if target_jail else "all jails"
return JailCommandResponse(
message=f"IP {body.ip!r} unbanned from {scope}.",

View File

@@ -44,7 +44,7 @@ from app.models.blocklist import (
ScheduleConfig,
ScheduleInfo,
)
from app.services import blocklist_service, geo_service, jail_service
from app.services import ban_service, blocklist_service, geo_service
from app.tasks.blocklist_import import run_import_with_resources
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
@@ -138,7 +138,7 @@ async def run_import_now(
socket_path,
geo_is_cached=geo_service.is_cached,
geo_batch_lookup=geo_batch_lookup,
ban_ip=jail_service.ban_ip,
ban_ip=ban_service.ban_ip,
)

View File

@@ -11,11 +11,14 @@ so BanGUI never modifies or locks the fail2ban database.
from __future__ import annotations
import asyncio
import contextlib
import ipaddress
import time
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, cast
import structlog
from app.exceptions import JailNotFoundError, JailOperationError
from app.models.ban import (
BLOCKLIST_JAIL,
BUCKET_SECONDS,
@@ -26,6 +29,8 @@ from app.models.ban import (
BansByJailResponse,
BanTrendBucket,
BanTrendResponse,
ActiveBan,
ActiveBanListResponse,
DashboardBanItem,
DashboardBanListResponse,
TimeRange,
@@ -42,6 +47,10 @@ from app.repositories.history_archive_repo import (
)
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
from app.utils.fail2ban_client import (
Fail2BanClient,
Fail2BanResponse,
)
if TYPE_CHECKING:
import aiohttp
@@ -70,6 +79,114 @@ _SOCKET_TIMEOUT: float = 5.0
# ---------------------------------------------------------------------------
def _ok(response: object) -> object:
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
Args:
response: Raw value returned by :meth:`~Fail2BanClient.send`.
Returns:
The payload ``data`` portion of the response.
Raises:
ValueError: If the response indicates an error (return code ≠ 0).
"""
try:
code, data = response # type: ignore[assignment]
except (TypeError, ValueError) as exc:
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
if code != 0:
raise ValueError(f"fail2ban returned error code {code}: {data!r}")
return data
def _to_dict(pairs: object) -> dict[str, object]:
"""Convert a list of ``(key, value)`` pairs to a plain dict.
Args:
pairs: A list of ``(key, value)`` pairs (or any iterable thereof).
Returns:
A :class:`dict` with the keys and values from *pairs*.
"""
if not isinstance(pairs, (list, tuple)):
return {}
result: dict[str, object] = {}
for item in pairs:
try:
k, v = item
result[str(k)] = v
except (TypeError, ValueError):
pass
return result
def _ensure_list(value: object | None) -> list[str]:
"""Coerce a fail2ban response value to a list of strings."""
if value is None:
return []
if isinstance(value, str):
return [value] if value.strip() else []
if isinstance(value, (list, tuple)):
return [str(v) for v in value if v is not None]
return [str(value)]
def _is_not_found_error(exc: Exception) -> bool:
"""Return ``True`` if *exc* indicates a jail does not exist."""
msg = str(exc).lower()
return any(
phrase in msg
for phrase in (
"unknown jail",
"unknownjail",
"no jail",
"does not exist",
"not found",
)
)
async def ban_ip(socket_path: str, jail: str, ip: str) -> None:
"""Ban an IP address in the specified jail."""
try:
ipaddress.ip_address(ip)
except ValueError as exc:
raise ValueError(f"Invalid IP address: {ip!r}") from exc
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
try:
_ok(await client.send(["set", jail, "banip", ip]))
except ValueError as exc:
if _is_not_found_error(exc):
raise JailNotFoundError(jail) from exc
raise JailOperationError(str(exc)) from exc
async def unban_ip(socket_path: str, ip: str, jail: str | None = None) -> None:
"""Unban an IP address from a specific jail or all jails."""
try:
ipaddress.ip_address(ip)
except ValueError as exc:
raise ValueError(f"Invalid IP address: {ip!r}") from exc
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
if jail is None:
_ok(await client.send(["unban", ip]))
return
try:
_ok(await client.send(["set", jail, "unbanip", ip]))
except ValueError as exc:
if _is_not_found_error(exc):
raise JailNotFoundError(jail) from exc
raise JailOperationError(str(exc)) from exc
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
"""Return a SQL fragment and its parameters for the origin filter.
@@ -88,6 +205,191 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
return "", ()
def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:
"""Parse a ban entry from ``get <jail> banip --with-time`` output."""
from datetime import UTC, datetime
try:
parts = entry.split("\t", 1)
ip = parts[0].strip()
ipaddress.ip_address(ip)
if len(parts) < 2:
return ActiveBan(
ip=ip,
jail=jail,
banned_at=None,
expires_at=None,
ban_count=1,
country=None,
)
time_part = parts[1].strip()
plus_idx = time_part.find(" + ")
if plus_idx == -1:
banned_at_str = time_part.strip()
expires_at_str: str | None = None
else:
banned_at_str = time_part[:plus_idx].strip()
remainder = time_part[plus_idx + 3 :]
eq_idx = remainder.find(" = ")
expires_at_str = remainder[eq_idx + 3 :].strip() if eq_idx != -1 else None
_date_fmt = "%Y-%m-%d %H:%M:%S"
def _to_iso(ts: str) -> str:
dt = datetime.strptime(ts, _date_fmt).replace(tzinfo=UTC)
return dt.isoformat()
banned_at_iso: str | None = None
expires_at_iso: str | None = None
with contextlib.suppress(ValueError):
banned_at_iso = _to_iso(banned_at_str)
with contextlib.suppress(ValueError):
if expires_at_str:
expires_at_iso = _to_iso(expires_at_str)
return ActiveBan(
ip=ip,
jail=jail,
banned_at=banned_at_iso,
expires_at=expires_at_iso,
ban_count=1,
country=None,
)
except (ValueError, IndexError, AttributeError) as exc:
log.debug("ban_entry_parse_error", entry=entry, jail=jail, error=str(exc))
return None
async def _enrich_bans(
bans: list[ActiveBan],
geo_enricher: GeoEnricher,
) -> list[ActiveBan]:
"""Enrich ban records with geo data asynchronously."""
geo_results: list[object | Exception] = await asyncio.gather(
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
return_exceptions=True,
)
enriched: list[ActiveBan] = []
for ban, geo in zip(bans, geo_results, strict=False):
if geo is not None and not isinstance(geo, Exception):
geo_info = cast("GeoInfo", geo)
enriched.append(ban.model_copy(update={"country": geo_info.country_code}))
else:
enriched.append(ban)
return enriched
async def get_active_bans(
socket_path: str,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None,
http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None,
) -> ActiveBanListResponse:
"""Return all currently banned IPs across every jail.
For each jail the ``get <jail> banip --with-time`` command is used
to retrieve ban start and expiry times alongside the IP address.
Geo enrichment strategy (highest priority first):
1. When *http_session* is provided the entire set of banned IPs is resolved
in a single :func:`~app.services.geo_service.lookup_batch` call (up to
100 IPs per HTTP request). This is far more efficient than concurrent
per-IP lookups and stays within ip-api.com rate limits.
2. When only *geo_enricher* is provided (legacy / test path) each IP is
resolved individually via the supplied async callable.
Args:
socket_path: Path to the fail2ban Unix domain socket.
geo_enricher: Optional async callable ``(ip) -> GeoInfo | None``.
Used to enrich each ban entry with country and ASN data.
Ignored when *http_session* is provided.
http_session: Optional shared :class:`aiohttp.ClientSession`. When
provided, :func:`~app.services.geo_service.lookup_batch` is used
for efficient bulk geo resolution.
app_db: Optional BanGUI application database connection used to
persist newly resolved geo entries across restarts. Only
meaningful when *http_session* is provided.
Returns:
:class:`~app.models.ban.ActiveBanListResponse` with all active bans.
Raises:
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
global_status = _to_dict(_ok(await client.send(["status"])))
jail_list_raw: str = str(global_status.get("Jail list", "") or "").strip()
jail_names: list[str] = (
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
if jail_list_raw
else []
)
if not jail_names:
return ActiveBanListResponse(bans=[], total=0)
results: list[object | Exception] = await asyncio.gather(
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
return_exceptions=True,
)
bans: list[ActiveBan] = []
for jail_name, raw_result in zip(jail_names, results, strict=False):
if isinstance(raw_result, Exception):
log.warning(
"active_bans_fetch_error",
jail=jail_name,
error=str(raw_result),
)
continue
try:
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
except (TypeError, ValueError) as exc:
log.warning(
"active_bans_parse_error",
jail=jail_name,
error=str(exc),
)
continue
for entry in ban_list:
ban = _parse_ban_entry(str(entry), jail_name)
if ban is not None:
bans.append(ban)
if http_session is not None and bans and geo_batch_lookup is not None:
all_ips: list[str] = [ban.ip for ban in bans]
try:
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
except Exception: # noqa: BLE001
log.warning("active_bans_batch_geo_failed")
geo_map = {}
enriched: list[ActiveBan] = []
for ban in bans:
geo = geo_map.get(ban.ip)
if geo is not None:
enriched.append(ban.model_copy(update={"country": geo.country_code}))
else:
enriched.append(ban)
bans = enriched
elif geo_enricher is not None:
bans = await _enrich_bans(bans, geo_enricher)
log.info("active_bans_fetched", total=len(bans))
return ActiveBanListResponse(bans=bans, total=len(bans))
_TIME_RANGE_SLACK_SECONDS: int = 60
@@ -502,7 +804,7 @@ async def bans_by_country(
country_names[cc] = cn
# Build companion table from recent rows (geo already cached from batch step).
bans: list[DashboardBanItem] = []
bans: list[ActiveBan] = []
for companion_row in companion_rows:
if source == "archive":
ip = companion_row["ip"]

View File

@@ -802,194 +802,6 @@ async def restart_daemon(
)
# ---------------------------------------------------------------------------
# Public API — Ban / Unban
# ---------------------------------------------------------------------------
async def ban_ip(socket_path: str, jail: str, ip: str) -> None:
"""Ban an IP address in a specific fail2ban jail.
The IP address is validated with :mod:`ipaddress` before the command
is sent to fail2ban.
Args:
socket_path: Path to the fail2ban Unix domain socket.
jail: Jail in which to apply the ban.
ip: IP address to ban (IPv4 or IPv6).
Raises:
ValueError: If *ip* is not a valid IP address.
JailNotFoundError: If *jail* is not a known jail.
JailOperationError: If fail2ban reports the operation failed.
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
# Validate the IP address before sending to avoid injection.
try:
ipaddress.ip_address(ip)
except ValueError as exc:
raise ValueError(f"Invalid IP address: {ip!r}") from exc
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
try:
_ok(await client.send(["set", jail, "banip", ip]))
log.info("ip_banned", ip=ip, jail=jail)
except ValueError as exc:
if _is_not_found_error(exc):
raise JailNotFoundError(jail) from exc
raise JailOperationError(str(exc)) from exc
async def unban_ip(
socket_path: str,
ip: str,
jail: str | None = None,
) -> None:
"""Unban an IP address from one or all fail2ban jails.
If *jail* is ``None``, the IP is unbanned from every jail using the
global ``unban`` command. Otherwise only the specified jail is
targeted.
Args:
socket_path: Path to the fail2ban Unix domain socket.
ip: IP address to unban.
jail: Jail to unban from. ``None`` means all jails.
Raises:
ValueError: If *ip* is not a valid IP address.
JailNotFoundError: If *jail* is specified but does not exist.
JailOperationError: If fail2ban reports the operation failed.
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
try:
ipaddress.ip_address(ip)
except ValueError as exc:
raise ValueError(f"Invalid IP address: {ip!r}") from exc
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
try:
if jail is None:
_ok(await client.send(["unban", ip]))
log.info("ip_unbanned_all_jails", ip=ip)
else:
_ok(await client.send(["set", jail, "unbanip", ip]))
log.info("ip_unbanned", ip=ip, jail=jail)
except ValueError as exc:
if _is_not_found_error(exc):
raise JailNotFoundError(jail or "") from exc
raise JailOperationError(str(exc)) from exc
async def get_active_bans(
socket_path: str,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None,
http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None,
) -> ActiveBanListResponse:
"""Return all currently banned IPs across every jail.
For each jail the ``get <jail> banip --with-time`` command is used
to retrieve ban start and expiry times alongside the IP address.
Geo enrichment strategy (highest priority first):
1. When *http_session* is provided the entire set of banned IPs is resolved
in a single :func:`~app.services.geo_service.lookup_batch` call (up to
100 IPs per HTTP request). This is far more efficient than concurrent
per-IP lookups and stays within ip-api.com rate limits.
2. When only *geo_enricher* is provided (legacy / test path) each IP is
resolved individually via the supplied async callable.
Args:
socket_path: Path to the fail2ban Unix domain socket.
geo_enricher: Optional async callable ``(ip) → GeoInfo | None``
used to enrich each ban entry with country and ASN data.
Ignored when *http_session* is provided.
http_session: Optional shared :class:`aiohttp.ClientSession`. When
provided, :func:`~app.services.geo_service.lookup_batch` is used
for efficient bulk geo resolution.
app_db: Optional BanGUI application database connection used to
persist newly resolved geo entries across restarts. Only
meaningful when *http_session* is provided.
Returns:
:class:`~app.models.ban.ActiveBanListResponse` with all active bans.
Raises:
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
# Fetch jail names.
global_status = _to_dict(_ok(await client.send(["status"])))
jail_list_raw: str = str(global_status.get("Jail list", "") or "").strip()
jail_names: list[str] = (
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
if jail_list_raw
else []
)
if not jail_names:
return ActiveBanListResponse(bans=[], total=0)
# For each jail, fetch the ban list with time info in parallel.
results: list[object | Exception] = await asyncio.gather(
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
return_exceptions=True,
)
bans: list[ActiveBan] = []
for jail_name, raw_result in zip(jail_names, results, strict=False):
if isinstance(raw_result, Exception):
log.warning(
"active_bans_fetch_error",
jail=jail_name,
error=str(raw_result),
)
continue
try:
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
except (TypeError, ValueError) as exc:
log.warning(
"active_bans_parse_error",
jail=jail_name,
error=str(exc),
)
continue
for entry in ban_list:
ban = _parse_ban_entry(str(entry), jail_name)
if ban is not None:
bans.append(ban)
# Enrich with geo data — prefer batch lookup over per-IP enricher.
if http_session is not None and bans and geo_batch_lookup is not None:
all_ips: list[str] = [ban.ip for ban in bans]
try:
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
except Exception: # noqa: BLE001
log.warning("active_bans_batch_geo_failed")
geo_map = {}
enriched: list[ActiveBan] = []
for ban in bans:
geo = geo_map.get(ban.ip)
if geo is not None:
enriched.append(ban.model_copy(update={"country": geo.country_code}))
else:
enriched.append(ban)
bans = enriched
elif geo_enricher is not None:
bans = await _enrich_bans(bans, geo_enricher)
log.info("active_bans_fetched", total=len(bans))
return ActiveBanListResponse(bans=bans, total=len(bans))
def _parse_ban_entry(entry: str, jail: str) -> ActiveBan | None:

View File

@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Any
import structlog
from app.db import open_db
from app.services import blocklist_service, jail_service
from app.services import ban_service, blocklist_service
from app.utils.runtime_state import get_effective_settings
if TYPE_CHECKING:
@@ -55,7 +55,7 @@ async def _run_import_with_resources(settings: Settings, http_session: ClientSes
db,
http_session,
socket_path,
ban_ip=jail_service.ban_ip,
ban_ip=ban_service.ban_ip,
)
log.info(
"blocklist_import_finished",

View File

@@ -84,7 +84,7 @@ class TestGetActiveBans:
total=1,
)
with patch(
"app.routers.bans.jail_service.get_active_bans",
"app.routers.bans.ban_service.get_active_bans",
AsyncMock(return_value=mock_response),
):
resp = await bans_client.get("/api/bans/active")
@@ -107,7 +107,7 @@ class TestGetActiveBans:
"""GET /api/bans/active returns empty list when no bans are active."""
mock_response = ActiveBanListResponse(bans=[], total=0)
with patch(
"app.routers.bans.jail_service.get_active_bans",
"app.routers.bans.ban_service.get_active_bans",
AsyncMock(return_value=mock_response),
):
resp = await bans_client.get("/api/bans/active")
@@ -132,7 +132,7 @@ class TestGetActiveBans:
total=1,
)
with patch(
"app.routers.bans.jail_service.get_active_bans",
"app.routers.bans.ban_service.get_active_bans",
AsyncMock(return_value=mock_response),
):
resp = await bans_client.get("/api/bans/active")
@@ -156,7 +156,7 @@ class TestBanIp:
async def test_201_on_success(self, bans_client: AsyncClient) -> None:
"""POST /api/bans returns 201 when the IP is banned."""
with patch(
"app.routers.bans.jail_service.ban_ip",
"app.routers.bans.ban_service.ban_ip",
AsyncMock(return_value=None),
):
resp = await bans_client.post(
@@ -170,7 +170,7 @@ class TestBanIp:
async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
"""POST /api/bans returns 400 for an invalid IP address."""
with patch(
"app.routers.bans.jail_service.ban_ip",
"app.routers.bans.ban_service.ban_ip",
AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")),
):
resp = await bans_client.post(
@@ -185,7 +185,7 @@ class TestBanIp:
from app.services.jail_service import JailNotFoundError
with patch(
"app.routers.bans.jail_service.ban_ip",
"app.routers.bans.ban_service.ban_ip",
AsyncMock(side_effect=JailNotFoundError("ghost")),
):
resp = await bans_client.post(
@@ -215,7 +215,7 @@ class TestUnbanIp:
async def test_200_unban_from_all(self, bans_client: AsyncClient) -> None:
"""DELETE /api/bans with unban_all=true unbans from all jails."""
with patch(
"app.routers.bans.jail_service.unban_ip",
"app.routers.bans.ban_service.unban_ip",
AsyncMock(return_value=None),
):
resp = await bans_client.request(
@@ -230,7 +230,7 @@ class TestUnbanIp:
async def test_200_unban_from_specific_jail(self, bans_client: AsyncClient) -> None:
"""DELETE /api/bans with a jail unbans from that jail only."""
with patch(
"app.routers.bans.jail_service.unban_ip",
"app.routers.bans.ban_service.unban_ip",
AsyncMock(return_value=None),
):
resp = await bans_client.request(
@@ -245,7 +245,7 @@ class TestUnbanIp:
async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
"""DELETE /api/bans returns 400 for an invalid IP."""
with patch(
"app.routers.bans.jail_service.unban_ip",
"app.routers.bans.ban_service.unban_ip",
AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")),
):
resp = await bans_client.request(
@@ -261,7 +261,7 @@ class TestUnbanIp:
from app.services.jail_service import JailNotFoundError
with patch(
"app.routers.bans.jail_service.unban_ip",
"app.routers.bans.ban_service.unban_ip",
AsyncMock(side_effect=JailNotFoundError("ghost")),
):
resp = await bans_client.request(

View File

@@ -174,17 +174,17 @@ class TestImport:
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
from app.services import jail_service
from app.services import ban_service
with patch(
"app.services.jail_service.ban_ip", new_callable=AsyncMock
"app.services.ban_service.ban_ip", new_callable=AsyncMock
) as mock_ban:
result = await blocklist_service.import_source(
source,
session,
"/tmp/fake.sock",
db,
ban_ip=jail_service.ban_ip,
ban_ip=ban_service.ban_ip,
)
assert result.ips_imported == 2
@@ -198,15 +198,15 @@ class TestImport:
session = _make_session(content)
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
from app.services import jail_service
from app.services import ban_service
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
result = await blocklist_service.import_source(
source,
session,
"/tmp/fake.sock",
db,
ban_ip=jail_service.ban_ip,
ban_ip=ban_service.ban_ip,
)
assert result.ips_imported == 1
@@ -217,14 +217,14 @@ class TestImport:
session = _make_session("", status=503)
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
from app.services import jail_service
from app.services import ban_service
result = await blocklist_service.import_source(
source,
session,
"/tmp/fake.sock",
db,
ban_ip=jail_service.ban_ip,
ban_ip=ban_service.ban_ip,
)
assert result.ips_imported == 0
@@ -234,6 +234,7 @@ class TestImport:
"""import_source aborts immediately and records an error when the target jail
does not exist in fail2ban instead of silently skipping every IP."""
from app.services.jail_service import JailNotFoundError
from app.services import ban_service
content = "\n".join(f"1.2.3.{i}" for i in range(100))
session = _make_session(content)
@@ -246,15 +247,13 @@ class TestImport:
call_count += 1
raise JailNotFoundError(jail)
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
from app.services import jail_service
with patch("app.services.ban_service.ban_ip", side_effect=_raise_jail_not_found):
result = await blocklist_service.import_source(
source,
session,
"/tmp/fake.sock",
db,
ban_ip=jail_service.ban_ip,
ban_ip=ban_service.ban_ip,
)
# Must abort after the first JailNotFoundError — only one ban attempt.
@@ -273,15 +272,15 @@ class TestImport:
session = _make_session(content)
with patch(
"app.services.jail_service.ban_ip", new_callable=AsyncMock
"app.services.ban_service.ban_ip", new_callable=AsyncMock
):
from app.services import jail_service
from app.services import ban_service
result = await blocklist_service.import_all(
db,
session,
"/tmp/fake.sock",
ban_ip=jail_service.ban_ip,
ban_ip=ban_service.ban_ip,
)
# Only S1 is enabled, S2 is disabled.
@@ -415,16 +414,16 @@ class TestGeoPrewarmCacheFilter:
def _mock_is_cached(ip: str) -> bool:
return ip == "1.2.3.4"
from app.services import jail_service
from app.services import ban_service
mock_batch = AsyncMock(return_value={})
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
result = await blocklist_service.import_source(
source,
session,
"/tmp/fake.sock",
db,
ban_ip=jail_service.ban_ip,
ban_ip=ban_service.ban_ip,
geo_is_cached=_mock_is_cached,
geo_batch_lookup=mock_batch,
)

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
import asyncio
import contextlib
from typing import Any
from unittest.mock import AsyncMock, patch
@@ -12,7 +13,7 @@ from app.exceptions import Fail2BanConnectionError
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
from app.models.geo import GeoDetail, GeoInfo
from app.models.jail import JailDetailResponse, JailListResponse
from app.services import jail_service
from app.services import ban_service, jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError
# ---------------------------------------------------------------------------
@@ -71,7 +72,10 @@ def _patch_client(responses: dict[str, Any]) -> Any:
def __init__(self, **_kw: Any) -> None:
self.send = mock_send
return patch("app.services.jail_service.Fail2BanClient", _FakeClient)
stack = contextlib.ExitStack()
stack.enter_context(patch("app.services.jail_service.Fail2BanClient", _FakeClient))
stack.enter_context(patch("app.services.ban_service.Fail2BanClient", _FakeClient))
return stack
# ---------------------------------------------------------------------------
@@ -555,19 +559,19 @@ class TestJailControls:
class TestBanUnban:
"""Unit tests for :func:`~app.services.jail_service.ban_ip` and
:func:`~app.services.jail_service.unban_ip`.
"""Unit tests for :func:`~app.services.ban_service.ban_ip` and
:func:`~app.services.ban_service.unban_ip`.
"""
async def test_ban_ip_success(self) -> None:
"""ban_ip sends the banip command for a valid IP."""
with _patch_client({"set|sshd|banip|1.2.3.4": (0, 1)}):
await jail_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise
await ban_service.ban_ip(_SOCKET, "sshd", "1.2.3.4") # should not raise
async def test_ban_ip_invalid_raises(self) -> None:
"""ban_ip raises ValueError for a non-IP value."""
with pytest.raises(ValueError, match="Invalid IP"):
await jail_service.ban_ip(_SOCKET, "sshd", "not-an-ip")
await ban_service.ban_ip(_SOCKET, "sshd", "not-an-ip")
async def test_ban_ip_unknown_jail_exception_raises_jail_not_found(self) -> None:
"""ban_ip raises JailNotFoundError when fail2ban returns UnknownJailException.
@@ -581,27 +585,27 @@ class TestBanUnban:
_patch_client({"set|missing-jail|banip|1.2.3.4": response}),
pytest.raises(JailNotFoundError, match="missing-jail"),
):
await jail_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4")
await ban_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4")
async def test_ban_ipv6_success(self) -> None:
"""ban_ip accepts an IPv6 address."""
with _patch_client({"set|sshd|banip|::1": (0, 1)}):
await jail_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise
await ban_service.ban_ip(_SOCKET, "sshd", "::1") # should not raise
async def test_unban_ip_all_jails(self) -> None:
"""unban_ip with jail=None uses the global unban command."""
with _patch_client({"unban|1.2.3.4": (0, 1)}):
await jail_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise
await ban_service.unban_ip(_SOCKET, "1.2.3.4") # should not raise
async def test_unban_ip_specific_jail(self) -> None:
"""unban_ip with a jail sends the set unbanip command."""
with _patch_client({"set|sshd|unbanip|1.2.3.4": (0, 1)}):
await jail_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise
await ban_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd") # should not raise
async def test_unban_invalid_ip_raises(self) -> None:
"""unban_ip raises ValueError for an invalid IP."""
with pytest.raises(ValueError, match="Invalid IP"):
await jail_service.unban_ip(_SOCKET, "bad-ip")
await ban_service.unban_ip(_SOCKET, "bad-ip")
# ---------------------------------------------------------------------------
@@ -610,7 +614,7 @@ class TestBanUnban:
class TestGetActiveBans:
"""Unit tests for :func:`~app.services.jail_service.get_active_bans`."""
"""Unit tests for :func:`~app.services.ban_service.get_active_bans`."""
async def test_returns_active_ban_list_response(self) -> None:
"""get_active_bans returns an ActiveBanListResponse."""
@@ -622,7 +626,7 @@ class TestGetActiveBans:
),
}
with _patch_client(responses):
result = await jail_service.get_active_bans(_SOCKET)
result = await ban_service.get_active_bans(_SOCKET)
assert isinstance(result, ActiveBanListResponse)
assert result.total == 1
@@ -633,7 +637,7 @@ class TestGetActiveBans:
"""get_active_bans returns empty list when no jails are active."""
responses = {"status": (0, [("Number of jail", 0), ("Jail list", "")])}
with _patch_client(responses):
result = await jail_service.get_active_bans(_SOCKET)
result = await ban_service.get_active_bans(_SOCKET)
assert result.total == 0
assert result.bans == []
@@ -645,7 +649,7 @@ class TestGetActiveBans:
"get|sshd|banip|--with-time": (0, []),
}
with _patch_client(responses):
result = await jail_service.get_active_bans(_SOCKET)
result = await ban_service.get_active_bans(_SOCKET)
assert result.total == 0
@@ -659,7 +663,7 @@ class TestGetActiveBans:
),
}
with _patch_client(responses):
result = await jail_service.get_active_bans(_SOCKET)
result = await ban_service.get_active_bans(_SOCKET)
ban = result.bans[0]
assert ban.banned_at is not None
@@ -691,8 +695,8 @@ class TestGetActiveBans:
def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock(side_effect=_side)
with patch("app.services.jail_service.Fail2BanClient", _FakeClientPartial):
result = await jail_service.get_active_bans(_SOCKET)
with patch("app.services.ban_service.Fail2BanClient", _FakeClientPartial):
result = await ban_service.get_active_bans(_SOCKET)
# Only sshd ban returned (nginx silently skipped)
assert result.total == 1
@@ -714,7 +718,7 @@ class TestGetActiveBans:
with _patch_client(responses):
mock_session = AsyncMock()
result = await jail_service.get_active_bans(
result = await ban_service.get_active_bans(
_SOCKET,
http_session=mock_session,
geo_batch_lookup=mock_batch,
@@ -738,7 +742,7 @@ class TestGetActiveBans:
with _patch_client(responses):
mock_session = AsyncMock()
result = await jail_service.get_active_bans(
result = await ban_service.get_active_bans(
_SOCKET,
http_session=mock_session,
geo_batch_lookup=failing_batch,
@@ -763,7 +767,7 @@ class TestGetActiveBans:
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
with _patch_client(responses):
result = await jail_service.get_active_bans(
result = await ban_service.get_active_bans(
_SOCKET, geo_enricher=_enricher
)