Extract health-check crash-detection logic into runtime state helper

This commit is contained in:
2026-04-17 16:58:24 +02:00
parent 1e2850a34e
commit 7a1cb0c46c
5 changed files with 122 additions and 69 deletions

View File

@@ -330,6 +330,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
**Docs changes needed:** Update `Docs/Refactoring.md`.
**Status:** Completed ✅
**Why this is needed:** Business rules about crash attribution timing should not live inside a scheduling artifact. Embedding decision logic in a background job makes it invisible, hard to test without the scheduler, and impossible to reuse from another trigger (e.g. a manual probe endpoint).
--- ---

View File

@@ -18,17 +18,17 @@ within 60 seconds of that activation, a
from __future__ import annotations from __future__ import annotations
import datetime import datetime
from typing import TYPE_CHECKING, TypedDict from typing import TYPE_CHECKING
import structlog import structlog
from app.models.config import PendingRecovery
from app.models.server import ServerStatus from app.models.server import ServerStatus
from app.services import health_service from app.services import health_service
from app.utils.runtime_state import ( from app.utils.runtime_state import (
RuntimeState, RuntimeState,
get_effective_settings, get_effective_settings,
get_runtime_state, get_runtime_state,
process_health_probe_result,
) )
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
@@ -39,20 +39,9 @@ if TYPE_CHECKING: # pragma: no cover
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
class ActivationRecord(TypedDict):
"""Stored timestamp data for a jail activation event."""
jail_name: str
at: datetime.datetime
#: How often the probe fires (seconds). #: How often the probe fires (seconds).
HEALTH_CHECK_INTERVAL: int = 30 HEALTH_CHECK_INTERVAL: int = 30
#: Maximum seconds since an activation for a subsequent crash to be attributed
#: to that activation.
_ACTIVATION_CRASH_WINDOW: int = 60
async def _run_probe_with_resources(settings: Settings, runtime_state: RuntimeState) -> None: async def _run_probe_with_resources(settings: Settings, runtime_state: RuntimeState) -> None:
"""Probe fail2ban and cache the result on the runtime state. """Probe fail2ban and cache the result on the runtime state.
@@ -68,57 +57,7 @@ async def _run_probe_with_resources(settings: Settings, runtime_state: RuntimeSt
ServerStatus(online=False), ServerStatus(online=False),
) )
status: ServerStatus = await health_service.probe(socket_path) status: ServerStatus = await health_service.probe(socket_path)
runtime_state.server_status = status process_health_probe_result(runtime_state, status)
now = datetime.datetime.now(tz=datetime.UTC)
# Log transitions between online and offline states.
if status.online and not prev_status.online:
log.info("fail2ban_came_online", version=status.version)
# Clear any pending recovery once fail2ban is back online.
existing: PendingRecovery | None = getattr(runtime_state, "pending_recovery", None)
if existing is not None and not existing.recovered:
runtime_state.pending_recovery = PendingRecovery(
jail_name=existing.jail_name,
activated_at=existing.activated_at,
detected_at=existing.detected_at,
recovered=True,
)
log.info(
"pending_recovery_resolved",
jail=existing.jail_name,
)
elif not status.online and prev_status.online:
log.warning("fail2ban_went_offline")
# Check whether this crash happened shortly after a jail activation.
last_activation: ActivationRecord | None = getattr(runtime_state, "last_activation", None)
if last_activation is not None:
activated_at: datetime.datetime = last_activation["at"]
seconds_since = (now - activated_at).total_seconds()
if seconds_since <= _ACTIVATION_CRASH_WINDOW:
jail_name: str = last_activation["jail_name"]
# Only create a new record when there is not already an
# unresolved one for the same jail.
current: PendingRecovery | None = getattr(runtime_state, "pending_recovery", None)
if current is None or current.recovered:
runtime_state.pending_recovery = PendingRecovery(
jail_name=jail_name,
activated_at=activated_at,
detected_at=now,
)
log.warning(
"activation_crash_detected",
jail=jail_name,
seconds_since_activation=seconds_since,
)
log.debug(
"health_check_complete",
online=status.online,
version=status.version,
active_jails=status.active_jails,
)
async def _run_probe(app: FastAPI) -> None: async def _run_probe(app: FastAPI) -> None:

View File

