Refactor backend services and jail configuration
- Refactor action_config_service, filter_config_service, jail_config_service, and jail_service - Add jail_socket utility module for socket communication - Update test_jail_service with new test cases - Update architecture and task documentation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -25,13 +25,6 @@ from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
JailNotFoundInConfigError,
|
||||
)
|
||||
import app.services.jail_service as jail_service
|
||||
from app.utils.config_file_utils import (
|
||||
_get_active_jail_names as _config_file_get_active_jail_names,
|
||||
_parse_jails_sync as _config_file_parse_jails_sync,
|
||||
_safe_jail_name,
|
||||
build_parser,
|
||||
)
|
||||
from app.models.config import (
|
||||
ActionConfig,
|
||||
ActionConfigUpdate,
|
||||
@@ -42,6 +35,17 @@ from app.models.config import (
|
||||
)
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.async_utils import run_blocking
|
||||
from app.utils.config_file_utils import (
|
||||
_get_active_jail_names as _config_file_get_active_jail_names,
|
||||
)
|
||||
from app.utils.config_file_utils import (
|
||||
_parse_jails_sync as _config_file_parse_jails_sync,
|
||||
)
|
||||
from app.utils.config_file_utils import (
|
||||
_safe_jail_name,
|
||||
build_parser,
|
||||
)
|
||||
from app.utils.jail_socket import reload_all
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -681,7 +685,7 @@ async def update_action(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_action_update_failed",
|
||||
@@ -749,7 +753,7 @@ async def create_action(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_action_create_failed",
|
||||
@@ -874,7 +878,7 @@ async def assign_action_to_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_action_failed",
|
||||
@@ -932,7 +936,7 @@ async def remove_action_from_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_remove_action_failed",
|
||||
|
||||
@@ -23,14 +23,6 @@ from app.exceptions import (
|
||||
FilterReadonlyError,
|
||||
JailNotFoundInConfigError,
|
||||
)
|
||||
import app.services.jail_service as jail_service
|
||||
from app.utils.config_file_utils import (
|
||||
_get_active_jail_names as _config_file_get_active_jail_names,
|
||||
_parse_jails_sync as _config_file_parse_jails_sync,
|
||||
_safe_filter_name,
|
||||
_safe_jail_name,
|
||||
set_jail_local_key_sync,
|
||||
)
|
||||
from app.models.config import (
|
||||
AssignFilterRequest,
|
||||
FilterConfig,
|
||||
@@ -41,6 +33,18 @@ from app.models.config import (
|
||||
)
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.async_utils import run_blocking
|
||||
from app.utils.config_file_utils import (
|
||||
_get_active_jail_names as _config_file_get_active_jail_names,
|
||||
)
|
||||
from app.utils.config_file_utils import (
|
||||
_parse_jails_sync as _config_file_parse_jails_sync,
|
||||
)
|
||||
from app.utils.config_file_utils import (
|
||||
_safe_filter_name,
|
||||
_safe_jail_name,
|
||||
set_jail_local_key_sync,
|
||||
)
|
||||
from app.utils.jail_socket import reload_all
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -508,7 +512,7 @@ async def update_filter(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_update_failed",
|
||||
@@ -582,7 +586,7 @@ async def create_filter(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_create_failed",
|
||||
@@ -704,7 +708,7 @@ async def assign_filter_to_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_filter_failed",
|
||||
|
||||
@@ -25,17 +25,6 @@ from app.exceptions import (
|
||||
JailNotFoundError,
|
||||
JailNotFoundInConfigError,
|
||||
)
|
||||
import app.services.jail_service as jail_service
|
||||
from app.utils.config_file_utils import (
|
||||
_build_inactive_jail,
|
||||
_parse_jails_sync as _config_file_parse_jails_sync,
|
||||
_get_active_jail_names as _config_file_get_active_jail_names,
|
||||
_probe_fail2ban_running,
|
||||
_safe_jail_name,
|
||||
_validate_jail_config_sync as _config_file_validate_jail_config_sync,
|
||||
start_daemon,
|
||||
wait_for_fail2ban,
|
||||
)
|
||||
from app.models.config import (
|
||||
ActivateJailRequest,
|
||||
InactiveJail,
|
||||
@@ -46,7 +35,23 @@ from app.models.config import (
|
||||
)
|
||||
from app.services import health_service
|
||||
from app.utils.async_utils import run_blocking
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.utils.config_file_utils import (
|
||||
_build_inactive_jail,
|
||||
_probe_fail2ban_running,
|
||||
_safe_jail_name,
|
||||
start_daemon,
|
||||
wait_for_fail2ban,
|
||||
)
|
||||
from app.utils.config_file_utils import (
|
||||
_get_active_jail_names as _config_file_get_active_jail_names,
|
||||
)
|
||||
from app.utils.config_file_utils import (
|
||||
_parse_jails_sync as _config_file_parse_jails_sync,
|
||||
)
|
||||
from app.utils.config_file_utils import (
|
||||
_validate_jail_config_sync as _config_file_validate_jail_config_sync,
|
||||
)
|
||||
from app.utils.jail_socket import reload_all
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -404,7 +409,7 @@ async def _activate_jail(
|
||||
# Activation reload — if it fails, roll back immediately #
|
||||
# ---------------------------------------------------------------------- #
|
||||
try:
|
||||
await jail_service.reload_all(socket_path, include_jails=[name])
|
||||
await reload_all(socket_path, include_jails=[name])
|
||||
except JailNotFoundError as exc:
|
||||
# Jail configuration is invalid (e.g. missing logpath that prevents
|
||||
# fail2ban from loading the jail). Roll back and provide a specific error.
|
||||
@@ -546,7 +551,7 @@ async def _rollback_activation_async(
|
||||
|
||||
# Step 2 — reload fail2ban with the restored config.
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_all(socket_path)
|
||||
log.info("jail_activation_rollback_reload_ok", jail=name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
|
||||
@@ -626,7 +631,7 @@ async def _deactivate_jail(
|
||||
)
|
||||
|
||||
try:
|
||||
await jail_service.reload_all(socket_path, exclude_jails=[name])
|
||||
await reload_all(socket_path, exclude_jails=[name])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
|
||||
|
||||
|
||||
@@ -36,7 +36,6 @@ from app.utils.fail2ban_client import (
|
||||
Fail2BanCommand,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanResponse,
|
||||
Fail2BanToken,
|
||||
)
|
||||
from app.utils.fail2ban_response import (
|
||||
ensure_list,
|
||||
@@ -44,6 +43,7 @@ from app.utils.fail2ban_response import (
|
||||
ok,
|
||||
to_dict,
|
||||
)
|
||||
from app.utils.jail_socket import reload_all
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
@@ -55,6 +55,8 @@ if TYPE_CHECKING:
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
__all__ = ["reload_all"]
|
||||
|
||||
class IpLookupResult(TypedDict):
|
||||
"""Result returned by :func:`lookup_ip`.
|
||||
|
||||
@@ -73,12 +75,6 @@ class IpLookupResult(TypedDict):
|
||||
|
||||
_SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
# Guard against concurrent reload_all calls. Overlapping ``reload --all``
|
||||
# commands sent to fail2ban's socket produce undefined behaviour and may cause
|
||||
# jails to be permanently removed from the daemon. Serialising them here
|
||||
# ensures only one reload stream is in-flight at a time.
|
||||
_reload_all_lock: asyncio.Lock | None = None
|
||||
|
||||
# Capability detection for optional fail2ban transmitter commands (backend, idle).
|
||||
# These commands are not supported in all fail2ban versions. Caching the result
|
||||
# avoids sending unsupported commands every polling cycle and spamming the
|
||||
@@ -87,19 +83,6 @@ _backend_cmd_supported: bool | None = None
|
||||
_backend_cmd_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
def _get_reload_all_lock() -> asyncio.Lock:
|
||||
"""Return the shared reload-all lock, initialising it lazily.
|
||||
|
||||
Asyncio primitives must be created inside an active event loop in test
|
||||
environments that create new loops per test. Lazily initialising the lock
|
||||
avoids binding it to the import-time loop.
|
||||
"""
|
||||
global _reload_all_lock
|
||||
if _reload_all_lock is None:
|
||||
_reload_all_lock = asyncio.Lock()
|
||||
return _reload_all_lock
|
||||
|
||||
|
||||
def _get_backend_cmd_lock() -> asyncio.Lock:
|
||||
"""Return the shared backend capability probe lock, initialising it lazily.
|
||||
|
||||
@@ -605,65 +588,6 @@ async def reload_jail(socket_path: str, name: str) -> None:
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
async def reload_all(
|
||||
socket_path: str,
|
||||
*,
|
||||
include_jails: list[str] | None = None,
|
||||
exclude_jails: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Reload all fail2ban jails at once.
|
||||
|
||||
Fetches the current jail list first so that a ``['start', name]`` entry
|
||||
can be included in the config stream for every active jail. Without a
|
||||
non-empty stream the end-of-reload phase deletes every jail that received
|
||||
no configuration commands.
|
||||
|
||||
*include_jails* are added to the stream (e.g. a newly activated jail that
|
||||
is not yet running). *exclude_jails* are removed from the stream (e.g. a
|
||||
jail that was just deactivated and should not be restarted).
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
include_jails: Extra jail names to add to the start stream.
|
||||
exclude_jails: Jail names to remove from the start stream.
|
||||
|
||||
Raises:
|
||||
JailNotFoundError: If a jail in *include_jails* does not exist or
|
||||
its configuration is invalid (e.g. missing logpath).
|
||||
JailOperationError: If fail2ban reports the operation failed for
|
||||
a different reason.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
async with _get_reload_all_lock():
|
||||
try:
|
||||
# Resolve jail names so we can build the minimal config stream.
|
||||
status_raw = ok(await client.send(["status"]))
|
||||
status_dict = to_dict(status_raw)
|
||||
jail_list_raw: str = str(status_dict.get("Jail list", ""))
|
||||
jail_names = [n.strip() for n in jail_list_raw.split(",") if n.strip()]
|
||||
|
||||
# Merge include/exclude sets so the stream matches the desired state.
|
||||
names_set: set[str] = set(jail_names)
|
||||
if include_jails:
|
||||
names_set.update(include_jails)
|
||||
if exclude_jails:
|
||||
names_set -= set(exclude_jails)
|
||||
|
||||
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
||||
ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
|
||||
log.info("all_jails_reloaded")
|
||||
except ValueError as exc:
|
||||
# Detect UnknownJailException (missing or invalid jail configuration)
|
||||
# and re-raise as JailNotFoundError for better error specificity.
|
||||
if is_not_found_error(exc):
|
||||
# Extract the jail name from include_jails if available.
|
||||
jail_name = include_jails[0] if include_jails else "unknown"
|
||||
raise JailNotFoundError(jail_name) from exc
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
|
||||
|
||||
async def restart(socket_path: str) -> None:
|
||||
"""Stop the fail2ban daemon via the Unix socket.
|
||||
|
||||
|
||||
108
backend/app/utils/jail_socket.py
Normal file
108
backend/app/utils/jail_socket.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Low-level socket operations for jail management.
|
||||
|
||||
Provides shared socket utilities for reloading jails across all services.
|
||||
These operations are extracted to a utility layer to avoid circular dependencies
|
||||
between sibling services (jail_service, jail_config_service, action_config_service,
|
||||
filter_config_service).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanToken,
|
||||
)
|
||||
from app.utils.fail2ban_response import (
|
||||
is_not_found_error,
|
||||
ok,
|
||||
to_dict,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# Socket communication timeout in seconds.
|
||||
SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
# Guard against concurrent reload_all calls. Overlapping ``reload --all``
|
||||
# commands sent to fail2ban's socket produce undefined behaviour and may cause
|
||||
# jails to be permanently removed from the daemon. Serialising them here
|
||||
# ensures only one reload stream is in-flight at a time.
|
||||
_reload_all_lock: asyncio.Lock | None = None
|
||||
|
||||
|
||||
def _get_reload_all_lock() -> asyncio.Lock:
|
||||
"""Return the shared reload-all lock, initialising it lazily.
|
||||
|
||||
Asyncio primitives must be created inside an active event loop in test
|
||||
environments that create new loops per test. Lazily initialising the lock
|
||||
avoids binding it to the import-time loop.
|
||||
"""
|
||||
global _reload_all_lock
|
||||
if _reload_all_lock is None:
|
||||
_reload_all_lock = asyncio.Lock()
|
||||
return _reload_all_lock
|
||||
|
||||
|
||||
async def reload_all(
|
||||
socket_path: str,
|
||||
*,
|
||||
include_jails: list[str] | None = None,
|
||||
exclude_jails: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Reload all fail2ban jails at once.
|
||||
|
||||
Fetches the current jail list first so that a ``['start', name]`` entry
|
||||
can be included in the config stream for every active jail. Without a
|
||||
non-empty stream the end-of-reload phase deletes every jail that received
|
||||
no configuration commands.
|
||||
|
||||
*include_jails* are added to the stream (e.g. a newly activated jail that
|
||||
is not yet running). *exclude_jails* are removed from the stream (e.g. a
|
||||
jail that was just deactivated and should not be restarted).
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
include_jails: Extra jail names to add to the start stream.
|
||||
exclude_jails: Jail names to remove from the start stream.
|
||||
|
||||
Raises:
|
||||
JailNotFoundError: If a jail in *include_jails* does not exist or
|
||||
its configuration is invalid (e.g. missing logpath).
|
||||
JailOperationError: If fail2ban reports the operation failed for
|
||||
a different reason.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=SOCKET_TIMEOUT)
|
||||
async with _get_reload_all_lock():
|
||||
try:
|
||||
# Resolve jail names so we can build the minimal config stream.
|
||||
status_raw = ok(await client.send(["status"]))
|
||||
status_dict = to_dict(status_raw)
|
||||
jail_list_raw: str = str(status_dict.get("Jail list", ""))
|
||||
jail_names = [n.strip() for n in jail_list_raw.split(",") if n.strip()]
|
||||
|
||||
# Merge include/exclude sets so the stream matches the desired state.
|
||||
names_set: set[str] = set(jail_names)
|
||||
if include_jails:
|
||||
names_set.update(include_jails)
|
||||
if exclude_jails:
|
||||
names_set -= set(exclude_jails)
|
||||
|
||||
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
||||
ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
|
||||
log.info("all_jails_reloaded")
|
||||
except ValueError as exc:
|
||||
# Detect UnknownJailException (missing or invalid jail configuration)
|
||||
# and re-raise as JailNotFoundError for better error specificity.
|
||||
if is_not_found_error(exc):
|
||||
# Extract the jail name from include_jails if available.
|
||||
jail_name = include_jails[0] if include_jails else "unknown"
|
||||
raise JailNotFoundError(jail_name) from exc
|
||||
raise JailOperationError(str(exc)) from exc
|
||||
@@ -15,6 +15,7 @@ from app.models.geo import GeoDetail, GeoInfo
|
||||
from app.models.jail import JailDetailResponse, JailListResponse
|
||||
from app.services import ban_service, jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
from app.utils import jail_socket
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -75,6 +76,7 @@ def _patch_client(responses: dict[str, Any]) -> Any:
|
||||
stack = contextlib.ExitStack()
|
||||
stack.enter_context(patch("app.services.jail_service.Fail2BanClient", _FakeClient))
|
||||
stack.enter_context(patch("app.services.ban_service.Fail2BanClient", _FakeClient))
|
||||
stack.enter_context(patch("app.utils.jail_socket.Fail2BanClient", _FakeClient))
|
||||
return stack
|
||||
|
||||
|
||||
@@ -281,12 +283,12 @@ class TestLockInitialization:
|
||||
|
||||
async def test_reload_all_lock_is_lazy_initialised(self) -> None:
|
||||
"""The reload-all lock should be created lazily on first use."""
|
||||
jail_service._reload_all_lock = None
|
||||
jail_socket._reload_all_lock = None
|
||||
|
||||
lock = _ = jail_service._get_reload_all_lock()
|
||||
lock = _ = jail_socket._get_reload_all_lock()
|
||||
|
||||
assert isinstance(lock, asyncio.Lock)
|
||||
assert jail_service._reload_all_lock is lock
|
||||
assert jail_socket._reload_all_lock is lock
|
||||
|
||||
async def test_backend_cmd_lock_is_lazy_initialised(self) -> None:
|
||||
"""The backend capability probe lock should be created lazily on first use."""
|
||||
|
||||
Reference in New Issue
Block a user