"""Tests for the setup router (POST /api/setup, GET /api/setup, GET /api/setup/timezone).""" from __future__ import annotations from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import aiosqlite import pytest from httpx import ASGITransport, AsyncClient from app.config import Settings from app.db import init_db from app.main import _lifespan, create_app # --------------------------------------------------------------------------- # Shared setup payload # --------------------------------------------------------------------------- _SETUP_PAYLOAD: dict[str, object] = { "master_password": "supersecret123", "database_path": "bangui.db", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "timezone": "UTC", "session_duration_minutes": 60, } # --------------------------------------------------------------------------- # Fixture for tests that need direct access to app.state # --------------------------------------------------------------------------- @pytest.fixture async def app_and_client(tmp_path: Path) -> tuple[object, AsyncClient]: # type: ignore[misc] """Yield ``(app, client)`` for tests that inspect ``app.state`` directly. Args: tmp_path: Pytest-provided isolated temporary directory. Yields: A tuple of ``(FastAPI app instance, AsyncClient)``. """ settings = Settings( database_path=str(tmp_path / "setup_cache_test.db"), fail2ban_socket="/tmp/fake_fail2ban.sock", session_secret="test-setup-cache-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path) db.row_factory = aiosqlite.Row await init_db(db) app.state.db = db transport: ASGITransport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: yield app, ac await db.close() class TestGetSetupStatus: """GET /api/setup — check setup completion state.""" async def test_returns_not_completed_on_fresh_db(self, client: AsyncClient) -> None: """Status endpoint reports setup not done on a fresh database.""" response = await client.get("/api/setup") assert response.status_code == 200 assert response.json() == {"completed": False} async def test_returns_completed_after_setup(self, client: AsyncClient) -> None: """Status endpoint reports setup done after POST /api/setup.""" await client.post( "/api/setup", json={ "master_password": "supersecret123", "database_path": "bangui.db", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "timezone": "UTC", "session_duration_minutes": 60, }, ) response = await client.get("/api/setup") assert response.status_code == 200 assert response.json() == {"completed": True} class TestPostSetup: """POST /api/setup — run the first-run configuration wizard.""" async def test_accepts_valid_payload(self, client: AsyncClient) -> None: """Setup endpoint returns 201 for a valid first-run payload.""" response = await client.post( "/api/setup", json={ "master_password": "supersecret123", "database_path": "bangui.db", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "timezone": "UTC", "session_duration_minutes": 60, }, ) assert response.status_code == 201 body = response.json() assert "message" in body async def test_rejects_short_password(self, client: AsyncClient) -> None: """Setup endpoint rejects passwords shorter than 8 characters.""" response = await client.post( "/api/setup", json={"master_password": "short"}, ) assert response.status_code == 422 async def test_rejects_second_call(self, client: AsyncClient) -> None: """Setup endpoint returns 409 if setup has already been completed.""" payload = { "master_password": "supersecret123", "database_path": "bangui.db", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "timezone": "UTC", "session_duration_minutes": 60, } first = await client.post("/api/setup", json=payload) assert first.status_code == 201 second = await client.post("/api/setup", json=payload) assert second.status_code == 409 async def test_accepts_defaults_for_optional_fields( self, client: AsyncClient ) -> None: """Setup endpoint uses defaults when optional fields are omitted.""" response = await client.post( "/api/setup", json={"master_password": "supersecret123"}, ) assert response.status_code == 201 class TestSetupRedirectMiddleware: """Verify that the setup-redirect middleware enforces setup-first.""" async def test_protected_endpoint_redirects_before_setup( self, client: AsyncClient ) -> None: """Non-setup API requests redirect to /api/setup on a fresh instance.""" response = await client.get( "/api/auth/login", follow_redirects=False, ) # Middleware issues 307 redirect to /api/setup assert response.status_code == 307 assert response.headers["location"] == "/api/setup" async def test_health_always_reachable_before_setup( self, client: AsyncClient ) -> None: """Health endpoint is always reachable even before setup.""" response = await client.get("/api/health") assert response.status_code == 200 async def test_no_redirect_after_setup(self, client: AsyncClient) -> None: """Protected endpoints are reachable (no redirect) after setup.""" await client.post( "/api/setup", json={"master_password": "supersecret123"}, ) # /api/auth/login should now be reachable (returns 405 GET not allowed, # not a setup redirect) response = await client.post( "/api/auth/login", json={"password": "wrong"}, follow_redirects=False, ) # 401 wrong password — not a 307 redirect assert response.status_code == 401 class TestGetTimezone: """GET /api/setup/timezone — return the configured IANA timezone.""" async def test_returns_utc_before_setup(self, client: AsyncClient) -> None: """Timezone endpoint returns 'UTC' on a fresh database (no setup yet).""" response = await client.get("/api/setup/timezone") assert response.status_code == 200 assert response.json() == {"timezone": "UTC"} async def test_returns_configured_timezone(self, client: AsyncClient) -> None: """Timezone endpoint returns the value set during setup.""" await client.post( "/api/setup", json={ "master_password": "supersecret123", "timezone": "Europe/Berlin", }, ) response = await client.get("/api/setup/timezone") assert response.status_code == 200 assert response.json() == {"timezone": "Europe/Berlin"} async def test_endpoint_always_reachable_before_setup( self, client: AsyncClient ) -> None: """Timezone endpoint is reachable before setup (no redirect).""" response = await client.get( "/api/setup/timezone", follow_redirects=False, ) # Should return 200, not a 307 redirect, because /api/setup paths # are always allowed by the SetupRedirectMiddleware. assert response.status_code == 200 # --------------------------------------------------------------------------- # Setup-complete flag caching in SetupRedirectMiddleware (Task 4) # --------------------------------------------------------------------------- class TestSetupCompleteCaching: """SetupRedirectMiddleware caches the setup_complete flag in ``app.state``.""" async def test_cache_flag_set_after_first_post_setup_request( self, app_and_client: tuple[object, AsyncClient], ) -> None: """``_setup_complete_cached`` is set to True on the first request after setup. The ``/api/setup`` path is in ``_ALWAYS_ALLOWED`` so it bypasses the middleware check. The first request to a non-exempt endpoint triggers the DB query and, when setup is complete, populates the cache flag. """ from fastapi import FastAPI app, client = app_and_client assert isinstance(app, FastAPI) # Complete setup (exempt from middleware, no flag set yet). resp = await client.post("/api/setup", json=_SETUP_PAYLOAD) assert resp.status_code == 201 # Flag not yet cached — setup was via an exempt path. assert not getattr(app.state, "_setup_complete_cached", False) # First non-exempt request — middleware queries DB and sets the flag. await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload] assert app.state._setup_complete_cached is True # type: ignore[attr-defined] async def test_cached_path_skips_is_setup_complete( self, app_and_client: tuple[object, AsyncClient], ) -> None: """Subsequent requests do not call ``is_setup_complete`` once flag is cached. After the flag is set, the middleware must not touch the database for any further requests — even if ``is_setup_complete`` would raise. """ from fastapi import FastAPI app, client = app_and_client assert isinstance(app, FastAPI) # Do setup and warm the cache. await client.post("/api/setup", json=_SETUP_PAYLOAD) await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload] assert app.state._setup_complete_cached is True # type: ignore[attr-defined] call_count = 0 async def _counting(db): # type: ignore[no-untyped-def] nonlocal call_count call_count += 1 return True with patch("app.services.setup_service.is_setup_complete", side_effect=_counting): await client.post( "/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}, ) # Cache was warm — is_setup_complete must not have been called. assert call_count == 0 # --------------------------------------------------------------------------- # Task 0.1 — Lifespan creates the database parent directory (Task 0.1) # --------------------------------------------------------------------------- class TestLifespanDatabaseDirectoryCreation: """App lifespan creates the database parent directory when it does not exist.""" async def test_creates_nested_database_directory(self, tmp_path: Path) -> None: """Lifespan creates intermediate directories for the database path. Verifies that a deeply-nested database path is handled correctly — the parent directories are created before ``aiosqlite.connect`` is called so the app does not crash on a fresh volume. """ nested_db = tmp_path / "deep" / "nested" / "bangui.db" assert not nested_db.parent.exists() settings = Settings( database_path=str(nested_db), fail2ban_socket="/tmp/fake.sock", session_secret="test-lifespan-mkdir-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) mock_scheduler = MagicMock() mock_scheduler.start = MagicMock() mock_scheduler.shutdown = MagicMock() with ( patch("app.services.geo_service.init_geoip"), patch( "app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None), ), patch("app.tasks.health_check.register"), patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), patch("app.main.AsyncIOScheduler", return_value=mock_scheduler), patch("app.main.ensure_jail_configs"), ): async with _lifespan(app): assert nested_db.parent.exists(), ( "Expected lifespan to create database parent directory" ) async def test_existing_database_directory_is_not_an_error( self, tmp_path: Path ) -> None: """Lifespan does not raise when the database directory already exists. ``mkdir(exist_ok=True)`` must be used so that re-starts on an existing volume do not fail. """ db_path = tmp_path / "bangui.db" # tmp_path already exists — this simulates a pre-existing volume. settings = Settings( database_path=str(db_path), fail2ban_socket="/tmp/fake.sock", session_secret="test-lifespan-exist-ok-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) mock_scheduler = MagicMock() mock_scheduler.start = MagicMock() mock_scheduler.shutdown = MagicMock() with ( patch("app.services.geo_service.init_geoip"), patch( "app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None), ), patch("app.tasks.health_check.register"), patch("app.tasks.blocklist_import.register"), patch("app.tasks.geo_cache_flush.register"), patch("app.tasks.geo_re_resolve.register"), patch("app.main.AsyncIOScheduler", return_value=mock_scheduler), patch("app.main.ensure_jail_configs"), ): # Should not raise FileExistsError or similar. async with _lifespan(app): assert tmp_path.exists() # --------------------------------------------------------------------------- # Task 0.2 — Middleware redirects when app.state.db is None # --------------------------------------------------------------------------- class TestSetupRedirectMiddlewareDbNone: """SetupRedirectMiddleware redirects when the database is not yet available.""" async def test_redirects_to_setup_when_db_not_set(self, tmp_path: Path) -> None: """A ``None`` db on app.state causes a 307 redirect to ``/api/setup``. Simulates the race window where a request arrives before the lifespan has finished initialising the database connection. """ settings = Settings( database_path=str(tmp_path / "bangui.db"), fail2ban_socket="/tmp/fake_fail2ban.sock", session_secret="test-db-none-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) # Deliberately do NOT set app.state.db to simulate startup not complete. transport = ASGITransport(app=app) async with AsyncClient( transport=transport, base_url="http://test" ) as ac: response = await ac.get("/api/auth/login", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "/api/setup" async def test_health_reachable_when_db_not_set(self, tmp_path: Path) -> None: """Health endpoint is always reachable even when db is not initialised.""" settings = Settings( database_path=str(tmp_path / "bangui.db"), fail2ban_socket="/tmp/fake_fail2ban.sock", session_secret="test-db-none-health-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) transport = ASGITransport(app=app) async with AsyncClient( transport=transport, base_url="http://test" ) as ac: response = await ac.get("/api/health") assert response.status_code == 200