Stage 1.1-1.3: reload_all include/exclude_jails params already implemented; added keyword-arg assertions in router and service tests. Stage 2.1/6.1: _send_command_sync retry loop (3 attempts, 150ms exp backoff) retrying on EAGAIN/ECONNREFUSED/ENOBUFS; immediate raise on all other errors. Stage 2.2: asyncio.Lock at module level in jail_service.reload_all to serialize concurrent `reload --all` commands. Stage 3.1: activate_jail re-queries _get_active_jail_names after reload; returns active=False with descriptive message if jail did not start. Stage 4.1/6.2: asyncio.Semaphore (max 10) in Fail2BanClient.send, lazy-initialized; logs fail2ban_command_waiting_semaphore at debug when waiting. Stage 5.1/5.2: unit tests asserting reload_all is called with include_jails and exclude_jails; activation verification happy/sad path tests. Stage 6.3: TestSendCommandSyncRetry (5 cases) + TestFail2BanClientSemaphore concurrency test. Stage 7.1-7.3: _since_unix uses time.time(); bans_by_jail debug logging with since_iso; diagnostic warning when total==0 despite table rows; unit test verifying the warning fires for stale data.
1043 lines
38 KiB
Python
"""Tests for ban_service.list_bans()."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
|
|
from app.services import ban_service
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_NOW: int = int(time.time())
|
|
_ONE_HOUR_AGO: int = _NOW - 3600
|
|
_TWO_DAYS_AGO: int = _NOW - 2 * 24 * 3600
|
|
|
|
|
|
async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
|
"""Create a minimal fail2ban SQLite database with the given ban rows.
|
|
|
|
Args:
|
|
path: Filesystem path for the new SQLite file.
|
|
rows: Sequence of dicts with keys ``jail``, ``ip``, ``timeofban``,
|
|
``bantime``, ``bancount``, and optionally ``data``.
|
|
"""
|
|
async with aiosqlite.connect(path) as db:
|
|
await db.execute(
|
|
"CREATE TABLE jails ("
|
|
"name TEXT NOT NULL UNIQUE, "
|
|
"enabled INTEGER NOT NULL DEFAULT 1"
|
|
")"
|
|
)
|
|
await db.execute(
|
|
"CREATE TABLE bans ("
|
|
"jail TEXT NOT NULL, "
|
|
"ip TEXT, "
|
|
"timeofban INTEGER NOT NULL, "
|
|
"bantime INTEGER NOT NULL, "
|
|
"bancount INTEGER NOT NULL DEFAULT 1, "
|
|
"data JSON"
|
|
")"
|
|
)
|
|
for row in rows:
|
|
await db.execute(
|
|
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
|
|
"VALUES (?, ?, ?, ?, ?, ?)",
|
|
(
|
|
row["jail"],
|
|
row["ip"],
|
|
row["timeofban"],
|
|
row.get("bantime", 3600),
|
|
row.get("bancount", 1),
|
|
json.dumps(row["data"]) if "data" in row else None,
|
|
),
|
|
)
|
|
await db.commit()
|
|
|
|
|
|
@pytest.fixture
|
|
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
|
"""Return the path to a test fail2ban SQLite database with several bans."""
|
|
path = str(tmp_path / "fail2ban_test.sqlite3")
|
|
await _create_f2b_db(
|
|
path,
|
|
[
|
|
{
|
|
"jail": "sshd",
|
|
"ip": "1.2.3.4",
|
|
"timeofban": _ONE_HOUR_AGO,
|
|
"bantime": 3600,
|
|
"bancount": 2,
|
|
"data": {
|
|
"matches": ["Nov 10 10:00 sshd[123]: Failed password for root"],
|
|
"failures": 5,
|
|
},
|
|
},
|
|
{
|
|
"jail": "nginx",
|
|
"ip": "5.6.7.8",
|
|
"timeofban": _ONE_HOUR_AGO,
|
|
"bantime": 7200,
|
|
"bancount": 1,
|
|
"data": {"matches": ["GET /admin HTTP/1.1"], "failures": 3},
|
|
},
|
|
{
|
|
"jail": "sshd",
|
|
"ip": "9.10.11.12",
|
|
"timeofban": _TWO_DAYS_AGO,
|
|
"bantime": 3600,
|
|
"bancount": 1,
|
|
"data": {"failures": 6}, # no matches
|
|
},
|
|
],
|
|
)
|
|
return path
|
|
|
|
|
|
@pytest.fixture
|
|
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
|
"""Return a database with bans from both blocklist-import and organic jails."""
|
|
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
|
|
await _create_f2b_db(
|
|
path,
|
|
[
|
|
{
|
|
"jail": "blocklist-import",
|
|
"ip": "10.0.0.1",
|
|
"timeofban": _ONE_HOUR_AGO,
|
|
"bantime": -1,
|
|
"bancount": 1,
|
|
},
|
|
{
|
|
"jail": "sshd",
|
|
"ip": "10.0.0.2",
|
|
"timeofban": _ONE_HOUR_AGO,
|
|
"bantime": 3600,
|
|
"bancount": 3,
|
|
},
|
|
{
|
|
"jail": "nginx",
|
|
"ip": "10.0.0.3",
|
|
"timeofban": _ONE_HOUR_AGO,
|
|
"bantime": 7200,
|
|
"bancount": 1,
|
|
},
|
|
],
|
|
)
|
|
return path
|
|
|
|
|
|
@pytest.fixture
|
|
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
|
"""Return the path to a fail2ban SQLite database with no ban records."""
|
|
path = str(tmp_path / "fail2ban_empty.sqlite3")
|
|
await _create_f2b_db(path, [])
|
|
return path
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_bans — happy path
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestListBansHappyPath:
|
|
"""Verify ban_service.list_bans() under normal conditions."""
|
|
|
|
async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
|
|
"""Only bans within the selected range are returned."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
|
|
|
# Two bans within last 24 h; one is 2 days old and excluded.
|
|
assert result.total == 2
|
|
assert len(result.items) == 2
|
|
|
|
async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
|
|
"""Items are ordered by ``banned_at`` descending (newest first)."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
|
|
|
timestamps = [item.banned_at for item in result.items]
|
|
assert timestamps == sorted(timestamps, reverse=True)
|
|
|
|
async def test_ban_fields_present(self, f2b_db_path: str) -> None:
|
|
"""Each item contains ip, jail, banned_at, ban_count."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
|
|
|
for item in result.items:
|
|
assert item.ip
|
|
assert item.jail
|
|
assert item.banned_at
|
|
assert item.ban_count >= 1
|
|
|
|
async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
|
|
"""``service`` field is the first element of ``data.matches``."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
|
|
|
sshd_item = next(i for i in result.items if i.jail == "sshd")
|
|
assert sshd_item.service is not None
|
|
assert "Failed password" in sshd_item.service
|
|
|
|
async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
|
|
"""``service`` is ``None`` when the ban has no stored matches."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
# Use 7d to include the older ban with no matches.
|
|
result = await ban_service.list_bans("/fake/sock", "7d")
|
|
|
|
no_match = next(i for i in result.items if i.ip == "9.10.11.12")
|
|
assert no_match.service is None
|
|
|
|
async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
|
|
"""When no bans exist the result has total=0 and no items."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
|
|
|
assert result.total == 0
|
|
assert result.items == []
|
|
|
|
async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
|
|
"""The ``365d`` range includes bans that are 2 days old."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "365d")
|
|
|
|
assert result.total == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_bans — geo enrichment
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestListBansGeoEnrichment:
|
|
"""Verify geo enrichment integration in ban_service.list_bans()."""
|
|
|
|
async def test_geo_data_applied_when_enricher_provided(
|
|
self, f2b_db_path: str
|
|
) -> None:
|
|
"""Geo fields are populated when an enricher returns data."""
|
|
from app.services.geo_service import GeoInfo
|
|
|
|
async def fake_enricher(ip: str) -> GeoInfo:
|
|
return GeoInfo(
|
|
country_code="DE",
|
|
country_name="Germany",
|
|
asn="AS3320",
|
|
org="Deutsche Telekom",
|
|
)
|
|
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans(
|
|
"/fake/sock", "24h", geo_enricher=fake_enricher
|
|
)
|
|
|
|
for item in result.items:
|
|
assert item.country_code == "DE"
|
|
assert item.country_name == "Germany"
|
|
assert item.asn == "AS3320"
|
|
|
|
async def test_geo_failure_does_not_break_results(
|
|
self, f2b_db_path: str
|
|
) -> None:
|
|
"""A geo enricher that raises still returns ban items (geo fields null)."""
|
|
|
|
async def failing_enricher(ip: str) -> None:
|
|
raise RuntimeError("geo service down")
|
|
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans(
|
|
"/fake/sock", "24h", geo_enricher=failing_enricher
|
|
)
|
|
|
|
assert result.total == 2
|
|
for item in result.items:
|
|
assert item.country_code is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_bans — batch geo enrichment via http_session
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestListBansBatchGeoEnrichment:
|
|
"""Verify that list_bans uses lookup_batch when http_session is provided."""
|
|
|
|
async def test_batch_geo_applied_via_http_session(
|
|
self, f2b_db_path: str
|
|
) -> None:
|
|
"""Geo fields are populated via lookup_batch when http_session is given."""
|
|
from unittest.mock import MagicMock
|
|
|
|
from app.services.geo_service import GeoInfo
|
|
|
|
fake_session = MagicMock()
|
|
fake_geo_map = {
|
|
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="Deutsche Telekom"),
|
|
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
|
|
}
|
|
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
), patch(
|
|
"app.services.geo_service.lookup_batch",
|
|
new=AsyncMock(return_value=fake_geo_map),
|
|
):
|
|
result = await ban_service.list_bans(
|
|
"/fake/sock", "24h", http_session=fake_session
|
|
)
|
|
|
|
assert result.total == 2
|
|
de_item = next(i for i in result.items if i.ip == "1.2.3.4")
|
|
us_item = next(i for i in result.items if i.ip == "5.6.7.8")
|
|
assert de_item.country_code == "DE"
|
|
assert de_item.country_name == "Germany"
|
|
assert us_item.country_code == "US"
|
|
assert us_item.country_name == "United States"
|
|
|
|
async def test_batch_failure_does_not_break_results(
|
|
self, f2b_db_path: str
|
|
) -> None:
|
|
"""A lookup_batch failure still returns items with null geo fields."""
|
|
from unittest.mock import MagicMock
|
|
|
|
fake_session = MagicMock()
|
|
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
), patch(
|
|
"app.services.geo_service.lookup_batch",
|
|
new=AsyncMock(side_effect=RuntimeError("batch geo down")),
|
|
):
|
|
result = await ban_service.list_bans(
|
|
"/fake/sock", "24h", http_session=fake_session
|
|
)
|
|
|
|
assert result.total == 2
|
|
for item in result.items:
|
|
assert item.country_code is None
|
|
|
|
async def test_http_session_takes_priority_over_geo_enricher(
|
|
self, f2b_db_path: str
|
|
) -> None:
|
|
"""When both http_session and geo_enricher are provided, batch wins."""
|
|
from unittest.mock import MagicMock
|
|
|
|
from app.services.geo_service import GeoInfo
|
|
|
|
fake_session = MagicMock()
|
|
fake_geo_map = {
|
|
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
|
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
|
}
|
|
|
|
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
|
|
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
|
|
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
), patch(
|
|
"app.services.geo_service.lookup_batch",
|
|
new=AsyncMock(return_value=fake_geo_map),
|
|
):
|
|
result = await ban_service.list_bans(
|
|
"/fake/sock",
|
|
"24h",
|
|
http_session=fake_session,
|
|
geo_enricher=enricher_should_not_be_called,
|
|
)
|
|
|
|
assert result.total == 2
|
|
for item in result.items:
|
|
assert item.country_code == "DE"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_bans — pagination
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestListBansPagination:
|
|
"""Verify pagination parameters in list_bans()."""
|
|
|
|
async def test_page_size_respected(self, f2b_db_path: str) -> None:
|
|
"""``page_size=1`` returns at most one item."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
|
|
|
assert len(result.items) == 1
|
|
assert result.page_size == 1
|
|
|
|
async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
|
|
"""The second page returns items not on the first page."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
|
|
page2 = await ban_service.list_bans("/fake/sock", "7d", page=2, page_size=1)
|
|
|
|
# Different IPs should appear on different pages.
|
|
assert page1.items[0].ip != page2.items[0].ip
|
|
|
|
async def test_total_reflects_full_count_not_page_count(
|
|
self, f2b_db_path: str
|
|
) -> None:
|
|
"""``total`` reports all matching records regardless of pagination."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=f2b_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
|
|
|
assert result.total == 3 # All three bans are within 7d.
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_bans / bans_by_country — origin derivation
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBanOriginDerivation:
|
|
"""Verify that ban_service correctly derives ``origin`` from jail names."""
|
|
|
|
async def test_blocklist_import_jail_yields_blocklist_origin(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
|
|
|
blocklist_items = [i for i in result.items if i.jail == "blocklist-import"]
|
|
assert len(blocklist_items) == 1
|
|
assert blocklist_items[0].origin == "blocklist"
|
|
|
|
async def test_organic_jail_yields_selfblock_origin(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
|
|
|
organic_items = [i for i in result.items if i.jail != "blocklist-import"]
|
|
assert len(organic_items) == 2
|
|
for item in organic_items:
|
|
assert item.origin == "selfblock"
|
|
|
|
async def test_all_items_carry_origin_field(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""Every returned item has an ``origin`` field with a valid value."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "24h")
|
|
|
|
for item in result.items:
|
|
assert item.origin in ("blocklist", "selfblock")
|
|
|
|
async def test_bans_by_country_blocklist_origin(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
|
|
|
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"]
|
|
assert len(blocklist_bans) == 1
|
|
assert blocklist_bans[0].origin == "blocklist"
|
|
|
|
async def test_bans_by_country_selfblock_origin(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""``bans_by_country`` derives origin correctly for organic jails."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
|
|
|
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"]
|
|
assert len(organic_bans) == 2
|
|
for ban in organic_bans:
|
|
assert ban.origin == "selfblock"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_bans / bans_by_country — origin filter parameter
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestOriginFilter:
|
|
"""Verify that the origin filter correctly restricts results."""
|
|
|
|
async def test_list_bans_blocklist_filter_returns_only_blocklist(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.list_bans(
|
|
"/fake/sock", "24h", origin="blocklist"
|
|
)
|
|
|
|
assert result.total == 1
|
|
assert len(result.items) == 1
|
|
assert result.items[0].jail == "blocklist-import"
|
|
assert result.items[0].origin == "blocklist"
|
|
|
|
async def test_list_bans_selfblock_filter_excludes_blocklist(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.list_bans(
|
|
"/fake/sock", "24h", origin="selfblock"
|
|
)
|
|
|
|
assert result.total == 2
|
|
assert len(result.items) == 2
|
|
for item in result.items:
|
|
assert item.jail != "blocklist-import"
|
|
assert item.origin == "selfblock"
|
|
|
|
async def test_list_bans_no_filter_returns_all(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""``origin=None`` applies no jail restriction — all bans returned."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
|
|
|
|
assert result.total == 3
|
|
|
|
async def test_bans_by_country_blocklist_filter(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.bans_by_country(
|
|
"/fake/sock", "24h", origin="blocklist"
|
|
)
|
|
|
|
assert result.total == 1
|
|
assert all(b.jail == "blocklist-import" for b in result.bans)
|
|
|
|
async def test_bans_by_country_selfblock_filter(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.bans_by_country(
|
|
"/fake/sock", "24h", origin="selfblock"
|
|
)
|
|
|
|
assert result.total == 2
|
|
assert all(b.jail != "blocklist-import" for b in result.bans)
|
|
|
|
async def test_bans_by_country_no_filter_returns_all(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
):
|
|
result = await ban_service.bans_by_country(
|
|
"/fake/sock", "24h", origin=None
|
|
)
|
|
|
|
assert result.total == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# bans_by_country — background geo resolution (Task 3)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
class TestBansbyCountryBackground:
|
|
"""bans_by_country() with http_session uses cache-only geo and fires a
|
|
background task for uncached IPs instead of blocking on API calls."""
|
|
|
|
async def test_cached_geo_returned_without_api_call(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""When all IPs are in the cache, lookup_cached_only returns them and
|
|
no background task is created."""
|
|
from app.services import geo_service
|
|
|
|
# Pre-populate the cache for all three IPs in the fixture.
|
|
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
|
country_code="DE", country_name="Germany", asn=None, org=None
|
|
)
|
|
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
|
country_code="US", country_name="United States", asn=None, org=None
|
|
)
|
|
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( # type: ignore[attr-defined]
|
|
country_code="JP", country_name="Japan", asn=None, org=None
|
|
)
|
|
|
|
with (
|
|
patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
),
|
|
patch(
|
|
"app.services.ban_service.asyncio.create_task"
|
|
) as mock_create_task,
|
|
):
|
|
mock_session = AsyncMock()
|
|
result = await ban_service.bans_by_country(
|
|
"/fake/sock", "24h", http_session=mock_session
|
|
)
|
|
|
|
# All countries resolved from cache — no background task needed.
|
|
mock_create_task.assert_not_called()
|
|
assert result.total == 3
|
|
# Country counts should reflect the cached data.
|
|
assert "DE" in result.countries or "US" in result.countries or "JP" in result.countries
|
|
geo_service.clear_cache()
|
|
|
|
async def test_uncached_ips_trigger_background_task(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""When IPs are NOT in the cache, create_task is called for background
|
|
resolution and the response returns without blocking."""
|
|
from app.services import geo_service
|
|
|
|
geo_service.clear_cache() # ensure cache is empty
|
|
|
|
with (
|
|
patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
),
|
|
patch(
|
|
"app.services.ban_service.asyncio.create_task"
|
|
) as mock_create_task,
|
|
):
|
|
mock_session = AsyncMock()
|
|
result = await ban_service.bans_by_country(
|
|
"/fake/sock", "24h", http_session=mock_session
|
|
)
|
|
|
|
# Background task must have been scheduled for uncached IPs.
|
|
mock_create_task.assert_called_once()
|
|
# Response is still valid with empty country map (IPs not cached yet).
|
|
assert result.total == 3
|
|
|
|
async def test_no_background_task_without_http_session(
|
|
self, mixed_origin_db_path: str
|
|
) -> None:
|
|
"""When http_session is None, no background task is created."""
|
|
from app.services import geo_service
|
|
|
|
geo_service.clear_cache()
|
|
|
|
with (
|
|
patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=mixed_origin_db_path),
|
|
),
|
|
patch(
|
|
"app.services.ban_service.asyncio.create_task"
|
|
) as mock_create_task,
|
|
):
|
|
result = await ban_service.bans_by_country(
|
|
"/fake/sock", "24h", http_session=None
|
|
)
|
|
|
|
mock_create_task.assert_not_called()
|
|
assert result.total == 3
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ban_trend
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBanTrend:
|
|
"""Verify ban_service.ban_trend() behaviour."""
|
|
|
|
async def test_24h_returns_24_buckets(self, empty_f2b_db_path: str) -> None:
|
|
"""``range_='24h'`` always yields exactly 24 buckets."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
|
):
|
|
result = await ban_service.ban_trend("/fake/sock", "24h")
|
|
|
|
assert len(result.buckets) == 24
|
|
assert result.bucket_size == "1h"
|
|
|
|
async def test_7d_returns_28_buckets(self, empty_f2b_db_path: str) -> None:
|
|
"""``range_='7d'`` yields 28 six-hour buckets."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
|
):
|
|
result = await ban_service.ban_trend("/fake/sock", "7d")
|
|
|
|
assert len(result.buckets) == 28
|
|
assert result.bucket_size == "6h"
|
|
|
|
async def test_30d_returns_30_buckets(self, empty_f2b_db_path: str) -> None:
|
|
"""``range_='30d'`` yields 30 daily buckets."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
|
):
|
|
result = await ban_service.ban_trend("/fake/sock", "30d")
|
|
|
|
assert len(result.buckets) == 30
|
|
assert result.bucket_size == "1d"
|
|
|
|
async def test_365d_bucket_size_label(self, empty_f2b_db_path: str) -> None:
|
|
"""``range_='365d'`` uses '7d' as the bucket size label."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
|
):
|
|
result = await ban_service.ban_trend("/fake/sock", "365d")
|
|
|
|
assert result.bucket_size == "7d"
|
|
assert len(result.buckets) > 0
|
|
|
|
async def test_empty_db_all_buckets_zero(self, empty_f2b_db_path: str) -> None:
|
|
"""All bucket counts are zero when the database has no bans."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
|
):
|
|
result = await ban_service.ban_trend("/fake/sock", "24h")
|
|
|
|
assert all(b.count == 0 for b in result.buckets)
|
|
|
|
async def test_buckets_are_time_ordered(self, empty_f2b_db_path: str) -> None:
|
|
"""Buckets are ordered chronologically (ascending timestamps)."""
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
|
):
|
|
result = await ban_service.ban_trend("/fake/sock", "7d")
|
|
|
|
timestamps = [b.timestamp for b in result.buckets]
|
|
assert timestamps == sorted(timestamps)
|
|
|
|
async def test_bans_counted_in_correct_bucket(self, tmp_path: Path) -> None:
|
|
"""A ban at a known time appears in the expected bucket."""
|
|
import time as _time
|
|
|
|
now = int(_time.time())
|
|
# Place a ban exactly 30 minutes ago — should land in bucket 0 of a 24h range
|
|
# (the most recent hour bucket when 'since' is ~24 h ago).
|
|
thirty_min_ago = now - 1800
|
|
path = str(tmp_path / "test_bucket.sqlite3")
|
|
await _create_f2b_db(
|
|
path,
|
|
[{"jail": "sshd", "ip": "1.2.3.4", "timeofban": thirty_min_ago}],
|
|
)
|
|
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=path),
|
|
):
|
|
result = await ban_service.ban_trend("/fake/sock", "24h")
|
|
|
|
# Total ban count across all buckets must be exactly 1.
|
|
assert sum(b.count for b in result.buckets) == 1
|
|
|
|
async def test_origin_filter_blocklist(self, tmp_path: Path) -> None:
|
|
"""``origin='blocklist'`` counts only blocklist-import bans."""
|
|
import time as _time
|
|
|
|
now = int(_time.time())
|
|
one_hour_ago = now - 3600
|
|
path = str(tmp_path / "test_trend_origin.sqlite3")
|
|
await _create_f2b_db(
|
|
path,
|
|
[
|
|
{"jail": "blocklist-import", "ip": "10.0.0.1", "timeofban": one_hour_ago},
|
|
{"jail": "sshd", "ip": "10.0.0.2", "timeofban": one_hour_ago},
|
|
],
|
|
)
|
|
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=path),
|
|
):
|
|
result = await ban_service.ban_trend(
|
|
"/fake/sock", "24h", origin="blocklist"
|
|
)
|
|
|
|
assert sum(b.count for b in result.buckets) == 1
|
|
|
|
async def test_origin_filter_selfblock(self, tmp_path: Path) -> None:
|
|
"""``origin='selfblock'`` excludes blocklist-import bans."""
|
|
import time as _time
|
|
|
|
now = int(_time.time())
|
|
one_hour_ago = now - 3600
|
|
path = str(tmp_path / "test_trend_selfblock.sqlite3")
|
|
await _create_f2b_db(
|
|
path,
|
|
[
|
|
{"jail": "blocklist-import", "ip": "10.0.0.1", "timeofban": one_hour_ago},
|
|
{"jail": "sshd", "ip": "10.0.0.2", "timeofban": one_hour_ago},
|
|
{"jail": "nginx", "ip": "10.0.0.3", "timeofban": one_hour_ago},
|
|
],
|
|
)
|
|
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=path),
|
|
):
|
|
result = await ban_service.ban_trend(
|
|
"/fake/sock", "24h", origin="selfblock"
|
|
)
|
|
|
|
assert sum(b.count for b in result.buckets) == 2
|
|
|
|
async def test_each_bucket_has_iso_timestamp(self, empty_f2b_db_path: str) -> None:
|
|
"""Every bucket timestamp is a valid ISO 8601 string."""
|
|
from datetime import datetime
|
|
|
|
with patch(
|
|
"app.services.ban_service._get_fail2ban_db_path",
|
|
new=AsyncMock(return_value=empty_f2b_db_path),
|
|
):
|
|
result = await ban_service.ban_trend("/fake/sock", "24h")
|
|
|
|
for bucket in result.buckets:
|
|
# datetime.fromisoformat raises ValueError on invalid input.
|
|
parsed = datetime.fromisoformat(bucket.timestamp)
|
|
assert parsed.tzinfo is not None # Must be timezone-aware (UTC)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# bans_by_jail
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBansByJail:
    """Verify ban_service.bans_by_jail() behaviour."""

    async def test_returns_jails_sorted_descending(self, tmp_path: Path) -> None:
        """Jails are returned ordered by count descending."""
        # Two sshd bans vs one nginx ban, all within the 24h window.
        one_hour_ago = int(time.time()) - 3600
        path = str(tmp_path / "test_by_jail.sqlite3")
        await _create_f2b_db(
            path,
            [
                {"jail": "sshd", "ip": "1.1.1.1", "timeofban": one_hour_ago},
                {"jail": "sshd", "ip": "1.1.1.2", "timeofban": one_hour_ago},
                {"jail": "nginx", "ip": "2.2.2.2", "timeofban": one_hour_ago},
            ],
        )

        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=path),
        ):
            result = await ban_service.bans_by_jail("/fake/sock", "24h")

        # sshd (2 bans) must sort before nginx (1 ban).
        assert result.jails[0].jail == "sshd"
        assert result.jails[0].count == 2
        assert result.jails[1].jail == "nginx"
        assert result.jails[1].count == 1

    async def test_total_equals_sum_of_counts(self, tmp_path: Path) -> None:
        """``total`` equals the sum of all per-jail counts."""
        one_hour_ago = int(time.time()) - 3600
        path = str(tmp_path / "test_by_jail_total.sqlite3")
        await _create_f2b_db(
            path,
            [
                {"jail": "sshd", "ip": "1.1.1.1", "timeofban": one_hour_ago},
                {"jail": "nginx", "ip": "2.2.2.2", "timeofban": one_hour_ago},
                {"jail": "nginx", "ip": "3.3.3.3", "timeofban": one_hour_ago},
            ],
        )

        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=path),
        ):
            result = await ban_service.bans_by_jail("/fake/sock", "24h")

        assert result.total == sum(j.count for j in result.jails)
        assert result.total == 3

    async def test_empty_db_returns_empty_list(self, empty_f2b_db_path: str) -> None:
        """An empty database returns an empty jails list with total zero."""
        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=empty_f2b_db_path),
        ):
            result = await ban_service.bans_by_jail("/fake/sock", "24h")

        assert result.jails == []
        assert result.total == 0

    async def test_excludes_bans_outside_time_window(self, f2b_db_path: str) -> None:
        """Bans older than the time window are not counted."""
        # f2b_db_path has one ban from _TWO_DAYS_AGO, which is outside "24h".
        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=f2b_db_path),
        ):
            result = await ban_service.bans_by_jail("/fake/sock", "24h")

        # Only 2 bans within 24h (both from _ONE_HOUR_AGO).
        assert result.total == 2

    async def test_origin_filter_blocklist(self, mixed_origin_db_path: str) -> None:
        """``origin='blocklist'`` returns only the blocklist-import jail."""
        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=mixed_origin_db_path),
        ):
            result = await ban_service.bans_by_jail(
                "/fake/sock", "24h", origin="blocklist"
            )

        assert len(result.jails) == 1
        assert result.jails[0].jail == "blocklist-import"
        assert result.total == 1

    async def test_origin_filter_selfblock(self, mixed_origin_db_path: str) -> None:
        """``origin='selfblock'`` excludes the blocklist-import jail."""
        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=mixed_origin_db_path),
        ):
            result = await ban_service.bans_by_jail(
                "/fake/sock", "24h", origin="selfblock"
            )

        jail_names = {j.jail for j in result.jails}
        assert "blocklist-import" not in jail_names
        assert result.total == 2

    async def test_no_origin_filter_returns_all_jails(
        self, mixed_origin_db_path: str
    ) -> None:
        """``origin=None`` returns bans from all jails."""
        with patch(
            "app.services.ban_service._get_fail2ban_db_path",
            new=AsyncMock(return_value=mixed_origin_db_path),
        ):
            result = await ban_service.bans_by_jail(
                "/fake/sock", "24h", origin=None
            )

        assert result.total == 3
        assert len(result.jails) == 3

    async def test_diagnostic_warning_when_zero_results_despite_data(
        self, tmp_path: Path
    ) -> None:
        """A warning is logged when the time-range filter excludes all existing rows."""
        # Insert rows with timeofban far in the past (outside any range window).
        far_past = int(time.time()) - 400 * 24 * 3600  # ~400 days ago
        path = str(tmp_path / "test_diag.sqlite3")
        await _create_f2b_db(
            path,
            [
                {"jail": "sshd", "ip": "1.1.1.1", "timeofban": far_past},
            ],
        )

        with (
            patch(
                "app.services.ban_service._get_fail2ban_db_path",
                new=AsyncMock(return_value=path),
            ),
            patch("app.services.ban_service.log") as mock_log,
        ):
            result = await ban_service.bans_by_jail("/fake/sock", "24h")

        assert result.total == 0
        assert result.jails == []
        # The diagnostic warning must have been emitted.
        warning_calls = [
            c
            for c in mock_log.warning.call_args_list
            if c[0][0] == "ban_service_bans_by_jail_empty_despite_data"
        ]
        assert len(warning_calls) == 1