Fix module-level asyncio locks in jail_service
Initialize jail_service locks lazily to avoid import-time event loop binding and add regression tests for lock creation.
This commit is contained in:
@@ -313,7 +313,9 @@ Multi-step orchestration in the router violates the zero-business-logic rule. It
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### TASK-10 — Move `GeoInfo → GeoDetail` translation out of the router 🟡
|
### TASK-10 — Move `GeoInfo → GeoDetail` translation out of the router ✅
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Where:**
|
**Where:**
|
||||||
`backend/app/routers/geo.py` — `async def lookup_ip()`, lines ~85–93:
|
`backend/app/routers/geo.py` — `async def lookup_ip()`, lines ~85–93:
|
||||||
@@ -344,7 +346,9 @@ Schema translation in the router adds fragility: if either model changes, the ma
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
### TASK-11 — Fix `asyncio.Lock()` created at module import time in `jail_service.py` 🟡
|
### TASK-11 — Fix `asyncio.Lock()` created at module import time in `jail_service.py` ✅
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Where:**
|
**Where:**
|
||||||
`backend/app/services/jail_service.py` — lines 71 and 78:
|
`backend/app/services/jail_service.py` — lines 71 and 78:
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ import structlog
|
|||||||
from app.exceptions import JailNotFoundError, JailOperationError
|
from app.exceptions import JailNotFoundError, JailOperationError
|
||||||
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
|
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
|
||||||
from app.models.config import BantimeEscalation
|
from app.models.config import BantimeEscalation
|
||||||
|
from app.models.geo import GeoDetail
|
||||||
from app.models.jail import (
|
from app.models.jail import (
|
||||||
Jail,
|
Jail,
|
||||||
JailDetailResponse,
|
JailDetailResponse,
|
||||||
@@ -55,7 +56,7 @@ class IpLookupResult(TypedDict):
|
|||||||
|
|
||||||
ip: str
|
ip: str
|
||||||
currently_banned_in: list[str]
|
currently_banned_in: list[str]
|
||||||
geo: GeoInfo | None
|
geo: GeoDetail | None
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -68,14 +69,39 @@ _SOCKET_TIMEOUT: float = 10.0
|
|||||||
# commands sent to fail2ban's socket produce undefined behaviour and may cause
|
# commands sent to fail2ban's socket produce undefined behaviour and may cause
|
||||||
# jails to be permanently removed from the daemon. Serialising them here
|
# jails to be permanently removed from the daemon. Serialising them here
|
||||||
# ensures only one reload stream is in-flight at a time.
|
# ensures only one reload stream is in-flight at a time.
|
||||||
_reload_all_lock: asyncio.Lock = asyncio.Lock()
|
_reload_all_lock: asyncio.Lock | None = None
|
||||||
|
|
||||||
# Capability detection for optional fail2ban transmitter commands (backend, idle).
|
# Capability detection for optional fail2ban transmitter commands (backend, idle).
|
||||||
# These commands are not supported in all fail2ban versions. Caching the result
|
# These commands are not supported in all fail2ban versions. Caching the result
|
||||||
# avoids sending unsupported commands every polling cycle and spamming the
|
# avoids sending unsupported commands every polling cycle and spamming the
|
||||||
# fail2ban log with "Invalid command" errors.
|
# fail2ban log with "Invalid command" errors.
|
||||||
_backend_cmd_supported: bool | None = None
|
_backend_cmd_supported: bool | None = None
|
||||||
_backend_cmd_lock: asyncio.Lock = asyncio.Lock()
|
_backend_cmd_lock: asyncio.Lock | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_reload_all_lock() -> asyncio.Lock:
|
||||||
|
"""Return the shared reload-all lock, initialising it lazily.
|
||||||
|
|
||||||
|
Asyncio primitives should be created while an event loop is running. This
|
||||||
|
matters in test environments that create a new loop per test. Lazily
|
||||||
|
initialising the lock avoids binding it to the import-time loop.
|
||||||
|
"""
|
||||||
|
global _reload_all_lock
|
||||||
|
if _reload_all_lock is None:
|
||||||
|
_reload_all_lock = asyncio.Lock()
|
||||||
|
return _reload_all_lock
|
||||||
|
|
||||||
|
|
||||||
|
def _get_backend_cmd_lock() -> asyncio.Lock:
|
||||||
|
"""Return the shared backend capability probe lock, initialising it lazily.
|
||||||
|
|
||||||
|
The caller must already be running inside the event loop when the lock is
|
||||||
|
created, which is true for all service entry points in this module.
|
||||||
|
"""
|
||||||
|
global _backend_cmd_lock
|
||||||
|
if _backend_cmd_lock is None:
|
||||||
|
_backend_cmd_lock = asyncio.Lock()
|
||||||
|
return _backend_cmd_lock
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Custom exceptions
|
# Custom exceptions
|
||||||
@@ -231,7 +257,7 @@ async def _check_backend_cmd_supported(
|
|||||||
return _backend_cmd_supported
|
return _backend_cmd_supported
|
||||||
|
|
||||||
# Slow path: acquire lock and probe the command once.
|
# Slow path: acquire lock and probe the command once.
|
||||||
async with _backend_cmd_lock:
|
async with _get_backend_cmd_lock():
|
||||||
# Double-check idiom: another coroutine may have probed while we waited.
|
# Double-check idiom: another coroutine may have probed while we waited.
|
||||||
if _backend_cmd_supported is not None:
|
if _backend_cmd_supported is not None:
|
||||||
return _backend_cmd_supported
|
return _backend_cmd_supported
|
||||||
@@ -256,7 +282,7 @@ async def _reset_backend_capability_cache() -> None:
|
|||||||
"""
|
"""
|
||||||
global _backend_cmd_supported
|
global _backend_cmd_supported
|
||||||
|
|
||||||
async with _backend_cmd_lock:
|
async with _get_backend_cmd_lock():
|
||||||
_backend_cmd_supported = None
|
_backend_cmd_supported = None
|
||||||
|
|
||||||
|
|
||||||
@@ -677,7 +703,7 @@ async def reload_all(
|
|||||||
cannot be reached.
|
cannot be reached.
|
||||||
"""
|
"""
|
||||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||||
async with _reload_all_lock:
|
async with _get_reload_all_lock():
|
||||||
try:
|
try:
|
||||||
# Resolve jail names so we can build the minimal config stream.
|
# Resolve jail names so we can build the minimal config stream.
|
||||||
status_raw = _ok(await client.send(["status"]))
|
status_raw = _ok(await client.send(["status"]))
|
||||||
@@ -1341,10 +1367,17 @@ async def lookup_ip(
|
|||||||
except (ValueError, TypeError):
|
except (ValueError, TypeError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
geo = None
|
geo: GeoDetail | None = None
|
||||||
if geo_enricher is not None:
|
if geo_enricher is not None:
|
||||||
with contextlib.suppress(Exception): # noqa: BLE001
|
with contextlib.suppress(Exception): # noqa: BLE001
|
||||||
geo = await geo_enricher(ip)
|
raw_geo = await geo_enricher(ip)
|
||||||
|
if raw_geo is not None:
|
||||||
|
geo = GeoDetail(
|
||||||
|
country_code=raw_geo.country_code,
|
||||||
|
country_name=raw_geo.country_name,
|
||||||
|
asn=raw_geo.asn,
|
||||||
|
org=raw_geo.org,
|
||||||
|
)
|
||||||
|
|
||||||
log.info("ip_lookup_completed", ip=ip, banned_in_jails=currently_banned_in)
|
log.info("ip_lookup_completed", ip=ip, banned_in_jails=currently_banned_in)
|
||||||
|
|
||||||
|
|||||||
@@ -2,12 +2,14 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
|
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
|
||||||
|
from app.models.geo import GeoDetail, GeoInfo
|
||||||
from app.models.jail import JailDetailResponse, JailListResponse
|
from app.models.jail import JailDetailResponse, JailListResponse
|
||||||
from app.services import jail_service
|
from app.services import jail_service
|
||||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||||
@@ -270,6 +272,28 @@ class TestListJails:
|
|||||||
assert jail.idle is False
|
assert jail.idle is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestLockInitialization:
|
||||||
|
"""Regression tests for asyncio lock creation in jail_service."""
|
||||||
|
|
||||||
|
async def test_reload_all_lock_is_lazy_initialised(self) -> None:
|
||||||
|
"""The reload-all lock should be created lazily on first use."""
|
||||||
|
jail_service._reload_all_lock = None
|
||||||
|
|
||||||
|
lock = _ = jail_service._get_reload_all_lock()
|
||||||
|
|
||||||
|
assert isinstance(lock, asyncio.Lock)
|
||||||
|
assert jail_service._reload_all_lock is lock
|
||||||
|
|
||||||
|
async def test_backend_cmd_lock_is_lazy_initialised(self) -> None:
|
||||||
|
"""The backend capability probe lock should be created lazily on first use."""
|
||||||
|
jail_service._backend_cmd_lock = None
|
||||||
|
|
||||||
|
lock = _ = jail_service._get_backend_cmd_lock()
|
||||||
|
|
||||||
|
assert isinstance(lock, asyncio.Lock)
|
||||||
|
assert jail_service._backend_cmd_lock is lock
|
||||||
|
|
||||||
|
|
||||||
class TestGetJail:
|
class TestGetJail:
|
||||||
"""Unit tests for :func:`~app.services.jail_service.get_jail`."""
|
"""Unit tests for :func:`~app.services.jail_service.get_jail`."""
|
||||||
|
|
||||||
@@ -771,6 +795,30 @@ class TestLookupIp:
|
|||||||
assert result["ip"] == "1.2.3.4"
|
assert result["ip"] == "1.2.3.4"
|
||||||
assert "sshd" in result["currently_banned_in"]
|
assert "sshd" in result["currently_banned_in"]
|
||||||
|
|
||||||
|
async def test_geo_enricher_returns_geo_detail(self) -> None:
|
||||||
|
"""lookup_ip converts GeoInfo from the enricher into GeoDetail."""
|
||||||
|
responses = {
|
||||||
|
"get|--all|banned|1.2.3.4": (0, []),
|
||||||
|
"status": _make_global_status("sshd"),
|
||||||
|
"get|sshd|banip": (0, ["1.2.3.4", "5.6.7.8"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _enricher(ip: str) -> GeoInfo:
|
||||||
|
return GeoInfo(country_code="DE", country_name="Germany", asn="AS123", org="Acme")
|
||||||
|
|
||||||
|
with _patch_client(responses):
|
||||||
|
result = await jail_service.lookup_ip(
|
||||||
|
_SOCKET,
|
||||||
|
"1.2.3.4",
|
||||||
|
geo_enricher=_enricher,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result["geo"], GeoDetail)
|
||||||
|
assert result["geo"].country_code == "DE"
|
||||||
|
assert result["geo"].country_name == "Germany"
|
||||||
|
assert result["geo"].asn == "AS123"
|
||||||
|
assert result["geo"].org == "Acme"
|
||||||
|
|
||||||
async def test_invalid_ip_raises(self) -> None:
|
async def test_invalid_ip_raises(self) -> None:
|
||||||
"""lookup_ip raises ValueError for invalid IP."""
|
"""lookup_ip raises ValueError for invalid IP."""
|
||||||
with pytest.raises(ValueError, match="Invalid IP"):
|
with pytest.raises(ValueError, match="Invalid IP"):
|
||||||
|
|||||||
Reference in New Issue
Block a user