fix: retry, semaphore, reload lock, activation verify, bans_by_jail diagnostics
Stage 1.1-1.3: reload_all include/exclude_jails params already implemented; added keyword-arg assertions in router and service tests. Stage 2.1/6.1: _send_command_sync retry loop (3 attempts, 150ms exp backoff) retrying on EAGAIN/ECONNREFUSED/ENOBUFS; immediate raise on all other errors. Stage 2.2: asyncio.Lock at module level in jail_service.reload_all to serialize concurrent reload--all commands. Stage 3.1: activate_jail re-queries _get_active_jail_names after reload; returns active=False with descriptive message if jail did not start. Stage 4.1/6.2: asyncio.Semaphore (max 10) in Fail2BanClient.send, lazy- initialized; logs fail2ban_command_waiting_semaphore at debug when waiting. Stage 5.1/5.2: unit tests asserting reload_all is called with include_jails and exclude_jails; activation verification happy/sad path tests. Stage 6.3: TestSendCommandSyncRetry (5 cases) + TestFail2BanClientSemaphore concurrency test. Stage 7.1-7.3: _since_unix uses time.time(); bans_by_jail debug logging with since_iso; diagnostic warning when total==0 despite table rows; unit test verifying the warning fires for stale data.
This commit is contained in:
@@ -18,7 +18,9 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import errno
|
||||
import socket
|
||||
import time
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
@@ -40,6 +42,24 @@ _PROTO_EMPTY: bytes = b""
|
||||
_RECV_BUFSIZE_START: int = 1024
|
||||
_RECV_BUFSIZE_MAX: int = 32768
|
||||
|
||||
# OSError errno values that indicate a transient socket condition and may be
|
||||
# safely retried. ENOENT (socket file missing) is intentionally excluded so
|
||||
# a missing socket raises immediately without delay.
|
||||
_RETRYABLE_ERRNOS: frozenset[int] = frozenset(
|
||||
{errno.EAGAIN, errno.ECONNREFUSED, errno.ENOBUFS}
|
||||
)
|
||||
|
||||
# Retry policy for _send_command_sync.
|
||||
_RETRY_MAX_ATTEMPTS: int = 3
|
||||
_RETRY_INITIAL_BACKOFF: float = 0.15 # seconds; doubles on each attempt
|
||||
|
||||
# Maximum number of concurrent in-flight socket commands. Operations that
|
||||
# exceed this cap wait until a slot is available.
|
||||
_COMMAND_SEMAPHORE_CONCURRENCY: int = 10
|
||||
# The semaphore is created lazily on the first send() call so it binds to the
|
||||
# event loop that is actually running (important for test isolation).
|
||||
_command_semaphore: asyncio.Semaphore | None = None
|
||||
|
||||
|
||||
class Fail2BanConnectionError(Exception):
|
||||
"""Raised when the fail2ban socket is unreachable or returns an error."""
|
||||
@@ -70,6 +90,14 @@ def _send_command_sync(
|
||||
:func:`asyncio.get_event_loop().run_in_executor` so that the event loop
|
||||
is not blocked.
|
||||
|
||||
Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``,
|
||||
``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with
|
||||
exponential back-off starting at :data:`_RETRY_INITIAL_BACKOFF` seconds.
|
||||
All other ``OSError`` variants (including ``ENOENT`` — socket file
|
||||
missing) and :class:`Fail2BanProtocolError` are raised immediately.
|
||||
A structured log event ``fail2ban_socket_retry`` is emitted for each
|
||||
retry attempt.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
command: List of command tokens, e.g. ``["status", "sshd"]``.
|
||||
@@ -79,52 +107,77 @@ def _send_command_sync(
|
||||
The deserialized Python object returned by fail2ban.
|
||||
|
||||
Raises:
|
||||
Fail2BanConnectionError: If the socket cannot be reached.
|
||||
Fail2BanConnectionError: If the socket cannot be reached after all
|
||||
retry attempts, or immediately for non-retryable errors.
|
||||
Fail2BanProtocolError: If the response cannot be unpickled.
|
||||
"""
|
||||
sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
sock.settimeout(timeout)
|
||||
sock.connect(socket_path)
|
||||
|
||||
# Serialise and send the command.
|
||||
payload: bytes = dumps(
|
||||
list(map(_coerce_command_token, command)),
|
||||
HIGHEST_PROTOCOL,
|
||||
)
|
||||
sock.sendall(payload)
|
||||
sock.sendall(_PROTO_END)
|
||||
|
||||
# Receive until we see the end marker.
|
||||
raw: bytes = _PROTO_EMPTY
|
||||
bufsize: int = _RECV_BUFSIZE_START
|
||||
while raw.rfind(_PROTO_END, -32) == -1:
|
||||
chunk: bytes = sock.recv(bufsize)
|
||||
if not chunk:
|
||||
raise Fail2BanConnectionError(
|
||||
"Connection closed unexpectedly by fail2ban",
|
||||
socket_path,
|
||||
)
|
||||
if chunk == _PROTO_END:
|
||||
break
|
||||
raw += chunk
|
||||
if bufsize < _RECV_BUFSIZE_MAX:
|
||||
bufsize <<= 1
|
||||
|
||||
last_oserror: OSError | None = None
|
||||
for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
|
||||
sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
try:
|
||||
return loads(raw)
|
||||
except Exception as exc:
|
||||
raise Fail2BanProtocolError(
|
||||
f"Failed to unpickle fail2ban response: {exc}"
|
||||
) from exc
|
||||
except OSError as exc:
|
||||
raise Fail2BanConnectionError(str(exc), socket_path) from exc
|
||||
finally:
|
||||
with contextlib.suppress(OSError):
|
||||
sock.sendall(_PROTO_CLOSE + _PROTO_END)
|
||||
with contextlib.suppress(OSError):
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
sock.settimeout(timeout)
|
||||
sock.connect(socket_path)
|
||||
|
||||
# Serialise and send the command.
|
||||
payload: bytes = dumps(
|
||||
list(map(_coerce_command_token, command)),
|
||||
HIGHEST_PROTOCOL,
|
||||
)
|
||||
sock.sendall(payload)
|
||||
sock.sendall(_PROTO_END)
|
||||
|
||||
# Receive until we see the end marker.
|
||||
raw: bytes = _PROTO_EMPTY
|
||||
bufsize: int = _RECV_BUFSIZE_START
|
||||
while raw.rfind(_PROTO_END, -32) == -1:
|
||||
chunk: bytes = sock.recv(bufsize)
|
||||
if not chunk:
|
||||
raise Fail2BanConnectionError(
|
||||
"Connection closed unexpectedly by fail2ban",
|
||||
socket_path,
|
||||
)
|
||||
if chunk == _PROTO_END:
|
||||
break
|
||||
raw += chunk
|
||||
if bufsize < _RECV_BUFSIZE_MAX:
|
||||
bufsize <<= 1
|
||||
|
||||
try:
|
||||
return loads(raw)
|
||||
except Exception as exc:
|
||||
raise Fail2BanProtocolError(
|
||||
f"Failed to unpickle fail2ban response: {exc}"
|
||||
) from exc
|
||||
except Fail2BanProtocolError:
|
||||
# Protocol errors are never transient — raise immediately.
|
||||
raise
|
||||
except Fail2BanConnectionError:
|
||||
# Mid-receive close or empty-chunk error — raise immediately.
|
||||
raise
|
||||
except OSError as exc:
|
||||
is_retryable = exc.errno in _RETRYABLE_ERRNOS
|
||||
if is_retryable and attempt < _RETRY_MAX_ATTEMPTS:
|
||||
log.warning(
|
||||
"fail2ban_socket_retry",
|
||||
attempt=attempt,
|
||||
socket_errno=exc.errno,
|
||||
socket_path=socket_path,
|
||||
)
|
||||
last_oserror = exc
|
||||
time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))
|
||||
continue
|
||||
raise Fail2BanConnectionError(str(exc), socket_path) from exc
|
||||
finally:
|
||||
with contextlib.suppress(OSError):
|
||||
sock.sendall(_PROTO_CLOSE + _PROTO_END)
|
||||
with contextlib.suppress(OSError):
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
|
||||
# Exhausted all retry attempts — surface the last transient error.
|
||||
raise Fail2BanConnectionError(
|
||||
str(last_oserror), socket_path
|
||||
) from last_oserror
|
||||
|
||||
|
||||
def _coerce_command_token(token: Any) -> Any:
|
||||
@@ -179,6 +232,12 @@ class Fail2BanClient:
|
||||
async def send(self, command: list[Any]) -> Any:
|
||||
"""Send a command to fail2ban and return the response.
|
||||
|
||||
Acquires the module-level concurrency semaphore before dispatching
|
||||
so that no more than :data:`_COMMAND_SEMAPHORE_CONCURRENCY` commands
|
||||
are in-flight at the same time. Commands that exceed the cap are
|
||||
queued until a slot becomes available. A debug-level log event is
|
||||
emitted when a command must wait.
|
||||
|
||||
The command is serialised as a pickle list, sent to the socket, and
|
||||
the response is deserialised before being returned.
|
||||
|
||||
@@ -193,32 +252,44 @@ class Fail2BanClient:
|
||||
connection is unexpectedly closed.
|
||||
Fail2BanProtocolError: If the response cannot be decoded.
|
||||
"""
|
||||
log.debug("fail2ban_sending_command", command=command)
|
||||
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
try:
|
||||
response: Any = await loop.run_in_executor(
|
||||
None,
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
except Fail2BanConnectionError:
|
||||
log.warning(
|
||||
"fail2ban_connection_error",
|
||||
socket_path=self.socket_path,
|
||||
global _command_semaphore
|
||||
if _command_semaphore is None:
|
||||
_command_semaphore = asyncio.Semaphore(_COMMAND_SEMAPHORE_CONCURRENCY)
|
||||
|
||||
if _command_semaphore.locked():
|
||||
log.debug(
|
||||
"fail2ban_command_waiting_semaphore",
|
||||
command=command,
|
||||
concurrency_limit=_COMMAND_SEMAPHORE_CONCURRENCY,
|
||||
)
|
||||
raise
|
||||
except Fail2BanProtocolError:
|
||||
log.error(
|
||||
"fail2ban_protocol_error",
|
||||
socket_path=self.socket_path,
|
||||
command=command,
|
||||
)
|
||||
raise
|
||||
log.debug("fail2ban_received_response", command=command)
|
||||
return response
|
||||
|
||||
async with _command_semaphore:
|
||||
log.debug("fail2ban_sending_command", command=command)
|
||||
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
try:
|
||||
response: Any = await loop.run_in_executor(
|
||||
None,
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
except Fail2BanConnectionError:
|
||||
log.warning(
|
||||
"fail2ban_connection_error",
|
||||
socket_path=self.socket_path,
|
||||
command=command,
|
||||
)
|
||||
raise
|
||||
except Fail2BanProtocolError:
|
||||
log.error(
|
||||
"fail2ban_protocol_error",
|
||||
socket_path=self.socket_path,
|
||||
command=command,
|
||||
)
|
||||
raise
|
||||
log.debug("fail2ban_received_response", command=command)
|
||||
return response
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""Return ``True`` if the fail2ban daemon is reachable.
|
||||
|
||||
Reference in New Issue
Block a user