Harden fail2ban integration and mark task complete
This commit is contained in:
@@ -7,14 +7,8 @@ infrastructure.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the bundled fail2ban package is importable.
|
||||
_FAIL2BAN_MASTER: Path = Path(__file__).resolve().parents[2] / "fail2ban-master"
|
||||
if str(_FAIL2BAN_MASTER) not in sys.path:
|
||||
sys.path.insert(0, str(_FAIL2BAN_MASTER))
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
@@ -468,36 +468,23 @@ class TestFail2BanClientSemaphore:
|
||||
"""Tests for the concurrency semaphore in :meth:`Fail2BanClient.send`."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_semaphore_limits_concurrency(self) -> None:
|
||||
"""No more than _COMMAND_SEMAPHORE_CONCURRENCY commands overlap."""
|
||||
async def test_semaphore_limits_concurrency_per_instance(self) -> None:
|
||||
"""Each client instance must enforce its own concurrency cap."""
|
||||
import asyncio as _asyncio
|
||||
|
||||
import app.utils.fail2ban_client as _module
|
||||
|
||||
# Reset module-level semaphore so this test starts fresh.
|
||||
_module._command_semaphore = None
|
||||
from app.utils import fail2ban_client as _module
|
||||
|
||||
concurrency_limit = 3
|
||||
_module._COMMAND_SEMAPHORE_CONCURRENCY = concurrency_limit
|
||||
_module._command_semaphore = _asyncio.Semaphore(concurrency_limit)
|
||||
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
client2 = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
assert client._command_semaphore is not client2._command_semaphore
|
||||
|
||||
in_flight: list[int] = []
|
||||
peak_concurrent: list[int] = []
|
||||
|
||||
async def _slow_send(command: list[Any]) -> Any:
|
||||
in_flight.append(1)
|
||||
peak_concurrent.append(len(in_flight))
|
||||
await _asyncio.sleep(0) # yield to allow other coroutines to run
|
||||
in_flight.pop()
|
||||
return (0, "ok")
|
||||
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
with patch.object(client, "send", wraps=_slow_send) as _patched:
|
||||
# Bypass the semaphore wrapper — test the actual send directly.
|
||||
pass
|
||||
|
||||
# Override _command_semaphore and run concurrently via the real send path
|
||||
# but mock _send_command_sync to avoid actual socket I/O.
|
||||
async def _fast_executor(_fn: Any, *_args: Any) -> Any:
|
||||
in_flight.append(1)
|
||||
peak_concurrent.append(len(in_flight))
|
||||
@@ -505,20 +492,17 @@ class TestFail2BanClientSemaphore:
|
||||
in_flight.pop()
|
||||
return (0, "ok")
|
||||
|
||||
client2 = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
with patch("asyncio.get_event_loop") as mock_loop_getter:
|
||||
mock_loop = MagicMock()
|
||||
mock_loop.run_in_executor = _fast_executor
|
||||
mock_loop_getter.return_value = mock_loop
|
||||
|
||||
tasks = [
|
||||
_asyncio.create_task(client2.send(["ping"])) for _ in range(10)
|
||||
_asyncio.create_task(client.send(["ping"])) for _ in range(10)
|
||||
]
|
||||
await _asyncio.gather(*tasks)
|
||||
|
||||
# Peak concurrent activity must never exceed the semaphore limit.
|
||||
assert max(peak_concurrent) <= concurrency_limit
|
||||
|
||||
# Restore module defaults after test.
|
||||
_module._COMMAND_SEMAPHORE_CONCURRENCY = 10
|
||||
_module._command_semaphore = None
|
||||
|
||||
Reference in New Issue
Block a user