Add origin field and filter for ban sources (Tasks 1 & 2)

- Task 1: Mark imported blocklist IP addresses
  - Add BanOrigin type and _derive_origin() to ban.py model
  - Populate origin field in ban_service list_bans() and bans_by_country()
  - BanTable and MapPage companion table show origin badge column
  - Tests: origin derivation in test_ban_service.py and test_dashboard.py

- Task 2: Add origin filter to dashboard and world map
  - ban_service: _origin_sql_filter() helper; origin param on list_bans()
    and bans_by_country()
  - dashboard router: optional origin query param forwarded to service
  - Frontend: BanOriginFilter type + BAN_ORIGIN_FILTER_LABELS in ban.ts
  - fetchBans / fetchBansByCountry forward origin to API
  - useBans / useMapData accept and pass origin; page resets on change
  - BanTable accepts origin prop; DashboardPage adds segmented filter
  - MapPage adds origin Select next to time-range picker
  - Tests: origin filter assertions in test_ban_service and test_dashboard
This commit is contained in:
2026-03-07 20:03:43 +01:00
parent 706d2e1df8
commit 53d664de4f
28 changed files with 1637 additions and 103 deletions

View File

@@ -447,3 +447,90 @@ class TestPreviewLog:
data = resp.json()
assert data["total_lines"] == 1
assert data["matched_count"] == 1
# ---------------------------------------------------------------------------
# GET /api/config/map-color-thresholds
# ---------------------------------------------------------------------------
class TestGetMapColorThresholds:
"""Tests for ``GET /api/config/map-color-thresholds``."""
async def test_200_returns_thresholds(self, config_client: AsyncClient) -> None:
"""GET /api/config/map-color-thresholds returns 200 with current values."""
resp = await config_client.get("/api/config/map-color-thresholds")
assert resp.status_code == 200
data = resp.json()
assert "threshold_high" in data
assert "threshold_medium" in data
assert "threshold_low" in data
# Should return defaults after setup
assert data["threshold_high"] == 100
assert data["threshold_medium"] == 50
assert data["threshold_low"] == 20
# ---------------------------------------------------------------------------
# PUT /api/config/map-color-thresholds
# ---------------------------------------------------------------------------
class TestUpdateMapColorThresholds:
"""Tests for ``PUT /api/config/map-color-thresholds``."""
async def test_200_updates_thresholds(self, config_client: AsyncClient) -> None:
"""PUT /api/config/map-color-thresholds returns 200 and updates settings."""
update_payload = {
"threshold_high": 200,
"threshold_medium": 80,
"threshold_low": 30,
}
resp = await config_client.put(
"/api/config/map-color-thresholds", json=update_payload
)
assert resp.status_code == 200
data = resp.json()
assert data["threshold_high"] == 200
assert data["threshold_medium"] == 80
assert data["threshold_low"] == 30
# Verify the values persist
get_resp = await config_client.get("/api/config/map-color-thresholds")
assert get_resp.status_code == 200
get_data = get_resp.json()
assert get_data["threshold_high"] == 200
assert get_data["threshold_medium"] == 80
assert get_data["threshold_low"] == 30
async def test_400_for_invalid_order(self, config_client: AsyncClient) -> None:
"""PUT /api/config/map-color-thresholds returns 400 if thresholds are misordered."""
invalid_payload = {
"threshold_high": 50,
"threshold_medium": 50,
"threshold_low": 20,
}
resp = await config_client.put(
"/api/config/map-color-thresholds", json=invalid_payload
)
assert resp.status_code == 400
assert "high > medium > low" in resp.json()["detail"]
async def test_400_for_non_positive_values(
self, config_client: AsyncClient
) -> None:
"""PUT /api/config/map-color-thresholds returns 422 for non-positive values (Pydantic validation)."""
invalid_payload = {
"threshold_high": 100,
"threshold_medium": 50,
"threshold_low": 0,
}
resp = await config_client.put(
"/api/config/map-color-thresholds", json=invalid_payload
)
# Pydantic validates ge=1 constraint before our service code runs
assert resp.status_code == 422

View File

