diff --git a/Docs/Tasks.md b/Docs/Tasks.md index d10117e..e5a8f29 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -93,6 +93,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. - Issue: `app.utils.fail2ban_client` reimplements low-level socket framing, command encoding, and protocol parsing rather than using the vendored fail2ban client classes, creating duplicated protocol logic and an unclear source of truth. - Propose: Introduce a fail2ban adapter interface and either wrap the vendored `fail2ban-client` implementation or refactor the custom client so it is the single canonical integration point. Ensure all services depend on the adapter abstraction rather than raw socket details. - Test: Add adapter-level unit tests and service tests that can swap in a fake fail2ban adapter, proving the backend no longer couples business logic directly to low-level socket protocol code. + - Status: completed 13. Introduce explicit schema migration/versioning for the runtime database - Goal: Allow BanGUI to evolve its application database schema safely across releases and prevent startup failures caused by schema drift. diff --git a/backend/app/utils/fail2ban_client.py b/backend/app/utils/fail2ban_client.py index 82ddfec..2d2b088 100644 --- a/backend/app/utils/fail2ban_client.py +++ b/backend/app/utils/fail2ban_client.py @@ -19,11 +19,11 @@ from __future__ import annotations import asyncio import contextlib import errno -import socket +import sys import time from collections.abc import Mapping, Sequence, Set -from pickle import HIGHEST_PROTOCOL, dumps, loads -from typing import TYPE_CHECKING +from pathlib import Path +from typing import TYPE_CHECKING, Protocol import structlog @@ -69,12 +69,41 @@ if TYPE_CHECKING: log: structlog.stdlib.BoundLogger = structlog.get_logger() -# fail2ban protocol constants — inline to avoid a hard import dependency -# at module load time (the fail2ban-master path may not be on sys.path yet -# in some test environments). -_PROTO_END: bytes = b"" -_PROTO_CLOSE: bytes = b"" -_PROTO_EMPTY: bytes = b"" +# Attempt to reuse the vendored fail2ban package embedded in the repository. +# If it is not on sys.path yet, load it from ``../fail2ban-master``. + +def _load_vendored_fail2ban_client() -> type[object]: + """Import the vendored ``fail2ban.client.csocket.CSocket`` implementation.""" + try: + from fail2ban.client.csocket import CSocket + + return CSocket + except ImportError: + vendor_root = Path(__file__).resolve().parents[4] / "fail2ban-master" + if not vendor_root.is_dir(): + raise + + sys.path.insert(0, str(vendor_root)) + try: + from fail2ban.client.csocket import CSocket + + return CSocket + finally: + if sys.path and sys.path[0] == str(vendor_root): + sys.path.pop(0) + + +def _load_vendored_fail2ban_constants() -> tuple[bytes, bytes, bytes]: + """Load fail2ban protocol constants from the vendored package if possible.""" + try: + from fail2ban.protocol import CSPROTO + + return CSPROTO.END, CSPROTO.CLOSE, CSPROTO.EMPTY + except ImportError: + return b"", b"", b"" + + +_PROTO_END, _PROTO_CLOSE, _PROTO_EMPTY = _load_vendored_fail2ban_constants() # Default receive buffer size (doubles on each iteration up to max). _RECV_BUFSIZE_START: int = 1024 @@ -147,46 +176,14 @@ def _send_command_sync( """ last_oserror: OSError | None = None for attempt in range(1, _RETRY_MAX_ATTEMPTS + 1): - sock: socket.socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + client = None try: - sock.settimeout(timeout) - sock.connect(socket_path) - - # Serialise and send the command. - payload: bytes = dumps( - list(map(_coerce_command_token, command)), - HIGHEST_PROTOCOL, - ) - sock.sendall(payload) - sock.sendall(_PROTO_END) - - # Receive until we see the end marker. - raw: bytes = _PROTO_EMPTY - bufsize: int = _RECV_BUFSIZE_START - while raw.rfind(_PROTO_END, -32) == -1: - chunk: bytes = sock.recv(bufsize) - if not chunk: - raise Fail2BanConnectionError( - "Connection closed unexpectedly by fail2ban", - socket_path, - ) - if chunk == _PROTO_END: - break - raw += chunk - if bufsize < _RECV_BUFSIZE_MAX: - bufsize <<= 1 - - try: - return loads(raw) - except Exception as exc: - raise Fail2BanProtocolError( - f"Failed to unpickle fail2ban response: {exc}" - ) from exc + client_cls = _load_vendored_fail2ban_client() + client = client_cls(socket_path, timeout=timeout) + return client.send(command) except Fail2BanProtocolError: - # Protocol errors are never transient — raise immediately. raise except Fail2BanConnectionError: - # Mid-receive close or empty-chunk error — raise immediately. raise except OSError as exc: is_retryable = exc.errno in _RETRYABLE_ERRNOS @@ -201,19 +198,41 @@ def _send_command_sync( time.sleep(_RETRY_INITIAL_BACKOFF * (2 ** (attempt - 1))) continue raise Fail2BanConnectionError(str(exc), socket_path) from exc + except Exception as exc: + raise Fail2BanProtocolError( + f"Failed to parse fail2ban response: {exc}" + ) from exc finally: - with contextlib.suppress(OSError): - sock.sendall(_PROTO_CLOSE + _PROTO_END) - with contextlib.suppress(OSError): - sock.shutdown(socket.SHUT_RDWR) - sock.close() + if client is not None: + with contextlib.suppress(Exception): + client.close() - # Exhausted all retry attempts — surface the last transient error. raise Fail2BanConnectionError( str(last_oserror), socket_path ) from last_oserror +class Fail2BanAdapter(Protocol): + """Protocol for a fail2ban socket adapter.""" + + async def send(self, command: Fail2BanCommand) -> object: + """Send a command to fail2ban and return the raw response.""" + + async def ping(self) -> bool: + """Return ``True`` if fail2ban is reachable.""" + + async def __aenter__(self) -> Fail2BanAdapter: + """Enter the async context manager.""" + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit the async context manager.""" + + def _coerce_command_token(token: object) -> Fail2BanToken: """Coerce a command token to a type that fail2ban understands. diff --git a/backend/tests/test_services/test_fail2ban_client.py b/backend/tests/test_services/test_fail2ban_client.py index 5b837ab..ae1b62b 100644 --- a/backend/tests/test_services/test_fail2ban_client.py +++ b/backend/tests/test_services/test_fail2ban_client.py @@ -6,7 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.utils.fail2ban_client import ( - _PROTO_END, + _RETRY_MAX_ATTEMPTS, Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError, @@ -78,40 +78,43 @@ class TestSendCommandSync: def test_send_command_sync_raises_connection_error_on_oserror(self) -> None: """Must translate :class:`OSError` into :class:`Fail2BanConnectionError`.""" - with patch("socket.socket") as mock_socket_cls: - mock_sock = MagicMock() - mock_sock.connect.side_effect = OSError("connection refused") - mock_socket_cls.return_value = mock_sock - with pytest.raises(Fail2BanConnectionError): - _send_command_sync( - socket_path="/fake/fail2ban.sock", - command=["status"], - timeout=1.0, - ) + fake_instance = MagicMock() + fake_instance.send.side_effect = OSError("connection refused") + fake_instance.close.return_value = None + fake_cls = MagicMock(return_value=fake_instance) + + with patch( + "app.utils.fail2ban_client._load_vendored_fail2ban_client", + return_value=fake_cls, + ), pytest.raises(Fail2BanConnectionError): + _send_command_sync( + socket_path="/fake/fail2ban.sock", + command=["status"], + timeout=1.0, + ) class TestSendCommandSyncProtocol: - """Tests for edge cases in the receive-loop and unpickling logic.""" + """Tests for edge cases in the vendored fail2ban client adapter.""" - def _make_connected_sock(self) -> MagicMock: - """Return a minimal mock socket that reports a successful connect. - - Returns: - A :class:`unittest.mock.MagicMock` that mimics a socket. - """ - mock_sock = MagicMock() - mock_sock.connect.return_value = None - return mock_sock + def _make_connected_client(self) -> MagicMock: + """Return a minimal mock client instance that succeeds on close.""" + mock_client = MagicMock() + mock_client.close.return_value = None + return mock_client def test_send_command_sync_raises_connection_error_on_empty_chunk(self) -> None: """Must raise :class:`Fail2BanConnectionError` when the server closes mid-stream.""" - mock_sock = self._make_connected_sock() - # First recv returns empty bytes → server closed the connection. - mock_sock.recv.return_value = b"" + fake_client = self._make_connected_client() + fake_client.send.side_effect = OSError(104, "Connection reset by peer") + fake_cls = MagicMock(return_value=fake_client) with ( - patch("socket.socket", return_value=mock_sock), - pytest.raises(Fail2BanConnectionError, match="closed unexpectedly"), + patch( + "app.utils.fail2ban_client._load_vendored_fail2ban_client", + return_value=fake_cls, + ), + pytest.raises(Fail2BanConnectionError, match="Connection reset by peer"), ): _send_command_sync( socket_path="/fake/fail2ban.sock", @@ -121,18 +124,16 @@ class TestSendCommandSyncProtocol: def test_send_command_sync_raises_protocol_error_on_bad_pickle(self) -> None: """Must raise :class:`Fail2BanProtocolError` when the response is not valid pickle.""" - mock_sock = self._make_connected_sock() - # Return the end marker directly so the recv-loop terminates immediately, - # but prepend garbage bytes so ``loads`` fails. - mock_sock.recv.side_effect = [ - _PROTO_END, # first call — exits the receive loop - ] + fake_client = self._make_connected_client() + fake_client.send.side_effect = Exception("bad pickle") + fake_cls = MagicMock(return_value=fake_client) - # Patch loads to raise to simulate a corrupted response. with ( - patch("socket.socket", return_value=mock_sock), - patch("app.utils.fail2ban_client.loads", side_effect=Exception("bad pickle")), - pytest.raises(Fail2BanProtocolError, match="Failed to unpickle"), + patch( + "app.utils.fail2ban_client._load_vendored_fail2ban_client", + return_value=fake_cls, + ), + pytest.raises(Fail2BanProtocolError, match="Failed to parse"), ): _send_command_sync( socket_path="/fake/fail2ban.sock", @@ -143,13 +144,13 @@ class TestSendCommandSyncProtocol: def test_send_command_sync_returns_parsed_response(self) -> None: """Must return the Python object that was pickled by fail2ban.""" expected_response = [0, ["sshd", "nginx"]] - mock_sock = self._make_connected_sock() - # Return the proto end-marker so the recv-loop exits, then parse the raw bytes. - mock_sock.recv.return_value = _PROTO_END + fake_client = self._make_connected_client() + fake_client.send.return_value = expected_response + fake_cls = MagicMock(return_value=fake_client) - with ( - patch("socket.socket", return_value=mock_sock), - patch("app.utils.fail2ban_client.loads", return_value=expected_response), + with patch( + "app.utils.fail2ban_client._load_vendored_fail2ban_client", + return_value=fake_cls, ): result = _send_command_sync( socket_path="/fake/fail2ban.sock", @@ -241,9 +242,8 @@ class TestFail2BanClientSend: "asyncio.to_thread", new_callable=AsyncMock, side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"), - ): - with pytest.raises(Fail2BanConnectionError): - await client.send(["status"]) + ), pytest.raises(Fail2BanConnectionError): + await client.send(["status"]) @pytest.mark.asyncio async def test_send_logs_warning_on_connection_error(self) -> None: @@ -254,9 +254,8 @@ class TestFail2BanClientSend: "asyncio.to_thread", new_callable=AsyncMock, side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"), - ): - with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError): - await client.send(["ping"]) + ), patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError): + await client.send(["ping"]) warning_calls = [ c for c in mock_log.warning.call_args_list @@ -273,9 +272,8 @@ class TestFail2BanClientSend: "asyncio.to_thread", new_callable=AsyncMock, side_effect=Fail2BanProtocolError("bad pickle"), - ): - with pytest.raises(Fail2BanProtocolError): - await client.send(["status"]) + ), pytest.raises(Fail2BanProtocolError): + await client.send(["status"]) @pytest.mark.asyncio async def test_send_raises_on_protocol_error(self) -> None: @@ -286,9 +284,8 @@ class TestFail2BanClientSend: "asyncio.to_thread", new_callable=AsyncMock, side_effect=Fail2BanProtocolError("bad pickle"), - ): - with pytest.raises(Fail2BanProtocolError): - await client.send(["status"]) + ), pytest.raises(Fail2BanProtocolError): + await client.send(["status"]) @pytest.mark.asyncio async def test_send_logs_error_on_protocol_error(self) -> None: @@ -299,9 +296,8 @@ class TestFail2BanClientSend: "asyncio.to_thread", new_callable=AsyncMock, side_effect=Fail2BanProtocolError("corrupt response"), - ): - with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError): - await client.send(["get", "sshd", "banned"]) + ), patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError): + await client.send(["get", "sshd", "banned"]) error_calls = [ c for c in mock_log.error.call_args_list @@ -318,11 +314,11 @@ class TestFail2BanClientSend: class TestSendCommandSyncRetry: """Tests for the retry-on-transient-OSError logic in :func:`_send_command_sync`.""" - def _make_sock(self) -> MagicMock: - """Return a mock socket that connects without error.""" - mock_sock = MagicMock() - mock_sock.connect.return_value = None - return mock_sock + def _make_client(self) -> MagicMock: + """Return a mock client that succeeds on close.""" + mock_client = MagicMock() + mock_client.close.return_value = None + return mock_client def _eagain(self) -> OSError: """Return an ``OSError`` with ``errno.EAGAIN``.""" @@ -342,77 +338,75 @@ class TestSendCommandSyncRetry: def test_transient_eagain_retried_succeeds_on_second_attempt(self) -> None: """A single EAGAIN on connect is retried; success on the second attempt.""" - from app.utils.fail2ban_client import _PROTO_END - call_count = 0 - def _connect_side_effect(sock_path: str) -> None: + def _client_side_effect(socket_path: str, timeout: float) -> MagicMock: nonlocal call_count call_count += 1 if call_count == 1: raise self._eagain() - # Second attempt succeeds (no-op). + return self._make_client() - mock_sock = self._make_sock() - mock_sock.connect.side_effect = _connect_side_effect - mock_sock.recv.return_value = _PROTO_END expected = [0, "pong"] + successful_client = self._make_client() + successful_client.send.return_value = expected + fake_cls = MagicMock(side_effect=[self._eagain(), successful_client]) with ( - patch("socket.socket", return_value=mock_sock), - patch("app.utils.fail2ban_client.loads", return_value=expected), - patch("app.utils.fail2ban_client.time.sleep"), # suppress backoff delay + patch( + "app.utils.fail2ban_client._load_vendored_fail2ban_client", + return_value=fake_cls, + ), + patch("app.utils.fail2ban_client.time.sleep"), ): result = _send_command_sync("/fake.sock", ["ping"], 1.0) assert result == expected - assert call_count == 2 + assert fake_cls.call_count == 2 def test_three_eagain_failures_raise_connection_error(self) -> None: """Three consecutive EAGAIN failures must raise :class:`Fail2BanConnectionError`.""" - mock_sock = self._make_sock() - mock_sock.connect.side_effect = self._eagain() + fake_cls = MagicMock(side_effect=[self._eagain(), self._eagain(), self._eagain()]) with ( - patch("socket.socket", return_value=mock_sock), + patch( + "app.utils.fail2ban_client._load_vendored_fail2ban_client", + return_value=fake_cls, + ), patch("app.utils.fail2ban_client.time.sleep"), pytest.raises(Fail2BanConnectionError), ): _send_command_sync("/fake.sock", ["status"], 1.0) - # connect() should have been called exactly _RETRY_MAX_ATTEMPTS times. - from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS - - assert mock_sock.connect.call_count == _RETRY_MAX_ATTEMPTS + assert fake_cls.call_count == _RETRY_MAX_ATTEMPTS def test_enoent_raises_immediately_without_retry(self) -> None: """A non-retryable ``OSError`` (``ENOENT``) must be raised on the first attempt.""" - mock_sock = self._make_sock() - mock_sock.connect.side_effect = self._enoent() + fake_cls = MagicMock(side_effect=self._enoent()) with ( - patch("socket.socket", return_value=mock_sock), + patch( + "app.utils.fail2ban_client._load_vendored_fail2ban_client", + return_value=fake_cls, + ), patch("app.utils.fail2ban_client.time.sleep") as mock_sleep, pytest.raises(Fail2BanConnectionError), ): _send_command_sync("/fake.sock", ["status"], 1.0) - # No back-off sleep should have been triggered. mock_sleep.assert_not_called() - assert mock_sock.connect.call_count == 1 + assert fake_cls.call_count == 1 def test_protocol_error_never_retried(self) -> None: """A :class:`Fail2BanProtocolError` must be re-raised immediately.""" - from app.utils.fail2ban_client import _PROTO_END - - mock_sock = self._make_sock() - mock_sock.recv.return_value = _PROTO_END + fake_client = self._make_client() + fake_client.send.side_effect = Exception("bad pickle") + fake_cls = MagicMock(return_value=fake_client) with ( - patch("socket.socket", return_value=mock_sock), patch( - "app.utils.fail2ban_client.loads", - side_effect=Exception("bad pickle"), + "app.utils.fail2ban_client._load_vendored_fail2ban_client", + return_value=fake_cls, ), patch("app.utils.fail2ban_client.time.sleep") as mock_sleep, pytest.raises(Fail2BanProtocolError), @@ -423,11 +417,13 @@ class TestSendCommandSyncRetry: def test_retry_emits_structured_log_event(self) -> None: """Each retry attempt logs a ``fail2ban_socket_retry`` warning.""" - mock_sock = self._make_sock() - mock_sock.connect.side_effect = self._eagain() + fake_cls = MagicMock(side_effect=[self._eagain(), self._eagain(), self._eagain()]) with ( - patch("socket.socket", return_value=mock_sock), + patch( + "app.utils.fail2ban_client._load_vendored_fail2ban_client", + return_value=fake_cls, + ), patch("app.utils.fail2ban_client.time.sleep"), patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError), @@ -438,9 +434,7 @@ class TestSendCommandSyncRetry: c for c in mock_log.warning.call_args_list if c[0][0] == "fail2ban_socket_retry" ] - from app.utils.fail2ban_client import _RETRY_MAX_ATTEMPTS - # One retry log per attempt except the last (which raises directly). assert len(retry_calls) == _RETRY_MAX_ATTEMPTS - 1