feat: Task 3 — invalid jail config recovery (pre-validation, crash detection, rollback)
- Backend: extend activate_jail() with pre-validation and 4-attempt post-reload
health probe; add validate_jail_config() and rollback_jail() service functions
- Backend: new endpoints POST /api/config/jails/{name}/validate,
GET /api/config/pending-recovery, POST /api/config/jails/{name}/rollback
- Backend: extend JailActivationResponse with fail2ban_running + validation_warnings;
add JailValidationIssue, JailValidationResult, PendingRecovery, RollbackResponse models
- Backend: health_check task tracks last_activation and creates PendingRecovery
record when fail2ban goes offline within 60 s of an activation
- Backend: add fail2ban_start_command setting (configurable start cmd for rollback)
- Frontend: ActivateJailDialog — pre-validation on open, crash-detected callback,
extended spinner text during activation+verify
- Frontend: JailsTab — Validate Config button for inactive jails, validation
result panels (blocking errors + advisory warnings)
- Frontend: RecoveryBanner component — polls pending-recovery, shows full-width
alert with Disable & Restart / View Logs buttons
- Frontend: MainLayout — mount RecoveryBanner at layout level
- Tests: 19 new backend service tests (validate, rollback, filter/action parsing)
+ 6 health_check crash-detection tests + 11 router tests; 5 RecoveryBanner
frontend tests; fix mock setup in existing activate_jail tests
This commit is contained in:
@@ -1874,3 +1874,217 @@ class TestGetServiceStatus:
|
||||
).get("/api/config/service-status")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Task 3 endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestValidateJailEndpoint:
|
||||
"""Tests for ``POST /api/config/jails/{name}/validate``."""
|
||||
|
||||
async def test_200_valid_config(self, config_client: AsyncClient) -> None:
|
||||
"""Returns 200 with valid=True when the jail config has no issues."""
|
||||
from app.models.config import JailValidationResult
|
||||
|
||||
mock_result = JailValidationResult(
|
||||
jail_name="sshd", valid=True, issues=[]
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.validate_jail_config",
|
||||
AsyncMock(return_value=mock_result),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/sshd/validate")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is True
|
||||
assert data["jail_name"] == "sshd"
|
||||
assert data["issues"] == []
|
||||
|
||||
async def test_200_invalid_config(self, config_client: AsyncClient) -> None:
|
||||
"""Returns 200 with valid=False and issues when there are errors."""
|
||||
from app.models.config import JailValidationIssue, JailValidationResult
|
||||
|
||||
issue = JailValidationIssue(field="filter", message="Filter file not found: filter.d/bad.conf (or .local)")
|
||||
mock_result = JailValidationResult(
|
||||
jail_name="sshd", valid=False, issues=[issue]
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.validate_jail_config",
|
||||
AsyncMock(return_value=mock_result),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/sshd/validate")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["valid"] is False
|
||||
assert len(data["issues"]) == 1
|
||||
assert data["issues"][0]["field"] == "filter"
|
||||
|
||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/bad-name/validate returns 400 on JailNameError."""
|
||||
from app.services.config_file_service import JailNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.validate_jail_config",
|
||||
AsyncMock(side_effect=JailNameError("bad name")),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/bad-name/validate")
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/sshd/validate returns 401 without session."""
|
||||
resp = await AsyncClient(
|
||||
transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined]
|
||||
base_url="http://test",
|
||||
).post("/api/config/jails/sshd/validate")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestPendingRecovery:
|
||||
"""Tests for ``GET /api/config/pending-recovery``."""
|
||||
|
||||
async def test_returns_null_when_no_pending_recovery(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
"""Returns null body (204-like 200) when pending_recovery is not set."""
|
||||
app = config_client._transport.app # type: ignore[attr-defined]
|
||||
app.state.pending_recovery = None
|
||||
|
||||
resp = await config_client.get("/api/config/pending-recovery")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() is None
|
||||
|
||||
async def test_returns_record_when_set(self, config_client: AsyncClient) -> None:
|
||||
"""Returns the PendingRecovery model when one is stored on app.state."""
|
||||
import datetime
|
||||
|
||||
from app.models.config import PendingRecovery
|
||||
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
record = PendingRecovery(
|
||||
jail_name="sshd",
|
||||
activated_at=now - datetime.timedelta(seconds=20),
|
||||
detected_at=now,
|
||||
)
|
||||
app = config_client._transport.app # type: ignore[attr-defined]
|
||||
app.state.pending_recovery = record
|
||||
|
||||
resp = await config_client.get("/api/config/pending-recovery")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["jail_name"] == "sshd"
|
||||
assert data["recovered"] is False
|
||||
|
||||
async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/pending-recovery returns 401 without session."""
|
||||
resp = await AsyncClient(
|
||||
transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined]
|
||||
base_url="http://test",
|
||||
).get("/api/config/pending-recovery")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRollbackEndpoint:
|
||||
"""Tests for ``POST /api/config/jails/{name}/rollback``."""
|
||||
|
||||
async def test_200_success_clears_pending_recovery(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
"""A successful rollback returns 200 and clears app.state.pending_recovery."""
|
||||
import datetime
|
||||
|
||||
from app.models.config import PendingRecovery, RollbackResponse
|
||||
|
||||
# Set up a pending recovery record on the app.
|
||||
app = config_client._transport.app # type: ignore[attr-defined]
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
app.state.pending_recovery = PendingRecovery(
|
||||
jail_name="sshd",
|
||||
activated_at=now - datetime.timedelta(seconds=10),
|
||||
detected_at=now,
|
||||
)
|
||||
|
||||
mock_result = RollbackResponse(
|
||||
jail_name="sshd",
|
||||
disabled=True,
|
||||
fail2ban_running=True,
|
||||
active_jails=0,
|
||||
message="Jail 'sshd' disabled and fail2ban restarted.",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.rollback_jail",
|
||||
AsyncMock(return_value=mock_result),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/sshd/rollback")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["disabled"] is True
|
||||
assert data["fail2ban_running"] is True
|
||||
# Successful rollback must clear the pending record.
|
||||
assert app.state.pending_recovery is None
|
||||
|
||||
async def test_200_fail_preserves_pending_recovery(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
"""When fail2ban is still down after rollback, pending_recovery is retained."""
|
||||
import datetime
|
||||
|
||||
from app.models.config import PendingRecovery, RollbackResponse
|
||||
|
||||
app = config_client._transport.app # type: ignore[attr-defined]
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
record = PendingRecovery(
|
||||
jail_name="sshd",
|
||||
activated_at=now - datetime.timedelta(seconds=10),
|
||||
detected_at=now,
|
||||
)
|
||||
app.state.pending_recovery = record
|
||||
|
||||
mock_result = RollbackResponse(
|
||||
jail_name="sshd",
|
||||
disabled=True,
|
||||
fail2ban_running=False,
|
||||
active_jails=0,
|
||||
message="fail2ban did not come back online.",
|
||||
)
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.rollback_jail",
|
||||
AsyncMock(return_value=mock_result),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/sshd/rollback")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["fail2ban_running"] is False
|
||||
# Pending record should NOT be cleared when rollback didn't fully succeed.
|
||||
assert app.state.pending_recovery is not None
|
||||
|
||||
async def test_400_for_invalid_jail_name(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/bad/rollback returns 400 on JailNameError."""
|
||||
from app.services.config_file_service import JailNameError
|
||||
|
||||
with patch(
|
||||
"app.routers.config.config_file_service.rollback_jail",
|
||||
AsyncMock(side_effect=JailNameError("bad")),
|
||||
):
|
||||
resp = await config_client.post("/api/config/jails/bad/rollback")
|
||||
|
||||
assert resp.status_code == 400
|
||||
|
||||
async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None:
|
||||
"""POST /api/config/jails/sshd/rollback returns 401 without session."""
|
||||
resp = await AsyncClient(
|
||||
transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined]
|
||||
base_url="http://test",
|
||||
).post("/api/config/jails/sshd/rollback")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
@@ -443,6 +443,10 @@ class TestActivateJail:
|
||||
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
|
||||
),
|
||||
patch("app.services.config_file_service.jail_service") as mock_js,
|
||||
patch(
|
||||
"app.services.config_file_service._probe_fail2ban_running",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock()
|
||||
result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
@@ -494,9 +498,13 @@ class TestActivateJail:
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
new=AsyncMock(side_effect=[set(), set()]),
|
||||
),
|
||||
patch("app.services.config_file_service.jail_service") as mock_js,
|
||||
patch(
|
||||
"app.services.config_file_service._probe_fail2ban_running",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock()
|
||||
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
@@ -2513,6 +2521,10 @@ class TestActivateJailReloadArgs:
|
||||
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
|
||||
),
|
||||
patch("app.services.config_file_service.jail_service") as mock_js,
|
||||
patch(
|
||||
"app.services.config_file_service._probe_fail2ban_running",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock()
|
||||
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
|
||||
@@ -2535,6 +2547,10 @@ class TestActivateJailReloadArgs:
|
||||
new=AsyncMock(side_effect=[set(), {"apache-auth"}]),
|
||||
),
|
||||
patch("app.services.config_file_service.jail_service") as mock_js,
|
||||
patch(
|
||||
"app.services.config_file_service._probe_fail2ban_running",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock()
|
||||
result = await activate_jail(
|
||||
@@ -2558,12 +2574,17 @@ class TestActivateJailReloadArgs:
|
||||
|
||||
req = ActivateJailRequest()
|
||||
# Pre-reload: jail not running. Post-reload: still not running (boot failed).
|
||||
# fail2ban is up (probe succeeds) but the jail didn't appear.
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(side_effect=[set(), set()]),
|
||||
),
|
||||
patch("app.services.config_file_service.jail_service") as mock_js,
|
||||
patch(
|
||||
"app.services.config_file_service._probe_fail2ban_running",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
):
|
||||
mock_js.reload_all = AsyncMock()
|
||||
result = await activate_jail(
|
||||
@@ -2600,3 +2621,212 @@ class TestDeactivateJailReloadArgs:
|
||||
"/fake.sock", exclude_jails=["sshd"]
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _validate_jail_config_sync (Task 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from app.services.config_file_service import ( # noqa: E402 (added after block)
|
||||
_validate_jail_config_sync,
|
||||
_extract_filter_base_name,
|
||||
_extract_action_base_name,
|
||||
validate_jail_config,
|
||||
rollback_jail,
|
||||
)
|
||||
|
||||
|
||||
class TestExtractFilterBaseName:
|
||||
def test_plain_name(self) -> None:
|
||||
assert _extract_filter_base_name("sshd") == "sshd"
|
||||
|
||||
def test_strips_mode_suffix(self) -> None:
|
||||
assert _extract_filter_base_name("sshd[mode=aggressive]") == "sshd"
|
||||
|
||||
def test_strips_whitespace(self) -> None:
|
||||
assert _extract_filter_base_name(" nginx ") == "nginx"
|
||||
|
||||
|
||||
class TestExtractActionBaseName:
|
||||
def test_plain_name(self) -> None:
|
||||
assert _extract_action_base_name("iptables-multiport") == "iptables-multiport"
|
||||
|
||||
def test_strips_option_suffix(self) -> None:
|
||||
assert _extract_action_base_name("iptables[name=SSH]") == "iptables"
|
||||
|
||||
def test_returns_none_for_variable_interpolation(self) -> None:
|
||||
assert _extract_action_base_name("%(action_)s") is None
|
||||
|
||||
def test_returns_none_for_dollar_variable(self) -> None:
|
||||
assert _extract_action_base_name("${action}") is None
|
||||
|
||||
|
||||
class TestValidateJailConfigSync:
|
||||
"""Tests for _validate_jail_config_sync — the sync validation core."""
|
||||
|
||||
def _setup_config(self, config_dir: Path, jail_cfg: str) -> None:
|
||||
"""Write a minimal fail2ban directory layout with *jail_cfg* content."""
|
||||
_write(config_dir / "jail.d" / "test.conf", jail_cfg)
|
||||
|
||||
def test_valid_config_no_issues(self, tmp_path: Path) -> None:
|
||||
"""A jail whose filter exists and has a valid regex should pass."""
|
||||
# Create a real filter file so the existence check passes.
|
||||
filter_d = tmp_path / "filter.d"
|
||||
filter_d.mkdir(parents=True, exist_ok=True)
|
||||
(filter_d / "sshd.conf").write_text("[Definition]\nfailregex = Host .* <HOST>\n")
|
||||
|
||||
self._setup_config(
|
||||
tmp_path,
|
||||
"[sshd]\nenabled = false\nfilter = sshd\nlogpath = /no/such/log\n",
|
||||
)
|
||||
|
||||
result = _validate_jail_config_sync(tmp_path, "sshd")
|
||||
# logpath advisory warning is OK; no blocking errors expected.
|
||||
blocking = [i for i in result.issues if i.field != "logpath"]
|
||||
assert blocking == [], blocking
|
||||
|
||||
def test_missing_filter_reported(self, tmp_path: Path) -> None:
|
||||
"""A jail whose filter file does not exist must report a filter issue."""
|
||||
self._setup_config(
|
||||
tmp_path,
|
||||
"[bad-jail]\nenabled = false\nfilter = nonexistent-filter\n",
|
||||
)
|
||||
|
||||
result = _validate_jail_config_sync(tmp_path, "bad-jail")
|
||||
assert not result.valid
|
||||
fields = [i.field for i in result.issues]
|
||||
assert "filter" in fields
|
||||
|
||||
def test_bad_failregex_reported(self, tmp_path: Path) -> None:
|
||||
"""A jail with an un-compilable failregex must report a failregex issue."""
|
||||
self._setup_config(
|
||||
tmp_path,
|
||||
"[broken]\nenabled = false\nfailregex = [invalid regex(\n",
|
||||
)
|
||||
|
||||
result = _validate_jail_config_sync(tmp_path, "broken")
|
||||
assert not result.valid
|
||||
fields = [i.field for i in result.issues]
|
||||
assert "failregex" in fields
|
||||
|
||||
def test_missing_log_path_is_advisory(self, tmp_path: Path) -> None:
|
||||
"""A missing log path should be reported in the logpath field."""
|
||||
self._setup_config(
|
||||
tmp_path,
|
||||
"[myjail]\nenabled = false\nlogpath = /no/such/path.log\n",
|
||||
)
|
||||
|
||||
result = _validate_jail_config_sync(tmp_path, "myjail")
|
||||
fields = [i.field for i in result.issues]
|
||||
assert "logpath" in fields
|
||||
|
||||
def test_missing_action_reported(self, tmp_path: Path) -> None:
|
||||
"""A jail referencing a non-existent action file must report an action issue."""
|
||||
self._setup_config(
|
||||
tmp_path,
|
||||
"[myjail]\nenabled = false\naction = nonexistent-action\n",
|
||||
)
|
||||
|
||||
result = _validate_jail_config_sync(tmp_path, "myjail")
|
||||
fields = [i.field for i in result.issues]
|
||||
assert "action" in fields
|
||||
|
||||
def test_unknown_jail_name(self, tmp_path: Path) -> None:
|
||||
"""Requesting validation for a jail not in any config returns an invalid result."""
|
||||
(tmp_path / "jail.d").mkdir(parents=True, exist_ok=True)
|
||||
|
||||
result = _validate_jail_config_sync(tmp_path, "ghost")
|
||||
assert not result.valid
|
||||
assert any(i.field == "name" for i in result.issues)
|
||||
|
||||
def test_variable_action_not_flagged(self, tmp_path: Path) -> None:
|
||||
"""An action like ``%(action_)s`` should not be checked for file existence."""
|
||||
self._setup_config(
|
||||
tmp_path,
|
||||
"[myjail]\nenabled = false\naction = %(action_)s\n",
|
||||
)
|
||||
result = _validate_jail_config_sync(tmp_path, "myjail")
|
||||
# Ensure no action file-missing error (the variable expression is skipped).
|
||||
action_errors = [i for i in result.issues if i.field == "action"]
|
||||
assert action_errors == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestValidateJailConfigAsync:
|
||||
"""Tests for the public async wrapper validate_jail_config."""
|
||||
|
||||
async def test_returns_jail_validation_result(self, tmp_path: Path) -> None:
|
||||
(tmp_path / "jail.d").mkdir(parents=True, exist_ok=True)
|
||||
_write(
|
||||
tmp_path / "jail.d" / "test.conf",
|
||||
"[testjail]\nenabled = false\n",
|
||||
)
|
||||
|
||||
result = await validate_jail_config(str(tmp_path), "testjail")
|
||||
assert result.jail_name == "testjail"
|
||||
|
||||
async def test_rejects_unsafe_name(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(JailNameError):
|
||||
await validate_jail_config(str(tmp_path), "../evil")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
class TestRollbackJail:
|
||||
"""Tests for rollback_jail (Task 3)."""
|
||||
|
||||
async def test_rollback_success(self, tmp_path: Path) -> None:
|
||||
"""When fail2ban comes back online, rollback returns fail2ban_running=True."""
|
||||
_write(tmp_path / "jail.d" / "sshd.conf", "[sshd]\nenabled = true\n")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._start_daemon",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._wait_for_fail2ban",
|
||||
new=AsyncMock(return_value=True),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value=set()),
|
||||
),
|
||||
):
|
||||
result = await rollback_jail(
|
||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||
)
|
||||
|
||||
assert result.disabled is True
|
||||
assert result.fail2ban_running is True
|
||||
assert result.jail_name == "sshd"
|
||||
# .local file must have enabled=false
|
||||
local = tmp_path / "jail.d" / "sshd.local"
|
||||
assert local.is_file()
|
||||
assert "enabled = false" in local.read_text()
|
||||
|
||||
async def test_rollback_fail2ban_still_down(self, tmp_path: Path) -> None:
|
||||
"""When fail2ban does not come back, rollback returns fail2ban_running=False."""
|
||||
_write(tmp_path / "jail.d" / "sshd.conf", "[sshd]\nenabled = true\n")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_file_service._start_daemon",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
patch(
|
||||
"app.services.config_file_service._wait_for_fail2ban",
|
||||
new=AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
result = await rollback_jail(
|
||||
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
|
||||
)
|
||||
|
||||
assert result.fail2ban_running is False
|
||||
assert result.disabled is True
|
||||
|
||||
async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None:
|
||||
with pytest.raises(JailNameError):
|
||||
await rollback_jail(
|
||||
str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"]
|
||||
)
|
||||
|
||||
|
||||
@@ -8,10 +8,12 @@ the scheduler and primes the initial status.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.config import PendingRecovery
|
||||
from app.models.server import ServerStatus
|
||||
from app.tasks.health_check import HEALTH_CHECK_INTERVAL, _run_probe, register
|
||||
|
||||
@@ -33,6 +35,8 @@ def _make_app(prev_online: bool = False) -> MagicMock:
|
||||
app.state.settings.fail2ban_socket = "/var/run/fail2ban/fail2ban.sock"
|
||||
app.state.server_status = ServerStatus(online=prev_online)
|
||||
app.state.scheduler = MagicMock()
|
||||
app.state.last_activation = None
|
||||
app.state.pending_recovery = None
|
||||
return app
|
||||
|
||||
|
||||
@@ -236,3 +240,111 @@ class TestRegister:
|
||||
|
||||
_, kwargs = app.state.scheduler.add_job.call_args
|
||||
assert kwargs["kwargs"] == {"app": app}
|
||||
|
||||
def test_register_initialises_last_activation_none(self) -> None:
|
||||
"""``register`` must set ``app.state.last_activation = None``."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
assert app.state.last_activation is None
|
||||
|
||||
def test_register_initialises_pending_recovery_none(self) -> None:
|
||||
"""``register`` must set ``app.state.pending_recovery = None``."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
assert app.state.pending_recovery is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Crash detection (Task 3)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCrashDetection:
|
||||
"""Tests for activation-crash detection in _run_probe."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crash_within_window_creates_pending_recovery(self) -> None:
|
||||
"""An online→offline transition within 60 s of activation must set pending_recovery."""
|
||||
app = _make_app(prev_online=True)
|
||||
now = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
app.state.last_activation = {
|
||||
"jail_name": "sshd",
|
||||
"at": now - datetime.timedelta(seconds=10),
|
||||
}
|
||||
app.state.pending_recovery = None
|
||||
|
||||
offline_status = ServerStatus(online=False)
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=offline_status,
|
||||
):
|
||||
await _run_probe(app)
|
||||
|
||||
assert app.state.pending_recovery is not None
|
||||
assert isinstance(app.state.pending_recovery, PendingRecovery)
|
||||
assert app.state.pending_recovery.jail_name == "sshd"
|
||||
assert app.state.pending_recovery.recovered is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crash_outside_window_does_not_create_pending_recovery(self) -> None:
|
||||
"""A crash more than 60 s after activation must NOT set pending_recovery."""
|
||||
app = _make_app(prev_online=True)
|
||||
app.state.last_activation = {
|
||||
"jail_name": "sshd",
|
||||
"at": datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
- datetime.timedelta(seconds=120),
|
||||
}
|
||||
app.state.pending_recovery = None
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=ServerStatus(online=False),
|
||||
):
|
||||
await _run_probe(app)
|
||||
|
||||
assert app.state.pending_recovery is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_came_online_marks_pending_recovery_resolved(self) -> None:
|
||||
"""An offline→online transition must mark an existing pending_recovery as recovered."""
|
||||
app = _make_app(prev_online=False)
|
||||
activated_at = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(seconds=30)
|
||||
detected_at = datetime.datetime.now(tz=datetime.timezone.utc)
|
||||
app.state.pending_recovery = PendingRecovery(
|
||||
jail_name="sshd",
|
||||
activated_at=activated_at,
|
||||
detected_at=detected_at,
|
||||
recovered=False,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=ServerStatus(online=True),
|
||||
):
|
||||
await _run_probe(app)
|
||||
|
||||
assert app.state.pending_recovery.recovered is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_crash_without_recent_activation_does_nothing(self) -> None:
|
||||
"""A crash when last_activation is None must not create a pending_recovery."""
|
||||
app = _make_app(prev_online=True)
|
||||
app.state.last_activation = None
|
||||
app.state.pending_recovery = None
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=ServerStatus(online=False),
|
||||
):
|
||||
await _run_probe(app)
|
||||
|
||||
assert app.state.pending_recovery is None
|
||||
|
||||
Reference in New Issue
Block a user