Add country-specific companion table filtering for map page
This commit is contained in:
@@ -126,6 +126,7 @@ async def get_currently_banned(
|
||||
since: int,
|
||||
origin: BanOrigin | None = None,
|
||||
*,
|
||||
ip_filter: list[str] | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> tuple[list[BanRecord], int]:
|
||||
@@ -135,6 +136,7 @@ async def get_currently_banned(
|
||||
db_path: File path to the fail2ban SQLite database.
|
||||
since: Unix timestamp to filter bans newer than or equal to.
|
||||
origin: Optional origin filter.
|
||||
ip_filter: Optional list of IP addresses to restrict the result to.
|
||||
limit: Optional maximum number of rows to return.
|
||||
offset: Optional offset for pagination.
|
||||
|
||||
@@ -142,14 +144,21 @@ async def get_currently_banned(
|
||||
A ``(records, total)`` tuple.
|
||||
"""
|
||||
|
||||
if ip_filter is not None and len(ip_filter) == 0:
|
||||
return [], 0
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
ip_filter_clause = ""
|
||||
if ip_filter is not None:
|
||||
placeholder = ", ".join("?" for _ in ip_filter)
|
||||
ip_filter_clause = f" AND ip IN ({placeholder})"
|
||||
|
||||
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
|
||||
db.row_factory = aiosqlite.Row
|
||||
|
||||
async with db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause + ip_filter_clause,
|
||||
(since, *origin_params, *(ip_filter or [])),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
@@ -157,9 +166,9 @@ async def get_currently_banned(
|
||||
query = (
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC"
|
||||
"WHERE timeofban >= ?" + origin_clause + ip_filter_clause + " ORDER BY timeofban DESC"
|
||||
)
|
||||
params: list[object] = [since, *origin_params]
|
||||
params: list[object] = [since, *origin_params, *(ip_filter or [])]
|
||||
if limit is not None:
|
||||
query += " LIMIT ?"
|
||||
params.append(limit)
|
||||
|
||||
@@ -40,13 +40,16 @@ async def get_archived_history(
|
||||
db: aiosqlite.Connection,
|
||||
since: int | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
ip_filter: str | list[str] | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
action: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 100,
|
||||
) -> tuple[list[dict], int]:
|
||||
"""Return a paginated archived history result set."""
|
||||
if isinstance(ip_filter, list) and len(ip_filter) == 0:
|
||||
return [], 0
|
||||
|
||||
wheres: list[str] = []
|
||||
params: list[object] = []
|
||||
|
||||
@@ -59,8 +62,13 @@ async def get_archived_history(
|
||||
params.append(jail)
|
||||
|
||||
if ip_filter is not None:
|
||||
wheres.append("ip LIKE ?")
|
||||
params.append(f"{ip_filter}%")
|
||||
if isinstance(ip_filter, list):
|
||||
placeholder = ", ".join("?" for _ in ip_filter)
|
||||
wheres.append(f"ip IN ({placeholder})")
|
||||
params.extend(ip_filter)
|
||||
else:
|
||||
wheres.append("ip LIKE ?")
|
||||
params.append(f"{ip_filter}%")
|
||||
|
||||
if origin == "blocklist":
|
||||
wheres.append("jail = ?")
|
||||
@@ -108,7 +116,7 @@ async def get_all_archived_history(
|
||||
db: aiosqlite.Connection,
|
||||
since: int | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
ip_filter: str | list[str] | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
action: str | None = None,
|
||||
) -> list[dict]:
|
||||
|
||||
@@ -83,7 +83,10 @@ async def get_dashboard_bans(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
|
||||
source: Literal["fail2ban", "archive"] = Query(
|
||||
default="fail2ban",
|
||||
description="Data source: 'fail2ban' or 'archive'.",
|
||||
),
|
||||
page: int = Query(default=1, ge=1, description="1-based page number."),
|
||||
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."),
|
||||
origin: BanOrigin | None = Query(
|
||||
@@ -137,11 +140,18 @@ async def get_bans_by_country(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
|
||||
source: Literal["fail2ban", "archive"] = Query(
|
||||
default="fail2ban",
|
||||
description="Data source: 'fail2ban' or 'archive'.",
|
||||
),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
),
|
||||
country_code: str | None = Query(
|
||||
default=None,
|
||||
description="ISO alpha-2 country code to filter companion rows.",
|
||||
),
|
||||
) -> BansByCountryResponse:
|
||||
"""Return ban counts aggregated by ISO country code.
|
||||
|
||||
@@ -173,6 +183,7 @@ async def get_bans_by_country(
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
app_db=request.app.state.db,
|
||||
origin=origin,
|
||||
country_code=country_code,
|
||||
)
|
||||
|
||||
|
||||
@@ -185,7 +196,10 @@ async def get_ban_trend(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
|
||||
source: Literal["fail2ban", "archive"] = Query(
|
||||
default="fail2ban",
|
||||
description="Data source: 'fail2ban' or 'archive'.",
|
||||
),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
@@ -235,7 +249,10 @@ async def get_bans_by_jail(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
|
||||
source: Literal["fail2ban", "archive"] = Query(
|
||||
default="fail2ban",
|
||||
description="Data source: 'fail2ban' or 'archive'.",
|
||||
),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
|
||||
@@ -290,6 +290,7 @@ async def bans_by_country(
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
country_code: str | None = None,
|
||||
) -> BansByCountryResponse:
|
||||
"""Aggregate ban counts per country for the selected time window.
|
||||
|
||||
@@ -350,16 +351,6 @@ async def bans_by_country(
|
||||
|
||||
total = len(all_rows)
|
||||
|
||||
# companion rows for the table should be most recent
|
||||
companion_rows, _ = await get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
page=1,
|
||||
page_size=_MAX_COMPANION_BANS,
|
||||
)
|
||||
|
||||
agg_rows = {}
|
||||
for row in all_rows:
|
||||
ip = str(row["ip"])
|
||||
@@ -393,14 +384,6 @@ async def bans_by_country(
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=_MAX_COMPANION_BANS,
|
||||
offset=0,
|
||||
)
|
||||
|
||||
unique_ips = [r.ip for r in agg_rows]
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
|
||||
@@ -434,6 +417,54 @@ async def bans_by_country(
|
||||
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
||||
geo_map = {ip: geo for ip, geo in results if geo is not None}
|
||||
|
||||
companion_rows: list[dict[str, object] | fail2ban_db_repo.BanRecord]
|
||||
if country_code is None:
|
||||
if source == "archive":
|
||||
companion_rows, _ = await get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
page=1,
|
||||
page_size=_MAX_COMPANION_BANS,
|
||||
)
|
||||
else:
|
||||
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
limit=_MAX_COMPANION_BANS,
|
||||
offset=0,
|
||||
)
|
||||
else:
|
||||
matched_ips = [
|
||||
ip
|
||||
for ip, geo in geo_map.items()
|
||||
if geo is not None and geo.country_code == country_code
|
||||
]
|
||||
|
||||
if source == "archive":
|
||||
if matched_ips:
|
||||
companion_rows = await get_all_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
action="ban",
|
||||
ip_filter=matched_ips,
|
||||
)
|
||||
else:
|
||||
companion_rows = []
|
||||
else:
|
||||
if matched_ips:
|
||||
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
origin=origin,
|
||||
ip_filter=matched_ips,
|
||||
)
|
||||
else:
|
||||
companion_rows = []
|
||||
|
||||
# Build country aggregation from the SQL-grouped rows.
|
||||
countries: dict[str, int] = {}
|
||||
country_names: dict[str, str] = {}
|
||||
|
||||
@@ -80,6 +80,32 @@ async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> No
|
||||
assert records[0].ip == "3.3.3.3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_currently_banned_filters_by_ip_list(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "1.1.1.1", 10, 1, "{}"),
|
||||
("jail1", "2.2.2.2", 20, 1, "{}"),
|
||||
("jail1", "3.3.3.3", 30, 1, "{}"),
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
records, total = await fail2ban_db_repo.get_currently_banned(
|
||||
db_path=db_path,
|
||||
since=0,
|
||||
ip_filter=["2.2.2.2", "3.3.3.3"],
|
||||
)
|
||||
|
||||
assert total == 2
|
||||
assert len(records) == 2
|
||||
assert {record.ip for record in records} == {"2.2.2.2", "3.3.3.3"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None:
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
|
||||
@@ -47,6 +47,10 @@ async def test_get_archived_history_filtering_and_pagination(app_db: str) -> Non
|
||||
assert total == 2
|
||||
assert len(rows) == 1
|
||||
|
||||
rows, total = await get_archived_history(db, ip_filter=["2.2.2.2"])
|
||||
assert total == 1
|
||||
assert rows[0]["ip"] == "2.2.2.2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_purge_archived_history(app_db: str) -> None:
|
||||
|
||||
@@ -522,6 +522,19 @@ class TestDashboardBansOriginField:
|
||||
|
||||
assert mock_fn.call_args[1]["source"] == "archive"
|
||||
|
||||
async def test_bans_by_country_country_code_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""The ``country_code`` query parameter is forwarded to bans_by_country."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch("app.routers.dashboard.ban_service.bans_by_country", new=mock_fn):
|
||||
await dashboard_client.get(
|
||||
"/api/dashboard/bans/by-country?country_code=DE"
|
||||
)
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("country_code") == "DE"
|
||||
|
||||
async def test_blocklist_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
|
||||
@@ -654,6 +654,54 @@ class TestOriginFilter:
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_country_code_returns_all_matched_rows(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
"""``bans_by_country`` returns all companion rows for the selected country."""
|
||||
path = str(tmp_path / "fail2ban_country_filter.sqlite3")
|
||||
rows = [
|
||||
{
|
||||
"jail": "sshd",
|
||||
"ip": "10.0.0.1",
|
||||
"timeofban": _ONE_HOUR_AGO - i,
|
||||
"bantime": 3600,
|
||||
"bancount": 1,
|
||||
"data": {"matches": ["failed login"]},
|
||||
}
|
||||
for i in range(205)
|
||||
]
|
||||
await _create_f2b_db(path, rows)
|
||||
|
||||
from app.services import geo_service
|
||||
|
||||
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(
|
||||
country_code="DE",
|
||||
country_name="Germany",
|
||||
asn=None,
|
||||
org=None,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
), patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task:
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
country_code="DE",
|
||||
http_session=AsyncMock(),
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
)
|
||||
|
||||
mock_create_task.assert_not_called()
|
||||
assert result.total == 205
|
||||
assert len(result.bans) == 205
|
||||
assert all(b.country_code == "DE" for b in result.bans)
|
||||
|
||||
geo_service.clear_cache()
|
||||
|
||||
async def test_bans_by_country_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
|
||||
Reference in New Issue
Block a user