refactoring-backend #3
@@ -56,6 +56,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
||||
### 6. Update async socket handling
|
||||
- Where found: `backend/app/utils/fail2ban_client.py`, `backend/app/startup.py`
|
||||
- Goal: use modern asyncio APIs (`get_running_loop()`), avoid blocking operations on the event loop, and ensure startup resources are cleaned up if initialization fails.
|
||||
- Status: completed — switched fail2ban socket I/O to `asyncio.to_thread` and added startup cleanup for failed resource initialization.
|
||||
- Possible traps and issues:
|
||||
- `asyncio.get_event_loop()` behavior changed in newer Python versions; this can cause runtime warnings or errors.
|
||||
- Resource leaks can occur if `startup_shared_resources()` fails before the lifespan shutdown path is reached.
|
||||
|
||||
@@ -7,6 +7,7 @@ in ``app.main`` delegates resource creation and task registration here.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import suppress
|
||||
from pathlib import Path
|
||||
|
||||
import aiohttp
|
||||
@@ -101,13 +102,22 @@ async def startup_shared_resources(
|
||||
http_session: aiohttp.ClientSession = _create_http_session(settings)
|
||||
geo_service.init_geoip(settings.geoip_db_path)
|
||||
|
||||
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
|
||||
scheduler.start()
|
||||
scheduler: AsyncIOScheduler | None = None
|
||||
try:
|
||||
scheduler = AsyncIOScheduler(timezone="UTC")
|
||||
scheduler.start()
|
||||
|
||||
health_check.register(app)
|
||||
blocklist_import.register(app)
|
||||
geo_cache_flush.register(app)
|
||||
geo_re_resolve.register(app)
|
||||
history_sync.register(app)
|
||||
health_check.register(app)
|
||||
blocklist_import.register(app)
|
||||
geo_cache_flush.register(app)
|
||||
geo_re_resolve.register(app)
|
||||
history_sync.register(app)
|
||||
|
||||
return http_session, scheduler
|
||||
return http_session, scheduler
|
||||
except Exception:
|
||||
with suppress(Exception):
|
||||
await http_session.close()
|
||||
if scheduler is not None:
|
||||
with suppress(Exception):
|
||||
scheduler.shutdown(wait=False)
|
||||
raise
|
||||
|
||||
@@ -121,9 +121,8 @@ def _send_command_sync(
|
||||
) -> object:
|
||||
"""Send a command to fail2ban and return the parsed response.
|
||||
|
||||
This is a **synchronous** function intended to be called from within
|
||||
:func:`asyncio.get_event_loop().run_in_executor` so that the event loop
|
||||
is not blocked.
|
||||
This is a **synchronous** function intended to be executed via
|
||||
:func:`asyncio.to_thread` so that the event loop is not blocked.
|
||||
|
||||
Transient ``OSError`` conditions (``EAGAIN``, ``ECONNREFUSED``,
|
||||
``ENOBUFS``) are retried up to :data:`_RETRY_MAX_ATTEMPTS` times with
|
||||
@@ -299,15 +298,13 @@ class Fail2BanClient:
|
||||
|
||||
async with self._command_semaphore:
|
||||
log.debug("fail2ban_sending_command", command=command)
|
||||
loop: asyncio.AbstractEventLoop = asyncio.get_event_loop()
|
||||
try:
|
||||
response: object = await loop.run_in_executor(
|
||||
None,
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
response: object = await asyncio.to_thread(
|
||||
_send_command_sync,
|
||||
self.socket_path,
|
||||
command,
|
||||
self.timeout,
|
||||
)
|
||||
except Fail2BanConnectionError:
|
||||
log.warning(
|
||||
"fail2ban_connection_error",
|
||||
|
||||
@@ -4,6 +4,7 @@ import asyncio
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import aiosqlite
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
@@ -123,6 +124,49 @@ async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Pat
|
||||
mock_scheduler.shutdown.assert_called_once_with(wait=False)
|
||||
|
||||
|
||||
async def test_lifespan_cleans_up_resources_when_startup_fails(tmp_path: Path) -> None:
|
||||
"""The lifespan must close resources if shared startup registration fails."""
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(tmp_path / "fail2ban"),
|
||||
session_secret="test-lifespan-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
mock_scheduler = MagicMock()
|
||||
mock_scheduler.start = MagicMock()
|
||||
mock_scheduler.shutdown = MagicMock()
|
||||
|
||||
mock_http_session = MagicMock()
|
||||
mock_http_session.close = AsyncMock()
|
||||
|
||||
with (
|
||||
patch("app.startup.ensure_jail_configs"),
|
||||
patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session),
|
||||
patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler),
|
||||
patch("app.startup.init_db", new=AsyncMock()),
|
||||
patch("app.services.geo_service.init_geoip"),
|
||||
patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
|
||||
patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
|
||||
patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
|
||||
patch("app.tasks.health_check.register", side_effect=RuntimeError("startup failed")),
|
||||
patch("app.tasks.blocklist_import.register"),
|
||||
patch("app.tasks.geo_cache_flush.register"),
|
||||
patch("app.tasks.geo_re_resolve.register"),
|
||||
patch("app.tasks.history_sync.register"),
|
||||
):
|
||||
with pytest.raises(RuntimeError, match="startup failed"):
|
||||
async with _lifespan(app):
|
||||
pass
|
||||
|
||||
mock_http_session.close.assert_awaited_once()
|
||||
mock_scheduler.shutdown.assert_called_once_with(wait=False)
|
||||
|
||||
|
||||
async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None:
|
||||
"""The shared HTTP client session is created with the configured limits."""
|
||||
settings = Settings(
|
||||
|
||||
@@ -226,15 +226,10 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must return the response from the executor."""
|
||||
expected = [0, "OK"]
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
# asyncio.get_event_loop().run_in_executor is called inside send().
|
||||
# We patch it on the loop object returned by asyncio.get_event_loop().
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(return_value=expected)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch("asyncio.to_thread", new_callable=AsyncMock, return_value=expected) as mock_to_thread:
|
||||
result = await client.send(["status"])
|
||||
|
||||
mock_to_thread.assert_awaited_once()
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -242,13 +237,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must re-raise :class:`Fail2BanConnectionError`."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock"),
|
||||
):
|
||||
with pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["status"])
|
||||
|
||||
@@ -257,13 +250,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must log a warning when a connection error occurs."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock"),
|
||||
):
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["ping"])
|
||||
|
||||
@@ -278,13 +269,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must re-raise :class:`Fail2BanProtocolError`."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("bad pickle")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanProtocolError("bad pickle"),
|
||||
):
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@@ -293,13 +282,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must propagate :class:`Fail2BanProtocolError` to the caller."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("bad pickle")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanProtocolError("bad pickle"),
|
||||
):
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@@ -308,13 +295,11 @@ class TestFail2BanClientSend:
|
||||
"""``send()`` must log an error when a protocol error occurs."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("corrupt response")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch(
|
||||
"asyncio.to_thread",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=Fail2BanProtocolError("corrupt response"),
|
||||
):
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["get", "sshd", "banned"])
|
||||
|
||||
@@ -492,11 +477,7 @@ class TestFail2BanClientSemaphore:
|
||||
in_flight.pop()
|
||||
return (0, "ok")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_loop_getter:
|
||||
mock_loop = MagicMock()
|
||||
mock_loop.run_in_executor = _fast_executor
|
||||
mock_loop_getter.return_value = mock_loop
|
||||
|
||||
with patch("asyncio.to_thread", new=_fast_executor):
|
||||
tasks = [
|
||||
_asyncio.create_task(client.send(["ping"])) for _ in range(10)
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user