- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite) - Reorganize imports to follow PEP 8 conventions - Convert TypeAlias to modern PEP 695 type syntax (where appropriate) - Use Sequence/Mapping from collections.abc for type hints (covariant) - Replace string literals with cast() for improved type inference - Fix casting of Fail2BanResponse and TypedDict patterns - Add IpLookupResult TypedDict for precise return type annotation - Reformat overlong lines for readability (120 char limit) - Add asyncio_mode and filterwarnings to pytest config - Update test fixtures with improved type hints This improves mypy type checking and makes type relationships explicit.
258 lines
8.2 KiB
Python
258 lines
8.2 KiB
Python
"""Performance benchmark for ban_service with 10 000+ banned IPs.

These tests assert that both ``list_bans`` and ``bans_by_country`` complete
within 2 seconds wall-clock time when the geo cache is warm and the fail2ban
database contains 10 000 synthetic ban records.

External network calls are eliminated by pre-populating the in-memory geo
cache before the timed section, so the benchmark measures only the database
query and in-process aggregation overhead.
"""

from __future__ import annotations

import random
import time
from typing import Any
from unittest.mock import AsyncMock, patch

import aiosqlite
import pytest

from app.services import ban_service, geo_service
from app.services.geo_service import GeoInfo

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

_BAN_COUNT: int = 10_000
_WALL_CLOCK_LIMIT: float = 2.0  # seconds

# Snapshot "now" once at import so all synthetic timestamps share one anchor.
_NOW: int = int(time.time())

#: Country codes to cycle through when generating synthetic geo data.
_COUNTRIES: list[tuple[str, str]] = [
    ("DE", "Germany"),
    ("US", "United States"),
    ("CN", "China"),
    ("RU", "Russia"),
    ("FR", "France"),
    ("BR", "Brazil"),
    ("IN", "India"),
    ("GB", "United Kingdom"),
]
||
# ---------------------------------------------------------------------------
|
||
# Fixtures
|
||
# ---------------------------------------------------------------------------
|
||
|
||
|
||
def _random_ip() -> str:
|
||
"""Generate a random-looking public IPv4 address string.
|
||
|
||
Returns:
|
||
Dotted-decimal string with each octet in range 1–254.
|
||
"""
|
||
return ".".join(str(random.randint(1, 254)) for _ in range(4))
|
||
|
||
|
||
def _random_jail() -> str:
|
||
"""Pick a jail name from a small pool.
|
||
|
||
Returns:
|
||
One of ``sshd``, ``nginx``, ``blocklist-import``.
|
||
"""
|
||
return random.choice(["sshd", "nginx", "blocklist-import"])
|
||
|
||
|
||
async def _seed_f2b_db(path: str, n: int) -> list[str]:
    """Create a fail2ban SQLite database with *n* synthetic ban rows.

    Bans are spread uniformly over the last 365 days.

    Args:
        path: Filesystem path for the new database.
        n: Number of rows to insert.

    Returns:
        List of all IP address strings inserted (one per row).
    """
    window_seconds = 365 * 24 * 3600
    addresses: list[str] = [_random_ip() for _ in range(n)]

    async with aiosqlite.connect(path) as conn:
        # Minimal schema matching what the ban service queries.
        await conn.execute(
            "CREATE TABLE jails ("
            "name TEXT NOT NULL UNIQUE, "
            "enabled INTEGER NOT NULL DEFAULT 1"
            ")"
        )
        await conn.execute(
            "CREATE TABLE bans ("
            "jail TEXT NOT NULL, "
            "ip TEXT, "
            "timeofban INTEGER NOT NULL, "
            "bantime INTEGER NOT NULL DEFAULT 3600, "
            "bancount INTEGER NOT NULL DEFAULT 1, "
            "data JSON"
            ")"
        )
        records = [
            (_random_jail(), addr, _NOW - random.randint(0, window_seconds), 3600, 1, None)
            for addr in addresses
        ]
        await conn.executemany(
            "INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
            "VALUES (?, ?, ?, ?, ?, ?)",
            records,
        )
        await conn.commit()

    return addresses
@pytest.fixture(scope="module")
def event_loop_policy() -> None:
    """Use the default event loop policy for module-scoped fixtures."""
    # NOTE(review): returning None presumably makes pytest-asyncio fall back
    # to its default policy for this module's async fixtures — confirm against
    # the pytest-asyncio version in use.
    return None
