Task 0.1: Create the database parent directory before connecting. `_lifespan` in main.py now calls `Path(database_path).parent.mkdir(parents=True, exist_ok=True)` before `aiosqlite.connect()`. The app therefore starts cleanly on a fresh Docker volume with a nested database path.
Task 0.2: SetupRedirectMiddleware redirects when db is None. The guard now reads `if db is None or not is_setup_complete(db)`. A missing database (startup still in progress) is treated as "setup not complete" instead of silently letting all API routes through.
Task 0.3: SetupGuard redirects to /setup on API failure. The `.catch()` handler now sets status to 'pending' instead of 'done'. A crashed backend cannot serve protected routes, so the conservative fallback is to redirect to /setup.
Task 0.4: SetupPage shows a spinner while checking setup status. Added a `checking` boolean state; a full-screen Spinner is rendered until `getSetupStatus()` resolves. This prevents the form from flashing before the redirect. Also added a `console.warn` in the catch block and a cleanup function returned from the `useEffect`.
Also: removed an unused `type: ignore[call-arg]` from config.py.
Tests: 18 backend tests pass; 117 frontend tests pass.
435 lines · 16 KiB · Python
"""Tests for the setup router (POST /api/setup, GET /api/setup, GET /api/setup/timezone)."""
|
|
|
|
from __future__ import annotations

from collections.abc import AsyncIterator, Iterator
from contextlib import contextmanager
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient

