fixed tests
This commit is contained in:
@@ -32,12 +32,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
``bantime``, ``bancount``, and optionally ``data``.
|
||||
"""
|
||||
async with aiosqlite.connect(path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE jails ("
|
||||
"name TEXT NOT NULL UNIQUE, "
|
||||
"enabled INTEGER NOT NULL DEFAULT 1"
|
||||
")"
|
||||
)
|
||||
await db.execute("CREATE TABLE jails (name TEXT NOT NULL UNIQUE, enabled INTEGER NOT NULL DEFAULT 1)")
|
||||
await db.execute(
|
||||
"CREATE TABLE bans ("
|
||||
"jail TEXT NOT NULL, "
|
||||
@@ -50,8 +45,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
)
|
||||
for row in rows:
|
||||
await db.execute(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
row["jail"],
|
||||
row["ip"],
|
||||
@@ -257,9 +251,7 @@ class TestListBansHappyPath:
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_source_archive_reads_from_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
async def test_source_archive_reads_from_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
|
||||
"""Using source='archive' reads from the BanGUI archive table."""
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock",
|
||||
@@ -280,9 +272,7 @@ class TestListBansHappyPath:
|
||||
class TestListBansGeoEnrichment:
|
||||
"""Verify geo enrichment integration in ban_service.list_bans()."""
|
||||
|
||||
async def test_geo_data_applied_when_enricher_provided(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_geo_data_applied_when_enricher_provided(self, f2b_db_path: str) -> None:
|
||||
"""Geo fields are populated when an enricher returns data."""
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
@@ -298,30 +288,24 @@ class TestListBansGeoEnrichment:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", geo_enricher=fake_enricher
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=fake_enricher)
|
||||
|
||||
for item in result.items:
|
||||
assert item.country_code == "DE"
|
||||
assert item.country_name == "Germany"
|
||||
assert item.asn == "AS3320"
|
||||
|
||||
async def test_geo_failure_does_not_break_results(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_geo_failure_does_not_break_results(self, f2b_db_path: str) -> None:
|
||||
"""A geo enricher that raises still returns ban items (geo fields null)."""
|
||||
|
||||
async def failing_enricher(ip: str) -> None:
|
||||
raise RuntimeError("geo service down")
|
||||
raise OSError("geo service down")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", geo_enricher=failing_enricher
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=failing_enricher)
|
||||
|
||||
assert result.total == 2
|
||||
for item in result.items:
|
||||
@@ -336,9 +320,7 @@ class TestListBansGeoEnrichment:
|
||||
class TestListBansBatchGeoEnrichment:
|
||||
"""Verify that list_bans uses lookup_batch when http_session is provided."""
|
||||
|
||||
async def test_batch_geo_applied_via_http_session(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_batch_geo_applied_via_http_session(self, f2b_db_path: str) -> None:
|
||||
"""Geo fields are populated via lookup_batch when http_session is given."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -350,6 +332,8 @@ class TestListBansBatchGeoEnrichment:
|
||||
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.lookup_batch = fake_geo_batch
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -359,7 +343,7 @@ class TestListBansBatchGeoEnrichment:
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
)
|
||||
|
||||
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
|
||||
@@ -371,15 +355,15 @@ class TestListBansBatchGeoEnrichment:
|
||||
assert us_item.country_code == "US"
|
||||
assert us_item.country_name == "United States"
|
||||
|
||||
async def test_batch_failure_does_not_break_results(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_batch_failure_does_not_break_results(self, f2b_db_path: str) -> None:
|
||||
"""A lookup_batch failure still returns items with null geo fields."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
fake_session = MagicMock()
|
||||
|
||||
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
|
||||
failing_geo_batch = AsyncMock(side_effect=OSError("batch geo down"))
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.lookup_batch = failing_geo_batch
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -389,16 +373,14 @@ class TestListBansBatchGeoEnrichment:
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=failing_geo_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
for item in result.items:
|
||||
assert item.country_code is None
|
||||
|
||||
async def test_http_session_takes_priority_over_geo_enricher(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_http_session_takes_priority_over_geo_enricher(self, f2b_db_path: str) -> None:
|
||||
"""When both http_session and geo_enricher are provided, batch wins."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -410,6 +392,8 @@ class TestListBansBatchGeoEnrichment:
|
||||
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.lookup_batch = fake_geo_batch
|
||||
|
||||
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
|
||||
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
|
||||
@@ -422,7 +406,7 @@ class TestListBansBatchGeoEnrichment:
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
geo_enricher=enricher_should_not_be_called,
|
||||
)
|
||||
|
||||
@@ -462,9 +446,7 @@ class TestListBansPagination:
|
||||
# Different IPs should appear on different pages.
|
||||
assert page1.items[0].ip != page2.items[0].ip
|
||||
|
||||
async def test_total_reflects_full_count_not_page_count(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_total_reflects_full_count_not_page_count(self, f2b_db_path: str) -> None:
|
||||
"""``total`` reports all matching records regardless of pagination."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -483,9 +465,7 @@ class TestListBansPagination:
|
||||
class TestBanOriginDerivation:
|
||||
"""Verify that ban_service correctly derives ``origin`` from jail names."""
|
||||
|
||||
async def test_blocklist_import_jail_yields_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_blocklist_import_jail_yields_blocklist_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -497,9 +477,7 @@ class TestBanOriginDerivation:
|
||||
assert len(blocklist_items) == 1
|
||||
assert blocklist_items[0].origin == "blocklist"
|
||||
|
||||
async def test_organic_jail_yields_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_organic_jail_yields_selfblock_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -512,9 +490,7 @@ class TestBanOriginDerivation:
|
||||
for item in organic_items:
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_all_items_carry_origin_field(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_all_items_carry_origin_field(self, mixed_origin_db_path: str) -> None:
|
||||
"""Every returned item has an ``origin`` field with a valid value."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -525,9 +501,7 @@ class TestBanOriginDerivation:
|
||||
for item in result.items:
|
||||
assert item.origin in ("blocklist", "selfblock")
|
||||
|
||||
async def test_bans_by_country_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_blocklist_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -535,13 +509,11 @@ class TestBanOriginDerivation:
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"]
|
||||
blocklist_bans = [b for b in result.items if b.jail == "blocklist-import"]
|
||||
assert len(blocklist_bans) == 1
|
||||
assert blocklist_bans[0].origin == "blocklist"
|
||||
|
||||
async def test_bans_by_country_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_selfblock_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` derives origin correctly for organic jails."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -549,7 +521,7 @@ class TestBanOriginDerivation:
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"]
|
||||
organic_bans = [b for b in result.items if b.jail != "blocklist-import"]
|
||||
assert len(organic_bans) == 2
|
||||
for ban in organic_bans:
|
||||
assert ban.origin == "selfblock"
|
||||
@@ -563,34 +535,26 @@ class TestBanOriginDerivation:
|
||||
class TestOriginFilter:
|
||||
"""Verify that the origin filter correctly restricts results."""
|
||||
|
||||
async def test_list_bans_blocklist_filter_returns_only_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_list_bans_blocklist_filter_returns_only_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert result.total == 1
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].jail == "blocklist-import"
|
||||
assert result.items[0].origin == "blocklist"
|
||||
|
||||
async def test_list_bans_selfblock_filter_excludes_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_list_bans_selfblock_filter_excludes_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.items) == 2
|
||||
@@ -598,9 +562,7 @@ class TestOriginFilter:
|
||||
assert item.jail != "blocklist-import"
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_list_bans_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_list_bans_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin=None`` applies no jail restriction — all bans returned."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -610,53 +572,39 @@ class TestOriginFilter:
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_blocklist_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_blocklist_filter(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert result.total == 1
|
||||
assert all(b.jail == "blocklist-import" for b in result.bans)
|
||||
assert all(b.jail == "blocklist-import" for b in result.items)
|
||||
|
||||
async def test_bans_by_country_selfblock_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_selfblock_filter(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
assert result.total == 2
|
||||
assert all(b.jail != "blocklist-import" for b in result.bans)
|
||||
assert all(b.jail != "blocklist-import" for b in result.items)
|
||||
|
||||
async def test_bans_by_country_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin=None
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", origin=None)
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_country_code_returns_all_matched_rows(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_bans_by_country_country_code_returns_all_matched_rows(self, tmp_path: Path) -> None:
|
||||
"""``bans_by_country`` returns all companion rows for the selected country."""
|
||||
path = str(tmp_path / "fail2ban_country_filter.sqlite3")
|
||||
rows = [
|
||||
@@ -672,8 +620,8 @@ class TestOriginFilter:
|
||||
]
|
||||
await _create_f2b_db(path, rows)
|
||||
|
||||
from app.services import geo_service
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import geo_service
|
||||
|
||||
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
|
||||
country_code="DE",
|
||||
@@ -682,12 +630,13 @@ class TestOriginFilter:
|
||||
org=None,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
), patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task:
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
),
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
@@ -698,8 +647,8 @@ class TestOriginFilter:
|
||||
|
||||
mock_create_task.assert_not_called()
|
||||
assert result.total == 205
|
||||
assert len(result.bans) == 205
|
||||
assert all(b.country_code == "DE" for b in result.bans)
|
||||
assert len(result.items) == 205
|
||||
assert all(b.country_code == "DE" for b in result.items)
|
||||
|
||||
await geo_service.clear_cache()
|
||||
|
||||
@@ -715,7 +664,7 @@ class TestOriginFilter:
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.bans) == 2
|
||||
assert len(result.items) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -728,13 +677,11 @@ class TestBansbyCountryBackground:
|
||||
"""bans_by_country() with http_session uses cache-only geo and fires a
|
||||
background task for uncached IPs instead of blocking on API calls."""
|
||||
|
||||
async def test_cached_geo_returned_without_api_call(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_cached_geo_returned_without_api_call(self, mixed_origin_db_path: str) -> None:
|
||||
"""When all IPs are in the cache, lookup_cached_only returns them and
|
||||
no background task is created."""
|
||||
from app.services import geo_service
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import geo_service
|
||||
|
||||
# Pre-populate the cache for all three IPs in the fixture.
|
||||
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
|
||||
@@ -752,9 +699,7 @@ class TestBansbyCountryBackground:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task,
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
@@ -763,7 +708,6 @@ class TestBansbyCountryBackground:
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
# All countries resolved from cache — no background task needed.
|
||||
@@ -773,9 +717,7 @@ class TestBansbyCountryBackground:
|
||||
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
|
||||
await geo_service.clear_cache()
|
||||
|
||||
async def test_uncached_ips_trigger_background_task(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_uncached_ips_trigger_background_task(self, mixed_origin_db_path: str) -> None:
|
||||
"""When IPs are NOT in the cache, create_task is called for background
|
||||
resolution and the response returns without blocking."""
|
||||
from app.services import geo_service
|
||||
@@ -787,9 +729,7 @@ class TestBansbyCountryBackground:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task,
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
@@ -798,7 +738,7 @@ class TestBansbyCountryBackground:
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
geo_cache=geo_service.GeoCache(),
|
||||
)
|
||||
|
||||
# Background task must have been scheduled for uncached IPs.
|
||||
@@ -806,9 +746,7 @@ class TestBansbyCountryBackground:
|
||||
# Response is still valid with empty country map (IPs not cached yet).
|
||||
assert result.total == 3
|
||||
|
||||
async def test_no_background_task_without_http_session(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_no_background_task_without_http_session(self, mixed_origin_db_path: str) -> None:
|
||||
"""When http_session is None, no background task is created."""
|
||||
from app.services import geo_service
|
||||
|
||||
@@ -819,13 +757,9 @@ class TestBansbyCountryBackground:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task,
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", http_session=None
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", http_session=None)
|
||||
|
||||
mock_create_task.assert_not_called()
|
||||
assert result.total == 3
|
||||
@@ -904,9 +838,7 @@ class TestBanTrend:
|
||||
timestamps = [b.timestamp for b in result.buckets]
|
||||
assert timestamps == sorted(timestamps)
|
||||
|
||||
async def test_ban_trend_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
async def test_ban_trend_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
|
||||
"""``ban_trend`` accepts source='archive' and uses archived rows."""
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock",
|
||||
@@ -959,9 +891,7 @@ class TestBanTrend:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert sum(b.count for b in result.buckets) == 1
|
||||
|
||||
@@ -985,9 +915,7 @@ class TestBanTrend:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
assert sum(b.count for b in result.buckets) == 2
|
||||
|
||||
@@ -1096,9 +1024,7 @@ class TestBansByJail:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert len(result.jails) == 1
|
||||
assert result.jails[0].jail == "blocklist-import"
|
||||
@@ -1110,32 +1036,24 @@ class TestBansByJail:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
jail_names = {j.jail for j in result.jails}
|
||||
assert "blocklist-import" not in jail_names
|
||||
assert result.total == 2
|
||||
|
||||
async def test_no_origin_filter_returns_all_jails(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_no_origin_filter_returns_all_jails(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin=None`` returns bans from all jails."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock", "24h", origin=None
|
||||
)
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin=None)
|
||||
|
||||
assert result.total == 3
|
||||
assert len(result.jails) == 3
|
||||
|
||||
async def test_bans_by_jail_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
async def test_bans_by_jail_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
|
||||
"""``bans_by_jail`` accepts source='archive' and aggregates archived rows."""
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock",
|
||||
@@ -1147,9 +1065,7 @@ class TestBansByJail:
|
||||
assert result.total == 2
|
||||
assert any(j.jail == "sshd" for j in result.jails)
|
||||
|
||||
async def test_diagnostic_warning_when_zero_results_despite_data(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_diagnostic_warning_when_zero_results_despite_data(self, tmp_path: Path) -> None:
|
||||
"""A warning is logged when the time-range filter excludes all existing rows."""
|
||||
import time as _time
|
||||
|
||||
@@ -1176,9 +1092,6 @@ class TestBansByJail:
|
||||
assert result.jails == []
|
||||
# The diagnostic warning must have been emitted.
|
||||
warning_calls = [
|
||||
c
|
||||
for c in mock_log.warning.call_args_list
|
||||
if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
|
||||
c for c in mock_log.warning.call_args_list if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
|
||||
]
|
||||
assert len(warning_calls) == 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user