feat: Task 3 — invalid jail config recovery (pre-validation, crash detection, rollback)

- Backend: extend activate_jail() with pre-validation and 4-attempt post-reload
  health probe; add validate_jail_config() and rollback_jail() service functions
- Backend: new endpoints POST /api/config/jails/{name}/validate,
  GET /api/config/pending-recovery, POST /api/config/jails/{name}/rollback
- Backend: extend JailActivationResponse with fail2ban_running + validation_warnings;
  add JailValidationIssue, JailValidationResult, PendingRecovery, RollbackResponse models
- Backend: health_check task tracks last_activation and creates PendingRecovery
  record when fail2ban goes offline within 60 s of an activation
- Backend: add fail2ban_start_command setting (configurable start cmd for rollback)
- Frontend: ActivateJailDialog — pre-validation on open, crash-detected callback,
  extended spinner text during activation+verify
- Frontend: JailsTab — Validate Config button for inactive jails, validation
  result panels (blocking errors + advisory warnings)
- Frontend: RecoveryBanner component — polls pending-recovery, shows full-width
  alert with Disable & Restart / View Logs buttons
- Frontend: MainLayout — mount RecoveryBanner at layout level
- Tests: 19 new backend service tests (validate, rollback, filter/action parsing)
  + 6 health_check crash-detection tests + 11 router tests; 5 RecoveryBanner
  frontend tests; fix mock setup in existing activate_jail tests
This commit is contained in:
2026-03-14 14:13:07 +01:00
parent ab11ece001
commit 0966f347c4
17 changed files with 1862 additions and 26 deletions

View File

