feat: implement dashboard ban overview (Stage 5)
- Add ban_service reading fail2ban SQLite DB via read-only aiosqlite - Add geo_service resolving IPs via ip-api.com with 10k in-memory cache - Add GET /api/dashboard/bans and GET /api/dashboard/accesses endpoints - Add TimeRange, DashboardBanItem, DashboardBanListResponse, AccessListItem, AccessListResponse models in models/ban.py - Build BanTable component (Fluent UI DataGrid) with bans/accesses modes, pagination, loading/error/empty states, and ban-count badges - Build useBans hook managing time-range and pagination state - Update DashboardPage: status bar + time-range toolbar + tab switcher - Add 37 new backend tests (ban service, geo service, dashboard router) - All 141 tests pass; ruff/mypy --strict/tsc --noEmit clean
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
"""Tests for the dashboard router (GET /api/dashboard/status)."""
|
||||
"""Tests for the dashboard router (GET /api/dashboard/status, GET /api/dashboard/bans, GET /api/dashboard/accesses)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
@@ -11,6 +12,12 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.ban import (
|
||||
AccessListItem,
|
||||
AccessListResponse,
|
||||
DashboardBanItem,
|
||||
DashboardBanListResponse,
|
||||
)
|
||||
from app.models.server import ServerStatus
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -56,6 +63,8 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
total_bans=10,
|
||||
total_failures=5,
|
||||
)
|
||||
# Provide a stub HTTP session so ban/access endpoints can access app.state.http_session.
|
||||
app.state.http_session = MagicMock()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
@@ -94,6 +103,7 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
|
||||
app.state.db = db
|
||||
|
||||
app.state.server_status = ServerStatus(online=False)
|
||||
app.state.http_session = MagicMock()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
@@ -192,3 +202,190 @@ class TestDashboardStatus:
|
||||
assert response.status_code == 200
|
||||
status = response.json()["status"]
|
||||
assert status["online"] is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dashboard bans endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
|
||||
"""Build a mock DashboardBanListResponse with *n* items."""
|
||||
items = [
|
||||
DashboardBanItem(
|
||||
ip=f"1.2.3.{i}",
|
||||
jail="sshd",
|
||||
banned_at="2026-03-01T10:00:00+00:00",
|
||||
service=None,
|
||||
country_code="DE",
|
||||
country_name="Germany",
|
||||
asn="AS3320",
|
||||
org="Telekom",
|
||||
ban_count=1,
|
||||
)
|
||||
for i in range(n)
|
||||
]
|
||||
return DashboardBanListResponse(items=items, total=n, page=1, page_size=100)
|
||||
|
||||
|
||||
def _make_access_list_response(n: int = 2) -> AccessListResponse:
|
||||
"""Build a mock AccessListResponse with *n* items."""
|
||||
items = [
|
||||
AccessListItem(
|
||||
ip=f"5.6.7.{i}",
|
||||
jail="nginx",
|
||||
timestamp="2026-03-01T10:00:00+00:00",
|
||||
line=f"GET /admin HTTP/1.1 attempt {i}",
|
||||
country_code="US",
|
||||
country_name="United States",
|
||||
asn="AS15169",
|
||||
org="Google LLC",
|
||||
)
|
||||
for i in range(n)
|
||||
]
|
||||
return AccessListResponse(items=items, total=n, page=1, page_size=100)
|
||||
|
||||
|
||||
class TestDashboardBans:
|
||||
"""GET /api/dashboard/bans."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=_make_ban_list_response()),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/dashboard/bans")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_response_contains_items_and_total(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Response body contains ``items`` list and ``total`` count."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=_make_ban_list_response(3)),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
body = response.json()
|
||||
assert "items" in body
|
||||
assert "total" in body
|
||||
assert body["total"] == 3
|
||||
assert len(body["items"]) == 3
|
||||
|
||||
async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None:
|
||||
"""If no ``range`` param is provided the default ``24h`` preset is used."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
called_range = mock_list.call_args[0][1]
|
||||
assert called_range == "24h"
|
||||
|
||||
async def test_accepts_time_range_param(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""The ``range`` query parameter is forwarded to ban_service."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
await dashboard_client.get("/api/dashboard/bans?range=7d")
|
||||
|
||||
called_range = mock_list.call_args[0][1]
|
||||
assert called_range == "7d"
|
||||
|
||||
async def test_empty_ban_list_returns_zero_total(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Returns ``total=0`` and empty ``items`` when no bans are in range."""
|
||||
empty = DashboardBanListResponse(items=[], total=0, page=1, page_size=100)
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=empty),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
body = response.json()
|
||||
assert body["total"] == 0
|
||||
assert body["items"] == []
|
||||
|
||||
async def test_item_shape_is_correct(self, dashboard_client: AsyncClient) -> None:
|
||||
"""Each item in ``items`` has the expected fields."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=_make_ban_list_response(1)),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
item = response.json()["items"][0]
|
||||
assert "ip" in item
|
||||
assert "jail" in item
|
||||
assert "banned_at" in item
|
||||
assert "ban_count" in item
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dashboard accesses endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDashboardAccesses:
|
||||
"""GET /api/dashboard/accesses."""
|
||||
|
||||
async def test_returns_200_when_authenticated(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Authenticated request returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_accesses",
|
||||
new=AsyncMock(return_value=_make_access_list_response()),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/accesses")
|
||||
assert response.status_code == 200
|
||||
|
||||
async def test_returns_401_when_unauthenticated(
|
||||
self, client: AsyncClient
|
||||
) -> None:
|
||||
"""Unauthenticated request returns HTTP 401."""
|
||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
response = await client.get("/api/dashboard/accesses")
|
||||
assert response.status_code == 401
|
||||
|
||||
async def test_response_contains_access_items(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Response body contains ``items`` with ``line`` fields."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_accesses",
|
||||
new=AsyncMock(return_value=_make_access_list_response(2)),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/accesses")
|
||||
|
||||
body = response.json()
|
||||
assert body["total"] == 2
|
||||
assert len(body["items"]) == 2
|
||||
assert "line" in body["items"][0]
|
||||
|
||||
async def test_default_range_is_24h(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""If no ``range`` param is provided the default ``24h`` preset is used."""
|
||||
mock_list = AsyncMock(return_value=_make_access_list_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_accesses", new=mock_list
|
||||
):
|
||||
await dashboard_client.get("/api/dashboard/accesses")
|
||||
|
||||
called_range = mock_list.call_args[0][1]
|
||||
assert called_range == "24h"
|
||||
|
||||
|
||||
359
backend/tests/test_services/test_ban_service.py
Normal file
359
backend/tests/test_services/test_ban_service.py
Normal file
@@ -0,0 +1,359 @@
|
||||
"""Tests for ban_service.list_bans() and ban_service.list_accesses()."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.services import ban_service
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_NOW: int = int(time.time())
|
||||
_ONE_HOUR_AGO: int = _NOW - 3600
|
||||
_TWO_DAYS_AGO: int = _NOW - 2 * 24 * 3600
|
||||
|
||||
|
||||
async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
|
||||
"""Create a minimal fail2ban SQLite database with the given ban rows.
|
||||
|
||||
Args:
|
||||
path: Filesystem path for the new SQLite file.
|
||||
rows: Sequence of dicts with keys ``jail``, ``ip``, ``timeofban``,
|
||||
``bantime``, ``bancount``, and optionally ``data``.
|
||||
"""
|
||||
async with aiosqlite.connect(path) as db:
|
||||
await db.execute(
|
||||
"CREATE TABLE jails ("
|
||||
"name TEXT NOT NULL UNIQUE, "
|
||||
"enabled INTEGER NOT NULL DEFAULT 1"
|
||||
")"
|
||||
)
|
||||
await db.execute(
|
||||
"CREATE TABLE bans ("
|
||||
"jail TEXT NOT NULL, "
|
||||
"ip TEXT, "
|
||||
"timeofban INTEGER NOT NULL, "
|
||||
"bantime INTEGER NOT NULL, "
|
||||
"bancount INTEGER NOT NULL DEFAULT 1, "
|
||||
"data JSON"
|
||||
")"
|
||||
)
|
||||
for row in rows:
|
||||
await db.execute(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
|
||||
"VALUES (?, ?, ?, ?, ?, ?)",
|
||||
(
|
||||
row["jail"],
|
||||
row["ip"],
|
||||
row["timeofban"],
|
||||
row.get("bantime", 3600),
|
||||
row.get("bancount", 1),
|
||||
json.dumps(row["data"]) if "data" in row else None,
|
||||
),
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
"""Return the path to a test fail2ban SQLite database with several bans."""
|
||||
path = str(tmp_path / "fail2ban_test.sqlite3")
|
||||
await _create_f2b_db(
|
||||
path,
|
||||
[
|
||||
{
|
||||
"jail": "sshd",
|
||||
"ip": "1.2.3.4",
|
||||
"timeofban": _ONE_HOUR_AGO,
|
||||
"bantime": 3600,
|
||||
"bancount": 2,
|
||||
"data": {
|
||||
"matches": ["Nov 10 10:00 sshd[123]: Failed password for root"],
|
||||
"failures": 5,
|
||||
},
|
||||
},
|
||||
{
|
||||
"jail": "nginx",
|
||||
"ip": "5.6.7.8",
|
||||
"timeofban": _ONE_HOUR_AGO,
|
||||
"bantime": 7200,
|
||||
"bancount": 1,
|
||||
"data": {"matches": ["GET /admin HTTP/1.1"], "failures": 3},
|
||||
},
|
||||
{
|
||||
"jail": "sshd",
|
||||
"ip": "9.10.11.12",
|
||||
"timeofban": _TWO_DAYS_AGO,
|
||||
"bantime": 3600,
|
||||
"bancount": 1,
|
||||
"data": {"failures": 6}, # no matches
|
||||
},
|
||||
],
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
"""Return the path to a fail2ban SQLite database with no ban records."""
|
||||
path = str(tmp_path / "fail2ban_empty.sqlite3")
|
||||
await _create_f2b_db(path, [])
|
||||
return path
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_bans — happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListBansHappyPath:
|
||||
"""Verify ban_service.list_bans() under normal conditions."""
|
||||
|
||||
async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
|
||||
"""Only bans within the selected range are returned."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
# Two bans within last 24 h; one is 2 days old and excluded.
|
||||
assert result.total == 2
|
||||
assert len(result.items) == 2
|
||||
|
||||
async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
|
||||
"""Items are ordered by ``banned_at`` descending (newest first)."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
timestamps = [item.banned_at for item in result.items]
|
||||
assert timestamps == sorted(timestamps, reverse=True)
|
||||
|
||||
async def test_ban_fields_present(self, f2b_db_path: str) -> None:
|
||||
"""Each item contains ip, jail, banned_at, ban_count."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
for item in result.items:
|
||||
assert item.ip
|
||||
assert item.jail
|
||||
assert item.banned_at
|
||||
assert item.ban_count >= 1
|
||||
|
||||
async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
|
||||
"""``service`` field is the first element of ``data.matches``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
sshd_item = next(i for i in result.items if i.jail == "sshd")
|
||||
assert sshd_item.service is not None
|
||||
assert "Failed password" in sshd_item.service
|
||||
|
||||
async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
|
||||
"""``service`` is ``None`` when the ban has no stored matches."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
# Use 7d to include the older ban with no matches.
|
||||
result = await ban_service.list_bans("/fake/sock", "7d")
|
||||
|
||||
no_match = next(i for i in result.items if i.ip == "9.10.11.12")
|
||||
assert no_match.service is None
|
||||
|
||||
async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
|
||||
"""When no bans exist the result has total=0 and no items."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
assert result.total == 0
|
||||
assert result.items == []
|
||||
|
||||
async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
|
||||
"""The ``365d`` range includes bans that are 2 days old."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "365d")
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_bans — geo enrichment
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListBansGeoEnrichment:
|
||||
"""Verify geo enrichment integration in ban_service.list_bans()."""
|
||||
|
||||
async def test_geo_data_applied_when_enricher_provided(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""Geo fields are populated when an enricher returns data."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
async def fake_enricher(ip: str) -> GeoInfo:
|
||||
return GeoInfo(
|
||||
country_code="DE",
|
||||
country_name="Germany",
|
||||
asn="AS3320",
|
||||
org="Deutsche Telekom",
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", geo_enricher=fake_enricher
|
||||
)
|
||||
|
||||
for item in result.items:
|
||||
assert item.country_code == "DE"
|
||||
assert item.country_name == "Germany"
|
||||
assert item.asn == "AS3320"
|
||||
|
||||
async def test_geo_failure_does_not_break_results(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""A geo enricher that raises still returns ban items (geo fields null)."""
|
||||
|
||||
async def failing_enricher(ip: str) -> None:
|
||||
raise RuntimeError("geo service down")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", geo_enricher=failing_enricher
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
for item in result.items:
|
||||
assert item.country_code is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_bans — pagination
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListBansPagination:
|
||||
"""Verify pagination parameters in list_bans()."""
|
||||
|
||||
async def test_page_size_respected(self, f2b_db_path: str) -> None:
|
||||
"""``page_size=1`` returns at most one item."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
|
||||
assert len(result.items) == 1
|
||||
assert result.page_size == 1
|
||||
|
||||
async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
|
||||
"""The second page returns items not on the first page."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
|
||||
page2 = await ban_service.list_bans("/fake/sock", "7d", page=2, page_size=1)
|
||||
|
||||
# Different IPs should appear on different pages.
|
||||
assert page1.items[0].ip != page2.items[0].ip
|
||||
|
||||
async def test_total_reflects_full_count_not_page_count(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""``total`` reports all matching records regardless of pagination."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
|
||||
assert result.total == 3 # All three bans are within 7d.
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_accesses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListAccesses:
|
||||
"""Verify ban_service.list_accesses()."""
|
||||
|
||||
async def test_expands_matches_into_rows(self, f2b_db_path: str) -> None:
|
||||
"""Each element in ``data.matches`` becomes a separate row."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_accesses("/fake/sock", "24h")
|
||||
|
||||
# Two bans in last 24h: sshd (1 match) + nginx (1 match) = 2 rows.
|
||||
assert result.total == 2
|
||||
assert len(result.items) == 2
|
||||
|
||||
async def test_access_item_has_line_field(self, f2b_db_path: str) -> None:
|
||||
"""Each access item contains the raw matched log line."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_accesses("/fake/sock", "24h")
|
||||
|
||||
for item in result.items:
|
||||
assert item.line
|
||||
|
||||
async def test_ban_with_no_matches_produces_no_access_rows(
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""Bans with empty matches list do not contribute rows."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_accesses("/fake/sock", "7d")
|
||||
|
||||
# Third ban (9.10.11.12) has no matches, so only 2 rows total.
|
||||
assert result.total == 2
|
||||
|
||||
async def test_empty_db_returns_zero_accesses(
|
||||
self, empty_f2b_db_path: str
|
||||
) -> None:
|
||||
"""Returns empty result when no bans exist."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_accesses("/fake/sock", "24h")
|
||||
|
||||
assert result.total == 0
|
||||
assert result.items == []
|
||||
212
backend/tests/test_services/test_geo_service.py
Normal file
212
backend/tests/test_services/test_geo_service.py
Normal file
@@ -0,0 +1,212 @@
|
||||
"""Tests for geo_service.lookup()."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services import geo_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_session(response_json: dict[str, object], status: int = 200) -> MagicMock:
|
||||
"""Build a mock aiohttp.ClientSession that returns *response_json*.
|
||||
|
||||
Args:
|
||||
response_json: The dict that the mock response's ``json()`` returns.
|
||||
status: HTTP status code for the mock response.
|
||||
|
||||
Returns:
|
||||
A :class:`MagicMock` that behaves like an
|
||||
``aiohttp.ClientSession`` in an ``async with`` context.
|
||||
"""
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = status
|
||||
mock_resp.json = AsyncMock(return_value=response_json)
|
||||
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = MagicMock()
|
||||
session.get = MagicMock(return_value=mock_ctx)
|
||||
return session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_geo_cache() -> None: # type: ignore[misc]
|
||||
"""Flush the module-level geo cache before every test."""
|
||||
geo_service.clear_cache()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLookupSuccess:
|
||||
"""geo_service.lookup() under normal conditions."""
|
||||
|
||||
async def test_returns_country_code(self) -> None:
|
||||
"""country_code is populated from the ``countryCode`` field."""
|
||||
session = _make_session(
|
||||
{
|
||||
"status": "success",
|
||||
"countryCode": "DE",
|
||||
"country": "Germany",
|
||||
"as": "AS3320 Deutsche Telekom AG",
|
||||
"org": "AS3320 Deutsche Telekom AG",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
assert result.country_code == "DE"
|
||||
|
||||
async def test_returns_country_name(self) -> None:
|
||||
"""country_name is populated from the ``country`` field."""
|
||||
session = _make_session(
|
||||
{
|
||||
"status": "success",
|
||||
"countryCode": "US",
|
||||
"country": "United States",
|
||||
"as": "AS15169 Google LLC",
|
||||
"org": "Google LLC",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
assert result.country_name == "United States"
|
||||
|
||||
async def test_asn_extracted_without_org_suffix(self) -> None:
|
||||
"""The ASN field contains only the ``AS<N>`` prefix, not the full string."""
|
||||
session = _make_session(
|
||||
{
|
||||
"status": "success",
|
||||
"countryCode": "DE",
|
||||
"country": "Germany",
|
||||
"as": "AS3320 Deutsche Telekom AG",
|
||||
"org": "Deutsche Telekom",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
assert result.asn == "AS3320"
|
||||
|
||||
async def test_org_populated(self) -> None:
|
||||
"""org field is populated from the ``org`` key."""
|
||||
session = _make_session(
|
||||
{
|
||||
"status": "success",
|
||||
"countryCode": "US",
|
||||
"country": "United States",
|
||||
"as": "AS15169 Google LLC",
|
||||
"org": "Google LLC",
|
||||
}
|
||||
)
|
||||
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
assert result.org == "Google LLC"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Cache behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLookupCaching:
|
||||
"""Verify that results are cached and the cache can be cleared."""
|
||||
|
||||
async def test_second_call_uses_cache(self) -> None:
|
||||
"""Subsequent lookups for the same IP do not make additional HTTP requests."""
|
||||
session = _make_session(
|
||||
{
|
||||
"status": "success",
|
||||
"countryCode": "DE",
|
||||
"country": "Germany",
|
||||
"as": "AS3320 Deutsche Telekom AG",
|
||||
"org": "Deutsche Telekom",
|
||||
}
|
||||
)
|
||||
|
||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
|
||||
# The session.get() should only have been called once.
|
||||
assert session.get.call_count == 1
|
||||
|
||||
async def test_clear_cache_forces_refetch(self) -> None:
|
||||
"""After clearing the cache a new HTTP request is made."""
|
||||
session = _make_session(
|
||||
{
|
||||
"status": "success",
|
||||
"countryCode": "DE",
|
||||
"country": "Germany",
|
||||
"as": "AS3320",
|
||||
"org": "Telekom",
|
||||
}
|
||||
)
|
||||
|
||||
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
||||
geo_service.clear_cache()
|
||||
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
|
||||
|
||||
assert session.get.call_count == 2
|
||||
|
||||
async def test_negative_result_cached(self) -> None:
|
||||
"""A failed lookup result (status != success) is also cached."""
|
||||
session = _make_session(
|
||||
{"status": "fail", "message": "reserved range"}
|
||||
)
|
||||
|
||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
|
||||
|
||||
assert session.get.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Failure modes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLookupFailures:
|
||||
"""geo_service.lookup() when things go wrong."""
|
||||
|
||||
async def test_non_200_response_returns_none(self) -> None:
|
||||
"""A 429 or 500 status returns ``None`` without caching."""
|
||||
session = _make_session({}, status=429)
|
||||
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
|
||||
assert result is None
|
||||
|
||||
async def test_network_error_returns_none(self) -> None:
|
||||
"""A network exception returns ``None``."""
|
||||
session = MagicMock()
|
||||
session.get = MagicMock(side_effect=OSError("connection refused"))
|
||||
|
||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
assert result is None
|
||||
|
||||
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
|
||||
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is cached."""
|
||||
session = _make_session({"status": "fail", "message": "private range"})
|
||||
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, GeoInfo)
|
||||
assert result.country_code is None
|
||||
assert result.country_name is None
|
||||
Reference in New Issue
Block a user