"""Async wrapper around the fail2ban Unix domain socket protocol. fail2ban uses a proprietary binary protocol over a Unix domain socket: commands are transmitted as pickle-serialised Python lists and responses are returned the same way. The protocol constants (``END``, ``CLOSE``) come from ``fail2ban.protocol.CSPROTO``. Because the underlying socket is blocking, all I/O is dispatched to a thread-pool executor so the FastAPI event loop is never blocked. Usage:: async with Fail2BanClient(socket_path="/var/run/fail2ban/fail2ban.sock") as client: status = await client.send(["status"]) """ from __future__ import annotations import asyncio import contextlib import errno import socket import time from pickle import HIGHEST_PROTOCOL, dumps, loads from typing import TYPE_CHECKING, Any if TYPE_CHECKING: from types import TracebackType import structlog log: structlog.stdlib.BoundLogger = structlog.get_logger() # fail2ban protocol constants — inline to avoid a hard import dependency # at module load time (the fail2ban-master path may not be on sys.path yet # in some test environments). _PROTO_END: bytes = b"" _PROTO_CLOSE: bytes = b"" _PROTO_EMPTY: bytes = b"" # Default receive buffer size (doubles on each iteration up to max). _RECV_BUFSIZE_START: int = 1024 _RECV_BUFSIZE_MAX: int = 32768 # OSError errno values that indicate a transient socket condition and may be # safely retried. ENOENT (socket file missing) is intentionally excluded so # a missing socket raises immediately without delay. _RETRYABLE_ERRNOS: frozenset[int] = frozenset( {errno.EAGAIN, errno.ECONNREFUSED, errno.ENOBUFS} ) # Retry policy for _send_command_sync. _RETRY_MAX_ATTEMPTS: int = 3 _RETRY_INITIAL_BACKOFF: float = 0.15 # seconds; doubles on each attempt # Maximum number of concurrent in-flight socket commands. Operations that # exceed this cap wait until a slot is available. 
_COMMAND_SEMAPHORE_CONCURRENCY: int = 10

# The semaphore is created lazily from inside a running event loop and is
# re-created whenever send() runs on a different loop than the one it was
# created for. asyncio synchronisation primitives must not be shared between
# event loops (they bind to the first loop that has to wait on them), which
# matters when e.g. each test runs in its own loop.
_command_semaphore: asyncio.Semaphore | None = None
# The event loop ``_command_semaphore`` was created for.
_command_semaphore_loop: asyncio.AbstractEventLoop | None = None


class Fail2BanConnectionError(Exception):
    """Raised when the fail2ban socket is unreachable or returns an error."""

    def __init__(self, message: str, socket_path: str) -> None:
        """Initialise with a human-readable message and the socket path.

        Args:
            message: Description of the connection problem.
            socket_path: The fail2ban socket path that was targeted.
        """
        self.socket_path: str = socket_path
        super().__init__(f"{message} (socket: {socket_path})")


class Fail2BanProtocolError(Exception):
    """Raised when the response from fail2ban cannot be parsed."""


def _get_command_semaphore() -> asyncio.Semaphore:
    """Return the concurrency semaphore for the currently running event loop.

    Must be called from within a coroutine. A new semaphore is created the
    first time, and again whenever the running loop differs from the loop the
    existing semaphore was created for.

    Returns:
        The semaphore limiting in-flight commands to
        :data:`_COMMAND_SEMAPHORE_CONCURRENCY`.
    """
    global _command_semaphore, _command_semaphore_loop
    loop = asyncio.get_running_loop()
    if _command_semaphore is None or _command_semaphore_loop is not loop:
        _command_semaphore = asyncio.Semaphore(_COMMAND_SEMAPHORE_CONCURRENCY)
        _command_semaphore_loop = loop
    return _command_semaphore