@@ -220,6 +220,7 @@ def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
asn="AS3320",
org="Telekom",
ban_count=1,
origin="selfblock",
)
for i in range(n)
]
@@ -334,10 +335,11 @@ def _make_bans_by_country_response() -> object:
asn="AS3320",
org="Telekom",
ban_count=1,
origin="selfblock",
),
DashboardBanItem(
ip="5.6.7.8",
jail="sshd",
jail="blocklist-import",
banned_at="2026-03-01T10:05:00+00:00",
service=None,
country_code="US",
@@ -345,6 +347,7 @@ def _make_bans_by_country_response() -> object:
asn="AS15169",
org="Google LLC",
ban_count=2,
origin="blocklist",
),
]
return BansByCountryResponse(
@@ -431,3 +434,146 @@ class TestBansByCountry:
assert body["total"] == 0
assert body["countries"] == {}
assert body["bans"] == []
# ---------------------------------------------------------------------------
# Origin field tests
# ---------------------------------------------------------------------------
class TestDashboardBansOriginField:
"""Verify that the ``origin`` field is present in API responses."""
async def test_origin_present_in_ban_list_items(
self, dashboard_client: AsyncClient
) -> None:
"""Each item in ``/api/dashboard/bans`` carries an ``origin`` field."""
with patch(
"app.routers.dashboard.ban_service.list_bans",
new=AsyncMock(return_value=_make_ban_list_response(1)),
):
response = await dashboard_client.get("/api/dashboard/bans")
item = response.json()["items"][0]
assert "origin" in item
assert item["origin"] in ("blocklist", "selfblock")
async def test_selfblock_origin_serialised_correctly(
self, dashboard_client: AsyncClient
) -> None:
"""A ban from a non-blocklist jail serialises as ``"selfblock"``."""
with patch(
"app.routers.dashboard.ban_service.list_bans",
new=AsyncMock(return_value=_make_ban_list_response(1)),
):
response = await dashboard_client.get("/api/dashboard/bans")
item = response.json()["items"][0]
assert item["jail"] == "sshd"
assert item["origin"] == "selfblock"
async def test_origin_present_in_bans_by_country(
self, dashboard_client: AsyncClient
) -> None:
"""Each ban in ``/api/dashboard/bans/by-country`` carries an ``origin``."""
with patch(
"app.routers.dashboard.ban_service.bans_by_country",
new=AsyncMock(return_value=_make_bans_by_country_response()),
):
response = await dashboard_client.get("/api/dashboard/bans/by-country")
bans = response.json()["bans"]
assert all("origin" in ban for ban in bans)
origins = {ban["origin"] for ban in bans}
assert origins == {"blocklist", "selfblock"}
async def test_blocklist_origin_serialised_correctly(
self, dashboard_client: AsyncClient
) -> None:
"""A ban from the ``blocklist-import`` jail serialises as ``"blocklist"``."""
with patch(
"app.routers.dashboard.ban_service.bans_by_country",
new=AsyncMock(return_value=_make_bans_by_country_response()),
):
response = await dashboard_client.get("/api/dashboard/bans/by-country")
bans = response.json()["bans"]
blocklist_ban = next(b for b in bans if b["jail"] == "blocklist-import")
assert blocklist_ban["origin"] == "blocklist"
# ---------------------------------------------------------------------------
# Origin filter query parameter tests
# ---------------------------------------------------------------------------
class TestOriginFilterParam:
"""Verify that the ``origin`` query parameter is forwarded to the service."""
async def test_bans_origin_blocklist_forwarded_to_service(
self, dashboard_client: AsyncClient
) -> None:
"""``?origin=blocklist`` is passed to ``ban_service.list_bans``."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
await dashboard_client.get("/api/dashboard/bans?origin=blocklist")
_, kwargs = mock_list.call_args
assert kwargs.get("origin") == "blocklist"
async def test_bans_origin_selfblock_forwarded_to_service(
self, dashboard_client: AsyncClient
) -> None:
"""``?origin=selfblock`` is passed to ``ban_service.list_bans``."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
await dashboard_client.get("/api/dashboard/bans?origin=selfblock")
_, kwargs = mock_list.call_args
assert kwargs.get("origin") == "selfblock"
async def test_bans_no_origin_param_defaults_to_none(
self, dashboard_client: AsyncClient
) -> None:
"""Omitting ``origin`` passes ``None`` to the service (no filtering)."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
await dashboard_client.get("/api/dashboard/bans")
_, kwargs = mock_list.call_args
assert kwargs.get("origin") is None
async def test_bans_invalid_origin_returns_422(
self, dashboard_client: AsyncClient
) -> None:
"""An invalid ``origin`` value returns HTTP 422 Unprocessable Entity."""
response = await dashboard_client.get("/api/dashboard/bans?origin=invalid")
assert response.status_code == 422
async def test_by_country_origin_blocklist_forwarded(
self, dashboard_client: AsyncClient
) -> None:
"""``?origin=blocklist`` is passed to ``ban_service.bans_by_country``."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch(
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
):
await dashboard_client.get(
"/api/dashboard/bans/by-country?origin=blocklist"
)
_, kwargs = mock_fn.call_args
assert kwargs.get("origin") == "blocklist"
async def test_by_country_no_origin_defaults_to_none(
self, dashboard_client: AsyncClient
) -> None:
"""Omitting ``origin`` passes ``None`` to ``bans_by_country``."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch(
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
):
await dashboard_client.get("/api/dashboard/bans/by-country")
_, kwargs = mock_fn.call_args
assert kwargs.get("origin") is None

View File

@@ -102,6 +102,39 @@ async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
return path
@pytest.fixture
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
"""Return a database with bans from both blocklist-import and organic jails."""
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
await _create_f2b_db(
path,
[
{
"jail": "blocklist-import",
"ip": "10.0.0.1",
"timeofban": _ONE_HOUR_AGO,
"bantime": -1,
"bancount": 1,
},
{
"jail": "sshd",
"ip": "10.0.0.2",
"timeofban": _ONE_HOUR_AGO,
"bantime": 3600,
"bancount": 3,
},
{
"jail": "nginx",
"ip": "10.0.0.3",
"timeofban": _ONE_HOUR_AGO,
"bantime": 7200,
"bancount": 1,
},
],
)
return path
@pytest.fixture
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
"""Return the path to a fail2ban SQLite database with no ban records."""
@@ -299,3 +332,183 @@ class TestListBansPagination:
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
assert result.total == 3 # All three bans are within 7d.
# ---------------------------------------------------------------------------
# list_bans / bans_by_country — origin derivation
# ---------------------------------------------------------------------------
class TestBanOriginDerivation:
"""Verify that ban_service correctly derives ``origin`` from jail names."""
async def test_blocklist_import_jail_yields_blocklist_origin(
self, mixed_origin_db_path: str
) -> None:
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans("/fake/sock", "24h")
blocklist_items = [i for i in result.items if i.jail == "blocklist-import"]
assert len(blocklist_items) == 1
assert blocklist_items[0].origin == "blocklist"
async def test_organic_jail_yields_selfblock_origin(
self, mixed_origin_db_path: str
) -> None:
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans("/fake/sock", "24h")
organic_items = [i for i in result.items if i.jail != "blocklist-import"]
assert len(organic_items) == 2
for item in organic_items:
assert item.origin == "selfblock"
async def test_all_items_carry_origin_field(
self, mixed_origin_db_path: str
) -> None:
"""Every returned item has an ``origin`` field with a valid value."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans("/fake/sock", "24h")
for item in result.items:
assert item.origin in ("blocklist", "selfblock")
async def test_bans_by_country_blocklist_origin(
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country("/fake/sock", "24h")
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"]
assert len(blocklist_bans) == 1
assert blocklist_bans[0].origin == "blocklist"
async def test_bans_by_country_selfblock_origin(
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` derives origin correctly for organic jails."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country("/fake/sock", "24h")
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"]
assert len(organic_bans) == 2
for ban in organic_bans:
assert ban.origin == "selfblock"
# ---------------------------------------------------------------------------
# list_bans / bans_by_country — origin filter parameter
# ---------------------------------------------------------------------------
class TestOriginFilter:
"""Verify that the origin filter correctly restricts results."""
async def test_list_bans_blocklist_filter_returns_only_blocklist(
self, mixed_origin_db_path: str
) -> None:
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", origin="blocklist"
)
assert result.total == 1
assert len(result.items) == 1
assert result.items[0].jail == "blocklist-import"
assert result.items[0].origin == "blocklist"
async def test_list_bans_selfblock_filter_excludes_blocklist(
self, mixed_origin_db_path: str
) -> None:
"""``origin='selfblock'`` excludes the blocklist-import jail."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", origin="selfblock"
)
assert result.total == 2
assert len(result.items) == 2
for item in result.items:
assert item.jail != "blocklist-import"
assert item.origin == "selfblock"
async def test_list_bans_no_filter_returns_all(
self, mixed_origin_db_path: str
) -> None:
"""``origin=None`` applies no jail restriction — all bans returned."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
assert result.total == 3
async def test_bans_by_country_blocklist_filter(
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", origin="blocklist"
)
assert result.total == 1
assert all(b.jail == "blocklist-import" for b in result.bans)
async def test_bans_by_country_selfblock_filter(
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", origin="selfblock"
)
assert result.total == 2
assert all(b.jail != "blocklist-import" for b in result.bans)
async def test_bans_by_country_no_filter_returns_all(
self, mixed_origin_db_path: str
) -> None:
"""``bans_by_country`` with ``origin=None`` returns all bans."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", origin=None
)
assert result.total == 3

View File

@@ -98,6 +98,23 @@ class TestRunSetup:
with pytest.raises(RuntimeError, match="already been completed"):
await setup_service.run_setup(db, **kwargs) # type: ignore[arg-type]
async def test_initializes_map_color_thresholds_with_defaults(
self, db: aiosqlite.Connection
) -> None:
"""run_setup() initializes map color thresholds with default values."""
await setup_service.run_setup(
db,
master_password="mypassword1",
database_path="bangui.db",
fail2ban_socket="/var/run/fail2ban/fail2ban.sock",
timezone="UTC",
session_duration_minutes=60,
)
high, medium, low = await setup_service.get_map_color_thresholds(db)
assert high == 100
assert medium == 50
assert low == 20
class TestGetTimezone:
async def test_returns_utc_on_fresh_db(self, db: aiosqlite.Connection) -> None:
@@ -119,6 +136,74 @@ class TestGetTimezone:
assert await setup_service.get_timezone(db) == "America/New_York"
class TestMapColorThresholds:
async def test_get_map_color_thresholds_returns_defaults_on_fresh_db(
self, db: aiosqlite.Connection
) -> None:
"""get_map_color_thresholds() returns default values on a fresh database."""
high, medium, low = await setup_service.get_map_color_thresholds(db)
assert high == 100
assert medium == 50
assert low == 20
async def test_set_map_color_thresholds_persists_values(
self, db: aiosqlite.Connection
) -> None:
"""set_map_color_thresholds() stores and retrieves custom values."""
await setup_service.set_map_color_thresholds(
db, threshold_high=200, threshold_medium=80, threshold_low=30
)
high, medium, low = await setup_service.get_map_color_thresholds(db)
assert high == 200
assert medium == 80
assert low == 30
async def test_set_map_color_thresholds_rejects_non_positive(
self, db: aiosqlite.Connection
) -> None:
"""set_map_color_thresholds() raises ValueError for non-positive thresholds."""
with pytest.raises(ValueError, match="positive integers"):
await setup_service.set_map_color_thresholds(
db, threshold_high=100, threshold_medium=50, threshold_low=0
)
with pytest.raises(ValueError, match="positive integers"):
await setup_service.set_map_color_thresholds(
db, threshold_high=-10, threshold_medium=50, threshold_low=20
)
async def test_set_map_color_thresholds_rejects_invalid_order(
self, db: aiosqlite.Connection
) -> None:
"""
set_map_color_thresholds() rejects invalid ordering.
"""
with pytest.raises(ValueError, match="high > medium > low"):
await setup_service.set_map_color_thresholds(
db, threshold_high=50, threshold_medium=50, threshold_low=20
)
with pytest.raises(ValueError, match="high > medium > low"):
await setup_service.set_map_color_thresholds(
db, threshold_high=100, threshold_medium=30, threshold_low=50
)
async def test_run_setup_initializes_default_thresholds(
self, db: aiosqlite.Connection
) -> None:
"""run_setup() initializes map color thresholds with defaults."""
await setup_service.run_setup(
db,
master_password="mypassword1",
database_path="bangui.db",
fail2ban_socket="/var/run/fail2ban/fail2ban.sock",
timezone="UTC",
session_duration_minutes=60,
)
high, medium, low = await setup_service.get_map_color_thresholds(db)
assert high == 100
assert medium == 50
assert low == 20
class TestRunSetupAsync:
"""Verify the async/non-blocking bcrypt behavior of run_setup."""