from app.config import Settings
from app.db import init_db
from app.main import _lifespan, create_app
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared setup payload
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SETUP_PAYLOAD: dict[str, object] = {
|
|
"master_password": "supersecret123",
|
|
"database_path": "bangui.db",
|
|
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
|
"timezone": "UTC",
|
|
"session_duration_minutes": 60,
|
|
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixture for tests that need direct access to app.state
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
async def app_and_client(tmp_path: Path) -> AsyncIterator[tuple[object, AsyncClient]]:
    """Yield ``(app, client)`` for tests that inspect ``app.state`` directly.

    ``ASGITransport`` does not drive the app's lifespan, so this fixture opens
    and initialises the SQLite connection itself and attaches it to
    ``app.state.db``.

    Args:
        tmp_path: Pytest-provided isolated temporary directory.

    Yields:
        A tuple of ``(FastAPI app instance, AsyncClient)``.
    """
    settings = Settings(
        database_path=str(tmp_path / "setup_cache_test.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        session_secret="test-setup-cache-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    # Close the connection even when schema initialisation or client setup
    # fails before the yield; otherwise the connection (and the background
    # thread aiosqlite runs it on) is left open.
    try:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        app.state.db = db

        transport: ASGITransport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield app, ac
    finally:
        await db.close()
|
|
|
|
|
|
class TestGetSetupStatus:
    """GET /api/setup — check setup completion state."""

    async def test_returns_not_completed_on_fresh_db(self, client: AsyncClient) -> None:
        """Status endpoint reports setup not done on a fresh database."""
        response = await client.get("/api/setup")
        assert response.status_code == 200
        assert response.json() == {"completed": False}

    async def test_returns_completed_after_setup(self, client: AsyncClient) -> None:
        """Status endpoint reports setup done after POST /api/setup."""
        # Check the precondition explicitly so a rejected setup fails here,
        # not later as a confusing ``completed`` mismatch.
        setup = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert setup.status_code == 201

        response = await client.get("/api/setup")
        assert response.status_code == 200
        assert response.json() == {"completed": True}
|
|
|
|
|
|
class TestPostSetup:
    """POST /api/setup — run the first-run configuration wizard."""

    async def test_accepts_valid_payload(self, client: AsyncClient) -> None:
        """Setup endpoint returns 201 for a valid first-run payload."""
        response = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert response.status_code == 201
        body = response.json()
        assert "message" in body

    async def test_rejects_short_password(self, client: AsyncClient) -> None:
        """Setup endpoint rejects passwords shorter than 8 characters."""
        response = await client.post(
            "/api/setup",
            json={"master_password": "short"},
        )
        assert response.status_code == 422

    async def test_rejects_second_call(self, client: AsyncClient) -> None:
        """Setup endpoint returns 409 if setup has already been completed."""
        first = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert first.status_code == 201

        second = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert second.status_code == 409

    async def test_accepts_defaults_for_optional_fields(
        self, client: AsyncClient
    ) -> None:
        """Setup endpoint uses defaults when optional fields are omitted."""
        response = await client.post(
            "/api/setup",
            json={"master_password": "supersecret123"},
        )
        assert response.status_code == 201
|
|
|
|
|
|
class TestSetupRedirectMiddleware:
    """Verify that the setup-redirect middleware enforces setup-first."""

    async def test_protected_endpoint_redirects_before_setup(
        self, client: AsyncClient
    ) -> None:
        """Non-setup API requests redirect to /api/setup on a fresh instance."""
        response = await client.get(
            "/api/auth/login",
            follow_redirects=False,
        )
        # Middleware issues 307 redirect to /api/setup
        assert response.status_code == 307
        assert response.headers["location"] == "/api/setup"

    async def test_health_always_reachable_before_setup(
        self, client: AsyncClient
    ) -> None:
        """Health endpoint is always reachable even before setup."""
        # Pin follow_redirects=False: if the client followed a 307 to
        # /api/setup, that endpoint's own 200 would mask the redirect.
        response = await client.get("/api/health", follow_redirects=False)
        assert response.status_code == 200

    async def test_no_redirect_after_setup(self, client: AsyncClient) -> None:
        """Protected endpoints are reachable (no redirect) after setup."""
        setup = await client.post(
            "/api/setup",
            json={"master_password": "supersecret123"},
        )
        assert setup.status_code == 201

        # After setup the middleware lets /api/auth/login through, so a wrong
        # password is rejected by the login route itself (401) instead of the
        # request being redirected to /api/setup (307).
        response = await client.post(
            "/api/auth/login",
            json={"password": "wrong"},
            follow_redirects=False,
        )
        assert response.status_code == 401
|
|
|
|
|
|
class TestGetTimezone:
    """Cover GET /api/setup/timezone, which reports the configured IANA zone."""

    async def test_returns_utc_before_setup(self, client: AsyncClient) -> None:
        """With no setup performed yet, the endpoint falls back to 'UTC'."""
        resp = await client.get("/api/setup/timezone")
        assert resp.status_code == 200
        assert resp.json() == {"timezone": "UTC"}

    async def test_returns_configured_timezone(self, client: AsyncClient) -> None:
        """The zone submitted during setup is echoed back by the endpoint."""
        wanted = "Europe/Berlin"
        setup_body = {"master_password": "supersecret123", "timezone": wanted}

        await client.post("/api/setup", json=setup_body)

        resp = await client.get("/api/setup/timezone")
        assert resp.status_code == 200
        assert resp.json() == {"timezone": wanted}

    async def test_endpoint_always_reachable_before_setup(
        self, client: AsyncClient
    ) -> None:
        """Before setup, the timezone route answers directly (no redirect)."""
        resp = await client.get("/api/setup/timezone", follow_redirects=False)

        # Anything under /api/setup is exempt from SetupRedirectMiddleware,
        # so a direct 200 is expected rather than a 307.
        assert resp.status_code == 200
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Setup-complete flag caching in SetupRedirectMiddleware (Task 4)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSetupCompleteCaching:
    """SetupRedirectMiddleware caches the setup_complete flag in ``app.state``."""

    async def test_cache_flag_set_after_first_post_setup_request(
        self,
        app_and_client: tuple[object, AsyncClient],
    ) -> None:
        """``_setup_complete_cached`` is set to True on the first request after setup.

        The ``/api/setup`` path is in ``_ALWAYS_ALLOWED`` so it bypasses the
        middleware check. The first request to a non-exempt endpoint triggers
        the DB query and, when setup is complete, populates the cache flag.
        """
        from fastapi import FastAPI

        app, client = app_and_client
        assert isinstance(app, FastAPI)
        # Bind as ``str`` so the login body is a plain ``dict[str, str]``.
        password = str(_SETUP_PAYLOAD["master_password"])

        # Complete setup (exempt from middleware, no flag set yet).
        resp = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert resp.status_code == 201

        # Flag not yet cached — setup was via an exempt path.
        assert not getattr(app.state, "_setup_complete_cached", False)

        # First non-exempt request — middleware queries DB and sets the flag.
        login = await client.post(
            "/api/auth/login",
            json={"password": password},
            follow_redirects=False,
        )
        # Setup is complete, so the middleware must not redirect to /api/setup.
        assert login.status_code != 307

        assert app.state._setup_complete_cached is True  # type: ignore[attr-defined]

    async def test_cached_path_skips_is_setup_complete(
        self,
        app_and_client: tuple[object, AsyncClient],
    ) -> None:
        """Subsequent requests do not call ``is_setup_complete`` once flag is cached.

        After the flag is set, the middleware must not query the database for
        further requests. ``is_setup_complete`` is replaced by a counting
        stand-in, and the test asserts that it is never invoked.
        """
        from fastapi import FastAPI

        app, client = app_and_client
        assert isinstance(app, FastAPI)
        password = str(_SETUP_PAYLOAD["master_password"])

        # Do setup and warm the cache.
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        await client.post("/api/auth/login", json={"password": password})
        assert app.state._setup_complete_cached is True  # type: ignore[attr-defined]

        call_count = 0

        async def _counting(db: object) -> bool:
            nonlocal call_count
            call_count += 1
            return True

        # NOTE(review): this only intercepts lookups made through the
        # ``setup_service`` module attribute. If the middleware binds the name
        # directly (``from app.services.setup_service import is_setup_complete``),
        # patch it where it is looked up instead, or this check passes vacuously.
        with patch("app.services.setup_service.is_setup_complete", side_effect=_counting):
            response = await client.post(
                "/api/auth/login",
                json={"password": password},
                follow_redirects=False,
            )

        # The request went through the middleware without being redirected...
        assert response.status_code != 307
        # ...and, with the cache warm, without calling is_setup_complete.
        assert call_count == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 0.1 — Lifespan creates the database parent directory
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestLifespanDatabaseDirectoryCreation:
    """App lifespan creates the database parent directory when it does not exist."""

    @staticmethod
    def _settings_for(db_path: Path, secret: str) -> Settings:
        """Build test settings that point the app at *db_path*."""
        return Settings(
            database_path=str(db_path),
            fail2ban_socket="/tmp/fake.sock",
            session_secret=secret,
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
        )

    @staticmethod
    @contextmanager
    def _stubbed_startup_services() -> Iterator[MagicMock]:
        """Patch the GeoIP, task-registration and scheduler hooks of ``_lifespan``.

        This lets ``_lifespan`` run in a test without external services.

        Yields:
            The ``MagicMock`` returned in place of ``AsyncIOScheduler()``.
        """
        # MagicMock auto-creates ``start``/``shutdown`` as callable mocks.
        mock_scheduler = MagicMock()
        with (
            patch("app.services.geo_service.init_geoip"),
            patch(
                "app.services.geo_service.load_cache_from_db",
                new=AsyncMock(return_value=None),
            ),
            patch("app.tasks.health_check.register"),
            patch("app.tasks.blocklist_import.register"),
            patch("app.tasks.geo_cache_flush.register"),
            patch("app.tasks.geo_re_resolve.register"),
            patch("app.main.AsyncIOScheduler", return_value=mock_scheduler),
        ):
            yield mock_scheduler

    async def test_creates_nested_database_directory(self, tmp_path: Path) -> None:
        """Lifespan creates intermediate directories for the database path.

        Verifies that a deeply-nested database path is handled correctly —
        the parent directories are created before ``aiosqlite.connect`` is
        called so the app does not crash on a fresh volume.
        """
        nested_db = tmp_path / "deep" / "nested" / "bangui.db"
        assert not nested_db.parent.exists()

        app = create_app(
            settings=self._settings_for(nested_db, "test-lifespan-mkdir-secret")
        )

        with self._stubbed_startup_services():
            async with _lifespan(app):
                assert nested_db.parent.is_dir(), (
                    "Expected lifespan to create database parent directory"
                )
                # Opening the SQLite connection creates the file, proving
                # startup got past the mkdir step to the connect call.
                assert nested_db.exists()

    async def test_existing_database_directory_is_not_an_error(
        self, tmp_path: Path
    ) -> None:
        """Lifespan does not raise when the database directory already exists.

        ``mkdir(exist_ok=True)`` must be used so that re-starts on an existing
        volume do not fail.
        """
        db_path = tmp_path / "bangui.db"
        # tmp_path already exists — this simulates a pre-existing volume.
        assert db_path.parent.is_dir()

        app = create_app(
            settings=self._settings_for(db_path, "test-lifespan-exist-ok-secret")
        )

        with self._stubbed_startup_services():
            # Should not raise FileExistsError or similar.
            async with _lifespan(app):
                # Startup reached the database connection (which creates the
                # file); a bare ``tmp_path.exists()`` would be trivially true.
                assert db_path.exists()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 0.2 — Middleware redirects when app.state.db is None
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSetupRedirectMiddlewareDbNone:
    """Requests that arrive before a database is attached are sent to setup."""

    @staticmethod
    def _transport_without_db(tmp_path: Path, secret: str) -> ASGITransport:
        """Wrap a freshly created app whose ``app.state.db`` was never assigned.

        No lifespan is run and the database is deliberately left unset. This
        mimics a request that lands before startup has finished.
        """
        cfg = Settings(
            database_path=str(tmp_path / "bangui.db"),
            fail2ban_socket="/tmp/fake_fail2ban.sock",
            session_secret=secret,
            session_duration_minutes=60,
            timezone="UTC",
            log_level="debug",
        )
        return ASGITransport(app=create_app(settings=cfg))

    async def test_redirects_to_setup_when_db_not_set(self, tmp_path: Path) -> None:
        """A missing db on app.state yields a 307 redirect to ``/api/setup``.

        This covers the race window in which the lifespan has not yet opened
        the database connection.
        """
        transport = self._transport_without_db(tmp_path, "test-db-none-secret")

        async with AsyncClient(transport=transport, base_url="http://test") as http:
            reply = await http.get("/api/auth/login", follow_redirects=False)

        assert reply.status_code == 307
        assert reply.headers["location"] == "/api/setup"

    async def test_health_reachable_when_db_not_set(self, tmp_path: Path) -> None:
        """The health route answers even though no database is initialised."""
        transport = self._transport_without_db(tmp_path, "test-db-none-health-secret")

        async with AsyncClient(transport=transport, base_url="http://test") as http:
            reply = await http.get("/api/health")

        assert reply.status_code == 200
|
|
|