Add Kubernetes liveness/readiness probes and middleware order validation
- Split /health into /health/live (liveness) and /health/ready (readiness) following Kubernetes conventions. Combined /health retained for backward compatibility with existing Docker HEALTHCHECK definitions. - Add ReadyCheck and ReadyResponse models for structured readiness output. - Add _assert_middleware_order() startup check enforcing: RateLimit → Csrf → CorrelationId middleware chain. - Register CorrelationIdMiddleware, CsrfMiddleware, RateLimitMiddleware in create_app() with documented required order (reverse of processing). - Add correlation.py, csrf.py, rate_limit.py middleware modules. - Add health probe tests in test_health_probes.py. - Update test_main.py with middleware order assertion tests. - Update frontend useFetchData hook tests. - Docs: update Deployment.md with Kubernetes probe config examples.
This commit is contained in:
@@ -12,7 +12,15 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError
|
||||
from app.main import CORSMiddleware, _enforce_single_worker, _lifespan, create_app
|
||||
from app.main import (
|
||||
CORSMiddleware,
|
||||
_assert_middleware_order,
|
||||
_enforce_single_worker,
|
||||
_lifespan,
|
||||
create_app,
|
||||
)
|
||||
from app.middleware.correlation import CorrelationIdMiddleware
|
||||
from app.middleware.rate_limit import RateLimitMiddleware
|
||||
from app.services import setup_service
|
||||
|
||||
|
||||
@@ -450,14 +458,23 @@ async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path:
|
||||
exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=load_cache))
|
||||
exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)))
|
||||
exit_stack.enter_context(patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=True)))
|
||||
exit_stack.enter_context(patch("app.services.setup_service.get_runtime_database_path", new=AsyncMock(return_value=runtime_db_path)))
|
||||
exit_stack.enter_context(patch("app.services.setup_service.get_persisted_runtime_settings", new=AsyncMock(return_value={
|
||||
"database_path": runtime_db_path,
|
||||
"fail2ban_socket": "/tmp/persisted.sock",
|
||||
"timezone": "Europe/Berlin",
|
||||
"session_duration_minutes": 123,
|
||||
})))
|
||||
exit_stack.enter_context(patch("app.services.setup_service.get_fail2ban_db_path", new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2")))
|
||||
exit_stack.enter_context(patch(
|
||||
"app.services.setup_service.get_runtime_database_path",
|
||||
new=AsyncMock(return_value=runtime_db_path),
|
||||
))
|
||||
exit_stack.enter_context(patch(
|
||||
"app.services.setup_service.get_persisted_runtime_settings",
|
||||
new=AsyncMock(return_value={
|
||||
"database_path": runtime_db_path,
|
||||
"fail2ban_socket": "/tmp/persisted.sock",
|
||||
"timezone": "Europe/Berlin",
|
||||
"session_duration_minutes": 123,
|
||||
}),
|
||||
))
|
||||
exit_stack.enter_context(patch(
|
||||
"app.services.setup_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2"),
|
||||
))
|
||||
exit_stack.enter_context(patch("app.tasks.health_check.register"))
|
||||
exit_stack.enter_context(patch("app.tasks.blocklist_import.register"))
|
||||
exit_stack.enter_context(patch("app.tasks.geo_cache_flush.register"))
|
||||
@@ -466,8 +483,9 @@ async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path:
|
||||
|
||||
with exit_stack:
|
||||
async with _lifespan(app):
|
||||
loaded_db_path = load_cache.call_args.args[0]
|
||||
runtime_connections = [conn for path, conn in opened_connections if path == runtime_db_path]
|
||||
runtime_connections = [
|
||||
conn for path, conn in opened_connections if path == runtime_db_path
|
||||
]
|
||||
assert runtime_connections, "Expected runtime database to be opened"
|
||||
|
||||
assert app.state.runtime_settings is not None
|
||||
@@ -538,6 +556,91 @@ async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: P
|
||||
assert all(connection.close.await_count == 1 for connection in connections)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Middleware order validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_settings(tmp_path: Path) -> Settings:
|
||||
"""Return a minimal Settings object with a temporary fail2ban config dir."""
|
||||
fail2ban_config_dir = tmp_path / "fail2ban"
|
||||
fail2ban_config_dir.mkdir()
|
||||
return Settings(
|
||||
database_path=str(tmp_path / "bangui.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
fail2ban_config_dir=str(fail2ban_config_dir),
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
|
||||
|
||||
def test_create_app_raises_on_incorrect_middleware_order(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""_assert_middleware_order() raises AssertionError when middleware order is wrong.
|
||||
|
||||
The security-critical chain requires:
|
||||
RateLimitMiddleware → CsrfMiddleware → CorrelationIdMiddleware
|
||||
in user_middleware (processing order: outermost → innermost).
|
||||
"""
|
||||
monkeypatch.setenv("TESTING", "1")
|
||||
settings = _make_settings(tmp_path)
|
||||
app = create_app(settings=settings)
|
||||
# Swap CorrelationIdMiddleware and RateLimitMiddleware to break the order.
|
||||
user_mw = app.user_middleware
|
||||
corr_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "CorrelationIdMiddleware")
|
||||
rate_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "RateLimitMiddleware")
|
||||
user_mw[corr_idx], user_mw[rate_idx] = user_mw[rate_idx], user_mw[corr_idx]
|
||||
with pytest.raises(AssertionError, match="must be registered before"):
|
||||
_assert_middleware_order(app)
|
||||
|
||||
|
||||
def test_middleware_order_validation_passes_for_correct_order(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""_assert_middleware_order() does not raise when middleware order is correct."""
|
||||
monkeypatch.setenv("TESTING", "1")
|
||||
settings = _make_settings(tmp_path)
|
||||
app = create_app(settings=settings)
|
||||
_assert_middleware_order(app) # Should not raise
|
||||
|
||||
|
||||
def test_create_app_validates_middleware_order_at_startup(
|
||||
tmp_path: Path,
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""create_app() raises immediately if middleware registration order is incorrect.
|
||||
|
||||
This test verifies the integration: _assert_middleware_order is called at the
|
||||
end of create_app, so a fresh app with deliberately wrong middleware order
|
||||
(simulated by patching add_middleware during creation) raises AssertionError.
|
||||
"""
|
||||
monkeypatch.setenv("TESTING", "1")
|
||||
settings = _make_settings(tmp_path)
|
||||
|
||||
from starlette.applications import Starlette
|
||||
|
||||
original_add = Starlette.add_middleware
|
||||
|
||||
def swapping_add(self, middleware_cls: type, **kwargs: object) -> None:
|
||||
"""Patched add_middleware that swaps CorrelationId and RateLimit."""
|
||||
if middleware_cls is CorrelationIdMiddleware:
|
||||
pass # Skip CorrelationId
|
||||
elif middleware_cls is RateLimitMiddleware:
|
||||
original_add(self, RateLimitMiddleware, **kwargs)
|
||||
original_add(self, CorrelationIdMiddleware)
|
||||
else:
|
||||
original_add(self, middleware_cls, **kwargs)
|
||||
|
||||
with patch.object(Starlette, "add_middleware", swapping_add), \
|
||||
pytest.raises(AssertionError, match="must be registered before"):
|
||||
create_app(settings=settings)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Single-worker enforcement
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
130
backend/tests/test_routers/test_health_probes.py
Normal file
130
backend/tests/test_routers/test_health_probes.py
Normal file
@@ -0,0 +1,130 @@
|
||||
"""Tests for the health-check router — liveness and readiness probes."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import AsyncClient
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.models.response import ReadyCheck
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /health/live — liveness probe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_returns_200(client: AsyncClient) -> None:
|
||||
"""``GET /api/v1/health/live`` must always return HTTP 200."""
|
||||
response = await client.get("/api/v1/health/live")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_body_is_ready_response(client: AsyncClient) -> None:
|
||||
"""Response body must be a ReadyResponse."""
|
||||
response = await client.get("/api/v1/health/live")
|
||||
data: dict[str, object] = response.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["failed_count"] == 0
|
||||
assert "checks" in data
|
||||
assert isinstance(data["checks"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_includes_process_check(client: AsyncClient) -> None:
|
||||
"""Liveness response must include a 'process' check."""
|
||||
response = await client.get("/api/v1/health/live")
|
||||
data: dict[str, object] = response.json()
|
||||
checks: list[dict[str, object]] = data["checks"] # type: ignore[assignment]
|
||||
assert any(c.get("name") == "process" and c.get("healthy") is True for c in checks)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /health/ready — readiness probe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_returns_200_when_all_pass(client: AsyncClient) -> None:
|
||||
"""``GET /api/v1/health/ready`` must return 200 when all subsystems pass."""
|
||||
with patch("app.routers.health._run_check", side_effect=lambda n, c, e: ReadyCheck(name=n, healthy=True)):
|
||||
response = await client.get("/api/v1/health/ready")
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_returns_503_when_subsystem_fails(client: AsyncClient) -> None:
|
||||
"""``GET /api/v1/health/ready`` must return 503 when at least one check fails."""
|
||||
# Force fail2ban offline
|
||||
client._transport.app.state.server_status = ServerStatus(online=False)
|
||||
response = await client.get("/api/v1/health/ready")
|
||||
assert response.status_code == 503
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_body_is_ready_response(client: AsyncClient) -> None:
|
||||
"""Response body must be a ReadyResponse."""
|
||||
response = await client.get("/api/v1/health/ready")
|
||||
data: dict[str, object] = response.json()
|
||||
assert data["status"] in ("ok", "error")
|
||||
assert "failed_count" in data
|
||||
assert "checks" in data
|
||||
assert isinstance(data["checks"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_includes_all_subsystems(client: AsyncClient) -> None:
|
||||
"""Readiness response must include checks for all four subsystems."""
|
||||
response = await client.get("/api/v1/health/ready")
|
||||
data: dict[str, object] = response.json()
|
||||
checks: list[dict[str, object]] = data["checks"] # type: ignore[assignment]
|
||||
names = {c["name"] for c in checks}
|
||||
assert names == {"database", "fail2ban", "config_dir", "scheduler"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_status_ok_when_all_healthy(client: AsyncClient) -> None:
|
||||
"""``status`` must be 'ok' when all checks pass."""
|
||||
with patch("app.routers.health._run_check", side_effect=lambda n, c, e: ReadyCheck(name=n, healthy=True)):
|
||||
response = await client.get("/api/v1/health/ready")
|
||||
data: dict[str, object] = response.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["failed_count"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_status_error_when_fail2ban_offline(client: AsyncClient) -> None:
|
||||
"""``status`` must be 'error' when fail2ban is offline."""
|
||||
client._transport.app.state.server_status = ServerStatus(online=False)
|
||||
response = await client.get("/api/v1/health/ready")
|
||||
data: dict[str, object] = response.json()
|
||||
assert data["status"] == "error"
|
||||
assert data["failed_count"] > 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_includes_failed_subsystem_detail(client: AsyncClient) -> None:
|
||||
"""When fail2ban is offline the fail2ban check must include an error message."""
|
||||
client._transport.app.state.server_status = ServerStatus(online=False)
|
||||
response = await client.get("/api/v1/health/ready")
|
||||
data: dict[str, object] = response.json()
|
||||
checks: list[dict[str, object]] = data["checks"] # type: ignore[assignment]
|
||||
f2b = next(c for c in checks if c["name"] == "fail2ban")
|
||||
assert f2b["healthy"] is False
|
||||
assert f2b["message"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_content_type_is_json(client: AsyncClient) -> None:
|
||||
"""``/api/v1/health/ready`` must set the ``Content-Type`` header to JSON."""
|
||||
response = await client.get("/api/v1/health/ready")
|
||||
assert "application/json" in response.headers.get("content-type", "")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_live_content_type_is_json(client: AsyncClient) -> None:
|
||||
"""``/api/v1/health/live`` must set the ``Content-Type`` header to JSON."""
|
||||
response = await client.get("/api/v1/health/live")
|
||||
assert "application/json" in response.headers.get("content-type", "")
|
||||
Reference in New Issue
Block a user