"""Performance benchmark for ban_service with 10 000+ banned IPs. These tests assert that both ``list_bans`` and ``bans_by_country`` complete within 2 seconds wall-clock time when the geo cache is warm and the fail2ban database contains 10 000 synthetic ban records. External network calls are eliminated by pre-populating the in-memory geo cache before the timed section, so the benchmark measures only the database query and in-process aggregation overhead. """ from __future__ import annotations import random import time from typing import Any from unittest.mock import AsyncMock, patch import aiosqlite import pytest from app.services import ban_service, geo_service from app.services.geo_service import GeoInfo # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- _BAN_COUNT: int = 10_000 _WALL_CLOCK_LIMIT: float = 2.0 # seconds _NOW: int = int(time.time()) #: Country codes to cycle through when generating synthetic geo data. _COUNTRIES: list[tuple[str, str]] = [ ("DE", "Germany"), ("US", "United States"), ("CN", "China"), ("RU", "Russia"), ("FR", "France"), ("BR", "Brazil"), ("IN", "India"), ("GB", "United Kingdom"), ] # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- def _random_ip() -> str: """Generate a random-looking public IPv4 address string. Returns: Dotted-decimal string with each octet in range 1–254. """ return ".".join(str(random.randint(1, 254)) for _ in range(4)) def _random_jail() -> str: """Pick a jail name from a small pool. Returns: One of ``sshd``, ``nginx``, ``blocklist-import``. """ return random.choice(["sshd", "nginx", "blocklist-import"]) async def _seed_f2b_db(path: str, n: int) -> list[str]: """Create a fail2ban SQLite database with *n* synthetic ban rows. Bans are spread uniformly over the last 365 days. Args: path: Filesystem path for the new database. n: Number of rows to insert. Returns: List of all unique IP address strings inserted. """ year_seconds = 365 * 24 * 3600 ips: list[str] = [_random_ip() for _ in range(n)] async with aiosqlite.connect(path) as db: await db.execute( "CREATE TABLE jails (" "name TEXT NOT NULL UNIQUE, " "enabled INTEGER NOT NULL DEFAULT 1" ")" ) await db.execute( "CREATE TABLE bans (" "jail TEXT NOT NULL, " "ip TEXT, " "timeofban INTEGER NOT NULL, " "bantime INTEGER NOT NULL DEFAULT 3600, " "bancount INTEGER NOT NULL DEFAULT 1, " "data JSON" ")" ) rows = [ (_random_jail(), ip, _NOW - random.randint(0, year_seconds), 3600, 1, None) for ip in ips ] await db.executemany( "INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) " "VALUES (?, ?, ?, ?, ?, ?)", rows, ) await db.commit() return ips @pytest.fixture(scope="module") def event_loop_policy() -> None: # type: ignore[misc] """Use the default event loop policy for module-scoped fixtures.""" return None @pytest.fixture(scope="module") async def perf_db_path(tmp_path_factory: Any) -> str: # type: ignore[misc] """Return the path to a fail2ban DB seeded with 10 000 synthetic bans. Module-scoped so the database is created only once for all perf tests. """ tmp_path = tmp_path_factory.mktemp("perf") path = str(tmp_path / "fail2ban_perf.sqlite3") ips = await _seed_f2b_db(path, _BAN_COUNT) # Pre-populate the in-memory geo cache so no network calls are made. geo_service.clear_cache() country_cycle = _COUNTRIES * (_BAN_COUNT // len(_COUNTRIES) + 1) for i, ip in enumerate(ips): cc, cn = country_cycle[i] geo_service._cache[ip] = GeoInfo( # noqa: SLF001 (test-only direct access) country_code=cc, country_name=cn, asn=f"AS{1000 + i % 500}", org="Synthetic ISP", ) return path # --------------------------------------------------------------------------- # Benchmark tests # --------------------------------------------------------------------------- class TestBanServicePerformance: """Wall-clock performance assertions for the ban service.""" async def test_list_bans_returns_within_time_limit( self, perf_db_path: str ) -> None: """``list_bans`` with 10 000 bans completes in under 2 seconds.""" async def noop_enricher(ip: str) -> GeoInfo | None: return geo_service._cache.get(ip) # noqa: SLF001 with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=perf_db_path), ): start = time.perf_counter() result = await ban_service.list_bans( "/fake/sock", "365d", page=1, page_size=100, geo_enricher=noop_enricher, ) elapsed = time.perf_counter() - start assert result.total == _BAN_COUNT, ( f"Expected {_BAN_COUNT} total bans, got {result.total}" ) assert len(result.items) == 100 assert elapsed < _WALL_CLOCK_LIMIT, ( f"list_bans took {elapsed:.2f}s — must be < {_WALL_CLOCK_LIMIT}s" ) async def test_bans_by_country_returns_within_time_limit( self, perf_db_path: str ) -> None: """``bans_by_country`` with 10 000 bans completes in under 2 seconds.""" async def noop_enricher(ip: str) -> GeoInfo | None: return geo_service._cache.get(ip) # noqa: SLF001 with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=perf_db_path), ): start = time.perf_counter() result = await ban_service.bans_by_country( "/fake/sock", "365d", geo_enricher=noop_enricher, ) elapsed = time.perf_counter() - start assert result.total == _BAN_COUNT assert len(result.countries) > 0 # At least one country resolved assert elapsed < _WALL_CLOCK_LIMIT, ( f"bans_by_country took {elapsed:.2f}s — must be < {_WALL_CLOCK_LIMIT}s" ) async def test_list_bans_country_data_populated( self, perf_db_path: str ) -> None: """All returned items have geo data from the warm cache.""" async def noop_enricher(ip: str) -> GeoInfo | None: return geo_service._cache.get(ip) # noqa: SLF001 with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=perf_db_path), ): result = await ban_service.list_bans( "/fake/sock", "365d", page=1, page_size=100, geo_enricher=noop_enricher, ) # Every item should have a country because the cache is warm. missing = [i for i in result.items if i.country_code is None] assert missing == [], f"{len(missing)} items missing country_code" async def test_bans_by_country_aggregation_correct( self, perf_db_path: str ) -> None: """Country aggregation sums across all 10 000 bans.""" async def noop_enricher(ip: str) -> GeoInfo | None: return geo_service._cache.get(ip) # noqa: SLF001 with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=perf_db_path), ): result = await ban_service.bans_by_country( "/fake/sock", "365d", geo_enricher=noop_enricher, ) total_in_countries = sum(result.countries.values()) # Total bans in country map should equal total bans (all IPs are cached). assert total_in_countries == _BAN_COUNT, ( f"Country sum {total_in_countries} != total {_BAN_COUNT}" )