Files
BanGUI/backend/tests/test_services/test_ban_service.py
Lukas 0225f32901 Fix country not shown in ban list due to geo rate limiting
list_bans() was calling geo_service.lookup() once per IP on the
page (e.g. 100 sequential HTTP requests), hitting the ip-api.com
free-tier single-IP limit of 45 req/min.  IPs beyond the ~45th
were added to the in-process negative cache (5 min TTL) and showed
as no country until the TTL expired.  The map endpoint never had
this problem because it used lookup_batch (100 IPs per POST).

Add http_session and app_db params to list_bans().  When
http_session is provided (production path), the entire page is
resolved in one lookup_batch() call instead of N individual ones.
The legacy geo_enricher callback is kept for test compatibility.
Update the dashboard router to use the batch path directly.

Add 3 tests covering the batch geo path, failure resilience, and
http_session priority over geo_enricher.
2026-03-10 17:20:13 +01:00

613 lines
22 KiB
Python

"""Tests for ban_service.list_bans()."""
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, patch
import aiosqlite
import pytest
from app.services import ban_service
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
# Reference timestamps in Unix epoch seconds.  ``_NOW`` is captured once when
# the module is imported; the other two are fixed offsets from it and serve as
# ``timeofban`` values in the fixtures below.
_NOW: int = int(time.time())
_ONE_HOUR_AGO: int = _NOW - 60 * 60
_TWO_DAYS_AGO: int = _NOW - 48 * 60 * 60
async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
    """Build a throwaway SQLite file shaped like a fail2ban database.

    Two tables are created -- ``jails`` and ``bans`` -- then one ``bans``
    record is inserted per entry in *rows* and the transaction is committed.

    Args:
        path: Filesystem location of the SQLite file to create.
        rows: Ban records to insert.  ``jail``, ``ip`` and ``timeofban`` are
            required keys; ``bantime`` falls back to ``3600`` and
            ``bancount`` to ``1`` when absent.  ``data`` is JSON-encoded when
            present and stored as NULL otherwise.
    """
    jails_ddl = (
        "CREATE TABLE jails ("
        "name TEXT NOT NULL UNIQUE, "
        "enabled INTEGER NOT NULL DEFAULT 1"
        ")"
    )
    bans_ddl = (
        "CREATE TABLE bans ("
        "jail TEXT NOT NULL, "
        "ip TEXT, "
        "timeofban INTEGER NOT NULL, "
        "bantime INTEGER NOT NULL, "
        "bancount INTEGER NOT NULL DEFAULT 1, "
        "data JSON"
        ")"
    )
    insert_sql = (
        "INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
        "VALUES (?, ?, ?, ?, ?, ?)"
    )

    async with aiosqlite.connect(path) as conn:
        await conn.execute(jails_ddl)
        await conn.execute(bans_ddl)
        for record in rows:
            values = (
                record["jail"],
                record["ip"],
                record["timeofban"],
                record.get("bantime", 3600),
                record.get("bancount", 1),
                # Only serialise ``data`` when the caller supplied it.
                None if "data" not in record else json.dumps(record["data"]),
            )
            await conn.execute(insert_sql, values)
        await conn.commit()
@pytest.fixture
async def f2b_db_path(tmp_path: Path) -> str:  # type: ignore[misc]
    """Create a three-ban fail2ban test database and return its path.

    Two bans are dated one hour ago (``sshd`` and ``nginx``) and one is dated
    two days ago (``sshd``, stored without a ``matches`` list).
    """
    recent_sshd: dict[str, Any] = {
        "jail": "sshd",
        "ip": "1.2.3.4",
        "timeofban": _ONE_HOUR_AGO,
        "bantime": 3600,
        "bancount": 2,
        "data": {
            "matches": ["Nov 10 10:00 sshd[123]: Failed password for root"],
            "failures": 5,
        },
    }
    recent_nginx: dict[str, Any] = {
        "jail": "nginx",
        "ip": "5.6.7.8",
        "timeofban": _ONE_HOUR_AGO,
        "bantime": 7200,
        "bancount": 1,
        "data": {"matches": ["GET /admin HTTP/1.1"], "failures": 3},
    }
    # This record deliberately carries no ``matches`` entry in ``data``.
    old_sshd: dict[str, Any] = {
        "jail": "sshd",
        "ip": "9.10.11.12",
        "timeofban": _TWO_DAYS_AGO,
        "bantime": 3600,
        "bancount": 1,
        "data": {"failures": 6},
    }

    db_file = str(tmp_path / "fail2ban_test.sqlite3")
    await _create_f2b_db(db_file, [recent_sshd, recent_nginx, old_sshd])
    return db_file
