Fix ban_service typing by replacing Any with GeoEnricher and GeoInfo

This commit is contained in:
2026-03-17 10:33:39 +01:00
parent c9e688cc52
commit dfbe126368
2 changed files with 49 additions and 30 deletions

View File

@@ -131,7 +131,9 @@ After completing TASK B-5, a `geo_service` method (or via `geo_cache_repo` throu
--- ---
#### TASK B-7 — Replace `Any` type annotations in `ban_service.py` #### TASK B-7 — Replace `Any` type annotations in `ban_service.py`
**Status:** Completed ✅
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations. **Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.

View File

@@ -13,15 +13,18 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import time import time
from collections.abc import Awaitable, Callable
from dataclasses import asdict from dataclasses import asdict
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, TypeAlias
import structlog import structlog
if TYPE_CHECKING: if TYPE_CHECKING:
import aiosqlite import aiosqlite
from app.services.geo_service import GeoInfo
from app.models.ban import ( from app.models.ban import (
BLOCKLIST_JAIL, BLOCKLIST_JAIL,
BUCKET_SECONDS, BUCKET_SECONDS,
@@ -34,6 +37,7 @@ from app.models.ban import (
BanTrendResponse, BanTrendResponse,
DashboardBanItem, DashboardBanItem,
DashboardBanListResponse, DashboardBanListResponse,
JailBanCount as JailBanCountModel,
TimeRange, TimeRange,
_derive_origin, _derive_origin,
bucket_count, bucket_count,
@@ -46,6 +50,8 @@ if TYPE_CHECKING:
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo"] | None]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Constants # Constants
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -144,7 +150,7 @@ async def _get_fail2ban_db_path(socket_path: str) -> str:
return str(data) return str(data)
def _parse_data_json(raw: Any) -> tuple[list[str], int]: def _parse_data_json(raw: object) -> tuple[list[str], int]:
"""Extract matches and failure count from the ``bans.data`` column. """Extract matches and failure count from the ``bans.data`` column.
The ``data`` column stores a JSON blob with optional keys: The ``data`` column stores a JSON blob with optional keys:
@@ -162,10 +168,10 @@ def _parse_data_json(raw: Any) -> tuple[list[str], int]:
if raw is None: if raw is None:
return [], 0 return [], 0
obj: dict[str, Any] = {} obj: dict[str, object] = {}
if isinstance(raw, str): if isinstance(raw, str):
try: try:
parsed: Any = json.loads(raw) parsed: object = json.loads(raw)
if isinstance(parsed, dict): if isinstance(parsed, dict):
obj = parsed obj = parsed
# json.loads("null") → None, or other non-dict — treat as empty # json.loads("null") → None, or other non-dict — treat as empty
@@ -174,8 +180,20 @@ def _parse_data_json(raw: Any) -> tuple[list[str], int]:
elif isinstance(raw, dict): elif isinstance(raw, dict):
obj = raw obj = raw
matches: list[str] = [str(m) for m in (obj.get("matches") or [])] raw_matches = obj.get("matches")
failures: int = int(obj.get("failures", 0)) if isinstance(raw_matches, list):
matches: list[str] = [str(m) for m in raw_matches]
else:
matches = []
raw_failures = obj.get("failures")
failures: int = 0
if isinstance(raw_failures, (int, float, str)):
try:
failures = int(raw_failures)
except (ValueError, TypeError):
failures = 0
return matches, failures return matches, failures
@@ -192,7 +210,7 @@ async def list_bans(
page_size: int = _DEFAULT_PAGE_SIZE, page_size: int = _DEFAULT_PAGE_SIZE,
http_session: aiohttp.ClientSession | None = None, http_session: aiohttp.ClientSession | None = None,
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
geo_enricher: Any | None = None, geo_enricher: GeoEnricher | None = None,
origin: BanOrigin | None = None, origin: BanOrigin | None = None,
) -> DashboardBanListResponse: ) -> DashboardBanListResponse:
"""Return a paginated list of bans within the selected time window. """Return a paginated list of bans within the selected time window.
@@ -258,7 +276,7 @@ async def list_bans(
# Batch-resolve geo data for all IPs on this page in a single API call. # Batch-resolve geo data for all IPs on this page in a single API call.
# This avoids hitting the 45 req/min single-IP rate limit when the # This avoids hitting the 45 req/min single-IP rate limit when the
# page contains many bans (e.g. after a large blocklist import). # page contains many bans (e.g. after a large blocklist import).
geo_map: dict[str, Any] = {} geo_map: dict[str, "GeoInfo"] = {}
if http_session is not None and rows: if http_session is not None and rows:
page_ips: list[str] = [r.ip for r in rows] page_ips: list[str] = [r.ip for r in rows]
try: try:
@@ -333,7 +351,7 @@ async def bans_by_country(
socket_path: str, socket_path: str,
range_: TimeRange, range_: TimeRange,
http_session: aiohttp.ClientSession | None = None, http_session: aiohttp.ClientSession | None = None,
geo_enricher: Any | None = None, geo_enricher: GeoEnricher | None = None,
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None, origin: BanOrigin | None = None,
) -> BansByCountryResponse: ) -> BansByCountryResponse:
@@ -410,7 +428,7 @@ async def bans_by_country(
) )
unique_ips: list[str] = [r.ip for r in agg_rows] unique_ips: list[str] = [r.ip for r in agg_rows]
geo_map: dict[str, Any] = {} geo_map: dict[str, "GeoInfo"] = {}
if http_session is not None and unique_ips: if http_session is not None and unique_ips:
# Serve only what is already in the in-memory cache — no API calls on # Serve only what is already in the in-memory cache — no API calls on
@@ -431,7 +449,7 @@ async def bans_by_country(
) )
elif geo_enricher is not None and unique_ips: elif geo_enricher is not None and unique_ips:
# Fallback: legacy per-IP enricher (used in tests / older callers). # Fallback: legacy per-IP enricher (used in tests / older callers).
async def _safe_lookup(ip: str) -> tuple[str, Any]: async def _safe_lookup(ip: str) -> tuple[str, "GeoInfo" | None]:
try: try:
return ip, await geo_enricher(ip) return ip, await geo_enricher(ip)
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
@@ -439,18 +457,18 @@ async def bans_by_country(
return ip, None return ip, None
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips)) results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
geo_map = dict(results) geo_map = {ip: geo for ip, geo in results if geo is not None}
# Build country aggregation from the SQL-grouped rows. # Build country aggregation from the SQL-grouped rows.
countries: dict[str, int] = {} countries: dict[str, int] = {}
country_names: dict[str, str] = {} country_names: dict[str, str] = {}
for row in agg_rows: for agg_row in agg_rows:
ip: str = row.ip ip: str = agg_row.ip
geo = geo_map.get(ip) geo = geo_map.get(ip)
cc: str | None = geo.country_code if geo else None cc: str | None = geo.country_code if geo else None
cn: str | None = geo.country_name if geo else None cn: str | None = geo.country_name if geo else None
event_count: int = row.event_count event_count: int = agg_row.event_count
if cc: if cc:
countries[cc] = countries.get(cc, 0) + event_count countries[cc] = countries.get(cc, 0) + event_count
@@ -459,27 +477,27 @@ async def bans_by_country(
# Build companion table from recent rows (geo already cached from batch step). # Build companion table from recent rows (geo already cached from batch step).
bans: list[DashboardBanItem] = [] bans: list[DashboardBanItem] = []
for row in companion_rows: for companion_row in companion_rows:
ip = row.ip ip = companion_row.ip
geo = geo_map.get(ip) geo = geo_map.get(ip)
cc = geo.country_code if geo else None cc = geo.country_code if geo else None
cn = geo.country_name if geo else None cn = geo.country_name if geo else None
asn: str | None = geo.asn if geo else None asn: str | None = geo.asn if geo else None
org: str | None = geo.org if geo else None org: str | None = geo.org if geo else None
matches, _ = _parse_data_json(row.data) matches, _ = _parse_data_json(companion_row.data)
bans.append( bans.append(
DashboardBanItem( DashboardBanItem(
ip=ip, ip=ip,
jail=row.jail, jail=companion_row.jail,
banned_at=_ts_to_iso(row.timeofban), banned_at=_ts_to_iso(companion_row.timeofban),
service=matches[0] if matches else None, service=matches[0] if matches else None,
country_code=cc, country_code=cc,
country_name=cn, country_name=cn,
asn=asn, asn=asn,
org=org, org=org,
ban_count=row.bancount, ban_count=companion_row.bancount,
origin=_derive_origin(row.jail), origin=_derive_origin(companion_row.jail),
) )
) )
@@ -608,7 +626,7 @@ async def bans_by_jail(
origin=origin, origin=origin,
) )
total, jails = await fail2ban_db_repo.get_bans_by_jail( total, jail_counts = await fail2ban_db_repo.get_bans_by_jail(
db_path=db_path, db_path=db_path,
since=since, since=since,
origin=origin, origin=origin,
@@ -634,11 +652,10 @@ async def bans_by_jail(
log.debug( log.debug(
"ban_service_bans_by_jail_result", "ban_service_bans_by_jail_result",
total=total, total=total,
jail_count=len(jails), jail_count=len(jail_counts),
) )
# Pydantic strict validation requires either dicts or model instances. return BansByJailResponse(
# Our repository returns dataclasses for simplicity, so convert them here. jails=[JailBanCountModel(jail=j.jail, count=j.count) for j in jail_counts],
jail_dicts: list[dict[str, object]] = [asdict(j) for j in jails] total=total,
)
return BansByJailResponse(jails=jail_dicts, total=total)