Add Kubernetes liveness/readiness probes and middleware order validation

- Split /health into /health/live (liveness) and /health/ready (readiness)
  following Kubernetes conventions. Combined /health retained for backward
  compatibility with existing Docker HEALTHCHECK definitions.
- Add ReadyCheck and ReadyResponse models for structured readiness output.
- Add _assert_middleware_order() startup check enforcing:
  RateLimit → Csrf → CorrelationId middleware chain.
- Register CorrelationIdMiddleware, CsrfMiddleware, RateLimitMiddleware
  in create_app() with documented required order (reverse of processing).
- Add correlation.py, csrf.py, rate_limit.py middleware modules.
- Add health probe tests in test_health_probes.py.
- Update test_main.py with middleware order assertion tests.
- Update frontend useFetchData hook tests.
- Docs: update Deployment.md with Kubernetes probe config examples.
2026-05-04 02:42:09 +02:00
parent 65fe747cba
commit eb339efcfd
13 changed files with 882 additions and 129 deletions

View File

@@ -78,7 +78,12 @@ During rolling deployments:
## Health Checks
-The backend container includes a health check endpoint at `GET /api/v1/health` that reports application and component status:
+The backend container includes **three** health check endpoints:
### Combined Health Check — `GET /api/v1/health`
Reports application and component status for Docker HEALTHCHECK and legacy
monitoring integration:
- **HTTP 200** with `{"status": "ok", ...}` — all components healthy
- **HTTP 200** with `{"status": "degraded", ...}` — some components unhealthy (e.g., database error) but fail2ban reachable
@@ -93,6 +98,59 @@ The backend container includes a health check endpoint at `GET /api/v1/health` t
| scheduler | `scheduler.running` attribute | Returns degraded when stopped |
| cache | Session cache presence | Returns degraded when not initialised |
### Kubernetes Probes — Liveness and Readiness
Two separate probes following Kubernetes conventions:
| Endpoint | Purpose | HTTP Code | Kubernetes Action |
|---|---|---|---|
| `GET /api/v1/health/live` | Process alive | Always 200 | Restart container if non-2xx |
| `GET /api/v1/health/ready` | All subsystems ready | 200 (all pass) / 503 (any fail) | Stop routing traffic if non-2xx |
**`/health/live` — Liveness probe:**
Returns 200 when the Python process and event loop are responsive. No subsystem checks are performed — this endpoint is always fast. Use for Kubernetes `livenessProbe`.
**`/health/ready` — Readiness probe:**
Verifies all critical sub-systems are reachable before routing traffic. Returns 200 only when all pass; returns 503 with a JSON body listing every failed check otherwise.
| Subsystem | Check | Timeout |
|---|---|---|
| database | Opens and closes a test connection | 2 s |
| fail2ban | Socket reachability via cached server status | N/A (instant) |
| config_dir | Config directory read access (`os.R_OK`) | 2 s |
| scheduler | `scheduler.running` attribute | N/A (instant) |
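The per-check timeouts in the table above can be sketched with `asyncio.wait_for`. This is a minimal illustration rather than the project's code: `CheckResult`, `run_with_timeout`, and the demo probes are hypothetical names, and the 2 s default is assumed to mirror the table.

```python
import asyncio
from dataclasses import dataclass
from typing import Optional

CHECK_TIMEOUT_SECONDS = 2.0  # assumption: mirrors the 2 s budget in the table


@dataclass
class CheckResult:
    """Hypothetical stand-in for a single readiness check result."""

    name: str
    healthy: bool
    message: Optional[str] = None


async def run_with_timeout(name, coro, timeout=CHECK_TIMEOUT_SECONDS):
    """Await coro; report a failed check if it raises or exceeds timeout."""
    try:
        await asyncio.wait_for(coro, timeout=timeout)
    except asyncio.TimeoutError:
        return CheckResult(name, False, f"timed out after {timeout} s")
    except Exception as exc:  # any other failure marks the check unhealthy
        return CheckResult(name, False, str(exc))
    return CheckResult(name, True)


async def demo():
    async def fast_probe():
        await asyncio.sleep(0)

    async def slow_probe():
        await asyncio.sleep(1)

    # Checks run concurrently, so the endpoint waits for the slowest one only.
    return await asyncio.gather(
        run_with_timeout("database", fast_probe()),
        run_with_timeout("config_dir", slow_probe(), timeout=0.05),
    )


results = asyncio.run(demo())
```

Bounding each check separately keeps one hanging subsystem from delaying the whole readiness response beyond its own budget.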
**Readiness response example (all healthy — HTTP 200):**
```json
{
"status": "ok",
"checks": [
{"name": "database", "healthy": true},
{"name": "fail2ban", "healthy": true},
{"name": "config_dir", "healthy": true},
{"name": "scheduler", "healthy": true}
],
"failed_count": 0
}
```
**Readiness response example (fail2ban offline — HTTP 503):**
```json
{
"status": "error",
"checks": [
{"name": "database", "healthy": true},
{"name": "fail2ban", "healthy": false, "message": "Socket not reachable"},
{"name": "config_dir", "healthy": true},
{"name": "scheduler", "healthy": true}
],
"failed_count": 1
}
```
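A monitoring script could pull the failing subsystems out of a readiness body like the ones above. The sketch below assumes only the JSON shape shown in the examples; `failed_checks` is a hypothetical helper, not part of the backend.

```python
import json


def failed_checks(body: str) -> list:
    """Return the names of all checks reported as unhealthy."""
    payload = json.loads(body)
    return [check["name"] for check in payload["checks"] if not check["healthy"]]


# Sample body copied from the HTTP 503 example above.
SAMPLE_503 = """
{
  "status": "error",
  "checks": [
    {"name": "database", "healthy": true},
    {"name": "fail2ban", "healthy": false, "message": "Socket not reachable"},
    {"name": "config_dir", "healthy": true},
    {"name": "scheduler", "healthy": true}
  ],
  "failed_count": 1
}
"""

failing = failed_checks(SAMPLE_503)
```

The length of the returned list should equal `failed_count` for any well-formed response.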
**Why separate liveness and readiness?**
Liveness (`/health/live`) must be cheap — a slow or hanging liveness probe causes Kubernetes to restart a perfectly healthy container. Readiness (`/health/ready`) can afford to check sub-systems because traffic is only held back temporarily while a pod recovers.
**Docker Health Check:**
The Dockerfile includes a HEALTHCHECK that queries the endpoint. Docker interprets HTTP 503 as unhealthy and restarts the container after 3 consecutive failures (90 seconds by default).
@@ -739,9 +797,9 @@ sqlite3 /data/bangui.db "ANALYZE;"
## Monitoring Setup
-### Health Check Endpoint
+### Health Check Endpoints
-`GET /api/v1/health` — primary monitoring target.
+**Combined health check** — `GET /api/v1/health` — primary monitoring target for Docker HEALTHCHECK.
| Status | HTTP Code | Meaning |
|--------|-----------|---------|
@@ -749,6 +807,17 @@ sqlite3 /data/bangui.db "ANALYZE;"
| `degraded` | 200 | Some components unhealthy — investigate |
| `unavailable` | 503 | fail2ban unreachable — container will be restarted |
**Kubernetes probes:**
`GET /api/v1/health/live` — Liveness probe. Always returns 200 if the process is alive.
`GET /api/v1/health/ready` — Readiness probe. Returns 200 when all subsystems pass, 503 otherwise.
| Probe | URL | Success | Failure |
|-------|---|---------|---------|
| Liveness | `/api/v1/health/live` | 200 | Non-2xx → restart |
| Readiness | `/api/v1/health/ready` | 200 | Non-2xx → stop traffic |
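The probe table can be read as a small decision function. This is only a sketch of the table's rule; `expected_action` is a hypothetical name, and it follows the table's 2xx wording even though Kubernetes itself counts any status from 200 to 399 as success and acts only after its configured failure threshold.

```python
def expected_action(probe: str, status_code: int) -> str:
    """Map a probe type and HTTP status to the action described in the table."""
    if 200 <= status_code < 300:
        return "none"
    if probe == "liveness":
        return "restart container"
    if probe == "readiness":
        return "stop routing traffic"
    raise ValueError(f"unknown probe type: {probe}")


live_ok = expected_action("liveness", 200)
ready_failed = expected_action("readiness", 503)
```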
### Structured Logging
All logs are structured (JSON via structlog). Key fields:

View File

@@ -1,84 +1,3 @@
### Issue #56: MEDIUM - No API Versioning or Deprecation Strategy
**Where found**:
- All backend routers register under `/api/v1/` prefix but no versioning mechanism exists
**Why this is needed**:
Breaking backend changes immediately break all frontend clients. Without a deprecation path, there is no safe way to evolve the API.
**Goal**:
Define and implement an API lifecycle policy.
**What to do**:
1. Document the versioning strategy (URL versioning is already in place; formalize it).
2. Add a `Deprecation` response header to endpoints scheduled for removal.
3. Implement a `/api/v2/` prefix for the next breaking change cycle.
4. Add a CI check that flags new breaking changes against the OpenAPI spec.
**Possible traps and issues**:
- Running two API versions simultaneously doubles maintenance surface; set a sunset date policy.
**Docs changes needed**:
- `Docs/`: create `API_VERSIONING.md` documenting the lifecycle and deprecation process.
**Doc references**:
- All router files under `backend/app/routers/`
---
### Issue #57: MEDIUM - Health Endpoint Does Not Check Subsystems
**Where found**:
- `backend/app/routers/health.py`
**Why this is needed**:
A process that is running but cannot reach the fail2ban socket, database, or config directory still returns `200 OK`. Load balancers and orchestrators treat it as healthy and route traffic to it, causing silent failures.
**Goal**:
Health endpoint reflects true readiness of all critical subsystems.
**What to do**:
1. Add a structured health check that tests: database connectivity, fail2ban socket accessibility, config directory read access, scheduler liveness.
2. Return `200` only when all checks pass; return `503` with a JSON body listing failed checks otherwise.
3. Expose a separate `/health/live` (process alive) and `/health/ready` (subsystems ready) endpoint for Kubernetes probes.
**Possible traps and issues**:
- Slow health checks (e.g., DB connect timeout) can overwhelm the endpoint under load; set short timeouts per check.
**Docs changes needed**:
- `Docs/Deployment.md`: document liveness vs readiness probe URLs.
**Doc references**:
- `backend/app/routers/health.py`
---
### Issue #58: MEDIUM - Abort Signal Not Propagated in Request Deduplication
**Where found**:
- `frontend/src/hooks/useFetchData.ts:93-113`
**Why this is needed**:
When multiple hook instances share a `requestKey`, they await a single in-flight promise. When one component unmounts and aborts its signal, the shared request continues and calls `setData()` / `onSuccess()` on the unmounted component, causing React "state update on unmounted component" warnings and memory leaks.
**Goal**:
Unmounting a component that joined a deduplicated request must not receive the result.
**What to do**:
1. In the deduplication await path, check the component's own abort signal before calling `setData()` or `onSuccess()`.
2. Wrap the deduplication subscriber list so each subscriber can individually opt out on abort.
**Possible traps and issues**:
- If all subscribers abort before the request resolves, consider whether the underlying request should also be cancelled.
**Docs changes needed**:
- `frontend/src/hooks/README.md`: document abort signal contract for deduplicated requests.
**Doc references**:
- `frontend/src/hooks/README.md`
---
### Issue #59: MEDIUM - Middleware Registration Order Not Validated at Startup
**Where found**:

View File

@@ -802,7 +802,9 @@ async def _request_validation_error_handler(
_EXACT_ALLOWED: frozenset[str] = frozenset(
    {
        "/api/v1/setup",  # GET/POST /api/v1/setup
-        "/api/v1/health",  # Health check endpoint
+        "/api/v1/health",  # Health check endpoint (combined)
+        "/api/v1/health/live",  # Kubernetes liveness probe
+        "/api/v1/health/ready",  # Kubernetes readiness probe
        "/api/docs",  # Swagger UI
        "/api/redoc",  # ReDoc
        "/api/openapi.json",  # OpenAPI schema
@@ -988,6 +990,48 @@ def _enforce_single_worker() -> None:
# ---------------------------------------------------------------------------


def _assert_middleware_order(app: FastAPI) -> None:
    """Assert required middleware order at startup.

    Positions are indices into ``app.user_middleware``. Starlette's
    ``add_middleware`` inserts at index 0, so index 0 is the most recently
    registered (outermost) middleware.

    Raises:
        AssertionError: If middleware are not in the required order.
    """
    registered = [m.cls.__name__ for m in app.user_middleware]

    # Find positions; skip middleware not in the security-critical chain
    order: tuple[str, ...] = (
        "RateLimitMiddleware",
        "CsrfMiddleware",
        "CorrelationIdMiddleware",
    )
    positions = {name: registered.index(name) for name in order if name in registered}

    # RateLimitMiddleware must be before CsrfMiddleware
    if (
        "RateLimitMiddleware" in positions
        and "CsrfMiddleware" in positions
        and positions["RateLimitMiddleware"] > positions["CsrfMiddleware"]
    ):
        raise AssertionError(
            f"Middleware order violation: RateLimitMiddleware (position {positions['RateLimitMiddleware']}) "
            f"must be registered before CsrfMiddleware (position {positions['CsrfMiddleware']}). "
            f"Current order: {registered}"
        )

    # CsrfMiddleware must be before CorrelationIdMiddleware
    if (
        "CsrfMiddleware" in positions
        and "CorrelationIdMiddleware" in positions
        and positions["CsrfMiddleware"] > positions["CorrelationIdMiddleware"]
    ):
        raise AssertionError(
            f"Middleware order violation: CsrfMiddleware (position {positions['CsrfMiddleware']}) "
            f"must be registered before CorrelationIdMiddleware (position {positions['CorrelationIdMiddleware']}). "
            f"Current order: {registered}"
        )


def create_app(settings: Settings | None = None) -> FastAPI:
    """Create and configure the BanGUI FastAPI application.
@@ -1066,11 +1110,18 @@ def create_app(settings: Settings | None = None) -> FastAPI:
    )

    # --- Middleware ---
-    # Note: middleware is applied in reverse order of registration.
-    # SecurityHeadersMiddleware must run early but after CORS/CSRF so headers
-    # are added to all responses including error responses.
-    # CorrelationIdMiddleware must run first (added last) so correlation ID
-    # is available to all downstream handlers and loggers.
+    # Note: Starlette applies middleware in reverse order of registration
+    # (last registered = outermost; first to see request, last to see response).
+    #
+    # Required processing order (outermost → innermost):
+    #   1. CorrelationIdMiddleware  generates/extracts correlation ID first
+    #   2. CsrfMiddleware           CSRF validation after correlation ID is available
+    #   3. RateLimitMiddleware      rate limiting last (needs correlation ID for logging)
+    #
+    # This requires registration order (reverse of processing):
+    #   1. RateLimitMiddleware      (registered first → innermost for responses)
+    #   2. CsrfMiddleware
+    #   3. CorrelationIdMiddleware  (registered last → outermost for requests)
    app.add_middleware(CorrelationIdMiddleware)
    app.add_middleware(SecurityHeadersMiddleware)
    app.add_middleware(SetupRedirectMiddleware)
@@ -1083,6 +1134,11 @@ def create_app(settings: Settings | None = None) -> FastAPI:
        settings=resolved_settings,
    )

    # Validate middleware order before returning the app.
    # Raising loud errors at startup is intentional — a misconfigured middleware
    # stack is a security-critical defect that must not slip through silently.
    _assert_middleware_order(app)

    # --- Exception handlers ---
    #

View File

@@ -11,6 +11,18 @@ Correlation IDs flow through the request lifecycle:
3. Middleware stores in structlog.contextvars
4. All log entries include the correlation ID automatically
5. Error responses include the correlation ID for client-side correlation

Processing order
-----------------
This middleware must be the outermost in the security-critical chain so it
executes first on incoming requests (outermost = first to see request,
last to see response). In the required chain:

    CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware

The registration order in ``main.py`` must be:

    RateLimitMiddleware, CsrfMiddleware, CorrelationIdMiddleware

(last registered = outermost in Starlette's reverse application).
"""

from __future__ import annotations

View File

@@ -9,6 +9,16 @@ is not CSRF-vulnerable. GET, HEAD, and OPTIONS requests are also exempt.
Cross-site requests cannot set custom headers without CORS preflight, which the
backend rejects for non-allowed origins, providing defense-in-depth.

Processing order
----------------
This middleware must be the middle component in the security-critical chain:

    CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware

It runs after CorrelationIdMiddleware has attached a correlation ID (so CSRF
rejections can include it in their log context), and before RateLimitMiddleware
(so rate-limit counters are only incremented for requests that pass CSRF checks).
"""

from __future__ import annotations

View File

@@ -20,6 +20,16 @@ scheduler lock). The startup warning log documents this constraint.
Redis-backed adapter that uses atomic INCR + EXPIRE semantics. The
check_allowed() and check_allowed_for_bucket() interfaces are designed
to make this swap-in without touching middleware or router code.

Processing order
----------------
This middleware must be the innermost in the security-critical chain:

    CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware

Rate limiting is last so that requests blocked by CsrfMiddleware do not
consume rate-limit budget, and so that rate-limit log entries (which are
unusual and potentially suspicious) always carry a correlation ID for tracing.
"""

from __future__ import annotations
View File

@@ -480,3 +480,53 @@ class FlushLogsResponse(BanGuiBaseModel):
    """

    message: str = Field(..., description="Human-readable result message from fail2ban.")


class ReadyCheck(BanGuiBaseModel):
    """Result of a single readiness subsystem check.

    Fields:
        name: Subsystem name (e.g., "database", "fail2ban", "config_dir").
        healthy: True when the subsystem is reachable/operational.
        message: Optional error message describing the failure.
    """

    name: str = Field(..., description="Subsystem name.")
    healthy: bool = Field(..., description="True when the subsystem is operational.")
    message: str | None = Field(
        default=None,
        description="Error detail when the check fails.",
    )


class ReadyResponse(BanGuiBaseModel):
    """Structured readiness check response for the ``/health/ready`` endpoint.

    Fields:
        status: "ok" when all checks pass, "error" when at least one failed.
        checks: Per-subsystem result list.
        failed_count: Number of checks that returned healthy=False.

    Example:
        ```python
        # All healthy (HTTP 200)
        {"status": "ok", "checks": [...], "failed_count": 0}
        # Some failed (HTTP 503)
        {"status": "error", "checks": [...], "failed_count": 2}
        ```
    """

    status: Literal["ok", "error"] = Field(
        ...,
        description="'ok' when all checks pass, 'error' when at least one fails.",
    )
    checks: list[ReadyCheck] = Field(
        default_factory=list,
        description="Per-subsystem check results.",
    )
    failed_count: int = Field(
        ...,
        ge=0,
        description="Number of checks that returned healthy=False.",
    )

View File

@@ -1,27 +1,36 @@
"""Health check router.

-A lightweight ``GET /api/v1/health`` endpoint that verifies the application
-is running and can serve requests. Also reports the cached fail2ban liveness
-state so monitoring tools and Docker health checks can observe daemon status
-without probing the socket directly.
-
-Comprehensive checks performed:
-- Database connectivity
-- fail2ban socket reachability (via cached server_status)
-- Background scheduler health
-- Session cache initialization
+Two distinct probes following Kubernetes conventions:
+
+* ``GET /api/v1/health/live`` — **Liveness** — checks that the Python process is
+  alive and the event loop is responsive. Always returns 200; a non-2xx answer
+  tells Kubernetes to *restart* the container.
+
+* ``GET /api/v1/health/ready`` — **Readiness** — checks that all critical
+  sub-systems (database, fail2ban socket, config directory, scheduler) are
+  reachable. Returns 200 only when all pass; returns 503 with a JSON body
+  listing every failed check otherwise. A non-2xx answer tells Kubernetes to
+  *stop routing traffic* to the pod until it recovers.
+
+The combined ``GET /api/v1/health`` endpoint is retained for backward
+compatibility with existing Docker HEALTHCHECK definitions.
"""

from __future__ import annotations

-from typing import Annotated, Literal
+import asyncio
+import os
+from typing import TYPE_CHECKING, Literal

import structlog
from fastapi import APIRouter, status
from fastapi.responses import JSONResponse

from app.dependencies import AppStateDep, ServerStatusDep
-from app.models.response import ComponentHealth, HealthResponse
+from app.models.response import ComponentHealth, HealthResponse, ReadyCheck, ReadyResponse
+
+if TYPE_CHECKING:
+    from collections.abc import Coroutine

router: APIRouter = APIRouter(prefix="/api/v1/health", tags=["Health"])
@@ -142,3 +151,164 @@ async def health_check(
            components=components,
        ).model_dump(),
    )


# --- Constants for subsystem checks ------------------------------------------ #

SUBSYSTEM_TIMEOUT_SECONDS: float = 2.0


# --- Helper: await a check coroutine with a short timeout --------------------- #


async def _run_check(
    name: str,
    coro: Coroutine[object, object, None],
    error_msg: str,
) -> ReadyCheck:
    """Run *coro* with a short timeout and return a ReadyCheck."""
    try:
        await asyncio.wait_for(coro, timeout=SUBSYSTEM_TIMEOUT_SECONDS)
        return ReadyCheck(name=name, healthy=True)
    except Exception as exc:  # noqa: BLE001 - any failure, including a timeout, marks the check unhealthy
        log.warning("ready_check_failed", subsystem=name, error=str(exc))
        return ReadyCheck(name=name, healthy=False, message=f"{error_msg}: {exc}")


# --- Liveness probe ---------------------------------------------------------- #


@router.get(
    "/live",
    summary="Process liveness probe",
    response_model=ReadyResponse,
    responses={
        200: {"description": "Process is alive"},
    },
)
async def liveness_probe() -> JSONResponse:
    """Lightweight liveness check for Kubernetes.

    Returns 200 when the Python process and event loop are responsive.
    A non-2xx response tells Kubernetes to restart the container.
    No subsystem checks are performed — this endpoint must be fast.
    """
    return JSONResponse(
        status_code=status.HTTP_200_OK,
        content=ReadyResponse(
            status="ok",
            checks=[ReadyCheck(name="process", healthy=True)],
            failed_count=0,
        ).model_dump(),
    )


# --- Readiness probe --------------------------------------------------------- #


async def _check_database(app_state: AppStateDep) -> ReadyCheck:
    """Check database connectivity with a short timeout."""
    from app.config import Settings
    from app.db import open_db

    effective_settings: Settings = (
        app_state.runtime_settings if app_state.runtime_settings is not None else app_state.settings
    )

    async def _probe() -> None:
        test_db = await open_db(effective_settings.database_path)
        await test_db.close()

    return await _run_check(
        "database",
        _probe(),
        "Connection failed",
    )


async def _check_fail2ban(app_state: AppStateDep, server_status: ServerStatusDep) -> ReadyCheck:
    """Check fail2ban socket reachability using the cached server status."""
    if server_status.online:
        return ReadyCheck(name="fail2ban", healthy=True)
    return ReadyCheck(name="fail2ban", healthy=False, message="Socket not reachable")


async def _check_config_dir(app_state: AppStateDep) -> ReadyCheck:
    """Check config directory read access."""
    from app.config import Settings

    effective_settings: Settings = (
        app_state.runtime_settings if app_state.runtime_settings is not None else app_state.settings
    )

    async def _probe() -> None:
        config_path = effective_settings.fail2ban_config_dir
        # os.access() returns a bool instead of raising, so turn a negative
        # result into an exception that _run_check reports as a failed check.
        readable = await asyncio.to_thread(os.access, config_path, os.R_OK)
        if not readable:
            raise PermissionError(f"{config_path} is missing or not readable")

    return await _run_check(
        "config_dir",
        _probe(),
        "Config directory not readable",
    )


async def _check_scheduler(app_state: AppStateDep) -> ReadyCheck:
    """Check scheduler liveness."""
    try:
        scheduler = app_state.scheduler
        if scheduler is not None and getattr(scheduler, "running", False):
            return ReadyCheck(name="scheduler", healthy=True)
        elif scheduler is not None:
            return ReadyCheck(name="scheduler", healthy=False, message="Scheduler stopped")
        else:
            return ReadyCheck(name="scheduler", healthy=False, message="Not initialised")
    except AttributeError:
        return ReadyCheck(name="scheduler", healthy=False, message="Not accessible")


@router.get(
    "/ready",
    summary="Subsystem readiness probe",
    response_model=ReadyResponse,
    responses={
        200: {"description": "All subsystems healthy"},
        503: {"description": "One or more subsystems unreachable"},
    },
)
async def readiness_probe(
    app_state: AppStateDep,
    server_status: ServerStatusDep,
) -> JSONResponse:
    """Readiness check for Kubernetes.

    Verifies all critical sub-systems are reachable:
    - Database connectivity
    - fail2ban socket (via cached server status)
    - Config directory read access
    - Background scheduler liveness

    Returns HTTP 200 only when every check passes; returns HTTP 503 with a
    JSON body listing every failed subsystem otherwise. Each check has a
    short per-subsystem timeout to prevent the endpoint from overwhelming the
    system under load.
    """
    db_check, f2b_check, config_check, sched_check = await asyncio.gather(
        _check_database(app_state),
        _check_fail2ban(app_state, server_status),
        _check_config_dir(app_state),
        _check_scheduler(app_state),
    )

    checks: list[ReadyCheck] = [db_check, f2b_check, config_check, sched_check]
    failed_count = sum(1 for c in checks if not c.healthy)
    http_status = status.HTTP_200_OK if failed_count == 0 else status.HTTP_503_SERVICE_UNAVAILABLE

    return JSONResponse(
        status_code=http_status,
        content=ReadyResponse(
            status="ok" if failed_count == 0 else "error",
            checks=checks,
            failed_count=failed_count,
        ).model_dump(),
    )

View File

@@ -12,7 +12,15 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError
-from app.main import CORSMiddleware, _enforce_single_worker, _lifespan, create_app
+from app.main import (
+    CORSMiddleware,
+    _assert_middleware_order,
+    _enforce_single_worker,
+    _lifespan,
+    create_app,
+)
+from app.middleware.correlation import CorrelationIdMiddleware
+from app.middleware.rate_limit import RateLimitMiddleware
from app.services import setup_service
@@ -450,14 +458,23 @@ async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path:
    exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=load_cache))
    exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)))
    exit_stack.enter_context(patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=True)))
-    exit_stack.enter_context(patch("app.services.setup_service.get_runtime_database_path", new=AsyncMock(return_value=runtime_db_path)))
-    exit_stack.enter_context(patch("app.services.setup_service.get_persisted_runtime_settings", new=AsyncMock(return_value={
-        "database_path": runtime_db_path,
-        "fail2ban_socket": "/tmp/persisted.sock",
-        "timezone": "Europe/Berlin",
-        "session_duration_minutes": 123,
-    })))
-    exit_stack.enter_context(patch("app.services.setup_service.get_fail2ban_db_path", new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2")))
+    exit_stack.enter_context(patch(
+        "app.services.setup_service.get_runtime_database_path",
+        new=AsyncMock(return_value=runtime_db_path),
+    ))
+    exit_stack.enter_context(patch(
+        "app.services.setup_service.get_persisted_runtime_settings",
+        new=AsyncMock(return_value={
+            "database_path": runtime_db_path,
+            "fail2ban_socket": "/tmp/persisted.sock",
+            "timezone": "Europe/Berlin",
+            "session_duration_minutes": 123,
+        }),
+    ))
+    exit_stack.enter_context(patch(
+        "app.services.setup_service.get_fail2ban_db_path",
+        new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2"),
+    ))
    exit_stack.enter_context(patch("app.tasks.health_check.register"))
    exit_stack.enter_context(patch("app.tasks.blocklist_import.register"))
    exit_stack.enter_context(patch("app.tasks.geo_cache_flush.register"))
@@ -466,8 +483,9 @@ async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path:
    with exit_stack:
        async with _lifespan(app):
-            loaded_db_path = load_cache.call_args.args[0]
-            runtime_connections = [conn for path, conn in opened_connections if path == runtime_db_path]
+            runtime_connections = [
+                conn for path, conn in opened_connections if path == runtime_db_path
+            ]
            assert runtime_connections, "Expected runtime database to be opened"
            assert app.state.runtime_settings is not None
@@ -538,6 +556,91 @@ async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: P
assert all(connection.close.await_count == 1 for connection in connections) assert all(connection.close.await_count == 1 for connection in connections)
# ---------------------------------------------------------------------------
# Middleware order validation
# ---------------------------------------------------------------------------
def _make_settings(tmp_path: Path) -> Settings:
"""Return a minimal Settings object with a temporary fail2ban config dir."""
fail2ban_config_dir = tmp_path / "fail2ban"
fail2ban_config_dir.mkdir()
return Settings(
database_path=str(tmp_path / "bangui.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
fail2ban_config_dir=str(fail2ban_config_dir),
session_secret="test-secret-key-do-not-use-in-production",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
def test_create_app_raises_on_incorrect_middleware_order(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""_assert_middleware_order() raises AssertionError when middleware order is wrong.
The security-critical chain requires:
RateLimitMiddleware → CsrfMiddleware → CorrelationIdMiddleware
in user_middleware (processing order: outermost → innermost).
"""
monkeypatch.setenv("TESTING", "1")
settings = _make_settings(tmp_path)
app = create_app(settings=settings)
# Swap CorrelationIdMiddleware and RateLimitMiddleware to break the order.
user_mw = app.user_middleware
corr_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "CorrelationIdMiddleware")
rate_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "RateLimitMiddleware")
user_mw[corr_idx], user_mw[rate_idx] = user_mw[rate_idx], user_mw[corr_idx]
with pytest.raises(AssertionError, match="must be registered before"):
_assert_middleware_order(app)
def test_middleware_order_validation_passes_for_correct_order(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""_assert_middleware_order() does not raise when middleware order is correct."""
monkeypatch.setenv("TESTING", "1")
settings = _make_settings(tmp_path)
app = create_app(settings=settings)
_assert_middleware_order(app) # Should not raise


def test_create_app_validates_middleware_order_at_startup(
    tmp_path: Path,
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """create_app() raises immediately if middleware registration order is incorrect.

    This test verifies the integration: _assert_middleware_order is called at the
    end of create_app, so a fresh app with deliberately wrong middleware order
    (simulated by patching add_middleware during creation) raises AssertionError.
    """
    monkeypatch.setenv("TESTING", "1")
    settings = _make_settings(tmp_path)

    from starlette.applications import Starlette

    original_add = Starlette.add_middleware

    def swapping_add(self, middleware_cls: type, **kwargs: object) -> None:
        """Patched add_middleware that swaps CorrelationId and RateLimit."""
        if middleware_cls is CorrelationIdMiddleware:
            pass  # Skip CorrelationId
        elif middleware_cls is RateLimitMiddleware:
            original_add(self, RateLimitMiddleware, **kwargs)
            original_add(self, CorrelationIdMiddleware)
        else:
            original_add(self, middleware_cls, **kwargs)

    with patch.object(Starlette, "add_middleware", swapping_add), \
         pytest.raises(AssertionError, match="must be registered before"):
        create_app(settings=settings)


# ---------------------------------------------------------------------------
# Single-worker enforcement
# ---------------------------------------------------------------------------


@@ -0,0 +1,130 @@
"""Tests for the health-check router — liveness and readiness probes."""

from unittest.mock import MagicMock, patch

import pytest
from httpx import AsyncClient

from app.models.server import ServerStatus
from app.models.response import ReadyCheck

# ---------------------------------------------------------------------------
# GET /health/live — liveness probe
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_liveness_returns_200(client: AsyncClient) -> None:
    """``GET /api/v1/health/live`` must always return HTTP 200."""
    response = await client.get("/api/v1/health/live")
    assert response.status_code == 200


@pytest.mark.asyncio
async def test_liveness_body_is_ready_response(client: AsyncClient) -> None:
    """Response body must be a ReadyResponse."""
    response = await client.get("/api/v1/health/live")
    data: dict[str, object] = response.json()
    assert data["status"] == "ok"
    assert data["failed_count"] == 0
    assert "checks" in data
    assert isinstance(data["checks"], list)


@pytest.mark.asyncio
async def test_liveness_includes_process_check(client: AsyncClient) -> None:
    """Liveness response must include a 'process' check."""
    response = await client.get("/api/v1/health/live")
    data: dict[str, object] = response.json()
    checks: list[dict[str, object]] = data["checks"]  # type: ignore[assignment]
    assert any(c.get("name") == "process" and c.get("healthy") is True for c in checks)


# ---------------------------------------------------------------------------
# GET /health/ready — readiness probe
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_readiness_returns_200_when_all_pass(client: AsyncClient) -> None:
    """``GET /api/v1/health/ready`` must return 200 when all subsystems pass."""
    with patch("app.routers.health._run_check", side_effect=lambda n, c, e: ReadyCheck(name=n, healthy=True)):
        response = await client.get("/api/v1/health/ready")
    assert response.status_code == 200


@pytest.mark.asyncio
async def test_readiness_returns_503_when_subsystem_fails(client: AsyncClient) -> None:
    """``GET /api/v1/health/ready`` must return 503 when at least one check fails."""
    # Force fail2ban offline
    client._transport.app.state.server_status = ServerStatus(online=False)
    response = await client.get("/api/v1/health/ready")
    assert response.status_code == 503


@pytest.mark.asyncio
async def test_readiness_body_is_ready_response(client: AsyncClient) -> None:
    """Response body must be a ReadyResponse."""
    response = await client.get("/api/v1/health/ready")
    data: dict[str, object] = response.json()
    assert data["status"] in ("ok", "error")
    assert "failed_count" in data
    assert "checks" in data
    assert isinstance(data["checks"], list)


@pytest.mark.asyncio
async def test_readiness_includes_all_subsystems(client: AsyncClient) -> None:
    """Readiness response must include checks for all four subsystems."""
    response = await client.get("/api/v1/health/ready")
    data: dict[str, object] = response.json()
    checks: list[dict[str, object]] = data["checks"]  # type: ignore[assignment]
    names = {c["name"] for c in checks}
    assert names == {"database", "fail2ban", "config_dir", "scheduler"}


@pytest.mark.asyncio
async def test_readiness_status_ok_when_all_healthy(client: AsyncClient) -> None:
    """``status`` must be 'ok' when all checks pass."""
    with patch("app.routers.health._run_check", side_effect=lambda n, c, e: ReadyCheck(name=n, healthy=True)):
        response = await client.get("/api/v1/health/ready")
    data: dict[str, object] = response.json()
    assert data["status"] == "ok"
    assert data["failed_count"] == 0


@pytest.mark.asyncio
async def test_readiness_status_error_when_fail2ban_offline(client: AsyncClient) -> None:
    """``status`` must be 'error' when fail2ban is offline."""
    client._transport.app.state.server_status = ServerStatus(online=False)
    response = await client.get("/api/v1/health/ready")
    data: dict[str, object] = response.json()
    assert data["status"] == "error"
    assert data["failed_count"] > 0


@pytest.mark.asyncio
async def test_readiness_includes_failed_subsystem_detail(client: AsyncClient) -> None:
    """When fail2ban is offline the fail2ban check must include an error message."""
    client._transport.app.state.server_status = ServerStatus(online=False)
    response = await client.get("/api/v1/health/ready")
    data: dict[str, object] = response.json()
    checks: list[dict[str, object]] = data["checks"]  # type: ignore[assignment]
    f2b = next(c for c in checks if c["name"] == "fail2ban")
    assert f2b["healthy"] is False
    assert f2b["message"] is not None


@pytest.mark.asyncio
async def test_readiness_content_type_is_json(client: AsyncClient) -> None:
    """``/api/v1/health/ready`` must set the ``Content-Type`` header to JSON."""
    response = await client.get("/api/v1/health/ready")
    assert "application/json" in response.headers.get("content-type", "")


@pytest.mark.asyncio
async def test_readiness_live_content_type_is_json(client: AsyncClient) -> None:
    """``/api/v1/health/live`` must set the ``Content-Type`` header to JSON."""
    response = await client.get("/api/v1/health/live")
    assert "application/json" in response.headers.get("content-type", "")


@@ -53,3 +53,28 @@ With drift correction:
- Total time from poll start to next poll start is always ~5 seconds
- Fetch duration doesn't affect the long-term polling rate
- Bandwidth and CPU usage remain consistent

## Request Deduplication with Abort Signal Handling
When `useFetchData` is called with a `requestKey`, multiple hook instances sharing the same key coalesce around a single in-flight request. Each instance is tracked as a **subscriber**.
### How it works
- A module-level `Map<requestKey, InFlightRequest>` holds in-flight requests.
- Each `InFlightRequest` contains a `subscribers` map: `subscriberId → { signal, onSetData, onSuccess }`.
- When a new hook instance joins a deduplicated request, it registers a subscriber entry and attaches a `once` abort listener that removes it from the subscriber list.
- When the underlying request resolves, **each subscriber's signal is checked before calling `setData` or `onSuccess`**. Aborted subscribers are skipped, preventing state updates on unmounted components.
- When all subscribers have aborted, the underlying `AbortController` is signalled, cancelling the in-flight fetch.
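
The subscriber bookkeeping described above can be sketched without React. This is a simplified illustration, not the hook's actual implementation: `RequestDeduper`, `subscribe`, and `start` are hypothetical names, every caller is treated as a subscriber (the hook itself distinguishes the initiator from later subscribers), and error handling is omitted.

```typescript
// Hypothetical sketch: RequestDeduper and its members are illustrative names,
// not part of useFetchData.
interface Subscriber<T> {
  signal: AbortSignal;
  onSetData: (response: T) => void;
}

interface InFlightRequest<T> {
  promise: Promise<T>;
  controller: AbortController;
  subscribers: Map<string, Subscriber<T>>;
}

class RequestDeduper<T> {
  private readonly inFlight = new Map<string, InFlightRequest<T>>();
  private nextId = 0;

  subscribe(
    key: string,
    fetcher: (signal: AbortSignal) => Promise<T>,
    signal: AbortSignal,
    onSetData: (response: T) => void,
  ): void {
    if (signal.aborted) return; // Already aborted: nothing to deliver.

    // Join the in-flight request for this key, or start one if none exists.
    let request = this.inFlight.get(key);
    if (request === undefined) {
      request = this.start(key, fetcher);
    }
    const entry = request;
    const id = `subscriber-${this.nextId++}`;
    entry.subscribers.set(id, { signal, onSetData });

    // On abort, drop this subscriber; if it was the last one, cancel the fetch.
    signal.addEventListener(
      "abort",
      () => {
        entry.subscribers.delete(id);
        if (entry.subscribers.size === 0) {
          entry.controller.abort();
          if (this.inFlight.get(key) === entry) this.inFlight.delete(key);
        }
      },
      { once: true },
    );
  }

  private start(
    key: string,
    fetcher: (signal: AbortSignal) => Promise<T>,
  ): InFlightRequest<T> {
    const controller = new AbortController();
    const entry: InFlightRequest<T> = {
      promise: fetcher(controller.signal),
      controller,
      subscribers: new Map(),
    };
    this.inFlight.set(key, entry);
    entry.promise
      .then((response) => {
        // Check each subscriber's signal before delivering the result.
        for (const sub of entry.subscribers.values()) {
          if (!sub.signal.aborted) sub.onSetData(response);
        }
      })
      .catch(() => {
        // Error handling is out of scope for this sketch.
      })
      .finally(() => {
        if (this.inFlight.get(key) === entry) this.inFlight.delete(key);
      });
    return entry;
  }
}
```

Two callers that subscribe under the same key share one `fetcher` call. Aborting one caller's signal leaves the shared fetch running for the other, and aborting the last remaining signal aborts the signal that was passed to `fetcher`.
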
### Guarantees
| Scenario | Result |
|---|---|
| Subscriber aborts before request resolves | Subscriber receives no `setData`/`onSuccess`. Entry removed from subscriber list. |
| Last subscriber aborts before request resolves | Underlying request is cancelled via `AbortController`. |
| Request resolves, subscriber already aborted | `signal.aborted` checked before `setData`/`onSuccess` — no-op. |
| Request errors | Non-abort errors passed to `handleFetchError`. Each subscriber's error state updated independently. |
### When to use deduplication
Use `requestKey` when multiple components may request the same resource simultaneously. The shared promise avoids redundant network calls and prevents race conditions from out-of-order responses.


@@ -1,8 +1,13 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { renderHook, act } from "@testing-library/react";
import { useFetchData, _resetInFlightRequests } from "../useFetchData";
import type { FetchError } from "../../types/api";

// Clear the module-level inFlightRequests map between tests to prevent state leakage
beforeEach(() => {
  _resetInFlightRequests();
});

describe("useFetchData", () => {
  it("fetches and selects data on mount", async () => {
    const fetcher = vi.fn().mockResolvedValue({ value: "test" });
@@ -406,4 +411,123 @@ describe("useFetchData", () => {
    expect(fetcher).toHaveBeenCalledTimes(2);
    expect(result.current.data).toBe("second");
  });

  it("aborted subscriber in deduplication does not receive setData/onSuccess", async () => {
    let resolveFirst: ((value: { value: string }) => void) | null = null;
    const fetcher = vi.fn().mockImplementation(
      () =>
        new Promise((resolve) => {
          resolveFirst = resolve;
        })
    );
    const selector = vi.fn((response: { value: string }) => response.value);
    const onSuccess1 = vi.fn();
    const onSuccess2 = vi.fn();

    const hook1 = renderHook(() =>
      useFetchData({
        fetcher,
        selector,
        errorMessage: "Failed to load",
        requestKey: "abort-subscriber-test",
        onSuccess: onSuccess1,
      })
    );
    const hook2 = renderHook(() =>
      useFetchData({
        fetcher,
        selector,
        errorMessage: "Failed to load",
        requestKey: "abort-subscriber-test",
        onSuccess: onSuccess2,
      })
    );

    await act(async () => {
      await Promise.resolve();
    });
    expect(fetcher).toHaveBeenCalledTimes(1);

    // Unmount hook2 (simulates component unmounting and aborting its signal)
    hook2.unmount();

    // Allow the request to resolve
    await act(async () => {
      resolveFirst?.({ value: "shared-data" });
      await Promise.resolve();
    });

    // hook2 should not have received the data (it aborted)
    expect(hook2.result.current.data).toBeUndefined();
    expect(onSuccess2).not.toHaveBeenCalled();

    // hook1 should still have the data
    expect(hook1.result.current.data).toBe("shared-data");
    expect(onSuccess1).toHaveBeenCalledWith({ value: "shared-data" });
  });

  it("last subscriber abort cancels underlying request", async () => {
    let resolveFirst: ((value: { value: string }) => void) | null = null;
    const abortSignals: AbortSignal[] = [];
    const fetcher = vi.fn().mockImplementation((signal: AbortSignal) => {
      abortSignals.push(signal);
      return new Promise((resolve) => {
        resolveFirst = resolve;
      });
    });
    const selector = vi.fn((response: { value: string }) => response.value);

    const hook1 = renderHook(() =>
      useFetchData({
        fetcher,
        selector,
        errorMessage: "Failed to load",
        requestKey: "cancel-underlying-test",
      })
    );
    await act(async () => {
      await Promise.resolve();
    });
    expect(fetcher).toHaveBeenCalledTimes(1);

    // Mount second subscriber
    const hook2 = renderHook(() =>
      useFetchData({
        fetcher,
        selector,
        errorMessage: "Failed to load",
        requestKey: "cancel-underlying-test",
      })
    );
    await act(async () => {
      await Promise.resolve();
    });

    // Both subscribers now sharing one request
    expect(fetcher).toHaveBeenCalledTimes(1);

    // Unmount first subscriber (doesn't cancel underlying request yet)
    hook1.unmount();
    await act(async () => {
      await Promise.resolve();
    });

    // Underlying request should NOT be cancelled yet (hook2 still waiting)
    expect(abortSignals[0]?.aborted).toBe(false);

    // Unmount second (last) subscriber — should cancel the underlying request
    hook2.unmount();
    await act(async () => {
      await Promise.resolve();
    });

    // Underlying request should now be cancelled
    expect(abortSignals[0]?.aborted).toBe(true);
  });
});


@@ -19,16 +19,33 @@ import type { FetchError } from "../types/api";
/**
 * Module-level cache for in-flight requests.
 * Maps requestKey to { promise, controller, subscribers } to enable deduplication
 * across multiple hook instances.
 */
interface Subscriber<TResponse> {
  /** This instance's abort signal. */
  signal: AbortSignal;
  /** Callbacks registered by this subscriber. */
  onSetData?: (response: TResponse) => void;
  onSuccess?: (response: TResponse) => void;
}

interface InFlightRequest<TResponse> {
  promise: Promise<TResponse>;
  controller: AbortController;
  /** Map of subscriberId -> Subscriber. Cleared when subscriber aborts. */
  subscribers: Map<string, Subscriber<TResponse>>;
  /** True when initiator has cleaned up but subscribers remain. */
  initiatorDone: boolean;
}

const inFlightRequests = new Map<string, InFlightRequest<unknown>>();

/** Visible for testing only. Clears all in-flight request state. */
export const _resetInFlightRequests = (): void => {
  inFlightRequests.clear();
};

export interface UseFetchDataOptions<TResponse, TData> {
  /** Async function that accepts an AbortSignal for cancellation. */
  fetcher: (signal: AbortSignal) => Promise<TResponse>;
@@ -86,18 +103,57 @@ export function useFetchData<TResponse, TData>(
  const [error, setError] = useState<FetchError | null>(null);
  const abortRef = useRef<AbortController | null>(null);
  const localControllerRef = useRef<AbortController | null>(null);
  const subscriberAbortRef = useRef<AbortController | null>(null);
  /** Unique ID for this instance, used to track its subscription to deduplicated requests. */
  const subscriberIdRef = useRef<string | null>(null);

  const refresh = useCallback((): void => {
    // Abort any previous request from this hook instance
    abortRef.current?.abort();
    localControllerRef.current = new AbortController();
    abortRef.current = localControllerRef.current;

    // If using request deduplication via requestKey and a request is already in-flight,
    // subscribe to it instead of launching a duplicate
    if (requestKey && inFlightRequests.has(requestKey)) {
      const inFlight = inFlightRequests.get(requestKey)! as InFlightRequest<TResponse>;

      // Create per-instance abort controller for this subscription.
      // Do NOT abort previous subscription signals here - each subscriber's signal
      // only aborts when THAT subscriber unmounts. Aborting old signals would
      // incorrectly remove other active subscribers from the list.
      subscriberAbortRef.current = new AbortController();

      // Register this instance as a subscriber.
      // Always generate a new ID - each hook instance must have its own subscriber ID
      // so that unmounting one subscriber does not remove another subscriber's entry.
      subscriberIdRef.current = crypto.randomUUID();
      const sid = subscriberIdRef.current;
      const subscription: Subscriber<TResponse> = {
        signal: subscriberAbortRef.current.signal,
        onSetData: (response: TResponse) => {
          setData(selector(response));
          if (onSuccess) {
            onSuccess(response);
          }
        },
        onSuccess,
      };
      inFlight.subscribers.set(sid, subscription);

      // When this instance aborts, remove it from the subscriber list.
      // Note: we do NOT abort the underlying controller here - that belongs to
      // the initiator and should only be aborted by the initiator itself.
      const cleanup = (): void => {
        inFlight.subscribers.delete(sid);
      };
      subscription.signal.addEventListener("abort", cleanup, { once: true });

      inFlight.promise
        .then((response: TResponse) => {
          // Check whether this subscriber has already aborted before propagating result
          if (subscription.signal.aborted) return;
          subscription.onSetData?.(response);
        })
        .catch((err: unknown) => {
          // Only handle non-abort errors; abort errors are silently ignored
@@ -139,7 +195,6 @@ export function useFetchData<TResponse, TData>(
        if (!controller.signal.aborted) {
          setLoading(false);
        }
        if (requestKey) {
          inFlightRequests.delete(requestKey);
        }
@@ -150,6 +205,8 @@ export function useFetchData<TResponse, TData>(
      inFlightRequests.set(requestKey, {
        promise: responsePromise,
        controller,
        subscribers: new Map(),
        initiatorDone: false,
      });
    }
  }, [fetcher, selector, errorMessage, onSuccess, requestKey]);
@@ -158,7 +215,25 @@ export function useFetchData<TResponse, TData>(
    refresh();
    return (): void => {
      if (subscriberIdRef.current !== null) {
        // Subscriber cleanup
        const inFlight = requestKey ? inFlightRequests.get(requestKey) : undefined;
        subscriberAbortRef.current?.abort();
        // If initiator already done and no subscribers left, cancel the request
        if (inFlight && inFlight.initiatorDone && inFlight.subscribers.size === 0) {
          inFlight.controller.abort();
        }
      } else {
        // Initiator cleanup
        const inFlight = requestKey ? inFlightRequests.get(requestKey) : undefined;
        // Mark initiator as done so last subscriber knows to cancel
        if (inFlight) {
          inFlight.initiatorDone = true;
        }
        if (!inFlight || inFlight.subscribers.size === 0) {
          abortRef.current?.abort();
        }
      }
    };
  }, [refresh]);