Add country-specific companion table filtering for map page

This commit is contained in:
2026-04-05 22:12:06 +02:00
parent c03a5c1cbc
commit c51858ec71
13 changed files with 332 additions and 85 deletions

View File

@@ -126,6 +126,7 @@ async def get_currently_banned(
since: int,
origin: BanOrigin | None = None,
*,
ip_filter: list[str] | None = None,
limit: int | None = None,
offset: int | None = None,
) -> tuple[list[BanRecord], int]:
@@ -135,6 +136,7 @@ async def get_currently_banned(
db_path: File path to the fail2ban SQLite database.
since: Unix timestamp to filter bans newer than or equal to.
origin: Optional origin filter.
ip_filter: Optional list of IP addresses to restrict the result to.
limit: Optional maximum number of rows to return.
offset: Optional offset for pagination.
@@ -142,14 +144,21 @@ async def get_currently_banned(
A ``(records, total)`` tuple.
"""
if ip_filter is not None and len(ip_filter) == 0:
return [], 0
origin_clause, origin_params = _origin_sql_filter(origin)
ip_filter_clause = ""
if ip_filter is not None:
placeholder = ", ".join("?" for _ in ip_filter)
ip_filter_clause = f" AND ip IN ({placeholder})"
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
(since, *origin_params),
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause + ip_filter_clause,
(since, *origin_params, *(ip_filter or [])),
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
@@ -157,9 +166,9 @@ async def get_currently_banned(
query = (
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC"
"WHERE timeofban >= ?" + origin_clause + ip_filter_clause + " ORDER BY timeofban DESC"
)
params: list[object] = [since, *origin_params]
params: list[object] = [since, *origin_params, *(ip_filter or [])]
if limit is not None:
query += " LIMIT ?"
params.append(limit)

View File

@@ -40,13 +40,16 @@ async def get_archived_history(
db: aiosqlite.Connection,
since: int | None = None,
jail: str | None = None,
ip_filter: str | None = None,
ip_filter: str | list[str] | None = None,
origin: BanOrigin | None = None,
action: str | None = None,
page: int = 1,
page_size: int = 100,
) -> tuple[list[dict], int]:
"""Return a paginated archived history result set."""
if isinstance(ip_filter, list) and len(ip_filter) == 0:
return [], 0
wheres: list[str] = []
params: list[object] = []
@@ -59,8 +62,13 @@ async def get_archived_history(
params.append(jail)
if ip_filter is not None:
wheres.append("ip LIKE ?")
params.append(f"{ip_filter}%")
if isinstance(ip_filter, list):
placeholder = ", ".join("?" for _ in ip_filter)
wheres.append(f"ip IN ({placeholder})")
params.extend(ip_filter)
else:
wheres.append("ip LIKE ?")
params.append(f"{ip_filter}%")
if origin == "blocklist":
wheres.append("jail = ?")
@@ -108,7 +116,7 @@ async def get_all_archived_history(
db: aiosqlite.Connection,
since: int | None = None,
jail: str | None = None,
ip_filter: str | None = None,
ip_filter: str | list[str] | None = None,
origin: BanOrigin | None = None,
action: str | None = None,
) -> list[dict]:

View File

@@ -83,7 +83,10 @@ async def get_dashboard_bans(
request: Request,
_auth: AuthDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
description="Data source: 'fail2ban' or 'archive'.",
),
page: int = Query(default=1, ge=1, description="1-based page number."),
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."),
origin: BanOrigin | None = Query(
@@ -137,11 +140,18 @@ async def get_bans_by_country(
request: Request,
_auth: AuthDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
description="Data source: 'fail2ban' or 'archive'.",
),
origin: BanOrigin | None = Query(
default=None,
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
),
country_code: str | None = Query(
default=None,
description="ISO alpha-2 country code to filter companion rows.",
),
) -> BansByCountryResponse:
"""Return ban counts aggregated by ISO country code.
@@ -173,6 +183,7 @@ async def get_bans_by_country(
geo_batch_lookup=geo_service.lookup_batch,
app_db=request.app.state.db,
origin=origin,
country_code=country_code,
)
@@ -185,7 +196,10 @@ async def get_ban_trend(
request: Request,
_auth: AuthDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
description="Data source: 'fail2ban' or 'archive'.",
),
origin: BanOrigin | None = Query(
default=None,
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
@@ -235,7 +249,10 @@ async def get_bans_by_jail(
request: Request,
_auth: AuthDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
source: Literal["fail2ban", "archive"] = Query(default="fail2ban", description="Data source: 'fail2ban' or 'archive'."),
source: Literal["fail2ban", "archive"] = Query(
default="fail2ban",
description="Data source: 'fail2ban' or 'archive'.",
),
origin: BanOrigin | None = Query(
default=None,
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",

View File

@@ -290,6 +290,7 @@ async def bans_by_country(
geo_enricher: GeoEnricher | None = None,
app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None,
country_code: str | None = None,
) -> BansByCountryResponse:
"""Aggregate ban counts per country for the selected time window.
@@ -350,16 +351,6 @@ async def bans_by_country(
total = len(all_rows)
# companion rows for the table should be most recent
companion_rows, _ = await get_archived_history(
db=app_db,
since=since,
origin=origin,
action="ban",
page=1,
page_size=_MAX_COMPANION_BANS,
)
agg_rows = {}
for row in all_rows:
ip = str(row["ip"])
@@ -393,14 +384,6 @@ async def bans_by_country(
origin=origin,
)
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
db_path=db_path,
since=since,
origin=origin,
limit=_MAX_COMPANION_BANS,
offset=0,
)
unique_ips = [r.ip for r in agg_rows]
geo_map: dict[str, GeoInfo] = {}
@@ -434,6 +417,54 @@ async def bans_by_country(
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
geo_map = {ip: geo for ip, geo in results if geo is not None}
companion_rows: list[dict[str, object] | fail2ban_db_repo.BanRecord]
if country_code is None:
if source == "archive":
companion_rows, _ = await get_archived_history(
db=app_db,
since=since,
origin=origin,
action="ban",
page=1,
page_size=_MAX_COMPANION_BANS,
)
else:
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
db_path=db_path,
since=since,
origin=origin,
limit=_MAX_COMPANION_BANS,
offset=0,
)
else:
matched_ips = [
ip
for ip, geo in geo_map.items()
if geo is not None and geo.country_code == country_code
]
if source == "archive":
if matched_ips:
companion_rows = await get_all_archived_history(
db=app_db,
since=since,
origin=origin,
action="ban",
ip_filter=matched_ips,
)
else:
companion_rows = []
else:
if matched_ips:
companion_rows, _ = await fail2ban_db_repo.get_currently_banned(
db_path=db_path,
since=since,
origin=origin,
ip_filter=matched_ips,
)
else:
companion_rows = []
# Build country aggregation from the SQL-grouped rows.
countries: dict[str, int] = {}
country_names: dict[str, str] = {}