Add country-specific companion table filtering for map page

This commit is contained in:
2026-04-05 22:12:06 +02:00
parent c03a5c1cbc
commit c51858ec71
13 changed files with 332 additions and 85 deletions

View File

@@ -80,6 +80,32 @@ async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> No
assert records[0].ip == "3.3.3.3"
@pytest.mark.asyncio
async def test_get_currently_banned_filters_by_ip_list(tmp_path: Path) -> None:
db_path = str(tmp_path / "fail2ban.db")
async with aiosqlite.connect(db_path) as db:
await _create_bans_table(db)
await db.executemany(
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
[
("jail1", "1.1.1.1", 10, 1, "{}"),
("jail1", "2.2.2.2", 20, 1, "{}"),
("jail1", "3.3.3.3", 30, 1, "{}"),
],
)
await db.commit()
records, total = await fail2ban_db_repo.get_currently_banned(
db_path=db_path,
since=0,
ip_filter=["2.2.2.2", "3.3.3.3"],
)
assert total == 2
assert len(records) == 2
assert {record.ip for record in records} == {"2.2.2.2", "3.3.3.3"}
@pytest.mark.asyncio
async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None:
db_path = str(tmp_path / "fail2ban.db")

View File

@@ -47,6 +47,10 @@ async def test_get_archived_history_filtering_and_pagination(app_db: str) -> Non
assert total == 2
assert len(rows) == 1
rows, total = await get_archived_history(db, ip_filter=["2.2.2.2"])
assert total == 1
assert rows[0]["ip"] == "2.2.2.2"
@pytest.mark.asyncio
async def test_purge_archived_history(app_db: str) -> None:

View File

@@ -522,6 +522,19 @@ class TestDashboardBansOriginField:
assert mock_fn.call_args[1]["source"] == "archive"
async def test_bans_by_country_country_code_forwarded(
self, dashboard_client: AsyncClient
) -> None:
"""The ``country_code`` query parameter is forwarded to bans_by_country."""
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
await dashboard_client.get(
"/api/dashboard/bans/by-country?country_code=DE"
)
_, kwargs = mock_fn.call_args
assert kwargs.get("country_code") == "DE"
async def test_blocklist_origin_serialised_correctly(
self, dashboard_client: AsyncClient
) -> None:

View File

@@ -654,6 +654,54 @@ class TestOriginFilter:
assert result.total == 3
async def test_bans_by_country_country_code_returns_all_matched_rows(
self, tmp_path: Path
) -> None:
"""``bans_by_country`` returns all companion rows for the selected country."""
path = str(tmp_path / "fail2ban_country_filter.sqlite3")
rows = [
{
"jail": "sshd",
"ip": "10.0.0.1",
"timeofban": _ONE_HOUR_AGO - i,
"bantime": 3600,
"bancount": 1,
"data": {"matches": ["failed login"]},
}
for i in range(205)
]
await _create_f2b_db(path, rows)
from app.services import geo_service
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(
country_code="DE",
country_name="Germany",
asn=None,
org=None,
)
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
), patch(
"app.services.ban_service.asyncio.create_task"
) as mock_create_task:
result = await ban_service.bans_by_country(
"/fake/sock",
"24h",
country_code="DE",
http_session=AsyncMock(),
geo_cache_lookup=geo_service.lookup_cached_only,
)
mock_create_task.assert_not_called()
assert result.total == 205
assert len(result.bans) == 205
assert all(b.country_code == "DE" for b in result.bans)
geo_service.clear_cache()
async def test_bans_by_country_source_archive_reads_archive(
self, app_db_with_archive: aiosqlite.Connection
) -> None: