Refactor fail2ban client to use vendored adapter
This commit is contained in:
@@ -19,11 +19,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import errno
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence, Set
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Protocol
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -69,12 +69,41 @@ if TYPE_CHECKING:
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# fail2ban protocol constants — inline to avoid a hard import dependency
|
||||
# at module load time (the fail2ban-master path may not be on sys.path yet
|
||||
# in some test environments).
|
||||
_PROTO_END: bytes = b"<F2B_END_COMMAND>"
|
||||
_PROTO_CLOSE: bytes = b"<F2B_CLOSE_COMMAND>"
|
||||
_PROTO_EMPTY: bytes = b""
|
||||
# Attempt to reuse the vendored fail2ban package embedded in the repository.
|
||||
# If it is not on sys.path yet, load it from ``../fail2ban-master``.
|
||||
|
||||
def _load_vendored_fail2ban_client() -> type[object]:
|
||||
"""Import the vendored ``fail2ban.client.csocket.CSocket`` implementation."""
|
||||
try:
|
||||
from fail2ban.client.csocket import CSocket
|
||||
|
||||
return CSocket
|
||||
except ImportError:
|
||||
vendor_root = Path(__file__).resolve().parents[4] / "fail2ban-master"
|
||||
if not vendor_root.is_dir():
|
||||
raise
|
||||
|
||||
sys.path.insert(0, str(vendor_root))
|
||||
try:
|
||||
from fail2ban.client.csocket import CSocket
|
||||
|
||||
return CSocket
|
||||
finally:
|
||||
if sys.path and sys.path[0] == str(vendor_root):
|
||||
sys.path.pop(0)
|
||||
|
||||
|
||||
def _load_vendored_fail2ban_constants() -> tuple[bytes, bytes, bytes]:
|
||||
"""Load fail2ban protocol constants from the vendored package if possible."""
|
||||
try:
|
||||
from fail2ban.protocol import CSPROTO
|
||||
|
||||
return CSPROTO.END, CSPROTO.CLOSE, CSPROTO.EMPTY
|
||||
except ImportError:
|
||||
return b"<F2B_END_COMMAND>", b"<F2B_CLOSE_COMMAND>", b""
|
||||
|
||||
|
||||
_PROTO_END, _PROTO_CLOSE, _PROTO_EMPTY = _load_vendored_fail2ban_constants()
|
||||
|
||||
# Default receive buffer size (doubles on each iteration up to max).
|
||||
_RECV_BUFSIZE_START: int = 1024
|
||||
@@ -147,46 +176,14 @@ def _send_command_sync(
|
||||
"""
|
||||
last_oserror: OSError | None = None
|
||||
for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
|
||||
sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
client = None
|
||||
try:
|
||||
sock.settimeout(timeout)
|
||||
sock.connect(socket_path)
|
||||
|
||||
# Serialise and send the command.
|
||||
payload: bytes = dumps(
|
||||
list(map(_coerce_command_token, command)),
|
||||
HIGHEST_PROTOCOL,
|
||||
)
|
||||
sock.sendall(payload)
|
||||
sock.sendall(_PROTO_END)
|
||||
|
||||
# Receive until we see the end marker.
|
||||
raw: bytes = _PROTO_EMPTY
|
||||
bufsize: int = _RECV_BUFSIZE_START
|
||||
while raw.rfind(_PROTO_END, -32) == -1:
|
||||
chunk: bytes = sock.recv(bufsize)
|
||||
if not chunk:
|
||||
raise Fail2BanConnectionError(
|
||||
"Connection closed unexpectedly by fail2ban",
|
||||
socket_path,
|
||||
)
|
||||
if chunk == _PROTO_END:
|
||||
break
|
||||
raw += chunk
|
||||
if bufsize < _RECV_BUFSIZE_MAX:
|
||||
bufsize <<= 1
|
||||
|
||||
try:
|
||||
return loads(raw)
|
||||
except Exception as exc:
|
||||
raise Fail2BanProtocolError(
|
||||
f"Failed to unpickle fail2ban response: {exc}"
|
||||
) from exc
|
||||
client_cls = _load_vendored_fail2ban_client()
|
||||
client = client_cls(socket_path, timeout=timeout)
|
||||
return client.send(command)
|
||||
except Fail2BanProtocolError:
|
||||
# Protocol errors are never transient — raise immediately.
|
||||
raise
|
||||
except Fail2BanConnectionError:
|
||||
# Mid-receive close or empty-chunk error — raise immediately.
|
||||
raise
|
||||
except OSError as exc:
|
||||
is_retryable = exc.errno in _RETRYABLE_ERRNOS
|
||||
@@ -201,19 +198,41 @@ def _send_command_sync(
|
||||
time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))
|
||||
continue
|
||||
raise Fail2BanConnectionError(str(exc), socket_path) from exc
|
||||
except Exception as exc:
|
||||
raise Fail2BanProtocolError(
|
||||
f"Failed to parse fail2ban response: {exc}"
|
||||
) from exc
|
||||
finally:
|
||||
with contextlib.suppress(OSError):
|
||||
sock.sendall(_PROTO_CLOSE + _PROTO_END)
|
||||
with contextlib.suppress(OSError):
|
||||
sock.shutdown(socket.SHUT_RDWR)
|
||||
sock.close()
|
||||
if client is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
client.close()
|
||||
|
||||
# Exhausted all retry attempts — surface the last transient error.
|
||||
raise Fail2BanConnectionError(
|
||||
str(last_oserror), socket_path
|
||||
) from last_oserror
|
||||
|
||||
|
||||
class Fail2BanAdapter(Protocol):
|
||||
"""Protocol for a fail2ban socket adapter."""
|
||||
|
||||
async def send(self, command: Fail2BanCommand) -> object:
|
||||
"""Send a command to fail2ban and return the raw response."""
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""Return ``True`` if fail2ban is reachable."""
|
||||
|
||||
async def __aenter__(self) -> Fail2BanAdapter:
|
||||
"""Enter the async context manager."""
|
||||
|
||||
async def __aexit__(
|
||||
self,
|
||||
exc_type: type[BaseException] | None,
|
||||
exc_val: BaseException | None,
|
||||
exc_tb: TracebackType | None,
|
||||
) -> None:
|
||||
"""Exit the async context manager."""
|
||||
|
||||
|
||||
def _coerce_command_token(token: object) -> Fail2BanToken:
|
||||
"""Coerce a command token to a type that fail2ban understands.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user