@pytest.fixture(scope="module")
async def perf_db_path(tmp_path_factory: Any) -> str:
    """Return the path to a fail2ban DB seeded with 10 000 synthetic bans.

    Module-scoped so the database is created only once for all perf tests.
    """
    db_file = tmp_path_factory.mktemp("perf") / "fail2ban_perf.sqlite3"
    db_path = str(db_file)
    seeded_ips = await _seed_f2b_db(db_path, _BAN_COUNT)

    # Pre-populate the in-memory geo cache so no network calls are made.
    geo_service.clear_cache()
    repeats = _BAN_COUNT // len(_COUNTRIES) + 1
    country_cycle = _COUNTRIES * repeats
    for idx, ip in enumerate(seeded_ips):
        code, name = country_cycle[idx]
        geo_service._cache[ip] = GeoInfo(  # noqa: SLF001 (test-only direct access)
            country_code=code,
            country_name=name,
            asn=f"AS{1000 + idx % 500}",
            org="Synthetic ISP",
        )

    return db_path
# ---------------------------------------------------------------------------
# Benchmark tests
# ---------------------------------------------------------------------------


class TestBanServicePerformance:
    """Wall-clock performance assertions for the ban service.

    The ``noop_enricher`` closure and the DB-path patch were previously
    duplicated verbatim in every test; they are factored into the private
    static helpers below. Test method names and signatures are unchanged.
    """

    @staticmethod
    async def _noop_enricher(ip: str) -> GeoInfo | None:
        """Resolve *ip* from the pre-warmed geo cache; never hits the network."""
        return geo_service._cache.get(ip)  # noqa: SLF001

    @staticmethod
    def _patch_db_path(db_path: str):
        """Return a context manager redirecting the fail2ban DB lookup to *db_path*."""
        return patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=db_path),
        )

    async def test_list_bans_returns_within_time_limit(
        self, perf_db_path: str
    ) -> None:
        """``list_bans`` with 10 000 bans completes in under 2 seconds."""
        with self._patch_db_path(perf_db_path):
            start = time.perf_counter()
            result = await ban_service.list_bans(
                "/fake/sock",
                "365d",
                page=1,
                page_size=100,
                geo_enricher=self._noop_enricher,
            )
            elapsed = time.perf_counter() - start

        assert result.total == _BAN_COUNT, (
            f"Expected {_BAN_COUNT} total bans, got {result.total}"
        )
        assert len(result.items) == 100
        assert elapsed < _WALL_CLOCK_LIMIT, (
            f"list_bans took {elapsed:.2f}s — must be < {_WALL_CLOCK_LIMIT}s"
        )

    async def test_bans_by_country_returns_within_time_limit(
        self, perf_db_path: str
    ) -> None:
        """``bans_by_country`` with 10 000 bans completes in under 2 seconds."""
        with self._patch_db_path(perf_db_path):
            start = time.perf_counter()
            result = await ban_service.bans_by_country(
                "/fake/sock",
                "365d",
                geo_enricher=self._noop_enricher,
            )
            elapsed = time.perf_counter() - start

        assert result.total == _BAN_COUNT
        assert len(result.countries) > 0  # At least one country resolved
        assert elapsed < _WALL_CLOCK_LIMIT, (
            f"bans_by_country took {elapsed:.2f}s — must be < {_WALL_CLOCK_LIMIT}s"
        )

    async def test_list_bans_country_data_populated(
        self, perf_db_path: str
    ) -> None:
        """All returned items have geo data from the warm cache."""
        with self._patch_db_path(perf_db_path):
            result = await ban_service.list_bans(
                "/fake/sock",
                "365d",
                page=1,
                page_size=100,
                geo_enricher=self._noop_enricher,
            )

        # Every item should have a country because the cache is warm.
        missing = [i for i in result.items if i.country_code is None]
        assert missing == [], f"{len(missing)} items missing country_code"

    async def test_bans_by_country_aggregation_correct(
        self, perf_db_path: str
    ) -> None:
        """Country aggregation sums across all 10 000 bans."""
        with self._patch_db_path(perf_db_path):
            result = await ban_service.bans_by_country(
                "/fake/sock",
                "365d",
                geo_enricher=self._noop_enricher,
            )

        total_in_countries = sum(result.countries.values())
        # Total bans in country map should equal total bans (all IPs are cached).
        assert total_in_countries == _BAN_COUNT, (
            f"Country sum {total_in_countries} != total {_BAN_COUNT}"
        )