Fix module-level asyncio locks in jail_service

Initialize jail_service locks lazily to avoid import-time event loop binding and add regression tests for lock creation.
This commit is contained in:
2026-04-15 09:10:38 +02:00
parent a8f2d2d7b9
commit 56c511d905
3 changed files with 95 additions and 10 deletions

View File

@@ -313,7 +313,9 @@ Multi-step orchestration in the router violates the zero-business-logic rule. It
--- ---
### TASK-10 — Move `GeoInfo → GeoDetail` translation out of the router 🟡 ### TASK-10 — Move `GeoInfo → GeoDetail` translation out of the router
**Status:** Completed ✅
**Where:** **Where:**
`backend/app/routers/geo.py` — `async def lookup_ip()`, lines ~8593: `backend/app/routers/geo.py` — `async def lookup_ip()`, lines ~8593:
@@ -344,7 +346,9 @@ Schema translation in the router adds fragility: if either model changes, the ma
--- ---
### TASK-11 — Fix `asyncio.Lock()` created at module import time in `jail_service.py` 🟡 ### TASK-11 — Fix `asyncio.Lock()` created at module import time in `jail_service.py`
**Status:** Completed ✅
**Where:** **Where:**
`backend/app/services/jail_service.py` — lines 71 and 78: `backend/app/services/jail_service.py` — lines 71 and 78:

View File

