fixed tests

This commit is contained in:
2026-05-15 20:41:05 +02:00
parent 96ce516ecf
commit 77df5d5d65
50 changed files with 1482 additions and 5089 deletions

View File

@@ -32,12 +32,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
``bantime``, ``bancount``, and optionally ``data``.
"""
async with aiosqlite.connect(path) as db:
await db.execute(
"CREATE TABLE jails ("
"name TEXT NOT NULL UNIQUE, "
"enabled INTEGER NOT NULL DEFAULT 1"
")"
)
await db.execute("CREATE TABLE jails (name TEXT NOT NULL UNIQUE, enabled INTEGER NOT NULL DEFAULT 1)")
await db.execute(
"CREATE TABLE bans ("
"jail TEXT NOT NULL, "
@@ -50,8 +45,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
)
for row in rows:
await db.execute(
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
"VALUES (?, ?, ?, ?, ?, ?)",
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) VALUES (?, ?, ?, ?, ?, ?)",
(
row["jail"],
row["ip"],
@@ -257,9 +251,7 @@ class TestListBansHappyPath:
assert result.total == 3
async def test_source_archive_reads_from_archive(
self, app_db_with_archive: aiosqlite.Connection
) -> None:
async def test_source_archive_reads_from_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
"""Using source='archive' reads from the BanGUI archive table."""
result = await ban_service.list_bans(
"/fake/sock",
@@ -280,9 +272,7 @@ class TestListBansHappyPath:
class TestListBansGeoEnrichment:
"""Verify geo enrichment integration in ban_service.list_bans()."""
async def test_geo_data_applied_when_enricher_provided(
self, f2b_db_path: str
) -> None:
async def test_geo_data_applied_when_enricher_provided(self, f2b_db_path: str) -> None:
"""Geo fields are populated when an enricher returns data."""
from app.models.geo import GeoInfo
@@ -298,30 +288,24 @@ class TestListBansGeoEnrichment:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", geo_enricher=fake_enricher
)
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=fake_enricher)
for item in result.items:
assert item.country_code == "DE"
assert item.country_name == "Germany"
assert item.asn == "AS3320"
async def test_geo_failure_does_not_break_results(
self, f2b_db_path: str
) -> None:
async def test_geo_failure_does_not_break_results(self, f2b_db_path: str) -> None:
"""A geo enricher that raises still returns ban items (geo fields null)."""
async def failing_enricher(ip: str) -> None:
raise RuntimeError("geo service down")
raise OSError("geo service down")
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", geo_enricher=failing_enricher
)
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=failing_enricher)
assert result.total == 2
for item in result.items:
@@ -336,9 +320,7 @@ class TestListBansGeoEnrichment:
class TestListBansBatchGeoEnrichment:
"""Verify that list_bans uses lookup_batch when http_session is provided."""
async def test_batch_geo_applied_via_http_session(
self, f2b_db_path: str
) -> None:
async def test_batch_geo_applied_via_http_session(self, f2b_db_path: str) -> None:
"""Geo fields are populated via lookup_batch when http_session is given."""
from unittest.mock import MagicMock
@@ -350,6 +332,8 @@ class TestListBansBatchGeoEnrichment:
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
}
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
mock_geo_cache = MagicMock()
mock_geo_cache.lookup_batch = fake_geo_batch
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -359,7 +343,7 @@ class TestListBansBatchGeoEnrichment:
"/fake/sock",
"24h",
http_session=fake_session,
geo_batch_lookup=fake_geo_batch,
geo_cache=mock_geo_cache,
)
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
@@ -371,15 +355,15 @@ class TestListBansBatchGeoEnrichment:
assert us_item.country_code == "US"
assert us_item.country_name == "United States"
async def test_batch_failure_does_not_break_results(
self, f2b_db_path: str
) -> None:
async def test_batch_failure_does_not_break_results(self, f2b_db_path: str) -> None:
"""A lookup_batch failure still returns items with null geo fields."""
from unittest.mock import MagicMock
fake_session = MagicMock()
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
failing_geo_batch = AsyncMock(side_effect=OSError("batch geo down"))
mock_geo_cache = MagicMock()
mock_geo_cache.lookup_batch = failing_geo_batch
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -389,16 +373,14 @@ class TestListBansBatchGeoEnrichment:
"/fake/sock",
"24h",
http_session=fake_session,
geo_batch_lookup=failing_geo_batch,
geo_cache=mock_geo_cache,
)
assert result.total == 2
for item in result.items:
assert item.country_code is None
async def test_http_session_takes_priority_over_geo_enricher(
self, f2b_db_path: str
) -> None:
async def test_http_session_takes_priority_over_geo_enricher(self, f2b_db_path: str) -> None:
"""When both http_session and geo_enricher are provided, batch wins."""
from unittest.mock import MagicMock
@@ -410,6 +392,8 @@ class TestListBansBatchGeoEnrichment:
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
}
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
mock_geo_cache = MagicMock()
mock_geo_cache.lookup_batch = fake_geo_batch
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
@@ -422,7 +406,7 @@ class TestListBansBatchGeoEnrichment:
"/fake/sock",
"24h",
http_session=fake_session,
geo_batch_lookup=fake_geo_batch,
geo_cache=mock_geo_cache,
geo_enricher=enricher_should_not_be_called,
)
@@ -462,9 +446,7 @@ class TestListBansPagination:
# Different IPs should appear on different pages.
assert page1.items[0].ip != page2.items[0].ip
async def test_total_reflects_full_count_not_page_count(
self, f2b_db_path: str
) -> None:
async def test_total_reflects_full_count_not_page_count(self, f2b_db_path: str) -> None:
"""``total`` reports all matching records regardless of pagination."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -483,9 +465,7 @@ class TestListBansPagination:
class TestBanOriginDerivation:
"""Verify that ban_service correctly derives ``origin`` from jail names."""
async def test_blocklist_import_jail_yields_blocklist_origin(
self, mixed_origin_db_path: str
) -> None:
async def test_blocklist_import_jail_yields_blocklist_origin(self, mixed_origin_db_path: str) -> None:
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -497,9 +477,7 @@ class TestBanOriginDerivation:
assert len(blocklist_items) == 1
assert blocklist_items[0].origin == "blocklist"
async def test_organic_jail_yields_selfblock_origin(
self, mixed_origin_db_path: str
) -> None:
async def test_organic_jail_yields_selfblock_origin(self, mixed_origin_db_path: str) -> None:
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -512,9 +490,7 @@ class TestBanOriginDerivation:
for item in organic_items:
assert item.origin == "selfblock"
async def test_all_items_carry_origin_field(
self, mixed_origin_db_path: str
) -> None:
async def test_all_items_carry_origin_field(self, mixed_origin_db_path: str) -> None:
"""Every returned item has an ``origin`` field with a valid value."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -525,9 +501,7 @@ class TestBanOriginDerivation:
for item in result.items:
assert item.origin in ("blocklist", "selfblock")
async def test_bans_by_country_blocklist_origin(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_blocklist_origin(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -535,13 +509,11 @@ class TestBanOriginDerivation:
):
result = await ban_service.bans_by_country("/fake/sock", "24h")
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"]
blocklist_bans = [b for b in result.items if b.jail == "blocklist-import"]
assert len(blocklist_bans) == 1
assert blocklist_bans[0].origin == "blocklist"
async def test_bans_by_country_selfblock_origin(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_selfblock_origin(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` derives origin correctly for organic jails."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -549,7 +521,7 @@ class TestBanOriginDerivation:
):
result = await ban_service.bans_by_country("/fake/sock", "24h")
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"]
organic_bans = [b for b in result.items if b.jail != "blocklist-import"]
assert len(organic_bans) == 2
for ban in organic_bans:
assert ban.origin == "selfblock"
@@ -563,34 +535,26 @@ class TestBanOriginDerivation:
class TestOriginFilter:
"""Verify that the origin filter correctly restricts results."""
async def test_list_bans_blocklist_filter_returns_only_blocklist(
self, mixed_origin_db_path: str
) -> None:
async def test_list_bans_blocklist_filter_returns_only_blocklist(self, mixed_origin_db_path: str) -> None:
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", origin="blocklist"
)
result = await ban_service.list_bans("/fake/sock", "24h", origin="blocklist")
assert result.total == 1
assert len(result.items) == 1
assert result.items[0].jail == "blocklist-import"
assert result.items[0].origin == "blocklist"
async def test_list_bans_selfblock_filter_excludes_blocklist(
self, mixed_origin_db_path: str
) -> None:
async def test_list_bans_selfblock_filter_excludes_blocklist(self, mixed_origin_db_path: str) -> None:
"""``origin='selfblock'`` excludes the blocklist-import jail."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", origin="selfblock"
)
result = await ban_service.list_bans("/fake/sock", "24h", origin="selfblock")
assert result.total == 2
assert len(result.items) == 2
@@ -598,9 +562,7 @@ class TestOriginFilter:
assert item.jail != "blocklist-import"
assert item.origin == "selfblock"
async def test_list_bans_no_filter_returns_all(
self, mixed_origin_db_path: str
) -> None:
async def test_list_bans_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
"""``origin=None`` applies no jail restriction — all bans returned."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -610,53 +572,39 @@ class TestOriginFilter:
assert result.total == 3
async def test_bans_by_country_blocklist_filter(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_blocklist_filter(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", origin="blocklist"
)
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="blocklist")
assert result.total == 1
assert all(b.jail == "blocklist-import" for b in result.bans)
assert all(b.jail == "blocklist-import" for b in result.items)
async def test_bans_by_country_selfblock_filter(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_selfblock_filter(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", origin="selfblock"
)
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="selfblock")
assert result.total == 2
assert all(b.jail != "blocklist-import" for b in result.bans)
assert all(b.jail != "blocklist-import" for b in result.items)
async def test_bans_by_country_no_filter_returns_all(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` with ``origin=None`` returns all bans."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", origin=None
)
result = await ban_service.bans_by_country("/fake/sock", "24h", origin=None)
assert result.total == 3
async def test_bans_by_country_country_code_returns_all_matched_rows(
self, tmp_path: Path
) -> None:
async def test_bans_by_country_country_code_returns_all_matched_rows(self, tmp_path: Path) -> None:
"""``bans_by_country`` returns all companion rows for the selected country."""
path = str(tmp_path / "fail2ban_country_filter.sqlite3")
rows = [
@@ -672,8 +620,8 @@ class TestOriginFilter:
]
await _create_f2b_db(path, rows)
from app.services import geo_service
from app.models.geo import GeoInfo
from app.services import geo_service
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
country_code="DE",
@@ -682,12 +630,13 @@ class TestOriginFilter:
org=None,
)
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
), patch(
"app.services.ban_service.asyncio.create_task"
) as mock_create_task:
with (
patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
),
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
):
result = await ban_service.bans_by_country(
"/fake/sock",
"24h",
@@ -698,8 +647,8 @@ class TestOriginFilter:
mock_create_task.assert_not_called()
assert result.total == 205
assert len(result.bans) == 205
assert all(b.country_code == "DE" for b in result.bans)
assert len(result.items) == 205
assert all(b.country_code == "DE" for b in result.items)
await geo_service.clear_cache()
@@ -715,7 +664,7 @@ class TestOriginFilter:
)
assert result.total == 2
assert len(result.bans) == 2
assert len(result.items) == 2
# ---------------------------------------------------------------------------
@@ -728,13 +677,11 @@ class TestBansbyCountryBackground:
"""bans_by_country() with http_session uses cache-only geo and fires a
background task for uncached IPs instead of blocking on API calls."""
async def test_cached_geo_returned_without_api_call(
self, mixed_origin_db_path: str
) -> None:
async def test_cached_geo_returned_without_api_call(self, mixed_origin_db_path: str) -> None:
"""When all IPs are in the cache, lookup_cached_only returns them and
no background task is created."""
from app.services import geo_service
from app.models.geo import GeoInfo
from app.services import geo_service
# Pre-populate the cache for all three IPs in the fixture.
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
@@ -752,9 +699,7 @@ class TestBansbyCountryBackground:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
),
patch(
"app.services.ban_service.asyncio.create_task"
) as mock_create_task,
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
):
mock_session = AsyncMock()
mock_batch = AsyncMock(return_value={})
@@ -763,7 +708,6 @@ class TestBansbyCountryBackground:
"24h",
http_session=mock_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=mock_batch,
)
# All countries resolved from cache — no background task needed.
@@ -773,9 +717,7 @@ class TestBansbyCountryBackground:
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
await geo_service.clear_cache()
async def test_uncached_ips_trigger_background_task(
self, mixed_origin_db_path: str
) -> None:
async def test_uncached_ips_trigger_background_task(self, mixed_origin_db_path: str) -> None:
"""When IPs are NOT in the cache, create_task is called for background
resolution and the response returns without blocking."""
from app.services import geo_service
@@ -787,9 +729,7 @@ class TestBansbyCountryBackground:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
),
patch(
"app.services.ban_service.asyncio.create_task"
) as mock_create_task,
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
):
mock_session = AsyncMock()
mock_batch = AsyncMock(return_value={})
@@ -798,7 +738,7 @@ class TestBansbyCountryBackground:
"24h",
http_session=mock_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=mock_batch,
geo_cache=geo_service.GeoCache(),
)
# Background task must have been scheduled for uncached IPs.
@@ -806,9 +746,7 @@ class TestBansbyCountryBackground:
# Response is still valid with empty country map (IPs not cached yet).
assert result.total == 3
async def test_no_background_task_without_http_session(
self, mixed_origin_db_path: str
) -> None:
async def test_no_background_task_without_http_session(self, mixed_origin_db_path: str) -> None:
"""When http_session is None, no background task is created."""
from app.services import geo_service
@@ -819,13 +757,9 @@ class TestBansbyCountryBackground:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
),
patch(
"app.services.ban_service.asyncio.create_task"
) as mock_create_task,
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", http_session=None
)
result = await ban_service.bans_by_country("/fake/sock", "24h", http_session=None)
mock_create_task.assert_not_called()
assert result.total == 3
@@ -904,9 +838,7 @@ class TestBanTrend:
timestamps = [b.timestamp for b in result.buckets]
assert timestamps == sorted(timestamps)
async def test_ban_trend_source_archive_reads_archive(
self, app_db_with_archive: aiosqlite.Connection
) -> None:
async def test_ban_trend_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
"""``ban_trend`` accepts source='archive' and uses archived rows."""
result = await ban_service.ban_trend(
"/fake/sock",
@@ -959,9 +891,7 @@ class TestBanTrend:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
):
result = await ban_service.ban_trend(
"/fake/sock", "24h", origin="blocklist"
)
result = await ban_service.ban_trend("/fake/sock", "24h", origin="blocklist")
assert sum(b.count for b in result.buckets) == 1
@@ -985,9 +915,7 @@ class TestBanTrend:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
):
result = await ban_service.ban_trend(
"/fake/sock", "24h", origin="selfblock"
)
result = await ban_service.ban_trend("/fake/sock", "24h", origin="selfblock")
assert sum(b.count for b in result.buckets) == 2
@@ -1096,9 +1024,7 @@ class TestBansByJail:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_jail(
"/fake/sock", "24h", origin="blocklist"
)
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="blocklist")
assert len(result.jails) == 1
assert result.jails[0].jail == "blocklist-import"
@@ -1110,32 +1036,24 @@ class TestBansByJail:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_jail(
"/fake/sock", "24h", origin="selfblock"
)
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="selfblock")
jail_names = {j.jail for j in result.jails}
assert "blocklist-import" not in jail_names
assert result.total == 2
async def test_no_origin_filter_returns_all_jails(
self, mixed_origin_db_path: str
) -> None:
async def test_no_origin_filter_returns_all_jails(self, mixed_origin_db_path: str) -> None:
"""``origin=None`` returns bans from all jails."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_jail(
"/fake/sock", "24h", origin=None
)
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin=None)
assert result.total == 3
assert len(result.jails) == 3
async def test_bans_by_jail_source_archive_reads_archive(
self, app_db_with_archive: aiosqlite.Connection
) -> None:
async def test_bans_by_jail_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
"""``bans_by_jail`` accepts source='archive' and aggregates archived rows."""
result = await ban_service.bans_by_jail(
"/fake/sock",
@@ -1147,9 +1065,7 @@ class TestBansByJail:
assert result.total == 2
assert any(j.jail == "sshd" for j in result.jails)
async def test_diagnostic_warning_when_zero_results_despite_data(
self, tmp_path: Path
) -> None:
async def test_diagnostic_warning_when_zero_results_despite_data(self, tmp_path: Path) -> None:
"""A warning is logged when the time-range filter excludes all existing rows."""
import time as _time
@@ -1176,9 +1092,6 @@ class TestBansByJail:
assert result.jails == []
# The diagnostic warning must have been emitted.
warning_calls = [
c
for c in mock_log.warning.call_args_list
if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
c for c in mock_log.warning.call_args_list if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
]
assert len(warning_calls) == 1