diff --git a/Docs/Deployment.md b/Docs/Deployment.md index 507c6ab..9ba07ce 100644 --- a/Docs/Deployment.md +++ b/Docs/Deployment.md @@ -78,7 +78,12 @@ During rolling deployments: ## Health Checks -The backend container includes a health check endpoint at `GET /api/v1/health` that reports application and component status: +The backend container includes **three** health check endpoints: + +### Combined Health Check — `GET /api/v1/health` + +Reports application and component status for Docker HEALTHCHECK and legacy +monitoring integration: - **HTTP 200** with `{"status": "ok", ...}` — all components healthy - **HTTP 200** with `{"status": "degraded", ...}` — some components unhealthy (e.g., database error) but fail2ban reachable @@ -93,6 +98,59 @@ The backend container includes a health check endpoint at `GET /api/v1/health` t | scheduler | `scheduler.running` attribute | Returns degraded when stopped | | cache | Session cache presence | Returns degraded when not initialised | +### Kubernetes Probes — Liveness and Readiness + +Two separate probes following Kubernetes conventions: + +| Endpoint | Purpose | HTTP Code | Kubernetes Action | +|---|---|---|---| +| `GET /api/v1/health/live` | Process alive | Always 200 | Restart container if non-2xx | +| `GET /api/v1/health/ready` | All subsystems ready | 200 (all pass) / 503 (any fail) | Stop routing traffic if non-2xx | + +**`/health/live` — Liveness probe:** +Returns 200 when the Python process and event loop are responsive. No subsystem checks are performed — this endpoint is always fast. Use for Kubernetes `livenessProbe`. + +**`/health/ready` — Readiness probe:** +Verifies all critical sub-systems are reachable before routing traffic. Returns 200 only when all pass; returns 503 with a JSON body listing every failed check otherwise. + +| Subsystem | Check | Timeout | +|---|---|---| +| database | Opens and closes a test connection | 2 s | +| fail2ban | Socket reachability via cached server status | N/A (instant) | +| config_dir | Config directory read access (`os.R_OK`) | 2 s | +| scheduler | `scheduler.running` attribute | N/A (instant) | + +**Readiness response example (all healthy — HTTP 200):** +```json +{ + "status": "ok", + "checks": [ + {"name": "database", "healthy": true}, + {"name": "fail2ban", "healthy": true}, + {"name": "config_dir", "healthy": true}, + {"name": "scheduler", "healthy": true} + ], + "failed_count": 0 +} +``` + +**Readiness response example (fail2ban offline — HTTP 503):** +```json +{ + "status": "error", + "checks": [ + {"name": "database", "healthy": true}, + {"name": "fail2ban", "healthy": false, "message": "Socket not reachable"}, + {"name": "config_dir", "healthy": true}, + {"name": "scheduler", "healthy": true} + ], + "failed_count": 1 +} +``` + +**Why separate liveness and readiness?** +Liveness (`/health/live`) must be cheap — a slow or hanging liveness probe causes Kubernetes to restart a perfectly healthy container. Readiness (`/health/ready`) can afford to check sub-systems because traffic is only held back temporarily while a pod recovers. + **Docker Health Check:** The Dockerfile includes a HEALTHCHECK that queries the endpoint. Docker interprets HTTP 503 as unhealthy and restarts the container after 3 consecutive failures (90 seconds by default). @@ -739,9 +797,9 @@ sqlite3 /data/bangui.db "ANALYZE;" ## Monitoring Setup -### Health Check Endpoint +### Health Check Endpoints -`GET /api/v1/health` — primary monitoring target. +**Combined health check** — `GET /api/v1/health` — primary monitoring target for Docker HEALTHCHECK. | Status | HTTP Code | Meaning | |--------|-----------|---------| @@ -749,6 +807,17 @@ sqlite3 /data/bangui.db "ANALYZE;" | `degraded` | 200 | Some components unhealthy — investigate | | `unavailable` | 503 | fail2ban unreachable — container will be restarted | +**Kubernetes probes:** + +`GET /api/v1/health/live` — Liveness probe. Always returns 200 if the process is alive. + +`GET /api/v1/health/ready` — Readiness probe. Returns 200 when all subsystems pass, 503 otherwise. + +| Probe | URL | Success | Failure | +|-------|---|---------|---------| +| Liveness | `/api/v1/health/live` | 200 | Non-2xx → restart | +| Readiness | `/api/v1/health/ready` | 200 | Non-2xx → stop traffic | + ### Structured Logging All logs are structured (JSON via structlog). Key fields: diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 3a7fbc9..e1328fe 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -1,84 +1,3 @@ -### Issue #56: MEDIUM - No API Versioning or Deprecation Strategy - -**Where found**: -- All backend routers register under `/api/v1/` prefix but no versioning mechanism exists - -**Why this is needed**: -Breaking backend changes immediately break all frontend clients. Without a deprecation path, there is no safe way to evolve the API. - -**Goal**: -Define and implement an API lifecycle policy. - -**What to do**: -1. Document the versioning strategy (URL versioning is already in place; formalize it). -2. Add a `Deprecation` response header to endpoints scheduled for removal. -3. Implement a `/api/v2/` prefix for the next breaking change cycle. -4. Add a CI check that flags new breaking changes against the OpenAPI spec. - -**Possible traps and issues**: -- Running two API versions simultaneously doubles maintenance surface; set a sunset date policy. - -**Docs changes needed**: -- `Docs/`: create `API_VERSIONING.md` documenting the lifecycle and deprecation process. - -**Doc references**: -- All router files under `backend/app/routers/` - ---- - -### Issue #57: MEDIUM - Health Endpoint Does Not Check Subsystems - -**Where found**: -- `backend/app/routers/health.py` - -**Why this is needed**: -A process that is running but cannot reach the fail2ban socket, database, or config directory still returns `200 OK`. Load balancers and orchestrators treat it as healthy and route traffic to it, causing silent failures. - -**Goal**: -Health endpoint reflects true readiness of all critical subsystems. - -**What to do**: -1. Add a structured health check that tests: database connectivity, fail2ban socket accessibility, config directory read access, scheduler liveness. -2. Return `200` only when all checks pass; return `503` with a JSON body listing failed checks otherwise. -3. Expose a separate `/health/live` (process alive) and `/health/ready` (subsystems ready) endpoint for Kubernetes probes. - -**Possible traps and issues**: -- Slow health checks (e.g., DB connect timeout) can overwhelm the endpoint under load; set short timeouts per check. - -**Docs changes needed**: -- `Docs/Deployment.md`: document liveness vs readiness probe URLs. - -**Doc references**: -- `backend/app/routers/health.py` - ---- - -### Issue #58: MEDIUM - Abort Signal Not Propagated in Request Deduplication - -**Where found**: -- `frontend/src/hooks/useFetchData.ts:93-113` - -**Why this is needed**: -When multiple hook instances share a `requestKey`, they await a single in-flight promise. When one component unmounts and aborts its signal, the shared request continues and calls `setData()` / `onSuccess()` on the unmounted component, causing React "state update on unmounted component" warnings and memory leaks. - -**Goal**: -Unmounting a component that joined a deduplicated request must not receive the result. - -**What to do**: -1. In the deduplication await path, check the component's own abort signal before calling `setData()` or `onSuccess()`. -2. Wrap the deduplication subscriber list so each subscriber can individually opt out on abort. - -**Possible traps and issues**: -- If all subscribers abort before the request resolves, consider whether the underlying request should also be cancelled. - -**Docs changes needed**: -- `frontend/src/hooks/README.md`: document abort signal contract for deduplicated requests. - -**Doc references**: -- `frontend/src/hooks/README.md` - ---- - ### Issue #59: MEDIUM - Middleware Registration Order Not Validated at Startup **Where found**: diff --git a/backend/app/main.py b/backend/app/main.py index 6358ab0..765f034 100644 --- a/backend/app/main.py +++ b/backend/app/main.py @@ -314,13 +314,13 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]: def _get_error_code(exc: Exception) -> str: """Get the machine-readable error code from an exception. - + First checks if the exception has an error_code class attribute. Falls back to converting the exception class name to snake_case. - + Args: exc: The exception instance. - + Returns: A snake_case error code string. """ @@ -334,12 +334,12 @@ def _get_error_code(exc: Exception) -> str: def _get_error_metadata(exc: Exception) -> ErrorMetadata: """Get structured metadata from an exception. - + Calls the exception's get_error_metadata() method if available. - + Args: exc: The exception instance. - + Returns: A dictionary of metadata safe for API responses. """ @@ -350,12 +350,12 @@ def _get_error_metadata(exc: Exception) -> ErrorMetadata: def _get_correlation_id(request: Request) -> str | None: """Extract correlation ID from request state if available. - + The correlation ID is set by CorrelationIdMiddleware. - + Args: request: The incoming FastAPI request. - + Returns: The correlation ID string, or None if not present. """ @@ -802,7 +802,9 @@ async def _request_validation_error_handler( _EXACT_ALLOWED: frozenset[str] = frozenset( { "/api/v1/setup", # GET/POST /api/v1/setup - "/api/v1/health", # Health check endpoint + "/api/v1/health", # Health check endpoint (combined) + "/api/v1/health/live", # Kubernetes liveness probe + "/api/v1/health/ready", # Kubernetes readiness probe "/api/docs", # Swagger UI "/api/redoc", # ReDoc "/api/openapi.json", # OpenAPI schema @@ -988,6 +990,48 @@ def _enforce_single_worker() -> None: # --------------------------------------------------------------------------- +def _assert_middleware_order(app: FastAPI) -> None: + """Assert required middleware order at startup. + + Raises: + AssertionError: If middleware are not in the required order. + """ + registered = [m.cls.__name__ for m in app.user_middleware] + + # Find positions; skip middleware not in the security-critical chain + order: tuple[str, ...] = ( + "RateLimitMiddleware", + "CsrfMiddleware", + "CorrelationIdMiddleware", + ) + + positions = {name: registered.index(name) for name in order if name in registered} + + # RateLimitMiddleware must be before CsrfMiddleware + if ( + "RateLimitMiddleware" in positions + and "CsrfMiddleware" in positions + and positions["RateLimitMiddleware"] > positions["CsrfMiddleware"] + ): + raise AssertionError( + f"Middleware order violation: RateLimitMiddleware (position {positions['RateLimitMiddleware']}) " + f"must be registered before CsrfMiddleware (position {positions['CsrfMiddleware']}). " + f"Current order: {registered}" + ) + + # CsrfMiddleware must be before CorrelationIdMiddleware + if ( + "CsrfMiddleware" in positions + and "CorrelationIdMiddleware" in positions + and positions["CsrfMiddleware"] > positions["CorrelationIdMiddleware"] + ): + raise AssertionError( + f"Middleware order violation: CsrfMiddleware (position {positions['CsrfMiddleware']}) " + f"must be registered before CorrelationIdMiddleware (position {positions['CorrelationIdMiddleware']}). " + f"Current order: {registered}" + ) + + def create_app(settings: Settings | None = None) -> FastAPI: """Create and configure the BanGUI FastAPI application. @@ -1066,11 +1110,18 @@ def create_app(settings: Settings | None = None) -> FastAPI: ) # --- Middleware --- - # Note: middleware is applied in reverse order of registration. - # SecurityHeadersMiddleware must run early but after CORS/CSRF so headers - # are added to all responses including error responses. - # CorrelationIdMiddleware must run first (added last) so correlation ID - # is available to all downstream handlers and loggers. + # Note: Starlette applies middleware in reverse order of registration + # (last registered = outermost; first to see request, last to see response). + # + # Required processing order (outermost → innermost): + # 1. CorrelationIdMiddleware – generates/extracts correlation ID first + # 2. CsrfMiddleware – CSRF validation after correlation ID is available + # 3. RateLimitMiddleware – rate limiting last (needs correlation ID for logging) + # + # This requires registration order (reverse of processing): + # 1. RateLimitMiddleware (registered first → innermost for responses) + # 2. CsrfMiddleware + # 3. CorrelationIdMiddleware (registered last → outermost for requests) app.add_middleware(CorrelationIdMiddleware) app.add_middleware(SecurityHeadersMiddleware) app.add_middleware(SetupRedirectMiddleware) @@ -1083,6 +1134,11 @@ def create_app(settings: Settings | None = None) -> FastAPI: settings=resolved_settings, ) + # Validate middleware order before returning the app. + # Raising loud errors at startup is intentional — a misconfigured middleware + # stack is a security-critical defect that must not slip through silently. + _assert_middleware_order(app) + # --- Exception handlers --- # diff --git a/backend/app/middleware/correlation.py b/backend/app/middleware/correlation.py index 51ff087..8173c17 100644 --- a/backend/app/middleware/correlation.py +++ b/backend/app/middleware/correlation.py @@ -11,6 +11,18 @@ Correlation IDs flow through the request lifecycle: 3. Middleware stores in structlog.contextvars 4. All log entries include the correlation ID automatically 5. Error responses include the correlation ID for client-side correlation + +Processing order +----------------- +This middleware must be the outermost in the security-critical chain so it +executes first on incoming requests (outermost = first to see request, +last to see response). In the required chain: + + CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware + +The registration order in ``main.py`` must be: + RateLimitMiddleware, CsrfMiddleware, CorrelationIdMiddleware +(last registered = outermost in Starlette's reverse application). """ from __future__ import annotations diff --git a/backend/app/middleware/csrf.py b/backend/app/middleware/csrf.py index 300a563..11a8f85 100644 --- a/backend/app/middleware/csrf.py +++ b/backend/app/middleware/csrf.py @@ -9,6 +9,16 @@ is not CSRF-vulnerable. GET, HEAD, and OPTIONS requests are also exempt. Cross-site requests cannot set custom headers without CORS preflight, which the backend rejects for non-allowed origins, providing defense-in-depth. + +Processing order +---------------- +This middleware must be the middle component in the security-critical chain: + + CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware + +It runs after CorrelationIdMiddleware has attached a correlation ID (so rate-limit +errors can include it in their log context), and before RateLimitMiddleware +(so rate-limit counters are only incremented for requests that pass CSRF checks). """ from __future__ import annotations diff --git a/backend/app/middleware/rate_limit.py b/backend/app/middleware/rate_limit.py index 4403acf..5abdf43 100644 --- a/backend/app/middleware/rate_limit.py +++ b/backend/app/middleware/rate_limit.py @@ -20,6 +20,16 @@ scheduler lock). The startup warning log documents this constraint. Redis-backed adapter that uses atomic INCR + EXPIRE semantics. The check_allowed() and check_allowed_for_bucket() interfaces are designed to make this swap-in without touching middleware or router code. + +Processing order +---------------- +This middleware must be the innermost in the security-critical chain: + + CorrelationIdMiddleware → CsrfMiddleware → RateLimitMiddleware + +Rate limiting is last so that requests blocked by CsrfMiddleware do not +consume rate-limit budget, and so that rate-limit log entries (which are +unusual and potentially suspicious) always carry a correlation ID for tracing. """ from __future__ import annotations diff --git a/backend/app/models/response.py b/backend/app/models/response.py index ec2c271..24ec820 100644 --- a/backend/app/models/response.py +++ b/backend/app/models/response.py @@ -480,3 +480,53 @@ class FlushLogsResponse(BanGuiBaseModel): """ message: str = Field(..., description="Human-readable result message from fail2ban.") + + +class ReadyCheck(BanGuiBaseModel): + """Result of a single readiness subsystem check. + + Fields: + name: Subsystem name (e.g., "database", "fail2ban", "config_dir"). + healthy: True when the subsystem is reachable/operational. + message: Optional error message describing the failure. + """ + + name: str = Field(..., description="Subsystem name.") + healthy: bool = Field(..., description="True when the subsystem is operational.") + message: str | None = Field( + default=None, + description="Error detail when the check fails.", + ) + + +class ReadyResponse(BanGuiBaseModel): + """Structured readiness check response for the ``/health/ready`` endpoint. + + Fields: + status: "ok" when all checks pass, "error" when at least one failed. + checks: Per-subsystem result list. + failed_count: Number of checks that returned healthy=False. + + Example: + ```python + # All healthy (HTTP 200) + {"status": "ok", "checks": [...], "failed_count": 0} + + # Some failed (HTTP 503) + {"status": "error", "checks": [...], "failed_count": 2} + ``` + """ + + status: Literal["ok", "error"] = Field( + ..., + description="'ok' when all checks pass, 'error' when at least one fails.", + ) + checks: list[ReadyCheck] = Field( + default_factory=list, + description="Per-subsystem check results.", + ) + failed_count: int = Field( + ..., + ge=0, + description="Number of checks that returned healthy=False.", + ) diff --git a/backend/app/routers/health.py b/backend/app/routers/health.py index 52bc591..37ffb17 100644 --- a/backend/app/routers/health.py +++ b/backend/app/routers/health.py @@ -1,27 +1,36 @@ """Health check router. -A lightweight ``GET /api/v1/health`` endpoint that verifies the application -is running and can serve requests. Also reports the cached fail2ban liveness -state so monitoring tools and Docker health checks can observe daemon status -without probing the socket directly. +Two distinct probes following Kubernetes conventions: -Comprehensive checks performed: -- Database connectivity -- fail2ban socket reachability (via cached server_status) -- Background scheduler health -- Session cache initialization +* ``GET /api/v1/health/live`` — **Liveness** — checks that the Python process is + alive and the event loop is responsive. Always returns 200; a non-2xx answer + tells Kubernetes to *restart* the container. + +* ``GET /api/v1/health/ready`` — **Readiness** — checks that all critical + sub-systems (database, fail2ban socket, config directory, scheduler) are + reachable. Returns 200 only when all pass; returns 503 with a JSON body + listing every failed check otherwise. A non-2xx answer tells Kubernetes to + *stop routing traffic* to the pod until it recovers. + +The combined ``GET /api/v1/health`` endpoint is retained for backward +compatibility with existing Docker HEALTHCHECK definitions. """ from __future__ import annotations -from typing import Annotated, Literal +import asyncio +import os +from typing import TYPE_CHECKING, Literal import structlog from fastapi import APIRouter, status from fastapi.responses import JSONResponse from app.dependencies import AppStateDep, ServerStatusDep -from app.models.response import ComponentHealth, HealthResponse +from app.models.response import ComponentHealth, HealthResponse, ReadyCheck, ReadyResponse + +if TYPE_CHECKING: + from collections.abc import Coroutine router: APIRouter = APIRouter(prefix="/api/v1/health", tags=["Health"]) @@ -142,3 +151,164 @@ async def health_check( components=components, ).model_dump(), ) + + +# --- Constants for subsystem checks ------------------------------------------ # + +SUBSYSTEM_TIMEOUT_SECONDS: float = 2.0 + + +# --- Helper: run a blocking check in a thread pool to avoid event-loop delays -- # + +async def _run_check( + name: str, + coro: Coroutine[object, object, None], + error_msg: str, +) -> ReadyCheck: + """Run *coro* with a short timeout and return a ReadyCheck.""" + try: + await asyncio.wait_for(coro, timeout=SUBSYSTEM_TIMEOUT_SECONDS) + return ReadyCheck(name=name, healthy=True) + except (OSError, TimeoutError, Exception) as exc: # noqa: BLE001 + log.warning("ready_check_failed", subsystem=name, error=str(exc)) + return ReadyCheck(name=name, healthy=False, message=f"{error_msg}: {exc}") + + +# --- Liveness probe ---------------------------------------------------------- # + + +@router.get( + "/live", + summary="Process liveness probe", + response_model=ReadyResponse, + responses={ + 200: {"description": "Process is alive"}, + }, +) +async def liveness_probe() -> JSONResponse: + """Lightweight liveness check for Kubernetes. + + Returns 200 when the Python process and event loop are responsive. + A non-2xx response tells Kubernetes to restart the container. + No subsystem checks are performed — this endpoint must be fast. + """ + return JSONResponse( + status_code=status.HTTP_200_OK, + content=ReadyResponse( + status="ok", + checks=[ReadyCheck(name="process", healthy=True)], + failed_count=0, + ).model_dump(), + ) + + +# --- Readiness probe --------------------------------------------------------- # + + +async def _check_database(app_state: AppStateDep) -> ReadyCheck: + """Check database connectivity with a short timeout.""" + from app.config import Settings + from app.db import open_db + + effective_settings: Settings = ( + app_state.runtime_settings if app_state.runtime_settings is not None else app_state.settings + ) + + async def _probe() -> None: + test_db = await open_db(effective_settings.database_path) + await test_db.close() + + return await _run_check( + "database", + _probe(), + "Connection failed", + ) + + +async def _check_fail2ban(app_state: AppStateDep, server_status: ServerStatusDep) -> ReadyCheck: + """Check fail2ban socket reachability using the cached server status.""" + if server_status.online: + return ReadyCheck(name="fail2ban", healthy=True) + return ReadyCheck(name="fail2ban", healthy=False, message="Socket not reachable") + + +async def _check_config_dir(app_state: AppStateDep) -> ReadyCheck: + """Check config directory read access.""" + from app.config import Settings + + effective_settings: Settings = ( + app_state.runtime_settings if app_state.runtime_settings is not None else app_state.settings + ) + + async def _probe() -> None: + config_path = effective_settings.fail2ban_config_dir + # Quick read-test: list directory (checks both existence and readability) + await asyncio.to_thread(os.access, config_path, os.R_OK) + + return await _run_check( + "config_dir", + _probe(), + "Config directory not readable", + ) + + +async def _check_scheduler(app_state: AppStateDep) -> ReadyCheck: + """Check scheduler liveness.""" + try: + scheduler = app_state.scheduler + if scheduler is not None and getattr(scheduler, "running", False): + return ReadyCheck(name="scheduler", healthy=True) + elif scheduler is not None: + return ReadyCheck(name="scheduler", healthy=False, message="Scheduler stopped") + else: + return ReadyCheck(name="scheduler", healthy=False, message="Not initialised") + except AttributeError: + return ReadyCheck(name="scheduler", healthy=False, message="Not accessible") + + +@router.get( + "/ready", + summary="Subsystem readiness probe", + response_model=ReadyResponse, + responses={ + 200: {"description": "All subsystems healthy"}, + 503: {"description": "One or more subsystems unreachable"}, + }, +) +async def readiness_probe( + app_state: AppStateDep, + server_status: ServerStatusDep, +) -> JSONResponse: + """Readiness check for Kubernetes. + + Verifies all critical sub-systems are reachable: + - Database connectivity + - fail2ban socket (via cached server status) + - Config directory read access + - Background scheduler liveness + + Returns HTTP 200 only when every check passes; returns HTTP 503 with a + JSON body listing every failed subsystem otherwise. Each check has a + short per-subsystem timeout to prevent the endpoint from overwhelming the + system under load. + """ + db_check, f2b_check, config_check, sched_check = await asyncio.gather( + _check_database(app_state), + _check_fail2ban(app_state, server_status), + _check_config_dir(app_state), + _check_scheduler(app_state), + ) + + checks: list[ReadyCheck] = [db_check, f2b_check, config_check, sched_check] + failed_count = sum(1 for c in checks if not c.healthy) + + http_status = status.HTTP_200_OK if failed_count == 0 else status.HTTP_503_SERVICE_UNAVAILABLE + + return JSONResponse( + status_code=http_status, + content=ReadyResponse( + status="ok" if failed_count == 0 else "error", + checks=checks, + failed_count=failed_count, + ).model_dump(), + ) diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 2191ea9..dfe63f9 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -12,7 +12,15 @@ from httpx import ASGITransport, AsyncClient from app.config import Settings from app.db import init_db from app.exceptions import ConfigValidationError, ConfigWriteError, JailNotFoundError -from app.main import CORSMiddleware, _enforce_single_worker, _lifespan, create_app +from app.main import ( + CORSMiddleware, + _assert_middleware_order, + _enforce_single_worker, + _lifespan, + create_app, +) +from app.middleware.correlation import CorrelationIdMiddleware +from app.middleware.rate_limit import RateLimitMiddleware from app.services import setup_service @@ -450,14 +458,23 @@ async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path: exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.load_cache_from_db", new=load_cache)) exit_stack.enter_context(patch("app.services.geo_cache.GeoCache.count_unresolved", new=AsyncMock(return_value=0))) exit_stack.enter_context(patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=True))) - exit_stack.enter_context(patch("app.services.setup_service.get_runtime_database_path", new=AsyncMock(return_value=runtime_db_path))) - exit_stack.enter_context(patch("app.services.setup_service.get_persisted_runtime_settings", new=AsyncMock(return_value={ - "database_path": runtime_db_path, - "fail2ban_socket": "/tmp/persisted.sock", - "timezone": "Europe/Berlin", - "session_duration_minutes": 123, - }))) - exit_stack.enter_context(patch("app.services.setup_service.get_fail2ban_db_path", new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2"))) + exit_stack.enter_context(patch( + "app.services.setup_service.get_runtime_database_path", + new=AsyncMock(return_value=runtime_db_path), + )) + exit_stack.enter_context(patch( + "app.services.setup_service.get_persisted_runtime_settings", + new=AsyncMock(return_value={ + "database_path": runtime_db_path, + "fail2ban_socket": "/tmp/persisted.sock", + "timezone": "Europe/Berlin", + "session_duration_minutes": 123, + }), + )) + exit_stack.enter_context(patch( + "app.services.setup_service.get_fail2ban_db_path", + new=AsyncMock(return_value="/tmp/fail2ban/banned.tar.bz2"), + )) exit_stack.enter_context(patch("app.tasks.health_check.register")) exit_stack.enter_context(patch("app.tasks.blocklist_import.register")) exit_stack.enter_context(patch("app.tasks.geo_cache_flush.register")) @@ -466,8 +483,9 @@ async def test_startup_loads_geo_cache_from_persisted_runtime_database(tmp_path: with exit_stack: async with _lifespan(app): - loaded_db_path = load_cache.call_args.args[0] - runtime_connections = [conn for path, conn in opened_connections if path == runtime_db_path] + runtime_connections = [ + conn for path, conn in opened_connections if path == runtime_db_path + ] assert runtime_connections, "Expected runtime database to be opened" assert app.state.runtime_settings is not None @@ -538,6 +556,91 @@ async def test_concurrent_requests_use_request_scoped_db_connections(tmp_path: P assert all(connection.close.await_count == 1 for connection in connections) +# --------------------------------------------------------------------------- +# Middleware order validation +# --------------------------------------------------------------------------- + + +def _make_settings(tmp_path: Path) -> Settings: + """Return a minimal Settings object with a temporary fail2ban config dir.""" + fail2ban_config_dir = tmp_path / "fail2ban" + fail2ban_config_dir.mkdir() + return Settings( + database_path=str(tmp_path / "bangui.db"), + fail2ban_socket="/tmp/fake_fail2ban.sock", + fail2ban_config_dir=str(fail2ban_config_dir), + session_secret="test-secret-key-do-not-use-in-production", + session_duration_minutes=60, + timezone="UTC", + log_level="debug", + ) + + +def test_create_app_raises_on_incorrect_middleware_order( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """_assert_middleware_order() raises AssertionError when middleware order is wrong. + + The security-critical chain requires: + RateLimitMiddleware → CsrfMiddleware → CorrelationIdMiddleware + in user_middleware (processing order: outermost → innermost). + """ + monkeypatch.setenv("TESTING", "1") + settings = _make_settings(tmp_path) + app = create_app(settings=settings) + # Swap CorrelationIdMiddleware and RateLimitMiddleware to break the order. + user_mw = app.user_middleware + corr_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "CorrelationIdMiddleware") + rate_idx = next(i for i, m in enumerate(user_mw) if m.cls.__name__ == "RateLimitMiddleware") + user_mw[corr_idx], user_mw[rate_idx] = user_mw[rate_idx], user_mw[corr_idx] + with pytest.raises(AssertionError, match="must be registered before"): + _assert_middleware_order(app) + + +def test_middleware_order_validation_passes_for_correct_order( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """_assert_middleware_order() does not raise when middleware order is correct.""" + monkeypatch.setenv("TESTING", "1") + settings = _make_settings(tmp_path) + app = create_app(settings=settings) + _assert_middleware_order(app) # Should not raise + + +def test_create_app_validates_middleware_order_at_startup( + tmp_path: Path, + monkeypatch: pytest.MonkeyPatch, +) -> None: + """create_app() raises immediately if middleware registration order is incorrect. + + This test verifies the integration: _assert_middleware_order is called at the + end of create_app, so a fresh app with deliberately wrong middleware order + (simulated by patching add_middleware during creation) raises AssertionError. + """ + monkeypatch.setenv("TESTING", "1") + settings = _make_settings(tmp_path) + + from starlette.applications import Starlette + + original_add = Starlette.add_middleware + + def swapping_add(self, middleware_cls: type, **kwargs: object) -> None: + """Patched add_middleware that swaps CorrelationId and RateLimit.""" + if middleware_cls is CorrelationIdMiddleware: + pass # Skip CorrelationId + elif middleware_cls is RateLimitMiddleware: + original_add(self, RateLimitMiddleware, **kwargs) + original_add(self, CorrelationIdMiddleware) + else: + original_add(self, middleware_cls, **kwargs) + + with patch.object(Starlette, "add_middleware", swapping_add), \ + pytest.raises(AssertionError, match="must be registered before"): + create_app(settings=settings) + + # --------------------------------------------------------------------------- # Single-worker enforcement # --------------------------------------------------------------------------- diff --git a/backend/tests/test_routers/test_health_probes.py b/backend/tests/test_routers/test_health_probes.py new file mode 100644 index 0000000..c7e1917 --- /dev/null +++ b/backend/tests/test_routers/test_health_probes.py @@ -0,0 +1,130 @@ +"""Tests for the health-check router — liveness and readiness probes.""" + +from unittest.mock import MagicMock, patch + +import pytest +from httpx import AsyncClient + +from app.models.server import ServerStatus +from app.models.response import ReadyCheck + + +# --------------------------------------------------------------------------- +# GET /health/live — liveness probe +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_liveness_returns_200(client: AsyncClient) -> None: + """``GET /api/v1/health/live`` must always return HTTP 200.""" + response = await client.get("/api/v1/health/live") + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_liveness_body_is_ready_response(client: AsyncClient) -> None: + """Response body must be a ReadyResponse.""" + response = await client.get("/api/v1/health/live") + data: dict[str, object] = response.json() + assert data["status"] == "ok" + assert data["failed_count"] == 0 + assert "checks" in data + assert isinstance(data["checks"], list) + + +@pytest.mark.asyncio +async def test_liveness_includes_process_check(client: AsyncClient) -> None: + """Liveness response must include a 'process' check.""" + response = await client.get("/api/v1/health/live") + data: dict[str, object] = response.json() + checks: list[dict[str, object]] = data["checks"] # type: ignore[assignment] + assert any(c.get("name") == "process" and c.get("healthy") is True for c in checks) + + +# --------------------------------------------------------------------------- +# GET /health/ready — readiness probe +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_readiness_returns_200_when_all_pass(client: AsyncClient) -> None: + """``GET /api/v1/health/ready`` must return 200 when all subsystems pass.""" + with patch("app.routers.health._run_check", side_effect=lambda n, c, e: ReadyCheck(name=n, healthy=True)): + response = await client.get("/api/v1/health/ready") + assert response.status_code == 200 + + +@pytest.mark.asyncio +async def test_readiness_returns_503_when_subsystem_fails(client: AsyncClient) -> None: + """``GET /api/v1/health/ready`` must return 503 when at least one check fails.""" + # Force fail2ban offline + client._transport.app.state.server_status = ServerStatus(online=False) + response = await client.get("/api/v1/health/ready") + assert response.status_code == 503 + + +@pytest.mark.asyncio +async def test_readiness_body_is_ready_response(client: AsyncClient) -> None: + """Response body must be a ReadyResponse.""" + response = await client.get("/api/v1/health/ready") + data: dict[str, object] = response.json() + assert data["status"] in ("ok", "error") + assert "failed_count" in data + assert "checks" in data + assert isinstance(data["checks"], list) + + +@pytest.mark.asyncio +async def test_readiness_includes_all_subsystems(client: AsyncClient) -> None: + """Readiness response must include checks for all four subsystems.""" + response = await client.get("/api/v1/health/ready") + data: dict[str, object] = response.json() + checks: list[dict[str, object]] = data["checks"] # type: ignore[assignment] + names = {c["name"] for c in checks} + assert names == {"database", "fail2ban", "config_dir", "scheduler"} + + +@pytest.mark.asyncio +async def test_readiness_status_ok_when_all_healthy(client: AsyncClient) -> None: + """``status`` must be 'ok' when all checks pass.""" + with patch("app.routers.health._run_check", side_effect=lambda n, c, e: ReadyCheck(name=n, healthy=True)): + response = await client.get("/api/v1/health/ready") + data: dict[str, object] = response.json() + assert data["status"] == "ok" + assert data["failed_count"] == 0 + + +@pytest.mark.asyncio +async def test_readiness_status_error_when_fail2ban_offline(client: AsyncClient) -> None: + """``status`` must be 'error' when fail2ban is offline.""" + client._transport.app.state.server_status = ServerStatus(online=False) + response = await client.get("/api/v1/health/ready") + data: dict[str, object] = response.json() + assert data["status"] == "error" + assert data["failed_count"] > 0 + + +@pytest.mark.asyncio +async def test_readiness_includes_failed_subsystem_detail(client: AsyncClient) -> None: + """When fail2ban is offline the fail2ban check must include an error message.""" + client._transport.app.state.server_status = ServerStatus(online=False) + response = await client.get("/api/v1/health/ready") + data: dict[str, object] = response.json() + checks: list[dict[str, object]] = data["checks"] # type: ignore[assignment] + f2b = next(c for c in checks if c["name"] == "fail2ban") + assert f2b["healthy"] is False + assert f2b["message"] is not None + + +@pytest.mark.asyncio +async def test_readiness_content_type_is_json(client: AsyncClient) -> None: + """``/api/v1/health/ready`` must set the ``Content-Type`` header to JSON.""" + response = await client.get("/api/v1/health/ready") + assert "application/json" in response.headers.get("content-type", "") + + +@pytest.mark.asyncio +async def test_readiness_live_content_type_is_json(client: AsyncClient) -> None: + """``/api/v1/health/live`` must set the ``Content-Type`` header to JSON.""" + response = await client.get("/api/v1/health/live") + assert "application/json" in response.headers.get("content-type", "") diff --git a/frontend/src/hooks/README.md b/frontend/src/hooks/README.md index a0fd657..ff1f57a 100644 --- a/frontend/src/hooks/README.md +++ b/frontend/src/hooks/README.md @@ -53,3 +53,28 @@ With drift correction: - Total time from poll start to next poll start is always ~5 seconds - Fetch duration doesn't affect the long-term polling rate - Bandwidth and CPU usage remain consistent + +## Request Deduplication with Abort Signal Handling + +When `useFetchData` is called with a `requestKey`, multiple hook instances sharing the same key coalesce around a single in-flight request. Each instance is tracked as a **subscriber**. + +### How it works + +- A module-level `Map` holds in-flight requests. +- Each `InFlightRequest` contains a `subscribers` map: `subscriberId → { signal, onSetData, onSuccess }`. +- When a new hook instance joins a deduplicated request, it registers a subscriber entry and attaches a `once` abort listener that removes it from the subscriber list. +- When the underlying request resolves, **each subscriber's signal is checked before calling `setData` or `onSuccess`**. Aborted subscribers are skipped, preventing state updates on unmounted components. +- When all subscribers have aborted, the underlying `AbortController` is signalled, cancelling the in-flight fetch. + +### Guarantees + +| Scenario | Result | +|---|---| +| Subscriber aborts before request resolves | Subscriber receives no `setData`/`onSuccess`. Entry removed from subscriber list. | +| Last subscriber aborts before request resolves | Underlying request is cancelled via `AbortController`. | +| Request resolves, subscriber already aborted | `signal.aborted` checked before `setData`/`onSuccess` — no-op. | +| Request errors | Non-abort errors passed to `handleFetchError`. Each subscriber's error state updated independently. | + +### When to use deduplication + +Use `requestKey` when multiple components may request the same resource simultaneously. The shared promise avoids redundant network calls and prevents race conditions from out-of-order responses. diff --git a/frontend/src/hooks/__tests__/useFetchData.test.ts b/frontend/src/hooks/__tests__/useFetchData.test.ts index 554543c..b185ce2 100644 --- a/frontend/src/hooks/__tests__/useFetchData.test.ts +++ b/frontend/src/hooks/__tests__/useFetchData.test.ts @@ -1,8 +1,13 @@ -import { describe, it, expect, vi } from "vitest"; +import { describe, it, expect, vi, beforeEach } from "vitest"; import { renderHook, act } from "@testing-library/react"; -import { useFetchData } from "../useFetchData"; +import { useFetchData, _resetInFlightRequests } from "../useFetchData"; import type { FetchError } from "../../types/api"; +// Clear the module-level inFlightRequests map between tests to prevent state leakage +beforeEach(() => { + _resetInFlightRequests(); +}); + describe("useFetchData", () => { it("fetches and selects data on mount", async () => { const fetcher = vi.fn().mockResolvedValue({ value: "test" }); @@ -406,4 +411,123 @@ describe("useFetchData", () => { expect(fetcher).toHaveBeenCalledTimes(2); expect(result.current.data).toBe("second"); }); + + it("aborted subscriber in deduplication does not receive setData/onSuccess", async () => { + let resolveFirst: ((value: { value: string }) => void) | null = null; + const fetcher = vi.fn().mockImplementation( + () => + new Promise((resolve) => { + resolveFirst = resolve; + }) + ); + const selector = vi.fn((response: { value: string }) => response.value); + const onSuccess1 = vi.fn(); + const onSuccess2 = vi.fn(); + + const hook1 = renderHook(() => + useFetchData({ + fetcher, + selector, + errorMessage: "Failed to load", + requestKey: "abort-subscriber-test", + onSuccess: onSuccess1, + }) + ); + + const hook2 = renderHook(() => + useFetchData({ + fetcher, + selector, + errorMessage: "Failed to load", + requestKey: "abort-subscriber-test", + onSuccess: onSuccess2, + }) + ); + + await act(async () => { + await Promise.resolve(); + }); + + expect(fetcher).toHaveBeenCalledTimes(1); + + // Unmount hook2 (simulates component unmounting and aborting its signal) + hook2.unmount(); + + // Allow the request to resolve + await act(async () => { + resolveFirst?.({ value: "shared-data" }); + await Promise.resolve(); + }); + + // hook2 should not have received the data (it aborted) + expect(hook2.result.current.data).toBeUndefined(); + expect(onSuccess2).not.toHaveBeenCalled(); + // hook1 should still have the data + expect(hook1.result.current.data).toBe("shared-data"); + expect(onSuccess1).toHaveBeenCalledWith({ value: "shared-data" }); + }); + + it("last subscriber abort cancels underlying request", async () => { + let resolveFirst: ((value: { value: string }) => void) | null = null; + const abortSignals: AbortSignal[] = []; + const fetcher = vi.fn().mockImplementation((signal: AbortSignal) => { + abortSignals.push(signal); + return new Promise((resolve) => { + resolveFirst = resolve; + }); + }); + const selector = vi.fn((response: { value: string }) => response.value); + + const hook1 = renderHook(() => + useFetchData({ + fetcher, + selector, + errorMessage: "Failed to load", + requestKey: "cancel-underlying-test", + }) + ); + + await act(async () => { + await Promise.resolve(); + }); + + expect(fetcher).toHaveBeenCalledTimes(1); + + // Mount second subscriber + const hook2 = renderHook(() => + useFetchData({ + fetcher, + selector, + errorMessage: "Failed to load", + requestKey: "cancel-underlying-test", + }) + ); + + await act(async () => { + await Promise.resolve(); + }); + + // Both subscribers now sharing one request + expect(fetcher).toHaveBeenCalledTimes(1); + + // Unmount first subscriber (doesn't cancel underlying request yet) + hook1.unmount(); + + await act(async () => { + await Promise.resolve(); + }); + + // Underlying request should NOT be cancelled yet (hook2 still waiting) + expect(abortSignals[0]?.aborted).toBe(false); + + // Unmount second (last) subscriber — should cancel the underlying request + hook2.unmount(); + + await act(async () => { + await Promise.resolve(); + }); + + // Underlying request should now be cancelled + expect(abortSignals[0]?.aborted).toBe(true); + }); }); diff --git a/frontend/src/hooks/useFetchData.ts b/frontend/src/hooks/useFetchData.ts index c941aef..bc2bf98 100644 --- a/frontend/src/hooks/useFetchData.ts +++ b/frontend/src/hooks/useFetchData.ts @@ -19,16 +19,33 @@ import type { FetchError } from "../types/api"; /** * Module-level cache for in-flight requests. - * Maps requestKey to { promise, controller } to enable deduplication + * Maps requestKey to { promise, controller, subscribers } to enable deduplication * across multiple hook instances. */ +interface Subscriber { + /** This instance's abort signal. */ + signal: AbortSignal; + /** Callbacks registered by this subscriber. */ + onSetData?: (response: TResponse) => void; + onSuccess?: (response: TResponse) => void; +} + interface InFlightRequest { promise: Promise; controller: AbortController; + /** Map of subscriberId -> Subscriber. Cleared when subscriber aborts. */ + subscribers: Map>; + /** True when initiator has cleaned up but subscribers remain. */ + initiatorDone: boolean; } const inFlightRequests = new Map>(); +/** Visible for testing only. Clears all in-flight request state. */ +export const _resetInFlightRequests = (): void => { + inFlightRequests.clear(); +}; + export interface UseFetchDataOptions { /** Async function that accepts an AbortSignal for cancellation. */ fetcher: (signal: AbortSignal) => Promise; @@ -86,18 +103,57 @@ export function useFetchData( const [error, setError] = useState(null); const abortRef = useRef(null); const localControllerRef = useRef(null); + const subscriberAbortRef = useRef(null); + /** Unique ID for this instance, used to track its subscription to deduplicated requests. */ + const subscriberIdRef = useRef(null); const refresh = useCallback((): void => { + // Abort any previous request from this hook instance + abortRef.current?.abort(); + localControllerRef.current = new AbortController(); + abortRef.current = localControllerRef.current; + // If using request deduplication via requestKey and a request is already in-flight, - // wait for it to complete instead of launching a duplicate + // subscribe to it instead of launching a duplicate if (requestKey && inFlightRequests.has(requestKey)) { const inFlight = inFlightRequests.get(requestKey)! as InFlightRequest; - inFlight.promise - .then((response: TResponse) => { + + // Create per-instance abort controller for this subscription. + // Do NOT abort previous subscription signals here - each subscriber's signal + // only aborts when THAT subscriber unmounts. Aborting old signals would + // incorrectly remove other active subscribers from the list. + subscriberAbortRef.current = new AbortController(); + + // Register this instance as a subscriber. + // Always generate a new ID - each hook instance must have its own subscriber ID + // so that unmounting one subscriber does not remove another subscriber's entry. + subscriberIdRef.current = crypto.randomUUID(); + const sid = subscriberIdRef.current; + const subscription: Subscriber = { + signal: subscriberAbortRef.current.signal, + onSetData: (response: TResponse) => { setData(selector(response)); if (onSuccess) { onSuccess(response); } + }, + onSuccess, + }; + inFlight.subscribers.set(sid, subscription); + + // When this instance aborts, remove it from the subscriber list. + // Note: we do NOT abort the underlying controller here - that belongs to + // the initiator and should only be aborted by the initiator itself. + const cleanup = (): void => { + inFlight.subscribers.delete(sid); + }; + subscription.signal.addEventListener("abort", cleanup, { once: true }); + + inFlight.promise + .then((response: TResponse) => { + // Check whether this subscriber has already aborted before propagating result + if (subscription.signal.aborted) return; + subscription.onSetData?.(response); }) .catch((err: unknown) => { // Only handle non-abort errors; abort errors are silently ignored @@ -139,7 +195,6 @@ export function useFetchData( if (!controller.signal.aborted) { setLoading(false); } - // Clear cache entry when response arrives if (requestKey) { inFlightRequests.delete(requestKey); } @@ -150,6 +205,8 @@ export function useFetchData( inFlightRequests.set(requestKey, { promise: responsePromise, controller, + subscribers: new Map(), + initiatorDone: false, }); } }, [fetcher, selector, errorMessage, onSuccess, requestKey]); @@ -158,7 +215,25 @@ export function useFetchData( refresh(); return (): void => { - abortRef.current?.abort(); + if (subscriberIdRef.current !== null) { + // Subscriber cleanup + const inFlight = requestKey ? inFlightRequests.get(requestKey) : undefined; + subscriberAbortRef.current?.abort(); + // If initiator already done and no subscribers left, cancel the request + if (inFlight && inFlight.initiatorDone && inFlight.subscribers.size === 0) { + inFlight.controller.abort(); + } + } else { + // Initiator cleanup + const inFlight = requestKey ? inFlightRequests.get(requestKey) : undefined; + // Mark initiator as done so last subscriber knows to cancel + if (inFlight) { + inFlight.initiatorDone = true; + } + if (!inFlight || inFlight.subscribers.size === 0) { + abortRef.current?.abort(); + } + } }; }, [refresh]);