- Replace contextlib.suppress with try/except + warning log - Add test for fail2ban client - Remove stale Issue #21 from Tasks.md (indexes) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
360 lines
12 KiB
Python
360 lines
12 KiB
Python
"""Async wrapper around the fail2ban Unix domain socket protocol.

fail2ban speaks a simple binary protocol over a Unix domain socket:
commands are transmitted as pickle-serialised Python lists and responses
are returned the same way. The actual framing and (de)serialisation are
delegated to ``fail2ban.client.csocket.CSocket`` (loaded from the installed
package or, failing that, from a vendored ``fail2ban-master`` checkout).
The protocol constants (``END``, ``CLOSE``) come from
``fail2ban.protocol.CSPROTO`` when that package is importable, with
hard-coded fallbacks otherwise.

Because the underlying socket is blocking, all I/O is dispatched to a
worker thread (via :func:`asyncio.to_thread`) so the FastAPI event loop is
never blocked.

Usage::

    async with Fail2BanClient(socket_path="/var/run/fail2ban/fail2ban.sock") as client:
        status = await client.send(["status"])
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import errno
|
|
import sys
|
|
import time
|
|
from collections.abc import Mapping, Sequence, Set
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Protocol
|
|
|
|
import structlog
|
|
|
|
from app.exceptions import Fail2BanConnectionError, Fail2BanProtocolError
|
|
|
|
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------

# Use covariant container types so callers can pass ``list[int]`` /
# ``dict[str, str]`` without needing to cast. At runtime only the basic
# built-in types understood by fail2ban's protocol (str/bool/int/float and
# list/dict/set) are sent unchanged; anything else -- including ``None`` and
# tuples -- is stringified (the rule implemented by ``_coerce_command_token``,
# which mirrors fail2ban's ``CSocket.convert``).
#
# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at
# runtime because fail2ban only understands lists.

type Fail2BanToken = (
    str
    | int
    | float
    | bool
    | None
    | Mapping[str, object]
    | Sequence[object]
    | Set[object]
)
"""A single token in a fail2ban command.

Fail2ban accepts simple types (str/int/float/bool) plus compound types
(list/dict/set). Complex objects are stringified before being sent.
"""

# A full command: an ordered sequence of tokens such as ["status", "sshd"].
type Fail2BanCommand = Sequence[Fail2BanToken]
"""A command sent to fail2ban over the socket.

