refactor: improve backend type safety and import organization

- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite)
- Reorganize imports to follow PEP 8 conventions
- Convert TypeAlias to modern PEP 695 type syntax (where appropriate)
- Use Sequence/Mapping from collections.abc for type hints (covariant)
- Replace string literals with cast() for improved type inference
- Fix casting of Fail2BanResponse and TypedDict patterns
- Add IpLookupResult TypedDict for precise return type annotation
- Reformat overlong lines for readability (120 char limit)
- Add asyncio_mode and filterwarnings to pytest config
- Update test fixtures with improved type hints

This improves mypy type checking and makes type relationships explicit.
This commit is contained in:
2026-03-20 13:44:14 +01:00
parent bdcdd5d672
commit 1c0bac1353
30 changed files with 431 additions and 644 deletions

View File

@@ -14,7 +14,8 @@ from __future__ import annotations
import asyncio
import contextlib
import ipaddress
from typing import TYPE_CHECKING, Awaitable, Callable, cast, TypeAlias
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, TypedDict, cast
import structlog
@@ -27,6 +28,7 @@ from app.models.jail import (
JailStatus,
JailSummary,
)
from app.services.geo_service import GeoInfo
from app.utils.fail2ban_client import (
Fail2BanClient,
Fail2BanCommand,
@@ -39,11 +41,21 @@ if TYPE_CHECKING:
import aiohttp
import aiosqlite
from app.services.geo_service import GeoInfo
log: structlog.stdlib.BoundLogger = structlog.get_logger()
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo | None"]]
class IpLookupResult(TypedDict):
"""Result returned by :func:`lookup_ip`.
This is intentionally a :class:`TypedDict` to provide precise typing for
callers (e.g. routers) while keeping the implementation flexible.
"""
ip: str
currently_banned_in: list[str]
geo: GeoInfo | None
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
# ---------------------------------------------------------------------------
# Constants
@@ -104,7 +116,7 @@ def _ok(response: object) -> object:
ValueError: If the response indicates an error (return code ≠ 0).
"""
try:
code, data = cast(Fail2BanResponse, response)
code, data = cast("Fail2BanResponse", response)
except (TypeError, ValueError) as exc:
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
@@ -202,7 +214,7 @@ async def _safe_get(
"""
try:
response = await client.send(command)
return _ok(cast(Fail2BanResponse, response))
return _ok(cast("Fail2BanResponse", response))
except (ValueError, TypeError, Exception):
return default
@@ -337,7 +349,6 @@ async def _fetch_jail_summary(
client.send(["get", name, "backend"]),
client.send(["get", name, "idle"]),
])
uses_backend_backend_commands = True
else:
# Commands not supported; return default values without sending.
async def _return_default(value: object | None) -> Fail2BanResponse:
@@ -347,7 +358,6 @@ async def _fetch_jail_summary(
_return_default("polling"), # backend default
_return_default(False), # idle default
])
uses_backend_backend_commands = False
_r = await asyncio.gather(*gather_list, return_exceptions=True)
status_raw: object | Exception = _r[0]
@@ -377,7 +387,7 @@ async def _fetch_jail_summary(
if isinstance(raw, Exception):
return fallback
try:
return int(str(_ok(cast(Fail2BanResponse, raw))))
return int(str(_ok(cast("Fail2BanResponse", raw))))
except (ValueError, TypeError):
return fallback
@@ -385,7 +395,7 @@ async def _fetch_jail_summary(
if isinstance(raw, Exception):
return fallback
try:
return str(_ok(cast(Fail2BanResponse, raw)))
return str(_ok(cast("Fail2BanResponse", raw)))
except (ValueError, TypeError):
return fallback
@@ -393,7 +403,7 @@ async def _fetch_jail_summary(
if isinstance(raw, Exception):
return fallback
try:
return bool(_ok(cast(Fail2BanResponse, raw)))
return bool(_ok(cast("Fail2BanResponse", raw)))
except (ValueError, TypeError):
return fallback
@@ -687,7 +697,7 @@ async def reload_all(
names_set -= set(exclude_jails)
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
_ok(await client.send(["reload", "--all", [], cast(Fail2BanToken, stream)]))
_ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
log.info("all_jails_reloaded")
except ValueError as exc:
# Detect UnknownJailException (missing or invalid jail configuration)
@@ -811,8 +821,8 @@ async def unban_ip(
async def get_active_bans(
socket_path: str,
geo_enricher: GeoEnricher | None = None,
http_session: "aiohttp.ClientSession" | None = None,
app_db: "aiosqlite.Connection" | None = None,
http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None,
) -> ActiveBanListResponse:
"""Return all currently banned IPs across every jail.
@@ -880,7 +890,7 @@ async def get_active_bans(
continue
try:
ban_list: list[str] = cast(list[str], _ok(raw_result)) or []
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
except (TypeError, ValueError) as exc:
log.warning(
"active_bans_parse_error",
@@ -1007,8 +1017,8 @@ async def get_jail_banned_ips(
page: int = 1,
page_size: int = 25,
search: str | None = None,
http_session: "aiohttp.ClientSession" | None = None,
app_db: "aiosqlite.Connection" | None = None,
http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None,
) -> JailBannedIpsResponse:
"""Return a paginated list of currently banned IPs for a single jail.
@@ -1055,7 +1065,7 @@ async def get_jail_banned_ips(
except (ValueError, TypeError):
raw_result = []
ban_list: list[str] = cast(list[str], raw_result) or []
ban_list: list[str] = cast("list[str]", raw_result) or []
# Parse all entries.
all_bans: list[ActiveBan] = []
@@ -1121,7 +1131,7 @@ async def _enrich_bans(
The same list with ``country`` fields populated where lookup succeeded.
"""
geo_results: list[object | Exception] = await asyncio.gather(
*[cast(Awaitable[object], geo_enricher(ban.ip)) for ban in bans],
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
return_exceptions=True,
)
enriched: list[ActiveBan] = []
@@ -1277,7 +1287,7 @@ async def lookup_ip(
socket_path: str,
ip: str,
geo_enricher: GeoEnricher | None = None,
) -> dict[str, object | list[str] | None]:
) -> IpLookupResult:
"""Return ban status and history for a single IP address.
Checks every running jail for whether the IP is currently banned.
@@ -1330,7 +1340,7 @@ async def lookup_ip(
if isinstance(result, Exception):
continue
try:
ban_list: list[str] = cast(list[str], _ok(result)) or []
ban_list: list[str] = cast("list[str]", _ok(result)) or []
if ip in ban_list:
currently_banned_in.append(jail_name)
except (ValueError, TypeError):