Refactor fail2ban client to use vendored adapter

This commit is contained in:
2026-04-12 19:25:56 +02:00
parent 21b38365c4
commit e271207795
3 changed files with 162 additions and 148 deletions

View File

@@ -19,11 +19,11 @@ from __future__ import annotations
import asyncio
import contextlib
import errno
import socket
import sys
import time
from collections.abc import Mapping, Sequence, Set
from pickle import HIGHEST_PROTOCOL, dumps, loads
from typing import TYPE_CHECKING
from pathlib import Path
from typing import TYPE_CHECKING, Protocol
import structlog
@@ -69,12 +69,41 @@ if TYPE_CHECKING:
log: structlog.stdlib.BoundLogger = structlog.get_logger()
# fail2ban protocol constants — inline to avoid a hard import dependency
# at module load time (the fail2ban-master path may not be on sys.path yet
# in some test environments).
_PROTO_END: bytes = b"<F2B_END_COMMAND>"
_PROTO_CLOSE: bytes = b"<F2B_CLOSE_COMMAND>"
_PROTO_EMPTY: bytes = b""
# Attempt to reuse the vendored fail2ban package embedded in the repository.
# If it is not on sys.path yet, load it from ``../fail2ban-master``.
def _load_vendored_fail2ban_client() -> type[object]:
"""Import the vendored ``fail2ban.client.csocket.CSocket`` implementation."""
try:
from fail2ban.client.csocket import CSocket
return CSocket
except ImportError:
vendor_root = Path(__file__).resolve().parents[4] / "fail2ban-master"
if not vendor_root.is_dir():
raise
sys.path.insert(0, str(vendor_root))
try:
from fail2ban.client.csocket import CSocket
return CSocket
finally:
if sys.path and sys.path[0] == str(vendor_root):
sys.path.pop(0)
def _load_vendored_fail2ban_constants() -> tuple[bytes, bytes, bytes]:
"""Load fail2ban protocol constants from the vendored package if possible."""
try:
from fail2ban.protocol import CSPROTO
return CSPROTO.END, CSPROTO.CLOSE, CSPROTO.EMPTY
except ImportError:
return b"<F2B_END_COMMAND>", b"<F2B_CLOSE_COMMAND>", b""
_PROTO_END, _PROTO_CLOSE, _PROTO_EMPTY = _load_vendored_fail2ban_constants()
# Default receive buffer size (doubles on each iteration up to max).
_RECV_BUFSIZE_START: int = 1024
@@ -147,46 +176,14 @@ def _send_command_sync(
"""
last_oserror: OSError | None = None
for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
client = None
try:
sock.settimeout(timeout)
sock.connect(socket_path)
# Serialise and send the command.
payload: bytes = dumps(
list(map(_coerce_command_token, command)),
HIGHEST_PROTOCOL,
)
sock.sendall(payload)
sock.sendall(_PROTO_END)
# Receive until we see the end marker.
raw: bytes = _PROTO_EMPTY
bufsize: int = _RECV_BUFSIZE_START
while raw.rfind(_PROTO_END, -32) == -1:
chunk: bytes = sock.recv(bufsize)
if not chunk:
raise Fail2BanConnectionError(
"Connection closed unexpectedly by fail2ban",
socket_path,
)
if chunk == _PROTO_END:
break
raw += chunk
if bufsize < _RECV_BUFSIZE_MAX:
bufsize <<= 1
try:
return loads(raw)
except Exception as exc:
raise Fail2BanProtocolError(
f"Failed to unpickle fail2ban response: {exc}"
) from exc
client_cls = _load_vendored_fail2ban_client()
client = client_cls(socket_path, timeout=timeout)
return client.send(command)
except Fail2BanProtocolError:
# Protocol errors are never transient — raise immediately.
raise
except Fail2BanConnectionError:
# Mid-receive close or empty-chunk error — raise immediately.
raise
except OSError as exc:
is_retryable = exc.errno in _RETRYABLE_ERRNOS
@@ -201,19 +198,41 @@ def _send_command_sync(
time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))
continue
raise Fail2BanConnectionError(str(exc), socket_path) from exc
except Exception as exc:
raise Fail2BanProtocolError(
f"Failed to parse fail2ban response: {exc}"
) from exc
finally:
with contextlib.suppress(OSError):
sock.sendall(_PROTO_CLOSE + _PROTO_END)
with contextlib.suppress(OSError):
sock.shutdown(socket.SHUT_RDWR)
sock.close()
if client is not None:
with contextlib.suppress(Exception):
client.close()
# Exhausted all retry attempts — surface the last transient error.
raise Fail2BanConnectionError(
str(last_oserror), socket_path
) from last_oserror
class Fail2BanAdapter(Protocol):
"""Protocol for a fail2ban socket adapter."""
async def send(self, command: Fail2BanCommand) -> object:
"""Send a command to fail2ban and return the raw response."""
async def ping(self) -> bool:
"""Return ``True`` if fail2ban is reachable."""
async def __aenter__(self) -> Fail2BanAdapter:
"""Enter the async context manager."""
async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
"""Exit the async context manager."""
def _coerce_command_token(token: object) -> Fail2BanToken:
"""Coerce a command token to a type that fail2ban understands.