Mark async socket handling task done and implement startup cleanup

This commit is contained in:
2026-04-09 22:13:22 +02:00
parent 148756fb79
commit 6b177f1881
5 changed files with 99 additions and 66 deletions

View File

@@ -56,6 +56,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
### 6. Update async socket handling ### 6. Update async socket handling
- Where found: `backend/app/utils/fail2ban_client.py`, `backend/app/startup.py` - Where found: `backend/app/utils/fail2ban_client.py`, `backend/app/startup.py`
- Goal: use modern asyncio APIs (`get_running_loop()`), avoid blocking operations on the event loop, and ensure startup resources are cleaned up if initialization fails. - Goal: use modern asyncio APIs (`get_running_loop()`), avoid blocking operations on the event loop, and ensure startup resources are cleaned up if initialization fails.
- Status: completed — switched fail2ban socket I/O to `asyncio.to_thread` and added startup cleanup for failed resource initialization.
- Possible traps and issues: - Possible traps and issues:
- `asyncio.get_event_loop()` behavior changed in newer Python versions; this can cause runtime warnings or errors. - `asyncio.get_event_loop()` behavior changed in newer Python versions; this can cause runtime warnings or errors.
- Resource leaks can occur if `startup_shared_resources()` fails before the lifespan shutdown path is reached. - Resource leaks can occur if `startup_shared_resources()` fails before the lifespan shutdown path is reached.

View File