@pytest.fixture
async def mixed_origin_db_path(tmp_path: Path) -> str:  # type: ignore[misc]
    """Create a database mixing ``blocklist-import`` and organic jail bans.

    All three bans are dated one hour ago; none carries a ``data`` payload.
    """
    # (jail, ip, bantime, bancount) for each ban to insert.
    ban_specs = [
        ("blocklist-import", "10.0.0.1", -1, 1),
        ("sshd", "10.0.0.2", 3600, 3),
        ("nginx", "10.0.0.3", 7200, 1),
    ]
    ban_rows: list[dict[str, Any]] = [
        {
            "jail": jail,
            "ip": ip,
            "timeofban": _ONE_HOUR_AGO,
            "bantime": bantime,
            "bancount": bancount,
        }
        for jail, ip, bantime, bancount in ban_specs
    ]

    db_file = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
    await _create_f2b_db(db_file, ban_rows)
    return db_file
@pytest.fixture
async def empty_f2b_db_path(tmp_path: Path) -> str:  # type: ignore[misc]
    """Create a fail2ban test database whose ``bans`` table is empty."""
    db_file = str(tmp_path / "fail2ban_empty.sqlite3")
    no_rows: list[dict[str, Any]] = []
    await _create_f2b_db(db_file, no_rows)
    return db_file
# ---------------------------------------------------------------------------
# list_bans — happy path
# ---------------------------------------------------------------------------
class TestListBansHappyPath:
    """Exercise ban_service.list_bans() against well-formed databases."""

    @staticmethod
    def _with_db(db_path: str) -> Any:
        """Return a patch making ``_get_fail2ban_db_path`` resolve to *db_path*."""
        return patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=db_path),
        )

    async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
        """Bans older than the requested range are left out."""
        with self._with_db(f2b_db_path):
            page = await ban_service.list_bans("/fake/sock", "24h")

        # The fixture has two bans from an hour ago and one from two days ago.
        assert page.total == 2
        assert len(page.items) == 2

    async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
        """Items come back ordered by ``banned_at``, newest first."""
        with self._with_db(f2b_db_path):
            page = await ban_service.list_bans("/fake/sock", "24h")

        stamps = [entry.banned_at for entry in page.items]
        assert stamps == sorted(stamps, reverse=True)

    async def test_ban_fields_present(self, f2b_db_path: str) -> None:
        """Every item exposes ip, jail, banned_at and ban_count."""
        with self._with_db(f2b_db_path):
            page = await ban_service.list_bans("/fake/sock", "24h")

        for entry in page.items:
            assert entry.ip
            assert entry.jail
            assert entry.banned_at
            assert entry.ban_count >= 1

    async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
        """``service`` is taken from the first entry of ``data.matches``."""
        with self._with_db(f2b_db_path):
            page = await ban_service.list_bans("/fake/sock", "24h")

        sshd_entry = next(entry for entry in page.items if entry.jail == "sshd")
        assert sshd_entry.service is not None
        assert "Failed password" in sshd_entry.service

    async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
        """``service`` stays ``None`` for a ban stored without matches."""
        with self._with_db(f2b_db_path):
            # The match-less ban is two days old, so widen the range to 7d.
            page = await ban_service.list_bans("/fake/sock", "7d")

        matchless = next(entry for entry in page.items if entry.ip == "9.10.11.12")
        assert matchless.service is None

    async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
        """An empty bans table yields ``total == 0`` and an empty item list."""
        with self._with_db(empty_f2b_db_path):
            page = await ban_service.list_bans("/fake/sock", "24h")

        assert page.total == 0
        assert page.items == []

    async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
        """The ``365d`` range also covers the two-day-old ban."""
        with self._with_db(f2b_db_path):
            page = await ban_service.list_bans("/fake/sock", "365d")

        assert page.total == 3