def _receive_response(sock: socket.socket, socket_path: str) -> bytes:
    """Read from *sock* until the fail2ban end-of-message marker is seen.

    Args:
        sock: A connected socket on which a command has just been sent.
        socket_path: Socket path, used only for error messages.

    Returns:
        The raw response bytes (possibly still followed by the END marker,
        which the unpickler ignores as trailing data).

    Raises:
        Fail2BanConnectionError: If the peer closes the connection before
            the END marker arrives.
        OSError: On socket errors, including timeouts.
    """
    # bytearray keeps accumulation linear for large responses (e.g. long ban
    # lists); repeated ``bytes += chunk`` would copy the buffer every time.
    raw = bytearray(_PROTO_EMPTY)
    bufsize = _RECV_BUFSIZE_START
    # Only the tail is scanned: the marker is sent last but may straddle two
    # chunks, so it is searched for in the accumulated buffer, not the chunk.
    while raw.rfind(_PROTO_END, -32) == -1:
        chunk = sock.recv(bufsize)
        if not chunk:
            raise Fail2BanConnectionError(
                "Connection closed unexpectedly by fail2ban",
                socket_path,
            )
        if chunk == _PROTO_END:
            break
        raw += chunk
        if bufsize < _RECV_BUFSIZE_MAX:
            bufsize <<= 1
    return bytes(raw)


def _decode_response(raw: bytes) -> Any:
    """Unpickle a raw fail2ban response.

    Args:
        raw: Bytes returned by :func:`_receive_response`.

    Returns:
        The deserialized Python object.

    Raises:
        Fail2BanProtocolError: If the bytes cannot be unpickled.
    """
    try:
        # SECURITY NOTE: pickle.loads can execute arbitrary code supplied by
        # the peer. This is only acceptable because the peer is expected to be
        # the trusted local fail2ban daemon; never point socket_path at an
        # endpoint controlled by anyone else.
        return loads(raw)
    except Exception as exc:
        raise Fail2BanProtocolError(
            f"Failed to unpickle fail2ban response: {exc}"
        ) from exc


def _send_command_sync(
    socket_path: str,
    command: list[Any],
    timeout: float,
) -> Any:
    """Send a command to fail2ban and return the parsed response.

    This is a **synchronous** function intended to be run in a thread-pool
    executor (``loop.run_in_executor``) so that the event loop is not
    blocked.

    Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``,
    ``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with
    exponential back-off starting at :data:`_RETRY_INITIAL_BACKOFF` seconds.
    All other ``OSError`` variants (including ``ENOENT`` — socket file
    missing) and :class:`Fail2BanProtocolError` are raised immediately. A
    structured log event ``fail2ban_socket_retry`` is emitted for each retry
    attempt.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        command: List of command tokens, e.g. ``["status", "sshd"]``.
        timeout: Socket timeout in seconds.

    Returns:
        The deserialized Python object returned by fail2ban.

    Raises:
        Fail2BanConnectionError: If the socket cannot be reached after all
            retry attempts, or immediately for non-retryable errors.
        Fail2BanProtocolError: If the response cannot be unpickled.
    """
    # The payload is identical for every attempt, so serialise it only once.
    payload: bytes = dumps(
        list(map(_coerce_command_token, command)),
        HIGHEST_PROTOCOL,
    )
    last_oserror: OSError | None = None
    for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
        sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        connected = False
        try:
            sock.settimeout(timeout)
            sock.connect(socket_path)
            connected = True
            sock.sendall(payload)
            sock.sendall(_PROTO_END)
            raw = _receive_response(sock, socket_path)
        except OSError as exc:
            # Fail2BanConnectionError / Fail2BanProtocolError are not OSError
            # subclasses, so they propagate immediately without retrying.
            if exc.errno in _RETRYABLE_ERRNOS and attempt < _RETRY_MAX_ATTEMPTS:
                log.warning(
                    "fail2ban_socket_retry",
                    attempt=attempt,
                    socket_errno=exc.errno,
                    socket_path=socket_path,
                )
                last_oserror = exc
                time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))
                continue
            raise Fail2BanConnectionError(str(exc), socket_path) from exc
        finally:
            # Best-effort teardown; a socket that never connected has nothing
            # to say goodbye to.
            if connected:
                with contextlib.suppress(OSError):
                    sock.sendall(_PROTO_CLOSE + _PROTO_END)
                with contextlib.suppress(OSError):
                    sock.shutdown(socket.SHUT_RDWR)
            sock.close()
        return _decode_response(raw)

    # Defensive: the final attempt always returns or raises above, but keep
    # an explicit error in case the retry policy is ever changed.
    raise Fail2BanConnectionError(
        str(last_oserror), socket_path
    ) from last_oserror


