"""Low-level socket operations for jail management. Provides shared socket utilities for reloading jails across all services. These operations are extracted to a utility layer to avoid circular dependencies between sibling services (jail_service, jail_config_service, action_config_service, filter_config_service). """ from __future__ import annotations import asyncio from typing import cast import structlog from app.exceptions import JailNotFoundError, JailOperationError from app.utils.fail2ban_client import ( Fail2BanClient, Fail2BanToken, ) from app.utils.fail2ban_response import ( is_not_found_error, ok, to_dict, ) log: structlog.stdlib.BoundLogger = structlog.get_logger() # Socket communication timeout in seconds. SOCKET_TIMEOUT: float = 10.0 # Guard against concurrent reload_all calls. Overlapping ``reload --all`` # commands sent to fail2ban's socket produce undefined behaviour and may cause # jails to be permanently removed from the daemon. Serialising them here # ensures only one reload stream is in-flight at a time. _reload_all_lock: asyncio.Lock | None = None def _get_reload_all_lock() -> asyncio.Lock: """Return the shared reload-all lock, initialising it lazily. Asyncio primitives must be created inside an active event loop in test environments that create new loops per test. Lazily initialising the lock avoids binding it to the import-time loop. """ global _reload_all_lock if _reload_all_lock is None: _reload_all_lock = asyncio.Lock() return _reload_all_lock async def reload_all( socket_path: str, *, include_jails: list[str] | None = None, exclude_jails: list[str] | None = None, ) -> None: """Reload all fail2ban jails at once. Fetches the current jail list first so that a ``['start', name]`` entry can be included in the config stream for every active jail. Without a non-empty stream the end-of-reload phase deletes every jail that received no configuration commands. *include_jails* are added to the stream (e.g. a newly activated jail that is not yet running). *exclude_jails* are removed from the stream (e.g. a jail that was just deactivated and should not be restarted). Args: socket_path: Path to the fail2ban Unix domain socket. include_jails: Extra jail names to add to the start stream. exclude_jails: Jail names to remove from the start stream. Raises: JailNotFoundError: If a jail in *include_jails* does not exist or its configuration is invalid (e.g. missing logpath). JailOperationError: If fail2ban reports the operation failed for a different reason. ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket cannot be reached. """ client = Fail2BanClient(socket_path=socket_path, timeout=SOCKET_TIMEOUT) async with _get_reload_all_lock(): try: # Resolve jail names so we can build the minimal config stream. status_raw = ok(await client.send(["status"])) status_dict = to_dict(status_raw) jail_list_raw: str = str(status_dict.get("Jail list", "")) jail_names = [n.strip() for n in jail_list_raw.split(",") if n.strip()] # Merge include/exclude sets so the stream matches the desired state. names_set: set[str] = set(jail_names) if include_jails: names_set.update(include_jails) if exclude_jails: names_set -= set(exclude_jails) stream: list[list[object]] = [["start", n] for n in sorted(names_set)] ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)])) log.info("all_jails_reloaded") except ValueError as exc: # Detect UnknownJailException (missing or invalid jail configuration) # and re-raise as JailNotFoundError for better error specificity. if is_not_found_error(exc): # Extract the jail name from include_jails if available. jail_name = include_jails[0] if include_jails else "unknown" raise JailNotFoundError(jail_name) from exc raise JailOperationError(str(exc)) from exc