# ---------------------------------------------------------------------------
# list_bans — geo enrichment
# ---------------------------------------------------------------------------
class TestListBansGeoEnrichment:
    """Check how list_bans() uses the per-IP ``geo_enricher`` callback."""

    @staticmethod
    def _with_db(db_path: str) -> Any:
        """Return a patch making ``_get_fail2ban_db_path`` resolve to *db_path*."""
        return patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=db_path),
        )

    async def test_geo_data_applied_when_enricher_provided(
        self, f2b_db_path: str
    ) -> None:
        """Values returned by the enricher end up on each ban item."""
        from app.services.geo_service import GeoInfo

        async def fake_enricher(ip: str) -> GeoInfo:
            # Same answer for every IP; only the wiring is under test.
            return GeoInfo(
                country_code="DE",
                country_name="Germany",
                asn="AS3320",
                org="Deutsche Telekom",
            )

        with self._with_db(f2b_db_path):
            page = await ban_service.list_bans(
                "/fake/sock", "24h", geo_enricher=fake_enricher
            )

        for entry in page.items:
            assert entry.country_code == "DE"
            assert entry.country_name == "Germany"
            assert entry.asn == "AS3320"

    async def test_geo_failure_does_not_break_results(
        self, f2b_db_path: str
    ) -> None:
        """An enricher that raises leaves the items intact with null geo."""

        async def failing_enricher(ip: str) -> None:
            raise RuntimeError("geo service down")

        with self._with_db(f2b_db_path):
            page = await ban_service.list_bans(
                "/fake/sock", "24h", geo_enricher=failing_enricher
            )

        assert page.total == 2
        for entry in page.items:
            assert entry.country_code is None
# ---------------------------------------------------------------------------
# list_bans — batch geo enrichment via http_session
# ---------------------------------------------------------------------------
class TestListBansBatchGeoEnrichment:
    """Verify that list_bans uses lookup_batch when http_session is provided.

    Each test replaces ``geo_service.lookup_batch`` with an ``AsyncMock`` and
    passes a ``MagicMock`` as the HTTP session.  A reference to the
    ``lookup_batch`` mock is kept so the tests can assert it was actually
    awaited; without that, the failure test would pass even if the batch path
    were never taken.
    """

    @staticmethod
    def _patch_db(db_path: str) -> Any:
        """Return a patch making ``_get_fail2ban_db_path`` resolve to *db_path*."""
        return patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=db_path),
        )

    @staticmethod
    def _patch_batch(batch_mock: AsyncMock) -> Any:
        """Return a patch replacing ``geo_service.lookup_batch`` with *batch_mock*."""
        return patch("app.services.geo_service.lookup_batch", new=batch_mock)

    async def test_batch_geo_applied_via_http_session(
        self, f2b_db_path: str
    ) -> None:
        """Geo fields are populated via lookup_batch when http_session is given."""
        from app.services.geo_service import GeoInfo
        from unittest.mock import MagicMock

        fake_session = MagicMock()
        fake_geo_map = {
            "1.2.3.4": GeoInfo(
                country_code="DE",
                country_name="Germany",
                asn="AS3320",
                org="Deutsche Telekom",
            ),
            "5.6.7.8": GeoInfo(
                country_code="US",
                country_name="United States",
                asn="AS15169",
                org="Google",
            ),
        }
        batch_mock = AsyncMock(return_value=fake_geo_map)

        with self._patch_db(f2b_db_path), self._patch_batch(batch_mock):
            result = await ban_service.list_bans(
                "/fake/sock", "24h", http_session=fake_session
            )

        batch_mock.assert_awaited()
        assert result.total == 2
        de_item = next(i for i in result.items if i.ip == "1.2.3.4")
        us_item = next(i for i in result.items if i.ip == "5.6.7.8")
        assert de_item.country_code == "DE"
        assert de_item.country_name == "Germany"
        assert us_item.country_code == "US"
        assert us_item.country_name == "United States"

    async def test_batch_failure_does_not_break_results(
        self, f2b_db_path: str
    ) -> None:
        """A lookup_batch failure still returns items with null geo fields."""
        from unittest.mock import MagicMock

        fake_session = MagicMock()
        batch_mock = AsyncMock(side_effect=RuntimeError("batch geo down"))

        with self._patch_db(f2b_db_path), self._patch_batch(batch_mock):
            result = await ban_service.list_bans(
                "/fake/sock", "24h", http_session=fake_session
            )

        # Null geo fields alone would also be seen if lookup_batch were never
        # invoked, so confirm the failing call really happened.
        batch_mock.assert_awaited()
        assert result.total == 2
        for item in result.items:
            assert item.country_code is None

    async def test_http_session_takes_priority_over_geo_enricher(
        self, f2b_db_path: str
    ) -> None:
        """When both http_session and geo_enricher are provided, batch wins."""
        from app.services.geo_service import GeoInfo
        from unittest.mock import MagicMock

        fake_session = MagicMock()
        fake_geo_map = {
            "1.2.3.4": GeoInfo(
                country_code="DE", country_name="Germany", asn=None, org=None
            ),
            "5.6.7.8": GeoInfo(
                country_code="DE", country_name="Germany", asn=None, org=None
            ),
        }
        batch_mock = AsyncMock(return_value=fake_geo_map)
        enricher_calls: list[str] = []

        async def enricher_should_not_be_called(ip: str) -> GeoInfo:
            # Record the call before raising.  NOTE(review): list_bans appears
            # to tolerate enricher errors (the enricher-failure test passes a
            # RuntimeError-raising callback and still gets results), so the
            # exception alone might never reach this test.
            enricher_calls.append(ip)
            raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")

        with self._patch_db(f2b_db_path), self._patch_batch(batch_mock):
            result = await ban_service.list_bans(
                "/fake/sock",
                "24h",
                http_session=fake_session,
                geo_enricher=enricher_should_not_be_called,
            )

        batch_mock.assert_awaited()
        assert enricher_calls == []
        assert result.total == 2
        for item in result.items:
            assert item.country_code == "DE"
