refactor: complete Task 2/3 geo decouple + exceptions centralization; mark as done
This commit is contained in:
23
backend/app/exceptions.py
Normal file
23
backend/app/exceptions.py
Normal file
@@ -0,0 +1,23 @@
|
||||
"""Shared domain exception classes used across routers and services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class JailNotFoundError(Exception):
|
||||
"""Raised when a requested jail name does not exist."""
|
||||
|
||||
|
||||
class JailOperationError(Exception):
|
||||
"""Raised when a fail2ban jail operation fails."""
|
||||
|
||||
|
||||
class ConfigValidationError(Exception):
|
||||
"""Raised when config values fail validation before applying."""
|
||||
|
||||
|
||||
class ConfigOperationError(Exception):
|
||||
"""Raised when a config payload update or command fails."""
|
||||
|
||||
|
||||
class ServerOperationError(Exception):
|
||||
"""Raised when a server control command (e.g. refresh) fails."""
|
||||
@@ -3,8 +3,18 @@
|
||||
Response models for the ``GET /api/geo/lookup/{ip}`` endpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
|
||||
class GeoDetail(BaseModel):
|
||||
"""Enriched geolocation data for an IP address.
|
||||
@@ -64,3 +74,26 @@ class IpLookupResponse(BaseModel):
|
||||
default=None,
|
||||
description="Enriched geographical and network information.",
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# shared service types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeoInfo:
|
||||
"""Geo resolution result used throughout backend services."""
|
||||
|
||||
country_code: str | None
|
||||
country_name: str | None
|
||||
asn: str | None
|
||||
org: str | None
|
||||
|
||||
|
||||
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
GeoBatchLookup = Callable[
|
||||
[list[str], "aiohttp.ClientSession", "aiosqlite.Connection | None"],
|
||||
Awaitable[dict[str, GeoInfo]],
|
||||
]
|
||||
GeoCacheLookup = Callable[[list[str]], tuple[dict[str, GeoInfo], list[str]]]
|
||||
|
||||
@@ -20,8 +20,8 @@ from fastapi import APIRouter, HTTPException, Request, status
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.ban import ActiveBanListResponse, BanRequest, UnbanAllResponse, UnbanRequest
|
||||
from app.models.jail import JailCommandResponse
|
||||
from app.services import jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
from app.services import geo_service, jail_service
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/bans", tags=["Bans"])
|
||||
@@ -73,6 +73,7 @@ async def get_active_bans(
|
||||
try:
|
||||
return await jail_service.get_active_bans(
|
||||
socket_path,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
http_session=http_session,
|
||||
app_db=app_db,
|
||||
)
|
||||
|
||||
@@ -42,7 +42,7 @@ from app.models.blocklist import (
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.services import blocklist_service
|
||||
from app.services import blocklist_service, geo_service
|
||||
from app.tasks import blocklist_import as blocklist_import_task
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/blocklists", tags=["Blocklists"])
|
||||
@@ -131,7 +131,13 @@ async def run_import_now(
|
||||
"""
|
||||
http_session: aiohttp.ClientSession = request.app.state.http_session
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
return await blocklist_service.import_all(db, http_session, socket_path)
|
||||
return await blocklist_service.import_all(
|
||||
db,
|
||||
http_session,
|
||||
socket_path,
|
||||
geo_is_cached=geo_service.is_cached,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
|
||||
@@ -93,12 +93,7 @@ from app.services.config_file_service import (
|
||||
JailNameError,
|
||||
JailNotFoundInConfigError,
|
||||
)
|
||||
from app.services.config_service import (
|
||||
ConfigOperationError,
|
||||
ConfigValidationError,
|
||||
JailNotFoundError,
|
||||
)
|
||||
from app.services.jail_service import JailOperationError
|
||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError, JailOperationError
|
||||
from app.tasks.health_check import _run_probe
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ from app.models.ban import (
|
||||
TimeRange,
|
||||
)
|
||||
from app.models.server import ServerStatus, ServerStatusResponse
|
||||
from app.services import ban_service
|
||||
from app.services import ban_service, geo_service
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/dashboard", tags=["Dashboard"])
|
||||
|
||||
@@ -120,6 +120,7 @@ async def get_dashboard_bans(
|
||||
page_size=page_size,
|
||||
http_session=http_session,
|
||||
app_db=None,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
@@ -163,6 +164,8 @@ async def get_bans_by_country(
|
||||
socket_path,
|
||||
range,
|
||||
http_session=http_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
app_db=None,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
@@ -19,9 +19,8 @@ import aiosqlite
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
|
||||
|
||||
from app.dependencies import AuthDep, get_db
|
||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
|
||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, IpLookupResponse
|
||||
from app.services import geo_service, jail_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"])
|
||||
|
||||
@@ -31,8 +31,8 @@ from app.models.jail import (
|
||||
JailDetailResponse,
|
||||
JailListResponse,
|
||||
)
|
||||
from app.services import jail_service
|
||||
from app.services.jail_service import JailNotFoundError, JailOperationError
|
||||
from app.services import geo_service, jail_service
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/jails", tags=["Jails"])
|
||||
@@ -606,6 +606,7 @@ async def get_jail_banned_ips(
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
search=search,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
http_session=http_session,
|
||||
app_db=app_db,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ from fastapi import APIRouter, HTTPException, Request, status
|
||||
from app.dependencies import AuthDep
|
||||
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.services import server_service
|
||||
from app.services.server_service import ServerOperationError
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/server", tags=["Server"])
|
||||
|
||||
@@ -11,11 +11,8 @@ so BanGUI never modifies or locks the fail2ban database.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -39,18 +36,16 @@ from app.models.ban import (
|
||||
JailBanCount as JailBanCountModel,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanResponse
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -102,98 +97,6 @@ def _since_unix(range_: TimeRange) -> int:
|
||||
return int(time.time()) - seconds
|
||||
|
||||
|
||||
def _ts_to_iso(unix_ts: int) -> str:
|
||||
"""Convert a Unix timestamp to an ISO 8601 UTC string.
|
||||
|
||||
Args:
|
||||
unix_ts: Seconds since the Unix epoch.
|
||||
|
||||
Returns:
|
||||
ISO 8601 UTC timestamp, e.g. ``"2026-03-01T12:00:00+00:00"``.
|
||||
"""
|
||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||
|
||||
|
||||
async def _get_fail2ban_db_path(socket_path: str) -> str:
|
||||
"""Query fail2ban for the path to its SQLite database.
|
||||
|
||||
Sends the ``get dbfile`` command via the fail2ban socket and returns
|
||||
the value of the ``dbfile`` setting.
|
||||
|
||||
Args:
|
||||
socket_path: Path to the fail2ban Unix domain socket.
|
||||
|
||||
Returns:
|
||||
Absolute path to the fail2ban SQLite database file.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If fail2ban reports that no database is configured
|
||||
or if the socket response is unexpected.
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
async with Fail2BanClient(socket_path, timeout=_SOCKET_TIMEOUT) as client:
|
||||
response = await client.send(["get", "dbfile"])
|
||||
|
||||
try:
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
|
||||
|
||||
if code != 0:
|
||||
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
|
||||
|
||||
return str(data)
|
||||
|
||||
|
||||
def _parse_data_json(raw: object) -> tuple[list[str], int]:
|
||||
"""Extract matches and failure count from the ``bans.data`` column.
|
||||
|
||||
The ``data`` column stores a JSON blob with optional keys:
|
||||
|
||||
* ``matches`` — list of raw matched log lines.
|
||||
* ``failures`` — total failure count that triggered the ban.
|
||||
|
||||
Args:
|
||||
raw: The raw ``data`` column value (string, dict, or ``None``).
|
||||
|
||||
Returns:
|
||||
A ``(matches, failures)`` tuple. Both default to empty/zero when
|
||||
parsing fails or the column is absent.
|
||||
"""
|
||||
if raw is None:
|
||||
return [], 0
|
||||
|
||||
obj: dict[str, object] = {}
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed: object = json.loads(raw)
|
||||
if isinstance(parsed, dict):
|
||||
obj = parsed
|
||||
# json.loads("null") → None, or other non-dict — treat as empty
|
||||
except json.JSONDecodeError:
|
||||
return [], 0
|
||||
elif isinstance(raw, dict):
|
||||
obj = raw
|
||||
|
||||
raw_matches = obj.get("matches")
|
||||
if isinstance(raw_matches, list):
|
||||
matches: list[str] = [str(m) for m in raw_matches]
|
||||
else:
|
||||
matches = []
|
||||
|
||||
raw_failures = obj.get("failures")
|
||||
failures: int = 0
|
||||
if isinstance(raw_failures, (int, float, str)):
|
||||
try:
|
||||
failures = int(raw_failures)
|
||||
except (ValueError, TypeError):
|
||||
failures = 0
|
||||
|
||||
return matches, failures
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -209,6 +112,7 @@ async def list_bans(
|
||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> DashboardBanListResponse:
|
||||
@@ -248,14 +152,13 @@ async def list_bans(
|
||||
:class:`~app.models.ban.DashboardBanListResponse` containing the
|
||||
paginated items and total count.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
effective_page_size: int = min(page_size, _MAX_PAGE_SIZE)
|
||||
offset: int = (page - 1) * effective_page_size
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_list_bans",
|
||||
db_path=db_path,
|
||||
@@ -276,10 +179,10 @@ async def list_bans(
|
||||
# This avoids hitting the 45 req/min single-IP rate limit when the
|
||||
# page contains many bans (e.g. after a large blocklist import).
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
if http_session is not None and rows:
|
||||
if http_session is not None and rows and geo_batch_lookup is not None:
|
||||
page_ips: list[str] = [r.ip for r in rows]
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
|
||||
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("ban_service_batch_geo_failed_list_bans")
|
||||
|
||||
@@ -287,9 +190,9 @@ async def list_bans(
|
||||
for row in rows:
|
||||
jail: str = row.jail
|
||||
ip: str = row.ip
|
||||
banned_at: str = _ts_to_iso(row.timeofban)
|
||||
banned_at: str = ts_to_iso(row.timeofban)
|
||||
ban_count: int = row.bancount
|
||||
matches, _ = _parse_data_json(row.data)
|
||||
matches, _ = parse_data_json(row.data)
|
||||
service: str | None = matches[0] if matches else None
|
||||
|
||||
country_code: str | None = None
|
||||
@@ -350,6 +253,8 @@ async def bans_by_country(
|
||||
socket_path: str,
|
||||
range_: TimeRange,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
geo_cache_lookup: GeoCacheLookup | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
@@ -389,11 +294,10 @@ async def bans_by_country(
|
||||
:class:`~app.models.ban.BansByCountryResponse` with per-country
|
||||
aggregation and the companion ban list.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_bans_by_country",
|
||||
db_path=db_path,
|
||||
@@ -429,23 +333,24 @@ async def bans_by_country(
|
||||
unique_ips: list[str] = [r.ip for r in agg_rows]
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
|
||||
if http_session is not None and unique_ips:
|
||||
if http_session is not None and unique_ips and geo_cache_lookup is not None:
|
||||
# Serve only what is already in the in-memory cache — no API calls on
|
||||
# the hot path. Uncached IPs are resolved asynchronously in the
|
||||
# background so subsequent requests benefit from a warmer cache.
|
||||
geo_map, uncached = geo_service.lookup_cached_only(unique_ips)
|
||||
geo_map, uncached = geo_cache_lookup(unique_ips)
|
||||
if uncached:
|
||||
log.info(
|
||||
"ban_service_geo_background_scheduled",
|
||||
uncached=len(uncached),
|
||||
cached=len(geo_map),
|
||||
)
|
||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||
# The dirty-set flush task persists results to the DB.
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
geo_service.lookup_batch(uncached, http_session, db=app_db),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
if geo_batch_lookup is not None:
|
||||
# Fire-and-forget: lookup_batch handles rate-limiting / retries.
|
||||
# The dirty-set flush task persists results to the DB.
|
||||
asyncio.create_task( # noqa: RUF006
|
||||
geo_batch_lookup(uncached, http_session, db=app_db),
|
||||
name="geo_bans_by_country",
|
||||
)
|
||||
elif geo_enricher is not None and unique_ips:
|
||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
|
||||
@@ -483,13 +388,13 @@ async def bans_by_country(
|
||||
cn = geo.country_name if geo else None
|
||||
asn: str | None = geo.asn if geo else None
|
||||
org: str | None = geo.org if geo else None
|
||||
matches, _ = _parse_data_json(companion_row.data)
|
||||
matches, _ = parse_data_json(companion_row.data)
|
||||
|
||||
bans.append(
|
||||
DashboardBanItem(
|
||||
ip=ip,
|
||||
jail=companion_row.jail,
|
||||
banned_at=_ts_to_iso(companion_row.timeofban),
|
||||
banned_at=ts_to_iso(companion_row.timeofban),
|
||||
service=matches[0] if matches else None,
|
||||
country_code=cc,
|
||||
country_name=cn,
|
||||
@@ -550,7 +455,7 @@ async def ban_trend(
|
||||
num_buckets: int = bucket_count(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"ban_service_ban_trend",
|
||||
db_path=db_path,
|
||||
@@ -571,7 +476,7 @@ async def ban_trend(
|
||||
|
||||
buckets: list[BanTrendBucket] = [
|
||||
BanTrendBucket(
|
||||
timestamp=_ts_to_iso(since + i * bucket_secs),
|
||||
timestamp=ts_to_iso(since + i * bucket_secs),
|
||||
count=counts[i],
|
||||
)
|
||||
for i in range(num_buckets)
|
||||
@@ -615,12 +520,12 @@ async def bans_by_jail(
|
||||
since: int = _since_unix(range_)
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.debug(
|
||||
"ban_service_bans_by_jail",
|
||||
db_path=db_path,
|
||||
since=since,
|
||||
since_iso=_ts_to_iso(since),
|
||||
since_iso=ts_to_iso(since),
|
||||
range=range_,
|
||||
origin=origin,
|
||||
)
|
||||
|
||||
@@ -33,9 +33,13 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: Settings key used to persist the schedule config.
|
||||
@@ -238,6 +242,8 @@ async def import_source(
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
db: aiosqlite.Connection,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
) -> ImportSourceResult:
|
||||
"""Download and apply bans from a single blocklist source.
|
||||
|
||||
@@ -339,12 +345,8 @@ async def import_source(
|
||||
)
|
||||
|
||||
# --- Pre-warm geo cache for newly imported IPs ---
|
||||
if imported_ips:
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
uncached_ips: list[str] = [
|
||||
ip for ip in imported_ips if not geo_service.is_cached(ip)
|
||||
]
|
||||
if imported_ips and geo_is_cached is not None:
|
||||
uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)]
|
||||
skipped_geo: int = len(imported_ips) - len(uncached_ips)
|
||||
|
||||
if skipped_geo > 0:
|
||||
@@ -355,9 +357,9 @@ async def import_source(
|
||||
to_lookup=len(uncached_ips),
|
||||
)
|
||||
|
||||
if uncached_ips:
|
||||
if uncached_ips and geo_batch_lookup is not None:
|
||||
try:
|
||||
await geo_service.lookup_batch(uncached_ips, http_session, db=db)
|
||||
await geo_batch_lookup(uncached_ips, http_session, db=db)
|
||||
log.info(
|
||||
"blocklist_geo_prewarm_complete",
|
||||
source_id=source.id,
|
||||
@@ -383,6 +385,8 @@ async def import_all(
|
||||
db: aiosqlite.Connection,
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
) -> ImportRunResult:
|
||||
"""Import all enabled blocklist sources.
|
||||
|
||||
@@ -406,7 +410,14 @@ async def import_all(
|
||||
|
||||
for row in sources:
|
||||
source = _row_to_source(row)
|
||||
result = await import_source(source, http_session, socket_path, db)
|
||||
result = await import_source(
|
||||
source,
|
||||
http_session,
|
||||
socket_path,
|
||||
db,
|
||||
geo_is_cached=geo_is_cached,
|
||||
geo_batch_lookup=geo_batch_lookup,
|
||||
)
|
||||
results.append(result)
|
||||
total_imported += result.ips_imported
|
||||
total_skipped += result.ips_skipped
|
||||
|
||||
@@ -54,8 +54,8 @@ from app.models.config import (
|
||||
JailValidationResult,
|
||||
RollbackResponse,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.services import jail_service
|
||||
from app.services.jail_service import JailNotFoundError as JailNotFoundError
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
|
||||
@@ -44,6 +44,7 @@ from app.models.config import (
|
||||
RegexTestResponse,
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.exceptions import ConfigOperationError, ConfigValidationError, JailNotFoundError
|
||||
from app.services import setup_service
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
|
||||
@@ -55,26 +56,7 @@ _SOCKET_TIMEOUT: float = 10.0
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNotFoundError(Exception):
|
||||
"""Raised when a requested jail name does not exist in fail2ban."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name that was not found.
|
||||
|
||||
Args:
|
||||
name: The jail name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail not found: {name!r}")
|
||||
|
||||
|
||||
class ConfigValidationError(Exception):
|
||||
"""Raised when a configuration value fails validation before writing."""
|
||||
|
||||
|
||||
class ConfigOperationError(Exception):
|
||||
"""Raised when a configuration write command fails."""
|
||||
# (exceptions are now defined in app.exceptions and imported above)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -41,13 +41,12 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.repositories import geo_cache_repo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -94,40 +93,6 @@ _BATCH_DELAY: float = 1.5
|
||||
#: transient error (e.g. connection reset due to rate limiting).
|
||||
_BATCH_MAX_RETRIES: int = 2
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Domain model
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class GeoInfo:
|
||||
"""Geographical and network metadata for a single IP address.
|
||||
|
||||
All fields default to ``None`` when the information is unavailable or
|
||||
the lookup fails gracefully.
|
||||
"""
|
||||
|
||||
country_code: str | None
|
||||
"""ISO 3166-1 alpha-2 country code, e.g. ``"DE"``."""
|
||||
|
||||
country_name: str | None
|
||||
"""Human-readable country name, e.g. ``"Germany"``."""
|
||||
|
||||
asn: str | None
|
||||
"""Autonomous System Number string, e.g. ``"AS3320"``."""
|
||||
|
||||
org: str | None
|
||||
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
|
||||
|
||||
|
||||
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
"""Async callable used to enrich IPs with :class:`~app.services.geo_service.GeoInfo`.
|
||||
|
||||
This is a shared type alias used by services that optionally accept a geo
|
||||
lookup callable (for example, :mod:`app.services.history_service`).
|
||||
"""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal cache
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING
|
||||
import structlog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.services.geo_service import GeoEnricher
|
||||
from app.models.geo import GeoEnricher
|
||||
|
||||
from app.models.ban import TIME_RANGE_SECONDS, TimeRange
|
||||
from app.models.history import (
|
||||
@@ -26,7 +26,7 @@ from app.models.history import (
|
||||
IpTimelineEvent,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -93,7 +93,7 @@ async def list_history(
|
||||
if range_ is not None:
|
||||
since = _since_unix(range_)
|
||||
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info(
|
||||
"history_service_list",
|
||||
db_path=db_path,
|
||||
@@ -116,9 +116,9 @@ async def list_history(
|
||||
for row in rows:
|
||||
jail_name: str = row.jail
|
||||
ip: str = row.ip
|
||||
banned_at: str = _ts_to_iso(row.timeofban)
|
||||
banned_at: str = ts_to_iso(row.timeofban)
|
||||
ban_count: int = row.bancount
|
||||
matches, failures = _parse_data_json(row.data)
|
||||
matches, failures = parse_data_json(row.data)
|
||||
|
||||
country_code: str | None = None
|
||||
country_name: str | None = None
|
||||
@@ -180,7 +180,7 @@ async def get_ip_detail(
|
||||
:class:`~app.models.history.IpDetailResponse` if any records exist
|
||||
for *ip*, or ``None`` if the IP has no history in the database.
|
||||
"""
|
||||
db_path: str = await _get_fail2ban_db_path(socket_path)
|
||||
db_path: str = await get_fail2ban_db_path(socket_path)
|
||||
log.info("history_service_ip_detail", db_path=db_path, ip=ip)
|
||||
|
||||
rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip)
|
||||
@@ -193,9 +193,9 @@ async def get_ip_detail(
|
||||
|
||||
for row in rows:
|
||||
jail_name: str = row.jail
|
||||
banned_at: str = _ts_to_iso(row.timeofban)
|
||||
banned_at: str = ts_to_iso(row.timeofban)
|
||||
ban_count: int = row.bancount
|
||||
matches, failures = _parse_data_json(row.data)
|
||||
matches, failures = parse_data_json(row.data)
|
||||
total_failures += failures
|
||||
timeline.append(
|
||||
IpTimelineEvent(
|
||||
|
||||
@@ -14,11 +14,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import ipaddress
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import JailNotFoundError, JailOperationError
|
||||
from app.models.ban import ActiveBan, ActiveBanListResponse, JailBannedIpsResponse
|
||||
from app.models.config import BantimeEscalation
|
||||
from app.models.jail import (
|
||||
@@ -28,7 +28,6 @@ from app.models.jail import (
|
||||
JailStatus,
|
||||
JailSummary,
|
||||
)
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanCommand,
|
||||
@@ -38,9 +37,13 @@ from app.utils.fail2ban_client import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Awaitable
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup, GeoEnricher, GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
class IpLookupResult(TypedDict):
|
||||
@@ -55,8 +58,6 @@ class IpLookupResult(TypedDict):
|
||||
geo: GeoInfo | None
|
||||
|
||||
|
||||
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -81,23 +82,6 @@ _backend_cmd_lock: asyncio.Lock = asyncio.Lock()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JailNotFoundError(Exception):
|
||||
"""Raised when a requested jail name does not exist in fail2ban."""
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
"""Initialise with the jail name that was not found.
|
||||
|
||||
Args:
|
||||
name: The jail name that could not be located.
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(f"Jail not found: {name!r}")
|
||||
|
||||
|
||||
class JailOperationError(Exception):
|
||||
"""Raised when a jail control command fails for a non-auth reason."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -820,6 +804,7 @@ async def unban_ip(
|
||||
|
||||
async def get_active_bans(
|
||||
socket_path: str,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
@@ -857,7 +842,6 @@ async def get_active_bans(
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket
|
||||
cannot be reached.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
@@ -905,10 +889,10 @@ async def get_active_bans(
|
||||
bans.append(ban)
|
||||
|
||||
# Enrich with geo data — prefer batch lookup over per-IP enricher.
|
||||
if http_session is not None and bans:
|
||||
if http_session is not None and bans and geo_batch_lookup is not None:
|
||||
all_ips: list[str] = [ban.ip for ban in bans]
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(all_ips, http_session, db=app_db)
|
||||
geo_map = await geo_batch_lookup(all_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("active_bans_batch_geo_failed")
|
||||
geo_map = {}
|
||||
@@ -1017,6 +1001,7 @@ async def get_jail_banned_ips(
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
search: str | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> JailBannedIpsResponse:
|
||||
@@ -1044,8 +1029,6 @@ async def get_jail_banned_ips(
|
||||
~app.utils.fail2ban_client.Fail2BanConnectionError: If the socket is
|
||||
unreachable.
|
||||
"""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
# Clamp page_size to the allowed maximum.
|
||||
page_size = min(page_size, _MAX_PAGE_SIZE)
|
||||
|
||||
@@ -1086,10 +1069,10 @@ async def get_jail_banned_ips(
|
||||
page_bans = all_bans[start : start + page_size]
|
||||
|
||||
# Geo-enrich only the page slice.
|
||||
if http_session is not None and page_bans:
|
||||
if http_session is not None and page_bans and geo_batch_lookup is not None:
|
||||
page_ips = [b.ip for b in page_bans]
|
||||
try:
|
||||
geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db)
|
||||
geo_map = await geo_batch_lookup(page_ips, http_session, db=app_db)
|
||||
except Exception: # noqa: BLE001
|
||||
log.warning("jail_banned_ips_geo_failed", jail=jail_name)
|
||||
geo_map = {}
|
||||
|
||||
@@ -14,6 +14,8 @@ from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.exceptions import ServerOperationError
|
||||
from app.models.server import ServerSettings, ServerSettingsResponse, ServerSettingsUpdate
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanResponse
|
||||
|
||||
@@ -54,15 +56,6 @@ def _to_str(value: object | None, default: str) -> str:
|
||||
return str(value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Custom exceptions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ServerOperationError(Exception):
|
||||
"""Raised when a server-level set command fails."""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
63
backend/app/utils/fail2ban_db_utils.py
Normal file
63
backend/app/utils/fail2ban_db_utils.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""Utilities shared by fail2ban-related services."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
def ts_to_iso(unix_ts: int) -> str:
|
||||
"""Convert a Unix timestamp to an ISO 8601 UTC string."""
|
||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||
|
||||
|
||||
async def get_fail2ban_db_path(socket_path: str) -> str:
|
||||
"""Query fail2ban for the path to its SQLite database file."""
|
||||
from app.utils.fail2ban_client import Fail2BanClient # pragma: no cover
|
||||
|
||||
socket_timeout: float = 5.0
|
||||
|
||||
async with Fail2BanClient(socket_path, timeout=socket_timeout) as client:
|
||||
response = await client.send(["get", "dbfile"])
|
||||
|
||||
if not isinstance(response, tuple) or len(response) != 2:
|
||||
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}")
|
||||
|
||||
code, data = response
|
||||
if code != 0:
|
||||
raise RuntimeError(f"fail2ban error code {code}: {data!r}")
|
||||
|
||||
if data is None:
|
||||
raise RuntimeError("fail2ban has no database configured (dbfile is None)")
|
||||
|
||||
return str(data)
|
||||
|
||||
|
||||
def parse_data_json(raw: object) -> tuple[list[str], int]:
|
||||
"""Extract matches and failure count from the fail2ban bans.data value."""
|
||||
if raw is None:
|
||||
return [], 0
|
||||
|
||||
obj: dict[str, object] = {}
|
||||
if isinstance(raw, str):
|
||||
try:
|
||||
parsed = json.loads(raw)
|
||||
if isinstance(parsed, dict):
|
||||
obj = parsed
|
||||
except json.JSONDecodeError:
|
||||
return [], 0
|
||||
elif isinstance(raw, dict):
|
||||
obj = raw
|
||||
|
||||
raw_matches = obj.get("matches")
|
||||
matches = [str(m) for m in raw_matches] if isinstance(raw_matches, list) else []
|
||||
|
||||
raw_failures = obj.get("failures")
|
||||
failures = 0
|
||||
if isinstance(raw_failures, (int, float, str)):
|
||||
try:
|
||||
failures = int(raw_failures)
|
||||
except (ValueError, TypeError):
|
||||
failures = 0
|
||||
|
||||
return matches, failures
|
||||
@@ -12,7 +12,7 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
|
||||
@@ -154,7 +154,7 @@ class TestListBansHappyPath:
|
||||
async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
|
||||
"""Only bans within the selected range are returned."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -166,7 +166,7 @@ class TestListBansHappyPath:
|
||||
async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
|
||||
"""Items are ordered by ``banned_at`` descending (newest first)."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -177,7 +177,7 @@ class TestListBansHappyPath:
|
||||
async def test_ban_fields_present(self, f2b_db_path: str) -> None:
|
||||
"""Each item contains ip, jail, banned_at, ban_count."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -191,7 +191,7 @@ class TestListBansHappyPath:
|
||||
async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
|
||||
"""``service`` field is the first element of ``data.matches``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -203,7 +203,7 @@ class TestListBansHappyPath:
|
||||
async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
|
||||
"""``service`` is ``None`` when the ban has no stored matches."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
# Use 7d to include the older ban with no matches.
|
||||
@@ -215,7 +215,7 @@ class TestListBansHappyPath:
|
||||
async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
|
||||
"""When no bans exist the result has total=0 and no items."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -226,7 +226,7 @@ class TestListBansHappyPath:
|
||||
async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
|
||||
"""The ``365d`` range includes bans that are 2 days old."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "365d")
|
||||
@@ -246,7 +246,7 @@ class TestListBansGeoEnrichment:
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""Geo fields are populated when an enricher returns data."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
async def fake_enricher(ip: str) -> GeoInfo:
|
||||
return GeoInfo(
|
||||
@@ -257,7 +257,7 @@ class TestListBansGeoEnrichment:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -278,7 +278,7 @@ class TestListBansGeoEnrichment:
|
||||
raise RuntimeError("geo service down")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -304,25 +304,27 @@ class TestListBansBatchGeoEnrichment:
|
||||
"""Geo fields are populated via lookup_batch when http_session is given."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_geo_map = {
|
||||
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="Deutsche Telekom"),
|
||||
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
), patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(return_value=fake_geo_map),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", http_session=fake_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
)
|
||||
|
||||
fake_geo_batch.assert_awaited_once_with(["1.2.3.4", "5.6.7.8"], fake_session, db=None)
|
||||
assert result.total == 2
|
||||
de_item = next(i for i in result.items if i.ip == "1.2.3.4")
|
||||
us_item = next(i for i in result.items if i.ip == "5.6.7.8")
|
||||
@@ -339,15 +341,17 @@ class TestListBansBatchGeoEnrichment:
|
||||
|
||||
fake_session = MagicMock()
|
||||
|
||||
failing_geo_batch = AsyncMock(side_effect=RuntimeError("batch geo down"))
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
), patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(side_effect=RuntimeError("batch geo down")),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock", "24h", http_session=fake_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=failing_geo_batch,
|
||||
)
|
||||
|
||||
assert result.total == 2
|
||||
@@ -360,28 +364,27 @@ class TestListBansBatchGeoEnrichment:
|
||||
"""When both http_session and geo_enricher are provided, batch wins."""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
fake_session = MagicMock()
|
||||
fake_geo_map = {
|
||||
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||
"5.6.7.8": GeoInfo(country_code="DE", country_name="Germany", asn=None, org=None),
|
||||
}
|
||||
fake_geo_batch = AsyncMock(return_value=fake_geo_map)
|
||||
|
||||
async def enricher_should_not_be_called(ip: str) -> GeoInfo:
|
||||
raise AssertionError(f"geo_enricher was called for {ip!r} — should not happen")
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
), patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(return_value=fake_geo_map),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=fake_session,
|
||||
geo_batch_lookup=fake_geo_batch,
|
||||
geo_enricher=enricher_should_not_be_called,
|
||||
)
|
||||
|
||||
@@ -401,7 +404,7 @@ class TestListBansPagination:
|
||||
async def test_page_size_respected(self, f2b_db_path: str) -> None:
|
||||
"""``page_size=1`` returns at most one item."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
@@ -412,7 +415,7 @@ class TestListBansPagination:
|
||||
async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
|
||||
"""The second page returns items not on the first page."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
|
||||
@@ -426,7 +429,7 @@ class TestListBansPagination:
|
||||
) -> None:
|
||||
"""``total`` reports all matching records regardless of pagination."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)
|
||||
@@ -447,7 +450,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""Bans from ``blocklist-import`` jail carry ``origin == "blocklist"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -461,7 +464,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""Bans from organic jails (sshd, nginx, …) carry ``origin == "selfblock"``."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -476,7 +479,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""Every returned item has an ``origin`` field with a valid value."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h")
|
||||
@@ -489,7 +492,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""``bans_by_country`` also derives origin correctly for blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
@@ -503,7 +506,7 @@ class TestBanOriginDerivation:
|
||||
) -> None:
|
||||
"""``bans_by_country`` derives origin correctly for organic jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country("/fake/sock", "24h")
|
||||
@@ -527,7 +530,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``origin='blocklist'`` returns only blocklist-import jail bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -544,7 +547,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -562,7 +565,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``origin=None`` applies no jail restriction — all bans returned."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans("/fake/sock", "24h", origin=None)
|
||||
@@ -574,7 +577,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='blocklist'`` counts only blocklist bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
@@ -589,7 +592,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin='selfblock'`` excludes blocklist jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
@@ -604,7 +607,7 @@ class TestOriginFilter:
|
||||
) -> None:
|
||||
"""``bans_by_country`` with ``origin=None`` returns all bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
@@ -644,7 +647,7 @@ class TestBansbyCountryBackground:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
@@ -652,8 +655,13 @@ class TestBansbyCountryBackground:
|
||||
) as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", http_session=mock_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
# All countries resolved from cache — no background task needed.
|
||||
@@ -674,7 +682,7 @@ class TestBansbyCountryBackground:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
@@ -682,8 +690,13 @@ class TestBansbyCountryBackground:
|
||||
) as mock_create_task,
|
||||
):
|
||||
mock_session = AsyncMock()
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
result = await ban_service.bans_by_country(
|
||||
"/fake/sock", "24h", http_session=mock_session
|
||||
"/fake/sock",
|
||||
"24h",
|
||||
http_session=mock_session,
|
||||
geo_cache_lookup=geo_service.lookup_cached_only,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
# Background task must have been scheduled for uncached IPs.
|
||||
@@ -701,7 +714,7 @@ class TestBansbyCountryBackground:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
),
|
||||
patch(
|
||||
@@ -727,7 +740,7 @@ class TestBanTrend:
|
||||
async def test_24h_returns_24_buckets(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='24h'`` always yields exactly 24 buckets."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -738,7 +751,7 @@ class TestBanTrend:
|
||||
async def test_7d_returns_28_buckets(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='7d'`` yields 28 six-hour buckets."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "7d")
|
||||
@@ -749,7 +762,7 @@ class TestBanTrend:
|
||||
async def test_30d_returns_30_buckets(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='30d'`` yields 30 daily buckets."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "30d")
|
||||
@@ -760,7 +773,7 @@ class TestBanTrend:
|
||||
async def test_365d_bucket_size_label(self, empty_f2b_db_path: str) -> None:
|
||||
"""``range_='365d'`` uses '7d' as the bucket size label."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "365d")
|
||||
@@ -771,7 +784,7 @@ class TestBanTrend:
|
||||
async def test_empty_db_all_buckets_zero(self, empty_f2b_db_path: str) -> None:
|
||||
"""All bucket counts are zero when the database has no bans."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -781,7 +794,7 @@ class TestBanTrend:
|
||||
async def test_buckets_are_time_ordered(self, empty_f2b_db_path: str) -> None:
|
||||
"""Buckets are ordered chronologically (ascending timestamps)."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "7d")
|
||||
@@ -804,7 +817,7 @@ class TestBanTrend:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -828,7 +841,7 @@ class TestBanTrend:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
@@ -854,7 +867,7 @@ class TestBanTrend:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.ban_trend(
|
||||
@@ -868,7 +881,7 @@ class TestBanTrend:
|
||||
from datetime import datetime
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.ban_trend("/fake/sock", "24h")
|
||||
@@ -904,7 +917,7 @@ class TestBansByJail:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -931,7 +944,7 @@ class TestBansByJail:
|
||||
)
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -942,7 +955,7 @@ class TestBansByJail:
|
||||
async def test_empty_db_returns_empty_list(self, empty_f2b_db_path: str) -> None:
|
||||
"""An empty database returns an empty jails list with total zero."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=empty_f2b_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -954,7 +967,7 @@ class TestBansByJail:
|
||||
"""Bans older than the time window are not counted."""
|
||||
# f2b_db_path has one ban from _TWO_DAYS_AGO, which is outside "24h".
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail("/fake/sock", "24h")
|
||||
@@ -965,7 +978,7 @@ class TestBansByJail:
|
||||
async def test_origin_filter_blocklist(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='blocklist'`` returns only the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
@@ -979,7 +992,7 @@ class TestBansByJail:
|
||||
async def test_origin_filter_selfblock(self, mixed_origin_db_path: str) -> None:
|
||||
"""``origin='selfblock'`` excludes the blocklist-import jail."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
@@ -995,7 +1008,7 @@ class TestBansByJail:
|
||||
) -> None:
|
||||
"""``origin=None`` returns bans from all jails."""
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=mixed_origin_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_jail(
|
||||
@@ -1023,7 +1036,7 @@ class TestBansByJail:
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=path),
|
||||
),
|
||||
patch("app.services.ban_service.log") as mock_log,
|
||||
|
||||
@@ -19,8 +19,8 @@ from unittest.mock import AsyncMock, patch
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import ban_service, geo_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -161,7 +161,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
start = time.perf_counter()
|
||||
@@ -191,7 +191,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
start = time.perf_counter()
|
||||
@@ -217,7 +217,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
result = await ban_service.list_bans(
|
||||
@@ -241,7 +241,7 @@ class TestBanServicePerformance:
|
||||
return geo_service._cache.get(ip) # noqa: SLF001
|
||||
|
||||
with patch(
|
||||
"app.services.ban_service._get_fail2ban_db_path",
|
||||
"app.services.ban_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=perf_db_path),
|
||||
):
|
||||
result = await ban_service.bans_by_country(
|
||||
|
||||
@@ -315,20 +315,15 @@ class TestGeoPrewarmCacheFilter:
|
||||
def _mock_is_cached(ip: str) -> bool:
|
||||
return ip == "1.2.3.4"
|
||||
|
||||
with (
|
||||
patch("app.services.jail_service.ban_ip", new_callable=AsyncMock),
|
||||
patch(
|
||||
"app.services.geo_service.is_cached",
|
||||
side_effect=_mock_is_cached,
|
||||
),
|
||||
patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new_callable=AsyncMock,
|
||||
return_value={},
|
||||
) as mock_batch,
|
||||
):
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
source,
|
||||
session,
|
||||
"/tmp/fake.sock",
|
||||
db,
|
||||
geo_is_cached=_mock_is_cached,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 3
|
||||
|
||||
@@ -7,8 +7,8 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.geo import GeoInfo
|
||||
from app.services import geo_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
|
||||
@@ -123,7 +123,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""No filter returns every record in the database."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket")
|
||||
@@ -135,7 +135,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""The ``range_`` filter excludes bans older than the window."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
# "24h" window should include only the two recent bans
|
||||
@@ -147,7 +147,7 @@ class TestListHistory:
|
||||
async def test_jail_filter(self, f2b_db_path: str) -> None:
|
||||
"""Jail filter restricts results to bans from that jail."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history("fake_socket", jail="nginx")
|
||||
@@ -157,7 +157,7 @@ class TestListHistory:
|
||||
async def test_ip_prefix_filter(self, f2b_db_path: str) -> None:
|
||||
"""IP prefix filter restricts results to matching IPs."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -170,7 +170,7 @@ class TestListHistory:
|
||||
async def test_combined_filters(self, f2b_db_path: str) -> None:
|
||||
"""Jail + IP prefix filters applied together narrow the result set."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -182,7 +182,7 @@ class TestListHistory:
|
||||
async def test_unknown_ip_returns_empty(self, f2b_db_path: str) -> None:
|
||||
"""Filtering by a non-existent IP returns an empty result set."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -196,7 +196,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""``failures`` field is parsed from the JSON ``data`` column."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -210,7 +210,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""``matches`` list is parsed from the JSON ``data`` column."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -226,7 +226,7 @@ class TestListHistory:
|
||||
) -> None:
|
||||
"""Records with ``data=NULL`` produce failures=0 and matches=[]."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -240,7 +240,7 @@ class TestListHistory:
|
||||
async def test_pagination(self, f2b_db_path: str) -> None:
|
||||
"""Pagination returns the correct slice."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.list_history(
|
||||
@@ -265,7 +265,7 @@ class TestGetIpDetail:
|
||||
) -> None:
|
||||
"""Returns ``None`` when the IP has no records in the database."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "99.99.99.99")
|
||||
@@ -276,7 +276,7 @@ class TestGetIpDetail:
|
||||
) -> None:
|
||||
"""Returns an IpDetailResponse with correct totals for a known IP."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||
@@ -291,7 +291,7 @@ class TestGetIpDetail:
|
||||
) -> None:
|
||||
"""Timeline events are ordered newest-first."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||
@@ -304,7 +304,7 @@ class TestGetIpDetail:
|
||||
async def test_last_ban_at_is_most_recent(self, f2b_db_path: str) -> None:
|
||||
"""``last_ban_at`` matches the banned_at of the first timeline event."""
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail("fake_socket", "1.2.3.4")
|
||||
@@ -316,7 +316,7 @@ class TestGetIpDetail:
|
||||
self, f2b_db_path: str
|
||||
) -> None:
|
||||
"""Geolocation is applied when a geo_enricher is provided."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
mock_geo = GeoInfo(
|
||||
country_code="US",
|
||||
@@ -327,7 +327,7 @@ class TestGetIpDetail:
|
||||
fake_enricher = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with patch(
|
||||
"app.services.history_service._get_fail2ban_db_path",
|
||||
"app.services.history_service.get_fail2ban_db_path",
|
||||
new=AsyncMock(return_value=f2b_db_path),
|
||||
):
|
||||
result = await history_service.get_ip_detail(
|
||||
|
||||
@@ -635,7 +635,7 @@ class TestGetActiveBans:
|
||||
|
||||
async def test_http_session_triggers_lookup_batch(self) -> None:
|
||||
"""When http_session is provided, geo_service.lookup_batch is used."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
responses = {
|
||||
"status": _make_global_status("sshd"),
|
||||
@@ -645,17 +645,14 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
|
||||
mock_batch = AsyncMock(return_value=mock_geo)
|
||||
|
||||
with (
|
||||
_patch_client(responses),
|
||||
patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(return_value=mock_geo),
|
||||
) as mock_batch,
|
||||
):
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
_SOCKET, http_session=mock_session
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=mock_batch,
|
||||
)
|
||||
|
||||
mock_batch.assert_awaited_once()
|
||||
@@ -672,16 +669,14 @@ class TestGetActiveBans:
|
||||
),
|
||||
}
|
||||
|
||||
with (
|
||||
_patch_client(responses),
|
||||
patch(
|
||||
"app.services.geo_service.lookup_batch",
|
||||
new=AsyncMock(side_effect=RuntimeError("geo down")),
|
||||
),
|
||||
):
|
||||
failing_batch = AsyncMock(side_effect=RuntimeError("geo down"))
|
||||
|
||||
with _patch_client(responses):
|
||||
mock_session = AsyncMock()
|
||||
result = await jail_service.get_active_bans(
|
||||
_SOCKET, http_session=mock_session
|
||||
_SOCKET,
|
||||
http_session=mock_session,
|
||||
geo_batch_lookup=failing_batch,
|
||||
)
|
||||
|
||||
assert result.total == 1
|
||||
@@ -689,7 +684,7 @@ class TestGetActiveBans:
|
||||
|
||||
async def test_geo_enricher_still_used_without_http_session(self) -> None:
|
||||
"""Legacy geo_enricher is still called when http_session is not provided."""
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
|
||||
responses = {
|
||||
"status": _make_global_status("sshd"),
|
||||
@@ -987,6 +982,7 @@ class TestGetJailBannedIps:
|
||||
page=1,
|
||||
page_size=2,
|
||||
http_session=http_session,
|
||||
geo_batch_lookup=geo_service.lookup_batch,
|
||||
)
|
||||
|
||||
# Only the 2-IP page slice should be passed to geo enrichment.
|
||||
|
||||
@@ -13,7 +13,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.models.geo import GeoInfo
|
||||
from app.tasks.geo_re_resolve import _run_re_resolve
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user