"""Async wrapper around the fail2ban Unix domain socket protocol. fail2ban uses a proprietary binary protocol over a Unix domain socket: commands are transmitted as pickle-serialised Python lists and responses are returned the same way. The protocol constants (``END``, ``CLOSE``) come from ``fail2ban.protocol.CSPROTO``. Because the underlying socket is blocking, all I/O is dispatched to a thread-pool executor so the FastAPI event loop is never blocked. Usage:: async with Fail2BanClient(socket_path="/var/run/fail2ban/fail2ban.sock") as client: status = await client.send(["status"]) """ from __future__ import annotations import asyncio import contextlib import errno import sys import time from collections.abc import Mapping, Sequence, Set from pathlib import Path from typing import TYPE_CHECKING, Protocol import structlog from app.exceptions import Fail2BanConnectionError, Fail2BanProtocolError # --------------------------------------------------------------------------- # Types # --------------------------------------------------------------------------- # Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]`` # without needing to cast. At runtime we only accept the basic built-in # containers supported by fail2ban's protocol (list/dict/set) and stringify # anything else. # # NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at # runtime because fail2ban only understands lists. type Fail2BanToken = ( str | int | float | bool | None | Mapping[str, object] | Sequence[object] | Set[object] ) """A single token in a fail2ban command. Fail2ban accepts simple types (str/int/float/bool) plus compound types (list/dict/set). Complex objects are stringified before being sent. """ type Fail2BanCommand = Sequence[Fail2BanToken] """A command sent to fail2ban over the socket. Commands are pickle serialised sequences of tokens. 
""" type Fail2BanResponse = tuple[int, object] """A typical fail2ban response containing a status code and payload.""" if TYPE_CHECKING: from types import TracebackType log: structlog.stdlib.BoundLogger = structlog.get_logger() # Attempt to reuse the vendored fail2ban package embedded in the repository. # If it is not on sys.path yet, load it from ``../fail2ban-master``. def _load_vendored_fail2ban_client() -> type[object]: """Import the vendored ``fail2ban.client.csocket.CSocket`` implementation.""" try: from fail2ban.client.csocket import CSocket return CSocket except ImportError: vendor_root = Path(__file__).resolve().parents[4] / "fail2ban-master" if not vendor_root.is_dir(): raise sys.path.insert(0, str(vendor_root)) try: from fail2ban.client.csocket import CSocket return CSocket finally: if sys.path and sys.path[0] == str(vendor_root): sys.path.pop(0) def _load_vendored_fail2ban_constants() -> tuple[bytes, bytes, bytes]: """Load fail2ban protocol constants from the vendored package if possible.""" try: from fail2ban.protocol import CSPROTO return CSPROTO.END, CSPROTO.CLOSE, CSPROTO.EMPTY except ImportError: return b"", b"", b"" _PROTO_END, _PROTO_CLOSE, _PROTO_EMPTY = _load_vendored_fail2ban_constants() # Default receive buffer size (doubles on each iteration up to max). _RECV_BUFSIZE_START: int = 1024 _RECV_BUFSIZE_MAX: int = 32768 # OSError errno values that indicate a transient socket condition and may be # safely retried. ENOENT (socket file missing) is intentionally excluded so # a missing socket raises immediately without delay. _RETRYABLE_ERRNOS: frozenset[int] = frozenset( {errno.EAGAIN, errno.ECONNREFUSED, errno.ENOBUFS} ) # Retry policy for _send_command_sync. _RETRY_MAX_ATTEMPTS: int = 3 _RETRY_INITIAL_BACKOFF: float = 0.15 # seconds; doubles on each attempt # Maximum number of concurrent in-flight socket commands per client. # Operations that exceed this cap wait until a slot is available. 
_COMMAND_SEMAPHORE_CONCURRENCY: int = 10


def _send_command_sync(
    socket_path: str,
    command: Fail2BanCommand,
    timeout: float,
) -> object:
    """Send a command to fail2ban and return the parsed response.

    This is a **synchronous** function intended to be executed via
    :func:`asyncio.to_thread` so that the event loop is not blocked.

    Each attempt opens a fresh connection, sends the command, reads the reply
    and closes the connection again. Transient ``OSError`` conditions
    (``EAGAIN``, ``ECONNREFUSED``, ``ENOBUFS``) are retried up to
    :data:`_RETRY_MAX_ATTEMPTS` times with exponential back-off starting at
    :data:`_RETRY_INITIAL_BACKOFF` seconds; the failed connection is closed
    before sleeping. All other ``OSError`` variants (including ``ENOENT`` —
    socket file missing) and :class:`Fail2BanProtocolError` are raised
    immediately. A structured log event ``fail2ban_socket_retry`` is emitted
    for each retry attempt.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        command: List of command tokens, e.g. ``["status", "sshd"]``.
        timeout: Socket timeout in seconds.

    Returns:
        The deserialized Python object returned by fail2ban.

    Raises:
        Fail2BanConnectionError: If the socket cannot be reached after all
            retry attempts, or immediately for non-retryable errors.
        Fail2BanProtocolError: If the fail2ban client library cannot be
            imported, if the response cannot be unpickled, or on any other
            unexpected error while talking to the socket.
    """
    try:
        client_cls = _load_vendored_fail2ban_client()
    except ImportError as exc:
        # Same exception type callers already handle, but with a message that
        # names the real cause instead of blaming the response.
        raise Fail2BanProtocolError(
            f"fail2ban client library is unavailable: {exc}"
        ) from exc

    # Coerce once up front so the exact same token list is sent on every
    # retry, even if a caller passes a one-shot iterable.
    tokens = [_coerce_command_token(token) for token in command]

    last_oserror: OSError | None = None
    for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
        client = None
        try:
            client = client_cls(socket_path, timeout=timeout)
            return client.send(tokens)
        except (Fail2BanProtocolError, Fail2BanConnectionError):
            # Already classified; must not be re-wrapped by the handlers below.
            raise
        except OSError as exc:
            if exc.errno not in _RETRYABLE_ERRNOS or attempt >= _RETRY_MAX_ATTEMPTS:
                raise Fail2BanConnectionError(str(exc), socket_path) from exc
            log.warning(
                "fail2ban_socket_retry",
                attempt=attempt,
                socket_errno=exc.errno,
                socket_path=socket_path,
            )
            last_oserror = exc
        except Exception as exc:
            raise Fail2BanProtocolError(
                f"Failed to parse fail2ban response: {exc}"
            ) from exc
        finally:
            if client is not None:
                # Best effort: a failing close must not mask the real outcome.
                with contextlib.suppress(Exception):
                    client.close()
        # Reached only after a retryable OSError. The failed connection has
        # already been closed above, so no socket is held while backing off.
        time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))

    # Defensive: only reachable when _RETRY_MAX_ATTEMPTS < 1.
    raise Fail2BanConnectionError(
        str(last_oserror), socket_path
    ) from last_oserror


class Fail2BanAdapter(Protocol):
    """Protocol for a fail2ban socket adapter."""

    async def send(self, command: Fail2BanCommand) -> object:
        """Send a command to fail2ban and return the raw response."""

    async def ping(self) -> bool:
        """Return ``True`` if fail2ban is reachable."""

    async def __aenter__(self) -> Fail2BanAdapter:
        """Enter the async context manager."""

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit the async context manager."""


