fixed tests

This commit is contained in:
2026-05-15 20:41:05 +02:00
parent 96ce516ecf
commit 77df5d5d65
50 changed files with 1482 additions and 5089 deletions

View File

@@ -10,13 +10,17 @@ import pytest
from httpx import ASGITransport, AsyncClient
import app
from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.ban import (
DashboardBanItem,
DashboardBanListResponse,
from app.models.ban_domain import (
DomainBansByCountry,
DomainBansByJail,
DomainBanTrend,
DomainBanTrendBucket,
DomainDashboardBanItem,
DomainDashboardBanList,
DomainJailBanCount,
)
from app.models.server import ServerStatus
@@ -25,7 +29,7 @@ from app.models.server import ServerStatus
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD = {
"master_password": "testpassword1",
"master_password": "Testpass1!",
"database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC",
@@ -40,13 +44,17 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
Unlike the shared ``client`` fixture this one also exposes access to
``app.state`` via the app instance so we can seed the status cache.
"""
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings(
database_path=str(tmp_path / "dashboard_test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
session_secret="test-dashboard-secret",
fail2ban_config_dir=str(config_dir),
session_secret="test-dashboard-secret-that-is-long-enough",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
session_cookie_secure=False,
)
app = create_app(settings=settings)
@@ -66,8 +74,13 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
# Provide a stub HTTP session so ban/access endpoints can access app.state.http_session.
app.state.http_session = MagicMock()
# Initialize GeoCache (normally done in lifespan handler)
from app.services.geo_cache import GeoCache
app.state.geo_cache = GeoCache()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
# Complete setup so the middleware doesn't redirect.
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
assert resp.status_code == 201
@@ -87,13 +100,17 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
@pytest.fixture
async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Like ``dashboard_client`` but with an offline server status."""
config_dir = tmp_path / "fail2ban"
config_dir.mkdir()
settings = Settings(
database_path=str(tmp_path / "dashboard_offline_test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
session_secret="test-dashboard-offline-secret",
fail2ban_config_dir=str(config_dir),
session_secret="test-dashboard-offline-secret-long-enough",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
session_cookie_secure=False,
)
app = create_app(settings=settings)
@@ -105,8 +122,13 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
app.state.server_status = ServerStatus(online=False)
app.state.http_session = MagicMock()
# Initialize GeoCache (normally done in lifespan handler)
from app.services.geo_cache import GeoCache
app.state.geo_cache = GeoCache()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
assert resp.status_code == 201
@@ -129,25 +151,19 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
class TestDashboardStatus:
"""GET /api/dashboard/status."""
async def test_returns_200_when_authenticated(
self, dashboard_client: AsyncClient
) -> None:
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
"""Authenticated request returns HTTP 200."""
response = await dashboard_client.get("/api/v1/dashboard/status")
assert response.status_code == 200
async def test_returns_401_when_unauthenticated(
self, client: AsyncClient
) -> None:
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
"""Unauthenticated request returns HTTP 401."""
# Complete setup so the middleware allows the request through.
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/status")
assert response.status_code == 401
async def test_response_shape_when_online(
self, dashboard_client: AsyncClient
) -> None:
async def test_response_shape_when_online(self, dashboard_client: AsyncClient) -> None:
"""Response contains the expected ``status`` object shape."""
response = await dashboard_client.get("/api/v1/dashboard/status")
body = response.json()
@@ -161,9 +177,7 @@ class TestDashboardStatus:
assert "total_bans" in status
assert "total_failures" in status
async def test_cached_values_returned_when_online(
self, dashboard_client: AsyncClient
) -> None:
async def test_cached_values_returned_when_online(self, dashboard_client: AsyncClient) -> None:
"""Endpoint returns the exact values from ``app.state.server_status``."""
response = await dashboard_client.get("/api/v1/dashboard/status")
body = response.json()
@@ -175,9 +189,7 @@ class TestDashboardStatus:
assert status["total_bans"] == 10
assert status["total_failures"] == 5
async def test_offline_status_returned_correctly(
self, offline_dashboard_client: AsyncClient
) -> None:
async def test_offline_status_returned_correctly(self, offline_dashboard_client: AsyncClient) -> None:
"""Endpoint returns online=False when the cache holds an offline snapshot."""
response = await offline_dashboard_client.get("/api/v1/dashboard/status")
assert response.status_code == 200
@@ -190,9 +202,7 @@ class TestDashboardStatus:
assert status["total_bans"] == 0
assert status["total_failures"] == 0
async def test_returns_offline_when_state_not_initialised(
self, client: AsyncClient
) -> None:
async def test_returns_offline_when_state_not_initialised(self, client: AsyncClient) -> None:
"""Endpoint returns online=False as a safe default if the cache is absent."""
# Setup + login so the endpoint is reachable.
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
@@ -200,7 +210,9 @@ class TestDashboardStatus:
"/api/v1/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]},
)
# server_status is not set on app.state in the shared `client` fixture.
# Clear server_status to simulate uninitialized state.
client._transport.app.state.server_status = None # type: ignore[attr-defined]
client._transport.app.state.server_status = None # type: ignore[attr-defined]
response = await client.get("/api/v1/dashboard/status")
assert response.status_code == 200
status = response.json()["status"]
@@ -212,10 +224,10 @@ class TestDashboardStatus:
# ---------------------------------------------------------------------------
def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
"""Build a mock DashboardBanListResponse with *n* items."""
def _make_ban_list_response(n: int = 2) -> DomainDashboardBanList:
"""Build a mock DomainDashboardBanList with *n* items."""
items = [
DashboardBanItem(
DomainDashboardBanItem(
ip=f"1.2.3.{i}",
jail="sshd",
banned_at="2026-03-01T10:00:00+00:00",
@@ -229,15 +241,18 @@ def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
)
for i in range(n)
]
return DashboardBanListResponse(items=items, total=n, page=1, page_size=100)
return DomainDashboardBanList(
items=items,
total=n,
page=1,
page_size=100,
)
class TestDashboardBans:
"""GET /api/dashboard/bans."""
async def test_returns_200_when_authenticated(
self, dashboard_client: AsyncClient
) -> None:
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
"""Authenticated request returns HTTP 200."""
with patch(
"app.routers.dashboard.ban_service.list_bans",
@@ -246,17 +261,13 @@ class TestDashboardBans:
response = await dashboard_client.get("/api/v1/dashboard/bans")
assert response.status_code == 200
async def test_returns_401_when_unauthenticated(
self, client: AsyncClient
) -> None:
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
"""Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/bans")
assert response.status_code == 401
async def test_response_contains_items_and_total(
self, dashboard_client: AsyncClient
) -> None:
async def test_response_contains_items_and_total(self, dashboard_client: AsyncClient) -> None:
"""Response body contains ``items`` list and ``total`` count."""
with patch(
"app.routers.dashboard.ban_service.list_bans",
@@ -266,8 +277,8 @@ class TestDashboardBans:
body = response.json()
assert "items" in body
assert "total" in body
assert body["total"] == 3
assert "pagination" in body
assert body["pagination"]["total"] == 3
assert len(body["items"]) == 3
async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None:
@@ -279,9 +290,7 @@ class TestDashboardBans:
called_range = mock_list.call_args[0][1]
assert called_range == "24h"
async def test_accepts_time_range_param(
self, dashboard_client: AsyncClient
) -> None:
async def test_accepts_time_range_param(self, dashboard_client: AsyncClient) -> None:
"""The ``range`` query parameter is forwarded to ban_service."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -290,9 +299,7 @@ class TestDashboardBans:
called_range = mock_list.call_args[0][1]
assert called_range == "7d"
async def test_accepts_source_param(
self, dashboard_client: AsyncClient
) -> None:
async def test_accepts_source_param(self, dashboard_client: AsyncClient) -> None:
"""The ``source`` query parameter is forwarded to ban_service."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -301,11 +308,14 @@ class TestDashboardBans:
called_source = mock_list.call_args[1]["source"]
assert called_source == "archive"
async def test_empty_ban_list_returns_zero_total(
self, dashboard_client: AsyncClient
) -> None:
async def test_empty_ban_list_returns_zero_total(self, dashboard_client: AsyncClient) -> None:
"""Returns ``total=0`` and empty ``items`` when no bans are in range."""
empty = DashboardBanListResponse(items=[], total=0, page=1, page_size=100)
empty = DomainDashboardBanList(
items=[],
total=0,
page=1,
page_size=100,
)
with patch(
"app.routers.dashboard.ban_service.list_bans",
new=AsyncMock(return_value=empty),
@@ -313,7 +323,7 @@ class TestDashboardBans:
response = await dashboard_client.get("/api/v1/dashboard/bans")
body = response.json()
assert body["total"] == 0
assert body["pagination"]["total"] == 0
assert body["items"] == []
async def test_item_shape_is_correct(self, dashboard_client: AsyncClient) -> None:
@@ -336,12 +346,10 @@ class TestDashboardBans:
# ---------------------------------------------------------------------------
def _make_bans_by_country_response() -> object:
"""Build a stub BansByCountryResponse."""
from app.models.ban import BansByCountryResponse
def _make_bans_by_country_response() -> DomainBansByCountry:
"""Build a stub DomainBansByCountry."""
items = [
DashboardBanItem(
DomainDashboardBanItem(
ip="1.2.3.4",
jail="sshd",
banned_at="2026-03-01T10:00:00+00:00",
@@ -353,7 +361,7 @@ def _make_bans_by_country_response() -> object:
ban_count=1,
origin="selfblock",
),
DashboardBanItem(
DomainDashboardBanItem(
ip="5.6.7.8",
jail="blocklist-import",
banned_at="2026-03-01T10:05:00+00:00",
@@ -366,10 +374,10 @@ def _make_bans_by_country_response() -> object:
origin="blocklist",
),
]
return BansByCountryResponse(
return DomainBansByCountry(
countries={"DE": 1, "US": 1},
country_names={"DE": "Germany", "US": "United States"},
bans=items,
items=items,
total=2,
)
@@ -378,9 +386,7 @@ def _make_bans_by_country_response() -> object:
class TestBansByCountry:
"""GET /api/dashboard/bans/by-country."""
async def test_returns_200_when_authenticated(
self, dashboard_client: AsyncClient
) -> None:
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
"""Authenticated request returns HTTP 200."""
with patch(
"app.routers.dashboard.ban_service.bans_by_country",
@@ -389,9 +395,7 @@ class TestBansByCountry:
response = await dashboard_client.get("/api/v1/dashboard/bans/by-country")
assert response.status_code == 200
async def test_returns_401_when_unauthenticated(
self, client: AsyncClient
) -> None:
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
"""Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/bans/by-country")
@@ -415,38 +419,26 @@ class TestBansByCountry:
assert body["countries"]["US"] == 1
assert body["country_names"]["DE"] == "Germany"
async def test_accepts_time_range_param(
self, dashboard_client: AsyncClient
) -> None:
async def test_accepts_time_range_param(self, dashboard_client: AsyncClient) -> None:
"""The range query parameter is forwarded to ban_service."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch(
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
):
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
await dashboard_client.get("/api/v1/dashboard/bans/by-country?range=7d")
called_range = mock_fn.call_args[0][1]
assert called_range == "7d"
async def test_invalid_source_returns_422(
self, dashboard_client: AsyncClient
) -> None:
"""An invalid source value returns HTTP 422."""
response = await dashboard_client.get(
"/api/v1/dashboard/bans/by-country?source=invalid"
)
assert response.status_code == 422
async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
"""An invalid source value returns HTTP 400."""
response = await dashboard_client.get("/api/v1/dashboard/bans/by-country?source=invalid")
assert response.status_code == 400
async def test_empty_window_returns_empty_response(
self, dashboard_client: AsyncClient
) -> None:
async def test_empty_window_returns_empty_response(self, dashboard_client: AsyncClient) -> None:
"""Empty time range returns empty countries dict and bans list."""
from app.models.ban import BansByCountryResponse
empty = BansByCountryResponse(
empty = DomainBansByCountry(
countries={},
country_names={},
bans=[],
items=[],
total=0,
)
with patch(
@@ -469,9 +461,7 @@ class TestBansByCountry:
class TestDashboardBansOriginField:
"""Verify that the ``origin`` field is present in API responses."""
async def test_origin_present_in_ban_list_items(
self, dashboard_client: AsyncClient
) -> None:
async def test_origin_present_in_ban_list_items(self, dashboard_client: AsyncClient) -> None:
"""Each item in ``/api/dashboard/bans`` carries an ``origin`` field."""
with patch(
"app.routers.dashboard.ban_service.list_bans",
@@ -483,9 +473,7 @@ class TestDashboardBansOriginField:
assert "origin" in item
assert item["origin"] in ("blocklist", "selfblock")
async def test_selfblock_origin_serialised_correctly(
self, dashboard_client: AsyncClient
) -> None:
async def test_selfblock_origin_serialised_correctly(self, dashboard_client: AsyncClient) -> None:
"""A ban from a non-blocklist jail serialises as ``"selfblock"``."""
with patch(
"app.routers.dashboard.ban_service.list_bans",
@@ -497,9 +485,7 @@ class TestDashboardBansOriginField:
assert item["jail"] == "sshd"
assert item["origin"] == "selfblock"
async def test_origin_present_in_bans_by_country(
self, dashboard_client: AsyncClient
) -> None:
async def test_origin_present_in_bans_by_country(self, dashboard_client: AsyncClient) -> None:
"""Each ban in ``/api/dashboard/bans/by-country`` carries an ``origin``."""
with patch(
"app.routers.dashboard.ban_service.bans_by_country",
@@ -512,9 +498,7 @@ class TestDashboardBansOriginField:
origins = {ban["origin"] for ban in bans}
assert origins == {"blocklist", "selfblock"}
async def test_bans_by_country_source_param_forwarded(
self, dashboard_client: AsyncClient
) -> None:
async def test_bans_by_country_source_param_forwarded(self, dashboard_client: AsyncClient) -> None:
"""The ``source`` query parameter is forwarded to bans_by_country."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
@@ -522,22 +506,16 @@ class TestDashboardBansOriginField:
assert mock_fn.call_args[1]["source"] == "archive"
async def test_bans_by_country_country_code_forwarded(
self, dashboard_client: AsyncClient
) -> None:
async def test_bans_by_country_country_code_forwarded(self, dashboard_client: AsyncClient) -> None:
"""The ``country_code`` query parameter is forwarded to bans_by_country."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
await dashboard_client.get(
"/api/v1/dashboard/bans/by-country?country_code=DE"
)
await dashboard_client.get("/api/v1/dashboard/bans/by-country?country_code=DE")
_, kwargs = mock_fn.call_args
assert kwargs.get("country_code") == "DE"
async def test_blocklist_origin_serialised_correctly(
self, dashboard_client: AsyncClient
) -> None:
async def test_blocklist_origin_serialised_correctly(self, dashboard_client: AsyncClient) -> None:
"""A ban from the ``blocklist-import`` jail serialises as ``"blocklist"``."""
with patch(
"app.routers.dashboard.ban_service.bans_by_country",
@@ -558,9 +536,7 @@ class TestDashboardBansOriginField:
class TestOriginFilterParam:
"""Verify that the ``origin`` query parameter is forwarded to the service."""
async def test_bans_origin_blocklist_forwarded_to_service(
self, dashboard_client: AsyncClient
) -> None:
async def test_bans_origin_blocklist_forwarded_to_service(self, dashboard_client: AsyncClient) -> None:
"""``?origin=blocklist`` is passed to ``ban_service.list_bans``."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -569,9 +545,7 @@ class TestOriginFilterParam:
_, kwargs = mock_list.call_args
assert kwargs.get("origin") == "blocklist"
async def test_bans_origin_selfblock_forwarded_to_service(
self, dashboard_client: AsyncClient
) -> None:
async def test_bans_origin_selfblock_forwarded_to_service(self, dashboard_client: AsyncClient) -> None:
"""``?origin=selfblock`` is passed to ``ban_service.list_bans``."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -580,9 +554,7 @@ class TestOriginFilterParam:
_, kwargs = mock_list.call_args
assert kwargs.get("origin") == "selfblock"
async def test_bans_no_origin_param_defaults_to_none(
self, dashboard_client: AsyncClient
) -> None:
async def test_bans_no_origin_param_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
"""Omitting ``origin`` passes ``None`` to the service (no filtering)."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
@@ -591,36 +563,24 @@ class TestOriginFilterParam:
_, kwargs = mock_list.call_args
assert kwargs.get("origin") is None
async def test_bans_invalid_origin_returns_422(
self, dashboard_client: AsyncClient
) -> None:
"""An invalid ``origin`` value returns HTTP 422 Unprocessable Entity."""
async def test_bans_invalid_origin_returns_400(self, dashboard_client: AsyncClient) -> None:
"""An invalid ``origin`` value returns HTTP 400."""
response = await dashboard_client.get("/api/v1/dashboard/bans?origin=invalid")
assert response.status_code == 422
assert response.status_code == 400
async def test_by_country_origin_blocklist_forwarded(
self, dashboard_client: AsyncClient
) -> None:
async def test_by_country_origin_blocklist_forwarded(self, dashboard_client: AsyncClient) -> None:
"""``?origin=blocklist`` is passed to ``ban_service.bans_by_country``."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch(
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
):
await dashboard_client.get(
"/api/v1/dashboard/bans/by-country?origin=blocklist"
)
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
await dashboard_client.get("/api/v1/dashboard/bans/by-country?origin=blocklist")
_, kwargs = mock_fn.call_args
assert kwargs.get("origin") == "blocklist"
async def test_by_country_no_origin_defaults_to_none(
self, dashboard_client: AsyncClient
) -> None:
async def test_by_country_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
"""Omitting ``origin`` passes ``None`` to ``bans_by_country``."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch(
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
):
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
await dashboard_client.get("/api/v1/dashboard/bans/by-country")
_, kwargs = mock_fn.call_args
@@ -632,24 +592,17 @@ class TestOriginFilterParam:
# ---------------------------------------------------------------------------
def _make_ban_trend_response(n_buckets: int = 24) -> object:
"""Build a stub :class:`~app.models.ban.BanTrendResponse`."""
from app.models.ban import BanTrendBucket, BanTrendResponse
buckets = [
BanTrendBucket(timestamp=f"2026-03-01T{i:02d}:00:00+00:00", count=i)
for i in range(n_buckets)
]
return BanTrendResponse(buckets=buckets, bucket_size="1h")
def _make_ban_trend_response(n_buckets: int = 24) -> DomainBanTrend:
"""Build a stub :class:`~app.models.ban_domain.DomainBanTrend`."""
buckets = [DomainBanTrendBucket(timestamp=f"2026-03-01T{i:02d}:00:00+00:00", count=i) for i in range(n_buckets)]
return DomainBanTrend(buckets=buckets, bucket_size="1h")
@pytest.mark.anyio
class TestBanTrend:
"""GET /api/dashboard/bans/trend."""
async def test_returns_200_when_authenticated(
self, dashboard_client: AsyncClient
) -> None:
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
"""Authenticated request returns HTTP 200."""
with patch(
"app.routers.dashboard.ban_service.ban_trend",
@@ -658,9 +611,7 @@ class TestBanTrend:
response = await dashboard_client.get("/api/v1/dashboard/bans/trend")
assert response.status_code == 200
async def test_returns_401_when_unauthenticated(
self, client: AsyncClient
) -> None:
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
"""Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/bans/trend")
@@ -680,9 +631,7 @@ class TestBanTrend:
assert len(body["buckets"]) == 24
assert body["bucket_size"] == "1h"
async def test_each_bucket_has_timestamp_and_count(
self, dashboard_client: AsyncClient
) -> None:
async def test_each_bucket_has_timestamp_and_count(self, dashboard_client: AsyncClient) -> None:
"""Every element of ``buckets`` has ``timestamp`` and ``count``."""
with patch(
"app.routers.dashboard.ban_service.ban_trend",
@@ -717,16 +666,12 @@ class TestBanTrend:
"""``?origin=blocklist`` is passed as a keyword arg to the service."""
mock_fn = AsyncMock(return_value=_make_ban_trend_response())
with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
await dashboard_client.get(
"/api/v1/dashboard/bans/trend?origin=blocklist"
)
await dashboard_client.get("/api/v1/dashboard/bans/trend?origin=blocklist")
_, kwargs = mock_fn.call_args
assert kwargs.get("origin") == "blocklist"
async def test_no_origin_defaults_to_none(
self, dashboard_client: AsyncClient
) -> None:
async def test_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
"""Omitting ``origin`` passes ``None`` to the service."""
mock_fn = AsyncMock(return_value=_make_ban_trend_response())
with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
@@ -735,29 +680,19 @@ class TestBanTrend:
_, kwargs = mock_fn.call_args
assert kwargs.get("origin") is None
async def test_invalid_range_returns_422(
self, dashboard_client: AsyncClient
) -> None:
"""An invalid ``range`` value returns HTTP 422."""
response = await dashboard_client.get(
"/api/v1/dashboard/bans/trend?range=invalid"
)
assert response.status_code == 422
async def test_invalid_range_returns_400(self, dashboard_client: AsyncClient) -> None:
"""An invalid ``range`` value returns HTTP 400."""
response = await dashboard_client.get("/api/v1/dashboard/bans/trend?range=invalid")
assert response.status_code == 400
async def test_invalid_source_returns_422(
self, dashboard_client: AsyncClient
) -> None:
"""An invalid source value returns HTTP 422."""
response = await dashboard_client.get(
"/api/v1/dashboard/bans/trend?source=invalid"
)
assert response.status_code == 422
async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
"""An invalid source value returns HTTP 400."""
response = await dashboard_client.get("/api/v1/dashboard/bans/trend?source=invalid")
assert response.status_code == 400
async def test_empty_buckets_response(self, dashboard_client: AsyncClient) -> None:
"""Empty bucket list is serialised correctly."""
from app.models.ban import BanTrendResponse
empty = BanTrendResponse(buckets=[], bucket_size="1h")
empty = DomainBanTrend(buckets=[], bucket_size="1h")
with patch(
"app.routers.dashboard.ban_service.ban_trend",
new=AsyncMock(return_value=empty),
@@ -774,14 +709,12 @@ class TestBanTrend:
# ---------------------------------------------------------------------------
def _make_bans_by_jail_response() -> object:
"""Build a stub :class:`~app.models.ban.BansByJailResponse`."""
from app.models.ban import BansByJailResponse, JailBanCount
return BansByJailResponse(
def _make_bans_by_jail_response() -> DomainBansByJail:
"""Build a stub :class:`~app.models.ban_domain.DomainBansByJail`."""
return DomainBansByJail(
jails=[
JailBanCount(jail="sshd", count=10),
JailBanCount(jail="nginx", count=5),
DomainJailBanCount(jail="sshd", count=10),
DomainJailBanCount(jail="nginx", count=5),
],
total=15,
)
@@ -791,9 +724,7 @@ def _make_bans_by_jail_response() -> object:
class TestBansByJail:
"""GET /api/dashboard/bans/by-jail."""
async def test_returns_200_when_authenticated(
self, dashboard_client: AsyncClient
) -> None:
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
"""Authenticated request returns HTTP 200."""
with patch(
"app.routers.dashboard.ban_service.bans_by_jail",
@@ -802,9 +733,7 @@ class TestBansByJail:
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail")
assert response.status_code == 200
async def test_returns_401_when_unauthenticated(
self, client: AsyncClient
) -> None:
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
"""Unauthenticated request returns HTTP 401."""
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/v1/dashboard/bans/by-jail")
@@ -823,9 +752,7 @@ class TestBansByJail:
assert "total" in body
assert isinstance(body["total"], int)
async def test_each_jail_has_name_and_count(
self, dashboard_client: AsyncClient
) -> None:
async def test_each_jail_has_name_and_count(self, dashboard_client: AsyncClient) -> None:
"""Every element of ``jails`` has ``jail`` (string) and ``count`` (int)."""
with patch(
"app.routers.dashboard.ban_service.bans_by_jail",
@@ -861,16 +788,12 @@ class TestBansByJail:
"""``?origin=blocklist`` is passed as a keyword arg to the service."""
mock_fn = AsyncMock(return_value=_make_bans_by_jail_response())
with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn):
await dashboard_client.get(
"/api/v1/dashboard/bans/by-jail?origin=blocklist"
)
await dashboard_client.get("/api/v1/dashboard/bans/by-jail?origin=blocklist")
_, kwargs = mock_fn.call_args
assert kwargs.get("origin") == "blocklist"
async def test_no_origin_defaults_to_none(
self, dashboard_client: AsyncClient
) -> None:
async def test_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
"""Omitting ``origin`` passes ``None`` to the service."""
mock_fn = AsyncMock(return_value=_make_bans_by_jail_response())
with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn):
@@ -879,23 +802,15 @@ class TestBansByJail:
_, kwargs = mock_fn.call_args
assert kwargs.get("origin") is None
async def test_invalid_range_returns_422(
self, dashboard_client: AsyncClient
) -> None:
"""An invalid ``range`` value returns HTTP 422."""
response = await dashboard_client.get(
"/api/v1/dashboard/bans/by-jail?range=invalid"
)
assert response.status_code == 422
async def test_invalid_range_returns_400(self, dashboard_client: AsyncClient) -> None:
"""An invalid ``range`` value returns HTTP 400."""
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail?range=invalid")
assert response.status_code == 400
async def test_invalid_source_returns_422(
self, dashboard_client: AsyncClient
) -> None:
"""An invalid source value returns HTTP 422."""
response = await dashboard_client.get(
"/api/v1/dashboard/bans/by-jail?source=invalid"
)
assert response.status_code == 422
async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
"""An invalid source value returns HTTP 400."""
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail?source=invalid")
assert response.status_code == 400
async def test_empty_jails_response(self, dashboard_client: AsyncClient) -> None:
"""Empty jails list is serialised correctly."""
@@ -911,4 +826,3 @@ class TestBansByJail:
body = response.json()
assert body["jails"] == []
assert body["total"] == 0