def _coerce_command_token(token: Any) -> Any:
    """Coerce a command token to a type that fail2ban understands.

    fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
    ``float``, ``list``, ``dict``, and ``set``. Any other type is stringified.

    Args:
        token: A single token from the command list.

    Returns:
        The token in a type safe for pickle transmission to fail2ban.
    """
    if isinstance(token, (str, bool, int, float, list, dict, set)):
        return token
    return str(token)


def _is_ping_success(response: Any) -> bool:
    """Return ``True`` if *response* is a successful reply to ``ping``.

    fail2ban answers every command with a ``(code, result)`` envelope where
    ``code == 0`` means success; a successful ping yields ``(0, "pong")``.
    A bare ``1`` is also accepted for backward compatibility with earlier
    behaviour of :meth:`Fail2BanClient.ping`.

    Args:
        response: The object returned by :meth:`Fail2BanClient.send`.

    Returns:
        Whether the response indicates the daemon answered the ping.
    """
    if isinstance(response, (list, tuple)):
        return len(response) == 2 and response[0] == 0 and response[1] == "pong"
    return bool(response == 1)


class Fail2BanClient:
    """Async client for communicating with the fail2ban daemon via its socket.

    All blocking socket I/O is offloaded to the default thread-pool executor
    so the asyncio event loop remains unblocked.

    The client can be used as an async context manager::

        async with Fail2BanClient(socket_path) as client:
            result = await client.send(["status"])

    Or used directly. There is nothing to close afterwards, because every
    command opens and closes its own socket::

        client = Fail2BanClient(socket_path)
        result = await client.send(["status"])
    """

    def __init__(
        self,
        socket_path: str,
        timeout: float = 5.0,
    ) -> None:
        """Initialise the client.

        Args:
            socket_path: Path to the fail2ban Unix domain socket.
            timeout: Socket I/O timeout in seconds.
        """
        self.socket_path: str = socket_path
        self.timeout: float = timeout

    async def send(self, command: list[Any]) -> Any:
        """Send a command to fail2ban and return the response.

        Acquires the concurrency semaphore for the running event loop before
        dispatching, so that no more than
        :data:`_COMMAND_SEMAPHORE_CONCURRENCY` commands are in flight at the
        same time. Commands that exceed the cap are queued until a slot
        becomes available. A debug-level log event is emitted when a command
        must wait.

        The command is serialised as a pickle list, sent to the socket, and
        the response is deserialised before being returned.

        Args:
            command: A list of command tokens, e.g. ``["status", "sshd"]``.

        Returns:
            The Python object returned by fail2ban (for a real daemon, a
            ``(code, result)`` envelope where ``code == 0`` means success).

        Raises:
            Fail2BanConnectionError: If the socket cannot be reached or the
                connection is unexpectedly closed.
            Fail2BanProtocolError: If the response cannot be decoded.
        """
        semaphore = _get_command_semaphore()
        if semaphore.locked():
            log.debug(
                "fail2ban_command_waiting_semaphore",
                command=command,
                concurrency_limit=_COMMAND_SEMAPHORE_CONCURRENCY,
            )
        async with semaphore:
            log.debug("fail2ban_sending_command", command=command)
            loop = asyncio.get_running_loop()
            try:
                response: Any = await loop.run_in_executor(
                    None,
                    _send_command_sync,
                    self.socket_path,
                    command,
                    self.timeout,
                )
            except Fail2BanConnectionError:
                log.warning(
                    "fail2ban_connection_error",
                    socket_path=self.socket_path,
                    command=command,
                )
                raise
            except Fail2BanProtocolError:
                log.error(
                    "fail2ban_protocol_error",
                    socket_path=self.socket_path,
                    command=command,
                )
                raise
            log.debug("fail2ban_received_response", command=command)
            return response

    async def ping(self) -> bool:
        """Return ``True`` if the fail2ban daemon is reachable.

        Sends a ``ping`` command and checks for a successful ``pong`` reply.

        Returns:
            ``True`` when the daemon responds correctly, ``False`` otherwise
            (including when it cannot be reached or the reply is undecodable).
        """
        try:
            response: Any = await self.send(["ping"])
        except (Fail2BanConnectionError, Fail2BanProtocolError):
            return False
        return _is_ping_success(response)

    async def __aenter__(self) -> Fail2BanClient:
        """Return self when used as an async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """No-op exit — each command opens and closes its own socket."""