fixed tests

This commit is contained in:
2026-05-15 20:41:05 +02:00
parent 96ce516ecf
commit 77df5d5d65
50 changed files with 1482 additions and 5089 deletions

View File

@@ -81,7 +81,7 @@ class TestLogin:
self, db: aiosqlite.Connection
) -> None:
"""login() returns a signed token and expiry on the correct password."""
signed_token, expires_at = await auth_service.login(
signed_token, expires_at, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -119,7 +119,7 @@ class TestLogin:
"""login() stores the session in the database."""
from app.repositories import session_repo
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -136,7 +136,7 @@ class TestValidateSession:
self, db: aiosqlite.Connection
) -> None:
"""validate_session() returns the session for a valid token."""
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -150,7 +150,7 @@ class TestValidateSession:
self, db: aiosqlite.Connection
) -> None:
"""validate_session() accepts a token signed with the configured secret."""
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -166,7 +166,7 @@ class TestValidateSession:
self, db: aiosqlite.Connection
) -> None:
"""validate_session() rejects signed tokens with an invalid signature."""
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -213,7 +213,7 @@ class TestLogout:
"""logout() deletes the session so it can no longer be validated."""
from app.repositories import session_repo
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -228,7 +228,7 @@ class TestLogout:
"""logout() accepts a signed token and revokes the underlying raw session."""
from app.repositories import session_repo
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -248,7 +248,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection
) -> None:
"""Tokens signed with current secret are validated immediately."""
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -264,7 +264,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection
) -> None:
"""Tokens signed with previous secret are accepted during rotation."""
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -280,7 +280,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection
) -> None:
"""Tokens signed with unknown secrets are rejected."""
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -308,7 +308,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection
) -> None:
"""During rotation, tokens signed with previous secret are re-signed."""
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -327,7 +327,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection
) -> None:
"""Validation processes token rotation during validation."""
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -348,7 +348,7 @@ class TestSecretRotation:
"""logout() accepts tokens signed with the previous secret."""
from app.repositories import session_repo
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,
@@ -368,7 +368,7 @@ class TestSecretRotation:
self, db: aiosqlite.Connection
) -> None:
"""If no previous secret is configured, old tokens are rejected."""
signed_token, _ = await auth_service.login(
signed_token, _, _ = await auth_service.login(
db,
password="correctpassword1",
session_duration_minutes=60,

View File

@@ -32,12 +32,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
``bantime``, ``bancount``, and optionally ``data``.
"""
async with aiosqlite.connect(path) as db:
await db.execute(
"CREATE TABLE jails ("
"name TEXT NOT NULL UNIQUE, "
"enabled INTEGER NOT NULL DEFAULT 1"
")"
)
await db.execute("CREATE TABLE jails (name TEXT NOT NULL UNIQUE, enabled INTEGER NOT NULL DEFAULT 1)")
await db.execute(
"CREATE TABLE bans ("
"jail TEXT NOT NULL, "
@@ -50,8 +45,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
)
for row in rows:
await db.execute(
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
"VALUES (?, ?, ?, ?, ?, ?)",
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) VALUES (?, ?, ?, ?, ?, ?)",
(
row["jail"],
row["ip"],
@@ -257,9 +251,7 @@ class TestListBansHappyPath:
assert result.total == 3
async def test_source_archive_reads_from_archive(
self, app_db_with_archive: aiosqlite.Connection
) -> None:
async def test_source_archive_reads_from_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
"""Using source='archive' reads from the BanGUI archive table."""
result = await ban_service.list_bans(
"/fake/sock",
@@ -280,9 +272,7 @@ class TestListBansHappyPath:
class TestListBansGeoEnrichment:
"""Verify geo enrichment integration in ban_service.list_bans()."""
async def test_geo_data_applied_when_enricher_provided(
self, f2b_db_path: str
) -> None:
async def test_geo_data_applied_when_enricher_provided(self, f2b_db_path: str) -> None:
"""Geo fields are populated when an enricher returns data."""
from app.models.geo import GeoInfo
@@ -298,30 +288,24 @@ class TestListBansGeoEnrichment:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", geo_enricher=fake_enricher
)
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=fake_enricher)
for item in result.items:
assert item.country_code == "DE"
assert item.country_name == "Germany"
assert item.asn == "AS3320"
async def test_geo_failure_does_not_break_results(
self, f2b_db_path: str
) -> None:
async def test_geo_failure_does_not_break_results(self, f2b_db_path: str) -> None:
"""A geo enricher that raises still returns ban items (geo fields null)."""
async def failing_enricher(ip: str) -> None:
raise RuntimeError("geo service down")
raise OSError("geo service down")
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", geo_enricher=failing_enricher
)
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=failing_enricher)
assert result.total == 2
for item in result.items:
@@ -336,9 +320,7 @@ class TestListBansGeoEnrichment:
class TestListBansBatchGeoEnrichment:
"""Verify that list_bans uses lookup_batch when http_session is provided."""
async def test_batch_geo_applied_via_http_session(
self, f2b_db_path: str
) -> None:
async def test_batch_geo_applied_via_http_session(self, f2b_db_path: str) -> None:
"""Geo fields are populated via lookup_batch when http_session is given."""
from unittest.mock import MagicMock
@@ -350,6 +332,8 @@ class TestListBansBatchGeoEnrichment:
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
}
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
mock_geo_cache = MagicMock()
mock_geo_cache.lookup_batch = fake_geo_batch
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -359,7 +343,7 @@ class TestListBansBatchGeoEnrichment:
"/fake/sock",
"24h",
http_session=fake_session,
geo_batch_lookup=fake_geo_batch,
geo_cache=mock_geo_cache,
)
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
@@ -371,15 +355,15 @@ class TestListBansBatchGeoEnrichment:
assert us_item.country_code == "US"
assert us_item.country_name == "United States"
async def test_batch_failure_does_not_break_results(
self, f2b_db_path: str
) -> None:
async def test_batch_failure_does_not_break_results(self, f2b_db_path: str) -> None:
"""A lookup_batch failure still returns items with null geo fields."""
from unittest.mock import MagicMock
fake_session = MagicMock()
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
failing_geo_batch = AsyncMock(side_effect=OSError("batch geo down"))
mock_geo_cache = MagicMock()
mock_geo_cache.lookup_batch = failing_geo_batch
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -389,16 +373,14 @@ class TestListBansBatchGeoEnrichment:
"/fake/sock",
"24h",
http_session=fake_session,
geo_batch_lookup=failing_geo_batch,
geo_cache=mock_geo_cache,
)
assert result.total == 2
for item in result.items:
assert item.country_code is None
async def test_http_session_takes_priority_over_geo_enricher(
self, f2b_db_path: str
) -> None:
async def test_http_session_takes_priority_over_geo_enricher(self, f2b_db_path: str) -> None:
"""When both http_session and geo_enricher are provided, batch wins."""
from unittest.mock import MagicMock
@@ -410,6 +392,8 @@ class TestListBansBatchGeoEnrichment:
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
}
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
mock_geo_cache = MagicMock()
mock_geo_cache.lookup_batch = fake_geo_batch
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
@@ -422,7 +406,7 @@ class TestListBansBatchGeoEnrichment:
"/fake/sock",
"24h",
http_session=fake_session,
geo_batch_lookup=fake_geo_batch,
geo_cache=mock_geo_cache,
geo_enricher=enricher_should_not_be_called,
)
@@ -462,9 +446,7 @@ class TestListBansPagination:
# Different IPs should appear on different pages.
assert page1.items[0].ip != page2.items[0].ip
async def test_total_reflects_full_count_not_page_count(
self, f2b_db_path: str
) -> None:
async def test_total_reflects_full_count_not_page_count(self, f2b_db_path: str) -> None:
"""``total`` reports all matching records regardless of pagination."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -483,9 +465,7 @@ class TestListBansPagination:
class TestBanOriginDerivation:
"""Verify that ban_service correctly derives ``origin`` from jail names."""
async def test_blocklist_import_jail_yields_blocklist_origin(
self, mixed_origin_db_path: str
) -> None:
async def test_blocklist_import_jail_yields_blocklist_origin(self, mixed_origin_db_path: str) -> None:
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -497,9 +477,7 @@ class TestBanOriginDerivation:
assert len(blocklist_items) == 1
assert blocklist_items[0].origin == "blocklist"
async def test_organic_jail_yields_selfblock_origin(
self, mixed_origin_db_path: str
) -> None:
async def test_organic_jail_yields_selfblock_origin(self, mixed_origin_db_path: str) -> None:
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -512,9 +490,7 @@ class TestBanOriginDerivation:
for item in organic_items:
assert item.origin == "selfblock"
async def test_all_items_carry_origin_field(
self, mixed_origin_db_path: str
) -> None:
async def test_all_items_carry_origin_field(self, mixed_origin_db_path: str) -> None:
"""Every returned item has an ``origin`` field with a valid value."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -525,9 +501,7 @@ class TestBanOriginDerivation:
for item in result.items:
assert item.origin in ("blocklist", "selfblock")
async def test_bans_by_country_blocklist_origin(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_blocklist_origin(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -535,13 +509,11 @@ class TestBanOriginDerivation:
):
result = await ban_service.bans_by_country("/fake/sock", "24h")
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"]
blocklist_bans = [b for b in result.items if b.jail == "blocklist-import"]
assert len(blocklist_bans) == 1
assert blocklist_bans[0].origin == "blocklist"
async def test_bans_by_country_selfblock_origin(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_selfblock_origin(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` derives origin correctly for organic jails."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -549,7 +521,7 @@ class TestBanOriginDerivation:
):
result = await ban_service.bans_by_country("/fake/sock", "24h")
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"]
organic_bans = [b for b in result.items if b.jail != "blocklist-import"]
assert len(organic_bans) == 2
for ban in organic_bans:
assert ban.origin == "selfblock"
@@ -563,34 +535,26 @@ class TestBanOriginDerivation:
class TestOriginFilter:
"""Verify that the origin filter correctly restricts results."""
async def test_list_bans_blocklist_filter_returns_only_blocklist(
self, mixed_origin_db_path: str
) -> None:
async def test_list_bans_blocklist_filter_returns_only_blocklist(self, mixed_origin_db_path: str) -> None:
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", origin="blocklist"
)
result = await ban_service.list_bans("/fake/sock", "24h", origin="blocklist")
assert result.total == 1
assert len(result.items) == 1
assert result.items[0].jail == "blocklist-import"
assert result.items[0].origin == "blocklist"
async def test_list_bans_selfblock_filter_excludes_blocklist(
self, mixed_origin_db_path: str
) -> None:
async def test_list_bans_selfblock_filter_excludes_blocklist(self, mixed_origin_db_path: str) -> None:
"""``origin='selfblock'`` excludes the blocklist-import jail."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", origin="selfblock"
)
result = await ban_service.list_bans("/fake/sock", "24h", origin="selfblock")
assert result.total == 2
assert len(result.items) == 2
@@ -598,9 +562,7 @@ class TestOriginFilter:
assert item.jail != "blocklist-import"
assert item.origin == "selfblock"
async def test_list_bans_no_filter_returns_all(
self, mixed_origin_db_path: str
) -> None:
async def test_list_bans_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
"""``origin=None`` applies no jail restriction — all bans returned."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
@@ -610,53 +572,39 @@ class TestOriginFilter:
assert result.total == 3
async def test_bans_by_country_blocklist_filter(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_blocklist_filter(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", origin="blocklist"
)
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="blocklist")
assert result.total == 1
assert all(b.jail == "blocklist-import" for b in result.bans)
assert all(b.jail == "blocklist-import" for b in result.items)
async def test_bans_by_country_selfblock_filter(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_selfblock_filter(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", origin="selfblock"
)
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="selfblock")
assert result.total == 2
assert all(b.jail != "blocklist-import" for b in result.bans)
assert all(b.jail != "blocklist-import" for b in result.items)
async def test_bans_by_country_no_filter_returns_all(
self, mixed_origin_db_path: str
) -> None:
async def test_bans_by_country_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
"""``bans_by_country`` with ``origin=None`` returns all bans."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", origin=None
)
result = await ban_service.bans_by_country("/fake/sock", "24h", origin=None)
assert result.total == 3
async def test_bans_by_country_country_code_returns_all_matched_rows(
self, tmp_path: Path
) -> None:
async def test_bans_by_country_country_code_returns_all_matched_rows(self, tmp_path: Path) -> None:
"""``bans_by_country`` returns all companion rows for the selected country."""
path = str(tmp_path / "fail2ban_country_filter.sqlite3")
rows = [
@@ -672,8 +620,8 @@ class TestOriginFilter:
]
await _create_f2b_db(path, rows)
from app.services import geo_service
from app.models.geo import GeoInfo
from app.services import geo_service
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
country_code="DE",
@@ -682,12 +630,13 @@ class TestOriginFilter:
org=None,
)
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
), patch(
"app.services.ban_service.asyncio.create_task"
) as mock_create_task:
with (
patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
),
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
):
result = await ban_service.bans_by_country(
"/fake/sock",
"24h",
@@ -698,8 +647,8 @@ class TestOriginFilter:
mock_create_task.assert_not_called()
assert result.total == 205
assert len(result.bans) == 205
assert all(b.country_code == "DE" for b in result.bans)
assert len(result.items) == 205
assert all(b.country_code == "DE" for b in result.items)
await geo_service.clear_cache()
@@ -715,7 +664,7 @@ class TestOriginFilter:
)
assert result.total == 2
assert len(result.bans) == 2
assert len(result.items) == 2
# ---------------------------------------------------------------------------
@@ -728,13 +677,11 @@ class TestBansbyCountryBackground:
"""bans_by_country() with http_session uses cache-only geo and fires a
background task for uncached IPs instead of blocking on API calls."""
async def test_cached_geo_returned_without_api_call(
self, mixed_origin_db_path: str
) -> None:
async def test_cached_geo_returned_without_api_call(self, mixed_origin_db_path: str) -> None:
"""When all IPs are in the cache, lookup_cached_only returns them and
no background task is created."""
from app.services import geo_service
from app.models.geo import GeoInfo
from app.services import geo_service
# Pre-populate the cache for all three IPs in the fixture.
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
@@ -752,9 +699,7 @@ class TestBansbyCountryBackground:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
),
patch(
"app.services.ban_service.asyncio.create_task"
) as mock_create_task,
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
):
mock_session = AsyncMock()
mock_batch = AsyncMock(return_value={})
@@ -763,7 +708,6 @@ class TestBansbyCountryBackground:
"24h",
http_session=mock_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=mock_batch,
)
# All countries resolved from cache — no background task needed.
@@ -773,9 +717,7 @@ class TestBansbyCountryBackground:
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
await geo_service.clear_cache()
async def test_uncached_ips_trigger_background_task(
self, mixed_origin_db_path: str
) -> None:
async def test_uncached_ips_trigger_background_task(self, mixed_origin_db_path: str) -> None:
"""When IPs are NOT in the cache, create_task is called for background
resolution and the response returns without blocking."""
from app.services import geo_service
@@ -787,9 +729,7 @@ class TestBansbyCountryBackground:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
),
patch(
"app.services.ban_service.asyncio.create_task"
) as mock_create_task,
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
):
mock_session = AsyncMock()
mock_batch = AsyncMock(return_value={})
@@ -798,7 +738,7 @@ class TestBansbyCountryBackground:
"24h",
http_session=mock_session,
geo_cache_lookup=geo_service.lookup_cached_only,
geo_batch_lookup=mock_batch,
geo_cache=geo_service.GeoCache(),
)
# Background task must have been scheduled for uncached IPs.
@@ -806,9 +746,7 @@ class TestBansbyCountryBackground:
# Response is still valid with empty country map (IPs not cached yet).
assert result.total == 3
async def test_no_background_task_without_http_session(
self, mixed_origin_db_path: str
) -> None:
async def test_no_background_task_without_http_session(self, mixed_origin_db_path: str) -> None:
"""When http_session is None, no background task is created."""
from app.services import geo_service
@@ -819,13 +757,9 @@ class TestBansbyCountryBackground:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
),
patch(
"app.services.ban_service.asyncio.create_task"
) as mock_create_task,
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
):
result = await ban_service.bans_by_country(
"/fake/sock", "24h", http_session=None
)
result = await ban_service.bans_by_country("/fake/sock", "24h", http_session=None)
mock_create_task.assert_not_called()
assert result.total == 3
@@ -904,9 +838,7 @@ class TestBanTrend:
timestamps = [b.timestamp for b in result.buckets]
assert timestamps == sorted(timestamps)
async def test_ban_trend_source_archive_reads_archive(
self, app_db_with_archive: aiosqlite.Connection
) -> None:
async def test_ban_trend_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
"""``ban_trend`` accepts source='archive' and uses archived rows."""
result = await ban_service.ban_trend(
"/fake/sock",
@@ -959,9 +891,7 @@ class TestBanTrend:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
):
result = await ban_service.ban_trend(
"/fake/sock", "24h", origin="blocklist"
)
result = await ban_service.ban_trend("/fake/sock", "24h", origin="blocklist")
assert sum(b.count for b in result.buckets) == 1
@@ -985,9 +915,7 @@ class TestBanTrend:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=path),
):
result = await ban_service.ban_trend(
"/fake/sock", "24h", origin="selfblock"
)
result = await ban_service.ban_trend("/fake/sock", "24h", origin="selfblock")
assert sum(b.count for b in result.buckets) == 2
@@ -1096,9 +1024,7 @@ class TestBansByJail:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_jail(
"/fake/sock", "24h", origin="blocklist"
)
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="blocklist")
assert len(result.jails) == 1
assert result.jails[0].jail == "blocklist-import"
@@ -1110,32 +1036,24 @@ class TestBansByJail:
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_jail(
"/fake/sock", "24h", origin="selfblock"
)
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="selfblock")
jail_names = {j.jail for j in result.jails}
assert "blocklist-import" not in jail_names
assert result.total == 2
async def test_no_origin_filter_returns_all_jails(
self, mixed_origin_db_path: str
) -> None:
async def test_no_origin_filter_returns_all_jails(self, mixed_origin_db_path: str) -> None:
"""``origin=None`` returns bans from all jails."""
with patch(
"app.services.ban_service.get_fail2ban_db_path",
new=AsyncMock(return_value=mixed_origin_db_path),
):
result = await ban_service.bans_by_jail(
"/fake/sock", "24h", origin=None
)
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin=None)
assert result.total == 3
assert len(result.jails) == 3
async def test_bans_by_jail_source_archive_reads_archive(
self, app_db_with_archive: aiosqlite.Connection
) -> None:
async def test_bans_by_jail_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
"""``bans_by_jail`` accepts source='archive' and aggregates archived rows."""
result = await ban_service.bans_by_jail(
"/fake/sock",
@@ -1147,9 +1065,7 @@ class TestBansByJail:
assert result.total == 2
assert any(j.jail == "sshd" for j in result.jails)
async def test_diagnostic_warning_when_zero_results_despite_data(
self, tmp_path: Path
) -> None:
async def test_diagnostic_warning_when_zero_results_despite_data(self, tmp_path: Path) -> None:
"""A warning is logged when the time-range filter excludes all existing rows."""
import time as _time
@@ -1176,9 +1092,6 @@ class TestBansByJail:
assert result.jails == []
# The diagnostic warning must have been emitted.
warning_calls = [
c
for c in mock_log.warning.call_args_list
if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
c for c in mock_log.warning.call_args_list if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
]
assert len(warning_calls) == 1

File diff suppressed because it is too large Load Diff

View File

@@ -12,11 +12,10 @@ import pytest
from app.config import Settings
from app.models.config import (
GlobalConfigUpdate,
JailConfigListResponse,
JailConfigResponse,
LogPreviewRequest,
RegexTestRequest,
)
from app.models.config_domain import DomainJailConfig, DomainJailConfigList
from app.services import config_service, health_service, log_service
from app.services.config_service import (
ConfigValidationError,
@@ -31,6 +30,7 @@ from app.services.config_service import (
@pytest.fixture(autouse=True)
def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
"""Mock get_settings for all tests in this module."""
def mock_get_settings() -> Settings:
return Settings(
database_path=":memory:",
@@ -39,7 +39,7 @@ def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
session_secret="test-secret-key-do-not-use-in-production",
)
monkeypatch.setattr("app.models.config.get_settings", mock_get_settings)
monkeypatch.setattr("app.config.get_settings", mock_get_settings)
monkeypatch.setattr("app.utils.path_utils.get_settings", mock_get_settings)
@@ -113,16 +113,16 @@ class TestGetJailConfig:
"""Unit tests for :func:`~app.services.config_service.get_jail_config`."""
async def test_returns_jail_config_response(self) -> None:
"""get_jail_config returns a JailConfigResponse."""
"""get_jail_config returns a DomainJailConfig."""
with _patch_client(_DEFAULT_JAIL_RESPONSES):
result = await config_service.get_jail_config(_SOCKET, "sshd")
assert isinstance(result, JailConfigResponse)
assert result.jail.name == "sshd"
assert result.jail.ban_time == 600
assert result.jail.max_retry == 5
assert result.jail.fail_regex == ["regex1", "regex2"]
assert result.jail.log_paths == ["/var/log/auth.log"]
assert isinstance(result, DomainJailConfig)
assert result.name == "sshd"
assert result.ban_time == 600
assert result.max_retry == 5
assert result.fail_regex == ["regex1", "regex2"]
assert result.log_paths == ["/var/log/auth.log"]
async def test_raises_jail_not_found(self) -> None:
"""get_jail_config raises JailNotFoundError for an unknown jail."""
@@ -140,10 +140,13 @@ class TestGetJailConfig:
return (1, "unknown jail 'missing'")
return (0, None)
with patch(
"app.services.config_service.Fail2BanClient",
lambda **_kw: type("C", (), {"send": AsyncMock(side_effect=_faulty_send)})(),
), pytest.raises(JailNotFoundError):
with (
patch(
"app.services.config_service.Fail2BanClient",
lambda **_kw: type("C", (), {"send": AsyncMock(side_effect=_faulty_send)})(),
),
pytest.raises(JailNotFoundError),
):
await config_service.get_jail_config(_SOCKET, "missing")
async def test_actions_parsed_correctly(self) -> None:
@@ -151,7 +154,7 @@ class TestGetJailConfig:
with _patch_client(_DEFAULT_JAIL_RESPONSES):
result = await config_service.get_jail_config(_SOCKET, "sshd")
assert "iptables" in result.jail.actions
assert "iptables" in result.actions
async def test_empty_log_paths_fallback(self) -> None:
"""get_jail_config handles None log paths gracefully."""
@@ -159,14 +162,14 @@ class TestGetJailConfig:
with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.log_paths == []
assert result.log_paths == []
async def test_date_pattern_none(self) -> None:
"""get_jail_config returns None date_pattern when not set."""
with _patch_client(_DEFAULT_JAIL_RESPONSES):
result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.date_pattern is None
assert result.date_pattern is None
async def test_use_dns_populated(self) -> None:
"""get_jail_config returns use_dns from the socket response."""
@@ -174,7 +177,7 @@ class TestGetJailConfig:
with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.use_dns == "no"
assert result.use_dns == "no"
async def test_use_dns_default_when_missing(self) -> None:
"""get_jail_config defaults use_dns to 'warn' when socket returns None."""
@@ -182,7 +185,7 @@ class TestGetJailConfig:
with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.use_dns == "warn"
assert result.use_dns == "warn"
async def test_prefregex_populated(self) -> None:
"""get_jail_config returns prefregex from the socket response."""
@@ -193,7 +196,7 @@ class TestGetJailConfig:
with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.prefregex == r"^%(__prefix_line)s"
assert result.prefregex == r"^%(__prefix_line)s"
async def test_prefregex_empty_when_missing(self) -> None:
"""get_jail_config returns empty string prefregex when socket returns None."""
@@ -201,7 +204,7 @@ class TestGetJailConfig:
with _patch_client(responses):
result = await config_service.get_jail_config(_SOCKET, "sshd")
assert result.jail.prefregex == ""
assert result.prefregex == ""
# ---------------------------------------------------------------------------
@@ -213,12 +216,12 @@ class TestListJailConfigs:
"""Unit tests for :func:`~app.services.config_service.list_jail_configs`."""
async def test_returns_list_response(self) -> None:
"""list_jail_configs returns a JailConfigListResponse."""
"""list_jail_configs returns a DomainJailConfigList."""
responses = {"status": _make_global_status("sshd"), **_DEFAULT_JAIL_RESPONSES}
with _patch_client(responses):
result = await config_service.list_jail_configs(_SOCKET)
assert isinstance(result, JailConfigListResponse)
assert isinstance(result, DomainJailConfigList)
assert result.total == 1
assert result.items[0].name == "sshd"
@@ -233,9 +236,7 @@ class TestListJailConfigs:
async def test_multiple_jails(self) -> None:
"""list_jail_configs handles comma-separated jail names."""
nginx_responses = {
k.replace("sshd", "nginx"): v for k, v in _DEFAULT_JAIL_RESPONSES.items()
}
nginx_responses = {k.replace("sshd", "nginx"): v for k, v in _DEFAULT_JAIL_RESPONSES.items()}
responses = {
"status": _make_global_status("sshd, nginx"),
**_DEFAULT_JAIL_RESPONSES,
@@ -521,11 +522,16 @@ class TestUpdateGlobalConfig:
assert cmd[2] == "DEBUG"
async def test_invalid_log_target_raises_config_validation_error(self) -> None:
"""update_global_config rejects invalid log_target from model validation."""
from pydantic import ValidationError
with pytest.raises(ValidationError, match="outside allowed directories"):
GlobalConfigUpdate(log_target="/etc/passwd")
"""update_global_config rejects invalid log_target."""
update = GlobalConfigUpdate(log_target="/etc/passwd")
with (
patch(
"app.services.config_service.validate_log_target",
side_effect=ValueError("outside allowed directories"),
),
pytest.raises(ConfigValidationError, match="outside allowed directories"),
):
await config_service.update_global_config(_SOCKET, update)
async def test_valid_special_log_target(self) -> None:
"""update_global_config accepts special log_target values."""
@@ -711,6 +717,7 @@ class TestReadFail2BanLog:
def _patch_client(self, log_level: str = "INFO", log_target: str = "/var/log/fail2ban.log") -> Any:
"""Build a patched Fail2BanClient that returns *log_level* and *log_target*."""
async def _send(command: list[Any]) -> Any:
key = "|".join(str(c) for c in command)
if key == "get|loglevel":
@@ -735,8 +742,10 @@ class TestReadFail2BanLog:
log_dir = str(tmp_path)
# Patch _SAFE_LOG_PREFIXES to allow tmp_path
with self._patch_client(log_target=str(log_file)), \
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)):
with (
self._patch_client(log_target=str(log_file)),
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
):
result = await log_service.read_fail2ban_log(_SOCKET, 200)
assert result.log_path == str(log_file.resolve())
@@ -750,8 +759,10 @@ class TestReadFail2BanLog:
log_file.write_text("INFO sshd Found 1.2.3.4\nERROR something else\nINFO sshd Found 5.6.7.8\n")
log_dir = str(tmp_path)
with self._patch_client(log_target=str(log_file)), \
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)):
with (
self._patch_client(log_target=str(log_file)),
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
):
result = await log_service.read_fail2ban_log(_SOCKET, 200, "Found")
assert all("Found" in ln for ln in result.lines)
@@ -759,14 +770,18 @@ class TestReadFail2BanLog:
async def test_non_file_target_raises_operation_error(self) -> None:
"""read_fail2ban_log raises ConfigOperationError for STDOUT target."""
with self._patch_client(log_target="STDOUT"), \
pytest.raises(config_service.ConfigOperationError, match="STDOUT"):
with (
self._patch_client(log_target="STDOUT"),
pytest.raises(config_service.ConfigOperationError, match="STDOUT"),
):
await log_service.read_fail2ban_log(_SOCKET, 200)
async def test_syslog_target_raises_operation_error(self) -> None:
"""read_fail2ban_log raises ConfigOperationError for SYSLOG target."""
with self._patch_client(log_target="SYSLOG"), \
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"):
with (
self._patch_client(log_target="SYSLOG"),
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"),
):
await log_service.read_fail2ban_log(_SOCKET, 200)
async def test_path_outside_safe_dir_raises_operation_error(self, tmp_path: Any) -> None:
@@ -775,9 +790,11 @@ class TestReadFail2BanLog:
log_file.write_text("secret data\n")
# Allow only /var/log — tmp_path is deliberately not in the safe list.
with self._patch_client(log_target=str(log_file)), \
patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)), \
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"):
with (
self._patch_client(log_target=str(log_file)),
patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)),
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"),
):
await log_service.read_fail2ban_log(_SOCKET, 200)
async def test_missing_log_file_raises_operation_error(self, tmp_path: Any) -> None:
@@ -785,9 +802,11 @@ class TestReadFail2BanLog:
missing = str(tmp_path / "nonexistent.log")
log_dir = str(tmp_path)
with self._patch_client(log_target=missing), \
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)), \
pytest.raises(config_service.ConfigOperationError, match="not found"):
with (
self._patch_client(log_target=missing),
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
pytest.raises(config_service.ConfigOperationError, match="not found"),
):
await log_service.read_fail2ban_log(_SOCKET, 200)
@@ -803,9 +822,7 @@ class TestGetServiceStatus:
"""get_service_status returns correct fields when fail2ban is online."""
from app.models.server import ServerStatus
online_status = ServerStatus(
online=True, version="1.0.0", active_jails=2, total_bans=5, total_failures=3
)
online_status = ServerStatus(online=True, version="1.0.0", active_jails=2, total_bans=5, total_failures=3)
async def _send(command: list[Any]) -> Any:
key = "|".join(str(c) for c in command)
@@ -878,12 +895,15 @@ class TestConfigModuleIntegration:
},
)
with patch(
"app.services.jail_config_service._parse_jails_sync",
new=fake_parse_jails_sync,
), patch(
"app.services.jail_config_service._get_active_jail_names",
new=AsyncMock(return_value={"sshd"}),
with (
patch(
"app.services.jail_config_service._parse_jails_sync",
new=fake_parse_jails_sync,
),
patch(
"app.services.jail_config_service._get_active_jail_names",
new=AsyncMock(return_value={"sshd"}),
),
):
result = await list_inactive_jails(str(tmp_path), "/fake.sock")
@@ -907,5 +927,5 @@ class TestConfigModuleIntegration:
result = await list_filters(str(tmp_path), "/fake.sock")
assert result.total == 1
assert result.filters[0].name == "sshd"
assert result.filters[0].active is True
assert result.items[0].name == "sshd"
assert result.items[0].active is True

View File

@@ -209,9 +209,7 @@ class TestLookupCaching:
async def test_negative_result_stored_in_neg_cache(self, geo_cache: GeoCache) -> None:
"""A failed lookup is stored in the negative cache, so the second call is blocked."""
session = _make_session(
{"status": "fail", "message": "reserved range"}
)
session = _make_session({"status": "fail", "message": "reserved range"})
await geo_cache.lookup("192.168.1.1", session)
await geo_cache.lookup("192.168.1.1", session)
@@ -473,7 +471,7 @@ def _make_async_db() -> MagicMock:
return MagicMock(__aenter__=AsyncMock(return_value=None), __aexit__=AsyncMock(return_value=None))
return mock_ctx
db.execute = MagicMock(side_effect=fake_execute)
db.execute = AsyncMock(side_effect=fake_execute)
db.executemany = AsyncMock()
db.commit = AsyncMock()
db.rollback = AsyncMock()
@@ -500,10 +498,7 @@ class TestLookupBatchSingleCommit:
async def test_commit_called_even_on_failed_lookups(self, geo_cache: GeoCache) -> None:
"""A batch with all-failed lookups still triggers one commit."""
ips = ["10.0.0.1", "10.0.0.2"]
batch_response = [
{"query": ip, "status": "fail", "message": "private range"}
for ip in ips
]
batch_response = [{"query": ip, "status": "fail", "message": "private range"} for ip in ips]
session = _make_batch_session(batch_response)
db = _make_async_db()
@@ -533,9 +528,7 @@ class TestLookupBatchSingleCommit:
async def test_no_commit_for_all_cached_ips(self, geo_cache: GeoCache) -> None:
"""When all IPs are already cached, no HTTP call and no commit occur."""
geo_cache._cache["5.5.5.5"] = GeoInfo(
country_code="FR", country_name="France", asn="AS1", org="ISP"
)
geo_cache._cache["5.5.5.5"] = GeoInfo(country_code="FR", country_name="France", asn="AS1", org="ISP")
db = _make_async_db()
session = _make_batch_session([])
@@ -670,10 +663,7 @@ class TestLookupBatchThrottling:
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
return {
ip: GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None)
for ip in chunk
}
return {ip: GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None) for ip in chunk}
with (
patch.object(
@@ -778,7 +768,7 @@ class TestErrorLogging:
async def test_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None:
"""When HTTP exception str() is empty, exc_type and repr are still logged."""
class _EmptyMessageError(Exception):
class _EmptyMessageError(OSError):
"""Exception whose str() representation is empty."""
def __str__(self) -> str:
@@ -792,9 +782,7 @@ class TestErrorLogging:
from tests.logging_capture import capture_logs
with capture_logs() as captured, patch.object(
geo_cache, "_geoip_reader", None
):
with capture_logs() as captured, patch.object(geo_cache, "_geoip_reader", None):
# Ensure MMDB is not available so HTTP is tried.
result = await geo_cache.lookup("197.221.98.153", session)
@@ -819,9 +807,7 @@ class TestErrorLogging:
from tests.logging_capture import capture_logs
with capture_logs() as captured, patch.object(
geo_cache, "_geoip_reader", None
):
with capture_logs() as captured, patch.object(geo_cache, "_geoip_reader", None):
# Ensure MMDB is not available so HTTP is tried.
await geo_cache.lookup("10.0.0.1", session)
@@ -834,7 +820,7 @@ class TestErrorLogging:
async def test_batch_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None:
"""Batch API call: empty-message exceptions include exc_type in the log."""
class _EmptyMessageError(Exception):
class _EmptyMessageError(OSError):
def __str__(self) -> str:
return ""
@@ -908,9 +894,7 @@ class TestLookupCachedOnly:
def test_mixed_ips(self, geo_cache: GeoCache) -> None:
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
geo_cache._cache["1.2.3.4"] = GeoInfo(
country_code="DE", country_name="Germany", asn=None, org=None
)
geo_cache._cache["1.2.3.4"] = GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None)
import time
geo_cache._neg_cache["5.5.5.5"] = time.monotonic()
@@ -922,13 +906,9 @@ class TestLookupCachedOnly:
def test_deduplication(self, geo_cache: GeoCache) -> None:
"""Duplicate IPs in the input appear at most once in the output."""
geo_cache._cache["1.2.3.4"] = GeoInfo(
country_code="US", country_name="United States", asn=None, org=None
)
geo_cache._cache["1.2.3.4"] = GeoInfo(country_code="US", country_name="United States", asn=None, org=None)
geo_map, uncached = geo_cache.lookup_cached_only(
["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"]
)
geo_map, uncached = geo_cache.lookup_cached_only(["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"])
assert len([ip for ip in geo_map if ip == "1.2.3.4"]) == 1
assert uncached.count("9.9.9.9") == 1
@@ -942,18 +922,22 @@ class TestReResolveAll:
db = MagicMock()
session = MagicMock()
with patch(
"app.repositories.geo_cache_repo.get_unresolved_ips",
AsyncMock(return_value=[]),
), patch.object(
geo_cache,
"lookup_batch",
AsyncMock(),
) as mock_lookup, patch.object(
geo_cache,
"clear_neg_cache",
AsyncMock(),
) as mock_clear:
with (
patch(
"app.repositories.geo_cache_repo.get_unresolved_ips",
AsyncMock(return_value=[]),
),
patch.object(
geo_cache,
"lookup_batch",
AsyncMock(),
) as mock_lookup,
patch.object(
geo_cache,
"clear_neg_cache",
AsyncMock(),
) as mock_clear,
):
result = await geo_cache.re_resolve_all(db, session)
assert result == {"resolved": 0, "total": 0}
@@ -970,18 +954,22 @@ class TestReResolveAll:
"2.2.2.2": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
}
with patch(
"app.repositories.geo_cache_repo.get_unresolved_ips",
AsyncMock(return_value=ips),
), patch.object(
geo_cache,
"lookup_batch",
AsyncMock(return_value=geo_map),
) as mock_lookup, patch.object(
geo_cache,
"clear_neg_cache",
AsyncMock(),
) as mock_clear:
with (
patch(
"app.repositories.geo_cache_repo.get_unresolved_ips",
AsyncMock(return_value=ips),
),
patch.object(
geo_cache,
"lookup_batch",
AsyncMock(return_value=geo_map),
) as mock_lookup,
patch.object(
geo_cache,
"clear_neg_cache",
AsyncMock(),
) as mock_clear,
):
result = await geo_cache.re_resolve_all(db, session)
assert result == {"resolved": 1, "total": 2}
@@ -1018,23 +1006,21 @@ class TestLookupBatchBulkWrites:
# One executemany for the positive rows.
assert db.executemany.await_count >= 1
# High-level: execute() must NOT be called for the batch writes.
db.execute.assert_not_awaited()
# BEGIN IMMEDIATE is called for transaction wrapper.
assert db.execute.await_count == 1
async def test_executemany_called_for_failed_ips(self, geo_cache: GeoCache) -> None:
"""When IPs fail resolution, a single executemany write covers neg entries."""
ips = ["10.0.0.1", "10.0.0.2"]
batch_response = [
{"query": ip, "status": "fail", "message": "private range"}
for ip in ips
]
batch_response = [{"query": ip, "status": "fail", "message": "private range"} for ip in ips]
session = _make_batch_session(batch_response)
db = _make_async_db()
await geo_cache.lookup_batch(ips, session, db=db)
assert db.executemany.await_count >= 1
db.execute.assert_not_awaited()
# BEGIN IMMEDIATE is called for transaction wrapper.
assert db.execute.await_count == 1
async def test_mixed_results_two_executemany_calls(self, geo_cache: GeoCache) -> None:
"""A mix of successful and failed IPs produces two executemany calls."""
@@ -1057,7 +1043,8 @@ class TestLookupBatchBulkWrites:
# One executemany for positives, one for negatives.
assert db.executemany.await_count == 2
db.execute.assert_not_awaited()
# BEGIN IMMEDIATE is called for transaction wrapper.
assert db.execute.await_count == 1
# ---------------------------------------------------------------------------
@@ -1071,9 +1058,7 @@ class TestCacheMetrics:
async def test_cache_hit_increments_hits(self) -> None:
"""lookup() with a cached IP increments _hits."""
geo_cache = GeoCache(allow_http_fallback=True)
geo_cache._cache["1.1.1.1"] = GeoInfo(
country_code="AU", country_name="Australia", asn=None, org=None
)
geo_cache._cache["1.1.1.1"] = GeoInfo(country_code="AU", country_name="Australia", asn=None, org=None)
await geo_cache.lookup("1.1.1.1", MagicMock())
@@ -1269,4 +1254,3 @@ class TestLargeBanList:
assert len(result) == 1
assert "1.1.1.1" in result

