Files
BanGUI/backend/app/utils/jail_socket.py
Lukas 83452ffc23 Refactor backend services and jail configuration
- Refactor action_config_service, filter_config_service, jail_config_service, and jail_service
- Add jail_socket utility module for socket communication
- Update test_jail_service with new test cases
- Update architecture and task documentation

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-25 18:34:03 +02:00

109 lines
4.2 KiB
Python

"""Low-level socket operations for jail management.
Provides shared socket utilities for reloading jails across all services.
These operations are extracted to a utility layer to avoid circular dependencies
between sibling services (jail_service, jail_config_service, action_config_service,
filter_config_service).
"""
from __future__ import annotations
import asyncio
from typing import cast
import structlog
from app.exceptions import JailNotFoundError, JailOperationError
from app.utils.fail2ban_client import (
Fail2BanClient,
Fail2BanToken,
)
from app.utils.fail2ban_response import (
is_not_found_error,
ok,
to_dict,
)
# Timeout, in seconds, for fail2ban socket communication (handed to
# Fail2BanClient when a client is constructed).
SOCKET_TIMEOUT: float = 10.0

# Module-level structured logger.
log: structlog.stdlib.BoundLogger = structlog.get_logger()
# Serialises reload_all calls. Sending overlapping ``reload --all`` commands
# over fail2ban's socket has undefined behaviour and can leave jails
# permanently removed from the daemon, so at most one reload stream may be
# in flight at any moment.
_reload_all_lock: asyncio.Lock | None = None


def _get_reload_all_lock() -> asyncio.Lock:
    """Return the module-wide reload-all lock, creating it on first use.

    Construction is deferred to the first call instead of happening at import
    time. This avoids binding the asyncio primitive to an import-time event
    loop in test environments that spin up a fresh loop per test. Once
    created, the same lock instance is returned on every later call.

    NOTE(review): because the instance is created only once, later event loops
    reuse it; on Python 3.10+ an ``asyncio.Lock`` binds to a loop the first
    time it is contended — confirm this is acceptable for per-test loops.
    """
    global _reload_all_lock
    lock = _reload_all_lock
    if lock is None:
        lock = _reload_all_lock = asyncio.Lock()
    return lock
def _parse_jail_list(raw: object) -> list[str]:
    """Split fail2ban's comma-separated ``Jail list`` value into jail names.

    Surrounding whitespace is stripped and empty entries are dropped, so an
    empty value yields an empty list.
    """
    return [name.strip() for name in str(raw).split(",") if name.strip()]


def _build_start_stream(
    running: list[str],
    include_jails: list[str] | None,
    exclude_jails: list[str] | None,
) -> list[list[object]]:
    """Return one ``['start', name]`` command per desired jail, sorted by name.

    The desired set is *running* plus *include_jails* minus *exclude_jails*;
    exclusion is applied last, so a name present in both extra lists is
    excluded.
    """
    desired = set(running)
    desired.update(include_jails or ())
    desired.difference_update(exclude_jails or ())
    return [["start", name] for name in sorted(desired)]


def _missing_jail_name(exc: Exception, include_jails: list[str] | None) -> str:
    """Best-effort guess at which jail a not-found error refers to.

    Prefers an *include_jails* entry whose name occurs in the error text (the
    longest match wins, so ``sshd`` beats ``ssh``). Otherwise it falls back to
    the first included jail, and to ``"unknown"`` when nothing was included.
    """
    if not include_jails:
        return "unknown"
    # NOTE(review): assumes the fail2ban error text usually names the offending
    # jail; when it does not, this degrades to the first included jail.
    message = str(exc)
    mentioned = [name for name in include_jails if name and name in message]
    return max(mentioned, key=len) if mentioned else include_jails[0]


async def reload_all(
    socket_path: str,
    *,
    include_jails: list[str] | None = None,
    exclude_jails: list[str] | None = None,
) -> None:
    """Reload all fail2ban jails at once.

    Fetches the current jail list first so that a ``['start', name]`` entry
    can be included in the config stream for every active jail. Without a
    non-empty stream the end-of-reload phase deletes every jail that received
    no configuration commands.

    *include_jails* are added to the stream (e.g. a newly activated jail that
    is not yet running). *exclude_jails* are removed from the stream (e.g. a
    jail that was just deactivated and should not be restarted). A name that
    appears in both lists is excluded.

    Calls are serialised through a module-level lock so that only one
    ``reload --all`` is in flight at a time.

    Args:
        socket_path: Path to the fail2ban Unix domain socket.
        include_jails: Extra jail names to add to the start stream.
        exclude_jails: Jail names to remove from the start stream.

    Raises:
        JailNotFoundError: If fail2ban reports a not-found error, e.g. a jail
            in *include_jails* does not exist or its configuration is invalid
            (such as a missing logpath). The reported name is the included
            jail mentioned in the error text. If none is mentioned, it is the
            first included jail, or ``"unknown"`` when no jails were included.
        JailOperationError: If fail2ban reports the operation failed for
            a different reason.
        ~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
            cannot be reached.
    """
    client = Fail2BanClient(socket_path=socket_path, timeout=SOCKET_TIMEOUT)
    async with _get_reload_all_lock():
        try:
            # Resolve the running jail names so we can build the minimal
            # config stream that keeps them alive across the reload.
            status_raw = ok(await client.send(["status"]))
            running = _parse_jail_list(to_dict(status_raw).get("Jail list", ""))
            stream = _build_start_stream(running, include_jails, exclude_jails)
            ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
            log.info("all_jails_reloaded")
        except ValueError as exc:
            # Map UnknownJailException-style failures (missing or invalid jail
            # configuration) to the more specific JailNotFoundError.
            if is_not_found_error(exc):
                raise JailNotFoundError(_missing_jail_name(exc, include_jails)) from exc
            raise JailOperationError(str(exc)) from exc