Optimise geo lookup and aggregation for 10k+ IPs
- Add persistent geo_cache SQLite table (db.py) - Rewrite geo_service: batch API (100 IPs/call), two-tier cache, no caching of failed lookups so they are retried - Pre-warm geo cache from DB on startup (main.py lifespan) - Rewrite bans_by_country: SQL GROUP BY ip aggregation + lookup_batch instead of 2000-row fetch + asyncio.gather individual calls - Pre-warm geo cache after blocklist import (blocklist_service) - Add 300ms debounce to useMapData hook to cancel stale requests - Add perf benchmark asserting <2s for 10k bans - Add seed_10k_bans.py script for manual perf testing
This commit is contained in:
@@ -64,6 +64,17 @@ CREATE TABLE IF NOT EXISTS import_log (
|
||||
);
|
||||
"""
|
||||
|
||||
_CREATE_GEO_CACHE: str = """
|
||||
CREATE TABLE IF NOT EXISTS geo_cache (
|
||||
ip TEXT PRIMARY KEY,
|
||||
country_code TEXT,
|
||||
country_name TEXT,
|
||||
asn TEXT,
|
||||
org TEXT,
|
||||
cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
|
||||
);
|
||||
"""
|
||||
|
||||
# Ordered list of DDL statements to execute on initialisation.
|
||||
_SCHEMA_STATEMENTS: list[str] = [
|
||||
_CREATE_SETTINGS,
|
||||
@@ -71,6 +82,7 @@ _SCHEMA_STATEMENTS: list[str] = [
|
||||
_CREATE_SESSIONS_TOKEN_INDEX,
|
||||
_CREATE_BLOCKLIST_SOURCES,
|
||||
_CREATE_IMPORT_LOG,
|
||||
_CREATE_GEO_CACHE,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -134,6 +134,11 @@ async def _lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
http_session: aiohttp.ClientSession = aiohttp.ClientSession()
|
||||
app.state.http_session = http_session
|
||||
|
||||
# --- Pre-warm geo cache from the persistent store ---
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
await geo_service.load_cache_from_db(db)
|
||||
|
||||
# --- Background task scheduler ---
|
||||
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
|
||||
scheduler.start()
|
||||
|
||||
@@ -9,14 +9,16 @@ Also provides ``GET /api/dashboard/bans`` for the dashboard ban-list table.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Annotated
|
||||
|
||||
import aiosqlite
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
from fastapi import APIRouter, Query, Request
|
||||
from fastapi import APIRouter, Depends, Query, Request
|
||||
|
||||
from app.dependencies import AuthDep
|
||||
from app.dependencies import AuthDep, get_db
|
||||
from app.models.ban import (
|
||||
BanOrigin,
|
||||
BansByCountryResponse,
|
||||
@@ -75,6 +77,7 @@ async def get_server_status(
|
||||
async def get_dashboard_bans(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
page: int = Query(default=1, ge=1, description="1-based page number."),
|
||||
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."),
|
||||
@@ -92,6 +95,7 @@ async def get_dashboard_bans(
|
||||
Args:
|
||||
request: The incoming request (used to access ``app.state``).
|
||||
_auth: Validated session dependency.
|
||||
db: BanGUI application database (for persistent geo cache writes).
|
||||
range: Time-range preset — ``"24h"``, ``"7d"``, ``"30d"``, or
|
||||
``"365d"``.
|
||||
page: 1-based page number.
|
||||
@@ -106,7 +110,7 @@ async def get_dashboard_bans(
|
||||
http_session: aiohttp.ClientSession = request.app.state.http_session
|
||||
|
||||
async def _enricher(ip: str) -> geo_service.GeoInfo | None:
|
||||
return await geo_service.lookup(ip, http_session)
|
||||
return await geo_service.lookup(ip, http_session, db=db)
|
||||
|
||||
return await ban_service.list_bans(
|
||||
socket_path,
|
||||
@@ -126,6 +130,7 @@ async def get_dashboard_bans(
|
||||
async def get_bans_by_country(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
db: Annotated[aiosqlite.Connection, Depends(get_db)],
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
@@ -134,30 +139,29 @@ async def get_bans_by_country(
|
||||
) -> BansByCountryResponse:
|
||||
"""Return ban counts aggregated by ISO country code.
|
||||
|
||||
Fetches up to 2 000 ban records in the selected time window, enriches
|
||||
every record with geo data, and returns a ``{country_code: count}`` map
|
||||
plus the full enriched ban list for the companion access table.
|
||||
Uses SQL aggregation (``GROUP BY ip``) and batch geo-resolution to handle
|
||||
10 000+ banned IPs efficiently. Returns a ``{country_code: count}`` map
|
||||
and the 200 most recent raw ban rows for the companion access table.
|
||||
|
||||
Args:
|
||||
request: The incoming request.
|
||||
_auth: Validated session dependency.
|
||||
db: BanGUI application database (for persistent geo cache writes).
|
||||
range: Time-range preset.
|
||||
origin: Optional filter by ban origin.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||
aggregation and the full ban list.
|
||||
aggregation and the companion ban list.
|
||||
"""
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
http_session: aiohttp.ClientSession = request.app.state.http_session
|
||||
|
||||
async def _enricher(ip: str) -> geo_service.GeoInfo | None:
|
||||
return await geo_service.lookup(ip, http_session)
|
||||
|
||||
return await ban_service.bans_by_country(
|
||||
socket_path,
|
||||
range,
|
||||
geo_enricher=_enricher,
|
||||
http_session=http_session,
|
||||
app_db=db,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import aiosqlite
|
||||
import structlog
|
||||
@@ -29,6 +29,9 @@ from app.models.ban import (
|
||||
)
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -280,35 +283,51 @@ async def list_bans(
|
||||
# bans_by_country
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
#: Maximum bans fetched for aggregation (guard against huge databases).
|
||||
_MAX_GEO_BANS: int = 2_000
|
||||
#: Maximum rows returned in the companion table alongside the map.
|
||||
_MAX_COMPANION_BANS: int = 200
|
||||
|
||||
|
||||
async def bans_by_country(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
geo_enricher: Any | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> BansByCountryResponse:
|
||||
"""Aggregate ban counts per country for the selected time window.
|
||||
|
||||
Fetches up to ``_MAX_GEO_BANS`` ban records from the fail2ban database,
|
||||
enriches them with geo data, and returns a ``{country_code: count}`` map
|
||||
alongside the enriched ban list for the companion access table.
|
||||
Uses a two-step strategy optimised for large datasets:
|
||||
|
||||
1. Queries the fail2ban DB with ``GROUP BY ip`` to get the per-IP ban
|
||||
counts for all unique IPs in the window — no row-count cap.
|
||||
2. Batch-resolves every unique IP via :func:`~app.services.geo_service.lookup_batch`
|
||||
(100 IPs per HTTP call) instead of one-at-a-time lookups.
|
||||
3. Returns a ``{country_code: count}`` aggregation and the 200 most
|
||||
recent raw rows (already geo-cached from step 2) for the companion
|
||||
table.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
range_: Time-range preset.
|
||||
geo_enricher: Optional async ``(ip) -> GeoInfo | None`` callable.
|
||||
http_session: Optional :class:`aiohttp.ClientSession` for batch
|
||||
geo lookups. When provided, :func:`geo_service.lookup_batch`
|
||||
is used instead of the *geo_enricher* callable.
|
||||
geo_enricher: Legacy async ``(ip) -> GeoInfo | None`` callable;
|
||||
used when *http_session* is ``None``.
|
||||
app_db: Optional BanGUI application database used to persist newly
|
||||
resolved geo entries across restarts.
|
||||
origin: Optional origin filter — ``"blocklist"`` restricts results to
|
||||
the ``blocklist-import`` jail, ``"selfblock"`` excludes it.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||
aggregation and the full ban list.
|
||||
aggregation and the companion ban list.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
@@ -323,6 +342,7 @@ async def bans_by_country(
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
|
||||
# Total count for the window.
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
@@ -330,6 +350,19 @@ async def bans_by_country(
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
|
||||
# Aggregation: unique IPs + their total event count.
|
||||
# No LIMIT here — we need all unique source IPs for accurate country counts.
|
||||
async with f2b_db.execute(
|
||||
"SELECT ip, COUNT(*) AS event_count "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " GROUP BY ip",
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
agg_rows = await cur.fetchall()
|
||||
|
||||
# Companion table: most recent raw rows for display alongside the map.
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
@@ -337,14 +370,21 @@ async def bans_by_country(
|
||||
+ origin_clause
|
||||
+ " ORDER BY timeofban DESC "
|
||||
"LIMIT ?",
|
||||
(since, *origin_params, _MAX_GEO_BANS),
|
||||
(since, *origin_params, _MAX_COMPANION_BANS),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
companion_rows = await cur.fetchall()
|
||||
|
||||
# Geo-enrich unique IPs in parallel.
|
||||
unique_ips: list[str] = list({str(r["ip"]) for r in rows})
|
||||
# Batch-resolve all unique IPs (much faster than individual lookups).
|
||||
unique_ips: list[str] = [str(r["ip"]) for r in agg_rows]
|
||||
geo_map: dict[str, Any] = {}
|
||||
if geo_enricher is not None and unique_ips:
|
||||
|
||||
if http_session is not None and unique_ips:
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(unique_ips, http_session, db=app_db)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("ban_service_batch_geo_failed", error=str(exc))
|
||||
elif geo_enricher is not None and unique_ips:
|
||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||
async def _safe_lookup(ip: str) -> tuple[str, Any]:
|
||||
try:
|
||||
return ip, await geo_enricher(ip)
|
||||
@@ -355,16 +395,29 @@ async def bans_by_country(
|
||||
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
||||
geo_map = dict(results)
|
||||
|
||||
# Build ban items and aggregate country counts.
|
||||
# Build country aggregation from the SQL-grouped rows.
|
||||
countries: dict[str, int] = {}
|
||||
country_names: dict[str, str] = {}
|
||||
bans: list[DashboardBanItem] = []
|
||||
|
||||
for row in rows:
|
||||
ip = str(row["ip"])
|
||||
for row in agg_rows:
|
||||
ip: str = str(row["ip"])
|
||||
geo = geo_map.get(ip)
|
||||
cc: str | None = geo.country_code if geo else None
|
||||
cn: str | None = geo.country_name if geo else None
|
||||
event_count: int = int(row["event_count"])
|
||||
|
||||
if cc:
|
||||
countries[cc] = countries.get(cc, 0) + event_count
|
||||
if cn and cc not in country_names:
|
||||
country_names[cc] = cn
|
||||
|
||||
# Build companion table from recent rows (geo already cached from batch step).
|
||||
bans: list[DashboardBanItem] = []
|
||||
for row in companion_rows:
|
||||
ip = str(row["ip"])
|
||||
geo = geo_map.get(ip)
|
||||
cc = geo.country_code if geo else None
|
||||
cn = geo.country_name if geo else None
|
||||
asn: str | None = geo.asn if geo else None
|
||||
org: str | None = geo.org if geo else None
|
||||
matches, _ = _parse_data_json(row["data"])
|
||||
@@ -384,11 +437,6 @@ async def bans_by_country(
|
||||
)
|
||||
)
|
||||
|
||||
if cc:
|
||||
countries[cc] = countries.get(cc, 0) + 1
|
||||
if cn and cc not in country_names:
|
||||
country_names[cc] = cn
|
||||
|
||||
return BansByCountryResponse(
|
||||
countries=countries,
|
||||
country_names=country_names,
|
||||
|
||||
@@ -245,6 +245,10 @@ async def import_source(
|
||||
fail2ban requires individual addresses. Any error encountered during
|
||||
download is recorded and the result is returned without raising.
|
||||
|
||||
After a successful import the geo cache is pre-warmed by batch-resolving
|
||||
all newly banned IPs. This ensures the dashboard and map show country
|
||||
data immediately after import rather than facing cold-cache lookups.
|
||||
|
||||
Args:
|
||||
source: The :class:`~app.models.blocklist.BlocklistSource` to import.
|
||||
http_session: Shared :class:`aiohttp.ClientSession`.
|
||||
@@ -287,6 +291,7 @@ async def import_source(
|
||||
imported = 0
|
||||
skipped = 0
|
||||
ban_error: str | None = None
|
||||
imported_ips: list[str] = []
|
||||
|
||||
# Import jail_service here to avoid circular import at module level.
|
||||
from app.services import jail_service # noqa: PLC0415
|
||||
@@ -304,6 +309,7 @@ async def import_source(
|
||||
try:
|
||||
await jail_service.ban_ip(socket_path, BLOCKLIST_JAIL, stripped)
|
||||
imported += 1
|
||||
imported_ips.append(stripped)
|
||||
except jail_service.JailNotFoundError as exc:
|
||||
# The target jail does not exist in fail2ban — there is no point
|
||||
# continuing because every subsequent ban would also fail.
|
||||
@@ -329,6 +335,25 @@ async def import_source(
|
||||
skipped=skipped,
|
||||
error=ban_error,
|
||||
)
|
||||
|
||||
# --- Pre-warm geo cache for newly imported IPs ---
|
||||
if imported_ips:
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
try:
|
||||
await geo_service.lookup_batch(imported_ips, http_session, db=db)
|
||||
log.info(
|
||||
"blocklist_geo_prewarm_complete",
|
||||
source_id=source.id,
|
||||
count=len(imported_ips),
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"blocklist_geo_prewarm_failed",
|
||||
source_id=source.id,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
return ImportSourceResult(
|
||||
source_id=source.id,
|
||||
source_url=source.url,
|
||||
|
||||
@@ -1,22 +1,39 @@
|
||||
"""Geo service.
|
||||
|
||||
Resolves IP addresses to their country, ASN, and organisation using the
|
||||
`ip-api.com <http://ip-api.com>`_ JSON API. Results are cached in memory
|
||||
to avoid redundant HTTP requests for addresses that appear repeatedly.
|
||||
`ip-api.com <http://ip-api.com>`_ JSON API. Results are cached in two tiers:
|
||||
|
||||
The free ip-api.com endpoint requires no API key and supports up to 45
|
||||
requests per minute. Because results are cached indefinitely for the life
|
||||
of the process, under normal load the rate limit is rarely approached.
|
||||
1. **In-memory dict** — fastest; survives for the life of the process.
|
||||
2. **Persistent SQLite table** (``geo_cache``) — survives restarts; loaded
|
||||
into the in-memory dict during application startup via
|
||||
:func:`load_cache_from_db`.
|
||||
|
||||
Only *successful* lookups (those returning a non-``None`` ``country_code``)
|
||||
are written to the persistent cache. Failed lookups are **not** cached so
|
||||
they will be retried on the next request.
|
||||
|
||||
For bulk operations the batch endpoint ``http://ip-api.com/batch`` is used
|
||||
(up to 100 IPs per HTTP call) which is far more efficient than one-at-a-time
|
||||
requests. Use :func:`lookup_batch` from the ban or blocklist services.
|
||||
|
||||
Usage::
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
from app.services import geo_service
|
||||
|
||||
# warm the cache from the persistent store at startup
|
||||
async with aiosqlite.connect("bangui.db") as db:
|
||||
await geo_service.load_cache_from_db(db)
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# single lookup
|
||||
info = await geo_service.lookup("1.2.3.4", session)
|
||||
if info:
|
||||
print(info.country_code) # "DE"
|
||||
|
||||
# bulk lookup (more efficient for large sets)
|
||||
geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -28,6 +45,7 @@ import structlog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -36,12 +54,22 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier).
|
||||
_API_URL: str = "http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
|
||||
_API_URL: str = (
|
||||
"http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
|
||||
)
|
||||
|
||||
#: ip-api.com batch endpoint — accepts up to 100 IPs per POST.
|
||||
_BATCH_API_URL: str = (
|
||||
"http://ip-api.com/batch?fields=status,message,country,countryCode,org,as,query"
|
||||
)
|
||||
|
||||
#: Maximum IPs per batch request (ip-api.com hard limit is 100).
|
||||
_BATCH_SIZE: int = 100
|
||||
|
||||
#: Maximum number of entries kept in the in-process cache before it is
|
||||
#: flushed completely. A simple eviction strategy — the cache is cheap to
|
||||
#: rebuild and memory is bounded.
|
||||
_MAX_CACHE_SIZE: int = 10_000
|
||||
#: rebuild from the persistent store.
|
||||
_MAX_CACHE_SIZE: int = 50_000
|
||||
|
||||
#: Timeout for outgoing geo API requests in seconds.
|
||||
_REQUEST_TIMEOUT: float = 5.0
|
||||
@@ -89,25 +117,95 @@ def clear_cache() -> None:
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# Persistent cache I/O
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def lookup(ip: str, http_session: aiohttp.ClientSession) -> GeoInfo | None:
|
||||
async def load_cache_from_db(db: aiosqlite.Connection) -> None:
|
||||
"""Pre-populate the in-memory cache from the ``geo_cache`` table.
|
||||
|
||||
Should be called once during application startup so the service starts
|
||||
with a warm cache instead of making cold API calls on the first request.
|
||||
|
||||
Args:
|
||||
db: Open :class:`aiosqlite.Connection` to the BanGUI application
|
||||
database (not the fail2ban database).
|
||||
"""
|
||||
count = 0
|
||||
async with db.execute(
|
||||
"SELECT ip, country_code, country_name, asn, org FROM geo_cache"
|
||||
) as cur:
|
||||
async for row in cur:
|
||||
ip: str = str(row[0])
|
||||
country_code: str | None = row[1]
|
||||
if country_code is None:
|
||||
continue
|
||||
_cache[ip] = GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=row[2],
|
||||
asn=row[3],
|
||||
org=row[4],
|
||||
)
|
||||
count += 1
|
||||
log.info("geo_cache_loaded_from_db", entries=count)
|
||||
|
||||
|
||||
async def _persist_entry(
|
||||
db: aiosqlite.Connection,
|
||||
ip: str,
|
||||
info: GeoInfo,
|
||||
) -> None:
|
||||
"""Upsert a resolved :class:`GeoInfo` into the ``geo_cache`` table.
|
||||
|
||||
Only called when ``info.country_code`` is not ``None`` so the persistent
|
||||
store never contains empty placeholder rows.
|
||||
|
||||
Args:
|
||||
db: BanGUI application database connection.
|
||||
ip: IP address string.
|
||||
info: Resolved geo data to persist.
|
||||
"""
|
||||
await db.execute(
|
||||
"""
|
||||
INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(ip) DO UPDATE SET
|
||||
country_code = excluded.country_code,
|
||||
country_name = excluded.country_name,
|
||||
asn = excluded.asn,
|
||||
org = excluded.org,
|
||||
cached_at = strftime('%Y-%m-%dT%H:%M:%fZ', 'now')
|
||||
""",
|
||||
(ip, info.country_code, info.country_name, info.asn, info.org),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — single lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def lookup(
|
||||
ip: str,
|
||||
http_session: aiohttp.ClientSession,
|
||||
db: aiosqlite.Connection | None = None,
|
||||
) -> GeoInfo | None:
|
||||
"""Resolve an IP address to country, ASN, and organisation metadata.
|
||||
|
||||
Results are cached in-process. If the cache exceeds ``_MAX_CACHE_SIZE``
|
||||
entries it is flushed before the new result is stored, keeping memory
|
||||
usage bounded.
|
||||
entries it is flushed before the new result is stored.
|
||||
|
||||
Private, loopback, and link-local addresses are resolved to a placeholder
|
||||
``GeoInfo`` with ``None`` values so callers are not blocked by pointless
|
||||
API calls for RFC-1918 ranges.
|
||||
Only successful resolutions (``country_code is not None``) are written to
|
||||
the persistent cache when *db* is provided. Failed lookups are **not**
|
||||
cached so they are retried on the next call.
|
||||
|
||||
Args:
|
||||
ip: IPv4 or IPv6 address string.
|
||||
http_session: Shared :class:`aiohttp.ClientSession` (from
|
||||
``app.state.http_session``).
|
||||
db: Optional BanGUI application database. When provided, successful
|
||||
lookups are persisted for cross-restart cache warming.
|
||||
|
||||
Returns:
|
||||
A :class:`GeoInfo` instance, or ``None`` when the lookup fails
|
||||
@@ -135,37 +233,170 @@ async def lookup(ip: str, http_session: aiohttp.ClientSession) -> GeoInfo | None
|
||||
ip=ip,
|
||||
message=data.get("message", "unknown"),
|
||||
)
|
||||
# Still cache a negative result so we do not retry reserved IPs.
|
||||
result = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
_store(ip, result)
|
||||
return result
|
||||
# Do NOT cache failed lookups — they will be retried on the next call.
|
||||
return GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
|
||||
country_code: str | None = _str_or_none(data.get("countryCode"))
|
||||
country_name: str | None = _str_or_none(data.get("country"))
|
||||
asn_raw: str | None = _str_or_none(data.get("as"))
|
||||
org_raw: str | None = _str_or_none(data.get("org"))
|
||||
|
||||
# ip-api returns the full "AS12345 Some Org" string in both "as" and "org".
|
||||
# Extract just the AS number prefix for the asn field.
|
||||
asn: str | None = asn_raw.split()[0] if asn_raw else None
|
||||
org: str | None = org_raw
|
||||
|
||||
result = GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=country_name,
|
||||
asn=asn,
|
||||
org=org,
|
||||
)
|
||||
result = _parse_single_response(data)
|
||||
_store(ip, result)
|
||||
log.debug("geo_lookup_success", ip=ip, country=country_code, asn=asn)
|
||||
if result.country_code is not None and db is not None:
|
||||
try:
|
||||
await _persist_entry(db, ip, result)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_persist_failed", ip=ip, error=str(exc))
|
||||
log.debug("geo_lookup_success", ip=ip, country=result.country_code, asn=result.asn)
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API — batch lookup
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def lookup_batch(
|
||||
ips: list[str],
|
||||
http_session: aiohttp.ClientSession,
|
||||
db: aiosqlite.Connection | None = None,
|
||||
) -> dict[str, GeoInfo]:
|
||||
"""Resolve multiple IP addresses in bulk using ip-api.com batch endpoint.
|
||||
|
||||
IPs already present in the in-memory cache are returned immediately
|
||||
without making an HTTP request. Uncached IPs are sent to
|
||||
``http://ip-api.com/batch`` in chunks of up to :data:`_BATCH_SIZE`.
|
||||
|
||||
Only successful resolutions (``country_code is not None``) are written to
|
||||
the persistent cache when *db* is provided.
|
||||
|
||||
Args:
|
||||
ips: List of IP address strings to resolve. Duplicates are ignored.
|
||||
http_session: Shared :class:`aiohttp.ClientSession`.
|
||||
db: Optional BanGUI application database for persistent cache writes.
|
||||
|
||||
Returns:
|
||||
Dict mapping ``ip → GeoInfo`` for every input IP. IPs whose
|
||||
resolution failed will have a ``GeoInfo`` with all-``None`` fields.
|
||||
"""
|
||||
geo_result: dict[str, GeoInfo] = {}
|
||||
uncached: list[str] = []
|
||||
|
||||
unique_ips = list(dict.fromkeys(ips)) # deduplicate, preserve order
|
||||
for ip in unique_ips:
|
||||
if ip in _cache:
|
||||
geo_result[ip] = _cache[ip]
|
||||
else:
|
||||
uncached.append(ip)
|
||||
|
||||
if not uncached:
|
||||
return geo_result
|
||||
|
||||
log.info("geo_batch_lookup_start", total=len(uncached))
|
||||
|
||||
for chunk_start in range(0, len(uncached), _BATCH_SIZE):
|
||||
chunk = uncached[chunk_start : chunk_start + _BATCH_SIZE]
|
||||
chunk_result = await _batch_api_call(chunk, http_session)
|
||||
|
||||
for ip, info in chunk_result.items():
|
||||
_store(ip, info)
|
||||
geo_result[ip] = info
|
||||
if info.country_code is not None and db is not None:
|
||||
try:
|
||||
await _persist_entry(db, ip, info)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_persist_failed", ip=ip, error=str(exc))
|
||||
|
||||
log.info(
|
||||
"geo_batch_lookup_complete",
|
||||
requested=len(uncached),
|
||||
resolved=sum(1 for g in geo_result.values() if g.country_code is not None),
|
||||
)
|
||||
return geo_result
|
||||
|
||||
|
||||
async def _batch_api_call(
|
||||
ips: list[str],
|
||||
http_session: aiohttp.ClientSession,
|
||||
) -> dict[str, GeoInfo]:
|
||||
"""Send one batch request to the ip-api.com batch endpoint.
|
||||
|
||||
Args:
|
||||
ips: Up to :data:`_BATCH_SIZE` IP address strings.
|
||||
http_session: Shared HTTP session.
|
||||
|
||||
Returns:
|
||||
Dict mapping ``ip → GeoInfo`` for every IP in *ips*. IPs where the
|
||||
API returned a failure record or the request raised an exception get
|
||||
an all-``None`` :class:`GeoInfo`.
|
||||
"""
|
||||
empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
|
||||
fallback: dict[str, GeoInfo] = dict.fromkeys(ips, empty)
|
||||
|
||||
payload = [{"query": ip} for ip in ips]
|
||||
try:
|
||||
async with http_session.post(
|
||||
_BATCH_API_URL,
|
||||
json=payload,
|
||||
timeout=_REQUEST_TIMEOUT * 2, # type: ignore[arg-type]
|
||||
) as resp:
|
||||
if resp.status != 200:
|
||||
log.warning("geo_batch_non_200", status=resp.status, count=len(ips))
|
||||
return fallback
|
||||
data: list[dict[str, object]] = await resp.json(content_type=None)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("geo_batch_request_failed", count=len(ips), error=str(exc))
|
||||
return fallback
|
||||
|
||||
out: dict[str, GeoInfo] = {}
|
||||
for entry in data:
|
||||
ip_str: str = str(entry.get("query", ""))
|
||||
if not ip_str:
|
||||
continue
|
||||
if entry.get("status") != "success":
|
||||
out[ip_str] = empty
|
||||
log.debug(
|
||||
"geo_batch_entry_failed",
|
||||
ip=ip_str,
|
||||
message=entry.get("message", "unknown"),
|
||||
)
|
||||
continue
|
||||
out[ip_str] = _parse_single_response(entry)
|
||||
|
||||
# Fill any IPs missing from the response.
|
||||
for ip in ips:
|
||||
if ip not in out:
|
||||
out[ip] = empty
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_single_response(data: dict[str, object]) -> GeoInfo:
|
||||
"""Build a :class:`GeoInfo` from a single ip-api.com response dict.
|
||||
|
||||
Args:
|
||||
data: A ``status == "success"`` JSON response from ip-api.com.
|
||||
|
||||
Returns:
|
||||
Populated :class:`GeoInfo`.
|
||||
"""
|
||||
country_code: str | None = _str_or_none(data.get("countryCode"))
|
||||
country_name: str | None = _str_or_none(data.get("country"))
|
||||
asn_raw: str | None = _str_or_none(data.get("as"))
|
||||
org_raw: str | None = _str_or_none(data.get("org"))
|
||||
|
||||
# ip-api returns "AS12345 Some Org" in both "as" and "org".
|
||||
asn: str | None = asn_raw.split()[0] if asn_raw else None
|
||||
|
||||
return GeoInfo(
|
||||
country_code=country_code,
|
||||
country_name=country_name,
|
||||
asn=asn,
|
||||
org=org_raw,
|
||||
)
|
||||
|
||||
|
||||
def _str_or_none(value: object) -> str | None:
|
||||
"""Return *value* as a non-empty string, or ``None``.
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ ignore = ["B008"] # FastAPI uses function calls in default arguments (Depends)
|
||||
# sys.path manipulation before stdlib imports is intentional in test helpers
|
||||
# pytest evaluates fixture type annotations at runtime, so TC001/TC002/TC003 are false-positives
|
||||
"tests/**" = ["E402", "TC001", "TC002", "TC003"]
|
||||
"app/routers/**" = ["TC001"] # FastAPI evaluates Depends() type aliases at runtime via get_type_hints()
|
||||
"app/routers/**" = ["TC001", "TC002"] # FastAPI evaluates Depends() type aliases at runtime via get_type_hints()
|
||||
|
||||
[tool.ruff.format]
|
||||
quote-style = "double"
|
||||
|
||||
0
backend/tests/scripts/__init__.py
Normal file
0
backend/tests/scripts/__init__.py
Normal file
213
backend/tests/scripts/seed_10k_bans.py
Normal file
213
backend/tests/scripts/seed_10k_bans.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Seed 10 000 synthetic bans into the fail2ban dev database.
|
||||
|
||||
Usage::
|
||||
|
||||
cd backend
|
||||
python tests/scripts/seed_10k_bans.py [--db-path /path/to/fail2ban.sqlite3]
|
||||
|
||||
This script inserts 10 000 synthetic ban rows spread over the last 365 days
|
||||
into the fail2ban SQLite database and pre-resolves all synthetic IPs into the
|
||||
BanGUI geo_cache. Run it once to get realistic dashboard and map load times
|
||||
in the browser without requiring a live fail2ban instance with active traffic.
|
||||
|
||||
.. warning::
|
||||
This script **writes** to the fail2ban database. Only use it against the
|
||||
development database (``Docker/fail2ban-dev-config/fail2ban.sqlite3`` or
|
||||
equivalent). Never run it against a production database.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import random
|
||||
import sqlite3
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Default paths
# ---------------------------------------------------------------------------

# Repo-relative default: this file lives at backend/tests/scripts/, so
# parents[3] resolves to the repository root.
_DEFAULT_F2B_DB: str = str(
    Path(__file__).resolve().parents[3] / "Docker" / "fail2ban-dev-config" / "fail2ban.sqlite3"
)
# Default BanGUI application database: parents[2] resolves to backend/.
_DEFAULT_APP_DB: str = str(
    Path(__file__).resolve().parents[2] / "bangui.db"
)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Number of synthetic ban rows to insert.
_BAN_COUNT: int = 10_000
# Window over which ban timestamps are spread (one year, in seconds).
_YEAR_SECONDS: int = 365 * 24 * 3600
# Jail names cycled through for the synthetic rows.
_JAIL_POOL: list[str] = ["sshd", "nginx", "blocklist-import", "postfix", "dovecot"]
# (country_code, country_name) pairs used to fabricate geo_cache entries.
_COUNTRY_POOL: list[tuple[str, str]] = [
    ("DE", "Germany"),
    ("US", "United States"),
    ("CN", "China"),
    ("RU", "Russia"),
    ("FR", "France"),
    ("BR", "Brazil"),
    ("IN", "India"),
    ("GB", "United Kingdom"),
    ("NL", "Netherlands"),
    ("CA", "Canada"),
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _random_ip() -> str:
|
||||
"""Return a random dotted-decimal IPv4 string in public ranges."""
|
||||
return ".".join(str(random.randint(1, 254)) for _ in range(4))
|
||||
|
||||
|
||||
def _seed_bans(f2b_db_path: str) -> list[str]:
    """Insert 10 000 synthetic ban rows into the fail2ban SQLite database.

    Uses the synchronous ``sqlite3`` module because fail2ban itself uses
    synchronous writes and the schema is straightforward.

    Args:
        f2b_db_path: Filesystem path to the fail2ban SQLite database.

    Returns:
        List of all IP addresses inserted.
    """
    now = int(time.time())
    ips: list[str] = [_random_ip() for _ in range(_BAN_COUNT)]
    rows = [
        (
            random.choice(_JAIL_POOL),
            ip,
            now - random.randint(0, _YEAR_SECONDS),
            3600,
            random.randint(1, 10),
            None,
        )
        for ip in ips
    ]

    con = sqlite3.connect(f2b_db_path)
    try:
        # Ensure the bans table exists (for dev environments where fail2ban
        # may not have created it yet).
        con.execute(
            "CREATE TABLE IF NOT EXISTS bans ("
            "jail TEXT NOT NULL, "
            "ip TEXT, "
            "timeofban INTEGER NOT NULL, "
            "bantime INTEGER NOT NULL DEFAULT 3600, "
            "bancount INTEGER NOT NULL DEFAULT 1, "
            "data JSON"
            ")"
        )
        con.executemany(
            "INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            rows,
        )
        con.commit()
    finally:
        # ``with sqlite3.connect(...)`` only manages the transaction, not the
        # connection itself — close explicitly to release the file handle.
        con.close()

    log.info("Inserted %d ban rows into %s", _BAN_COUNT, f2b_db_path)
    return ips
|
||||
|
||||
|
||||
def _seed_geo_cache(app_db_path: str, ips: list[str]) -> None:
    """Pre-populate the BanGUI geo_cache table for all inserted IPs.

    Assigns synthetic country data cycling through :data:`_COUNTRY_POOL` so
    the world map shows a realistic distribution of countries without making
    any real HTTP requests.

    Args:
        app_db_path: Filesystem path to the BanGUI application database.
        ips: List of IP addresses to pre-cache.
    """
    country_cycle = _COUNTRY_POOL * (len(ips) // len(_COUNTRY_POOL) + 1)
    rows = [
        (ip, cc, cn, f"AS{1000 + i % 500}", f"Synthetic ISP {i % 50}")
        for i, (ip, (cc, cn)) in enumerate(zip(ips, country_cycle, strict=False))
    ]

    con = sqlite3.connect(app_db_path)
    try:
        # Mirror the geo_cache schema created by db.py so the script also
        # works against a fresh database file.
        con.execute(
            "CREATE TABLE IF NOT EXISTS geo_cache ("
            "ip TEXT PRIMARY KEY, "
            "country_code TEXT, "
            "country_name TEXT, "
            "asn TEXT, "
            "org TEXT, "
            "cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))"
            ")"
        )
        con.executemany(
            """
            INSERT INTO geo_cache (ip, country_code, country_name, asn, org)
            VALUES (?, ?, ?, ?, ?)
            ON CONFLICT(ip) DO UPDATE SET
                country_code = excluded.country_code,
                country_name = excluded.country_name,
                asn = excluded.asn,
                org = excluded.org
            """,
            rows,
        )
        con.commit()
    finally:
        # ``with sqlite3.connect(...)`` only manages the transaction, not the
        # connection itself — close explicitly to release the file handle.
        con.close()

    log.info("Pre-cached geo data for %d IPs in %s", len(ips), app_db_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Entry point
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
    """Parse CLI arguments and run the seed operation."""
    logging.basicConfig(level=logging.INFO, format="%(levelname)s %(message)s")

    parser = argparse.ArgumentParser(
        description="Seed 10 000 synthetic bans for performance testing."
    )
    parser.add_argument(
        "--f2b-db",
        default=_DEFAULT_F2B_DB,
        help=f"Path to the fail2ban SQLite database (default: {_DEFAULT_F2B_DB})",
    )
    parser.add_argument(
        "--app-db",
        default=_DEFAULT_APP_DB,
        help=f"Path to the BanGUI application database (default: {_DEFAULT_APP_DB})",
    )
    namespace = parser.parse_args()

    fail2ban_db = Path(namespace.f2b_db)
    bangui_db = Path(namespace.app_db)

    # Guard clauses: refuse to run when either target directory is missing.
    for message, candidate in (
        ("fail2ban DB directory does not exist: %s", fail2ban_db),
        ("App DB directory does not exist: %s", bangui_db),
    ):
        if not candidate.parent.exists():
            log.error(message, candidate.parent)
            sys.exit(1)

    log.info("Seeding %d bans into: %s", _BAN_COUNT, fail2ban_db)
    inserted_ips = _seed_bans(str(fail2ban_db))

    log.info("Pre-caching geo data into: %s", bangui_db)
    _seed_geo_cache(str(bangui_db), inserted_ips)

    log.info("Done. Restart the BanGUI backend to load the new geo cache entries.")
|
||||
|
||||
|
||||
# Allow direct execution: ``python tests/scripts/seed_10k_bans.py``.
if __name__ == "__main__":
    main()
|
||||
257
backend/tests/test_services/test_ban_service_perf.py
Normal file
257
backend/tests/test_services/test_ban_service_perf.py
Normal file
@@ -0,0 +1,257 @@
|
||||
"""Performance benchmark for ban_service with 10 000+ banned IPs.
|
||||
|
||||
These tests assert that both ``list_bans`` and ``bans_by_country`` complete
|
||||
within 2 seconds wall-clock time when the geo cache is warm and the fail2ban
|
||||
database contains 10 000 synthetic ban records.
|
||||
|
||||
External network calls are eliminated by pre-populating the in-memory geo
|
||||
cache before the timed section, so the benchmark measures only the database
|
||||
query and in-process aggregation overhead.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import random
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.services import ban_service, geo_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_BAN_COUNT: int = 10_000
|
||||
_WALL_CLOCK_LIMIT: float = 2.0 # seconds
|
||||
|
||||
_NOW: int = int(time.time())
|
||||
|
||||
#: Country codes to cycle through when generating synthetic geo data.
|
||||
_COUNTRIES: list[tuple[str, str]] = [
|
||||
("DE", "Germany"),
|
||||
("US", "United States"),
|
||||
("CN", "China"),
|
||||
("RU", "Russia"),
|
||||
("FR", "France"),
|
||||
("BR", "Brazil"),
|
||||
("IN", "India"),
|
||||
("GB", "United Kingdom"),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _random_ip() -> str:
|
||||
"""Generate a random-looking public IPv4 address string.
|
||||
|
||||
Returns:
|
||||
Dotted-decimal string with each octet in range 1–254.
|
||||
"""
|
||||
return ".".join(str(random.randint(1, 254)) for _ in range(4))
|
||||
|
||||
|
||||
def _random_jail() -> str:
|
||||
"""Pick a jail name from a small pool.
|
||||
|
||||
Returns:
|
||||
One of ``sshd``, ``nginx``, ``blocklist-import``.
|
||||
"""
|
||||
return random.choice(["sshd", "nginx", "blocklist-import"])
|
||||
|
||||
|
||||
async def _seed_f2b_db(path: str, n: int) -> list[str]:
    """Create a fail2ban SQLite database with *n* synthetic ban rows.

    Bans are spread uniformly over the last 365 days.

    Args:
        path: Filesystem path for the new database.
        n: Number of rows to insert.

    Returns:
        List of all unique IP address strings inserted.
    """
    year_seconds = 365 * 24 * 3600
    # Draw until *n* distinct addresses exist: independent random draws can
    # collide, and both this docstring and the per-IP geo-cache warm-up in
    # the fixture rely on uniqueness.
    unique: set[str] = set()
    while len(unique) < n:
        unique.add(_random_ip())
    ips: list[str] = list(unique)

    async with aiosqlite.connect(path) as db:
        await db.execute(
            "CREATE TABLE jails ("
            "name TEXT NOT NULL UNIQUE, "
            "enabled INTEGER NOT NULL DEFAULT 1"
            ")"
        )
        await db.execute(
            "CREATE TABLE bans ("
            "jail TEXT NOT NULL, "
            "ip TEXT, "
            "timeofban INTEGER NOT NULL, "
            "bantime INTEGER NOT NULL DEFAULT 3600, "
            "bancount INTEGER NOT NULL DEFAULT 1, "
            "data JSON"
            ")"
        )
        rows = [
            (_random_jail(), ip, _NOW - random.randint(0, year_seconds), 3600, 1, None)
            for ip in ips
        ]
        await db.executemany(
            "INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            rows,
        )
        await db.commit()

    return ips
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def event_loop_policy() -> None:  # type: ignore[misc]
    """Use the default event loop policy for module-scoped fixtures."""
    # NOTE(review): presumably shadows a pytest-asyncio fixture of the same
    # name so module-scoped async fixtures run on one shared loop — confirm.
    return None
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
async def perf_db_path(tmp_path_factory: Any) -> str:  # type: ignore[misc]
    """Return the path to a fail2ban DB seeded with 10 000 synthetic bans.

    Module-scoped so the database is created only once for all perf tests.
    """
    db_file = tmp_path_factory.mktemp("perf") / "fail2ban_perf.sqlite3"
    db_path = str(db_file)
    seeded_ips = await _seed_f2b_db(db_path, _BAN_COUNT)

    # Warm the in-memory geo cache so the benchmarks never touch the network.
    geo_service.clear_cache()
    repeated_countries = _COUNTRIES * (_BAN_COUNT // len(_COUNTRIES) + 1)
    pairs = zip(seeded_ips, repeated_countries, strict=False)
    for index, (address, (code, name)) in enumerate(pairs):
        geo_service._cache[address] = GeoInfo(  # noqa: SLF001 (test-only direct access)
            country_code=code,
            country_name=name,
            asn=f"AS{1000 + index % 500}",
            org="Synthetic ISP",
        )

    return db_path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBanServicePerformance:
    """Wall-clock performance assertions for the ban service."""

    @staticmethod
    async def _cached_enricher(ip: str) -> GeoInfo | None:
        """Resolve *ip* from the pre-warmed in-memory cache.

        Shared by all tests in this class (previously copy-pasted as a local
        closure in each test). Never performs network I/O, so the benchmarks
        measure only DB query and aggregation overhead.
        """
        return geo_service._cache.get(ip)  # noqa: SLF001

    async def test_list_bans_returns_within_time_limit(
        self, perf_db_path: str
    ) -> None:
        """``list_bans`` with 10 000 bans completes in under 2 seconds."""
        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=perf_db_path),
        ):
            start = time.perf_counter()
            result = await ban_service.list_bans(
                "/fake/sock",
                "365d",
                page=1,
                page_size=100,
                geo_enricher=self._cached_enricher,
            )
            elapsed = time.perf_counter() - start

        assert result.total == _BAN_COUNT, (
            f"Expected {_BAN_COUNT} total bans, got {result.total}"
        )
        assert len(result.items) == 100
        assert elapsed < _WALL_CLOCK_LIMIT, (
            f"list_bans took {elapsed:.2f}s — must be < {_WALL_CLOCK_LIMIT}s"
        )

    async def test_bans_by_country_returns_within_time_limit(
        self, perf_db_path: str
    ) -> None:
        """``bans_by_country`` with 10 000 bans completes in under 2 seconds."""
        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=perf_db_path),
        ):
            start = time.perf_counter()
            result = await ban_service.bans_by_country(
                "/fake/sock",
                "365d",
                geo_enricher=self._cached_enricher,
            )
            elapsed = time.perf_counter() - start

        assert result.total == _BAN_COUNT
        assert len(result.countries) > 0  # At least one country resolved
        assert elapsed < _WALL_CLOCK_LIMIT, (
            f"bans_by_country took {elapsed:.2f}s — must be < {_WALL_CLOCK_LIMIT}s"
        )

    async def test_list_bans_country_data_populated(
        self, perf_db_path: str
    ) -> None:
        """All returned items have geo data from the warm cache."""
        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=perf_db_path),
        ):
            result = await ban_service.list_bans(
                "/fake/sock",
                "365d",
                page=1,
                page_size=100,
                geo_enricher=self._cached_enricher,
            )

        # Every item should have a country because the cache is warm.
        missing = [i for i in result.items if i.country_code is None]
        assert missing == [], f"{len(missing)} items missing country_code"

    async def test_bans_by_country_aggregation_correct(
        self, perf_db_path: str
    ) -> None:
        """Country aggregation sums across all 10 000 bans."""
        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=perf_db_path),
        ):
            result = await ban_service.bans_by_country(
                "/fake/sock",
                "365d",
                geo_enricher=self._cached_enricher,
            )

        total_in_countries = sum(result.countries.values())
        # Total bans in country map should equal total bans (all IPs are cached).
        assert total_in_countries == _BAN_COUNT, (
            f"Country sum {total_in_countries} != total {_BAN_COUNT}"
        )
|
||||
@@ -166,8 +166,8 @@ class TestLookupCaching:
|
||||
|
||||
assert session.get.call_count == 2
|
||||
|
||||
async def test_negative_result_cached(self) -> None:
|
||||
"""A failed lookup result (status != success) is also cached."""
|
||||
async def test_negative_result_not_cached(self) -> None:
|
||||
"""A failed lookup (status != success) is NOT cached so it is retried."""
|
||||
session = _make_session(
|
||||
{"status": "fail", "message": "reserved range"}
|
||||
)
|
||||
@@ -175,7 +175,8 @@ class TestLookupCaching:
|
||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||
|
||||
assert session.get.call_count == 1
|
||||
# Failed lookups must not be cached — both calls must reach the API.
|
||||
assert session.get.call_count == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -201,7 +202,7 @@ class TestLookupFailures:
|
||||
assert result is None
|
||||
|
||||
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
|
||||
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is cached."""
|
||||
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user