Files
BanGUI/backend/app/utils/fail2ban_client.py
Lukas 1c0bac1353 refactor: improve backend type safety and import organization
- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite)
- Reorganize imports to follow PEP 8 conventions
- Convert TypeAlias to modern PEP 695 type syntax (where appropriate)
- Use Sequence/Mapping from collections.abc for type hints (covariant)
- Replace string literals with cast() for improved type inference
- Fix casting of Fail2BanResponse and TypedDict patterns
- Add IpLookupResult TypedDict for precise return type annotation
- Reformat overlong lines for readability (120 char limit)
- Add asyncio_mode and filterwarnings to pytest config
- Update test fixtures with improved type hints

This improves mypy type checking and makes type relationships explicit.
2026-03-22 14:24:24 +01:00

357 lines
12 KiB
Python

"""Async wrapper around the fail2ban Unix domain socket protocol.
fail2ban uses a proprietary binary protocol over a Unix domain socket:
commands are transmitted as pickle-serialised Python lists and responses
are returned the same way. The protocol constants (``END``, ``CLOSE``)
come from ``fail2ban.protocol.CSPROTO``.
Because the underlying socket is blocking, all I/O is dispatched to a
thread-pool executor so the FastAPI event loop is never blocked.
Usage::
async with Fail2BanClient(socket_path="/var/run/fail2ban/fail2ban.sock") as client:
status = await client.send(["status"])
"""
from __future__ import annotations
import asyncio
import contextlib
import errno
import socket
import time
from collections.abc import Mapping, Sequence, Set
from pickle import HIGHEST_PROTOCOL, dumps, loads
from typing import TYPE_CHECKING
import structlog
# ---------------------------------------------------------------------------
# Types
# ---------------------------------------------------------------------------
# Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]``
# without needing to cast. At runtime we only accept the basic built-in
# containers supported by fail2ban's protocol (list/dict/set) and stringify
# anything else.
#
# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at
# runtime because fail2ban only understands lists.
#
# NOTE: the same applies to every member of the union that is not one of the
# concrete types checked by ``_coerce_command_token`` (str, bool, int, float,
# list, dict, set): ``None`` is sent as the string ``"None"``, and a
# ``Mapping`` / ``Set`` that is not a real ``dict`` / ``set`` is sent as its
# ``str()`` representation.
type Fail2BanToken = (
    str
    | int
    | float
    | bool
    | None
    | Mapping[str, object]
    | Sequence[object]
    | Set[object]
)
"""A single token in a fail2ban command.
Fail2ban accepts simple types (str/int/float/bool) plus compound types
(list/dict/set). Complex objects are stringified before being sent.
"""
# A command is the ordered token list that gets pickled and written to the
# socket; any ``Sequence`` is accepted because it is copied into a ``list``
# before serialisation.
type Fail2BanCommand = Sequence[Fail2BanToken]
"""A command sent to fail2ban over the socket.
Commands are pickle serialised sequences of tokens.
"""
# NOTE(review): this alias is not referenced by any signature in the visible
# part of this module (``send`` returns ``object``); presumably callers cast
# responses to it — confirm against the call sites.
type Fail2BanResponse = tuple[int, object]
"""A typical fail2ban response containing a status code and payload."""
if TYPE_CHECKING:
from types import TracebackType
# Module-level structured logger used for retry, wait and error events.
log: structlog.stdlib.BoundLogger = structlog.get_logger()

# fail2ban protocol constants — inline to avoid a hard import dependency
# at module load time (the fail2ban-master path may not be on sys.path yet
# in some test environments).
# ``_PROTO_END`` terminates every message in both directions;
# ``_PROTO_CLOSE`` (followed by ``_PROTO_END``) is sent best-effort before the
# socket is shut down; ``_PROTO_EMPTY`` is the initial receive accumulator.
_PROTO_END: bytes = b"<F2B_END_COMMAND>"
_PROTO_CLOSE: bytes = b"<F2B_CLOSE_COMMAND>"
_PROTO_EMPTY: bytes = b""

