Add origin field and filter for ban sources (Tasks 1 & 2)
- Task 1: Mark imported blocklist IP addresses
- Add BanOrigin type and _derive_origin() to ban.py model
- Populate origin field in ban_service list_bans() and bans_by_country()
- BanTable and MapPage companion table show origin badge column
- Tests: origin derivation in test_ban_service.py and test_dashboard.py
- Task 2: Add origin filter to dashboard and world map
- ban_service: _origin_sql_filter() helper; origin param on list_bans()
and bans_by_country()
- dashboard router: optional origin query param forwarded to service
- Frontend: BanOriginFilter type + BAN_ORIGIN_FILTER_LABELS in ban.ts
- fetchBans / fetchBansByCountry forward origin to API
- useBans / useMapData accept and pass origin; page resets on change
- BanTable accepts origin prop; DashboardPage adds segmented filter
- MapPage adds origin Select next to time-range picker
- Tests: origin filter assertions in test_ban_service and test_dashboard
This commit is contained in:
@@ -102,6 +102,39 @@ async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
"""Return a database with bans from both blocklist-import and organic jails."""
|
||||
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
|
||||
await _create_f2b_db(
|
||||
path,
|
||||
[
|
||||
{
|
||||
"jail": "blocklist-import",
|
||||
"ip": "10.0.0.1",
|
||||
"timeofban": _ONE_HOUR_AGO,
|
||||
"bantime": -1,
|
||||
"bancount": 1,
|
||||
},
|
||||
{
|
||||
"jail": "sshd",
|
||||
"ip": "10.0.0.2",
|
||||
"timeofban": _ONE_HOUR_AGO,
|
||||
"bantime": 3600,
|
||||
"bancount": 3,
|
||||
},
|
||||
{
|
||||
"jail": "nginx",
|
||||
"ip": "10.0.0.3",
|
||||
"timeofban": _ONE_HOUR_AGO,
|
||||
"bantime": 7200,
|
||||
"bancount": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
"""Return the path to a fail2ban SQLite database with no ban records."""
|
||||
@@ -299,3 +332,183 @@ class TestListBansPagination:
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
|
||||
assert result.total == 3 # All three bans are within 7d.
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_bans / bans_by_country — origin derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBanOriginDerivation:
|
||||
"""Verify that ban_service correctly derives ``origin`` from jail names."""
|
||||
|
||||
async def test_blocklist_import_jail_yields_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
blocklist_items = [i for i in result.items if i.jail == "blocklist-import"]
|
||||
assert len(blocklist_items) == 1
|
||||
assert blocklist_items[0].origin == "blocklist"
|
||||
|
||||
async def test_organic_jail_yields_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
organic_items = [i for i in result.items if i.jail != "blocklist-import"]
|
||||
assert len(organic_items) == 2
|
||||
for item in organic_items:
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_all_items_carry_origin_field(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""Every returned item has an ``origin`` field with a valid value."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
for item in result.items:
|
||||
assert item.origin in ("blocklist", "selfblock")
|
||||
|
||||
async def test_bans_by_country_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"]
|
||||
assert len(blocklist_bans) == 1
|
||||
assert blocklist_bans[0].origin == "blocklist"
|
||||
|
||||
async def test_bans_by_country_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` derives origin correctly for organic jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"]
|
||||
assert len(organic_bans) == 2
|
||||
for ban in organic_bans:
|
||||
assert ban.origin == "selfblock"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_bans / bans_by_country — origin filter parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOriginFilter:
|
||||
"""Verify that the origin filter correctly restricts results."""
|
||||
|
||||
async def test_list_bans_blocklist_filter_returns_only_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].jail == "blocklist-import"
|
||||
assert result.items[0].origin == "blocklist"
|
||||
|
||||
async def test_list_bans_selfblock_filter_excludes_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.items) == 2
|
||||
for item in result.items:
|
||||
assert item.jail != "blocklist-import"
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_list_bans_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``origin=None`` applies no jail restriction — all bans returned."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_blocklist_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
assert all(b.jail == "blocklist-import" for b in result.bans)
|
||||
|
||||
async def test_bans_by_country_selfblock_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert all(b.jail != "blocklist-import" for b in result.bans)
|
||||
|
||||
async def test_bans_by_country_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin=None
|
||||
)
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
@@ -98,6 +98,23 @@ class TestRunSetup:
|
||||
with pytest.raises(RuntimeError, match="already been completed"):
|
||||
await setup_service.run_setup(db, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def test_initializes_map_color_thresholds_with_defaults(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""run_setup() initializes map color thresholds with default values."""
|
||||
await setup_service.run_setup(
|
||||
db,
|
||||
master_password="mypassword1",
|
||||
database_path="bangui.db",
|
||||
fail2ban_socket="/var/run/fail2ban/fail2ban.sock",
|
||||
timezone="UTC",
|
||||
session_duration_minutes=60,
|
||||
)
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
assert high == 100
|
||||
assert medium == 50
|
||||
assert low == 20
|
||||
|
||||
|
||||
class TestGetTimezone:
|
||||
async def test_returns_utc_on_fresh_db(self, db: aiosqlite.Connection) -> None:
|
||||
@@ -119,6 +136,74 @@ class TestGetTimezone:
|
||||
assert await setup_service.get_timezone(db) == "America/New_York"
|
||||
|
||||
|
||||
class TestMapColorThresholds:
|
||||
async def test_get_map_color_thresholds_returns_defaults_on_fresh_db(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""get_map_color_thresholds() returns default values on a fresh database."""
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
assert high == 100
|
||||
assert medium == 50
|
||||
assert low == 20
|
||||
|
||||
async def test_set_map_color_thresholds_persists_values(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""set_map_color_thresholds() stores and retrieves custom values."""
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=200, threshold_medium=80, threshold_low=30
|
||||
)
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
assert high == 200
|
||||
assert medium == 80
|
||||
assert low == 30
|
||||
|
||||
async def test_set_map_color_thresholds_rejects_non_positive(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""set_map_color_thresholds() raises ValueError for non-positive thresholds."""
|
||||
with pytest.raises(ValueError, match="positive integers"):
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=100, threshold_medium=50, threshold_low=0
|
||||
)
|
||||
with pytest.raises(ValueError, match="positive integers"):
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=-10, threshold_medium=50, threshold_low=20
|
||||
)
|
||||
|
||||
async def test_set_map_color_thresholds_rejects_invalid_order(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""
|
||||
set_map_color_thresholds() rejects invalid ordering.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="high > medium > low"):
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=50, threshold_medium=50, threshold_low=20
|
||||
)
|
||||
with pytest.raises(ValueError, match="high > medium > low"):
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=100, threshold_medium=30, threshold_low=50
|
||||
)
|
||||
|
||||
async def test_run_setup_initializes_default_thresholds(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""run_setup() initializes map color thresholds with defaults."""
|
||||
await setup_service.run_setup(
|
||||
db,
|
||||
master_password="mypassword1",
|
||||
database_path="bangui.db",
|
||||
fail2ban_socket="/var/run/fail2ban/fail2ban.sock",
|
||||
timezone="UTC",
|
||||
session_duration_minutes=60,
|
||||
)
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
assert high == 100
|
||||
assert medium == 50
|
||||
assert low == 20
|
||||
|
||||
|
||||
class TestRunSetupAsync:
|
||||
"""Verify the async/non-blocking bcrypt behavior of run_setup."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user