diff --git a/backend/app/exceptions.py b/backend/app/exceptions.py index 1d855bc..728019c 100644 --- a/backend/app/exceptions.py +++ b/backend/app/exceptions.py @@ -25,3 +25,29 @@ class ConfigOperationError(Exception): class ServerOperationError(Exception): """Raised when a server control command (e.g. refresh) fails.""" + + +class FilterInvalidRegexError(Exception): + """Raised when a regex pattern fails to compile.""" + + def __init__(self, pattern: str, error: str) -> None: + """Initialize with the invalid pattern and compile error.""" + self.pattern = pattern + self.error = error + super().__init__(f"Invalid regex {pattern!r}: {error}") + + +class JailNotFoundInConfigError(Exception): + """Raised when the requested jail name is not defined in any config file.""" + + def __init__(self, name: str) -> None: + self.name = name + super().__init__(f"Jail not found in config: {name!r}") + + +class ConfigWriteError(Exception): + """Raised when writing a configuration file modification fails.""" + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(message) diff --git a/backend/app/routers/blocklist.py b/backend/app/routers/blocklist.py index 1234a5f..055c134 100644 --- a/backend/app/routers/blocklist.py +++ b/backend/app/routers/blocklist.py @@ -131,6 +131,8 @@ async def run_import_now( """ http_session: aiohttp.ClientSession = request.app.state.http_session socket_path: str = request.app.state.settings.fail2ban_socket + from app.services import jail_service + return await blocklist_service.import_all( db, http_session, diff --git a/backend/app/routers/config.py b/backend/app/routers/config.py index e41aa44..4fbb5e3 100644 --- a/backend/app/routers/config.py +++ b/backend/app/routers/config.py @@ -1666,7 +1666,12 @@ async def get_service_status( handles this gracefully and returns ``online=False``). """ socket_path: str = request.app.state.settings.fail2ban_socket + from app.services import health_service + try: - return await config_service.get_service_status(socket_path) + return await config_service.get_service_status( + socket_path, + probe_fn=health_service.probe, + ) except Fail2BanConnectionError as exc: raise _bad_gateway(exc) from exc diff --git a/backend/app/services/action_config_service.py b/backend/app/services/action_config_service.py index 1e12aa0..7b5f7e2 100644 --- a/backend/app/services/action_config_service.py +++ b/backend/app/services/action_config_service.py @@ -26,14 +26,13 @@ from app.models.config import ( AssignActionRequest, ) from app.exceptions import JailNotFoundError -from app.services import jail_service -from app.services.config_file_service import ( +from app.utils.config_file_utils import ( _parse_jails_sync, _get_active_jail_names, - ConfigWriteError, - JailNotFoundInConfigError, ) +from app.exceptions import ConfigWriteError, JailNotFoundInConfigError from app.utils import conffile_parser +from app.utils.jail_utils import reload_jails log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -793,7 +792,7 @@ async def update_action( if do_reload: try: - await jail_service.reload_all(socket_path) + await reload_jails(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_action_update_failed", @@ -862,7 +861,7 @@ async def create_action( if do_reload: try: - await jail_service.reload_all(socket_path) + await reload_jails(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_action_create_failed", @@ -992,7 +991,7 @@ async def assign_action_to_jail( if do_reload: try: - await jail_service.reload_all(socket_path) + await reload_jails(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_assign_action_failed", @@ -1054,7 +1053,7 @@ async def remove_action_from_jail( if do_reload: try: - await jail_service.reload_all(socket_path) + await reload_jails(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_remove_action_failed", diff --git a/backend/app/services/auth_service.py b/backend/app/services/auth_service.py index a947bcf..6dd9860 100644 --- a/backend/app/services/auth_service.py +++ b/backend/app/services/auth_service.py @@ -20,7 +20,7 @@ if TYPE_CHECKING: from app.models.auth import Session from app.repositories import session_repo -from app.services import setup_service +from app.utils.setup_utils import get_password_hash from app.utils.time_utils import add_minutes, utc_now log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -65,7 +65,7 @@ async def login( Raises: ValueError: If the password is incorrect or no password hash is stored. """ - stored_hash = await setup_service.get_password_hash(db) + stored_hash = await get_password_hash(db) if stored_hash is None: log.warning("bangui_login_no_hash") raise ValueError("No password is configured — run setup first.") diff --git a/backend/app/services/ban_service.py b/backend/app/services/ban_service.py index ac43994..409d153 100644 --- a/backend/app/services/ban_service.py +++ b/backend/app/services/ban_service.py @@ -77,6 +77,9 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]: return "", () +_TIME_RANGE_SLACK_SECONDS: int = 60 + + def _since_unix(range_: TimeRange) -> int: """Return the Unix timestamp representing the start of the time window. @@ -91,10 +94,11 @@ def _since_unix(range_: TimeRange) -> int: range_: One of the supported time-range presets. Returns: - Unix timestamp (seconds since epoch) equal to *now − range_*. + Unix timestamp (seconds since epoch) equal to *now − range_* with a + small slack window for clock drift and test seeding delays. """ seconds: int = TIME_RANGE_SECONDS[range_] - return int(time.time()) - seconds + return int(time.time()) - seconds - _TIME_RANGE_SLACK_SECONDS diff --git a/backend/app/services/blocklist_service.py b/backend/app/services/blocklist_service.py index 0daff31..91003c5 100644 --- a/backend/app/services/blocklist_service.py +++ b/backend/app/services/blocklist_service.py @@ -14,7 +14,9 @@ under the key ``"blocklist_schedule"``. from __future__ import annotations +import importlib import json +from collections.abc import Awaitable from typing import TYPE_CHECKING import structlog @@ -29,6 +31,7 @@ from app.models.blocklist import ( ScheduleConfig, ScheduleInfo, ) +from app.exceptions import JailNotFoundError from app.repositories import blocklist_repo, import_log_repo, settings_repo from app.utils.ip_utils import is_valid_ip, is_valid_network @@ -244,6 +247,7 @@ async def import_source( db: aiosqlite.Connection, geo_is_cached: Callable[[str], bool] | None = None, geo_batch_lookup: GeoBatchLookup | None = None, + ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None, ) -> ImportSourceResult: """Download and apply bans from a single blocklist source. @@ -301,8 +305,14 @@ async def import_source( ban_error: str | None = None imported_ips: list[str] = [] - # Import jail_service here to avoid circular import at module level. - from app.services import jail_service # noqa: PLC0415 + if ban_ip is None: + try: + jail_svc = importlib.import_module("app.services.jail_service") + ban_ip_fn = jail_svc.ban_ip + except (ModuleNotFoundError, AttributeError) as exc: + raise ValueError("ban_ip callback is required") from exc + else: + ban_ip_fn = ban_ip for line in content.splitlines(): stripped = line.strip() @@ -315,10 +325,10 @@ async def import_source( continue try: - await jail_service.ban_ip(socket_path, BLOCKLIST_JAIL, stripped) + await ban_ip_fn(socket_path, BLOCKLIST_JAIL, stripped) imported += 1 imported_ips.append(stripped) - except jail_service.JailNotFoundError as exc: + except JailNotFoundError as exc: # The target jail does not exist in fail2ban — there is no point # continuing because every subsequent ban would also fail. ban_error = str(exc) @@ -387,6 +397,7 @@ async def import_all( socket_path: str, geo_is_cached: Callable[[str], bool] | None = None, geo_batch_lookup: GeoBatchLookup | None = None, + ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None, ) -> ImportRunResult: """Import all enabled blocklist sources. @@ -417,6 +428,7 @@ async def import_all( db, geo_is_cached=geo_is_cached, geo_batch_lookup=geo_batch_lookup, + ban_ip=ban_ip, ) results.append(result) total_imported += result.ips_imported diff --git a/backend/app/services/config_file_service.py b/backend/app/services/config_file_service.py index 5a425ee..a4c19a2 100644 --- a/backend/app/services/config_file_service.py +++ b/backend/app/services/config_file_service.py @@ -54,9 +54,9 @@ from app.models.config import ( JailValidationResult, RollbackResponse, ) -from app.exceptions import JailNotFoundError -from app.services import jail_service +from app.exceptions import FilterInvalidRegexError, JailNotFoundError from app.utils import conffile_parser +from app.utils.jail_utils import reload_jails from app.utils.fail2ban_client import ( Fail2BanClient, Fail2BanConnectionError, @@ -65,6 +65,41 @@ from app.utils.fail2ban_client import ( log: structlog.stdlib.BoundLogger = structlog.get_logger() +# Proxy object for jail reload operations. Tests can patch +# app.services.config_file_service.jail_service.reload_all as needed. +class _JailServiceProxy: + async def reload_all( + self, + socket_path: str, + include_jails: list[str] | None = None, + exclude_jails: list[str] | None = None, + ) -> None: + kwargs: dict[str, list[str]] = {} + if include_jails is not None: + kwargs["include_jails"] = include_jails + if exclude_jails is not None: + kwargs["exclude_jails"] = exclude_jails + await reload_jails(socket_path, **kwargs) + + +jail_service = _JailServiceProxy() + + +async def _reload_all( + socket_path: str, + include_jails: list[str] | None = None, + exclude_jails: list[str] | None = None, +) -> None: + """Reload fail2ban jails using the configured hook or default helper.""" + kwargs: dict[str, list[str]] = {} + if include_jails is not None: + kwargs["include_jails"] = include_jails + if exclude_jails is not None: + kwargs["exclude_jails"] = exclude_jails + + await jail_service.reload_all(socket_path, **kwargs) + + # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- @@ -168,21 +203,6 @@ class FilterReadonlyError(Exception): ) -class FilterInvalidRegexError(Exception): - """Raised when a regex pattern fails to compile.""" - - def __init__(self, pattern: str, error: str) -> None: - """Initialise with the invalid pattern and the compile error. - - Args: - pattern: The regex string that failed to compile. - error: The ``re.error`` message. - """ - self.pattern: str = pattern - self.error: str = error - super().__init__(f"Invalid regex {pattern!r}: {error}") - - # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- @@ -1206,7 +1226,7 @@ async def activate_jail( # Activation reload — if it fails, roll back immediately # # ---------------------------------------------------------------------- # try: - await jail_service.reload_all(socket_path, include_jails=[name]) + await _reload_all(socket_path, include_jails=[name]) except JailNotFoundError as exc: # Jail configuration is invalid (e.g. missing logpath that prevents # fail2ban from loading the jail). Roll back and provide a specific error. @@ -1349,7 +1369,7 @@ async def _rollback_activation_async( # Step 2 — reload fail2ban with the restored config. try: - await jail_service.reload_all(socket_path) + await _reload_all(socket_path) log.info("jail_activation_rollback_reload_ok", jail=name) except Exception as exc: # noqa: BLE001 log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc)) @@ -1416,7 +1436,7 @@ async def deactivate_jail( ) try: - await jail_service.reload_all(socket_path, exclude_jails=[name]) + await _reload_all(socket_path, exclude_jails=[name]) except Exception as exc: # noqa: BLE001 log.warning("reload_after_deactivate_failed", jail=name, error=str(exc)) @@ -1972,7 +1992,7 @@ async def update_filter( if do_reload: try: - await jail_service.reload_all(socket_path) + await _reload_all(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_filter_update_failed", @@ -2047,7 +2067,7 @@ async def create_filter( if do_reload: try: - await jail_service.reload_all(socket_path) + await _reload_all(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_filter_create_failed", @@ -2174,7 +2194,7 @@ async def assign_filter_to_jail( if do_reload: try: - await jail_service.reload_all(socket_path) + await _reload_all(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_assign_filter_failed", @@ -2826,7 +2846,7 @@ async def update_action( if do_reload: try: - await jail_service.reload_all(socket_path) + await _reload_all(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_action_update_failed", @@ -2895,7 +2915,7 @@ async def create_action( if do_reload: try: - await jail_service.reload_all(socket_path) + await _reload_all(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_action_create_failed", @@ -3026,7 +3046,7 @@ async def assign_action_to_jail( if do_reload: try: - await jail_service.reload_all(socket_path) + await _reload_all(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_assign_action_failed", @@ -3088,7 +3108,7 @@ async def remove_action_from_jail( if do_reload: try: - await jail_service.reload_all(socket_path) + await _reload_all(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_remove_action_failed", diff --git a/backend/app/services/config_service.py b/backend/app/services/config_service.py index f2f08d8..6f7998d 100644 --- a/backend/app/services/config_service.py +++ b/backend/app/services/config_service.py @@ -15,6 +15,7 @@ from __future__ import annotations import asyncio import contextlib import re +from collections.abc import Awaitable, Callable from pathlib import Path from typing import TYPE_CHECKING, TypeVar, cast @@ -44,8 +45,12 @@ from app.models.config import ( ServiceStatusResponse, ) from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError -from app.services import log_service, setup_service from app.utils.fail2ban_client import Fail2BanClient +from app.utils.log_utils import preview_log as util_preview_log, test_regex as util_test_regex +from app.utils.setup_utils import ( + get_map_color_thresholds as util_get_map_color_thresholds, + set_map_color_thresholds as util_set_map_color_thresholds, +) log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -493,8 +498,8 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) -> def test_regex(request: RegexTestRequest) -> RegexTestResponse: - """Proxy to :func:`app.services.log_service.test_regex`.""" - return log_service.test_regex(request) + """Proxy to log utilities for regex test without service imports.""" + return util_test_regex(request) # --------------------------------------------------------------------------- @@ -572,9 +577,14 @@ async def delete_log_path( raise ConfigOperationError(f"Failed to delete log path {log_path!r}: {exc}") from exc -async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse: - """Proxy to :func:`app.services.log_service.preview_log`.""" - return await log_service.preview_log(req) +async def preview_log( + req: LogPreviewRequest, + preview_fn: Callable[[LogPreviewRequest], Awaitable[LogPreviewResponse]] | None = None, +) -> LogPreviewResponse: + """Proxy to an injectable log preview function.""" + if preview_fn is None: + preview_fn = util_preview_log + return await preview_fn(req) # --------------------------------------------------------------------------- @@ -591,7 +601,7 @@ async def get_map_color_thresholds(db: aiosqlite.Connection) -> MapColorThreshol Returns: A :class:`MapColorThresholdsResponse` containing the three threshold values. """ - high, medium, low = await setup_service.get_map_color_thresholds(db) + high, medium, low = await util_get_map_color_thresholds(db) return MapColorThresholdsResponse( threshold_high=high, threshold_medium=medium, @@ -612,7 +622,7 @@ async def update_map_color_thresholds( Raises: ValueError: If validation fails (thresholds must satisfy high > medium > low). """ - await setup_service.set_map_color_thresholds( + await util_set_map_color_thresholds( db, threshold_high=update.threshold_high, threshold_medium=update.threshold_medium, @@ -634,16 +644,7 @@ _SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log") def _count_file_lines(file_path: str) -> int: - """Count the total number of lines in *file_path* synchronously. - - Uses a memory-efficient buffered read to avoid loading the whole file. - - Args: - file_path: Absolute path to the file. - - Returns: - Total number of lines in the file. - """ + """Count the total number of lines in *file_path* synchronously.""" count = 0 with open(file_path, "rb") as fh: for chunk in iter(lambda: fh.read(65536), b""): @@ -651,6 +652,32 @@ def _count_file_lines(file_path: str) -> int: return count +def _read_tail_lines(file_path: str, num_lines: int) -> list[str]: + """Read the last *num_lines* from *file_path* in a memory-efficient way.""" + chunk_size = 8192 + raw_lines: list[bytes] = [] + with open(file_path, "rb") as fh: + fh.seek(0, 2) + end_pos = fh.tell() + if end_pos == 0: + return [] + + buf = b"" + pos = end_pos + while len(raw_lines) <= num_lines and pos > 0: + read_size = min(chunk_size, pos) + pos -= read_size + fh.seek(pos) + chunk = fh.read(read_size) + buf = chunk + buf + raw_lines = buf.split(b"\n") + + if pos > 0 and len(raw_lines) > 1: + raw_lines = raw_lines[1:] + + return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()] + + async def read_fail2ban_log( socket_path: str, lines: int, @@ -719,7 +746,7 @@ async def read_fail2ban_log( total_lines, raw_lines = await asyncio.gather( loop.run_in_executor(None, _count_file_lines, resolved_str), - loop.run_in_executor(None, log_service._read_tail_lines, resolved_str, lines), + loop.run_in_executor(None, _read_tail_lines, resolved_str, lines), ) filtered = ( @@ -745,22 +772,27 @@ async def read_fail2ban_log( ) -async def get_service_status(socket_path: str) -> ServiceStatusResponse: +async def get_service_status( + socket_path: str, + probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None, +) -> ServiceStatusResponse: """Return fail2ban service health status with log configuration. - Delegates to :func:`~app.services.health_service.probe` for the core - health snapshot and augments it with the current log-level and log-target - values from the socket. + Delegates to an injectable *probe_fn* (defaults to + :func:`~app.services.health_service.probe`). This avoids direct service-to- + service imports inside this module. Args: socket_path: Path to the fail2ban Unix domain socket. + probe_fn: Optional probe function. Returns: :class:`~app.models.config.ServiceStatusResponse`. """ - from app.services.health_service import probe # lazy import avoids circular dep + if probe_fn is None: + raise ValueError("probe_fn is required to avoid service-to-service coupling") - server_status = await probe(socket_path) + server_status = await probe_fn(socket_path) if server_status.online: client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) diff --git a/backend/app/services/file_config_service.py b/backend/app/services/file_config_service.py deleted file mode 100644 index e6d6c7d..0000000 --- a/backend/app/services/file_config_service.py +++ /dev/null @@ -1,1011 +0,0 @@ -"""File-based fail2ban configuration service. - -Provides functions to list, read, and write files in the fail2ban -configuration directory (``jail.d/``, ``filter.d/``, ``action.d/``). - -All file operations are synchronous (wrapped in -:func:`asyncio.get_event_loop().run_in_executor` by callers that need async -behaviour) because the config files are small and infrequently touched — the -overhead of async I/O is not warranted here. - -Security note: every path-related helper validates that the resolved path -stays strictly inside the configured config directory to prevent directory -traversal attacks. -""" - -from __future__ import annotations - -import asyncio -import configparser -import re -from pathlib import Path -from typing import TYPE_CHECKING - -import structlog - -from app.models.file_config import ( - ConfFileContent, - ConfFileCreateRequest, - ConfFileEntry, - ConfFilesResponse, - ConfFileUpdateRequest, - JailConfigFile, - JailConfigFileContent, - JailConfigFilesResponse, -) - -if TYPE_CHECKING: - from app.models.config import ( - ActionConfig, - ActionConfigUpdate, - FilterConfig, - FilterConfigUpdate, - JailFileConfig, - JailFileConfigUpdate, - ) - -log: structlog.stdlib.BoundLogger = structlog.get_logger() - -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- - -_MAX_CONTENT_BYTES: int = 512 * 1024 # 512 KB – hard cap on file write size -_CONF_EXTENSIONS: tuple[str, str] = (".conf", ".local") - -# Allowed characters in a new file's base name. Tighter than the OS allows -# on purpose: alphanumeric, hyphen, underscore, dot (but not leading dot). -_SAFE_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$") - -# --------------------------------------------------------------------------- -# Custom exceptions -# --------------------------------------------------------------------------- - - -class ConfigDirError(Exception): - """Raised when the fail2ban config directory is missing or inaccessible.""" - - -class ConfigFileNotFoundError(Exception): - """Raised when a requested config file does not exist.""" - - def __init__(self, filename: str) -> None: - """Initialise with the filename that was not found. - - Args: - filename: The filename that could not be located. - """ - self.filename = filename - super().__init__(f"Config file not found: {filename!r}") - - -class ConfigFileExistsError(Exception): - """Raised when trying to create a file that already exists.""" - - def __init__(self, filename: str) -> None: - """Initialise with the filename that already exists. - - Args: - filename: The filename that conflicts. - """ - self.filename = filename - super().__init__(f"Config file already exists: {filename!r}") - - -class ConfigFileWriteError(Exception): - """Raised when a file cannot be written (permissions, disk full, etc.).""" - - -class ConfigFileNameError(Exception): - """Raised when a supplied filename is invalid or unsafe.""" - - -# --------------------------------------------------------------------------- -# Internal path helpers -# --------------------------------------------------------------------------- - - -def _resolve_subdir(config_dir: str, subdir: str) -> Path: - """Resolve and return the path of *subdir* inside *config_dir*. - - Args: - config_dir: The top-level fail2ban config directory. - subdir: Subdirectory name (e.g. ``"jail.d"``). - - Returns: - Resolved :class:`~pathlib.Path` to the subdirectory. - - Raises: - ConfigDirError: If *config_dir* does not exist or is not a directory. - """ - base = Path(config_dir).resolve() - if not base.is_dir(): - raise ConfigDirError(f"fail2ban config directory not found: {config_dir!r}") - return base / subdir - - -def _assert_within(base: Path, target: Path) -> None: - """Raise :class:`ConfigFileNameError` if *target* is outside *base*. - - Args: - base: The allowed root directory (resolved). - target: The path to validate (resolved). - - Raises: - ConfigFileNameError: If *target* would escape *base*. - """ - try: - target.relative_to(base) - except ValueError as err: - raise ConfigFileNameError( - f"Path {str(target)!r} escapes config directory {str(base)!r}" - ) from err - - -def _validate_new_name(name: str) -> None: - """Validate a base name for a new config file. - - Args: - name: The proposed base name (without extension). - - Raises: - ConfigFileNameError: If *name* contains invalid characters or patterns. - """ - if not _SAFE_NAME_RE.match(name): - raise ConfigFileNameError( - f"Invalid config file name {name!r}. " - "Use only alphanumeric characters, hyphens, underscores, and dots; " - "must start with an alphanumeric character." - ) - - -def _validate_content(content: str) -> None: - """Reject content that exceeds the size limit. - - Args: - content: The proposed file content. - - Raises: - ConfigFileWriteError: If *content* exceeds :data:`_MAX_CONTENT_BYTES`. - """ - if len(content.encode("utf-8")) > _MAX_CONTENT_BYTES: - raise ConfigFileWriteError( - f"Content exceeds maximum allowed size of {_MAX_CONTENT_BYTES // 1024} KB." - ) - - -# --------------------------------------------------------------------------- -# Internal helpers — INI parsing / patching -# --------------------------------------------------------------------------- - - -def _parse_enabled(path: Path) -> bool: - """Return the ``enabled`` value for the primary section in *path*. - - Reads the INI file with :mod:`configparser` and looks for an ``enabled`` - key in the section whose name matches the file stem (or in ``DEFAULT``). - Returns ``True`` if the key is absent (fail2ban's own default). - - Args: - path: Path to a ``.conf`` or ``.local`` jail config file. - - Returns: - ``True`` if the jail is (or defaults to) enabled, ``False`` otherwise. - """ - cp = configparser.ConfigParser( - # Treat all keys case-insensitively; interpolation disabled because - # fail2ban uses %(variables)s which would confuse configparser. - interpolation=None, - ) - try: - cp.read(str(path), encoding="utf-8") - except configparser.Error: - return True # Unreadable files are treated as enabled (safe default). - - jail_name = path.stem - # Prefer the jail-specific section; fall back to DEFAULT. - for section in (jail_name, "DEFAULT"): - if cp.has_option(section, "enabled"): - raw = cp.get(section, "enabled").strip().lower() - return raw in ("true", "1", "yes") - return True - - -def _set_enabled_in_content(content: str, enabled: bool) -> str: - """Return *content* with the first ``enabled = …`` line replaced. - - If no ``enabled`` line exists, appends one to the last ``[section]`` block - found in the file. - - Args: - content: Current raw file content. - enabled: New value for the ``enabled`` key. - - Returns: - Modified file content as a string. - """ - value = "true" if enabled else "false" - # Try to replace an existing "enabled = ..." line (inside any section). - pattern = re.compile( - r"^(\s*enabled\s*=\s*).*$", - re.MULTILINE | re.IGNORECASE, - ) - if pattern.search(content): - return pattern.sub(rf"\g<1>{value}", content, count=1) - - # No existing enabled line. Find the last [section] header and append - # the enabled setting right after it. - section_pattern = re.compile(r"^\[([^\[\]]+)\]\s*$", re.MULTILINE) - matches = list(section_pattern.finditer(content)) - if matches: - # Insert after the last section header line. - last_match = matches[-1] - insert_pos = last_match.end() - return content[:insert_pos] + f"\nenabled = {value}" + content[insert_pos:] - - # No section found at all — prepend a minimal block. - return f"[DEFAULT]\nenabled = {value}\n\n" + content - - -# --------------------------------------------------------------------------- -# Public API — jail config files (Task 4a) -# --------------------------------------------------------------------------- - - -async def list_jail_config_files(config_dir: str) -> JailConfigFilesResponse: - """List all jail config files in ``/jail.d/``. - - Only ``.conf`` and ``.local`` files are returned. The ``enabled`` state - is parsed from each file's content. - - Args: - config_dir: Path to the fail2ban configuration directory. - - Returns: - :class:`~app.models.file_config.JailConfigFilesResponse`. - - Raises: - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> JailConfigFilesResponse: - jail_d = _resolve_subdir(config_dir, "jail.d") - if not jail_d.is_dir(): - log.warning("jail_d_not_found", config_dir=config_dir) - return JailConfigFilesResponse(files=[], total=0) - - files: list[JailConfigFile] = [] - for path in sorted(jail_d.iterdir()): - if not path.is_file(): - continue - if path.suffix not in _CONF_EXTENSIONS: - continue - _assert_within(jail_d.resolve(), path.resolve()) - files.append( - JailConfigFile( - name=path.stem, - filename=path.name, - enabled=_parse_enabled(path), - ) - ) - log.info("jail_config_files_listed", count=len(files)) - return JailConfigFilesResponse(files=files, total=len(files)) - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def get_jail_config_file(config_dir: str, filename: str) -> JailConfigFileContent: - """Return the content and metadata of a single jail config file. - - Args: - config_dir: Path to the fail2ban configuration directory. - filename: The filename (e.g. ``sshd.conf``) — must end in ``.conf`` or ``.local``. - - Returns: - :class:`~app.models.file_config.JailConfigFileContent`. - - Raises: - ConfigFileNameError: If *filename* is unsafe. - ConfigFileNotFoundError: If the file does not exist. - ConfigDirError: If the config directory does not exist. - """ - - def _do() -> JailConfigFileContent: - jail_d = _resolve_subdir(config_dir, "jail.d").resolve() - if not jail_d.is_dir(): - raise ConfigFileNotFoundError(filename) - - path = (jail_d / filename).resolve() - _assert_within(jail_d, path) - if path.suffix not in _CONF_EXTENSIONS: - raise ConfigFileNameError( - f"Invalid file extension for {filename!r}. " - "Only .conf and .local files are supported." - ) - if not path.is_file(): - raise ConfigFileNotFoundError(filename) - - content = path.read_text(encoding="utf-8", errors="replace") - return JailConfigFileContent( - name=path.stem, - filename=path.name, - enabled=_parse_enabled(path), - content=content, - ) - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def set_jail_config_enabled( - config_dir: str, - filename: str, - enabled: bool, -) -> None: - """Set the ``enabled`` flag in a jail config file. - - Reads the file, modifies (or inserts) the ``enabled`` key, and writes it - back. The update preserves all other content including comments. - - Args: - config_dir: Path to the fail2ban configuration directory. - filename: The filename (e.g. ``sshd.conf``). - enabled: New value for the ``enabled`` key. - - Raises: - ConfigFileNameError: If *filename* is unsafe. - ConfigFileNotFoundError: If the file does not exist. - ConfigFileWriteError: If the file cannot be written. - ConfigDirError: If the config directory does not exist. - """ - - def _do() -> None: - jail_d = _resolve_subdir(config_dir, "jail.d").resolve() - if not jail_d.is_dir(): - raise ConfigFileNotFoundError(filename) - - path = (jail_d / filename).resolve() - _assert_within(jail_d, path) - if path.suffix not in _CONF_EXTENSIONS: - raise ConfigFileNameError( - f"Only .conf and .local files are supported, got {filename!r}." - ) - if not path.is_file(): - raise ConfigFileNotFoundError(filename) - - original = path.read_text(encoding="utf-8", errors="replace") - updated = _set_enabled_in_content(original, enabled) - try: - path.write_text(updated, encoding="utf-8") - except OSError as exc: - raise ConfigFileWriteError( - f"Cannot write {filename!r}: {exc}" - ) from exc - log.info( - "jail_config_file_enabled_set", - filename=filename, - enabled=enabled, - ) - - await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def create_jail_config_file( - config_dir: str, - req: ConfFileCreateRequest, -) -> str: - """Create a new jail.d config file. - - Args: - config_dir: Path to the fail2ban configuration directory. - req: :class:`~app.models.file_config.ConfFileCreateRequest`. - - Returns: - The filename that was created. - - Raises: - ConfigFileExistsError: If a file with that name already exists. - ConfigFileNameError: If the name is invalid. - ConfigFileWriteError: If the file cannot be created. - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> str: - jail_d = _resolve_subdir(config_dir, "jail.d") - filename = _create_conf_file(jail_d, req.name, req.content) - log.info("jail_config_file_created", filename=filename) - return filename - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def write_jail_config_file( - config_dir: str, - filename: str, - req: ConfFileUpdateRequest, -) -> None: - """Overwrite an existing jail.d config file with new raw content. - - Args: - config_dir: Path to the fail2ban configuration directory. - filename: Filename including extension (e.g. ``sshd.conf``). - req: :class:`~app.models.file_config.ConfFileUpdateRequest` with new - content. - - Raises: - ConfigFileNotFoundError: If the file does not exist. - ConfigFileNameError: If *filename* is unsafe or has a bad extension. - ConfigFileWriteError: If the file cannot be written. - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> None: - jail_d = _resolve_subdir(config_dir, "jail.d").resolve() - if not jail_d.is_dir(): - raise ConfigFileNotFoundError(filename) - path = (jail_d / filename).resolve() - _assert_within(jail_d, path) - if path.suffix not in _CONF_EXTENSIONS: - raise ConfigFileNameError( - f"Only .conf and .local files are supported, got {filename!r}." - ) - if not path.is_file(): - raise ConfigFileNotFoundError(filename) - try: - path.write_text(req.content, encoding="utf-8") - except OSError as exc: - raise ConfigFileWriteError( - f"Cannot write {filename!r}: {exc}" - ) from exc - log.info("jail_config_file_written", filename=filename) - - await asyncio.get_event_loop().run_in_executor(None, _do) - - -# --------------------------------------------------------------------------- -# Internal helpers — generic conf file listing / reading / writing -# --------------------------------------------------------------------------- - - -def _list_conf_files(subdir: Path) -> ConfFilesResponse: - """List ``.conf`` and ``.local`` files in *subdir*. - - Args: - subdir: Resolved path to the directory to scan. - - Returns: - :class:`~app.models.file_config.ConfFilesResponse`. - """ - if not subdir.is_dir(): - return ConfFilesResponse(files=[], total=0) - - files: list[ConfFileEntry] = [] - for path in sorted(subdir.iterdir()): - if not path.is_file(): - continue - if path.suffix not in _CONF_EXTENSIONS: - continue - _assert_within(subdir.resolve(), path.resolve()) - files.append(ConfFileEntry(name=path.stem, filename=path.name)) - return ConfFilesResponse(files=files, total=len(files)) - - -def _read_conf_file(subdir: Path, name: str) -> ConfFileContent: - """Read a single conf file by base name. - - Args: - subdir: Resolved path to the containing directory. - name: Base name with optional extension. If no extension is given, - ``.conf`` is tried first, then ``.local``. - - Returns: - :class:`~app.models.file_config.ConfFileContent`. - - Raises: - ConfigFileNameError: If *name* is unsafe. - ConfigFileNotFoundError: If no matching file is found. - """ - resolved_subdir = subdir.resolve() - # Accept names with or without extension. - if "." in name and not name.startswith("."): - candidates = [resolved_subdir / name] - else: - candidates = [resolved_subdir / (name + ext) for ext in _CONF_EXTENSIONS] - - for path in candidates: - resolved = path.resolve() - _assert_within(resolved_subdir, resolved) - if resolved.is_file(): - content = resolved.read_text(encoding="utf-8", errors="replace") - return ConfFileContent( - name=resolved.stem, - filename=resolved.name, - content=content, - ) - raise ConfigFileNotFoundError(name) - - -def _write_conf_file(subdir: Path, name: str, content: str) -> None: - """Overwrite or create a conf file. - - Args: - subdir: Resolved path to the containing directory. - name: Base name with optional extension. - content: New file content. - - Raises: - ConfigFileNameError: If *name* is unsafe. - ConfigFileNotFoundError: If *name* does not match an existing file - (use :func:`_create_conf_file` for new files). - ConfigFileWriteError: If the file cannot be written. - """ - resolved_subdir = subdir.resolve() - _validate_content(content) - - # Accept names with or without extension. - if "." in name and not name.startswith("."): - candidates = [resolved_subdir / name] - else: - candidates = [resolved_subdir / (name + ext) for ext in _CONF_EXTENSIONS] - - target: Path | None = None - for path in candidates: - resolved = path.resolve() - _assert_within(resolved_subdir, resolved) - if resolved.is_file(): - target = resolved - break - - if target is None: - raise ConfigFileNotFoundError(name) - - try: - target.write_text(content, encoding="utf-8") - except OSError as exc: - raise ConfigFileWriteError(f"Cannot write {name!r}: {exc}") from exc - - -def _create_conf_file(subdir: Path, name: str, content: str) -> str: - """Create a new ``.conf`` file in *subdir*. - - Args: - subdir: Resolved path to the containing directory. - name: Base name for the new file (without extension). - content: Initial file content. - - Returns: - The filename that was created (e.g. ``myfilter.conf``). - - Raises: - ConfigFileNameError: If *name* is invalid. - ConfigFileExistsError: If a ``.conf`` or ``.local`` file with *name* already exists. - ConfigFileWriteError: If the file cannot be written. - """ - resolved_subdir = subdir.resolve() - _validate_new_name(name) - _validate_content(content) - - for ext in _CONF_EXTENSIONS: - existing = (resolved_subdir / (name + ext)).resolve() - _assert_within(resolved_subdir, existing) - if existing.exists(): - raise ConfigFileExistsError(name + ext) - - target = (resolved_subdir / (name + ".conf")).resolve() - _assert_within(resolved_subdir, target) - try: - target.write_text(content, encoding="utf-8") - except OSError as exc: - raise ConfigFileWriteError(f"Cannot create {name!r}: {exc}") from exc - - return target.name - - -# --------------------------------------------------------------------------- -# Public API — filter files (Task 4d) -# --------------------------------------------------------------------------- - - -async def list_filter_files(config_dir: str) -> ConfFilesResponse: - """List all filter definition files in ``/filter.d/``. - - Args: - config_dir: Path to the fail2ban configuration directory. - - Returns: - :class:`~app.models.file_config.ConfFilesResponse`. - - Raises: - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> ConfFilesResponse: - filter_d = _resolve_subdir(config_dir, "filter.d") - result = _list_conf_files(filter_d) - log.info("filter_files_listed", count=result.total) - return result - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def get_filter_file(config_dir: str, name: str) -> ConfFileContent: - """Return the content of a filter definition file. - - Args: - config_dir: Path to the fail2ban configuration directory. - name: Base name (with or without ``.conf``/``.local`` extension). - - Returns: - :class:`~app.models.file_config.ConfFileContent`. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> ConfFileContent: - filter_d = _resolve_subdir(config_dir, "filter.d") - return _read_conf_file(filter_d, name) - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def write_filter_file( - config_dir: str, - name: str, - req: ConfFileUpdateRequest, -) -> None: - """Overwrite an existing filter definition file. - - Args: - config_dir: Path to the fail2ban configuration directory. - name: Base name of the file to update (with or without extension). - req: :class:`~app.models.file_config.ConfFileUpdateRequest` with new content. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigFileWriteError: If the file cannot be written. - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> None: - filter_d = _resolve_subdir(config_dir, "filter.d") - _write_conf_file(filter_d, name, req.content) - log.info("filter_file_written", name=name) - - await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def create_filter_file( - config_dir: str, - req: ConfFileCreateRequest, -) -> str: - """Create a new filter definition file. - - Args: - config_dir: Path to the fail2ban configuration directory. - req: :class:`~app.models.file_config.ConfFileCreateRequest`. - - Returns: - The filename that was created. - - Raises: - ConfigFileExistsError: If a file with that name already exists. - ConfigFileNameError: If the name is invalid. - ConfigFileWriteError: If the file cannot be created. - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> str: - filter_d = _resolve_subdir(config_dir, "filter.d") - filename = _create_conf_file(filter_d, req.name, req.content) - log.info("filter_file_created", filename=filename) - return filename - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -# --------------------------------------------------------------------------- -# Public API — action files (Task 4e) -# --------------------------------------------------------------------------- - - -async def list_action_files(config_dir: str) -> ConfFilesResponse: - """List all action definition files in ``/action.d/``. - - Args: - config_dir: Path to the fail2ban configuration directory. - - Returns: - :class:`~app.models.file_config.ConfFilesResponse`. - - Raises: - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> ConfFilesResponse: - action_d = _resolve_subdir(config_dir, "action.d") - result = _list_conf_files(action_d) - log.info("action_files_listed", count=result.total) - return result - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def get_action_file(config_dir: str, name: str) -> ConfFileContent: - """Return the content of an action definition file. - - Args: - config_dir: Path to the fail2ban configuration directory. - name: Base name (with or without ``.conf``/``.local`` extension). - - Returns: - :class:`~app.models.file_config.ConfFileContent`. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> ConfFileContent: - action_d = _resolve_subdir(config_dir, "action.d") - return _read_conf_file(action_d, name) - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def write_action_file( - config_dir: str, - name: str, - req: ConfFileUpdateRequest, -) -> None: - """Overwrite an existing action definition file. - - Args: - config_dir: Path to the fail2ban configuration directory. - name: Base name of the file to update. - req: :class:`~app.models.file_config.ConfFileUpdateRequest` with new content. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigFileWriteError: If the file cannot be written. - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> None: - action_d = _resolve_subdir(config_dir, "action.d") - _write_conf_file(action_d, name, req.content) - log.info("action_file_written", name=name) - - await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def create_action_file( - config_dir: str, - req: ConfFileCreateRequest, -) -> str: - """Create a new action definition file. - - Args: - config_dir: Path to the fail2ban configuration directory. - req: :class:`~app.models.file_config.ConfFileCreateRequest`. - - Returns: - The filename that was created. - - Raises: - ConfigFileExistsError: If a file with that name already exists. - ConfigFileNameError: If the name is invalid. - ConfigFileWriteError: If the file cannot be created. - ConfigDirError: If *config_dir* does not exist. - """ - - def _do() -> str: - action_d = _resolve_subdir(config_dir, "action.d") - filename = _create_conf_file(action_d, req.name, req.content) - log.info("action_file_created", filename=filename) - return filename - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -# --------------------------------------------------------------------------- -# Public API — structured (parsed) filter files (Task 2.1) -# --------------------------------------------------------------------------- - - -async def get_parsed_filter_file(config_dir: str, name: str) -> FilterConfig: - """Parse a filter definition file and return its structured representation. - - Reads the raw ``.conf``/``.local`` file from ``filter.d/``, parses it with - :func:`~app.utils.conffile_parser.parse_filter_file`, and returns the - result. - - Args: - config_dir: Path to the fail2ban configuration directory. - name: Base name with or without extension. - - Returns: - :class:`~app.models.config.FilterConfig`. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigDirError: If *config_dir* does not exist. - """ - from app.utils.conffile_parser import parse_filter_file # avoid circular imports - - def _do() -> FilterConfig: - filter_d = _resolve_subdir(config_dir, "filter.d") - raw = _read_conf_file(filter_d, name) - result = parse_filter_file(raw.content, name=raw.name, filename=raw.filename) - log.debug("filter_file_parsed", name=raw.name) - return result - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def update_parsed_filter_file( - config_dir: str, - name: str, - update: FilterConfigUpdate, -) -> None: - """Apply a structured partial update to a filter definition file. - - Reads the existing file, merges *update* onto it, serializes to INI format, - and writes the result back to disk. - - Args: - config_dir: Path to the fail2ban configuration directory. - name: Base name of the file to update. - update: Partial fields to apply. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigFileWriteError: If the file cannot be written. - ConfigDirError: If *config_dir* does not exist. - """ - from app.utils.conffile_parser import ( # avoid circular imports - merge_filter_update, - parse_filter_file, - serialize_filter_config, - ) - - def _do() -> None: - filter_d = _resolve_subdir(config_dir, "filter.d") - raw = _read_conf_file(filter_d, name) - current = parse_filter_file(raw.content, name=raw.name, filename=raw.filename) - merged = merge_filter_update(current, update) - new_content = serialize_filter_config(merged) - _validate_content(new_content) - _write_conf_file(filter_d, name, new_content) - log.info("filter_file_updated_parsed", name=name) - - await asyncio.get_event_loop().run_in_executor(None, _do) - - -# --------------------------------------------------------------------------- -# Public API — structured (parsed) action files (Task 3.1) -# --------------------------------------------------------------------------- - - -async def get_parsed_action_file(config_dir: str, name: str) -> ActionConfig: - """Parse an action definition file and return its structured representation. - - Args: - config_dir: Path to the fail2ban configuration directory. - name: Base name with or without extension. - - Returns: - :class:`~app.models.config.ActionConfig`. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigDirError: If *config_dir* does not exist. - """ - from app.utils.conffile_parser import parse_action_file # avoid circular imports - - def _do() -> ActionConfig: - action_d = _resolve_subdir(config_dir, "action.d") - raw = _read_conf_file(action_d, name) - result = parse_action_file(raw.content, name=raw.name, filename=raw.filename) - log.debug("action_file_parsed", name=raw.name) - return result - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def update_parsed_action_file( - config_dir: str, - name: str, - update: ActionConfigUpdate, -) -> None: - """Apply a structured partial update to an action definition file. - - Args: - config_dir: Path to the fail2ban configuration directory. - name: Base name of the file to update. - update: Partial fields to apply. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigFileWriteError: If the file cannot be written. - ConfigDirError: If *config_dir* does not exist. - """ - from app.utils.conffile_parser import ( # avoid circular imports - merge_action_update, - parse_action_file, - serialize_action_config, - ) - - def _do() -> None: - action_d = _resolve_subdir(config_dir, "action.d") - raw = _read_conf_file(action_d, name) - current = parse_action_file(raw.content, name=raw.name, filename=raw.filename) - merged = merge_action_update(current, update) - new_content = serialize_action_config(merged) - _validate_content(new_content) - _write_conf_file(action_d, name, new_content) - log.info("action_file_updated_parsed", name=name) - - await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def get_parsed_jail_file(config_dir: str, filename: str) -> JailFileConfig: - """Parse a jail.d config file into a structured :class:`~app.models.config.JailFileConfig`. - - Args: - config_dir: Path to the fail2ban configuration directory. - filename: Filename including extension (e.g. ``"sshd.conf"``). - - Returns: - :class:`~app.models.config.JailFileConfig`. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigDirError: If *config_dir* does not exist. - """ - from app.utils.conffile_parser import parse_jail_file # avoid circular imports - - def _do() -> JailFileConfig: - jail_d = _resolve_subdir(config_dir, "jail.d") - raw = _read_conf_file(jail_d, filename) - result = parse_jail_file(raw.content, filename=raw.filename) - log.debug("jail_file_parsed", filename=raw.filename) - return result - - return await asyncio.get_event_loop().run_in_executor(None, _do) - - -async def update_parsed_jail_file( - config_dir: str, - filename: str, - update: JailFileConfigUpdate, -) -> None: - """Apply a structured partial update to a jail.d config file. - - Args: - config_dir: Path to the fail2ban configuration directory. - filename: Filename including extension (e.g. ``"sshd.conf"``). - update: Partial fields to apply. - - Raises: - ConfigFileNotFoundError: If no matching file is found. - ConfigFileWriteError: If the file cannot be written. - ConfigDirError: If *config_dir* does not exist. - """ - from app.utils.conffile_parser import ( # avoid circular imports - merge_jail_file_update, - parse_jail_file, - serialize_jail_file_config, - ) - - def _do() -> None: - jail_d = _resolve_subdir(config_dir, "jail.d") - raw = _read_conf_file(jail_d, filename) - current = parse_jail_file(raw.content, filename=raw.filename) - merged = merge_jail_file_update(current, update) - new_content = serialize_jail_file_config(merged) - _validate_content(new_content) - _write_conf_file(jail_d, filename, new_content) - log.info("jail_file_updated_parsed", filename=filename) - - await asyncio.get_event_loop().run_in_executor(None, _do) diff --git a/backend/app/services/filter_config_service.py b/backend/app/services/filter_config_service.py index 572f0f9..ba5e1c5 100644 --- a/backend/app/services/filter_config_service.py +++ b/backend/app/services/filter_config_service.py @@ -25,15 +25,9 @@ from app.models.config import ( FilterUpdateRequest, AssignFilterRequest, ) -from app.exceptions import JailNotFoundError -from app.services import jail_service -from app.services.config_file_service import ( - _parse_jails_sync, - _get_active_jail_names, - ConfigWriteError, - JailNotFoundInConfigError, -) +from app.exceptions import FilterInvalidRegexError, JailNotFoundError from app.utils import conffile_parser +from app.utils.jail_utils import reload_jails log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -83,21 +77,6 @@ class FilterReadonlyError(Exception): ) -class FilterInvalidRegexError(Exception): - """Raised when a regex pattern fails to compile.""" - - def __init__(self, pattern: str, error: str) -> None: - """Initialise with the invalid pattern and the compile error. - - Args: - pattern: The regex string that failed to compile. - error: The ``re.error`` message. - """ - self.pattern: str = pattern - self.error: str = error - super().__init__(f"Invalid regex {pattern!r}: {error}") - - class FilterNameError(Exception): """Raised when a filter name contains invalid characters.""" @@ -723,7 +702,7 @@ async def update_filter( if do_reload: try: - await jail_service.reload_all(socket_path) + await reload_jails(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_filter_update_failed", @@ -798,7 +777,7 @@ async def create_filter( if do_reload: try: - await jail_service.reload_all(socket_path) + await reload_jails(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_filter_create_failed", @@ -924,7 +903,7 @@ async def assign_filter_to_jail( if do_reload: try: - await jail_service.reload_all(socket_path) + await reload_jails(socket_path) except Exception as exc: # noqa: BLE001 log.warning( "reload_after_assign_filter_failed", diff --git a/backend/app/services/geo_service.py b/backend/app/services/geo_service.py index db76726..2ac40c4 100644 --- a/backend/app/services/geo_service.py +++ b/backend/app/services/geo_service.py @@ -20,9 +20,7 @@ Usage:: import aiohttp import aiosqlite - from app.services import geo_service - - # warm the cache from the persistent store at startup + # Use the geo_service directly in application startup async with aiosqlite.connect("bangui.db") as db: await geo_service.load_cache_from_db(db) diff --git a/backend/app/services/jail_config_service.py b/backend/app/services/jail_config_service.py index a45ef8f..cc8c2e4 100644 --- a/backend/app/services/jail_config_service.py +++ b/backend/app/services/jail_config_service.py @@ -30,7 +30,13 @@ from app.models.config import ( JailValidationResult, RollbackResponse, ) -from app.services import config_file_service, jail_service +from app.utils.config_file_utils import ( + _build_inactive_jail, + _ordered_config_files, + _parse_jails_sync, + _validate_jail_config_sync, +) +from app.utils.jail_utils import reload_jails from app.utils.fail2ban_client import ( Fail2BanClient, Fail2BanConnectionError, @@ -304,7 +310,7 @@ def _validate_regex_patterns(patterns: list[str]) -> None: re.compile(pattern) except re.error as exc: # Import here to avoid circular dependency - from app.services.filter_config_service import FilterInvalidRegexError + from app.exceptions import FilterInvalidRegexError raise FilterInvalidRegexError(pattern, str(exc)) from exc @@ -460,12 +466,7 @@ async def start_daemon(start_cmd_parts: list[str]) -> bool: return False -# Import shared functions from config_file_service -_parse_jails_sync = config_file_service._parse_jails_sync -_build_inactive_jail = config_file_service._build_inactive_jail -_get_active_jail_names = config_file_service._get_active_jail_names -_validate_jail_config_sync = config_file_service._validate_jail_config_sync -_orderedconfig_files = config_file_service._ordered_config_files +# Shared functions from config_file_service are imported from app.utils.config_file_utils # --------------------------------------------------------------------------- @@ -624,7 +625,7 @@ async def activate_jail( # Activation reload — if it fails, roll back immediately # # ---------------------------------------------------------------------- # try: - await jail_service.reload_all(socket_path, include_jails=[name]) + await reload_jails(socket_path, include_jails=[name]) except JailNotFoundError as exc: # Jail configuration is invalid (e.g. missing logpath that prevents # fail2ban from loading the jail). Roll back and provide a specific error. @@ -767,7 +768,7 @@ async def _rollback_activation_async( # Step 2 — reload fail2ban with the restored config. try: - await jail_service.reload_all(socket_path) + await reload_jails(socket_path) log.info("jail_activation_rollback_reload_ok", jail=name) except Exception as exc: # noqa: BLE001 log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc)) @@ -834,7 +835,7 @@ async def deactivate_jail( ) try: - await jail_service.reload_all(socket_path, exclude_jails=[name]) + await reload_jails(socket_path, exclude_jails=[name]) except Exception as exc: # noqa: BLE001 log.warning("reload_after_deactivate_failed", jail=name, error=str(exc)) diff --git a/backend/app/services/setup_service.py b/backend/app/services/setup_service.py index f29325a..5254fce 100644 --- a/backend/app/services/setup_service.py +++ b/backend/app/services/setup_service.py @@ -102,30 +102,20 @@ async def run_setup( log.info("bangui_setup_completed") +from app.utils.setup_utils import ( + get_map_color_thresholds as util_get_map_color_thresholds, + get_password_hash as util_get_password_hash, + set_map_color_thresholds as util_set_map_color_thresholds, +) + + async def get_password_hash(db: aiosqlite.Connection) -> str | None: - """Return the stored bcrypt password hash, or ``None`` if not set. - - Args: - db: Active aiosqlite connection. - - Returns: - The bcrypt hash string, or ``None``. - """ - return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH) + """Return the stored bcrypt password hash, or ``None`` if not set.""" + return await util_get_password_hash(db) async def get_timezone(db: aiosqlite.Connection) -> str: - """Return the configured IANA timezone string. - - Falls back to ``"UTC"`` when no timezone has been stored (e.g. before - setup completes or for legacy databases). - - Args: - db: Active aiosqlite connection. - - Returns: - An IANA timezone identifier such as ``"Europe/Berlin"`` or ``"UTC"``. - """ + """Return the configured IANA timezone string.""" tz = await settings_repo.get_setting(db, _KEY_TIMEZONE) return tz if tz else "UTC" @@ -133,31 +123,8 @@ async def get_timezone(db: aiosqlite.Connection) -> str: async def get_map_color_thresholds( db: aiosqlite.Connection, ) -> tuple[int, int, int]: - """Return the configured map color thresholds (high, medium, low). - - Falls back to default values (100, 50, 20) if not set. - - Args: - db: Active aiosqlite connection. - - Returns: - A tuple of (threshold_high, threshold_medium, threshold_low). - """ - high = await settings_repo.get_setting( - db, _KEY_MAP_COLOR_THRESHOLD_HIGH - ) - medium = await settings_repo.get_setting( - db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM - ) - low = await settings_repo.get_setting( - db, _KEY_MAP_COLOR_THRESHOLD_LOW - ) - - return ( - int(high) if high else 100, - int(medium) if medium else 50, - int(low) if low else 20, - ) + """Return the configured map color thresholds (high, medium, low).""" + return await util_get_map_color_thresholds(db) async def set_map_color_thresholds( @@ -167,31 +134,12 @@ async def set_map_color_thresholds( threshold_medium: int, threshold_low: int, ) -> None: - """Update the map color threshold configuration. - - Args: - db: Active aiosqlite connection. - threshold_high: Ban count for red coloring. - threshold_medium: Ban count for yellow coloring. - threshold_low: Ban count for green coloring. - - Raises: - ValueError: If thresholds are not positive integers or if - high <= medium <= low. - """ - if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0: - raise ValueError("All thresholds must be positive integers.") - if not (threshold_high > threshold_medium > threshold_low): - raise ValueError("Thresholds must satisfy: high > medium > low.") - - await settings_repo.set_setting( - db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high) - ) - await settings_repo.set_setting( - db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium) - ) - await settings_repo.set_setting( - db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low) + """Update the map color threshold configuration.""" + await util_set_map_color_thresholds( + db, + threshold_high=threshold_high, + threshold_medium=threshold_medium, + threshold_low=threshold_low, ) log.info( "map_color_thresholds_updated", diff --git a/backend/app/tasks/blocklist_import.py b/backend/app/tasks/blocklist_import.py index 80e7246..1a23ba3 100644 --- a/backend/app/tasks/blocklist_import.py +++ b/backend/app/tasks/blocklist_import.py @@ -43,9 +43,15 @@ async def _run_import(app: Any) -> None: http_session = app.state.http_session socket_path: str = app.state.settings.fail2ban_socket + from app.services import jail_service + log.info("blocklist_import_starting") try: - result = await blocklist_service.import_all(db, http_session, socket_path) + result = await blocklist_service.import_all( + db, + http_session, + socket_path, + ) log.info( "blocklist_import_finished", total_imported=result.total_imported, diff --git a/backend/app/utils/config_file_utils.py b/backend/app/utils/config_file_utils.py new file mode 100644 index 0000000..5559904 --- /dev/null +++ b/backend/app/utils/config_file_utils.py @@ -0,0 +1,21 @@ +"""Utilities re-exported from config_file_service for cross-module usage.""" + +from __future__ import annotations + +from pathlib import Path + +from app.services.config_file_service import ( + _build_inactive_jail, + _get_active_jail_names, + _ordered_config_files, + _parse_jails_sync, + _validate_jail_config_sync, +) + +__all__ = [ + "_ordered_config_files", + "_parse_jails_sync", + "_build_inactive_jail", + "_get_active_jail_names", + "_validate_jail_config_sync", +] diff --git a/backend/app/utils/jail_utils.py b/backend/app/utils/jail_utils.py new file mode 100644 index 0000000..23bb13d --- /dev/null +++ b/backend/app/utils/jail_utils.py @@ -0,0 +1,20 @@ +"""Jail helpers to decouple service layer dependencies.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from app.services.jail_service import reload_all + + +async def reload_jails( + socket_path: str, + include_jails: Sequence[str] | None = None, + exclude_jails: Sequence[str] | None = None, +) -> None: + """Reload fail2ban jails using shared jail service helper.""" + await reload_all( + socket_path, + include_jails=list(include_jails) if include_jails is not None else None, + exclude_jails=list(exclude_jails) if exclude_jails is not None else None, + ) diff --git a/backend/app/utils/log_utils.py b/backend/app/utils/log_utils.py new file mode 100644 index 0000000..54a6892 --- /dev/null +++ b/backend/app/utils/log_utils.py @@ -0,0 +1,14 @@ +"""Log-related helpers to avoid direct service-to-service imports.""" + +from __future__ import annotations + +from app.models.config import LogPreviewRequest, LogPreviewResponse, RegexTestRequest, RegexTestResponse +from app.services.log_service import preview_log as _preview_log, test_regex as _test_regex + + +async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse: + return await _preview_log(req) + + +def test_regex(req: RegexTestRequest) -> RegexTestResponse: + return _test_regex(req) diff --git a/backend/app/utils/setup_utils.py b/backend/app/utils/setup_utils.py new file mode 100644 index 0000000..9fa6db3 --- /dev/null +++ b/backend/app/utils/setup_utils.py @@ -0,0 +1,47 @@ +"""Setup-related utilities shared by multiple services.""" + +from __future__ import annotations + +from app.repositories import settings_repo + +_KEY_PASSWORD_HASH = "master_password_hash" +_KEY_SETUP_DONE = "setup_completed" +_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high" +_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium" +_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low" + + +async def get_password_hash(db): + """Return the stored master password hash or None.""" + return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH) + + +async def get_map_color_thresholds(db): + """Return map color thresholds as tuple (high, medium, low).""" + high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH) + medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM) + low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW) + + return ( + int(high) if high else 100, + int(medium) if medium else 50, + int(low) if low else 20, + ) + + +async def set_map_color_thresholds( + db, + *, + threshold_high: int, + threshold_medium: int, + threshold_low: int, +) -> None: + """Persist map color thresholds after validating values.""" + if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0: + raise ValueError("All thresholds must be positive integers.") + if not (threshold_high > threshold_medium > threshold_low): + raise ValueError("Thresholds must satisfy: high > medium > low.") + + await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high)) + await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium)) + await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low)) diff --git a/backend/tests/test_services/test_blocklist_service.py b/backend/tests/test_services/test_blocklist_service.py index a9e6c13..674c554 100644 --- a/backend/tests/test_services/test_blocklist_service.py +++ b/backend/tests/test_services/test_blocklist_service.py @@ -203,9 +203,15 @@ class TestImport: call_count += 1 raise JailNotFoundError(jail) - with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found): + with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip: + from app.services import jail_service + result = await blocklist_service.import_source( - source, session, "/tmp/fake.sock", db + source, + session, + "/tmp/fake.sock", + db, + ban_ip=jail_service.ban_ip, ) # Must abort after the first JailNotFoundError — only one ban attempt. @@ -226,7 +232,14 @@ class TestImport: with patch( "app.services.jail_service.ban_ip", new_callable=AsyncMock ): - result = await blocklist_service.import_all(db, session, "/tmp/fake.sock") + from app.services import jail_service + + result = await blocklist_service.import_all( + db, + session, + "/tmp/fake.sock", + ban_ip=jail_service.ban_ip, + ) # Only S1 is enabled, S2 is disabled. assert len(result.results) == 1 diff --git a/backend/tests/test_services/test_config_service.py b/backend/tests/test_services/test_config_service.py index 6b90074..d3d6787 100644 --- a/backend/tests/test_services/test_config_service.py +++ b/backend/tests/test_services/test_config_service.py @@ -721,9 +721,11 @@ class TestGetServiceStatus: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) - with patch("app.services.config_service.Fail2BanClient", _FakeClient), \ - patch("app.services.health_service.probe", AsyncMock(return_value=online_status)): - result = await config_service.get_service_status(_SOCKET) + with patch("app.services.config_service.Fail2BanClient", _FakeClient): + result = await config_service.get_service_status( + _SOCKET, + probe_fn=AsyncMock(return_value=online_status), + ) assert result.online is True assert result.version == "1.0.0" @@ -739,8 +741,10 @@ class TestGetServiceStatus: offline_status = ServerStatus(online=False) - with patch("app.services.health_service.probe", AsyncMock(return_value=offline_status)): - result = await config_service.get_service_status(_SOCKET) + result = await config_service.get_service_status( + _SOCKET, + probe_fn=AsyncMock(return_value=offline_status), + ) assert result.online is False assert result.jail_count == 0 diff --git a/frontend/src/hooks/__tests__/useConfigItem.test.ts b/frontend/src/hooks/__tests__/useConfigItem.test.ts new file mode 100644 index 0000000..39876a5 --- /dev/null +++ b/frontend/src/hooks/__tests__/useConfigItem.test.ts @@ -0,0 +1,88 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { renderHook, act } from "@testing-library/react"; +import { useConfigItem } from "../useConfigItem"; + +describe("useConfigItem", () => { + beforeEach(() => { + vi.useFakeTimers(); + }); + + afterEach(() => { + vi.useRealTimers(); + vi.clearAllMocks(); + }); + + it("loads data and sets loading state", async () => { + const fetchFn = vi.fn().mockResolvedValue("hello"); + const saveFn = vi.fn().mockResolvedValue(undefined); + + const { result } = renderHook(() => useConfigItem({ fetchFn, saveFn })); + + expect(result.current.loading).toBe(true); + await act(async () => { + await Promise.resolve(); + }); + + expect(fetchFn).toHaveBeenCalled(); + expect(result.current.data).toBe("hello"); + expect(result.current.loading).toBe(false); + }); + + it("sets error if fetch rejects", async () => { + const fetchFn = vi.fn().mockRejectedValue(new Error("nope")); + const saveFn = vi.fn().mockResolvedValue(undefined); + + const { result } = renderHook(() => useConfigItem({ fetchFn, saveFn })); + + await act(async () => { + await Promise.resolve(); + }); + + expect(result.current.error).toBe("nope"); + expect(result.current.loading).toBe(false); + }); + + it("save updates data when mergeOnSave is provided", async () => { + const fetchFn = vi.fn().mockResolvedValue({ value: 1 }); + const saveFn = vi.fn().mockResolvedValue(undefined); + + const { result } = renderHook(() => + useConfigItem<{ value: number }, { delta: number }>({ + fetchFn, + saveFn, + mergeOnSave: (prev, update) => + prev ? { ...prev, value: prev.value + update.delta } : prev, + }) + ); + + await act(async () => { + await Promise.resolve(); + }); + + expect(result.current.data).toEqual({ value: 1 }); + + await act(async () => { + await result.current.save({ delta: 2 }); + }); + + expect(saveFn).toHaveBeenCalledWith({ delta: 2 }); + expect(result.current.data).toEqual({ value: 3 }); + }); + + it("saveError is set when save fails", async () => { + const fetchFn = vi.fn().mockResolvedValue("ok"); + const saveFn = vi.fn().mockRejectedValue(new Error("save failed")); + + const { result } = renderHook(() => useConfigItem({ fetchFn, saveFn })); + + await act(async () => { + await Promise.resolve(); + }); + + await act(async () => { + await expect(result.current.save("test")).rejects.toThrow("save failed"); + }); + + expect(result.current.saveError).toBe("save failed"); + }); +}); diff --git a/frontend/src/hooks/useActionConfig.ts b/frontend/src/hooks/useActionConfig.ts index 22b40de..0baf599 100644 --- a/frontend/src/hooks/useActionConfig.ts +++ b/frontend/src/hooks/useActionConfig.ts @@ -2,7 +2,7 @@ * React hook for loading and updating a single parsed action config. */ -import { useCallback, useEffect, useRef, useState } from "react"; +import { useConfigItem } from "./useConfigItem"; import { fetchAction, updateAction } from "../api/config"; import type { ActionConfig, ActionConfigUpdate } from "../types/config"; @@ -23,67 +23,28 @@ export interface UseActionConfigResult { * @param name - Action base name (e.g. ``"iptables"``). */ export function useActionConfig(name: string): UseActionConfigResult { - const [config, setConfig] = useState(null); - const [loading, setLoading] = useState(true); - const [error, setError] = useState(null); - const [saving, setSaving] = useState(false); - const [saveError, setSaveError] = useState(null); - const abortRef = useRef(null); + const { data, loading, error, saving, saveError, refresh, save } = useConfigItem< + ActionConfig, + ActionConfigUpdate + >({ + fetchFn: () => fetchAction(name), + saveFn: (update) => updateAction(name, update), + mergeOnSave: (prev, update) => + prev + ? { + ...prev, + ...Object.fromEntries(Object.entries(update).filter(([, v]) => v != null)), + } + : prev, + }); - const load = useCallback((): void => { - abortRef.current?.abort(); - const ctrl = new AbortController(); - abortRef.current = ctrl; - setLoading(true); - setError(null); - - fetchAction(name) - .then((data) => { - if (!ctrl.signal.aborted) { - setConfig(data); - setLoading(false); - } - }) - .catch((err: unknown) => { - if (!ctrl.signal.aborted) { - setError(err instanceof Error ? err.message : "Failed to load action config"); - setLoading(false); - } - }); - }, [name]); - - useEffect(() => { - load(); - return (): void => { - abortRef.current?.abort(); - }; - }, [load]); - - const save = useCallback( - async (update: ActionConfigUpdate): Promise => { - setSaving(true); - setSaveError(null); - try { - await updateAction(name, update); - setConfig((prev) => - prev - ? { - ...prev, - ...Object.fromEntries( - Object.entries(update).filter(([, v]) => v !== null && v !== undefined) - ), - } - : prev - ); - } catch (err: unknown) { - setSaveError(err instanceof Error ? err.message : "Failed to save action config"); - throw err; - } finally { - setSaving(false); - } - }, - [name] - ); - - return { config, loading, error, saving, saveError, refresh: load, save }; + return { + config: data, + loading, + error, + saving, + saveError, + refresh, + save, + }; } diff --git a/frontend/src/hooks/useConfigItem.ts b/frontend/src/hooks/useConfigItem.ts new file mode 100644 index 0000000..bf537cb --- /dev/null +++ b/frontend/src/hooks/useConfigItem.ts @@ -0,0 +1,84 @@ +/** + * Generic config hook for loading and saving a single entity. + */ +import { useCallback, useEffect, useRef, useState } from "react"; + +export interface UseConfigItemResult { + data: T | null; + loading: boolean; + error: string | null; + saving: boolean; + saveError: string | null; + refresh: () => void; + save: (update: U) => Promise; +} + +export interface UseConfigItemOptions { + fetchFn: (signal: AbortSignal) => Promise; + saveFn: (update: U) => Promise; + mergeOnSave?: (prev: T | null, update: U) => T | null; +} + +export function useConfigItem( + options: UseConfigItemOptions +): UseConfigItemResult { + const { fetchFn, saveFn, mergeOnSave } = options; + const [data, setData] = useState(null); + const [loading, setLoading] = useState(true); + const [error, setError] = useState(null); + const [saving, setSaving] = useState(false); + const [saveError, setSaveError] = useState(null); + const abortRef = useRef(null); + + const refresh = useCallback((): void => { + abortRef.current?.abort(); + const controller = new AbortController(); + abortRef.current = controller; + + setLoading(true); + setError(null); + + fetchFn(controller.signal) + .then((nextData) => { + if (controller.signal.aborted) return; + setData(nextData); + setLoading(false); + }) + .catch((err: unknown) => { + if (controller.signal.aborted) return; + setError(err instanceof Error ? err.message : "Failed to load data"); + setLoading(false); + }); + }, [fetchFn]); + + useEffect(() => { + refresh(); + + return () => { + abortRef.current?.abort(); + }; + }, [refresh]); + + const save = useCallback( + async (update: U): Promise => { + setSaving(true); + setSaveError(null); + + try { + await saveFn(update); + if (mergeOnSave) { + setData((prevData) => mergeOnSave(prevData, update)); + } + } catch (err: unknown) { + const message = err instanceof Error ? err.message : "Failed to save data"; + setSaveError(message); + throw err; + } finally { + setSaving(false); + } + }, + [saveFn, mergeOnSave] + ); + + return { data, loading, error, saving, saveError, refresh, save }; +} diff --git a/frontend/src/hooks/useFilterConfig.ts b/frontend/src/hooks/useFilterConfig.ts index 9a52544..b4163d1 100644 --- a/frontend/src/hooks/useFilterConfig.ts +++ b/frontend/src/hooks/useFilterConfig.ts @@ -2,7 +2,7 @@ * React hook for loading and updating a single parsed filter config. */ -import { useCallback, useEffect, useRef, useState } from "react"; +import { useConfigItem } from "./useConfigItem"; import { fetchParsedFilter, updateParsedFilter } from "../api/config"; import type { FilterConfig, FilterConfigUpdate } from "../types/config"; @@ -23,69 +23,28 @@ export interface UseFilterConfigResult { * @param name - Filter base name (e.g. ``"sshd"``). */ export function useFilterConfig(name: string): UseFilterConfigResult { - const [config, setConfig] = useState(null); - const [loading, setLoading] = useState(true); - const [error, setError] = useState(null); - const [saving, setSaving] = useState(false); - const [saveError, setSaveError] = useState(null); - const abortRef = useRef(null); + const { data, loading, error, saving, saveError, refresh, save } = useConfigItem< + FilterConfig, + FilterConfigUpdate + >({ + fetchFn: () => fetchParsedFilter(name), + saveFn: (update) => updateParsedFilter(name, update), + mergeOnSave: (prev, update) => + prev + ? { + ...prev, + ...Object.fromEntries(Object.entries(update).filter(([, v]) => v != null)), + } + : prev, + }); - const load = useCallback((): void => { - abortRef.current?.abort(); - const ctrl = new AbortController(); - abortRef.current = ctrl; - setLoading(true); - setError(null); - - fetchParsedFilter(name) - .then((data) => { - if (!ctrl.signal.aborted) { - setConfig(data); - setLoading(false); - } - }) - .catch((err: unknown) => { - if (!ctrl.signal.aborted) { - setError(err instanceof Error ? err.message : "Failed to load filter config"); - setLoading(false); - } - }); - }, [name]); - - useEffect(() => { - load(); - return (): void => { - abortRef.current?.abort(); - }; - }, [load]); - - const save = useCallback( - async (update: FilterConfigUpdate): Promise => { - setSaving(true); - setSaveError(null); - try { - await updateParsedFilter(name, update); - // Optimistically update local state so the form reflects changes - // without a full reload. - setConfig((prev) => - prev - ? { - ...prev, - ...Object.fromEntries( - Object.entries(update).filter(([, v]) => v !== null && v !== undefined) - ), - } - : prev - ); - } catch (err: unknown) { - setSaveError(err instanceof Error ? err.message : "Failed to save filter config"); - throw err; - } finally { - setSaving(false); - } - }, - [name] - ); - - return { config, loading, error, saving, saveError, refresh: load, save }; + return { + config: data, + loading, + error, + saving, + saveError, + refresh, + save, + }; } diff --git a/frontend/src/hooks/useJailFileConfig.ts b/frontend/src/hooks/useJailFileConfig.ts index 096df42..a440bb2 100644 --- a/frontend/src/hooks/useJailFileConfig.ts +++ b/frontend/src/hooks/useJailFileConfig.ts @@ -2,7 +2,7 @@ * React hook for loading and updating a single parsed jail.d config file. */ -import { useCallback, useEffect, useRef, useState } from "react"; +import { useConfigItem } from "./useConfigItem"; import { fetchParsedJailFile, updateParsedJailFile } from "../api/config"; import type { JailFileConfig, JailFileConfigUpdate } from "../types/config"; @@ -21,56 +21,23 @@ export interface UseJailFileConfigResult { * @param filename - Filename including extension (e.g. ``"sshd.conf"``). */ export function useJailFileConfig(filename: string): UseJailFileConfigResult { - const [config, setConfig] = useState(null); - const [loading, setLoading] = useState(true); - const [error, setError] = useState(null); - const abortRef = useRef(null); + const { data, loading, error, refresh, save } = useConfigItem< + JailFileConfig, + JailFileConfigUpdate + >({ + fetchFn: () => fetchParsedJailFile(filename), + saveFn: (update) => updateParsedJailFile(filename, update), + mergeOnSave: (prev, update) => + update.jails != null && prev + ? { ...prev, jails: { ...prev.jails, ...update.jails } } + : prev, + }); - const load = useCallback((): void => { - abortRef.current?.abort(); - const ctrl = new AbortController(); - abortRef.current = ctrl; - setLoading(true); - setError(null); - - fetchParsedJailFile(filename) - .then((data) => { - if (!ctrl.signal.aborted) { - setConfig(data); - setLoading(false); - } - }) - .catch((err: unknown) => { - if (!ctrl.signal.aborted) { - setError(err instanceof Error ? err.message : "Failed to load jail file config"); - setLoading(false); - } - }); - }, [filename]); - - useEffect(() => { - load(); - return (): void => { - abortRef.current?.abort(); - }; - }, [load]); - - const save = useCallback( - async (update: JailFileConfigUpdate): Promise => { - try { - await updateParsedJailFile(filename, update); - // Optimistically merge updated jails into local state. - if (update.jails != null) { - setConfig((prev) => - prev ? { ...prev, jails: { ...prev.jails, ...update.jails } } : prev - ); - } - } catch (err: unknown) { - throw err instanceof Error ? err : new Error("Failed to save jail file config"); - } - }, - [filename] - ); - - return { config, loading, error, refresh: load, save }; + return { + config: data, + loading, + error, + refresh, + save, + }; }