Refactor map color threshold storage into dedicated settings service

This commit is contained in:
2026-04-17 15:13:07 +02:00
parent 13b3fde274
commit c21cf82e9e
11 changed files with 467 additions and 349 deletions

View File

@@ -138,6 +138,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
**Why this is needed:** Triplicated implementation violates DRY and means a change to the threshold schema must be made in three places. Using `setup_service` for ongoing runtime settings is conceptually wrong and misleads maintainers. **Why this is needed:** Triplicated implementation violates DRY and means a change to the threshold schema must be made in three places. Using `setup_service` for ongoing runtime settings is conceptually wrong and misleads maintainers.
**Status:** Completed ✅
--- ---
### Task 8 — Move the `activate_jail` 3-step restart workflow out of `config_misc.py` router ### Task 8 — Move the `activate_jail` 3-step restart workflow out of `config_misc.py` router

View File

@@ -5,8 +5,17 @@ from typing import Annotated
import structlog import structlog
from fastapi import APIRouter, HTTPException, Query, Request, status from fastapi import APIRouter, HTTPException, Query, Request, status
from app.dependencies import AuthDep, DbDep, Fail2BanSocketDep, Fail2BanStartCommandDep from app.dependencies import (
from app.exceptions import ConfigOperationError, JailOperationError AuthDep,
DbDep,
Fail2BanSocketDep,
Fail2BanStartCommandDep,
)
from app.exceptions import (
ConfigOperationError,
Fail2BanConnectionError,
JailOperationError,
)
from app.models.config import ( from app.models.config import (
Fail2BanLogResponse, Fail2BanLogResponse,
GlobalConfigResponse, GlobalConfigResponse,
@@ -19,9 +28,12 @@ from app.models.config import (
RegexTestResponse, RegexTestResponse,
ServiceStatusResponse, ServiceStatusResponse,
) )
from app.services import config_service, jail_service, log_service, setup_service from app.services import (
config_service,
jail_service,
log_service,
)
from app.utils.config_file_utils import start_daemon, wait_for_fail2ban from app.utils.config_file_utils import start_daemon, wait_for_fail2ban
from app.exceptions import Fail2BanConnectionError
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -41,17 +53,20 @@ def _bad_request(message: str) -> HTTPException:
detail=message, detail=message,
) )
@router.get( @router.get(
"/global", "/global",
response_model=GlobalConfigResponse, response_model=GlobalConfigResponse,
summary="Return global fail2ban settings", summary="Return global fail2ban settings",
) )
async def get_global_config( async def get_global_config(
request: Request, _request: Request,
_auth: AuthDep, _auth: AuthDep,
socket_path: Fail2BanSocketDep, socket_path: Fail2BanSocketDep,
) -> GlobalConfigResponse: ) -> GlobalConfigResponse:
"""Return global fail2ban settings (log level, log target, database config). """Return global fail2ban settings.
Includes log level, log target, and database configuration.
Args: Args:
request: Incoming request. request: Incoming request.
@@ -69,15 +84,13 @@ async def get_global_config(
raise _bad_gateway(exc) from exc raise _bad_gateway(exc) from exc
@router.put( @router.put(
"/global", "/global",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
summary="Update global fail2ban settings", summary="Update global fail2ban settings",
) )
async def update_global_config( async def update_global_config(
request: Request, _request: Request,
_auth: AuthDep, _auth: AuthDep,
socket_path: Fail2BanSocketDep, socket_path: Fail2BanSocketDep,
body: GlobalConfigUpdate, body: GlobalConfigUpdate,
@@ -105,16 +118,13 @@ async def update_global_config(
# Reload endpoint # Reload endpoint
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post( @router.post(
"/reload", "/reload",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
summary="Reload fail2ban to apply configuration changes", summary="Reload fail2ban to apply configuration changes",
) )
async def reload_fail2ban( async def reload_fail2ban(
request: Request, _request: Request,
_auth: AuthDep, _auth: AuthDep,
socket_path: Fail2BanSocketDep, socket_path: Fail2BanSocketDep,
) -> None: ) -> None:
@@ -144,16 +154,13 @@ async def reload_fail2ban(
# Restart endpoint # Restart endpoint
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post( @router.post(
"/restart", "/restart",
status_code=status.HTTP_204_NO_CONTENT, status_code=status.HTTP_204_NO_CONTENT,
summary="Restart the fail2ban service", summary="Restart the fail2ban service",
) )
async def restart_fail2ban( async def restart_fail2ban(
request: Request, _request: Request,
_auth: AuthDep, _auth: AuthDep,
socket_path: Fail2BanSocketDep, socket_path: Fail2BanSocketDep,
start_cmd: Fail2BanStartCommandDep, start_cmd: Fail2BanStartCommandDep,
@@ -175,8 +182,8 @@ async def restart_fail2ban(
HTTPException: 503 when fail2ban does not come back online within HTTPException: 503 when fail2ban does not come back online within
10 seconds after being started. Check the fail2ban log for 10 seconds after being started. Check the fail2ban log for
initialisation errors. Use initialisation errors. Use
``POST /api/config/jails/{name}/rollback`` if a specific jail ``POST /api/config/jails/{name}/rollback``
is suspect. if a specific jail is suspect.
""" """
start_cmd_parts: list[str] = start_cmd.split() start_cmd_parts: list[str] = start_cmd.split()
@@ -194,15 +201,21 @@ async def restart_fail2ban(
# Step 2: start the daemon via subprocess. # Step 2: start the daemon via subprocess.
await start_daemon(start_cmd_parts) await start_daemon(start_cmd_parts)
# Step 3: probe the socket until fail2ban is responsive or the budget expires. # Step 3: probe the socket until fail2ban is responsive or the budget
fail2ban_running: bool = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0) # expires.
fail2ban_running: bool = await wait_for_fail2ban(
socket_path,
max_wait_seconds=10.0,
)
if not fail2ban_running: if not fail2ban_running:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail=( detail=(
"fail2ban was stopped but did not come back online within 10 seconds. " "fail2ban was stopped but did not come back "
"online within 10 seconds. "
"Check the fail2ban log for initialisation errors. " "Check the fail2ban log for initialisation errors. "
"Use POST /api/config/jails/{name}/rollback if a specific jail is suspect." "Use POST /api/config/jails/{name}/rollback if a "
"specific jail is suspect."
), ),
) )
log.info("fail2ban_restarted") log.info("fail2ban_restarted")
@@ -212,9 +225,6 @@ async def restart_fail2ban(
# Regex tester (stateless) # Regex tester (stateless)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post( @router.post(
"/regex-test", "/regex-test",
response_model=RegexTestResponse, response_model=RegexTestResponse,
@@ -234,7 +244,8 @@ async def regex_test(
body: Sample log line and regex pattern. body: Sample log line and regex pattern.
Returns: Returns:
:class:`~app.models.config.RegexTestResponse` with match result and groups. :class:`~app.models.config.RegexTestResponse` with match result and
groups.
""" """
return log_service.test_regex(body) return log_service.test_regex(body)
@@ -243,9 +254,6 @@ async def regex_test(
# Log path management # Log path management
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.post( @router.post(
"/preview-log", "/preview-log",
response_model=LogPreviewResponse, response_model=LogPreviewResponse,
@@ -275,16 +283,13 @@ async def preview_log(
# Map color thresholds # Map color thresholds
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@router.get( @router.get(
"/map-color-thresholds", "/map-color-thresholds",
response_model=MapColorThresholdsResponse, response_model=MapColorThresholdsResponse,
summary="Get map color threshold configuration", summary="Get map color threshold configuration",
) )
async def get_map_color_thresholds( async def get_map_color_thresholds(
request: Request, _request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep, db: DbDep,
) -> MapColorThresholdsResponse: ) -> MapColorThresholdsResponse:
@@ -298,16 +303,7 @@ async def get_map_color_thresholds(
:class:`~app.models.config.MapColorThresholdsResponse` with :class:`~app.models.config.MapColorThresholdsResponse` with
current thresholds. current thresholds.
""" """
high, medium, low = await setup_service.get_map_color_thresholds(db) return await config_service.get_map_color_thresholds(db)
return MapColorThresholdsResponse(
threshold_high=high,
threshold_medium=medium,
threshold_low=low,
)
@router.put( @router.put(
"/map-color-thresholds", "/map-color-thresholds",
@@ -315,7 +311,7 @@ async def get_map_color_thresholds(
summary="Update map color threshold configuration", summary="Update map color threshold configuration",
) )
async def update_map_color_thresholds( async def update_map_color_thresholds(
request: Request, _request: Request,
_auth: AuthDep, _auth: AuthDep,
db: DbDep, db: DbDep,
body: MapColorThresholdsUpdate, body: MapColorThresholdsUpdate,
@@ -336,22 +332,11 @@ async def update_map_color_thresholds(
properly ordered). properly ordered).
""" """
try: try:
await setup_service.set_map_color_thresholds( await config_service.update_map_color_thresholds(db, body)
db,
threshold_high=body.threshold_high,
threshold_medium=body.threshold_medium,
threshold_low=body.threshold_low,
)
except ValueError as exc: except ValueError as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
return MapColorThresholdsResponse( return await config_service.get_map_color_thresholds(db)
threshold_high=body.threshold_high,
threshold_medium=body.threshold_medium,
threshold_low=body.threshold_low,
)
@router.get( @router.get(
@@ -360,13 +345,26 @@ async def update_map_color_thresholds(
summary="Read the tail of the fail2ban daemon log file", summary="Read the tail of the fail2ban daemon log file",
) )
async def get_fail2ban_log( async def get_fail2ban_log(
request: Request, _request: Request,
_auth: AuthDep, _auth: AuthDep,
socket_path: Fail2BanSocketDep, socket_path: Fail2BanSocketDep,
lines: Annotated[int, Query(ge=1, le=2000, description="Number of lines to return from the tail.")] = 200, lines: Annotated[
filter: Annotated[ # noqa: A002 int,
Query(
ge=1,
le=2000,
description="Number of lines to return from the tail.",
),
] = 200,
filter_: Annotated[ # noqa: A002
str | None, str | None,
Query(description="Plain-text substring filter; only matching lines are returned."), Query(
alias="filter",
description=(
"Plain-text substring filter; "
"only matching lines are returned."
),
),
] = None, ] = None,
) -> Fail2BanLogResponse: ) -> Fail2BanLogResponse:
"""Return the tail of the fail2ban daemon log file. """Return the tail of the fail2ban daemon log file.
@@ -390,22 +388,20 @@ async def get_fail2ban_log(
HTTPException: 502 when fail2ban is unreachable. HTTPException: 502 when fail2ban is unreachable.
""" """
try: try:
return await config_service.read_fail2ban_log(socket_path, lines, filter) return await log_service.read_fail2ban_log(socket_path, lines, filter_)
except ConfigOperationError as exc: except ConfigOperationError as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except Fail2BanConnectionError as exc: except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc raise _bad_gateway(exc) from exc
@router.get( @router.get(
"/service-status", "/service-status",
response_model=ServiceStatusResponse, response_model=ServiceStatusResponse,
summary="Return fail2ban service health status with log configuration", summary="Return fail2ban service health status with log configuration",
) )
async def get_service_status( async def get_service_status(
request: Request, _request: Request,
_auth: AuthDep, _auth: AuthDep,
socket_path: Fail2BanSocketDep, socket_path: Fail2BanSocketDep,
) -> ServiceStatusResponse: ) -> ServiceStatusResponse:
@@ -428,11 +424,9 @@ async def get_service_status(
from app.services import health_service from app.services import health_service
try: try:
return await config_service.get_service_status( return await health_service.get_service_status(
socket_path, socket_path,
probe_fn=health_service.probe, probe_fn=health_service.probe,
) )
except Fail2BanConnectionError as exc: except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc raise _bad_gateway(exc) from exc

View File

@@ -13,10 +13,8 @@ routers can serialise them directly.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from app.utils.async_utils import run_blocking
import contextlib import contextlib
import re import re
from pathlib import Path
from typing import TYPE_CHECKING, TypeVar, cast from typing import TYPE_CHECKING, TypeVar, cast
import structlog import structlog
@@ -28,14 +26,10 @@ if TYPE_CHECKING:
import aiosqlite import aiosqlite
from app import __version__
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
from app.services.log_service import preview_log as util_preview_log
from app.services.log_service import test_regex as util_test_regex
from app.models.config import ( from app.models.config import (
AddLogPathRequest, AddLogPathRequest,
BantimeEscalation, BantimeEscalation,
Fail2BanLogResponse,
GlobalConfigResponse, GlobalConfigResponse,
GlobalConfigUpdate, GlobalConfigUpdate,
JailConfig, JailConfig,
@@ -48,15 +42,16 @@ from app.models.config import (
MapColorThresholdsUpdate, MapColorThresholdsUpdate,
RegexTestRequest, RegexTestRequest,
RegexTestResponse, RegexTestResponse,
ServiceStatusResponse,
) )
from app.utils.fail2ban_client import Fail2BanClient from app.services.log_service import preview_log as util_preview_log
from app.utils.setup_utils import ( from app.services.log_service import test_regex as util_test_regex
from app.services.settings_service import (
get_map_color_thresholds as util_get_map_color_thresholds, get_map_color_thresholds as util_get_map_color_thresholds,
) )
from app.utils.setup_utils import ( from app.services.settings_service import (
set_map_color_thresholds as util_set_map_color_thresholds, set_map_color_thresholds as util_set_map_color_thresholds,
) )
from app.utils.fail2ban_client import Fail2BanClient
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -649,181 +644,3 @@ _NON_FILE_LOG_TARGETS: frozenset[str] = frozenset(
_SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log") _SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log")
def _count_file_lines(file_path: str) -> int:
"""Count the total number of lines in *file_path* synchronously."""
count = 0
with open(file_path, "rb") as fh:
for chunk in iter(lambda: fh.read(65536), b""):
count += chunk.count(b"\n")
return count
def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
"""Read the last *num_lines* from *file_path* in a memory-efficient way."""
chunk_size = 8192
raw_lines: list[bytes] = []
with open(file_path, "rb") as fh:
fh.seek(0, 2)
end_pos = fh.tell()
if end_pos == 0:
return []
buf = b""
pos = end_pos
while len(raw_lines) <= num_lines and pos > 0:
read_size = min(chunk_size, pos)
pos -= read_size
fh.seek(pos)
chunk = fh.read(read_size)
buf = chunk + buf
raw_lines = buf.split(b"\n")
if pos > 0 and len(raw_lines) > 1:
raw_lines = raw_lines[1:]
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
async def read_fail2ban_log(
socket_path: str,
lines: int,
filter_text: str | None = None,
) -> Fail2BanLogResponse:
"""Read the tail of the fail2ban daemon log file.
Queries the fail2ban socket for the current log target and log level,
validates that the target is a readable file, then returns the last
*lines* entries optionally filtered by *filter_text*.
Security: the resolved log path is rejected unless it starts with one of
the paths in :data:`_SAFE_LOG_PREFIXES`, preventing path traversal.
Args:
socket_path: Path to the fail2ban Unix domain socket.
lines: Number of lines to return from the tail of the file (12000).
filter_text: Optional plain-text substring — only matching lines are
returned. Applied server-side; does not affect *total_lines*.
Returns:
:class:`~app.models.config.Fail2BanLogResponse`.
Raises:
ConfigOperationError: When the log target is not a file, when the
resolved path is outside the allowed directories, or when the
file cannot be read.
~app.utils.fail2ban_client.Fail2BanConnectionError: Socket unreachable.
"""
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
log_level_raw, log_target_raw = await asyncio.gather(
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
)
log_level = str(log_level_raw or "INFO").upper()
log_target = str(log_target_raw or "STDOUT")
# Reject non-file targets up front.
if log_target.upper() in _NON_FILE_LOG_TARGETS:
raise ConfigOperationError(
f"fail2ban is logging to {log_target!r}. "
"File-based log viewing is only available when fail2ban logs to a file path."
)
# Resolve and validate (security: no path traversal outside safe dirs).
try:
resolved = Path(log_target).resolve()
except (ValueError, OSError) as exc:
raise ConfigOperationError(
f"Cannot resolve log target path {log_target!r}: {exc}"
) from exc
resolved_str = str(resolved)
if not any(resolved_str.startswith(safe) for safe in _SAFE_LOG_PREFIXES):
raise ConfigOperationError(
f"Log path {resolved_str!r} is outside the allowed directory. "
"Only paths under /var/log or /config/log are permitted."
)
if not resolved.is_file():
raise ConfigOperationError(f"Log file not found: {resolved_str!r}")
loop = asyncio.get_event_loop()
total_lines, raw_lines = await asyncio.gather(
run_blocking( _count_file_lines, resolved_str),
run_blocking( _read_tail_lines, resolved_str, lines),
)
filtered = (
[ln for ln in raw_lines if filter_text in ln]
if filter_text
else raw_lines
)
log.info(
"fail2ban_log_read",
log_path=resolved_str,
lines_requested=lines,
lines_returned=len(filtered),
filter_active=filter_text is not None,
)
return Fail2BanLogResponse(
log_path=resolved_str,
lines=filtered,
total_lines=total_lines,
log_level=log_level,
log_target=log_target,
)
async def get_service_status(
socket_path: str,
probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None,
) -> ServiceStatusResponse:
"""Return fail2ban service health status with log configuration.
Delegates to an injectable *probe_fn* (defaults to
:func:`~app.services.health_service.probe`). This avoids direct service-to-
service imports inside this module.
Args:
socket_path: Path to the fail2ban Unix domain socket.
probe_fn: Optional probe function.
Returns:
:class:`~app.models.config.ServiceStatusResponse`.
"""
if probe_fn is None:
raise ValueError("probe_fn is required to avoid service-to-service coupling")
server_status = await probe_fn(socket_path)
if server_status.online:
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
log_level_raw, log_target_raw = await asyncio.gather(
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
)
log_level = str(log_level_raw or "INFO").upper()
log_target = str(log_target_raw or "STDOUT")
else:
log_level = "UNKNOWN"
log_target = "UNKNOWN"
log.info(
"service_status_fetched",
online=server_status.online,
jail_count=server_status.active_jails,
)
return ServiceStatusResponse(
online=server_status.online,
version=__version__,
jail_count=server_status.active_jails,
total_bans=server_status.total_bans,
total_failures=server_status.total_failures,
log_level=log_level,
log_target=log_target,
)

View File

@@ -9,13 +9,18 @@ seconds by the background health-check task, not on every HTTP request.
from __future__ import annotations from __future__ import annotations
from typing import cast import asyncio
from collections.abc import Awaitable, Callable
from typing import TypeVar, cast
import structlog import structlog
from app import __version__
from app.models.config import ServiceStatusResponse
from app.models.server import ServerStatus from app.models.server import ServerStatus
from app.utils.fail2ban_client import ( from app.utils.fail2ban_client import (
Fail2BanClient, Fail2BanClient,
Fail2BanCommand,
Fail2BanConnectionError, Fail2BanConnectionError,
Fail2BanProtocolError, Fail2BanProtocolError,
Fail2BanResponse, Fail2BanResponse,
@@ -49,7 +54,9 @@ def _ok(response: object) -> object:
try: try:
code, data = cast("Fail2BanResponse", response) code, data = cast("Fail2BanResponse", response)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc raise ValueError(
f"Unexpected fail2ban response shape: {response!r}"
) from exc
if code != 0: if code != 0:
raise ValueError(f"fail2ban returned error code {code}: {data!r}") raise ValueError(f"fail2ban returned error code {code}: {data!r}")
@@ -81,13 +88,101 @@ def _to_dict(pairs: object) -> dict[str, object]:
return result return result
T = TypeVar("T")
async def _safe_get(
client: Fail2BanClient,
command: Fail2BanCommand,
default: object | None = None,
) -> object | None:
"""Send a command and return *default* if it fails."""
try:
return _ok(await client.send(command))
except (
Fail2BanConnectionError,
Fail2BanProtocolError,
ValueError,
OSError,
):
return default
async def _safe_get_typed(
client: Fail2BanClient,
command: Fail2BanCommand,
default: T,
) -> T:
"""Send a command and return the result typed as ``default``'s type."""
return cast("T", await _safe_get(client, command, default))
async def get_service_status(
socket_path: str,
probe_fn: Callable[[str], Awaitable[ServerStatus]] | None = None,
) -> ServiceStatusResponse:
"""Return fail2ban service health status with log configuration.
Delegates to an injectable *probe_fn* (defaults to
:func:`~app.services.health_service.probe`).
Args:
socket_path: Path to the fail2ban Unix domain socket.
probe_fn: Optional probe function.
Returns:
:class:`~app.models.config.ServiceStatusResponse`.
"""
if probe_fn is None:
raise ValueError(
"probe_fn is required to avoid service-to-service coupling"
)
server_status = await probe_fn(socket_path)
if server_status.online:
client = Fail2BanClient(
socket_path=socket_path,
timeout=_SOCKET_TIMEOUT,
)
log_level_raw, log_target_raw = await asyncio.gather(
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
)
log_level = str(log_level_raw or "INFO").upper()
log_target = str(log_target_raw or "STDOUT")
else:
log_level = "UNKNOWN"
log_target = "UNKNOWN"
log.info(
"service_status_fetched",
online=server_status.online,
jail_count=server_status.active_jails,
)
return ServiceStatusResponse(
online=server_status.online,
version=__version__,
jail_count=server_status.active_jails,
total_bans=server_status.total_bans,
total_failures=server_status.total_failures,
log_level=log_level,
log_target=log_target,
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Public interface # Public interface
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerStatus: async def probe(
"""Probe the fail2ban daemon and return a :class:`~app.models.server.ServerStatus`. socket_path: str,
timeout: float = _SOCKET_TIMEOUT,
) -> ServerStatus:
"""Probe the fail2ban daemon and return a
:class:`~app.models.server.ServerStatus`.
Sends ``ping``, ``version``, ``status``, and per-jail ``status <jail>`` Sends ``ping``, ``version``, ``status``, and per-jail ``status <jail>``
commands. Any socket or protocol error is caught and results in an commands. Any socket or protocol error is caught and results in an
@@ -109,11 +204,14 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
ping_data = _ok(await client.send(["ping"])) ping_data = _ok(await client.send(["ping"]))
if ping_data != "pong": if ping_data != "pong":
log.warning("fail2ban_unexpected_ping_response", response=ping_data) log.warning(
"fail2ban_unexpected_ping_response",
response=ping_data,
)
return ServerStatus(online=False) return ServerStatus(online=False)
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
# 2. Version # # 2. Version
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
try: try:
version: str | None = str(_ok(await client.send(["version"]))) version: str | None = str(_ok(await client.send(["version"])))
@@ -125,7 +223,9 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
# ------------------------------------------------------------------ # # ------------------------------------------------------------------ #
status_data = _to_dict(_ok(await client.send(["status"]))) status_data = _to_dict(_ok(await client.send(["status"])))
active_jails: int = int(str(status_data.get("Number of jail", 0) or 0)) active_jails: int = int(str(status_data.get("Number of jail", 0) or 0))
jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip() jail_list_raw: str = str(
status_data.get("Jail list", "") or ""
).strip()
jail_names: list[str] = ( jail_names: list[str] = (
[j.strip() for j in jail_list_raw.split(",") if j.strip()] [j.strip() for j in jail_list_raw.split(",") if j.strip()]
if jail_list_raw if jail_list_raw
@@ -140,11 +240,17 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
for jail_name in jail_names: for jail_name in jail_names:
try: try:
jail_resp = _to_dict(_ok(await client.send(["status", jail_name]))) jail_resp = _to_dict(
_ok(await client.send(["status", jail_name]))
)
filter_stats = _to_dict(jail_resp.get("Filter") or []) filter_stats = _to_dict(jail_resp.get("Filter") or [])
action_stats = _to_dict(jail_resp.get("Actions") or []) action_stats = _to_dict(jail_resp.get("Actions") or [])
total_failures += int(str(filter_stats.get("Currently failed", 0) or 0)) total_failures += int(
total_bans += int(str(action_stats.get("Currently banned", 0) or 0)) str(filter_stats.get("Currently failed", 0) or 0)
)
total_bans += int(
str(action_stats.get("Currently banned", 0) or 0)
)
except (ValueError, TypeError, KeyError) as exc: except (ValueError, TypeError, KeyError) as exc:
log.warning( log.warning(
"fail2ban_jail_status_parse_error", "fail2ban_jail_status_parse_error",
@@ -174,5 +280,3 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
except ValueError as exc: except ValueError as exc:
log.error("fail2ban_probe_parse_error", error=str(exc)) log.error("fail2ban_probe_parse_error", error=str(exc))
return ServerStatus(online=False) return ServerStatus(online=False)

View File

@@ -1,22 +1,168 @@
"""Log helper service. """Log helper service.
Contains regex test and log preview helpers that are independent of Contains regex test, log preview, and fail2ban log reading helpers.
fail2ban socket operations.
""" """
from __future__ import annotations from __future__ import annotations
from app.utils.async_utils import run_blocking import asyncio
import re import re
import structlog
from pathlib import Path from pathlib import Path
from typing import TypeVar, cast
from app.exceptions import ConfigOperationError
from app.models.config import ( from app.models.config import (
Fail2BanLogResponse,
LogPreviewLine, LogPreviewLine,
LogPreviewRequest, LogPreviewRequest,
LogPreviewResponse, LogPreviewResponse,
RegexTestRequest, RegexTestRequest,
RegexTestResponse, RegexTestResponse,
) )
from app.utils.async_utils import run_blocking
from app.utils.fail2ban_client import (
Fail2BanClient,
Fail2BanConnectionError,
Fail2BanProtocolError,
Fail2BanResponse,
)
log: structlog.stdlib.BoundLogger = structlog.get_logger()
_SOCKET_TIMEOUT: float = 10.0
_NON_FILE_LOG_TARGETS: frozenset[str] = frozenset(
{"STDOUT", "STDERR", "SYSLOG", "SYSTEMD-JOURNAL"}
)
_SAFE_LOG_PREFIXES: tuple[str, ...] = ("/var/log", "/config/log")
def _ok(response: object) -> object:
"""Extract the payload from a fail2ban ``(return_code, data)`` response."""
try:
code, data = cast(Fail2BanResponse, response)
except (TypeError, ValueError) as exc:
raise ValueError(
f"Unexpected fail2ban response shape: {response!r}"
) from exc
if code != 0:
raise ValueError(f"fail2ban returned error code {code}: {data!r}")
return data
def _count_file_lines(file_path: str) -> int:
"""Count the total number of lines in *file_path* synchronously."""
count = 0
with open(file_path, "rb") as fh:
for chunk in iter(lambda: fh.read(65536), b""):
count += chunk.count(b"\n")
return count
async def _safe_get(
client: Fail2BanClient,
command: list[str],
default: object | None = None,
) -> object | None:
"""Send a command and return *default* if it fails."""
try:
return _ok(await client.send(command))
except (
Fail2BanConnectionError,
Fail2BanProtocolError,
OSError,
ValueError,
):
return default
T = TypeVar("T")
async def _safe_get_typed(
client: Fail2BanClient,
command: list[str],
default: T,
) -> T:
"""Send a command and return the result typed as ``default``'s type."""
return cast("T", await _safe_get(client, command, default))
async def read_fail2ban_log(
socket_path: str,
lines: int,
filter_text: str | None = None,
) -> Fail2BanLogResponse:
"""Read the tail of the fail2ban daemon log file.
Queries the fail2ban socket for the current log target and log level,
validates that the target is a readable file, then returns the last
*lines* entries optionally filtered by *filter_text*.
"""
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
log_level_raw, log_target_raw = await asyncio.gather(
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
)
log_level = str(log_level_raw or "INFO").upper()
log_target = str(log_target_raw or "STDOUT")
if log_target.upper() in _NON_FILE_LOG_TARGETS:
raise ConfigOperationError(
f"fail2ban is logging to {log_target!r}. "
"File-based log viewing is only available when fail2ban logs "
"to a file path."
)
try:
resolved = Path(log_target).resolve()
except (ValueError, OSError) as exc:
raise ConfigOperationError(
f"Cannot resolve log target path {log_target!r}: {exc}"
) from exc
resolved_str = str(resolved)
if not any(resolved_str.startswith(safe) for safe in _SAFE_LOG_PREFIXES):
raise ConfigOperationError(
f"Log path {resolved_str!r} is outside the allowed directory. "
"Only paths under /var/log or /config/log are permitted."
)
if not resolved.is_file():
raise ConfigOperationError(f"Log file not found: {resolved_str!r}")
total_lines, raw_lines = await asyncio.gather(
run_blocking(_count_file_lines, resolved_str),
run_blocking(_read_tail_lines, resolved_str, lines),
)
filtered = (
[ln for ln in raw_lines if filter_text in ln]
if filter_text
else raw_lines
)
log.info(
"fail2ban_log_read",
log_path=resolved_str,
lines_requested=lines,
lines_returned=len(filtered),
filter_active=filter_text is not None,
)
return Fail2BanLogResponse(
log_path=resolved_str,
lines=filtered,
total_lines=total_lines,
log_level=log_level,
log_target=log_target,
)
def test_regex(request: RegexTestRequest) -> RegexTestResponse: def test_regex(request: RegexTestRequest) -> RegexTestResponse:
@@ -38,7 +184,10 @@ def test_regex(request: RegexTestRequest) -> RegexTestResponse:
return RegexTestResponse(matched=False) return RegexTestResponse(matched=False)
groups: list[str] = list(match.groups() or []) groups: list[str] = list(match.groups() or [])
return RegexTestResponse(matched=True, groups=[str(g) for g in groups if g is not None]) return RegexTestResponse(
matched=True,
groups=[str(g) for g in groups if g is not None],
)
async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse: async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
@@ -87,7 +236,11 @@ async def preview_log(req: LogPreviewRequest) -> LogPreviewResponse:
matched_count = 0 matched_count = 0
for line in raw_lines: for line in raw_lines:
m = compiled.search(line) m = compiled.search(line)
groups = [str(g) for g in (m.groups() or []) if g is not None] if m else [] groups = [
str(g)
for g in (m.groups() or [])
if g is not None
] if m else []
result_lines.append( result_lines.append(
LogPreviewLine(line=line, matched=(m is not None), groups=groups), LogPreviewLine(line=line, matched=(m is not None), groups=groups),
) )
@@ -124,4 +277,8 @@ def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
if pos > 0 and len(raw_lines) > 1: if pos > 0 and len(raw_lines) > 1:
raw_lines = raw_lines[1:] raw_lines = raw_lines[1:]
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()] return [
ln.decode("utf-8", errors="replace").rstrip()
for ln in raw_lines[-num_lines:]
if ln.strip()
]

View File

@@ -284,21 +284,6 @@ class ConfigService(Protocol):
) -> None: ) -> None:
... ...
async def read_fail2ban_log(
self,
socket_path: str,
lines: int,
filter_text: str | None = None,
) -> Fail2BanLogResponse:
...
async def get_service_status(
self,
socket_path: str,
probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None,
) -> ServiceStatusResponse:
...
@runtime_checkable @runtime_checkable
class HistoryService(Protocol): class HistoryService(Protocol):

View File

@@ -0,0 +1,89 @@
"""Shared settings persistence helpers.
This service centralises storage and validation for application settings that are
shared between setup and runtime configuration workflows.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import structlog
from app.repositories import settings_repo
if TYPE_CHECKING: # pragma: no cover
import aiosqlite
log: structlog.stdlib.BoundLogger = structlog.get_logger()
_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high"
_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
_DEFAULT_MAP_COLOR_THRESHOLD_HIGH = 100
_DEFAULT_MAP_COLOR_THRESHOLD_MEDIUM = 50
_DEFAULT_MAP_COLOR_THRESHOLD_LOW = 20
async def get_map_color_thresholds(
db: aiosqlite.Connection,
) -> tuple[int, int, int]:
"""Return map color thresholds from stored settings.
Args:
db: Active aiosqlite connection.
Returns:
A tuple of ``(high, medium, low)`` thresholds.
"""
high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH)
medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM)
low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW)
return (
int(high) if high else _DEFAULT_MAP_COLOR_THRESHOLD_HIGH,
int(medium) if medium else _DEFAULT_MAP_COLOR_THRESHOLD_MEDIUM,
int(low) if low else _DEFAULT_MAP_COLOR_THRESHOLD_LOW,
)
async def set_map_color_thresholds(
db: aiosqlite.Connection,
*,
threshold_high: int,
threshold_medium: int,
threshold_low: int,
) -> None:
"""Persist validated map color thresholds.
Args:
db: Active aiosqlite connection.
threshold_high: High threshold value.
threshold_medium: Medium threshold value.
threshold_low: Low threshold value.
Raises:
ValueError: If thresholds are non-positive or misordered.
"""
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
raise ValueError("All thresholds must be positive integers.")
if not (threshold_high > threshold_medium > threshold_low):
raise ValueError("Thresholds must satisfy: high > medium > low.")
await settings_repo.set_setting(
db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high)
)
await settings_repo.set_setting(
db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium)
)
await settings_repo.set_setting(
db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low)
)
log.info(
"map_color_thresholds_persisted",
high=threshold_high,
medium=threshold_medium,
low=threshold_low,
)

View File

@@ -15,16 +15,16 @@ import structlog
from app.db import init_db, open_db from app.db import init_db, open_db
from app.repositories import settings_repo from app.repositories import settings_repo
from app.utils.async_utils import run_blocking from app.services.settings_service import (
from app.utils.setup_utils import (
get_map_color_thresholds as util_get_map_color_thresholds, get_map_color_thresholds as util_get_map_color_thresholds,
) )
from app.services.settings_service import (
set_map_color_thresholds as util_set_map_color_thresholds,
)
from app.utils.async_utils import run_blocking
from app.utils.setup_utils import ( from app.utils.setup_utils import (
get_password_hash as util_get_password_hash, get_password_hash as util_get_password_hash,
) )
from app.utils.setup_utils import (
set_map_color_thresholds as util_set_map_color_thresholds,
)
if TYPE_CHECKING: if TYPE_CHECKING:
import aiosqlite import aiosqlite

View File

@@ -6,42 +6,8 @@ from app.repositories import settings_repo
_KEY_PASSWORD_HASH = "master_password_hash" _KEY_PASSWORD_HASH = "master_password_hash"
_KEY_SETUP_DONE = "setup_completed" _KEY_SETUP_DONE = "setup_completed"
_KEY_MAP_COLOR_THRESHOLD_HIGH = "map_color_threshold_high"
_KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
async def get_password_hash(db): async def get_password_hash(db):
"""Return the stored master password hash or None.""" """Return the stored master password hash or None."""
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH) return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
async def get_map_color_thresholds(db):
"""Return map color thresholds as tuple (high, medium, low)."""
high = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH)
medium = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM)
low = await settings_repo.get_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW)
return (
int(high) if high else 100,
int(medium) if medium else 50,
int(low) if low else 20,
)
async def set_map_color_thresholds(
db,
*,
threshold_high: int,
threshold_medium: int,
threshold_low: int,
) -> None:
"""Persist map color thresholds after validating values."""
if threshold_high <= 0 or threshold_medium <= 0 or threshold_low <= 0:
raise ValueError("All thresholds must be positive integers.")
if not (threshold_high > threshold_medium > threshold_low):
raise ValueError("Thresholds must satisfy: high > medium > low.")
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_HIGH, str(threshold_high))
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_MEDIUM, str(threshold_medium))
await settings_repo.set_setting(db, _KEY_MAP_COLOR_THRESHOLD_LOW, str(threshold_low))

View File

@@ -227,7 +227,7 @@ class TestServiceStatusBanguiVersion:
async def test_online_response_contains_bangui_version(self) -> None: async def test_online_response_contains_bangui_version(self) -> None:
"""The returned model must contain the ``bangui_version`` field.""" """The returned model must contain the ``bangui_version`` field."""
from app.models.server import ServerStatus from app.models.server import ServerStatus
from app.services import config_service from app.services import health_service
import app import app
online_status = ServerStatus( online_status = ServerStatus(
@@ -250,8 +250,8 @@ class TestServiceStatusBanguiVersion:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock(side_effect=_send) self.send = AsyncMock(side_effect=_send)
with patch("app.services.config_service.Fail2BanClient", _FakeClient): with patch("app.services.health_service.Fail2BanClient", _FakeClient):
result = await config_service.get_service_status( result = await health_service.get_service_status(
"/fake/socket", "/fake/socket",
probe_fn=AsyncMock(return_value=online_status), probe_fn=AsyncMock(return_value=online_status),
) )
@@ -263,12 +263,12 @@ class TestServiceStatusBanguiVersion:
async def test_offline_response_contains_bangui_version(self) -> None: async def test_offline_response_contains_bangui_version(self) -> None:
"""Even when fail2ban is offline, ``bangui_version`` must be present.""" """Even when fail2ban is offline, ``bangui_version`` must be present."""
from app.models.server import ServerStatus from app.models.server import ServerStatus
from app.services import config_service from app.services import health_service
import app import app
offline_status = ServerStatus(online=False) offline_status = ServerStatus(online=False)
result = await config_service.get_service_status( result = await health_service.get_service_status(
"/fake/socket", "/fake/socket",
probe_fn=AsyncMock(return_value=offline_status), probe_fn=AsyncMock(return_value=offline_status),
) )

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from contextlib import ExitStack
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
@@ -15,7 +16,7 @@ from app.models.config import (
LogPreviewRequest, LogPreviewRequest,
RegexTestRequest, RegexTestRequest,
) )
from app.services import config_service from app.services import config_service, health_service, log_service
from app.services.config_service import ( from app.services.config_service import (
ConfigValidationError, ConfigValidationError,
JailNotFoundError, JailNotFoundError,
@@ -650,7 +651,10 @@ class TestReadFail2BanLog:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock(side_effect=_send) self.send = AsyncMock(side_effect=_send)
return patch("app.services.config_service.Fail2BanClient", _FakeClient) stack = ExitStack()
stack.enter_context(patch("app.services.config_service.Fail2BanClient", _FakeClient))
stack.enter_context(patch("app.services.log_service.Fail2BanClient", _FakeClient))
return stack
async def test_returns_log_lines_from_file(self, tmp_path: Any) -> None: async def test_returns_log_lines_from_file(self, tmp_path: Any) -> None:
"""read_fail2ban_log returns lines from the file and counts totals.""" """read_fail2ban_log returns lines from the file and counts totals."""
@@ -660,8 +664,8 @@ class TestReadFail2BanLog:
# Patch _SAFE_LOG_PREFIXES to allow tmp_path # Patch _SAFE_LOG_PREFIXES to allow tmp_path
with self._patch_client(log_target=str(log_file)), \ with self._patch_client(log_target=str(log_file)), \
patch("app.services.config_service._SAFE_LOG_PREFIXES", (log_dir,)): patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)):
result = await config_service.read_fail2ban_log(_SOCKET, 200) result = await log_service.read_fail2ban_log(_SOCKET, 200)
assert result.log_path == str(log_file.resolve()) assert result.log_path == str(log_file.resolve())
assert result.total_lines >= 3 assert result.total_lines >= 3
@@ -675,8 +679,8 @@ class TestReadFail2BanLog:
log_dir = str(tmp_path) log_dir = str(tmp_path)
with self._patch_client(log_target=str(log_file)), \ with self._patch_client(log_target=str(log_file)), \
patch("app.services.config_service._SAFE_LOG_PREFIXES", (log_dir,)): patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)):
result = await config_service.read_fail2ban_log(_SOCKET, 200, "Found") result = await log_service.read_fail2ban_log(_SOCKET, 200, "Found")
assert all("Found" in ln for ln in result.lines) assert all("Found" in ln for ln in result.lines)
assert result.total_lines >= 3 # total is unfiltered assert result.total_lines >= 3 # total is unfiltered
@@ -685,13 +689,13 @@ class TestReadFail2BanLog:
"""read_fail2ban_log raises ConfigOperationError for STDOUT target.""" """read_fail2ban_log raises ConfigOperationError for STDOUT target."""
with self._patch_client(log_target="STDOUT"), \ with self._patch_client(log_target="STDOUT"), \
pytest.raises(config_service.ConfigOperationError, match="STDOUT"): pytest.raises(config_service.ConfigOperationError, match="STDOUT"):
await config_service.read_fail2ban_log(_SOCKET, 200) await log_service.read_fail2ban_log(_SOCKET, 200)
async def test_syslog_target_raises_operation_error(self) -> None: async def test_syslog_target_raises_operation_error(self) -> None:
"""read_fail2ban_log raises ConfigOperationError for SYSLOG target.""" """read_fail2ban_log raises ConfigOperationError for SYSLOG target."""
with self._patch_client(log_target="SYSLOG"), \ with self._patch_client(log_target="SYSLOG"), \
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"): pytest.raises(config_service.ConfigOperationError, match="SYSLOG"):
await config_service.read_fail2ban_log(_SOCKET, 200) await log_service.read_fail2ban_log(_SOCKET, 200)
async def test_path_outside_safe_dir_raises_operation_error(self, tmp_path: Any) -> None: async def test_path_outside_safe_dir_raises_operation_error(self, tmp_path: Any) -> None:
"""read_fail2ban_log rejects a log_target outside allowed directories.""" """read_fail2ban_log rejects a log_target outside allowed directories."""
@@ -700,9 +704,9 @@ class TestReadFail2BanLog:
# Allow only /var/log — tmp_path is deliberately not in the safe list. # Allow only /var/log — tmp_path is deliberately not in the safe list.
with self._patch_client(log_target=str(log_file)), \ with self._patch_client(log_target=str(log_file)), \
patch("app.services.config_service._SAFE_LOG_PREFIXES", ("/var/log",)), \ patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)), \
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"): pytest.raises(config_service.ConfigOperationError, match="outside the allowed"):
await config_service.read_fail2ban_log(_SOCKET, 200) await log_service.read_fail2ban_log(_SOCKET, 200)
async def test_missing_log_file_raises_operation_error(self, tmp_path: Any) -> None: async def test_missing_log_file_raises_operation_error(self, tmp_path: Any) -> None:
"""read_fail2ban_log raises ConfigOperationError when the file does not exist.""" """read_fail2ban_log raises ConfigOperationError when the file does not exist."""
@@ -710,9 +714,9 @@ class TestReadFail2BanLog:
log_dir = str(tmp_path) log_dir = str(tmp_path)
with self._patch_client(log_target=missing), \ with self._patch_client(log_target=missing), \
patch("app.services.config_service._SAFE_LOG_PREFIXES", (log_dir,)), \ patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)), \
pytest.raises(config_service.ConfigOperationError, match="not found"): pytest.raises(config_service.ConfigOperationError, match="not found"):
await config_service.read_fail2ban_log(_SOCKET, 200) await log_service.read_fail2ban_log(_SOCKET, 200)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -743,8 +747,8 @@ class TestGetServiceStatus:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock(side_effect=_send) self.send = AsyncMock(side_effect=_send)
with patch("app.services.config_service.Fail2BanClient", _FakeClient): with patch("app.services.health_service.Fail2BanClient", _FakeClient):
result = await config_service.get_service_status( result = await health_service.get_service_status(
_SOCKET, _SOCKET,
probe_fn=AsyncMock(return_value=online_status), probe_fn=AsyncMock(return_value=online_status),
) )
@@ -765,7 +769,7 @@ class TestGetServiceStatus:
offline_status = ServerStatus(online=False) offline_status = ServerStatus(online=False)
result = await config_service.get_service_status( result = await health_service.get_service_status(
_SOCKET, _SOCKET,
probe_fn=AsyncMock(return_value=offline_status), probe_fn=AsyncMock(return_value=offline_status),
) )