Add tests for background tasks and fail2ban client utility
- tests/test_tasks/test_blocklist_import.py: 14 tests, 96% coverage - tests/test_tasks/test_health_check.py: 12 tests, 100% coverage - tests/test_tasks/test_geo_cache_flush.py: 8 tests, 100% coverage - tests/test_services/test_fail2ban_client.py: 24 new tests, 96% coverage Total: 50 new tests (628 → 678 passing). Overall coverage 85% → 87%. ruff, mypy --strict, tsc, and eslint all clean.
This commit is contained in:
@@ -5,9 +5,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from app.utils.fail2ban_client import (
|
||||
_PROTO_END,
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanProtocolError,
|
||||
_coerce_command_token,
|
||||
_send_command_sync,
|
||||
)
|
||||
|
||||
@@ -85,3 +87,223 @@ class TestSendCommandSync:
|
||||
command=["status"],
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
|
||||
class TestSendCommandSyncProtocol:
|
||||
"""Tests for edge cases in the receive-loop and unpickling logic."""
|
||||
|
||||
def _make_connected_sock(self) -> MagicMock:
|
||||
"""Return a minimal mock socket that reports a successful connect.
|
||||
|
||||
Returns:
|
||||
A :class:`unittest.mock.MagicMock` that mimics a socket.
|
||||
"""
|
||||
mock_sock = MagicMock()
|
||||
mock_sock.connect.return_value = None
|
||||
return mock_sock
|
||||
|
||||
def test_send_command_sync_raises_connection_error_on_empty_chunk(self) -> None:
|
||||
"""Must raise :class:`Fail2BanConnectionError` when the server closes mid-stream."""
|
||||
mock_sock = self._make_connected_sock()
|
||||
# First recv returns empty bytes → server closed the connection.
|
||||
mock_sock.recv.return_value = b""
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
pytest.raises(Fail2BanConnectionError, match="closed unexpectedly"),
|
||||
):
|
||||
_send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
command=["ping"],
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
def test_send_command_sync_raises_protocol_error_on_bad_pickle(self) -> None:
|
||||
"""Must raise :class:`Fail2BanProtocolError` when the response is not valid pickle."""
|
||||
mock_sock = self._make_connected_sock()
|
||||
# Return the end marker directly so the recv-loop terminates immediately,
|
||||
# but prepend garbage bytes so ``loads`` fails.
|
||||
mock_sock.recv.side_effect = [
|
||||
_PROTO_END, # first call — exits the receive loop
|
||||
]
|
||||
|
||||
# Patch loads to raise to simulate a corrupted response.
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch("app.utils.fail2ban_client.loads", side_effect=Exception("bad pickle")),
|
||||
pytest.raises(Fail2BanProtocolError, match="Failed to unpickle"),
|
||||
):
|
||||
_send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
command=["status"],
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
def test_send_command_sync_returns_parsed_response(self) -> None:
|
||||
"""Must return the Python object that was pickled by fail2ban."""
|
||||
expected_response = [0, ["sshd", "nginx"]]
|
||||
mock_sock = self._make_connected_sock()
|
||||
# Return the proto end-marker so the recv-loop exits, then parse the raw bytes.
|
||||
mock_sock.recv.return_value = _PROTO_END
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch("app.utils.fail2ban_client.loads", return_value=expected_response),
|
||||
):
|
||||
result = _send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
command=["status"],
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
assert result == expected_response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _coerce_command_token
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoerceCommandToken:
|
||||
"""Tests for :func:`~app.utils.fail2ban_client._coerce_command_token`."""
|
||||
|
||||
def test_coerce_str_unchanged(self) -> None:
|
||||
"""``str`` tokens must pass through unchanged."""
|
||||
assert _coerce_command_token("sshd") == "sshd"
|
||||
|
||||
def test_coerce_bool_unchanged(self) -> None:
|
||||
"""``bool`` tokens must pass through unchanged."""
|
||||
assert _coerce_command_token(True) is True # noqa: FBT003
|
||||
|
||||
def test_coerce_int_unchanged(self) -> None:
|
||||
"""``int`` tokens must pass through unchanged."""
|
||||
assert _coerce_command_token(42) == 42
|
||||
|
||||
def test_coerce_float_unchanged(self) -> None:
|
||||
"""``float`` tokens must pass through unchanged."""
|
||||
assert _coerce_command_token(1.5) == 1.5
|
||||
|
||||
def test_coerce_list_unchanged(self) -> None:
|
||||
"""``list`` tokens must pass through unchanged."""
|
||||
token: list[int] = [1, 2]
|
||||
assert _coerce_command_token(token) is token
|
||||
|
||||
def test_coerce_dict_unchanged(self) -> None:
|
||||
"""``dict`` tokens must pass through unchanged."""
|
||||
token: dict[str, str] = {"key": "value"}
|
||||
assert _coerce_command_token(token) is token
|
||||
|
||||
def test_coerce_set_unchanged(self) -> None:
|
||||
"""``set`` tokens must pass through unchanged."""
|
||||
token: set[str] = {"a", "b"}
|
||||
assert _coerce_command_token(token) is token
|
||||
|
||||
def test_coerce_unknown_type_stringified(self) -> None:
|
||||
"""Any other type must be converted to its ``str()`` representation."""
|
||||
|
||||
class CustomObj:
|
||||
def __str__(self) -> str:
|
||||
return "custom_repr"
|
||||
|
||||
assert _coerce_command_token(CustomObj()) == "custom_repr"
|
||||
|
||||
def test_coerce_none_stringified(self) -> None:
|
||||
"""``None`` must be stringified to ``"None"``."""
|
||||
assert _coerce_command_token(None) == "None"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extended tests for Fail2BanClient.send
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFail2BanClientSend:
|
||||
"""Tests for :meth:`Fail2BanClient.send`."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_returns_response_on_success(self) -> None:
|
||||
"""``send()`` must return the response from the executor."""
|
||||
expected = [0, "OK"]
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
# asyncio.get_event_loop().run_in_executor is called inside send().
|
||||
# We patch it on the loop object returned by asyncio.get_event_loop().
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(return_value=expected)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
result = await client.send(["status"])
|
||||
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reraises_connection_error(self) -> None:
|
||||
"""``send()`` must re-raise :class:`Fail2BanConnectionError`."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["status"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_logs_warning_on_connection_error(self) -> None:
|
||||
"""``send()`` must log a warning when a connection error occurs."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["ping"])
|
||||
|
||||
warning_calls = [
|
||||
c for c in mock_log.warning.call_args_list
|
||||
if c[0][0] == "fail2ban_connection_error"
|
||||
]
|
||||
assert len(warning_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reraises_protocol_error(self) -> None:
|
||||
"""``send()`` must re-raise :class:`Fail2BanProtocolError`."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("bad pickle")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_logs_error_on_protocol_error(self) -> None:
|
||||
"""``send()`` must log an error when a protocol error occurs."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("corrupt response")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["get", "sshd", "banned"])
|
||||
|
||||
error_calls = [
|
||||
c for c in mock_log.error.call_args_list
|
||||
if c[0][0] == "fail2ban_protocol_error"
|
||||
]
|
||||
assert len(error_calls) == 1
|
||||
|
||||
Reference in New Issue
Block a user