# Default receive buffer size (doubles on each iteration up to max).
_RECV_BUFSIZE_START: int = 1024
_RECV_BUFSIZE_MAX: int = 32768

# OSError errno values that indicate a transient socket condition and may be
# safely retried. ENOENT (socket file missing) is intentionally excluded so
# a missing socket raises immediately without delay.
_RETRYABLE_ERRNOS: frozenset[int] = frozenset(
    {errno.EAGAIN, errno.ECONNREFUSED, errno.ENOBUFS}
)

# Retry policy for _send_command_sync.
# The attempt count includes the first try, so 3 means "one try plus at most
# two retries".
_RETRY_MAX_ATTEMPTS: int = 3
_RETRY_INITIAL_BACKOFF: float = 0.15 # seconds; doubles on each attempt

# Maximum number of concurrent in-flight socket commands. Operations that
# exceed this cap wait until a slot is available.
_COMMAND_SEMAPHORE_CONCURRENCY: int = 10

# The semaphore is created lazily on the first send() call so it binds to the
# event loop that is actually running (important for test isolation).
# NOTE(review): nothing in the visible part of this module resets the global
# back to ``None``, so one semaphore is shared by every client for the life
# of the process — tests that use several event loops presumably reset it
# themselves; confirm.
_command_semaphore: asyncio.Semaphore | None = None
class Fail2BanConnectionError(Exception):
    """Signals that the fail2ban socket could not be used.

    Covers both "cannot reach the socket" and "the connection failed or was
    closed while talking to it". The targeted socket path is kept on the
    instance so handlers can report it without parsing the message.
    """

    def __init__(self, message: str, socket_path: str) -> None:
        """Build the exception text and remember the socket path.

        Args:
            message: Description of the connection problem.
            socket_path: The fail2ban socket path that was targeted.
        """
        # The path is appended to the text so it shows up in plain tracebacks.
        full_text = f"{message} (socket: {socket_path})"
        super().__init__(full_text)
        self.socket_path: str = socket_path
class Fail2BanProtocolError(Exception):
    """Signals that a reply was received from fail2ban but could not be decoded."""
def _recv_until_end(sock: socket.socket, socket_path: str) -> bytes:
    """Read from ``sock`` until the fail2ban end-of-message marker arrives.

    Args:
        sock: A connected socket on which a command has already been sent.
        socket_path: Socket path, used only for error reporting.

    Returns:
        All bytes received, including the trailing end marker when it
        arrived together with payload data.

    Raises:
        Fail2BanConnectionError: If the peer closes the connection before
            the end marker is seen.
        OSError: Any socket-level error (including timeouts) from ``recv``.
    """
    # A bytearray grows in place; repeated ``bytes +=`` would re-copy the
    # whole buffer for every chunk.
    received = bytearray()
    bufsize: int = _RECV_BUFSIZE_START
    # The marker is far shorter than 32 bytes, so only the tail is searched.
    while received.rfind(_PROTO_END, -32) == -1:
        chunk: bytes = sock.recv(bufsize)
        if not chunk:
            raise Fail2BanConnectionError(
                "Connection closed unexpectedly by fail2ban",
                socket_path,
            )
        if chunk == _PROTO_END:
            # The marker arrived on its own; it carries no payload.
            break
        received += chunk
        if bufsize < _RECV_BUFSIZE_MAX:
            bufsize <<= 1
    return bytes(received)