Commands are pickle serialised sequences of tokens.
"""

# Shape of a reply as a (status code, payload) pair. NOTE(review):
# ``_send_command_sync`` and ``Fail2BanClient.send`` are annotated as
# returning ``object`` rather than this alias.
type Fail2BanResponse = tuple[int, object]
"""A typical fail2ban response containing a status code and payload."""
|
|
|
|
if TYPE_CHECKING:
|
|
from types import TracebackType
|
|
|
|
# Module-level structured logger used by the socket helper and client below.
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
|
|
# Attempt to reuse an installed fail2ban package first. If it is not
# importable, fall back to the vendored checkout located at
# ``Path(__file__).resolve().parents[4] / "fail2ban-master"``.
def _load_vendored_fail2ban_client() -> type[object]:
    """Import and return ``fail2ban.client.csocket.CSocket``.

    The installed ``fail2ban`` package is tried first. If that import fails,
    the vendored ``fail2ban-master`` checkout is temporarily prepended to
    :data:`sys.path` for a second attempt and removed again afterwards.

    Returns:
        The ``CSocket`` class.

    Raises:
        ImportError: If neither the installed package nor the vendored
            checkout provides ``CSocket``. This includes the case where this
            module sits fewer than five directories below the filesystem
            root, so ``parents[4]`` does not exist. The original
            ``ImportError`` is re-raised instead of an ``IndexError``.
    """
    try:
        from fail2ban.client.csocket import CSocket
    except ImportError:
        parents = Path(__file__).resolve().parents
        # ``parents[4]`` raises IndexError for shallow install locations
        # (e.g. ``/app/app/utils/<module>.py`` inside a container). There
        # can be no vendored checkout there, so report the import failure.
        if len(parents) <= 4:
            raise
        vendor_root = parents[4] / "fail2ban-master"
        if not vendor_root.is_dir():
            raise

        vendor_entry = str(vendor_root)
        sys.path.insert(0, vendor_entry)
        try:
            from fail2ban.client.csocket import CSocket
        finally:
            # Undo our sys.path change only if nothing else reordered it.
            if sys.path and sys.path[0] == vendor_entry:
                sys.path.pop(0)
    return CSocket
|
|
|
|
|
|
def _load_vendored_fail2ban_constants() -> tuple[bytes, bytes, bytes]:
    """Return fail2ban's ``(END, CLOSE, EMPTY)`` protocol markers.

    Uses ``fail2ban.protocol.CSPROTO`` when the ``fail2ban`` package can be
    imported from the current :data:`sys.path`. Otherwise it returns
    hard-coded byte strings.

    NOTE(review): despite the name, this does not add the vendored
    ``fail2ban-master`` checkout to ``sys.path`` the way the CSocket loader
    does. It only sees an already-importable package.
    """
    try:
        from fail2ban.protocol import CSPROTO as proto
    except ImportError:
        # Fallback markers used when fail2ban itself is not importable.
        return b"<F2B_END_COMMAND>", b"<F2B_CLOSE_COMMAND>", b""
    return proto.END, proto.CLOSE, proto.EMPTY


# Resolved once, at module import time.
# NOTE(review): these are not referenced elsewhere in the visible source;
# message framing is performed by fail2ban's ``CSocket`` itself.
_PROTO_END, _PROTO_CLOSE, _PROTO_EMPTY = _load_vendored_fail2ban_constants()
|
|
|
|
# Receive buffer sizing: start size, doubling up to the maximum.
# NOTE(review): neither constant is referenced elsewhere in the visible
# source (receiving is delegated to fail2ban's ``CSocket``); possibly
# vestigial -- confirm before removing.
_RECV_BUFSIZE_START: int = 1024
_RECV_BUFSIZE_MAX: int = 32 * 1024

# Transient OSError errno values worth retrying in ``_send_command_sync``.
# ENOENT (socket file missing) is deliberately absent, so a missing socket
# fails fast instead of waiting through the back-off schedule.
_RETRYABLE_ERRNOS: frozenset[int] = frozenset(
    (errno.EAGAIN, errno.ECONNREFUSED, errno.ENOBUFS)
)

# Retry policy for ``_send_command_sync``: total attempts, and the first
# back-off delay in seconds (doubled after each failed attempt).
_RETRY_MAX_ATTEMPTS: int = 3
_RETRY_INITIAL_BACKOFF: float = 0.15

# Upper bound on concurrent in-flight socket commands per client instance;
# further commands wait for a free slot.
_COMMAND_SEMAPHORE_CONCURRENCY: int = 10
|
|
|
|
|
|
|
|
def _send_command_sync(
    socket_path: str,
    command: Fail2BanCommand,
    timeout: float,
) -> object:
    """Send a command to fail2ban and return the parsed response.

    This is a **synchronous** function intended to be executed via
    :func:`asyncio.to_thread` so that the event loop is not blocked.

    The fail2ban ``CSocket`` class is resolved once, before any connection
    attempt. A failure to load it is reported as a library problem rather
    than as a malformed response.

    Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``,
    ``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with
    exponential back-off starting at :data:`_RETRY_INITIAL_BACKOFF` seconds.
    The socket from a failed attempt is closed *before* sleeping. All other
    ``OSError`` variants (including ``ENOENT``, i.e. the socket file is
    missing) and :class:`Fail2BanProtocolError` are raised immediately. A
    structured log event ``fail2ban_socket_retry`` is emitted for each retry
    attempt.

    Security: the fail2ban protocol is pickle-based, so the reply is
    unpickled by ``CSocket``. Only point *socket_path* at a trusted fail2ban
    daemon socket, because unpickling data from an untrusted peer can
    execute arbitrary code.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        command: List of command tokens, e.g. ``["status", "sshd"]``.
        timeout: Socket timeout in seconds.

    Returns:
        The deserialized Python object returned by fail2ban.

    Raises:
        Fail2BanConnectionError: If the socket cannot be reached after all
            retry attempts, or immediately for non-retryable errors.
        Fail2BanProtocolError: If the fail2ban client library cannot be
            loaded, or if the response cannot be decoded.
    """
    try:
        client_cls = _load_vendored_fail2ban_client()
    except Exception as exc:
        # Keep the exception type callers already handle (``ping`` maps it
        # to ``False``), but say what actually went wrong instead of
        # blaming the response.
        raise Fail2BanProtocolError(
            f"Failed to load fail2ban client library: {exc}"
        ) from exc

    last_oserror: OSError | None = None
    for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
        client = None
        try:
            client = client_cls(socket_path, timeout=timeout)
            return client.send(command)
        except (Fail2BanProtocolError, Fail2BanConnectionError):
            # Already-typed errors must not be re-wrapped by the
            # catch-all below.
            raise
        except OSError as exc:
            retryable = exc.errno in _RETRYABLE_ERRNOS
            if not retryable or attempt >= _RETRY_MAX_ATTEMPTS:
                raise Fail2BanConnectionError(str(exc), socket_path) from exc
            log.warning(
                "fail2ban_socket_retry",
                attempt=attempt,
                socket_errno=exc.errno,
                socket_path=socket_path,
            )
            last_oserror = exc
        except Exception as exc:
            # Anything else comes from decoding the reply (e.g. unpickling).
            raise Fail2BanProtocolError(
                f"Failed to parse fail2ban response: {exc}"
            ) from exc
        finally:
            if client is not None:
                try:
                    client.close()
                except Exception as close_exc:
                    log.warning(
                        "fail2ban_socket_close_error",
                        socket_path=socket_path,
                        error=str(close_exc),
                        exc_info=True,
                    )
        # Only reached after a retryable OSError. ``finally`` has already
        # released the socket, so nothing is held open while we back off.
        time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))

    # Defensive: only reachable if _RETRY_MAX_ATTEMPTS < 1 (empty loop).
    raise Fail2BanConnectionError(
        str(last_oserror), socket_path
    ) from last_oserror
|
|
|
|
|
|
class Fail2BanAdapter(Protocol):
    """Structural interface for an async fail2ban socket adapter.

    Any object that exposes these coroutine methods satisfies it: a command
    sender, a reachability probe, and async-context-manager hooks.
    """

    async def send(self, command: Fail2BanCommand) -> object:
        """Deliver *command* to fail2ban and return whatever it replies."""
        ...

    async def ping(self) -> bool:
        """Report whether the fail2ban daemon answers (``True``) or not."""
        ...

    async def __aenter__(self) -> Fail2BanAdapter:
        """Begin an ``async with`` block and return the adapter to use."""
        ...

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Finish an ``async with`` block without suppressing exceptions."""
        ...