@@ -15,14 +15,22 @@ from typing import TYPE_CHECKING, Any
from starlette.datastructures import State from starlette.datastructures import State
import structlog
from app.models.config import PendingRecovery from app.models.config import PendingRecovery
from app.models.server import ServerStatus from app.models.server import ServerStatus
if TYPE_CHECKING: # pragma: no cover if TYPE_CHECKING: # pragma: no cover
from app.config import Settings from app.config import Settings
log: structlog.stdlib.BoundLogger = structlog.get_logger()
ActivationRecord = dict[str, datetime.datetime] ActivationRecord = dict[str, datetime.datetime]
# Maximum seconds since an activation for a subsequent crash to be
# attributed to that activation.
_ACTIVATION_CRASH_WINDOW: int = 60
_RUNTIME_ATTRIBUTES: frozenset[str] = frozenset( _RUNTIME_ATTRIBUTES: frozenset[str] = frozenset(
{ {
"setup_complete_cached", "setup_complete_cached",
@@ -151,3 +159,67 @@ def clear_pending_recovery(app: Any) -> None:
def clear_activation_record(app: Any) -> None: def clear_activation_record(app: Any) -> None:
"""Clear the current activation tracking record.""" """Clear the current activation tracking record."""
get_runtime_state(app).last_activation = None get_runtime_state(app).last_activation = None
def _mark_pending_recovery_resolved(runtime_state: RuntimeState) -> None:
    """Flag an outstanding pending-recovery record as recovered, if one exists."""
    record = runtime_state.pending_recovery
    if record is None or record.recovered:
        return

    # A fresh record is built rather than mutating the existing one in place.
    runtime_state.pending_recovery = PendingRecovery(
        jail_name=record.jail_name,
        activated_at=record.activated_at,
        detected_at=record.detected_at,
        recovered=True,
    )
    log.info(
        "pending_recovery_resolved",
        jail=record.jail_name,
    )


def _attribute_crash_to_activation(
    runtime_state: RuntimeState,
    now: datetime.datetime,
) -> None:
    """Record a pending recovery when the crash closely follows an activation."""
    activation = runtime_state.last_activation
    if activation is None:
        return

    activated_at = activation["at"]
    elapsed = (now - activated_at).total_seconds()
    if elapsed > _ACTIVATION_CRASH_WINDOW:
        return

    jail_name = activation["jail_name"]
    outstanding = runtime_state.pending_recovery
    # An unresolved record already covers this outage; leave it untouched.
    if outstanding is not None and not outstanding.recovered:
        return

    runtime_state.pending_recovery = PendingRecovery(
        jail_name=jail_name,
        activated_at=activated_at,
        detected_at=now,
    )
    log.warning(
        "activation_crash_detected",
        jail=jail_name,
        seconds_since_activation=elapsed,
    )


def process_health_probe_result(
    runtime_state: RuntimeState,
    status: ServerStatus,
    now: datetime.datetime | None = None,
) -> None:
    """Store a fresh probe result and react to online/offline transitions.

    The new status replaces the cached one on ``runtime_state``.  When the
    daemon comes back online, an unresolved pending-recovery record is marked
    as recovered.  When it goes offline within the activation crash window
    of the last recorded jail activation, a new pending-recovery record is
    created unless an unresolved one already exists.

    Args:
        runtime_state: The mutable runtime state manager.
        status: The latest fail2ban server status.
        now: Timestamp used for time-based decisions; the current UTC time
            is used when omitted.
    """
    # A missing cached status is treated as "previously offline".
    previous = getattr(runtime_state, "server_status", ServerStatus(online=False))
    runtime_state.server_status = status

    if now is None:
        now = datetime.datetime.now(tz=datetime.UTC)

    came_online = status.online and not previous.online
    went_offline = not status.online and previous.online

    if came_online:
        log.info("fail2ban_came_online", version=status.version)
        _mark_pending_recovery_resolved(runtime_state)
    elif went_offline:
        log.warning("fail2ban_went_offline")
        _attribute_crash_to_activation(runtime_state, now)

    log.debug(
        "health_check_complete",
        online=status.online,
        version=status.version,
        active_jails=status.active_jails,
    )

View File

@@ -77,7 +77,7 @@ class TestRunProbe:
"app.tasks.health_check.health_service.probe", "app.tasks.health_check.health_service.probe",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=new_status, return_value=new_status,
), patch("app.tasks.health_check.log") as mock_log: ), patch("app.utils.runtime_state.log") as mock_log:
await _run_probe(app) await _run_probe(app)
online_calls = [c for c in mock_log.info.call_args_list if c[0][0] == "fail2ban_came_online"] online_calls = [c for c in mock_log.info.call_args_list if c[0][0] == "fail2ban_came_online"]
@@ -93,7 +93,7 @@ class TestRunProbe:
"app.tasks.health_check.health_service.probe", "app.tasks.health_check.health_service.probe",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=new_status, return_value=new_status,
), patch("app.tasks.health_check.log") as mock_log: ), patch("app.utils.runtime_state.log") as mock_log:
await _run_probe(app) await _run_probe(app)
offline_calls = [c for c in mock_log.warning.call_args_list if c[0][0] == "fail2ban_went_offline"] offline_calls = [c for c in mock_log.warning.call_args_list if c[0][0] == "fail2ban_went_offline"]
@@ -109,7 +109,7 @@ class TestRunProbe:
"app.tasks.health_check.health_service.probe", "app.tasks.health_check.health_service.probe",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=new_status, return_value=new_status,
), patch("app.tasks.health_check.log") as mock_log: ), patch("app.utils.runtime_state.log") as mock_log:
await _run_probe(app) await _run_probe(app)
transition_calls = [ transition_calls = [
@@ -134,7 +134,7 @@ class TestRunProbe:
"app.tasks.health_check.health_service.probe", "app.tasks.health_check.health_service.probe",
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=new_status, return_value=new_status,
), patch("app.tasks.health_check.log") as mock_log: ), patch("app.utils.runtime_state.log") as mock_log:
await _run_probe(app) await _run_probe(app)
transition_calls = [ transition_calls = [
@@ -180,7 +180,7 @@ class TestRunProbe:
new_callable=AsyncMock, new_callable=AsyncMock,
return_value=new_status, return_value=new_status,
), ),
patch("app.tasks.health_check.log"), patch("app.utils.runtime_state.log"),
): ):
# Must not raise even with no prior status. # Must not raise even with no prior status.
await _run_probe(app) await _run_probe(app)