@@ -7,6 +7,7 @@ in ``app.main`` delegates resource creation and task registration here.
from __future__ import annotations from __future__ import annotations
from contextlib import suppress
from pathlib import Path from pathlib import Path
import aiohttp import aiohttp
@@ -101,13 +102,22 @@ async def startup_shared_resources(
http_session: aiohttp.ClientSession = _create_http_session(settings) http_session: aiohttp.ClientSession = _create_http_session(settings)
geo_service.init_geoip(settings.geoip_db_path) geo_service.init_geoip(settings.geoip_db_path)
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC") scheduler: AsyncIOScheduler | None = None
scheduler.start() try:
scheduler = AsyncIOScheduler(timezone="UTC")
scheduler.start()
health_check.register(app) health_check.register(app)
blocklist_import.register(app) blocklist_import.register(app)
geo_cache_flush.register(app) geo_cache_flush.register(app)
geo_re_resolve.register(app) geo_re_resolve.register(app)
history_sync.register(app) history_sync.register(app)
return http_session, scheduler return http_session, scheduler
except Exception:
with suppress(Exception):
await http_session.close()
if scheduler is not None:
with suppress(Exception):
scheduler.shutdown(wait=False)
raise

View File

@@ -121,9 +121,8 @@ def _send_command_sync(
) -> object: ) -> object:
"""Send a command to fail2ban and return the parsed response. """Send a command to fail2ban and return the parsed response.
This is a **synchronous** function intended to be called from within This is a **synchronous** function intended to be executed via
:func:`asyncio.get_event_loop().run_in_executor` so that the event loop :func:`asyncio.to_thread` so that the event loop is not blocked.
is not blocked.
Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``, Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``,
``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with ``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with
@@ -299,15 +298,13 @@ class Fail2BanClient:
async with self._command_semaphore: async with self._command_semaphore:
log.debug("fail2ban_sending_command", command=command) log.debug("fail2ban_sending_command", command=command)
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
try: try:
response: object = await loop.run_in_executor( response: object = await asyncio.to_thread(
None, _send_command_sync,
_send_command_sync, self.socket_path,
self.socket_path, command,
command, self.timeout,
self.timeout, )
)
except Fail2BanConnectionError: except Fail2BanConnectionError:
log.warning( log.warning(
"fail2ban_connection_error", "fail2ban_connection_error",

View File

@@ -4,6 +4,7 @@ import asyncio
from pathlib import Path from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import aiosqlite import aiosqlite
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
@@ -123,6 +124,49 @@ async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Pat
mock_scheduler.shutdown.assert_called_once_with(wait=False) mock_scheduler.shutdown.assert_called_once_with(wait=False)
async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None:
"""The lifespan must close resources if shared startup registration fails."""
settings = Settings(
database_path=str(tmp_path / "bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-lifespan-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
mock_scheduler = MagicMock()
mock_scheduler.start = MagicMock()
mock_scheduler.shutdown = MagicMock()
mock_http_session = MagicMock()
mock_http_session.close = AsyncMock()
with (
patch("app.startup.ensure_jail_configs"),
patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
patch("app.startup.init_db", new=AsyncMock()),
patch("app.services.geo_service.init_geoip"),
patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
patch("app.tasks.health_check.register", side_effect=RuntimeError("startup failed")),
patch("app.tasks.blocklist_import.register"),
patch("app.tasks.geo_cache_flush.register"),
patch("app.tasks.geo_re_resolve.register"),
patch("app.tasks.history_sync.register"),
):
with pytest.raises(RuntimeError, match="startup failed"):
async with _lifespan(app):
pass
mock_http_session.close.assert_awaited_once()
mock_scheduler.shutdown.assert_called_once_with(wait=False)
async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None: async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None:
"""The shared HTTP client session is created with the configured limits.""" """The shared HTTP client session is created with the configured limits."""
settings = Settings( settings = Settings(

View File

@@ -226,15 +226,10 @@ class TestFail2BanClientSend:
"""``send()`` must return the response from the executor.""" """``send()`` must return the response from the executor."""
expected = [0, "OK"] expected = [0, "OK"]
client = Fail2BanClient(socket_path="/fake/fail2ban.sock") client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
# asyncio.get_event_loop().run_in_executor is called inside send(). with patch("asyncio.to_thread", new_callable=AsyncMock, return_value=expected) as mock_to_thread:
# We patch it on the loop object returned by asyncio.get_event_loop().
with patch("asyncio.get_event_loop") as mock_get_loop:
mock_loop = AsyncMock()
mock_loop.run_in_executor = AsyncMock(return_value=expected)
mock_get_loop.return_value = mock_loop
result = await client.send(["status"]) result = await client.send(["status"])
mock_to_thread.assert_awaited_once()
assert result == expected assert result == expected
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -242,13 +237,11 @@ class TestFail2BanClientSend:
"""``send()`` must re-raise :class:`Fail2BanConnectionError`.""" """``send()`` must re-raise :class:`Fail2BanConnectionError`."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock") client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop: with patch(
mock_loop = AsyncMock() "asyncio.to_thread",
mock_loop.run_in_executor = AsyncMock( new_callable=AsyncMock,
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock") side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"),
) ):
mock_get_loop.return_value = mock_loop
with pytest.raises(Fail2BanConnectionError): with pytest.raises(Fail2BanConnectionError):
await client.send(["status"]) await client.send(["status"])
@@ -257,13 +250,11 @@ class TestFail2BanClientSend:
"""``send()`` must log a warning when a connection error occurs.""" """``send()`` must log a warning when a connection error occurs."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock") client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop: with patch(
mock_loop = AsyncMock() "asyncio.to_thread",
mock_loop.run_in_executor = AsyncMock( new_callable=AsyncMock,
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock") side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"),
) ):
mock_get_loop.return_value = mock_loop
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError): with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
await client.send(["ping"]) await client.send(["ping"])
@@ -278,13 +269,11 @@ class TestFail2BanClientSend:
"""``send()`` must re-raise :class:`Fail2BanProtocolError`.""" """``send()`` must re-raise :class:`Fail2BanProtocolError`."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock") client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop: with patch(
mock_loop = AsyncMock() "asyncio.to_thread",
mock_loop.run_in_executor = AsyncMock( new_callable=AsyncMock,
side_effect=Fail2BanProtocolError("bad pickle") side_effect=Fail2BanProtocolError("bad pickle"),
) ):
mock_get_loop.return_value = mock_loop
with pytest.raises(Fail2BanProtocolError): with pytest.raises(Fail2BanProtocolError):
await client.send(["status"]) await client.send(["status"])
@@ -293,13 +282,11 @@ class TestFail2BanClientSend:
"""``send()`` must propagate :class:`Fail2BanProtocolError` to the caller.""" """``send()`` must propagate :class:`Fail2BanProtocolError` to the caller."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock") client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop: with patch(
mock_loop = AsyncMock() "asyncio.to_thread",
mock_loop.run_in_executor = AsyncMock( new_callable=AsyncMock,
side_effect=Fail2BanProtocolError("bad pickle") side_effect=Fail2BanProtocolError("bad pickle"),
) ):
mock_get_loop.return_value = mock_loop
with pytest.raises(Fail2BanProtocolError): with pytest.raises(Fail2BanProtocolError):
await client.send(["status"]) await client.send(["status"])
@@ -308,13 +295,11 @@ class TestFail2BanClientSend:
"""``send()`` must log an error when a protocol error occurs.""" """``send()`` must log an error when a protocol error occurs."""
client = Fail2BanClient(socket_path="/fake/fail2ban.sock") client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
with patch("asyncio.get_event_loop") as mock_get_loop: with patch(
mock_loop = AsyncMock() "asyncio.to_thread",
mock_loop.run_in_executor = AsyncMock( new_callable=AsyncMock,
side_effect=Fail2BanProtocolError("corrupt response") side_effect=Fail2BanProtocolError("corrupt response"),
) ):
mock_get_loop.return_value = mock_loop
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError): with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
await client.send(["get", "sshd", "banned"]) await client.send(["get", "sshd", "banned"])
@@ -492,11 +477,7 @@ class TestFail2BanClientSemaphore:
in_flight.pop() in_flight.pop()
return (0, "ok") return (0, "ok")
with patch("asyncio.get_event_loop") as mock_loop_getter: with patch("asyncio.to_thread", new=_fast_executor):
mock_loop = MagicMock()
mock_loop.run_in_executor = _fast_executor
mock_loop_getter.return_value = mock_loop
tasks = [ tasks = [
_asyncio.create_task(client.send(["ping"])) for _ in range(10) _asyncio.create_task(client.send(["ping"])) for _ in range(10)
] ]