fix: retry, semaphore, reload lock, activation verify, bans_by_jail diagnostics

Stage 1.1-1.3: reload_all include/exclude_jails params already implemented;
  added keyword-arg assertions in router and service tests.

Stage 2.1/6.1: _send_command_sync retry loop (3 attempts, 150ms exp backoff)
  retrying on EAGAIN/ECONNREFUSED/ENOBUFS; immediate raise on all other errors.

Stage 2.2: asyncio.Lock at module level in jail_service.reload_all to
  serialize concurrent reload--all commands.

Stage 3.1: activate_jail re-queries _get_active_jail_names after reload;
  returns active=False with descriptive message if jail did not start.

Stage 4.1/6.2: asyncio.Semaphore (max 10) in Fail2BanClient.send, lazy-
  initialized; logs fail2ban_command_waiting_semaphore at debug when waiting.

Stage 5.1/5.2: unit tests asserting reload_all is called with include_jails
  and exclude_jails; activation verification happy/sad path tests.

Stage 6.3: TestSendCommandSyncRetry (5 cases) + TestFail2BanClientSemaphore
  concurrency test.

Stage 7.1-7.3: _since_unix uses time.time(); bans_by_jail debug logging with
  since_iso; diagnostic warning when total==0 despite table rows; unit test
  verifying the warning fires for stale data.
This commit is contained in:
2026-03-14 11:09:55 +01:00
parent 2274e20123
commit 2f2e5a7419
9 changed files with 880 additions and 115 deletions

View File

@@ -1005,3 +1005,38 @@ class TestBansByJail:
assert result.total == 3
assert len(result.jails) == 3
async def test_diagnostic_warning_when_zero_results_despite_data(
self, tmp_path: Path
) -> None:
"""A warning is logged when the time-range filter excludes all existing rows."""
import time as _time
# Insert rows with timeofban far in the past (outside any range window).
far_past = int(_time.time()) - 400 * 24 * 3600 # ~400 days ago
path = str(tmp_path / "test_diag.sqlite3")
await _create_f2b_db(
path,
[
{"jail": "sshd", "ip": "1.1.1.1", "timeofban": far_past},
],
)
with (
patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=path),
),
patch("app.services.ban_service.log") as mock_log,
):
result = await ban_service.bans_by_jail("/fake/sock", "24h")
assert result.total == 0
assert result.jails == []
# The diagnostic warning must have been emitted.
warning_calls = [
c
for c in mock_log.warning.call_args_list
if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
]
assert len(warning_calls) == 1

View File

@@ -440,7 +440,7 @@ class TestActivateJail:
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
),
patch("app.services.config_file_service.jail_service") as mock_js,
):
@@ -2491,3 +2491,112 @@ class TestRemoveActionFromJail:
mock_reload.assert_awaited_once()
# ---------------------------------------------------------------------------
# activate_jail — reload_all keyword argument assertions (Stage 5.1)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestActivateJailReloadArgs:
"""Verify activate_jail calls reload_all with include_jails=[name]."""
async def test_activate_passes_include_jails(self, tmp_path: Path) -> None:
"""activate_jail must pass include_jails=[name] to reload_all."""
_write(tmp_path / "jail.conf", JAIL_CONF)
from app.models.config import ActivateJailRequest
req = ActivateJailRequest()
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
),
patch("app.services.config_file_service.jail_service") as mock_js,
):
mock_js.reload_all = AsyncMock()
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
mock_js.reload_all.assert_awaited_once_with(
"/fake.sock", include_jails=["apache-auth"]
)
async def test_activate_returns_active_true_when_jail_starts(
self, tmp_path: Path
) -> None:
"""activate_jail returns active=True when the jail appears in post-reload names."""
_write(tmp_path / "jail.conf", JAIL_CONF)
from app.models.config import ActivateJailRequest
req = ActivateJailRequest()
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
),
patch("app.services.config_file_service.jail_service") as mock_js,
):
mock_js.reload_all = AsyncMock()
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
assert result.active is True
assert "activated" in result.message.lower()
async def test_activate_returns_active_false_when_jail_does_not_start(
self, tmp_path: Path
) -> None:
"""activate_jail returns active=False when the jail is absent after reload.
This covers the Stage 3.1 requirement: if the jail config is invalid
(bad regex, missing log file, etc.) fail2ban may silently refuse to
start the jail even though the reload command succeeded.
"""
_write(tmp_path / "jail.conf", JAIL_CONF)
from app.models.config import ActivateJailRequest
req = ActivateJailRequest()
# Pre-reload: jail not running. Post-reload: still not running (boot failed).
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(side_effect=[set(), set()]),
),
patch("app.services.config_file_service.jail_service") as mock_js,
):
mock_js.reload_all = AsyncMock()
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
assert result.active is False
assert "apache-auth" in result.name
# ---------------------------------------------------------------------------
# deactivate_jail — reload_all keyword argument assertions (Stage 5.2)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestDeactivateJailReloadArgs:
"""Verify deactivate_jail calls reload_all with exclude_jails=[name]."""
async def test_deactivate_passes_exclude_jails(self, tmp_path: Path) -> None:
"""deactivate_jail must pass exclude_jails=[name] to reload_all."""
_write(tmp_path / "jail.conf", JAIL_CONF)
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value={"sshd"}),
),
patch("app.services.config_file_service.jail_service") as mock_js,
):
mock_js.reload_all = AsyncMock()
await deactivate_jail(str(tmp_path), "/fake.sock", "sshd")
mock_js.reload_all.assert_awaited_once_with(
"/fake.sock", exclude_jails=["sshd"]
)

