Fix geo_re_resolve async mocks and mark tasks complete
This commit is contained in:
@@ -15,7 +15,7 @@ under the key ``"blocklist_schedule"``.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -56,7 +56,7 @@ _PREVIEW_MAX_BYTES: int = 65536
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _row_to_source(row: dict[str, Any]) -> BlocklistSource:
|
||||
def _row_to_source(row: dict[str, object]) -> BlocklistSource:
|
||||
"""Convert a repository row dict to a :class:`BlocklistSource`.
|
||||
|
||||
Args:
|
||||
@@ -542,7 +542,7 @@ async def list_import_logs(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _aiohttp_timeout(seconds: float) -> Any:
|
||||
def _aiohttp_timeout(seconds: float) -> "aiohttp.ClientTimeout":
|
||||
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -28,7 +28,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, cast, TypeAlias
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -57,7 +57,12 @@ from app.models.config import (
|
||||
from app.services import jail_service
|
||||
from app.services.jail_service import JailNotFoundError as JailNotFoundError
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanCommand,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanResponse,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -539,10 +544,10 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
||||
try:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
def _to_dict_inner(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict_inner(pairs: object) -> dict[str, object]:
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -551,8 +556,8 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
||||
pass
|
||||
return result
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
code, data = response
|
||||
def _ok(response: object) -> object:
|
||||
code, data = cast(Fail2BanResponse, response)
|
||||
if code != 0:
|
||||
raise ValueError(f"fail2ban error {code}: {data!r}")
|
||||
return data
|
||||
@@ -813,7 +818,7 @@ def _write_local_override_sync(
|
||||
config_dir: Path,
|
||||
jail_name: str,
|
||||
enabled: bool,
|
||||
overrides: dict[str, Any],
|
||||
overrides: dict[str, object],
|
||||
) -> None:
|
||||
"""Write a ``jail.d/{name}.local`` file atomically.
|
||||
|
||||
@@ -862,7 +867,7 @@ def _write_local_override_sync(
|
||||
if overrides.get("port") is not None:
|
||||
lines.append(f"port = {overrides['port']}")
|
||||
if overrides.get("logpath"):
|
||||
paths: list[str] = overrides["logpath"]
|
||||
paths: list[str] = cast(list[str], overrides["logpath"])
|
||||
if paths:
|
||||
lines.append(f"logpath = {paths[0]}")
|
||||
for p in paths[1:]:
|
||||
@@ -1209,7 +1214,7 @@ async def activate_jail(
|
||||
),
|
||||
)
|
||||
|
||||
overrides: dict[str, Any] = {
|
||||
overrides: dict[str, object] = {
|
||||
"bantime": req.bantime,
|
||||
"findtime": req.findtime,
|
||||
"maxretry": req.maxretry,
|
||||
|
||||
@@ -30,7 +30,8 @@ Usage::
|
||||
# single lookup
|
||||
info = await geo_service.lookup("1.2.3.4", session)
|
||||
if info:
|
||||
print(info.country_code) # "DE"
|
||||
# info.country_code == "DE"
|
||||
... # use the GeoInfo object in your application
|
||||
|
||||
# bulk lookup (more efficient for large sets)
|
||||
geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session)
|
||||
@@ -42,7 +43,7 @@ import asyncio
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, TypeAlias
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
@@ -119,7 +120,7 @@ class GeoInfo:
|
||||
"""Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``."""
|
||||
|
||||
|
||||
GeoEnricher: TypeAlias = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
"""Async callable used to enrich IPs with :class:`~app.services.geo_service.GeoInfo`.
|
||||
|
||||
This is a shared type alias used by services that optionally accept a geo
|
||||
|
||||
@@ -9,12 +9,17 @@ seconds by the background health-check task, not on every HTTP request.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanProtocolError,
|
||||
Fail2BanResponse,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -25,7 +30,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
_SOCKET_TIMEOUT: float = 5.0
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
fail2ban wraps every response in a ``(0, data)`` success tuple or
|
||||
@@ -42,7 +47,7 @@ def _ok(response: Any) -> Any:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast(Fail2BanResponse, response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
@@ -52,7 +57,7 @@ def _ok(response: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||
|
||||
fail2ban returns structured data as lists of 2-tuples rather than dicts.
|
||||
@@ -66,7 +71,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
"""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -119,7 +124,7 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
||||
# 3. Global status — jail count and names #
|
||||
# ------------------------------------------------------------------ #
|
||||
status_data = _to_dict(_ok(await client.send(["status"])))
|
||||
active_jails: int = int(status_data.get("Number of jail", 0) or 0)
|
||||
active_jails: int = int(str(status_data.get("Number of jail", 0) or 0))
|
||||
jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip()
|
||||
jail_names: list[str] = (
|
||||
[j.strip() for j in jail_list_raw.split(",") if j.strip()]
|
||||
@@ -138,8 +143,8 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta
|
||||
jail_resp = _to_dict(_ok(await client.send(["status", jail_name])))
|
||||
filter_stats = _to_dict(jail_resp.get("Filter") or [])
|
||||
action_stats = _to_dict(jail_resp.get("Actions") or [])
|
||||
total_failures += int(filter_stats.get("Currently failed", 0) or 0)
|
||||
total_bans += int(action_stats.get("Currently banned", 0) or 0)
|
||||
total_failures += int(str(filter_stats.get("Currently failed", 0) or 0))
|
||||
total_bans += int(str(action_stats.get("Currently banned", 0) or 0))
|
||||
except (ValueError, TypeError, KeyError) as exc:
|
||||
log.warning(
|
||||
"fail2ban_jail_status_parse_error",
|
||||
|
||||
@@ -14,7 +14,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import ipaddress
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, cast, TypeAlias
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -27,10 +27,24 @@ from app.models.jail import (
|
||||
JailStatus,
|
||||
JailSummary,
|
||||
)
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanCommand,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanResponse,
|
||||
Fail2BanToken,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo | None"]]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -77,7 +91,7 @@ class JailOperationError(Exception):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract the payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
Args:
|
||||
@@ -90,7 +104,7 @@ def _ok(response: Any) -> Any:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast(Fail2BanResponse, response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
@@ -100,7 +114,7 @@ def _ok(response: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict.
|
||||
|
||||
Args:
|
||||
@@ -111,7 +125,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
"""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -121,7 +135,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_list(value: Any) -> list[str]:
|
||||
def _ensure_list(value: object | None) -> list[str]:
|
||||
"""Coerce a fail2ban response value to a list of strings.
|
||||
|
||||
Some fail2ban ``get`` responses return ``None`` or a single string
|
||||
@@ -170,9 +184,9 @@ def _is_not_found_error(exc: Exception) -> bool:
|
||||
|
||||
async def _safe_get(
|
||||
client: Fail2BanClient,
|
||||
command: list[Any],
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
command: Fail2BanCommand,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
"""Send a ``get`` command and return ``default`` on error.
|
||||
|
||||
Errors during optional detail queries (logpath, regex, etc.) should
|
||||
@@ -187,7 +201,8 @@ async def _safe_get(
|
||||
The response payload, or *default* on any error.
|
||||
"""
|
||||
try:
|
||||
return _ok(await client.send(command))
|
||||
response = await client.send(command)
|
||||
return _ok(cast(Fail2BanResponse, response))
|
||||
except (ValueError, TypeError, Exception):
|
||||
return default
|
||||
|
||||
@@ -309,7 +324,7 @@ async def _fetch_jail_summary(
|
||||
backend_cmd_is_supported = await _check_backend_cmd_supported(client, name)
|
||||
|
||||
# Build the gather list based on command support.
|
||||
gather_list: list[Any] = [
|
||||
gather_list: list[Awaitable[object]] = [
|
||||
client.send(["status", name, "short"]),
|
||||
client.send(["get", name, "bantime"]),
|
||||
client.send(["get", name, "findtime"]),
|
||||
@@ -325,7 +340,7 @@ async def _fetch_jail_summary(
|
||||
uses_backend_backend_commands = True
|
||||
else:
|
||||
# Commands not supported; return default values without sending.
|
||||
async def _return_default(value: Any) -> tuple[int, Any]:
|
||||
async def _return_default(value: object | None) -> Fail2BanResponse:
|
||||
return (0, value)
|
||||
|
||||
gather_list.extend([
|
||||
@@ -335,12 +350,12 @@ async def _fetch_jail_summary(
|
||||
uses_backend_backend_commands = False
|
||||
|
||||
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
||||
status_raw: Any = _r[0]
|
||||
bantime_raw: Any = _r[1]
|
||||
findtime_raw: Any = _r[2]
|
||||
maxretry_raw: Any = _r[3]
|
||||
backend_raw: Any = _r[4]
|
||||
idle_raw: Any = _r[5]
|
||||
status_raw: object | Exception = _r[0]
|
||||
bantime_raw: object | Exception = _r[1]
|
||||
findtime_raw: object | Exception = _r[2]
|
||||
maxretry_raw: object | Exception = _r[3]
|
||||
backend_raw: object | Exception = _r[4]
|
||||
idle_raw: object | Exception = _r[5]
|
||||
|
||||
# Parse jail status (filter + actions).
|
||||
jail_status: JailStatus | None = None
|
||||
@@ -350,35 +365,35 @@ async def _fetch_jail_summary(
|
||||
filter_stats = _to_dict(raw.get("Filter") or [])
|
||||
action_stats = _to_dict(raw.get("Actions") or [])
|
||||
jail_status = JailStatus(
|
||||
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
||||
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
||||
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
||||
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
||||
)
|
||||
except (ValueError, TypeError) as exc:
|
||||
log.warning("jail_status_parse_error", jail=name, error=str(exc))
|
||||
|
||||
def _safe_int(raw: Any, fallback: int) -> int:
|
||||
def _safe_int(raw: object | Exception, fallback: int) -> int:
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return int(_ok(raw))
|
||||
return int(str(_ok(cast(Fail2BanResponse, raw))))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
def _safe_str(raw: Any, fallback: str) -> str:
|
||||
def _safe_str(raw: object | Exception, fallback: str) -> str:
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return str(_ok(raw))
|
||||
return str(_ok(cast(Fail2BanResponse, raw)))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
def _safe_bool(raw: Any, fallback: bool = False) -> bool:
|
||||
def _safe_bool(raw: object | Exception, fallback: bool = False) -> bool:
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return bool(_ok(raw))
|
||||
return bool(_ok(cast(Fail2BanResponse, raw)))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
@@ -428,10 +443,10 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
action_stats = _to_dict(raw.get("Actions") or [])
|
||||
|
||||
jail_status = JailStatus(
|
||||
currently_banned=int(action_stats.get("Currently banned", 0) or 0),
|
||||
total_banned=int(action_stats.get("Total banned", 0) or 0),
|
||||
currently_failed=int(filter_stats.get("Currently failed", 0) or 0),
|
||||
total_failed=int(filter_stats.get("Total failed", 0) or 0),
|
||||
currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)),
|
||||
total_banned=int(str(action_stats.get("Total banned", 0) or 0)),
|
||||
currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)),
|
||||
total_failed=int(str(filter_stats.get("Total failed", 0) or 0)),
|
||||
)
|
||||
|
||||
# Fetch all detail fields in parallel.
|
||||
@@ -480,11 +495,11 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
bt_increment: bool = bool(bt_increment_raw)
|
||||
bantime_escalation = BantimeEscalation(
|
||||
increment=bt_increment,
|
||||
factor=float(bt_factor_raw) if bt_factor_raw is not None else None,
|
||||
factor=float(str(bt_factor_raw)) if bt_factor_raw is not None else None,
|
||||
formula=str(bt_formula_raw) if bt_formula_raw else None,
|
||||
multipliers=str(bt_multipliers_raw) if bt_multipliers_raw else None,
|
||||
max_time=int(bt_maxtime_raw) if bt_maxtime_raw is not None else None,
|
||||
rnd_time=int(bt_rndtime_raw) if bt_rndtime_raw is not None else None,
|
||||
max_time=int(str(bt_maxtime_raw)) if bt_maxtime_raw is not None else None,
|
||||
rnd_time=int(str(bt_rndtime_raw)) if bt_rndtime_raw is not None else None,
|
||||
overall_jails=bool(bt_overalljails_raw),
|
||||
)
|
||||
|
||||
@@ -500,9 +515,9 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse:
|
||||
ignore_ips=_ensure_list(ignoreip_raw),
|
||||
date_pattern=str(datepattern_raw) if datepattern_raw else None,
|
||||
log_encoding=str(logencoding_raw or "UTF-8"),
|
||||
find_time=int(findtime_raw or 600),
|
||||
ban_time=int(bantime_raw or 600),
|
||||
max_retry=int(maxretry_raw or 5),
|
||||
find_time=int(str(findtime_raw or 600)),
|
||||
ban_time=int(str(bantime_raw or 600)),
|
||||
max_retry=int(str(maxretry_raw or 5)),
|
||||
bantime_escalation=bantime_escalation,
|
||||
status=jail_status,
|
||||
actions=_ensure_list(actions_raw),
|
||||
@@ -671,8 +686,8 @@ async def reload_all(
|
||||
if exclude_jails:
|
||||
names_set -= set(exclude_jails)
|
||||
|
||||
stream: list[list[str]] = [["start", n] for n in sorted(names_set)]
|
||||
_ok(await client.send(["reload", "--all", [], stream]))
|
||||
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
||||
_ok(await client.send(["reload", "--all", [], cast(Fail2BanToken, stream)]))
|
||||
log.info("all_jails_reloaded")
|
||||
except ValueError as exc:
|
||||
# Detect UnknownJailException (missing or invalid jail configuration)
|
||||
@@ -795,9 +810,9 @@ async def unban_ip(
|
||||
|
||||
async def get_active_bans(
|
||||
socket_path: str,
|
||||
geo_enricher: Any | None = None,
|
||||
http_session: Any | None = None,
|
||||
app_db: Any | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
http_session: "aiohttp.ClientSession" | None = None,
|
||||
app_db: "aiosqlite.Connection" | None = None,
|
||||
) -> ActiveBanListResponse:
|
||||
"""Return all currently banned IPs across every jail.
|
||||
|
||||
@@ -849,7 +864,7 @@ async def get_active_bans(
|
||||
return ActiveBanListResponse(bans=[], total=0)
|
||||
|
||||
# For each jail, fetch the ban list with time info in parallel.
|
||||
results: list[Any] = await asyncio.gather(
|
||||
results: list[object | Exception] = await asyncio.gather(
|
||||
*[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names],
|
||||
return_exceptions=True,
|
||||
)
|
||||
@@ -865,7 +880,7 @@ async def get_active_bans(
|
||||
continue
|
||||
|
||||
try:
|
||||
ban_list: list[str] = _ok(raw_result) or []
|
||||
ban_list: list[str] = cast(list[str], _ok(raw_result)) or []
|
||||
except (TypeError, ValueError) as exc:
|
||||
log.warning(
|
||||
"active_bans_parse_error",
|
||||
@@ -992,8 +1007,8 @@ async def get_jail_banned_ips(
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
search: str | None = None,
|
||||
http_session: Any | None = None,
|
||||
app_db: Any | None = None,
|
||||
http_session: "aiohttp.ClientSession" | None = None,
|
||||
app_db: "aiosqlite.Connection" | None = None,
|
||||
) -> JailBannedIpsResponse:
|
||||
"""Return a paginated list of currently banned IPs for a single jail.
|
||||
|
||||
@@ -1040,7 +1055,7 @@ async def get_jail_banned_ips(
|
||||
except (ValueError, TypeError):
|
||||
raw_result = []
|
||||
|
||||
ban_list: list[str] = raw_result or []
|
||||
ban_list: list[str] = cast(list[str], raw_result) or []
|
||||
|
||||
# Parse all entries.
|
||||
all_bans: list[ActiveBan] = []
|
||||
@@ -1094,7 +1109,7 @@ async def get_jail_banned_ips(
|
||||
|
||||
async def _enrich_bans(
|
||||
bans: list[ActiveBan],
|
||||
geo_enricher: Any,
|
||||
geo_enricher: GeoEnricher,
|
||||
) -> list[ActiveBan]:
|
||||
"""Enrich ban records with geo data asynchronously.
|
||||
|
||||
@@ -1105,14 +1120,15 @@ async def _enrich_bans(
|
||||
Returns:
|
||||
The same list with ``country`` fields populated where lookup succeeded.
|
||||
"""
|
||||
geo_results: list[Any] = await asyncio.gather(
|
||||
*[geo_enricher(ban.ip) for ban in bans],
|
||||
geo_results: list[object | Exception] = await asyncio.gather(
|
||||
*[cast(Awaitable[object], geo_enricher(ban.ip)) for ban in bans],
|
||||
return_exceptions=True,
|
||||
)
|
||||
enriched: list[ActiveBan] = []
|
||||
for ban, geo in zip(bans, geo_results, strict=False):
|
||||
if geo is not None and not isinstance(geo, Exception):
|
||||
enriched.append(ban.model_copy(update={"country": geo.country_code}))
|
||||
geo_info = cast("GeoInfo", geo)
|
||||
enriched.append(ban.model_copy(update={"country": geo_info.country_code}))
|
||||
else:
|
||||
enriched.append(ban)
|
||||
return enriched
|
||||
@@ -1260,8 +1276,8 @@ async def set_ignore_self(socket_path: str, name: str, *, on: bool) -> None:
|
||||
async def lookup_ip(
|
||||
socket_path: str,
|
||||
ip: str,
|
||||
geo_enricher: Any | None = None,
|
||||
) -> dict[str, Any]:
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
) -> dict[str, object | list[str] | None]:
|
||||
"""Return ban status and history for a single IP address.
|
||||
|
||||
Checks every running jail for whether the IP is currently banned.
|
||||
@@ -1304,7 +1320,7 @@ async def lookup_ip(
|
||||
)
|
||||
|
||||
# Check ban status per jail in parallel.
|
||||
ban_results: list[Any] = await asyncio.gather(
|
||||
ban_results: list[object | Exception] = await asyncio.gather(
|
||||
*[client.send(["get", jn, "banip"]) for jn in jail_names],
|
||||
return_exceptions=True,
|
||||
)
|
||||
@@ -1314,7 +1330,7 @@ async def lookup_ip(
|
||||
if isinstance(result, Exception):
|
||||
continue
|
||||
try:
|
||||
ban_list: list[str] = _ok(result) or []
|
||||
ban_list: list[str] = cast(list[str], _ok(result)) or []
|
||||
if ip in ban_list:
|
||||
currently_banned_in.append(jail_name)
|
||||
except (ValueError, TypeError):
|
||||
@@ -1351,6 +1367,6 @@ async def unban_all_ips(socket_path: str) -> int:
|
||||
cannot be reached.
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
count: int = int(_ok(await client.send(["unban", "--all"])))
|
||||
count: int = int(str(_ok(await client.send(["unban", "--all"])) or 0))
|
||||
log.info("all_ips_unbanned", count=count)
|
||||
return count
|
||||
|
||||
Reference in New Issue
Block a user