@@ -60,6 +60,15 @@ class Settings(BaseSettings):
"Used for listing, viewing, and editing configuration files through the web UI."
),
)
fail2ban_start_command: str = Field(
default="fail2ban-client start",
description=(
"Shell command used to start (not reload) the fail2ban daemon during "
"recovery rollback. Split by whitespace to build the argument list — "
"no shell interpretation is performed. "
"Example: 'systemctl start fail2ban' or 'fail2ban-client start'."
),
)
model_config = SettingsConfigDict(
env_prefix="BANGUI_",

View File

@@ -3,6 +3,8 @@
Request, response, and domain models for the config router and service.
"""
import datetime
from pydantic import BaseModel, ConfigDict, Field
# ---------------------------------------------------------------------------
@@ -860,6 +862,102 @@ class JailActivationResponse(BaseModel):
description="New activation state: ``True`` after activate, ``False`` after deactivate.",
)
message: str = Field(..., description="Human-readable result message.")
fail2ban_running: bool = Field(
default=True,
description=(
"Whether the fail2ban daemon is still running after the activation "
"and reload. ``False`` signals that the daemon may have crashed."
),
)
validation_warnings: list[str] = Field(
default_factory=list,
description="Non-fatal warnings from the pre-activation validation step.",
)
# ---------------------------------------------------------------------------
# Jail validation models (Task 3)
# ---------------------------------------------------------------------------
class JailValidationIssue(BaseModel):
"""A single issue found during pre-activation validation of a jail config."""
model_config = ConfigDict(strict=True)
field: str = Field(
...,
description="Config field associated with this issue, e.g. 'filter', 'failregex', 'logpath'.",
)
message: str = Field(..., description="Human-readable description of the issue.")
class JailValidationResult(BaseModel):
"""Result of pre-activation validation of a single jail configuration."""
model_config = ConfigDict(strict=True)
jail_name: str = Field(..., description="Name of the validated jail.")
valid: bool = Field(..., description="True when no issues were found.")
issues: list[JailValidationIssue] = Field(
default_factory=list,
description="Validation issues found. Empty when valid=True.",
)
# ---------------------------------------------------------------------------
# Rollback response model (Task 3)
# ---------------------------------------------------------------------------
class RollbackResponse(BaseModel):
"""Response for ``POST /api/config/jails/{name}/rollback``."""
model_config = ConfigDict(strict=True)
jail_name: str = Field(..., description="Name of the jail that was disabled.")
disabled: bool = Field(
...,
description="Whether the jail's .local override was successfully written with enabled=false.",
)
fail2ban_running: bool = Field(
...,
description="Whether fail2ban is online after the rollback attempt.",
)
active_jails: int = Field(
default=0,
ge=0,
description="Number of currently active jails after a successful restart.",
)
message: str = Field(..., description="Human-readable result message.")
# ---------------------------------------------------------------------------
# Pending recovery model (Task 3)
# ---------------------------------------------------------------------------
class PendingRecovery(BaseModel):
"""Records a probable activation-caused fail2ban crash pending user action."""
model_config = ConfigDict(strict=True)
jail_name: str = Field(
...,
description="Name of the jail whose activation likely caused the crash.",
)
activated_at: datetime.datetime = Field(
...,
description="ISO-8601 UTC timestamp of when the jail was activated.",
)
detected_at: datetime.datetime = Field(
...,
description="ISO-8601 UTC timestamp of when the crash was detected.",
)
recovered: bool = Field(
default=False,
description="Whether fail2ban has been successfully restarted.",
)
# ---------------------------------------------------------------------------

View File

@@ -9,6 +9,9 @@ global settings, test regex patterns, add log paths, and preview log files.
* ``GET /api/config/jails/inactive`` — list all inactive jails
* ``POST /api/config/jails/{name}/activate`` — activate an inactive jail
* ``POST /api/config/jails/{name}/deactivate`` — deactivate an active jail
* ``POST /api/config/jails/{name}/validate`` — validate jail config pre-activation (Task 3)
* ``POST /api/config/jails/{name}/rollback`` — disable bad jail and restart fail2ban (Task 3)
* ``GET /api/config/pending-recovery`` — active crash-recovery record (Task 3)
* ``POST /api/config/jails/{name}/filter`` — assign a filter to a jail
* ``POST /api/config/jails/{name}/action`` — add an action to a jail
* ``DELETE /api/config/jails/{name}/action/{action_name}`` — remove an action from a jail
@@ -34,6 +37,7 @@ global settings, test regex patterns, add log paths, and preview log files.
from __future__ import annotations
import datetime
from typing import Annotated
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
@@ -60,12 +64,15 @@ from app.models.config import (
JailConfigListResponse,
JailConfigResponse,
JailConfigUpdate,
JailValidationResult,
LogPreviewRequest,
LogPreviewResponse,
MapColorThresholdsResponse,
MapColorThresholdsUpdate,
PendingRecovery,
RegexTestRequest,
RegexTestResponse,
RollbackResponse,
ServiceStatusResponse,
)
from app.services import config_file_service, config_service, jail_service
@@ -611,7 +618,7 @@ async def activate_jail(
req = body if body is not None else ActivateJailRequest()
try:
return await config_file_service.activate_jail(
result = await config_file_service.activate_jail(
config_dir, socket_path, name, req
)
except JailNameError as exc:
@@ -631,6 +638,24 @@ async def activate_jail(
except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc
# Record this activation so the health-check task can attribute a
# subsequent fail2ban crash to it.
request.app.state.last_activation = {
"jail_name": name,
"at": datetime.datetime.now(tz=datetime.UTC),
}
# If fail2ban stopped responding after the reload, create a pending-recovery
# record immediately (before the background health task notices).
if not result.fail2ban_running:
request.app.state.pending_recovery = PendingRecovery(
jail_name=name,
activated_at=request.app.state.last_activation["at"],
detected_at=datetime.datetime.now(tz=datetime.UTC),
)
return result
@router.post(
"/jails/{name}/deactivate",
@@ -684,6 +709,125 @@ async def deactivate_jail(
raise _bad_gateway(exc) from exc
# ---------------------------------------------------------------------------
# Jail validation & rollback endpoints (Task 3)
# ---------------------------------------------------------------------------
@router.post(
"/jails/{name}/validate",
response_model=JailValidationResult,
summary="Validate jail configuration before activation",
)
async def validate_jail(
request: Request,
_auth: AuthDep,
name: _NamePath,
) -> JailValidationResult:
"""Run pre-activation validation checks on a jail configuration.
Validates filter and action file existence, regex pattern compilation, and
log path existence without modifying any files or reloading fail2ban.
Args:
request: FastAPI request object.
_auth: Validated session.
name: Jail name to validate.
Returns:
:class:`~app.models.config.JailValidationResult` with any issues found.
Raises:
HTTPException: 400 if *name* contains invalid characters.
HTTPException: 404 if *name* is not found in any config file.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
try:
return await config_file_service.validate_jail_config(config_dir, name)
except JailNameError as exc:
raise _bad_request(str(exc)) from exc
@router.get(
"/pending-recovery",
response_model=PendingRecovery | None,
summary="Return active crash-recovery record if one exists",
)
async def get_pending_recovery(
request: Request,
_auth: AuthDep,
) -> PendingRecovery | None:
"""Return the current :class:`~app.models.config.PendingRecovery` record.
A non-null response means fail2ban crashed shortly after a jail activation
and the user should be offered a rollback option. Returns ``null`` (HTTP
200 with ``null`` body) when no recovery is pending.
Args:
request: FastAPI request object.
_auth: Validated session.
Returns:
:class:`~app.models.config.PendingRecovery` or ``None``.
"""
return getattr(request.app.state, "pending_recovery", None)
@router.post(
"/jails/{name}/rollback",
response_model=RollbackResponse,
summary="Disable a bad jail config and restart fail2ban",
)
async def rollback_jail(
request: Request,
_auth: AuthDep,
name: _NamePath,
) -> RollbackResponse:
"""Disable the specified jail and attempt to restart fail2ban.
Writes ``enabled = false`` to ``jail.d/{name}.local`` (works even when
fail2ban is down — no socket is needed), then runs the configured start
command and waits up to ten seconds for the daemon to come back online.
On success, clears the :class:`~app.models.config.PendingRecovery` record.
Args:
request: FastAPI request object.
_auth: Validated session.
name: Jail name to disable and roll back.
Returns:
:class:`~app.models.config.RollbackResponse`.
Raises:
HTTPException: 400 if *name* contains invalid characters.
HTTPException: 500 if writing the .local override file fails.
"""
config_dir: str = request.app.state.settings.fail2ban_config_dir
socket_path: str = request.app.state.settings.fail2ban_socket
start_cmd: str = request.app.state.settings.fail2ban_start_command
start_cmd_parts: list[str] = start_cmd.split()
try:
result = await config_file_service.rollback_jail(
config_dir, socket_path, name, start_cmd_parts
)
except JailNameError as exc:
raise _bad_request(str(exc)) from exc
except ConfigWriteError as exc:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to write config override: {exc}",
) from exc
# Clear pending recovery if fail2ban came back online.
if result.fail2ban_running:
request.app.state.pending_recovery = None
request.app.state.last_activation = None
return result
# ---------------------------------------------------------------------------
# Filter discovery endpoints (Task 2.1)
# ---------------------------------------------------------------------------

View File

