Refactor fail2ban client to use vendored adapter
This commit is contained in:
@@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from app.utils.fail2ban_client import (
|
||||
_PROTO_END,
|
||||
_RETRY_MAX_ATTEMPTS,
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanProtocolError,
|
||||
@@ -78,40 +78,43 @@ class TestSendCommandSync:
|
||||
|
||||
def test_send_command_sync_raises_connection_error_on_oserror(self) -> None:
|
||||
"""Must translate :class:`OSError` into :class:`Fail2BanConnectionError`."""
|
||||
with patch("socket.socket") as mock_socket_cls:
|
||||
mock_sock = MagicMock()
|
||||
mock_sock.connect.side_effect = OSError("connection refused")
|
||||
mock_socket_cls.return_value = mock_sock
|
||||
with pytest.raises(Fail2BanConnectionError):
|
||||
_send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
command=["status"],
|
||||
timeout=1.0,
|
||||
)
|
||||
fake_instance = MagicMock()
|
||||
fake_instance.send.side_effect = OSError("connection refused")
|
||||
fake_instance.close.return_value = None
|
||||
fake_cls = MagicMock(return_value=fake_instance)
|
||||
|
||||
with patch(
|
||||
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||
return_value=fake_cls,
|
||||
), pytest.raises(Fail2BanConnectionError):
|
||||
_send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
command=["status"],
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
|
||||
class TestSendCommandSyncProtocol:
|
||||
"""Tests for edge cases in the receive-loop and unpickling logic."""
|
||||
"""Tests for edge cases in the vendored fail2ban client adapter."""
|
||||
|
||||
def _make_connected_sock(self) -> MagicMock:
|
||||
"""Return a minimal mock socket that reports a successful connect.
|
||||
|
||||
Returns:
|
||||
A :class:`unittest.mock.MagicMock` that mimics a socket.
|
||||
"""
|
||||
mock_sock = MagicMock()
|
||||
mock_sock.connect.return_value = None
|
||||
return mock_sock
|
||||
def _make_connected_client(self) -> MagicMock:
|
||||
"""Return a minimal mock client instance that succeeds on close."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.close.return_value = None
|
||||
return mock_client
|
||||
|
||||
def test_send_command_sync_raises_connection_error_on_empty_chunk(self) -> None:
|
||||
"""Must raise :class:`Fail2BanConnectionError` when the server closes mid-stream."""
|
||||
mock_sock = self._make_connected_sock()
|
||||
# First recv returns empty bytes → server closed the connection.
|
||||
mock_sock.recv.return_value = b""
|
||||
fake_client = self._make_connected_client()
|
||||
fake_client.send.side_effect = OSError(104, "Connection reset by peer")
|
||||
fake_cls = MagicMock(return_value=fake_client)
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
pytest.raises(Fail2BanConnectionError, match="closed unexpectedly"),
|
||||
patch(
|
||||
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||
return_value=fake_cls,
|
||||
),
|
||||
pytest.raises(Fail2BanConnectionError, match="Connection reset by peer"),
|
||||
):
|
||||
_send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
@@ -121,18 +124,16 @@ class TestSendCommandSyncProtocol:
|
||||
|
||||
def test_send_command_sync_raises_protocol_error_on_bad_pickle(self) -> None:
|
||||
"""Must raise :class:`Fail2BanProtocolError` when the response is not valid pickle."""
|
||||
mock_sock = self._make_connected_sock()
|
||||
# Return the end marker directly so the recv-loop terminates immediately,
|
||||
# but prepend garbage bytes so ``loads`` fails.
|
||||
mock_sock.recv.side_effect = [
|
||||
_PROTO_END, # first call — exits the receive loop
|
||||
]
|
||||
fake_client = self._make_connected_client()
|
||||
fake_client.send.side_effect = Exception("bad pickle")
|
||||
fake_cls = MagicMock(return_value=fake_client)
|
||||
|
||||
# Patch loads to raise to simulate a corrupted response.
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch("app.utils.fail2ban_client.loads", side_effect=Exception("bad pickle")),
|
||||
pytest.raises(Fail2BanProtocolError, match="Failed to unpickle"),
|
||||
patch(
|
||||
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||
return_value=fake_cls,
|
||||
),
|
||||
pytest.raises(Fail2BanProtocolError, match="Failed to parse"),
|
||||
):
|
||||
_send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
@@ -143,13 +144,13 @@ class TestSendCommandSyncProtocol:
|
||||
def test_send_command_sync_returns_parsed_response(self) -> None:
|
||||
"""Must return the Python object that was pickled by fail2ban."""
|
||||
expected_response = [0, ["sshd", "nginx"]]
|
||||
mock_sock = self._make_connected_sock()
|
||||
# Return the proto end-marker so the recv-loop exits, then parse the raw bytes.
|
||||
mock_sock.recv.return_value = _PROTO_END
|
||||
fake_client = self._make_connected_client()
|
||||
fake_client.send.return_value = expected_response
|
||||
fake_cls = MagicMock(return_value=fake_client)
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch("app.utils.fail2ban_client.loads", return_value=expected_response),
|
||||
with patch(
|
||||
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||
return_value=fake_cls,
|
||||
):
|
||||
result = _send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
@@ -241,9 +242,8 @@ class TestFail2BanClientSend:
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"),
|
||||
):
|
||||
with pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["status"])
|
||||
), pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["status"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_logs_warning_on_connection_error(self) -> None:
|
||||
@@ -254,9 +254,8 @@ class TestFail2BanClientSend:
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"),
|
||||
):
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["ping"])
|
||||
), patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["ping"])
|
||||
|
||||
warning_calls = [
|
||||
c for c in mock_log.warning.call_args_list
|
||||
@@ -273,9 +272,8 @@ class TestFail2BanClientSend:
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanProtocolError("bad pickle"),
|
||||
):
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
), pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_raises_on_protocol_error(self) -> None:
|
||||
@@ -286,9 +284,8 @@ class TestFail2BanClientSend:
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanProtocolError("bad pickle"),
|
||||
):
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
), pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_logs_error_on_protocol_error(self) -> None:
|
||||
@@ -299,9 +296,8 @@ class TestFail2BanClientSend:
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanProtocolError("corrupt response"),
|
||||
):
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["get", "sshd", "banned"])
|
||||
), patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["get", "sshd", "banned"])
|
||||
|
||||
error_calls = [
|
||||
c for c in mock_log.error.call_args_list
|
||||
@@ -318,11 +314,11 @@ class TestFail2BanClientSend:
|
||||
class TestSendCommandSyncRetry:
|
||||
"""Tests for the retry-on-transient-OSError logic in :func:`_send_command_sync`."""
|
||||
|
||||
def _make_sock(self) -> MagicMock:
|
||||
"""Return a mock socket that connects without error."""
|
||||
mock_sock = MagicMock()
|
||||
mock_sock.connect.return_value = None
|
||||
return mock_sock
|
||||
def _make_client(self) -> MagicMock:
|
||||
"""Return a mock client that succeeds on close."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.close.return_value = None
|
||||
return mock_client
|
||||
|
||||
def _eagain(self) -> OSError:
|
||||
"""Return an ``OSError`` with ``errno.EAGAIN``."""
|
||||
@@ -342,77 +338,75 @@ class TestSendCommandSyncRetry:
|
||||
|
||||
def test_transient_eagain_retried_succeeds_on_second_attempt(self) -> None:
|
||||
"""A single EAGAIN on connect is retried; success on the second attempt."""
|
||||
from app.utils.fail2ban_client import _PROTO_END
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _connect_side_effect(sock_path: str) -> None:
|
||||
def _client_side_effect(socket_path: str, timeout: float) -> MagicMock:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise self._eagain()
|
||||
# Second attempt succeeds (no-op).
|
||||
return self._make_client()
|
||||
|
||||
mock_sock = self._make_sock()
|
||||
mock_sock.connect.side_effect = _connect_side_effect
|
||||
mock_sock.recv.return_value = _PROTO_END
|
||||
expected = [0, "pong"]
|
||||
successful_client = self._make_client()
|
||||
successful_client.send.return_value = expected
|
||||
fake_cls = MagicMock(side_effect=[self._eagain(), successful_client])
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch("app.utils.fail2ban_client.loads", return_value=expected),
|
||||
patch("app.utils.fail2ban_client.time.sleep"), # suppress backoff delay
|
||||
patch(
|
||||
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||
return_value=fake_cls,
|
||||
),
|
||||
patch("app.utils.fail2ban_client.time.sleep"),
|
||||
):
|
||||
result = _send_command_sync("/fake.sock", ["ping"], 1.0)
|
||||
|
||||
assert result == expected
|
||||
assert call_count == 2
|
||||
assert fake_cls.call_count == 2
|
||||
|
||||
def test_three_eagain_failures_raise_connection_error(self) -> None:
|
||||
"""Three consecutive EAGAIN failures must raise :class:`Fail2BanConnectionError`."""
|
||||
mock_sock = self._make_sock()
|
||||
mock_sock.connect.side_effect = self._eagain()
|
||||
fake_cls = MagicMock(side_effect=[self._eagain(), self._eagain(), self._eagain()])
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch(
|
||||
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||
return_value=fake_cls,
|
||||
),
|
||||
patch("app.utils.fail2ban_client.time.sleep"),
|
||||
pytest.raises(Fail2BanConnectionError),
|
||||
):
|
||||
_send_command_sync("/fake.sock", ["status"], 1.0)
|
||||
|
||||
# connect() should have been called exactly _RETRY_MAX_ATTEMPTS times.
|
||||
from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS
|
||||
|
||||
assert mock_sock.connect.call_count == _RETRY_MAX_ATTEMPTS
|
||||
assert fake_cls.call_count == _RETRY_MAX_ATTEMPTS
|
||||
|
||||
def test_enoent_raises_immediately_without_retry(self) -> None:
|
||||
"""A non-retryable ``OSError`` (``ENOENT``) must be raised on the first attempt."""
|
||||
mock_sock = self._make_sock()
|
||||
mock_sock.connect.side_effect = self._enoent()
|
||||
fake_cls = MagicMock(side_effect=self._enoent())
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch(
|
||||
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||
return_value=fake_cls,
|
||||
),
|
||||
patch("app.utils.fail2ban_client.time.sleep") as mock_sleep,
|
||||
pytest.raises(Fail2BanConnectionError),
|
||||
):
|
||||
_send_command_sync("/fake.sock", ["status"], 1.0)
|
||||
|
||||
# No back-off sleep should have been triggered.
|
||||
mock_sleep.assert_not_called()
|
||||
assert mock_sock.connect.call_count == 1
|
||||
assert fake_cls.call_count == 1
|
||||
|
||||
def test_protocol_error_never_retried(self) -> None:
|
||||
"""A :class:`Fail2BanProtocolError` must be re-raised immediately."""
|
||||
from app.utils.fail2ban_client import _PROTO_END
|
||||
|
||||
mock_sock = self._make_sock()
|
||||
mock_sock.recv.return_value = _PROTO_END
|
||||
fake_client = self._make_client()
|
||||
fake_client.send.side_effect = Exception("bad pickle")
|
||||
fake_cls = MagicMock(return_value=fake_client)
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch(
|
||||
"app.utils.fail2ban_client.loads",
|
||||
side_effect=Exception("bad pickle"),
|
||||
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||
return_value=fake_cls,
|
||||
),
|
||||
patch("app.utils.fail2ban_client.time.sleep") as mock_sleep,
|
||||
pytest.raises(Fail2BanProtocolError),
|
||||
@@ -423,11 +417,13 @@ class TestSendCommandSyncRetry:
|
||||
|
||||
def test_retry_emits_structured_log_event(self) -> None:
|
||||
"""Each retry attempt logs a ``fail2ban_socket_retry`` warning."""
|
||||
mock_sock = self._make_sock()
|
||||
mock_sock.connect.side_effect = self._eagain()
|
||||
fake_cls = MagicMock(side_effect=[self._eagain(), self._eagain(), self._eagain()])
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch(
|
||||
"app.utils.fail2ban_client._load_vendored_fail2ban_client",
|
||||
return_value=fake_cls,
|
||||
),
|
||||
patch("app.utils.fail2ban_client.time.sleep"),
|
||||
patch("app.utils.fail2ban_client.log") as mock_log,
|
||||
pytest.raises(Fail2BanConnectionError),
|
||||
@@ -438,9 +434,7 @@ class TestSendCommandSyncRetry:
|
||||
c for c in mock_log.warning.call_args_list
|
||||
if c[0][0] == "fail2ban_socket_retry"
|
||||
]
|
||||
from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS
|
||||
|
||||
# One retry log per attempt except the last (which raises directly).
|
||||
assert len(retry_calls) == _RETRY_MAX_ATTEMPTS - 1
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user