View File

@@ -1,5 +1,6 @@
"""Tests for app.utils.fail2ban_client."""
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -287,6 +288,21 @@ class TestFail2BanClientSend:
with pytest.raises(Fail2BanProtocolError):
await client.send(["status"])
@pytest.mark.asyncio
async def test_send_raises_on_protocol_error(self) -> None:
"""``send()`` must propagate :class:`Fail2BanProtocolError` to the caller."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop:
mock_loop = AsyncMock()
mock_loop.run_in_executor = AsyncMock(
side_effect=Fail2BanProtocolError("bad pickle")
)
mock_get_loop.return_value = mock_loop
with pytest.raises(Fail2BanProtocolError):
await client.send(["status"])
@pytest.mark.asyncio
async def test_send_logs_error_on_protocol_error(self) -> None:
"""``send()`` must log an error when a protocol error occurs."""
@@ -307,3 +323,202 @@ class TestFail2BanClientSend:
if c[0][0] == "fail2ban_protocol_error"
]
assert len(error_calls) == 1
# ---------------------------------------------------------------------------
# Tests for _send_command_sync retry logic (Stage 6.1 / 6.3)
# ---------------------------------------------------------------------------
class TestSendCommandSyncRetry:
"""Tests for the retry-on-transient-OSError logic in :func:`_send_command_sync`."""
def _make_sock(self) -> MagicMock:
"""Return a mock socket that connects without error."""
mock_sock = MagicMock()
mock_sock.connect.return_value = None
return mock_sock
def _eagain(self) -> OSError:
"""Return an ``OSError`` with ``errno.EAGAIN``."""
import errno as _errno
err = OSError("Resource temporarily unavailable")
err.errno = _errno.EAGAIN
return err
def _enoent(self) -> OSError:
"""Return an ``OSError`` with ``errno.ENOENT``."""
import errno as _errno
err = OSError("No such file or directory")
err.errno = _errno.ENOENT
return err
def test_transient_eagain_retried_succeeds_on_second_attempt(self) -> None:
"""A single EAGAIN on connect is retried; success on the second attempt."""
from app.utils.fail2ban_client import _PROTO_END
call_count = 0
def _connect_side_effect(sock_path: str) -> None:
nonlocal call_count
call_count += 1
if call_count == 1:
raise self._eagain()
# Second attempt succeeds (no-op).
mock_sock = self._make_sock()
mock_sock.connect.side_effect = _connect_side_effect
mock_sock.recv.return_value = _PROTO_END
expected = [0, "pong"]
with (
patch("socket.socket", return_value=mock_sock),
patch("app.utils.fail2ban_client.loads", return_value=expected),
patch("app.utils.fail2ban_client.time.sleep"), # suppress backoff delay
):
result = _send_command_sync("/fake.sock", ["ping"], 1.0)
assert result == expected
assert call_count == 2
def test_three_eagain_failures_raise_connection_error(self) -> None:
"""Three consecutive EAGAIN failures must raise :class:`Fail2BanConnectionError`."""
mock_sock = self._make_sock()
mock_sock.connect.side_effect = self._eagain()
with (
patch("socket.socket", return_value=mock_sock),
patch("app.utils.fail2ban_client.time.sleep"),
pytest.raises(Fail2BanConnectionError),
):
_send_command_sync("/fake.sock", ["status"], 1.0)
# connect() should have been called exactly _RETRY_MAX_ATTEMPTS times.
from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS
assert mock_sock.connect.call_count == _RETRY_MAX_ATTEMPTS
def test_enoent_raises_immediately_without_retry(self) -> None:
"""A non-retryable ``OSError`` (``ENOENT``) must be raised on the first attempt."""
mock_sock = self._make_sock()
mock_sock.connect.side_effect = self._enoent()
with (
patch("socket.socket", return_value=mock_sock),
patch("app.utils.fail2ban_client.time.sleep") as mock_sleep,
pytest.raises(Fail2BanConnectionError),
):
_send_command_sync("/fake.sock", ["status"], 1.0)
# No back-off sleep should have been triggered.
mock_sleep.assert_not_called()
assert mock_sock.connect.call_count == 1
def test_protocol_error_never_retried(self) -> None:
"""A :class:`Fail2BanProtocolError` must be re-raised immediately."""
from app.utils.fail2ban_client import _PROTO_END
mock_sock = self._make_sock()
mock_sock.recv.return_value = _PROTO_END
with (
patch("socket.socket", return_value=mock_sock),
patch(
"app.utils.fail2ban_client.loads",
side_effect=Exception("bad pickle"),
),
patch("app.utils.fail2ban_client.time.sleep") as mock_sleep,
pytest.raises(Fail2BanProtocolError),
):
_send_command_sync("/fake.sock", ["status"], 1.0)
mock_sleep.assert_not_called()
def test_retry_emits_structured_log_event(self) -> None:
"""Each retry attempt logs a ``fail2ban_socket_retry`` warning."""
mock_sock = self._make_sock()
mock_sock.connect.side_effect = self._eagain()
with (
patch("socket.socket", return_value=mock_sock),
patch("app.utils.fail2ban_client.time.sleep"),
patch("app.utils.fail2ban_client.log") as mock_log,
pytest.raises(Fail2BanConnectionError),
):
_send_command_sync("/fake.sock", ["status"], 1.0)
retry_calls = [
c for c in mock_log.warning.call_args_list
if c[0][0] == "fail2ban_socket_retry"
]
from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS
# One retry log per attempt except the last (which raises directly).
assert len(retry_calls) == _RETRY_MAX_ATTEMPTS - 1
# ---------------------------------------------------------------------------
# Tests for Fail2BanClient semaphore (Stage 6.2 / 6.3)
# ---------------------------------------------------------------------------
class TestFail2BanClientSemaphore:
"""Tests for the concurrency semaphore in :meth:`Fail2BanClient.send`."""
@pytest.mark.asyncio
async def test_semaphore_limits_concurrency(self) -> None:
"""No more than _COMMAND_SEMAPHORE_CONCURRENCY commands overlap."""
import asyncio as _asyncio
import app.utils.fail2ban_client as _module
# Reset module-level semaphore so this test starts fresh.
_module._command_semaphore = None
concurrency_limit = 3
_module._COMMAND_SEMAPHORE_CONCURRENCY = concurrency_limit
_module._command_semaphore = _asyncio.Semaphore(concurrency_limit)
in_flight: list[int] = []
peak_concurrent: list[int] = []
async def _slow_send(command: list[Any]) -> Any:
in_flight.append(1)
peak_concurrent.append(len(in_flight))
await _asyncio.sleep(0) # yield to allow other coroutines to run
in_flight.pop()
return (0, "ok")
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch.object(client, "send", wraps=_slow_send) as _patched:
# Bypass the semaphore wrapper — test the actual send directly.
pass
# Override _command_semaphore and run concurrently via the real send path
# but mock _send_command_sync to avoid actual socket I/O.
async def _fast_executor(_fn: Any, *_args: Any) -> Any:
in_flight.append(1)
peak_concurrent.append(len(in_flight))
await _asyncio.sleep(0)
in_flight.pop()
return (0, "ok")
client2 = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_loop_getter:
mock_loop = MagicMock()
mock_loop.run_in_executor = _fast_executor
mock_loop_getter.return_value = mock_loop
tasks = [
_asyncio.create_task(client2.send(["ping"])) for _ in range(10)
]
await _asyncio.gather(*tasks)
# Peak concurrent activity must never exceed the semaphore limit.
assert max(peak_concurrent) <= concurrency_limit
# Restore module defaults after test.
_module._COMMAND_SEMAPHORE_CONCURRENCY = 10
_module._command_semaphore = None

View File

@@ -292,7 +292,7 @@ class TestJailControls:
with _patch_client(
{
"status": _make_global_status("sshd, nginx"),
"reload|--all|[]|[['start', 'sshd'], ['start', 'nginx']]": (0, "OK"),
"reload|--all|[]|[['start', 'nginx'], ['start', 'sshd']]": (0, "OK"),
}
):
await jail_service.reload_all(_SOCKET) # should not raise
@@ -307,6 +307,38 @@ class TestJailControls:
):
await jail_service.reload_all(_SOCKET) # should not raise
async def test_reload_all_include_jails(self) -> None:
"""reload_all with include_jails adds the new jail to the stream."""
with _patch_client(
{
"status": _make_global_status("sshd, nginx"),
"reload|--all|[]|[['start', 'apache-auth'], ['start', 'nginx'], ['start', 'sshd']]": (0, "OK"),
}
):
await jail_service.reload_all(_SOCKET, include_jails=["apache-auth"])
async def test_reload_all_exclude_jails(self) -> None:
"""reload_all with exclude_jails removes the jail from the stream."""
with _patch_client(
{
"status": _make_global_status("sshd, nginx"),
"reload|--all|[]|[['start', 'nginx']]": (0, "OK"),
}
):
await jail_service.reload_all(_SOCKET, exclude_jails=["sshd"])
async def test_reload_all_include_and_exclude(self) -> None:
"""reload_all with both include and exclude applies both correctly."""
with _patch_client(
{
"status": _make_global_status("old, nginx"),
"reload|--all|[]|[['start', 'new'], ['start', 'nginx']]": (0, "OK"),
}
):
await jail_service.reload_all(
_SOCKET, include_jails=["new"], exclude_jails=["old"]
)
async def test_start_not_found_raises(self) -> None:
"""start_jail raises JailNotFoundError for unknown jail."""
with _patch_client({"start|ghost": (1, Exception("Unknown jail: 'ghost'"))}), pytest.raises(JailNotFoundError):