"""Tests for the dashboard router.

Covers ``GET /api/dashboard/status``, ``GET /api/dashboard/bans``,
``GET /api/dashboard/bans/by-country`` and ``GET /api/dashboard/bans/trend``.
"""

from __future__ import annotations

from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient

from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.ban import (
    DashboardBanItem,
    DashboardBanListResponse,
)
from app.models.server import ServerStatus

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

_SETUP_PAYLOAD = {
    "master_password": "testpassword1",
    "database_path": "bangui.db",
    "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
    "timezone": "UTC",
    "session_duration_minutes": 60,
}


@asynccontextmanager
async def _authenticated_client(
    tmp_path: Path,
    *,
    db_filename: str,
    session_secret: str,
    server_status: ServerStatus,
) -> AsyncIterator[AsyncClient]:
    """Yield a logged-in ``AsyncClient`` bound to a freshly created app.

    The app gets its own SQLite database under *tmp_path*, the given
    ``server_status`` snapshot in ``app.state.server_status`` and a
    ``MagicMock`` in ``app.state.http_session``. Setup and login are
    completed (and asserted) before the client is yielded.

    Args:
        tmp_path: Per-test temporary directory that holds the database file.
        db_filename: Name of the SQLite file created inside *tmp_path*.
        session_secret: Value passed as ``Settings.session_secret``.
        server_status: Snapshot stored in ``app.state.server_status``.

    Yields:
        An ``AsyncClient`` carrying a valid session cookie.
    """
    settings = Settings(
        database_path=str(tmp_path / db_filename),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        session_secret=session_secret,
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    # try/finally so the connection is closed even when init_db, setup or
    # login fails before the client is handed to the test.
    try:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        app.state.db = db
        app.state.server_status = server_status
        # Provide a stub HTTP session so ban/access endpoints can access
        # app.state.http_session.
        app.state.http_session = MagicMock()

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            # Complete setup so the middleware doesn't redirect.
            resp = await ac.post("/api/setup", json=_SETUP_PAYLOAD)
            assert resp.status_code == 201

            # Login to get a session cookie.
            login_resp = await ac.post(
                "/api/auth/login",
                json={"password": _SETUP_PAYLOAD["master_password"]},
            )
            assert login_resp.status_code == 200

            yield ac
    finally:
        await db.close()


@pytest.fixture
async def dashboard_client(tmp_path: Path) -> AsyncIterator[AsyncClient]:
    """Provide an authenticated ``AsyncClient`` with a pre-seeded server status.

    Unlike the shared ``client`` fixture this one seeds
    ``app.state.server_status`` with an online snapshot so the status
    endpoint has something to return.
    """
    online_status = ServerStatus(
        online=True,
        version="1.0.2",
        active_jails=2,
        total_bans=10,
        total_failures=5,
    )
    async with _authenticated_client(
        tmp_path,
        db_filename="dashboard_test.db",
        session_secret="test-dashboard-secret",
        server_status=online_status,
    ) as ac:
        yield ac


@pytest.fixture
async def offline_dashboard_client(tmp_path: Path) -> AsyncIterator[AsyncClient]:
    """Like ``dashboard_client`` but with an offline server status."""
    async with _authenticated_client(
        tmp_path,
        db_filename="dashboard_offline_test.db",
        session_secret="test-dashboard-offline-secret",
        server_status=ServerStatus(online=False),
    ) as ac:
        yield ac


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


class TestDashboardStatus:
    """GET /api/dashboard/status."""

    async def test_returns_200_when_authenticated(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Authenticated request returns HTTP 200."""
        response = await dashboard_client.get("/api/dashboard/status")
        assert response.status_code == 200

    async def test_returns_401_when_unauthenticated(
        self, client: AsyncClient
    ) -> None:
        """Unauthenticated request returns HTTP 401."""
        # Complete setup so the middleware allows the request through.
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        response = await client.get("/api/dashboard/status")
        assert response.status_code == 401

    async def test_response_shape_when_online(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Response contains the expected ``status`` object shape."""
        response = await dashboard_client.get("/api/dashboard/status")
        body = response.json()

        assert "status" in body
        status = body["status"]
        assert "online" in status
        assert "version" in status
        assert "active_jails" in status
        assert "total_bans" in status
        assert "total_failures" in status

    async def test_cached_values_returned_when_online(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Endpoint returns the exact values from ``app.state.server_status``."""
        response = await dashboard_client.get("/api/dashboard/status")
        status = response.json()["status"]

        assert status["online"] is True
        assert status["version"] == "1.0.2"
        assert status["active_jails"] == 2
        assert status["total_bans"] == 10
        assert status["total_failures"] == 5

    async def test_offline_status_returned_correctly(
        self, offline_dashboard_client: AsyncClient
    ) -> None:
        """Endpoint returns online=False when the cache holds an offline snapshot."""
        response = await offline_dashboard_client.get("/api/dashboard/status")
        assert response.status_code == 200
        status = response.json()["status"]

        assert status["online"] is False
        assert status["version"] is None
        assert status["active_jails"] == 0
        assert status["total_bans"] == 0
        assert status["total_failures"] == 0

    async def test_returns_offline_when_state_not_initialised(
        self, client: AsyncClient
    ) -> None:
        """Endpoint returns online=False as a safe default if the cache is absent."""
        # Setup + login so the endpoint is reachable.
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        await client.post(
            "/api/auth/login",
            json={"password": _SETUP_PAYLOAD["master_password"]},
        )
        # server_status is not set on app.state in the shared `client` fixture.
        response = await client.get("/api/dashboard/status")
        assert response.status_code == 200
        status = response.json()["status"]
        assert status["online"] is False


# ---------------------------------------------------------------------------
# Dashboard bans endpoint
# ---------------------------------------------------------------------------


def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
    """Build a mock DashboardBanListResponse with *n* ``sshd``/selfblock items."""
    items = [
        DashboardBanItem(
            ip=f"1.2.3.{i}",
            jail="sshd",
            banned_at="2026-03-01T10:00:00+00:00",
            service=None,
            country_code="DE",
            country_name="Germany",
            asn="AS3320",
            org="Telekom",
            ban_count=1,
            origin="selfblock",
        )
        for i in range(n)
    ]
    return DashboardBanListResponse(items=items, total=n, page=1, page_size=100)


class TestDashboardBans:
    """GET /api/dashboard/bans."""

    async def test_returns_200_when_authenticated(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Authenticated request returns HTTP 200."""
        with patch(
            "app.routers.dashboard.ban_service.list_bans",
            new=AsyncMock(return_value=_make_ban_list_response()),
        ):
            response = await dashboard_client.get("/api/dashboard/bans")
        assert response.status_code == 200

    async def test_returns_401_when_unauthenticated(
        self, client: AsyncClient
    ) -> None:
        """Unauthenticated request returns HTTP 401."""
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        response = await client.get("/api/dashboard/bans")
        assert response.status_code == 401

    async def test_response_contains_items_and_total(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Response body contains ``items`` list and ``total`` count."""
        with patch(
            "app.routers.dashboard.ban_service.list_bans",
            new=AsyncMock(return_value=_make_ban_list_response(3)),
        ):
            response = await dashboard_client.get("/api/dashboard/bans")

        body = response.json()
        assert "items" in body
        assert "total" in body
        assert body["total"] == 3
        assert len(body["items"]) == 3

    async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None:
        """If no ``range`` param is provided the default ``24h`` preset is used."""
        mock_list = AsyncMock(return_value=_make_ban_list_response())
        with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
            await dashboard_client.get("/api/dashboard/bans")

        # The range is the second positional argument of list_bans.
        called_range = mock_list.call_args[0][1]
        assert called_range == "24h"

    async def test_accepts_time_range_param(
        self, dashboard_client: AsyncClient
    ) -> None:
        """The ``range`` query parameter is forwarded to ban_service."""
        mock_list = AsyncMock(return_value=_make_ban_list_response())
        with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
            await dashboard_client.get("/api/dashboard/bans?range=7d")

        called_range = mock_list.call_args[0][1]
        assert called_range == "7d"

    async def test_empty_ban_list_returns_zero_total(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Returns ``total=0`` and empty ``items`` when no bans are in range."""
        empty = DashboardBanListResponse(items=[], total=0, page=1, page_size=100)
        with patch(
            "app.routers.dashboard.ban_service.list_bans",
            new=AsyncMock(return_value=empty),
        ):
            response = await dashboard_client.get("/api/dashboard/bans")

        body = response.json()
        assert body["total"] == 0
        assert body["items"] == []

    async def test_item_shape_is_correct(self, dashboard_client: AsyncClient) -> None:
        """Each item in ``items`` has the expected fields."""
        with patch(
            "app.routers.dashboard.ban_service.list_bans",
            new=AsyncMock(return_value=_make_ban_list_response(1)),
        ):
            response = await dashboard_client.get("/api/dashboard/bans")

        item = response.json()["items"][0]
        assert "ip" in item
        assert "jail" in item
        assert "banned_at" in item
        assert "ban_count" in item
# ---------------------------------------------------------------------------
# Bans by country endpoint
# ---------------------------------------------------------------------------


def _make_bans_by_country_response() -> object:
    """Build a stub BansByCountryResponse.

    It holds two bans. The first is a ``selfblock`` ban from the ``sshd``
    jail (DE). The second is a ``blocklist`` ban from the
    ``blocklist-import`` jail (US).
    """
    from app.models.ban import BansByCountryResponse

    items = [
        DashboardBanItem(
            ip="1.2.3.4",
            jail="sshd",
            banned_at="2026-03-01T10:00:00+00:00",
            service=None,
            country_code="DE",
            country_name="Germany",
            asn="AS3320",
            org="Telekom",
            ban_count=1,
            origin="selfblock",
        ),
        DashboardBanItem(
            ip="5.6.7.8",
            jail="blocklist-import",
            banned_at="2026-03-01T10:05:00+00:00",
            service=None,
            country_code="US",
            country_name="United States",
            asn="AS15169",
            org="Google LLC",
            ban_count=2,
            origin="blocklist",
        ),
    ]
    return BansByCountryResponse(
        countries={"DE": 1, "US": 1},
        country_names={"DE": "Germany", "US": "United States"},
        bans=items,
        total=2,
    )


@pytest.mark.anyio
class TestBansByCountry:
    """GET /api/dashboard/bans/by-country."""

    async def test_returns_200_when_authenticated(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Authenticated request returns HTTP 200."""
        with patch(
            "app.routers.dashboard.ban_service.bans_by_country",
            new=AsyncMock(return_value=_make_bans_by_country_response()),
        ):
            response = await dashboard_client.get("/api/dashboard/bans/by-country")
        assert response.status_code == 200

    async def test_returns_401_when_unauthenticated(
        self, client: AsyncClient
    ) -> None:
        """Unauthenticated request returns HTTP 401."""
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        response = await client.get("/api/dashboard/bans/by-country")
        assert response.status_code == 401

    async def test_response_shape(self, dashboard_client: AsyncClient) -> None:
        """Response body contains countries, country_names, bans, total."""
        with patch(
            "app.routers.dashboard.ban_service.bans_by_country",
            new=AsyncMock(return_value=_make_bans_by_country_response()),
        ):
            response = await dashboard_client.get("/api/dashboard/bans/by-country")

        body = response.json()
        assert "countries" in body
        assert "country_names" in body
        assert "bans" in body
        assert "total" in body
        assert body["total"] == 2
        assert body["countries"]["DE"] == 1
        assert body["countries"]["US"] == 1
        assert body["country_names"]["DE"] == "Germany"

    async def test_accepts_time_range_param(
        self, dashboard_client: AsyncClient
    ) -> None:
        """The range query parameter is forwarded to ban_service."""
        mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(
            "app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
        ):
            await dashboard_client.get("/api/dashboard/bans/by-country?range=7d")

        # The range is the second positional argument of bans_by_country.
        called_range = mock_fn.call_args[0][1]
        assert called_range == "7d"

    async def test_empty_window_returns_empty_response(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Empty time range returns empty countries dict and bans list."""
        from app.models.ban import BansByCountryResponse

        empty = BansByCountryResponse(
            countries={},
            country_names={},
            bans=[],
            total=0,
        )
        with patch(
            "app.routers.dashboard.ban_service.bans_by_country",
            new=AsyncMock(return_value=empty),
        ):
            response = await dashboard_client.get("/api/dashboard/bans/by-country")

        body = response.json()
        assert body["total"] == 0
        assert body["countries"] == {}
        assert body["bans"] == []


# ---------------------------------------------------------------------------
# Origin field tests
# ---------------------------------------------------------------------------


class TestDashboardBansOriginField:
    """Verify that the ``origin`` field is present in API responses."""

    async def test_origin_present_in_ban_list_items(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Each item in ``/api/dashboard/bans`` carries an ``origin`` field."""
        with patch(
            "app.routers.dashboard.ban_service.list_bans",
            new=AsyncMock(return_value=_make_ban_list_response(1)),
        ):
            response = await dashboard_client.get("/api/dashboard/bans")

        item = response.json()["items"][0]
        assert "origin" in item
        assert item["origin"] in ("blocklist", "selfblock")

    async def test_selfblock_origin_serialised_correctly(
        self, dashboard_client: AsyncClient
    ) -> None:
        """A ban from a non-blocklist jail serialises as ``"selfblock"``."""
        with patch(
            "app.routers.dashboard.ban_service.list_bans",
            new=AsyncMock(return_value=_make_ban_list_response(1)),
        ):
            response = await dashboard_client.get("/api/dashboard/bans")

        item = response.json()["items"][0]
        assert item["jail"] == "sshd"
        assert item["origin"] == "selfblock"

    async def test_origin_present_in_bans_by_country(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Each ban in ``/api/dashboard/bans/by-country`` carries an ``origin``."""
        with patch(
            "app.routers.dashboard.ban_service.bans_by_country",
            new=AsyncMock(return_value=_make_bans_by_country_response()),
        ):
            response = await dashboard_client.get("/api/dashboard/bans/by-country")

        bans = response.json()["bans"]
        assert all("origin" in ban for ban in bans)
        origins = {ban["origin"] for ban in bans}
        assert origins == {"blocklist", "selfblock"}

    async def test_blocklist_origin_serialised_correctly(
        self, dashboard_client: AsyncClient
    ) -> None:
        """A ban from the ``blocklist-import`` jail serialises as ``"blocklist"``."""
        with patch(
            "app.routers.dashboard.ban_service.bans_by_country",
            new=AsyncMock(return_value=_make_bans_by_country_response()),
        ):
            response = await dashboard_client.get("/api/dashboard/bans/by-country")

        bans = response.json()["bans"]
        # Use a default instead of a bare next(): a StopIteration escaping a
        # coroutine becomes an opaque RuntimeError rather than a test failure.
        blocklist_ban = next(
            (b for b in bans if b["jail"] == "blocklist-import"), None
        )
        assert blocklist_ban is not None, "no blocklist-import ban in response"
        assert blocklist_ban["origin"] == "blocklist"


# ---------------------------------------------------------------------------
# Origin filter query parameter tests
# ---------------------------------------------------------------------------


class TestOriginFilterParam:
    """Verify that the ``origin`` query parameter is forwarded to the service."""

    async def test_bans_origin_blocklist_forwarded_to_service(
        self, dashboard_client: AsyncClient
    ) -> None:
        """``?origin=blocklist`` is passed to ``ban_service.list_bans``."""
        mock_list = AsyncMock(return_value=_make_ban_list_response())
        with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
            await dashboard_client.get("/api/dashboard/bans?origin=blocklist")

        _, kwargs = mock_list.call_args
        assert kwargs.get("origin") == "blocklist"

    async def test_bans_origin_selfblock_forwarded_to_service(
        self, dashboard_client: AsyncClient
    ) -> None:
        """``?origin=selfblock`` is passed to ``ban_service.list_bans``."""
        mock_list = AsyncMock(return_value=_make_ban_list_response())
        with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
            await dashboard_client.get("/api/dashboard/bans?origin=selfblock")

        _, kwargs = mock_list.call_args
        assert kwargs.get("origin") == "selfblock"

    async def test_bans_no_origin_param_defaults_to_none(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Omitting ``origin`` passes ``None`` to the service (no filtering)."""
        mock_list = AsyncMock(return_value=_make_ban_list_response())
        with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
            await dashboard_client.get("/api/dashboard/bans")

        _, kwargs = mock_list.call_args
        assert kwargs.get("origin") is None

    async def test_bans_invalid_origin_returns_422(
        self, dashboard_client: AsyncClient
    ) -> None:
        """An invalid ``origin`` value returns HTTP 422 Unprocessable Entity."""
        response = await dashboard_client.get("/api/dashboard/bans?origin=invalid")
        assert response.status_code == 422

    async def test_by_country_origin_blocklist_forwarded(
        self, dashboard_client: AsyncClient
    ) -> None:
        """``?origin=blocklist`` is passed to ``ban_service.bans_by_country``."""
        mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(
            "app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
        ):
            await dashboard_client.get(
                "/api/dashboard/bans/by-country?origin=blocklist"
            )

        _, kwargs = mock_fn.call_args
        assert kwargs.get("origin") == "blocklist"

    async def test_by_country_no_origin_defaults_to_none(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Omitting ``origin`` passes ``None`` to ``bans_by_country``."""
        mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(
            "app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
        ):
            await dashboard_client.get("/api/dashboard/bans/by-country")

        _, kwargs = mock_fn.call_args
        assert kwargs.get("origin") is None


# ---------------------------------------------------------------------------
# Ban trend endpoint
# ---------------------------------------------------------------------------


def _make_ban_trend_response(n_buckets: int = 24) -> object:
    """Build a stub :class:`~app.models.ban.BanTrendResponse`.

    Produces *n_buckets* consecutive hourly buckets starting at
    ``2026-03-01T00:00:00+00:00``. Bucket ``i`` has ``count == i``.

    Timestamps are computed with ``timedelta`` so that more than 24 buckets
    roll over into the next day. A formatted hour would produce invalid
    strings such as ``T24:00``. For ``i < 24`` the strings are unchanged,
    e.g. ``"2026-03-01T05:00:00+00:00"``.
    """
    from datetime import datetime, timedelta, timezone

    from app.models.ban import BanTrendBucket, BanTrendResponse

    start = datetime(2026, 3, 1, tzinfo=timezone.utc)
    buckets = [
        BanTrendBucket(
            timestamp=(start + timedelta(hours=i)).isoformat(),
            count=i,
        )
        for i in range(n_buckets)
    ]
    return BanTrendResponse(buckets=buckets, bucket_size="1h")


@pytest.mark.anyio
class TestBanTrend:
    """GET /api/dashboard/bans/trend."""

    async def test_returns_200_when_authenticated(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Authenticated request returns HTTP 200."""
        with patch(
            "app.routers.dashboard.ban_service.ban_trend",
            new=AsyncMock(return_value=_make_ban_trend_response()),
        ):
            response = await dashboard_client.get("/api/dashboard/bans/trend")
        assert response.status_code == 200

    async def test_returns_401_when_unauthenticated(
        self, client: AsyncClient
    ) -> None:
        """Unauthenticated request returns HTTP 401."""
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        response = await client.get("/api/dashboard/bans/trend")
        assert response.status_code == 401

    async def test_response_shape(self, dashboard_client: AsyncClient) -> None:
        """Response body contains ``buckets`` list and ``bucket_size`` string."""
        with patch(
            "app.routers.dashboard.ban_service.ban_trend",
            new=AsyncMock(return_value=_make_ban_trend_response(24)),
        ):
            response = await dashboard_client.get("/api/dashboard/bans/trend")

        body = response.json()
        assert "buckets" in body
        assert "bucket_size" in body
        assert len(body["buckets"]) == 24
        assert body["bucket_size"] == "1h"

    async def test_each_bucket_has_timestamp_and_count(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Every element of ``buckets`` has ``timestamp`` and ``count``."""
        with patch(
            "app.routers.dashboard.ban_service.ban_trend",
            new=AsyncMock(return_value=_make_ban_trend_response(3)),
        ):
            response = await dashboard_client.get("/api/dashboard/bans/trend")

        for bucket in response.json()["buckets"]:
            assert "timestamp" in bucket
            assert "count" in bucket
            assert isinstance(bucket["count"], int)

    async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None:
        """Omitting ``range`` defaults to ``24h``."""
        mock_fn = AsyncMock(return_value=_make_ban_trend_response())
        with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
            await dashboard_client.get("/api/dashboard/bans/trend")

        # The range is the second positional argument of ban_trend.
        called_range = mock_fn.call_args[0][1]
        assert called_range == "24h"

    async def test_accepts_range_param(self, dashboard_client: AsyncClient) -> None:
        """The ``range`` query parameter is forwarded to the service."""
        mock_fn = AsyncMock(return_value=_make_ban_trend_response(28))
        with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
            await dashboard_client.get("/api/dashboard/bans/trend?range=7d")

        called_range = mock_fn.call_args[0][1]
        assert called_range == "7d"

    async def test_origin_param_forwarded(self, dashboard_client: AsyncClient) -> None:
        """``?origin=blocklist`` is passed as a keyword arg to the service."""
        mock_fn = AsyncMock(return_value=_make_ban_trend_response())
        with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
            await dashboard_client.get(
                "/api/dashboard/bans/trend?origin=blocklist"
            )

        _, kwargs = mock_fn.call_args
        assert kwargs.get("origin") == "blocklist"

    async def test_no_origin_defaults_to_none(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Omitting ``origin`` passes ``None`` to the service."""
        mock_fn = AsyncMock(return_value=_make_ban_trend_response())
        with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
            await dashboard_client.get("/api/dashboard/bans/trend")

        _, kwargs = mock_fn.call_args
        assert kwargs.get("origin") is None

    async def test_invalid_range_returns_422(
        self, dashboard_client: AsyncClient
    ) -> None:
        """An invalid ``range`` value returns HTTP 422."""
        response = await dashboard_client.get(
            "/api/dashboard/bans/trend?range=invalid"
        )
        assert response.status_code == 422

    async def test_empty_buckets_response(self, dashboard_client: AsyncClient) -> None:
        """Empty bucket list is serialised correctly."""
        from app.models.ban import BanTrendResponse

        empty = BanTrendResponse(buckets=[], bucket_size="1h")
        with patch(
            "app.routers.dashboard.ban_service.ban_trend",
            new=AsyncMock(return_value=empty),
        ):
            response = await dashboard_client.get("/api/dashboard/bans/trend")

        body = response.json()
        assert body["buckets"] == []
        assert body["bucket_size"] == "1h"