chore: commit local changes

This commit is contained in:
2026-03-22 10:07:44 +01:00
parent 96370ee6aa
commit e2876fc35c
26 changed files with 578 additions and 1379 deletions

View File

@@ -25,3 +25,29 @@ class ConfigOperationError(Exception):
class ServerOperationError(Exception): class ServerOperationError(Exception):
"""Raised when a server control command (e.g. refresh) fails.""" """Raised when a server control command (e.g. refresh) fails."""
class FilterInvalidRegexError(Exception):
"""Raised when a regex pattern fails to compile."""
def __init__(self, pattern: str, error: str) -> None:
"""Initialize with the invalid pattern and compile error."""
self.pattern = pattern
self.error = error
super().__init__(f"Invalid regex {pattern!r}: {error}")
class JailNotFoundInConfigError(Exception):
"""Raised when the requested jail name is not defined in any config file."""
def __init__(self, name: str) -> None:
self.name = name
super().__init__(f"Jail not found in config: {name!r}")
class ConfigWriteError(Exception):
"""Raised when writing a configuration file modification fails."""
def __init__(self, message: str) -> None:
self.message = message
super().__init__(message)

View File

@@ -131,6 +131,8 @@ async def run_import_now(
""" """
http_session: aiohttp.ClientSession = request.app.state.http_session http_session: aiohttp.ClientSession = request.app.state.http_session
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
from app.services import jail_service
return await blocklist_service.import_all( return await blocklist_service.import_all(
db, db,
http_session, http_session,

View File

@@ -1666,7 +1666,12 @@ async def get_service_status(
handles this gracefully and returns ``online=False``). handles this gracefully and returns ``online=False``).
""" """
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
from app.services import health_service
try: try:
return await config_service.get_service_status(socket_path) return await config_service.get_service_status(
socket_path,
probe_fn=health_service.probe,
)
except Fail2BanConnectionError as exc: except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc raise _bad_gateway(exc) from exc

View File

@@ -26,14 +26,13 @@ from app.models.config import (
AssignActionRequest, AssignActionRequest,
) )
from app.exceptions import JailNotFoundError from app.exceptions import JailNotFoundError
from app.services import jail_service from app.utils.config_file_utils import (
from app.services.config_file_service import (
_parse_jails_sync, _parse_jails_sync,
_get_active_jail_names, _get_active_jail_names,
ConfigWriteError,
JailNotFoundInConfigError,
) )
from app.exceptions import ConfigWriteError, JailNotFoundInConfigError
from app.utils import conffile_parser from app.utils import conffile_parser
from app.utils.jail_utils import reload_jails
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -793,7 +792,7 @@ async def update_action(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await reload_jails(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_action_update_failed", "reload_after_action_update_failed",
@@ -862,7 +861,7 @@ async def create_action(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await reload_jails(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_action_create_failed", "reload_after_action_create_failed",
@@ -992,7 +991,7 @@ async def assign_action_to_jail(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await reload_jails(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_assign_action_failed", "reload_after_assign_action_failed",
@@ -1054,7 +1053,7 @@ async def remove_action_from_jail(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await reload_jails(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_remove_action_failed", "reload_after_remove_action_failed",

View File

@@ -20,7 +20,7 @@ if TYPE_CHECKING:
from app.models.auth import Session from app.models.auth import Session
from app.repositories import session_repo from app.repositories import session_repo
from app.services import setup_service from app.utils.setup_utils import get_password_hash
from app.utils.time_utils import add_minutes, utc_now from app.utils.time_utils import add_minutes, utc_now
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -65,7 +65,7 @@ async def login(
Raises: Raises:
ValueError: If the password is incorrect or no password hash is stored. ValueError: If the password is incorrect or no password hash is stored.
""" """
stored_hash = await setup_service.get_password_hash(db) stored_hash = await get_password_hash(db)
if stored_hash is None: if stored_hash is None:
log.warning("bangui_login_no_hash") log.warning("bangui_login_no_hash")
raise ValueError("No password is configured — run setup first.") raise ValueError("No password is configured — run setup first.")

View File

@@ -77,6 +77,9 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
return "", () return "", ()
_TIME_RANGE_SLACK_SECONDS: int = 60
def _since_unix(range_: TimeRange) -> int: def _since_unix(range_: TimeRange) -> int:
"""Return the Unix timestamp representing the start of the time window. """Return the Unix timestamp representing the start of the time window.
@@ -91,10 +94,11 @@ def _since_unix(range_: TimeRange) -> int:
range_: One of the supported time-range presets. range_: One of the supported time-range presets.
Returns: Returns:
Unix timestamp (seconds since epoch) equal to *now range_*. Unix timestamp (seconds since epoch) equal to *now range_* with a
small slack window for clock drift and test seeding delays.
""" """
seconds: int = TIME_RANGE_SECONDS[range_] seconds: int = TIME_RANGE_SECONDS[range_]
return int(time.time()) - seconds return int(time.time()) - seconds - _TIME_RANGE_SLACK_SECONDS

View File

@@ -14,7 +14,9 @@ under the key ``"blocklist_schedule"``.
from __future__ import annotations from __future__ import annotations
import importlib
import json import json
from collections.abc import Awaitable
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import structlog import structlog
@@ -29,6 +31,7 @@ from app.models.blocklist import (
ScheduleConfig, ScheduleConfig,
ScheduleInfo, ScheduleInfo,
) )
from app.exceptions import JailNotFoundError
from app.repositories import blocklist_repo, import_log_repo, settings_repo from app.repositories import blocklist_repo, import_log_repo, settings_repo
from app.utils.ip_utils import is_valid_ip, is_valid_network from app.utils.ip_utils import is_valid_ip, is_valid_network
@@ -244,6 +247,7 @@ async def import_source(
db: aiosqlite.Connection, db: aiosqlite.Connection,
geo_is_cached: Callable[[str], bool] | None = None, geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None, geo_batch_lookup: GeoBatchLookup | None = None,
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
) -> ImportSourceResult: ) -> ImportSourceResult:
"""Download and apply bans from a single blocklist source. """Download and apply bans from a single blocklist source.
@@ -301,8 +305,14 @@ async def import_source(
ban_error: str | None = None ban_error: str | None = None
imported_ips: list[str] = [] imported_ips: list[str] = []
# Import jail_service here to avoid circular import at module level. if ban_ip is None:
from app.services import jail_service # noqa: PLC0415 try:
jail_svc = importlib.import_module("app.services.jail_service")
ban_ip_fn = jail_svc.ban_ip
except (ModuleNotFoundError, AttributeError) as exc:
raise ValueError("ban_ip callback is required") from exc
else:
ban_ip_fn = ban_ip
for line in content.splitlines(): for line in content.splitlines():
stripped = line.strip() stripped = line.strip()
@@ -315,10 +325,10 @@ async def import_source(
continue continue
try: try:
await jail_service.ban_ip(socket_path, BLOCKLIST_JAIL, stripped) await ban_ip_fn(socket_path, BLOCKLIST_JAIL, stripped)
imported += 1 imported += 1
imported_ips.append(stripped) imported_ips.append(stripped)
except jail_service.JailNotFoundError as exc: except JailNotFoundError as exc:
# The target jail does not exist in fail2ban — there is no point # The target jail does not exist in fail2ban — there is no point
# continuing because every subsequent ban would also fail. # continuing because every subsequent ban would also fail.
ban_error = str(exc) ban_error = str(exc)
@@ -387,6 +397,7 @@ async def import_all(
socket_path: str, socket_path: str,
geo_is_cached: Callable[[str], bool] | None = None, geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None, geo_batch_lookup: GeoBatchLookup | None = None,
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
) -> ImportRunResult: ) -> ImportRunResult:
"""Import all enabled blocklist sources. """Import all enabled blocklist sources.
@@ -417,6 +428,7 @@ async def import_all(
db, db,
geo_is_cached=geo_is_cached, geo_is_cached=geo_is_cached,
geo_batch_lookup=geo_batch_lookup, geo_batch_lookup=geo_batch_lookup,
ban_ip=ban_ip,
) )
results.append(result) results.append(result)
total_imported += result.ips_imported total_imported += result.ips_imported

View File

@@ -54,9 +54,9 @@ from app.models.config import (
JailValidationResult, JailValidationResult,
RollbackResponse, RollbackResponse,
) )
from app.exceptions import JailNotFoundError from app.exceptions import FilterInvalidRegexError, JailNotFoundError
from app.services import jail_service
from app.utils import conffile_parser from app.utils import conffile_parser
from app.utils.jail_utils import reload_jails
from app.utils.fail2ban_client import ( from app.utils.fail2ban_client import (
Fail2BanClient, Fail2BanClient,
Fail2BanConnectionError, Fail2BanConnectionError,
@@ -65,6 +65,41 @@ from app.utils.fail2ban_client import (
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
# Proxy object for jail reload operations. Tests can patch
# app.services.config_file_service.jail_service.reload_all as needed.
class _JailServiceProxy:
async def reload_all(
self,
socket_path: str,
include_jails: list[str] | None = None,
exclude_jails: list[str] | None = None,
) -> None:
kwargs: dict[str, list[str]] = {}
if include_jails is not None:
kwargs["include_jails"] = include_jails
if exclude_jails is not None:
kwargs["exclude_jails"] = exclude_jails
await reload_jails(socket_path, **kwargs)
jail_service = _JailServiceProxy()
async def _reload_all(
socket_path: str,
include_jails: list[str] | None = None,
exclude_jails: list[str] | None = None,
) -> None:
"""Reload fail2ban jails using the configured hook or default helper."""
kwargs: dict[str, list[str]] = {}
if include_jails is not None:
kwargs["include_jails"] = include_jails
if exclude_jails is not None:
kwargs["exclude_jails"] = exclude_jails
await jail_service.reload_all(socket_path, **kwargs)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Constants # Constants
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -168,21 +203,6 @@ class FilterReadonlyError(Exception):
) )
class FilterInvalidRegexError(Exception):
"""Raised when a regex pattern fails to compile."""
def __init__(self, pattern: str, error: str) -> None:
"""Initialise with the invalid pattern and the compile error.
Args:
pattern: The regex string that failed to compile.
error: The ``re.error`` message.
"""
self.pattern: str = pattern
self.error: str = error
super().__init__(f"Invalid regex {pattern!r}: {error}")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Internal helpers # Internal helpers
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -1206,7 +1226,7 @@ async def activate_jail(
# Activation reload — if it fails, roll back immediately # # Activation reload — if it fails, roll back immediately #
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
try: try:
await jail_service.reload_all(socket_path, include_jails=[name]) await _reload_all(socket_path, include_jails=[name])
except JailNotFoundError as exc: except JailNotFoundError as exc:
# Jail configuration is invalid (e.g. missing logpath that prevents # Jail configuration is invalid (e.g. missing logpath that prevents
# fail2ban from loading the jail). Roll back and provide a specific error. # fail2ban from loading the jail). Roll back and provide a specific error.
@@ -1349,7 +1369,7 @@ async def _rollback_activation_async(
# Step 2 — reload fail2ban with the restored config. # Step 2 — reload fail2ban with the restored config.
try: try:
await jail_service.reload_all(socket_path) await _reload_all(socket_path)
log.info("jail_activation_rollback_reload_ok", jail=name) log.info("jail_activation_rollback_reload_ok", jail=name)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc)) log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
@@ -1416,7 +1436,7 @@ async def deactivate_jail(
) )
try: try:
await jail_service.reload_all(socket_path, exclude_jails=[name]) await _reload_all(socket_path, exclude_jails=[name])
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc)) log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
@@ -1972,7 +1992,7 @@ async def update_filter(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await _reload_all(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_filter_update_failed", "reload_after_filter_update_failed",
@@ -2047,7 +2067,7 @@ async def create_filter(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await _reload_all(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_filter_create_failed", "reload_after_filter_create_failed",
@@ -2174,7 +2194,7 @@ async def assign_filter_to_jail(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await _reload_all(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_assign_filter_failed", "reload_after_assign_filter_failed",
@@ -2826,7 +2846,7 @@ async def update_action(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await _reload_all(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_action_update_failed", "reload_after_action_update_failed",
@@ -2895,7 +2915,7 @@ async def create_action(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await _reload_all(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_action_create_failed", "reload_after_action_create_failed",
@@ -3026,7 +3046,7 @@ async def assign_action_to_jail(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await _reload_all(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_assign_action_failed", "reload_after_assign_action_failed",
@@ -3088,7 +3108,7 @@ async def remove_action_from_jail(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await _reload_all(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_remove_action_failed", "reload_after_remove_action_failed",

View File

@@ -15,6 +15,7 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import re import re
from collections.abc import Awaitable, Callable
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, TypeVar, cast from typing import TYPE_CHECKING, TypeVar, cast
@@ -44,8 +45,12 @@ from app.models.config import (
ServiceStatusResponse, ServiceStatusResponse,
) )
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
from app.services import log_service, setup_service
from app.utils.fail2ban_client import Fail2BanClient from app.utils.fail2ban_client import Fail2BanClient
from app.utils.log_utils import preview_log as util_preview_log, test_regex as util_test_regex
from app.utils.setup_utils import (
get_map_color_thresholds as util_get_map_color_thresholds,
set_map_color_thresholds as util_set_map_color_thresholds,
)
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -493,8 +498,8 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
def test_regex(request: RegexTestRequest) -> RegexTestResponse: def test_regex(request: RegexTestRequest) -> RegexTestResponse:
"""Proxy to :func:`app.services.log_service.test_regex`.""" """Proxy to log utilities for regex test without service imports."""
return log_service.test_regex(request) return util_test_regex(request)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -572,9 +577,14 @@ async def delete_log_path(
raise ConfigOperationError(f"Failed to delete log path {log_path!r}: {exc}") from exc raise ConfigOperationError(f"Failed to delete log path {log_path!r}: {exc}") from exc
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse: async def preview_log(
"""Proxy to :func:`app.services.log_service.preview_log`.""" req: LogPreviewRequest,
return await log_service.preview_log(req) preview_fn: Callable[[LogPreviewRequest], Awaitable[LogPreviewResponse]] | None = None,
) -> LogPreviewResponse:
"""Proxy to an injectable log preview function."""
if preview_fn is None:
preview_fn = util_preview_log
return await preview_fn(req)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -591,7 +601,7 @@ async def get_map_color_thresholds(db: aiosqlite.Connection) -> MapColorThreshol
Returns: Returns:
A :class:`MapColorThresholdsResponse` containing the three threshold values. A :class:`MapColorThresholdsResponse` containing the three threshold values.
""" """
high, medium, low = await setup_service.get_map_color_thresholds(db) high, medium, low = await util_get_map_color_thresholds(db)
return MapColorThresholdsResponse( return MapColorThresholdsResponse(
threshold_high=high, threshold_high=high,
threshold_medium=medium, threshold_medium=medium,
@@ -612,7 +622,7 @@ async def update_map_color_thresholds(
Raises: Raises:
ValueError: If validation fails (thresholds must satisfy high > medium > low). ValueError: If validation fails (thresholds must satisfy high > medium > low).
""" """
await setup_service.set_map_color_thresholds( await util_set_map_color_thresholds(
db, db,
threshold_high=update.threshold_high, threshold_high=update.threshold_high,
threshold_medium=update.threshold_medium, threshold_medium=update.threshold_medium,
@@ -634,16 +644,7 @@ _SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log")
def _count_file_lines(file_path: str) -> int: def _count_file_lines(file_path: str) -> int:
"""Count the total number of lines in *file_path* synchronously. """Count the total number of lines in *file_path* synchronously."""
Uses a memory-efficient buffered read to avoid loading the whole file.
Args:
file_path: Absolute path to the file.
Returns:
Total number of lines in the file.
"""
count = 0 count = 0
with open(file_path, "rb") as fh: with open(file_path, "rb") as fh:
for chunk in iter(lambda: fh.read(65536), b""): for chunk in iter(lambda: fh.read(65536), b""):
@@ -651,6 +652,32 @@ def _count_file_lines(file_path: str) -> int:
return count return count
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
"""Read the last *num_lines* from *file_path* in a memory-efficient way."""
chunk_size = 8192
raw_lines: list[bytes] = []
with open(file_path, "rb") as fh:
fh.seek(0, 2)
end_pos = fh.tell()
if end_pos == 0:
return []
buf = b""
pos = end_pos
while len(raw_lines) <= num_lines and pos > 0:
read_size = min(chunk_size, pos)
pos -= read_size
fh.seek(pos)
chunk = fh.read(read_size)
buf = chunk + buf
raw_lines = buf.split(b"\n")
if pos > 0 and len(raw_lines) > 1:
raw_lines = raw_lines[1:]
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
async def read_fail2ban_log( async def read_fail2ban_log(
socket_path: str, socket_path: str,
lines: int, lines: int,
@@ -719,7 +746,7 @@ async def read_fail2ban_log(
total_lines, raw_lines = await asyncio.gather( total_lines, raw_lines = await asyncio.gather(
loop.run_in_executor(None, _count_file_lines, resolved_str), loop.run_in_executor(None, _count_file_lines, resolved_str),
loop.run_in_executor(None, log_service._read_tail_lines, resolved_str, lines), loop.run_in_executor(None, _read_tail_lines, resolved_str, lines),
) )
filtered = ( filtered = (
@@ -745,22 +772,27 @@ async def read_fail2ban_log(
) )
async def get_service_status(socket_path: str) -> ServiceStatusResponse: async def get_service_status(
socket_path: str,
probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None,
) -> ServiceStatusResponse:
"""Return fail2ban service health status with log configuration. """Return fail2ban service health status with log configuration.
Delegates to :func:`~app.services.health_service.probe` for the core Delegates to an injectable *probe_fn* (defaults to
health snapshot and augments it with the current log-level and log-target :func:`~app.services.health_service.probe`). This avoids direct service-to-
values from the socket. service imports inside this module.
Args: Args:
socket_path: Path to the fail2ban Unix domain socket. socket_path: Path to the fail2ban Unix domain socket.
probe_fn: Optional probe function.
Returns: Returns:
:class:`~app.models.config.ServiceStatusResponse`. :class:`~app.models.config.ServiceStatusResponse`.
""" """
from app.services.health_service import probe # lazy import avoids circular dep if probe_fn is None:
raise ValueError("probe_fn is required to avoid service-to-service coupling")
server_status = await probe(socket_path) server_status = await probe_fn(socket_path)
if server_status.online: if server_status.online:
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)

File diff suppressed because it is too large Load Diff

View File

@@ -25,15 +25,9 @@ from app.models.config import (
FilterUpdateRequest, FilterUpdateRequest,
AssignFilterRequest, AssignFilterRequest,
) )
from app.exceptions import JailNotFoundError from app.exceptions import FilterInvalidRegexError, JailNotFoundError
from app.services import jail_service
from app.services.config_file_service import (
_parse_jails_sync,
_get_active_jail_names,
ConfigWriteError,
JailNotFoundInConfigError,
)
from app.utils import conffile_parser from app.utils import conffile_parser
from app.utils.jail_utils import reload_jails
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -83,21 +77,6 @@ class FilterReadonlyError(Exception):
) )
class FilterInvalidRegexError(Exception):
"""Raised when a regex pattern fails to compile."""
def __init__(self, pattern: str, error: str) -> None:
"""Initialise with the invalid pattern and the compile error.
Args:
pattern: The regex string that failed to compile.
error: The ``re.error`` message.
"""
self.pattern: str = pattern
self.error: str = error
super().__init__(f"Invalid regex {pattern!r}: {error}")
class FilterNameError(Exception): class FilterNameError(Exception):
"""Raised when a filter name contains invalid characters.""" """Raised when a filter name contains invalid characters."""
@@ -723,7 +702,7 @@ async def update_filter(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await reload_jails(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_filter_update_failed", "reload_after_filter_update_failed",
@@ -798,7 +777,7 @@ async def create_filter(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await reload_jails(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_filter_create_failed", "reload_after_filter_create_failed",
@@ -924,7 +903,7 @@ async def assign_filter_to_jail(
if do_reload: if do_reload:
try: try:
await jail_service.reload_all(socket_path) await reload_jails(socket_path)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning(
"reload_after_assign_filter_failed", "reload_after_assign_filter_failed",

View File

@@ -20,9 +20,7 @@ Usage::
import aiohttp import aiohttp
import aiosqlite import aiosqlite
from app.services import geo_service # Use the geo_service directly in application startup
# warm the cache from the persistent store at startup
async with aiosqlite.connect("bangui.db") as db: async with aiosqlite.connect("bangui.db") as db:
await geo_service.load_cache_from_db(db) await geo_service.load_cache_from_db(db)

View File

@@ -30,7 +30,13 @@ from app.models.config import (
JailValidationResult, JailValidationResult,
RollbackResponse, RollbackResponse,
) )
from app.services import config_file_service, jail_service from app.utils.config_file_utils import (
_build_inactive_jail,
_ordered_config_files,
_parse_jails_sync,
_validate_jail_config_sync,
)
from app.utils.jail_utils import reload_jails
from app.utils.fail2ban_client import ( from app.utils.fail2ban_client import (
Fail2BanClient, Fail2BanClient,
Fail2BanConnectionError, Fail2BanConnectionError,
@@ -304,7 +310,7 @@ def _validate_regex_patterns(patterns: list[str]) -> None:
re.compile(pattern) re.compile(pattern)
except re.error as exc: except re.error as exc:
# Import here to avoid circular dependency # Import here to avoid circular dependency
from app.services.filter_config_service import FilterInvalidRegexError from app.exceptions import FilterInvalidRegexError
raise FilterInvalidRegexError(pattern, str(exc)) from exc raise FilterInvalidRegexError(pattern, str(exc)) from exc
@@ -460,12 +466,7 @@ async def start_daemon(start_cmd_parts: list[str]) -> bool:
return False return False
# Import shared functions from config_file_service # Shared functions from config_file_service are imported from app.utils.config_file_utils
_parse_jails_sync = config_file_service._parse_jails_sync
_build_inactive_jail = config_file_service._build_inactive_jail
_get_active_jail_names = config_file_service._get_active_jail_names
_validate_jail_config_sync = config_file_service._validate_jail_config_sync
_orderedconfig_files = config_file_service._ordered_config_files
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -624,7 +625,7 @@ async def activate_jail(
# Activation reload — if it fails, roll back immediately # # Activation reload — if it fails, roll back immediately #
# ---------------------------------------------------------------------- # # ---------------------------------------------------------------------- #
try: try:
await jail_service.reload_all(socket_path, include_jails=[name]) await reload_jails(socket_path, include_jails=[name])
except JailNotFoundError as exc: except JailNotFoundError as exc:
# Jail configuration is invalid (e.g. missing logpath that prevents # Jail configuration is invalid (e.g. missing logpath that prevents
# fail2ban from loading the jail). Roll back and provide a specific error. # fail2ban from loading the jail). Roll back and provide a specific error.
@@ -767,7 +768,7 @@ async def _rollback_activation_async(
# Step 2 — reload fail2ban with the restored config. # Step 2 — reload fail2ban with the restored config.
try: try:
await jail_service.reload_all(socket_path) await reload_jails(socket_path)
log.info("jail_activation_rollback_reload_ok", jail=name) log.info("jail_activation_rollback_reload_ok", jail=name)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc)) log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
@@ -834,7 +835,7 @@ async def deactivate_jail(
) )
try: try:
await jail_service.reload_all(socket_path, exclude_jails=[name]) await reload_jails(socket_path, exclude_jails=[name])
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc)) log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))

View File

@@ -102,30 +102,20 @@ async def run_setup(
log.info("bangui_setup_completed") log.info("bangui_setup_completed")
from app.utils.setup_utils import (
get_map_color_thresholds as util_get_map_color_thresholds,
get_password_hash as util_get_password_hash,
set_map_color_thresholds as util_set_map_color_thresholds,
)
async def get_password_hash(db: aiosqlite.Connection) -> str | None: async def get_password_hash(db: aiosqlite.Connection) -> str | None:
"""Return the stored bcrypt password hash, or ``None`` if not set. """Return the stored bcrypt password hash, or ``None`` if not set."""
return await util_get_password_hash(db)
Args:
db: Active aiosqlite connection.
Returns:
The bcrypt hash string, or ``None``.
"""
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
async def get_timezone(db: aiosqlite.Connection) -> str: async def get_timezone(db: aiosqlite.Connection) -> str:
"""Return the configured IANA timezone string. """Return the configured IANA timezone string."""
Falls back to ``"UTC"`` when no timezone has been stored (e.g. before
setup completes or for legacy databases).
Args:
db: Active aiosqlite connection.
Returns:
An IANA timezone identifier such as ``"Europe/Berlin"`` or ``"UTC"``.
"""
tz = await settings_repo.get_setting(db, _KEY_TIMEZONE) tz = await settings_repo.get_setting(db, _KEY_TIMEZONE)
return tz if tz else "UTC" return tz if tz else "UTC"
@@ -133,31 +123,8 @@ async def get_timezone(db: aiosqlite.Connection) -> str:
async def get_map_color_thresholds( async def get_map_color_thresholds(
db: aiosqlite.Connection, db: aiosqlite.Connection,
) -> tuple[int, int, int]: ) -> tuple[int, int, int]:
"""Return the configured map color thresholds (high, medium, low). """Return the configured map color thresholds (high, medium, low)."""
return await util_get_map_color_thresholds(db)
Falls back to default values (100, 50, 20) if not set.
Args:
db: Active aiosqlite connection.
Returns:
A tuple of (threshold_high, threshold_medium, threshold_low).
"""
high = await settings_repo.get_setting(
db, _KEY_MAP_COLOR_THRESHOLD_HIGH
)
medium = await settings_repo.get_setting(
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM
)
low = await settings_repo.get_setting(
db, _KEY_MAP_COLOR_THRESHOLD_LOW
)
return (
int(high) if high else 100,
int(medium) if medium else 50,
int(low) if low else 20,
)
async def set_map_color_thresholds( async def set_map_color_thresholds(
@@ -167,31 +134,12 @@ async def set_map_color_thresholds(
threshold_medium: int, threshold_medium: int,
threshold_low: int, threshold_low: int,
) -> None: ) -> None:
"""Update the map color threshold configuration. """Update the map color threshold configuration."""
await util_set_map_color_thresholds(
Args: db,
db: Active aiosqlite connection. threshold_high=threshold_high,
threshold_high: Ban count for red coloring. threshold_medium=threshold_medium,
threshold_medium: Ban count for yellow coloring. threshold_low=threshold_low,
threshold_low: Ban count for green coloring.
Raises:
ValueError: If thresholds are not positive integers or if
high <= medium <= low.
"""
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
raise ValueError("All thresholds must be positive integers.")
if not (threshold_high > threshold_medium > threshold_low):
raise ValueError("Thresholds must satisfy: high > medium > low.")
await settings_repo.set_setting(
db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high)
)
await settings_repo.set_setting(
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium)
)
await settings_repo.set_setting(
db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low)
) )
log.info( log.info(
"map_color_thresholds_updated", "map_color_thresholds_updated",

View File

@@ -43,9 +43,15 @@ async def _run_import(app: Any) -> None:
http_session = app.state.http_session http_session = app.state.http_session
socket_path: str = app.state.settings.fail2ban_socket socket_path: str = app.state.settings.fail2ban_socket
from app.services import jail_service
log.info("blocklist_import_starting") log.info("blocklist_import_starting")
try: try:
result = await blocklist_service.import_all(db, http_session, socket_path) result = await blocklist_service.import_all(
db,
http_session,
socket_path,
)
log.info( log.info(
"blocklist_import_finished", "blocklist_import_finished",
total_imported=result.total_imported, total_imported=result.total_imported,

View File

@@ -0,0 +1,21 @@
"""Utilities re-exported from config_file_service for cross-module usage."""
from __future__ import annotations
from pathlib import Path
from app.services.config_file_service import (
_build_inactive_jail,
_get_active_jail_names,
_ordered_config_files,
_parse_jails_sync,
_validate_jail_config_sync,
)
__all__ = [
"_ordered_config_files",
"_parse_jails_sync",
"_build_inactive_jail",
"_get_active_jail_names",
"_validate_jail_config_sync",
]

View File

@@ -0,0 +1,20 @@
"""Jail helpers to decouple service layer dependencies."""
from __future__ import annotations
from collections.abc import Sequence
from app.services.jail_service import reload_all
async def reload_jails(
socket_path: str,
include_jails: Sequence[str] | None = None,
exclude_jails: Sequence[str] | None = None,
) -> None:
"""Reload fail2ban jails using shared jail service helper."""
await reload_all(
socket_path,
include_jails=list(include_jails) if include_jails is not None else None,
exclude_jails=list(exclude_jails) if exclude_jails is not None else None,
)

View File

@@ -0,0 +1,14 @@
"""Log-related helpers to avoid direct service-to-service imports."""
from __future__ import annotations
from app.models.config import LogPreviewRequest, LogPreviewResponse, RegexTestRequest, RegexTestResponse
from app.services.log_service import preview_log as _preview_log, test_regex as _test_regex
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
return await _preview_log(req)
def test_regex(req: RegexTestRequest) -> RegexTestResponse:
return _test_regex(req)

View File

@@ -0,0 +1,47 @@
"""Setup-related utilities shared by multiple services."""
from __future__ import annotations
from app.repositories import settings_repo
_KEY_PASSWORD_HASH = "master_password_hash"
_KEY_SETUP_DONE = "setup_completed"
_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high"
_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
async def get_password_hash(db):
"""Return the stored master password hash or None."""
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
async def get_map_color_thresholds(db):
"""Return map color thresholds as tuple (high, medium, low)."""
high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH)
medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM)
low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW)
return (
int(high) if high else 100,
int(medium) if medium else 50,
int(low) if low else 20,
)
async def set_map_color_thresholds(
db,
*,
threshold_high: int,
threshold_medium: int,
threshold_low: int,
) -> None:
"""Persist map color thresholds after validating values."""
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
raise ValueError("All thresholds must be positive integers.")
if not (threshold_high > threshold_medium > threshold_low):
raise ValueError("Thresholds must satisfy: high > medium > low.")
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high))
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium))
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low))

View File

@@ -203,9 +203,15 @@ class TestImport:
call_count += 1 call_count += 1
raise JailNotFoundError(jail) raise JailNotFoundError(jail)
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found): with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip:
from app.services import jail_service
result = await blocklist_service.import_source( result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db source,
session,
"/tmp/fake.sock",
db,
ban_ip=jail_service.ban_ip,
) )
# Must abort after the first JailNotFoundError — only one ban attempt. # Must abort after the first JailNotFoundError — only one ban attempt.
@@ -226,7 +232,14 @@ class TestImport:
with patch( with patch(
"app.services.jail_service.ban_ip", new_callable=AsyncMock "app.services.jail_service.ban_ip", new_callable=AsyncMock
): ):
result = await blocklist_service.import_all(db, session, "/tmp/fake.sock") from app.services import jail_service
result = await blocklist_service.import_all(
db,
session,
"/tmp/fake.sock",
ban_ip=jail_service.ban_ip,
)
# Only S1 is enabled, S2 is disabled. # Only S1 is enabled, S2 is disabled.
assert len(result.results) == 1 assert len(result.results) == 1

View File

@@ -721,9 +721,11 @@ class TestGetServiceStatus:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock(side_effect=_send) self.send = AsyncMock(side_effect=_send)
with patch("app.services.config_service.Fail2BanClient", _FakeClient), \ with patch("app.services.config_service.Fail2BanClient", _FakeClient):
patch("app.services.health_service.probe", AsyncMock(return_value=online_status)): result = await config_service.get_service_status(
result = await config_service.get_service_status(_SOCKET) _SOCKET,
probe_fn=AsyncMock(return_value=online_status),
)
assert result.online is True assert result.online is True
assert result.version == "1.0.0" assert result.version == "1.0.0"
@@ -739,8 +741,10 @@ class TestGetServiceStatus:
offline_status = ServerStatus(online=False) offline_status = ServerStatus(online=False)
with patch("app.services.health_service.probe", AsyncMock(return_value=offline_status)): result = await config_service.get_service_status(
result = await config_service.get_service_status(_SOCKET) _SOCKET,
probe_fn=AsyncMock(return_value=offline_status),
)
assert result.online is False assert result.online is False
assert result.jail_count == 0 assert result.jail_count == 0

View File

@@ -0,0 +1,88 @@
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { useConfigItem } from "../useConfigItem";
describe("useConfigItem", () => {
beforeEach(() => {
vi.useFakeTimers();
});
afterEach(() => {
vi.useRealTimers();
vi.clearAllMocks();
});
it("loads data and sets loading state", async () => {
const fetchFn = vi.fn().mockResolvedValue("hello");
const saveFn = vi.fn().mockResolvedValue(undefined);
const { result } = renderHook(() => useConfigItem<string, string>({ fetchFn, saveFn }));
expect(result.current.loading).toBe(true);
await act(async () => {
await Promise.resolve();
});
expect(fetchFn).toHaveBeenCalled();
expect(result.current.data).toBe("hello");
expect(result.current.loading).toBe(false);
});
it("sets error if fetch rejects", async () => {
const fetchFn = vi.fn().mockRejectedValue(new Error("nope"));
const saveFn = vi.fn().mockResolvedValue(undefined);
const { result } = renderHook(() => useConfigItem<string, string>({ fetchFn, saveFn }));
await act(async () => {
await Promise.resolve();
});
expect(result.current.error).toBe("nope");
expect(result.current.loading).toBe(false);
});
it("save updates data when mergeOnSave is provided", async () => {
const fetchFn = vi.fn().mockResolvedValue({ value: 1 });
const saveFn = vi.fn().mockResolvedValue(undefined);
const { result } = renderHook(() =>
useConfigItem<{ value: number }, { delta: number }>({
fetchFn,
saveFn,
mergeOnSave: (prev, update) =>
prev ? { ...prev, value: prev.value + update.delta } : prev,
})
);
await act(async () => {
await Promise.resolve();
});
expect(result.current.data).toEqual({ value: 1 });
await act(async () => {
await result.current.save({ delta: 2 });
});
expect(saveFn).toHaveBeenCalledWith({ delta: 2 });
expect(result.current.data).toEqual({ value: 3 });
});
it("saveError is set when save fails", async () => {
const fetchFn = vi.fn().mockResolvedValue("ok");
const saveFn = vi.fn().mockRejectedValue(new Error("save failed"));
const { result } = renderHook(() => useConfigItem<string, string>({ fetchFn, saveFn }));
await act(async () => {
await Promise.resolve();
});
await act(async () => {
await expect(result.current.save("test")).rejects.toThrow("save failed");
});
expect(result.current.saveError).toBe("save failed");
});
});

View File

@@ -2,7 +2,7 @@
* React hook for loading and updating a single parsed action config. * React hook for loading and updating a single parsed action config.
*/ */
import { useCallback, useEffect, useRef, useState } from "react"; import { useConfigItem } from "./useConfigItem";
import { fetchAction, updateAction } from "../api/config"; import { fetchAction, updateAction } from "../api/config";
import type { ActionConfig, ActionConfigUpdate } from "../types/config"; import type { ActionConfig, ActionConfigUpdate } from "../types/config";
@@ -23,67 +23,28 @@ export interface UseActionConfigResult {
* @param name - Action base name (e.g. ``"iptables"``). * @param name - Action base name (e.g. ``"iptables"``).
*/ */
export function useActionConfig(name: string): UseActionConfigResult { export function useActionConfig(name: string): UseActionConfigResult {
const [config, setConfig] = useState<ActionConfig | null>(null); const { data, loading, error, saving, saveError, refresh, save } = useConfigItem<
const [loading, setLoading] = useState(true); ActionConfig,
const [error, setError] = useState<string | null>(null); ActionConfigUpdate
const [saving, setSaving] = useState(false); >({
const [saveError, setSaveError] = useState<string | null>(null); fetchFn: () => fetchAction(name),
const abortRef = useRef<AbortController | null>(null); saveFn: (update) => updateAction(name, update),
mergeOnSave: (prev, update) =>
const load = useCallback((): void => {
abortRef.current?.abort();
const ctrl = new AbortController();
abortRef.current = ctrl;
setLoading(true);
setError(null);
fetchAction(name)
.then((data) => {
if (!ctrl.signal.aborted) {
setConfig(data);
setLoading(false);
}
})
.catch((err: unknown) => {
if (!ctrl.signal.aborted) {
setError(err instanceof Error ? err.message : "Failed to load action config");
setLoading(false);
}
});
}, [name]);
useEffect(() => {
load();
return (): void => {
abortRef.current?.abort();
};
}, [load]);
const save = useCallback(
async (update: ActionConfigUpdate): Promise<void> => {
setSaving(true);
setSaveError(null);
try {
await updateAction(name, update);
setConfig((prev) =>
prev prev
? { ? {
...prev, ...prev,
...Object.fromEntries( ...Object.fromEntries(Object.entries(update).filter(([, v]) => v != null)),
Object.entries(update).filter(([, v]) => v !== null && v !== undefined)
),
} }
: prev : prev,
); });
} catch (err: unknown) {
setSaveError(err instanceof Error ? err.message : "Failed to save action config");
throw err;
} finally {
setSaving(false);
}
},
[name]
);
return { config, loading, error, saving, saveError, refresh: load, save }; return {
config: data,
loading,
error,
saving,
saveError,
refresh,
save,
};
} }

View File

@@ -0,0 +1,84 @@
/**
* Generic config hook for loading and saving a single entity.
*/
import { useCallback, useEffect, useRef, useState } from "react";
export interface UseConfigItemResult<T, U> {
data: T | null;
loading: boolean;
error: string | null;
saving: boolean;
saveError: string | null;
refresh: () => void;
save: (update: U) => Promise<void>;
}
export interface UseConfigItemOptions<T, U> {
fetchFn: (signal: AbortSignal) => Promise<T>;
saveFn: (update: U) => Promise<void>;
mergeOnSave?: (prev: T | null, update: U) => T | null;
}
export function useConfigItem<T, U>(
options: UseConfigItemOptions<T, U>
): UseConfigItemResult<T, U> {
const { fetchFn, saveFn, mergeOnSave } = options;
const [data, setData] = useState<T | null>(null);
const [loading, setLoading] = useState(true);
const [error, setError] = useState<string | null>(null);
const [saving, setSaving] = useState(false);
const [saveError, setSaveError] = useState<string | null>(null);
const abortRef = useRef<AbortController | null>(null);
const refresh = useCallback((): void => {
abortRef.current?.abort();
const controller = new AbortController();
abortRef.current = controller;
setLoading(true);
setError(null);
fetchFn(controller.signal)
.then((nextData) => {
if (controller.signal.aborted) return;
setData(nextData);
setLoading(false);
})
.catch((err: unknown) => {
if (controller.signal.aborted) return;
setError(err instanceof Error ? err.message : "Failed to load data");
setLoading(false);
});
}, [fetchFn]);
useEffect(() => {
refresh();
return () => {
abortRef.current?.abort();
};
}, [refresh]);
const save = useCallback(
async (update: U): Promise<void> => {
setSaving(true);
setSaveError(null);
try {
await saveFn(update);
if (mergeOnSave) {
setData((prevData) => mergeOnSave(prevData, update));
}
} catch (err: unknown) {
const message = err instanceof Error ? err.message : "Failed to save data";
setSaveError(message);
throw err;
} finally {
setSaving(false);
}
},
[saveFn, mergeOnSave]
);
return { data, loading, error, saving, saveError, refresh, save };
}

View File

@@ -2,7 +2,7 @@
* React hook for loading and updating a single parsed filter config. * React hook for loading and updating a single parsed filter config.
*/ */
import { useCallback, useEffect, useRef, useState } from "react"; import { useConfigItem } from "./useConfigItem";
import { fetchParsedFilter, updateParsedFilter } from "../api/config"; import { fetchParsedFilter, updateParsedFilter } from "../api/config";
import type { FilterConfig, FilterConfigUpdate } from "../types/config"; import type { FilterConfig, FilterConfigUpdate } from "../types/config";
@@ -23,69 +23,28 @@ export interface UseFilterConfigResult {
* @param name - Filter base name (e.g. ``"sshd"``). * @param name - Filter base name (e.g. ``"sshd"``).
*/ */
export function useFilterConfig(name: string): UseFilterConfigResult { export function useFilterConfig(name: string): UseFilterConfigResult {
const [config, setConfig] = useState<FilterConfig | null>(null); const { data, loading, error, saving, saveError, refresh, save } = useConfigItem<
const [loading, setLoading] = useState(true); FilterConfig,
const [error, setError] = useState<string | null>(null); FilterConfigUpdate
const [saving, setSaving] = useState(false); >({
const [saveError, setSaveError] = useState<string | null>(null); fetchFn: () => fetchParsedFilter(name),
const abortRef = useRef<AbortController | null>(null); saveFn: (update) => updateParsedFilter(name, update),
mergeOnSave: (prev, update) =>
const load = useCallback((): void => {
abortRef.current?.abort();
const ctrl = new AbortController();
abortRef.current = ctrl;
setLoading(true);
setError(null);
fetchParsedFilter(name)
.then((data) => {
if (!ctrl.signal.aborted) {
setConfig(data);
setLoading(false);
}
})
.catch((err: unknown) => {
if (!ctrl.signal.aborted) {
setError(err instanceof Error ? err.message : "Failed to load filter config");
setLoading(false);
}
});
}, [name]);
useEffect(() => {
load();
return (): void => {
abortRef.current?.abort();
};
}, [load]);
const save = useCallback(
async (update: FilterConfigUpdate): Promise<void> => {
setSaving(true);
setSaveError(null);
try {
await updateParsedFilter(name, update);
// Optimistically update local state so the form reflects changes
// without a full reload.
setConfig((prev) =>
prev prev
? { ? {
...prev, ...prev,
...Object.fromEntries( ...Object.fromEntries(Object.entries(update).filter(([, v]) => v != null)),
Object.entries(update).filter(([, v]) => v !== null && v !== undefined)
),
} }
: prev : prev,
); });
} catch (err: unknown) {
setSaveError(err instanceof Error ? err.message : "Failed to save filter config");
throw err;
} finally {
setSaving(false);
}
},
[name]
);
return { config, loading, error, saving, saveError, refresh: load, save }; return {
config: data,
loading,
error,
saving,
saveError,
refresh,
save,
};
} }

View File

@@ -2,7 +2,7 @@
* React hook for loading and updating a single parsed jail.d config file. * React hook for loading and updating a single parsed jail.d config file.
*/ */
import { useCallback, useEffect, useRef, useState } from "react"; import { useConfigItem } from "./useConfigItem";
import { fetchParsedJailFile, updateParsedJailFile } from "../api/config"; import { fetchParsedJailFile, updateParsedJailFile } from "../api/config";
import type { JailFileConfig, JailFileConfigUpdate } from "../types/config"; import type { JailFileConfig, JailFileConfigUpdate } from "../types/config";
@@ -21,56 +21,23 @@ export interface UseJailFileConfigResult {
* @param filename - Filename including extension (e.g. ``"sshd.conf"``). * @param filename - Filename including extension (e.g. ``"sshd.conf"``).
*/ */
export function useJailFileConfig(filename: string): UseJailFileConfigResult { export function useJailFileConfig(filename: string): UseJailFileConfigResult {
const [config, setConfig] = useState<JailFileConfig | null>(null); const { data, loading, error, refresh, save } = useConfigItem<
const [loading, setLoading] = useState(true); JailFileConfig,
const [error, setError] = useState<string | null>(null); JailFileConfigUpdate
const abortRef = useRef<AbortController | null>(null); >({
fetchFn: () => fetchParsedJailFile(filename),
const load = useCallback((): void => { saveFn: (update) => updateParsedJailFile(filename, update),
abortRef.current?.abort(); mergeOnSave: (prev, update) =>
const ctrl = new AbortController(); update.jails != null && prev
abortRef.current = ctrl; ? { ...prev, jails: { ...prev.jails, ...update.jails } }
setLoading(true); : prev,
setError(null);
fetchParsedJailFile(filename)
.then((data) => {
if (!ctrl.signal.aborted) {
setConfig(data);
setLoading(false);
}
})
.catch((err: unknown) => {
if (!ctrl.signal.aborted) {
setError(err instanceof Error ? err.message : "Failed to load jail file config");
setLoading(false);
}
}); });
}, [filename]);
useEffect(() => { return {
load(); config: data,
return (): void => { loading,
abortRef.current?.abort(); error,
refresh,
save,
}; };
}, [load]);
const save = useCallback(
async (update: JailFileConfigUpdate): Promise<void> => {
try {
await updateParsedJailFile(filename, update);
// Optimistically merge updated jails into local state.
if (update.jails != null) {
setConfig((prev) =>
prev ? { ...prev, jails: { ...prev.jails, ...update.jails } } : prev
);
}
} catch (err: unknown) {
throw err instanceof Error ? err : new Error("Failed to save jail file config");
}
},
[filename]
);
return { config, loading, error, refresh: load, save };
} }