On startup BanGUI now verifies that the four fail2ban jail config files required by its two custom jails (manual-Jail and blocklist-import) are present in `$fail2ban_config_dir/jail.d`. Any missing file is created with the correct default content; existing files are never overwritten.

Files managed:
- manual-Jail.conf (enabled=false template)
- manual-Jail.local (enabled=true override)
- blocklist-import.conf (enabled=false template)
- blocklist-import.local (enabled=true override)

The check runs in the lifespan hook immediately after logging is configured and before the database is opened.
437 lines
16 KiB
Python
437 lines
16 KiB
Python
"""Tests for the setup router (POST /api/setup, GET /api/setup, GET /api/setup/timezone)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from app.config import Settings
|
|
from app.db import init_db
|
|
from app.main import _lifespan, create_app
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Shared setup payload
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SETUP_PAYLOAD: dict[str, object] = {
|
|
"master_password": "supersecret123",
|
|
"database_path": "bangui.db",
|
|
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
|
"timezone": "UTC",
|
|
"session_duration_minutes": 60,
|
|
}
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixture for tests that need direct access to app.state
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
async def app_and_client(tmp_path: Path) -> tuple[object, AsyncClient]:  # type: ignore[misc]
    """Yield ``(app, client)`` for tests that inspect ``app.state`` directly.

    The app is created with its own SQLite database under ``tmp_path``.  The
    connection is opened, initialised with ``init_db`` and attached to
    ``app.state.db`` by this fixture itself.

    Args:
        tmp_path: Pytest-provided isolated temporary directory.

    Yields:
        A tuple of ``(FastAPI app instance, AsyncClient)``.
    """
    settings = Settings(
        database_path=str(tmp_path / "setup_cache_test.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        session_secret="test-setup-cache-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    try:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        app.state.db = db

        transport: ASGITransport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            yield app, ac
    finally:
        # Always release the connection, even when schema initialisation or
        # the client teardown raises, so no connection outlives the test.
        await db.close()
|
|
|
|
|
|
class TestGetSetupStatus:
    """GET /api/setup — check setup completion state."""

    async def test_returns_not_completed_on_fresh_db(self, client: AsyncClient) -> None:
        """Status endpoint reports setup not done on a fresh database."""
        response = await client.get("/api/setup")
        assert response.status_code == 200
        assert response.json() == {"completed": False}

    async def test_returns_completed_after_setup(self, client: AsyncClient) -> None:
        """Status endpoint reports setup done after POST /api/setup."""
        # Reuse the module-wide payload instead of repeating the literal, and
        # fail early with a clear status mismatch if the setup call itself
        # did not succeed.
        setup_response = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert setup_response.status_code == 201

        response = await client.get("/api/setup")
        assert response.status_code == 200
        assert response.json() == {"completed": True}
|
|
|
|
|
|
class TestPostSetup:
    """POST /api/setup — run the first-run configuration wizard."""

    async def test_accepts_valid_payload(self, client: AsyncClient) -> None:
        """Setup endpoint returns 201 for a valid first-run payload."""
        # The full payload lives in the shared module constant.
        response = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert response.status_code == 201
        body = response.json()
        assert "message" in body

    async def test_rejects_short_password(self, client: AsyncClient) -> None:
        """Setup endpoint rejects passwords shorter than 8 characters."""
        response = await client.post(
            "/api/setup",
            json={"master_password": "short"},
        )
        assert response.status_code == 422

    async def test_rejects_second_call(self, client: AsyncClient) -> None:
        """Setup endpoint returns 409 if setup has already been completed."""
        first = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert first.status_code == 201

        # Same body again: the second attempt must be refused.
        second = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert second.status_code == 409

    async def test_accepts_defaults_for_optional_fields(
        self, client: AsyncClient
    ) -> None:
        """Setup endpoint uses defaults when optional fields are omitted."""
        # Deliberately NOT the shared payload: only the required field is sent.
        response = await client.post(
            "/api/setup",
            json={"master_password": "supersecret123"},
        )
        assert response.status_code == 201
|
|
|
|
|
|
class TestSetupRedirectMiddleware:
    """Checks that requests are steered to setup until setup has been run."""

    async def test_protected_endpoint_redirects_before_setup(
        self, client: AsyncClient
    ) -> None:
        """A non-setup API request on a fresh instance is redirected to /api/setup."""
        redirect = await client.get("/api/auth/login", follow_redirects=False)

        # Expect a temporary redirect pointing at the setup endpoint.
        assert redirect.status_code == 307
        assert redirect.headers["location"] == "/api/setup"

    async def test_health_always_reachable_before_setup(
        self, client: AsyncClient
    ) -> None:
        """The health endpoint answers 200 even though setup has not been run."""
        health = await client.get("/api/health")
        assert health.status_code == 200

    async def test_no_redirect_after_setup(self, client: AsyncClient) -> None:
        """After setup, a protected endpoint is served instead of redirected."""
        setup_body = {"master_password": "supersecret123"}
        await client.post("/api/setup", json=setup_body)

        # Log in with a wrong password: the request must reach the login
        # handler (401) rather than being bounced to setup (307).
        wrong_credentials = {"password": "wrong"}
        login = await client.post(
            "/api/auth/login",
            json=wrong_credentials,
            follow_redirects=False,
        )
        assert login.status_code == 401
|
|
|
|
|
|
class TestGetTimezone:
    """GET /api/setup/timezone — report the configured IANA timezone."""

    async def test_returns_utc_before_setup(self, client: AsyncClient) -> None:
        """With no setup performed yet, the endpoint answers with 'UTC'."""
        reply = await client.get("/api/setup/timezone")
        assert reply.status_code == 200
        assert reply.json() == {"timezone": "UTC"}

    async def test_returns_configured_timezone(self, client: AsyncClient) -> None:
        """The timezone chosen during setup is what the endpoint returns."""
        setup_body = {
            "master_password": "supersecret123",
            "timezone": "Europe/Berlin",
        }
        await client.post("/api/setup", json=setup_body)

        reply = await client.get("/api/setup/timezone")
        assert reply.status_code == 200
        assert reply.json() == {"timezone": "Europe/Berlin"}

    async def test_endpoint_always_reachable_before_setup(
        self, client: AsyncClient
    ) -> None:
        """Before setup the endpoint is served directly, not redirected."""
        reply = await client.get("/api/setup/timezone", follow_redirects=False)

        # A 200 (rather than 307) is expected here: per the original test
        # notes, /api/setup paths are always allowed by SetupRedirectMiddleware.
        assert reply.status_code == 200
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Setup-complete flag caching in SetupRedirectMiddleware (Task 4)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSetupCompleteCaching:
    """Caching of the setup-complete flag on ``app.state`` by SetupRedirectMiddleware."""

    async def test_cache_flag_set_after_first_post_setup_request(
        self,
        app_and_client: tuple[object, AsyncClient],
    ) -> None:
        """The first non-exempt request after setup sets ``_setup_complete_cached``.

        Posting to ``/api/setup`` does not populate the flag (the path is
        exempt from the middleware check, per the original test notes).  The
        following request to ``/api/auth/login`` is expected to leave
        ``app.state._setup_complete_cached`` set to ``True``.
        """
        from fastapi import FastAPI

        app, client = app_and_client
        assert isinstance(app, FastAPI)

        # Run setup through the exempt endpoint.
        setup_response = await client.post("/api/setup", json=_SETUP_PAYLOAD)
        assert setup_response.status_code == 201

        # Nothing should have been cached by the exempt request.
        assert not getattr(app.state, "_setup_complete_cached", False)

        # A non-exempt request is what populates the flag.
        credentials = {"password": _SETUP_PAYLOAD["master_password"]}
        await client.post("/api/auth/login", json=credentials)  # type: ignore[call-overload]

        assert app.state._setup_complete_cached is True  # type: ignore[attr-defined]

    async def test_cached_path_skips_is_setup_complete(
        self,
        app_and_client: tuple[object, AsyncClient],
    ) -> None:
        """Once the flag is cached, ``is_setup_complete`` is no longer invoked.

        The service function is replaced with a recording stand-in while one
        more request is sent; the stand-in must never be called.
        """
        from fastapi import FastAPI

        app, client = app_and_client
        assert isinstance(app, FastAPI)

        # Complete setup, then send one non-exempt request to warm the cache.
        credentials = {"password": _SETUP_PAYLOAD["master_password"]}
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        await client.post("/api/auth/login", json=credentials)  # type: ignore[call-overload]
        assert app.state._setup_complete_cached is True  # type: ignore[attr-defined]

        recorded_calls: list[object] = []

        async def _recording(db):  # type: ignore[no-untyped-def]
            # Stand-in for is_setup_complete: note the call, report "complete".
            recorded_calls.append(db)
            return True

        with patch("app.services.setup_service.is_setup_complete", side_effect=_recording):
            await client.post(
                "/api/auth/login",
                json={"password": _SETUP_PAYLOAD["master_password"]},
            )

        # With a warm cache the stand-in must not have been reached at all.
        assert len(recorded_calls) == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 0.1 — Lifespan creates the database parent directory
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestLifespanDatabaseDirectoryCreation:
    """The app lifespan creates the database's parent directory when missing."""

    async def test_creates_nested_database_directory(self, tmp_path: Path) -> None:
        """Entering the lifespan creates missing intermediate directories.

        A database path two levels below ``tmp_path`` is configured; its
        parent must not exist beforehand and must exist once the lifespan
        context has been entered.  Geo services, task registration, the
        scheduler and ``ensure_jail_configs`` are patched out.
        """
        db_file = tmp_path / "deep" / "nested" / "bangui.db"
        assert not db_file.parent.exists()

        # Scheduler stand-in whose start/shutdown are explicit mocks.
        scheduler_stub = MagicMock()
        scheduler_stub.start = MagicMock()
        scheduler_stub.shutdown = MagicMock()

        app = create_app(
            settings=Settings(
                database_path=str(db_file),
                fail2ban_socket="/tmp/fake.sock",
                session_secret="test-lifespan-mkdir-secret",
                session_duration_minutes=60,
                timezone="UTC",
                log_level="debug",
            )
        )

        with (
            patch("app.services.geo_service.init_geoip"),
            patch(
                "app.services.geo_service.load_cache_from_db",
                new=AsyncMock(return_value=None),
            ),
            patch("app.tasks.health_check.register"),
            patch("app.tasks.blocklist_import.register"),
            patch("app.tasks.geo_cache_flush.register"),
            patch("app.tasks.geo_re_resolve.register"),
            patch("app.main.AsyncIOScheduler", return_value=scheduler_stub),
            patch("app.main.ensure_jail_configs"),
        ):
            async with _lifespan(app):
                assert db_file.parent.exists(), (
                    "Expected lifespan to create database parent directory"
                )

    async def test_existing_database_directory_is_not_an_error(
        self, tmp_path: Path
    ) -> None:
        """Entering the lifespan succeeds when the directory already exists.

        The database sits directly in ``tmp_path``, which already exists, so
        this stands in for a restart on an existing volume.  The test passes
        as long as entering the lifespan raises nothing.
        """
        # Parent is tmp_path itself, i.e. a directory that is already there.
        db_file = tmp_path / "bangui.db"

        scheduler_stub = MagicMock()
        scheduler_stub.start = MagicMock()
        scheduler_stub.shutdown = MagicMock()

        app = create_app(
            settings=Settings(
                database_path=str(db_file),
                fail2ban_socket="/tmp/fake.sock",
                session_secret="test-lifespan-exist-ok-secret",
                session_duration_minutes=60,
                timezone="UTC",
                log_level="debug",
            )
        )

        with (
            patch("app.services.geo_service.init_geoip"),
            patch(
                "app.services.geo_service.load_cache_from_db",
                new=AsyncMock(return_value=None),
            ),
            patch("app.tasks.health_check.register"),
            patch("app.tasks.blocklist_import.register"),
            patch("app.tasks.geo_cache_flush.register"),
            patch("app.tasks.geo_re_resolve.register"),
            patch("app.main.AsyncIOScheduler", return_value=scheduler_stub),
            patch("app.main.ensure_jail_configs"),
        ):
            # No FileExistsError (or similar) may escape from startup.
            async with _lifespan(app):
                assert tmp_path.exists()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Task 0.2 — Middleware redirects when app.state.db is None
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSetupRedirectMiddlewareDbNone:
    """Behaviour of SetupRedirectMiddleware while no database is attached."""

    async def test_redirects_to_setup_when_db_not_set(self, tmp_path: Path) -> None:
        """Without ``app.state.db`` a protected request gets a 307 to ``/api/setup``.

        The app is created but no database is ever attached, which stands in
        for a request arriving before startup has finished opening the
        database connection.
        """
        app = create_app(
            settings=Settings(
                database_path=str(tmp_path / "bangui.db"),
                fail2ban_socket="/tmp/fake_fail2ban.sock",
                session_secret="test-db-none-secret",
                session_duration_minutes=60,
                timezone="UTC",
                log_level="debug",
            )
        )
        # app.state.db is intentionally left unset here.

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            redirect = await http.get("/api/auth/login", follow_redirects=False)

        assert redirect.status_code == 307
        assert redirect.headers["location"] == "/api/setup"

    async def test_health_reachable_when_db_not_set(self, tmp_path: Path) -> None:
        """The health endpoint still answers 200 when no database is attached."""
        app = create_app(
            settings=Settings(
                database_path=str(tmp_path / "bangui.db"),
                fail2ban_socket="/tmp/fake_fail2ban.sock",
                session_secret="test-db-none-health-secret",
                session_duration_minutes=60,
                timezone="UTC",
                log_level="debug",
            )
        )

        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as http:
            health = await http.get("/api/health")

        assert health.status_code == 200
|
|
|