Mark async socket handling task done and implement startup cleanup

This commit is contained in:
2026-04-09 22:13:22 +02:00
parent 148756fb79
commit 6b177f1881
5 changed files with 99 additions and 66 deletions

View File

@@ -4,6 +4,7 @@ import asyncio
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import aiosqlite
from httpx import ASGITransport, AsyncClient
@@ -123,6 +124,49 @@ async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Pat
mock_scheduler.shutdown.assert_called_once_with(wait=False)
async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None:
"""The lifespan must close resources if shared startup registration fails."""
settings = Settings(
database_path=str(tmp_path / "bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-lifespan-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
mock_scheduler = MagicMock()
mock_scheduler.start = MagicMock()
mock_scheduler.shutdown = MagicMock()
mock_http_session = MagicMock()
mock_http_session.close = AsyncMock()
with (
patch("app.startup.ensure_jail_configs"),
patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
patch("app.startup.init_db", new=AsyncMock()),
patch("app.services.geo_service.init_geoip"),
patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
patch("app.tasks.health_check.register", side_effect=RuntimeError("startup failed")),
patch("app.tasks.blocklist_import.register"),
patch("app.tasks.geo_cache_flush.register"),
patch("app.tasks.geo_re_resolve.register"),
patch("app.tasks.history_sync.register"),
):
with pytest.raises(RuntimeError, match="startup failed"):
async with _lifespan(app):
pass
mock_http_session.close.assert_awaited_once()
mock_scheduler.shutdown.assert_called_once_with(wait=False)
async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None:
"""The shared HTTP client session is created with the configured limits."""
settings = Settings(

View File

@@ -226,15 +226,10 @@ class TestFail2BanClientSend:
"""``send()`` must return the response from the executor."""
expected = [0, "OK"]
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
# asyncio.get_event_loop().run_in_executor is called inside send().
# We patch it on the loop object returned by asyncio.get_event_loop().
with patch("asyncio.get_event_loop") as mock_get_loop:
mock_loop = AsyncMock()
mock_loop.run_in_executor = AsyncMock(return_value=expected)
mock_get_loop.return_value = mock_loop
with patch("asyncio.to_thread", new_callable=AsyncMock, return_value=expected) as mock_to_thread:
result = await client.send(["status"])
mock_to_thread.assert_awaited_once()
assert result == expected
@pytest.mark.asyncio
@@ -242,13 +237,11 @@ class TestFail2BanClientSend:
"""``send()`` must re-raise :class:`Fail2BanConnectionError`."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop:
mock_loop = AsyncMock()
mock_loop.run_in_executor = AsyncMock(
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock")
)
mock_get_loop.return_value = mock_loop
with patch(
"asyncio.to_thread",
new_callable=AsyncMock,
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"),
):
with pytest.raises(Fail2BanConnectionError):
await client.send(["status"])
@@ -257,13 +250,11 @@ class TestFail2BanClientSend:
"""``send()`` must log a warning when a connection error occurs."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop:
mock_loop = AsyncMock()
mock_loop.run_in_executor = AsyncMock(
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock")
)
mock_get_loop.return_value = mock_loop
with patch(
"asyncio.to_thread",
new_callable=AsyncMock,
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"),
):
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
await client.send(["ping"])
@@ -278,13 +269,11 @@ class TestFail2BanClientSend:
"""``send()`` must re-raise :class:`Fail2BanProtocolError`."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop:
mock_loop = AsyncMock()
mock_loop.run_in_executor = AsyncMock(
side_effect=Fail2BanProtocolError("bad pickle")
)
mock_get_loop.return_value = mock_loop
with patch(
"asyncio.to_thread",
new_callable=AsyncMock,
side_effect=Fail2BanProtocolError("bad pickle"),
):
with pytest.raises(Fail2BanProtocolError):
await client.send(["status"])
@@ -293,13 +282,11 @@ class TestFail2BanClientSend:
"""``send()`` must propagate :class:`Fail2BanProtocolError` to the caller."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop:
mock_loop = AsyncMock()
mock_loop.run_in_executor = AsyncMock(
side_effect=Fail2BanProtocolError("bad pickle")
)
mock_get_loop.return_value = mock_loop
with patch(
"asyncio.to_thread",
new_callable=AsyncMock,
side_effect=Fail2BanProtocolError("bad pickle"),
):
with pytest.raises(Fail2BanProtocolError):
await client.send(["status"])
@@ -308,13 +295,11 @@ class TestFail2BanClientSend:
"""``send()`` must log an error when a protocol error occurs."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop:
mock_loop = AsyncMock()
mock_loop.run_in_executor = AsyncMock(
side_effect=Fail2BanProtocolError("corrupt response")
)
mock_get_loop.return_value = mock_loop
with patch(
"asyncio.to_thread",
new_callable=AsyncMock,
side_effect=Fail2BanProtocolError("corrupt response"),
):
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
await client.send(["get", "sshd", "banned"])
@@ -492,11 +477,7 @@ class TestFail2BanClientSemaphore:
in_flight.pop()
return (0, "ok")
with patch("asyncio.get_event_loop") as mock_loop_getter:
mock_loop = MagicMock()
mock_loop.run_in_executor = _fast_executor
mock_loop_getter.return_value = mock_loop
with patch("asyncio.to_thread", new=_fast_executor):
tasks = [
_asyncio.create_task(client.send(["ping"])) for _ in range(10)
]