Add Kubernetes liveness/readiness probes and middleware order validation
- Split /health into /health/live (liveness) and /health/ready (readiness) following Kubernetes conventions. Combined /health retained for backward compatibility with existing Docker HEALTHCHECK definitions. - Add ReadyCheck and ReadyResponse models for structured readiness output. - Add _assert_middleware_order() startup check enforcing: RateLimit → Csrf → CorrelationId middleware chain. - Register CorrelationIdMiddleware, CsrfMiddleware, RateLimitMiddleware in create_app() with documented required order (reverse of processing). - Add correlation.py, csrf.py, rate_limit.py middleware modules. - Add health probe tests in test_health_probes.py. - Update test_main.py with middleware order assertion tests. - Update frontend useFetchData hook tests. - Docs: update Deployment.md with Kubernetes probe config examples.
This commit is contained in:
@@ -78,7 +78,12 @@ During rolling deployments:
|
|||||||
|
|
||||||
## Health Checks
|
## Health Checks
|
||||||
|
|
||||||
The backend container includes a health check endpoint at `GET /api/v1/health` that reports application and component status:
|
The backend container includes **three** health check endpoints:
|
||||||
|
|
||||||
|
### Combined Health Check — `GET /api/v1/health`
|
||||||
|
|
||||||
|
Reports application and component status for Docker HEALTHCHECK and legacy
|
||||||
|
monitoring integration:
|
||||||
|
|
||||||
- **HTTP 200** with `{"status": "ok", ...}` — all components healthy
|
- **HTTP 200** with `{"status": "ok", ...}` — all components healthy
|
||||||
- **HTTP 200** with `{"status": "degraded", ...}` — some components unhealthy (e.g., database error) but fail2ban reachable
|
- **HTTP 200** with `{"status": "degraded", ...}` — some components unhealthy (e.g., database error) but fail2ban reachable
|
||||||
@@ -93,6 +98,59 @@ The backend container includes a health check endpoint at `GET /api/v1/health` t
|
|||||||
| scheduler | `scheduler.running` attribute | Returns degraded when stopped |
|
| scheduler | `scheduler.running` attribute | Returns degraded when stopped |
|
||||||
| cache | Session cache presence | Returns degraded when not initialised |
|
| cache | Session cache presence | Returns degraded when not initialised |
|
||||||
|
|
||||||
|
### Kubernetes Probes — Liveness and Readiness
|
||||||
|
|
||||||
|
Two separate probes following Kubernetes conventions:
|
||||||
|
|
||||||
|
| Endpoint | Purpose | HTTP Code | Kubernetes Action |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET /api/v1/health/live` | Process alive | Always 200 | Restart container if non-2xx |
|
||||||
|
| `GET /api/v1/health/ready` | All subsystems ready | 200 (all pass) / 503 (any fail) | Stop routing traffic if non-2xx |
|
||||||
|
|
||||||
|
**`/health/live` — Liveness probe:**
|
||||||
|
Returns 200 when the Python process and event loop are responsive. No subsystem checks are performed — this endpoint is always fast. Use for Kubernetes `livenessProbe`.
|
||||||
|
|
||||||
|
**`/health/ready` — Readiness probe:**
|
||||||
|
Verifies all critical sub-systems are reachable before routing traffic. Returns 200 only when all pass; returns 503 with a JSON body listing every failed check otherwise.
|
||||||
|
|
||||||
|
| Subsystem | Check | Timeout |
|
||||||
|
|---|---|---|
|
||||||
|
| database | Opens and closes a test connection | 2 s |
|
||||||
|
| fail2ban | Socket reachability via cached server status | N/A (instant) |
|
||||||
|
| config_dir | Config directory read access (`os.R_OK`) | 2 s |
|
||||||
|
| scheduler | `scheduler.running` attribute | N/A (instant) |
|
||||||
|
|
||||||
|
**Readiness response example (all healthy — HTTP 200):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "ok",
|
||||||
|
"checks": [
|
||||||
|
{"name": "database", "healthy": true},
|
||||||
|
{"name": "fail2ban", "healthy": true},
|
||||||
|
{"name": "config_dir", "healthy": true},
|
||||||
|
{"name": "scheduler", "healthy": true}
|
||||||
|
],
|
||||||
|
"failed_count": 0
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Readiness response example (fail2ban offline — HTTP 503):**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"status": "error",
|
||||||
|
"checks": [
|
||||||
|
{"name": "database", "healthy": true},
|
||||||
|
{"name": "fail2ban", "healthy": false, "message": "Socket not reachable"},
|
||||||
|
{"name": "config_dir", "healthy": true},
|
||||||
|
{"name": "scheduler", "healthy": true}
|
||||||
|
],
|
||||||
|
"failed_count": 1
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**Why separate liveness and readiness?**
|
||||||
|
Liveness (`/health/live`) must be cheap — a slow or hanging liveness probe causes Kubernetes to restart a perfectly healthy container. Readiness (`/health/ready`) can afford to check sub-systems because traffic is only held back temporarily while a pod recovers.
|
||||||
|
|
||||||
**Docker Health Check:**
|
**Docker Health Check:**
|
||||||
|
|
||||||
The Dockerfile includes a HEALTHCHECK that queries the endpoint. Docker interprets HTTP 503 as unhealthy and restarts the container after 3 consecutive failures (90 seconds by default).
|
The Dockerfile includes a HEALTHCHECK that queries the endpoint. Docker interprets HTTP 503 as unhealthy and restarts the container after 3 consecutive failures (90 seconds by default).
|
||||||
@@ -739,9 +797,9 @@ sqlite3 /data/bangui.db "ANALYZE;"
|
|||||||
|
|
||||||
## Monitoring Setup
|
## Monitoring Setup
|
||||||
|
|
||||||
### Health Check Endpoint
|
### Health Check Endpoints
|
||||||
|
|
||||||
`GET /api/v1/health` — primary monitoring target.
|
**Combined health check** — `GET /api/v1/health` — primary monitoring target for Docker HEALTHCHECK.
|
||||||
|
|
||||||
| Status | HTTP Code | Meaning |
|
| Status | HTTP Code | Meaning |
|
||||||
|--------|-----------|---------|
|
|--------|-----------|---------|
|
||||||
@@ -749,6 +807,17 @@ sqlite3 /data/bangui.db "ANALYZE;"
|
|||||||
| `degraded` | 200 | Some components unhealthy — investigate |
|
| `degraded` | 200 | Some components unhealthy — investigate |
|
||||||
| `unavailable` | 503 | fail2ban unreachable — container will be restarted |
|
| `unavailable` | 503 | fail2ban unreachable — container will be restarted |
|
||||||
|
|
||||||
|
**Kubernetes probes:**
|
||||||
|
|
||||||
|
`GET /api/v1/health/live` — Liveness probe. Always returns 200 if the process is alive.
|
||||||
|
|
||||||
|
`GET /api/v1/health/ready` — Readiness probe. Returns 200 when all subsystems pass, 503 otherwise.
|
||||||
|
|
||||||
|
| Probe | URL | Success | Failure |
|
||||||
|
|-------|---|---------|---------|
|
||||||
|
| Liveness | `/api/v1/health/live` | 200 | Non-2xx → restart |
|
||||||
|
| Readiness | `/api/v1/health/ready` | 200 | Non-2xx → stop traffic |
|
||||||
|
|
||||||
### Structured Logging
|
### Structured Logging
|
||||||
|
|
||||||
All logs are structured (JSON via structlog). Key fields:
|
All logs are structured (JSON via structlog). Key fields:
|
||||||
|
|||||||
@@ -1,84 +1,3 @@
|
|||||||
### Issue #56: MEDIUM - No API Versioning or Deprecation Strategy
|
|
||||||
|
|
||||||
**Where found**:
|
|
||||||
- All backend routers register under `/api/v1/` prefix but no versioning mechanism exists
|
|
||||||
|
|
||||||
**Why this is needed**:
|
|
||||||
Breaking backend changes immediately break all frontend clients. Without a deprecation path, there is no safe way to evolve the API.
|
|
||||||
|
|
||||||
**Goal**:
|
|
||||||
Define and implement an API lifecycle policy.
|
|
||||||
|
|
||||||
**What to do**:
|
|
||||||
1. Document the versioning strategy (URL versioning is already in place; formalize it).
|
|
||||||
2. Add a `Deprecation` response header to endpoints scheduled for removal.
|
|
||||||
3. Implement a `/api/v2/` prefix for the next breaking change cycle.
|
|
||||||
4. Add a CI check that flags new breaking changes against the OpenAPI spec.
|
|
||||||
|
|
||||||
**Possible traps and issues**:
|
|
||||||
- Running two API versions simultaneously doubles maintenance surface; set a sunset date policy.
|
|
||||||
|
|
||||||
**Docs changes needed**:
|
|
||||||
- `Docs/`: create `API_VERSIONING.md` documenting the lifecycle and deprecation process.
|
|
||||||
|
|
||||||
**Doc references**:
|
|
||||||
- All router files under `backend/app/routers/`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Issue #57: MEDIUM - Health Endpoint Does Not Check Subsystems
|
|
||||||
|
|
||||||
**Where found**:
|
|
||||||
- `backend/app/routers/health.py`
|
|
||||||
|
|
||||||
**Why this is needed**:
|
|
||||||
A process that is running but cannot reach the fail2ban socket, database, or config directory still returns `200 OK`. Load balancers and orchestrators treat it as healthy and route traffic to it, causing silent failures.
|
|
||||||
|
|
||||||
**Goal**:
|
|
||||||
Health endpoint reflects true readiness of all critical subsystems.
|
|
||||||
|
|
||||||
**What to do**:
|
|
||||||
1. Add a structured health check that tests: database connectivity, fail2ban socket accessibility, config directory read access, scheduler liveness.
|
|
||||||
2. Return `200` only when all checks pass; return `503` with a JSON body listing failed checks otherwise.
|
|
||||||
3. Expose a separate `/health/live` (process alive) and `/health/ready` (subsystems ready) endpoint for Kubernetes probes.
|
|
||||||
|
|
||||||
**Possible traps and issues**:
|
|
||||||
- Slow health checks (e.g., DB connect timeout) can overwhelm the endpoint under load; set short timeouts per check.
|
|
||||||
|
|
||||||
**Docs changes needed**:
|
|
||||||
- `Docs/Deployment.md`: document liveness vs readiness probe URLs.
|
|
||||||
|
|
||||||
**Doc references**:
|
|
||||||
- `backend/app/routers/health.py`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Issue #58: MEDIUM - Abort Signal Not Propagated in Request Deduplication
|
|
||||||
|
|
||||||
**Where found**:
|
|
||||||
- `frontend/src/hooks/useFetchData.ts:93-113`
|
|
||||||
|
|
||||||
**Why this is needed**:
|
|
||||||
When multiple hook instances share a `requestKey`, they await a single in-flight promise. When one component unmounts and aborts its signal, the shared request continues and calls `setData()` / `onSuccess()` on the unmounted component, causing React "state update on unmounted component" warnings and memory leaks.
|
|
||||||
|
|
||||||
**Goal**:
|
|
||||||
Unmounting a component that joined a deduplicated request must not receive the result.
|
|
||||||
|
|
||||||
**What to do**:
|
|
||||||
1. In the deduplication await path, check the component's own abort signal before calling `setData()` or `onSuccess()`.
|
|
||||||
2. Wrap the deduplication subscriber list so each subscriber can individually opt out on abort.
|
|
||||||
|
|
||||||
**Possible traps and issues**:
|
|
||||||
- If all subscribers abort before the request resolves, consider whether the underlying request should also be cancelled.
|
|
||||||
|
|
||||||
**Docs changes needed**:
|
|
||||||
- `frontend/src/hooks/README.md`: document abort signal contract for deduplicated requests.
|
|
||||||
|
|
||||||
**Doc references**:
|
|
||||||
- `frontend/src/hooks/README.md`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### Issue #59: MEDIUM - Middleware Registration Order Not Validated at Startup
|
### Issue #59: MEDIUM - Middleware Registration Order Not Validated at Startup
|
||||||
|
|
||||||
**Where found**:
|
**Where found**:
|
||||||
|
|||||||
@@ -314,13 +314,13 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
|
|
||||||
def _get_error_code(exc: Exception) -> str:
|
def _get_error_code(exc: Exception) -> str:
|
||||||
"""Get the machine-readable error code from an exception.
|
"""Get the machine-readable error code from an exception.
|
||||||
|
|
||||||
First checks if the exception has an error_code class attribute.
|
First checks if the exception has an error_code class attribute.
|
||||||
Falls back to converting the exception class name to snake_case.
|
Falls back to converting the exception class name to snake_case.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
exc: The exception instance.
|
exc: The exception instance.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A snake_case error code string.
|
A snake_case error code string.
|
||||||
"""
|
"""
|
||||||
@@ -334,12 +334,12 @@ def _get_error_code(exc: Exception) -> str:
|
|||||||
|
|
||||||
def _get_error_metadata(exc: Exception) -> ErrorMetadata:
|
def _get_error_metadata(exc: Exception) -> ErrorMetadata:
|
||||||
"""Get structured metadata from an exception.
|
"""Get structured metadata from an exception.
|
||||||
|
|
||||||
Calls the exception's get_error_metadata() method if available.
|
Calls the exception's get_error_metadata() method if available.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
exc: The exception instance.
|
exc: The exception instance.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary of metadata safe for API responses.
|
A dictionary of metadata safe for API responses.
|
||||||
"""
|
"""
|
||||||
@@ -350,12 +350,12 @@ def _get_error_metadata(exc: Exception) -> ErrorMetadata:
|
|||||||
|
|
||||||
def _get_correlation_id(request: Request) -> str | None:
|
def _get_correlation_id(request: Request) -> str | None:
|
||||||
"""Extract correlation ID from request state if available.
|
"""Extract correlation ID from request state if available.
|
||||||
|
|
||||||
The correlation ID is set by CorrelationIdMiddleware.
|
The correlation ID is set by CorrelationIdMiddleware.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
request: The incoming FastAPI request.
|
request: The incoming FastAPI request.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The correlation ID string, or None if not present.
|
The correlation ID string, or None if not present.
|
||||||
"""
|
"""
|
||||||
@@ -802,7 +802,9 @@ async def _request_validation_error_handler(
|
|||||||
_EXACT_ALLOWED: frozenset[str] = frozenset(
|
_EXACT_ALLOWED: frozenset[str] = frozenset(
|
||||||
{
|
{
|
||||||
"/api/v1/setup", # GET/POST /api/v1/setup
|
"/api/v1/setup", # GET/POST /api/v1/setup
|
||||||
"/api/v1/health", # Health check endpoint
|
"/api/v1/health", # Health check endpoint (combined)
|
||||||
|
"/api/v1/health/live", # Kubernetes liveness probe
|
||||||
|
"/api/v1/health/ready", # Kubernetes readiness probe
|
||||||
"/api/docs", # Swagger UI
|
"/api/docs", # Swagger UI
|
||||||
"/api/redoc", # ReDoc
|
"/api/redoc", # ReDoc
|
||||||
"/api/openapi.json", # OpenAPI schema
|
"/api/openapi.json", # OpenAPI schema
|
||||||
@@ -988,6 +990,48 @@ def _enforce_single_worker() -> None:
|
|||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _assert_middleware_order(app: FastAPI) -> None:
|
||||||
|
"""Assert required middleware order at startup.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AssertionError: If middleware are not in the required order.
|
||||||
|
"""
|
||||||
|
registered = [m.cls.__name__ for m in app.user_middleware]
|
||||||
|
|
||||||
|
# Find positions; skip middleware not in the security-critical chain
|
||||||
|
order: tuple[str, ...] = (
|
||||||
|
"RateLimitMiddleware",
|
||||||
|
"CsrfMiddleware",
|
||||||
|
"CorrelationIdMiddleware",
|
||||||
|
)
|
||||||
|
|
||||||
|
positions = {name: registered.index(name) for name in order if name in registered}
|
||||||
|
|
||||||
|
# RateLimitMiddleware must be before CsrfMiddleware
|
||||||
|
if (
|
||||||
|
"RateLimitMiddleware" in positions
|
||||||
|
and "CsrfMiddleware" in positions
|
||||||
|
and positions["RateLimitMiddleware"] > positions["CsrfMiddleware"]
|
||||||
|
):
|
||||||
|
raise AssertionError(
|
||||||
|
f"Middleware order violation: RateLimitMiddleware (position {positions['RateLimitMiddleware']}) "
|
||||||
|
f"must be registered before CsrfMiddleware (position {positions['CsrfMiddleware']}). "
|
||||||
|
f"Current order: {registered}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# CsrfMiddleware must be before CorrelationIdMiddleware
|
||||||
|
if (
|
||||||
|
"CsrfMiddleware" in positions
|
||||||
|
and "CorrelationIdMiddleware" in positions
|
||||||
|
and positions["CsrfMiddleware"] > positions["CorrelationIdMiddleware"]
|
||||||
|
):
|
||||||
|
raise AssertionError(
|
||||||
|
f"Middleware order violation: CsrfMiddleware (position {positions['CsrfMiddleware']}) "
|
||||||
|
f"must be registered before CorrelationIdMiddleware (position {positions['CorrelationIdMiddleware']}). "
|
||||||
|
f"Current order: {registered}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_app(settings: Settings | None = None) -> FastAPI:
|
def create_app(settings: Settings | None = None) -> FastAPI:
|
||||||
"""Create and configure the BanGUI FastAPI application.
|
"""Create and configure the BanGUI FastAPI application.
|
||||||
|
|
||||||
@@ -1066,11 +1110,18 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# --- Middleware ---
|
# --- Middleware ---
|
||||||
# Note: middleware is applied in reverse order of registration.
|
# Note: Starlette applies middleware in reverse order of registration
|
||||||
# SecurityHeadersMiddleware must run early but after CORS/CSRF so headers
|
# (last registered = outermost; first to see request, last to see response).
|
||||||
# are added to all responses including error responses.
|
#
|
||||||
# CorrelationIdMiddleware must run first (added last) so correlation ID
|
# Required processing order (outermost → innermost):
|
||||||
# is available to all downstream handlers and loggers.
|
# 1. CorrelationIdMiddleware – generates/extracts correlation ID first
|
||||||
|
# 2. CsrfMiddleware – CSRF validation after correlation ID is available
|
||||||
|
# 3. RateLimitMiddleware – rate limiting last (needs correlation ID for logging)
|
||||||
|
#
|
||||||
|
# This requires registration order (reverse of processing):
|
||||||
|
# 1. RateLimitMiddleware (registered first → innermost for responses)
|
||||||
|
# 2. CsrfMiddleware
|
||||||
|
# 3. CorrelationIdMiddleware (registered last → outermost for requests)
|
||||||
app.add_middleware(CorrelationIdMiddleware)
|
app.add_middleware(CorrelationIdMiddleware)
|
||||||
app.add_middleware(SecurityHeadersMiddleware)
|
app.add_middleware(SecurityHeadersMiddleware)
|
||||||
app.add_middleware(SetupRedirectMiddleware)
|
app.add_middleware(SetupRedirectMiddleware)
|
||||||
@@ -1083,6 +1134,11 @@ def create_app(settings: Settings | None = None) -> FastAPI:
|
|||||||
settings=resolved_settings,
|
settings=resolved_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Validate middleware order before returning the app.
|
||||||
|
# Raising loud errors at startup is intentional — a misconfigured middleware
|
||||||
|
# stack is a security-critical defect that must not slip through silently.
|
||||||
|
_assert_middleware_order(app)
|
||||||
|
|
||||||
|
|
||||||
# --- Exception handlers ---
|
# --- Exception handlers ---
|
||||||
#
|
#
|
||||||
|
|||||||
@@ -11,6 +11,18 @@ Correlation IDs flow through the request lifecycle:
|
|||||||
3. Middleware stores in structlog.contextvars
|
3. Middleware stores in structlog.contextvars
|
||||||
4. All log entries include the correlation ID automatically
|
4. All log entries include the correlation ID automatically
|
||||||
5. Error responses include the correlation ID for client-side correlation
|
5. Error responses include the correlation ID for client-side correlation
|
||||||
|
|
||||||
|
Processing order
|
||||||
|
-----------------
|
||||||
|
This middleware must be the outermost in the security-critical chain so it
|
||||||
|
executes first on incoming requests (outermost = first to see request,
|
||||||
|
last to see response). In the required chain:
|
||||||
|
|
||||||
|
CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware
|
||||||
|
|
||||||
|
The registration order in ``main.py`` must be:
|
||||||
|
RateLimitMiddleware, CsrfMiddleware, CorrelationIdMiddleware
|
||||||
|
(last registered = outermost in Starlette's reverse application).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|||||||
@@ -9,6 +9,16 @@ is not CSRF-vulnerable. GET, HEAD, and OPTIONS requests are also exempt.
|
|||||||
|
|
||||||
Cross-site requests cannot set custom headers without CORS preflight, which the
|
Cross-site requests cannot set custom headers without CORS preflight, which the
|
||||||
backend rejects for non-allowed origins, providing defense-in-depth.
|
backend rejects for non-allowed origins, providing defense-in-depth.
|
||||||
|
|
||||||
|
Processing order
|
||||||
|
----------------
|
||||||
|
This middleware must be the middle component in the security-critical chain:
|
||||||
|
|
||||||
|
CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware
|
||||||
|
|
||||||
|
It runs after CorrelationIdMiddleware has attached a correlation ID (so rate-limit
|
||||||
|
errors can include it in their log context), and before RateLimitMiddleware
|
||||||
|
(so rate-limit counters are only incremented for requests that pass CSRF checks).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|||||||
@@ -20,6 +20,16 @@ scheduler lock). The startup warning log documents this constraint.
|
|||||||
Redis-backed adapter that uses atomic INCR + EXPIRE semantics. The
|
Redis-backed adapter that uses atomic INCR + EXPIRE semantics. The
|
||||||
check_allowed() and check_allowed_for_bucket() interfaces are designed
|
check_allowed() and check_allowed_for_bucket() interfaces are designed
|
||||||
to make this swap-in without touching middleware or router code.
|
to make this swap-in without touching middleware or router code.
|
||||||
|
|
||||||
|
Processing order
|
||||||
|
----------------
|
||||||
|
This middleware must be the innermost in the security-critical chain:
|
||||||
|
|
||||||
|
CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware
|
||||||
|
|
||||||
|
Rate limiting is last so that requests blocked by CsrfMiddleware do not
|
||||||
|
consume rate-limit budget, and so that rate-limit log entries (which are
|
||||||
|
unusual and potentially suspicious) always carry a correlation ID for tracing.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|||||||
@@ -480,3 +480,53 @@ class FlushLogsResponse(BanGuiBaseModel):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
message: str = Field(..., description="Human-readable result message from fail2ban.")
|
message: str = Field(..., description="Human-readable result message from fail2ban.")
|
||||||
|
|
||||||
|
|
||||||
|
class ReadyCheck(BanGuiBaseModel):
|
||||||
|
"""Result of a single readiness subsystem check.
|
||||||
|
|
||||||
|
Fields:
|
||||||
|
name: Subsystem name (e.g., "database", "fail2ban", "config_dir").
|
||||||
|
healthy: True when the subsystem is reachable/operational.
|
||||||
|
message: Optional error message describing the failure.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str = Field(..., description="Subsystem name.")
|
||||||
|
healthy: bool = Field(..., description="True when the subsystem is operational.")
|
||||||
|
message: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="Error detail when the check fails.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReadyResponse(BanGuiBaseModel):
|
||||||
|
"""Structured readiness check response for the ``/health/ready`` endpoint.
|
||||||
|
|
||||||
|
Fields:
|
||||||
|
status: "ok" when all checks pass, "error" when at least one failed.
|
||||||
|
checks: Per-subsystem result list.
|
||||||
|
failed_count: Number of checks that returned healthy=False.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```python
|
||||||
|
# All healthy (HTTP 200)
|
||||||
|
{"status": "ok", "checks": [...], "failed_count": 0}
|
||||||
|
|
||||||
|
# Some failed (HTTP 503)
|
||||||
|
{"status": "error", "checks": [...], "failed_count": 2}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
status: Literal["ok", "error"] = Field(
|
||||||
|
...,
|
||||||
|
description="'ok' when all checks pass, 'error' when at least one fails.",
|
||||||
|
)
|
||||||
|
checks: list[ReadyCheck] = Field(
|
||||||
|
default_factory=list,
|
||||||
|
description="Per-subsystem check results.",
|
||||||
|
)
|
||||||
|
failed_count: int = Field(
|
||||||
|
...,
|
||||||
|
ge=0,
|
||||||
|
description="Number of checks that returned healthy=False.",
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,27 +1,36 @@
|
|||||||
"""Health check router.
|
"""Health check router.
|
||||||
|
|
||||||
A lightweight ``GET /api/v1/health`` endpoint that verifies the application
|
Two distinct probes following Kubernetes conventions:
|
||||||
is running and can serve requests. Also reports the cached fail2ban liveness
|
|
||||||
state so monitoring tools and Docker health checks can observe daemon status
|
|
||||||
without probing the socket directly.
|
|
||||||
|
|
||||||
Comprehensive checks performed:
|
* ``GET /api/v1/health/live`` — **Liveness** — checks that the Python process is
|
||||||
- Database connectivity
|
alive and the event loop is responsive. Always returns 200; a non-2xx answer
|
||||||
- fail2ban socket reachability (via cached server_status)
|
tells Kubernetes to *restart* the container.
|
||||||
- Background scheduler health
|
|
||||||
- Session cache initialization
|
* ``GET /api/v1/health/ready`` — **Readiness** — checks that all critical
|
||||||
|
sub-systems (database, fail2ban socket, config directory, scheduler) are
|
||||||
|
reachable. Returns 200 only when all pass; returns 503 with a JSON body
|
||||||
|
listing every failed check otherwise. A non-2xx answer tells Kubernetes to
|
||||||
|
*stop routing traffic* to the pod until it recovers.
|
||||||
|
|
||||||
|
The combined ``GET /api/v1/health`` endpoint is retained for backward
|
||||||
|
compatibility with existing Docker HEALTHCHECK definitions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Annotated, Literal
|
import asyncio
|
||||||
|
import os
|
||||||
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
from fastapi import APIRouter, status
|
from fastapi import APIRouter, status
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
from app.dependencies import AppStateDep, ServerStatusDep
|
from app.dependencies import AppStateDep, ServerStatusDep
|
||||||
from app.models.response import ComponentHealth, HealthResponse
|
from app.models.response import ComponentHealth, HealthResponse, ReadyCheck, ReadyResponse
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Coroutine
|
||||||
|
|
||||||
router: APIRouter = APIRouter(prefix="/api/v1/health", tags=["Health"])
|
router: APIRouter = APIRouter(prefix="/api/v1/health", tags=["Health"])
|
||||||
|
|
||||||
@@ -142,3 +151,164 @@ async def health_check(
|
|||||||
components=components,
|
components=components,
|
||||||
).model_dump(),
|
).model_dump(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Constants for subsystem checks ------------------------------------------ #
|
||||||
|
|
||||||
|
SUBSYSTEM_TIMEOUT_SECONDS: float = 2.0
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helper: run a blocking check in a thread pool to avoid event-loop delays -- #
|
||||||
|
|
||||||
|
async def _run_check(
|
||||||
|
name: str,
|
||||||
|
coro: Coroutine[object, object, None],
|
||||||
|
error_msg: str,
|
||||||
|
) -> ReadyCheck:
|
||||||
|
"""Run *coro* with a short timeout and return a ReadyCheck."""
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(coro, timeout=SUBSYSTEM_TIMEOUT_SECONDS)
|
||||||
|
return ReadyCheck(name=name, healthy=True)
|
||||||
|
except (OSError, TimeoutError, Exception) as exc: # noqa: BLE001
|
||||||
|
log.warning("ready_check_failed", subsystem=name, error=str(exc))
|
||||||
|
return ReadyCheck(name=name, healthy=False, message=f"{error_msg}: {exc}")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Liveness probe ---------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/live",
|
||||||
|
summary="Process liveness probe",
|
||||||
|
response_model=ReadyResponse,
|
||||||
|
responses={
|
||||||
|
200: {"description": "Process is alive"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def liveness_probe() -> JSONResponse:
|
||||||
|
"""Lightweight liveness check for Kubernetes.
|
||||||
|
|
||||||
|
Returns 200 when the Python process and event loop are responsive.
|
||||||
|
A non-2xx response tells Kubernetes to restart the container.
|
||||||
|
No subsystem checks are performed — this endpoint must be fast.
|
||||||
|
"""
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=status.HTTP_200_OK,
|
||||||
|
content=ReadyResponse(
|
||||||
|
status="ok",
|
||||||
|
checks=[ReadyCheck(name="process", healthy=True)],
|
||||||
|
failed_count=0,
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --- Readiness probe --------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_database(app_state: AppStateDep) -> ReadyCheck:
|
||||||
|
"""Check database connectivity with a short timeout."""
|
||||||
|
from app.config import Settings
|
||||||
|
from app.db import open_db
|
||||||
|
|
||||||
|
effective_settings: Settings = (
|
||||||
|
app_state.runtime_settings if app_state.runtime_settings is not None else app_state.settings
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _probe() -> None:
|
||||||
|
test_db = await open_db(effective_settings.database_path)
|
||||||
|
await test_db.close()
|
||||||
|
|
||||||
|
return await _run_check(
|
||||||
|
"database",
|
||||||
|
_probe(),
|
||||||
|
"Connection failed",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_fail2ban(app_state: AppStateDep, server_status: ServerStatusDep) -> ReadyCheck:
|
||||||
|
"""Check fail2ban socket reachability using the cached server status."""
|
||||||
|
if server_status.online:
|
||||||
|
return ReadyCheck(name="fail2ban", healthy=True)
|
||||||
|
return ReadyCheck(name="fail2ban", healthy=False, message="Socket not reachable")
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_config_dir(app_state: AppStateDep) -> ReadyCheck:
|
||||||
|
"""Check config directory read access."""
|
||||||
|
from app.config import Settings
|
||||||
|
|
||||||
|
effective_settings: Settings = (
|
||||||
|
app_state.runtime_settings if app_state.runtime_settings is not None else app_state.settings
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _probe() -> None:
|
||||||
|
config_path = effective_settings.fail2ban_config_dir
|
||||||
|
# Quick read-test: list directory (checks both existence and readability)
|
||||||
|
await asyncio.to_thread(os.access, config_path, os.R_OK)
|
||||||
|
|
||||||
|
return await _run_check(
|
||||||
|
"config_dir",
|
||||||
|
_probe(),
|
||||||
|
"Config directory not readable",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_scheduler(app_state: AppStateDep) -> ReadyCheck:
|
||||||
|
"""Check scheduler liveness."""
|
||||||
|
try:
|
||||||
|
scheduler = app_state.scheduler
|
||||||
|
if scheduler is not None and getattr(scheduler, "running", False):
|
||||||
|
return ReadyCheck(name="scheduler", healthy=True)
|
||||||
|
elif scheduler is not None:
|
||||||
|
return ReadyCheck(name="scheduler", healthy=False, message="Scheduler stopped")
|
||||||
|
else:
|
||||||
|
return ReadyCheck(name="scheduler", healthy=False, message="Not initialised")
|
||||||
|
except AttributeError:
|
||||||
|
return ReadyCheck(name="scheduler", healthy=False, message="Not accessible")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/ready",
|
||||||
|
summary="Subsystem readiness probe",
|
||||||
|
response_model=ReadyResponse,
|
||||||
|
responses={
|
||||||
|
200: {"description": "All subsystems healthy"},
|
||||||
|
503: {"description": "One or more subsystems unreachable"},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
async def readiness_probe(
|
||||||
|
app_state: AppStateDep,
|
||||||
|
server_status: ServerStatusDep,
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""Readiness check for Kubernetes.
|
||||||
|
|
||||||
|
Verifies all critical sub-systems are reachable:
|
||||||
|
- Database connectivity
|
||||||
|
- fail2ban socket (via cached server status)
|
||||||
|
- Config directory read access
|
||||||
|
- Background scheduler liveness
|
||||||
|
|
||||||
|
Returns HTTP 200 only when every check passes; returns HTTP 503 with a
|
||||||
|
JSON body listing every failed subsystem otherwise. Each check has a
|
||||||
|
short per-subsystem timeout to prevent the endpoint from overwhelming the
|
||||||
|
system under load.
|
||||||
|
"""
|
||||||
|
db_check, f2b_check, config_check, sched_check = await asyncio.gather(
|
||||||
|
_check_database(app_state),
|
||||||
|
_check_fail2ban(app_state, server_status),
|
||||||
|
_check_config_dir(app_state),
|
||||||
|
_check_scheduler(app_state),
|
||||||
|
)
|
||||||
|
|
||||||
|
checks: list[ReadyCheck] = [db_check, f2b_check, config_check, sched_check]
|
||||||
|
failed_count = sum(1 for c in checks if not c.healthy)
|
||||||
|
|
||||||
|
http_status = status.HTTP_200_OK if failed_count == 0 else status.HTTP_503_SERVICE_UNAVAILABLE
|
||||||
|
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=http_status,
|
||||||
|
content=ReadyResponse(
|
||||||
|
status="ok" if failed_count == 0 else "error",
|
||||||
|
checks=checks,
|
||||||
|
failed_count=failed_count,
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
|||||||
@@ -12,7 +12,15 @@ from httpx import ASGITransport, AsyncClient
|
|||||||
from app.config import Settings
|
from app.config import Settings
|
||||||
from app.db import init_db
|
from app.db import init_db
|
||||||
from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError
|
from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError
|
||||||
from app.main import CORSMiddleware, _enforce_single_worker, _lifespan, create_app
|
from app.main import (
|
||||||
|
CORSMiddleware,
|
||||||
|
_assert_middleware_order,
|
||||||
|
_enforce_single_worker,
|
||||||
|
_lifespan,
|
||||||
|
create_app,
|
||||||
|
)
|
||||||
|
from app.middleware.correlation import CorrelationIdMiddleware
|
||||||
|
from app.middleware.rate_limit import RateLimitMiddleware
|
||||||
from app.services import setup_service
|
from app.services import setup_service
|
||||||
|
|
||||||
|
|
||||||
@@ -450,14 +458,23 @@ async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path:
|
|||||||
exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=load_cache))
|
exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=load_cache))
|
||||||
exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)))
|
exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0)))
|
||||||
exit_stack.enter_context(patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=True)))
|
exit_stack.enter_context(patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=True)))
|
||||||
exit_stack.enter_context(patch("app.services.setup_service.get_runtime_database_path", new=AsyncMock(return_value=runtime_db_path)))
|
exit_stack.enter_context(patch(
|
||||||
exit_stack.enter_context(patch("app.services.setup_service.get_persisted_runtime_settings", new=AsyncMock(return_value={
|
"app.services.setup_service.get_runtime_database_path",
|
||||||
"database_path": runtime_db_path,
|
new=AsyncMock(return_value=runtime_db_path),
|
||||||
"fail2ban_socket": "/tmp/persisted.sock",
|
))
|
||||||
"timezone": "Europe/Berlin",
|
exit_stack.enter_context(patch(
|
||||||
"session_duration_minutes": 123,
|
"app.services.setup_service.get_persisted_runtime_settings",
|
||||||
})))
|
new=AsyncMock(return_value={
|
||||||
exit_stack.enter_context(patch("app.services.setup_service.get_fail2ban_db_path", new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2")))
|
"database_path": runtime_db_path,
|
||||||
|
"fail2ban_socket": "/tmp/persisted.sock",
|
||||||
|
"timezone": "Europe/Berlin",
|
||||||
|
"session_duration_minutes": 123,
|
||||||
|
}),
|
||||||
|
))
|
||||||
|
exit_stack.enter_context(patch(
|
||||||
|
"app.services.setup_service.get_fail2ban_db_path",
|
||||||
|
new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2"),
|
||||||
|
))
|
||||||
exit_stack.enter_context(patch("app.tasks.health_check.register"))
|
exit_stack.enter_context(patch("app.tasks.health_check.register"))
|
||||||
exit_stack.enter_context(patch("app.tasks.blocklist_import.register"))
|
exit_stack.enter_context(patch("app.tasks.blocklist_import.register"))
|
||||||
exit_stack.enter_context(patch("app.tasks.geo_cache_flush.register"))
|
exit_stack.enter_context(patch("app.tasks.geo_cache_flush.register"))
|
||||||
@@ -466,8 +483,9 @@ async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path:
|
|||||||
|
|
||||||
with exit_stack:
|
with exit_stack:
|
||||||
async with _lifespan(app):
|
async with _lifespan(app):
|
||||||
loaded_db_path = load_cache.call_args.args[0]
|
runtime_connections = [
|
||||||
runtime_connections = [conn for path, conn in opened_connections if path == runtime_db_path]
|
conn for path, conn in opened_connections if path == runtime_db_path
|
||||||
|
]
|
||||||
assert runtime_connections, "Expected runtime database to be opened"
|
assert runtime_connections, "Expected runtime database to be opened"
|
||||||
|
|
||||||
assert app.state.runtime_settings is not None
|
assert app.state.runtime_settings is not None
|
||||||
@@ -538,6 +556,91 @@ async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: P
|
|||||||
assert all(connection.close.await_count == 1 for connection in connections)
|
assert all(connection.close.await_count == 1 for connection in connections)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Middleware order validation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_settings(tmp_path: Path) -> Settings:
|
||||||
|
"""Return a minimal Settings object with a temporary fail2ban config dir."""
|
||||||
|
fail2ban_config_dir = tmp_path / "fail2ban"
|
||||||
|
fail2ban_config_dir.mkdir()
|
||||||
|
return Settings(
|
||||||
|
database_path=str(tmp_path / "bangui.db"),
|
||||||
|
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||||
|
fail2ban_config_dir=str(fail2ban_config_dir),
|
||||||
|
session_secret="test-secret-key-do-not-use-in-production",
|
||||||
|
session_duration_minutes=60,
|
||||||
|
timezone="UTC",
|
||||||
|
log_level="debug",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_app_raises_on_incorrect_middleware_order(
|
||||||
|
tmp_path: Path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""_assert_middleware_order() raises AssertionError when middleware order is wrong.
|
||||||
|
|
||||||
|
The security-critical chain requires:
|
||||||
|
RateLimitMiddleware → CsrfMiddleware → CorrelationIdMiddleware
|
||||||
|
in user_middleware (processing order: outermost → innermost).
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("TESTING", "1")
|
||||||
|
settings = _make_settings(tmp_path)
|
||||||
|
app = create_app(settings=settings)
|
||||||
|
# Swap CorrelationIdMiddleware and RateLimitMiddleware to break the order.
|
||||||
|
user_mw = app.user_middleware
|
||||||
|
corr_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "CorrelationIdMiddleware")
|
||||||
|
rate_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "RateLimitMiddleware")
|
||||||
|
user_mw[corr_idx], user_mw[rate_idx] = user_mw[rate_idx], user_mw[corr_idx]
|
||||||
|
with pytest.raises(AssertionError, match="must be registered before"):
|
||||||
|
_assert_middleware_order(app)
|
||||||
|
|
||||||
|
|
||||||
|
def test_middleware_order_validation_passes_for_correct_order(
|
||||||
|
tmp_path: Path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""_assert_middleware_order() does not raise when middleware order is correct."""
|
||||||
|
monkeypatch.setenv("TESTING", "1")
|
||||||
|
settings = _make_settings(tmp_path)
|
||||||
|
app = create_app(settings=settings)
|
||||||
|
_assert_middleware_order(app) # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_app_validates_middleware_order_at_startup(
|
||||||
|
tmp_path: Path,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
) -> None:
|
||||||
|
"""create_app() raises immediately if middleware registration order is incorrect.
|
||||||
|
|
||||||
|
This test verifies the integration: _assert_middleware_order is called at the
|
||||||
|
end of create_app, so a fresh app with deliberately wrong middleware order
|
||||||
|
(simulated by patching add_middleware during creation) raises AssertionError.
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("TESTING", "1")
|
||||||
|
settings = _make_settings(tmp_path)
|
||||||
|
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
|
||||||
|
original_add = Starlette.add_middleware
|
||||||
|
|
||||||
|
def swapping_add(self, middleware_cls: type, **kwargs: object) -> None:
|
||||||
|
"""Patched add_middleware that swaps CorrelationId and RateLimit."""
|
||||||
|
if middleware_cls is CorrelationIdMiddleware:
|
||||||
|
pass # Skip CorrelationId
|
||||||
|
elif middleware_cls is RateLimitMiddleware:
|
||||||
|
original_add(self, RateLimitMiddleware, **kwargs)
|
||||||
|
original_add(self, CorrelationIdMiddleware)
|
||||||
|
else:
|
||||||
|
original_add(self, middleware_cls, **kwargs)
|
||||||
|
|
||||||
|
with patch.object(Starlette, "add_middleware", swapping_add), \
|
||||||
|
pytest.raises(AssertionError, match="must be registered before"):
|
||||||
|
create_app(settings=settings)
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Single-worker enforcement
|
# Single-worker enforcement
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
130
backend/tests/test_routers/test_health_probes.py
Normal file
130
backend/tests/test_routers/test_health_probes.py
Normal file
@@ -0,0 +1,130 @@
|
|||||||
|
"""Tests for the health-check router — liveness and readiness probes."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from httpx import AsyncClient
|
||||||
|
|
||||||
|
from app.models.server import ServerStatus
|
||||||
|
from app.models.response import ReadyCheck
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /health/live — liveness probe
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_liveness_returns_200(client: AsyncClient) -> None:
|
||||||
|
"""``GET /api/v1/health/live`` must always return HTTP 200."""
|
||||||
|
response = await client.get("/api/v1/health/live")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_liveness_body_is_ready_response(client: AsyncClient) -> None:
|
||||||
|
"""Response body must be a ReadyResponse."""
|
||||||
|
response = await client.get("/api/v1/health/live")
|
||||||
|
data: dict[str, object] = response.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
assert data["failed_count"] == 0
|
||||||
|
assert "checks" in data
|
||||||
|
assert isinstance(data["checks"], list)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_liveness_includes_process_check(client: AsyncClient) -> None:
|
||||||
|
"""Liveness response must include a 'process' check."""
|
||||||
|
response = await client.get("/api/v1/health/live")
|
||||||
|
data: dict[str, object] = response.json()
|
||||||
|
checks: list[dict[str, object]] = data["checks"] # type: ignore[assignment]
|
||||||
|
assert any(c.get("name") == "process" and c.get("healthy") is True for c in checks)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# GET /health/ready — readiness probe
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_readiness_returns_200_when_all_pass(client: AsyncClient) -> None:
|
||||||
|
"""``GET /api/v1/health/ready`` must return 200 when all subsystems pass."""
|
||||||
|
with patch("app.routers.health._run_check", side_effect=lambda n, c, e: ReadyCheck(name=n, healthy=True)):
|
||||||
|
response = await client.get("/api/v1/health/ready")
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_readiness_returns_503_when_subsystem_fails(client: AsyncClient) -> None:
|
||||||
|
"""``GET /api/v1/health/ready`` must return 503 when at least one check fails."""
|
||||||
|
# Force fail2ban offline
|
||||||
|
client._transport.app.state.server_status = ServerStatus(online=False)
|
||||||
|
response = await client.get("/api/v1/health/ready")
|
||||||
|
assert response.status_code == 503
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_readiness_body_is_ready_response(client: AsyncClient) -> None:
|
||||||
|
"""Response body must be a ReadyResponse."""
|
||||||
|
response = await client.get("/api/v1/health/ready")
|
||||||
|
data: dict[str, object] = response.json()
|
||||||
|
assert data["status"] in ("ok", "error")
|
||||||
|
assert "failed_count" in data
|
||||||
|
assert "checks" in data
|
||||||
|
assert isinstance(data["checks"], list)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_readiness_includes_all_subsystems(client: AsyncClient) -> None:
|
||||||
|
"""Readiness response must include checks for all four subsystems."""
|
||||||
|
response = await client.get("/api/v1/health/ready")
|
||||||
|
data: dict[str, object] = response.json()
|
||||||
|
checks: list[dict[str, object]] = data["checks"] # type: ignore[assignment]
|
||||||
|
names = {c["name"] for c in checks}
|
||||||
|
assert names == {"database", "fail2ban", "config_dir", "scheduler"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_readiness_status_ok_when_all_healthy(client: AsyncClient) -> None:
|
||||||
|
"""``status`` must be 'ok' when all checks pass."""
|
||||||
|
with patch("app.routers.health._run_check", side_effect=lambda n, c, e: ReadyCheck(name=n, healthy=True)):
|
||||||
|
response = await client.get("/api/v1/health/ready")
|
||||||
|
data: dict[str, object] = response.json()
|
||||||
|
assert data["status"] == "ok"
|
||||||
|
assert data["failed_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_readiness_status_error_when_fail2ban_offline(client: AsyncClient) -> None:
|
||||||
|
"""``status`` must be 'error' when fail2ban is offline."""
|
||||||
|
client._transport.app.state.server_status = ServerStatus(online=False)
|
||||||
|
response = await client.get("/api/v1/health/ready")
|
||||||
|
data: dict[str, object] = response.json()
|
||||||
|
assert data["status"] == "error"
|
||||||
|
assert data["failed_count"] > 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_readiness_includes_failed_subsystem_detail(client: AsyncClient) -> None:
|
||||||
|
"""When fail2ban is offline the fail2ban check must include an error message."""
|
||||||
|
client._transport.app.state.server_status = ServerStatus(online=False)
|
||||||
|
response = await client.get("/api/v1/health/ready")
|
||||||
|
data: dict[str, object] = response.json()
|
||||||
|
checks: list[dict[str, object]] = data["checks"] # type: ignore[assignment]
|
||||||
|
f2b = next(c for c in checks if c["name"] == "fail2ban")
|
||||||
|
assert f2b["healthy"] is False
|
||||||
|
assert f2b["message"] is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_readiness_content_type_is_json(client: AsyncClient) -> None:
|
||||||
|
"""``/api/v1/health/ready`` must set the ``Content-Type`` header to JSON."""
|
||||||
|
response = await client.get("/api/v1/health/ready")
|
||||||
|
assert "application/json" in response.headers.get("content-type", "")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_readiness_live_content_type_is_json(client: AsyncClient) -> None:
|
||||||
|
"""``/api/v1/health/live`` must set the ``Content-Type`` header to JSON."""
|
||||||
|
response = await client.get("/api/v1/health/live")
|
||||||
|
assert "application/json" in response.headers.get("content-type", "")
|
||||||
@@ -53,3 +53,28 @@ With drift correction:
|
|||||||
- Total time from poll start to next poll start is always ~5 seconds
|
- Total time from poll start to next poll start is always ~5 seconds
|
||||||
- Fetch duration doesn't affect the long-term polling rate
|
- Fetch duration doesn't affect the long-term polling rate
|
||||||
- Bandwidth and CPU usage remain consistent
|
- Bandwidth and CPU usage remain consistent
|
||||||
|
|
||||||
|
## Request Deduplication with Abort Signal Handling
|
||||||
|
|
||||||
|
When `useFetchData` is called with a `requestKey`, multiple hook instances sharing the same key coalesce around a single in-flight request. Each instance is tracked as a **subscriber**.
|
||||||
|
|
||||||
|
### How it works
|
||||||
|
|
||||||
|
- A module-level `Map<requestKey, InFlightRequest>` holds in-flight requests.
|
||||||
|
- Each `InFlightRequest` contains a `subscribers` map: `subscriberId → { signal, onSetData, onSuccess }`.
|
||||||
|
- When a new hook instance joins a deduplicated request, it registers a subscriber entry and attaches a `once` abort listener that removes it from the subscriber list.
|
||||||
|
- When the underlying request resolves, **each subscriber's signal is checked before calling `setData` or `onSuccess`**. Aborted subscribers are skipped, preventing state updates on unmounted components.
|
||||||
|
- When all subscribers have aborted, the underlying `AbortController` is signalled, cancelling the in-flight fetch.
|
||||||
|
|
||||||
|
### Guarantees
|
||||||
|
|
||||||
|
| Scenario | Result |
|
||||||
|
|---|---|
|
||||||
|
| Subscriber aborts before request resolves | Subscriber receives no `setData`/`onSuccess`. Entry removed from subscriber list. |
|
||||||
|
| Last subscriber aborts before request resolves | Underlying request is cancelled via `AbortController`. |
|
||||||
|
| Request resolves, subscriber already aborted | `signal.aborted` checked before `setData`/`onSuccess` — no-op. |
|
||||||
|
| Request errors | Non-abort errors passed to `handleFetchError`. Each subscriber's error state updated independently. |
|
||||||
|
|
||||||
|
### When to use deduplication
|
||||||
|
|
||||||
|
Use `requestKey` when multiple components may request the same resource simultaneously. The shared promise avoids redundant network calls and prevents race conditions from out-of-order responses.
|
||||||
|
|||||||
@@ -1,8 +1,13 @@
|
|||||||
import { describe, it, expect, vi } from "vitest";
|
import { describe, it, expect, vi, beforeEach } from "vitest";
|
||||||
import { renderHook, act } from "@testing-library/react";
|
import { renderHook, act } from "@testing-library/react";
|
||||||
import { useFetchData } from "../useFetchData";
|
import { useFetchData, _resetInFlightRequests } from "../useFetchData";
|
||||||
import type { FetchError } from "../../types/api";
|
import type { FetchError } from "../../types/api";
|
||||||
|
|
||||||
|
// Clear the module-level inFlightRequests map between tests to prevent state leakage
|
||||||
|
beforeEach(() => {
|
||||||
|
_resetInFlightRequests();
|
||||||
|
});
|
||||||
|
|
||||||
describe("useFetchData", () => {
|
describe("useFetchData", () => {
|
||||||
it("fetches and selects data on mount", async () => {
|
it("fetches and selects data on mount", async () => {
|
||||||
const fetcher = vi.fn().mockResolvedValue({ value: "test" });
|
const fetcher = vi.fn().mockResolvedValue({ value: "test" });
|
||||||
@@ -406,4 +411,123 @@ describe("useFetchData", () => {
|
|||||||
expect(fetcher).toHaveBeenCalledTimes(2);
|
expect(fetcher).toHaveBeenCalledTimes(2);
|
||||||
expect(result.current.data).toBe("second");
|
expect(result.current.data).toBe("second");
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("aborted subscriber in deduplication does not receive setData/onSuccess", async () => {
|
||||||
|
let resolveFirst: ((value: { value: string }) => void) | null = null;
|
||||||
|
const fetcher = vi.fn().mockImplementation(
|
||||||
|
() =>
|
||||||
|
new Promise((resolve) => {
|
||||||
|
resolveFirst = resolve;
|
||||||
|
})
|
||||||
|
);
|
||||||
|
const selector = vi.fn((response: { value: string }) => response.value);
|
||||||
|
const onSuccess1 = vi.fn();
|
||||||
|
const onSuccess2 = vi.fn();
|
||||||
|
|
||||||
|
const hook1 = renderHook(() =>
|
||||||
|
useFetchData({
|
||||||
|
fetcher,
|
||||||
|
selector,
|
||||||
|
errorMessage: "Failed to load",
|
||||||
|
requestKey: "abort-subscriber-test",
|
||||||
|
onSuccess: onSuccess1,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
const hook2 = renderHook(() =>
|
||||||
|
useFetchData({
|
||||||
|
fetcher,
|
||||||
|
selector,
|
||||||
|
errorMessage: "Failed to load",
|
||||||
|
requestKey: "abort-subscriber-test",
|
||||||
|
onSuccess: onSuccess2,
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await Promise.resolve();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(fetcher).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
// Unmount hook2 (simulates component unmounting and aborting its signal)
|
||||||
|
hook2.unmount();
|
||||||
|
|
||||||
|
// Allow the request to resolve
|
||||||
|
await act(async () => {
|
||||||
|
resolveFirst?.({ value: "shared-data" });
|
||||||
|
await Promise.resolve();
|
||||||
|
});
|
||||||
|
|
||||||
|
// hook2 should not have received the data (it aborted)
|
||||||
|
expect(hook2.result.current.data).toBeUndefined();
|
||||||
|
expect(onSuccess2).not.toHaveBeenCalled();
|
||||||
|
// hook1 should still have the data
|
||||||
|
expect(hook1.result.current.data).toBe("shared-data");
|
||||||
|
expect(onSuccess1).toHaveBeenCalledWith({ value: "shared-data" });
|
||||||
|
});
|
||||||
|
|
||||||
|
it("last subscriber abort cancels underlying request", async () => {
|
||||||
|
let resolveFirst: ((value: { value: string }) => void) | null = null;
|
||||||
|
const abortSignals: AbortSignal[] = [];
|
||||||
|
const fetcher = vi.fn().mockImplementation((signal: AbortSignal) => {
|
||||||
|
abortSignals.push(signal);
|
||||||
|
return new Promise((resolve) => {
|
||||||
|
resolveFirst = resolve;
|
||||||
|
});
|
||||||
|
});
|
||||||
|
const selector = vi.fn((response: { value: string }) => response.value);
|
||||||
|
|
||||||
|
const hook1 = renderHook(() =>
|
||||||
|
useFetchData({
|
||||||
|
fetcher,
|
||||||
|
selector,
|
||||||
|
errorMessage: "Failed to load",
|
||||||
|
requestKey: "cancel-underlying-test",
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await Promise.resolve();
|
||||||
|
});
|
||||||
|
|
||||||
|
expect(fetcher).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
// Mount second subscriber
|
||||||
|
const hook2 = renderHook(() =>
|
||||||
|
useFetchData({
|
||||||
|
fetcher,
|
||||||
|
selector,
|
||||||
|
errorMessage: "Failed to load",
|
||||||
|
requestKey: "cancel-underlying-test",
|
||||||
|
})
|
||||||
|
);
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await Promise.resolve();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Both subscribers now sharing one request
|
||||||
|
expect(fetcher).toHaveBeenCalledTimes(1);
|
||||||
|
|
||||||
|
// Unmount first subscriber (doesn't cancel underlying request yet)
|
||||||
|
hook1.unmount();
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await Promise.resolve();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Underlying request should NOT be cancelled yet (hook2 still waiting)
|
||||||
|
expect(abortSignals[0]?.aborted).toBe(false);
|
||||||
|
|
||||||
|
// Unmount second (last) subscriber — should cancel the underlying request
|
||||||
|
hook2.unmount();
|
||||||
|
|
||||||
|
await act(async () => {
|
||||||
|
await Promise.resolve();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Underlying request should now be cancelled
|
||||||
|
expect(abortSignals[0]?.aborted).toBe(true);
|
||||||
|
});
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -19,16 +19,33 @@ import type { FetchError } from "../types/api";
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Module-level cache for in-flight requests.
|
* Module-level cache for in-flight requests.
|
||||||
* Maps requestKey to { promise, controller } to enable deduplication
|
* Maps requestKey to { promise, controller, subscribers } to enable deduplication
|
||||||
* across multiple hook instances.
|
* across multiple hook instances.
|
||||||
*/
|
*/
|
||||||
|
interface Subscriber<TResponse> {
|
||||||
|
/** This instance's abort signal. */
|
||||||
|
signal: AbortSignal;
|
||||||
|
/** Callbacks registered by this subscriber. */
|
||||||
|
onSetData?: (response: TResponse) => void;
|
||||||
|
onSuccess?: (response: TResponse) => void;
|
||||||
|
}
|
||||||
|
|
||||||
interface InFlightRequest<TResponse> {
|
interface InFlightRequest<TResponse> {
|
||||||
promise: Promise<TResponse>;
|
promise: Promise<TResponse>;
|
||||||
controller: AbortController;
|
controller: AbortController;
|
||||||
|
/** Map of subscriberId -> Subscriber. Cleared when subscriber aborts. */
|
||||||
|
subscribers: Map<string, Subscriber<TResponse>>;
|
||||||
|
/** True when initiator has cleaned up but subscribers remain. */
|
||||||
|
initiatorDone: boolean;
|
||||||
}
|
}
|
||||||
|
|
||||||
const inFlightRequests = new Map<string, InFlightRequest<unknown>>();
|
const inFlightRequests = new Map<string, InFlightRequest<unknown>>();
|
||||||
|
|
||||||
|
/** Visible for testing only. Clears all in-flight request state. */
|
||||||
|
export const _resetInFlightRequests = (): void => {
|
||||||
|
inFlightRequests.clear();
|
||||||
|
};
|
||||||
|
|
||||||
export interface UseFetchDataOptions<TResponse, TData> {
|
export interface UseFetchDataOptions<TResponse, TData> {
|
||||||
/** Async function that accepts an AbortSignal for cancellation. */
|
/** Async function that accepts an AbortSignal for cancellation. */
|
||||||
fetcher: (signal: AbortSignal) => Promise<TResponse>;
|
fetcher: (signal: AbortSignal) => Promise<TResponse>;
|
||||||
@@ -86,18 +103,57 @@ export function useFetchData<TResponse, TData>(
|
|||||||
const [error, setError] = useState<FetchError | null>(null);
|
const [error, setError] = useState<FetchError | null>(null);
|
||||||
const abortRef = useRef<AbortController | null>(null);
|
const abortRef = useRef<AbortController | null>(null);
|
||||||
const localControllerRef = useRef<AbortController | null>(null);
|
const localControllerRef = useRef<AbortController | null>(null);
|
||||||
|
const subscriberAbortRef = useRef<AbortController | null>(null);
|
||||||
|
/** Unique ID for this instance, used to track its subscription to deduplicated requests. */
|
||||||
|
const subscriberIdRef = useRef<string | null>(null);
|
||||||
|
|
||||||
const refresh = useCallback((): void => {
|
const refresh = useCallback((): void => {
|
||||||
|
// Abort any previous request from this hook instance
|
||||||
|
abortRef.current?.abort();
|
||||||
|
localControllerRef.current = new AbortController();
|
||||||
|
abortRef.current = localControllerRef.current;
|
||||||
|
|
||||||
// If using request deduplication via requestKey and a request is already in-flight,
|
// If using request deduplication via requestKey and a request is already in-flight,
|
||||||
// wait for it to complete instead of launching a duplicate
|
// subscribe to it instead of launching a duplicate
|
||||||
if (requestKey && inFlightRequests.has(requestKey)) {
|
if (requestKey && inFlightRequests.has(requestKey)) {
|
||||||
const inFlight = inFlightRequests.get(requestKey)! as InFlightRequest<TResponse>;
|
const inFlight = inFlightRequests.get(requestKey)! as InFlightRequest<TResponse>;
|
||||||
inFlight.promise
|
|
||||||
.then((response: TResponse) => {
|
// Create per-instance abort controller for this subscription.
|
||||||
|
// Do NOT abort previous subscription signals here - each subscriber's signal
|
||||||
|
// only aborts when THAT subscriber unmounts. Aborting old signals would
|
||||||
|
// incorrectly remove other active subscribers from the list.
|
||||||
|
subscriberAbortRef.current = new AbortController();
|
||||||
|
|
||||||
|
// Register this instance as a subscriber.
|
||||||
|
// Always generate a new ID - each hook instance must have its own subscriber ID
|
||||||
|
// so that unmounting one subscriber does not remove another subscriber's entry.
|
||||||
|
subscriberIdRef.current = crypto.randomUUID();
|
||||||
|
const sid = subscriberIdRef.current;
|
||||||
|
const subscription: Subscriber<TResponse> = {
|
||||||
|
signal: subscriberAbortRef.current.signal,
|
||||||
|
onSetData: (response: TResponse) => {
|
||||||
setData(selector(response));
|
setData(selector(response));
|
||||||
if (onSuccess) {
|
if (onSuccess) {
|
||||||
onSuccess(response);
|
onSuccess(response);
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
onSuccess,
|
||||||
|
};
|
||||||
|
inFlight.subscribers.set(sid, subscription);
|
||||||
|
|
||||||
|
// When this instance aborts, remove it from the subscriber list.
|
||||||
|
// Note: we do NOT abort the underlying controller here - that belongs to
|
||||||
|
// the initiator and should only be aborted by the initiator itself.
|
||||||
|
const cleanup = (): void => {
|
||||||
|
inFlight.subscribers.delete(sid);
|
||||||
|
};
|
||||||
|
subscription.signal.addEventListener("abort", cleanup, { once: true });
|
||||||
|
|
||||||
|
inFlight.promise
|
||||||
|
.then((response: TResponse) => {
|
||||||
|
// Check whether this subscriber has already aborted before propagating result
|
||||||
|
if (subscription.signal.aborted) return;
|
||||||
|
subscription.onSetData?.(response);
|
||||||
})
|
})
|
||||||
.catch((err: unknown) => {
|
.catch((err: unknown) => {
|
||||||
// Only handle non-abort errors; abort errors are silently ignored
|
// Only handle non-abort errors; abort errors are silently ignored
|
||||||
@@ -139,7 +195,6 @@ export function useFetchData<TResponse, TData>(
|
|||||||
if (!controller.signal.aborted) {
|
if (!controller.signal.aborted) {
|
||||||
setLoading(false);
|
setLoading(false);
|
||||||
}
|
}
|
||||||
// Clear cache entry when response arrives
|
|
||||||
if (requestKey) {
|
if (requestKey) {
|
||||||
inFlightRequests.delete(requestKey);
|
inFlightRequests.delete(requestKey);
|
||||||
}
|
}
|
||||||
@@ -150,6 +205,8 @@ export function useFetchData<TResponse, TData>(
|
|||||||
inFlightRequests.set(requestKey, {
|
inFlightRequests.set(requestKey, {
|
||||||
promise: responsePromise,
|
promise: responsePromise,
|
||||||
controller,
|
controller,
|
||||||
|
subscribers: new Map(),
|
||||||
|
initiatorDone: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
}, [fetcher, selector, errorMessage, onSuccess, requestKey]);
|
}, [fetcher, selector, errorMessage, onSuccess, requestKey]);
|
||||||
@@ -158,7 +215,25 @@ export function useFetchData<TResponse, TData>(
|
|||||||
refresh();
|
refresh();
|
||||||
|
|
||||||
return (): void => {
|
return (): void => {
|
||||||
abortRef.current?.abort();
|
if (subscriberIdRef.current !== null) {
|
||||||
|
// Subscriber cleanup
|
||||||
|
const inFlight = requestKey ? inFlightRequests.get(requestKey) : undefined;
|
||||||
|
subscriberAbortRef.current?.abort();
|
||||||
|
// If initiator already done and no subscribers left, cancel the request
|
||||||
|
if (inFlight && inFlight.initiatorDone && inFlight.subscribers.size === 0) {
|
||||||
|
inFlight.controller.abort();
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Initiator cleanup
|
||||||
|
const inFlight = requestKey ? inFlightRequests.get(requestKey) : undefined;
|
||||||
|
// Mark initiator as done so last subscriber knows to cancel
|
||||||
|
if (inFlight) {
|
||||||
|
inFlight.initiatorDone = true;
|
||||||
|
}
|
||||||
|
if (!inFlight || inFlight.subscribers.size === 0) {
|
||||||
|
abortRef.current?.abort();
|
||||||
|
}
|
||||||
|
}
|
||||||
};
|
};
|
||||||
}, [refresh]);
|
}, [refresh]);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user