def _exchange_once(socket_path: str, payload: bytes, timeout: float) -> object:
    """Perform a single connect / send / receive round trip.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        payload: The already pickled command (without the end marker).
        timeout: Socket timeout in seconds.

    Returns:
        The unpickled response object.

    Raises:
        OSError: If creating, connecting, writing to or reading from the
            socket fails (the caller decides whether to retry).
        Fail2BanConnectionError: If the peer closes the connection early.
        Fail2BanProtocolError: If the response cannot be unpickled.
    """
    sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        sock.settimeout(timeout)
        sock.connect(socket_path)
        sock.sendall(payload)
        sock.sendall(_PROTO_END)
        raw: bytes = _recv_until_end(sock, socket_path)
    finally:
        # Best-effort polite close: tell the server we are done, then tear
        # the socket down. Failures here (e.g. the connection was never
        # established) are deliberately ignored.
        with contextlib.suppress(OSError):
            sock.sendall(_PROTO_CLOSE + _PROTO_END)
        with contextlib.suppress(OSError):
            sock.shutdown(socket.SHUT_RDWR)
        sock.close()

    # NOTE(security): unpickling can execute arbitrary code. This is only
    # acceptable because the peer is expected to be the local fail2ban
    # daemon — never point ``socket_path`` at an untrusted endpoint.
    try:
        # Bytes after the pickle STOP opcode (the end marker) are ignored.
        return loads(raw)
    except Exception as exc:
        raise Fail2BanProtocolError(
            f"Failed to unpickle fail2ban response: {exc}"
        ) from exc


def _send_command_sync(
    socket_path: str,
    command: Fail2BanCommand,
    timeout: float,
) -> object:
    """Send a command to fail2ban and return the parsed response.

    This is a **synchronous** function intended to be run through an event
    loop's ``run_in_executor`` so that the event loop is not blocked.

    Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``,
    ``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with
    exponential back-off starting at :data:`_RETRY_INITIAL_BACKOFF` seconds.
    All other ``OSError`` variants (including ``ENOENT`` — socket file
    missing, and failures while creating the socket) are wrapped in
    :class:`Fail2BanConnectionError` and raised immediately.
    :class:`Fail2BanProtocolError` and a mid-receive connection close are
    never retried.

    A structured log event ``fail2ban_socket_retry`` is emitted for each
    retry attempt. Every attempt uses a fresh socket, which is closed before
    the back-off sleep starts.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        command: List of command tokens, e.g. ``["status", "sshd"]``.
        timeout: Socket timeout in seconds.

    Returns:
        The deserialized Python object returned by fail2ban.

    Raises:
        Fail2BanConnectionError: If the socket cannot be reached after all
            retry attempts, or immediately for non-retryable errors.
        Fail2BanProtocolError: If the response cannot be unpickled.
    """
    # The payload is identical for every attempt, so serialise it once.
    payload: bytes = dumps(
        list(map(_coerce_command_token, command)),
        HIGHEST_PROTOCOL,
    )

    for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
        try:
            # Fail2BanProtocolError / Fail2BanConnectionError are not OSError
            # subclasses, so they propagate from here without being retried.
            return _exchange_once(socket_path, payload, timeout)
        except OSError as exc:
            out_of_attempts = attempt >= _RETRY_MAX_ATTEMPTS
            if exc.errno not in _RETRYABLE_ERRNOS or out_of_attempts:
                raise Fail2BanConnectionError(str(exc), socket_path) from exc
            log.warning(
                "fail2ban_socket_retry",
                attempt=attempt,
                socket_errno=exc.errno,
                socket_path=socket_path,
            )
            time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))

    # Every loop iteration either returns, raises, or moves on to a further
    # attempt, so this point is reached only when the retry limit is < 1 and
    # the loop body never ran.
    raise Fail2BanConnectionError(
        f"No connection attempt was made (retry limit: {_RETRY_MAX_ATTEMPTS})",
        socket_path,
    )
