fix: retry, semaphore, reload lock, activation verify, bans_by_jail diagnostics
Stage 1.1-1.3: reload_all include/exclude_jails params already implemented; added keyword-arg assertions in router and service tests. Stage 2.1/6.1: _send_command_sync retry loop (3 attempts, 150ms exp backoff) retrying on EAGAIN/ECONNREFUSED/ENOBUFS; immediate raise on all other errors. Stage 2.2: asyncio.Lock at module level in jail_service.reload_all to serialize concurrent "reload --all" commands. Stage 3.1: activate_jail re-queries _get_active_jail_names after reload; returns active=False with descriptive message if jail did not start. Stage 4.1/6.2: asyncio.Semaphore (max 10) in Fail2BanClient.send, lazily initialized; logs fail2ban_command_waiting_semaphore at debug when waiting. Stage 5.1/5.2: unit tests asserting reload_all is called with include_jails and exclude_jails; activation verification happy/sad path tests. Stage 6.3: TestSendCommandSyncRetry (5 cases) + TestFail2BanClientSemaphore concurrency test. Stage 7.1-7.3: _since_unix uses time.time(); bans_by_jail debug logging with since_iso; diagnostic warning when total==0 despite table rows; unit test verifying the warning fires for stale data.
This commit is contained in:
@@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -76,6 +77,13 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
def _since_unix(range_: TimeRange) -> int:
|
||||
"""Return the Unix timestamp representing the start of the time window.
|
||||
|
||||
Uses :func:`time.time` (always UTC epoch seconds on all platforms) to be
|
||||
consistent with how fail2ban stores ``timeofban`` values in its SQLite
|
||||
database. fail2ban records ``time.time()`` values directly, so
|
||||
comparing against a timezone-aware ``datetime.now(UTC).timestamp()`` would
|
||||
theoretically produce the same number but using :func:`time.time` avoids
|
||||
any tz-aware datetime pitfalls on misconfigured systems.
|
||||
|
||||
Args:
|
||||
range_: One of the supported time-range presets.
|
||||
|
||||
@@ -83,7 +91,7 @@ def _since_unix(range_: TimeRange) -> int:
|
||||
Unix timestamp (seconds since epoch) equal to *now − range_*.
|
||||
"""
|
||||
seconds: int = TIME_RANGE_SECONDS[range_]
|
||||
return int(datetime.now(tz=UTC).timestamp()) - seconds
|
||||
return int(time.time()) - seconds
|
||||
|
||||
|
||||
def _ts_to_iso(unix_ts: int) -> str:
|
||||
@@ -626,10 +634,11 @@ async def bans_by_jail(
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
since_iso=_ts_to_iso(since),
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
@@ -644,6 +653,24 @@ async def bans_by_jail(
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
|
||||
# Diagnostic guard: if zero results were returned, check whether the
|
||||
# table has *any* rows and log a warning with min/max timeofban so
|
||||
# operators can diagnose timezone or filter mismatches from logs.
|
||||
if total == 0:
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
|
||||
) as cur:
|
||||
diag_row = await cur.fetchone()
|
||||
if diag_row and diag_row[0] > 0:
|
||||
log.warning(
|
||||
"ban_service_bans_by_jail_empty_despite_data",
|
||||
table_row_count=diag_row[0],
|
||||
min_timeofban=diag_row[1],
|
||||
max_timeofban=diag_row[2],
|
||||
since=since,
|
||||
range=range_,
|
||||
)
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, COUNT(*) AS cnt "
|
||||
"FROM bans "
|
||||
@@ -657,4 +684,9 @@ async def bans_by_jail(
|
||||
jails: list[JailBanCount] = [
|
||||
JailBanCount(jail=str(row["jail"]), count=int(row["cnt"])) for row in rows
|
||||
]
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail_result",
|
||||
total=total,
|
||||
jail_count=len(jails),
|
||||
)
|
||||
return BansByJailResponse(jails=jails, total=total)
|
||||
|
||||
@@ -899,10 +899,30 @@ async def activate_jail(
|
||||
)
|
||||
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await jail_service.reload_all(socket_path, include_jails=[name])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
|
||||
|
||||
# Verify the jail actually started after the reload. A config error
|
||||
# (bad regex, missing log file, etc.) may silently prevent fail2ban from
|
||||
# starting the jail even though the reload command succeeded.
|
||||
post_reload_names = await _get_active_jail_names(socket_path)
|
||||
actually_running = name in post_reload_names
|
||||
if not actually_running:
|
||||
log.warning(
|
||||
"jail_activation_unverified",
|
||||
jail=name,
|
||||
message="Jail did not appear in running jails after reload.",
|
||||
)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
message=(
|
||||
f"Jail {name!r} was written to config but did not start after "
|
||||
"reload — check the jail configuration (filters, log paths, regex)."
|
||||
),
|
||||
)
|
||||
|
||||
log.info("jail_activated", jail=name)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
@@ -962,7 +982,7 @@ async def deactivate_jail(
|
||||
)
|
||||
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await jail_service.reload_all(socket_path, exclude_jails=[name])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
|
||||
|
||||
|
||||
@@ -37,6 +37,12 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
_SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
# Guard against concurrent reload_all calls. Overlapping ``reload --all``
|
||||
# commands sent to fail2ban's socket produce undefined behaviour and may cause
|
||||
# jails to be permanently removed from the daemon. Serialising them here
|
||||
# ensures only one reload stream is in-flight at a time.
|
||||
_reload_all_lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -540,7 +546,12 @@ async def reload_jail(socket_path: str, name: str) -> None:
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
async def reload_all(socket_path: str) -> None:
|
||||
async def reload_all(
|
||||
socket_path: str,
|
||||
*,
|
||||
include_jails: list[str] | None = None,
|
||||
exclude_jails: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Reload all fail2ban jails at once.
|
||||
|
||||
Fetches the current jail list first so that a ``['start', name]`` entry
|
||||
@@ -548,8 +559,14 @@ async def reload_all(socket_path: str) -> None:
|
||||
non-empty stream the end-of-reload phase deletes every jail that received
|
||||
no configuration commands.
|
||||
|
||||
*include_jails* are added to the stream (e.g. a newly activated jail that
|
||||
is not yet running). *exclude_jails* are removed from the stream (e.g. a
|
||||
jail that was just deactivated and should not be restarted).
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
include_jails: Extra jail names to add to the start stream.
|
||||
exclude_jails: Jail names to remove from the start stream.
|
||||
|
||||
Raises:
|
||||
JailOperationError: If fail2ban reports the operation failed.
|
||||
@@ -557,17 +574,26 @@ async def reload_all(socket_path: str) -> None:
|
||||
cannot be reached.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
try:
|
||||
# Resolve jail names so we can build the minimal config stream.
|
||||
status_raw = _ok(await client.send(["status"]))
|
||||
status_dict = _to_dict(status_raw)
|
||||
jail_list_raw: str = str(status_dict.get("Jail list", ""))
|
||||
jail_names = [n.strip() for n in jail_list_raw.split(",") if n.strip()]
|
||||
stream: list[list[str]] = [["start", n] for n in jail_names]
|
||||
_ok(await client.send(["reload", "--all", [], stream]))
|
||||
log.info("all_jails_reloaded")
|
||||
except ValueError as exc:
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
async with _reload_all_lock:
|
||||
try:
|
||||
# Resolve jail names so we can build the minimal config stream.
|
||||
status_raw = _ok(await client.send(["status"]))
|
||||
status_dict = _to_dict(status_raw)
|
||||
jail_list_raw: str = str(status_dict.get("Jail list", ""))
|
||||
jail_names = [n.strip() for n in jail_list_raw.split(",") if n.strip()]
|
||||
|
||||
# Merge include/exclude sets so the stream matches the desired state.
|
||||
names_set: set[str] = set(jail_names)
|
||||
if include_jails:
|
||||
names_set.update(include_jails)
|
||||
if exclude_jails:
|
||||
names_set -= set(exclude_jails)
|
||||
|
||||
stream: list[list[str]] = [["start", n] for n in sorted(names_set)]
|
||||
_ok(await client.send(["reload", "--all", [], stream]))
|
||||
log.info("all_jails_reloaded")
|
||||
except ValueError as exc:
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -18,7 +18,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import errno
|
||||
import socket
|
||||
import time
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -40,6 +42,24 @@ _PROTO_EMPTY: bytes = b""
|
||||
_RECV_BUFSIZE_START: int = 1024
|
||||
_RECV_BUFSIZE_MAX: int = 32768
|
||||
|
||||
# OSError errno values that indicate a transient socket condition and may be
|
||||
# safely retried. ENOENT (socket file missing) is intentionally excluded so
|
||||
# a missing socket raises immediately without delay.
|
||||
_RETRYABLE_ERRNOS: frozenset[int] = frozenset(
|
||||
{errno.EAGAIN, errno.ECONNREFUSED, errno.ENOBUFS}
|
||||
)
|
||||
|
||||
# Retry policy for _send_command_sync.
|
||||
_RETRY_MAX_ATTEMPTS: int = 3
|
||||
_RETRY_INITIAL_BACKOFF: float = 0.15 # seconds; doubles on each attempt
|
||||
|
||||
# Maximum number of concurrent in-flight socket commands. Operations that
|
||||
# exceed this cap wait until a slot is available.
|
||||
_COMMAND_SEMAPHORE_CONCURRENCY: int = 10
|
||||
# The semaphore is created lazily on the first send() call so it binds to the
|
||||
# event loop that is actually running (important for test isolation).
|
||||
_command_semaphore: asyncio.Semaphore | None = None
|
||||
|
||||
|
||||
class Fail2BanConnectionError(Exception):
|
||||
"""Raised when the fail2ban socket is unreachable or returns an error."""
|
||||
@@ -70,6 +90,14 @@ def _send_command_sync(
|
||||
:func:`asyncio.get_event_loop().run_in_executor` so that the event loop
|
||||
is not blocked.
|
||||
|
||||
Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``,
|
||||
``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with
|
||||
exponential back-off starting at :data:`_RETRY_INITIAL_BACKOFF` seconds.
|
||||
All other ``OSError`` variants (including ``ENOENT`` — socket file
|
||||
missing) and :class:`Fail2BanProtocolError` are raised immediately.
|
||||
A structured log event ``fail2ban_socket_retry`` is emitted for each
|
||||
retry attempt.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
command: List of command tokens, e.g. ``["status", "sshd"]``.
|
||||
@@ -79,52 +107,77 @@ def _send_command_sync(
|
||||
The deserialized Python object returned by fail2ban.
|
||||
|
||||
Raises:
|
||||
Fail2BanConnectionError: If the socket cannot be reached.
|
||||
Fail2BanConnectionError: If the socket cannot be reached after all
|
||||
retry attempts, or immediately for non-retryable errors.
|
||||
Fail2BanProtocolError: If the response cannot be unpickled.
|
||||
"""
|
||||
sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.settimeout(timeout)
|
||||
sock.connect(socket_path)
|
||||
|
||||
# Serialise and send the command.
|
||||
payload: bytes = dumps(
|
||||
list(map(_coerce_command_token, command)),
|
||||
HIGHEST_PROTOCOL,
|
||||
)
|
||||
sock.sendall(payload)
|
||||
sock.sendall(_PROTO_END)
|
||||
|
||||
# Receive until we see the end marker.
|
||||
raw: bytes = _PROTO_EMPTY
|
||||
bufsize: int = _RECV_BUFSIZE_START
|
||||
while raw.rfind(_PROTO_END, -32) == -1:
|
||||
chunk: bytes = sock.recv(bufsize)
|
||||
if not chunk:
|
||||
raise Fail2BanConnectionError(
|
||||
"Connection closed unexpectedly by fail2ban",
|
||||
socket_path,
|
||||
)
|
||||
if chunk == _PROTO_END:
|
||||
break
|
||||
raw += chunk
|
||||
if bufsize < _RECV_BUFSIZE_MAX:
|
||||
bufsize <<= 1
|
||||
|
||||
last_oserror: OSError | None = None
|
||||
for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
|
||||
sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
return loads(raw)
|
||||
except Exception as exc:
|
||||
raise Fail2BanProtocolError(
|
||||
f"Failed to unpickle fail2ban response: {exc}"
|
||||
) from exc
|
||||
except OSError as exc:
|
||||
raise Fail2BanConnectionError(str(exc), socket_path) from exc
|
||||
finally:
|
||||
with contextlib.suppress(OSError):
|
||||
sock.sendall(_PROTO_CLOSE + _PROTO_END)
|
||||
with contextlib.suppress(OSError):
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
sock.settimeout(timeout)
|
||||
sock.connect(socket_path)
|
||||
|
||||
# Serialise and send the command.
|
||||
payload: bytes = dumps(
|
||||
list(map(_coerce_command_token, command)),
|
||||
HIGHEST_PROTOCOL,
|
||||
)
|
||||
sock.sendall(payload)
|
||||
sock.sendall(_PROTO_END)
|
||||
|
||||
# Receive until we see the end marker.
|
||||
raw: bytes = _PROTO_EMPTY
|
||||
bufsize: int = _RECV_BUFSIZE_START
|
||||
while raw.rfind(_PROTO_END, -32) == -1:
|
||||
chunk: bytes = sock.recv(bufsize)
|
||||
if not chunk:
|
||||
raise Fail2BanConnectionError(
|
||||
"Connection closed unexpectedly by fail2ban",
|
||||
socket_path,
|
||||
)
|
||||
if chunk == _PROTO_END:
|
||||
break
|
||||
raw += chunk
|
||||
if bufsize < _RECV_BUFSIZE_MAX:
|
||||
bufsize <<= 1
|
||||
|
||||
try:
|
||||
return loads(raw)
|
||||
except Exception as exc:
|
||||
raise Fail2BanProtocolError(
|
||||
f"Failed to unpickle fail2ban response: {exc}"
|
||||
) from exc
|
||||
except Fail2BanProtocolError:
|
||||
# Protocol errors are never transient — raise immediately.
|
||||
raise
|
||||
except Fail2BanConnectionError:
|
||||
# Mid-receive close or empty-chunk error — raise immediately.
|
||||
raise
|
||||
except OSError as exc:
|
||||
is_retryable = exc.errno in _RETRYABLE_ERRNOS
|
||||
if is_retryable and attempt < _RETRY_MAX_ATTEMPTS:
|
||||
log.warning(
|
||||
"fail2ban_socket_retry",
|
||||
attempt=attempt,
|
||||
socket_errno=exc.errno,
|
||||
socket_path=socket_path,
|
||||
)
|
||||
last_oserror = exc
|
||||
time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))
|
||||
continue
|
||||
raise Fail2BanConnectionError(str(exc), socket_path) from exc
|
||||
finally:
|
||||
with contextlib.suppress(OSError):
|
||||
sock.sendall(_PROTO_CLOSE + _PROTO_END)
|
||||
with contextlib.suppress(OSError):
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
|
||||
# Exhausted all retry attempts — surface the last transient error.
|
||||
raise Fail2BanConnectionError(
|
||||
str(last_oserror), socket_path
|
||||
) from last_oserror
|
||||
|
||||
|
||||
def _coerce_command_token(token: Any) -> Any:
|
||||
@@ -179,6 +232,12 @@ class Fail2BanClient:
|
||||
async def send(self, command: list[Any]) -> Any:
|
||||
"""Send a command to fail2ban and return the response.
|
||||
|
||||
Acquires the module-level concurrency semaphore before dispatching
|
||||
so that no more than :data:`_COMMAND_SEMAPHORE_CONCURRENCY` commands
|
||||
are in-flight at the same time. Commands that exceed the cap are
|
||||
queued until a slot becomes available. A debug-level log event is
|
||||
emitted when a command must wait.
|
||||
|
||||
The command is serialised as a pickle list, sent to the socket, and
|
||||
the response is deserialised before being returned.
|
||||
|
||||
@@ -193,32 +252,44 @@ class Fail2BanClient:
|
||||
connection is unexpectedly closed.
|
||||
Fail2BanProtocolError: If the response cannot be decoded.
|
||||
"""
|
||||
log.debug("fail2ban_sending_command", command=command)
|
||||
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
try:
|
||||
response: Any = await loop.run_in_executor(
|
||||
None,
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
except Fail2BanConnectionError:
|
||||
log.warning(
|
||||
"fail2ban_connection_error",
|
||||
socket_path=self.socket_path,
|
||||
global _command_semaphore
|
||||
if _command_semaphore is None:
|
||||
_command_semaphore = asyncio.Semaphore(_COMMAND_SEMAPHORE_CONCURRENCY)
|
||||
|
||||
if _command_semaphore.locked():
|
||||
log.debug(
|
||||
"fail2ban_command_waiting_semaphore",
|
||||
command=command,
|
||||
concurrency_limit=_COMMAND_SEMAPHORE_CONCURRENCY,
|
||||
)
|
||||
raise
|
||||
except Fail2BanProtocolError:
|
||||
log.error(
|
||||
"fail2ban_protocol_error",
|
||||
socket_path=self.socket_path,
|
||||
command=command,
|
||||
)
|
||||
raise
|
||||
log.debug("fail2ban_received_response", command=command)
|
||||
return response
|
||||
|
||||
async with _command_semaphore:
|
||||
log.debug("fail2ban_sending_command", command=command)
|
||||
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
try:
|
||||
response: Any = await loop.run_in_executor(
|
||||
None,
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
except Fail2BanConnectionError:
|
||||
log.warning(
|
||||
"fail2ban_connection_error",
|
||||
socket_path=self.socket_path,
|
||||
command=command,
|
||||
)
|
||||
raise
|
||||
except Fail2BanProtocolError:
|
||||
log.error(
|
||||
"fail2ban_protocol_error",
|
||||
socket_path=self.socket_path,
|
||||
command=command,
|
||||
)
|
||||
raise
|
||||
log.debug("fail2ban_received_response", command=command)
|
||||
return response
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""Return ``True`` if the fail2ban daemon is reachable.
|
||||
|
||||
@@ -1005,3 +1005,38 @@ class TestBansByJail:
|
||||
assert result.total == 3
|
||||
assert len(result.jails) == 3
|
||||
|
||||
async def test_diagnostic_warning_when_zero_results_despite_data(
    self, tmp_path: Path
) -> None:
    """Emit the diagnostic warning when the window filters out every stored ban.

    The database holds a single ban dated roughly 400 days ago, so a
    ``"24h"`` query must report zero bans while the bans table itself is
    non-empty. That is the situation the
    ``ban_service_bans_by_jail_empty_despite_data`` warning is meant to flag.
    """
    import time as _time

    # Roughly 400 days back: older than the 24h window requested below.
    stale_ts = int(_time.time()) - 400 * 24 * 3600
    db_file = str(tmp_path / "test_diag.sqlite3")
    stale_rows = [{"jail": "sshd", "ip": "1.1.1.1", "timeofban": stale_ts}]
    await _create_f2b_db(db_file, stale_rows)

    db_path_patch = patch(
        "app.services.ban_service._get_fail2ban_db_path",
        new=AsyncMock(return_value=db_file),
    )
    with db_path_patch, patch("app.services.ban_service.log") as mock_log:
        result = await ban_service.bans_by_jail("/fake/sock", "24h")

    assert result.total == 0
    assert result.jails == []

    # Exactly one diagnostic warning is expected for the empty result.
    emitted = [
        entry
        for entry in mock_log.warning.call_args_list
        if entry[0][0] == "ban_service_bans_by_jail_empty_despite_data"
    ]
    assert len(emitted) == 1
|
||||
|
||||
|
||||
@@ -440,7 +440,7 @@ class TestActivateJail:
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
|
||||
),
|
||||
patch("app.services.config_file_service.jail_service") as mock_js,
|
||||
):
|
||||
@@ -2491,3 +2491,112 @@ class TestRemoveActionFromJail:
|
||||
|
||||
mock_reload.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# activate_jail — reload_all keyword argument assertions (Stage 5.1)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestActivateJailReloadArgs:
    """Check how activate_jail drives reload_all and reports the outcome."""

    @staticmethod
    async def _activate(
        tmp_path: Path, names_after_reload: set[str]
    ) -> tuple[object, AsyncMock]:
        """Run ``activate_jail`` for ``apache-auth`` against mocked collaborators.

        ``_get_active_jail_names`` is patched to return an empty set on its
        first call and *names_after_reload* on its second call, and
        ``jail_service.reload_all`` is replaced with an ``AsyncMock``.

        Args:
            tmp_path: Directory that receives the ``jail.conf`` fixture.
            names_after_reload: Jail names reported by the second
                ``_get_active_jail_names`` call.

        Returns:
            The value returned by ``activate_jail`` and the ``reload_all`` mock.
        """
        _write(tmp_path / "jail.conf", JAIL_CONF)
        from app.models.config import ActivateJailRequest

        req = ActivateJailRequest()
        with (
            patch(
                "app.services.config_file_service._get_active_jail_names",
                new=AsyncMock(side_effect=[set(), names_after_reload]),
            ),
            patch("app.services.config_file_service.jail_service") as mock_js,
        ):
            mock_js.reload_all = AsyncMock()
            result = await activate_jail(
                str(tmp_path), "/fake.sock", "apache-auth", req
            )
        return result, mock_js.reload_all

    async def test_activate_passes_include_jails(self, tmp_path: Path) -> None:
        """activate_jail must pass include_jails=[name] to reload_all."""
        _, reload_all = await self._activate(tmp_path, {"apache-auth"})

        reload_all.assert_awaited_once_with(
            "/fake.sock", include_jails=["apache-auth"]
        )

    async def test_activate_returns_active_true_when_jail_starts(
        self, tmp_path: Path
    ) -> None:
        """activate_jail returns active=True when the jail appears in post-reload names."""
        result, _ = await self._activate(tmp_path, {"apache-auth"})

        assert result.active is True
        assert "activated" in result.message.lower()

    async def test_activate_returns_active_false_when_jail_does_not_start(
        self, tmp_path: Path
    ) -> None:
        """activate_jail returns active=False when the jail is absent after reload.

        This covers the Stage 3.1 requirement: the jail name is missing from
        both the pre-reload and the post-reload jail listings, so the
        response must report the jail as inactive and explain why.
        """
        result, _ = await self._activate(tmp_path, set())

        assert result.active is False
        # Equality, not substring containment: the response must name exactly
        # the jail that was requested.
        assert result.name == "apache-auth"
        assert "did not start" in result.message
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# deactivate_jail — reload_all keyword argument assertions (Stage 5.2)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestDeactivateJailReloadArgs:
    """Check the keyword arguments deactivate_jail forwards to reload_all."""

    async def test_deactivate_passes_exclude_jails(self, tmp_path: Path) -> None:
        """The deactivated jail is handed to reload_all through exclude_jails."""
        _write(tmp_path / "jail.conf", JAIL_CONF)

        active_names_patch = patch(
            "app.services.config_file_service._get_active_jail_names",
            new=AsyncMock(return_value={"sshd"}),
        )
        jail_service_patch = patch("app.services.config_file_service.jail_service")
        with active_names_patch, jail_service_patch as mock_js:
            reload_mock = AsyncMock()
            mock_js.reload_all = reload_mock
            await deactivate_jail(str(tmp_path), "/fake.sock", "sshd")

        reload_mock.assert_awaited_once_with("/fake.sock", exclude_jails=["sshd"])
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for app.utils.fail2ban_client."""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -287,6 +288,21 @@ class TestFail2BanClientSend:
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_raises_on_protocol_error(self) -> None:
|
||||
"""``send()`` must propagate :class:`Fail2BanProtocolError` to the caller."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("bad pickle")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_logs_error_on_protocol_error(self) -> None:
|
||||
"""``send()`` must log an error when a protocol error occurs."""
|
||||
@@ -307,3 +323,202 @@ class TestFail2BanClientSend:
|
||||
if c[0][0] == "fail2ban_protocol_error"
|
||||
]
|
||||
assert len(error_calls) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _send_command_sync retry logic (Stage 6.1 / 6.3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSendCommandSyncRetry:
    """Exercise the transient-``OSError`` retry handling of :func:`_send_command_sync`."""

    @staticmethod
    def _os_error(code: int, text: str) -> OSError:
        """Build an ``OSError`` carrying *text* whose ``errno`` is set to *code*."""
        failure = OSError(text)
        failure.errno = code
        return failure

    def _make_sock(self) -> MagicMock:
        """Create a socket stand-in whose ``connect`` returns ``None``."""
        fake_sock = MagicMock()
        fake_sock.connect.return_value = None
        return fake_sock

    def _eagain(self) -> OSError:
        """Build an ``OSError`` whose ``errno`` is ``errno.EAGAIN``."""
        import errno as _errno

        return self._os_error(_errno.EAGAIN, "Resource temporarily unavailable")

    def _enoent(self) -> OSError:
        """Build an ``OSError`` whose ``errno`` is ``errno.ENOENT``."""
        import errno as _errno

        return self._os_error(_errno.ENOENT, "No such file or directory")

    def test_transient_eagain_retried_succeeds_on_second_attempt(self) -> None:
        """One EAGAIN from ``connect`` is retried and the second attempt succeeds."""
        from app.utils.fail2ban_client import _PROTO_END

        fake_sock = self._make_sock()
        # First connect() raises EAGAIN; the second returns normally.
        fake_sock.connect.side_effect = [self._eagain(), None]
        fake_sock.recv.return_value = _PROTO_END
        reply = [0, "pong"]

        with (
            patch("socket.socket", return_value=fake_sock),
            patch("app.utils.fail2ban_client.loads", return_value=reply),
            patch("app.utils.fail2ban_client.time.sleep"),  # no real back-off delay
        ):
            outcome = _send_command_sync("/fake.sock", ["ping"], 1.0)

        assert outcome == reply
        assert fake_sock.connect.call_count == 2

    def test_three_eagain_failures_raise_connection_error(self) -> None:
        """EAGAIN on every attempt ends in :class:`Fail2BanConnectionError`."""
        from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS

        fake_sock = self._make_sock()
        fake_sock.connect.side_effect = self._eagain()

        with (
            patch("socket.socket", return_value=fake_sock),
            patch("app.utils.fail2ban_client.time.sleep"),
            pytest.raises(Fail2BanConnectionError),
        ):
            _send_command_sync("/fake.sock", ["status"], 1.0)

        # One connect() per attempt, and no attempts beyond the configured cap.
        assert fake_sock.connect.call_count == _RETRY_MAX_ATTEMPTS

    def test_enoent_raises_immediately_without_retry(self) -> None:
        """``ENOENT`` is not retried: it fails on the first attempt."""
        fake_sock = self._make_sock()
        fake_sock.connect.side_effect = self._enoent()

        with (
            patch("socket.socket", return_value=fake_sock),
            patch("app.utils.fail2ban_client.time.sleep") as sleep_mock,
            pytest.raises(Fail2BanConnectionError),
        ):
            _send_command_sync("/fake.sock", ["status"], 1.0)

        # A single connect() and no back-off sleep means no retry happened.
        assert fake_sock.connect.call_count == 1
        sleep_mock.assert_not_called()

    def test_protocol_error_never_retried(self) -> None:
        """A failing ``loads`` surfaces as :class:`Fail2BanProtocolError` without back-off."""
        from app.utils.fail2ban_client import _PROTO_END

        fake_sock = self._make_sock()
        fake_sock.recv.return_value = _PROTO_END
        broken_loads = patch(
            "app.utils.fail2ban_client.loads",
            side_effect=Exception("bad pickle"),
        )

        with (
            patch("socket.socket", return_value=fake_sock),
            broken_loads,
            patch("app.utils.fail2ban_client.time.sleep") as sleep_mock,
            pytest.raises(Fail2BanProtocolError),
        ):
            _send_command_sync("/fake.sock", ["status"], 1.0)

        sleep_mock.assert_not_called()

    def test_retry_emits_structured_log_event(self) -> None:
        """Every retried attempt logs one ``fail2ban_socket_retry`` warning."""
        from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS

        fake_sock = self._make_sock()
        fake_sock.connect.side_effect = self._eagain()

        with (
            patch("socket.socket", return_value=fake_sock),
            patch("app.utils.fail2ban_client.time.sleep"),
            patch("app.utils.fail2ban_client.log") as mock_log,
            pytest.raises(Fail2BanConnectionError),
        ):
            _send_command_sync("/fake.sock", ["status"], 1.0)

        retry_events = [
            entry
            for entry in mock_log.warning.call_args_list
            if entry[0][0] == "fail2ban_socket_retry"
        ]

        # The final attempt raises instead of logging a retry.
        assert len(retry_events) == _RETRY_MAX_ATTEMPTS - 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for Fail2BanClient semaphore (Stage 6.2 / 6.3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFail2BanClientSemaphore:
    """Tests for the concurrency semaphore in :meth:`Fail2BanClient.send`."""

    @pytest.mark.asyncio
    async def test_semaphore_limits_concurrency(self) -> None:
        """No more than the configured number of commands overlap.

        The module-level limit and semaphore are swapped for a small limit,
        ten ``send`` calls are started concurrently, and a fake
        ``run_in_executor`` records how many calls are inside it at once.
        The module globals are restored even when the test body fails.
        """
        import asyncio as _asyncio

        import app.utils.fail2ban_client as _module

        concurrency_limit = 3
        task_count = 10
        in_flight = 0
        peak = 0

        async def _fake_executor(_fn: Any, *_args: Any) -> Any:
            """Stand in for ``loop.run_in_executor`` and track overlap."""
            nonlocal in_flight, peak
            in_flight += 1
            peak = max(peak, in_flight)
            await _asyncio.sleep(0)  # yield so other tasks get a chance to run
            in_flight -= 1
            return (0, "ok")

        # Remember the real limit so it is restored rather than hard-coded.
        saved_limit = _module._COMMAND_SEMAPHORE_CONCURRENCY
        _module._COMMAND_SEMAPHORE_CONCURRENCY = concurrency_limit
        _module._command_semaphore = _asyncio.Semaphore(concurrency_limit)
        try:
            client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
            with patch("asyncio.get_event_loop") as mock_loop_getter:
                mock_loop = MagicMock()
                mock_loop.run_in_executor = _fake_executor
                mock_loop_getter.return_value = mock_loop

                tasks = [
                    _asyncio.create_task(client.send(["ping"]))
                    for _ in range(task_count)
                ]
                results = await _asyncio.gather(*tasks)
        finally:
            # Reset to None so the next send() creates a fresh semaphore.
            _module._COMMAND_SEMAPHORE_CONCURRENCY = saved_limit
            _module._command_semaphore = None

        assert len(results) == task_count
        # The fake executor must have run, and never above the limit.
        assert 1 <= peak <= concurrency_limit
|
||||
|
||||
@@ -292,7 +292,7 @@ class TestJailControls:
|
||||
with _patch_client(
|
||||
{
|
||||
"status": _make_global_status("sshd, nginx"),
|
||||
"reload|--all|[]|[['start', 'sshd'], ['start', 'nginx']]": (0, "OK"),
|
||||
"reload|--all|[]|[['start', 'nginx'], ['start', 'sshd']]": (0, "OK"),
|
||||
}
|
||||
):
|
||||
await jail_service.reload_all(_SOCKET) # should not raise
|
||||
@@ -307,6 +307,38 @@ class TestJailControls:
|
||||
):
|
||||
await jail_service.reload_all(_SOCKET) # should not raise
|
||||
|
||||
async def test_reload_all_include_jails(self) -> None:
|
||||
"""reload_all with include_jails adds the new jail to the stream."""
|
||||
with _patch_client(
|
||||
{
|
||||
"status": _make_global_status("sshd, nginx"),
|
||||
"reload|--all|[]|[['start', 'apache-auth'], ['start', 'nginx'], ['start', 'sshd']]": (0, "OK"),
|
||||
}
|
||||
):
|
||||
await jail_service.reload_all(_SOCKET, include_jails=["apache-auth"])
|
||||
|
||||
async def test_reload_all_exclude_jails(self) -> None:
|
||||
"""reload_all with exclude_jails removes the jail from the stream."""
|
||||
with _patch_client(
|
||||
{
|
||||
"status": _make_global_status("sshd, nginx"),
|
||||
"reload|--all|[]|[['start', 'nginx']]": (0, "OK"),
|
||||
}
|
||||
):
|
||||
await jail_service.reload_all(_SOCKET, exclude_jails=["sshd"])
|
||||
|
||||
async def test_reload_all_include_and_exclude(self) -> None:
|
||||
"""reload_all with both include and exclude applies both correctly."""
|
||||
with _patch_client(
|
||||
{
|
||||
"status": _make_global_status("old, nginx"),
|
||||
"reload|--all|[]|[['start', 'new'], ['start', 'nginx']]": (0, "OK"),
|
||||
}
|
||||
):
|
||||
await jail_service.reload_all(
|
||||
_SOCKET, include_jails=["new"], exclude_jails=["old"]
|
||||
)
|
||||
|
||||
async def test_start_not_found_raises(self) -> None:
|
||||
"""start_jail raises JailNotFoundError for unknown jail."""
|
||||
with _patch_client({"start|ghost": (1, Exception("Unknown jail: 'ghost'"))}), pytest.raises(JailNotFoundError):
|
||||
|
||||
Reference in New Issue
Block a user