diff --git a/Docs/Tasks.md b/Docs/Tasks.md index f773def..9b5c492 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -131,7 +131,9 @@ After completing TASK B-5, a `geo_service` method (or via `geo_cache_repo` throu --- -#### TASK B-7 — Replace `Any` type annotations in `ban_service.py` +#### TASK B-7 — Replace `Any` type annotations in `ban_service.py` ✅ + +**Status:** Completed ✅ **Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations. diff --git a/backend/app/services/ban_service.py b/backend/app/services/ban_service.py index 14c9bc7..26d1687 100644 --- a/backend/app/services/ban_service.py +++ b/backend/app/services/ban_service.py @@ -13,15 +13,18 @@ from __future__ import annotations import asyncio import json import time +from collections.abc import Awaitable, Callable from dataclasses import asdict from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, TypeAlias import structlog if TYPE_CHECKING: import aiosqlite + from app.services.geo_service import GeoInfo + from app.models.ban import ( BLOCKLIST_JAIL, BUCKET_SECONDS, @@ -34,6 +37,7 @@ from app.models.ban import ( BanTrendResponse, DashboardBanItem, DashboardBanListResponse, + JailBanCount as JailBanCountModel, TimeRange, _derive_origin, bucket_count, @@ -46,6 +50,8 @@ if TYPE_CHECKING: log: structlog.stdlib.BoundLogger = structlog.get_logger() +GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo | None"]] + # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- @@ -144,7 +150,7 @@ async def _get_fail2ban_db_path(socket_path: str) -> str: return str(data) -def _parse_data_json(raw: Any) -> tuple[list[str], int]: +def _parse_data_json(raw: object) -> tuple[list[str], int]: """Extract matches and failure count from the ``bans.data`` column. 
The ``data`` column stores a JSON blob with optional keys: @@ -162,10 +168,10 @@ def _parse_data_json(raw: Any) -> tuple[list[str], int]: if raw is None: return [], 0 - obj: dict[str, Any] = {} + obj: dict[str, object] = {} if isinstance(raw, str): try: - parsed: Any = json.loads(raw) + parsed: object = json.loads(raw) if isinstance(parsed, dict): obj = parsed # json.loads("null") → None, or other non-dict — treat as empty @@ -174,8 +180,20 @@ def _parse_data_json(raw: Any) -> tuple[list[str], int]: elif isinstance(raw, dict): obj = raw - matches: list[str] = [str(m) for m in (obj.get("matches") or [])] - failures: int = int(obj.get("failures", 0)) + raw_matches = obj.get("matches") + if isinstance(raw_matches, list): + matches: list[str] = [str(m) for m in raw_matches] + else: + matches = [] + + raw_failures = obj.get("failures") + failures: int = 0 + if isinstance(raw_failures, (int, float, str)): + try: + failures = int(raw_failures) + except (ValueError, TypeError): + failures = 0 + return matches, failures @@ -192,7 +210,7 @@ async def list_bans( page_size: int = _DEFAULT_PAGE_SIZE, http_session: aiohttp.ClientSession | None = None, app_db: aiosqlite.Connection | None = None, - geo_enricher: Any | None = None, + geo_enricher: GeoEnricher | None = None, origin: BanOrigin | None = None, ) -> DashboardBanListResponse: """Return a paginated list of bans within the selected time window. @@ -258,7 +276,7 @@ async def list_bans( # Batch-resolve geo data for all IPs on this page in a single API call. # This avoids hitting the 45 req/min single-IP rate limit when the # page contains many bans (e.g. after a large blocklist import). 
- geo_map: dict[str, Any] = {} + geo_map: dict[str, "GeoInfo"] = {} if http_session is not None and rows: page_ips: list[str] = [r.ip for r in rows] try: @@ -333,7 +351,7 @@ async def bans_by_country( socket_path: str, range_: TimeRange, http_session: aiohttp.ClientSession | None = None, - geo_enricher: Any | None = None, + geo_enricher: GeoEnricher | None = None, app_db: aiosqlite.Connection | None = None, origin: BanOrigin | None = None, ) -> BansByCountryResponse: @@ -410,7 +428,7 @@ async def bans_by_country( ) unique_ips: list[str] = [r.ip for r in agg_rows] - geo_map: dict[str, Any] = {} + geo_map: dict[str, "GeoInfo"] = {} if http_session is not None and unique_ips: # Serve only what is already in the in-memory cache — no API calls on @@ -431,7 +449,7 @@ async def bans_by_country( ) elif geo_enricher is not None and unique_ips: # Fallback: legacy per-IP enricher (used in tests / older callers). - async def _safe_lookup(ip: str) -> tuple[str, Any]: + async def _safe_lookup(ip: str) -> tuple[str, "GeoInfo" | None]: try: return ip, await geo_enricher(ip) except Exception: # noqa: BLE001 @@ -439,18 +457,18 @@ async def bans_by_country( return ip, None results = await asyncio.gather(*(_safe_lookup(ip) for ip in unique_ips)) - geo_map = dict(results) + geo_map = {ip: geo for ip, geo in results if geo is not None} # Build country aggregation from the SQL-grouped rows. countries: dict[str, int] = {} country_names: dict[str, str] = {} - for row in agg_rows: - ip: str = row.ip + for agg_row in agg_rows: + ip: str = agg_row.ip geo = geo_map.get(ip) cc: str | None = geo.country_code if geo else None cn: str | None = geo.country_name if geo else None - event_count: int = row.event_count + event_count: int = agg_row.event_count if cc: countries[cc] = countries.get(cc, 0) + event_count @@ -459,27 +477,27 @@ async def bans_by_country( # Build companion table from recent rows (geo already cached from batch step). 
bans: list[DashboardBanItem] = [] - for row in companion_rows: - ip = row.ip + for companion_row in companion_rows: + ip = companion_row.ip geo = geo_map.get(ip) cc = geo.country_code if geo else None cn = geo.country_name if geo else None asn: str | None = geo.asn if geo else None org: str | None = geo.org if geo else None - matches, _ = _parse_data_json(row.data) + matches, _ = _parse_data_json(companion_row.data) bans.append( DashboardBanItem( ip=ip, - jail=row.jail, - banned_at=_ts_to_iso(row.timeofban), + jail=companion_row.jail, + banned_at=_ts_to_iso(companion_row.timeofban), service=matches[0] if matches else None, country_code=cc, country_name=cn, asn=asn, org=org, - ban_count=row.bancount, - origin=_derive_origin(row.jail), + ban_count=companion_row.bancount, + origin=_derive_origin(companion_row.jail), ) ) @@ -608,7 +626,7 @@ async def bans_by_jail( origin=origin, ) - total, jails = await fail2ban_db_repo.get_bans_by_jail( + total, jail_counts = await fail2ban_db_repo.get_bans_by_jail( db_path=db_path, since=since, origin=origin, @@ -634,11 +652,10 @@ async def bans_by_jail( log.debug( "ban_service_bans_by_jail_result", total=total, - jail_count=len(jails), + jail_count=len(jail_counts), ) - # Pydantic strict validation requires either dicts or model instances. - # Our repository returns dataclasses for simplicity, so convert them here. - jail_dicts: list[dict[str, object]] = [asdict(j) for j in jails] - - return BansByJailResponse(jails=jail_dicts, total=total) + return BansByJailResponse( + jails=[JailBanCountModel(jail=j.jail, count=j.count) for j in jail_counts], + total=total, + )