# Change summary (258 lines, 8.2 KiB, Python):
# - Add persistent geo_cache SQLite table (db.py)
# - Rewrite geo_service: batch API (100 IPs/call), two-tier cache, no caching
#   of failed lookups so they are retried
# - Pre-warm geo cache from DB on startup (main.py lifespan)
# - Rewrite bans_by_country: SQL GROUP BY ip aggregation + lookup_batch
#   instead of 2000-row fetch + asyncio.gather individual calls
# - Pre-warm geo cache after blocklist import (blocklist_service)
# - Add 300ms debounce to useMapData hook to cancel stale requests
# - Add perf benchmark asserting <2s for 10k bans
# - Add seed_10k_bans.py script for manual perf testing
"""Performance benchmark for ban_service with 10 000+ banned IPs.

These tests assert that both ``list_bans`` and ``bans_by_country`` complete
within 2 seconds wall-clock time when the geo cache is warm and the fail2ban
database contains 10 000 synthetic ban records.

External network calls are eliminated by pre-populating the in-memory geo
cache before the timed section, so the benchmark measures only the database
query and in-process aggregation overhead.
"""

from __future__ import annotations

import random
import time
from typing import Any
from unittest.mock import AsyncMock, patch

import aiosqlite
import pytest

from app.services import ban_service, geo_service
from app.services.geo_service import GeoInfo


# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

# Number of synthetic ban rows seeded into the benchmark database.
_BAN_COUNT: int = 10_000
# Upper bound on wall-clock time for each benchmarked call.
_WALL_CLOCK_LIMIT: float = 2.0  # seconds

# Benchmark reference time; ban timestamps are spread backwards from here.
_NOW: int = int(time.time())

#: Country codes to cycle through when generating synthetic geo data.
_COUNTRIES: list[tuple[str, str]] = [
    ("DE", "Germany"),
    ("US", "United States"),
    ("CN", "China"),
    ("RU", "Russia"),
    ("FR", "France"),
    ("BR", "Brazil"),
    ("IN", "India"),
    ("GB", "United Kingdom"),
]


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

def _random_ip() -> str:
|
||
"""Generate a random-looking public IPv4 address string.
|
||
|
||
Returns:
|
||
Dotted-decimal string with each octet in range 1–254.
|
||
"""
|
||
return ".".join(str(random.randint(1, 254)) for _ in range(4))
|
||
|
||
|
||
def _random_jail() -> str:
|
||
"""Pick a jail name from a small pool.
|
||
|
||
Returns:
|
||
One of ``sshd``, ``nginx``, ``blocklist-import``.
|
||
"""
|
||
return random.choice(["sshd", "nginx", "blocklist-import"])
|
||
|
||
|
||
async def _seed_f2b_db(path: str, n: int) -> list[str]:
    """Create a fail2ban SQLite database with *n* synthetic ban rows.

    Bans are spread uniformly over the last 365 days.  An empty ``jails``
    table is also created — NOTE(review): presumably to mirror the real
    fail2ban schema in case the service joins against it; confirm against
    the ban service's queries.

    Args:
        path: Filesystem path for the new database.
        n: Number of rows to insert.

    Returns:
        List of the *n* randomly generated IP address strings, in insertion
        order.  Addresses are drawn independently at random, so the list may
        contain occasional duplicates (not guaranteed unique).
    """
    year_seconds = 365 * 24 * 3600
    ips: list[str] = [_random_ip() for _ in range(n)]

    async with aiosqlite.connect(path) as db:
        await db.execute(
            "CREATE TABLE jails ("
            "name TEXT NOT NULL UNIQUE, "
            "enabled INTEGER NOT NULL DEFAULT 1"
            ")"
        )
        await db.execute(
            "CREATE TABLE bans ("
            "jail TEXT NOT NULL, "
            "ip TEXT, "
            "timeofban INTEGER NOT NULL, "
            "bantime INTEGER NOT NULL DEFAULT 3600, "
            "bancount INTEGER NOT NULL DEFAULT 1, "
            "data JSON"
            ")"
        )
        # One row per generated IP, banned at a random moment in the past year.
        rows = [
            (_random_jail(), ip, _NOW - random.randint(0, year_seconds), 3600, 1, None)
            for ip in ips
        ]
        await db.executemany(
            "INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            rows,
        )
        await db.commit()

    return ips

@pytest.fixture(scope="module")
def event_loop_policy() -> None:  # type: ignore[misc]
    """Fall back to the default event loop policy for module-scoped fixtures."""
    return None

