484 lines
18 KiB
Python
484 lines
18 KiB
Python
"""Tests for app.utils.fail2ban_client."""
|
|
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.utils.fail2ban_client import (
|
|
_RETRY_MAX_ATTEMPTS,
|
|
Fail2BanClient,
|
|
Fail2BanConnectionError,
|
|
Fail2BanProtocolError,
|
|
_coerce_command_token,
|
|
_send_command_sync,
|
|
)
|
|
|
|
|
|
class TestFail2BanClientPing:
    """Tests for :meth:`Fail2BanClient.ping`.

    In every test the client's ``send`` coroutine is replaced by an
    :class:`AsyncMock`, so only the mapping from the ``send`` outcome to a
    boolean is exercised and no socket is touched.
    """

    _SOCKET_PATH = "/fake/fail2ban.sock"

    async def _ping_with(self, **send_mock_kwargs: Any) -> bool:
        """Run ``ping()`` on a fresh client whose ``send`` mock is built from *send_mock_kwargs*."""
        probe = Fail2BanClient(socket_path=self._SOCKET_PATH)
        with patch.object(probe, "send", new_callable=AsyncMock, **send_mock_kwargs):
            return await probe.ping()

    @pytest.mark.asyncio
    async def test_ping_returns_true_when_daemon_responds(self) -> None:
        """A reply of ``1`` from the daemon makes ``ping()`` report ``True``."""
        outcome = await self._ping_with(return_value=1)
        assert outcome is True

    @pytest.mark.asyncio
    async def test_ping_returns_false_on_connection_error(self) -> None:
        """An unreachable daemon (connection error raised by ``send``) yields ``False``."""
        refused = Fail2BanConnectionError("refused", self._SOCKET_PATH)
        outcome = await self._ping_with(side_effect=refused)
        assert outcome is False

    @pytest.mark.asyncio
    async def test_ping_returns_false_on_protocol_error(self) -> None:
        """An unparseable reply (protocol error raised by ``send``) yields ``False``."""
        garbled = Fail2BanProtocolError("bad pickle")
        outcome = await self._ping_with(side_effect=garbled)
        assert outcome is False
|
|
|
|
|
|
class TestFail2BanClientContextManager:
    """Tests covering ``async with`` support on :class:`Fail2BanClient`."""

    @pytest.mark.asyncio
    async def test_context_manager_returns_self(self) -> None:
        """Entering the client as an async context manager must bind that same object."""
        original = Fail2BanClient(socket_path="/fake/fail2ban.sock")
        async with original as entered:
            # The assertion stays inside the block so it runs while the context is open.
            assert entered is original
|
|
|
|
|
|
class TestSendCommandSync:
    """Tests for the blocking :func:`_send_command_sync` helper."""

    def test_send_command_sync_raises_connection_error_when_socket_absent(self) -> None:
        """A socket path that does not exist must surface as :class:`Fail2BanConnectionError`."""
        missing_socket = "/nonexistent/fail2ban.sock"
        with pytest.raises(Fail2BanConnectionError):
            _send_command_sync(socket_path=missing_socket, command=["ping"], timeout=1.0)

    def test_send_command_sync_raises_connection_error_on_oserror(self) -> None:
        """An :class:`OSError` raised while sending is translated to :class:`Fail2BanConnectionError`."""
        refusing_client = MagicMock()
        refusing_client.close.return_value = None
        refusing_client.send.side_effect = OSError("connection refused")
        client_factory = MagicMock(return_value=refusing_client)

        # Swap the vendored client class for the factory above so no real socket is used.
        loader_patch = patch(
            "app.utils.fail2ban_client._load_vendored_fail2ban_client",
            return_value=client_factory,
        )
        with loader_patch, pytest.raises(Fail2BanConnectionError):
            _send_command_sync(
                socket_path="/fake/fail2ban.sock",
                command=["status"],
                timeout=1.0,
            )