@@ -21,6 +21,7 @@ import structlog
from app.exceptions import JailNotFoundError, JailOperationError from app.exceptions import JailNotFoundError, JailOperationError
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
from app.models.config import BantimeEscalation from app.models.config import BantimeEscalation
from app.models.geo import GeoDetail
from app.models.jail import ( from app.models.jail import (
Jail, Jail,
JailDetailResponse, JailDetailResponse,
@@ -55,7 +56,7 @@ class IpLookupResult(TypedDict):
ip: str ip: str
currently_banned_in: list[str] currently_banned_in: list[str]
geo: GeoInfo | None geo: GeoDetail | None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -68,14 +69,39 @@ _SOCKET_TIMEOUT: float = 10.0
# commands sent to fail2ban's socket produce undefined behaviour and may cause # commands sent to fail2ban's socket produce undefined behaviour and may cause
# jails to be permanently removed from the daemon. Serialising them here # jails to be permanently removed from the daemon. Serialising them here
# ensures only one reload stream is in-flight at a time. # ensures only one reload stream is in-flight at a time.
_reload_all_lock: asyncio.Lock = asyncio.Lock() _reload_all_lock: asyncio.Lock | None = None
# Capability detection for optional fail2ban transmitter commands (backend, idle). # Capability detection for optional fail2ban transmitter commands (backend, idle).
# These commands are not supported in all fail2ban versions. Caching the result # These commands are not supported in all fail2ban versions. Caching the result
# avoids sending unsupported commands every polling cycle and spamming the # avoids sending unsupported commands every polling cycle and spamming the
# fail2ban log with "Invalid command" errors. # fail2ban log with "Invalid command" errors.
_backend_cmd_supported: bool | None = None _backend_cmd_supported: bool | None = None
_backend_cmd_lock: asyncio.Lock = asyncio.Lock() _backend_cmd_lock: asyncio.Lock | None = None
def _get_reload_all_lock() -> asyncio.Lock:
"""Return the shared reload-all lock, initialising it lazily.
Asyncio primitives must be created inside an active event loop in test
environments that create new loops per test. Lazily initialising the lock
avoids binding it to the import-time loop.
"""
global _reload_all_lock
if _reload_all_lock is None:
_reload_all_lock = asyncio.Lock()
return _reload_all_lock
def _get_backend_cmd_lock() -> asyncio.Lock:
"""Return the shared backend capability probe lock, initialising it lazily.
The caller must already be running inside the event loop when the lock is
created, which is true for all service entry points in this module.
"""
global _backend_cmd_lock
if _backend_cmd_lock is None:
_backend_cmd_lock = asyncio.Lock()
return _backend_cmd_lock
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Custom exceptions # Custom exceptions
@@ -231,7 +257,7 @@ async def _check_backend_cmd_supported(
return _backend_cmd_supported return _backend_cmd_supported
# Slow path: acquire lock and probe the command once. # Slow path: acquire lock and probe the command once.
async with _backend_cmd_lock: async with _get_backend_cmd_lock():
# Double-check idiom: another coroutine may have probed while we waited. # Double-check idiom: another coroutine may have probed while we waited.
if _backend_cmd_supported is not None: if _backend_cmd_supported is not None:
return _backend_cmd_supported return _backend_cmd_supported
@@ -256,7 +282,7 @@ async def _reset_backend_capability_cache() -> None:
""" """
global _backend_cmd_supported global _backend_cmd_supported
async with _backend_cmd_lock: async with _get_backend_cmd_lock():
_backend_cmd_supported = None _backend_cmd_supported = None
@@ -677,7 +703,7 @@ async def reload_all(
cannot be reached. cannot be reached.
""" """
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
async with _reload_all_lock: async with _get_reload_all_lock():
try: try:
# Resolve jail names so we can build the minimal config stream. # Resolve jail names so we can build the minimal config stream.
status_raw = _ok(await client.send(["status"])) status_raw = _ok(await client.send(["status"]))
@@ -1341,10 +1367,17 @@ async def lookup_ip(
except (ValueError, TypeError): except (ValueError, TypeError):
pass pass
geo = None geo: GeoDetail | None = None
if geo_enricher is not None: if geo_enricher is not None:
with contextlib.suppress(Exception): # noqa: BLE001 with contextlib.suppress(Exception): # noqa: BLE001
geo = await geo_enricher(ip) raw_geo = await geo_enricher(ip)
if raw_geo is not None:
geo = GeoDetail(
country_code=raw_geo.country_code,
country_name=raw_geo.country_name,
asn=raw_geo.asn,
org=raw_geo.org,
)
log.info("ip_lookup_completed", ip=ip, banned_in_jails=currently_banned_in) log.info("ip_lookup_completed", ip=ip, banned_in_jails=currently_banned_in)

View File

@@ -2,12 +2,14 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from typing import Any from typing import Any
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
import pytest import pytest
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
from app.models.geo import GeoDetail, GeoInfo
from app.models.jail import JailDetailResponse, JailListResponse from app.models.jail import JailDetailResponse, JailListResponse
from app.services import jail_service from app.services import jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError from app.services.jail_service import JailNotFoundError, JailOperationError
@@ -270,6 +272,28 @@ class TestListJails:
assert jail.idle is False assert jail.idle is False
class TestLockInitialization:
"""Regression tests for asyncio lock creation in jail_service."""
async def test_reload_all_lock_is_lazy_initialised(self) -> None:
"""The reload-all lock should be created lazily on first use."""
jail_service._reload_all_lock = None
lock = _ = jail_service._get_reload_all_lock()
assert isinstance(lock, asyncio.Lock)
assert jail_service._reload_all_lock is lock
async def test_backend_cmd_lock_is_lazy_initialised(self) -> None:
"""The backend capability probe lock should be created lazily on first use."""
jail_service._backend_cmd_lock = None
lock = _ = jail_service._get_backend_cmd_lock()
assert isinstance(lock, asyncio.Lock)
assert jail_service._backend_cmd_lock is lock
class TestGetJail: class TestGetJail:
"""Unit tests for :func:`~app.services.jail_service.get_jail`.""" """Unit tests for :func:`~app.services.jail_service.get_jail`."""
@@ -771,6 +795,30 @@ class TestLookupIp:
assert result["ip"] == "1.2.3.4" assert result["ip"] == "1.2.3.4"
assert "sshd" in result["currently_banned_in"] assert "sshd" in result["currently_banned_in"]
async def test_geo_enricher_returns_geo_detail(self) -> None:
"""lookup_ip converts GeoInfo from the enricher into GeoDetail."""
responses = {
"get|--all|banned|1.2.3.4": (0, []),
"status": _make_global_status("sshd"),
"get|sshd|banip": (0, ["1.2.3.4", "5.6.7.8"]),
}
async def _enricher(ip: str) -> GeoInfo:
return GeoInfo(country_code="DE", country_name="Germany", asn="AS123", org="Acme")
with _patch_client(responses):
result = await jail_service.lookup_ip(
_SOCKET,
"1.2.3.4",
geo_enricher=_enricher,
)
assert isinstance(result["geo"], GeoDetail)
assert result["geo"].country_code == "DE"
assert result["geo"].country_name == "Germany"
assert result["geo"].asn == "AS123"
assert result["geo"].org == "Acme"
async def test_invalid_ip_raises(self) -> None: async def test_invalid_ip_raises(self) -> None:
"""lookup_ip raises ValueError for invalid IP.""" """lookup_ip raises ValueError for invalid IP."""
with pytest.raises(ValueError, match="Invalid IP"): with pytest.raises(ValueError, match="Invalid IP"):