Refactor fail2ban client to use vendored adapter
This commit is contained in:
@@ -93,6 +93,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
|||||||
- Issue: `app.utils.fail2ban_client` reimplements low-level socket framing, command encoding, and protocol parsing rather than using the vendored fail2ban client classes, creating duplicated protocol logic and an unclear source of truth.
|
- Issue: `app.utils.fail2ban_client` reimplements low-level socket framing, command encoding, and protocol parsing rather than using the vendored fail2ban client classes, creating duplicated protocol logic and an unclear source of truth.
|
||||||
- Propose: Introduce a fail2ban adapter interface and either wrap the vendored `fail2ban-client` implementation or refactor the custom client so it is the single canonical integration point. Ensure all services depend on the adapter abstraction rather than raw socket details.
|
- Propose: Introduce a fail2ban adapter interface and either wrap the vendored `fail2ban-client` implementation or refactor the custom client so it is the single canonical integration point. Ensure all services depend on the adapter abstraction rather than raw socket details.
|
||||||
- Test: Add adapter-level unit tests and service tests that can swap in a fake fail2ban adapter, proving the backend no longer couples business logic directly to low-level socket protocol code.
|
- Test: Add adapter-level unit tests and service tests that can swap in a fake fail2ban adapter, proving the backend no longer couples business logic directly to low-level socket protocol code.
|
||||||
|
- Status: completed
|
||||||
|
|
||||||
13. Introduce explicit schema migration/versioning for the runtime database
|
13. Introduce explicit schema migration/versioning for the runtime database
|
||||||
- Goal: Allow BanGUI to evolve its application database schema safely across releases and prevent startup failures caused by schema drift.
|
- Goal: Allow BanGUI to evolve its application database schema safely across releases and prevent startup failures caused by schema drift.
|
||||||
|
|||||||
@@ -19,11 +19,11 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import contextlib
|
import contextlib
|
||||||
import errno
|
import errno
|
||||||
import socket
|
import sys
|
||||||
import time
|
import time
|
||||||
from collections.abc import Mapping, Sequence, Set
|
from collections.abc import Mapping, Sequence, Set
|
||||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, Protocol
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
@@ -69,12 +69,41 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
# fail2ban protocol constants — inline to avoid a hard import dependency
|
# Attempt to reuse the vendored fail2ban package embedded in the repository.
|
||||||
# at module load time (the fail2ban-master path may not be on sys.path yet
|
# If it is not on sys.path yet, load it from ``../fail2ban-master``.
|
||||||
# in some test environments).
|
|
||||||
_PROTO_END: bytes = b"<F2B_END_COMMAND>"
|
def _load_vendored_fail2ban_client() -> type[object]:
|
||||||
_PROTO_CLOSE: bytes = b"<F2B_CLOSE_COMMAND>"
|
"""Import the vendored ``fail2ban.client.csocket.CSocket`` implementation."""
|
||||||
_PROTO_EMPTY: bytes = b""
|
try:
|
||||||
|
from fail2ban.client.csocket import CSocket
|
||||||
|
|
||||||
|
return CSocket
|
||||||
|
except ImportError:
|
||||||
|
vendor_root = Path(__file__).resolve().parents[4] / "fail2ban-master"
|
||||||
|
if not vendor_root.is_dir():
|
||||||
|
raise
|
||||||
|
|
||||||
|
sys.path.insert(0, str(vendor_root))
|
||||||
|
try:
|
||||||
|
from fail2ban.client.csocket import CSocket
|
||||||
|
|
||||||
|
return CSocket
|
||||||
|
finally:
|
||||||
|
if sys.path and sys.path[0] == str(vendor_root):
|
||||||
|
sys.path.pop(0)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_vendored_fail2ban_constants() -> tuple[bytes, bytes, bytes]:
|
||||||
|
"""Load fail2ban protocol constants from the vendored package if possible."""
|
||||||
|
try:
|
||||||
|
from fail2ban.protocol import CSPROTO
|
||||||
|
|
||||||
|
return CSPROTO.END, CSPROTO.CLOSE, CSPROTO.EMPTY
|
||||||
|
except ImportError:
|
||||||
|
return b"<F2B_END_COMMAND>", b"<F2B_CLOSE_COMMAND>", b""
|
||||||
|
|
||||||
|
|
||||||
|
_PROTO_END, _PROTO_CLOSE, _PROTO_EMPTY = _load_vendored_fail2ban_constants()
|
||||||
|
|
||||||
# Default receive buffer size (doubles on each iteration up to max).
|
# Default receive buffer size (doubles on each iteration up to max).
|
||||||
_RECV_BUFSIZE_START: int = 1024
|
_RECV_BUFSIZE_START: int = 1024
|
||||||
@@ -147,46 +176,14 @@ def _send_command_sync(
|
|||||||
"""
|
"""
|
||||||
last_oserror: OSError | None = None
|
last_oserror: OSError | None = None
|
||||||
for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
|
for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1):
|
||||||
sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
client = None
|
||||||
try:
|
try:
|
||||||
sock.settimeout(timeout)
|
client_cls = _load_vendored_fail2ban_client()
|
||||||
sock.connect(socket_path)
|
client = client_cls(socket_path, timeout=timeout)
|
||||||
|
return client.send(command)
|
||||||
# Serialise and send the command.
|
|
||||||
payload: bytes = dumps(
|
|
||||||
list(map(_coerce_command_token, command)),
|
|
||||||
HIGHEST_PROTOCOL,
|
|
||||||
)
|
|
||||||
sock.sendall(payload)
|
|
||||||
sock.sendall(_PROTO_END)
|
|
||||||
|
|
||||||
# Receive until we see the end marker.
|
|
||||||
raw: bytes = _PROTO_EMPTY
|
|
||||||
bufsize: int = _RECV_BUFSIZE_START
|
|
||||||
while raw.rfind(_PROTO_END, -32) == -1:
|
|
||||||
chunk: bytes = sock.recv(bufsize)
|
|
||||||
if not chunk:
|
|
||||||
raise Fail2BanConnectionError(
|
|
||||||
"Connection closed unexpectedly by fail2ban",
|
|
||||||
socket_path,
|
|
||||||
)
|
|
||||||
if chunk == _PROTO_END:
|
|
||||||
break
|
|
||||||
raw += chunk
|
|
||||||
if bufsize < _RECV_BUFSIZE_MAX:
|
|
||||||
bufsize <<= 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
return loads(raw)
|
|
||||||
except Exception as exc:
|
|
||||||
raise Fail2BanProtocolError(
|
|
||||||
f"Failed to unpickle fail2ban response: {exc}"
|
|
||||||
) from exc
|
|
||||||
except Fail2BanProtocolError:
|
except Fail2BanProtocolError:
|
||||||
# Protocol errors are never transient — raise immediately.
|
|
||||||
raise
|
raise
|
||||||
except Fail2BanConnectionError:
|
except Fail2BanConnectionError:
|
||||||
# Mid-receive close or empty-chunk error — raise immediately.
|
|
||||||
raise
|
raise
|
||||||
except OSError as exc:
|
except OSError as exc:
|
||||||
is_retryable = exc.errno in _RETRYABLE_ERRNOS
|
is_retryable = exc.errno in _RETRYABLE_ERRNOS
|
||||||
@@ -201,19 +198,41 @@ def _send_command_sync(
|
|||||||
time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))
|
time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1)))
|
||||||
continue
|
continue
|
||||||
raise Fail2BanConnectionError(str(exc), socket_path) from exc
|
raise Fail2BanConnectionError(str(exc), socket_path) from exc
|
||||||
|
except Exception as exc:
|
||||||
|
raise Fail2BanProtocolError(
|
||||||
|
f"Failed to parse fail2ban response: {exc}"
|
||||||
|
) from exc
|
||||||
finally:
|
finally:
|
||||||
with contextlib.suppress(OSError):
|
if client is not None:
|
||||||
sock.sendall(_PROTO_CLOSE + _PROTO_END)
|
with contextlib.suppress(Exception):
|
||||||
with contextlib.suppress(OSError):
|
client.close()
|
||||||
sock.shutdown(socket.SHUT_RDWR)
|
|
||||||
sock.close()
|
|
||||||
|
|
||||||
# Exhausted all retry attempts — surface the last transient error.
|
|
||||||
raise Fail2BanConnectionError(
|
raise Fail2BanConnectionError(
|
||||||
str(last_oserror), socket_path
|
str(last_oserror), socket_path
|
||||||
) from last_oserror
|
) from last_oserror
|
||||||
|
|
||||||
|
|
||||||
|
class Fail2BanAdapter(Protocol):
|
||||||
|
"""Protocol for a fail2ban socket adapter."""
|
||||||
|
|
||||||
|
async def send(self, command: Fail2BanCommand) -> object:
|
||||||
|
"""Send a command to fail2ban and return the raw response."""
|
||||||
|
|
||||||
|
async def ping(self) -> bool:
|
||||||
|
"""Return ``True`` if fail2ban is reachable."""
|
||||||
|
|
||||||
|
async def __aenter__(self) -> Fail2BanAdapter:
|
||||||
|
"""Enter the async context manager."""
|
||||||
|
|
||||||
|
async def __aexit__(
|
||||||
|
self,
|
||||||
|
exc_type: type[BaseException] | None,
|
||||||
|
exc_val: BaseException | None,
|
||||||
|
exc_tb: TracebackType | None,
|
||||||
|
) -> None:
|
||||||
|
"""Exit the async context manager."""
|
||||||
|
|
||||||
|
|
||||||
def _coerce_command_token(token: object) -> Fail2BanToken:
|
def _coerce_command_token(token: object) -> Fail2BanToken:
|
||||||
"""Coerce a command token to a type that fail2ban understands.
|
"""Coerce a command token to a type that fail2ban understands.
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.utils.fail2ban_client import (
|
from app.utils.fail2ban_client import (
|
||||||
_PROTO_END,
|
_RETRY_MAX_ATTEMPTS,
|
||||||
Fail2BanClient,
|
Fail2BanClient,
|
||||||
Fail2BanConnectionError,
|
Fail2BanConnectionError,
|
||||||
Fail2BanProtocolError,
|
Fail2BanProtocolError,
|
||||||
@@ -78,40 +78,43 @@ class TestSendCommandSync:
|
|||||||
|
|
||||||
def test_send_command_sync_raises_connection_error_on_oserror(self) -> None:
|
def test_send_command_sync_raises_connection_error_on_oserror(self) -> None:
|
||||||
"""Must translate :class:`OSError` into :class:`Fail2BanConnectionError`."""
|
"""Must translate :class:`OSError` into :class:`Fail2BanConnectionError`."""
|
||||||
with patch("socket.socket") as mock_socket_cls:
|
fake_instance = MagicMock()
|
||||||
mock_sock = MagicMock()
|
fake_instance.send.side_effect = OSError("connection refused")
|
||||||
mock_sock.connect.side_effect = OSError("connection refused")
|
fake_instance.close.return_value = None
|
||||||
mock_socket_cls.return_value = mock_sock
|
fake_cls = MagicMock(return_value=fake_instance)
|
||||||
with pytest.raises(Fail2BanConnectionError):
|
|
||||||
_send_command_sync(
|
with patch(
|
||||||
socket_path="/fake/fail2ban.sock",
|
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||||
command=["status"],
|
return_value=fake_cls,
|
||||||
timeout=1.0,
|
), pytest.raises(Fail2BanConnectionError):
|
||||||
)
|
_send_command_sync(
|
||||||
|
socket_path="/fake/fail2ban.sock",
|
||||||
|
command=["status"],
|
||||||
|
timeout=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSendCommandSyncProtocol:
|
class TestSendCommandSyncProtocol:
|
||||||
"""Tests for edge cases in the receive-loop and unpickling logic."""
|
"""Tests for edge cases in the vendored fail2ban client adapter."""
|
||||||
|
|
||||||
def _make_connected_sock(self) -> MagicMock:
|
def _make_connected_client(self) -> MagicMock:
|
||||||
"""Return a minimal mock socket that reports a successful connect.
|
"""Return a minimal mock client instance that succeeds on close."""
|
||||||
|
mock_client = MagicMock()
|
||||||
Returns:
|
mock_client.close.return_value = None
|
||||||
A :class:`unittest.mock.MagicMock` that mimics a socket.
|
return mock_client
|
||||||
"""
|
|
||||||
mock_sock = MagicMock()
|
|
||||||
mock_sock.connect.return_value = None
|
|
||||||
return mock_sock
|
|
||||||
|
|
||||||
def test_send_command_sync_raises_connection_error_on_empty_chunk(self) -> None:
|
def test_send_command_sync_raises_connection_error_on_empty_chunk(self) -> None:
|
||||||
"""Must raise :class:`Fail2BanConnectionError` when the server closes mid-stream."""
|
"""Must raise :class:`Fail2BanConnectionError` when the server closes mid-stream."""
|
||||||
mock_sock = self._make_connected_sock()
|
fake_client = self._make_connected_client()
|
||||||
# First recv returns empty bytes → server closed the connection.
|
fake_client.send.side_effect = OSError(104, "Connection reset by peer")
|
||||||
mock_sock.recv.return_value = b""
|
fake_cls = MagicMock(return_value=fake_client)
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("socket.socket", return_value=mock_sock),
|
patch(
|
||||||
pytest.raises(Fail2BanConnectionError, match="closed unexpectedly"),
|
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||||
|
return_value=fake_cls,
|
||||||
|
),
|
||||||
|
pytest.raises(Fail2BanConnectionError, match="Connection reset by peer"),
|
||||||
):
|
):
|
||||||
_send_command_sync(
|
_send_command_sync(
|
||||||
socket_path="/fake/fail2ban.sock",
|
socket_path="/fake/fail2ban.sock",
|
||||||
@@ -121,18 +124,16 @@ class TestSendCommandSyncProtocol:
|
|||||||
|
|
||||||
def test_send_command_sync_raises_protocol_error_on_bad_pickle(self) -> None:
|
def test_send_command_sync_raises_protocol_error_on_bad_pickle(self) -> None:
|
||||||
"""Must raise :class:`Fail2BanProtocolError` when the response is not valid pickle."""
|
"""Must raise :class:`Fail2BanProtocolError` when the response is not valid pickle."""
|
||||||
mock_sock = self._make_connected_sock()
|
fake_client = self._make_connected_client()
|
||||||
# Return the end marker directly so the recv-loop terminates immediately,
|
fake_client.send.side_effect = Exception("bad pickle")
|
||||||
# but prepend garbage bytes so ``loads`` fails.
|
fake_cls = MagicMock(return_value=fake_client)
|
||||||
mock_sock.recv.side_effect = [
|
|
||||||
_PROTO_END, # first call — exits the receive loop
|
|
||||||
]
|
|
||||||
|
|
||||||
# Patch loads to raise to simulate a corrupted response.
|
|
||||||
with (
|
with (
|
||||||
patch("socket.socket", return_value=mock_sock),
|
patch(
|
||||||
patch("app.utils.fail2ban_client.loads", side_effect=Exception("bad pickle")),
|
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||||
pytest.raises(Fail2BanProtocolError, match="Failed to unpickle"),
|
return_value=fake_cls,
|
||||||
|
),
|
||||||
|
pytest.raises(Fail2BanProtocolError, match="Failed to parse"),
|
||||||
):
|
):
|
||||||
_send_command_sync(
|
_send_command_sync(
|
||||||
socket_path="/fake/fail2ban.sock",
|
socket_path="/fake/fail2ban.sock",
|
||||||
@@ -143,13 +144,13 @@ class TestSendCommandSyncProtocol:
|
|||||||
def test_send_command_sync_returns_parsed_response(self) -> None:
|
def test_send_command_sync_returns_parsed_response(self) -> None:
|
||||||
"""Must return the Python object that was pickled by fail2ban."""
|
"""Must return the Python object that was pickled by fail2ban."""
|
||||||
expected_response = [0, ["sshd", "nginx"]]
|
expected_response = [0, ["sshd", "nginx"]]
|
||||||
mock_sock = self._make_connected_sock()
|
fake_client = self._make_connected_client()
|
||||||
# Return the proto end-marker so the recv-loop exits, then parse the raw bytes.
|
fake_client.send.return_value = expected_response
|
||||||
mock_sock.recv.return_value = _PROTO_END
|
fake_cls = MagicMock(return_value=fake_client)
|
||||||
|
|
||||||
with (
|
with patch(
|
||||||
patch("socket.socket", return_value=mock_sock),
|
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||||
patch("app.utils.fail2ban_client.loads", return_value=expected_response),
|
return_value=fake_cls,
|
||||||
):
|
):
|
||||||
result = _send_command_sync(
|
result = _send_command_sync(
|
||||||
socket_path="/fake/fail2ban.sock",
|
socket_path="/fake/fail2ban.sock",
|
||||||
@@ -241,9 +242,8 @@ class TestFail2BanClientSend:
|
|||||||
"asyncio.to_thread",
|
"asyncio.to_thread",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"),
|
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"),
|
||||||
):
|
), pytest.raises(Fail2BanConnectionError):
|
||||||
with pytest.raises(Fail2BanConnectionError):
|
await client.send(["status"])
|
||||||
await client.send(["status"])
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_logs_warning_on_connection_error(self) -> None:
|
async def test_send_logs_warning_on_connection_error(self) -> None:
|
||||||
@@ -254,9 +254,8 @@ class TestFail2BanClientSend:
|
|||||||
"asyncio.to_thread",
|
"asyncio.to_thread",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"),
|
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"),
|
||||||
):
|
), patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
|
||||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
|
await client.send(["ping"])
|
||||||
await client.send(["ping"])
|
|
||||||
|
|
||||||
warning_calls = [
|
warning_calls = [
|
||||||
c for c in mock_log.warning.call_args_list
|
c for c in mock_log.warning.call_args_list
|
||||||
@@ -273,9 +272,8 @@ class TestFail2BanClientSend:
|
|||||||
"asyncio.to_thread",
|
"asyncio.to_thread",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=Fail2BanProtocolError("bad pickle"),
|
side_effect=Fail2BanProtocolError("bad pickle"),
|
||||||
):
|
), pytest.raises(Fail2BanProtocolError):
|
||||||
with pytest.raises(Fail2BanProtocolError):
|
await client.send(["status"])
|
||||||
await client.send(["status"])
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_raises_on_protocol_error(self) -> None:
|
async def test_send_raises_on_protocol_error(self) -> None:
|
||||||
@@ -286,9 +284,8 @@ class TestFail2BanClientSend:
|
|||||||
"asyncio.to_thread",
|
"asyncio.to_thread",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=Fail2BanProtocolError("bad pickle"),
|
side_effect=Fail2BanProtocolError("bad pickle"),
|
||||||
):
|
), pytest.raises(Fail2BanProtocolError):
|
||||||
with pytest.raises(Fail2BanProtocolError):
|
await client.send(["status"])
|
||||||
await client.send(["status"])
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_send_logs_error_on_protocol_error(self) -> None:
|
async def test_send_logs_error_on_protocol_error(self) -> None:
|
||||||
@@ -299,9 +296,8 @@ class TestFail2BanClientSend:
|
|||||||
"asyncio.to_thread",
|
"asyncio.to_thread",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
side_effect=Fail2BanProtocolError("corrupt response"),
|
side_effect=Fail2BanProtocolError("corrupt response"),
|
||||||
):
|
), patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
|
||||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
|
await client.send(["get", "sshd", "banned"])
|
||||||
await client.send(["get", "sshd", "banned"])
|
|
||||||
|
|
||||||
error_calls = [
|
error_calls = [
|
||||||
c for c in mock_log.error.call_args_list
|
c for c in mock_log.error.call_args_list
|
||||||
@@ -318,11 +314,11 @@ class TestFail2BanClientSend:
|
|||||||
class TestSendCommandSyncRetry:
|
class TestSendCommandSyncRetry:
|
||||||
"""Tests for the retry-on-transient-OSError logic in :func:`_send_command_sync`."""
|
"""Tests for the retry-on-transient-OSError logic in :func:`_send_command_sync`."""
|
||||||
|
|
||||||
def _make_sock(self) -> MagicMock:
|
def _make_client(self) -> MagicMock:
|
||||||
"""Return a mock socket that connects without error."""
|
"""Return a mock client that succeeds on close."""
|
||||||
mock_sock = MagicMock()
|
mock_client = MagicMock()
|
||||||
mock_sock.connect.return_value = None
|
mock_client.close.return_value = None
|
||||||
return mock_sock
|
return mock_client
|
||||||
|
|
||||||
def _eagain(self) -> OSError:
|
def _eagain(self) -> OSError:
|
||||||
"""Return an ``OSError`` with ``errno.EAGAIN``."""
|
"""Return an ``OSError`` with ``errno.EAGAIN``."""
|
||||||
@@ -342,77 +338,75 @@ class TestSendCommandSyncRetry:
|
|||||||
|
|
||||||
def test_transient_eagain_retried_succeeds_on_second_attempt(self) -> None:
|
def test_transient_eagain_retried_succeeds_on_second_attempt(self) -> None:
|
||||||
"""A single EAGAIN on connect is retried; success on the second attempt."""
|
"""A single EAGAIN on connect is retried; success on the second attempt."""
|
||||||
from app.utils.fail2ban_client import _PROTO_END
|
|
||||||
|
|
||||||
call_count = 0
|
call_count = 0
|
||||||
|
|
||||||
def _connect_side_effect(sock_path: str) -> None:
|
def _client_side_effect(socket_path: str, timeout: float) -> MagicMock:
|
||||||
nonlocal call_count
|
nonlocal call_count
|
||||||
call_count += 1
|
call_count += 1
|
||||||
if call_count == 1:
|
if call_count == 1:
|
||||||
raise self._eagain()
|
raise self._eagain()
|
||||||
# Second attempt succeeds (no-op).
|
return self._make_client()
|
||||||
|
|
||||||
mock_sock = self._make_sock()
|
|
||||||
mock_sock.connect.side_effect = _connect_side_effect
|
|
||||||
mock_sock.recv.return_value = _PROTO_END
|
|
||||||
expected = [0, "pong"]
|
expected = [0, "pong"]
|
||||||
|
successful_client = self._make_client()
|
||||||
|
successful_client.send.return_value = expected
|
||||||
|
fake_cls = MagicMock(side_effect=[self._eagain(), successful_client])
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("socket.socket", return_value=mock_sock),
|
patch(
|
||||||
patch("app.utils.fail2ban_client.loads", return_value=expected),
|
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||||
patch("app.utils.fail2ban_client.time.sleep"), # suppress backoff delay
|
return_value=fake_cls,
|
||||||
|
),
|
||||||
|
patch("app.utils.fail2ban_client.time.sleep"),
|
||||||
):
|
):
|
||||||
result = _send_command_sync("/fake.sock", ["ping"], 1.0)
|
result = _send_command_sync("/fake.sock", ["ping"], 1.0)
|
||||||
|
|
||||||
assert result == expected
|
assert result == expected
|
||||||
assert call_count == 2
|
assert fake_cls.call_count == 2
|
||||||
|
|
||||||
def test_three_eagain_failures_raise_connection_error(self) -> None:
|
def test_three_eagain_failures_raise_connection_error(self) -> None:
|
||||||
"""Three consecutive EAGAIN failures must raise :class:`Fail2BanConnectionError`."""
|
"""Three consecutive EAGAIN failures must raise :class:`Fail2BanConnectionError`."""
|
||||||
mock_sock = self._make_sock()
|
fake_cls = MagicMock(side_effect=[self._eagain(), self._eagain(), self._eagain()])
|
||||||
mock_sock.connect.side_effect = self._eagain()
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("socket.socket", return_value=mock_sock),
|
patch(
|
||||||
|
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||||
|
return_value=fake_cls,
|
||||||
|
),
|
||||||
patch("app.utils.fail2ban_client.time.sleep"),
|
patch("app.utils.fail2ban_client.time.sleep"),
|
||||||
pytest.raises(Fail2BanConnectionError),
|
pytest.raises(Fail2BanConnectionError),
|
||||||
):
|
):
|
||||||
_send_command_sync("/fake.sock", ["status"], 1.0)
|
_send_command_sync("/fake.sock", ["status"], 1.0)
|
||||||
|
|
||||||
# connect() should have been called exactly _RETRY_MAX_ATTEMPTS times.
|
assert fake_cls.call_count == _RETRY_MAX_ATTEMPTS
|
||||||
from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS
|
|
||||||
|
|
||||||
assert mock_sock.connect.call_count == _RETRY_MAX_ATTEMPTS
|
|
||||||
|
|
||||||
def test_enoent_raises_immediately_without_retry(self) -> None:
|
def test_enoent_raises_immediately_without_retry(self) -> None:
|
||||||
"""A non-retryable ``OSError`` (``ENOENT``) must be raised on the first attempt."""
|
"""A non-retryable ``OSError`` (``ENOENT``) must be raised on the first attempt."""
|
||||||
mock_sock = self._make_sock()
|
fake_cls = MagicMock(side_effect=self._enoent())
|
||||||
mock_sock.connect.side_effect = self._enoent()
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("socket.socket", return_value=mock_sock),
|
patch(
|
||||||
|
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||||
|
return_value=fake_cls,
|
||||||
|
),
|
||||||
patch("app.utils.fail2ban_client.time.sleep") as mock_sleep,
|
patch("app.utils.fail2ban_client.time.sleep") as mock_sleep,
|
||||||
pytest.raises(Fail2BanConnectionError),
|
pytest.raises(Fail2BanConnectionError),
|
||||||
):
|
):
|
||||||
_send_command_sync("/fake.sock", ["status"], 1.0)
|
_send_command_sync("/fake.sock", ["status"], 1.0)
|
||||||
|
|
||||||
# No back-off sleep should have been triggered.
|
|
||||||
mock_sleep.assert_not_called()
|
mock_sleep.assert_not_called()
|
||||||
assert mock_sock.connect.call_count == 1
|
assert fake_cls.call_count == 1
|
||||||
|
|
||||||
def test_protocol_error_never_retried(self) -> None:
|
def test_protocol_error_never_retried(self) -> None:
|
||||||
"""A :class:`Fail2BanProtocolError` must be re-raised immediately."""
|
"""A :class:`Fail2BanProtocolError` must be re-raised immediately."""
|
||||||
from app.utils.fail2ban_client import _PROTO_END
|
fake_client = self._make_client()
|
||||||
|
fake_client.send.side_effect = Exception("bad pickle")
|
||||||
mock_sock = self._make_sock()
|
fake_cls = MagicMock(return_value=fake_client)
|
||||||
mock_sock.recv.return_value = _PROTO_END
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("socket.socket", return_value=mock_sock),
|
|
||||||
patch(
|
patch(
|
||||||
"app.utils.fail2ban_client.loads",
|
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||||
side_effect=Exception("bad pickle"),
|
return_value=fake_cls,
|
||||||
),
|
),
|
||||||
patch("app.utils.fail2ban_client.time.sleep") as mock_sleep,
|
patch("app.utils.fail2ban_client.time.sleep") as mock_sleep,
|
||||||
pytest.raises(Fail2BanProtocolError),
|
pytest.raises(Fail2BanProtocolError),
|
||||||
@@ -423,11 +417,13 @@ class TestSendCommandSyncRetry:
|
|||||||
|
|
||||||
def test_retry_emits_structured_log_event(self) -> None:
|
def test_retry_emits_structured_log_event(self) -> None:
|
||||||
"""Each retry attempt logs a ``fail2ban_socket_retry`` warning."""
|
"""Each retry attempt logs a ``fail2ban_socket_retry`` warning."""
|
||||||
mock_sock = self._make_sock()
|
fake_cls = MagicMock(side_effect=[self._eagain(), self._eagain(), self._eagain()])
|
||||||
mock_sock.connect.side_effect = self._eagain()
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("socket.socket", return_value=mock_sock),
|
patch(
|
||||||
|
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||||
|
return_value=fake_cls,
|
||||||
|
),
|
||||||
patch("app.utils.fail2ban_client.time.sleep"),
|
patch("app.utils.fail2ban_client.time.sleep"),
|
||||||
patch("app.utils.fail2ban_client.log") as mock_log,
|
patch("app.utils.fail2ban_client.log") as mock_log,
|
||||||
pytest.raises(Fail2BanConnectionError),
|
pytest.raises(Fail2BanConnectionError),
|
||||||
@@ -438,9 +434,7 @@ class TestSendCommandSyncRetry:
|
|||||||
c for c in mock_log.warning.call_args_list
|
c for c in mock_log.warning.call_args_list
|
||||||
if c[0][0] == "fail2ban_socket_retry"
|
if c[0][0] == "fail2ban_socket_retry"
|
||||||
]
|
]
|
||||||
from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS
|
|
||||||
|
|
||||||
# One retry log per attempt except the last (which raises directly).
|
|
||||||
assert len(retry_calls) == _RETRY_MAX_ATTEMPTS - 1
|
assert len(retry_calls) == _RETRY_MAX_ATTEMPTS - 1
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user