# ---------------------------------------------------------------------------
# list_bans — pagination
# ---------------------------------------------------------------------------
class TestListBansPagination:
    """Check the ``page`` and ``page_size`` arguments of list_bans()."""

    @staticmethod
    def _with_db(db_path: str) -> Any:
        """Return a patch making ``_get_fail2ban_db_path`` resolve to *db_path*."""
        return patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=db_path),
        )

    async def test_page_size_respected(self, f2b_db_path: str) -> None:
        """With ``page_size=1`` exactly one item is returned."""
        with self._with_db(f2b_db_path):
            page = await ban_service.list_bans("/fake/sock", "7d", page_size=1)

        assert len(page.items) == 1
        assert page.page_size == 1

    async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
        """Page two holds a different ban than page one."""
        with self._with_db(f2b_db_path):
            first = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
            second = await ban_service.list_bans("/fake/sock", "7d", page=2, page_size=1)

        # With one item per page, consecutive pages must show distinct IPs.
        assert first.items[0].ip != second.items[0].ip

    async def test_total_reflects_full_count_not_page_count(
        self, f2b_db_path: str
    ) -> None:
        """``total`` counts every matching ban, not just the current page."""
        with self._with_db(f2b_db_path):
            page = await ban_service.list_bans("/fake/sock", "7d", page_size=1)

        # The fixture's three bans all fall inside the 7d window.
        assert page.total == 3
