Mark async socket handling task done and implement startup cleanup
This commit is contained in:
@@ -226,15 +226,10 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must return the response from the executor."""
|
||||
expected = [0, "OK"]
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
# asyncio.get_event_loop().run_in_executor is called inside send().
|
||||
# We patch it on the loop object returned by asyncio.get_event_loop().
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(return_value=expected)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch("asyncio.to_thread", new_callable=AsyncMock, return_value=expected) as mock_to_thread:
|
||||
result = await client.send(["status"])
|
||||
|
||||
mock_to_thread.assert_awaited_once()
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -242,13 +237,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must re-raise :class:`Fail2BanConnectionError`."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"),
|
||||
):
|
||||
with pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["status"])
|
||||
|
||||
@@ -257,13 +250,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must log a warning when a connection error occurs."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"),
|
||||
):
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["ping"])
|
||||
|
||||
@@ -278,13 +269,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must re-raise :class:`Fail2BanProtocolError`."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("bad pickle")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanProtocolError("bad pickle"),
|
||||
):
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@@ -293,13 +282,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must propagate :class:`Fail2BanProtocolError` to the caller."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("bad pickle")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanProtocolError("bad pickle"),
|
||||
):
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@@ -308,13 +295,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must log an error when a protocol error occurs."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("corrupt response")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanProtocolError("corrupt response"),
|
||||
):
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["get", "sshd", "banned"])
|
||||
|
||||
@@ -492,11 +477,7 @@ class TestFail2BanClientSemaphore:
|
||||
in_flight.pop()
|
||||
return (0, "ok")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_loop_getter:
|
||||
mock_loop = MagicMock()
|
||||
mock_loop.run_in_executor = _fast_executor
|
||||
mock_loop_getter.return_value = mock_loop
|
||||
|
||||
with patch("asyncio.to_thread", new=_fast_executor):
|
||||
tasks = [
|
||||
_asyncio.create_task(client.send(["ping"])) for _ in range(10)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user