|
|
|
|
|
|
class TestSendCommandSyncProtocol:
    """Tests for edge cases in the vendored fail2ban client adapter.

    The vendored client loader is patched to return a ``MagicMock`` class, so
    each test scripts exactly what the client's ``send`` does.
    """

    def _make_connected_client(self) -> MagicMock:
        """Return a minimal mock client instance that succeeds on close."""
        mock_client = MagicMock()
        mock_client.close.return_value = None
        return mock_client

    def test_send_command_sync_raises_connection_error_on_empty_chunk(self) -> None:
        """Must raise :class:`Fail2BanConnectionError` when the peer resets the connection.

        The reset is simulated by making the client's ``send`` raise
        ``ECONNRESET``. ``errno.ECONNRESET`` is used instead of a numeric
        literal because the value is platform-specific (104 on Linux, 54 on
        macOS/BSD). ``time.sleep`` is patched so the test never waits on a real
        backoff in case the adapter treats this errno as retryable.
        """
        import errno as _errno

        fake_client = self._make_connected_client()
        fake_client.send.side_effect = OSError(_errno.ECONNRESET, "Connection reset by peer")
        fake_cls = MagicMock(return_value=fake_client)

        with (
            patch(
                "app.utils.fail2ban_client._load_vendored_fail2ban_client",
                return_value=fake_cls,
            ),
            patch("app.utils.fail2ban_client.time.sleep"),
            pytest.raises(Fail2BanConnectionError, match="Connection reset by peer"),
        ):
            _send_command_sync(
                socket_path="/fake/fail2ban.sock",
                command=["ping"],
                timeout=1.0,
            )

    def test_send_command_sync_raises_protocol_error_on_bad_pickle(self) -> None:
        """Must raise :class:`Fail2BanProtocolError` when the response is not valid pickle."""
        fake_client = self._make_connected_client()
        # A generic (non-OSError) failure from send() stands in for a decode error.
        fake_client.send.side_effect = Exception("bad pickle")
        fake_cls = MagicMock(return_value=fake_client)

        with (
            patch(
                "app.utils.fail2ban_client._load_vendored_fail2ban_client",
                return_value=fake_cls,
            ),
            pytest.raises(Fail2BanProtocolError, match="Failed to parse"),
        ):
            _send_command_sync(
                socket_path="/fake/fail2ban.sock",
                command=["status"],
                timeout=1.0,
            )

    def test_send_command_sync_returns_parsed_response(self) -> None:
        """Must return the Python object that was pickled by fail2ban."""
        expected_response = [0, ["sshd", "nginx"]]
        fake_client = self._make_connected_client()
        fake_client.send.return_value = expected_response
        fake_cls = MagicMock(return_value=fake_client)

        with patch(
            "app.utils.fail2ban_client._load_vendored_fail2ban_client",
            return_value=fake_cls,
        ):
            result = _send_command_sync(
                socket_path="/fake/fail2ban.sock",
                command=["status"],
                timeout=1.0,
            )

        assert result == expected_response
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests for _coerce_command_token
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestCoerceCommandToken:
    """Tests for :func:`~app.utils.fail2ban_client._coerce_command_token`.

    Pass-through cases are checked by identity (``is token``) rather than by
    equality. That way a value-equal conversion such as ``42`` -> ``42.0`` is
    reported as a failure instead of silently passing.
    """

    def test_coerce_str_unchanged(self) -> None:
        """``str`` tokens must pass through unchanged."""
        token = "sshd"
        assert _coerce_command_token(token) is token

    def test_coerce_bool_unchanged(self) -> None:
        """``bool`` tokens must pass through unchanged."""
        assert _coerce_command_token(True) is True  # noqa: FBT003

    def test_coerce_int_unchanged(self) -> None:
        """``int`` tokens must pass through unchanged (the same object, not an equal ``float``)."""
        token = 42
        assert _coerce_command_token(token) is token

    def test_coerce_float_unchanged(self) -> None:
        """``float`` tokens must pass through unchanged."""
        token = 1.5
        assert _coerce_command_token(token) is token

    def test_coerce_list_unchanged(self) -> None:
        """``list`` tokens must pass through unchanged."""
        token: list[int] = [1, 2]
        assert _coerce_command_token(token) is token

    def test_coerce_dict_unchanged(self) -> None:
        """``dict`` tokens must pass through unchanged."""
        token: dict[str, str] = {"key": "value"}
        assert _coerce_command_token(token) is token

    def test_coerce_set_unchanged(self) -> None:
        """``set`` tokens must pass through unchanged."""
        token: set[str] = {"a", "b"}
        assert _coerce_command_token(token) is token

    def test_coerce_unknown_type_stringified(self) -> None:
        """Any other type must be converted to its ``str()`` representation."""

        class CustomObj:
            def __str__(self) -> str:
                return "custom_repr"

        assert _coerce_command_token(CustomObj()) == "custom_repr"

    def test_coerce_none_stringified(self) -> None:
        """``None`` must be stringified to ``"None"``."""
        assert _coerce_command_token(None) == "None"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Extended tests for Fail2BanClient.send
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestFail2BanClientSend:
    """Tests for :meth:`Fail2BanClient.send`.

    Every test patches ``asyncio.to_thread`` with an :class:`AsyncMock`, so the
    blocking worker never runs and no socket is opened.
    """

    @pytest.mark.asyncio
    async def test_send_returns_response_on_success(self) -> None:
        """``send()`` must return the response from the executor."""
        expected = [0, "OK"]
        client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
        with patch("asyncio.to_thread", new_callable=AsyncMock, return_value=expected) as mock_to_thread:
            result = await client.send(["status"])

        mock_to_thread.assert_awaited_once()
        assert result == expected

    @pytest.mark.asyncio
    async def test_send_reraises_connection_error(self) -> None:
        """``send()`` must re-raise :class:`Fail2BanConnectionError`."""
        client = Fail2BanClient(socket_path="/fake/fail2ban.sock")

        with (
            patch(
                "asyncio.to_thread",
                new_callable=AsyncMock,
                side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"),
            ),
            pytest.raises(Fail2BanConnectionError),
        ):
            await client.send(["status"])

    @pytest.mark.asyncio
    async def test_send_logs_warning_on_connection_error(self) -> None:
        """``send()`` must log exactly one ``fail2ban_connection_error`` warning."""
        client = Fail2BanClient(socket_path="/fake/fail2ban.sock")

        with (
            patch(
                "asyncio.to_thread",
                new_callable=AsyncMock,
                side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"),
            ),
            patch("app.utils.fail2ban_client.log") as mock_log,
            pytest.raises(Fail2BanConnectionError),
        ):
            await client.send(["ping"])

        # Slice the positional args so calls made with keyword arguments only
        # are skipped instead of raising IndexError.
        warning_calls = [
            c for c in mock_log.warning.call_args_list
            if c.args[:1] == ("fail2ban_connection_error",)
        ]
        assert len(warning_calls) == 1

    @pytest.mark.asyncio
    async def test_send_reraises_protocol_error(self) -> None:
        """``send()`` must re-raise :class:`Fail2BanProtocolError`."""
        client = Fail2BanClient(socket_path="/fake/fail2ban.sock")

        with (
            patch(
                "asyncio.to_thread",
                new_callable=AsyncMock,
                side_effect=Fail2BanProtocolError("bad pickle"),
            ),
            pytest.raises(Fail2BanProtocolError),
        ):
            await client.send(["status"])

    @pytest.mark.asyncio
    async def test_send_raises_on_protocol_error(self) -> None:
        """``send()`` must propagate the *original* :class:`Fail2BanProtocolError` instance.

        ``test_send_reraises_protocol_error`` only checks the exception type.
        This test additionally pins that the caller receives the very exception
        raised by the worker, not a re-wrapped copy.
        """
        client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
        error = Fail2BanProtocolError("bad pickle")

        with (
            patch(
                "asyncio.to_thread",
                new_callable=AsyncMock,
                side_effect=error,
            ),
            pytest.raises(Fail2BanProtocolError) as exc_info,
        ):
            await client.send(["status"])

        assert exc_info.value is error

    @pytest.mark.asyncio
    async def test_send_logs_error_on_protocol_error(self) -> None:
        """``send()`` must log exactly one ``fail2ban_protocol_error`` error."""
        client = Fail2BanClient(socket_path="/fake/fail2ban.sock")

        with (
            patch(
                "asyncio.to_thread",
                new_callable=AsyncMock,
                side_effect=Fail2BanProtocolError("corrupt response"),
            ),
            patch("app.utils.fail2ban_client.log") as mock_log,
            pytest.raises(Fail2BanProtocolError),
        ):
            await client.send(["get", "sshd", "banned"])

        # Same keyword-safe filtering as the connection-error logging test.
        error_calls = [
            c for c in mock_log.error.call_args_list
            if c.args[:1] == ("fail2ban_protocol_error",)
        ]
        assert len(error_calls) == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests for _send_command_sync retry logic (Stage 6.1 / 6.3)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSendCommandSyncRetry:
    """Tests for the retry-on-transient-OSError logic in :func:`_send_command_sync`.

    The vendored client loader is patched to return a ``MagicMock`` class whose
    ``side_effect`` scripts each connection attempt. The module's
    ``time.sleep`` is patched so retry backoff never actually waits.
    """

    _LOADER_TARGET = "app.utils.fail2ban_client._load_vendored_fail2ban_client"
    _SLEEP_TARGET = "app.utils.fail2ban_client.time.sleep"

    def _make_client(self) -> MagicMock:
        """Return a mock client that succeeds on close."""
        mock_client = MagicMock()
        mock_client.close.return_value = None
        return mock_client

    def _eagain(self) -> OSError:
        """Return an ``OSError`` with ``errno.EAGAIN`` (retried by the adapter)."""
        import errno as _errno

        err = OSError("Resource temporarily unavailable")
        err.errno = _errno.EAGAIN
        return err

    def _enoent(self) -> OSError:
        """Return an ``OSError`` with ``errno.ENOENT`` (not retried by the adapter)."""
        import errno as _errno

        err = OSError("No such file or directory")
        err.errno = _errno.ENOENT
        return err

    def _eagain_on_every_attempt(self) -> list[OSError]:
        """Return one fresh EAGAIN error for each of the ``_RETRY_MAX_ATTEMPTS`` attempts.

        The list is sized from the constant so the mock's ``side_effect`` never
        runs dry if the retry budget changes. A dry list would raise
        ``StopIteration`` instead of the intended ``OSError``.
        """
        return [self._eagain() for _ in range(_RETRY_MAX_ATTEMPTS)]

    def test_transient_eagain_retried_succeeds_on_second_attempt(self) -> None:
        """A single EAGAIN on connect is retried; success on the second attempt."""
        expected = [0, "pong"]
        successful_client = self._make_client()
        successful_client.send.return_value = expected
        # The first instantiation fails with EAGAIN; the second yields a working client.
        fake_cls = MagicMock(side_effect=[self._eagain(), successful_client])

        with (
            patch(self._LOADER_TARGET, return_value=fake_cls),
            patch(self._SLEEP_TARGET),
        ):
            result = _send_command_sync("/fake.sock", ["ping"], 1.0)

        assert result == expected
        assert fake_cls.call_count == 2

    def test_three_eagain_failures_raise_connection_error(self) -> None:
        """EAGAIN on every allowed attempt must raise :class:`Fail2BanConnectionError`."""
        fake_cls = MagicMock(side_effect=self._eagain_on_every_attempt())

        with (
            patch(self._LOADER_TARGET, return_value=fake_cls),
            patch(self._SLEEP_TARGET),
            pytest.raises(Fail2BanConnectionError),
        ):
            _send_command_sync("/fake.sock", ["status"], 1.0)

        assert fake_cls.call_count == _RETRY_MAX_ATTEMPTS

    def test_enoent_raises_immediately_without_retry(self) -> None:
        """A non-retryable ``OSError`` (``ENOENT``) must be raised on the first attempt."""
        fake_cls = MagicMock(side_effect=self._enoent())

        with (
            patch(self._LOADER_TARGET, return_value=fake_cls),
            patch(self._SLEEP_TARGET) as mock_sleep,
            pytest.raises(Fail2BanConnectionError),
        ):
            _send_command_sync("/fake.sock", ["status"], 1.0)

        mock_sleep.assert_not_called()
        assert fake_cls.call_count == 1

    def test_protocol_error_never_retried(self) -> None:
        """A :class:`Fail2BanProtocolError` must be re-raised immediately."""
        fake_client = self._make_client()
        fake_client.send.side_effect = Exception("bad pickle")
        fake_cls = MagicMock(return_value=fake_client)

        with (
            patch(self._LOADER_TARGET, return_value=fake_cls),
            patch(self._SLEEP_TARGET) as mock_sleep,
            pytest.raises(Fail2BanProtocolError),
        ):
            _send_command_sync("/fake.sock", ["status"], 1.0)

        mock_sleep.assert_not_called()
        # No second attempt: the command reached the client exactly once.
        fake_client.send.assert_called_once()

    def test_retry_emits_structured_log_event(self) -> None:
        """Each retry attempt logs a ``fail2ban_socket_retry`` warning."""
        fake_cls = MagicMock(side_effect=self._eagain_on_every_attempt())

        with (
            patch(self._LOADER_TARGET, return_value=fake_cls),
            patch(self._SLEEP_TARGET),
            patch("app.utils.fail2ban_client.log") as mock_log,
            pytest.raises(Fail2BanConnectionError),
        ):
            _send_command_sync("/fake.sock", ["status"], 1.0)

        # Slice the positional args so keyword-only log calls cannot raise IndexError.
        retry_calls = [
            c for c in mock_log.warning.call_args_list
            if c.args[:1] == ("fail2ban_socket_retry",)
        ]

        # One warning per retry, i.e. every attempt except the first.
        assert len(retry_calls) == _RETRY_MAX_ATTEMPTS - 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests for Fail2BanClient semaphore (Stage 6.2 / 6.3)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestFail2BanClientSemaphore:
    """Tests for the concurrency semaphore in :meth:`Fail2BanClient.send`."""

    @pytest.mark.asyncio
    async def test_semaphore_limits_concurrency_per_instance(self) -> None:
        """Each client instance must enforce its own concurrency cap.

        The module-level ``_COMMAND_SEMAPHORE_CONCURRENCY`` is overridden with
        :func:`unittest.mock.patch.object`. That restores whatever value was
        there before, even when an assertion inside the block fails, so the
        override cannot leak into other tests.
        """
        import asyncio as _asyncio

        from app.utils import fail2ban_client as _module

        concurrency_limit = 3
        in_flight = 0
        peak_concurrent = 0

        async def _fast_executor(_fn: Any, *_args: Any, **_kwargs: Any) -> Any:
            """Stand-in for ``asyncio.to_thread`` that records how many calls overlap."""
            nonlocal in_flight, peak_concurrent
            in_flight += 1
            peak_concurrent = max(peak_concurrent, in_flight)
            # Yield to the event loop so other queued sends get a chance to enter.
            await _asyncio.sleep(0)
            in_flight -= 1
            return (0, "ok")

        # Clients are created inside the override because each instance owns its
        # own semaphore (asserted below). Presumably that semaphore is sized
        # from the module constant when the client is constructed.
        with patch.object(_module, "_COMMAND_SEMAPHORE_CONCURRENCY", concurrency_limit):
            client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
            client2 = Fail2BanClient(socket_path="/fake/fail2ban.sock")

            assert client._command_semaphore is not client2._command_semaphore

            with patch("asyncio.to_thread", new=_fast_executor):
                tasks = [
                    _asyncio.create_task(client.send(["ping"])) for _ in range(10)
                ]
                await _asyncio.gather(*tasks)

        # The executor must actually have run; otherwise the cap check is vacuous.
        assert peak_concurrent >= 1
        assert peak_concurrent <= concurrency_limit
|