@pytest.fixture(scope="module")
async def perf_db_path(tmp_path_factory: Any) -> str:  # type: ignore[misc]
    """Return the path to a fail2ban DB seeded with 10 000 synthetic bans.

    Module-scoped so the database is created only once for all perf tests.
    The in-memory geo cache is warmed for every seeded IP so the timed
    sections never trigger a network lookup.
    """
    tmp_path = tmp_path_factory.mktemp("perf")
    path = str(tmp_path / "fail2ban_perf.sqlite3")
    ips = await _seed_f2b_db(path, _BAN_COUNT)

    # Pre-populate the in-memory geo cache so no network calls are made.
    geo_service.clear_cache()
    for i, ip in enumerate(ips):
        # Cycle through the country pool by modulo indexing instead of
        # materializing a 10 000-element repeated list — same pairing per
        # index, zero extra allocation.
        cc, cn = _COUNTRIES[i % len(_COUNTRIES)]
        geo_service._cache[ip] = GeoInfo(  # noqa: SLF001 (test-only direct access)
            country_code=cc,
            country_name=cn,
            asn=f"AS{1000 + i % 500}",
            org="Synthetic ISP",
        )

    return path

# ---------------------------------------------------------------------------
# Benchmark tests
# ---------------------------------------------------------------------------


class TestBanServicePerformance:
    """Wall-clock performance assertions for the ban service.

    Every test resolves geo data via :meth:`_cache_enricher`, which reads
    only the pre-warmed in-memory cache, so no network I/O is measured.
    """

    @staticmethod
    async def _cache_enricher(ip: str) -> GeoInfo | None:
        """Return the cached geo info for *ip*; never performs a lookup."""
        return geo_service._cache.get(ip)  # noqa: SLF001

    @staticmethod
    def _patched_db_path(path: str) -> Any:
        """Patch the ban service's fail2ban DB path resolver to *path*."""
        return patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=path),
        )

    async def test_list_bans_returns_within_time_limit(
        self, perf_db_path: str
    ) -> None:
        """``list_bans`` with 10 000 bans completes in under 2 seconds."""
        with self._patched_db_path(perf_db_path):
            start = time.perf_counter()
            result = await ban_service.list_bans(
                "/fake/sock",
                "365d",
                page=1,
                page_size=100,
                geo_enricher=self._cache_enricher,
            )
            elapsed = time.perf_counter() - start

        assert result.total == _BAN_COUNT, (
            f"Expected {_BAN_COUNT} total bans, got {result.total}"
        )
        assert len(result.items) == 100
        assert elapsed < _WALL_CLOCK_LIMIT, (
            f"list_bans took {elapsed:.2f}s — must be < {_WALL_CLOCK_LIMIT}s"
        )

    async def test_bans_by_country_returns_within_time_limit(
        self, perf_db_path: str
    ) -> None:
        """``bans_by_country`` with 10 000 bans completes in under 2 seconds."""
        with self._patched_db_path(perf_db_path):
            start = time.perf_counter()
            result = await ban_service.bans_by_country(
                "/fake/sock",
                "365d",
                geo_enricher=self._cache_enricher,
            )
            elapsed = time.perf_counter() - start

        assert result.total == _BAN_COUNT
        assert len(result.countries) > 0  # At least one country resolved
        assert elapsed < _WALL_CLOCK_LIMIT, (
            f"bans_by_country took {elapsed:.2f}s — must be < {_WALL_CLOCK_LIMIT}s"
        )

    async def test_list_bans_country_data_populated(
        self, perf_db_path: str
    ) -> None:
        """All returned items have geo data from the warm cache."""
        with self._patched_db_path(perf_db_path):
            result = await ban_service.list_bans(
                "/fake/sock",
                "365d",
                page=1,
                page_size=100,
                geo_enricher=self._cache_enricher,
            )

        # Every item should have a country because the cache is warm.
        missing = [i for i in result.items if i.country_code is None]
        assert missing == [], f"{len(missing)} items missing country_code"

    async def test_bans_by_country_aggregation_correct(
        self, perf_db_path: str
    ) -> None:
        """Country aggregation sums across all 10 000 bans."""
        with self._patched_db_path(perf_db_path):
            result = await ban_service.bans_by_country(
                "/fake/sock",
                "365d",
                geo_enricher=self._cache_enricher,
            )

        total_in_countries = sum(result.countries.values())
        # Total bans in country map should equal total bans (all IPs are cached).
        assert total_in_countries == _BAN_COUNT, (
            f"Country sum {total_in_countries} != total {_BAN_COUNT}"
        )