From 6b177f1881ec0da092896f7c9b7df8c3f7398270 Mon Sep 17 00:00:00 2001 From: Lukas Date: Thu, 9 Apr 2026 22:13:22 +0200 Subject: [PATCH] Mark async socket handling task done and implement startup cleanup --- Docs/Tasks.md | 1 + backend/app/startup.py | 26 +++++-- backend/app/utils/fail2ban_client.py | 19 ++--- backend/tests/test_main.py | 44 +++++++++++ .../test_services/test_fail2ban_client.py | 75 +++++++------------ 5 files changed, 99 insertions(+), 66 deletions(-) diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 11bd672..8fcb370 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -56,6 +56,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. ### 6. Update async socket handling - Where found: `backend/app/utils/fail2ban_client.py`, `backend/app/startup.py` - Goal: use modern asyncio APIs (`get_running_loop()`), avoid blocking operations on the event loop, and ensure startup resources are cleaned up if initialization fails. +- Status: completed — switched fail2ban socket I/O to `asyncio.to_thread` and added startup cleanup for failed resource initialization. - Possible traps and issues: - `asyncio.get_event_loop()` behavior changed in newer Python versions; this can cause runtime warnings or errors. - Resource leaks can occur if `startup_shared_resources()` fails before the lifespan shutdown path is reached. diff --git a/backend/app/startup.py b/backend/app/startup.py index ebf908a..ba8d405 100644 --- a/backend/app/startup.py +++ b/backend/app/startup.py @@ -7,6 +7,7 @@ in ``app.main`` delegates resource creation and task registration here. from __future__ import annotations +from contextlib import suppress from pathlib import Path import aiohttp @@ -101,13 +102,22 @@ async def startup_shared_resources( http_session: aiohttp.ClientSession = _create_http_session(settings) geo_service.init_geoip(settings.geoip_db_path) - scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC") - scheduler.start() + scheduler: AsyncIOScheduler | None = None + try: + scheduler = AsyncIOScheduler(timezone="UTC") + scheduler.start() - health_check.register(app) - blocklist_import.register(app) - geo_cache_flush.register(app) - geo_re_resolve.register(app) - history_sync.register(app) + health_check.register(app) + blocklist_import.register(app) + geo_cache_flush.register(app) + geo_re_resolve.register(app) + history_sync.register(app) - return http_session, scheduler + return http_session, scheduler + except Exception: + with suppress(Exception): + await http_session.close() + if scheduler is not None: + with suppress(Exception): + scheduler.shutdown(wait=False) + raise diff --git a/backend/app/utils/fail2ban_client.py b/backend/app/utils/fail2ban_client.py index 7e04e9c..82ddfec 100644 --- a/backend/app/utils/fail2ban_client.py +++ b/backend/app/utils/fail2ban_client.py @@ -121,9 +121,8 @@ def _send_command_sync( ) -> object: """Send a command to fail2ban and return the parsed response. - This is a **synchronous** function intended to be called from within - :func:`asyncio.get_event_loop().run_in_executor` so that the event loop - is not blocked. + This is a **synchronous** function intended to be executed via + :func:`asyncio.to_thread` so that the event loop is not blocked. Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``, ``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with @@ -299,15 +298,13 @@ class Fail2BanClient: async with self._command_semaphore: log.debug("fail2ban_sending_command", command=command) - loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() try: - response: object = await loop.run_in_executor( - None, - _send_command_sync, - self.socket_path, - command, - self.timeout, - ) + response: object = await asyncio.to_thread( + _send_command_sync, + self.socket_path, + command, + self.timeout, + ) except Fail2BanConnectionError: log.warning( "fail2ban_connection_error", diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 95b96d8..c5ce8e7 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -4,6 +4,7 @@ import asyncio from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch +import pytest import aiosqlite from httpx import ASGITransport, AsyncClient @@ -123,6 +124,49 @@ async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Pat mock_scheduler.shutdown.assert_called_once_with(wait=False) +async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None: + """The lifespan must close resources if shared startup registration fails.""" + settings = Settings( + database_path=str(tmp_path / "bangui.db"), + fail2ban_socket="/tmp/fake_fail2ban.sock", + fail2ban_config_dir=str(tmp_path / "fail2ban"), + session_secret="test-lifespan-secret", + session_duration_minutes=60, + timezone="UTC", + log_level="debug", + ) + app = create_app(settings=settings) + + mock_scheduler = MagicMock() + mock_scheduler.start = MagicMock() + mock_scheduler.shutdown = MagicMock() + + mock_http_session = MagicMock() + mock_http_session.close = AsyncMock() + + with ( + patch("app.startup.ensure_jail_configs"), + patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session), + patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), + patch("app.startup.init_db", new=AsyncMock()), + patch("app.services.geo_service.init_geoip"), + patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), + patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), + patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)), + patch("app.tasks.health_check.register", side_effect=RuntimeError("startup failed")), + patch("app.tasks.blocklist_import.register"), + patch("app.tasks.geo_cache_flush.register"), + patch("app.tasks.geo_re_resolve.register"), + patch("app.tasks.history_sync.register"), + ): + with pytest.raises(RuntimeError, match="startup failed"): + async with _lifespan(app): + pass + + mock_http_session.close.assert_awaited_once() + mock_scheduler.shutdown.assert_called_once_with(wait=False) + + async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None: """The shared HTTP client session is created with the configured limits.""" settings = Settings( diff --git a/backend/tests/test_services/test_fail2ban_client.py b/backend/tests/test_services/test_fail2ban_client.py index 45f25d0..5b837ab 100644 --- a/backend/tests/test_services/test_fail2ban_client.py +++ b/backend/tests/test_services/test_fail2ban_client.py @@ -226,15 +226,10 @@ class TestFail2BanClientSend: """``send()`` must return the response from the executor.""" expected = [0, "OK"] client = Fail2BanClient(socket_path="/fake/fail2ban.sock") - # asyncio.get_event_loop().run_in_executor is called inside send(). - # We patch it on the loop object returned by asyncio.get_event_loop(). - with patch("asyncio.get_event_loop") as mock_get_loop: - mock_loop = AsyncMock() - mock_loop.run_in_executor = AsyncMock(return_value=expected) - mock_get_loop.return_value = mock_loop - + with patch("asyncio.to_thread", new_callable=AsyncMock, return_value=expected) as mock_to_thread: result = await client.send(["status"]) + mock_to_thread.assert_awaited_once() assert result == expected @pytest.mark.asyncio @@ -242,13 +237,11 @@ class TestFail2BanClientSend: """``send()`` must re-raise :class:`Fail2BanConnectionError`.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") - with patch("asyncio.get_event_loop") as mock_get_loop: - mock_loop = AsyncMock() - mock_loop.run_in_executor = AsyncMock( - side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock") - ) - mock_get_loop.return_value = mock_loop - + with patch( + "asyncio.to_thread", + new_callable=AsyncMock, + side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"), + ): with pytest.raises(Fail2BanConnectionError): await client.send(["status"]) @@ -257,13 +250,11 @@ class TestFail2BanClientSend: """``send()`` must log a warning when a connection error occurs.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") - with patch("asyncio.get_event_loop") as mock_get_loop: - mock_loop = AsyncMock() - mock_loop.run_in_executor = AsyncMock( - side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock") - ) - mock_get_loop.return_value = mock_loop - + with patch( + "asyncio.to_thread", + new_callable=AsyncMock, + side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"), + ): with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError): await client.send(["ping"]) @@ -278,13 +269,11 @@ class TestFail2BanClientSend: """``send()`` must re-raise :class:`Fail2BanProtocolError`.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") - with patch("asyncio.get_event_loop") as mock_get_loop: - mock_loop = AsyncMock() - mock_loop.run_in_executor = AsyncMock( - side_effect=Fail2BanProtocolError("bad pickle") - ) - mock_get_loop.return_value = mock_loop - + with patch( + "asyncio.to_thread", + new_callable=AsyncMock, + side_effect=Fail2BanProtocolError("bad pickle"), + ): with pytest.raises(Fail2BanProtocolError): await client.send(["status"]) @@ -293,13 +282,11 @@ class TestFail2BanClientSend: """``send()`` must propagate :class:`Fail2BanProtocolError` to the caller.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") - with patch("asyncio.get_event_loop") as mock_get_loop: - mock_loop = AsyncMock() - mock_loop.run_in_executor = AsyncMock( - side_effect=Fail2BanProtocolError("bad pickle") - ) - mock_get_loop.return_value = mock_loop - + with patch( + "asyncio.to_thread", + new_callable=AsyncMock, + side_effect=Fail2BanProtocolError("bad pickle"), + ): with pytest.raises(Fail2BanProtocolError): await client.send(["status"]) @@ -308,13 +295,11 @@ class TestFail2BanClientSend: """``send()`` must log an error when a protocol error occurs.""" client = Fail2BanClient(socket_path="/fake/fail2ban.sock") - with patch("asyncio.get_event_loop") as mock_get_loop: - mock_loop = AsyncMock() - mock_loop.run_in_executor = AsyncMock( - side_effect=Fail2BanProtocolError("corrupt response") - ) - mock_get_loop.return_value = mock_loop - + with patch( + "asyncio.to_thread", + new_callable=AsyncMock, + side_effect=Fail2BanProtocolError("corrupt response"), + ): with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError): await client.send(["get", "sshd", "banned"]) @@ -492,11 +477,7 @@ class TestFail2BanClientSemaphore: in_flight.pop() return (0, "ok") - with patch("asyncio.get_event_loop") as mock_loop_getter: - mock_loop = MagicMock() - mock_loop.run_in_executor = _fast_executor - mock_loop_getter.return_value = mock_loop - + with patch("asyncio.to_thread", new=_fast_executor): tasks = [ _asyncio.create_task(client.send(["ping"])) for _ in range(10) ]