Add origin field and filter for ban sources (Tasks 1 & 2)
- Task 1: Mark imported blocklist IP addresses
- Add BanOrigin type and _derive_origin() to ban.py model
- Populate origin field in ban_service list_bans() and bans_by_country()
- BanTable and MapPage companion table show origin badge column
- Tests: origin derivation in test_ban_service.py and test_dashboard.py
- Task 2: Add origin filter to dashboard and world map
- ban_service: _origin_sql_filter() helper; origin param on list_bans()
and bans_by_country()
- dashboard router: optional origin query param forwarded to service
- Frontend: BanOriginFilter type + BAN_ORIGIN_FILTER_LABELS in ban.ts
- fetchBans / fetchBansByCountry forward origin to API
- useBans / useMapData accept and pass origin; page resets on change
- BanTable accepts origin prop; DashboardPage adds segmented filter
- MapPage adds origin Select next to time-range picker
- Tests: origin filter assertions in test_ban_service and test_dashboard
This commit is contained in:
@@ -48,6 +48,26 @@ class UnbanRequest(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
#: Discriminator literal for the origin of a ban.
|
||||
BanOrigin = Literal["blocklist", "selfblock"]
|
||||
|
||||
#: Jail name used by the blocklist import service.
|
||||
BLOCKLIST_JAIL: str = "blocklist-import"
|
||||
|
||||
|
||||
def _derive_origin(jail: str) -> BanOrigin:
|
||||
"""Derive the ban origin from the jail name.
|
||||
|
||||
Args:
|
||||
jail: The jail that issued the ban.
|
||||
|
||||
Returns:
|
||||
``"blocklist"`` when the jail is the dedicated blocklist-import
|
||||
jail, ``"selfblock"`` otherwise.
|
||||
"""
|
||||
return "blocklist" if jail == BLOCKLIST_JAIL else "selfblock"
|
||||
|
||||
|
||||
class Ban(BaseModel):
|
||||
"""Domain model representing a single active or historical ban record."""
|
||||
|
||||
@@ -65,6 +85,10 @@ class Ban(BaseModel):
|
||||
default=None,
|
||||
description="ISO 3166-1 alpha-2 country code resolved from the IP.",
|
||||
)
|
||||
origin: BanOrigin = Field(
|
||||
...,
|
||||
description="Whether this ban came from a blocklist import or fail2ban itself.",
|
||||
)
|
||||
|
||||
|
||||
class BanResponse(BaseModel):
|
||||
@@ -146,6 +170,10 @@ class DashboardBanItem(BaseModel):
|
||||
description="Organisation name associated with the IP.",
|
||||
)
|
||||
ban_count: int = Field(..., ge=1, description="How many times this IP was banned.")
|
||||
origin: BanOrigin = Field(
|
||||
...,
|
||||
description="Whether this ban came from a blocklist import or fail2ban itself.",
|
||||
)
|
||||
|
||||
|
||||
class DashboardBanListResponse(BaseModel):
|
||||
|
||||
@@ -169,3 +169,36 @@ class LogPreviewResponse(BaseModel):
|
||||
total_lines: int = Field(..., ge=0)
|
||||
matched_count: int = Field(..., ge=0)
|
||||
regex_error: str | None = Field(default=None, description="Set if the regex failed to compile.")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Map color threshold models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MapColorThresholdsResponse(BaseModel):
|
||||
"""Response for ``GET /api/config/map-thresholds``."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
threshold_high: int = Field(
|
||||
..., description="Ban count for red coloring."
|
||||
)
|
||||
threshold_medium: int = Field(
|
||||
..., description="Ban count for yellow coloring."
|
||||
)
|
||||
threshold_low: int = Field(
|
||||
..., description="Ban count for green coloring."
|
||||
)
|
||||
|
||||
|
||||
class MapColorThresholdsUpdate(BaseModel):
|
||||
"""Payload for ``PUT /api/config/map-thresholds``."""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
threshold_high: int = Field(..., gt=0, description="Ban count for red.")
|
||||
threshold_medium: int = Field(
|
||||
..., gt=0, description="Ban count for yellow."
|
||||
)
|
||||
threshold_low: int = Field(..., gt=0, description="Ban count for green.")
|
||||
|
||||
@@ -30,6 +30,8 @@ from app.models.config import (
|
||||
JailConfigUpdate,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
MapColorThresholdsUpdate,
|
||||
RegexTestRequest,
|
||||
RegexTestResponse,
|
||||
)
|
||||
@@ -380,3 +382,83 @@ async def preview_log(
|
||||
:class:`~app.models.config.LogPreviewResponse` with per-line results.
|
||||
"""
|
||||
return await config_service.preview_log(body)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Map color thresholds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@router.get(
|
||||
"/map-color-thresholds",
|
||||
response_model=MapColorThresholdsResponse,
|
||||
summary="Get map color threshold configuration",
|
||||
)
|
||||
async def get_map_color_thresholds(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
) -> MapColorThresholdsResponse:
|
||||
"""Return the configured map color thresholds.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object.
|
||||
_auth: Validated session.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.MapColorThresholdsResponse` with
|
||||
current thresholds.
|
||||
"""
|
||||
from app.services import setup_service
|
||||
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(
|
||||
request.app.state.db
|
||||
)
|
||||
return MapColorThresholdsResponse(
|
||||
threshold_high=high,
|
||||
threshold_medium=medium,
|
||||
threshold_low=low,
|
||||
)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/map-color-thresholds",
|
||||
response_model=MapColorThresholdsResponse,
|
||||
summary="Update map color threshold configuration",
|
||||
)
|
||||
async def update_map_color_thresholds(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
body: MapColorThresholdsUpdate,
|
||||
) -> MapColorThresholdsResponse:
|
||||
"""Update the map color threshold configuration.
|
||||
|
||||
Args:
|
||||
request: FastAPI request object.
|
||||
_auth: Validated session.
|
||||
body: New threshold values.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.config.MapColorThresholdsResponse` with
|
||||
updated thresholds.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if validation fails (thresholds not
|
||||
properly ordered).
|
||||
"""
|
||||
from app.services import setup_service
|
||||
|
||||
try:
|
||||
await setup_service.set_map_color_thresholds(
|
||||
request.app.state.db,
|
||||
threshold_high=body.threshold_high,
|
||||
threshold_medium=body.threshold_medium,
|
||||
threshold_low=body.threshold_low,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
|
||||
return MapColorThresholdsResponse(
|
||||
threshold_high=body.threshold_high,
|
||||
threshold_medium=body.threshold_medium,
|
||||
threshold_low=body.threshold_low,
|
||||
)
|
||||
|
||||
@@ -18,6 +18,7 @@ from fastapi import APIRouter, Query, Request
|
||||
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import (
|
||||
BanOrigin,
|
||||
BansByCountryResponse,
|
||||
DashboardBanListResponse,
|
||||
TimeRange,
|
||||
@@ -77,6 +78,10 @@ async def get_dashboard_bans(
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
page: int = Query(default=1, ge=1, description="1-based page number."),
|
||||
page_size: int = Query(default=_DEFAULT_PAGE_SIZE, ge=1, le=500, description="Items per page."),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
),
|
||||
) -> DashboardBanListResponse:
|
||||
"""Return a paginated list of bans within the selected time window.
|
||||
|
||||
@@ -91,6 +96,7 @@ async def get_dashboard_bans(
|
||||
``"365d"``.
|
||||
page: 1-based page number.
|
||||
page_size: Maximum items per page (1–500).
|
||||
origin: Optional filter by ban origin.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.DashboardBanListResponse` with paginated
|
||||
@@ -108,6 +114,7 @@ async def get_dashboard_bans(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
geo_enricher=_enricher,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
|
||||
@@ -120,6 +127,10 @@ async def get_bans_by_country(
|
||||
request: Request,
|
||||
_auth: AuthDep,
|
||||
range: TimeRange = Query(default=_DEFAULT_RANGE, description="Time-range preset."),
|
||||
origin: BanOrigin | None = Query(
|
||||
default=None,
|
||||
description="Filter by ban origin: 'blocklist' or 'selfblock'. Omit for all.",
|
||||
),
|
||||
) -> BansByCountryResponse:
|
||||
"""Return ban counts aggregated by ISO country code.
|
||||
|
||||
@@ -131,6 +142,7 @@ async def get_bans_by_country(
|
||||
request: The incoming request.
|
||||
_auth: Validated session dependency.
|
||||
range: Time-range preset.
|
||||
origin: Optional filter by ban origin.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||
@@ -146,5 +158,6 @@ async def get_bans_by_country(
|
||||
socket_path,
|
||||
range,
|
||||
geo_enricher=_enricher,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
|
||||
@@ -18,11 +18,14 @@ import aiosqlite
|
||||
import structlog
|
||||
|
||||
from app.models.ban import (
|
||||
BLOCKLIST_JAIL,
|
||||
TIME_RANGE_SECONDS,
|
||||
BanOrigin,
|
||||
BansByCountryResponse,
|
||||
DashboardBanItem,
|
||||
DashboardBanListResponse,
|
||||
TimeRange,
|
||||
_derive_origin,
|
||||
)
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
|
||||
@@ -41,6 +44,24 @@ _SOCKET_TIMEOUT: float = 5.0
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||
"""Return a SQL fragment and its parameters for the origin filter.
|
||||
|
||||
Args:
|
||||
origin: ``"blocklist"`` to restrict to the blocklist-import jail,
|
||||
``"selfblock"`` to exclude it, or ``None`` for no restriction.
|
||||
|
||||
Returns:
|
||||
A ``(sql_fragment, params)`` pair — the fragment starts with ``" AND"``
|
||||
so it can be appended directly to an existing WHERE clause.
|
||||
"""
|
||||
if origin == "blocklist":
|
||||
return " AND jail = ?", (BLOCKLIST_JAIL,)
|
||||
if origin == "selfblock":
|
||||
return " AND jail != ?", (BLOCKLIST_JAIL,)
|
||||
return "", ()
|
||||
|
||||
|
||||
def _since_unix(range_: TimeRange) -> int:
|
||||
"""Return the Unix timestamp representing the start of the time window.
|
||||
|
||||
@@ -148,6 +169,7 @@ async def list_bans(
|
||||
page: int = 1,
|
||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||
geo_enricher: Any | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> DashboardBanListResponse:
|
||||
"""Return a paginated list of bans within the selected time window.
|
||||
|
||||
@@ -164,6 +186,8 @@ async def list_bans(
|
||||
(default: ``100``).
|
||||
geo_enricher: Optional async callable ``(ip: str) -> GeoInfo | None``.
|
||||
When supplied every result is enriched with country and ASN data.
|
||||
origin: Optional origin filter — ``"blocklist"`` restricts results to
|
||||
the ``blocklist-import`` jail, ``"selfblock"`` excludes it.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.DashboardBanListResponse` containing the
|
||||
@@ -172,16 +196,23 @@ async def list_bans(
|
||||
since: int = _since_unix(range_)
|
||||
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
log.info("ban_service_list_bans", db_path=db_path, since=since, range=range_)
|
||||
log.info(
|
||||
"ban_service_list_bans",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?",
|
||||
(since,),
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
@@ -189,10 +220,11 @@ async def list_bans(
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ? "
|
||||
"ORDER BY timeofban DESC "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " ORDER BY timeofban DESC "
|
||||
"LIMIT ? OFFSET ?",
|
||||
(since, effective_page_size, offset),
|
||||
(since, *origin_params, effective_page_size, offset),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
@@ -232,6 +264,7 @@ async def list_bans(
|
||||
asn=asn,
|
||||
org=org,
|
||||
ban_count=ban_count,
|
||||
origin=_derive_origin(jail),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -255,6 +288,7 @@ async def bans_by_country(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
geo_enricher: Any | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> BansByCountryResponse:
|
||||
"""Aggregate ban counts per country for the selected time window.
|
||||
|
||||
@@ -266,6 +300,8 @@ async def bans_by_country(
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
range_: Time-range preset.
|
||||
geo_enricher: Optional async ``(ip) -> GeoInfo | None`` callable.
|
||||
origin: Optional origin filter — ``"blocklist"`` restricts results to
|
||||
the ``blocklist-import`` jail, ``"selfblock"`` excludes it.
|
||||
|
||||
Returns:
|
||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||
@@ -274,15 +310,22 @@ async def bans_by_country(
|
||||
import asyncio
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
log.info("ban_service_bans_by_country", db_path=db_path, since=since, range=range_)
|
||||
log.info(
|
||||
"ban_service_bans_by_country",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db:
|
||||
f2b_db.row_factory = aiosqlite.Row
|
||||
|
||||
async with f2b_db.execute(
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?",
|
||||
(since,),
|
||||
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
|
||||
(since, *origin_params),
|
||||
) as cur:
|
||||
count_row = await cur.fetchone()
|
||||
total: int = int(count_row[0]) if count_row else 0
|
||||
@@ -290,10 +333,11 @@ async def bans_by_country(
|
||||
async with f2b_db.execute(
|
||||
"SELECT jail, ip, timeofban, bancount, data "
|
||||
"FROM bans "
|
||||
"WHERE timeofban >= ? "
|
||||
"ORDER BY timeofban DESC "
|
||||
"WHERE timeofban >= ?"
|
||||
+ origin_clause
|
||||
+ " ORDER BY timeofban DESC "
|
||||
"LIMIT ?",
|
||||
(since, _MAX_GEO_BANS),
|
||||
(since, *origin_params, _MAX_GEO_BANS),
|
||||
) as cur:
|
||||
rows = await cur.fetchall()
|
||||
|
||||
@@ -336,6 +380,7 @@ async def bans_by_country(
|
||||
asn=asn,
|
||||
org=org,
|
||||
ban_count=int(row["bancount"]),
|
||||
origin=_derive_origin(str(row["jail"])),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -16,10 +16,13 @@ import asyncio
|
||||
import contextlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import structlog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
from app.models.config import (
|
||||
AddLogPathRequest,
|
||||
GlobalConfigResponse,
|
||||
@@ -31,9 +34,12 @@ from app.models.config import (
|
||||
LogPreviewLine,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
MapColorThresholdsUpdate,
|
||||
RegexTestRequest,
|
||||
RegexTestResponse,
|
||||
)
|
||||
from app.services import setup_service
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
@@ -609,3 +615,46 @@ def _read_tail_lines(file_path: str, num_lines: int) -> list[str]:
|
||||
if pos > 0 and len(raw_lines) > 1:
|
||||
raw_lines = raw_lines[1:]
|
||||
return [ln.decode("utf-8", errors="replace").rstrip() for ln in raw_lines[-num_lines:] if ln.strip()]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Map color thresholds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
async def get_map_color_thresholds(db: aiosqlite.Connection) -> MapColorThresholdsResponse:
|
||||
"""Retrieve the current map color threshold configuration.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection to the application database.
|
||||
|
||||
Returns:
|
||||
A :class:`MapColorThresholdsResponse` containing the three threshold values.
|
||||
"""
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
return MapColorThresholdsResponse(
|
||||
threshold_high=high,
|
||||
threshold_medium=medium,
|
||||
threshold_low=low,
|
||||
)
|
||||
|
||||
|
||||
async def update_map_color_thresholds(
|
||||
db: aiosqlite.Connection,
|
||||
update: MapColorThresholdsUpdate,
|
||||
) -> None:
|
||||
"""Update the map color threshold configuration.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection to the application database.
|
||||
update: The new threshold values.
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails (thresholds must satisfy high > medium > low).
|
||||
"""
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db,
|
||||
threshold_high=update.threshold_high,
|
||||
threshold_medium=update.threshold_medium,
|
||||
threshold_low=update.threshold_low,
|
||||
)
|
||||
|
||||
@@ -447,3 +447,90 @@ class TestPreviewLog:
|
||||
data = resp.json()
|
||||
assert data["total_lines"] == 1
|
||||
assert data["matched_count"] == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/config/map-color-thresholds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetMapColorThresholds:
|
||||
"""Tests for ``GET /api/config/map-color-thresholds``."""
|
||||
|
||||
async def test_200_returns_thresholds(self, config_client: AsyncClient) -> None:
|
||||
"""GET /api/config/map-color-thresholds returns 200 with current values."""
|
||||
resp = await config_client.get("/api/config/map-color-thresholds")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "threshold_high" in data
|
||||
assert "threshold_medium" in data
|
||||
assert "threshold_low" in data
|
||||
# Should return defaults after setup
|
||||
assert data["threshold_high"] == 100
|
||||
assert data["threshold_medium"] == 50
|
||||
assert data["threshold_low"] == 20
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/config/map-color-thresholds
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateMapColorThresholds:
|
||||
"""Tests for ``PUT /api/config/map-color-thresholds``."""
|
||||
|
||||
async def test_200_updates_thresholds(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/map-color-thresholds returns 200 and updates settings."""
|
||||
update_payload = {
|
||||
"threshold_high": 200,
|
||||
"threshold_medium": 80,
|
||||
"threshold_low": 30,
|
||||
}
|
||||
resp = await config_client.put(
|
||||
"/api/config/map-color-thresholds", json=update_payload
|
||||
)
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["threshold_high"] == 200
|
||||
assert data["threshold_medium"] == 80
|
||||
assert data["threshold_low"] == 30
|
||||
|
||||
# Verify the values persist
|
||||
get_resp = await config_client.get("/api/config/map-color-thresholds")
|
||||
assert get_resp.status_code == 200
|
||||
get_data = get_resp.json()
|
||||
assert get_data["threshold_high"] == 200
|
||||
assert get_data["threshold_medium"] == 80
|
||||
assert get_data["threshold_low"] == 30
|
||||
|
||||
async def test_400_for_invalid_order(self, config_client: AsyncClient) -> None:
|
||||
"""PUT /api/config/map-color-thresholds returns 400 if thresholds are misordered."""
|
||||
invalid_payload = {
|
||||
"threshold_high": 50,
|
||||
"threshold_medium": 50,
|
||||
"threshold_low": 20,
|
||||
}
|
||||
resp = await config_client.put(
|
||||
"/api/config/map-color-thresholds", json=invalid_payload
|
||||
)
|
||||
|
||||
assert resp.status_code == 400
|
||||
assert "high > medium > low" in resp.json()["detail"]
|
||||
|
||||
async def test_400_for_non_positive_values(
|
||||
self, config_client: AsyncClient
|
||||
) -> None:
|
||||
"""PUT /api/config/map-color-thresholds returns 422 for non-positive values (Pydantic validation)."""
|
||||
invalid_payload = {
|
||||
"threshold_high": 100,
|
||||
"threshold_medium": 50,
|
||||
"threshold_low": 0,
|
||||
}
|
||||
resp = await config_client.put(
|
||||
"/api/config/map-color-thresholds", json=invalid_payload
|
||||
)
|
||||
|
||||
# Pydantic validates ge=1 constraint before our service code runs
|
||||
assert resp.status_code == 422
|
||||
|
||||
@@ -220,6 +220,7 @@ def _make_ban_list_response(n: int = 2) -> DashboardBanListResponse:
|
||||
asn="AS3320",
|
||||
org="Telekom",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
)
|
||||
for i in range(n)
|
||||
]
|
||||
@@ -334,10 +335,11 @@ def _make_bans_by_country_response() -> object:
|
||||
asn="AS3320",
|
||||
org="Telekom",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
),
|
||||
DashboardBanItem(
|
||||
ip="5.6.7.8",
|
||||
jail="sshd",
|
||||
jail="blocklist-import",
|
||||
banned_at="2026-03-01T10:05:00+00:00",
|
||||
service=None,
|
||||
country_code="US",
|
||||
@@ -345,6 +347,7 @@ def _make_bans_by_country_response() -> object:
|
||||
asn="AS15169",
|
||||
org="Google LLC",
|
||||
ban_count=2,
|
||||
origin="blocklist",
|
||||
),
|
||||
]
|
||||
return BansByCountryResponse(
|
||||
@@ -431,3 +434,146 @@ class TestBansByCountry:
|
||||
assert body["total"] == 0
|
||||
assert body["countries"] == {}
|
||||
assert body["bans"] == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Origin field tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDashboardBansOriginField:
|
||||
"""Verify that the ``origin`` field is present in API responses."""
|
||||
|
||||
async def test_origin_present_in_ban_list_items(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Each item in ``/api/dashboard/bans`` carries an ``origin`` field."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=_make_ban_list_response(1)),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
item = response.json()["items"][0]
|
||||
assert "origin" in item
|
||||
assert item["origin"] in ("blocklist", "selfblock")
|
||||
|
||||
async def test_selfblock_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""A ban from a non-blocklist jail serialises as ``"selfblock"``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.list_bans",
|
||||
new=AsyncMock(return_value=_make_ban_list_response(1)),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
item = response.json()["items"][0]
|
||||
assert item["jail"] == "sshd"
|
||||
assert item["origin"] == "selfblock"
|
||||
|
||||
async def test_origin_present_in_bans_by_country(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Each ban in ``/api/dashboard/bans/by-country`` carries an ``origin``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
new=AsyncMock(return_value=_make_bans_by_country_response()),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans/by-country")
|
||||
|
||||
bans = response.json()["bans"]
|
||||
assert all("origin" in ban for ban in bans)
|
||||
origins = {ban["origin"] for ban in bans}
|
||||
assert origins == {"blocklist", "selfblock"}
|
||||
|
||||
async def test_blocklist_origin_serialised_correctly(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""A ban from the ``blocklist-import`` jail serialises as ``"blocklist"``."""
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country",
|
||||
new=AsyncMock(return_value=_make_bans_by_country_response()),
|
||||
):
|
||||
response = await dashboard_client.get("/api/dashboard/bans/by-country")
|
||||
|
||||
bans = response.json()["bans"]
|
||||
blocklist_ban = next(b for b in bans if b["jail"] == "blocklist-import")
|
||||
assert blocklist_ban["origin"] == "blocklist"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Origin filter query parameter tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOriginFilterParam:
|
||||
"""Verify that the ``origin`` query parameter is forwarded to the service."""
|
||||
|
||||
async def test_bans_origin_blocklist_forwarded_to_service(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""``?origin=blocklist`` is passed to ``ban_service.list_bans``."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
await dashboard_client.get("/api/dashboard/bans?origin=blocklist")
|
||||
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_bans_origin_selfblock_forwarded_to_service(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""``?origin=selfblock`` is passed to ``ban_service.list_bans``."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
await dashboard_client.get("/api/dashboard/bans?origin=selfblock")
|
||||
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") == "selfblock"
|
||||
|
||||
async def test_bans_no_origin_param_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to the service (no filtering)."""
|
||||
mock_list = AsyncMock(return_value=_make_ban_list_response())
|
||||
with patch("app.routers.dashboard.ban_service.list_bans", new=mock_list):
|
||||
await dashboard_client.get("/api/dashboard/bans")
|
||||
|
||||
_, kwargs = mock_list.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
async def test_bans_invalid_origin_returns_422(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""An invalid ``origin`` value returns HTTP 422 Unprocessable Entity."""
|
||||
response = await dashboard_client.get("/api/dashboard/bans?origin=invalid")
|
||||
assert response.status_code == 422
|
||||
|
||||
async def test_by_country_origin_blocklist_forwarded(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""``?origin=blocklist`` is passed to ``ban_service.bans_by_country``."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
await dashboard_client.get(
|
||||
"/api/dashboard/bans/by-country?origin=blocklist"
|
||||
)
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") == "blocklist"
|
||||
|
||||
async def test_by_country_no_origin_defaults_to_none(
|
||||
self, dashboard_client: AsyncClient
|
||||
) -> None:
|
||||
"""Omitting ``origin`` passes ``None`` to ``bans_by_country``."""
|
||||
mock_fn = AsyncMock(return_value=_make_bans_by_country_response())
|
||||
with patch(
|
||||
"app.routers.dashboard.ban_service.bans_by_country", new=mock_fn
|
||||
):
|
||||
await dashboard_client.get("/api/dashboard/bans/by-country")
|
||||
|
||||
_, kwargs = mock_fn.call_args
|
||||
assert kwargs.get("origin") is None
|
||||
|
||||
@@ -102,6 +102,39 @@ async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
"""Return a database with bans from both blocklist-import and organic jails."""
|
||||
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
|
||||
await _create_f2b_db(
|
||||
path,
|
||||
[
|
||||
{
|
||||
"jail": "blocklist-import",
|
||||
"ip": "10.0.0.1",
|
||||
"timeofban": _ONE_HOUR_AGO,
|
||||
"bantime": -1,
|
||||
"bancount": 1,
|
||||
},
|
||||
{
|
||||
"jail": "sshd",
|
||||
"ip": "10.0.0.2",
|
||||
"timeofban": _ONE_HOUR_AGO,
|
||||
"bantime": 3600,
|
||||
"bancount": 3,
|
||||
},
|
||||
{
|
||||
"jail": "nginx",
|
||||
"ip": "10.0.0.3",
|
||||
"timeofban": _ONE_HOUR_AGO,
|
||||
"bantime": 7200,
|
||||
"bancount": 1,
|
||||
},
|
||||
],
|
||||
)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
|
||||
"""Return the path to a fail2ban SQLite database with no ban records."""
|
||||
@@ -299,3 +332,183 @@ class TestListBansPagination:
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
|
||||
assert result.total == 3 # All three bans are within 7d.
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_bans / bans_by_country — origin derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBanOriginDerivation:
|
||||
"""Verify that ban_service correctly derives ``origin`` from jail names."""
|
||||
|
||||
async def test_blocklist_import_jail_yields_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
blocklist_items = [i for i in result.items if i.jail == "blocklist-import"]
|
||||
assert len(blocklist_items) == 1
|
||||
assert blocklist_items[0].origin == "blocklist"
|
||||
|
||||
async def test_organic_jail_yields_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
organic_items = [i for i in result.items if i.jail != "blocklist-import"]
|
||||
assert len(organic_items) == 2
|
||||
for item in organic_items:
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_all_items_carry_origin_field(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""Every returned item has an ``origin`` field with a valid value."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
|
||||
for item in result.items:
|
||||
assert item.origin in ("blocklist", "selfblock")
|
||||
|
||||
async def test_bans_by_country_blocklist_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
blocklist_bans = [b for b in result.bans if b.jail == "blocklist-import"]
|
||||
assert len(blocklist_bans) == 1
|
||||
assert blocklist_bans[0].origin == "blocklist"
|
||||
|
||||
async def test_bans_by_country_selfblock_origin(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` derives origin correctly for organic jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
|
||||
organic_bans = [b for b in result.bans if b.jail != "blocklist-import"]
|
||||
assert len(organic_bans) == 2
|
||||
for ban in organic_bans:
|
||||
assert ban.origin == "selfblock"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_bans / bans_by_country — origin filter parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOriginFilter:
|
||||
"""Verify that the origin filter correctly restricts results."""
|
||||
|
||||
async def test_list_bans_blocklist_filter_returns_only_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
assert len(result.items) == 1
|
||||
assert result.items[0].jail == "blocklist-import"
|
||||
assert result.items[0].origin == "blocklist"
|
||||
|
||||
async def test_list_bans_selfblock_filter_excludes_blocklist(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert len(result.items) == 2
|
||||
for item in result.items:
|
||||
assert item.jail != "blocklist-import"
|
||||
assert item.origin == "selfblock"
|
||||
|
||||
async def test_list_bans_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``origin=None`` applies no jail restriction — all bans returned."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
async def test_bans_by_country_blocklist_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="blocklist"
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
assert all(b.jail == "blocklist-import" for b in result.bans)
|
||||
|
||||
async def test_bans_by_country_selfblock_filter(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin="selfblock"
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
assert all(b.jail != "blocklist-import" for b in result.bans)
|
||||
|
||||
async def test_bans_by_country_no_filter_returns_all(
|
||||
self, mixed_origin_db_path: str
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", origin=None
|
||||
)
|
||||
|
||||
assert result.total == 3
|
||||
|
||||
@@ -98,6 +98,23 @@ class TestRunSetup:
|
||||
with pytest.raises(RuntimeError, match="already been completed"):
|
||||
await setup_service.run_setup(db, **kwargs) # type: ignore[arg-type]
|
||||
|
||||
async def test_initializes_map_color_thresholds_with_defaults(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""run_setup() initializes map color thresholds with default values."""
|
||||
await setup_service.run_setup(
|
||||
db,
|
||||
master_password="mypassword1",
|
||||
database_path="bangui.db",
|
||||
fail2ban_socket="/var/run/fail2ban/fail2ban.sock",
|
||||
timezone="UTC",
|
||||
session_duration_minutes=60,
|
||||
)
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
assert high == 100
|
||||
assert medium == 50
|
||||
assert low == 20
|
||||
|
||||
|
||||
class TestGetTimezone:
|
||||
async def test_returns_utc_on_fresh_db(self, db: aiosqlite.Connection) -> None:
|
||||
@@ -119,6 +136,74 @@ class TestGetTimezone:
|
||||
assert await setup_service.get_timezone(db) == "America/New_York"
|
||||
|
||||
|
||||
class TestMapColorThresholds:
|
||||
async def test_get_map_color_thresholds_returns_defaults_on_fresh_db(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""get_map_color_thresholds() returns default values on a fresh database."""
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
assert high == 100
|
||||
assert medium == 50
|
||||
assert low == 20
|
||||
|
||||
async def test_set_map_color_thresholds_persists_values(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""set_map_color_thresholds() stores and retrieves custom values."""
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=200, threshold_medium=80, threshold_low=30
|
||||
)
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
assert high == 200
|
||||
assert medium == 80
|
||||
assert low == 30
|
||||
|
||||
async def test_set_map_color_thresholds_rejects_non_positive(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""set_map_color_thresholds() raises ValueError for non-positive thresholds."""
|
||||
with pytest.raises(ValueError, match="positive integers"):
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=100, threshold_medium=50, threshold_low=0
|
||||
)
|
||||
with pytest.raises(ValueError, match="positive integers"):
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=-10, threshold_medium=50, threshold_low=20
|
||||
)
|
||||
|
||||
async def test_set_map_color_thresholds_rejects_invalid_order(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""
|
||||
set_map_color_thresholds() rejects invalid ordering.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="high > medium > low"):
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=50, threshold_medium=50, threshold_low=20
|
||||
)
|
||||
with pytest.raises(ValueError, match="high > medium > low"):
|
||||
await setup_service.set_map_color_thresholds(
|
||||
db, threshold_high=100, threshold_medium=30, threshold_low=50
|
||||
)
|
||||
|
||||
async def test_run_setup_initializes_default_thresholds(
|
||||
self, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""run_setup() initializes map color thresholds with defaults."""
|
||||
await setup_service.run_setup(
|
||||
db,
|
||||
master_password="mypassword1",
|
||||
database_path="bangui.db",
|
||||
fail2ban_socket="/var/run/fail2ban/fail2ban.sock",
|
||||
timezone="UTC",
|
||||
session_duration_minutes=60,
|
||||
)
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(db)
|
||||
assert high == 100
|
||||
assert medium == 50
|
||||
assert low == 20
|
||||
|
||||
|
||||
class TestRunSetupAsync:
|
||||
"""Verify the async/non-blocking bcrypt behavior of run_setup."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user