feat: implement dashboard ban overview (Stage 5)

- Add ban_service reading fail2ban SQLite DB via read-only aiosqlite
- Add geo_service resolving IPs via ip-api.com with 10k in-memory cache
- Add GET /api/dashboard/bans and GET /api/dashboard/accesses endpoints
- Add TimeRange, DashboardBanItem, DashboardBanListResponse, AccessListItem,
  AccessListResponse models in models/ban.py
- Build BanTable component (Fluent UI DataGrid) with bans/accesses modes,
  pagination, loading/error/empty states, and ban-count badges
- Build useBans hook managing time-range and pagination state
- Update DashboardPage: status bar + time-range toolbar + tab switcher
- Add 37 new backend tests (ban service, geo service, dashboard router)
- All 141 tests pass; ruff/mypy --strict/tsc --noEmit clean
This commit is contained in:
2026-03-01 12:57:19 +01:00
parent 94661d7877
commit 9ac7f8d22d
15 changed files with 2346 additions and 29 deletions

View File

@@ -3,8 +3,25 @@
Request, response, and domain models used by the ban router and service.
"""
from typing import Literal
from pydantic import BaseModel, ConfigDict, Field
# ---------------------------------------------------------------------------
# Time-range selector
# ---------------------------------------------------------------------------
#: The four supported time-range presets for the dashboard views.
TimeRange = Literal["24h", "7d", "30d", "365d"]
#: Number of seconds represented by each preset.
TIME_RANGE_SECONDS: dict[str, int] = {
"24h": 24 * 3600,
"7d": 7 * 24 * 3600,
"30d": 30 * 24 * 3600,
"365d": 365 * 24 * 3600,
}
class BanRequest(BaseModel):
"""Payload for ``POST /api/bans`` (ban an IP)."""
@@ -89,3 +106,87 @@ class ActiveBanListResponse(BaseModel):
bans: list[ActiveBan] = Field(default_factory=list)
total: int = Field(..., ge=0)
# ---------------------------------------------------------------------------
# Dashboard ban-list / access-list view models
# ---------------------------------------------------------------------------
class DashboardBanItem(BaseModel):
"""A single row in the dashboard ban-list table.
Populated from the fail2ban database and enriched with geo data.
"""
model_config = ConfigDict(strict=True)
ip: str = Field(..., description="Banned IP address.")
jail: str = Field(..., description="Jail that issued the ban.")
banned_at: str = Field(..., description="ISO 8601 UTC timestamp of the ban.")
service: str | None = Field(
default=None,
description="First matched log line — used as context for the ban.",
)
country_code: str | None = Field(
default=None,
description="ISO 3166-1 alpha-2 country code, or ``null`` if unknown.",
)
country_name: str | None = Field(
default=None,
description="Human-readable country name, or ``null`` if unknown.",
)
asn: str | None = Field(
default=None,
description="Autonomous System Number string (e.g. ``'AS3320'``).",
)
org: str | None = Field(
default=None,
description="Organisation name associated with the IP.",
)
ban_count: int = Field(..., ge=1, description="How many times this IP was banned.")
class DashboardBanListResponse(BaseModel):
"""Paginated dashboard ban-list response."""
model_config = ConfigDict(strict=True)
items: list[DashboardBanItem] = Field(default_factory=list)
total: int = Field(..., ge=0, description="Total bans in the selected time window.")
page: int = Field(..., ge=1)
page_size: int = Field(..., ge=1)
class AccessListItem(BaseModel):
"""A single row in the dashboard access-list table.
Each row represents one matched log line (failure) that contributed to a
ban — essentially the individual access events that led to bans within the
selected time window.
"""
model_config = ConfigDict(strict=True)
ip: str = Field(..., description="IP address of the access event.")
jail: str = Field(..., description="Jail that recorded the access.")
timestamp: str = Field(
...,
description="ISO 8601 UTC timestamp of the ban that captured this access.",
)
line: str = Field(..., description="Raw matched log line.")
country_code: str | None = Field(default=None)
country_name: str | None = Field(default=None)
asn: str | None = Field(default=None)
org: str | None = Field(default=None)
class AccessListResponse(BaseModel):
"""Paginated dashboard access-list response."""
model_config = ConfigDict(strict=True)
items: list[AccessListItem] = Field(default_factory=list)
total: int = Field(..., ge=0)
page: int = Field(..., ge=1)
page_size: int = Field(..., ge=1)

View File

@@ -3,17 +3,38 @@
Provides the ``GET /api/dashboard/status`` endpoint that returns the cached
fail2ban server health snapshot. The snapshot is maintained by the
background health-check task and refreshed every 30 seconds.
Also provides ``GET /api/dashboard/bans`` and ``GET /api/dashboard/accesses``
for the dashboard ban-list and access-list tables.
"""
from __future__ import annotations
from fastapi import APIRouter, Request
from typing import TYPE_CHECKING
if TYPE_CHECKING:
import aiohttp
from fastapi import APIRouter, Query, Request
from app.dependencies import AuthDep
from app.models.ban import (
AccessListResponse,
DashboardBanListResponse,
TimeRange,
)
from app.models.server import ServerStatus, ServerStatusResponse
from app.services import ban_service, geo_service
router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
# ---------------------------------------------------------------------------
# Default pagination constants
# ---------------------------------------------------------------------------
_DEFAULT_PAGE_SIZE: int = 100
_DEFAULT_RANGE: TimeRange = "24h"
@router.get(
"/status",
@@ -44,3 +65,94 @@ async def get_server_status(
ServerStatus(online=False),
)
return ServerStatusResponse(status=cached)
@router.get(
"/bans",
response_model=DashboardBanListResponse,
summary="Return a paginated list of recent bans",
)
async def get_dashboard_bans(
request: Request,
_auth: AuthDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
page: int = Query(default=1, ge=1, description="1-based page number."),
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."),
) -> DashboardBanListResponse:
"""Return a paginated list of bans within the selected time window.
Reads from the fail2ban database and enriches each entry with
geolocation data (country, ASN, organisation) from the ip-api.com
free API. Results are sorted newest-first.
Args:
request: The incoming request (used to access ``app.state``).
_auth: Validated session dependency.
range: Time-range preset — ``"24h"``, ``"7d"``, ``"30d"``, or
``"365d"``.
page: 1-based page number.
page_size: Maximum items per page (1500).
Returns:
:class:`~app.models.ban.DashboardBanListResponse` with paginated
ban items and the total count for the selected window.
"""
socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session
async def _enricher(ip: str) -> geo_service.GeoInfo | None:
return await geo_service.lookup(ip, http_session)
return await ban_service.list_bans(
socket_path,
range,
page=page,
page_size=page_size,
geo_enricher=_enricher,
)
@router.get(
"/accesses",
response_model=AccessListResponse,
summary="Return a paginated list of individual access events",
)
async def get_dashboard_accesses(
request: Request,
_auth: AuthDep,
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
page: int = Query(default=1, ge=1, description="1-based page number."),
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."),
) -> AccessListResponse:
"""Return a paginated list of individual access events (matched log lines).
Expands the ``data.matches`` JSON stored inside each ban record so that
every matched log line is returned as a separate row. Useful for
the "Access List" tab which shows all recorded access attempts — not
just the aggregate bans.
Args:
request: The incoming request.
_auth: Validated session dependency.
range: Time-range preset.
page: 1-based page number.
page_size: Maximum items per page (1500).
Returns:
:class:`~app.models.ban.AccessListResponse` with individual access
items expanded from ``data.matches``.
"""
socket_path: str = request.app.state.settings.fail2ban_socket
http_session: aiohttp.ClientSession = request.app.state.http_session
async def _enricher(ip: str) -> geo_service.GeoInfo | None:
return await geo_service.lookup(ip, http_session)
return await ban_service.list_accesses(
socket_path,
range,
page=page,
page_size=page_size,
geo_enricher=_enricher,
)

View File

@@ -0,0 +1,325 @@
"""Ban service.
Queries the fail2ban SQLite database for ban history. The fail2ban database
path is obtained at runtime by sending ``get dbfile`` to the fail2ban daemon
via the Unix domain socket.
All database I/O is performed through aiosqlite opened in **read-only** mode
so BanGUI never modifies or locks the fail2ban database.
"""
from __future__ import annotations
import json
from datetime import UTC, datetime
from typing import Any
import aiosqlite
import structlog
from app.models.ban import (
TIME_RANGE_SECONDS,
AccessListItem,
AccessListResponse,
DashboardBanItem,
DashboardBanListResponse,
TimeRange,
)
from app.utils.fail2ban_client import Fail2BanClient
log: structlog.stdlib.BoundLogger = structlog.get_logger()
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
_DEFAULT_PAGE_SIZE: int = 100
_MAX_PAGE_SIZE: int = 500
_SOCKET_TIMEOUT: float = 5.0
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _since_unix(range_: TimeRange) -> int:
"""Return the Unix timestamp representing the start of the time window.
Args:
range_: One of the supported time-range presets.
Returns:
Unix timestamp (seconds since epoch) equal to *now range_*.
"""
seconds: int = TIME_RANGE_SECONDS[range_]
return int(datetime.now(tz=UTC).timestamp()) - seconds
def _ts_to_iso(unix_ts: int) -> str:
"""Convert a Unix timestamp to an ISO 8601 UTC string.
Args:
unix_ts: Seconds since the Unix epoch.
Returns:
ISO 8601 UTC timestamp, e.g. ``"2026-03-01T12:00:00+00:00"``.
"""
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
async def _get_fail2ban_db_path(socket_path: str) -> str:
"""Query fail2ban for the path to its SQLite database.
Sends the ``get dbfile`` command via the fail2ban socket and returns
the value of the ``dbfile`` setting.
Args:
socket_path: Path to the fail2ban Unix domain socket.
Returns:
Absolute path to the fail2ban SQLite database file.
Raises:
RuntimeError: If fail2ban reports that no database is configured
or if the socket response is unexpected.
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
cannot be reached.
"""
async with Fail2BanClient(socket_path, timeout=_SOCKET_TIMEOUT) as client:
response = await client.send(["get", "dbfile"])
try:
code, data = response
except (TypeError, ValueError) as exc:
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
if code != 0:
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
if data is None:
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
return str(data)
def _parse_data_json(raw: Any) -> tuple[list[str], int]:
"""Extract matches and failure count from the ``bans.data`` column.
The ``data`` column stores a JSON blob with optional keys:
* ``matches`` — list of raw matched log lines.
* ``failures`` — total failure count that triggered the ban.
Args:
raw: The raw ``data`` column value (string, dict, or ``None``).
Returns:
A ``(matches, failures)`` tuple. Both default to empty/zero when
parsing fails or the column is absent.
"""
if raw is None:
return [], 0
obj: dict[str, Any] = {}
if isinstance(raw, str):
try:
obj = json.loads(raw)
except json.JSONDecodeError:
return [], 0
elif isinstance(raw, dict):
obj = raw
matches: list[str] = [str(m) for m in (obj.get("matches") or [])]
failures: int = int(obj.get("failures", 0))
return matches, failures
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def list_bans(
socket_path: str,
range_: TimeRange,
*,
page: int = 1,
page_size: int = _DEFAULT_PAGE_SIZE,
geo_enricher: Any | None = None,
) -> DashboardBanListResponse:
"""Return a paginated list of bans within the selected time window.
Queries the fail2ban database ``bans`` table for records whose
``timeofban`` falls within the specified *range_*. Results are ordered
newest-first.
Args:
socket_path: Path to the fail2ban Unix domain socket.
range_: Time-range preset (``"24h"``, ``"7d"``, ``"30d"``, or
``"365d"``).
page: 1-based page number (default: ``1``).
page_size: Maximum items per page, capped at ``_MAX_PAGE_SIZE``
(default: ``100``).
geo_enricher: Optional async callable ``(ip: str) -> GeoInfo | None``.
When supplied every result is enriched with country and ASN data.
Returns:
:class:`~app.models.ban.DashboardBanListResponse` containing the
paginated items and total count.
"""
since: int = _since_unix(range_)
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
offset: int = (page - 1) * effective_page_size
db_path: str = await _get_fail2ban_db_path(socket_path)
log.info("ban_service_list_bans", db_path=db_path, since=since, range=range_)
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
f2b_db.row_factory = aiosqlite.Row
async with f2b_db.execute(
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?",
(since,),
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
async with f2b_db.execute(
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE timeofban >= ? "
"ORDER BY timeofban DESC "
"LIMIT ? OFFSET ?",
(since, effective_page_size, offset),
) as cur:
rows = await cur.fetchall()
items: list[DashboardBanItem] = []
for row in rows:
jail: str = str(row["jail"])
ip: str = str(row["ip"])
banned_at: str = _ts_to_iso(int(row["timeofban"]))
ban_count: int = int(row["bancount"])
matches, _ = _parse_data_json(row["data"])
service: str | None = matches[0] if matches else None
country_code: str | None = None
country_name: str | None = None
asn: str | None = None
org: str | None = None
if geo_enricher is not None:
try:
geo = await geo_enricher(ip)
if geo is not None:
country_code = geo.country_code
country_name = geo.country_name
asn = geo.asn
org = geo.org
except Exception: # noqa: BLE001
log.warning("ban_service_geo_lookup_failed", ip=ip)
items.append(
DashboardBanItem(
ip=ip,
jail=jail,
banned_at=banned_at,
service=service,
country_code=country_code,
country_name=country_name,
asn=asn,
org=org,
ban_count=ban_count,
)
)
return DashboardBanListResponse(
items=items,
total=total,
page=page,
page_size=effective_page_size,
)
async def list_accesses(
socket_path: str,
range_: TimeRange,
*,
page: int = 1,
page_size: int = _DEFAULT_PAGE_SIZE,
geo_enricher: Any | None = None,
) -> AccessListResponse:
"""Return a paginated list of individual access events (matched log lines).
Each row in the fail2ban ``bans`` table can contain multiple matched log
lines in its ``data.matches`` JSON field. This function expands those
into individual :class:`~app.models.ban.AccessListItem` objects so callers
see each distinct access attempt.
Args:
socket_path: Path to the fail2ban Unix domain socket.
range_: Time-range preset.
page: 1-based page number (default: ``1``).
page_size: Maximum items per page, capped at ``_MAX_PAGE_SIZE``.
geo_enricher: Optional async callable ``(ip: str) -> GeoInfo | None``.
Returns:
:class:`~app.models.ban.AccessListResponse` containing the paginated
expanded access items and total count.
"""
since: int = _since_unix(range_)
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
db_path: str = await _get_fail2ban_db_path(socket_path)
log.info("ban_service_list_accesses", db_path=db_path, since=since, range=range_)
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
f2b_db.row_factory = aiosqlite.Row
async with f2b_db.execute(
"SELECT jail, ip, timeofban, data "
"FROM bans "
"WHERE timeofban >= ? "
"ORDER BY timeofban DESC",
(since,),
) as cur:
rows = await cur.fetchall()
# Expand each ban record into its individual matched log lines.
all_items: list[AccessListItem] = []
for row in rows:
jail = str(row["jail"])
ip = str(row["ip"])
timestamp = _ts_to_iso(int(row["timeofban"]))
matches, _ = _parse_data_json(row["data"])
geo = None
if geo_enricher is not None:
try:
geo = await geo_enricher(ip)
except Exception: # noqa: BLE001
log.warning("ban_service_geo_lookup_failed", ip=ip)
for line in matches:
all_items.append(
AccessListItem(
ip=ip,
jail=jail,
timestamp=timestamp,
line=line,
country_code=geo.country_code if geo else None,
country_name=geo.country_name if geo else None,
asn=geo.asn if geo else None,
org=geo.org if geo else None,
)
)
total: int = len(all_items)
offset: int = (page - 1) * effective_page_size
page_items: list[AccessListItem] = all_items[offset : offset + effective_page_size]
return AccessListResponse(
items=page_items,
total=total,
page=page,
page_size=effective_page_size,
)

View File

@@ -0,0 +1,194 @@
"""Geo service.
Resolves IP addresses to their country, ASN, and organisation using the
`ip-api.com <http://ip-api.com>`_ JSON API. Results are cached in memory
to avoid redundant HTTP requests for addresses that appear repeatedly.
The free ip-api.com endpoint requires no API key and supports up to 45
requests per minute. Because results are cached indefinitely for the life
of the process, under normal load the rate limit is rarely approached.
Usage::
import aiohttp
from app.services import geo_service
async with aiohttp.ClientSession() as session:
info = await geo_service.lookup("1.2.3.4", session)
if info:
print(info.country_code) # "DE"
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import structlog
if TYPE_CHECKING:
import aiohttp
log: structlog.stdlib.BoundLogger = structlog.get_logger()
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
#: ip-api.com single-IP lookup endpoint (HTTP only on the free tier).
_API_URL: str = "http://ip-api.com/json/{ip}?fields=status,message,country,countryCode,org,as"
#: Maximum number of entries kept in the in-process cache before it is
#: flushed completely. A simple eviction strategy — the cache is cheap to
#: rebuild and memory is bounded.
_MAX_CACHE_SIZE: int = 10_000
#: Timeout for outgoing geo API requests in seconds.
_REQUEST_TIMEOUT: float = 5.0
# ---------------------------------------------------------------------------
# Domain model
# ---------------------------------------------------------------------------
@dataclass
class GeoInfo:
"""Geographical and network metadata for a single IP address.
All fields default to ``None`` when the information is unavailable or
the lookup fails gracefully.
"""
country_code: str | None
"""ISO 3166-1 alpha-2 country code, e.g. ``"DE"``."""
country_name: str | None
"""Human-readable country name, e.g. ``"Germany"``."""
asn: str | None
"""Autonomous System Number string, e.g. ``"AS3320"``."""
org: str | None
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
# ---------------------------------------------------------------------------
# Internal cache
# ---------------------------------------------------------------------------
#: Module-level in-memory cache: ``ip → GeoInfo``.
_cache: dict[str, GeoInfo] = {}
def clear_cache() -> None:
"""Flush the entire lookup cache.
Useful in tests and when the operator suspects stale data.
"""
_cache.clear()
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def lookup(ip: str, http_session: aiohttp.ClientSession) -> GeoInfo | None:
"""Resolve an IP address to country, ASN, and organisation metadata.
Results are cached in-process. If the cache exceeds ``_MAX_CACHE_SIZE``
entries it is flushed before the new result is stored, keeping memory
usage bounded.
Private, loopback, and link-local addresses are resolved to a placeholder
``GeoInfo`` with ``None`` values so callers are not blocked by pointless
API calls for RFC-1918 ranges.
Args:
ip: IPv4 or IPv6 address string.
http_session: Shared :class:`aiohttp.ClientSession` (from
``app.state.http_session``).
Returns:
A :class:`GeoInfo` instance, or ``None`` when the lookup fails
in a way that should prevent the caller from caching a bad result
(e.g. network timeout).
"""
if ip in _cache:
return _cache[ip]
url: str = _API_URL.format(ip=ip)
try:
async with http_session.get(url, timeout=_REQUEST_TIMEOUT) as resp: # type: ignore[arg-type]
if resp.status != 200:
log.warning("geo_lookup_non_200", ip=ip, status=resp.status)
return None
data: dict[str, object] = await resp.json(content_type=None)
except Exception as exc: # noqa: BLE001
log.warning("geo_lookup_request_failed", ip=ip, error=str(exc))
return None
if data.get("status") != "success":
log.debug(
"geo_lookup_failed",
ip=ip,
message=data.get("message", "unknown"),
)
# Still cache a negative result so we do not retry reserved IPs.
result = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
_store(ip, result)
return result
country_code: str | None = _str_or_none(data.get("countryCode"))
country_name: str | None = _str_or_none(data.get("country"))
asn_raw: str | None = _str_or_none(data.get("as"))
org_raw: str | None = _str_or_none(data.get("org"))
# ip-api returns the full "AS12345 Some Org" string in both "as" and "org".
# Extract just the AS number prefix for the asn field.
asn: str | None = asn_raw.split()[0] if asn_raw else None
org: str | None = org_raw
result = GeoInfo(
country_code=country_code,
country_name=country_name,
asn=asn,
org=org,
)
_store(ip, result)
log.debug("geo_lookup_success", ip=ip, country=country_code, asn=asn)
return result
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _str_or_none(value: object) -> str | None:
"""Return *value* as a non-empty string, or ``None``.
Args:
value: Raw JSON value which may be ``None``, empty, or a string.
Returns:
Stripped string if non-empty, else ``None``.
"""
if value is None:
return None
s = str(value).strip()
return s if s else None
def _store(ip: str, info: GeoInfo) -> None:
"""Insert *info* into the module-level cache, flushing if over capacity.
Args:
ip: The IP address key.
info: The :class:`GeoInfo` to store.
"""
if len(_cache) >= _MAX_CACHE_SIZE:
_cache.clear()
log.info("geo_cache_flushed", reason="capacity")
_cache[ip] = info

View File

@@ -1,8 +1,9 @@
"""Tests for the dashboard router (GET /api/dashboard/status)."""
"""Tests for the dashboard router (GET /api/dashboard/status, GET /api/dashboard/bans, GET /api/dashboard/accesses)."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
@@ -11,6 +12,12 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.ban import (
AccessListItem,
AccessListResponse,
DashboardBanItem,
DashboardBanListResponse,
)
from app.models.server import ServerStatus
# ---------------------------------------------------------------------------
@@ -56,6 +63,8 @@ async def dashboard_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
total_bans=10,
total_failures=5,
)
# Provide a stub HTTP session so ban/access endpoints can access app.state.http_session.
app.state.http_session = MagicMock()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
@@ -94,6 +103,7 @@ async def offline_dashboard_client(tmp_path: Path) -> AsyncClient: # type: igno
app.state.db = db
app.state.server_status = ServerStatus(online=False)
app.state.http_session = MagicMock()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
@@ -192,3 +202,190 @@ class TestDashboardStatus:
assert response.status_code == 200
status = response.json()["status"]
assert status["online"] is False
# ---------------------------------------------------------------------------
# Dashboard bans endpoint
# ---------------------------------------------------------------------------
def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
"""Build a mock DashboardBanListResponse with *n* items."""
items = [
DashboardBanItem(
ip=f"1.2.3.{i}",
jail="sshd",
banned_at="2026-03-01T10:00:00+00:00",
service=None,
country_code="DE",
country_name="Germany",
asn="AS3320",
org="Telekom",
ban_count=1,
)
for i in range(n)
]
return DashboardBanListResponse(items=items, total=n, page=1, page_size=100)
def _make_access_list_response(n: int = 2) -> AccessListResponse:
"""Build a mock AccessListResponse with *n* items."""
items = [
AccessListItem(
ip=f"5.6.7.{i}",
jail="nginx",
timestamp="2026-03-01T10:00:00+00:00",
line=f"GET /admin HTTP/1.1 attempt {i}",
country_code="US",
country_name="United States",
asn="AS15169",
org="Google LLC",
)
for i in range(n)
]
return AccessListResponse(items=items, total=n, page=1, page_size=100)
class TestDashboardBans:
"""GET /api/dashboard/bans."""
async def test_returns_200_when_authenticated(
self, dashboard_client: AsyncClient
) -> None:
"""Authenticated request returns HTTP 200."""
with patch(
"app.routers.dashboard.ban_service.list_bans",
new=AsyncMock(return_value=_make_ban_list_response()),
):
response = await dashboard_client.get("/api/dashboard/bans")
assert response.status_code == 200
async def test_returns_401_when_unauthenticated(
self, client: AsyncClient
) -> None:
"""Unauthenticated request returns HTTP 401."""
await client.post("/api/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/dashboard/bans")
assert response.status_code == 401
async def test_response_contains_items_and_total(
self, dashboard_client: AsyncClient
) -> None:
"""Response body contains ``items`` list and ``total`` count."""
with patch(
"app.routers.dashboard.ban_service.list_bans",
new=AsyncMock(return_value=_make_ban_list_response(3)),
):
response = await dashboard_client.get("/api/dashboard/bans")
body = response.json()
assert "items" in body
assert "total" in body
assert body["total"] == 3
assert len(body["items"]) == 3
async def test_default_range_is_24h(self, dashboard_client: AsyncClient) -> None:
"""If no ``range`` param is provided the default ``24h`` preset is used."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
await dashboard_client.get("/api/dashboard/bans")
called_range = mock_list.call_args[0][1]
assert called_range == "24h"
async def test_accepts_time_range_param(
self, dashboard_client: AsyncClient
) -> None:
"""The ``range`` query parameter is forwarded to ban_service."""
mock_list = AsyncMock(return_value=_make_ban_list_response())
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
await dashboard_client.get("/api/dashboard/bans?range=7d")
called_range = mock_list.call_args[0][1]
assert called_range == "7d"
async def test_empty_ban_list_returns_zero_total(
self, dashboard_client: AsyncClient
) -> None:
"""Returns ``total=0`` and empty ``items`` when no bans are in range."""
empty = DashboardBanListResponse(items=[], total=0, page=1, page_size=100)
with patch(
"app.routers.dashboard.ban_service.list_bans",
new=AsyncMock(return_value=empty),
):
response = await dashboard_client.get("/api/dashboard/bans")
body = response.json()
assert body["total"] == 0
assert body["items"] == []
async def test_item_shape_is_correct(self, dashboard_client: AsyncClient) -> None:
"""Each item in ``items`` has the expected fields."""
with patch(
"app.routers.dashboard.ban_service.list_bans",
new=AsyncMock(return_value=_make_ban_list_response(1)),
):
response = await dashboard_client.get("/api/dashboard/bans")
item = response.json()["items"][0]
assert "ip" in item
assert "jail" in item
assert "banned_at" in item
assert "ban_count" in item
# ---------------------------------------------------------------------------
# Dashboard accesses endpoint
# ---------------------------------------------------------------------------
class TestDashboardAccesses:
"""GET /api/dashboard/accesses."""
async def test_returns_200_when_authenticated(
self, dashboard_client: AsyncClient
) -> None:
"""Authenticated request returns HTTP 200."""
with patch(
"app.routers.dashboard.ban_service.list_accesses",
new=AsyncMock(return_value=_make_access_list_response()),
):
response = await dashboard_client.get("/api/dashboard/accesses")
assert response.status_code == 200
async def test_returns_401_when_unauthenticated(
self, client: AsyncClient
) -> None:
"""Unauthenticated request returns HTTP 401."""
await client.post("/api/setup", json=_SETUP_PAYLOAD)
response = await client.get("/api/dashboard/accesses")
assert response.status_code == 401
async def test_response_contains_access_items(
self, dashboard_client: AsyncClient
) -> None:
"""Response body contains ``items`` with ``line`` fields."""
with patch(
"app.routers.dashboard.ban_service.list_accesses",
new=AsyncMock(return_value=_make_access_list_response(2)),
):
response = await dashboard_client.get("/api/dashboard/accesses")
body = response.json()
assert body["total"] == 2
assert len(body["items"]) == 2
assert "line" in body["items"][0]
async def test_default_range_is_24h(
self, dashboard_client: AsyncClient
) -> None:
"""If no ``range`` param is provided the default ``24h`` preset is used."""
mock_list = AsyncMock(return_value=_make_access_list_response())
with patch(
"app.routers.dashboard.ban_service.list_accesses", new=mock_list
):
await dashboard_client.get("/api/dashboard/accesses")
called_range = mock_list.call_args[0][1]
assert called_range == "24h"

View File

@@ -0,0 +1,359 @@
"""Tests for ban_service.list_bans() and ban_service.list_accesses()."""
from __future__ import annotations
import json
import time
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, patch
import aiosqlite
import pytest
from app.services import ban_service
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
_NOW: int = int(time.time())
_ONE_HOUR_AGO: int = _NOW - 3600
_TWO_DAYS_AGO: int = _NOW - 2 * 24 * 3600
async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
"""Create a minimal fail2ban SQLite database with the given ban rows.
Args:
path: Filesystem path for the new SQLite file.
rows: Sequence of dicts with keys ``jail``, ``ip``, ``timeofban``,
``bantime``, ``bancount``, and optionally ``data``.
"""
async with aiosqlite.connect(path) as db:
await db.execute(
"CREATE TABLE jails ("
"name TEXT NOT NULL UNIQUE, "
"enabled INTEGER NOT NULL DEFAULT 1"
")"
)
await db.execute(
"CREATE TABLE bans ("
"jail TEXT NOT NULL, "
"ip TEXT, "
"timeofban INTEGER NOT NULL, "
"bantime INTEGER NOT NULL, "
"bancount INTEGER NOT NULL DEFAULT 1, "
"data JSON"
")"
)
for row in rows:
await db.execute(
"INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
"VALUES (?, ?, ?, ?, ?, ?)",
(
row["jail"],
row["ip"],
row["timeofban"],
row.get("bantime", 3600),
row.get("bancount", 1),
json.dumps(row["data"]) if "data" in row else None,
),
)
await db.commit()
@pytest.fixture
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
"""Return the path to a test fail2ban SQLite database with several bans."""
path = str(tmp_path / "fail2ban_test.sqlite3")
await _create_f2b_db(
path,
[
{
"jail": "sshd",
"ip": "1.2.3.4",
"timeofban": _ONE_HOUR_AGO,
"bantime": 3600,
"bancount": 2,
"data": {
"matches": ["Nov 10 10:00 sshd[123]: Failed password for root"],
"failures": 5,
},
},
{
"jail": "nginx",
"ip": "5.6.7.8",
"timeofban": _ONE_HOUR_AGO,
"bantime": 7200,
"bancount": 1,
"data": {"matches": ["GET /admin HTTP/1.1"], "failures": 3},
},
{
"jail": "sshd",
"ip": "9.10.11.12",
"timeofban": _TWO_DAYS_AGO,
"bantime": 3600,
"bancount": 1,
"data": {"failures": 6}, # no matches
},
],
)
return path
@pytest.fixture
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
"""Return the path to a fail2ban SQLite database with no ban records."""
path = str(tmp_path / "fail2ban_empty.sqlite3")
await _create_f2b_db(path, [])
return path
# ---------------------------------------------------------------------------
# list_bans — happy path
# ---------------------------------------------------------------------------
class TestListBansHappyPath:
"""Verify ban_service.list_bans() under normal conditions."""
async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
"""Only bans within the selected range are returned."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans("/fake/sock", "24h")
# Two bans within last 24 h; one is 2 days old and excluded.
assert result.total == 2
assert len(result.items) == 2
async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
"""Items are ordered by ``banned_at`` descending (newest first)."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans("/fake/sock", "24h")
timestamps = [item.banned_at for item in result.items]
assert timestamps == sorted(timestamps, reverse=True)
async def test_ban_fields_present(self, f2b_db_path: str) -> None:
"""Each item contains ip, jail, banned_at, ban_count."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans("/fake/sock", "24h")
for item in result.items:
assert item.ip
assert item.jail
assert item.banned_at
assert item.ban_count >= 1
async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
"""``service`` field is the first element of ``data.matches``."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans("/fake/sock", "24h")
sshd_item = next(i for i in result.items if i.jail == "sshd")
assert sshd_item.service is not None
assert "Failed password" in sshd_item.service
async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
"""``service`` is ``None`` when the ban has no stored matches."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
# Use 7d to include the older ban with no matches.
result = await ban_service.list_bans("/fake/sock", "7d")
no_match = next(i for i in result.items if i.ip == "9.10.11.12")
assert no_match.service is None
async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
"""When no bans exist the result has total=0 and no items."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path),
):
result = await ban_service.list_bans("/fake/sock", "24h")
assert result.total == 0
assert result.items == []
async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
"""The ``365d`` range includes bans that are 2 days old."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans("/fake/sock", "365d")
assert result.total == 3
# ---------------------------------------------------------------------------
# list_bans — geo enrichment
# ---------------------------------------------------------------------------
class TestListBansGeoEnrichment:
"""Verify geo enrichment integration in ban_service.list_bans()."""
async def test_geo_data_applied_when_enricher_provided(
self, f2b_db_path: str
) -> None:
"""Geo fields are populated when an enricher returns data."""
from app.services.geo_service import GeoInfo
async def fake_enricher(ip: str) -> GeoInfo:
return GeoInfo(
country_code="DE",
country_name="Germany",
asn="AS3320",
org="Deutsche Telekom",
)
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", geo_enricher=fake_enricher
)
for item in result.items:
assert item.country_code == "DE"
assert item.country_name == "Germany"
assert item.asn == "AS3320"
async def test_geo_failure_does_not_break_results(
self, f2b_db_path: str
) -> None:
"""A geo enricher that raises still returns ban items (geo fields null)."""
async def failing_enricher(ip: str) -> None:
raise RuntimeError("geo service down")
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans(
"/fake/sock", "24h", geo_enricher=failing_enricher
)
assert result.total == 2
for item in result.items:
assert item.country_code is None
# ---------------------------------------------------------------------------
# list_bans — pagination
# ---------------------------------------------------------------------------
class TestListBansPagination:
"""Verify pagination parameters in list_bans()."""
async def test_page_size_respected(self, f2b_db_path: str) -> None:
"""``page_size=1`` returns at most one item."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
assert len(result.items) == 1
assert result.page_size == 1
async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
"""The second page returns items not on the first page."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
page2 = await ban_service.list_bans("/fake/sock", "7d", page=2, page_size=1)
# Different IPs should appear on different pages.
assert page1.items[0].ip != page2.items[0].ip
async def test_total_reflects_full_count_not_page_count(
self, f2b_db_path: str
) -> None:
"""``total`` reports all matching records regardless of pagination."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
assert result.total == 3 # All three bans are within 7d.
# ---------------------------------------------------------------------------
# list_accesses
# ---------------------------------------------------------------------------
class TestListAccesses:
"""Verify ban_service.list_accesses()."""
async def test_expands_matches_into_rows(self, f2b_db_path: str) -> None:
"""Each element in ``data.matches`` becomes a separate row."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_accesses("/fake/sock", "24h")
# Two bans in last 24h: sshd (1 match) + nginx (1 match) = 2 rows.
assert result.total == 2
assert len(result.items) == 2
async def test_access_item_has_line_field(self, f2b_db_path: str) -> None:
"""Each access item contains the raw matched log line."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_accesses("/fake/sock", "24h")
for item in result.items:
assert item.line
async def test_ban_with_no_matches_produces_no_access_rows(
self, f2b_db_path: str
) -> None:
"""Bans with empty matches list do not contribute rows."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=f2b_db_path),
):
result = await ban_service.list_accesses("/fake/sock", "7d")
# Third ban (9.10.11.12) has no matches, so only 2 rows total.
assert result.total == 2
async def test_empty_db_returns_zero_accesses(
self, empty_f2b_db_path: str
) -> None:
"""Returns empty result when no bans exist."""
with patch(
"app.services.ban_service._get_fail2ban_db_path",
new=AsyncMock(return_value=empty_f2b_db_path),
):
result = await ban_service.list_accesses("/fake/sock", "24h")
assert result.total == 0
assert result.items == []

