Add country-specific companion table filtering for map page
This commit is contained in:
@@ -80,6 +80,32 @@ async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> No
|
||||
assert records[0].ip == "3.3.3.3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_currently_banned_filters_by_ip_list(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 10, 1, "{}"),
|
||||
("jail1", "2.2.2.2", 20, 1, "{}"),
|
||||
("jail1", "3.3.3.3", 30, 1, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
records, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=0,
|
||||
ip_filter=["2.2.2.2", "3.3.3.3"],
|
||||
)
|
||||
|
||||
assert total == 2
|
||||
assert len(records) == 2
|
||||
assert {record.ip for record in records} == {"2.2.2.2", "3.3.3.3"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
|
||||
@@ -47,6 +47,10 @@ async def test_get_archived_history_filtering_and_pagination(app_db: str) -> Non
|
||||
assert total == 2
|
||||
assert len(rows) == 1
|
||||
|
||||
rows, total = await get_archived_history(db, ip_filter=["2.2.2.2"])
|
||||
assert total == 1
|
||||
assert rows[0]["ip"] == "2.2.2.2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_purge_archived_history(app_db: str) -> None:
|
||||
|
||||
@@ -522,6 +522,19 @@ class TestDashboardBansOriginField:
|
||||
|
||||
assert mock_fn.call_args[1]["source"] == "archive"
|
||||
|
||||
async def test_bans_by_country_country_code_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""The ``country_code`` query parameter is forwarded to bans_by_country."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get(
|
||||
"/api/dashboard/bans/by-country?country_code=DE"
|
||||
)
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("country_code") == "DE"
|
||||
|
||||
async def test_blocklist_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
|
||||
@@ -654,6 +654,54 @@ class TestOriginFilter:
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_country_code_returns_all_matched_rows(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``bans_by_country`` returns all companion rows for the selected country."""
|
||||
path = str(tmp_path / "fail2ban_country_filter.sqlite3")
|
||||
rows = [
|
||||
{
|
||||
"jail": "sshd",
|
||||
"ip": "10.0.0.1",
|
||||
"timeofban": _ONE_HOUR_AGO - i,
|
||||
"bantime": 3600,
|
||||
"bancount": 1,
|
||||
"data": {"matches": ["failed login"]},
|
||||
}
|
||||
for i in range(205)
|
||||
]
|
||||
await _create_f2b_db(path, rows)
|
||||
|
||||
from app.services import geo_service
|
||||
|
||||
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(
|
||||
country_code="DE",
|
||||
country_name="Germany",
|
||||
asn=None,
|
||||
org=None,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
), patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task:
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
country_code="DE",
|
||||
http_session=AsyncMock(),
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
)
|
||||
|
||||
mock_create_task.assert_not_called()
|
||||
assert result.total == 205
|
||||
assert len(result.bans) == 205
|
||||
assert all(b.country_code == "DE" for b in result.bans)
|
||||
|
||||
geo_service.clear_cache()
|
||||
|
||||
async def test_bans_by_country_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user