@@ -50,6 +50,9 @@ from app.models.config import (
InactiveJail,
InactiveJailListResponse,
JailActivationResponse,
JailValidationIssue,
JailValidationResult,
RollbackResponse,
)
from app.services import conffile_parser, jail_service
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
@@ -560,6 +563,242 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
return set()
# ---------------------------------------------------------------------------
# Validation helpers (Task 3)
# ---------------------------------------------------------------------------
# Seconds to wait between fail2ban liveness probes after a reload.
_POST_RELOAD_PROBE_INTERVAL: float = 2.0
# Maximum number of post-reload probe attempts (initial attempt + retries).
_POST_RELOAD_MAX_ATTEMPTS: int = 4
def _extract_action_base_name(action_str: str) -> str | None:
"""Return the base action name from an action assignment string.
Returns ``None`` for complex fail2ban expressions that cannot be resolved
to a single filename (e.g. ``%(action_)s`` interpolations or multi-token
composite actions).
Args:
action_str: A single line from the jail's ``action`` setting.
Returns:
Simple base name suitable for a filesystem lookup, or ``None``.
"""
if "%" in action_str or "$" in action_str:
return None
base = action_str.split("[")[0].strip()
if _SAFE_ACTION_NAME_RE.match(base):
return base
return None
def _validate_jail_config_sync(
config_dir: Path,
name: str,
) -> JailValidationResult:
"""Run synchronous pre-activation checks on a jail configuration.
Validates:
1. Filter file existence in ``filter.d/``.
2. Action file existence in ``action.d/`` (for resolvable action names).
3. Regex compilation for every ``failregex`` and ``ignoreregex`` pattern.
4. Log path existence on disk (generates warnings, not errors).
Args:
config_dir: The fail2ban configuration root directory.
name: Validated jail name.
Returns:
:class:`~app.models.config.JailValidationResult` with any issues found.
"""
issues: list[JailValidationIssue] = []
all_jails, _ = _parse_jails_sync(config_dir)
settings = all_jails.get(name)
if settings is None:
return JailValidationResult(
jail_name=name,
valid=False,
issues=[
JailValidationIssue(
field="name",
message=f"Jail {name!r} not found in config files.",
)
],
)
filter_d = config_dir / "filter.d"
action_d = config_dir / "action.d"
# 1. Filter existence check.
raw_filter = settings.get("filter", "")
if raw_filter:
mode = settings.get("mode", "normal")
resolved = _resolve_filter(raw_filter, name, mode)
base_filter = _extract_filter_base_name(resolved)
if base_filter:
conf_ok = (filter_d / f"{base_filter}.conf").is_file()
local_ok = (filter_d / f"{base_filter}.local").is_file()
if not conf_ok and not local_ok:
issues.append(
JailValidationIssue(
field="filter",
message=(
f"Filter file not found: filter.d/{base_filter}.conf"
" (or .local)"
),
)
)
# 2. Action existence check.
raw_action = settings.get("action", "")
if raw_action:
for action_line in _parse_multiline(raw_action):
action_name = _extract_action_base_name(action_line)
if action_name:
conf_ok = (action_d / f"{action_name}.conf").is_file()
local_ok = (action_d / f"{action_name}.local").is_file()
if not conf_ok and not local_ok:
issues.append(
JailValidationIssue(
field="action",
message=(
f"Action file not found: action.d/{action_name}.conf"
" (or .local)"
),
)
)
# 3. failregex compilation.
for pattern in _parse_multiline(settings.get("failregex", "")):
try:
re.compile(pattern)
except re.error as exc:
issues.append(
JailValidationIssue(
field="failregex",
message=f"Invalid regex pattern: {exc}",
)
)
# 4. ignoreregex compilation.
for pattern in _parse_multiline(settings.get("ignoreregex", "")):
try:
re.compile(pattern)
except re.error as exc:
issues.append(
JailValidationIssue(
field="ignoreregex",
message=f"Invalid regex pattern: {exc}",
)
)
# 5. Log path existence (warning only — paths may be created at runtime).
raw_logpath = settings.get("logpath", "")
if raw_logpath:
for log_path in _parse_multiline(raw_logpath):
# Skip glob patterns and fail2ban variable references.
if "*" in log_path or "?" in log_path or "%(" in log_path:
continue
if not Path(log_path).exists():
issues.append(
JailValidationIssue(
field="logpath",
message=f"Log file not found on disk: {log_path}",
)
)
valid = len(issues) == 0
log.debug(
"jail_validation_complete",
jail=name,
valid=valid,
issue_count=len(issues),
)
return JailValidationResult(jail_name=name, valid=valid, issues=issues)
async def _probe_fail2ban_running(socket_path: str) -> bool:
"""Return ``True`` if the fail2ban socket responds to a ping.
Args:
socket_path: Path to the fail2ban Unix domain socket.
Returns:
``True`` when fail2ban is reachable, ``False`` otherwise.
"""
try:
client = Fail2BanClient(socket_path=socket_path, timeout=5.0)
resp = await client.send(["ping"])
return isinstance(resp, (list, tuple)) and resp[0] == 0
except Exception: # noqa: BLE001
return False
async def _wait_for_fail2ban(
socket_path: str,
max_wait_seconds: float = 10.0,
poll_interval: float = 2.0,
) -> bool:
"""Poll the fail2ban socket until it responds or the timeout expires.
Args:
socket_path: Path to the fail2ban Unix domain socket.
max_wait_seconds: Total time budget in seconds.
poll_interval: Delay between probe attempts in seconds.
Returns:
``True`` if fail2ban came online within the budget.
"""
elapsed = 0.0
while elapsed < max_wait_seconds:
if await _probe_fail2ban_running(socket_path):
return True
await asyncio.sleep(poll_interval)
elapsed += poll_interval
return False
async def _start_daemon(start_cmd_parts: list[str]) -> bool:
"""Start the fail2ban daemon using *start_cmd_parts*.
Uses :func:`asyncio.create_subprocess_exec` (no shell interpretation)
to avoid command injection.
Args:
start_cmd_parts: Command and arguments, e.g.
``["fail2ban-client", "start"]``.
Returns:
``True`` when the process exited with code 0.
"""
if not start_cmd_parts:
log.warning("fail2ban_start_cmd_empty")
return False
try:
proc = await asyncio.create_subprocess_exec(
*start_cmd_parts,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
await asyncio.wait_for(proc.wait(), timeout=30.0)
success = proc.returncode == 0
if not success:
log.warning(
"fail2ban_start_cmd_nonzero",
cmd=start_cmd_parts,
returncode=proc.returncode,
)
return success
except (TimeoutError, OSError) as exc:
log.warning("fail2ban_start_cmd_error", cmd=start_cmd_parts, error=str(exc))
return False
def _write_local_override_sync(
config_dir: Path,
jail_name: str,
@@ -846,9 +1085,10 @@ async def activate_jail(
) -> JailActivationResponse:
"""Enable an inactive jail and reload fail2ban.
Writes ``enabled = true`` (plus any override values from *req*) to
``jail.d/{name}.local`` and then triggers a full fail2ban reload so the
jail starts immediately.
Performs pre-activation validation, writes ``enabled = true`` (plus any
override values from *req*) to ``jail.d/{name}.local``, and triggers a
full fail2ban reload. After the reload a multi-attempt health probe
determines whether fail2ban (and the specific jail) are still running.
Args:
config_dir: Absolute path to the fail2ban configuration directory.
@@ -857,7 +1097,8 @@ async def activate_jail(
req: Optional override values to write alongside ``enabled = true``.
Returns:
:class:`~app.models.config.JailActivationResponse`.
:class:`~app.models.config.JailActivationResponse` including
``fail2ban_running`` and ``validation_warnings`` fields.
Raises:
JailNameError: If *name* contains invalid characters.
@@ -881,6 +1122,20 @@ async def activate_jail(
if name in active_names:
raise JailAlreadyActiveError(name)
# ---------------------------------------------------------------------- #
# Pre-activation validation — collect warnings but do not block #
# ---------------------------------------------------------------------- #
validation_result: JailValidationResult = await loop.run_in_executor(
None, _validate_jail_config_sync, Path(config_dir), name
)
warnings: list[str] = [f"{i.field}: {i.message}" for i in validation_result.issues]
if warnings:
log.warning(
"jail_activation_validation_warnings",
jail=name,
warnings=warnings,
)
overrides: dict[str, Any] = {
"bantime": req.bantime,
"findtime": req.findtime,
@@ -903,9 +1158,35 @@ async def activate_jail(
except Exception as exc: # noqa: BLE001
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
# Verify the jail actually started after the reload. A config error
# (bad regex, missing log file, etc.) may silently prevent fail2ban from
# starting the jail even though the reload command succeeded.
# ---------------------------------------------------------------------- #
# Post-reload health probe with retries #
# ---------------------------------------------------------------------- #
fail2ban_running = False
for attempt in range(_POST_RELOAD_MAX_ATTEMPTS):
if attempt > 0:
await asyncio.sleep(_POST_RELOAD_PROBE_INTERVAL)
if await _probe_fail2ban_running(socket_path):
fail2ban_running = True
break
if not fail2ban_running:
log.warning(
"fail2ban_down_after_activate",
jail=name,
message="fail2ban socket unreachable after reload — daemon may have crashed.",
)
return JailActivationResponse(
name=name,
active=False,
fail2ban_running=False,
validation_warnings=warnings,
message=(
f"Jail {name!r} was written to config but fail2ban stopped "
"responding after reload. The jail configuration may be invalid."
),
)
# Verify the jail actually started (config error may prevent it silently).
post_reload_names = await _get_active_jail_names(socket_path)
actually_running = name in post_reload_names
if not actually_running:
@@ -917,6 +1198,8 @@ async def activate_jail(
return JailActivationResponse(
name=name,
active=False,
fail2ban_running=True,
validation_warnings=warnings,
message=(
f"Jail {name!r} was written to config but did not start after "
"reload — check the jail configuration (filters, log paths, regex)."
@@ -927,6 +1210,8 @@ async def activate_jail(
return JailActivationResponse(
name=name,
active=True,
fail2ban_running=True,
validation_warnings=warnings,
message=f"Jail {name!r} activated successfully.",
)
@@ -994,6 +1279,117 @@ async def deactivate_jail(
)
async def validate_jail_config(
config_dir: str,
name: str,
) -> JailValidationResult:
"""Run pre-activation validation checks on a jail configuration.
Validates that referenced filter and action files exist in ``filter.d/``
and ``action.d/``, that all regex patterns compile, and that declared log
paths exist on disk.
Args:
config_dir: Absolute path to the fail2ban configuration directory.
name: Name of the jail to validate.
Returns:
:class:`~app.models.config.JailValidationResult` with any issues found.
Raises:
JailNameError: If *name* contains invalid characters.
"""
_safe_jail_name(name)
loop = asyncio.get_event_loop()
return await loop.run_in_executor(
None,
_validate_jail_config_sync,
Path(config_dir),
name,
)
async def rollback_jail(
config_dir: str,
socket_path: str,
name: str,
start_cmd_parts: list[str],
) -> RollbackResponse:
"""Disable a bad jail config and restart the fail2ban daemon.
Writes ``enabled = false`` to ``jail.d/{name}.local`` (works even when
fail2ban is down — only a file write), then attempts to start the daemon
with *start_cmd_parts*. Waits up to 10 seconds for the socket to respond.
Args:
config_dir: Absolute path to the fail2ban configuration directory.
socket_path: Path to the fail2ban Unix domain socket.
name: Name of the jail to disable.
start_cmd_parts: Argument list for the daemon start command, e.g.
``["fail2ban-client", "start"]``.
Returns:
:class:`~app.models.config.RollbackResponse`.
Raises:
JailNameError: If *name* contains invalid characters.
ConfigWriteError: If writing the ``.local`` file fails.
"""
_safe_jail_name(name)
loop = asyncio.get_event_loop()
# Write enabled=false — this must succeed even when fail2ban is down.
await loop.run_in_executor(
None,
_write_local_override_sync,
Path(config_dir),
name,
False,
{},
)
log.info("jail_rolled_back_disabled", jail=name)
# Attempt to start the daemon.
started = await _start_daemon(start_cmd_parts)
log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
# Wait for the socket to come back.
fail2ban_running = await _wait_for_fail2ban(
socket_path, max_wait_seconds=10.0, poll_interval=2.0
)
active_jails = 0
if fail2ban_running:
names = await _get_active_jail_names(socket_path)
active_jails = len(names)
if fail2ban_running:
log.info("jail_rollback_success", jail=name, active_jails=active_jails)
return RollbackResponse(
jail_name=name,
disabled=True,
fail2ban_running=True,
active_jails=active_jails,
message=(
f"Jail {name!r} disabled and fail2ban restarted successfully "
f"with {active_jails} active jail(s)."
),
)
log.warning("jail_rollback_fail2ban_still_down", jail=name)
return RollbackResponse(
jail_name=name,
disabled=True,
fail2ban_running=False,
active_jails=0,
message=(
f"Jail {name!r} was disabled but fail2ban did not come back online. "
"Check the fail2ban log for additional errors."
),
)
# ---------------------------------------------------------------------------
# Filter discovery helpers (Task 2.1)
# ---------------------------------------------------------------------------

View File

@@ -4,14 +4,25 @@ Registers an APScheduler job that probes the fail2ban socket every 30 seconds
and stores the result on ``app.state.server_status``. The dashboard endpoint
reads from this cache, keeping HTTP responses fast and the daemon connection
decoupled from user-facing requests.
Crash detection (Task 3)
------------------------
When a jail activation is performed, the router stores a timestamp on
``app.state.last_activation`` (a ``dict`` with ``jail_name`` and ``at``
keys). If the health probe subsequently detects an online→offline transition
within 60 seconds of that activation, a
:class:`~app.models.config.PendingRecovery` record is written to
``app.state.pending_recovery`` so the UI can offer a one-click rollback.
"""
from __future__ import annotations
import datetime
from typing import TYPE_CHECKING, Any
import structlog
from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.services import health_service
@@ -23,10 +34,19 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
#: How often the probe fires (seconds).
HEALTH_CHECK_INTERVAL: int = 30
#: Maximum seconds since an activation for a subsequent crash to be attributed
#: to that activation.
_ACTIVATION_CRASH_WINDOW: int = 60
async def _run_probe(app: Any) -> None:
"""Probe fail2ban and cache the result on *app.state*.
Detects online/offline state transitions. When fail2ban goes offline
within :data:`_ACTIVATION_CRASH_WINDOW` seconds of the last jail
activation, writes a :class:`~app.models.config.PendingRecovery` record to
``app.state.pending_recovery``.
This is the APScheduler job callback. It reads ``fail2ban_socket`` from
``app.state.settings``, runs the health probe, and writes the result to
``app.state.server_status``.
@@ -42,11 +62,54 @@ async def _run_probe(app: Any) -> None:
status: ServerStatus = await health_service.probe(socket_path)
app.state.server_status = status
now = datetime.datetime.now(tz=datetime.UTC)
# Log transitions between online and offline states.
if status.online and not prev_status.online:
log.info("fail2ban_came_online", version=status.version)
# Clear any pending recovery once fail2ban is back online.
existing: PendingRecovery | None = getattr(
app.state, "pending_recovery", None
)
if existing is not None and not existing.recovered:
app.state.pending_recovery = PendingRecovery(
jail_name=existing.jail_name,
activated_at=existing.activated_at,
detected_at=existing.detected_at,
recovered=True,
)
log.info(
"pending_recovery_resolved",
jail=existing.jail_name,
)
elif not status.online and prev_status.online:
log.warning("fail2ban_went_offline")
# Check whether this crash happened shortly after a jail activation.
last_activation: dict[str, Any] | None = getattr(
app.state, "last_activation", None
)
if last_activation is not None:
activated_at: datetime.datetime = last_activation["at"]
seconds_since = (now - activated_at).total_seconds()
if seconds_since <= _ACTIVATION_CRASH_WINDOW:
jail_name: str = last_activation["jail_name"]
# Only create a new record when there is not already an
# unresolved one for the same jail.
current: PendingRecovery | None = getattr(
app.state, "pending_recovery", None
)
if current is None or current.recovered:
app.state.pending_recovery = PendingRecovery(
jail_name=jail_name,
activated_at=activated_at,
detected_at=now,
)
log.warning(
"activation_crash_detected",
jail=jail_name,
seconds_since_activation=seconds_since,
)
log.debug(
"health_check_complete",
@@ -71,6 +134,10 @@ def register(app: FastAPI) -> None:
# first probe fires.
app.state.server_status = ServerStatus(online=False)
# Initialise activation tracking state.
app.state.last_activation = None
app.state.pending_recovery = None
app.state.scheduler.add_job(
_run_probe,
trigger="interval",

View File

@@ -1874,3 +1874,217 @@ class TestGetServiceStatus:
).get("/api/config/service-status")
assert resp.status_code == 401
# ---------------------------------------------------------------------------
# Task 3 endpoints
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestValidateJailEndpoint:
"""Tests for ``POST /api/config/jails/{name}/validate``."""
async def test_200_valid_config(self, config_client: AsyncClient) -> None:
"""Returns 200 with valid=True when the jail config has no issues."""
from app.models.config import JailValidationResult
mock_result = JailValidationResult(
jail_name="sshd", valid=True, issues=[]
)
with patch(
"app.routers.config.config_file_service.validate_jail_config",
AsyncMock(return_value=mock_result),
):
resp = await config_client.post("/api/config/jails/sshd/validate")
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is True
assert data["jail_name"] == "sshd"
assert data["issues"] == []
async def test_200_invalid_config(self, config_client: AsyncClient) -> None:
"""Returns 200 with valid=False and issues when there are errors."""
from app.models.config import JailValidationIssue, JailValidationResult
issue = JailValidationIssue(field="filter", message="Filter file not found: filter.d/bad.conf (or .local)")
mock_result = JailValidationResult(
jail_name="sshd", valid=False, issues=[issue]
)
with patch(
"app.routers.config.config_file_service.validate_jail_config",
AsyncMock(return_value=mock_result),
):
resp = await config_client.post("/api/config/jails/sshd/validate")
assert resp.status_code == 200
data = resp.json()
assert data["valid"] is False
assert len(data["issues"]) == 1
assert data["issues"][0]["field"] == "filter"
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
"""POST /api/config/jails/bad-name/validate returns 400 on JailNameError."""
from app.services.config_file_service import JailNameError
with patch(
"app.routers.config.config_file_service.validate_jail_config",
AsyncMock(side_effect=JailNameError("bad name")),
):
resp = await config_client.post("/api/config/jails/bad-name/validate")
assert resp.status_code == 400
async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None:
"""POST /api/config/jails/sshd/validate returns 401 without session."""
resp = await AsyncClient(
transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined]
base_url="http://test",
).post("/api/config/jails/sshd/validate")
assert resp.status_code == 401
@pytest.mark.asyncio
class TestPendingRecovery:
"""Tests for ``GET /api/config/pending-recovery``."""
async def test_returns_null_when_no_pending_recovery(
self, config_client: AsyncClient
) -> None:
"""Returns null body (204-like 200) when pending_recovery is not set."""
app = config_client._transport.app # type: ignore[attr-defined]
app.state.pending_recovery = None
resp = await config_client.get("/api/config/pending-recovery")
assert resp.status_code == 200
assert resp.json() is None
async def test_returns_record_when_set(self, config_client: AsyncClient) -> None:
"""Returns the PendingRecovery model when one is stored on app.state."""
import datetime
from app.models.config import PendingRecovery
now = datetime.datetime.now(tz=datetime.timezone.utc)
record = PendingRecovery(
jail_name="sshd",
activated_at=now - datetime.timedelta(seconds=20),
detected_at=now,
)
app = config_client._transport.app # type: ignore[attr-defined]
app.state.pending_recovery = record
resp = await config_client.get("/api/config/pending-recovery")
assert resp.status_code == 200
data = resp.json()
assert data["jail_name"] == "sshd"
assert data["recovered"] is False
async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None:
"""GET /api/config/pending-recovery returns 401 without session."""
resp = await AsyncClient(
transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined]
base_url="http://test",
).get("/api/config/pending-recovery")
assert resp.status_code == 401
@pytest.mark.asyncio
class TestRollbackEndpoint:
"""Tests for ``POST /api/config/jails/{name}/rollback``."""
async def test_200_success_clears_pending_recovery(
self, config_client: AsyncClient
) -> None:
"""A successful rollback returns 200 and clears app.state.pending_recovery."""
import datetime
from app.models.config import PendingRecovery, RollbackResponse
# Set up a pending recovery record on the app.
app = config_client._transport.app # type: ignore[attr-defined]
now = datetime.datetime.now(tz=datetime.timezone.utc)
app.state.pending_recovery = PendingRecovery(
jail_name="sshd",
activated_at=now - datetime.timedelta(seconds=10),
detected_at=now,
)
mock_result = RollbackResponse(
jail_name="sshd",
disabled=True,
fail2ban_running=True,
active_jails=0,
message="Jail 'sshd' disabled and fail2ban restarted.",
)
with patch(
"app.routers.config.config_file_service.rollback_jail",
AsyncMock(return_value=mock_result),
):
resp = await config_client.post("/api/config/jails/sshd/rollback")
assert resp.status_code == 200
data = resp.json()
assert data["disabled"] is True
assert data["fail2ban_running"] is True
# Successful rollback must clear the pending record.
assert app.state.pending_recovery is None
async def test_200_fail_preserves_pending_recovery(
self, config_client: AsyncClient
) -> None:
"""When fail2ban is still down after rollback, pending_recovery is retained."""
import datetime
from app.models.config import PendingRecovery, RollbackResponse
app = config_client._transport.app # type: ignore[attr-defined]
now = datetime.datetime.now(tz=datetime.timezone.utc)
record = PendingRecovery(
jail_name="sshd",
activated_at=now - datetime.timedelta(seconds=10),
detected_at=now,
)
app.state.pending_recovery = record
mock_result = RollbackResponse(
jail_name="sshd",
disabled=True,
fail2ban_running=False,
active_jails=0,
message="fail2ban did not come back online.",
)
with patch(
"app.routers.config.config_file_service.rollback_jail",
AsyncMock(return_value=mock_result),
):
resp = await config_client.post("/api/config/jails/sshd/rollback")
assert resp.status_code == 200
data = resp.json()
assert data["fail2ban_running"] is False
# Pending record should NOT be cleared when rollback didn't fully succeed.
assert app.state.pending_recovery is not None
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
"""POST /api/config/jails/bad/rollback returns 400 on JailNameError."""
from app.services.config_file_service import JailNameError
with patch(
"app.routers.config.config_file_service.rollback_jail",
AsyncMock(side_effect=JailNameError("bad")),
):
resp = await config_client.post("/api/config/jails/bad/rollback")
assert resp.status_code == 400
async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None:
"""POST /api/config/jails/sshd/rollback returns 401 without session."""
resp = await AsyncClient(
transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined]
base_url="http://test",
).post("/api/config/jails/sshd/rollback")
assert resp.status_code == 401

View File

@@ -443,6 +443,10 @@ class TestActivateJail:
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
),
patch("app.services.config_file_service.jail_service") as mock_js,
patch(
"app.services.config_file_service._probe_fail2ban_running",
new=AsyncMock(return_value=True),
),
):
mock_js.reload_all = AsyncMock()
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
@@ -494,9 +498,13 @@ class TestActivateJail:
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
new=AsyncMock(side_effect=[set(), set()]),
),
patch("app.services.config_file_service.jail_service") as mock_js,
patch(
"app.services.config_file_service._probe_fail2ban_running",
new=AsyncMock(return_value=True),
),
):
mock_js.reload_all = AsyncMock()
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
@@ -2513,6 +2521,10 @@ class TestActivateJailReloadArgs:
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
),
patch("app.services.config_file_service.jail_service") as mock_js,
patch(
"app.services.config_file_service._probe_fail2ban_running",
new=AsyncMock(return_value=True),
),
):
mock_js.reload_all = AsyncMock()
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
@@ -2535,6 +2547,10 @@ class TestActivateJailReloadArgs:
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
),
patch("app.services.config_file_service.jail_service") as mock_js,
patch(
"app.services.config_file_service._probe_fail2ban_running",
new=AsyncMock(return_value=True),
),
):
mock_js.reload_all = AsyncMock()
result = await activate_jail(
@@ -2558,12 +2574,17 @@ class TestActivateJailReloadArgs:
req = ActivateJailRequest()
# Pre-reload: jail not running. Post-reload: still not running (boot failed).
# fail2ban is up (probe succeeds) but the jail didn't appear.
with (
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(side_effect=[set(), set()]),
),
patch("app.services.config_file_service.jail_service") as mock_js,
patch(
"app.services.config_file_service._probe_fail2ban_running",
new=AsyncMock(return_value=True),
),
):
mock_js.reload_all = AsyncMock()
result = await activate_jail(
@@ -2600,3 +2621,212 @@ class TestDeactivateJailReloadArgs:
"/fake.sock", exclude_jails=["sshd"]
)
# ---------------------------------------------------------------------------
# _validate_jail_config_sync (Task 3)
# ---------------------------------------------------------------------------
from app.services.config_file_service import ( # noqa: E402 (added after block)
_validate_jail_config_sync,
_extract_filter_base_name,
_extract_action_base_name,
validate_jail_config,
rollback_jail,
)
class TestExtractFilterBaseName:
def test_plain_name(self) -> None:
assert _extract_filter_base_name("sshd") == "sshd"
def test_strips_mode_suffix(self) -> None:
assert _extract_filter_base_name("sshd[mode=aggressive]") == "sshd"
def test_strips_whitespace(self) -> None:
assert _extract_filter_base_name(" nginx ") == "nginx"
class TestExtractActionBaseName:
def test_plain_name(self) -> None:
assert _extract_action_base_name("iptables-multiport") == "iptables-multiport"
def test_strips_option_suffix(self) -> None:
assert _extract_action_base_name("iptables[name=SSH]") == "iptables"
def test_returns_none_for_variable_interpolation(self) -> None:
assert _extract_action_base_name("%(action_)s") is None
def test_returns_none_for_dollar_variable(self) -> None:
assert _extract_action_base_name("${action}") is None
class TestValidateJailConfigSync:
"""Tests for _validate_jail_config_sync — the sync validation core."""
def _setup_config(self, config_dir: Path, jail_cfg: str) -> None:
"""Write a minimal fail2ban directory layout with *jail_cfg* content."""
_write(config_dir / "jail.d" / "test.conf", jail_cfg)
def test_valid_config_no_issues(self, tmp_path: Path) -> None:
"""A jail whose filter exists and has a valid regex should pass."""
# Create a real filter file so the existence check passes.
filter_d = tmp_path / "filter.d"
filter_d.mkdir(parents=True, exist_ok=True)
(filter_d / "sshd.conf").write_text("[Definition]\nfailregex = Host .* <HOST>\n")
self._setup_config(
tmp_path,
"[sshd]\nenabled = false\nfilter = sshd\nlogpath = /no/such/log\n",
)
result = _validate_jail_config_sync(tmp_path, "sshd")
# logpath advisory warning is OK; no blocking errors expected.
blocking = [i for i in result.issues if i.field != "logpath"]
assert blocking == [], blocking
def test_missing_filter_reported(self, tmp_path: Path) -> None:
"""A jail whose filter file does not exist must report a filter issue."""
self._setup_config(
tmp_path,
"[bad-jail]\nenabled = false\nfilter = nonexistent-filter\n",
)
result = _validate_jail_config_sync(tmp_path, "bad-jail")
assert not result.valid
fields = [i.field for i in result.issues]
assert "filter" in fields
def test_bad_failregex_reported(self, tmp_path: Path) -> None:
"""A jail with an un-compilable failregex must report a failregex issue."""
self._setup_config(
tmp_path,
"[broken]\nenabled = false\nfailregex = [invalid regex(\n",
)
result = _validate_jail_config_sync(tmp_path, "broken")
assert not result.valid
fields = [i.field for i in result.issues]
assert "failregex" in fields
def test_missing_log_path_is_advisory(self, tmp_path: Path) -> None:
"""A missing log path should be reported in the logpath field."""
self._setup_config(
tmp_path,
"[myjail]\nenabled = false\nlogpath = /no/such/path.log\n",
)
result = _validate_jail_config_sync(tmp_path, "myjail")
fields = [i.field for i in result.issues]
assert "logpath" in fields
def test_missing_action_reported(self, tmp_path: Path) -> None:
"""A jail referencing a non-existent action file must report an action issue."""
self._setup_config(
tmp_path,
"[myjail]\nenabled = false\naction = nonexistent-action\n",
)
result = _validate_jail_config_sync(tmp_path, "myjail")
fields = [i.field for i in result.issues]
assert "action" in fields
def test_unknown_jail_name(self, tmp_path: Path) -> None:
"""Requesting validation for a jail not in any config returns an invalid result."""
(tmp_path / "jail.d").mkdir(parents=True, exist_ok=True)
result = _validate_jail_config_sync(tmp_path, "ghost")
assert not result.valid
assert any(i.field == "name" for i in result.issues)
def test_variable_action_not_flagged(self, tmp_path: Path) -> None:
"""An action like ``%(action_)s`` should not be checked for file existence."""
self._setup_config(
tmp_path,
"[myjail]\nenabled = false\naction = %(action_)s\n",
)
result = _validate_jail_config_sync(tmp_path, "myjail")
# Ensure no action file-missing error (the variable expression is skipped).
action_errors = [i for i in result.issues if i.field == "action"]
assert action_errors == []
@pytest.mark.asyncio
class TestValidateJailConfigAsync:
"""Tests for the public async wrapper validate_jail_config."""
async def test_returns_jail_validation_result(self, tmp_path: Path) -> None:
(tmp_path / "jail.d").mkdir(parents=True, exist_ok=True)
_write(
tmp_path / "jail.d" / "test.conf",
"[testjail]\nenabled = false\n",
)
result = await validate_jail_config(str(tmp_path), "testjail")
assert result.jail_name == "testjail"
async def test_rejects_unsafe_name(self, tmp_path: Path) -> None:
with pytest.raises(JailNameError):
await validate_jail_config(str(tmp_path), "../evil")
@pytest.mark.asyncio
class TestRollbackJail:
"""Tests for rollback_jail (Task 3)."""
async def test_rollback_success(self, tmp_path: Path) -> None:
"""When fail2ban comes back online, rollback returns fail2ban_running=True."""
_write(tmp_path / "jail.d" / "sshd.conf", "[sshd]\nenabled = true\n")
with (
patch(
"app.services.config_file_service._start_daemon",
new=AsyncMock(return_value=True),
),
patch(
"app.services.config_file_service._wait_for_fail2ban",
new=AsyncMock(return_value=True),
),
patch(
"app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()),
),
):
result = await rollback_jail(
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
assert result.disabled is True
assert result.fail2ban_running is True
assert result.jail_name == "sshd"
# .local file must have enabled=false
local = tmp_path / "jail.d" / "sshd.local"
assert local.is_file()
assert "enabled = false" in local.read_text()
async def test_rollback_fail2ban_still_down(self, tmp_path: Path) -> None:
"""When fail2ban does not come back, rollback returns fail2ban_running=False."""
_write(tmp_path / "jail.d" / "sshd.conf", "[sshd]\nenabled = true\n")
with (
patch(
"app.services.config_file_service._start_daemon",
new=AsyncMock(return_value=False),
),
patch(
"app.services.config_file_service._wait_for_fail2ban",
new=AsyncMock(return_value=False),
),
):
result = await rollback_jail(
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
assert result.fail2ban_running is False
assert result.disabled is True
async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None:
with pytest.raises(JailNameError):
await rollback_jail(
str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"]
)

View File

@@ -8,10 +8,12 @@ the scheduler and primes the initial status.
from __future__ import annotations
import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.tasks.health_check import HEALTH_CHECK_INTERVAL, _run_probe, register
@@ -33,6 +35,8 @@ def _make_app(prev_online: bool = False) -> MagicMock:
app.state.settings.fail2ban_socket = "/var/run/fail2ban/fail2ban.sock"
app.state.server_status = ServerStatus(online=prev_online)
app.state.scheduler = MagicMock()
app.state.last_activation = None
app.state.pending_recovery = None
return app
@@ -236,3 +240,111 @@ class TestRegister:
_, kwargs = app.state.scheduler.add_job.call_args
assert kwargs["kwargs"] == {"app": app}
def test_register_initialises_last_activation_none(self) -> None:
"""``register`` must set ``app.state.last_activation = None``."""
app = _make_app()
register(app)
assert app.state.last_activation is None
def test_register_initialises_pending_recovery_none(self) -> None:
"""``register`` must set ``app.state.pending_recovery = None``."""
app = _make_app()
register(app)
assert app.state.pending_recovery is None
# ---------------------------------------------------------------------------
# Crash detection (Task 3)
# ---------------------------------------------------------------------------
class TestCrashDetection:
"""Tests for activation-crash detection in _run_probe."""
@pytest.mark.asyncio
async def test_crash_within_window_creates_pending_recovery(self) -> None:
"""An online→offline transition within 60 s of activation must set pending_recovery."""
app = _make_app(prev_online=True)
now = datetime.datetime.now(tz=datetime.timezone.utc)
app.state.last_activation = {
"jail_name": "sshd",
"at": now - datetime.timedelta(seconds=10),
}
app.state.pending_recovery = None
offline_status = ServerStatus(online=False)
with patch(
"app.tasks.health_check.health_service.probe",
new_callable=AsyncMock,
return_value=offline_status,
):
await _run_probe(app)
assert app.state.pending_recovery is not None
assert isinstance(app.state.pending_recovery, PendingRecovery)
assert app.state.pending_recovery.jail_name == "sshd"
assert app.state.pending_recovery.recovered is False
@pytest.mark.asyncio
async def test_crash_outside_window_does_not_create_pending_recovery(self) -> None:
"""A crash more than 60 s after activation must NOT set pending_recovery."""
app = _make_app(prev_online=True)
app.state.last_activation = {
"jail_name": "sshd",
"at": datetime.datetime.now(tz=datetime.timezone.utc)
- datetime.timedelta(seconds=120),
}
app.state.pending_recovery = None
with patch(
"app.tasks.health_check.health_service.probe",
new_callable=AsyncMock,
return_value=ServerStatus(online=False),
):
await _run_probe(app)
assert app.state.pending_recovery is None
@pytest.mark.asyncio
async def test_came_online_marks_pending_recovery_resolved(self) -> None:
"""An offline→online transition must mark an existing pending_recovery as recovered."""
app = _make_app(prev_online=False)
activated_at = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(seconds=30)
detected_at = datetime.datetime.now(tz=datetime.timezone.utc)
app.state.pending_recovery = PendingRecovery(
jail_name="sshd",
activated_at=activated_at,
detected_at=detected_at,
recovered=False,
)
with patch(
"app.tasks.health_check.health_service.probe",
new_callable=AsyncMock,
return_value=ServerStatus(online=True),
):
await _run_probe(app)
assert app.state.pending_recovery.recovered is True
@pytest.mark.asyncio
async def test_crash_without_recent_activation_does_nothing(self) -> None:
"""A crash when last_activation is None must not create a pending_recovery."""
app = _make_app(prev_online=True)
app.state.last_activation = None
app.state.pending_recovery = None
with patch(
"app.tasks.health_check.health_service.probe",
new_callable=AsyncMock,
return_value=ServerStatus(online=False),
):
await _run_probe(app)
assert app.state.pending_recovery is None