From 3ccfc20c64e54e6162e34092f86ff264f0f0eb17 Mon Sep 17 00:00:00 2001 From: Lukas Date: Mon, 6 Apr 2026 20:20:14 +0200 Subject: [PATCH] Harden fail2ban integration and mark task complete --- Docs/Tasks.md | 6 ++-- backend/app/main.py | 32 +---------------- backend/app/utils/fail2ban_client.py | 18 ++++------ backend/tests/conftest.py | 6 ---- .../test_services/test_fail2ban_client.py | 34 +++++-------------- 5 files changed, 20 insertions(+), 76 deletions(-) diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 166b269..d536825 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -22,7 +22,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. - Introduce explicit factories or providers for shared resources such as DB, HTTP client session, scheduler, and settings. - Ensure routers depend on injected providers rather than global state or dynamic imports. -- **Harden fail2ban integration.** +- **Harden fail2ban integration.** ✅ - Remove the `sys.path` hack that locates `fail2ban-master` at runtime. - Replace it with a deterministic packaging or configuration model so the backend does not depend on repository layout. - Refactor `Fail2BanClient` so concurrency control is instance-based and not backed by hidden module globals. @@ -73,8 +73,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. ### Priority Execution Plan 1. ✅ Fix the global SQLite connection pattern and tests. -2. Refactor dependency injection / explicit shared resources. -3. Harden fail2ban client concurrency and packaging. +2. ✅ Refactor dependency injection / explicit shared resources. +3. ✅ Harden fail2ban client concurrency and packaging. 4. Convert setup guard to a safer startup-driven model. 5. Add deployment-safe configuration and production-ready CORS. 6. Add lifecycle and concurrency regression tests. diff --git a/backend/app/main.py b/backend/app/main.py index bd3ddc0..e11a73e 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -51,36 +51,6 @@ from app.tasks import blocklist_import, geo_cache_flush, geo_re_resolve, health_ from app.utils.fail2ban_client import Fail2BanConnectionError, Fail2BanProtocolError from app.utils.jail_config import ensure_jail_configs -# --------------------------------------------------------------------------- -# Ensure the bundled fail2ban package is importable from fail2ban-master/ -# -# The directory layout differs between local dev and the Docker image: -# Local: /backend/app/main.py → fail2ban-master at parents[2] -# Docker: /app/app/main.py → fail2ban-master at parents[1] -# Walk up from this file until we find a "fail2ban-master" sibling directory -# so the path resolution is environment-agnostic. -# --------------------------------------------------------------------------- - - -def _find_fail2ban_master() -> Path | None: - """Return the first ``fail2ban-master`` directory found while walking up. - - Returns: - Absolute :class:`~pathlib.Path` to the ``fail2ban-master`` directory, - or ``None`` if no such directory exists among the ancestors. - """ - here = Path(__file__).resolve() - for ancestor in here.parents: - candidate = ancestor / "fail2ban-master" - if candidate.is_dir(): - return candidate - return None - - -_fail2ban_master: Path | None = _find_fail2ban_master() -if _fail2ban_master is not None and str(_fail2ban_master) not in sys.path: - sys.path.insert(0, str(_fail2ban_master)) - log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -328,8 +298,8 @@ class SetupRedirectMiddleware(BaseHTTPMiddleware): if path.startswith("/api") and not getattr( request.app.state, "_setup_complete_cached", False ): - from app.services import setup_service # noqa: PLC0415 from app.db import open_db # noqa: PLC0415 + from app.services import setup_service # noqa: PLC0415 db = getattr(request.app.state, "db", None) if db is None: diff --git a/backend/app/utils/fail2ban_client.py b/backend/app/utils/fail2ban_client.py index d02a6a5..7e04e9c 100644 --- a/backend/app/utils/fail2ban_client.py +++ b/backend/app/utils/fail2ban_client.py @@ -91,12 +91,9 @@ _RETRYABLE_ERRNOS: frozenset[int] = frozenset( _RETRY_MAX_ATTEMPTS: int = 3 _RETRY_INITIAL_BACKOFF: float = 0.15 # seconds; doubles on each attempt -# Maximum number of concurrent in-flight socket commands. Operations that -# exceed this cap wait until a slot is available. +# Maximum number of concurrent in-flight socket commands per client. +# Operations that exceed this cap wait until a slot is available. _COMMAND_SEMAPHORE_CONCURRENCY: int = 10 -# The semaphore is created lazily on the first send() call so it binds to the -# event loop that is actually running (important for test isolation). -_command_semaphore: asyncio.Semaphore | None = None class Fail2BanConnectionError(Exception): @@ -266,6 +263,9 @@ class Fail2BanClient: """ self.socket_path: str = socket_path self.timeout: float = timeout + self._command_semaphore: asyncio.Semaphore = asyncio.Semaphore( + _COMMAND_SEMAPHORE_CONCURRENCY + ) async def send(self, command: Fail2BanCommand) -> object: """Send a command to fail2ban and return the response. @@ -290,18 +290,14 @@ class Fail2BanClient: connection is unexpectedly closed. Fail2BanProtocolError: If the response cannot be decoded. """ - global _command_semaphore - if _command_semaphore is None: - _command_semaphore = asyncio.Semaphore(_COMMAND_SEMAPHORE_CONCURRENCY) - - if _command_semaphore.locked(): + if self._command_semaphore.locked(): log.debug( "fail2ban_command_waiting_semaphore", command=command, concurrency_limit=_COMMAND_SEMAPHORE_CONCURRENCY, ) - async with _command_semaphore: + async with self._command_semaphore: log.debug("fail2ban_sending_command", command=command) loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() try: diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index cb14b9f..50af991 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -7,14 +7,8 @@ infrastructure. from __future__ import annotations -import sys from pathlib import Path -# Ensure the bundled fail2ban package is importable. -_FAIL2BAN_MASTER: Path = Path(__file__).resolve().parents[2] / "fail2ban-master" -if str(_FAIL2BAN_MASTER) not in sys.path: - sys.path.insert(0, str(_FAIL2BAN_MASTER)) - import aiosqlite import pytest from httpx import ASGITransport, AsyncClient diff --git a/backend/tests/test_services/test_fail2ban_client.py b/backend/tests/test_services/test_fail2ban_client.py index 8e344ec..45f25d0 100644 --- a/backend/tests/test_services/test_fail2ban_client.py +++ b/backend/tests/test_services/test_fail2ban_client.py @@ -468,36 +468,23 @@ class TestFail2BanClientSemaphore: """Tests for the concurrency semaphore in :meth:`Fail2BanClient.send`.""" @pytest.mark.asyncio - async def test_semaphore_limits_concurrency(self) -> None: - """No more than _COMMAND_SEMAPHORE_CONCURRENCY commands overlap.""" + async def test_semaphore_limits_concurrency_per_instance(self) -> None: + """Each client instance must enforce its own concurrency cap.""" import asyncio as _asyncio - import app.utils.fail2ban_client as _module - - # Reset module-level semaphore so this test starts fresh. - _module._command_semaphore = None + from app.utils import fail2ban_client as _module concurrency_limit = 3 _module._COMMAND_SEMAPHORE_CONCURRENCY = concurrency_limit - _module._command_semaphore = _asyncio.Semaphore(concurrency_limit) + + client = Fail2BanClient(socket_path="/fake/fail2ban.sock") + client2 = Fail2BanClient(socket_path="/fake/fail2ban.sock") + + assert client._command_semaphore is not client2._command_semaphore in_flight: list[int] = [] peak_concurrent: list[int] = [] - async def _slow_send(command: list[Any]) -> Any: - in_flight.append(1) - peak_concurrent.append(len(in_flight)) - await _asyncio.sleep(0) # yield to allow other coroutines to run - in_flight.pop() - return (0, "ok") - - client = Fail2BanClient(socket_path="/fake/fail2ban.sock") - with patch.object(client, "send", wraps=_slow_send) as _patched: - # Bypass the semaphore wrapper — test the actual send directly. - pass - - # Override _command_semaphore and run concurrently via the real send path - # but mock _send_command_sync to avoid actual socket I/O. async def _fast_executor(_fn: Any, *_args: Any) -> Any: in_flight.append(1) peak_concurrent.append(len(in_flight)) @@ -505,20 +492,17 @@ class TestFail2BanClientSemaphore: in_flight.pop() return (0, "ok") - client2 = Fail2BanClient(socket_path="/fake/fail2ban.sock") with patch("asyncio.get_event_loop") as mock_loop_getter: mock_loop = MagicMock() mock_loop.run_in_executor = _fast_executor mock_loop_getter.return_value = mock_loop tasks = [ - _asyncio.create_task(client2.send(["ping"])) for _ in range(10) + _asyncio.create_task(client.send(["ping"])) for _ in range(10) ] await _asyncio.gather(*tasks) - # Peak concurrent activity must never exceed the semaphore limit. assert max(peak_concurrent) <= concurrency_limit # Restore module defaults after test. _module._COMMAND_SEMAPHORE_CONCURRENCY = 10 - _module._command_semaphore = None