- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite) - Reorganize imports to follow PEP 8 conventions - Convert TypeAlias to modern PEP 695 type syntax (where appropriate) - Use Sequence/Mapping from collections.abc for type hints (covariant) - Replace string literals with cast() for improved type inference - Fix casting of Fail2BanResponse and TypedDict patterns - Add IpLookupResult TypedDict for precise return type annotation - Reformat overlong lines for readability (120 char limit) - Add asyncio_mode and filterwarnings to pytest config - Update test fixtures with improved type hints This improves mypy type checking and makes type relationships explicit.
357 lines
12 KiB
Python
357 lines
12 KiB
Python
"""Async wrapper around the fail2ban Unix domain socket protocol.
|
|
|
|
fail2ban uses a small custom protocol over a Unix domain socket:
commands are transmitted as pickle-serialised Python lists and responses
are returned the same way, each message terminated by an end marker.
The protocol markers (``END``, ``CLOSE``) mirror the values defined in
``fail2ban.protocol.CSPROTO``; this module inlines them instead of
importing that package.
|
|
|
|
Because the underlying socket is blocking, all I/O is dispatched to a
|
|
thread-pool executor so the FastAPI event loop is never blocked.
|
|
|
|
Usage::
|
|
|
|
async with Fail2BanClient(socket_path="/var/run/fail2ban/fail2ban.sock") as client:
|
|
status = await client.send(["status"])
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import contextlib
|
|
import errno
|
|
import socket
|
|
import time
|
|
from collections.abc import Mapping, Sequence, Set
|
|
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
|
from typing import TYPE_CHECKING
|
|
|
|
import structlog
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Types
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]``
|
|
# without needing to cast. At runtime we only accept the basic built-in
|
|
# containers supported by fail2ban's protocol (list/dict/set) and stringify
|
|
# anything else.
|
|
#
|
|
# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at
|
|
# runtime because fail2ban only understands lists.
|
|
|
|
type Fail2BanToken = (
|
|
str
|
|
| int
|
|
| float
|
|
| bool
|
|
| None
|
|
| Mapping[str, object]
|
|
| Sequence[object]
|
|
| Set[object]
|
|
)
|
|
"""A single token in a fail2ban command.
|
|
|
|
Fail2ban accepts simple types (str/int/float/bool) plus compound types
|
|
(list/dict/set). Complex objects are stringified before being sent.
|
|
"""
|
|
|
|
type Fail2BanCommand = Sequence[Fail2BanToken]
|
|
"""A command sent to fail2ban over the socket.
|
|
|
|
Commands are pickle serialised sequences of tokens.
|
|
"""
|
|
|
|
type Fail2BanResponse = tuple[int, object]
|
|
"""A typical fail2ban response containing a status code and payload."""
|
|
|
|
if TYPE_CHECKING:
|
|
from types import TracebackType
|
|
|
|
# Module-level structlog logger shared by the socket helpers and
# Fail2BanClient (retry, semaphore-wait and error events).
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
|
|
# fail2ban wire-protocol markers. They mirror ``fail2ban.protocol.CSPROTO`` but
# are inlined rather than imported so this module loads even when the
# fail2ban sources are not on ``sys.path`` (as in some test environments).
_PROTO_END: bytes = b"<F2B_END_COMMAND>"
_PROTO_CLOSE: bytes = b"<F2B_CLOSE_COMMAND>"
_PROTO_EMPTY: bytes = b""

# Receive buffer sizing: start at 1 KiB and double after every chunk until
# the 32 KiB cap is reached.
_RECV_BUFSIZE_START: int = 1 << 10
_RECV_BUFSIZE_MAX: int = 1 << 15

# errno values treated as transient socket conditions that are worth a retry.
# ENOENT (socket file missing) is deliberately absent so a missing socket is
# reported at once instead of after the whole back-off schedule.
_RETRYABLE_ERRNOS: frozenset[int] = frozenset(
    (errno.EAGAIN, errno.ECONNREFUSED, errno.ENOBUFS)
)

# Retry policy for _send_command_sync: at most this many attempts; the pause
# between attempts starts at the initial back-off and doubles every time.
_RETRY_MAX_ATTEMPTS: int = 3
_RETRY_INITIAL_BACKOFF: float = 0.15  # seconds

# Upper bound on socket commands in flight at once; further callers queue
# until a slot frees up.
_COMMAND_SEMAPHORE_CONCURRENCY: int = 10

# Instantiated lazily by the first Fail2BanClient.send() call, i.e. while an
# event loop is running. On Python 3.10+ an asyncio.Semaphore binds to a loop
# the first time a waiter has to be queued, so code that runs several event
# loops over the process lifetime (e.g. one loop per test) should reset this
# to ``None`` between loops.
_command_semaphore: asyncio.Semaphore | None = None
|
|
|
|
|
|
class Fail2BanConnectionError(Exception):
    """Signal that the fail2ban socket could not be used.

    The targeted socket path is appended to the exception message and is also
    kept on the instance as ``socket_path`` so callers can report it.
    """

    def __init__(self, message: str, socket_path: str) -> None:
        """Build the error from a description and the socket path.

        Args:
            message: What went wrong with the connection.
            socket_path: Filesystem path of the fail2ban socket involved.
        """
        full_message = f"{message} (socket: {socket_path})"
        super().__init__(full_message)
        self.socket_path: str = socket_path
|
|
|
|
|
|
class Fail2BanProtocolError(Exception):
    """Signal that a reply received from fail2ban could not be decoded.

    Raised, for example, when the bytes read from the socket are not a valid
    pickle.
    """
|
|
|
|
|
|
def _send_command_sync(
    socket_path: str,
    command: Fail2BanCommand,
    timeout: float,
) -> object:
    """Send a command to fail2ban and return the parsed response.

    This is a **synchronous** function intended to be run in a thread-pool
    executor (e.g. via ``loop.run_in_executor``) so that the event loop is
    not blocked.

    Every attempt opens a fresh connection. Transient ``OSError`` conditions
    (``EAGAIN``, ``ECONNREFUSED``, ``ENOBUFS``) are retried up to
    :data:`_RETRY_MAX_ATTEMPTS` times with exponential back-off starting at
    :data:`_RETRY_INITIAL_BACKOFF` seconds, and a structured log event
    ``fail2ban_socket_retry`` is emitted before each retry. All other
    ``OSError`` variants -- including ``ENOENT`` (socket file missing),
    timeouts and failures to create the socket itself -- are wrapped in
    :class:`Fail2BanConnectionError` immediately. :class:`Fail2BanProtocolError`
    and an unexpected connection close are never retried.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        command: Sequence of command tokens, e.g. ``["status", "sshd"]``.
        timeout: Socket timeout in seconds.

    Returns:
        The deserialized Python object returned by fail2ban.

    Raises:
        Fail2BanConnectionError: If the socket cannot be reached after all
            retry attempts, or immediately for non-retryable errors.
        Fail2BanProtocolError: If the response cannot be unpickled.
    """
    last_oserror: OSError | None = None
    for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
        try:
            return _exchange_once(socket_path, command, timeout)
        except OSError as exc:
            if exc.errno in _RETRYABLE_ERRNOS and attempt < _RETRY_MAX_ATTEMPTS:
                log.warning(
                    "fail2ban_socket_retry",
                    attempt=attempt,
                    socket_errno=exc.errno,
                    socket_path=socket_path,
                )
                last_oserror = exc
                # The socket of the failed attempt is already closed here, so
                # no descriptor is held open while backing off.
                time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))
                continue
            raise Fail2BanConnectionError(str(exc), socket_path) from exc

    # Defensive only: the final attempt above always returns or raises, so this
    # is reached solely if _RETRY_MAX_ATTEMPTS is configured below 1.
    raise Fail2BanConnectionError(
        str(last_oserror), socket_path
    ) from last_oserror


def _exchange_once(
    socket_path: str,
    command: Fail2BanCommand,
    timeout: float,
) -> object:
    """Perform one connect/send/receive round trip and decode the reply.

    Socket creation happens inside the retry loop's ``try`` (unlike a bare
    ``socket.socket()`` call before it), so an ``OSError`` raised while
    creating the socket is handled like any other connection failure.

    Raises:
        OSError: Any socket-level failure, left to the caller's retry policy.
        Fail2BanConnectionError: If fail2ban closes the connection mid-reply.
        Fail2BanProtocolError: If the reply cannot be unpickled.
    """
    raw: bytes
    with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as sock:
        try:
            sock.settimeout(timeout)
            sock.connect(socket_path)

            # Serialise the (coerced) tokens as a list and terminate the frame.
            payload: bytes = dumps(
                list(map(_coerce_command_token, command)),
                HIGHEST_PROTOCOL,
            )
            sock.sendall(payload)
            sock.sendall(_PROTO_END)

            raw = _recv_until_end(sock, socket_path)
        finally:
            # Best effort: tell fail2ban we are done, then tear the socket
            # down. Either call may fail if the connection never came up.
            with contextlib.suppress(OSError):
                sock.sendall(_PROTO_CLOSE + _PROTO_END)
            with contextlib.suppress(OSError):
                sock.shutdown(socket.SHUT_RDWR)

    # SECURITY NOTE: pickle.loads can execute arbitrary code embedded in the
    # data. This is acceptable only while socket_path points at the trusted
    # fail2ban daemon; never aim it at an endpoint you do not control.
    try:
        return loads(raw)
    except Exception as exc:
        raise Fail2BanProtocolError(
            f"Failed to unpickle fail2ban response: {exc}"
        ) from exc


def _recv_until_end(sock: socket.socket, socket_path: str) -> bytes:
    """Read from *sock* until the fail2ban end marker has been received.

    The receive size starts at :data:`_RECV_BUFSIZE_START` and doubles after
    every chunk up to :data:`_RECV_BUFSIZE_MAX`. A chunk consisting solely of
    the end marker stops the loop without being stored; otherwise the marker
    remains at the tail of the returned bytes, which is harmless because
    ``pickle.loads`` ignores data after the pickled object.

    Raises:
        Fail2BanConnectionError: If the peer closes the connection before the
            end marker arrives.
        OSError: On socket errors or timeouts.
    """
    # bytearray appends in amortised O(1); bytes += bytes would copy the whole
    # accumulated response on every chunk.
    buffer = bytearray(_PROTO_EMPTY)
    bufsize: int = _RECV_BUFSIZE_START
    # Only the tail needs scanning: the marker is shorter than 32 bytes.
    while buffer.rfind(_PROTO_END, -32) == -1:
        chunk: bytes = sock.recv(bufsize)
        if not chunk:
            raise Fail2BanConnectionError(
                "Connection closed unexpectedly by fail2ban",
                socket_path,
            )
        if chunk == _PROTO_END:
            break
        buffer += chunk
        if bufsize < _RECV_BUFSIZE_MAX:
            bufsize <<= 1
    return bytes(buffer)
|
|
|
|
|
|
def _coerce_command_token(token: object) -> Fail2BanToken:
    """Return *token* in a form fail2ban's socket protocol can carry.

    Instances of ``str``, ``bool``, ``int``, ``float``, ``list``, ``dict`` or
    ``set`` (including subclasses) -- the types fail2ban's ``CSocket.convert``
    accepts -- are passed through unchanged. Anything else, such as ``None``,
    tuples, frozensets or arbitrary objects, is replaced by its ``str()``.

    Args:
        token: One element of a command sequence.

    Returns:
        The original object, or its string form when it is not natively
        supported.
    """
    native_types = (str, bool, int, float, list, dict, set)
    if not isinstance(token, native_types):
        return str(token)
    return token
|
|
|
|
|
|
class Fail2BanClient:
    """Async client for communicating with the fail2ban daemon via its socket.

    All blocking socket I/O is offloaded to the default thread-pool executor
    so the asyncio event loop remains unblocked. The client holds no
    persistent connection: every :meth:`send` opens, uses and closes its own
    socket, so there is nothing to release afterwards.

    The client can be used as an async context manager::

        async with Fail2BanClient(socket_path) as client:
            result = await client.send(["status"])

    Or instantiated directly (no explicit close is required)::

        client = Fail2BanClient(socket_path)
        result = await client.send(["status"])
    """

    def __init__(
        self,
        socket_path: str,
        timeout: float = 5.0,
    ) -> None:
        """Initialise the client.

        Args:
            socket_path: Path to the fail2ban Unix domain socket.
            timeout: Socket I/O timeout in seconds, applied to each
                connection attempt.
        """
        self.socket_path: str = socket_path
        self.timeout: float = timeout

    async def send(self, command: Fail2BanCommand) -> object:
        """Send a command to fail2ban and return the response.

        Acquires the module-level concurrency semaphore before dispatching
        so that no more than :data:`_COMMAND_SEMAPHORE_CONCURRENCY` commands
        are in flight at the same time. Commands that exceed the cap are
        queued until a slot becomes available; a debug-level log event is
        emitted when a command has to wait.

        The command is serialised as a pickle list, sent to the socket, and
        the response is deserialised before being returned. The blocking
        exchange runs in the running loop's default executor.

        Args:
            command: A sequence of command tokens, e.g. ``["status", "sshd"]``.

        Returns:
            The deserialised Python object returned by fail2ban (usually a
            ``(status_code, payload)`` pair, see ``Fail2BanResponse``).

        Raises:
            Fail2BanConnectionError: If the socket cannot be reached or the
                connection is unexpectedly closed.
            Fail2BanProtocolError: If the response cannot be decoded.
        """
        global _command_semaphore
        # Instantiated on first use (not at import time), i.e. while an event
        # loop is running.
        if _command_semaphore is None:
            _command_semaphore = asyncio.Semaphore(_COMMAND_SEMAPHORE_CONCURRENCY)
        semaphore = _command_semaphore

        if semaphore.locked():
            log.debug(
                "fail2ban_command_waiting_semaphore",
                command=command,
                concurrency_limit=_COMMAND_SEMAPHORE_CONCURRENCY,
            )

        async with semaphore:
            log.debug("fail2ban_sending_command", command=command)
            # get_running_loop() is the documented, preferred way to obtain
            # the loop from inside a coroutine (over get_event_loop()).
            loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
            try:
                response: object = await loop.run_in_executor(
                    None,
                    _send_command_sync,
                    self.socket_path,
                    command,
                    self.timeout,
                )
            except Fail2BanConnectionError:
                log.warning(
                    "fail2ban_connection_error",
                    socket_path=self.socket_path,
                    command=command,
                )
                raise
            except Fail2BanProtocolError:
                log.error(
                    "fail2ban_protocol_error",
                    socket_path=self.socket_path,
                    command=command,
                )
                raise
            log.debug("fail2ban_received_response", command=command)
            return response

    async def ping(self) -> bool:
        """Return ``True`` if the fail2ban daemon is reachable.

        Sends a ``ping`` command and checks for a ``pong`` reply. fail2ban
        wraps its replies in a ``(status_code, payload)`` pair, so a healthy
        daemon answers ``(0, "pong")``. A bare ``1`` -- the value this method
        previously compared against -- is still accepted.

        Returns:
            ``True`` when the daemon responds correctly, ``False`` on a
            connection/protocol error or any other reply.
        """
        try:
            response: object = await self.send(["ping"])
        except (Fail2BanConnectionError, Fail2BanProtocolError):
            return False
        return self._is_pong(response)

    @staticmethod
    def _is_pong(response: object) -> bool:
        """Return ``True`` if *response* is a successful reply to ``ping``."""
        if response == 1:
            return True
        if isinstance(response, (tuple, list)) and len(response) == 2:
            status, payload = response
            return status == 0 and payload == "pong"
        return False

    async def __aenter__(self) -> Fail2BanClient:
        """Return self when used as an async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """No-op exit — each command opens and closes its own socket."""
|