- Refactor action_config_service, filter_config_service, jail_config_service, and jail_service - Add jail_socket utility module for socket communication - Update test_jail_service with new test cases - Update architecture and task documentation Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
109 lines
4.2 KiB
Python
109 lines
4.2 KiB
Python
"""Low-level socket operations for jail management.
|
|
|
|
Provides shared socket utilities for reloading jails across all services.
|
|
These operations are extracted to a utility layer to avoid circular dependencies
|
|
between sibling services (jail_service, jail_config_service, action_config_service,
|
|
filter_config_service).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
from typing import cast
|
|
|
|
import structlog
|
|
|
|
from app.exceptions import JailNotFoundError, JailOperationError
|
|
from app.utils.fail2ban_client import (
|
|
Fail2BanClient,
|
|
Fail2BanToken,
|
|
)
|
|
from app.utils.fail2ban_response import (
|
|
is_not_found_error,
|
|
ok,
|
|
to_dict,
|
|
)
|
|
|
|
# Module-level structured logger; reload_all emits its success event through it.
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
|
|
|
# Timeout, in seconds, handed to the fail2ban client for socket communication.
SOCKET_TIMEOUT: float = 10.0
|
|
|
|
# Serialises reload_all calls.  Sending overlapping ``reload --all`` commands
# to fail2ban's socket produces undefined behaviour and may cause jails to be
# permanently removed from the daemon, so only one reload stream is allowed
# in flight at a time.  The lock starts out unset and is created on first use.
_reload_all_lock: asyncio.Lock | None = None
|
|
|
|
|
|
def _get_reload_all_lock() -> asyncio.Lock:
    """Return the module-wide reload-all lock, creating it on first use.

    The lock is deliberately not built at import time: test environments
    that spin up a fresh event loop per test need asyncio primitives to be
    created while a loop is active, and deferring construction keeps the
    lock from being tied to whatever loop existed during import.

    Returns:
        The single shared :class:`asyncio.Lock` guarding ``reload --all``.
    """
    global _reload_all_lock
    lock = _reload_all_lock
    if lock is None:
        # First caller creates the lock and publishes it for everyone else.
        lock = _reload_all_lock = asyncio.Lock()
    return lock
|
|
|
|
|
|
async def reload_all(
    socket_path: str,
    *,
    include_jails: list[str] | None = None,
    exclude_jails: list[str] | None = None,
) -> None:
    """Reload every fail2ban jail with a single ``reload --all`` command.

    The running jails are first looked up via ``status`` so that the config
    stream carries a ``['start', name]`` entry for each of them.  An empty
    stream must be avoided: during the end-of-reload phase fail2ban deletes
    every jail that received no configuration commands.

    Callers can adjust the set of jails that is started.  Names in
    *include_jails* are added (e.g. a newly activated jail that is not yet
    running) and names in *exclude_jails* are removed (e.g. a jail that was
    just deactivated and should not be restarted).  Exclusion is applied
    after inclusion, so a name present in both lists is left out.

    The whole exchange runs under the shared reload-all lock, so concurrent
    callers are processed one after another.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        include_jails: Extra jail names to add to the start stream.
        exclude_jails: Jail names to remove from the start stream.

    Raises:
        JailNotFoundError: If fail2ban reports a missing jail or an invalid
            jail configuration (e.g. missing logpath).  The error names the
            first entry of *include_jails*, or ``"unknown"`` when none
            were given.
        JailOperationError: If fail2ban reports the operation failed for
            a different reason.
        ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
            cannot be reached.
    """
    client = Fail2BanClient(socket_path=socket_path, timeout=SOCKET_TIMEOUT)
    async with _get_reload_all_lock():
        try:
            # Ask the daemon which jails are currently running.
            status_reply = ok(await client.send(["status"]))
            status_fields = to_dict(status_reply)
            raw_list: str = str(status_fields.get("Jail list", ""))
            stripped = (part.strip() for part in raw_list.split(","))
            wanted: set[str] = {name for name in stripped if name}

            # Apply the caller's additions first, then its removals, so the
            # stream reflects the desired end state.
            wanted.update(include_jails or ())
            wanted.difference_update(exclude_jails or ())

            # Sorted for a deterministic command order.
            stream: list[list[object]] = [
                ["start", name] for name in sorted(wanted)
            ]
            ok(
                await client.send(
                    ["reload", "--all", [], cast("Fail2BanToken", stream)]
                )
            )
            log.info("all_jails_reloaded")
        except ValueError as exc:
            # Anything other than an unknown-jail failure is reported as a
            # generic operation error.
            if not is_not_found_error(exc):
                raise JailOperationError(str(exc)) from exc
            # The failing jail is not identified here, so the first requested
            # include is reported as the best available guess.
            reported = include_jails[0] if include_jails else "unknown"
            raise JailNotFoundError(reported) from exc
|