Add origin field and filter for ban sources (Tasks 1 & 2)
- Task 1: Mark imported blocklist IP addresses
- Add BanOrigin type and _derive_origin() to ban.py model
- Populate origin field in ban_service list_bans() and bans_by_country()
- BanTable and MapPage companion table show origin badge column
- Tests: origin derivation in test_ban_service.py and test_dashboard.py
- Task 2: Add origin filter to dashboard and world map
- ban_service: _origin_sql_filter() helper; origin param on list_bans()
and bans_by_country()
- dashboard router: optional origin query param forwarded to service
- Frontend: BanOriginFilter type + BAN_ORIGIN_FILTER_LABELS in ban.ts
- fetchBans / fetchBansByCountry forward origin to API
- useBans / useMapData accept and pass origin; page resets on change
- BanTable accepts origin prop; DashboardPage adds segmented filter
- MapPage adds origin Select next to time-range picker
- Tests: origin filter assertions in test_ban_service and test_dashboard
This commit is contained in:
@@ -447,3 +447,90 @@ class TestPreviewLog:
|
||||
data = resp.json()
|
||||
assert data["total_lines"] == 1
|
||||
assert data["matched_count"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/config/map-color-thresholds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetMapColorThresholds:
|
||||
"""Tests for ``GET /api/config/map-color-thresholds``."""
|
||||
|
||||
async def test_200_returns_thresholds(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/map-color-thresholds returns 200 with current values."""
|
||||
resp = await config_client.get("/api/config/map-color-thresholds")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "threshold_high" in data
|
||||
assert "threshold_medium" in data
|
||||
assert "threshold_low" in data
|
||||
# Should return defaults after setup
|
||||
assert data["threshold_high"] == 100
|
||||
assert data["threshold_medium"] == 50
|
||||
assert data["threshold_low"] == 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/config/map-color-thresholds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateMapColorThresholds:
|
||||
"""Tests for ``PUT /api/config/map-color-thresholds``."""
|
||||
|
||||
async def test_200_updates_thresholds(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/map-color-thresholds returns 200 and updates settings."""
|
||||
update_payload = {
|
||||
"threshold_high": 200,
|
||||
"threshold_medium": 80,
|
||||
"threshold_low": 30,
|
||||
}
|
||||
resp = await config_client.put(
|
||||
"/api/config/map-color-thresholds", json=update_payload
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["threshold_high"] == 200
|
||||
assert data["threshold_medium"] == 80
|
||||
assert data["threshold_low"] == 30
|
||||
|
||||
# Verify the values persist
|
||||
get_resp = await config_client.get("/api/config/map-color-thresholds")
|
||||
assert get_resp.status_code == 200
|
||||
get_data = get_resp.json()
|
||||
assert get_data["threshold_high"] == 200
|
||||
assert get_data["threshold_medium"] == 80
|
||||
assert get_data["threshold_low"] == 30
|
||||
|
||||
async def test_400_for_invalid_order(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/map-color-thresholds returns 400 if thresholds are misordered."""
|
||||
invalid_payload = {
|
||||
"threshold_high": 50,
|
||||
"threshold_medium": 50,
|
||||
"threshold_low": 20,
|
||||
}
|
||||
resp = await config_client.put(
|
||||
"/api/config/map-color-thresholds", json=invalid_payload
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "high > medium > low" in resp.json()["detail"]
|
||||
|
||||
async def test_400_for_non_positive_values(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
"""PUT /api/config/map-color-thresholds returns 422 for non-positive values (Pydantic validation)."""
|
||||
invalid_payload = {
|
||||
"threshold_high": 100,
|
||||
"threshold_medium": 50,
|
||||
"threshold_low": 0,
|
||||
}
|
||||
resp = await config_client.put(
|
||||
"/api/config/map-color-thresholds", json=invalid_payload
|
||||
)
|
||||
|
||||
# Pydantic validates ge=1 constraint before our service code runs
|
||||
assert resp.status_code == 422
|
||||
|
||||
@@ -220,6 +220,7 @@ def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
|
||||
asn="AS3320",
|
||||
org="Telekom",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
)
|
||||
for i in range(n)
|
||||
]
|
||||
@@ -334,10 +335,11 @@ def _make_bans_by_country_response() -> object:
|
||||
asn="AS3320",
|
||||
org="Telekom",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
),
|
||||
DashboardBanItem(
|
||||
ip="5.6.7.8",
|
||||
jail="sshd",
|
||||
jail="blocklist-import",
|
||||
banned_at="2026-03-01T10:05:00+00:00",
|
||||
service=None,
|
||||
country_code="US",
|
||||
@@ -345,6 +347,7 @@ def _make_bans_by_country_response() -> object:
|
||||
asn="AS15169",
|
||||
org="Google LLC",
|
||||
ban_count=2,
|
||||
origin="blocklist",
|
||||
),
|
||||
]
|
||||
return BansByCountryResponse(
|
||||
@@ -431,3 +434,146 @@ class TestBansByCountry:
|
||||
assert body["total"] == 0
|
||||
assert body["countries"] == {}
|
||||
assert body["bans"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Origin field tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDashboardBansOriginField:
|
||||
"""Verify that the ``origin`` field is present in API responses."""
|
||||
|
||||
async def test_origin_present_in_ban_list_items(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Each item in ``/api/dashboard/bans`` carries an ``origin`` field."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=_make_ban_list_response(1)),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
item = response.json()["items"][0]
|
||||
assert "origin" in item
|
||||
assert item["origin"] in ("blocklist", "selfblock")
|
||||
|
||||
async def test_selfblock_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""A ban from a non-blocklist jail serialises as ``"selfblock"``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=_make_ban_list_response(1)),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
item = response.json()["items"][0]
|
||||
assert item["jail"] == "sshd"
|
||||
assert item["origin"] == "selfblock"
|
||||
|
||||
async def test_origin_present_in_bans_by_country(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Each ban in ``/api/dashboard/bans/by-country`` carries an ``origin``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
new=AsyncMock(return_value=_make_bans_by_country_response()),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans/by-country")
|
||||
|
||||
bans = response.json()["bans"]
|
||||
assert all("origin" in ban for ban in bans)
|
||||
origins = {ban["origin"] for ban in bans}
|
||||
assert origins == {"blocklist", "selfblock"}
|
||||
|
||||
async def test_blocklist_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""A ban from the ``blocklist-import`` jail serialises as ``"blocklist"``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
new=AsyncMock(return_value=_make_bans_by_country_response()),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans/by-country")
|
||||
|
||||
bans = response.json()["bans"]
|
||||
blocklist_ban = next(b for b in bans if b["jail"] == "blocklist-import")
|
||||
assert blocklist_ban["origin"] == "blocklist"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Origin filter query parameter tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOriginFilterParam:
|
||||
"""Verify that the ``origin`` query parameter is forwarded to the service."""
|
||||
|
||||
async def test_bans_origin_blocklist_forwarded_to_service(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""``?origin=blocklist`` is passed to ``ban_service.list_bans``."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
await dashboard_client.get("/api/dashboard/bans?origin=blocklist")
|
||||
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_bans_origin_selfblock_forwarded_to_service(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""``?origin=selfblock`` is passed to ``ban_service.list_bans``."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
await dashboard_client.get("/api/dashboard/bans?origin=selfblock")
|
||||
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") == "selfblock"
|
||||
|
||||
async def test_bans_no_origin_param_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to the service (no filtering)."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
async def test_bans_invalid_origin_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid ``origin`` value returns HTTP 422 Unprocessable Entity."""
|
||||
response = await dashboard_client.get("/api/dashboard/bans?origin=invalid")
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_by_country_origin_blocklist_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""``?origin=blocklist`` is passed to ``ban_service.bans_by_country``."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
await dashboard_client.get(
|
||||
"/api/dashboard/bans/by-country?origin=blocklist"
|
||||
)
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_by_country_no_origin_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to ``bans_by_country``."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
await dashboard_client.get("/api/dashboard/bans/by-country")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
Reference in New Issue
Block a user