Fix ban_service typing by replacing Any with GeoEnricher and GeoInfo
This commit is contained in:
@@ -131,7 +131,9 @@ After completing TASK B-5, a `geo_service` method (or via `geo_cache_repo` throu
|
|||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
#### TASK B-7 — Replace `Any` type annotations in `ban_service.py`
|
#### TASK B-7 — Replace `Any` type annotations in `ban_service.py` ✅
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
**Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations.
|
||||||
|
|
||||||
|
|||||||
@@ -13,15 +13,18 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Awaitable, Callable
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, TypeAlias
|
||||||
|
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import aiosqlite
|
import aiosqlite
|
||||||
|
|
||||||
|
from app.services.geo_service import GeoInfo
|
||||||
|
|
||||||
from app.models.ban import (
|
from app.models.ban import (
|
||||||
BLOCKLIST_JAIL,
|
BLOCKLIST_JAIL,
|
||||||
BUCKET_SECONDS,
|
BUCKET_SECONDS,
|
||||||
@@ -34,6 +37,7 @@ from app.models.ban import (
|
|||||||
BanTrendResponse,
|
BanTrendResponse,
|
||||||
DashboardBanItem,
|
DashboardBanItem,
|
||||||
DashboardBanListResponse,
|
DashboardBanListResponse,
|
||||||
|
JailBanCount as JailBanCountModel,
|
||||||
TimeRange,
|
TimeRange,
|
||||||
_derive_origin,
|
_derive_origin,
|
||||||
bucket_count,
|
bucket_count,
|
||||||
@@ -46,6 +50,8 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||||
|
|
||||||
|
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo"] | None]
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Constants
|
# Constants
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -144,7 +150,7 @@ async def _get_fail2ban_db_path(socket_path: str) -> str:
|
|||||||
return str(data)
|
return str(data)
|
||||||
|
|
||||||
|
|
||||||
def _parse_data_json(raw: Any) -> tuple[list[str], int]:
|
def _parse_data_json(raw: object) -> tuple[list[str], int]:
|
||||||
"""Extract matches and failure count from the ``bans.data`` column.
|
"""Extract matches and failure count from the ``bans.data`` column.
|
||||||
|
|
||||||
The ``data`` column stores a JSON blob with optional keys:
|
The ``data`` column stores a JSON blob with optional keys:
|
||||||
@@ -162,10 +168,10 @@ def _parse_data_json(raw: Any) -> tuple[list[str], int]:
|
|||||||
if raw is None:
|
if raw is None:
|
||||||
return [], 0
|
return [], 0
|
||||||
|
|
||||||
obj: dict[str, Any] = {}
|
obj: dict[str, object] = {}
|
||||||
if isinstance(raw, str):
|
if isinstance(raw, str):
|
||||||
try:
|
try:
|
||||||
parsed: Any = json.loads(raw)
|
parsed: object = json.loads(raw)
|
||||||
if isinstance(parsed, dict):
|
if isinstance(parsed, dict):
|
||||||
obj = parsed
|
obj = parsed
|
||||||
# json.loads("null") → None, or other non-dict — treat as empty
|
# json.loads("null") → None, or other non-dict — treat as empty
|
||||||
@@ -174,8 +180,20 @@ def _parse_data_json(raw: Any) -> tuple[list[str], int]:
|
|||||||
elif isinstance(raw, dict):
|
elif isinstance(raw, dict):
|
||||||
obj = raw
|
obj = raw
|
||||||
|
|
||||||
matches: list[str] = [str(m) for m in (obj.get("matches") or [])]
|
raw_matches = obj.get("matches")
|
||||||
failures: int = int(obj.get("failures", 0))
|
if isinstance(raw_matches, list):
|
||||||
|
matches: list[str] = [str(m) for m in raw_matches]
|
||||||
|
else:
|
||||||
|
matches = []
|
||||||
|
|
||||||
|
raw_failures = obj.get("failures")
|
||||||
|
failures: int = 0
|
||||||
|
if isinstance(raw_failures, (int, float, str)):
|
||||||
|
try:
|
||||||
|
failures = int(raw_failures)
|
||||||
|
except (ValueError, TypeError):
|
||||||
|
failures = 0
|
||||||
|
|
||||||
return matches, failures
|
return matches, failures
|
||||||
|
|
||||||
|
|
||||||
@@ -192,7 +210,7 @@ async def list_bans(
|
|||||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
page_size: int = _DEFAULT_PAGE_SIZE,
|
||||||
http_session: aiohttp.ClientSession | None = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
app_db: aiosqlite.Connection | None = None,
|
app_db: aiosqlite.Connection | None = None,
|
||||||
geo_enricher: Any | None = None,
|
geo_enricher: GeoEnricher | None = None,
|
||||||
origin: BanOrigin | None = None,
|
origin: BanOrigin | None = None,
|
||||||
) -> DashboardBanListResponse:
|
) -> DashboardBanListResponse:
|
||||||
"""Return a paginated list of bans within the selected time window.
|
"""Return a paginated list of bans within the selected time window.
|
||||||
@@ -258,7 +276,7 @@ async def list_bans(
|
|||||||
# Batch-resolve geo data for all IPs on this page in a single API call.
|
# Batch-resolve geo data for all IPs on this page in a single API call.
|
||||||
# This avoids hitting the 45 req/min single-IP rate limit when the
|
# This avoids hitting the 45 req/min single-IP rate limit when the
|
||||||
# page contains many bans (e.g. after a large blocklist import).
|
# page contains many bans (e.g. after a large blocklist import).
|
||||||
geo_map: dict[str, Any] = {}
|
geo_map: dict[str, "GeoInfo"] = {}
|
||||||
if http_session is not None and rows:
|
if http_session is not None and rows:
|
||||||
page_ips: list[str] = [r.ip for r in rows]
|
page_ips: list[str] = [r.ip for r in rows]
|
||||||
try:
|
try:
|
||||||
@@ -333,7 +351,7 @@ async def bans_by_country(
|
|||||||
socket_path: str,
|
socket_path: str,
|
||||||
range_: TimeRange,
|
range_: TimeRange,
|
||||||
http_session: aiohttp.ClientSession | None = None,
|
http_session: aiohttp.ClientSession | None = None,
|
||||||
geo_enricher: Any | None = None,
|
geo_enricher: GeoEnricher | None = None,
|
||||||
app_db: aiosqlite.Connection | None = None,
|
app_db: aiosqlite.Connection | None = None,
|
||||||
origin: BanOrigin | None = None,
|
origin: BanOrigin | None = None,
|
||||||
) -> BansByCountryResponse:
|
) -> BansByCountryResponse:
|
||||||
@@ -410,7 +428,7 @@ async def bans_by_country(
|
|||||||
)
|
)
|
||||||
|
|
||||||
unique_ips: list[str] = [r.ip for r in agg_rows]
|
unique_ips: list[str] = [r.ip for r in agg_rows]
|
||||||
geo_map: dict[str, Any] = {}
|
geo_map: dict[str, "GeoInfo"] = {}
|
||||||
|
|
||||||
if http_session is not None and unique_ips:
|
if http_session is not None and unique_ips:
|
||||||
# Serve only what is already in the in-memory cache — no API calls on
|
# Serve only what is already in the in-memory cache — no API calls on
|
||||||
@@ -431,7 +449,7 @@ async def bans_by_country(
|
|||||||
)
|
)
|
||||||
elif geo_enricher is not None and unique_ips:
|
elif geo_enricher is not None and unique_ips:
|
||||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||||
async def _safe_lookup(ip: str) -> tuple[str, Any]:
|
async def _safe_lookup(ip: str) -> tuple[str, "GeoInfo" | None]:
|
||||||
try:
|
try:
|
||||||
return ip, await geo_enricher(ip)
|
return ip, await geo_enricher(ip)
|
||||||
except Exception: # noqa: BLE001
|
except Exception: # noqa: BLE001
|
||||||
@@ -439,18 +457,18 @@ async def bans_by_country(
|
|||||||
return ip, None
|
return ip, None
|
||||||
|
|
||||||
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips))
|
||||||
geo_map = dict(results)
|
geo_map = {ip: geo for ip, geo in results if geo is not None}
|
||||||
|
|
||||||
# Build country aggregation from the SQL-grouped rows.
|
# Build country aggregation from the SQL-grouped rows.
|
||||||
countries: dict[str, int] = {}
|
countries: dict[str, int] = {}
|
||||||
country_names: dict[str, str] = {}
|
country_names: dict[str, str] = {}
|
||||||
|
|
||||||
for row in agg_rows:
|
for agg_row in agg_rows:
|
||||||
ip: str = row.ip
|
ip: str = agg_row.ip
|
||||||
geo = geo_map.get(ip)
|
geo = geo_map.get(ip)
|
||||||
cc: str | None = geo.country_code if geo else None
|
cc: str | None = geo.country_code if geo else None
|
||||||
cn: str | None = geo.country_name if geo else None
|
cn: str | None = geo.country_name if geo else None
|
||||||
event_count: int = row.event_count
|
event_count: int = agg_row.event_count
|
||||||
|
|
||||||
if cc:
|
if cc:
|
||||||
countries[cc] = countries.get(cc, 0) + event_count
|
countries[cc] = countries.get(cc, 0) + event_count
|
||||||
@@ -459,27 +477,27 @@ async def bans_by_country(
|
|||||||
|
|
||||||
# Build companion table from recent rows (geo already cached from batch step).
|
# Build companion table from recent rows (geo already cached from batch step).
|
||||||
bans: list[DashboardBanItem] = []
|
bans: list[DashboardBanItem] = []
|
||||||
for row in companion_rows:
|
for companion_row in companion_rows:
|
||||||
ip = row.ip
|
ip = companion_row.ip
|
||||||
geo = geo_map.get(ip)
|
geo = geo_map.get(ip)
|
||||||
cc = geo.country_code if geo else None
|
cc = geo.country_code if geo else None
|
||||||
cn = geo.country_name if geo else None
|
cn = geo.country_name if geo else None
|
||||||
asn: str | None = geo.asn if geo else None
|
asn: str | None = geo.asn if geo else None
|
||||||
org: str | None = geo.org if geo else None
|
org: str | None = geo.org if geo else None
|
||||||
matches, _ = _parse_data_json(row.data)
|
matches, _ = _parse_data_json(companion_row.data)
|
||||||
|
|
||||||
bans.append(
|
bans.append(
|
||||||
DashboardBanItem(
|
DashboardBanItem(
|
||||||
ip=ip,
|
ip=ip,
|
||||||
jail=row.jail,
|
jail=companion_row.jail,
|
||||||
banned_at=_ts_to_iso(row.timeofban),
|
banned_at=_ts_to_iso(companion_row.timeofban),
|
||||||
service=matches[0] if matches else None,
|
service=matches[0] if matches else None,
|
||||||
country_code=cc,
|
country_code=cc,
|
||||||
country_name=cn,
|
country_name=cn,
|
||||||
asn=asn,
|
asn=asn,
|
||||||
org=org,
|
org=org,
|
||||||
ban_count=row.bancount,
|
ban_count=companion_row.bancount,
|
||||||
origin=_derive_origin(row.jail),
|
origin=_derive_origin(companion_row.jail),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -608,7 +626,7 @@ async def bans_by_jail(
|
|||||||
origin=origin,
|
origin=origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
total, jails = await fail2ban_db_repo.get_bans_by_jail(
|
total, jail_counts = await fail2ban_db_repo.get_bans_by_jail(
|
||||||
db_path=db_path,
|
db_path=db_path,
|
||||||
since=since,
|
since=since,
|
||||||
origin=origin,
|
origin=origin,
|
||||||
@@ -634,11 +652,10 @@ async def bans_by_jail(
|
|||||||
log.debug(
|
log.debug(
|
||||||
"ban_service_bans_by_jail_result",
|
"ban_service_bans_by_jail_result",
|
||||||
total=total,
|
total=total,
|
||||||
jail_count=len(jails),
|
jail_count=len(jail_counts),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Pydantic strict validation requires either dicts or model instances.
|
return BansByJailResponse(
|
||||||
# Our repository returns dataclasses for simplicity, so convert them here.
|
jails=[JailBanCountModel(jail=j.jail, count=j.count) for j in jail_counts],
|
||||||
jail_dicts: list[dict[str, object]] = [asdict(j) for j in jails]
|
total=total,
|
||||||
|
)
|
||||||
return BansByJailResponse(jails=jail_dicts, total=total)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user