Remove Any type annotations from config_service.py

Replace Any with typed aliases (Fail2BanToken/Fail2BanCommand/Fail2BanResponse), add typed helper, and update task list.
This commit is contained in:
2026-03-17 11:42:46 +01:00
parent ce59a66973
commit 482399c9e2
2 changed files with 55 additions and 61 deletions

View File

@@ -16,10 +16,12 @@ import asyncio
import contextlib
import re
from pathlib import Path
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, TypeVar, cast
import structlog
from app.utils.fail2ban_client import Fail2BanCommand, Fail2BanResponse, Fail2BanToken
if TYPE_CHECKING:
import aiosqlite
@@ -80,7 +82,7 @@ class ConfigOperationError(Exception):
# ---------------------------------------------------------------------------
def _ok(response: Any) -> Any:
def _ok(response: object) -> object:
"""Extract payload from a fail2ban ``(return_code, data)`` response.
Args:
@@ -93,7 +95,7 @@ def _ok(response: Any) -> Any:
ValueError: If the return code indicates an error.
"""
try:
code, data = response
code, data = cast(Fail2BanResponse, response)
except (TypeError, ValueError) as exc:
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
if code != 0:
@@ -101,11 +103,11 @@ def _ok(response: Any) -> Any:
return data
def _to_dict(pairs: Any) -> dict[str, Any]:
def _to_dict(pairs: object) -> dict[str, object]:
"""Convert a list of ``(key, value)`` pairs to a plain dict."""
if not isinstance(pairs, (list, tuple)):
return {}
result: dict[str, Any] = {}
result: dict[str, object] = {}
for item in pairs:
try:
k, v = item
@@ -115,7 +117,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]:
return result
def _ensure_list(value: Any) -> list[str]:
def _ensure_list(value: object | None) -> list[str]:
"""Coerce a fail2ban ``get`` result to a list of strings."""
if value is None:
return []
@@ -126,11 +128,14 @@ def _ensure_list(value: Any) -> list[str]:
return [str(value)]
_T = TypeVar("_T")
async def _safe_get(
client: Fail2BanClient,
command: list[Any],
default: Any = None,
) -> Any:
command: Fail2BanCommand,
default: object | None = None,
) -> object | None:
"""Send a command and return *default* if it fails."""
try:
return _ok(await client.send(command))
@@ -138,6 +143,15 @@ async def _safe_get(
return default
async def _safe_get_typed(
client: Fail2BanClient,
command: Fail2BanCommand,
default: _T,
) -> _T:
"""Send a command and return the result typed as ``default``'s type."""
return cast(_T, await _safe_get(client, command, default))
def _is_not_found_error(exc: Exception) -> bool:
"""Return ``True`` if *exc* signals an unknown jail."""
msg = str(exc).lower()
@@ -192,47 +206,25 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse:
raise JailNotFoundError(name) from exc
raise
(
bantime_raw,
findtime_raw,
maxretry_raw,
failregex_raw,
ignoreregex_raw,
logpath_raw,
datepattern_raw,
logencoding_raw,
backend_raw,
usedns_raw,
prefregex_raw,
actions_raw,
bt_increment_raw,
bt_factor_raw,
bt_formula_raw,
bt_multipliers_raw,
bt_maxtime_raw,
bt_rndtime_raw,
bt_overalljails_raw,
) = await asyncio.gather(
_safe_get(client, ["get", name, "bantime"], 600),
_safe_get(client, ["get", name, "findtime"], 600),
_safe_get(client, ["get", name, "maxretry"], 5),
_safe_get(client, ["get", name, "failregex"], []),
_safe_get(client, ["get", name, "ignoreregex"], []),
_safe_get(client, ["get", name, "logpath"], []),
_safe_get(client, ["get", name, "datepattern"], None),
_safe_get(client, ["get", name, "logencoding"], "UTF-8"),
_safe_get(client, ["get", name, "backend"], "polling"),
_safe_get(client, ["get", name, "usedns"], "warn"),
_safe_get(client, ["get", name, "prefregex"], ""),
_safe_get(client, ["get", name, "actions"], []),
_safe_get(client, ["get", name, "bantime.increment"], False),
_safe_get(client, ["get", name, "bantime.factor"], None),
_safe_get(client, ["get", name, "bantime.formula"], None),
_safe_get(client, ["get", name, "bantime.multipliers"], None),
_safe_get(client, ["get", name, "bantime.maxtime"], None),
_safe_get(client, ["get", name, "bantime.rndtime"], None),
_safe_get(client, ["get", name, "bantime.overalljails"], False),
)
bantime_raw: int = await _safe_get_typed(client, ["get", name, "bantime"], 600)
findtime_raw: int = await _safe_get_typed(client, ["get", name, "findtime"], 600)
maxretry_raw: int = await _safe_get_typed(client, ["get", name, "maxretry"], 5)
failregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "failregex"], [])
ignoreregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "ignoreregex"], [])
logpath_raw: list[object] = await _safe_get_typed(client, ["get", name, "logpath"], [])
datepattern_raw: str | None = await _safe_get_typed(client, ["get", name, "datepattern"], None)
logencoding_raw: str = await _safe_get_typed(client, ["get", name, "logencoding"], "UTF-8")
backend_raw: str = await _safe_get_typed(client, ["get", name, "backend"], "polling")
usedns_raw: str = await _safe_get_typed(client, ["get", name, "usedns"], "warn")
prefregex_raw: str = await _safe_get_typed(client, ["get", name, "prefregex"], "")
actions_raw: list[object] = await _safe_get_typed(client, ["get", name, "actions"], [])
bt_increment_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.increment"], False)
bt_factor_raw: str | float | None = await _safe_get_typed(client, ["get", name, "bantime.factor"], None)
bt_formula_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.formula"], None)
bt_multipliers_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.multipliers"], None)
bt_maxtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.maxtime"], None)
bt_rndtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.rndtime"], None)
bt_overalljails_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.overalljails"], False)
bantime_escalation = BantimeEscalation(
increment=bool(bt_increment_raw),
@@ -352,7 +344,7 @@ async def update_jail_config(
raise JailNotFoundError(name) from exc
raise
async def _set(key: str, value: Any) -> None:
async def _set(key: str, value: Fail2BanToken) -> None:
try:
_ok(await client.send(["set", name, key, value]))
except ValueError as exc:
@@ -423,7 +415,7 @@ async def _replace_regex_list(
new_patterns: Replacement list (may be empty to clear).
"""
# Determine current count.
current_raw = await _safe_get(client, ["get", jail, field], [])
current_raw: list[object] = await _safe_get_typed(client, ["get", jail, field], [])
current: list[str] = _ensure_list(current_raw)
del_cmd = f"del{field}"
@@ -470,10 +462,10 @@ async def get_global_config(socket_path: str) -> GlobalConfigResponse:
db_purge_age_raw,
db_max_matches_raw,
) = await asyncio.gather(
_safe_get(client, ["get", "loglevel"], "INFO"),
_safe_get(client, ["get", "logtarget"], "STDOUT"),
_safe_get(client, ["get", "dbpurgeage"], 86400),
_safe_get(client, ["get", "dbmaxmatches"], 10),
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
_safe_get_typed(client, ["get", "dbpurgeage"], 86400),
_safe_get_typed(client, ["get", "dbmaxmatches"], 10),
)
return GlobalConfigResponse(
@@ -497,7 +489,7 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) ->
"""
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
async def _set_global(key: str, value: Any) -> None:
async def _set_global(key: str, value: Fail2BanToken) -> None:
try:
_ok(await client.send(["set", key, value]))
except ValueError as exc:
@@ -822,8 +814,8 @@ async def read_fail2ban_log(
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
log_level_raw, log_target_raw = await asyncio.gather(
_safe_get(client, ["get", "loglevel"], "INFO"),
_safe_get(client, ["get", "logtarget"], "STDOUT"),
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
)
log_level = str(log_level_raw or "INFO").upper()
@@ -905,8 +897,8 @@ async def get_service_status(socket_path: str) -> ServiceStatusResponse:
if server_status.online:
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
log_level_raw, log_target_raw = await asyncio.gather(
_safe_get(client, ["get", "loglevel"], "INFO"),
_safe_get(client, ["get", "logtarget"], "STDOUT"),
_safe_get_typed(client, ["get", "loglevel"], "INFO"),
_safe_get_typed(client, ["get", "logtarget"], "STDOUT"),
)
log_level = str(log_level_raw or "INFO").upper()
log_target = str(log_target_raw or "STDOUT")