fixed tests
This commit is contained in:
@@ -10,13 +10,17 @@ import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
import app
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.ban import (
|
||||
DashboardBanItem,
|
||||
DashboardBanListResponse,
|
||||
from app.models.ban_domain import (
|
||||
DomainBansByCountry,
|
||||
DomainBansByJail,
|
||||
DomainBanTrend,
|
||||
DomainBanTrendBucket,
|
||||
DomainDashboardBanItem,
|
||||
DomainDashboardBanList,
|
||||
DomainJailBanCount,
|
||||
)
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
@@ -25,7 +29,7 @@ from app.models.server import ServerStatus
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "testpassword1",
|
||||
"master_password": "Testpass1!",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
@@ -40,13 +44,17 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
Unlike the shared ``client`` fixture this one also exposes access to
|
||||
``app.state`` via the app instance so we can seed the status cache.
|
||||
"""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "dashboard_test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-dashboard-secret",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-dashboard-secret-that-is-long-enough",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
@@ -66,8 +74,13 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
# Provide a stub HTTP session so ban/access endpoints can access app.state.http_session.
|
||||
app.state.http_session = MagicMock()
|
||||
|
||||
# Initialize GeoCache (normally done in lifespan handler)
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
# Complete setup so the middleware doesn't redirect.
|
||||
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
assert resp.status_code == 201
|
||||
@@ -87,13 +100,17 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
@pytest.fixture
|
||||
async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Like ``dashboard_client`` but with an offline server status."""
|
||||
config_dir = tmp_path / "fail2ban"
|
||||
config_dir.mkdir()
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "dashboard_offline_test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-dashboard-offline-secret",
|
||||
fail2ban_config_dir=str(config_dir),
|
||||
session_secret="test-dashboard-offline-secret-long-enough",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
@@ -105,8 +122,13 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
|
||||
app.state.server_status = ServerStatus(online=False)
|
||||
app.state.http_session = MagicMock()
|
||||
|
||||
# Initialize GeoCache (normally done in lifespan handler)
|
||||
from app.services.geo_cache import GeoCache
|
||||
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
resp = await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
assert resp.status_code == 201
|
||||
|
||||
@@ -129,25 +151,19 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
|
||||
class TestDashboardStatus:
|
||||
"""GET /api/dashboard/status."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
# Complete setup so the middleware allows the request through.
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/status")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_response_shape_when_online(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_response_shape_when_online(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Response contains the expected ``status`` object shape."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/status")
|
||||
body = response.json()
|
||||
@@ -161,9 +177,7 @@ class TestDashboardStatus:
|
||||
assert "total_bans" in status
|
||||
assert "total_failures" in status
|
||||
|
||||
async def test_cached_values_returned_when_online(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_cached_values_returned_when_online(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Endpoint returns the exact values from ``app.state.server_status``."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/status")
|
||||
body = response.json()
|
||||
@@ -175,9 +189,7 @@ class TestDashboardStatus:
|
||||
assert status["total_bans"] == 10
|
||||
assert status["total_failures"] == 5
|
||||
|
||||
async def test_offline_status_returned_correctly(
|
||||
self, offline_dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_offline_status_returned_correctly(self, offline_dashboard_client: AsyncClient) -> None:
|
||||
"""Endpoint returns online=False when the cache holds an offline snapshot."""
|
||||
response = await offline_dashboard_client.get("/api/v1/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
@@ -190,9 +202,7 @@ class TestDashboardStatus:
|
||||
assert status["total_bans"] == 0
|
||||
assert status["total_failures"] == 0
|
||||
|
||||
async def test_returns_offline_when_state_not_initialised(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_offline_when_state_not_initialised(self, client: AsyncClient) -> None:
|
||||
"""Endpoint returns online=False as a safe default if the cache is absent."""
|
||||
# Setup + login so the endpoint is reachable.
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
@@ -200,7 +210,9 @@ class TestDashboardStatus:
|
||||
"/api/v1/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
)
|
||||
# server_status is not set on app.state in the shared `client` fixture.
|
||||
# Clear server_status to simulate uninitialized state.
|
||||
client._transport.app.state.server_status = None # type: ignore[attr-defined]
|
||||
client._transport.app.state.server_status = None # type: ignore[attr-defined]
|
||||
response = await client.get("/api/v1/dashboard/status")
|
||||
assert response.status_code == 200
|
||||
status = response.json()["status"]
|
||||
@@ -212,10 +224,10 @@ class TestDashboardStatus:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
|
||||
"""Build a mock DashboardBanListResponse with *n* items."""
|
||||
def _make_ban_list_response(n: int = 2) -> DomainDashboardBanList:
|
||||
"""Build a mock DomainDashboardBanList with *n* items."""
|
||||
items = [
|
||||
DashboardBanItem(
|
||||
DomainDashboardBanItem(
|
||||
ip=f"1.2.3.{i}",
|
||||
jail="sshd",
|
||||
banned_at="2026-03-01T10:00:00+00:00",
|
||||
@@ -229,15 +241,18 @@ def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
|
||||
)
|
||||
for i in range(n)
|
||||
]
|
||||
return DashboardBanListResponse(items=items, total=n, page=1, page_size=100)
|
||||
return DomainDashboardBanList(
|
||||
items=items,
|
||||
total=n,
|
||||
page=1,
|
||||
page_size=100,
|
||||
)
|
||||
|
||||
|
||||
class TestDashboardBans:
|
||||
"""GET /api/dashboard/bans."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
@@ -246,17 +261,13 @@ class TestDashboardBans:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/bans")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_response_contains_items_and_total(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_response_contains_items_and_total(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Response body contains ``items`` list and ``total`` count."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
@@ -266,8 +277,8 @@ class TestDashboardBans:
|
||||
|
||||
body = response.json()
|
||||
assert "items" in body
|
||||
assert "total" in body
|
||||
assert body["total"] == 3
|
||||
assert "pagination" in body
|
||||
assert body["pagination"]["total"] == 3
|
||||
assert len(body["items"]) == 3
|
||||
|
||||
async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None:
|
||||
@@ -279,9 +290,7 @@ class TestDashboardBans:
|
||||
called_range = mock_list.call_args[0][1]
|
||||
assert called_range == "24h"
|
||||
|
||||
async def test_accepts_time_range_param(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_accepts_time_range_param(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The ``range`` query parameter is forwarded to ban_service."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -290,9 +299,7 @@ class TestDashboardBans:
|
||||
called_range = mock_list.call_args[0][1]
|
||||
assert called_range == "7d"
|
||||
|
||||
async def test_accepts_source_param(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_accepts_source_param(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The ``source`` query parameter is forwarded to ban_service."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -301,11 +308,14 @@ class TestDashboardBans:
|
||||
called_source = mock_list.call_args[1]["source"]
|
||||
assert called_source == "archive"
|
||||
|
||||
async def test_empty_ban_list_returns_zero_total(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_empty_ban_list_returns_zero_total(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Returns ``total=0`` and empty ``items`` when no bans are in range."""
|
||||
empty = DashboardBanListResponse(items=[], total=0, page=1, page_size=100)
|
||||
empty = DomainDashboardBanList(
|
||||
items=[],
|
||||
total=0,
|
||||
page=1,
|
||||
page_size=100,
|
||||
)
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=empty),
|
||||
@@ -313,7 +323,7 @@ class TestDashboardBans:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans")
|
||||
|
||||
body = response.json()
|
||||
assert body["total"] == 0
|
||||
assert body["pagination"]["total"] == 0
|
||||
assert body["items"] == []
|
||||
|
||||
async def test_item_shape_is_correct(self, dashboard_client: AsyncClient) -> None:
|
||||
@@ -336,12 +346,10 @@ class TestDashboardBans:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bans_by_country_response() -> object:
|
||||
"""Build a stub BansByCountryResponse."""
|
||||
from app.models.ban import BansByCountryResponse
|
||||
|
||||
def _make_bans_by_country_response() -> DomainBansByCountry:
|
||||
"""Build a stub DomainBansByCountry."""
|
||||
items = [
|
||||
DashboardBanItem(
|
||||
DomainDashboardBanItem(
|
||||
ip="1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at="2026-03-01T10:00:00+00:00",
|
||||
@@ -353,7 +361,7 @@ def _make_bans_by_country_response() -> object:
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
),
|
||||
DashboardBanItem(
|
||||
DomainDashboardBanItem(
|
||||
ip="5.6.7.8",
|
||||
jail="blocklist-import",
|
||||
banned_at="2026-03-01T10:05:00+00:00",
|
||||
@@ -366,10 +374,10 @@ def _make_bans_by_country_response() -> object:
|
||||
origin="blocklist",
|
||||
),
|
||||
]
|
||||
return BansByCountryResponse(
|
||||
return DomainBansByCountry(
|
||||
countries={"DE": 1, "US": 1},
|
||||
country_names={"DE": "Germany", "US": "United States"},
|
||||
bans=items,
|
||||
items=items,
|
||||
total=2,
|
||||
)
|
||||
|
||||
@@ -378,9 +386,7 @@ def _make_bans_by_country_response() -> object:
|
||||
class TestBansByCountry:
|
||||
"""GET /api/dashboard/bans/by-country."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
@@ -389,9 +395,7 @@ class TestBansByCountry:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-country")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/bans/by-country")
|
||||
@@ -415,38 +419,26 @@ class TestBansByCountry:
|
||||
assert body["countries"]["US"] == 1
|
||||
assert body["country_names"]["DE"] == "Germany"
|
||||
|
||||
async def test_accepts_time_range_param(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_accepts_time_range_param(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The range query parameter is forwarded to ban_service."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-country?range=7d")
|
||||
|
||||
called_range = mock_fn.call_args[0][1]
|
||||
assert called_range == "7d"
|
||||
|
||||
async def test_invalid_source_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid source value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-country?source=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid source value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-country?source=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_empty_window_returns_empty_response(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_empty_window_returns_empty_response(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Empty time range returns empty countries dict and bans list."""
|
||||
from app.models.ban import BansByCountryResponse
|
||||
|
||||
empty = BansByCountryResponse(
|
||||
empty = DomainBansByCountry(
|
||||
countries={},
|
||||
country_names={},
|
||||
bans=[],
|
||||
items=[],
|
||||
total=0,
|
||||
)
|
||||
with patch(
|
||||
@@ -469,9 +461,7 @@ class TestBansByCountry:
|
||||
class TestDashboardBansOriginField:
|
||||
"""Verify that the ``origin`` field is present in API responses."""
|
||||
|
||||
async def test_origin_present_in_ban_list_items(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_origin_present_in_ban_list_items(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Each item in ``/api/dashboard/bans`` carries an ``origin`` field."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
@@ -483,9 +473,7 @@ class TestDashboardBansOriginField:
|
||||
assert "origin" in item
|
||||
assert item["origin"] in ("blocklist", "selfblock")
|
||||
|
||||
async def test_selfblock_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_selfblock_origin_serialised_correctly(self, dashboard_client: AsyncClient) -> None:
|
||||
"""A ban from a non-blocklist jail serialises as ``"selfblock"``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
@@ -497,9 +485,7 @@ class TestDashboardBansOriginField:
|
||||
assert item["jail"] == "sshd"
|
||||
assert item["origin"] == "selfblock"
|
||||
|
||||
async def test_origin_present_in_bans_by_country(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_origin_present_in_bans_by_country(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Each ban in ``/api/dashboard/bans/by-country`` carries an ``origin``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
@@ -512,9 +498,7 @@ class TestDashboardBansOriginField:
|
||||
origins = {ban["origin"] for ban in bans}
|
||||
assert origins == {"blocklist", "selfblock"}
|
||||
|
||||
async def test_bans_by_country_source_param_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_by_country_source_param_forwarded(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The ``source`` query parameter is forwarded to bans_by_country."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
@@ -522,22 +506,16 @@ class TestDashboardBansOriginField:
|
||||
|
||||
assert mock_fn.call_args[1]["source"] == "archive"
|
||||
|
||||
async def test_bans_by_country_country_code_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_by_country_country_code_forwarded(self, dashboard_client: AsyncClient) -> None:
|
||||
"""The ``country_code`` query parameter is forwarded to bans_by_country."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-country?country_code=DE"
|
||||
)
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-country?country_code=DE")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("country_code") == "DE"
|
||||
|
||||
async def test_blocklist_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_blocklist_origin_serialised_correctly(self, dashboard_client: AsyncClient) -> None:
|
||||
"""A ban from the ``blocklist-import`` jail serialises as ``"blocklist"``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
@@ -558,9 +536,7 @@ class TestDashboardBansOriginField:
|
||||
class TestOriginFilterParam:
|
||||
"""Verify that the ``origin`` query parameter is forwarded to the service."""
|
||||
|
||||
async def test_bans_origin_blocklist_forwarded_to_service(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_origin_blocklist_forwarded_to_service(self, dashboard_client: AsyncClient) -> None:
|
||||
"""``?origin=blocklist`` is passed to ``ban_service.list_bans``."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -569,9 +545,7 @@ class TestOriginFilterParam:
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_bans_origin_selfblock_forwarded_to_service(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_origin_selfblock_forwarded_to_service(self, dashboard_client: AsyncClient) -> None:
|
||||
"""``?origin=selfblock`` is passed to ``ban_service.list_bans``."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -580,9 +554,7 @@ class TestOriginFilterParam:
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") == "selfblock"
|
||||
|
||||
async def test_bans_no_origin_param_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_bans_no_origin_param_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to the service (no filtering)."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
@@ -591,36 +563,24 @@ class TestOriginFilterParam:
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
async def test_bans_invalid_origin_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid ``origin`` value returns HTTP 422 Unprocessable Entity."""
|
||||
async def test_bans_invalid_origin_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid ``origin`` value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans?origin=invalid")
|
||||
assert response.status_code == 422
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_by_country_origin_blocklist_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_by_country_origin_blocklist_forwarded(self, dashboard_client: AsyncClient) -> None:
|
||||
"""``?origin=blocklist`` is passed to ``ban_service.bans_by_country``."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-country?origin=blocklist"
|
||||
)
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-country?origin=blocklist")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_by_country_no_origin_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_by_country_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to ``bans_by_country``."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-country")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
@@ -632,24 +592,17 @@ class TestOriginFilterParam:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ban_trend_response(n_buckets: int = 24) -> object:
|
||||
"""Build a stub :class:`~app.models.ban.BanTrendResponse`."""
|
||||
from app.models.ban import BanTrendBucket, BanTrendResponse
|
||||
|
||||
buckets = [
|
||||
BanTrendBucket(timestamp=f"2026-03-01T{i:02d}:00:00+00:00", count=i)
|
||||
for i in range(n_buckets)
|
||||
]
|
||||
return BanTrendResponse(buckets=buckets, bucket_size="1h")
|
||||
def _make_ban_trend_response(n_buckets: int = 24) -> DomainBanTrend:
|
||||
"""Build a stub :class:`~app.models.ban_domain.DomainBanTrend`."""
|
||||
buckets = [DomainBanTrendBucket(timestamp=f"2026-03-01T{i:02d}:00:00+00:00", count=i) for i in range(n_buckets)]
|
||||
return DomainBanTrend(buckets=buckets, bucket_size="1h")
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
class TestBanTrend:
|
||||
"""GET /api/dashboard/bans/trend."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.ban_trend",
|
||||
@@ -658,9 +611,7 @@ class TestBanTrend:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/trend")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/bans/trend")
|
||||
@@ -680,9 +631,7 @@ class TestBanTrend:
|
||||
assert len(body["buckets"]) == 24
|
||||
assert body["bucket_size"] == "1h"
|
||||
|
||||
async def test_each_bucket_has_timestamp_and_count(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_each_bucket_has_timestamp_and_count(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Every element of ``buckets`` has ``timestamp`` and ``count``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.ban_trend",
|
||||
@@ -717,16 +666,12 @@ class TestBanTrend:
|
||||
"""``?origin=blocklist`` is passed as a keyword arg to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_ban_trend_response())
|
||||
with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
|
||||
await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/trend?origin=blocklist"
|
||||
)
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/trend?origin=blocklist")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_no_origin_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_ban_trend_response())
|
||||
with patch("app.routers.dashboard.ban_service.ban_trend", new=mock_fn):
|
||||
@@ -735,29 +680,19 @@ class TestBanTrend:
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
async def test_invalid_range_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid ``range`` value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/trend?range=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_range_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid ``range`` value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/trend?range=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_invalid_source_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid source value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/trend?source=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid source value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/trend?source=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_empty_buckets_response(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Empty bucket list is serialised correctly."""
|
||||
from app.models.ban import BanTrendResponse
|
||||
|
||||
empty = BanTrendResponse(buckets=[], bucket_size="1h")
|
||||
empty = DomainBanTrend(buckets=[], bucket_size="1h")
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.ban_trend",
|
||||
new=AsyncMock(return_value=empty),
|
||||
@@ -774,14 +709,12 @@ class TestBanTrend:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bans_by_jail_response() -> object:
|
||||
"""Build a stub :class:`~app.models.ban.BansByJailResponse`."""
|
||||
from app.models.ban import BansByJailResponse, JailBanCount
|
||||
|
||||
return BansByJailResponse(
|
||||
def _make_bans_by_jail_response() -> DomainBansByJail:
|
||||
"""Build a stub :class:`~app.models.ban_domain.DomainBansByJail`."""
|
||||
return DomainBansByJail(
|
||||
jails=[
|
||||
JailBanCount(jail="sshd", count=10),
|
||||
JailBanCount(jail="nginx", count=5),
|
||||
DomainJailBanCount(jail="sshd", count=10),
|
||||
DomainJailBanCount(jail="nginx", count=5),
|
||||
],
|
||||
total=15,
|
||||
)
|
||||
@@ -791,9 +724,7 @@ def _make_bans_by_jail_response() -> object:
|
||||
class TestBansByJail:
|
||||
"""GET /api/dashboard/bans/by-jail."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_200_when_authenticated(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_jail",
|
||||
@@ -802,9 +733,7 @@ class TestBansByJail:
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
async def test_returns_401_when_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/v1/dashboard/bans/by-jail")
|
||||
@@ -823,9 +752,7 @@ class TestBansByJail:
|
||||
assert "total" in body
|
||||
assert isinstance(body["total"], int)
|
||||
|
||||
async def test_each_jail_has_name_and_count(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_each_jail_has_name_and_count(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Every element of ``jails`` has ``jail`` (string) and ``count`` (int)."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_jail",
|
||||
@@ -861,16 +788,12 @@ class TestBansByJail:
|
||||
"""``?origin=blocklist`` is passed as a keyword arg to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_jail_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn):
|
||||
await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-jail?origin=blocklist"
|
||||
)
|
||||
await dashboard_client.get("/api/v1/dashboard/bans/by-jail?origin=blocklist")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_no_origin_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_no_origin_defaults_to_none(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to the service."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_jail_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_jail", new=mock_fn):
|
||||
@@ -879,23 +802,15 @@ class TestBansByJail:
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
async def test_invalid_range_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid ``range`` value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-jail?range=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_range_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid ``range`` value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail?range=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_invalid_source_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid source value returns HTTP 422."""
|
||||
response = await dashboard_client.get(
|
||||
"/api/v1/dashboard/bans/by-jail?source=invalid"
|
||||
)
|
||||
assert response.status_code == 422
|
||||
async def test_invalid_source_returns_400(self, dashboard_client: AsyncClient) -> None:
|
||||
"""An invalid source value returns HTTP 400."""
|
||||
response = await dashboard_client.get("/api/v1/dashboard/bans/by-jail?source=invalid")
|
||||
assert response.status_code == 400
|
||||
|
||||
async def test_empty_jails_response(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Empty jails list is serialised correctly."""
|
||||
@@ -911,4 +826,3 @@ class TestBansByJail:
|
||||
body = response.json()
|
||||
assert body["jails"] == []
|
||||
assert body["total"] == 0
|
||||
|
||||
|
||||
Reference in New Issue
Block a user