- Task 1: Mark imported blocklist IP addresses
- Add BanOrigin type and _derive_origin() to ban.py model
- Populate origin field in ban_service list_bans() and bans_by_country()
- BanTable and MapPage companion table show origin badge column
- Tests: origin derivation in test_ban_service.py and test_dashboard.py
- Task 2: Add origin filter to dashboard and world map
- ban_service: _origin_sql_filter() helper; origin param on list_bans()
and bans_by_country()
- dashboard router: optional origin query param forwarded to service
- Frontend: BanOriginFilter type + BAN_ORIGIN_FILTER_LABELS in ban.ts
- fetchBans / fetchBansByCountry forward origin to API
- useBans / useMapData accept and pass origin; page resets on change
- BanTable accepts origin prop; DashboardPage adds segmented filter
- MapPage adds origin Select next to time-range picker
- Tests: origin filter assertions in test_ban_service and test_dashboard
580 lines
21 KiB
Python
580 lines
21 KiB
Python
"""Tests for the dashboard router (GET /api/dashboard/status, GET /api/dashboard/bans)."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
from httpx import ASGITransport, AsyncClient
|
|
|
|
from app.config import Settings
|
|
from app.db import init_db
|
|
from app.main import create_app
|
|
from app.models.ban import (
|
|
DashboardBanItem,
|
|
DashboardBanListResponse,
|
|
)
|
|
from app.models.server import ServerStatus
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_SETUP_PAYLOAD = {
|
|
"master_password": "testpassword1",
|
|
"database_path": "bangui.db",
|
|
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
|
"timezone": "UTC",
|
|
"session_duration_minutes": 60,
|
|
}
|
|
|
|
|
|
@pytest.fixture
async def dashboard_client(tmp_path: Path) -> AsyncClient:  # type: ignore[misc]
    """Yield an authenticated ``AsyncClient`` with a pre-seeded server status.

    Unlike the shared ``client`` fixture, this fixture builds its own app
    instance and seeds ``app.state`` (database connection, cached
    ``ServerStatus`` and a stub HTTP session) before any request is made.

    Args:
        tmp_path: Pytest-provided temporary directory that holds the SQLite
            database file used by this fixture.

    Yields:
        A client that has already completed ``POST /api/setup`` and
        ``POST /api/auth/login``.
    """
    settings = Settings(
        database_path=str(tmp_path / "dashboard_test.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        session_secret="test-dashboard-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    # try/finally guarantees the connection is closed even when schema
    # initialisation or one of the setup/login assertions below fails before
    # the fixture reaches its ``yield``.
    try:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        app.state.db = db

        # Pre-seed a server status so the endpoint has something to return.
        app.state.server_status = ServerStatus(
            online=True,
            version="1.0.2",
            active_jails=2,
            total_bans=10,
            total_failures=5,
        )
        # Provide a stub HTTP session so ban/access endpoints can access
        # app.state.http_session.
        app.state.http_session = MagicMock()

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            # Complete setup so the middleware doesn't redirect.
            resp = await ac.post("/api/setup", json=_SETUP_PAYLOAD)
            assert resp.status_code == 201

            # Login to get a session cookie.
            login_resp = await ac.post(
                "/api/auth/login",
                json={"password": _SETUP_PAYLOAD["master_password"]},
            )
            assert login_resp.status_code == 200

            yield ac
    finally:
        await db.close()
|
|
|
|
|
|
@pytest.fixture
async def offline_dashboard_client(tmp_path: Path) -> AsyncClient:  # type: ignore[misc]
    """Yield an authenticated ``AsyncClient`` whose cached status is offline.

    Mirrors ``dashboard_client`` except that ``app.state.server_status`` is
    seeded with ``ServerStatus(online=False)``.

    Args:
        tmp_path: Pytest-provided temporary directory that holds the SQLite
            database file used by this fixture.

    Yields:
        A client that has already completed ``POST /api/setup`` and
        ``POST /api/auth/login``.
    """
    settings = Settings(
        database_path=str(tmp_path / "dashboard_offline_test.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        session_secret="test-dashboard-offline-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
    )
    app = create_app(settings=settings)

    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    # try/finally guarantees the connection is closed even when schema
    # initialisation or one of the setup/login assertions below fails before
    # the fixture reaches its ``yield``.
    try:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        app.state.db = db

        app.state.server_status = ServerStatus(online=False)
        app.state.http_session = MagicMock()

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            resp = await ac.post("/api/setup", json=_SETUP_PAYLOAD)
            assert resp.status_code == 201

            login_resp = await ac.post(
                "/api/auth/login",
                json={"password": _SETUP_PAYLOAD["master_password"]},
            )
            assert login_resp.status_code == 200

            yield ac
    finally:
        await db.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDashboardStatus:
    """Tests for ``GET /api/dashboard/status``."""

    async def test_returns_200_when_authenticated(
        self, dashboard_client: AsyncClient
    ) -> None:
        """A logged-in client receives HTTP 200."""
        resp = await dashboard_client.get("/api/dashboard/status")
        assert resp.status_code == 200

    async def test_returns_401_when_unauthenticated(
        self, client: AsyncClient
    ) -> None:
        """A client that never logged in receives HTTP 401."""
        # Setup is completed first so the middleware lets the request through.
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        resp = await client.get("/api/dashboard/status")
        assert resp.status_code == 401

    async def test_response_shape_when_online(
        self, dashboard_client: AsyncClient
    ) -> None:
        """The body wraps the expected fields in a ``status`` object."""
        resp = await dashboard_client.get("/api/dashboard/status")
        payload = resp.json()

        assert "status" in payload
        status = payload["status"]
        expected_fields = (
            "online",
            "version",
            "active_jails",
            "total_bans",
            "total_failures",
        )
        for field in expected_fields:
            assert field in status

    async def test_cached_values_returned_when_online(
        self, dashboard_client: AsyncClient
    ) -> None:
        """The endpoint echoes the values seeded into ``app.state.server_status``."""
        resp = await dashboard_client.get("/api/dashboard/status")
        status = resp.json()["status"]

        assert status["online"] is True
        assert status["version"] == "1.0.2"
        assert status["active_jails"] == 2
        assert status["total_bans"] == 10
        assert status["total_failures"] == 5

    async def test_offline_status_returned_correctly(
        self, offline_dashboard_client: AsyncClient
    ) -> None:
        """An offline snapshot in the cache is reported as ``online=False``."""
        resp = await offline_dashboard_client.get("/api/dashboard/status")
        assert resp.status_code == 200
        status = resp.json()["status"]

        assert status["online"] is False
        assert status["version"] is None
        assert status["active_jails"] == 0
        assert status["total_bans"] == 0
        assert status["total_failures"] == 0

    async def test_returns_offline_when_state_not_initialised(
        self, client: AsyncClient
    ) -> None:
        """With no cached status the endpoint falls back to ``online=False``."""
        # Run setup and login first so the endpoint is reachable.
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        credentials = {"password": _SETUP_PAYLOAD["master_password"]}
        await client.post("/api/auth/login", json=credentials)

        # server_status is not set on app.state in the shared `client` fixture.
        resp = await client.get("/api/dashboard/status")
        assert resp.status_code == 200
        status = resp.json()["status"]
        assert status["online"] is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Dashboard bans endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
    """Return a stub ``DashboardBanListResponse`` holding *n* ban items.

    Every item is an ``sshd`` ban with origin ``"selfblock"``; only the last
    octet of the IP differs (``1.2.3.0`` … ``1.2.3.<n-1>``).
    """
    bans = []
    for idx in range(n):
        ban = DashboardBanItem(
            ip=f"1.2.3.{idx}",
            jail="sshd",
            banned_at="2026-03-01T10:00:00+00:00",
            service=None,
            country_code="DE",
            country_name="Germany",
            asn="AS3320",
            org="Telekom",
            ban_count=1,
            origin="selfblock",
        )
        bans.append(ban)
    return DashboardBanListResponse(items=bans, total=n, page=1, page_size=100)
|
|
|
|
|
|
class TestDashboardBans:
    """Tests for ``GET /api/dashboard/bans``."""

    # Dotted path of the service call that every test in this class stubs out.
    _LIST_BANS = "app.routers.dashboard.ban_service.list_bans"

    async def test_returns_200_when_authenticated(
        self, dashboard_client: AsyncClient
    ) -> None:
        """A logged-in client receives HTTP 200."""
        stub = AsyncMock(return_value=_make_ban_list_response())
        with patch(self._LIST_BANS, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans")
        assert resp.status_code == 200

    async def test_returns_401_when_unauthenticated(
        self, client: AsyncClient
    ) -> None:
        """A client that never logged in receives HTTP 401."""
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        resp = await client.get("/api/dashboard/bans")
        assert resp.status_code == 401

    async def test_response_contains_items_and_total(
        self, dashboard_client: AsyncClient
    ) -> None:
        """The body exposes an ``items`` list and a matching ``total``."""
        stub = AsyncMock(return_value=_make_ban_list_response(3))
        with patch(self._LIST_BANS, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans")

        payload = resp.json()
        assert "items" in payload
        assert "total" in payload
        assert payload["total"] == 3
        assert len(payload["items"]) == 3

    async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None:
        """Without a ``range`` param the service is called with ``24h``."""
        stub = AsyncMock(return_value=_make_ban_list_response())
        with patch(self._LIST_BANS, new=stub):
            await dashboard_client.get("/api/dashboard/bans")

        # The range preset is the second positional argument of the call.
        forwarded_range = stub.call_args.args[1]
        assert forwarded_range == "24h"

    async def test_accepts_time_range_param(
        self, dashboard_client: AsyncClient
    ) -> None:
        """An explicit ``range`` query value reaches ``ban_service``."""
        stub = AsyncMock(return_value=_make_ban_list_response())
        with patch(self._LIST_BANS, new=stub):
            await dashboard_client.get("/api/dashboard/bans?range=7d")

        # The range preset is the second positional argument of the call.
        forwarded_range = stub.call_args.args[1]
        assert forwarded_range == "7d"

    async def test_empty_ban_list_returns_zero_total(
        self, dashboard_client: AsyncClient
    ) -> None:
        """An empty service result yields ``total=0`` and no ``items``."""
        no_bans = DashboardBanListResponse(items=[], total=0, page=1, page_size=100)
        stub = AsyncMock(return_value=no_bans)
        with patch(self._LIST_BANS, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans")

        payload = resp.json()
        assert payload["total"] == 0
        assert payload["items"] == []

    async def test_item_shape_is_correct(self, dashboard_client: AsyncClient) -> None:
        """Every serialised item carries the core ban fields."""
        stub = AsyncMock(return_value=_make_ban_list_response(1))
        with patch(self._LIST_BANS, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans")

        first = resp.json()["items"][0]
        for field in ("ip", "jail", "banned_at", "ban_count"):
            assert field in first
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Bans by country endpoint
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_bans_by_country_response() -> object:
    """Return a stub ``BansByCountryResponse`` with two bans.

    One ban comes from the ``sshd`` jail (origin ``"selfblock"``, DE) and one
    from the ``blocklist-import`` jail (origin ``"blocklist"``, US).
    """
    from app.models.ban import BansByCountryResponse

    selfblock_ban = DashboardBanItem(
        ip="1.2.3.4",
        jail="sshd",
        banned_at="2026-03-01T10:00:00+00:00",
        service=None,
        country_code="DE",
        country_name="Germany",
        asn="AS3320",
        org="Telekom",
        ban_count=1,
        origin="selfblock",
    )
    blocklist_ban = DashboardBanItem(
        ip="5.6.7.8",
        jail="blocklist-import",
        banned_at="2026-03-01T10:05:00+00:00",
        service=None,
        country_code="US",
        country_name="United States",
        asn="AS15169",
        org="Google LLC",
        ban_count=2,
        origin="blocklist",
    )
    return BansByCountryResponse(
        countries={"DE": 1, "US": 1},
        country_names={"DE": "Germany", "US": "United States"},
        bans=[selfblock_ban, blocklist_ban],
        total=2,
    )
|
|
|
|
|
|
@pytest.mark.anyio
class TestBansByCountry:
    """Tests for ``GET /api/dashboard/bans/by-country``."""

    # Dotted path of the service call that every test in this class stubs out.
    _BY_COUNTRY = "app.routers.dashboard.ban_service.bans_by_country"

    async def test_returns_200_when_authenticated(
        self, dashboard_client: AsyncClient
    ) -> None:
        """A logged-in client receives HTTP 200."""
        stub = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(self._BY_COUNTRY, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans/by-country")
        assert resp.status_code == 200

    async def test_returns_401_when_unauthenticated(
        self, client: AsyncClient
    ) -> None:
        """A client that never logged in receives HTTP 401."""
        await client.post("/api/setup", json=_SETUP_PAYLOAD)
        resp = await client.get("/api/dashboard/bans/by-country")
        assert resp.status_code == 401

    async def test_response_shape(self, dashboard_client: AsyncClient) -> None:
        """The body exposes countries, country_names, bans and total."""
        stub = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(self._BY_COUNTRY, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans/by-country")

        payload = resp.json()
        for field in ("countries", "country_names", "bans", "total"):
            assert field in payload
        assert payload["total"] == 2
        assert payload["countries"]["DE"] == 1
        assert payload["countries"]["US"] == 1
        assert payload["country_names"]["DE"] == "Germany"

    async def test_accepts_time_range_param(
        self, dashboard_client: AsyncClient
    ) -> None:
        """An explicit ``range`` query value reaches ``ban_service``."""
        stub = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(self._BY_COUNTRY, new=stub):
            await dashboard_client.get("/api/dashboard/bans/by-country?range=7d")

        # The range preset is the second positional argument of the call.
        forwarded_range = stub.call_args.args[1]
        assert forwarded_range == "7d"

    async def test_empty_window_returns_empty_response(
        self, dashboard_client: AsyncClient
    ) -> None:
        """An empty service result yields no countries and no bans."""
        from app.models.ban import BansByCountryResponse

        nothing = BansByCountryResponse(
            countries={},
            country_names={},
            bans=[],
            total=0,
        )
        stub = AsyncMock(return_value=nothing)
        with patch(self._BY_COUNTRY, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans/by-country")

        payload = resp.json()
        assert payload["total"] == 0
        assert payload["countries"] == {}
        assert payload["bans"] == []
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Origin field tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestDashboardBansOriginField:
    """Check that serialised bans expose the ``origin`` field."""

    # Dotted paths of the two service calls stubbed out by these tests.
    _LIST_BANS = "app.routers.dashboard.ban_service.list_bans"
    _BY_COUNTRY = "app.routers.dashboard.ban_service.bans_by_country"

    async def test_origin_present_in_ban_list_items(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Items from ``/api/dashboard/bans`` include a valid ``origin``."""
        stub = AsyncMock(return_value=_make_ban_list_response(1))
        with patch(self._LIST_BANS, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans")

        first = resp.json()["items"][0]
        assert "origin" in first
        assert first["origin"] in ("blocklist", "selfblock")

    async def test_selfblock_origin_serialised_correctly(
        self, dashboard_client: AsyncClient
    ) -> None:
        """The stubbed ``sshd`` ban is serialised with origin ``"selfblock"``."""
        stub = AsyncMock(return_value=_make_ban_list_response(1))
        with patch(self._LIST_BANS, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans")

        first = resp.json()["items"][0]
        assert first["jail"] == "sshd"
        assert first["origin"] == "selfblock"

    async def test_origin_present_in_bans_by_country(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Bans from ``/api/dashboard/bans/by-country`` include ``origin``."""
        stub = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(self._BY_COUNTRY, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans/by-country")

        entries = resp.json()["bans"]
        assert all("origin" in entry for entry in entries)
        seen_origins = {entry["origin"] for entry in entries}
        assert seen_origins == {"blocklist", "selfblock"}

    async def test_blocklist_origin_serialised_correctly(
        self, dashboard_client: AsyncClient
    ) -> None:
        """The stubbed ``blocklist-import`` ban is serialised as ``"blocklist"``."""
        stub = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(self._BY_COUNTRY, new=stub):
            resp = await dashboard_client.get("/api/dashboard/bans/by-country")

        entries = resp.json()["bans"]
        imported = next(
            entry for entry in entries if entry["jail"] == "blocklist-import"
        )
        assert imported["origin"] == "blocklist"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Origin filter query parameter tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestOriginFilterParam:
    """Check that the ``origin`` query parameter reaches the service layer."""

    # Dotted paths of the two service calls stubbed out by these tests.
    _LIST_BANS = "app.routers.dashboard.ban_service.list_bans"
    _BY_COUNTRY = "app.routers.dashboard.ban_service.bans_by_country"

    async def test_bans_origin_blocklist_forwarded_to_service(
        self, dashboard_client: AsyncClient
    ) -> None:
        """``?origin=blocklist`` arrives at ``ban_service.list_bans``."""
        stub = AsyncMock(return_value=_make_ban_list_response())
        with patch(self._LIST_BANS, new=stub):
            await dashboard_client.get("/api/dashboard/bans?origin=blocklist")

        forwarded = stub.call_args.kwargs
        assert forwarded.get("origin") == "blocklist"

    async def test_bans_origin_selfblock_forwarded_to_service(
        self, dashboard_client: AsyncClient
    ) -> None:
        """``?origin=selfblock`` arrives at ``ban_service.list_bans``."""
        stub = AsyncMock(return_value=_make_ban_list_response())
        with patch(self._LIST_BANS, new=stub):
            await dashboard_client.get("/api/dashboard/bans?origin=selfblock")

        forwarded = stub.call_args.kwargs
        assert forwarded.get("origin") == "selfblock"

    async def test_bans_no_origin_param_defaults_to_none(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Without ``origin`` the service sees ``None`` (no filtering)."""
        stub = AsyncMock(return_value=_make_ban_list_response())
        with patch(self._LIST_BANS, new=stub):
            await dashboard_client.get("/api/dashboard/bans")

        # NOTE(review): ``.get`` also yields None when the router omits the
        # ``origin`` keyword altogether, so this does not prove it was passed.
        forwarded = stub.call_args.kwargs
        assert forwarded.get("origin") is None

    async def test_bans_invalid_origin_returns_422(
        self, dashboard_client: AsyncClient
    ) -> None:
        """An unknown ``origin`` value is rejected with HTTP 422."""
        resp = await dashboard_client.get("/api/dashboard/bans?origin=invalid")
        assert resp.status_code == 422

    async def test_by_country_origin_blocklist_forwarded(
        self, dashboard_client: AsyncClient
    ) -> None:
        """``?origin=blocklist`` arrives at ``ban_service.bans_by_country``."""
        stub = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(self._BY_COUNTRY, new=stub):
            await dashboard_client.get(
                "/api/dashboard/bans/by-country?origin=blocklist"
            )

        forwarded = stub.call_args.kwargs
        assert forwarded.get("origin") == "blocklist"

    async def test_by_country_no_origin_defaults_to_none(
        self, dashboard_client: AsyncClient
    ) -> None:
        """Without ``origin`` ``bans_by_country`` sees ``None``."""
        stub = AsyncMock(return_value=_make_bans_by_country_response())
        with patch(self._BY_COUNTRY, new=stub):
            await dashboard_client.get("/api/dashboard/bans/by-country")

        # NOTE(review): ``.get`` also yields None when the router omits the
        # ``origin`` keyword altogether, so this does not prove it was passed.
        forwarded = stub.call_args.kwargs
        assert forwarded.get("origin") is None
|