feat: Stage 4 — fail2ban connection and server status
This commit is contained in:
@@ -33,7 +33,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
from app.config import Settings, get_settings
|
||||
from app.db import init_db
|
||||
from app.routers import auth, health, setup
|
||||
from app.routers import auth, dashboard, health, setup
|
||||
from app.tasks import health_check
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ensure the bundled fail2ban package is importable from fail2ban-master/
|
||||
@@ -114,6 +115,9 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
scheduler.start()
|
||||
app.state.scheduler = scheduler
|
||||
|
||||
# --- Health-check background probe ---
|
||||
health_check.register(app)
|
||||
|
||||
log.info("bangui_started")
|
||||
|
||||
try:
|
||||
@@ -268,5 +272,6 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
||||
app.include_router(health.router)
|
||||
app.include_router(setup.router)
|
||||
app.include_router(auth.router)
|
||||
app.include_router(dashboard.router)
|
||||
|
||||
return app
|
||||
|
||||
46
backend/app/routers/dashboard.py
Normal file
46
backend/app/routers/dashboard.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Dashboard router.
|
||||
|
||||
Provides the ``GET /api/dashboard/status`` endpoint that returns the cached
|
||||
fail2ban server health snapshot. The snapshot is maintained by the
|
||||
background health-check task and refreshed every 30 seconds.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Request
|
||||
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.server import ServerStatus, ServerStatusResponse
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
|
||||
@router.get(
|
||||
"/status",
|
||||
response_model=ServerStatusResponse,
|
||||
summary="Return the cached fail2ban server status",
|
||||
)
|
||||
async def get_server_status(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
) -> ServerStatusResponse:
|
||||
"""Return the most recent fail2ban health snapshot.
|
||||
|
||||
The snapshot is populated by a background task that runs every 30 seconds.
|
||||
If the task has not yet executed a placeholder ``online=False`` status is
|
||||
returned so the response is always well-formed.
|
||||
|
||||
Args:
|
||||
request: The incoming request (used to access ``app.state``).
|
||||
_auth: Validated session — enforces authentication on this endpoint.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.server.ServerStatusResponse` containing the
|
||||
current health snapshot.
|
||||
"""
|
||||
cached: ServerStatus = getattr(
|
||||
request.app.state,
|
||||
"server_status",
|
||||
ServerStatus(online=False),
|
||||
)
|
||||
return ServerStatusResponse(status=cached)
|
||||
171
backend/app/services/health_service.py
Normal file
171
backend/app/services/health_service.py
Normal file
@@ -0,0 +1,171 @@
|
||||
"""Health service.
|
||||
|
||||
Probes the fail2ban socket to determine whether the daemon is reachable and
|
||||
collects aggregated server statistics (version, jail count, ban counts).
|
||||
|
||||
The probe is intentionally lightweight — it is meant to be called every 30
|
||||
seconds by the background health-check task, not on every HTTP request.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SOCKET_TIMEOUT: float = 5.0
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
fail2ban wraps every response in a ``(0, data)`` success tuple or
|
||||
a ``(1, exception)`` error tuple. This helper returns ``data`` for
|
||||
successful responses or raises :class:`ValueError` for error responses.
|
||||
|
||||
Args:
|
||||
response: Raw value returned by :meth:`~Fail2BanClient.send`.
|
||||
|
||||
Returns:
|
||||
The payload ``data`` portion of the response.
|
||||
|
||||
Raises:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
if code != 0:
|
||||
raise ValueError(f"fail2ban returned error code {code}: {data!r}")
|
||||
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||
|
||||
fail2ban returns structured data as lists of 2-tuples rather than dicts.
|
||||
This helper converts them safely, ignoring non-pair items.
|
||||
|
||||
Args:
|
||||
pairs: A list of ``(key, value)`` pairs (or any iterable thereof).
|
||||
|
||||
Returns:
|
||||
A :class:`dict` with the keys and values from *pairs*.
|
||||
"""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
result[str(k)] = v
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public interface
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerStatus:
|
||||
"""Probe the fail2ban daemon and return a :class:`~app.models.server.ServerStatus`.
|
||||
|
||||
Sends ``ping``, ``version``, ``status``, and per-jail ``status <jail>``
|
||||
commands. Any socket or protocol error is caught and results in an
|
||||
``online=False`` status so the dashboard can always return a safe default.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
timeout: Per-command socket timeout in seconds.
|
||||
|
||||
Returns:
|
||||
A :class:`~app.models.server.ServerStatus` snapshot. ``online`` is
|
||||
``True`` when the daemon is reachable, ``False`` otherwise.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=timeout)
|
||||
|
||||
try:
|
||||
# ------------------------------------------------------------------ #
|
||||
# 1. Connectivity check #
|
||||
# ------------------------------------------------------------------ #
|
||||
ping_data = _ok(await client.send(["ping"]))
|
||||
if ping_data != "pong":
|
||||
log.warning("fail2ban_unexpected_ping_response", response=ping_data)
|
||||
return ServerStatus(online=False)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 2. Version #
|
||||
# ------------------------------------------------------------------ #
|
||||
try:
|
||||
version: str | None = str(_ok(await client.send(["version"])))
|
||||
except (ValueError, TypeError):
|
||||
version = None
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 3. Global status — jail count and names #
|
||||
# ------------------------------------------------------------------ #
|
||||
status_data = _to_dict(_ok(await client.send(["status"])))
|
||||
active_jails: int = int(status_data.get("Number of jail", 0) or 0)
|
||||
jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip()
|
||||
jail_names: list[str] = (
|
||||
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
||||
if jail_list_raw
|
||||
else []
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------ #
|
||||
# 4. Per-jail aggregation #
|
||||
# ------------------------------------------------------------------ #
|
||||
total_bans: int = 0
|
||||
total_failures: int = 0
|
||||
|
||||
for jail_name in jail_names:
|
||||
try:
|
||||
jail_resp = _to_dict(_ok(await client.send(["status", jail_name])))
|
||||
filter_stats = _to_dict(jail_resp.get("Filter") or [])
|
||||
action_stats = _to_dict(jail_resp.get("Actions") or [])
|
||||
total_failures += int(filter_stats.get("Currently failed", 0) or 0)
|
||||
total_bans += int(action_stats.get("Currently banned", 0) or 0)
|
||||
except (ValueError, TypeError, KeyError) as exc:
|
||||
log.warning(
|
||||
"fail2ban_jail_status_parse_error",
|
||||
jail=jail_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
log.debug(
|
||||
"fail2ban_probe_ok",
|
||||
version=version,
|
||||
active_jails=active_jails,
|
||||
total_bans=total_bans,
|
||||
total_failures=total_failures,
|
||||
)
|
||||
|
||||
return ServerStatus(
|
||||
online=True,
|
||||
version=version,
|
||||
active_jails=active_jails,
|
||||
total_bans=total_bans,
|
||||
total_failures=total_failures,
|
||||
)
|
||||
|
||||
except (Fail2BanConnectionError, Fail2BanProtocolError) as exc:
|
||||
log.warning("fail2ban_probe_failed", error=str(exc))
|
||||
return ServerStatus(online=False)
|
||||
except ValueError as exc:
|
||||
log.error("fail2ban_probe_parse_error", error=str(exc))
|
||||
return ServerStatus(online=False)
|
||||
79
backend/app/tasks/health_check.py
Normal file
79
backend/app/tasks/health_check.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Health-check background task.
|
||||
|
||||
Registers an APScheduler job that probes the fail2ban socket every 30 seconds
|
||||
and stores the result on ``app.state.server_status``. The dashboard endpoint
|
||||
reads from this cache, keeping HTTP responses fast and the daemon connection
|
||||
decoupled from user-facing requests.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import health_service
|
||||
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from fastapi import FastAPI
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: How often the probe fires (seconds).
|
||||
HEALTH_CHECK_INTERVAL: int = 30
|
||||
|
||||
|
||||
async def _run_probe(app: Any) -> None:
|
||||
"""Probe fail2ban and cache the result on *app.state*.
|
||||
|
||||
This is the APScheduler job callback. It reads ``fail2ban_socket`` from
|
||||
``app.state.settings``, runs the health probe, and writes the result to
|
||||
``app.state.server_status``.
|
||||
|
||||
Args:
|
||||
app: The :class:`fastapi.FastAPI` application instance passed by the
|
||||
scheduler via the ``kwargs`` mechanism.
|
||||
"""
|
||||
socket_path: str = app.state.settings.fail2ban_socket
|
||||
status: ServerStatus = await health_service.probe(socket_path)
|
||||
app.state.server_status = status
|
||||
log.debug(
|
||||
"health_check_complete",
|
||||
online=status.online,
|
||||
version=status.version,
|
||||
active_jails=status.active_jails,
|
||||
)
|
||||
|
||||
|
||||
def register(app: FastAPI) -> None:
|
||||
"""Add the health-check job to the application scheduler.
|
||||
|
||||
Must be called after the scheduler has been started (i.e., inside the
|
||||
lifespan handler, after ``scheduler.start()``).
|
||||
|
||||
Args:
|
||||
app: The :class:`fastapi.FastAPI` application instance whose
|
||||
``app.state.scheduler`` will receive the job.
|
||||
"""
|
||||
# Initialise the cache with an offline placeholder so the dashboard
|
||||
# endpoint is always able to return a valid response even before the
|
||||
# first probe fires.
|
||||
app.state.server_status = ServerStatus(online=False)
|
||||
|
||||
app.state.scheduler.add_job(
|
||||
_run_probe,
|
||||
trigger="interval",
|
||||
seconds=HEALTH_CHECK_INTERVAL,
|
||||
kwargs={"app": app},
|
||||
id="health_check",
|
||||
replace_existing=True,
|
||||
# Fire immediately on startup too, so the UI isn't dark for 30 s.
|
||||
next_run_time=__import__("datetime").datetime.now(
|
||||
tz=__import__("datetime").timezone.utc
|
||||
),
|
||||
)
|
||||
log.info(
|
||||
"health_check_scheduled",
|
||||
interval_seconds=HEALTH_CHECK_INTERVAL,
|
||||
)
|
||||
@@ -43,8 +43,8 @@ ignore = ["B008"] # FastAPI uses function calls in default arguments (Depends)
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
# sys.path manipulation before stdlib imports is intentional in test helpers
|
||||
# pytest evaluates fixture type annotations at runtime, so TC002/TC003 are false-positives
|
||||
"tests/**" = ["E402", "TC002", "TC003"]
|
||||
# pytest evaluates fixture type annotations at runtime, so TC001/TC002/TC003 are false-positives
|
||||
"tests/**" = ["E402", "TC001", "TC002", "TC003"]
|
||||
"app/routers/**" = ["TC001"] # FastAPI evaluates Depends() type aliases at runtime via get_type_hints()
|
||||
|
||||
[tool.ruff.format]
|
||||
|
||||
194
backend/tests/test_routers/test_dashboard.py
Normal file
194
backend/tests/test_routers/test_dashboard.py
Normal file
@@ -0,0 +1,194 @@
|
||||
"""Tests for the dashboard router (GET /api/dashboard/status)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "testpassword1",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
"session_duration_minutes": 60,
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated ``AsyncClient`` with a pre-seeded server status.
|
||||
|
||||
Unlike the shared ``client`` fixture this one also exposes access to
|
||||
``app.state`` via the app instance so we can seed the status cache.
|
||||
"""
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "dashboard_test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-dashboard-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
app.state.db = db
|
||||
|
||||
# Pre-seed a server status so the endpoint has something to return.
|
||||
app.state.server_status = ServerStatus(
|
||||
online=True,
|
||||
version="1.0.2",
|
||||
active_jails=2,
|
||||
total_bans=10,
|
||||
total_failures=5,
|
||||
)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
# Complete setup so the middleware doesn't redirect.
|
||||
resp = await ac.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
assert resp.status_code == 201
|
||||
|
||||
# Login to get a session cookie.
|
||||
login_resp = await ac.post(
|
||||
"/api/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
)
|
||||
assert login_resp.status_code == 200
|
||||
|
||||
yield ac
|
||||
|
||||
await db.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Like ``dashboard_client`` but with an offline server status."""
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "dashboard_offline_test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-dashboard-offline-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
app.state.db = db
|
||||
|
||||
app.state.server_status = ServerStatus(online=False)
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
resp = await ac.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
assert resp.status_code == 201
|
||||
|
||||
login_resp = await ac.post(
|
||||
"/api/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
)
|
||||
assert login_resp.status_code == 200
|
||||
|
||||
yield ac
|
||||
|
||||
await db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDashboardStatus:
|
||||
"""GET /api/dashboard/status."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
response = await dashboard_client.get("/api/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
# Complete setup so the middleware allows the request through.
|
||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/dashboard/status")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_response_shape_when_online(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Response contains the expected ``status`` object shape."""
|
||||
response = await dashboard_client.get("/api/dashboard/status")
|
||||
body = response.json()
|
||||
|
||||
assert "status" in body
|
||||
status = body["status"]
|
||||
assert "online" in status
|
||||
assert "version" in status
|
||||
assert "active_jails" in status
|
||||
assert "total_bans" in status
|
||||
assert "total_failures" in status
|
||||
|
||||
async def test_cached_values_returned_when_online(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Endpoint returns the exact values from ``app.state.server_status``."""
|
||||
response = await dashboard_client.get("/api/dashboard/status")
|
||||
status = response.json()["status"]
|
||||
|
||||
assert status["online"] is True
|
||||
assert status["version"] == "1.0.2"
|
||||
assert status["active_jails"] == 2
|
||||
assert status["total_bans"] == 10
|
||||
assert status["total_failures"] == 5
|
||||
|
||||
async def test_offline_status_returned_correctly(
|
||||
self, offline_dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Endpoint returns online=False when the cache holds an offline snapshot."""
|
||||
response = await offline_dashboard_client.get("/api/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
status = response.json()["status"]
|
||||
|
||||
assert status["online"] is False
|
||||
assert status["version"] is None
|
||||
assert status["active_jails"] == 0
|
||||
assert status["total_bans"] == 0
|
||||
assert status["total_failures"] == 0
|
||||
|
||||
async def test_returns_offline_when_state_not_initialised(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Endpoint returns online=False as a safe default if the cache is absent."""
|
||||
# Setup + login so the endpoint is reachable.
|
||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
await client.post(
|
||||
"/api/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
)
|
||||
# server_status is not set on app.state in the shared `client` fixture.
|
||||
response = await client.get("/api/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
status = response.json()["status"]
|
||||
assert status["online"] is False
|
||||
263
backend/tests/test_services/test_health_service.py
Normal file
263
backend/tests/test_services/test_health_service.py
Normal file
@@ -0,0 +1,263 @@
|
||||
"""Tests for health_service.probe()."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.services import health_service
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SOCKET = "/fake/fail2ban.sock"
|
||||
|
||||
|
||||
def _make_send(responses: dict[str, Any]) -> AsyncMock:
|
||||
"""Build an ``AsyncMock`` for ``Fail2BanClient.send`` keyed by command[0].
|
||||
|
||||
For the ``["status", jail_name]`` command the key is
|
||||
``"status:<jail_name>"``.
|
||||
"""
|
||||
|
||||
async def _side_effect(command: list[str]) -> Any:
|
||||
key = f"status:{command[1]}" if len(command) >= 2 and command[0] == "status" else command[0]
|
||||
if key not in responses:
|
||||
raise KeyError(f"Unexpected command key {key!r} in mock")
|
||||
return responses[key]
|
||||
|
||||
mock = AsyncMock(side_effect=_side_effect)
|
||||
return mock
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProbeOnline:
|
||||
"""Verify probe() correctly parses a healthy fail2ban response."""
|
||||
|
||||
async def test_online_flag_is_true(self) -> None:
|
||||
"""status.online is True when ping succeeds."""
|
||||
send = _make_send(
|
||||
{
|
||||
"ping": (0, "pong"),
|
||||
"version": (0, "1.0.2"),
|
||||
"status": (0, [("Number of jail", 0), ("Jail list", "")]),
|
||||
}
|
||||
)
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result: ServerStatus = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.online is True
|
||||
|
||||
async def test_version_parsed(self) -> None:
|
||||
"""status.version contains the version string returned by fail2ban."""
|
||||
send = _make_send(
|
||||
{
|
||||
"ping": (0, "pong"),
|
||||
"version": (0, "1.1.0"),
|
||||
"status": (0, [("Number of jail", 0), ("Jail list", "")]),
|
||||
}
|
||||
)
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.version == "1.1.0"
|
||||
|
||||
async def test_active_jails_count(self) -> None:
|
||||
"""status.active_jails reflects the jail count from the status command."""
|
||||
send = _make_send(
|
||||
{
|
||||
"ping": (0, "pong"),
|
||||
"version": (0, "1.0.2"),
|
||||
"status": (0, [("Number of jail", 2), ("Jail list", "sshd, nginx")]),
|
||||
"status:sshd": (
|
||||
0,
|
||||
[
|
||||
("Filter", [("Currently failed", 3), ("Total failed", 100)]),
|
||||
("Actions", [("Currently banned", 1), ("Total banned", 50)]),
|
||||
],
|
||||
),
|
||||
"status:nginx": (
|
||||
0,
|
||||
[
|
||||
("Filter", [("Currently failed", 2), ("Total failed", 50)]),
|
||||
("Actions", [("Currently banned", 0), ("Total banned", 10)]),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.active_jails == 2
|
||||
|
||||
async def test_total_bans_aggregated(self) -> None:
|
||||
"""status.total_bans sums 'Currently banned' across all jails."""
|
||||
send = _make_send(
|
||||
{
|
||||
"ping": (0, "pong"),
|
||||
"version": (0, "1.0.2"),
|
||||
"status": (0, [("Number of jail", 2), ("Jail list", "sshd, nginx")]),
|
||||
"status:sshd": (
|
||||
0,
|
||||
[
|
||||
("Filter", [("Currently failed", 3), ("Total failed", 100)]),
|
||||
("Actions", [("Currently banned", 4), ("Total banned", 50)]),
|
||||
],
|
||||
),
|
||||
"status:nginx": (
|
||||
0,
|
||||
[
|
||||
("Filter", [("Currently failed", 1), ("Total failed", 20)]),
|
||||
("Actions", [("Currently banned", 2), ("Total banned", 15)]),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.total_bans == 6 # 4 + 2
|
||||
|
||||
async def test_total_failures_aggregated(self) -> None:
|
||||
"""status.total_failures sums 'Currently failed' across all jails."""
|
||||
send = _make_send(
|
||||
{
|
||||
"ping": (0, "pong"),
|
||||
"version": (0, "1.0.2"),
|
||||
"status": (0, [("Number of jail", 2), ("Jail list", "sshd, nginx")]),
|
||||
"status:sshd": (
|
||||
0,
|
||||
[
|
||||
("Filter", [("Currently failed", 3), ("Total failed", 100)]),
|
||||
("Actions", [("Currently banned", 1), ("Total banned", 50)]),
|
||||
],
|
||||
),
|
||||
"status:nginx": (
|
||||
0,
|
||||
[
|
||||
("Filter", [("Currently failed", 2), ("Total failed", 20)]),
|
||||
("Actions", [("Currently banned", 0), ("Total banned", 10)]),
|
||||
],
|
||||
),
|
||||
}
|
||||
)
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.total_failures == 5 # 3 + 2
|
||||
|
||||
async def test_empty_jail_list(self) -> None:
|
||||
"""Probe succeeds with zero jails — no per-jail queries are made."""
|
||||
send = _make_send(
|
||||
{
|
||||
"ping": (0, "pong"),
|
||||
"version": (0, "1.0.2"),
|
||||
"status": (0, [("Number of jail", 0), ("Jail list", "")]),
|
||||
}
|
||||
)
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.online is True
|
||||
assert result.active_jails == 0
|
||||
assert result.total_bans == 0
|
||||
assert result.total_failures == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Error handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProbeOffline:
|
||||
"""Verify probe() returns online=False when the daemon is unreachable."""
|
||||
|
||||
async def test_connection_error_returns_offline(self) -> None:
|
||||
"""Fail2BanConnectionError → online=False."""
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("socket not found", _SOCKET)
|
||||
)
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.online is False
|
||||
assert result.version is None
|
||||
|
||||
async def test_protocol_error_returns_offline(self) -> None:
|
||||
"""Fail2BanProtocolError → online=False."""
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("bad pickle")
|
||||
)
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.online is False
|
||||
|
||||
async def test_bad_ping_response_returns_offline(self) -> None:
|
||||
"""An unexpected ping response → online=False (defensive guard)."""
|
||||
send = _make_send({"ping": (0, "NOTPONG")})
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.online is False
|
||||
|
||||
async def test_error_code_in_ping_returns_offline(self) -> None:
|
||||
"""An error return code in the ping response → online=False."""
|
||||
send = _make_send({"ping": (1, "ERROR")})
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.online is False
|
||||
|
||||
async def test_per_jail_error_is_tolerated(self) -> None:
|
||||
"""A parse error on an individual jail's status does not break the probe."""
|
||||
send = _make_send(
|
||||
{
|
||||
"ping": (0, "pong"),
|
||||
"version": (0, "1.0.2"),
|
||||
"status": (0, [("Number of jail", 1), ("Jail list", "sshd")]),
|
||||
# Return garbage to trigger parse tolerance.
|
||||
"status:sshd": (0, "INVALID"),
|
||||
}
|
||||
)
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
# The service should still be online even if per-jail parsing fails.
|
||||
assert result.online is True
|
||||
assert result.total_bans == 0
|
||||
assert result.total_failures == 0
|
||||
|
||||
@pytest.mark.parametrize("version_return", [(1, "ERROR"), (0, None)])
|
||||
async def test_version_failure_is_tolerated(self, version_return: tuple[int, Any]) -> None:
|
||||
"""A failed or null version response does not prevent a successful probe."""
|
||||
send = _make_send(
|
||||
{
|
||||
"ping": (0, "pong"),
|
||||
"version": version_return,
|
||||
"status": (0, [("Number of jail", 0), ("Jail list", "")]),
|
||||
}
|
||||
)
|
||||
with patch("app.services.health_service.Fail2BanClient") as mock_client:
|
||||
mock_client.return_value.send = send
|
||||
result = await health_service.probe(_SOCKET)
|
||||
|
||||
assert result.online is True
|
||||
Reference in New Issue
Block a user