def _coerce_command_token(token: object) -> Fail2BanToken:
    """Coerce a command token to a type that fail2ban understands.

    fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
    ``float``, ``list``, ``dict``, and ``set``. Any other type (including
    ``tuple`` and ``None``) is stringified.

    Args:
        token: A single token from the command list.

    Returns:
        The token in a type safe for pickle transmission to fail2ban.
    """
    if isinstance(token, (str, bool, int, float, list, dict, set)):
        return token
    return str(token)


def _is_success_response(response: object) -> bool:
    """Return ``True`` if *response* is a successful fail2ban reply.

    The fail2ban server wraps every command result as a
    ``(return_code, payload)`` pair; return code ``0`` signals success.
    """
    return (
        isinstance(response, (tuple, list))
        and len(response) == 2
        and response[0] == 0
    )


class Fail2BanClient:
    """Async client for communicating with the fail2ban daemon via its socket.

    All blocking socket I/O is offloaded to the default thread-pool executor
    (via :func:`asyncio.to_thread`) so the asyncio event loop remains
    unblocked. Every command opens and closes its own socket connection, so
    the client holds no open sockets between calls.

    The client can be used as an async context manager::

        async with Fail2BanClient(socket_path) as client:
            result = await client.send(["status"])

    Or instantiated directly; there is nothing to close afterwards::

        client = Fail2BanClient(socket_path)
        result = await client.send(["status"])
    """

    def __init__(
        self,
        socket_path: str,
        timeout: float = 5.0,
    ) -> None:
        """Initialise the client.

        Args:
            socket_path: Path to the fail2ban Unix domain socket.
            timeout: Socket I/O timeout in seconds, applied to each attempt.
        """
        self.socket_path: str = socket_path
        self.timeout: float = timeout
        # Per-instance cap on in-flight commands; each in-flight command
        # occupies a worker thread of the default executor.
        self._command_semaphore: asyncio.Semaphore = asyncio.Semaphore(
            _COMMAND_SEMAPHORE_CONCURRENCY
        )

    async def send(self, command: Fail2BanCommand) -> object:
        """Send a command to fail2ban and return the response.

        Acquires this client's concurrency semaphore before dispatching, so
        that no more than :data:`_COMMAND_SEMAPHORE_CONCURRENCY` commands from
        this instance are in flight at the same time. Commands that exceed
        the cap are queued until a slot becomes available. A debug-level log
        event is emitted when a command has to wait.

        The command is serialised as a pickle list, sent to the socket, and
        the response is deserialised before being returned. The blocking
        exchange (including its retry policy) runs in a worker thread.

        Args:
            command: A list of command tokens, e.g. ``["status", "sshd"]``.

        Returns:
            The Python object returned by fail2ban, typically a
            ``(return_code, payload)`` pair (see ``Fail2BanResponse``).

        Raises:
            Fail2BanConnectionError: If the socket cannot be reached or the
                connection is unexpectedly closed.
            Fail2BanProtocolError: If the response cannot be decoded.
        """
        if self._command_semaphore.locked():
            log.debug(
                "fail2ban_command_waiting_semaphore",
                command=command,
                concurrency_limit=_COMMAND_SEMAPHORE_CONCURRENCY,
            )

        async with self._command_semaphore:
            log.debug("fail2ban_sending_command", command=command)
            try:
                response: object = await asyncio.to_thread(
                    _send_command_sync,
                    self.socket_path,
                    command,
                    self.timeout,
                )
            except Fail2BanConnectionError:
                log.warning(
                    "fail2ban_connection_error",
                    socket_path=self.socket_path,
                    command=command,
                )
                raise
            except Fail2BanProtocolError:
                log.error(
                    "fail2ban_protocol_error",
                    socket_path=self.socket_path,
                    command=command,
                )
                raise
            log.debug("fail2ban_received_response", command=command)
            return response

    async def ping(self) -> bool:
        """Return ``True`` if the fail2ban daemon is reachable.

        Sends a ``ping`` command. The fail2ban server answers every command
        with a ``(return_code, payload)`` pair. For ``ping`` a healthy daemon
        replies ``(0, "pong")``, so the daemon counts as reachable when it
        replies with return code ``0``.

        Returns:
            ``True`` when the daemon responds successfully. ``False`` on a
            connection or protocol failure, or on an error reply.
        """
        try:
            response: object = await self.send(["ping"])
        except (Fail2BanConnectionError, Fail2BanProtocolError):
            return False
        return _is_success_response(response)

    async def __aenter__(self) -> Fail2BanClient:
        """Return self when used as an async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """No-op exit — each command opens and closes its own socket."""