fix: retry, semaphore, reload lock, activation verify, bans_by_jail diagnostics

Stage 1.1-1.3: reload_all include/exclude_jails params already implemented;
  added keyword-arg assertions in router and service tests.

Stage 2.1/6.1: _send_command_sync retry loop (3 attempts, 150ms exp backoff)
  retrying on EAGAIN/ECONNREFUSED/ENOBUFS; immediate raise on all other errors.

Stage 2.2: module-level asyncio.Lock in jail_service, held by reload_all, to
  serialize concurrent reload --all commands.

Stage 3.1: activate_jail re-queries _get_active_jail_names after reload;
  returns active=False with descriptive message if jail did not start.

Stage 4.1/6.2: asyncio.Semaphore (max 10) in Fail2BanClient.send, lazy-
  initialized; logs fail2ban_command_waiting_semaphore at debug when waiting.

Stage 5.1/5.2: unit tests asserting reload_all is called with include_jails
  and exclude_jails; activation verification happy/sad path tests.

Stage 6.3: TestSendCommandSyncRetry (5 cases) + TestFail2BanClientSemaphore
  concurrency test.

Stage 7.1-7.3: _since_unix uses time.time(); bans_by_jail debug logging with
  since_iso; diagnostic warning when total==0 despite table rows; unit test
  verifying the warning fires for stale data.
This commit is contained in:
2026-03-14 11:09:55 +01:00
parent 2274e20123
commit 2f2e5a7419
9 changed files with 880 additions and 115 deletions

View File

