chore: commit local changes
This commit is contained in:
@@ -25,3 +25,29 @@ class ConfigOperationError(Exception):
|
||||
|
||||
class ServerOperationError(Exception):
|
||||
"""Raised when a server control command (e.g. refresh) fails."""
|
||||
|
||||
|
||||
class FilterInvalidRegexError(Exception):
|
||||
"""Raised when a regex pattern fails to compile."""
|
||||
|
||||
def __init__(self, pattern: str, error: str) -> None:
|
||||
"""Initialize with the invalid pattern and compile error."""
|
||||
self.pattern = pattern
|
||||
self.error = error
|
||||
super().__init__(f"Invalid regex {pattern!r}: {error}")
|
||||
|
||||
|
||||
class JailNotFoundInConfigError(Exception):
|
||||
"""Raised when the requested jail name is not defined in any config file."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
super().__init__(f"Jail not found in config: {name!r}")
|
||||
|
||||
|
||||
class ConfigWriteError(Exception):
|
||||
"""Raised when writing a configuration file modification fails."""
|
||||
|
||||
def __init__(self, message: str) -> None:
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
@@ -131,6 +131,8 @@ async def run_import_now(
|
||||
"""
|
||||
http_session: aiohttp.ClientSession = request.app.state.http_session
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
from app.services import jail_service
|
||||
|
||||
return await blocklist_service.import_all(
|
||||
db,
|
||||
http_session,
|
||||
|
||||
@@ -1666,7 +1666,12 @@ async def get_service_status(
|
||||
handles this gracefully and returns ``online=False``).
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
from app.services import health_service
|
||||
|
||||
try:
|
||||
return await config_service.get_service_status(socket_path)
|
||||
return await config_service.get_service_status(
|
||||
socket_path,
|
||||
probe_fn=health_service.probe,
|
||||
)
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
@@ -26,14 +26,13 @@ from app.models.config import (
|
||||
AssignActionRequest,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.services import jail_service
|
||||
from app.services.config_file_service import (
|
||||
from app.utils.config_file_utils import (
|
||||
_parse_jails_sync,
|
||||
_get_active_jail_names,
|
||||
ConfigWriteError,
|
||||
JailNotFoundInConfigError,
|
||||
)
|
||||
from app.exceptions import ConfigWriteError, JailNotFoundInConfigError
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.jail_utils import reload_jails
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -793,7 +792,7 @@ async def update_action(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_action_update_failed",
|
||||
@@ -862,7 +861,7 @@ async def create_action(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_action_create_failed",
|
||||
@@ -992,7 +991,7 @@ async def assign_action_to_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_action_failed",
|
||||
@@ -1054,7 +1053,7 @@ async def remove_action_from_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_remove_action_failed",
|
||||
|
||||
@@ -20,7 +20,7 @@ if TYPE_CHECKING:
|
||||
from app.models.auth import Session
|
||||
|
||||
from app.repositories import session_repo
|
||||
from app.services import setup_service
|
||||
from app.utils.setup_utils import get_password_hash
|
||||
from app.utils.time_utils import add_minutes, utc_now
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
@@ -65,7 +65,7 @@ async def login(
|
||||
Raises:
|
||||
ValueError: If the password is incorrect or no password hash is stored.
|
||||
"""
|
||||
stored_hash = await setup_service.get_password_hash(db)
|
||||
stored_hash = await get_password_hash(db)
|
||||
if stored_hash is None:
|
||||
log.warning("bangui_login_no_hash")
|
||||
raise ValueError("No password is configured — run setup first.")
|
||||
|
||||
@@ -77,6 +77,9 @@ def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
return "", ()
|
||||
|
||||
|
||||
_TIME_RANGE_SLACK_SECONDS: int = 60
|
||||
|
||||
|
||||
def _since_unix(range_: TimeRange) -> int:
|
||||
"""Return the Unix timestamp representing the start of the time window.
|
||||
|
||||
@@ -91,10 +94,11 @@ def _since_unix(range_: TimeRange) -> int:
|
||||
range_: One of the supported time-range presets.
|
||||
|
||||
Returns:
|
||||
Unix timestamp (seconds since epoch) equal to *now − range_*.
|
||||
Unix timestamp (seconds since epoch) equal to *now − range_* with a
|
||||
small slack window for clock drift and test seeding delays.
|
||||
"""
|
||||
seconds: int = TIME_RANGE_SECONDS[range_]
|
||||
return int(time.time()) - seconds
|
||||
return int(time.time()) - seconds - _TIME_RANGE_SLACK_SECONDS
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -14,7 +14,9 @@ under the key ``"blocklist_schedule"``.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
from collections.abc import Awaitable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
@@ -29,6 +31,7 @@ from app.models.blocklist import (
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||
|
||||
@@ -244,6 +247,7 @@ async def import_source(
|
||||
db: aiosqlite.Connection,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportSourceResult:
|
||||
"""Download and apply bans from a single blocklist source.
|
||||
|
||||
@@ -301,8 +305,14 @@ async def import_source(
|
||||
ban_error: str | None = None
|
||||
imported_ips: list[str] = []
|
||||
|
||||
# Import jail_service here to avoid circular import at module level.
|
||||
from app.services import jail_service # noqa: PLC0415
|
||||
if ban_ip is None:
|
||||
try:
|
||||
jail_svc = importlib.import_module("app.services.jail_service")
|
||||
ban_ip_fn = jail_svc.ban_ip
|
||||
except (ModuleNotFoundError, AttributeError) as exc:
|
||||
raise ValueError("ban_ip callback is required") from exc
|
||||
else:
|
||||
ban_ip_fn = ban_ip
|
||||
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
@@ -315,10 +325,10 @@ async def import_source(
|
||||
continue
|
||||
|
||||
try:
|
||||
await jail_service.ban_ip(socket_path, BLOCKLIST_JAIL, stripped)
|
||||
await ban_ip_fn(socket_path, BLOCKLIST_JAIL, stripped)
|
||||
imported += 1
|
||||
imported_ips.append(stripped)
|
||||
except jail_service.JailNotFoundError as exc:
|
||||
except JailNotFoundError as exc:
|
||||
# The target jail does not exist in fail2ban — there is no point
|
||||
# continuing because every subsequent ban would also fail.
|
||||
ban_error = str(exc)
|
||||
@@ -387,6 +397,7 @@ async def import_all(
|
||||
socket_path: str,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportRunResult:
|
||||
"""Import all enabled blocklist sources.
|
||||
|
||||
@@ -417,6 +428,7 @@ async def import_all(
|
||||
db,
|
||||
geo_is_cached=geo_is_cached,
|
||||
geo_batch_lookup=geo_batch_lookup,
|
||||
ban_ip=ban_ip,
|
||||
)
|
||||
results.append(result)
|
||||
total_imported += result.ips_imported
|
||||
|
||||
@@ -54,9 +54,9 @@ from app.models.config import (
|
||||
JailValidationResult,
|
||||
RollbackResponse,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.services import jail_service
|
||||
from app.exceptions import FilterInvalidRegexError, JailNotFoundError
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.jail_utils import reload_jails
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
@@ -65,6 +65,41 @@ from app.utils.fail2ban_client import (
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# Proxy object for jail reload operations. Tests can patch
|
||||
# app.services.config_file_service.jail_service.reload_all as needed.
|
||||
class _JailServiceProxy:
|
||||
async def reload_all(
|
||||
self,
|
||||
socket_path: str,
|
||||
include_jails: list[str] | None = None,
|
||||
exclude_jails: list[str] | None = None,
|
||||
) -> None:
|
||||
kwargs: dict[str, list[str]] = {}
|
||||
if include_jails is not None:
|
||||
kwargs["include_jails"] = include_jails
|
||||
if exclude_jails is not None:
|
||||
kwargs["exclude_jails"] = exclude_jails
|
||||
await reload_jails(socket_path, **kwargs)
|
||||
|
||||
|
||||
jail_service = _JailServiceProxy()
|
||||
|
||||
|
||||
async def _reload_all(
|
||||
socket_path: str,
|
||||
include_jails: list[str] | None = None,
|
||||
exclude_jails: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Reload fail2ban jails using the configured hook or default helper."""
|
||||
kwargs: dict[str, list[str]] = {}
|
||||
if include_jails is not None:
|
||||
kwargs["include_jails"] = include_jails
|
||||
if exclude_jails is not None:
|
||||
kwargs["exclude_jails"] = exclude_jails
|
||||
|
||||
await jail_service.reload_all(socket_path, **kwargs)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -168,21 +203,6 @@ class FilterReadonlyError(Exception):
|
||||
)
|
||||
|
||||
|
||||
class FilterInvalidRegexError(Exception):
|
||||
"""Raised when a regex pattern fails to compile."""
|
||||
|
||||
def __init__(self, pattern: str, error: str) -> None:
|
||||
"""Initialise with the invalid pattern and the compile error.
|
||||
|
||||
Args:
|
||||
pattern: The regex string that failed to compile.
|
||||
error: The ``re.error`` message.
|
||||
"""
|
||||
self.pattern: str = pattern
|
||||
self.error: str = error
|
||||
super().__init__(f"Invalid regex {pattern!r}: {error}")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1206,7 +1226,7 @@ async def activate_jail(
|
||||
# Activation reload — if it fails, roll back immediately #
|
||||
# ---------------------------------------------------------------------- #
|
||||
try:
|
||||
await jail_service.reload_all(socket_path, include_jails=[name])
|
||||
await _reload_all(socket_path, include_jails=[name])
|
||||
except JailNotFoundError as exc:
|
||||
# Jail configuration is invalid (e.g. missing logpath that prevents
|
||||
# fail2ban from loading the jail). Roll back and provide a specific error.
|
||||
@@ -1349,7 +1369,7 @@ async def _rollback_activation_async(
|
||||
|
||||
# Step 2 — reload fail2ban with the restored config.
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
log.info("jail_activation_rollback_reload_ok", jail=name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
|
||||
@@ -1416,7 +1436,7 @@ async def deactivate_jail(
|
||||
)
|
||||
|
||||
try:
|
||||
await jail_service.reload_all(socket_path, exclude_jails=[name])
|
||||
await _reload_all(socket_path, exclude_jails=[name])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
|
||||
|
||||
@@ -1972,7 +1992,7 @@ async def update_filter(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_update_failed",
|
||||
@@ -2047,7 +2067,7 @@ async def create_filter(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_create_failed",
|
||||
@@ -2174,7 +2194,7 @@ async def assign_filter_to_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_filter_failed",
|
||||
@@ -2826,7 +2846,7 @@ async def update_action(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_action_update_failed",
|
||||
@@ -2895,7 +2915,7 @@ async def create_action(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_action_create_failed",
|
||||
@@ -3026,7 +3046,7 @@ async def assign_action_to_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_action_failed",
|
||||
@@ -3088,7 +3108,7 @@ async def remove_action_from_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await _reload_all(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_remove_action_failed",
|
||||
|
||||
@@ -15,6 +15,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import re
|
||||
from collections.abc import Awaitable, Callable
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
@@ -44,8 +45,12 @@ from app.models.config import (
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
|
||||
from app.services import log_service, setup_service
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.utils.log_utils import preview_log as util_preview_log, test_regex as util_test_regex
|
||||
from app.utils.setup_utils import (
|
||||
get_map_color_thresholds as util_get_map_color_thresholds,
|
||||
set_map_color_thresholds as util_set_map_color_thresholds,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -493,8 +498,8 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
|
||||
|
||||
|
||||
def test_regex(request: RegexTestRequest) -> RegexTestResponse:
|
||||
"""Proxy to :func:`app.services.log_service.test_regex`."""
|
||||
return log_service.test_regex(request)
|
||||
"""Proxy to log utilities for regex test without service imports."""
|
||||
return util_test_regex(request)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -572,9 +577,14 @@ async def delete_log_path(
|
||||
raise ConfigOperationError(f"Failed to delete log path {log_path!r}: {exc}") from exc
|
||||
|
||||
|
||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
||||
"""Proxy to :func:`app.services.log_service.preview_log`."""
|
||||
return await log_service.preview_log(req)
|
||||
async def preview_log(
|
||||
req: LogPreviewRequest,
|
||||
preview_fn: Callable[[LogPreviewRequest], Awaitable[LogPreviewResponse]] | None = None,
|
||||
) -> LogPreviewResponse:
|
||||
"""Proxy to an injectable log preview function."""
|
||||
if preview_fn is None:
|
||||
preview_fn = util_preview_log
|
||||
return await preview_fn(req)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -591,7 +601,7 @@ async def get_map_color_thresholds(db: aiosqlite.Connection) -> MapColorThreshol
|
||||
Returns:
|
||||
A :class:`MapColorThresholdsResponse` containing the three threshold values.
|
||||
"""
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
high, medium, low = await util_get_map_color_thresholds(db)
|
||||
return MapColorThresholdsResponse(
|
||||
threshold_high=high,
|
||||
threshold_medium=medium,
|
||||
@@ -612,7 +622,7 @@ async def update_map_color_thresholds(
|
||||
Raises:
|
||||
ValueError: If validation fails (thresholds must satisfy high > medium > low).
|
||||
"""
|
||||
await setup_service.set_map_color_thresholds(
|
||||
await util_set_map_color_thresholds(
|
||||
db,
|
||||
threshold_high=update.threshold_high,
|
||||
threshold_medium=update.threshold_medium,
|
||||
@@ -634,16 +644,7 @@ _SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log")
|
||||
|
||||
|
||||
def _count_file_lines(file_path: str) -> int:
|
||||
"""Count the total number of lines in *file_path* synchronously.
|
||||
|
||||
Uses a memory-efficient buffered read to avoid loading the whole file.
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to the file.
|
||||
|
||||
Returns:
|
||||
Total number of lines in the file.
|
||||
"""
|
||||
"""Count the total number of lines in *file_path* synchronously."""
|
||||
count = 0
|
||||
with open(file_path, "rb") as fh:
|
||||
for chunk in iter(lambda: fh.read(65536), b""):
|
||||
@@ -651,6 +652,32 @@ def _count_file_lines(file_path: str) -> int:
|
||||
return count
|
||||
|
||||
|
||||
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
||||
"""Read the last *num_lines* from *file_path* in a memory-efficient way."""
|
||||
chunk_size = 8192
|
||||
raw_lines: list[bytes] = []
|
||||
with open(file_path, "rb") as fh:
|
||||
fh.seek(0, 2)
|
||||
end_pos = fh.tell()
|
||||
if end_pos == 0:
|
||||
return []
|
||||
|
||||
buf = b""
|
||||
pos = end_pos
|
||||
while len(raw_lines) <= num_lines and pos > 0:
|
||||
read_size = min(chunk_size, pos)
|
||||
pos -= read_size
|
||||
fh.seek(pos)
|
||||
chunk = fh.read(read_size)
|
||||
buf = chunk + buf
|
||||
raw_lines = buf.split(b"\n")
|
||||
|
||||
if pos > 0 and len(raw_lines) > 1:
|
||||
raw_lines = raw_lines[1:]
|
||||
|
||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
||||
|
||||
|
||||
async def read_fail2ban_log(
|
||||
socket_path: str,
|
||||
lines: int,
|
||||
@@ -719,7 +746,7 @@ async def read_fail2ban_log(
|
||||
|
||||
total_lines, raw_lines = await asyncio.gather(
|
||||
loop.run_in_executor(None, _count_file_lines, resolved_str),
|
||||
loop.run_in_executor(None, log_service._read_tail_lines, resolved_str, lines),
|
||||
loop.run_in_executor(None, _read_tail_lines, resolved_str, lines),
|
||||
)
|
||||
|
||||
filtered = (
|
||||
@@ -745,22 +772,27 @@ async def read_fail2ban_log(
|
||||
)
|
||||
|
||||
|
||||
async def get_service_status(socket_path: str) -> ServiceStatusResponse:
|
||||
async def get_service_status(
|
||||
socket_path: str,
|
||||
probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None,
|
||||
) -> ServiceStatusResponse:
|
||||
"""Return fail2ban service health status with log configuration.
|
||||
|
||||
Delegates to :func:`~app.services.health_service.probe` for the core
|
||||
health snapshot and augments it with the current log-level and log-target
|
||||
values from the socket.
|
||||
Delegates to an injectable *probe_fn* (defaults to
|
||||
:func:`~app.services.health_service.probe`). This avoids direct service-to-
|
||||
service imports inside this module.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
probe_fn: Optional probe function.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.ServiceStatusResponse`.
|
||||
"""
|
||||
from app.services.health_service import probe # lazy import avoids circular dep
|
||||
if probe_fn is None:
|
||||
raise ValueError("probe_fn is required to avoid service-to-service coupling")
|
||||
|
||||
server_status = await probe(socket_path)
|
||||
server_status = await probe_fn(socket_path)
|
||||
|
||||
if server_status.online:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -25,15 +25,9 @@ from app.models.config import (
|
||||
FilterUpdateRequest,
|
||||
AssignFilterRequest,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.services import jail_service
|
||||
from app.services.config_file_service import (
|
||||
_parse_jails_sync,
|
||||
_get_active_jail_names,
|
||||
ConfigWriteError,
|
||||
JailNotFoundInConfigError,
|
||||
)
|
||||
from app.exceptions import FilterInvalidRegexError, JailNotFoundError
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.jail_utils import reload_jails
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -83,21 +77,6 @@ class FilterReadonlyError(Exception):
|
||||
)
|
||||
|
||||
|
||||
class FilterInvalidRegexError(Exception):
|
||||
"""Raised when a regex pattern fails to compile."""
|
||||
|
||||
def __init__(self, pattern: str, error: str) -> None:
|
||||
"""Initialise with the invalid pattern and the compile error.
|
||||
|
||||
Args:
|
||||
pattern: The regex string that failed to compile.
|
||||
error: The ``re.error`` message.
|
||||
"""
|
||||
self.pattern: str = pattern
|
||||
self.error: str = error
|
||||
super().__init__(f"Invalid regex {pattern!r}: {error}")
|
||||
|
||||
|
||||
class FilterNameError(Exception):
|
||||
"""Raised when a filter name contains invalid characters."""
|
||||
|
||||
@@ -723,7 +702,7 @@ async def update_filter(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_update_failed",
|
||||
@@ -798,7 +777,7 @@ async def create_filter(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_filter_create_failed",
|
||||
@@ -924,7 +903,7 @@ async def assign_filter_to_jail(
|
||||
|
||||
if do_reload:
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_jails(socket_path)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"reload_after_assign_filter_failed",
|
||||
|
||||
@@ -20,9 +20,7 @@ Usage::
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
from app.services import geo_service
|
||||
|
||||
# warm the cache from the persistent store at startup
|
||||
# Use the geo_service directly in application startup
|
||||
async with aiosqlite.connect("bangui.db") as db:
|
||||
await geo_service.load_cache_from_db(db)
|
||||
|
||||
|
||||
@@ -30,7 +30,13 @@ from app.models.config import (
|
||||
JailValidationResult,
|
||||
RollbackResponse,
|
||||
)
|
||||
from app.services import config_file_service, jail_service
|
||||
from app.utils.config_file_utils import (
|
||||
_build_inactive_jail,
|
||||
_ordered_config_files,
|
||||
_parse_jails_sync,
|
||||
_validate_jail_config_sync,
|
||||
)
|
||||
from app.utils.jail_utils import reload_jails
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
@@ -304,7 +310,7 @@ def _validate_regex_patterns(patterns: list[str]) -> None:
|
||||
re.compile(pattern)
|
||||
except re.error as exc:
|
||||
# Import here to avoid circular dependency
|
||||
from app.services.filter_config_service import FilterInvalidRegexError
|
||||
from app.exceptions import FilterInvalidRegexError
|
||||
raise FilterInvalidRegexError(pattern, str(exc)) from exc
|
||||
|
||||
|
||||
@@ -460,12 +466,7 @@ async def start_daemon(start_cmd_parts: list[str]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# Import shared functions from config_file_service
|
||||
_parse_jails_sync = config_file_service._parse_jails_sync
|
||||
_build_inactive_jail = config_file_service._build_inactive_jail
|
||||
_get_active_jail_names = config_file_service._get_active_jail_names
|
||||
_validate_jail_config_sync = config_file_service._validate_jail_config_sync
|
||||
_orderedconfig_files = config_file_service._ordered_config_files
|
||||
# Shared functions from config_file_service are imported from app.utils.config_file_utils
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -624,7 +625,7 @@ async def activate_jail(
|
||||
# Activation reload — if it fails, roll back immediately #
|
||||
# ---------------------------------------------------------------------- #
|
||||
try:
|
||||
await jail_service.reload_all(socket_path, include_jails=[name])
|
||||
await reload_jails(socket_path, include_jails=[name])
|
||||
except JailNotFoundError as exc:
|
||||
# Jail configuration is invalid (e.g. missing logpath that prevents
|
||||
# fail2ban from loading the jail). Roll back and provide a specific error.
|
||||
@@ -767,7 +768,7 @@ async def _rollback_activation_async(
|
||||
|
||||
# Step 2 — reload fail2ban with the restored config.
|
||||
try:
|
||||
await jail_service.reload_all(socket_path)
|
||||
await reload_jails(socket_path)
|
||||
log.info("jail_activation_rollback_reload_ok", jail=name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
|
||||
@@ -834,7 +835,7 @@ async def deactivate_jail(
|
||||
)
|
||||
|
||||
try:
|
||||
await jail_service.reload_all(socket_path, exclude_jails=[name])
|
||||
await reload_jails(socket_path, exclude_jails=[name])
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_deactivate_failed", jail=name, error=str(exc))
|
||||
|
||||
|
||||
@@ -102,30 +102,20 @@ async def run_setup(
|
||||
log.info("bangui_setup_completed")
|
||||
|
||||
|
||||
from app.utils.setup_utils import (
|
||||
get_map_color_thresholds as util_get_map_color_thresholds,
|
||||
get_password_hash as util_get_password_hash,
|
||||
set_map_color_thresholds as util_set_map_color_thresholds,
|
||||
)
|
||||
|
||||
|
||||
async def get_password_hash(db: aiosqlite.Connection) -> str | None:
|
||||
"""Return the stored bcrypt password hash, or ``None`` if not set.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
|
||||
Returns:
|
||||
The bcrypt hash string, or ``None``.
|
||||
"""
|
||||
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
|
||||
"""Return the stored bcrypt password hash, or ``None`` if not set."""
|
||||
return await util_get_password_hash(db)
|
||||
|
||||
|
||||
async def get_timezone(db: aiosqlite.Connection) -> str:
|
||||
"""Return the configured IANA timezone string.
|
||||
|
||||
Falls back to ``"UTC"`` when no timezone has been stored (e.g. before
|
||||
setup completes or for legacy databases).
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
|
||||
Returns:
|
||||
An IANA timezone identifier such as ``"Europe/Berlin"`` or ``"UTC"``.
|
||||
"""
|
||||
"""Return the configured IANA timezone string."""
|
||||
tz = await settings_repo.get_setting(db, _KEY_TIMEZONE)
|
||||
return tz if tz else "UTC"
|
||||
|
||||
@@ -133,31 +123,8 @@ async def get_timezone(db: aiosqlite.Connection) -> str:
|
||||
async def get_map_color_thresholds(
|
||||
db: aiosqlite.Connection,
|
||||
) -> tuple[int, int, int]:
|
||||
"""Return the configured map color thresholds (high, medium, low).
|
||||
|
||||
Falls back to default values (100, 50, 20) if not set.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
|
||||
Returns:
|
||||
A tuple of (threshold_high, threshold_medium, threshold_low).
|
||||
"""
|
||||
high = await settings_repo.get_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_HIGH
|
||||
)
|
||||
medium = await settings_repo.get_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM
|
||||
)
|
||||
low = await settings_repo.get_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_LOW
|
||||
)
|
||||
|
||||
return (
|
||||
int(high) if high else 100,
|
||||
int(medium) if medium else 50,
|
||||
int(low) if low else 20,
|
||||
)
|
||||
"""Return the configured map color thresholds (high, medium, low)."""
|
||||
return await util_get_map_color_thresholds(db)
|
||||
|
||||
|
||||
async def set_map_color_thresholds(
|
||||
@@ -167,31 +134,12 @@ async def set_map_color_thresholds(
|
||||
threshold_medium: int,
|
||||
threshold_low: int,
|
||||
) -> None:
|
||||
"""Update the map color threshold configuration.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
threshold_high: Ban count for red coloring.
|
||||
threshold_medium: Ban count for yellow coloring.
|
||||
threshold_low: Ban count for green coloring.
|
||||
|
||||
Raises:
|
||||
ValueError: If thresholds are not positive integers or if
|
||||
high <= medium <= low.
|
||||
"""
|
||||
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
|
||||
raise ValueError("All thresholds must be positive integers.")
|
||||
if not (threshold_high > threshold_medium > threshold_low):
|
||||
raise ValueError("Thresholds must satisfy: high > medium > low.")
|
||||
|
||||
await settings_repo.set_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high)
|
||||
)
|
||||
await settings_repo.set_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium)
|
||||
)
|
||||
await settings_repo.set_setting(
|
||||
db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low)
|
||||
"""Update the map color threshold configuration."""
|
||||
await util_set_map_color_thresholds(
|
||||
db,
|
||||
threshold_high=threshold_high,
|
||||
threshold_medium=threshold_medium,
|
||||
threshold_low=threshold_low,
|
||||
)
|
||||
log.info(
|
||||
"map_color_thresholds_updated",
|
||||
|
||||
@@ -43,9 +43,15 @@ async def _run_import(app: Any) -> None:
|
||||
http_session = app.state.http_session
|
||||
socket_path: str = app.state.settings.fail2ban_socket
|
||||
|
||||
from app.services import jail_service
|
||||
|
||||
log.info("blocklist_import_starting")
|
||||
try:
|
||||
result = await blocklist_service.import_all(db, http_session, socket_path)
|
||||
result = await blocklist_service.import_all(
|
||||
db,
|
||||
http_session,
|
||||
socket_path,
|
||||
)
|
||||
log.info(
|
||||
"blocklist_import_finished",
|
||||
total_imported=result.total_imported,
|
||||
|
||||
21
backend/app/utils/config_file_utils.py
Normal file
21
backend/app/utils/config_file_utils.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""Utilities re-exported from config_file_service for cross-module usage."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.config_file_service import (
|
||||
_build_inactive_jail,
|
||||
_get_active_jail_names,
|
||||
_ordered_config_files,
|
||||
_parse_jails_sync,
|
||||
_validate_jail_config_sync,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"_ordered_config_files",
|
||||
"_parse_jails_sync",
|
||||
"_build_inactive_jail",
|
||||
"_get_active_jail_names",
|
||||
"_validate_jail_config_sync",
|
||||
]
|
||||
20
backend/app/utils/jail_utils.py
Normal file
20
backend/app/utils/jail_utils.py
Normal file
@@ -0,0 +1,20 @@
|
||||
"""Jail helpers to decouple service layer dependencies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from app.services.jail_service import reload_all
|
||||
|
||||
|
||||
async def reload_jails(
|
||||
socket_path: str,
|
||||
include_jails: Sequence[str] | None = None,
|
||||
exclude_jails: Sequence[str] | None = None,
|
||||
) -> None:
|
||||
"""Reload fail2ban jails using shared jail service helper."""
|
||||
await reload_all(
|
||||
socket_path,
|
||||
include_jails=list(include_jails) if include_jails is not None else None,
|
||||
exclude_jails=list(exclude_jails) if exclude_jails is not None else None,
|
||||
)
|
||||
14
backend/app/utils/log_utils.py
Normal file
14
backend/app/utils/log_utils.py
Normal file
@@ -0,0 +1,14 @@
|
||||
"""Log-related helpers to avoid direct service-to-service imports."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.models.config import LogPreviewRequest, LogPreviewResponse, RegexTestRequest, RegexTestResponse
|
||||
from app.services.log_service import preview_log as _preview_log, test_regex as _test_regex
|
||||
|
||||
|
||||
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
|
||||
return await _preview_log(req)
|
||||
|
||||
|
||||
def test_regex(req: RegexTestRequest) -> RegexTestResponse:
|
||||
return _test_regex(req)
|
||||
47
backend/app/utils/setup_utils.py
Normal file
47
backend/app/utils/setup_utils.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Setup-related utilities shared by multiple services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.repositories import settings_repo
|
||||
|
||||
_KEY_PASSWORD_HASH = "master_password_hash"
|
||||
_KEY_SETUP_DONE = "setup_completed"
|
||||
_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high"
|
||||
_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
|
||||
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
|
||||
|
||||
|
||||
async def get_password_hash(db):
|
||||
"""Return the stored master password hash or None."""
|
||||
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
|
||||
|
||||
|
||||
async def get_map_color_thresholds(db):
|
||||
"""Return map color thresholds as tuple (high, medium, low)."""
|
||||
high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH)
|
||||
medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM)
|
||||
low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW)
|
||||
|
||||
return (
|
||||
int(high) if high else 100,
|
||||
int(medium) if medium else 50,
|
||||
int(low) if low else 20,
|
||||
)
|
||||
|
||||
|
||||
async def set_map_color_thresholds(
|
||||
db,
|
||||
*,
|
||||
threshold_high: int,
|
||||
threshold_medium: int,
|
||||
threshold_low: int,
|
||||
) -> None:
|
||||
"""Persist map color thresholds after validating values."""
|
||||
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
|
||||
raise ValueError("All thresholds must be positive integers.")
|
||||
if not (threshold_high > threshold_medium > threshold_low):
|
||||
raise ValueError("Thresholds must satisfy: high > medium > low.")
|
||||
|
||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high))
|
||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium))
|
||||
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low))
|
||||
@@ -203,9 +203,15 @@ class TestImport:
|
||||
call_count += 1
|
||||
raise JailNotFoundError(jail)
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
|
||||
with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found) as mocked_ban_ip:
|
||||
from app.services import jail_service
|
||||
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
# Must abort after the first JailNotFoundError — only one ban attempt.
|
||||
@@ -226,7 +232,14 @@ class TestImport:
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
):
|
||||
result = await blocklist_service.import_all(db, session, "/tmp/fake.sock")
|
||||
from app.services import jail_service
|
||||
|
||||
result = await blocklist_service.import_all(
|
||||
db,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
ban_ip=jail_service.ban_ip,
|
||||
)
|
||||
|
||||
# Only S1 is enabled, S2 is disabled.
|
||||
assert len(result.results) == 1
|
||||
|
||||
@@ -721,9 +721,11 @@ class TestGetServiceStatus:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(side_effect=_send)
|
||||
|
||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient), \
|
||||
patch("app.services.health_service.probe", AsyncMock(return_value=online_status)):
|
||||
result = await config_service.get_service_status(_SOCKET)
|
||||
with patch("app.services.config_service.Fail2BanClient", _FakeClient):
|
||||
result = await config_service.get_service_status(
|
||||
_SOCKET,
|
||||
probe_fn=AsyncMock(return_value=online_status),
|
||||
)
|
||||
|
||||
assert result.online is True
|
||||
assert result.version == "1.0.0"
|
||||
@@ -739,8 +741,10 @@ class TestGetServiceStatus:
|
||||
|
||||
offline_status = ServerStatus(online=False)
|
||||
|
||||
with patch("app.services.health_service.probe", AsyncMock(return_value=offline_status)):
|
||||
result = await config_service.get_service_status(_SOCKET)
|
||||
result = await config_service.get_service_status(
|
||||
_SOCKET,
|
||||
probe_fn=AsyncMock(return_value=offline_status),
|
||||
)
|
||||
|
||||
assert result.online is False
|
||||
assert result.jail_count == 0
|
||||
|
||||
88
frontend/src/hooks/__tests__/useConfigItem.test.ts
Normal file
88
frontend/src/hooks/__tests__/useConfigItem.test.ts
Normal file
@@ -0,0 +1,88 @@
|
||||
import { describe, it, expect, vi, beforeEach, afterEach } from "vitest";
|
||||
import { renderHook, act } from "@testing-library/react";
|
||||
import { useConfigItem } from "../useConfigItem";
|
||||
|
||||
describe("useConfigItem", () => {
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers();
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
vi.useRealTimers();
|
||||
vi.clearAllMocks();
|
||||
});
|
||||
|
||||
it("loads data and sets loading state", async () => {
|
||||
const fetchFn = vi.fn().mockResolvedValue("hello");
|
||||
const saveFn = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const { result } = renderHook(() => useConfigItem<string, string>({ fetchFn, saveFn }));
|
||||
|
||||
expect(result.current.loading).toBe(true);
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(fetchFn).toHaveBeenCalled();
|
||||
expect(result.current.data).toBe("hello");
|
||||
expect(result.current.loading).toBe(false);
|
||||
});
|
||||
|
||||
it("sets error if fetch rejects", async () => {
|
||||
const fetchFn = vi.fn().mockRejectedValue(new Error("nope"));
|
||||
const saveFn = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const { result } = renderHook(() => useConfigItem<string, string>({ fetchFn, saveFn }));
|
||||
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(result.current.error).toBe("nope");
|
||||
expect(result.current.loading).toBe(false);
|
||||
});
|
||||
|
||||
it("save updates data when mergeOnSave is provided", async () => {
|
||||
const fetchFn = vi.fn().mockResolvedValue({ value: 1 });
|
||||
const saveFn = vi.fn().mockResolvedValue(undefined);
|
||||
|
||||
const { result } = renderHook(() =>
|
||||
useConfigItem<{ value: number }, { delta: number }>({
|
||||
fetchFn,
|
||||
saveFn,
|
||||
mergeOnSave: (prev, update) =>
|
||||
prev ? { ...prev, value: prev.value + update.delta } : prev,
|
||||
})
|
||||
);
|
||||
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
expect(result.current.data).toEqual({ value: 1 });
|
||||
|
||||
await act(async () => {
|
||||
await result.current.save({ delta: 2 });
|
||||
});
|
||||
|
||||
expect(saveFn).toHaveBeenCalledWith({ delta: 2 });
|
||||
expect(result.current.data).toEqual({ value: 3 });
|
||||
});
|
||||
|
||||
it("saveError is set when save fails", async () => {
|
||||
const fetchFn = vi.fn().mockResolvedValue("ok");
|
||||
const saveFn = vi.fn().mockRejectedValue(new Error("save failed"));
|
||||
|
||||
const { result } = renderHook(() => useConfigItem<string, string>({ fetchFn, saveFn }));
|
||||
|
||||
await act(async () => {
|
||||
await Promise.resolve();
|
||||
});
|
||||
|
||||
await act(async () => {
|
||||
await expect(result.current.save("test")).rejects.toThrow("save failed");
|
||||
});
|
||||
|
||||
expect(result.current.saveError).toBe("save failed");
|
||||
});
|
||||
});
|
||||
@@ -2,7 +2,7 @@
|
||||
* React hook for loading and updating a single parsed action config.
|
||||
*/
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { useConfigItem } from "./useConfigItem";
|
||||
import { fetchAction, updateAction } from "../api/config";
|
||||
import type { ActionConfig, ActionConfigUpdate } from "../types/config";
|
||||
|
||||
@@ -23,67 +23,28 @@ export interface UseActionConfigResult {
|
||||
* @param name - Action base name (e.g. ``"iptables"``).
|
||||
*/
|
||||
export function useActionConfig(name: string): UseActionConfigResult {
|
||||
const [config, setConfig] = useState<ActionConfig | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [saveError, setSaveError] = useState<string | null>(null);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const { data, loading, error, saving, saveError, refresh, save } = useConfigItem<
|
||||
ActionConfig,
|
||||
ActionConfigUpdate
|
||||
>({
|
||||
fetchFn: () => fetchAction(name),
|
||||
saveFn: (update) => updateAction(name, update),
|
||||
mergeOnSave: (prev, update) =>
|
||||
prev
|
||||
? {
|
||||
...prev,
|
||||
...Object.fromEntries(Object.entries(update).filter(([, v]) => v != null)),
|
||||
}
|
||||
: prev,
|
||||
});
|
||||
|
||||
const load = useCallback((): void => {
|
||||
abortRef.current?.abort();
|
||||
const ctrl = new AbortController();
|
||||
abortRef.current = ctrl;
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
fetchAction(name)
|
||||
.then((data) => {
|
||||
if (!ctrl.signal.aborted) {
|
||||
setConfig(data);
|
||||
setLoading(false);
|
||||
}
|
||||
})
|
||||
.catch((err: unknown) => {
|
||||
if (!ctrl.signal.aborted) {
|
||||
setError(err instanceof Error ? err.message : "Failed to load action config");
|
||||
setLoading(false);
|
||||
}
|
||||
});
|
||||
}, [name]);
|
||||
|
||||
useEffect(() => {
|
||||
load();
|
||||
return (): void => {
|
||||
abortRef.current?.abort();
|
||||
};
|
||||
}, [load]);
|
||||
|
||||
const save = useCallback(
|
||||
async (update: ActionConfigUpdate): Promise<void> => {
|
||||
setSaving(true);
|
||||
setSaveError(null);
|
||||
try {
|
||||
await updateAction(name, update);
|
||||
setConfig((prev) =>
|
||||
prev
|
||||
? {
|
||||
...prev,
|
||||
...Object.fromEntries(
|
||||
Object.entries(update).filter(([, v]) => v !== null && v !== undefined)
|
||||
),
|
||||
}
|
||||
: prev
|
||||
);
|
||||
} catch (err: unknown) {
|
||||
setSaveError(err instanceof Error ? err.message : "Failed to save action config");
|
||||
throw err;
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
},
|
||||
[name]
|
||||
);
|
||||
|
||||
return { config, loading, error, saving, saveError, refresh: load, save };
|
||||
return {
|
||||
config: data,
|
||||
loading,
|
||||
error,
|
||||
saving,
|
||||
saveError,
|
||||
refresh,
|
||||
save,
|
||||
};
|
||||
}
|
||||
|
||||
84
frontend/src/hooks/useConfigItem.ts
Normal file
84
frontend/src/hooks/useConfigItem.ts
Normal file
@@ -0,0 +1,84 @@
|
||||
/**
|
||||
* Generic config hook for loading and saving a single entity.
|
||||
*/
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
|
||||
export interface UseConfigItemResult<T, U> {
|
||||
data: T | null;
|
||||
loading: boolean;
|
||||
error: string | null;
|
||||
saving: boolean;
|
||||
saveError: string | null;
|
||||
refresh: () => void;
|
||||
save: (update: U) => Promise<void>;
|
||||
}
|
||||
|
||||
export interface UseConfigItemOptions<T, U> {
|
||||
fetchFn: (signal: AbortSignal) => Promise<T>;
|
||||
saveFn: (update: U) => Promise<void>;
|
||||
mergeOnSave?: (prev: T | null, update: U) => T | null;
|
||||
}
|
||||
|
||||
export function useConfigItem<T, U>(
|
||||
options: UseConfigItemOptions<T, U>
|
||||
): UseConfigItemResult<T, U> {
|
||||
const { fetchFn, saveFn, mergeOnSave } = options;
|
||||
const [data, setData] = useState<T | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [saveError, setSaveError] = useState<string | null>(null);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
|
||||
const refresh = useCallback((): void => {
|
||||
abortRef.current?.abort();
|
||||
const controller = new AbortController();
|
||||
abortRef.current = controller;
|
||||
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
fetchFn(controller.signal)
|
||||
.then((nextData) => {
|
||||
if (controller.signal.aborted) return;
|
||||
setData(nextData);
|
||||
setLoading(false);
|
||||
})
|
||||
.catch((err: unknown) => {
|
||||
if (controller.signal.aborted) return;
|
||||
setError(err instanceof Error ? err.message : "Failed to load data");
|
||||
setLoading(false);
|
||||
});
|
||||
}, [fetchFn]);
|
||||
|
||||
useEffect(() => {
|
||||
refresh();
|
||||
|
||||
return () => {
|
||||
abortRef.current?.abort();
|
||||
};
|
||||
}, [refresh]);
|
||||
|
||||
const save = useCallback(
|
||||
async (update: U): Promise<void> => {
|
||||
setSaving(true);
|
||||
setSaveError(null);
|
||||
|
||||
try {
|
||||
await saveFn(update);
|
||||
if (mergeOnSave) {
|
||||
setData((prevData) => mergeOnSave(prevData, update));
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
const message = err instanceof Error ? err.message : "Failed to save data";
|
||||
setSaveError(message);
|
||||
throw err;
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
},
|
||||
[saveFn, mergeOnSave]
|
||||
);
|
||||
|
||||
return { data, loading, error, saving, saveError, refresh, save };
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
* React hook for loading and updating a single parsed filter config.
|
||||
*/
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { useConfigItem } from "./useConfigItem";
|
||||
import { fetchParsedFilter, updateParsedFilter } from "../api/config";
|
||||
import type { FilterConfig, FilterConfigUpdate } from "../types/config";
|
||||
|
||||
@@ -23,69 +23,28 @@ export interface UseFilterConfigResult {
|
||||
* @param name - Filter base name (e.g. ``"sshd"``).
|
||||
*/
|
||||
export function useFilterConfig(name: string): UseFilterConfigResult {
|
||||
const [config, setConfig] = useState<FilterConfig | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [saveError, setSaveError] = useState<string | null>(null);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const { data, loading, error, saving, saveError, refresh, save } = useConfigItem<
|
||||
FilterConfig,
|
||||
FilterConfigUpdate
|
||||
>({
|
||||
fetchFn: () => fetchParsedFilter(name),
|
||||
saveFn: (update) => updateParsedFilter(name, update),
|
||||
mergeOnSave: (prev, update) =>
|
||||
prev
|
||||
? {
|
||||
...prev,
|
||||
...Object.fromEntries(Object.entries(update).filter(([, v]) => v != null)),
|
||||
}
|
||||
: prev,
|
||||
});
|
||||
|
||||
const load = useCallback((): void => {
|
||||
abortRef.current?.abort();
|
||||
const ctrl = new AbortController();
|
||||
abortRef.current = ctrl;
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
fetchParsedFilter(name)
|
||||
.then((data) => {
|
||||
if (!ctrl.signal.aborted) {
|
||||
setConfig(data);
|
||||
setLoading(false);
|
||||
}
|
||||
})
|
||||
.catch((err: unknown) => {
|
||||
if (!ctrl.signal.aborted) {
|
||||
setError(err instanceof Error ? err.message : "Failed to load filter config");
|
||||
setLoading(false);
|
||||
}
|
||||
});
|
||||
}, [name]);
|
||||
|
||||
useEffect(() => {
|
||||
load();
|
||||
return (): void => {
|
||||
abortRef.current?.abort();
|
||||
};
|
||||
}, [load]);
|
||||
|
||||
const save = useCallback(
|
||||
async (update: FilterConfigUpdate): Promise<void> => {
|
||||
setSaving(true);
|
||||
setSaveError(null);
|
||||
try {
|
||||
await updateParsedFilter(name, update);
|
||||
// Optimistically update local state so the form reflects changes
|
||||
// without a full reload.
|
||||
setConfig((prev) =>
|
||||
prev
|
||||
? {
|
||||
...prev,
|
||||
...Object.fromEntries(
|
||||
Object.entries(update).filter(([, v]) => v !== null && v !== undefined)
|
||||
),
|
||||
}
|
||||
: prev
|
||||
);
|
||||
} catch (err: unknown) {
|
||||
setSaveError(err instanceof Error ? err.message : "Failed to save filter config");
|
||||
throw err;
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
},
|
||||
[name]
|
||||
);
|
||||
|
||||
return { config, loading, error, saving, saveError, refresh: load, save };
|
||||
return {
|
||||
config: data,
|
||||
loading,
|
||||
error,
|
||||
saving,
|
||||
saveError,
|
||||
refresh,
|
||||
save,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
* React hook for loading and updating a single parsed jail.d config file.
|
||||
*/
|
||||
|
||||
import { useCallback, useEffect, useRef, useState } from "react";
|
||||
import { useConfigItem } from "./useConfigItem";
|
||||
import { fetchParsedJailFile, updateParsedJailFile } from "../api/config";
|
||||
import type { JailFileConfig, JailFileConfigUpdate } from "../types/config";
|
||||
|
||||
@@ -21,56 +21,23 @@ export interface UseJailFileConfigResult {
|
||||
* @param filename - Filename including extension (e.g. ``"sshd.conf"``).
|
||||
*/
|
||||
export function useJailFileConfig(filename: string): UseJailFileConfigResult {
|
||||
const [config, setConfig] = useState<JailFileConfig | null>(null);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const abortRef = useRef<AbortController | null>(null);
|
||||
const { data, loading, error, refresh, save } = useConfigItem<
|
||||
JailFileConfig,
|
||||
JailFileConfigUpdate
|
||||
>({
|
||||
fetchFn: () => fetchParsedJailFile(filename),
|
||||
saveFn: (update) => updateParsedJailFile(filename, update),
|
||||
mergeOnSave: (prev, update) =>
|
||||
update.jails != null && prev
|
||||
? { ...prev, jails: { ...prev.jails, ...update.jails } }
|
||||
: prev,
|
||||
});
|
||||
|
||||
const load = useCallback((): void => {
|
||||
abortRef.current?.abort();
|
||||
const ctrl = new AbortController();
|
||||
abortRef.current = ctrl;
|
||||
setLoading(true);
|
||||
setError(null);
|
||||
|
||||
fetchParsedJailFile(filename)
|
||||
.then((data) => {
|
||||
if (!ctrl.signal.aborted) {
|
||||
setConfig(data);
|
||||
setLoading(false);
|
||||
}
|
||||
})
|
||||
.catch((err: unknown) => {
|
||||
if (!ctrl.signal.aborted) {
|
||||
setError(err instanceof Error ? err.message : "Failed to load jail file config");
|
||||
setLoading(false);
|
||||
}
|
||||
});
|
||||
}, [filename]);
|
||||
|
||||
useEffect(() => {
|
||||
load();
|
||||
return (): void => {
|
||||
abortRef.current?.abort();
|
||||
};
|
||||
}, [load]);
|
||||
|
||||
const save = useCallback(
|
||||
async (update: JailFileConfigUpdate): Promise<void> => {
|
||||
try {
|
||||
await updateParsedJailFile(filename, update);
|
||||
// Optimistically merge updated jails into local state.
|
||||
if (update.jails != null) {
|
||||
setConfig((prev) =>
|
||||
prev ? { ...prev, jails: { ...prev.jails, ...update.jails } } : prev
|
||||
);
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
throw err instanceof Error ? err : new Error("Failed to save jail file config");
|
||||
}
|
||||
},
|
||||
[filename]
|
||||
);
|
||||
|
||||
return { config, loading, error, refresh: load, save };
|
||||
return {
|
||||
config: data,
|
||||
loading,
|
||||
error,
|
||||
refresh,
|
||||
save,
|
||||
};
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user