"""Tests for app.utils.fail2ban_client.""" from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.utils.fail2ban_client import ( _PROTO_END, Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError, _coerce_command_token, _send_command_sync, ) class TestFail2BanClientPing: """Tests for :meth:`Fail2BanClient.ping`.""" @pytest.mark.asyncio async def test_ping_returns_true_when_daemon_responds(self) -> None: """``ping()`` must return ``True`` when fail2ban responds with 1.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch.object(client, "send", new_callable=AsyncMock, return_value=1): result = await client.ping() assert result is True @pytest.mark.asyncio async def test_ping_returns_false_on_connection_error(self) -> None: """``ping()`` must return ``False`` when the daemon is unreachable.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch.object( client, "send", new_callable=AsyncMock, side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"), ): result = await client.ping() assert result is False @pytest.mark.asyncio async def test_ping_returns_false_on_protocol_error(self) -> None: """``ping()`` must return ``False`` if the response cannot be parsed.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch.object( client, "send", new_callable=AsyncMock, side_effect=Fail2BanProtocolError("bad pickle"), ): result = await client.ping() assert result is False class TestFail2BanClientContextManager: """Tests for the async context manager protocol.""" @pytest.mark.asyncio async def test_context_manager_returns_self(self) -> None: """``async with Fail2BanClient(...)`` must yield the client itself.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") async with client as ctx: assert ctx is client class TestSendCommandSync: """Tests for the synchronous :func:`_send_command_sync` helper.""" def test_send_command_sync_raises_connection_error_when_socket_absent(self) -> None: """Must raise :class:`Fail2BanConnectionError` if the socket does not exist.""" with pytest.raises(Fail2BanConnectionError): _send_command_sync( socket_path="/nonexistent/fail2ban.sock", command=["ping"], timeout=1.0, ) def test_send_command_sync_raises_connection_error_on_oserror(self) -> None: """Must translate :class:`OSError` into :class:`Fail2BanConnectionError`.""" with patch("socket.socket") as mock_socket_cls: mock_sock = MagicMock() mock_sock.connect.side_effect = OSError("connection refused") mock_socket_cls.return_value = mock_sock with pytest.raises(Fail2BanConnectionError): _send_command_sync( socket_path="/fake/fail2ban.sock", command=["status"], timeout=1.0, ) class TestSendCommandSyncProtocol: """Tests for edge cases in the receive-loop and unpickling logic.""" def _make_connected_sock(self) -> MagicMock: """Return a minimal mock socket that reports a successful connect. Returns: A :class:`unittest.mock.MagicMock` that mimics a socket. """ mock_sock = MagicMock() mock_sock.connect.return_value = None return mock_sock def test_send_command_sync_raises_connection_error_on_empty_chunk(self) -> None: """Must raise :class:`Fail2BanConnectionError` when the server closes mid-stream.""" mock_sock = self._make_connected_sock() # First recv returns empty bytes → server closed the connection. mock_sock.recv.return_value = b"" with ( patch("socket.socket", return_value=mock_sock), pytest.raises(Fail2BanConnectionError, match="closed unexpectedly"), ): _send_command_sync( socket_path="/fake/fail2ban.sock", command=["ping"], timeout=1.0, ) def test_send_command_sync_raises_protocol_error_on_bad_pickle(self) -> None: """Must raise :class:`Fail2BanProtocolError` when the response is not valid pickle.""" mock_sock = self._make_connected_sock() # Return the end marker directly so the recv-loop terminates immediately, # but prepend garbage bytes so ``loads`` fails. mock_sock.recv.side_effect = [ _PROTO_END, # first call — exits the receive loop ] # Patch loads to raise to simulate a corrupted response. with ( patch("socket.socket", return_value=mock_sock), patch("app.utils.fail2ban_client.loads", side_effect=Exception("bad pickle")), pytest.raises(Fail2BanProtocolError, match="Failed to unpickle"), ): _send_command_sync( socket_path="/fake/fail2ban.sock", command=["status"], timeout=1.0, ) def test_send_command_sync_returns_parsed_response(self) -> None: """Must return the Python object that was pickled by fail2ban.""" expected_response = [0, ["sshd", "nginx"]] mock_sock = self._make_connected_sock() # Return the proto end-marker so the recv-loop exits, then parse the raw bytes. mock_sock.recv.return_value = _PROTO_END with ( patch("socket.socket", return_value=mock_sock), patch("app.utils.fail2ban_client.loads", return_value=expected_response), ): result = _send_command_sync( socket_path="/fake/fail2ban.sock", command=["status"], timeout=1.0, ) assert result == expected_response # --------------------------------------------------------------------------- # Tests for _coerce_command_token # --------------------------------------------------------------------------- class TestCoerceCommandToken: """Tests for :func:`~app.utils.fail2ban_client._coerce_command_token`.""" def test_coerce_str_unchanged(self) -> None: """``str`` tokens must pass through unchanged.""" assert _coerce_command_token("sshd") == "sshd" def test_coerce_bool_unchanged(self) -> None: """``bool`` tokens must pass through unchanged.""" assert _coerce_command_token(True) is True # noqa: FBT003 def test_coerce_int_unchanged(self) -> None: """``int`` tokens must pass through unchanged.""" assert _coerce_command_token(42) == 42 def test_coerce_float_unchanged(self) -> None: """``float`` tokens must pass through unchanged.""" assert _coerce_command_token(1.5) == 1.5 def test_coerce_list_unchanged(self) -> None: """``list`` tokens must pass through unchanged.""" token: list[int] = [1, 2] assert _coerce_command_token(token) is token def test_coerce_dict_unchanged(self) -> None: """``dict`` tokens must pass through unchanged.""" token: dict[str, str] = {"key": "value"} assert _coerce_command_token(token) is token def test_coerce_set_unchanged(self) -> None: """``set`` tokens must pass through unchanged.""" token: set[str] = {"a", "b"} assert _coerce_command_token(token) is token def test_coerce_unknown_type_stringified(self) -> None: """Any other type must be converted to its ``str()`` representation.""" class CustomObj: def __str__(self) -> str: return "custom_repr" assert _coerce_command_token(CustomObj()) == "custom_repr" def test_coerce_none_stringified(self) -> None: """``None`` must be stringified to ``"None"``.""" assert _coerce_command_token(None) == "None" # --------------------------------------------------------------------------- # Extended tests for Fail2BanClient.send # --------------------------------------------------------------------------- class TestFail2BanClientSend: """Tests for :meth:`Fail2BanClient.send`.""" @pytest.mark.asyncio async def test_send_returns_response_on_success(self) -> None: """``send()`` must return the response from the executor.""" expected = [0, "OK"] client = Fail2BanClient(socket_path="/fake/fail2ban.sock") # asyncio.get_event_loop().run_in_executor is called inside send(). # We patch it on the loop object returned by asyncio.get_event_loop(). with patch("asyncio.get_event_loop") as mock_get_loop: mock_loop = AsyncMock() mock_loop.run_in_executor = AsyncMock(return_value=expected) mock_get_loop.return_value = mock_loop result = await client.send(["status"]) assert result == expected @pytest.mark.asyncio async def test_send_reraises_connection_error(self) -> None: """``send()`` must re-raise :class:`Fail2BanConnectionError`.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch("asyncio.get_event_loop") as mock_get_loop: mock_loop = AsyncMock() mock_loop.run_in_executor = AsyncMock( side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock") ) mock_get_loop.return_value = mock_loop with pytest.raises(Fail2BanConnectionError): await client.send(["status"]) @pytest.mark.asyncio async def test_send_logs_warning_on_connection_error(self) -> None: """``send()`` must log a warning when a connection error occurs.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch("asyncio.get_event_loop") as mock_get_loop: mock_loop = AsyncMock() mock_loop.run_in_executor = AsyncMock( side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock") ) mock_get_loop.return_value = mock_loop with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError): await client.send(["ping"]) warning_calls = [ c for c in mock_log.warning.call_args_list if c[0][0] == "fail2ban_connection_error" ] assert len(warning_calls) == 1 @pytest.mark.asyncio async def test_send_reraises_protocol_error(self) -> None: """``send()`` must re-raise :class:`Fail2BanProtocolError`.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch("asyncio.get_event_loop") as mock_get_loop: mock_loop = AsyncMock() mock_loop.run_in_executor = AsyncMock( side_effect=Fail2BanProtocolError("bad pickle") ) mock_get_loop.return_value = mock_loop with pytest.raises(Fail2BanProtocolError): await client.send(["status"]) @pytest.mark.asyncio async def test_send_raises_on_protocol_error(self) -> None: """``send()`` must propagate :class:`Fail2BanProtocolError` to the caller.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch("asyncio.get_event_loop") as mock_get_loop: mock_loop = AsyncMock() mock_loop.run_in_executor = AsyncMock( side_effect=Fail2BanProtocolError("bad pickle") ) mock_get_loop.return_value = mock_loop with pytest.raises(Fail2BanProtocolError): await client.send(["status"]) @pytest.mark.asyncio async def test_send_logs_error_on_protocol_error(self) -> None: """``send()`` must log an error when a protocol error occurs.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch("asyncio.get_event_loop") as mock_get_loop: mock_loop = AsyncMock() mock_loop.run_in_executor = AsyncMock( side_effect=Fail2BanProtocolError("corrupt response") ) mock_get_loop.return_value = mock_loop with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError): await client.send(["get", "sshd", "banned"]) error_calls = [ c for c in mock_log.error.call_args_list if c[0][0] == "fail2ban_protocol_error" ] assert len(error_calls) == 1 # --------------------------------------------------------------------------- # Tests for _send_command_sync retry logic (Stage 6.1 / 6.3) # --------------------------------------------------------------------------- class TestSendCommandSyncRetry: """Tests for the retry-on-transient-OSError logic in :func:`_send_command_sync`.""" def _make_sock(self) -> MagicMock: """Return a mock socket that connects without error.""" mock_sock = MagicMock() mock_sock.connect.return_value = None return mock_sock def _eagain(self) -> OSError: """Return an ``OSError`` with ``errno.EAGAIN``.""" import errno as _errno err = OSError("Resource temporarily unavailable") err.errno = _errno.EAGAIN return err def _enoent(self) -> OSError: """Return an ``OSError`` with ``errno.ENOENT``.""" import errno as _errno err = OSError("No such file or directory") err.errno = _errno.ENOENT return err def test_transient_eagain_retried_succeeds_on_second_attempt(self) -> None: """A single EAGAIN on connect is retried; success on the second attempt.""" from app.utils.fail2ban_client import _PROTO_END call_count = 0 def _connect_side_effect(sock_path: str) -> None: nonlocal call_count call_count += 1 if call_count == 1: raise self._eagain() # Second attempt succeeds (no-op). mock_sock = self._make_sock() mock_sock.connect.side_effect = _connect_side_effect mock_sock.recv.return_value = _PROTO_END expected = [0, "pong"] with ( patch("socket.socket", return_value=mock_sock), patch("app.utils.fail2ban_client.loads", return_value=expected), patch("app.utils.fail2ban_client.time.sleep"), # suppress backoff delay ): result = _send_command_sync("/fake.sock", ["ping"], 1.0) assert result == expected assert call_count == 2 def test_three_eagain_failures_raise_connection_error(self) -> None: """Three consecutive EAGAIN failures must raise :class:`Fail2BanConnectionError`.""" mock_sock = self._make_sock() mock_sock.connect.side_effect = self._eagain() with ( patch("socket.socket", return_value=mock_sock), patch("app.utils.fail2ban_client.time.sleep"), pytest.raises(Fail2BanConnectionError), ): _send_command_sync("/fake.sock", ["status"], 1.0) # connect() should have been called exactly _RETRY_MAX_ATTEMPTS times. from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS assert mock_sock.connect.call_count == _RETRY_MAX_ATTEMPTS def test_enoent_raises_immediately_without_retry(self) -> None: """A non-retryable ``OSError`` (``ENOENT``) must be raised on the first attempt.""" mock_sock = self._make_sock() mock_sock.connect.side_effect = self._enoent() with ( patch("socket.socket", return_value=mock_sock), patch("app.utils.fail2ban_client.time.sleep") as mock_sleep, pytest.raises(Fail2BanConnectionError), ): _send_command_sync("/fake.sock", ["status"], 1.0) # No back-off sleep should have been triggered. mock_sleep.assert_not_called() assert mock_sock.connect.call_count == 1 def test_protocol_error_never_retried(self) -> None: """A :class:`Fail2BanProtocolError` must be re-raised immediately.""" from app.utils.fail2ban_client import _PROTO_END mock_sock = self._make_sock() mock_sock.recv.return_value = _PROTO_END with ( patch("socket.socket", return_value=mock_sock), patch( "app.utils.fail2ban_client.loads", side_effect=Exception("bad pickle"), ), patch("app.utils.fail2ban_client.time.sleep") as mock_sleep, pytest.raises(Fail2BanProtocolError), ): _send_command_sync("/fake.sock", ["status"], 1.0) mock_sleep.assert_not_called() def test_retry_emits_structured_log_event(self) -> None: """Each retry attempt logs a ``fail2ban_socket_retry`` warning.""" mock_sock = self._make_sock() mock_sock.connect.side_effect = self._eagain() with ( patch("socket.socket", return_value=mock_sock), patch("app.utils.fail2ban_client.time.sleep"), patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError), ): _send_command_sync("/fake.sock", ["status"], 1.0) retry_calls = [ c for c in mock_log.warning.call_args_list if c[0][0] == "fail2ban_socket_retry" ] from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS # One retry log per attempt except the last (which raises directly). assert len(retry_calls) == _RETRY_MAX_ATTEMPTS - 1 # --------------------------------------------------------------------------- # Tests for Fail2BanClient semaphore (Stage 6.2 / 6.3) # --------------------------------------------------------------------------- class TestFail2BanClientSemaphore: """Tests for the concurrency semaphore in :meth:`Fail2BanClient.send`.""" @pytest.mark.asyncio async def test_semaphore_limits_concurrency(self) -> None: """No more than _COMMAND_SEMAPHORE_CONCURRENCY commands overlap.""" import asyncio as _asyncio import app.utils.fail2ban_client as _module # Reset module-level semaphore so this test starts fresh. _module._command_semaphore = None concurrency_limit = 3 _module._COMMAND_SEMAPHORE_CONCURRENCY = concurrency_limit _module._command_semaphore = _asyncio.Semaphore(concurrency_limit) in_flight: list[int] = [] peak_concurrent: list[int] = [] async def _slow_send(command: list[Any]) -> Any: in_flight.append(1) peak_concurrent.append(len(in_flight)) await _asyncio.sleep(0) # yield to allow other coroutines to run in_flight.pop() return (0, "ok") client = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch.object(client, "send", wraps=_slow_send) as _patched: # Bypass the semaphore wrapper — test the actual send directly. pass # Override _command_semaphore and run concurrently via the real send path # but mock _send_command_sync to avoid actual socket I/O. async def _fast_executor(_fn: Any, *_args: Any) -> Any: in_flight.append(1) peak_concurrent.append(len(in_flight)) await _asyncio.sleep(0) in_flight.pop() return (0, "ok") client2 = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch("asyncio.get_event_loop") as mock_loop_getter: mock_loop = MagicMock() mock_loop.run_in_executor = _fast_executor mock_loop_getter.return_value = mock_loop tasks = [ _asyncio.create_task(client2.send(["ping"])) for _ in range(10) ] await _asyncio.gather(*tasks) # Peak concurrent activity must never exceed the semaphore limit. assert max(peak_concurrent) <= concurrency_limit # Restore module defaults after test. _module._COMMAND_SEMAPHORE_CONCURRENCY = 10 _module._command_semaphore = None