From c21cf82e9e3e8ffc9c26e9f9ae530aa6ab00c094 Mon Sep 17 00:00:00 2001 From: Lukas Date: Fri, 17 Apr 2026 15:13:07 +0200 Subject: [PATCH] Refactor map color threshold storage into dedicated settings service --- Docs/Tasks.md | 2 + backend/app/routers/config_misc.py | 132 ++++++------ backend/app/services/config_service.py | 193 +----------------- backend/app/services/health_service.py | 128 ++++++++++-- backend/app/services/log_service.py | 169 ++++++++++++++- backend/app/services/protocols.py | 15 -- backend/app/services/settings_service.py | 89 ++++++++ backend/app/services/setup_service.py | 10 +- backend/app/utils/setup_utils.py | 34 --- backend/tests/test_regression_500s.py | 10 +- .../test_services/test_config_service.py | 34 +-- 11 files changed, 467 insertions(+), 349 deletions(-) create mode 100644 backend/app/services/settings_service.py diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 2469719..fcd2a14 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -138,6 +138,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. **Why this is needed:** Triplicated implementation violates DRY and means a change to the threshold schema must be made in three places. Using `setup_service` for ongoing runtime settings is conceptually wrong and misleads maintainers. +**Status:** Completed ✅ + --- ### Task 8 — Move the `activate_jail` 3-step restart workflow out of `config_misc.py` router diff --git a/backend/app/routers/config_misc.py b/backend/app/routers/config_misc.py index 167d9ca..a7a7d74 100644 --- a/backend/app/routers/config_misc.py +++ b/backend/app/routers/config_misc.py @@ -5,8 +5,17 @@ from typing import Annotated import structlog from fastapi import APIRouter, HTTPException, Query, Request, status -from app.dependencies import AuthDep, DbDep, Fail2BanSocketDep, Fail2BanStartCommandDep -from app.exceptions import ConfigOperationError, JailOperationError +from app.dependencies import ( + AuthDep, + DbDep, + Fail2BanSocketDep, + Fail2BanStartCommandDep, +) +from app.exceptions import ( + ConfigOperationError, + Fail2BanConnectionError, + JailOperationError, +) from app.models.config import ( Fail2BanLogResponse, GlobalConfigResponse, @@ -19,9 +28,12 @@ from app.models.config import ( RegexTestResponse, ServiceStatusResponse, ) -from app.services import config_service, jail_service, log_service, setup_service +from app.services import ( + config_service, + jail_service, + log_service, +) from app.utils.config_file_utils import start_daemon, wait_for_fail2ban -from app.exceptions import Fail2BanConnectionError log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -41,17 +53,20 @@ def _bad_request(message: str) -> HTTPException: detail=message, ) + @router.get( "/global", response_model=GlobalConfigResponse, summary="Return global fail2ban settings", ) async def get_global_config( - request: Request, + _request: Request, _auth: AuthDep, socket_path: Fail2BanSocketDep, ) -> GlobalConfigResponse: - """Return global fail2ban settings (log level, log target, database config). + """Return global fail2ban settings. + + Includes log level, log target, and database configuration. Args: request: Incoming request. @@ -69,15 +84,13 @@ async def get_global_config( raise _bad_gateway(exc) from exc - - @router.put( "/global", status_code=status.HTTP_204_NO_CONTENT, summary="Update global fail2ban settings", ) async def update_global_config( - request: Request, + _request: Request, _auth: AuthDep, socket_path: Fail2BanSocketDep, body: GlobalConfigUpdate, @@ -105,16 +118,13 @@ async def update_global_config( # Reload endpoint # --------------------------------------------------------------------------- - - - @router.post( "/reload", status_code=status.HTTP_204_NO_CONTENT, summary="Reload fail2ban to apply configuration changes", ) async def reload_fail2ban( - request: Request, + _request: Request, _auth: AuthDep, socket_path: Fail2BanSocketDep, ) -> None: @@ -144,16 +154,13 @@ async def reload_fail2ban( # Restart endpoint # --------------------------------------------------------------------------- - - - @router.post( "/restart", status_code=status.HTTP_204_NO_CONTENT, summary="Restart the fail2ban service", ) async def restart_fail2ban( - request: Request, + _request: Request, _auth: AuthDep, socket_path: Fail2BanSocketDep, start_cmd: Fail2BanStartCommandDep, @@ -175,8 +182,8 @@ async def restart_fail2ban( HTTPException: 503 when fail2ban does not come back online within 10 seconds after being started. Check the fail2ban log for initialisation errors. Use - ``POST /api/config/jails/{name}/rollback`` if a specific jail - is suspect. + ``POST /api/config/jails/{name}/rollback`` + if a specific jail is suspect. """ start_cmd_parts: list[str] = start_cmd.split() @@ -194,15 +201,21 @@ async def restart_fail2ban( # Step 2: start the daemon via subprocess. await start_daemon(start_cmd_parts) - # Step 3: probe the socket until fail2ban is responsive or the budget expires. - fail2ban_running: bool = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0) + # Step 3: probe the socket until fail2ban is responsive or the budget + # expires. + fail2ban_running: bool = await wait_for_fail2ban( + socket_path, + max_wait_seconds=10.0, + ) if not fail2ban_running: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail=( - "fail2ban was stopped but did not come back online within 10 seconds. " + "fail2ban was stopped but did not come back " + "online within 10 seconds. " "Check the fail2ban log for initialisation errors. " - "Use POST /api/config/jails/{name}/rollback if a specific jail is suspect." + "Use POST /api/config/jails/{name}/rollback if a " + "specific jail is suspect." ), ) log.info("fail2ban_restarted") @@ -212,9 +225,6 @@ async def restart_fail2ban( # Regex tester (stateless) # --------------------------------------------------------------------------- - - - @router.post( "/regex-test", response_model=RegexTestResponse, @@ -234,7 +244,8 @@ async def regex_test( body: Sample log line and regex pattern. Returns: - :class:`~app.models.config.RegexTestResponse` with match result and groups. + :class:`~app.models.config.RegexTestResponse` with match result and + groups. """ return log_service.test_regex(body) @@ -243,9 +254,6 @@ async def regex_test( # Log path management # --------------------------------------------------------------------------- - - - @router.post( "/preview-log", response_model=LogPreviewResponse, @@ -275,16 +283,13 @@ async def preview_log( # Map color thresholds # --------------------------------------------------------------------------- - - - @router.get( "/map-color-thresholds", response_model=MapColorThresholdsResponse, summary="Get map color threshold configuration", ) async def get_map_color_thresholds( - request: Request, + _request: Request, _auth: AuthDep, db: DbDep, ) -> MapColorThresholdsResponse: @@ -298,16 +303,7 @@ async def get_map_color_thresholds( :class:`~app.models.config.MapColorThresholdsResponse` with current thresholds. """ - high, medium, low = await setup_service.get_map_color_thresholds(db) - return MapColorThresholdsResponse( - threshold_high=high, - threshold_medium=medium, - threshold_low=low, - ) - - - - + return await config_service.get_map_color_thresholds(db) @router.put( "/map-color-thresholds", @@ -315,7 +311,7 @@ async def get_map_color_thresholds( summary="Update map color threshold configuration", ) async def update_map_color_thresholds( - request: Request, + _request: Request, _auth: AuthDep, db: DbDep, body: MapColorThresholdsUpdate, @@ -336,22 +332,11 @@ async def update_map_color_thresholds( properly ordered). """ try: - await setup_service.set_map_color_thresholds( - db, - threshold_high=body.threshold_high, - threshold_medium=body.threshold_medium, - threshold_low=body.threshold_low, - ) + await config_service.update_map_color_thresholds(db, body) except ValueError as exc: raise _bad_request(str(exc)) from exc - return MapColorThresholdsResponse( - threshold_high=body.threshold_high, - threshold_medium=body.threshold_medium, - threshold_low=body.threshold_low, - ) - - + return await config_service.get_map_color_thresholds(db) @router.get( @@ -360,13 +345,26 @@ async def update_map_color_thresholds( summary="Read the tail of the fail2ban daemon log file", ) async def get_fail2ban_log( - request: Request, + _request: Request, _auth: AuthDep, socket_path: Fail2BanSocketDep, - lines: Annotated[int, Query(ge=1, le=2000, description="Number of lines to return from the tail.")] = 200, - filter: Annotated[ # noqa: A002 + lines: Annotated[ + int, + Query( + ge=1, + le=2000, + description="Number of lines to return from the tail.", + ), + ] = 200, + filter_: Annotated[ # noqa: A002 str | None, - Query(description="Plain-text substring filter; only matching lines are returned."), + Query( + alias="filter", + description=( + "Plain-text substring filter; " + "only matching lines are returned." + ), + ), ] = None, ) -> Fail2BanLogResponse: """Return the tail of the fail2ban daemon log file. @@ -390,22 +388,20 @@ async def get_fail2ban_log( HTTPException: 502 when fail2ban is unreachable. """ try: - return await config_service.read_fail2ban_log(socket_path, lines, filter) + return await log_service.read_fail2ban_log(socket_path, lines, filter_) except ConfigOperationError as exc: raise _bad_request(str(exc)) from exc except Fail2BanConnectionError as exc: raise _bad_gateway(exc) from exc - - @router.get( "/service-status", response_model=ServiceStatusResponse, summary="Return fail2ban service health status with log configuration", ) async def get_service_status( - request: Request, + _request: Request, _auth: AuthDep, socket_path: Fail2BanSocketDep, ) -> ServiceStatusResponse: @@ -428,11 +424,9 @@ async def get_service_status( from app.services import health_service try: - return await config_service.get_service_status( + return await health_service.get_service_status( socket_path, probe_fn=health_service.probe, ) except Fail2BanConnectionError as exc: raise _bad_gateway(exc) from exc - - diff --git a/backend/app/services/config_service.py b/backend/app/services/config_service.py index a11965b..9b4ddd7 100644 --- a/backend/app/services/config_service.py +++ b/backend/app/services/config_service.py @@ -13,10 +13,8 @@ routers can serialise them directly. from __future__ import annotations import asyncio -from app.utils.async_utils import run_blocking import contextlib import re -from pathlib import Path from typing import TYPE_CHECKING, TypeVar, cast import structlog @@ -28,14 +26,10 @@ if TYPE_CHECKING: import aiosqlite -from app import __version__ from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError -from app.services.log_service import preview_log as util_preview_log -from app.services.log_service import test_regex as util_test_regex from app.models.config import ( AddLogPathRequest, BantimeEscalation, - Fail2BanLogResponse, GlobalConfigResponse, GlobalConfigUpdate, JailConfig, @@ -48,15 +42,16 @@ from app.models.config import ( MapColorThresholdsUpdate, RegexTestRequest, RegexTestResponse, - ServiceStatusResponse, ) -from app.utils.fail2ban_client import Fail2BanClient -from app.utils.setup_utils import ( +from app.services.log_service import preview_log as util_preview_log +from app.services.log_service import test_regex as util_test_regex +from app.services.settings_service import ( get_map_color_thresholds as util_get_map_color_thresholds, ) -from app.utils.setup_utils import ( +from app.services.settings_service import ( set_map_color_thresholds as util_set_map_color_thresholds, ) +from app.utils.fail2ban_client import Fail2BanClient log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -649,181 +644,3 @@ _NON_FILE_LOG_TARGETS: frozenset[str] = frozenset( _SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log") -def _count_file_lines(file_path: str) -> int: - """Count the total number of lines in *file_path* synchronously.""" - count = 0 - with open(file_path, "rb") as fh: - for chunk in iter(lambda: fh.read(65536), b""): - count += chunk.count(b"\n") - return count - - -def _read_tail_lines(file_path: str, num_lines: int) -> list[str]: - """Read the last *num_lines* from *file_path* in a memory-efficient way.""" - chunk_size = 8192 - raw_lines: list[bytes] = [] - with open(file_path, "rb") as fh: - fh.seek(0, 2) - end_pos = fh.tell() - if end_pos == 0: - return [] - - buf = b"" - pos = end_pos - while len(raw_lines) <= num_lines and pos > 0: - read_size = min(chunk_size, pos) - pos -= read_size - fh.seek(pos) - chunk = fh.read(read_size) - buf = chunk + buf - raw_lines = buf.split(b"\n") - - if pos > 0 and len(raw_lines) > 1: - raw_lines = raw_lines[1:] - - return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()] - - -async def read_fail2ban_log( - socket_path: str, - lines: int, - filter_text: str | None = None, -) -> Fail2BanLogResponse: - """Read the tail of the fail2ban daemon log file. - - Queries the fail2ban socket for the current log target and log level, - validates that the target is a readable file, then returns the last - *lines* entries optionally filtered by *filter_text*. - - Security: the resolved log path is rejected unless it starts with one of - the paths in :data:`_SAFE_LOG_PREFIXES`, preventing path traversal. - - Args: - socket_path: Path to the fail2ban Unix domain socket. - lines: Number of lines to return from the tail of the file (1–2000). - filter_text: Optional plain-text substring — only matching lines are - returned. Applied server-side; does not affect *total_lines*. - - Returns: - :class:`~app.models.config.Fail2BanLogResponse`. - - Raises: - ConfigOperationError: When the log target is not a file, when the - resolved path is outside the allowed directories, or when the - file cannot be read. - ~app.utils.fail2ban_client.Fail2BanConnectionError: Socket unreachable. - """ - client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) - - log_level_raw, log_target_raw = await asyncio.gather( - _safe_get_typed(client, ["get", "loglevel"], "INFO"), - _safe_get_typed(client, ["get", "logtarget"], "STDOUT"), - ) - - log_level = str(log_level_raw or "INFO").upper() - log_target = str(log_target_raw or "STDOUT") - - # Reject non-file targets up front. - if log_target.upper() in _NON_FILE_LOG_TARGETS: - raise ConfigOperationError( - f"fail2ban is logging to {log_target!r}. " - "File-based log viewing is only available when fail2ban logs to a file path." - ) - - # Resolve and validate (security: no path traversal outside safe dirs). - try: - resolved = Path(log_target).resolve() - except (ValueError, OSError) as exc: - raise ConfigOperationError( - f"Cannot resolve log target path {log_target!r}: {exc}" - ) from exc - - resolved_str = str(resolved) - if not any(resolved_str.startswith(safe) for safe in _SAFE_LOG_PREFIXES): - raise ConfigOperationError( - f"Log path {resolved_str!r} is outside the allowed directory. " - "Only paths under /var/log or /config/log are permitted." - ) - - if not resolved.is_file(): - raise ConfigOperationError(f"Log file not found: {resolved_str!r}") - - loop = asyncio.get_event_loop() - - total_lines, raw_lines = await asyncio.gather( - run_blocking( _count_file_lines, resolved_str), - run_blocking( _read_tail_lines, resolved_str, lines), - ) - - filtered = ( - [ln for ln in raw_lines if filter_text in ln] - if filter_text - else raw_lines - ) - - log.info( - "fail2ban_log_read", - log_path=resolved_str, - lines_requested=lines, - lines_returned=len(filtered), - filter_active=filter_text is not None, - ) - - return Fail2BanLogResponse( - log_path=resolved_str, - lines=filtered, - total_lines=total_lines, - log_level=log_level, - log_target=log_target, - ) - - -async def get_service_status( - socket_path: str, - probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None, -) -> ServiceStatusResponse: - """Return fail2ban service health status with log configuration. - - Delegates to an injectable *probe_fn* (defaults to - :func:`~app.services.health_service.probe`). This avoids direct service-to- - service imports inside this module. - - Args: - socket_path: Path to the fail2ban Unix domain socket. - probe_fn: Optional probe function. - - Returns: - :class:`~app.models.config.ServiceStatusResponse`. - """ - if probe_fn is None: - raise ValueError("probe_fn is required to avoid service-to-service coupling") - - server_status = await probe_fn(socket_path) - - if server_status.online: - client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) - log_level_raw, log_target_raw = await asyncio.gather( - _safe_get_typed(client, ["get", "loglevel"], "INFO"), - _safe_get_typed(client, ["get", "logtarget"], "STDOUT"), - ) - log_level = str(log_level_raw or "INFO").upper() - log_target = str(log_target_raw or "STDOUT") - else: - log_level = "UNKNOWN" - log_target = "UNKNOWN" - - log.info( - "service_status_fetched", - online=server_status.online, - jail_count=server_status.active_jails, - ) - - return ServiceStatusResponse( - online=server_status.online, - version=__version__, - jail_count=server_status.active_jails, - total_bans=server_status.total_bans, - total_failures=server_status.total_failures, - log_level=log_level, - log_target=log_target, - ) diff --git a/backend/app/services/health_service.py b/backend/app/services/health_service.py index b342f5f..5f46527 100644 --- a/backend/app/services/health_service.py +++ b/backend/app/services/health_service.py @@ -9,13 +9,18 @@ seconds by the background health-check task, not on every HTTP request. from __future__ import annotations -from typing import cast +import asyncio +from collections.abc import Awaitable, Callable +from typing import TypeVar, cast import structlog +from app import __version__ +from app.models.config import ServiceStatusResponse from app.models.server import ServerStatus from app.utils.fail2ban_client import ( Fail2BanClient, + Fail2BanCommand, Fail2BanConnectionError, Fail2BanProtocolError, Fail2BanResponse, @@ -49,7 +54,9 @@ def _ok(response: object) -> object: try: code, data = cast("Fail2BanResponse", response) except (TypeError, ValueError) as exc: - raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc + raise ValueError( + f"Unexpected fail2ban response shape: {response!r}" + ) from exc if code != 0: raise ValueError(f"fail2ban returned error code {code}: {data!r}") @@ -81,13 +88,101 @@ def _to_dict(pairs: object) -> dict[str, object]: return result +T = TypeVar("T") + + +async def _safe_get( + client: Fail2BanClient, + command: Fail2BanCommand, + default: object | None = None, +) -> object | None: + """Send a command and return *default* if it fails.""" + try: + return _ok(await client.send(command)) + except ( + Fail2BanConnectionError, + Fail2BanProtocolError, + ValueError, + OSError, + ): + return default + + +async def _safe_get_typed( + client: Fail2BanClient, + command: Fail2BanCommand, + default: T, +) -> T: + """Send a command and return the result typed as ``default``'s type.""" + return cast("T", await _safe_get(client, command, default)) + + +async def get_service_status( + socket_path: str, + probe_fn: Callable[[str], Awaitable[ServerStatus]] | None = None, +) -> ServiceStatusResponse: + """Return fail2ban service health status with log configuration. + + Delegates to an injectable *probe_fn* (defaults to + :func:`~app.services.health_service.probe`). + + Args: + socket_path: Path to the fail2ban Unix domain socket. + probe_fn: Optional probe function. + + Returns: + :class:`~app.models.config.ServiceStatusResponse`. + """ + if probe_fn is None: + raise ValueError( + "probe_fn is required to avoid service-to-service coupling" + ) + + server_status = await probe_fn(socket_path) + + if server_status.online: + client = Fail2BanClient( + socket_path=socket_path, + timeout=_SOCKET_TIMEOUT, + ) + log_level_raw, log_target_raw = await asyncio.gather( + _safe_get_typed(client, ["get", "loglevel"], "INFO"), + _safe_get_typed(client, ["get", "logtarget"], "STDOUT"), + ) + log_level = str(log_level_raw or "INFO").upper() + log_target = str(log_target_raw or "STDOUT") + else: + log_level = "UNKNOWN" + log_target = "UNKNOWN" + + log.info( + "service_status_fetched", + online=server_status.online, + jail_count=server_status.active_jails, + ) + + return ServiceStatusResponse( + online=server_status.online, + version=__version__, + jail_count=server_status.active_jails, + total_bans=server_status.total_bans, + total_failures=server_status.total_failures, + log_level=log_level, + log_target=log_target, + ) + + # --------------------------------------------------------------------------- # Public interface # --------------------------------------------------------------------------- -async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerStatus: - """Probe the fail2ban daemon and return a :class:`~app.models.server.ServerStatus`. +async def probe( + socket_path: str, + timeout: float = _SOCKET_TIMEOUT, +) -> ServerStatus: + """Probe the fail2ban daemon and return a + :class:`~app.models.server.ServerStatus`. Sends ``ping``, ``version``, ``status``, and per-jail ``status `` commands. Any socket or protocol error is caught and results in an @@ -109,11 +204,14 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta # ------------------------------------------------------------------ # ping_data = _ok(await client.send(["ping"])) if ping_data != "pong": - log.warning("fail2ban_unexpected_ping_response", response=ping_data) + log.warning( + "fail2ban_unexpected_ping_response", + response=ping_data, + ) return ServerStatus(online=False) # ------------------------------------------------------------------ # - # 2. Version # + # 2. Version # ------------------------------------------------------------------ # try: version: str | None = str(_ok(await client.send(["version"]))) @@ -125,7 +223,9 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta # ------------------------------------------------------------------ # status_data = _to_dict(_ok(await client.send(["status"]))) active_jails: int = int(str(status_data.get("Number of jail", 0) or 0)) - jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip() + jail_list_raw: str = str( + status_data.get("Jail list", "") or "" + ).strip() jail_names: list[str] = ( [j.strip() for j in jail_list_raw.split(",") if j.strip()] if jail_list_raw @@ -140,11 +240,17 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta for jail_name in jail_names: try: - jail_resp = _to_dict(_ok(await client.send(["status", jail_name]))) + jail_resp = _to_dict( + _ok(await client.send(["status", jail_name])) + ) filter_stats = _to_dict(jail_resp.get("Filter") or []) action_stats = _to_dict(jail_resp.get("Actions") or []) - total_failures += int(str(filter_stats.get("Currently failed", 0) or 0)) - total_bans += int(str(action_stats.get("Currently banned", 0) or 0)) + total_failures += int( + str(filter_stats.get("Currently failed", 0) or 0) + ) + total_bans += int( + str(action_stats.get("Currently banned", 0) or 0) + ) except (ValueError, TypeError, KeyError) as exc: log.warning( "fail2ban_jail_status_parse_error", @@ -174,5 +280,3 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta except ValueError as exc: log.error("fail2ban_probe_parse_error", error=str(exc)) return ServerStatus(online=False) - - diff --git a/backend/app/services/log_service.py b/backend/app/services/log_service.py index 7006e0e..be80de4 100644 --- a/backend/app/services/log_service.py +++ b/backend/app/services/log_service.py @@ -1,22 +1,168 @@ """Log helper service. -Contains regex test and log preview helpers that are independent of -fail2ban socket operations. +Contains regex test, log preview, and fail2ban log reading helpers. """ from __future__ import annotations -from app.utils.async_utils import run_blocking +import asyncio import re +import structlog from pathlib import Path +from typing import TypeVar, cast +from app.exceptions import ConfigOperationError from app.models.config import ( + Fail2BanLogResponse, LogPreviewLine, LogPreviewRequest, LogPreviewResponse, RegexTestRequest, RegexTestResponse, ) +from app.utils.async_utils import run_blocking +from app.utils.fail2ban_client import ( + Fail2BanClient, + Fail2BanConnectionError, + Fail2BanProtocolError, + Fail2BanResponse, +) + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + +_SOCKET_TIMEOUT: float = 10.0 + +_NON_FILE_LOG_TARGETS: frozenset[str] = frozenset( + {"STDOUT", "STDERR", "SYSLOG", "SYSTEMD-JOURNAL"} +) + +_SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log") + + +def _ok(response: object) -> object: + """Extract the payload from a fail2ban ``(return_code, data)`` response.""" + try: + code, data = cast(Fail2BanResponse, response) + except (TypeError, ValueError) as exc: + raise ValueError( + f"Unexpected fail2ban response shape: {response!r}" + ) from exc + + if code != 0: + raise ValueError(f"fail2ban returned error code {code}: {data!r}") + + return data + + +def _count_file_lines(file_path: str) -> int: + """Count the total number of lines in *file_path* synchronously.""" + count = 0 + with open(file_path, "rb") as fh: + for chunk in iter(lambda: fh.read(65536), b""): + count += chunk.count(b"\n") + return count + + +async def _safe_get( + client: Fail2BanClient, + command: list[str], + default: object | None = None, +) -> object | None: + """Send a command and return *default* if it fails.""" + try: + return _ok(await client.send(command)) + except ( + Fail2BanConnectionError, + Fail2BanProtocolError, + OSError, + ValueError, + ): + return default + + +T = TypeVar("T") + + +async def _safe_get_typed( + client: Fail2BanClient, + command: list[str], + default: T, +) -> T: + """Send a command and return the result typed as ``default``'s type.""" + return cast("T", await _safe_get(client, command, default)) + + +async def read_fail2ban_log( + socket_path: str, + lines: int, + filter_text: str | None = None, +) -> Fail2BanLogResponse: + """Read the tail of the fail2ban daemon log file. + + Queries the fail2ban socket for the current log target and log level, + validates that the target is a readable file, then returns the last + *lines* entries optionally filtered by *filter_text*. + """ + client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) + + log_level_raw, log_target_raw = await asyncio.gather( + _safe_get_typed(client, ["get", "loglevel"], "INFO"), + _safe_get_typed(client, ["get", "logtarget"], "STDOUT"), + ) + + log_level = str(log_level_raw or "INFO").upper() + log_target = str(log_target_raw or "STDOUT") + + if log_target.upper() in _NON_FILE_LOG_TARGETS: + raise ConfigOperationError( + f"fail2ban is logging to {log_target!r}. " + "File-based log viewing is only available when fail2ban logs " + "to a file path." + ) + + try: + resolved = Path(log_target).resolve() + except (ValueError, OSError) as exc: + raise ConfigOperationError( + f"Cannot resolve log target path {log_target!r}: {exc}" + ) from exc + + resolved_str = str(resolved) + if not any(resolved_str.startswith(safe) for safe in _SAFE_LOG_PREFIXES): + raise ConfigOperationError( + f"Log path {resolved_str!r} is outside the allowed directory. " + "Only paths under /var/log or /config/log are permitted." + ) + + if not resolved.is_file(): + raise ConfigOperationError(f"Log file not found: {resolved_str!r}") + + total_lines, raw_lines = await asyncio.gather( + run_blocking(_count_file_lines, resolved_str), + run_blocking(_read_tail_lines, resolved_str, lines), + ) + + filtered = ( + [ln for ln in raw_lines if filter_text in ln] + if filter_text + else raw_lines + ) + + log.info( + "fail2ban_log_read", + log_path=resolved_str, + lines_requested=lines, + lines_returned=len(filtered), + filter_active=filter_text is not None, + ) + + return Fail2BanLogResponse( + log_path=resolved_str, + lines=filtered, + total_lines=total_lines, + log_level=log_level, + log_target=log_target, + ) def test_regex(request: RegexTestRequest) -> RegexTestResponse: @@ -38,7 +184,10 @@ def test_regex(request: RegexTestRequest) -> RegexTestResponse: return RegexTestResponse(matched=False) groups: list[str] = list(match.groups() or []) - return RegexTestResponse(matched=True, groups=[str(g) for g in groups if g is not None]) + return RegexTestResponse( + matched=True, + groups=[str(g) for g in groups if g is not None], + ) async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse: @@ -87,7 +236,11 @@ async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse: matched_count = 0 for line in raw_lines: m = compiled.search(line) - groups = [str(g) for g in (m.groups() or []) if g is not None] if m else [] + groups = [ + str(g) + for g in (m.groups() or []) + if g is not None + ] if m else [] result_lines.append( LogPreviewLine(line=line, matched=(m is not None), groups=groups), ) @@ -124,4 +277,8 @@ def _read_tail_lines(file_path: str, num_lines: int) -> list[str]: if pos > 0 and len(raw_lines) > 1: raw_lines = raw_lines[1:] - return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()] + return [ + ln.decode("utf-8", errors="replace").rstrip() + for ln in raw_lines[-num_lines:] + if ln.strip() + ] diff --git a/backend/app/services/protocols.py b/backend/app/services/protocols.py index 0eac734..dd65795 100644 --- a/backend/app/services/protocols.py +++ b/backend/app/services/protocols.py @@ -284,21 +284,6 @@ class ConfigService(Protocol): ) -> None: ... - async def read_fail2ban_log( - self, - socket_path: str, - lines: int, - filter_text: str | None = None, - ) -> Fail2BanLogResponse: - ... - - async def get_service_status( - self, - socket_path: str, - probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None, - ) -> ServiceStatusResponse: - ... - @runtime_checkable class HistoryService(Protocol): diff --git a/backend/app/services/settings_service.py b/backend/app/services/settings_service.py new file mode 100644 index 0000000..8721c87 --- /dev/null +++ b/backend/app/services/settings_service.py @@ -0,0 +1,89 @@ +"""Shared settings persistence helpers. + +This service centralises storage and validation for application settings that are +shared between setup and runtime configuration workflows. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import structlog + +from app.repositories import settings_repo + +if TYPE_CHECKING: # pragma: no cover + import aiosqlite + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + +_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high" +_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium" +_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low" + +_DEFAULT_MAP_COLOR_THRESHOLD_HIGH = 100 +_DEFAULT_MAP_COLOR_THRESHOLD_MEDIUM = 50 +_DEFAULT_MAP_COLOR_THRESHOLD_LOW = 20 + + +async def get_map_color_thresholds( + db: aiosqlite.Connection, +) -> tuple[int, int, int]: + """Return map color thresholds from stored settings. + + Args: + db: Active aiosqlite connection. + + Returns: + A tuple of ``(high, medium, low)`` thresholds. + """ + high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH) + medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM) + low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW) + + return ( + int(high) if high else _DEFAULT_MAP_COLOR_THRESHOLD_HIGH, + int(medium) if medium else _DEFAULT_MAP_COLOR_THRESHOLD_MEDIUM, + int(low) if low else _DEFAULT_MAP_COLOR_THRESHOLD_LOW, + ) + + +async def set_map_color_thresholds( + db: aiosqlite.Connection, + *, + threshold_high: int, + threshold_medium: int, + threshold_low: int, +) -> None: + """Persist validated map color thresholds. + + Args: + db: Active aiosqlite connection. + threshold_high: High threshold value. + threshold_medium: Medium threshold value. + threshold_low: Low threshold value. + + Raises: + ValueError: If thresholds are non-positive or misordered. + """ + if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0: + raise ValueError("All thresholds must be positive integers.") + if not (threshold_high > threshold_medium > threshold_low): + raise ValueError("Thresholds must satisfy: high > medium > low.") + + await settings_repo.set_setting( + db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high) + ) + await settings_repo.set_setting( + db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium) + ) + await settings_repo.set_setting( + db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low) + ) + + log.info( + "map_color_thresholds_persisted", + high=threshold_high, + medium=threshold_medium, + low=threshold_low, + ) diff --git a/backend/app/services/setup_service.py b/backend/app/services/setup_service.py index 230c715..f529c56 100644 --- a/backend/app/services/setup_service.py +++ b/backend/app/services/setup_service.py @@ -15,16 +15,16 @@ import structlog from app.db import init_db, open_db from app.repositories import settings_repo -from app.utils.async_utils import run_blocking -from app.utils.setup_utils import ( +from app.services.settings_service import ( get_map_color_thresholds as util_get_map_color_thresholds, ) +from app.services.settings_service import ( + set_map_color_thresholds as util_set_map_color_thresholds, +) +from app.utils.async_utils import run_blocking from app.utils.setup_utils import ( get_password_hash as util_get_password_hash, ) -from app.utils.setup_utils import ( - set_map_color_thresholds as util_set_map_color_thresholds, -) if TYPE_CHECKING: import aiosqlite diff --git a/backend/app/utils/setup_utils.py b/backend/app/utils/setup_utils.py index 9fa6db3..e3b79c4 100644 --- a/backend/app/utils/setup_utils.py +++ b/backend/app/utils/setup_utils.py @@ -6,42 +6,8 @@ from app.repositories import settings_repo _KEY_PASSWORD_HASH = "master_password_hash" _KEY_SETUP_DONE = "setup_completed" -_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high" -_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium" -_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low" async def get_password_hash(db): """Return the stored master password hash or None.""" return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH) - - -async def get_map_color_thresholds(db): - """Return map color thresholds as tuple (high, medium, low).""" - high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH) - medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM) - low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW) - - return ( - int(high) if high else 100, - int(medium) if medium else 50, - int(low) if low else 20, - ) - - -async def set_map_color_thresholds( - db, - *, - threshold_high: int, - threshold_medium: int, - threshold_low: int, -) -> None: - """Persist map color thresholds after validating values.""" - if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0: - raise ValueError("All thresholds must be positive integers.") - if not (threshold_high > threshold_medium > threshold_low): - raise ValueError("Thresholds must satisfy: high > medium > low.") - - await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high)) - await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium)) - await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low)) diff --git a/backend/tests/test_regression_500s.py b/backend/tests/test_regression_500s.py index 079e7fe..efa82aa 100644 --- a/backend/tests/test_regression_500s.py +++ b/backend/tests/test_regression_500s.py @@ -227,7 +227,7 @@ class TestServiceStatusBanguiVersion: async def test_online_response_contains_bangui_version(self) -> None: """The returned model must contain the ``bangui_version`` field.""" from app.models.server import ServerStatus - from app.services import config_service + from app.services import health_service import app online_status = ServerStatus( @@ -250,8 +250,8 @@ class TestServiceStatusBanguiVersion: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) - with patch("app.services.config_service.Fail2BanClient", _FakeClient): - result = await config_service.get_service_status( + with patch("app.services.health_service.Fail2BanClient", _FakeClient): + result = await health_service.get_service_status( "/fake/socket", probe_fn=AsyncMock(return_value=online_status), ) @@ -263,12 +263,12 @@ class TestServiceStatusBanguiVersion: async def test_offline_response_contains_bangui_version(self) -> None: """Even when fail2ban is offline, ``bangui_version`` must be present.""" from app.models.server import ServerStatus - from app.services import config_service + from app.services import health_service import app offline_status = ServerStatus(online=False) - result = await config_service.get_service_status( + result = await health_service.get_service_status( "/fake/socket", probe_fn=AsyncMock(return_value=offline_status), ) diff --git a/backend/tests/test_services/test_config_service.py b/backend/tests/test_services/test_config_service.py index 521bf47..9af33f7 100644 --- a/backend/tests/test_services/test_config_service.py +++ b/backend/tests/test_services/test_config_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +from contextlib import ExitStack from pathlib import Path from typing import Any from unittest.mock import AsyncMock, patch @@ -15,7 +16,7 @@ from app.models.config import ( LogPreviewRequest, RegexTestRequest, ) -from app.services import config_service +from app.services import config_service, health_service, log_service from app.services.config_service import ( ConfigValidationError, JailNotFoundError, @@ -650,7 +651,10 @@ class TestReadFail2BanLog: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) - return patch("app.services.config_service.Fail2BanClient", _FakeClient) + stack = ExitStack() + stack.enter_context(patch("app.services.config_service.Fail2BanClient", _FakeClient)) + stack.enter_context(patch("app.services.log_service.Fail2BanClient", _FakeClient)) + return stack async def test_returns_log_lines_from_file(self, tmp_path: Any) -> None: """read_fail2ban_log returns lines from the file and counts totals.""" @@ -660,8 +664,8 @@ class TestReadFail2BanLog: # Patch _SAFE_LOG_PREFIXES to allow tmp_path with self._patch_client(log_target=str(log_file)), \ - patch("app.services.config_service._SAFE_LOG_PREFIXES", (log_dir,)): - result = await config_service.read_fail2ban_log(_SOCKET, 200) + patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)): + result = await log_service.read_fail2ban_log(_SOCKET, 200) assert result.log_path == str(log_file.resolve()) assert result.total_lines >= 3 @@ -675,8 +679,8 @@ class TestReadFail2BanLog: log_dir = str(tmp_path) with self._patch_client(log_target=str(log_file)), \ - patch("app.services.config_service._SAFE_LOG_PREFIXES", (log_dir,)): - result = await config_service.read_fail2ban_log(_SOCKET, 200, "Found") + patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)): + result = await log_service.read_fail2ban_log(_SOCKET, 200, "Found") assert all("Found" in ln for ln in result.lines) assert result.total_lines >= 3 # total is unfiltered @@ -685,13 +689,13 @@ class TestReadFail2BanLog: """read_fail2ban_log raises ConfigOperationError for STDOUT target.""" with self._patch_client(log_target="STDOUT"), \ pytest.raises(config_service.ConfigOperationError, match="STDOUT"): - await config_service.read_fail2ban_log(_SOCKET, 200) + await log_service.read_fail2ban_log(_SOCKET, 200) async def test_syslog_target_raises_operation_error(self) -> None: """read_fail2ban_log raises ConfigOperationError for SYSLOG target.""" with self._patch_client(log_target="SYSLOG"), \ pytest.raises(config_service.ConfigOperationError, match="SYSLOG"): - await config_service.read_fail2ban_log(_SOCKET, 200) + await log_service.read_fail2ban_log(_SOCKET, 200) async def test_path_outside_safe_dir_raises_operation_error(self, tmp_path: Any) -> None: """read_fail2ban_log rejects a log_target outside allowed directories.""" @@ -700,9 +704,9 @@ class TestReadFail2BanLog: # Allow only /var/log — tmp_path is deliberately not in the safe list. with self._patch_client(log_target=str(log_file)), \ - patch("app.services.config_service._SAFE_LOG_PREFIXES", ("/var/log",)), \ + patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)), \ pytest.raises(config_service.ConfigOperationError, match="outside the allowed"): - await config_service.read_fail2ban_log(_SOCKET, 200) + await log_service.read_fail2ban_log(_SOCKET, 200) async def test_missing_log_file_raises_operation_error(self, tmp_path: Any) -> None: """read_fail2ban_log raises ConfigOperationError when the file does not exist.""" @@ -710,9 +714,9 @@ class TestReadFail2BanLog: log_dir = str(tmp_path) with self._patch_client(log_target=missing), \ - patch("app.services.config_service._SAFE_LOG_PREFIXES", (log_dir,)), \ + patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)), \ pytest.raises(config_service.ConfigOperationError, match="not found"): - await config_service.read_fail2ban_log(_SOCKET, 200) + await log_service.read_fail2ban_log(_SOCKET, 200) # --------------------------------------------------------------------------- @@ -743,8 +747,8 @@ class TestGetServiceStatus: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) - with patch("app.services.config_service.Fail2BanClient", _FakeClient): - result = await config_service.get_service_status( + with patch("app.services.health_service.Fail2BanClient", _FakeClient): + result = await health_service.get_service_status( _SOCKET, probe_fn=AsyncMock(return_value=online_status), ) @@ -765,7 +769,7 @@ class TestGetServiceStatus: offline_status = ServerStatus(online=False) - result = await config_service.get_service_status( + result = await health_service.get_service_status( _SOCKET, probe_fn=AsyncMock(return_value=offline_status), )