View File

@@ -138,7 +138,7 @@ class TestListHistory:
new=AsyncMock(return_value=f2b_db_path),
):
result = await history_service.list_history("fake_socket")
assert result.pagination.total == 4
assert result.total == 4
assert len(result.items) == 4
async def test_time_range_filter_excludes_old_bans(
@@ -153,7 +153,7 @@ class TestListHistory:
result = await history_service.list_history(
"fake_socket", range_="24h"
)
assert result.pagination.total == 2
assert result.total == 2
async def test_jail_filter(self, f2b_db_path: str) -> None:
"""Jail filter restricts results to bans from that jail."""
@@ -162,7 +162,7 @@ class TestListHistory:
new=AsyncMock(return_value=f2b_db_path),
):
result = await history_service.list_history("fake_socket", jail="nginx")
assert result.pagination.total == 1
assert result.total == 1
assert result.items[0].jail == "nginx"
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
@@ -174,7 +174,7 @@ class TestListHistory:
result = await history_service.list_history(
"fake_socket", ip_filter="1.2.3"
)
assert result.pagination.total == 2
assert result.total == 2
for item in result.items:
assert item.ip.startswith("1.2.3")
@@ -188,7 +188,7 @@ class TestListHistory:
"fake_socket", jail="sshd", ip_filter="1.2.3.4"
)
# 2 sshd bans for 1.2.3.4
assert result.pagination.total == 2
assert result.total == 2
async def test_origin_filter_selfblock(self, f2b_db_path: str) -> None:
"""Origin filter should include only selfblock entries."""
@@ -200,7 +200,7 @@ class TestListHistory:
"fake_socket", origin="selfblock"
)
assert result.pagination.total == 4
assert result.total == 4
assert all(item.jail != "blocklist-import" for item in result.items)
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
@@ -212,7 +212,7 @@ class TestListHistory:
result = await history_service.list_history(
"fake_socket", ip_filter="99.99.99.99"
)
assert result.pagination.total == 0
assert result.total == 0
assert result.items == []
async def test_failures_extracted_from_data(
@@ -226,7 +226,7 @@ class TestListHistory:
result = await history_service.list_history(
"fake_socket", ip_filter="5.6.7.8"
)
assert result.pagination.total == 1
assert result.total == 1
assert result.items[0].failures == 3
async def test_matches_extracted_from_data(
@@ -287,7 +287,7 @@ class TestListHistory:
result = await history_service.list_history(
"fake_socket", ip_filter="9.0.0.1"
)
assert result.pagination.total == 1
assert result.total == 1
item = result.items[0]
assert item.failures == 0
assert item.matches == []
@@ -301,10 +301,10 @@ class TestListHistory:
result = await history_service.list_history(
"fake_socket", page=1, page_size=2
)
assert result.pagination.total == 4
assert result.total == 4
assert len(result.items) == 2
assert result.pagination.page == 1
assert result.pagination.page_size == 2
assert result.page == 1
assert result.page_size == 2
async def test_source_archive_reads_from_archive(self, f2b_db_path: str, tmp_path: Path) -> None:
"""Using source='archive' reads from the BanGUI archive table."""
@@ -328,7 +328,7 @@ class TestListHistory:
db=db,
)
assert result.pagination.total == 1
assert result.total == 1
assert result.items[0].ip == "10.0.0.1"
@@ -363,8 +363,8 @@ class TestGetIpDetail:
assert result is not None
assert result.ip == "1.2.3.4"
assert result.pagination.total_bans == 2
assert result.pagination.total_failures == 10 # 5 + 5
assert result.total_bans == 2
assert result.total_failures == 10 # 5 + 5
async def test_timeline_ordered_newest_first(
self, f2b_db_path: str

View File

@@ -80,9 +80,8 @@ class TestNormaliseIp:
def test_normalise_ip_ipv4_mapped_ipv6_to_ipv4(self) -> None:
assert normalise_ip("::ffff:192.168.1.1") == "192.168.1.1"
def test_normalise_ip_invalid_raises_value_error(self) -> None:
with pytest.raises(ValueError):
normalise_ip("not-an-ip")
def test_normalise_ip_invalid_returns_unchanged(self) -> None:
assert normalise_ip("not-an-ip") == "not-an-ip"
class TestNormaliseNetwork:

View File

@@ -10,9 +10,13 @@ from unittest.mock import AsyncMock, patch
import pytest
from app.exceptions import Fail2BanConnectionError
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
from app.models.ban_domain import DomainActiveBanList
from app.models.geo import GeoDetail, GeoInfo
from app.models.jail import JailDetailResponse, JailListResponse
from app.models.jail_domain import (
DomainJailBannedIps,
DomainJailDetail,
DomainJailList,
)
from app.services import ban_service, jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError
from app.utils import jail_socket
@@ -109,9 +113,9 @@ class TestListJails:
with _patch_client(responses):
result = await jail_service.list_jails(_SOCKET, jail_service_state)
assert isinstance(result, JailListResponse)
assert isinstance(result, DomainJailList)
assert result.total == 1
assert result.jails[0].name == "sshd"
assert result.items[0].name == "sshd"
async def test_empty_jail_list(self, jail_service_state: JailServiceState) -> None:
"""list_jails returns empty response when no jails are active."""
@@ -120,7 +124,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state)
assert result.total == 0
assert result.jails == []
assert result.items == []
async def test_jail_status_populated(self, jail_service_state: JailServiceState) -> None:
"""list_jails populates JailStatus with failed/banned counters."""
@@ -136,7 +140,7 @@ class TestListJails:
with _patch_client(responses):
result = await jail_service.list_jails(_SOCKET, jail_service_state)
jail = result.jails[0]
jail = result.items[0]
assert jail.status is not None
assert jail.status.currently_banned == 5
assert jail.status.total_banned == 50
@@ -155,7 +159,7 @@ class TestListJails:
with _patch_client(responses):
result = await jail_service.list_jails(_SOCKET, jail_service_state)
jail = result.jails[0]
jail = result.items[0]
assert jail.ban_time == 3600
assert jail.find_time == 300
assert jail.max_retry == 3
@@ -183,7 +187,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state)
assert result.total == 2
names = {j.name for j in result.jails}
names = {j.name for j in result.items}
assert names == {"sshd", "nginx"}
async def test_connection_error_propagates(self, jail_service_state: JailServiceState) -> None:
@@ -223,7 +227,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state)
# Verify the result uses the default values for backend and idle.
jail = result.jails[0]
jail = result.items[0]
assert jail.backend == "polling" # default
assert jail.idle is False # default
# Capability should now be cached as False.
@@ -249,7 +253,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state)
# Verify real values are returned.
jail = result.jails[0]
jail = result.items[0]
assert jail.backend == "systemd" # real value
assert jail.idle is True # real value
# Capability should now be cached as True.
@@ -280,7 +284,7 @@ class TestListJails:
result = await jail_service.list_jails(_SOCKET, jail_service_state)
# Both jails should return default values (cached result is False).
for jail in result.jails:
for jail in result.items:
assert jail.backend == "polling"
assert jail.idle is False
@@ -329,11 +333,11 @@ class TestGetJail:
}
async def test_returns_jail_detail_response(self, jail_service_state: JailServiceState) -> None:
"""get_jail returns a JailDetailResponse."""
"""get_jail returns a DomainJailDetail."""
with _patch_client(self._full_responses()):
result = await jail_service.get_jail(_SOCKET, "sshd")
assert isinstance(result, JailDetailResponse)
assert isinstance(result, DomainJailDetail)
assert result.jail.name == "sshd"
async def test_log_paths_parsed(self, jail_service_state: JailServiceState) -> None:
@@ -453,9 +457,7 @@ class TestJailControls:
"reload|--all|[]|[['start', 'new'], ['start', 'nginx']]": (0, "OK"),
}
):
await jail_service.reload_all(
_SOCKET, include_jails=["new"], exclude_jails=["old"]
)
await jail_service.reload_all(_SOCKET, include_jails=["new"], exclude_jails=["old"])
async def test_reload_all_unknown_jail_raises_jail_not_found(self) -> None:
"""reload_all detects UnknownJailException and raises JailNotFoundError.
@@ -465,18 +467,19 @@ class TestJailControls:
test verifies that reload_all detects this and re-raises as
JailNotFoundError instead of the generic JailOperationError.
"""
with _patch_client(
{
"status": _make_global_status("sshd"),
"reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": (
1,
Exception("UnknownJailException('airsonic-auth')"),
),
}
), pytest.raises(jail_service.JailNotFoundError) as exc_info:
await jail_service.reload_all(
_SOCKET, include_jails=["airsonic-auth"]
)
with (
_patch_client(
{
"status": _make_global_status("sshd"),
"reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": (
1,
Exception("UnknownJailException('airsonic-auth')"),
),
}
),
pytest.raises(jail_service.JailNotFoundError) as exc_info,
):
await jail_service.reload_all(_SOCKET, include_jails=["airsonic-auth"])
assert exc_info.value.name == "airsonic-auth"
async def test_restart_sends_stop_command(self) -> None:
@@ -486,9 +489,7 @@ class TestJailControls:
async def test_restart_operation_error_raises(self) -> None:
"""restart() raises JailOperationError when fail2ban rejects the stop."""
with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises(
JailOperationError
):
with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises(JailOperationError):
await jail_service.restart(_SOCKET)
async def test_restart_connection_error_propagates(self) -> None:
@@ -496,9 +497,7 @@ class TestJailControls:
class _FailClient:
def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock(
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
)
self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))
with (
patch("app.services.jail_service.Fail2BanClient", _FailClient),
@@ -638,7 +637,7 @@ class TestGetActiveBans:
with _patch_client(responses):
result = await ban_service.get_active_bans(_SOCKET)
assert isinstance(result, ActiveBanListResponse)
assert isinstance(result, DomainActiveBanList)
assert result.total == 1
assert result.bans[0].ip == "1.2.3.4"
assert result.bans[0].jail == "sshd"
@@ -724,17 +723,18 @@ class TestGetActiveBans:
),
}
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
mock_batch = AsyncMock(return_value=mock_geo)
mock_cache = AsyncMock()
mock_cache.lookup_batch = AsyncMock(return_value=mock_geo)
with _patch_client(responses):
mock_session = AsyncMock()
result = await ban_service.get_active_bans(
_SOCKET,
http_session=mock_session,
geo_batch_lookup=mock_batch,
geo_cache=mock_cache,
)
mock_batch.assert_awaited_once()
mock_cache.lookup_batch.assert_awaited_once()
assert result.total == 1
assert result.bans[0].country == "DE"
@@ -748,14 +748,17 @@ class TestGetActiveBans:
),
}
failing_batch = AsyncMock(side_effect=RuntimeError("geo down"))
import aiohttp
mock_cache = AsyncMock()
mock_cache.lookup_batch = AsyncMock(side_effect=aiohttp.ClientError("geo down"))
with _patch_client(responses):
mock_session = AsyncMock()
result = await ban_service.get_active_bans(
_SOCKET,
http_session=mock_session,
geo_batch_lookup=failing_batch,
geo_cache=mock_cache,
)
assert result.total == 1
@@ -777,9 +780,7 @@ class TestGetActiveBans:
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
with _patch_client(responses):
result = await ban_service.get_active_bans(
_SOCKET, geo_enricher=_enricher
)
result = await ban_service.get_active_bans(_SOCKET, geo_enricher=_enricher)
assert result.total == 1
assert result.bans[0].country == "JP"
@@ -875,7 +876,7 @@ class TestLookupIp:
assert result.geo.org == "Acme"
async def test_http_session_uses_geo_service_lookup(self) -> None:
"""lookup_ip uses geo_service.lookup when http_session is provided."""
"""lookup_ip uses geo_enricher when provided."""
responses = {
"get|--all|banned|1.2.3.4": (0, []),
"status": _make_global_status("sshd"),
@@ -883,19 +884,16 @@ class TestLookupIp:
}
mock_geo = GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
mock_session = AsyncMock()
mock_enricher = AsyncMock(return_value=mock_geo)
with _patch_client(responses), patch(
"app.services.jail_service.geo_service.lookup",
AsyncMock(return_value=mock_geo),
) as mock_lookup:
with _patch_client(responses):
result = await jail_service.lookup_ip(
_SOCKET,
"1.2.3.4",
http_session=mock_session,
geo_enricher=mock_enricher,
)
mock_lookup.assert_awaited_once_with("1.2.3.4", mock_session)
mock_enricher.assert_awaited_once_with("1.2.3.4")
assert isinstance(result.geo, GeoDetail)
assert result.geo.country_code == "JP"
assert result.geo.country_name == "Japan"
@@ -985,7 +983,7 @@ class TestGetJailBannedIps:
with _patch_client(_banned_ips_responses()):
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
assert isinstance(result, JailBannedIpsResponse)
assert isinstance(result, DomainJailBannedIps)
async def test_total_reflects_all_entries(self) -> None:
"""total equals the number of parsed ban entries."""
@@ -996,12 +994,8 @@ class TestGetJailBannedIps:
async def test_page_1_returns_first_n_items(self) -> None:
"""page=1 with page_size=2 returns the first two entries."""
with _patch_client(
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
):
result = await jail_service.get_jail_banned_ips(
_SOCKET, "sshd", page=1, page_size=2
)
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=1, page_size=2)
assert len(result.items) == 2
assert result.items[0].ip == "1.2.3.4"
@@ -1010,12 +1004,8 @@ class TestGetJailBannedIps:
async def test_page_2_returns_remaining_items(self) -> None:
"""page=2 with page_size=2 returns the third entry."""
with _patch_client(
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
):
result = await jail_service.get_jail_banned_ips(
_SOCKET, "sshd", page=2, page_size=2
)
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=2, page_size=2)
assert len(result.items) == 1
assert result.items[0].ip == "9.10.11.12"
@@ -1023,9 +1013,7 @@ class TestGetJailBannedIps:
async def test_page_beyond_last_returns_empty_items(self) -> None:
"""Requesting a page past the end returns an empty items list."""
with _patch_client(_banned_ips_responses()):
result = await jail_service.get_jail_banned_ips(
_SOCKET, "sshd", page=99, page_size=25
)
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=99, page_size=25)
assert result.items == []
assert result.total == 2
@@ -1033,9 +1021,7 @@ class TestGetJailBannedIps:
async def test_search_filter_narrows_results(self) -> None:
"""search parameter filters entries by IP substring."""
with _patch_client(_banned_ips_responses()):
result = await jail_service.get_jail_banned_ips(
_SOCKET, "sshd", search="1.2.3"
)
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="1.2.3")
assert result.total == 1
assert result.items[0].ip == "1.2.3.4"
@@ -1044,18 +1030,14 @@ class TestGetJailBannedIps:
"""search filter is case-insensitive."""
entries = ["192.168.0.1\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"]
with _patch_client(_banned_ips_responses(entries=entries)):
result = await jail_service.get_jail_banned_ips(
_SOCKET, "sshd", search="192.168"
)
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="192.168")
assert result.total == 1
async def test_search_no_match_returns_empty(self) -> None:
"""search that matches nothing returns empty items and total=0."""
with _patch_client(_banned_ips_responses()):
result = await jail_service.get_jail_banned_ips(
_SOCKET, "sshd", search="999.999"
)
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="999.999")
assert result.total == 0
assert result.items == []
@@ -1080,9 +1062,7 @@ class TestGetJailBannedIps:
"get|sshd|banip|--with-time": (0, entries),
}
with _patch_client(responses):
result = await jail_service.get_jail_banned_ips(
_SOCKET, "sshd", page=1, page_size=200
)
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=1, page_size=200)
assert len(result.items) <= 100
@@ -1090,30 +1070,22 @@ class TestGetJailBannedIps:
"""Geo enrichment is requested only for IPs in the current page."""
from unittest.mock import MagicMock
from app.services import geo_service
http_session = MagicMock()
geo_enrichment_ips: list[list[str]] = []
async def _mock_lookup_batch(
ips: list[str], _session: Any, **_kw: Any
) -> dict[str, Any]:
geo_enrichment_ips.append(list(ips))
return {}
mock_cache = MagicMock()
mock_cache.lookup_batch = AsyncMock(
side_effect=lambda ips, _session, **_kw: (geo_enrichment_ips.append(list(ips)), {})[-1]
)
with (
_patch_client(
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
),
patch.object(geo_service, "lookup_batch", side_effect=_mock_lookup_batch),
):
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
result = await jail_service.get_jail_banned_ips(
_SOCKET,
"sshd",
page=1,
page_size=2,
http_session=http_session,
geo_batch_lookup=geo_service.lookup_batch,
geo_cache=mock_cache,
)
# Only the 2-IP page slice should be passed to geo enrichment.
@@ -1123,6 +1095,7 @@ class TestGetJailBannedIps:
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
# Simulate fail2ban returning an "unknown jail" error.
class _FakeClient:
def __init__(self, **_kw: Any) -> None:
@@ -1142,9 +1115,7 @@ class TestGetJailBannedIps:
class _FailClient:
def __init__(self, **_kw: Any) -> None:
self.send = AsyncMock(
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
)
self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))
with (
patch("app.services.jail_service.Fail2BanClient", _FailClient),

View File

@@ -7,7 +7,8 @@ from unittest.mock import AsyncMock, patch
import pytest
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
from app.models.server import ServerSettingsUpdate
from app.models.server_domain import DomainServerSettingsResult
from app.services import server_service
from app.services.server_service import ServerOperationError
@@ -58,7 +59,7 @@ class TestGetSettings:
with _patch_client(_DEFAULT_RESPONSES):
result = await server_service.get_settings(_SOCKET)
assert isinstance(result, ServerSettingsResponse)
assert isinstance(result, DomainServerSettingsResult)
assert result.settings.log_level == "INFO"
assert result.settings.log_target == "/var/log/fail2ban.log"
assert result.settings.db_purge_age == 86400