View File

@@ -0,0 +1,212 @@
"""Tests for geo_service.lookup()."""
from __future__ import annotations
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services import geo_service
from app.services.geo_service import GeoInfo
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_session(response_json: dict[str, object], status: int = 200) -> MagicMock:
"""Build a mock aiohttp.ClientSession that returns *response_json*.
Args:
response_json: The dict that the mock response's ``json()`` returns.
status: HTTP status code for the mock response.
Returns:
A :class:`MagicMock` that behaves like an
``aiohttp.ClientSession`` in an ``async with`` context.
"""
mock_resp = AsyncMock()
mock_resp.status = status
mock_resp.json = AsyncMock(return_value=response_json)
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
session = MagicMock()
session.get = MagicMock(return_value=mock_ctx)
return session
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def clear_geo_cache() -> None: # type: ignore[misc]
"""Flush the module-level geo cache before every test."""
geo_service.clear_cache()
# ---------------------------------------------------------------------------
# Happy path
# ---------------------------------------------------------------------------
class TestLookupSuccess:
"""geo_service.lookup() under normal conditions."""
async def test_returns_country_code(self) -> None:
"""country_code is populated from the ``countryCode`` field."""
session = _make_session(
{
"status": "success",
"countryCode": "DE",
"country": "Germany",
"as": "AS3320 Deutsche Telekom AG",
"org": "AS3320 Deutsche Telekom AG",
}
)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
assert result is not None
assert result.country_code == "DE"
async def test_returns_country_name(self) -> None:
"""country_name is populated from the ``country`` field."""
session = _make_session(
{
"status": "success",
"countryCode": "US",
"country": "United States",
"as": "AS15169 Google LLC",
"org": "Google LLC",
}
)
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
assert result is not None
assert result.country_name == "United States"
async def test_asn_extracted_without_org_suffix(self) -> None:
"""The ASN field contains only the ``AS<N>`` prefix, not the full string."""
session = _make_session(
{
"status": "success",
"countryCode": "DE",
"country": "Germany",
"as": "AS3320 Deutsche Telekom AG",
"org": "Deutsche Telekom",
}
)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
assert result is not None
assert result.asn == "AS3320"
async def test_org_populated(self) -> None:
"""org field is populated from the ``org`` key."""
session = _make_session(
{
"status": "success",
"countryCode": "US",
"country": "United States",
"as": "AS15169 Google LLC",
"org": "Google LLC",
}
)
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type]
assert result is not None
assert result.org == "Google LLC"
# ---------------------------------------------------------------------------
# Cache behaviour
# ---------------------------------------------------------------------------
class TestLookupCaching:
"""Verify that results are cached and the cache can be cleared."""
async def test_second_call_uses_cache(self) -> None:
"""Subsequent lookups for the same IP do not make additional HTTP requests."""
session = _make_session(
{
"status": "success",
"countryCode": "DE",
"country": "Germany",
"as": "AS3320 Deutsche Telekom AG",
"org": "Deutsche Telekom",
}
)
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
# The session.get() should only have been called once.
assert session.get.call_count == 1
async def test_clear_cache_forces_refetch(self) -> None:
"""After clearing the cache a new HTTP request is made."""
session = _make_session(
{
"status": "success",
"countryCode": "DE",
"country": "Germany",
"as": "AS3320",
"org": "Telekom",
}
)
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
geo_service.clear_cache()
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type]
assert session.get.call_count == 2
async def test_negative_result_cached(self) -> None:
"""A failed lookup result (status != success) is also cached."""
session = _make_session(
{"status": "fail", "message": "reserved range"}
)
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type]
assert session.get.call_count == 1
# ---------------------------------------------------------------------------
# Failure modes
# ---------------------------------------------------------------------------
class TestLookupFailures:
"""geo_service.lookup() when things go wrong."""
async def test_non_200_response_returns_none(self) -> None:
"""A 429 or 500 status returns ``None`` without caching."""
session = _make_session({}, status=429)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type]
assert result is None
async def test_network_error_returns_none(self) -> None:
"""A network exception returns ``None``."""
session = MagicMock()
session.get = MagicMock(side_effect=OSError("connection refused"))
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
assert result is None
async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is cached."""
session = _make_session({"status": "fail", "message": "private range"})
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type]
assert result is not None
assert isinstance(result, GeoInfo)
assert result.country_code is None
assert result.country_name is None