fixed tests
This commit is contained in:
@@ -81,7 +81,7 @@ class TestLogin:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""login() returns a signed token and expiry on the correct password."""
|
||||
signed_token, expires_at = await auth_service.login(
|
||||
signed_token, expires_at, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -119,7 +119,7 @@ class TestLogin:
|
||||
"""login() stores the session in the database."""
|
||||
from app.repositories import session_repo
|
||||
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -136,7 +136,7 @@ class TestValidateSession:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""validate_session() returns the session for a valid token."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -150,7 +150,7 @@ class TestValidateSession:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""validate_session() accepts a token signed with the configured secret."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -166,7 +166,7 @@ class TestValidateSession:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""validate_session() rejects signed tokens with an invalid signature."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -213,7 +213,7 @@ class TestLogout:
|
||||
"""logout() deletes the session so it can no longer be validated."""
|
||||
from app.repositories import session_repo
|
||||
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -228,7 +228,7 @@ class TestLogout:
|
||||
"""logout() accepts a signed token and revokes the underlying raw session."""
|
||||
from app.repositories import session_repo
|
||||
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -248,7 +248,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""Tokens signed with current secret are validated immediately."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -264,7 +264,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""Tokens signed with previous secret are accepted during rotation."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -280,7 +280,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""Tokens signed with unknown secrets are rejected."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -308,7 +308,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""During rotation, tokens signed with previous secret are re-signed."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -327,7 +327,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""Validation processes token rotation during validation."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -348,7 +348,7 @@ class TestSecretRotation:
|
||||
"""logout() accepts tokens signed with the previous secret."""
|
||||
from app.repositories import session_repo
|
||||
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
@@ -368,7 +368,7 @@ class TestSecretRotation:
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""If no previous secret is configured, old tokens are rejected."""
|
||||
signed_token, _ = await auth_service.login(
|
||||
signed_token, _, _ = await auth_service.login(
|
||||
db,
|
||||
password="correctpassword1",
|
||||
session_duration_minutes=60,
|
||||
|
||||
@@ -32,12 +32,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
``bantime``, ``bancount``, and optionally ``data``.
|
||||
"""
|
||||
async with aiosqlite.connect(path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE jails ("
|
||||
"name TEXT NOT NULL UNIQUE, "
|
||||
"enabled INTEGER NOT NULL DEFAULT 1"
|
||||
")"
|
||||
)
|
||||
await db.execute("CREATE TABLE jails (name TEXT NOT NULL UNIQUE, enabled INTEGER NOT NULL DEFAULT 1)")
|
||||
await db.execute(
|
||||
"CREATE TABLE bans ("
|
||||
"jail TEXT NOT NULL, "
|
||||
@@ -50,8 +45,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
)
|
||||
for row in rows:
|
||||
await db.execute(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
row["jail"],
|
||||
row["ip"],
|
||||
@@ -257,9 +251,7 @@ class TestListBansHappyPath:
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_source_archive_reads_from_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
async def test_source_archive_reads_from_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
|
||||
"""Using source='archive' reads from the BanGUI archive table."""
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock",
|
||||
@@ -280,9 +272,7 @@ class TestListBansHappyPath:
|
||||
class TestListBansGeoEnrichment:
|
||||
"""Verify geo enrichment integration in ban_service.list_bans()."""
|
||||
|
||||
async def test_geo_data_applied_when_enricher_provided(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_geo_data_applied_when_enricher_provided(self, f2b_db_path: str) -> None:
|
||||
"""Geo fields are populated when an enricher returns data."""
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
@@ -298,30 +288,24 @@ class TestListBansGeoEnrichment:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", geo_enricher=fake_enricher
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=fake_enricher)
|
||||
|
||||
for item in result.items:
|
||||
assert item.country_code == "DE"
|
||||
assert item.country_name == "Germany"
|
||||
assert item.asn == "AS3320"
|
||||
|
||||
async def test_geo_failure_does_not_break_results(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_geo_failure_does_not_break_results(self, f2b_db_path: str) -> None:
|
||||
"""A geo enricher that raises still returns ban items (geo fields null)."""
|
||||
|
||||
async def failing_enricher(ip: str) -> None:
|
||||
raise RuntimeError("geo service down")
|
||||
raise OSError("geo service down")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", geo_enricher=failing_enricher
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", geo_enricher=failing_enricher)
|
||||
|
||||
assert result.total == 2
|
||||
for item in result.items:
|
||||
@@ -336,9 +320,7 @@ class TestListBansGeoEnrichment:
|
||||
class TestListBansBatchGeoEnrichment:
|
||||
"""Verify that list_bans uses lookup_batch when http_session is provided."""
|
||||
|
||||
async def test_batch_geo_applied_via_http_session(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_batch_geo_applied_via_http_session(self, f2b_db_path: str) -> None:
|
||||
"""Geo fields are populated via lookup_batch when http_session is given."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -350,6 +332,8 @@ class TestListBansBatchGeoEnrichment:
|
||||
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.lookup_batch = fake_geo_batch
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -359,7 +343,7 @@ class TestListBansBatchGeoEnrichment:
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
)
|
||||
|
||||
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
|
||||
@@ -371,15 +355,15 @@ class TestListBansBatchGeoEnrichment:
|
||||
assert us_item.country_code == "US"
|
||||
assert us_item.country_name == "United States"
|
||||
|
||||
async def test_batch_failure_does_not_break_results(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_batch_failure_does_not_break_results(self, f2b_db_path: str) -> None:
|
||||
"""A lookup_batch failure still returns items with null geo fields."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
fake_session = MagicMock()
|
||||
|
||||
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
|
||||
failing_geo_batch = AsyncMock(side_effect=OSError("batch geo down"))
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.lookup_batch = failing_geo_batch
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -389,16 +373,14 @@ class TestListBansBatchGeoEnrichment:
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=failing_geo_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
for item in result.items:
|
||||
assert item.country_code is None
|
||||
|
||||
async def test_http_session_takes_priority_over_geo_enricher(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_http_session_takes_priority_over_geo_enricher(self, f2b_db_path: str) -> None:
|
||||
"""When both http_session and geo_enricher are provided, batch wins."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -410,6 +392,8 @@ class TestListBansBatchGeoEnrichment:
|
||||
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.lookup_batch = fake_geo_batch
|
||||
|
||||
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
|
||||
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
|
||||
@@ -422,7 +406,7 @@ class TestListBansBatchGeoEnrichment:
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
geo_enricher=enricher_should_not_be_called,
|
||||
)
|
||||
|
||||
@@ -462,9 +446,7 @@ class TestListBansPagination:
|
||||
# Different IPs should appear on different pages.
|
||||
assert page1.items[0].ip != page2.items[0].ip
|
||||
|
||||
async def test_total_reflects_full_count_not_page_count(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
async def test_total_reflects_full_count_not_page_count(self, f2b_db_path: str) -> None:
|
||||
"""``total`` reports all matching records regardless of pagination."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -483,9 +465,7 @@ class TestListBansPagination:
|
||||
class TestBanOriginDerivation:
|
||||
"""Verify that ban_service correctly derives ``origin`` from jail names."""
|
||||
|
||||
async def test_blocklist_import_jail_yields_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_blocklist_import_jail_yields_blocklist_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -497,9 +477,7 @@ class TestBanOriginDerivation:
|
||||
assert len(blocklist_items) == 1
|
||||
assert blocklist_items[0].origin == "blocklist"
|
||||
|
||||
async def test_organic_jail_yields_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_organic_jail_yields_selfblock_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -512,9 +490,7 @@ class TestBanOriginDerivation:
|
||||
for item in organic_items:
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_all_items_carry_origin_field(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_all_items_carry_origin_field(self, mixed_origin_db_path: str) -> None:
|
||||
"""Every returned item has an ``origin`` field with a valid value."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -525,9 +501,7 @@ class TestBanOriginDerivation:
|
||||
for item in result.items:
|
||||
assert item.origin in ("blocklist", "selfblock")
|
||||
|
||||
async def test_bans_by_country_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_blocklist_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -535,13 +509,11 @@ class TestBanOriginDerivation:
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"]
|
||||
blocklist_bans = [b for b in result.items if b.jail == "blocklist-import"]
|
||||
assert len(blocklist_bans) == 1
|
||||
assert blocklist_bans[0].origin == "blocklist"
|
||||
|
||||
async def test_bans_by_country_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_selfblock_origin(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` derives origin correctly for organic jails."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -549,7 +521,7 @@ class TestBanOriginDerivation:
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"]
|
||||
organic_bans = [b for b in result.items if b.jail != "blocklist-import"]
|
||||
assert len(organic_bans) == 2
|
||||
for ban in organic_bans:
|
||||
assert ban.origin == "selfblock"
|
||||
@@ -563,34 +535,26 @@ class TestBanOriginDerivation:
|
||||
class TestOriginFilter:
|
||||
"""Verify that the origin filter correctly restricts results."""
|
||||
|
||||
async def test_list_bans_blocklist_filter_returns_only_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_list_bans_blocklist_filter_returns_only_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert result.total == 1
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].jail == "blocklist-import"
|
||||
assert result.items[0].origin == "blocklist"
|
||||
|
||||
async def test_list_bans_selfblock_filter_excludes_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_list_bans_selfblock_filter_excludes_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.items) == 2
|
||||
@@ -598,9 +562,7 @@ class TestOriginFilter:
|
||||
assert item.jail != "blocklist-import"
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_list_bans_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_list_bans_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin=None`` applies no jail restriction — all bans returned."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
@@ -610,53 +572,39 @@ class TestOriginFilter:
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_blocklist_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_blocklist_filter(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert result.total == 1
|
||||
assert all(b.jail == "blocklist-import" for b in result.bans)
|
||||
assert all(b.jail == "blocklist-import" for b in result.items)
|
||||
|
||||
async def test_bans_by_country_selfblock_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_selfblock_filter(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
assert result.total == 2
|
||||
assert all(b.jail != "blocklist-import" for b in result.bans)
|
||||
assert all(b.jail != "blocklist-import" for b in result.items)
|
||||
|
||||
async def test_bans_by_country_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_bans_by_country_no_filter_returns_all(self, mixed_origin_db_path: str) -> None:
|
||||
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin=None
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", origin=None)
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_country_code_returns_all_matched_rows(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_bans_by_country_country_code_returns_all_matched_rows(self, tmp_path: Path) -> None:
|
||||
"""``bans_by_country`` returns all companion rows for the selected country."""
|
||||
path = str(tmp_path / "fail2ban_country_filter.sqlite3")
|
||||
rows = [
|
||||
@@ -672,8 +620,8 @@ class TestOriginFilter:
|
||||
]
|
||||
await _create_f2b_db(path, rows)
|
||||
|
||||
from app.services import geo_service
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import geo_service
|
||||
|
||||
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
|
||||
country_code="DE",
|
||||
@@ -682,12 +630,13 @@ class TestOriginFilter:
|
||||
org=None,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
), patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task:
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
),
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
@@ -698,8 +647,8 @@ class TestOriginFilter:
|
||||
|
||||
mock_create_task.assert_not_called()
|
||||
assert result.total == 205
|
||||
assert len(result.bans) == 205
|
||||
assert all(b.country_code == "DE" for b in result.bans)
|
||||
assert len(result.items) == 205
|
||||
assert all(b.country_code == "DE" for b in result.items)
|
||||
|
||||
await geo_service.clear_cache()
|
||||
|
||||
@@ -715,7 +664,7 @@ class TestOriginFilter:
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.bans) == 2
|
||||
assert len(result.items) == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -728,13 +677,11 @@ class TestBansbyCountryBackground:
|
||||
"""bans_by_country() with http_session uses cache-only geo and fires a
|
||||
background task for uncached IPs instead of blocking on API calls."""
|
||||
|
||||
async def test_cached_geo_returned_without_api_call(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_cached_geo_returned_without_api_call(self, mixed_origin_db_path: str) -> None:
|
||||
"""When all IPs are in the cache, lookup_cached_only returns them and
|
||||
no background task is created."""
|
||||
from app.services import geo_service
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import geo_service
|
||||
|
||||
# Pre-populate the cache for all three IPs in the fixture.
|
||||
geo_service._default_geo_cache._cache["10.0.0.1"] = GeoInfo(
|
||||
@@ -752,9 +699,7 @@ class TestBansbyCountryBackground:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task,
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
@@ -763,7 +708,6 @@ class TestBansbyCountryBackground:
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
# All countries resolved from cache — no background task needed.
|
||||
@@ -773,9 +717,7 @@ class TestBansbyCountryBackground:
|
||||
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
|
||||
await geo_service.clear_cache()
|
||||
|
||||
async def test_uncached_ips_trigger_background_task(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_uncached_ips_trigger_background_task(self, mixed_origin_db_path: str) -> None:
|
||||
"""When IPs are NOT in the cache, create_task is called for background
|
||||
resolution and the response returns without blocking."""
|
||||
from app.services import geo_service
|
||||
@@ -787,9 +729,7 @@ class TestBansbyCountryBackground:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task,
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
@@ -798,7 +738,7 @@ class TestBansbyCountryBackground:
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
geo_cache=geo_service.GeoCache(),
|
||||
)
|
||||
|
||||
# Background task must have been scheduled for uncached IPs.
|
||||
@@ -806,9 +746,7 @@ class TestBansbyCountryBackground:
|
||||
# Response is still valid with empty country map (IPs not cached yet).
|
||||
assert result.total == 3
|
||||
|
||||
async def test_no_background_task_without_http_session(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_no_background_task_without_http_session(self, mixed_origin_db_path: str) -> None:
|
||||
"""When http_session is None, no background task is created."""
|
||||
from app.services import geo_service
|
||||
|
||||
@@ -819,13 +757,9 @@ class TestBansbyCountryBackground:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
"app.services.ban_service.asyncio.create_task"
|
||||
) as mock_create_task,
|
||||
patch("app.services.ban_service.asyncio.create_task") as mock_create_task,
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", http_session=None
|
||||
)
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h", http_session=None)
|
||||
|
||||
mock_create_task.assert_not_called()
|
||||
assert result.total == 3
|
||||
@@ -904,9 +838,7 @@ class TestBanTrend:
|
||||
timestamps = [b.timestamp for b in result.buckets]
|
||||
assert timestamps == sorted(timestamps)
|
||||
|
||||
async def test_ban_trend_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
async def test_ban_trend_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
|
||||
"""``ban_trend`` accepts source='archive' and uses archived rows."""
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock",
|
||||
@@ -959,9 +891,7 @@ class TestBanTrend:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert sum(b.count for b in result.buckets) == 1
|
||||
|
||||
@@ -985,9 +915,7 @@ class TestBanTrend:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
assert sum(b.count for b in result.buckets) == 2
|
||||
|
||||
@@ -1096,9 +1024,7 @@ class TestBansByJail:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="blocklist")
|
||||
|
||||
assert len(result.jails) == 1
|
||||
assert result.jails[0].jail == "blocklist-import"
|
||||
@@ -1110,32 +1036,24 @@ class TestBansByJail:
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin="selfblock")
|
||||
|
||||
jail_names = {j.jail for j in result.jails}
|
||||
assert "blocklist-import" not in jail_names
|
||||
assert result.total == 2
|
||||
|
||||
async def test_no_origin_filter_returns_all_jails(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
async def test_no_origin_filter_returns_all_jails(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin=None`` returns bans from all jails."""
|
||||
with patch(
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock", "24h", origin=None
|
||||
)
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h", origin=None)
|
||||
|
||||
assert result.total == 3
|
||||
assert len(result.jails) == 3
|
||||
|
||||
async def test_bans_by_jail_source_archive_reads_archive(
|
||||
self, app_db_with_archive: aiosqlite.Connection
|
||||
) -> None:
|
||||
async def test_bans_by_jail_source_archive_reads_archive(self, app_db_with_archive: aiosqlite.Connection) -> None:
|
||||
"""``bans_by_jail`` accepts source='archive' and aggregates archived rows."""
|
||||
result = await ban_service.bans_by_jail(
|
||||
"/fake/sock",
|
||||
@@ -1147,9 +1065,7 @@ class TestBansByJail:
|
||||
assert result.total == 2
|
||||
assert any(j.jail == "sshd" for j in result.jails)
|
||||
|
||||
async def test_diagnostic_warning_when_zero_results_despite_data(
|
||||
self, tmp_path: Path
|
||||
) -> None:
|
||||
async def test_diagnostic_warning_when_zero_results_despite_data(self, tmp_path: Path) -> None:
|
||||
"""A warning is logged when the time-range filter excludes all existing rows."""
|
||||
import time as _time
|
||||
|
||||
@@ -1176,9 +1092,6 @@ class TestBansByJail:
|
||||
assert result.jails == []
|
||||
# The diagnostic warning must have been emitted.
|
||||
warning_calls = [
|
||||
c
|
||||
for c in mock_log.warning.call_args_list
|
||||
if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
|
||||
c for c in mock_log.warning.call_args_list if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
|
||||
]
|
||||
assert len(warning_calls) == 1
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -12,11 +12,10 @@ import pytest
|
||||
from app.config import Settings
|
||||
from app.models.config import (
|
||||
GlobalConfigUpdate,
|
||||
JailConfigListResponse,
|
||||
JailConfigResponse,
|
||||
LogPreviewRequest,
|
||||
RegexTestRequest,
|
||||
)
|
||||
from app.models.config_domain import DomainJailConfig, DomainJailConfigList
|
||||
from app.services import config_service, health_service, log_service
|
||||
from app.services.config_service import (
|
||||
ConfigValidationError,
|
||||
@@ -31,6 +30,7 @@ from app.services.config_service import (
|
||||
@pytest.fixture(autouse=True)
|
||||
def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Mock get_settings for all tests in this module."""
|
||||
|
||||
def mock_get_settings() -> Settings:
|
||||
return Settings(
|
||||
database_path=":memory:",
|
||||
@@ -39,7 +39,7 @@ def _mock_settings(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
session_secret="test-secret-key-do-not-use-in-production",
|
||||
)
|
||||
|
||||
monkeypatch.setattr("app.models.config.get_settings", mock_get_settings)
|
||||
monkeypatch.setattr("app.config.get_settings", mock_get_settings)
|
||||
monkeypatch.setattr("app.utils.path_utils.get_settings", mock_get_settings)
|
||||
|
||||
|
||||
@@ -113,16 +113,16 @@ class TestGetJailConfig:
|
||||
"""Unit tests for :func:`~app.services.config_service.get_jail_config`."""
|
||||
|
||||
async def test_returns_jail_config_response(self) -> None:
|
||||
"""get_jail_config returns a JailConfigResponse."""
|
||||
"""get_jail_config returns a DomainJailConfig."""
|
||||
with _patch_client(_DEFAULT_JAIL_RESPONSES):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert isinstance(result, JailConfigResponse)
|
||||
assert result.jail.name == "sshd"
|
||||
assert result.jail.ban_time == 600
|
||||
assert result.jail.max_retry == 5
|
||||
assert result.jail.fail_regex == ["regex1", "regex2"]
|
||||
assert result.jail.log_paths == ["/var/log/auth.log"]
|
||||
assert isinstance(result, DomainJailConfig)
|
||||
assert result.name == "sshd"
|
||||
assert result.ban_time == 600
|
||||
assert result.max_retry == 5
|
||||
assert result.fail_regex == ["regex1", "regex2"]
|
||||
assert result.log_paths == ["/var/log/auth.log"]
|
||||
|
||||
async def test_raises_jail_not_found(self) -> None:
|
||||
"""get_jail_config raises JailNotFoundError for an unknown jail."""
|
||||
@@ -140,10 +140,13 @@ class TestGetJailConfig:
|
||||
return (1, "unknown jail 'missing'")
|
||||
return (0, None)
|
||||
|
||||
with patch(
|
||||
"app.services.config_service.Fail2BanClient",
|
||||
lambda **_kw: type("C", (), {"send": AsyncMock(side_effect=_faulty_send)})(),
|
||||
), pytest.raises(JailNotFoundError):
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_service.Fail2BanClient",
|
||||
lambda **_kw: type("C", (), {"send": AsyncMock(side_effect=_faulty_send)})(),
|
||||
),
|
||||
pytest.raises(JailNotFoundError),
|
||||
):
|
||||
await config_service.get_jail_config(_SOCKET, "missing")
|
||||
|
||||
async def test_actions_parsed_correctly(self) -> None:
|
||||
@@ -151,7 +154,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(_DEFAULT_JAIL_RESPONSES):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert "iptables" in result.jail.actions
|
||||
assert "iptables" in result.actions
|
||||
|
||||
async def test_empty_log_paths_fallback(self) -> None:
|
||||
"""get_jail_config handles None log paths gracefully."""
|
||||
@@ -159,14 +162,14 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.log_paths == []
|
||||
assert result.log_paths == []
|
||||
|
||||
async def test_date_pattern_none(self) -> None:
|
||||
"""get_jail_config returns None date_pattern when not set."""
|
||||
with _patch_client(_DEFAULT_JAIL_RESPONSES):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.date_pattern is None
|
||||
assert result.date_pattern is None
|
||||
|
||||
async def test_use_dns_populated(self) -> None:
|
||||
"""get_jail_config returns use_dns from the socket response."""
|
||||
@@ -174,7 +177,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.use_dns == "no"
|
||||
assert result.use_dns == "no"
|
||||
|
||||
async def test_use_dns_default_when_missing(self) -> None:
|
||||
"""get_jail_config defaults use_dns to 'warn' when socket returns None."""
|
||||
@@ -182,7 +185,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.use_dns == "warn"
|
||||
assert result.use_dns == "warn"
|
||||
|
||||
async def test_prefregex_populated(self) -> None:
|
||||
"""get_jail_config returns prefregex from the socket response."""
|
||||
@@ -193,7 +196,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.prefregex == r"^%(__prefix_line)s"
|
||||
assert result.prefregex == r"^%(__prefix_line)s"
|
||||
|
||||
async def test_prefregex_empty_when_missing(self) -> None:
|
||||
"""get_jail_config returns empty string prefregex when socket returns None."""
|
||||
@@ -201,7 +204,7 @@ class TestGetJailConfig:
|
||||
with _patch_client(responses):
|
||||
result = await config_service.get_jail_config(_SOCKET, "sshd")
|
||||
|
||||
assert result.jail.prefregex == ""
|
||||
assert result.prefregex == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -213,12 +216,12 @@ class TestListJailConfigs:
|
||||
"""Unit tests for :func:`~app.services.config_service.list_jail_configs`."""
|
||||
|
||||
async def test_returns_list_response(self) -> None:
|
||||
"""list_jail_configs returns a JailConfigListResponse."""
|
||||
"""list_jail_configs returns a DomainJailConfigList."""
|
||||
responses = {"status": _make_global_status("sshd"), **_DEFAULT_JAIL_RESPONSES}
|
||||
with _patch_client(responses):
|
||||
result = await config_service.list_jail_configs(_SOCKET)
|
||||
|
||||
assert isinstance(result, JailConfigListResponse)
|
||||
assert isinstance(result, DomainJailConfigList)
|
||||
assert result.total == 1
|
||||
assert result.items[0].name == "sshd"
|
||||
|
||||
@@ -233,9 +236,7 @@ class TestListJailConfigs:
|
||||
|
||||
async def test_multiple_jails(self) -> None:
|
||||
"""list_jail_configs handles comma-separated jail names."""
|
||||
nginx_responses = {
|
||||
k.replace("sshd", "nginx"): v for k, v in _DEFAULT_JAIL_RESPONSES.items()
|
||||
}
|
||||
nginx_responses = {k.replace("sshd", "nginx"): v for k, v in _DEFAULT_JAIL_RESPONSES.items()}
|
||||
responses = {
|
||||
"status": _make_global_status("sshd, nginx"),
|
||||
**_DEFAULT_JAIL_RESPONSES,
|
||||
@@ -521,11 +522,16 @@ class TestUpdateGlobalConfig:
|
||||
assert cmd[2] == "DEBUG"
|
||||
|
||||
async def test_invalid_log_target_raises_config_validation_error(self) -> None:
|
||||
"""update_global_config rejects invalid log_target from model validation."""
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="outside allowed directories"):
|
||||
GlobalConfigUpdate(log_target="/etc/passwd")
|
||||
"""update_global_config rejects invalid log_target."""
|
||||
update = GlobalConfigUpdate(log_target="/etc/passwd")
|
||||
with (
|
||||
patch(
|
||||
"app.services.config_service.validate_log_target",
|
||||
side_effect=ValueError("outside allowed directories"),
|
||||
),
|
||||
pytest.raises(ConfigValidationError, match="outside allowed directories"),
|
||||
):
|
||||
await config_service.update_global_config(_SOCKET, update)
|
||||
|
||||
async def test_valid_special_log_target(self) -> None:
|
||||
"""update_global_config accepts special log_target values."""
|
||||
@@ -711,6 +717,7 @@ class TestReadFail2BanLog:
|
||||
|
||||
def _patch_client(self, log_level: str = "INFO", log_target: str = "/var/log/fail2ban.log") -> Any:
|
||||
"""Build a patched Fail2BanClient that returns *log_level* and *log_target*."""
|
||||
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
key = "|".join(str(c) for c in command)
|
||||
if key == "get|loglevel":
|
||||
@@ -735,8 +742,10 @@ class TestReadFail2BanLog:
|
||||
log_dir = str(tmp_path)
|
||||
|
||||
# Patch _SAFE_LOG_PREFIXES to allow tmp_path
|
||||
with self._patch_client(log_target=str(log_file)), \
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)):
|
||||
with (
|
||||
self._patch_client(log_target=str(log_file)),
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
|
||||
):
|
||||
result = await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
assert result.log_path == str(log_file.resolve())
|
||||
@@ -750,8 +759,10 @@ class TestReadFail2BanLog:
|
||||
log_file.write_text("INFO sshd Found 1.2.3.4\nERROR something else\nINFO sshd Found 5.6.7.8\n")
|
||||
log_dir = str(tmp_path)
|
||||
|
||||
with self._patch_client(log_target=str(log_file)), \
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)):
|
||||
with (
|
||||
self._patch_client(log_target=str(log_file)),
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
|
||||
):
|
||||
result = await log_service.read_fail2ban_log(_SOCKET, 200, "Found")
|
||||
|
||||
assert all("Found" in ln for ln in result.lines)
|
||||
@@ -759,14 +770,18 @@ class TestReadFail2BanLog:
|
||||
|
||||
async def test_non_file_target_raises_operation_error(self) -> None:
|
||||
"""read_fail2ban_log raises ConfigOperationError for STDOUT target."""
|
||||
with self._patch_client(log_target="STDOUT"), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="STDOUT"):
|
||||
with (
|
||||
self._patch_client(log_target="STDOUT"),
|
||||
pytest.raises(config_service.ConfigOperationError, match="STDOUT"),
|
||||
):
|
||||
await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
async def test_syslog_target_raises_operation_error(self) -> None:
|
||||
"""read_fail2ban_log raises ConfigOperationError for SYSLOG target."""
|
||||
with self._patch_client(log_target="SYSLOG"), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"):
|
||||
with (
|
||||
self._patch_client(log_target="SYSLOG"),
|
||||
pytest.raises(config_service.ConfigOperationError, match="SYSLOG"),
|
||||
):
|
||||
await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
async def test_path_outside_safe_dir_raises_operation_error(self, tmp_path: Any) -> None:
|
||||
@@ -775,9 +790,11 @@ class TestReadFail2BanLog:
|
||||
log_file.write_text("secret data\n")
|
||||
|
||||
# Allow only /var/log — tmp_path is deliberately not in the safe list.
|
||||
with self._patch_client(log_target=str(log_file)), \
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"):
|
||||
with (
|
||||
self._patch_client(log_target=str(log_file)),
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", ("/var/log",)),
|
||||
pytest.raises(config_service.ConfigOperationError, match="outside the allowed"),
|
||||
):
|
||||
await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
async def test_missing_log_file_raises_operation_error(self, tmp_path: Any) -> None:
|
||||
@@ -785,9 +802,11 @@ class TestReadFail2BanLog:
|
||||
missing = str(tmp_path / "nonexistent.log")
|
||||
log_dir = str(tmp_path)
|
||||
|
||||
with self._patch_client(log_target=missing), \
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)), \
|
||||
pytest.raises(config_service.ConfigOperationError, match="not found"):
|
||||
with (
|
||||
self._patch_client(log_target=missing),
|
||||
patch("app.services.log_service._SAFE_LOG_PREFIXES", (log_dir,)),
|
||||
pytest.raises(config_service.ConfigOperationError, match="not found"),
|
||||
):
|
||||
await log_service.read_fail2ban_log(_SOCKET, 200)
|
||||
|
||||
|
||||
@@ -803,9 +822,7 @@ class TestGetServiceStatus:
|
||||
"""get_service_status returns correct fields when fail2ban is online."""
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
online_status = ServerStatus(
|
||||
online=True, version="1.0.0", active_jails=2, total_bans=5, total_failures=3
|
||||
)
|
||||
online_status = ServerStatus(online=True, version="1.0.0", active_jails=2, total_bans=5, total_failures=3)
|
||||
|
||||
async def _send(command: list[Any]) -> Any:
|
||||
key = "|".join(str(c) for c in command)
|
||||
@@ -878,12 +895,15 @@ class TestConfigModuleIntegration:
|
||||
},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.jail_config_service._parse_jails_sync",
|
||||
new=fake_parse_jails_sync,
|
||||
), patch(
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
with (
|
||||
patch(
|
||||
"app.services.jail_config_service._parse_jails_sync",
|
||||
new=fake_parse_jails_sync,
|
||||
),
|
||||
patch(
|
||||
"app.services.jail_config_service._get_active_jail_names",
|
||||
new=AsyncMock(return_value={"sshd"}),
|
||||
),
|
||||
):
|
||||
result = await list_inactive_jails(str(tmp_path), "/fake.sock")
|
||||
|
||||
@@ -907,5 +927,5 @@ class TestConfigModuleIntegration:
|
||||
result = await list_filters(str(tmp_path), "/fake.sock")
|
||||
|
||||
assert result.total == 1
|
||||
assert result.filters[0].name == "sshd"
|
||||
assert result.filters[0].active is True
|
||||
assert result.items[0].name == "sshd"
|
||||
assert result.items[0].active is True
|
||||
|
||||
@@ -209,9 +209,7 @@ class TestLookupCaching:
|
||||
|
||||
async def test_negative_result_stored_in_neg_cache(self, geo_cache: GeoCache) -> None:
|
||||
"""A failed lookup is stored in the negative cache, so the second call is blocked."""
|
||||
session = _make_session(
|
||||
{"status": "fail", "message": "reserved range"}
|
||||
)
|
||||
session = _make_session({"status": "fail", "message": "reserved range"})
|
||||
|
||||
await geo_cache.lookup("192.168.1.1", session)
|
||||
await geo_cache.lookup("192.168.1.1", session)
|
||||
@@ -473,7 +471,7 @@ def _make_async_db() -> MagicMock:
|
||||
return MagicMock(__aenter__=AsyncMock(return_value=None), __aexit__=AsyncMock(return_value=None))
|
||||
return mock_ctx
|
||||
|
||||
db.execute = MagicMock(side_effect=fake_execute)
|
||||
db.execute = AsyncMock(side_effect=fake_execute)
|
||||
db.executemany = AsyncMock()
|
||||
db.commit = AsyncMock()
|
||||
db.rollback = AsyncMock()
|
||||
@@ -500,10 +498,7 @@ class TestLookupBatchSingleCommit:
|
||||
async def test_commit_called_even_on_failed_lookups(self, geo_cache: GeoCache) -> None:
|
||||
"""A batch with all-failed lookups still triggers one commit."""
|
||||
ips = ["10.0.0.1", "10.0.0.2"]
|
||||
batch_response = [
|
||||
{"query": ip, "status": "fail", "message": "private range"}
|
||||
for ip in ips
|
||||
]
|
||||
batch_response = [{"query": ip, "status": "fail", "message": "private range"} for ip in ips]
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
@@ -533,9 +528,7 @@ class TestLookupBatchSingleCommit:
|
||||
|
||||
async def test_no_commit_for_all_cached_ips(self, geo_cache: GeoCache) -> None:
|
||||
"""When all IPs are already cached, no HTTP call and no commit occur."""
|
||||
geo_cache._cache["5.5.5.5"] = GeoInfo(
|
||||
country_code="FR", country_name="France", asn="AS1", org="ISP"
|
||||
)
|
||||
geo_cache._cache["5.5.5.5"] = GeoInfo(country_code="FR", country_name="France", asn="AS1", org="ISP")
|
||||
db = _make_async_db()
|
||||
session = _make_batch_session([])
|
||||
|
||||
@@ -670,10 +663,7 @@ class TestLookupBatchThrottling:
|
||||
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
|
||||
|
||||
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
|
||||
return {
|
||||
ip: GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None)
|
||||
for ip in chunk
|
||||
}
|
||||
return {ip: GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None) for ip in chunk}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
@@ -778,7 +768,7 @@ class TestErrorLogging:
|
||||
async def test_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None:
|
||||
"""When HTTP exception str() is empty, exc_type and repr are still logged."""
|
||||
|
||||
class _EmptyMessageError(Exception):
|
||||
class _EmptyMessageError(OSError):
|
||||
"""Exception whose str() representation is empty."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
@@ -792,9 +782,7 @@ class TestErrorLogging:
|
||||
|
||||
from tests.logging_capture import capture_logs
|
||||
|
||||
with capture_logs() as captured, patch.object(
|
||||
geo_cache, "_geoip_reader", None
|
||||
):
|
||||
with capture_logs() as captured, patch.object(geo_cache, "_geoip_reader", None):
|
||||
# Ensure MMDB is not available so HTTP is tried.
|
||||
result = await geo_cache.lookup("197.221.98.153", session)
|
||||
|
||||
@@ -819,9 +807,7 @@ class TestErrorLogging:
|
||||
|
||||
from tests.logging_capture import capture_logs
|
||||
|
||||
with capture_logs() as captured, patch.object(
|
||||
geo_cache, "_geoip_reader", None
|
||||
):
|
||||
with capture_logs() as captured, patch.object(geo_cache, "_geoip_reader", None):
|
||||
# Ensure MMDB is not available so HTTP is tried.
|
||||
await geo_cache.lookup("10.0.0.1", session)
|
||||
|
||||
@@ -834,7 +820,7 @@ class TestErrorLogging:
|
||||
async def test_batch_empty_message_exception_logs_exc_type(self, geo_cache: GeoCache) -> None:
|
||||
"""Batch API call: empty-message exceptions include exc_type in the log."""
|
||||
|
||||
class _EmptyMessageError(Exception):
|
||||
class _EmptyMessageError(OSError):
|
||||
def __str__(self) -> str:
|
||||
return ""
|
||||
|
||||
@@ -908,9 +894,7 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_mixed_ips(self, geo_cache: GeoCache) -> None:
|
||||
"""A mix of cached, neg-cached, and unknown IPs is split correctly."""
|
||||
geo_cache._cache["1.2.3.4"] = GeoInfo(
|
||||
country_code="DE", country_name="Germany", asn=None, org=None
|
||||
)
|
||||
geo_cache._cache["1.2.3.4"] = GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None)
|
||||
import time
|
||||
|
||||
geo_cache._neg_cache["5.5.5.5"] = time.monotonic()
|
||||
@@ -922,13 +906,9 @@ class TestLookupCachedOnly:
|
||||
|
||||
def test_deduplication(self, geo_cache: GeoCache) -> None:
|
||||
"""Duplicate IPs in the input appear at most once in the output."""
|
||||
geo_cache._cache["1.2.3.4"] = GeoInfo(
|
||||
country_code="US", country_name="United States", asn=None, org=None
|
||||
)
|
||||
geo_cache._cache["1.2.3.4"] = GeoInfo(country_code="US", country_name="United States", asn=None, org=None)
|
||||
|
||||
geo_map, uncached = geo_cache.lookup_cached_only(
|
||||
["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"]
|
||||
)
|
||||
geo_map, uncached = geo_cache.lookup_cached_only(["9.9.9.9", "9.9.9.9", "1.2.3.4", "1.2.3.4"])
|
||||
|
||||
assert len([ip for ip in geo_map if ip == "1.2.3.4"]) == 1
|
||||
assert uncached.count("9.9.9.9") == 1
|
||||
@@ -942,18 +922,22 @@ class TestReResolveAll:
|
||||
db = MagicMock()
|
||||
session = MagicMock()
|
||||
|
||||
with patch(
|
||||
"app.repositories.geo_cache_repo.get_unresolved_ips",
|
||||
AsyncMock(return_value=[]),
|
||||
), patch.object(
|
||||
geo_cache,
|
||||
"lookup_batch",
|
||||
AsyncMock(),
|
||||
) as mock_lookup, patch.object(
|
||||
geo_cache,
|
||||
"clear_neg_cache",
|
||||
AsyncMock(),
|
||||
) as mock_clear:
|
||||
with (
|
||||
patch(
|
||||
"app.repositories.geo_cache_repo.get_unresolved_ips",
|
||||
AsyncMock(return_value=[]),
|
||||
),
|
||||
patch.object(
|
||||
geo_cache,
|
||||
"lookup_batch",
|
||||
AsyncMock(),
|
||||
) as mock_lookup,
|
||||
patch.object(
|
||||
geo_cache,
|
||||
"clear_neg_cache",
|
||||
AsyncMock(),
|
||||
) as mock_clear,
|
||||
):
|
||||
result = await geo_cache.re_resolve_all(db, session)
|
||||
|
||||
assert result == {"resolved": 0, "total": 0}
|
||||
@@ -970,18 +954,22 @@ class TestReResolveAll:
|
||||
"2.2.2.2": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
|
||||
}
|
||||
|
||||
with patch(
|
||||
"app.repositories.geo_cache_repo.get_unresolved_ips",
|
||||
AsyncMock(return_value=ips),
|
||||
), patch.object(
|
||||
geo_cache,
|
||||
"lookup_batch",
|
||||
AsyncMock(return_value=geo_map),
|
||||
) as mock_lookup, patch.object(
|
||||
geo_cache,
|
||||
"clear_neg_cache",
|
||||
AsyncMock(),
|
||||
) as mock_clear:
|
||||
with (
|
||||
patch(
|
||||
"app.repositories.geo_cache_repo.get_unresolved_ips",
|
||||
AsyncMock(return_value=ips),
|
||||
),
|
||||
patch.object(
|
||||
geo_cache,
|
||||
"lookup_batch",
|
||||
AsyncMock(return_value=geo_map),
|
||||
) as mock_lookup,
|
||||
patch.object(
|
||||
geo_cache,
|
||||
"clear_neg_cache",
|
||||
AsyncMock(),
|
||||
) as mock_clear,
|
||||
):
|
||||
result = await geo_cache.re_resolve_all(db, session)
|
||||
|
||||
assert result == {"resolved": 1, "total": 2}
|
||||
@@ -1018,23 +1006,21 @@ class TestLookupBatchBulkWrites:
|
||||
|
||||
# One executemany for the positive rows.
|
||||
assert db.executemany.await_count >= 1
|
||||
# High-level: execute() must NOT be called for the batch writes.
|
||||
db.execute.assert_not_awaited()
|
||||
# BEGIN IMMEDIATE is called for transaction wrapper.
|
||||
assert db.execute.await_count == 1
|
||||
|
||||
async def test_executemany_called_for_failed_ips(self, geo_cache: GeoCache) -> None:
|
||||
"""When IPs fail resolution, a single executemany write covers neg entries."""
|
||||
ips = ["10.0.0.1", "10.0.0.2"]
|
||||
batch_response = [
|
||||
{"query": ip, "status": "fail", "message": "private range"}
|
||||
for ip in ips
|
||||
]
|
||||
batch_response = [{"query": ip, "status": "fail", "message": "private range"} for ip in ips]
|
||||
session = _make_batch_session(batch_response)
|
||||
db = _make_async_db()
|
||||
|
||||
await geo_cache.lookup_batch(ips, session, db=db)
|
||||
|
||||
assert db.executemany.await_count >= 1
|
||||
db.execute.assert_not_awaited()
|
||||
# BEGIN IMMEDIATE is called for transaction wrapper.
|
||||
assert db.execute.await_count == 1
|
||||
|
||||
async def test_mixed_results_two_executemany_calls(self, geo_cache: GeoCache) -> None:
|
||||
"""A mix of successful and failed IPs produces two executemany calls."""
|
||||
@@ -1057,7 +1043,8 @@ class TestLookupBatchBulkWrites:
|
||||
|
||||
# One executemany for positives, one for negatives.
|
||||
assert db.executemany.await_count == 2
|
||||
db.execute.assert_not_awaited()
|
||||
# BEGIN IMMEDIATE is called for transaction wrapper.
|
||||
assert db.execute.await_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -1071,9 +1058,7 @@ class TestCacheMetrics:
|
||||
async def test_cache_hit_increments_hits(self) -> None:
|
||||
"""lookup() with a cached IP increments _hits."""
|
||||
geo_cache = GeoCache(allow_http_fallback=True)
|
||||
geo_cache._cache["1.1.1.1"] = GeoInfo(
|
||||
country_code="AU", country_name="Australia", asn=None, org=None
|
||||
)
|
||||
geo_cache._cache["1.1.1.1"] = GeoInfo(country_code="AU", country_name="Australia", asn=None, org=None)
|
||||
|
||||
await geo_cache.lookup("1.1.1.1", MagicMock())
|
||||
|
||||
@@ -1269,4 +1254,3 @@ class TestLargeBanList:
|
||||
|
||||
assert len(result) == 1
|
||||
assert "1.1.1.1" in result
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ class TestListHistory:
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket")
|
||||
assert result.pagination.total == 4
|
||||
assert result.total == 4
|
||||
assert len(result.items) == 4
|
||||
|
||||
async def test_time_range_filter_excludes_old_bans(
|
||||
@@ -153,7 +153,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", range_="24h"
|
||||
)
|
||||
assert result.pagination.total == 2
|
||||
assert result.total == 2
|
||||
|
||||
async def test_jail_filter(self, f2b_db_path: str) -> None:
|
||||
"""Jail filter restricts results to bans from that jail."""
|
||||
@@ -162,7 +162,7 @@ class TestListHistory:
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket", jail="nginx")
|
||||
assert result.pagination.total == 1
|
||||
assert result.total == 1
|
||||
assert result.items[0].jail == "nginx"
|
||||
|
||||
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
|
||||
@@ -174,7 +174,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", ip_filter="1.2.3"
|
||||
)
|
||||
assert result.pagination.total == 2
|
||||
assert result.total == 2
|
||||
for item in result.items:
|
||||
assert item.ip.startswith("1.2.3")
|
||||
|
||||
@@ -188,7 +188,7 @@ class TestListHistory:
|
||||
"fake_socket", jail="sshd", ip_filter="1.2.3.4"
|
||||
)
|
||||
# 2 sshd bans for 1.2.3.4
|
||||
assert result.pagination.total == 2
|
||||
assert result.total == 2
|
||||
|
||||
async def test_origin_filter_selfblock(self, f2b_db_path: str) -> None:
|
||||
"""Origin filter should include only selfblock entries."""
|
||||
@@ -200,7 +200,7 @@ class TestListHistory:
|
||||
"fake_socket", origin="selfblock"
|
||||
)
|
||||
|
||||
assert result.pagination.total == 4
|
||||
assert result.total == 4
|
||||
assert all(item.jail != "blocklist-import" for item in result.items)
|
||||
|
||||
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
|
||||
@@ -212,7 +212,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", ip_filter="99.99.99.99"
|
||||
)
|
||||
assert result.pagination.total == 0
|
||||
assert result.total == 0
|
||||
assert result.items == []
|
||||
|
||||
async def test_failures_extracted_from_data(
|
||||
@@ -226,7 +226,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", ip_filter="5.6.7.8"
|
||||
)
|
||||
assert result.pagination.total == 1
|
||||
assert result.total == 1
|
||||
assert result.items[0].failures == 3
|
||||
|
||||
async def test_matches_extracted_from_data(
|
||||
@@ -287,7 +287,7 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", ip_filter="9.0.0.1"
|
||||
)
|
||||
assert result.pagination.total == 1
|
||||
assert result.total == 1
|
||||
item = result.items[0]
|
||||
assert item.failures == 0
|
||||
assert item.matches == []
|
||||
@@ -301,10 +301,10 @@ class TestListHistory:
|
||||
result = await history_service.list_history(
|
||||
"fake_socket", page=1, page_size=2
|
||||
)
|
||||
assert result.pagination.total == 4
|
||||
assert result.total == 4
|
||||
assert len(result.items) == 2
|
||||
assert result.pagination.page == 1
|
||||
assert result.pagination.page_size == 2
|
||||
assert result.page == 1
|
||||
assert result.page_size == 2
|
||||
|
||||
async def test_source_archive_reads_from_archive(self, f2b_db_path: str, tmp_path: Path) -> None:
|
||||
"""Using source='archive' reads from the BanGUI archive table."""
|
||||
@@ -328,7 +328,7 @@ class TestListHistory:
|
||||
db=db,
|
||||
)
|
||||
|
||||
assert result.pagination.total == 1
|
||||
assert result.total == 1
|
||||
assert result.items[0].ip == "10.0.0.1"
|
||||
|
||||
|
||||
@@ -363,8 +363,8 @@ class TestGetIpDetail:
|
||||
|
||||
assert result is not None
|
||||
assert result.ip == "1.2.3.4"
|
||||
assert result.pagination.total_bans == 2
|
||||
assert result.pagination.total_failures == 10 # 5 + 5
|
||||
assert result.total_bans == 2
|
||||
assert result.total_failures == 10 # 5 + 5
|
||||
|
||||
async def test_timeline_ordered_newest_first(
|
||||
self, f2b_db_path: str
|
||||
|
||||
@@ -80,9 +80,8 @@ class TestNormaliseIp:
|
||||
def test_normalise_ip_ipv4_mapped_ipv6_to_ipv4(self) -> None:
|
||||
assert normalise_ip("::ffff:192.168.1.1") == "192.168.1.1"
|
||||
|
||||
def test_normalise_ip_invalid_raises_value_error(self) -> None:
|
||||
with pytest.raises(ValueError):
|
||||
normalise_ip("not-an-ip")
|
||||
def test_normalise_ip_invalid_returns_unchanged(self) -> None:
|
||||
assert normalise_ip("not-an-ip") == "not-an-ip"
|
||||
|
||||
|
||||
class TestNormaliseNetwork:
|
||||
|
||||
@@ -10,9 +10,13 @@ from unittest.mock import AsyncMock, patch
|
||||
import pytest
|
||||
|
||||
from app.exceptions import Fail2BanConnectionError
|
||||
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
|
||||
from app.models.ban_domain import DomainActiveBanList
|
||||
from app.models.geo import GeoDetail, GeoInfo
|
||||
from app.models.jail import JailDetailResponse, JailListResponse
|
||||
from app.models.jail_domain import (
|
||||
DomainJailBannedIps,
|
||||
DomainJailDetail,
|
||||
DomainJailList,
|
||||
)
|
||||
from app.services import ban_service, jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
from app.utils import jail_socket
|
||||
@@ -109,9 +113,9 @@ class TestListJails:
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
assert isinstance(result, JailListResponse)
|
||||
assert isinstance(result, DomainJailList)
|
||||
assert result.total == 1
|
||||
assert result.jails[0].name == "sshd"
|
||||
assert result.items[0].name == "sshd"
|
||||
|
||||
async def test_empty_jail_list(self, jail_service_state: JailServiceState) -> None:
|
||||
"""list_jails returns empty response when no jails are active."""
|
||||
@@ -120,7 +124,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
assert result.total == 0
|
||||
assert result.jails == []
|
||||
assert result.items == []
|
||||
|
||||
async def test_jail_status_populated(self, jail_service_state: JailServiceState) -> None:
|
||||
"""list_jails populates JailStatus with failed/banned counters."""
|
||||
@@ -136,7 +140,7 @@ class TestListJails:
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
jail = result.jails[0]
|
||||
jail = result.items[0]
|
||||
assert jail.status is not None
|
||||
assert jail.status.currently_banned == 5
|
||||
assert jail.status.total_banned == 50
|
||||
@@ -155,7 +159,7 @@ class TestListJails:
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
jail = result.jails[0]
|
||||
jail = result.items[0]
|
||||
assert jail.ban_time == 3600
|
||||
assert jail.find_time == 300
|
||||
assert jail.max_retry == 3
|
||||
@@ -183,7 +187,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
assert result.total == 2
|
||||
names = {j.name for j in result.jails}
|
||||
names = {j.name for j in result.items}
|
||||
assert names == {"sshd", "nginx"}
|
||||
|
||||
async def test_connection_error_propagates(self, jail_service_state: JailServiceState) -> None:
|
||||
@@ -223,7 +227,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
# Verify the result uses the default values for backend and idle.
|
||||
jail = result.jails[0]
|
||||
jail = result.items[0]
|
||||
assert jail.backend == "polling" # default
|
||||
assert jail.idle is False # default
|
||||
# Capability should now be cached as False.
|
||||
@@ -249,7 +253,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
# Verify real values are returned.
|
||||
jail = result.jails[0]
|
||||
jail = result.items[0]
|
||||
assert jail.backend == "systemd" # real value
|
||||
assert jail.idle is True # real value
|
||||
# Capability should now be cached as True.
|
||||
@@ -280,7 +284,7 @@ class TestListJails:
|
||||
result = await jail_service.list_jails(_SOCKET, jail_service_state)
|
||||
|
||||
# Both jails should return default values (cached result is False).
|
||||
for jail in result.jails:
|
||||
for jail in result.items:
|
||||
assert jail.backend == "polling"
|
||||
assert jail.idle is False
|
||||
|
||||
@@ -329,11 +333,11 @@ class TestGetJail:
|
||||
}
|
||||
|
||||
async def test_returns_jail_detail_response(self, jail_service_state: JailServiceState) -> None:
|
||||
"""get_jail returns a JailDetailResponse."""
|
||||
"""get_jail returns a DomainJailDetail."""
|
||||
with _patch_client(self._full_responses()):
|
||||
result = await jail_service.get_jail(_SOCKET, "sshd")
|
||||
|
||||
assert isinstance(result, JailDetailResponse)
|
||||
assert isinstance(result, DomainJailDetail)
|
||||
assert result.jail.name == "sshd"
|
||||
|
||||
async def test_log_paths_parsed(self, jail_service_state: JailServiceState) -> None:
|
||||
@@ -453,9 +457,7 @@ class TestJailControls:
|
||||
"reload|--all|[]|[['start', 'new'], ['start', 'nginx']]": (0, "OK"),
|
||||
}
|
||||
):
|
||||
await jail_service.reload_all(
|
||||
_SOCKET, include_jails=["new"], exclude_jails=["old"]
|
||||
)
|
||||
await jail_service.reload_all(_SOCKET, include_jails=["new"], exclude_jails=["old"])
|
||||
|
||||
async def test_reload_all_unknown_jail_raises_jail_not_found(self) -> None:
|
||||
"""reload_all detects UnknownJailException and raises JailNotFoundError.
|
||||
@@ -465,18 +467,19 @@ class TestJailControls:
|
||||
test verifies that reload_all detects this and re-raises as
|
||||
JailNotFoundError instead of the generic JailOperationError.
|
||||
"""
|
||||
with _patch_client(
|
||||
{
|
||||
"status": _make_global_status("sshd"),
|
||||
"reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": (
|
||||
1,
|
||||
Exception("UnknownJailException('airsonic-auth')"),
|
||||
),
|
||||
}
|
||||
), pytest.raises(jail_service.JailNotFoundError) as exc_info:
|
||||
await jail_service.reload_all(
|
||||
_SOCKET, include_jails=["airsonic-auth"]
|
||||
)
|
||||
with (
|
||||
_patch_client(
|
||||
{
|
||||
"status": _make_global_status("sshd"),
|
||||
"reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": (
|
||||
1,
|
||||
Exception("UnknownJailException('airsonic-auth')"),
|
||||
),
|
||||
}
|
||||
),
|
||||
pytest.raises(jail_service.JailNotFoundError) as exc_info,
|
||||
):
|
||||
await jail_service.reload_all(_SOCKET, include_jails=["airsonic-auth"])
|
||||
assert exc_info.value.name == "airsonic-auth"
|
||||
|
||||
async def test_restart_sends_stop_command(self) -> None:
|
||||
@@ -486,9 +489,7 @@ class TestJailControls:
|
||||
|
||||
async def test_restart_operation_error_raises(self) -> None:
|
||||
"""restart() raises JailOperationError when fail2ban rejects the stop."""
|
||||
with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises(
|
||||
JailOperationError
|
||||
):
|
||||
with _patch_client({"stop": (1, Exception("cannot stop"))}), pytest.raises(JailOperationError):
|
||||
await jail_service.restart(_SOCKET)
|
||||
|
||||
async def test_restart_connection_error_propagates(self) -> None:
|
||||
@@ -496,9 +497,7 @@ class TestJailControls:
|
||||
|
||||
class _FailClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
|
||||
)
|
||||
self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.Fail2BanClient", _FailClient),
|
||||
@@ -638,7 +637,7 @@ class TestGetActiveBans:
|
||||
with _patch_client(responses):
|
||||
result = await ban_service.get_active_bans(_SOCKET)
|
||||
|
||||
assert isinstance(result, ActiveBanListResponse)
|
||||
assert isinstance(result, DomainActiveBanList)
|
||||
assert result.total == 1
|
||||
assert result.bans[0].ip == "1.2.3.4"
|
||||
assert result.bans[0].jail == "sshd"
|
||||
@@ -724,17 +723,18 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
|
||||
mock_batch = AsyncMock(return_value=mock_geo)
|
||||
mock_cache = AsyncMock()
|
||||
mock_cache.lookup_batch = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=mock_batch,
|
||||
geo_cache=mock_cache,
|
||||
)
|
||||
|
||||
mock_batch.assert_awaited_once()
|
||||
mock_cache.lookup_batch.assert_awaited_once()
|
||||
assert result.total == 1
|
||||
assert result.bans[0].country == "DE"
|
||||
|
||||
@@ -748,14 +748,17 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
|
||||
failing_batch = AsyncMock(side_effect=RuntimeError("geo down"))
|
||||
import aiohttp
|
||||
|
||||
mock_cache = AsyncMock()
|
||||
mock_cache.lookup_batch = AsyncMock(side_effect=aiohttp.ClientError("geo down"))
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=failing_batch,
|
||||
geo_cache=mock_cache,
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
@@ -777,9 +780,7 @@ class TestGetActiveBans:
|
||||
return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
|
||||
|
||||
with _patch_client(responses):
|
||||
result = await ban_service.get_active_bans(
|
||||
_SOCKET, geo_enricher=_enricher
|
||||
)
|
||||
result = await ban_service.get_active_bans(_SOCKET, geo_enricher=_enricher)
|
||||
|
||||
assert result.total == 1
|
||||
assert result.bans[0].country == "JP"
|
||||
@@ -875,7 +876,7 @@ class TestLookupIp:
|
||||
assert result.geo.org == "Acme"
|
||||
|
||||
async def test_http_session_uses_geo_service_lookup(self) -> None:
|
||||
"""lookup_ip uses geo_service.lookup when http_session is provided."""
|
||||
"""lookup_ip uses geo_enricher when provided."""
|
||||
responses = {
|
||||
"get|--all|banned|1.2.3.4": (0, []),
|
||||
"status": _make_global_status("sshd"),
|
||||
@@ -883,19 +884,16 @@ class TestLookupIp:
|
||||
}
|
||||
|
||||
mock_geo = GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
|
||||
mock_session = AsyncMock()
|
||||
mock_enricher = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with _patch_client(responses), patch(
|
||||
"app.services.jail_service.geo_service.lookup",
|
||||
AsyncMock(return_value=mock_geo),
|
||||
) as mock_lookup:
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.lookup_ip(
|
||||
_SOCKET,
|
||||
"1.2.3.4",
|
||||
http_session=mock_session,
|
||||
geo_enricher=mock_enricher,
|
||||
)
|
||||
|
||||
mock_lookup.assert_awaited_once_with("1.2.3.4", mock_session)
|
||||
mock_enricher.assert_awaited_once_with("1.2.3.4")
|
||||
assert isinstance(result.geo, GeoDetail)
|
||||
assert result.geo.country_code == "JP"
|
||||
assert result.geo.country_name == "Japan"
|
||||
@@ -985,7 +983,7 @@ class TestGetJailBannedIps:
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
|
||||
|
||||
assert isinstance(result, JailBannedIpsResponse)
|
||||
assert isinstance(result, DomainJailBannedIps)
|
||||
|
||||
async def test_total_reflects_all_entries(self) -> None:
|
||||
"""total equals the number of parsed ban entries."""
|
||||
@@ -996,12 +994,8 @@ class TestGetJailBannedIps:
|
||||
|
||||
async def test_page_1_returns_first_n_items(self) -> None:
|
||||
"""page=1 with page_size=2 returns the first two entries."""
|
||||
with _patch_client(
|
||||
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
|
||||
):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=1, page_size=2
|
||||
)
|
||||
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=1, page_size=2)
|
||||
|
||||
assert len(result.items) == 2
|
||||
assert result.items[0].ip == "1.2.3.4"
|
||||
@@ -1010,12 +1004,8 @@ class TestGetJailBannedIps:
|
||||
|
||||
async def test_page_2_returns_remaining_items(self) -> None:
|
||||
"""page=2 with page_size=2 returns the third entry."""
|
||||
with _patch_client(
|
||||
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
|
||||
):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=2, page_size=2
|
||||
)
|
||||
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=2, page_size=2)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].ip == "9.10.11.12"
|
||||
@@ -1023,9 +1013,7 @@ class TestGetJailBannedIps:
|
||||
async def test_page_beyond_last_returns_empty_items(self) -> None:
|
||||
"""Requesting a page past the end returns an empty items list."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=99, page_size=25
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=99, page_size=25)
|
||||
|
||||
assert result.items == []
|
||||
assert result.total == 2
|
||||
@@ -1033,9 +1021,7 @@ class TestGetJailBannedIps:
|
||||
async def test_search_filter_narrows_results(self) -> None:
|
||||
"""search parameter filters entries by IP substring."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", search="1.2.3"
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="1.2.3")
|
||||
|
||||
assert result.total == 1
|
||||
assert result.items[0].ip == "1.2.3.4"
|
||||
@@ -1044,18 +1030,14 @@ class TestGetJailBannedIps:
|
||||
"""search filter is case-insensitive."""
|
||||
entries = ["192.168.0.1\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"]
|
||||
with _patch_client(_banned_ips_responses(entries=entries)):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", search="192.168"
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="192.168")
|
||||
|
||||
assert result.total == 1
|
||||
|
||||
async def test_search_no_match_returns_empty(self) -> None:
|
||||
"""search that matches nothing returns empty items and total=0."""
|
||||
with _patch_client(_banned_ips_responses()):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", search="999.999"
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", search="999.999")
|
||||
|
||||
assert result.total == 0
|
||||
assert result.items == []
|
||||
@@ -1080,9 +1062,7 @@ class TestGetJailBannedIps:
|
||||
"get|sshd|banip|--with-time": (0, entries),
|
||||
}
|
||||
with _patch_client(responses):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET, "sshd", page=1, page_size=200
|
||||
)
|
||||
result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd", page=1, page_size=200)
|
||||
|
||||
assert len(result.items) <= 100
|
||||
|
||||
@@ -1090,30 +1070,22 @@ class TestGetJailBannedIps:
|
||||
"""Geo enrichment is requested only for IPs in the current page."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.services import geo_service
|
||||
|
||||
http_session = MagicMock()
|
||||
geo_enrichment_ips: list[list[str]] = []
|
||||
|
||||
async def _mock_lookup_batch(
|
||||
ips: list[str], _session: Any, **_kw: Any
|
||||
) -> dict[str, Any]:
|
||||
geo_enrichment_ips.append(list(ips))
|
||||
return {}
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.lookup_batch = AsyncMock(
|
||||
side_effect=lambda ips, _session, **_kw: (geo_enrichment_ips.append(list(ips)), {})[-1]
|
||||
)
|
||||
|
||||
with (
|
||||
_patch_client(
|
||||
_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
|
||||
),
|
||||
patch.object(geo_service, "lookup_batch", side_effect=_mock_lookup_batch),
|
||||
):
|
||||
with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
|
||||
result = await jail_service.get_jail_banned_ips(
|
||||
_SOCKET,
|
||||
"sshd",
|
||||
page=1,
|
||||
page_size=2,
|
||||
http_session=http_session,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
geo_cache=mock_cache,
|
||||
)
|
||||
|
||||
# Only the 2-IP page slice should be passed to geo enrichment.
|
||||
@@ -1123,6 +1095,7 @@ class TestGetJailBannedIps:
|
||||
|
||||
async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
|
||||
"""get_jail_banned_ips raises JailNotFoundError for unknown jail."""
|
||||
|
||||
# Simulate fail2ban returning an "unknown jail" error.
|
||||
class _FakeClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
@@ -1142,9 +1115,7 @@ class TestGetJailBannedIps:
|
||||
|
||||
class _FailClient:
|
||||
def __init__(self, **_kw: Any) -> None:
|
||||
self.send = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("no socket", _SOCKET)
|
||||
)
|
||||
self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.Fail2BanClient", _FailClient),
|
||||
|
||||
@@ -7,7 +7,8 @@ from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.models.server import ServerSettingsUpdate
|
||||
from app.models.server_domain import DomainServerSettingsResult
|
||||
from app.services import server_service
|
||||
from app.services.server_service import ServerOperationError
|
||||
|
||||
@@ -58,7 +59,7 @@ class TestGetSettings:
|
||||
with _patch_client(_DEFAULT_RESPONSES):
|
||||
result = await server_service.get_settings(_SOCKET)
|
||||
|
||||
assert isinstance(result, ServerSettingsResponse)
|
||||
assert isinstance(result, DomainServerSettingsResult)
|
||||
assert result.settings.log_level == "INFO"
|
||||
assert result.settings.log_target == "/var/log/fail2ban.log"
|
||||
assert result.settings.db_purge_age == 86400
|
||||
|
||||
Reference in New Issue
Block a user