diff --git a/Docs/Tasks.md b/Docs/Tasks.md index e7df82e..a8d80e3 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -392,6 +392,8 @@ For each component listed: #### TASK B-12 โ€” Remove `Any` type annotations in `config_service.py` +**Status:** Completed โœ… + **Violated rule:** Backend-Development.md ยง1 โ€” Never use `Any`; all functions must have explicit type annotations. **Files affected:** diff --git a/backend/app/services/config_service.py b/backend/app/services/config_service.py index d791061..362a51a 100644 --- a/backend/app/services/config_service.py +++ b/backend/app/services/config_service.py @@ -16,10 +16,12 @@ import asyncio import contextlib import re from pathlib import Path -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, TypeVar, cast import structlog +from app.utils.fail2ban_client import Fail2BanCommand, Fail2BanResponse, Fail2BanToken + if TYPE_CHECKING: import aiosqlite @@ -80,7 +82,7 @@ class ConfigOperationError(Exception): # --------------------------------------------------------------------------- -def _ok(response: Any) -> Any: +def _ok(response: object) -> object: """Extract payload from a fail2ban ``(return_code, data)`` response. Args: @@ -93,7 +95,7 @@ def _ok(response: Any) -> Any: ValueError: If the return code indicates an error. """ try: - code, data = response + code, data = cast(Fail2BanResponse, response) except (TypeError, ValueError) as exc: raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc if code != 0: @@ -101,11 +103,11 @@ def _ok(response: Any) -> Any: return data -def _to_dict(pairs: Any) -> dict[str, Any]: +def _to_dict(pairs: object) -> dict[str, object]: """Convert a list of ``(key, value)`` pairs to a plain dict.""" if not isinstance(pairs, (list, tuple)): return {} - result: dict[str, Any] = {} + result: dict[str, object] = {} for item in pairs: try: k, v = item @@ -115,7 +117,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]: return result -def _ensure_list(value: Any) -> list[str]: +def _ensure_list(value: object | None) -> list[str]: """Coerce a fail2ban ``get`` result to a list of strings.""" if value is None: return [] @@ -126,11 +128,14 @@ def _ensure_list(value: Any) -> list[str]: return [str(value)] +_T = TypeVar("_T") + + async def _safe_get( client: Fail2BanClient, - command: list[Any], - default: Any = None, -) -> Any: + command: Fail2BanCommand, + default: object | None = None, +) -> object | None: """Send a command and return *default* if it fails.""" try: return _ok(await client.send(command)) @@ -138,6 +143,15 @@ async def _safe_get( return default +async def _safe_get_typed( + client: Fail2BanClient, + command: Fail2BanCommand, + default: _T, +) -> _T: + """Send a command and return the result typed as ``default``'s type.""" + return cast(_T, await _safe_get(client, command, default)) + + def _is_not_found_error(exc: Exception) -> bool: """Return ``True`` if *exc* signals an unknown jail.""" msg = str(exc).lower() @@ -192,47 +206,25 @@ async def get_jail_config(socket_path: str, name: str) -> JailConfigResponse: raise JailNotFoundError(name) from exc raise - ( - bantime_raw, - findtime_raw, - maxretry_raw, - failregex_raw, - ignoreregex_raw, - logpath_raw, - datepattern_raw, - logencoding_raw, - backend_raw, - usedns_raw, - prefregex_raw, - actions_raw, - bt_increment_raw, - bt_factor_raw, - bt_formula_raw, - bt_multipliers_raw, - bt_maxtime_raw, - bt_rndtime_raw, - bt_overalljails_raw, - ) = await asyncio.gather( - _safe_get(client, ["get", name, "bantime"], 600), - _safe_get(client, ["get", name, "findtime"], 600), - _safe_get(client, ["get", name, "maxretry"], 5), - _safe_get(client, ["get", name, "failregex"], []), - _safe_get(client, ["get", name, "ignoreregex"], []), - _safe_get(client, ["get", name, "logpath"], []), - _safe_get(client, ["get", name, "datepattern"], None), - _safe_get(client, ["get", name, "logencoding"], "UTF-8"), - _safe_get(client, ["get", name, "backend"], "polling"), - _safe_get(client, ["get", name, "usedns"], "warn"), - _safe_get(client, ["get", name, "prefregex"], ""), - _safe_get(client, ["get", name, "actions"], []), - _safe_get(client, ["get", name, "bantime.increment"], False), - _safe_get(client, ["get", name, "bantime.factor"], None), - _safe_get(client, ["get", name, "bantime.formula"], None), - _safe_get(client, ["get", name, "bantime.multipliers"], None), - _safe_get(client, ["get", name, "bantime.maxtime"], None), - _safe_get(client, ["get", name, "bantime.rndtime"], None), - _safe_get(client, ["get", name, "bantime.overalljails"], False), - ) + bantime_raw: int = await _safe_get_typed(client, ["get", name, "bantime"], 600) + findtime_raw: int = await _safe_get_typed(client, ["get", name, "findtime"], 600) + maxretry_raw: int = await _safe_get_typed(client, ["get", name, "maxretry"], 5) + failregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "failregex"], []) + ignoreregex_raw: list[object] = await _safe_get_typed(client, ["get", name, "ignoreregex"], []) + logpath_raw: list[object] = await _safe_get_typed(client, ["get", name, "logpath"], []) + datepattern_raw: str | None = await _safe_get_typed(client, ["get", name, "datepattern"], None) + logencoding_raw: str = await _safe_get_typed(client, ["get", name, "logencoding"], "UTF-8") + backend_raw: str = await _safe_get_typed(client, ["get", name, "backend"], "polling") + usedns_raw: str = await _safe_get_typed(client, ["get", name, "usedns"], "warn") + prefregex_raw: str = await _safe_get_typed(client, ["get", name, "prefregex"], "") + actions_raw: list[object] = await _safe_get_typed(client, ["get", name, "actions"], []) + bt_increment_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.increment"], False) + bt_factor_raw: str | float | None = await _safe_get_typed(client, ["get", name, "bantime.factor"], None) + bt_formula_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.formula"], None) + bt_multipliers_raw: str | None = await _safe_get_typed(client, ["get", name, "bantime.multipliers"], None) + bt_maxtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.maxtime"], None) + bt_rndtime_raw: str | int | None = await _safe_get_typed(client, ["get", name, "bantime.rndtime"], None) + bt_overalljails_raw: bool = await _safe_get_typed(client, ["get", name, "bantime.overalljails"], False) bantime_escalation = BantimeEscalation( increment=bool(bt_increment_raw), @@ -352,7 +344,7 @@ async def update_jail_config( raise JailNotFoundError(name) from exc raise - async def _set(key: str, value: Any) -> None: + async def _set(key: str, value: Fail2BanToken) -> None: try: _ok(await client.send(["set", name, key, value])) except ValueError as exc: @@ -423,7 +415,7 @@ async def _replace_regex_list( new_patterns: Replacement list (may be empty to clear). """ # Determine current count. - current_raw = await _safe_get(client, ["get", jail, field], []) + current_raw: list[object] = await _safe_get_typed(client, ["get", jail, field], []) current: list[str] = _ensure_list(current_raw) del_cmd = f"del{field}" @@ -470,10 +462,10 @@ async def get_global_config(socket_path: str) -> GlobalConfigResponse: db_purge_age_raw, db_max_matches_raw, ) = await asyncio.gather( - _safe_get(client, ["get", "loglevel"], "INFO"), - _safe_get(client, ["get", "logtarget"], "STDOUT"), - _safe_get(client, ["get", "dbpurgeage"], 86400), - _safe_get(client, ["get", "dbmaxmatches"], 10), + _safe_get_typed(client, ["get", "loglevel"], "INFO"), + _safe_get_typed(client, ["get", "logtarget"], "STDOUT"), + _safe_get_typed(client, ["get", "dbpurgeage"], 86400), + _safe_get_typed(client, ["get", "dbmaxmatches"], 10), ) return GlobalConfigResponse( @@ -497,7 +489,7 @@ async def update_global_config(socket_path: str, update: GlobalConfigUpdate) -> """ client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) - async def _set_global(key: str, value: Any) -> None: + async def _set_global(key: str, value: Fail2BanToken) -> None: try: _ok(await client.send(["set", key, value])) except ValueError as exc: @@ -822,8 +814,8 @@ async def read_fail2ban_log( client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) log_level_raw, log_target_raw = await asyncio.gather( - _safe_get(client, ["get", "loglevel"], "INFO"), - _safe_get(client, ["get", "logtarget"], "STDOUT"), + _safe_get_typed(client, ["get", "loglevel"], "INFO"), + _safe_get_typed(client, ["get", "logtarget"], "STDOUT"), ) log_level = str(log_level_raw or "INFO").upper() @@ -905,8 +897,8 @@ async def get_service_status(socket_path: str) -> ServiceStatusResponse: if server_status.online: client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) log_level_raw, log_target_raw = await asyncio.gather( - _safe_get(client, ["get", "loglevel"], "INFO"), - _safe_get(client, ["get", "logtarget"], "STDOUT"), + _safe_get_typed(client, ["get", "loglevel"], "INFO"), + _safe_get_typed(client, ["get", "logtarget"], "STDOUT"), ) log_level = str(log_level_raw or "INFO").upper() log_target = str(log_target_raw or "STDOUT")