# ---------------------------------------------------------------------------
# list_bans / bans_by_country — origin derivation
# ---------------------------------------------------------------------------
class TestBanOriginDerivation:
    """Check the ``origin`` value ban_service attaches based on the jail name."""

    @staticmethod
    def _with_db(db_path: str) -> Any:
        """Return a patch making ``_get_fail2ban_db_path`` resolve to *db_path*."""
        return patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=db_path),
        )

    async def test_blocklist_import_jail_yields_blocklist_origin(
        self, mixed_origin_db_path: str
    ) -> None:
        """A ban in the ``blocklist-import`` jail has ``origin == "blocklist"``."""
        with self._with_db(mixed_origin_db_path):
            page = await ban_service.list_bans("/fake/sock", "24h")

        from_blocklist = [
            entry for entry in page.items if entry.jail == "blocklist-import"
        ]
        assert len(from_blocklist) == 1
        assert from_blocklist[0].origin == "blocklist"

    async def test_organic_jail_yields_selfblock_origin(
        self, mixed_origin_db_path: str
    ) -> None:
        """Bans in any other jail (sshd, nginx, …) have ``origin == "selfblock"``."""
        with self._with_db(mixed_origin_db_path):
            page = await ban_service.list_bans("/fake/sock", "24h")

        organic = [entry for entry in page.items if entry.jail != "blocklist-import"]
        assert len(organic) == 2
        for entry in organic:
            assert entry.origin == "selfblock"

    async def test_all_items_carry_origin_field(
        self, mixed_origin_db_path: str
    ) -> None:
        """Each returned item's ``origin`` is one of the two known values."""
        with self._with_db(mixed_origin_db_path):
            page = await ban_service.list_bans("/fake/sock", "24h")

        for entry in page.items:
            assert entry.origin in ("blocklist", "selfblock")

    async def test_bans_by_country_blocklist_origin(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country`` marks the blocklist-import ban as ``blocklist``."""
        with self._with_db(mixed_origin_db_path):
            summary = await ban_service.bans_by_country("/fake/sock", "24h")

        from_blocklist = [
            ban for ban in summary.bans if ban.jail == "blocklist-import"
        ]
        assert len(from_blocklist) == 1
        assert from_blocklist[0].origin == "blocklist"

    async def test_bans_by_country_selfblock_origin(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country`` marks bans from other jails as ``selfblock``."""
        with self._with_db(mixed_origin_db_path):
            summary = await ban_service.bans_by_country("/fake/sock", "24h")

        organic = [ban for ban in summary.bans if ban.jail != "blocklist-import"]
        assert len(organic) == 2
        for ban in organic:
            assert ban.origin == "selfblock"
# ---------------------------------------------------------------------------
# list_bans / bans_by_country — origin filter parameter
# ---------------------------------------------------------------------------
class TestOriginFilter:
    """Check that the ``origin`` argument narrows results as requested."""

    @staticmethod
    def _with_db(db_path: str) -> Any:
        """Return a patch making ``_get_fail2ban_db_path`` resolve to *db_path*."""
        return patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=db_path),
        )

    async def test_list_bans_blocklist_filter_returns_only_blocklist(
        self, mixed_origin_db_path: str
    ) -> None:
        """``origin='blocklist'`` keeps only the blocklist-import jail ban."""
        with self._with_db(mixed_origin_db_path):
            page = await ban_service.list_bans(
                "/fake/sock", "24h", origin="blocklist"
            )

        assert page.total == 1
        assert len(page.items) == 1
        only_entry = page.items[0]
        assert only_entry.jail == "blocklist-import"
        assert only_entry.origin == "blocklist"

    async def test_list_bans_selfblock_filter_excludes_blocklist(
        self, mixed_origin_db_path: str
    ) -> None:
        """``origin='selfblock'`` drops the blocklist-import jail ban."""
        with self._with_db(mixed_origin_db_path):
            page = await ban_service.list_bans(
                "/fake/sock", "24h", origin="selfblock"
            )

        assert page.total == 2
        assert len(page.items) == 2
        for entry in page.items:
            assert entry.jail != "blocklist-import"
            assert entry.origin == "selfblock"

    async def test_list_bans_no_filter_returns_all(
        self, mixed_origin_db_path: str
    ) -> None:
        """``origin=None`` imposes no jail restriction; all three bans return."""
        with self._with_db(mixed_origin_db_path):
            page = await ban_service.list_bans("/fake/sock", "24h", origin=None)

        assert page.total == 3

    async def test_bans_by_country_blocklist_filter(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
        with self._with_db(mixed_origin_db_path):
            summary = await ban_service.bans_by_country(
                "/fake/sock", "24h", origin="blocklist"
            )

        assert summary.total == 1
        assert all(ban.jail == "blocklist-import" for ban in summary.bans)

    async def test_bans_by_country_selfblock_filter(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country`` with ``origin='selfblock'`` leaves blocklist bans out."""
        with self._with_db(mixed_origin_db_path):
            summary = await ban_service.bans_by_country(
                "/fake/sock", "24h", origin="selfblock"
            )

        assert summary.total == 2
        assert all(ban.jail != "blocklist-import" for ban in summary.bans)

    async def test_bans_by_country_no_filter_returns_all(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country`` with ``origin=None`` returns all three bans."""
        with self._with_db(mixed_origin_db_path):
            summary = await ban_service.bans_by_country(
                "/fake/sock", "24h", origin=None
            )

        assert summary.total == 3