View File

@@ -1,8 +1,11 @@
from __future__ import annotations from __future__ import annotations
import datetime
from unittest.mock import MagicMock from unittest.mock import MagicMock
from app.config import Settings from app.config import Settings
from app.models.config import PendingRecovery
from app.models.server import ServerStatus
from app.utils.runtime_state import get_app_settings, get_effective_settings from app.utils.runtime_state import get_app_settings, get_effective_settings
@@ -45,3 +48,40 @@ def test_get_app_settings_reads_bootstrap_settings() -> None:
app = _FakeApp(_FakeState(settings=settings)) app = _FakeApp(_FakeState(settings=settings))
assert get_app_settings(app) is settings assert get_app_settings(app) is settings
def test_process_health_probe_result_creates_pending_recovery_within_window() -> None:
    """Going offline 30 s after an activation creates an unresolved record."""
    from app.utils.runtime_state import RuntimeState, process_health_probe_result

    probe_time = datetime.datetime.now(tz=datetime.UTC)
    activation = {
        "jail_name": "sshd",
        "at": probe_time - datetime.timedelta(seconds=30),
    }
    state = RuntimeState(
        server_status=ServerStatus(online=True),
        last_activation=activation,
        pending_recovery=None,
    )

    offline = ServerStatus(online=False)
    process_health_probe_result(state, offline, now=probe_time)

    assert state.pending_recovery is not None
    assert state.pending_recovery.jail_name == "sshd"
    assert state.pending_recovery.recovered is False
def test_process_health_probe_result_resolves_existing_pending_recovery() -> None:
    """Coming back online marks an unresolved pending recovery as recovered."""
    from app.utils.runtime_state import RuntimeState, process_health_probe_result

    activation_time = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=30)
    crash_time = activation_time + datetime.timedelta(seconds=10)
    probe_time = activation_time + datetime.timedelta(seconds=20)

    unresolved = PendingRecovery(
        jail_name="sshd",
        activated_at=activation_time,
        detected_at=crash_time,
        recovered=False,
    )
    state = RuntimeState(
        server_status=ServerStatus(online=False),
        pending_recovery=unresolved,
    )

    online = ServerStatus(online=True)
    process_health_probe_result(state, online, now=probe_time)

    assert state.pending_recovery is not None
    assert state.pending_recovery.recovered is True