def _coerce_command_token(token: object) -> Fail2BanToken:
"""Coerce a command token to a type that fail2ban understands.
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
``float``, ``list``, ``dict``, and ``set``. Any other type is
stringified.
Args:
token: A single token from the command list.
Returns:
The token in a type safe for pickle transmission to fail2ban.
"""
if isinstance(token, (str, bool, int, float, list, dict, set)):
return token
return str(token)
class Fail2BanClient:
    """Async client for communicating with the fail2ban daemon via its socket.

    All blocking socket I/O is offloaded to the default thread-pool executor
    so the asyncio event loop remains unblocked. The client itself holds no
    open connection: every command opens and closes its own socket.

    The client can be used as an async context manager::

        async with Fail2BanClient(socket_path) as client:
            result = await client.send(["status"])

    Or instantiated directly (there is nothing to close afterwards)::

        client = Fail2BanClient(socket_path)
        result = await client.send(["status"])
    """

    def __init__(
        self,
        socket_path: str,
        timeout: float = 5.0,
    ) -> None:
        """Initialise the client.

        Args:
            socket_path: Path to the fail2ban Unix domain socket.
            timeout: Socket I/O timeout in seconds.
        """
        self.socket_path: str = socket_path
        self.timeout: float = timeout

    async def send(self, command: Fail2BanCommand) -> object:
        """Send a command to fail2ban and return the response.

        Acquires the module-level concurrency semaphore before dispatching
        so that no more than :data:`_COMMAND_SEMAPHORE_CONCURRENCY` commands
        are in-flight at the same time. Commands that exceed the cap are
        queued until a slot becomes available. A debug-level log event is
        emitted when a command must wait.

        The command is serialised as a pickle list, sent to the socket, and
        the response is deserialised before being returned.

        Args:
            command: A list of command tokens, e.g. ``["status", "sshd"]``.

        Returns:
            The Python object returned by fail2ban (typically a list or dict).

        Raises:
            Fail2BanConnectionError: If the socket cannot be reached or the
                connection is unexpectedly closed.
            Fail2BanProtocolError: If the response cannot be decoded.
        """
        global _command_semaphore
        # Created on first use rather than at import time so it is
        # constructed while an event loop is running.
        if _command_semaphore is None:
            _command_semaphore = asyncio.Semaphore(_COMMAND_SEMAPHORE_CONCURRENCY)

        if _command_semaphore.locked():
            log.debug(
                "fail2ban_command_waiting_semaphore",
                command=command,
                concurrency_limit=_COMMAND_SEMAPHORE_CONCURRENCY,
            )

        async with _command_semaphore:
            log.debug("fail2ban_sending_command", command=command)
            # Inside a coroutine a loop is always running; get_running_loop()
            # is the preferred accessor and never creates a loop implicitly.
            loop: asyncio.AbstractEventLoop = asyncio.get_running_loop()
            try:
                response: object = await loop.run_in_executor(
                    None,
                    _send_command_sync,
                    self.socket_path,
                    command,
                    self.timeout,
                )
            except Fail2BanConnectionError:
                log.warning(
                    "fail2ban_connection_error",
                    socket_path=self.socket_path,
                    command=command,
                )
                raise
            except Fail2BanProtocolError:
                log.error(
                    "fail2ban_protocol_error",
                    socket_path=self.socket_path,
                    command=command,
                )
                raise
            log.debug("fail2ban_received_response", command=command)
            return response

    async def ping(self) -> bool:
        """Return ``True`` if the fail2ban daemon is reachable.

        Sends a ``ping`` command and checks for a ``pong`` response. Two
        reply shapes are treated as success:

        * a two-element ``(status, payload)`` tuple or list with status ``0``
          and payload ``"pong"`` — the ``(code, payload)`` shape this module
          describes as the typical fail2ban response;
        * the bare value ``1``, kept for backward compatibility with the
          previous check.

        Returns:
            ``True`` when the daemon responds correctly, ``False`` otherwise
            (including when the socket is unreachable or the reply cannot be
            decoded).
        """
        try:
            response: object = await self.send(["ping"])
        except (Fail2BanConnectionError, Fail2BanProtocolError):
            return False

        if isinstance(response, (tuple, list)) and len(response) == 2:
            status, payload = response
            return bool(status == 0 and payload == "pong")
        # Legacy success value accepted by earlier versions of this check.
        return bool(response == 1)

    async def __aenter__(self) -> Fail2BanClient:
        """Return self when used as an async context manager."""
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """No-op exit — each command opens and closes its own socket.

        Exceptions raised inside the ``async with`` body are not suppressed.
        """