Remove Any type annotations from config_service.py
Replace Any with typed aliases (Fail2BanToken/Fail2BanCommand/Fail2BanResponse), add typed helper, and update task list.
This commit is contained in:
@@ -16,10 +16,12 @@ import asyncio
|
||||
import contextlib
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, TypeVar, cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.utils.fail2ban_client import Fail2BanCommand, Fail2BanResponse, Fail2BanToken
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
@@ -80,7 +82,7 @@ class ConfigOperationError(Exception):
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _ok(response: Any) -> Any:
|
||||
def _ok(response: object) -> object:
|
||||
"""Extract payload from a fail2ban ``(return_code, data)`` response.
|
||||
|
||||
Args:
|
||||
@@ -93,7 +95,7 @@ def _ok(response: Any) -> Any:
|
||||
ValueError: If the return code indicates an error.
|
||||
"""
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast(Fail2BanResponse, response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
if code != 0:
|
||||
@@ -101,11 +103,11 @@ def _ok(response: Any) -> Any:
|
||||
return data
|
||||
|
||||
|
||||
def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
def _to_dict(pairs: object) -> dict[str, object]:
|
||||
"""Convert a list of ``(key, value)`` pairs to a plain dict."""
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
@@ -115,7 +117,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def _ensure_list(value: Any) -> list[str]:
|
||||
def _ensure_list(value: object | None) -> list[str]:
|
||||
"""Coerce a fail2ban ``get`` result to a list of strings."""
|
||||
if value is None:
|
||||
return []
|
||||
@@ -126,11 +128,14 @@ def _ensure_list(value: Any) -> list[str]:
|
||||
return [str(value)]
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
|
||||
async def _safe_get(
|
||||
client: Fail2BanClient,
|
||||
command: list[Any],
|
||||
default: Any = None,
|
||||
) -> Any:
|
||||
command: Fail2BanCommand,
|
||||
default: object | None = None,
|
||||
) -> object | None:
|
||||
"""Send a command and return *default* if it fails."""
|
||||
try:
|
||||
return _ok(await client.send(command))
|
||||
@@ -138,6 +143,15 @@ async def _safe_get(
|
||||
return default
|
||||
|
||||
|
||||
async def _safe_get_typed(
|
||||
client: Fail2BanClient,
|
||||
command: Fail2BanCommand,
|
||||
default: _T,
|
||||
) -> _T:
|
||||
"""Send a command and return the result typed as ``default``'s type."""
|
||||
return cast(_T, await _safe_get(client, command, default))
|
||||
|
||||
|
||||
def _is_not_found_error(exc: Exception) -> bool:
|
||||
"""Return ``True`` if *exc* signals an unknown jail."""
|
||||
msg = str(exc).lower()
|
||||
@@ -192,47 +206,25 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
|
||||
raise JailNotFoundError(name) from exc
|
||||
raise
|
||||
|
||||
(
|
||||
bantime_raw,
|
||||
findtime_raw,
|
||||
maxretry_raw,
|
||||
failregex_raw,
|
||||
ignoreregex_raw,
|
||||
logpath_raw,
|
||||
datepattern_raw,
|
||||
logencoding_raw,
|
||||
backend_raw,
|
||||
usedns_raw,
|
||||
prefregex_raw,
|
||||
actions_raw,
|
||||
bt_increment_raw,
|
||||
bt_factor_raw,
|
||||
bt_formula_raw,
|
||||
bt_multipliers_raw,
|
||||
bt_maxtime_raw,
|
||||
bt_rndtime_raw,
|
||||
bt_overalljails_raw,
|
||||
) = await asyncio.gather(
|
||||
_safe_get(client, ["get", name, "bantime"], 600),
|
||||
_safe_get(client, ["get", name, "findtime"], 600),
|
||||
_safe_get(client, ["get", name, "maxretry"], 5),
|
||||
_safe_get(client, ["get", name, "failregex"], []),
|
||||
_safe_get(client, ["get", name, "ignoreregex"], []),
|
||||
_safe_get(client, ["get", name, "logpath"], []),
|
||||
_safe_get(client, ["get", name, "datepattern"], None),
|
||||
_safe_get(client, ["get", name, "logencoding"], "UTF-8"),
|
||||
_safe_get(client, ["get", name, "backend"], "polling"),
|
||||
_safe_get(client, ["get", name, "usedns"], "warn"),
|
||||
_safe_get(client, ["get", name, "prefregex"], ""),
|
||||
_safe_get(client, ["get", name, "actions"], []),
|
||||
_safe_get(client, ["get", name, "bantime.increment"], False),
|
||||
_safe_get(client, ["get", name, "bantime.factor"], None),
|
||||
_safe_get(client, ["get", name, "bantime.formula"], None),
|
||||
_safe_get(client, ["get", name, "bantime.multipliers"], None),
|
||||
_safe_get(client, ["get", name, "bantime.maxtime"], None),
|
||||
_safe_get(client, ["get", name, "bantime.rndtime"], None),
|
||||
_safe_get(client, ["get", name, "bantime.overalljails"], False),
|
||||
)
|
||||
bantime_raw: int = await _safe_get_typed(client, ["get", name, "bantime"], 600)
|
||||
findtime_raw: int = await _safe_get_typed(client, ["get", name, "findtime"], 600)
|
||||
maxretry_raw: int = await _safe_get_typed(client, ["get", name, "maxretry"], 5)
|
||||
failregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "failregex"], [])
|
||||
ignoreregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "ignoreregex"], [])
|
||||
logpath_raw: list[object] = await _safe_get_typed(client, ["get", name, "logpath"], [])
|
||||
datepattern_raw: str | None = await _safe_get_typed(client, ["get", name, "datepattern"], None)
|
||||
logencoding_raw: str = await _safe_get_typed(client, ["get", name, "logencoding"], "UTF-8")
|
||||
backend_raw: str = await _safe_get_typed(client, ["get", name, "backend"], "polling")
|
||||
usedns_raw: str = await _safe_get_typed(client, ["get", name, "usedns"], "warn")
|
||||
prefregex_raw: str = await _safe_get_typed(client, ["get", name, "prefregex"], "")
|
||||
actions_raw: list[object] = await _safe_get_typed(client, ["get", name, "actions"], [])
|
||||
bt_increment_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.increment"], False)
|
||||
bt_factor_raw: str | float | None = await _safe_get_typed(client, ["get", name, "bantime.factor"], None)
|
||||
bt_formula_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.formula"], None)
|
||||
bt_multipliers_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.multipliers"], None)
|
||||
bt_maxtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.maxtime"], None)
|
||||
bt_rndtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.rndtime"], None)
|
||||
bt_overalljails_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.overalljails"], False)
|
||||
|
||||
bantime_escalation = BantimeEscalation(
|
||||
increment=bool(bt_increment_raw),
|
||||
@@ -352,7 +344,7 @@ async def update_jail_config(
|
||||
raise JailNotFoundError(name) from exc
|
||||
raise
|
||||
|
||||
async def _set(key: str, value: Any) -> None:
|
||||
async def _set(key: str, value: Fail2BanToken) -> None:
|
||||
try:
|
||||
_ok(await client.send(["set", name, key, value]))
|
||||
except ValueError as exc:
|
||||
@@ -423,7 +415,7 @@ async def _replace_regex_list(
|
||||
new_patterns: Replacement list (may be empty to clear).
|
||||
"""
|
||||
# Determine current count.
|
||||
current_raw = await _safe_get(client, ["get", jail, field], [])
|
||||
current_raw: list[object] = await _safe_get_typed(client, ["get", jail, field], [])
|
||||
current: list[str] = _ensure_list(current_raw)
|
||||
|
||||
del_cmd = f"del{field}"
|
||||
@@ -470,10 +462,10 @@ async def get_global_config(socket_path: str) -> GlobalConfigResponse:
|
||||
db_purge_age_raw,
|
||||
db_max_matches_raw,
|
||||
) = await asyncio.gather(
|
||||
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get(client, ["get", "dbpurgeage"], 86400),
|
||||
_safe_get(client, ["get", "dbmaxmatches"], 10),
|
||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get_typed(client, ["get", "dbpurgeage"], 86400),
|
||||
_safe_get_typed(client, ["get", "dbmaxmatches"], 10),
|
||||
)
|
||||
|
||||
return GlobalConfigResponse(
|
||||
@@ -497,7 +489,7 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
|
||||
"""
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
async def _set_global(key: str, value: Any) -> None:
|
||||
async def _set_global(key: str, value: Fail2BanToken) -> None:
|
||||
try:
|
||||
_ok(await client.send(["set", key, value]))
|
||||
except ValueError as exc:
|
||||
@@ -822,8 +814,8 @@ async def read_fail2ban_log(
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
log_level_raw, log_target_raw = await asyncio.gather(
|
||||
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
||||
)
|
||||
|
||||
log_level = str(log_level_raw or "INFO").upper()
|
||||
@@ -905,8 +897,8 @@ async def get_service_status(socket_path: str) -> ServiceStatusResponse:
|
||||
if server_status.online:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
log_level_raw, log_target_raw = await asyncio.gather(
|
||||
_safe_get(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get(client, ["get", "logtarget"], "STDOUT"),
|
||||
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
|
||||
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
|
||||
)
|
||||
log_level = str(log_level_raw or "INFO").upper()
|
||||
log_target = str(log_target_raw or "STDOUT")
|
||||
|
||||
Reference in New Issue
Block a user