@@ -12,6 +12,7 @@ from __future__ import annotations
import asyncio
import json
import time
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
@@ -76,6 +77,13 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
def _since_unix(range_: TimeRange) -> int:
"""Return the Unix timestamp representing the start of the time window.
Uses :func:`time.time` (always UTC epoch seconds on all platforms) to be
consistent with how fail2ban stores ``timeofban`` values in its SQLite
database. fail2ban records ``time.time()`` values directly, so
comparing against a timezone-aware ``datetime.now(UTC).timestamp()`` would
theoretically produce the same number but using :func:`time.time` avoids
any tz-aware datetime pitfalls on misconfigured systems.
Args:
range_: One of the supported time-range presets.
@@ -83,7 +91,7 @@ def _since_unix(range_: TimeRange) -> int:
Unix timestamp (seconds since epoch) equal to *now - range_*.
"""
seconds: int = TIME_RANGE_SECONDS[range_]
return int(datetime.now(tz=UTC).timestamp()) - seconds
return int(time.time()) - seconds
def _ts_to_iso(unix_ts: int) -> str:
@@ -626,10 +634,11 @@ async def bans_by_jail(
origin_clause, origin_params = _origin_sql_filter(origin)
db_path: str = await _get_fail2ban_db_path(socket_path)
log.info(
log.debug(
"ban_service_bans_by_jail",
db_path=db_path,
since=since,
since_iso=_ts_to_iso(since),
range=range_,
origin=origin,
)
@@ -644,6 +653,24 @@ async def bans_by_jail(
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
# Diagnostic guard: if zero results were returned, check whether the
# table has *any* rows and log a warning with min/max timeofban so
# operators can diagnose timezone or filter mismatches from logs.
if total == 0:
async with f2b_db.execute(
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
) as cur:
diag_row = await cur.fetchone()
if diag_row and diag_row[0] > 0:
log.warning(
"ban_service_bans_by_jail_empty_despite_data",
table_row_count=diag_row[0],
min_timeofban=diag_row[1],
max_timeofban=diag_row[2],
since=since,
range=range_,
)
async with f2b_db.execute(
"SELECT jail, COUNT(*) AS cnt "
"FROM bans "
@@ -657,4 +684,9 @@ async def bans_by_jail(
jails: list[JailBanCount] = [
JailBanCount(jail=str(row["jail"]), count=int(row["cnt"])) for row in rows
]
log.debug(
"ban_service_bans_by_jail_result",
total=total,
jail_count=len(jails),
)
return BansByJailResponse(jails=jails, total=total)

View File

@@ -899,10 +899,30 @@ async def activate_jail(
)
try:
await jail_service.reload_all(socket_path)
await jail_service.reload_all(socket_path, include_jails=[name])
except Exception as exc: # noqa: BLE001
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
# Verify the jail actually started after the reload. A config error
# (bad regex, missing log file, etc.) may silently prevent fail2ban from
# starting the jail even though the reload command succeeded.
post_reload_names = await _get_active_jail_names(socket_path)
actually_running = name in post_reload_names
if not actually_running:
log.warning(
"jail_activation_unverified",
jail=name,
message="Jail did not appear in running jails after reload.",
)
return JailActivationResponse(
name=name,
active=False,
message=(
f"Jail {name!r} was written to config but did not start after "
"reload — check the jail configuration (filters, log paths, regex)."
),
)
log.info("jail_activated", jail=name)
return JailActivationResponse(
name=name,
@@ -962,7 +982,7 @@ async def deactivate_jail(
)
try:
await jail_service.reload_all(socket_path)
await jail_service.reload_all(socket_path, exclude_jails=[name])
except Exception as exc: # noqa: BLE001
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))

View File

@@ -37,6 +37,12 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
_SOCKET_TIMEOUT: float = 10.0
# Guard against concurrent reload_all calls. Overlapping ``reload --all``
# commands sent to fail2ban's socket produce undefined behaviour and may cause
# jails to be permanently removed from the daemon. Serialising them here
# ensures only one reload stream is in-flight at a time.
_reload_all_lock: asyncio.Lock = asyncio.Lock()
# ---------------------------------------------------------------------------
# Custom exceptions
# ---------------------------------------------------------------------------
@@ -540,7 +546,12 @@ async def reload_jail(socket_path: str, name: str) -> None:
raise JailOperationError(str(exc)) from exc
async def reload_all(socket_path: str) -> None:
async def reload_all(
socket_path: str,
*,
include_jails: list[str] | None = None,
exclude_jails: list[str] | None = None,
) -> None:
"""Reload all fail2ban jails at once.
Fetches the current jail list first so that a ``['start', name]`` entry
@@ -548,8 +559,14 @@ async def reload_all(socket_path: str) -> None:
non-empty stream the end-of-reload phase deletes every jail that received
no configuration commands.
*include_jails* are added to the stream (e.g. a newly activated jail that
is not yet running). *exclude_jails* are removed from the stream (e.g. a
jail that was just deactivated and should not be restarted).
Args:
socket_path: Path to the fail2ban Unix domain socket.
include_jails: Extra jail names to add to the start stream.
exclude_jails: Jail names to remove from the start stream.
Raises:
JailOperationError: If fail2ban reports the operation failed.
@@ -557,17 +574,26 @@ async def reload_all(socket_path: str) -> None:
cannot be reached.
"""
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
try:
# Resolve jail names so we can build the minimal config stream.
status_raw = _ok(await client.send(["status"]))
status_dict = _to_dict(status_raw)
jail_list_raw: str = str(status_dict.get("Jail list", ""))
jail_names = [n.strip() for n in jail_list_raw.split(",") if n.strip()]
stream: list[list[str]] = [["start", n] for n in jail_names]
_ok(await client.send(["reload", "--all", [], stream]))
log.info("all_jails_reloaded")
except ValueError as exc:
raise JailOperationError(str(exc)) from exc
async with _reload_all_lock:
try:
# Resolve jail names so we can build the minimal config stream.
status_raw = _ok(await client.send(["status"]))
status_dict = _to_dict(status_raw)
jail_list_raw: str = str(status_dict.get("Jail list", ""))
jail_names = [n.strip() for n in jail_list_raw.split(",") if n.strip()]
# Merge include/exclude sets so the stream matches the desired state.
names_set: set[str] = set(jail_names)
if include_jails:
names_set.update(include_jails)
if exclude_jails:
names_set -= set(exclude_jails)
stream: list[list[str]] = [["start", n] for n in sorted(names_set)]
_ok(await client.send(["reload", "--all", [], stream]))
log.info("all_jails_reloaded")
except ValueError as exc:
raise JailOperationError(str(exc)) from exc
# ---------------------------------------------------------------------------

View File

@@ -18,7 +18,9 @@ from __future__ import annotations
import asyncio
import contextlib
import errno
import socket
import time
from pickle import HIGHEST_PROTOCOL, dumps, loads
from typing import TYPE_CHECKING, Any
@@ -40,6 +42,24 @@ _PROTO_EMPTY: bytes = b""
_RECV_BUFSIZE_START: int = 1024
_RECV_BUFSIZE_MAX: int = 32768
# OSError errno values that indicate a transient socket condition and may be
# safely retried. ENOENT (socket file missing) is intentionally excluded so
# a missing socket raises immediately without delay.
_RETRYABLE_ERRNOS: frozenset[int] = frozenset(
{errno.EAGAIN, errno.ECONNREFUSED, errno.ENOBUFS}
)
# Retry policy for _send_command_sync.
_RETRY_MAX_ATTEMPTS: int = 3
_RETRY_INITIAL_BACKOFF: float = 0.15 # seconds; doubles on each attempt
# Maximum number of concurrent in-flight socket commands. Operations that
# exceed this cap wait until a slot is available.
_COMMAND_SEMAPHORE_CONCURRENCY: int = 10
# The semaphore is created lazily on the first send() call so it binds to the
# event loop that is actually running (important for test isolation).
_command_semaphore: asyncio.Semaphore | None = None
class Fail2BanConnectionError(Exception):
"""Raised when the fail2ban socket is unreachable or returns an error."""
@@ -70,6 +90,14 @@ def _send_command_sync(
:func:`asyncio.get_event_loop().run_in_executor` so that the event loop
is not blocked.
Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``,
``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with
exponential back-off starting at :data:`_RETRY_INITIAL_BACKOFF` seconds.
All other ``OSError`` variants (including ``ENOENT`` — socket file
missing) and :class:`Fail2BanProtocolError` are raised immediately.
A structured log event ``fail2ban_socket_retry`` is emitted for each
retry attempt.
Args:
socket_path: Path to the fail2ban Unix domain socket.
command: List of command tokens, e.g. ``["status", "sshd"]``.
@@ -79,52 +107,77 @@ def _send_command_sync(
The deserialized Python object returned by fail2ban.
Raises:
Fail2BanConnectionError: If the socket cannot be reached.
Fail2BanConnectionError: If the socket cannot be reached after all
retry attempts, or immediately for non-retryable errors.
Fail2BanProtocolError: If the response cannot be unpickled.
"""
sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
sock.settimeout(timeout)
sock.connect(socket_path)
# Serialise and send the command.
payload: bytes = dumps(
list(map(_coerce_command_token, command)),
HIGHEST_PROTOCOL,
)
sock.sendall(payload)
sock.sendall(_PROTO_END)
# Receive until we see the end marker.
raw: bytes = _PROTO_EMPTY
bufsize: int = _RECV_BUFSIZE_START
while raw.rfind(_PROTO_END, -32) == -1:
chunk: bytes = sock.recv(bufsize)
if not chunk:
raise Fail2BanConnectionError(
"Connection closed unexpectedly by fail2ban",
socket_path,
)
if chunk == _PROTO_END:
break
raw += chunk
if bufsize < _RECV_BUFSIZE_MAX:
bufsize <<= 1
last_oserror: OSError | None = None
for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
try:
return loads(raw)
except Exception as exc:
raise Fail2BanProtocolError(
f"Failed to unpickle fail2ban response: {exc}"
) from exc
except OSError as exc:
raise Fail2BanConnectionError(str(exc), socket_path) from exc
finally:
with contextlib.suppress(OSError):
sock.sendall(_PROTO_CLOSE + _PROTO_END)
with contextlib.suppress(OSError):
sock.shutdown(socket.SHUT_RDWR)
sock.close()
sock.settimeout(timeout)
sock.connect(socket_path)
# Serialise and send the command.
payload: bytes = dumps(
list(map(_coerce_command_token, command)),
HIGHEST_PROTOCOL,
)
sock.sendall(payload)
sock.sendall(_PROTO_END)
# Receive until we see the end marker.
raw: bytes = _PROTO_EMPTY
bufsize: int = _RECV_BUFSIZE_START
while raw.rfind(_PROTO_END, -32) == -1:
chunk: bytes = sock.recv(bufsize)
if not chunk:
raise Fail2BanConnectionError(
"Connection closed unexpectedly by fail2ban",
socket_path,
)
if chunk == _PROTO_END:
break
raw += chunk
if bufsize < _RECV_BUFSIZE_MAX:
bufsize <<= 1
try:
return loads(raw)
except Exception as exc:
raise Fail2BanProtocolError(
f"Failed to unpickle fail2ban response: {exc}"
) from exc
except Fail2BanProtocolError:
# Protocol errors are never transient — raise immediately.
raise
except Fail2BanConnectionError:
# Mid-receive close or empty-chunk error — raise immediately.
raise
except OSError as exc:
is_retryable = exc.errno in _RETRYABLE_ERRNOS
if is_retryable and attempt < _RETRY_MAX_ATTEMPTS:
log.warning(
"fail2ban_socket_retry",
attempt=attempt,
socket_errno=exc.errno,
socket_path=socket_path,
)
last_oserror = exc
time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))
continue
raise Fail2BanConnectionError(str(exc), socket_path) from exc
finally:
with contextlib.suppress(OSError):
sock.sendall(_PROTO_CLOSE + _PROTO_END)
with contextlib.suppress(OSError):
sock.shutdown(socket.SHUT_RDWR)
sock.close()
# Exhausted all retry attempts — surface the last transient error.
raise Fail2BanConnectionError(
str(last_oserror), socket_path
) from last_oserror
def _coerce_command_token(token: Any) -> Any:
@@ -179,6 +232,12 @@ class Fail2BanClient:
async def send(self, command: list[Any]) -> Any:
"""Send a command to fail2ban and return the response.
Acquires the module-level concurrency semaphore before dispatching
so that no more than :data:`_COMMAND_SEMAPHORE_CONCURRENCY` commands
are in-flight at the same time. Commands that exceed the cap are
queued until a slot becomes available. A debug-level log event is
emitted when a command must wait.
The command is serialised as a pickle list, sent to the socket, and
the response is deserialised before being returned.
@@ -193,32 +252,44 @@ class Fail2BanClient:
connection is unexpectedly closed.
Fail2BanProtocolError: If the response cannot be decoded.
"""
log.debug("fail2ban_sending_command", command=command)
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
try:
response: Any = await loop.run_in_executor(
None,
_send_command_sync,
self.socket_path,
command,
self.timeout,
)
except Fail2BanConnectionError:
log.warning(
"fail2ban_connection_error",
socket_path=self.socket_path,
global _command_semaphore
if _command_semaphore is None:
_command_semaphore = asyncio.Semaphore(_COMMAND_SEMAPHORE_CONCURRENCY)
if _command_semaphore.locked():
log.debug(
"fail2ban_command_waiting_semaphore",
command=command,
concurrency_limit=_COMMAND_SEMAPHORE_CONCURRENCY,
)
raise
except Fail2BanProtocolError:
log.error(
"fail2ban_protocol_error",
socket_path=self.socket_path,
command=command,
)
raise
log.debug("fail2ban_received_response", command=command)
return response
async with _command_semaphore:
log.debug("fail2ban_sending_command", command=command)
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
try:
response: Any = await loop.run_in_executor(
None,
_send_command_sync,
self.socket_path,
command,
self.timeout,
)
except Fail2BanConnectionError:
log.warning(
"fail2ban_connection_error",
socket_path=self.socket_path,
command=command,
)
raise
except Fail2BanProtocolError:
log.error(
"fail2ban_protocol_error",
socket_path=self.socket_path,
command=command,
)
raise
log.debug("fail2ban_received_response", command=command)
return response
async def ping(self) -> bool:
"""Return ``True`` if the fail2ban daemon is reachable.