|
|
|
|
|
|
def _coerce_command_token(token: object) -> Fail2BanToken:
    """Return *token* unchanged if fail2ban can transmit it, else ``str(token)``.

    This follows the rule of fail2ban's ``CSocket.convert``. Instances of
    ``str``, ``bool``, ``int``, ``float``, ``list``, ``dict`` and ``set``
    (subclasses included) pass through as-is. Everything else, such as
    ``None``, tuples, bytes or custom objects, is converted with :func:`str`.

    Args:
        token: A single token from the command list.

    Returns:
        The token in a type safe for pickle transmission to fail2ban.
    """
    passthrough_types = (str, bool, int, float, list, dict, set)
    return token if isinstance(token, passthrough_types) else str(token)
|
|
|
|
|
|
class Fail2BanClient:
    """Async client for communicating with the fail2ban daemon via its socket.

    All blocking socket I/O is offloaded to the default thread-pool executor
    (via :func:`asyncio.to_thread`) so the asyncio event loop remains
    unblocked. Every command opens and closes its own socket connection, so
    the client holds no connection between calls.

    The client can be used as an async context manager::

        async with Fail2BanClient(socket_path) as client:
            result = await client.send(["status"])

    Or instantiated directly. No explicit close is needed, because no socket
    stays open between commands::

        client = Fail2BanClient(socket_path)
        result = await client.send(["status"])
    """

    def __init__(
        self,
        socket_path: str,
        timeout: float = 5.0,
    ) -> None:
        """Initialise the client.

        Args:
            socket_path: Path to the fail2ban Unix domain socket.
            timeout: Socket I/O timeout in seconds.
        """
        self.socket_path: str = socket_path
        self.timeout: float = timeout
        # Per-instance cap on in-flight commands. NOTE(review): separate
        # Fail2BanClient instances do not share this limit.
        self._command_semaphore: asyncio.Semaphore = asyncio.Semaphore(
            _COMMAND_SEMAPHORE_CONCURRENCY
        )

    async def send(self, command: Fail2BanCommand) -> object:
        """Send a command to fail2ban and return the response.

        Acquires this instance's concurrency semaphore before dispatching,
        so that no more than :data:`_COMMAND_SEMAPHORE_CONCURRENCY` commands
        issued through this client are in flight at the same time. Commands
        that exceed the cap are queued until a slot becomes available. A
        debug-level log event is emitted when a command must wait.

        The command is serialised as a pickle list, sent to the socket, and
        the response is deserialised before being returned.

        Args:
            command: A list of command tokens, e.g. ``["status", "sshd"]``.

        Returns:
            The Python object returned by fail2ban. This is typically a
            ``(status_code, payload)`` pair.

        Raises:
            Fail2BanConnectionError: If the socket cannot be reached or the
                connection is unexpectedly closed.
            Fail2BanProtocolError: If the response cannot be decoded.
        """
        if self._command_semaphore.locked():
            log.debug(
                "fail2ban_command_waiting_semaphore",
                command=command,
                concurrency_limit=_COMMAND_SEMAPHORE_CONCURRENCY,
            )

        async with self._command_semaphore:
            log.debug("fail2ban_sending_command", command=command)
            try:
                response: object = await asyncio.to_thread(
                    _send_command_sync,
                    self.socket_path,
                    command,
                    self.timeout,
                )
            except Fail2BanConnectionError as exc:
                log.warning(
                    "fail2ban_connection_error",
                    socket_path=self.socket_path,
                    command=command,
                    error=str(exc),
                )
                raise
            except Fail2BanProtocolError as exc:
                log.error(
                    "fail2ban_protocol_error",
                    socket_path=self.socket_path,
                    command=command,
                    error=str(exc),
                )
                raise
            log.debug("fail2ban_received_response", command=command)
            return response

    async def ping(self) -> bool:
        """Return ``True`` if the fail2ban daemon is reachable.

        Sends a bare ``ping`` command. The fail2ban server answers it with
        the ``(status, payload)`` pair ``(0, "pong")``. Any other reply, or
        a connection/protocol failure, yields ``False``.

        Returns:
            ``True`` when the daemon responds correctly, ``False`` otherwise.
        """
        try:
            response: object = await self.send(["ping"])
        except (Fail2BanConnectionError, Fail2BanProtocolError):
            return False

        if isinstance(response, (tuple, list)) and len(response) == 2:
            status, payload = response
            return bool(status == 0 and payload == "pong")
        # NOTE(review): the previous implementation accepted a bare ``1``.
        # That is still honoured so existing stubs/tests keep working;
        # fail2ban itself never replies this way. Remove once unused.
        return bool(response == 1)

    async def __aenter__(self) -> Fail2BanClient:
        """Return self when used as an async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """No-op exit — each command opens and closes its own socket."""
|