diff --git a/Docs/Tasks.md b/Docs/Tasks.md index a8d80e3..265c094 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -158,6 +158,8 @@ After completing TASK B-5, a `geo_service` method (or via `geo_cache_repo` throu #### TASK B-8 — Remove `print()` from `geo_service.py` docstring example +**Status:** Completed ✅ + **Violated rule:** Refactoring.md §4 / Backend-Development.md §2 — Never use `print()` in production code; use `structlog`. **Files affected:** @@ -229,6 +231,8 @@ Remove or rewrite the docstring snippet so it does not contain a bare `print()` #### TASK F-1 — Wrap `SetupPage` API calls in a dedicated hook +**Status:** Completed ✅ + **Violated rule:** Refactoring.md §3.1 — Pages must not call API functions from `src/api/` directly; all data fetching goes through hooks. **Files affected:** @@ -409,6 +413,8 @@ For each component listed: #### TASK B-13 — Remove `Any` type annotations in `jail_service.py` +**Status:** Completed ✅ + **Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations. **Files affected:** @@ -424,6 +430,8 @@ For each component listed: #### TASK B-14 — Remove `Any` type annotations in `health_service.py` +**Status:** Completed ✅ + **Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations. **Files affected:** @@ -439,6 +447,8 @@ For each component listed: #### TASK B-15 — Remove `Any` type annotations in `blocklist_service.py` +**Status:** Completed ✅ + **Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations. **Files affected:** @@ -454,6 +464,8 @@ For each component listed: #### TASK B-16 — Remove `Any` type annotations in `import_log_repo.py` +**Status:** Completed ✅ + **Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations. **Files affected:** @@ -470,6 +482,8 @@ For each component listed: #### TASK B-17 — Remove `Any` type annotations in `config_file_service.py` +**Status:** Completed ✅ + **Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations. **Files affected:** @@ -485,6 +499,8 @@ For each component listed: #### TASK B-18 — Remove `Any` type annotations in `fail2ban_client.py` +**Status:** Completed ✅ + **Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations. **Files affected:** @@ -500,6 +516,8 @@ For each component listed: #### TASK B-19 — Remove `Any` annotations from background tasks +**Status:** Completed ✅ + **Violated rule:** Backend-Development.md §1 — Never use `Any`; all functions must have explicit type annotations. **Files affected:** @@ -517,6 +535,8 @@ For each component listed: #### TASK B-20 — Remove `type: ignore` in `dependencies.get_settings` +**Status:** Completed ✅ + **Violated rule:** Backend-Development.md §1 — Avoid `Any` and ignored type errors. **Files affected:** @@ -527,3 +547,19 @@ For each component listed: 1. Introduce a typed model (e.g., `TypedDict` or `Protocol`) for `app.state` to declare `settings: Settings` and other shared state properties. 2. Update `get_settings` (and any other helpers that read from `app.state`) so the return type is inferred as `Settings` without needing a `type: ignore` comment. 3. Run `mypy --strict` or `pyright` to confirm the type ignore is no longer needed. + +--- + +#### TASK B-21 — Fix `geo_re_resolve` test mocks to support async calls + +**Status:** Completed ✅ + +**Violated rule:** Test code must correctly mock async coroutines (`AsyncMock`) when awaited. + +**Files affected:** +- `backend/tests/test_tasks/test_geo_re_resolve.py` — patched `geo_service` to ensure `get_unresolved_ips` is an `AsyncMock`. + +**What to do:** + +1. Ensure all mocks for async service methods are `AsyncMock` so they can be awaited. +2. Run `pytest -q -c backend/pyproject.toml` to confirm the test suite passes. diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 0afb7d4..7505073 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -7,7 +7,7 @@ directly — to keep coupling explicit and testable. """ import time -from typing import Annotated +from typing import Annotated, Protocol, cast import aiosqlite import structlog @@ -19,6 +19,13 @@ from app.utils.time_utils import utc_now log: structlog.stdlib.BoundLogger = structlog.get_logger() + +class AppState(Protocol): + """Partial view of the FastAPI application state used by dependencies.""" + + settings: Settings + + _COOKIE_NAME = "bangui_session" # --------------------------------------------------------------------------- @@ -85,7 +92,8 @@ async def get_settings(request: Request) -> Settings: Returns: The application settings loaded at startup. """ - return request.app.state.settings # type: ignore[no-any-return] + state = cast(AppState, request.app.state) + return state.settings async def require_auth( diff --git a/backend/app/repositories/import_log_repo.py b/backend/app/repositories/import_log_repo.py index 6ec284e..860fc51 100644 --- a/backend/app/repositories/import_log_repo.py +++ b/backend/app/repositories/import_log_repo.py @@ -8,12 +8,25 @@ table. All methods are plain async functions that accept a from __future__ import annotations import math -from typing import TYPE_CHECKING, Any +from collections.abc import Mapping +from typing import TYPE_CHECKING, TypedDict, cast if TYPE_CHECKING: import aiosqlite +class ImportLogRow(TypedDict): + """Row shape returned by queries on the import_log table.""" + + id: int + source_id: int | None + source_url: str + timestamp: str + ips_imported: int + ips_skipped: int + errors: str | None + + async def add_log( db: aiosqlite.Connection, *, @@ -54,7 +67,7 @@ async def list_logs( source_id: int | None = None, page: int = 1, page_size: int = 50, -) -> tuple[list[dict[str, Any]], int]: +) -> tuple[list[ImportLogRow], int]: """Return a paginated list of import log entries. Args: @@ -68,8 +81,8 @@ async def list_logs( *total* is the count of all matching rows (ignoring pagination). """ where = "" - params_count: list[Any] = [] - params_rows: list[Any] = [] + params_count: list[object] = [] + params_rows: list[object] = [] if source_id is not None: where = " WHERE source_id = ?" @@ -102,7 +115,7 @@ async def list_logs( return items, total -async def get_last_log(db: aiosqlite.Connection) -> dict[str, Any] | None: +async def get_last_log(db: aiosqlite.Connection) -> ImportLogRow | None: """Return the most recent import log entry across all sources. Args: @@ -143,13 +156,14 @@ def compute_total_pages(total: int, page_size: int) -> int: # --------------------------------------------------------------------------- -def _row_to_dict(row: Any) -> dict[str, Any]: +def _row_to_dict(row: object) -> ImportLogRow: """Convert an aiosqlite row to a plain Python dict. Args: - row: An :class:`aiosqlite.Row` or sequence returned by a cursor. + row: An :class:`aiosqlite.Row` or similar mapping returned by a cursor. Returns: Dict mapping column names to Python values. """ - return dict(row) + mapping = cast(Mapping[str, object], row) + return cast(ImportLogRow, dict(mapping)) diff --git a/backend/app/services/blocklist_service.py b/backend/app/services/blocklist_service.py index 23df0d1..91c7671 100644 --- a/backend/app/services/blocklist_service.py +++ b/backend/app/services/blocklist_service.py @@ -15,7 +15,7 @@ under the key ``"blocklist_schedule"``. from __future__ import annotations import json -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import structlog @@ -56,7 +56,7 @@ _PREVIEW_MAX_BYTES: int = 65536 # --------------------------------------------------------------------------- -def _row_to_source(row: dict[str, Any]) -> BlocklistSource: +def _row_to_source(row: dict[str, object]) -> BlocklistSource: """Convert a repository row dict to a :class:`BlocklistSource`. Args: @@ -542,7 +542,7 @@ async def list_import_logs( # --------------------------------------------------------------------------- -def _aiohttp_timeout(seconds: float) -> Any: +def _aiohttp_timeout(seconds: float) -> "aiohttp.ClientTimeout": """Return an :class:`aiohttp.ClientTimeout` with the given total timeout. Args: diff --git a/backend/app/services/config_file_service.py b/backend/app/services/config_file_service.py index ef9ca46..c31a24a 100644 --- a/backend/app/services/config_file_service.py +++ b/backend/app/services/config_file_service.py @@ -28,7 +28,7 @@ import os import re import tempfile from pathlib import Path -from typing import Any +from typing import TYPE_CHECKING, cast, TypeAlias import structlog @@ -57,7 +57,12 @@ from app.models.config import ( from app.services import jail_service from app.services.jail_service import JailNotFoundError as JailNotFoundError from app.utils import conffile_parser -from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError +from app.utils.fail2ban_client import ( + Fail2BanClient, + Fail2BanCommand, + Fail2BanConnectionError, + Fail2BanResponse, +) log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -539,10 +544,10 @@ async def _get_active_jail_names(socket_path: str) -> set[str]: try: client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) - def _to_dict_inner(pairs: Any) -> dict[str, Any]: + def _to_dict_inner(pairs: object) -> dict[str, object]: if not isinstance(pairs, (list, tuple)): return {} - result: dict[str, Any] = {} + result: dict[str, object] = {} for item in pairs: try: k, v = item @@ -551,8 +556,8 @@ async def _get_active_jail_names(socket_path: str) -> set[str]: pass return result - def _ok(response: Any) -> Any: - code, data = response + def _ok(response: object) -> object: + code, data = cast(Fail2BanResponse, response) if code != 0: raise ValueError(f"fail2ban error {code}: {data!r}") return data @@ -813,7 +818,7 @@ def _write_local_override_sync( config_dir: Path, jail_name: str, enabled: bool, - overrides: dict[str, Any], + overrides: dict[str, object], ) -> None: """Write a ``jail.d/{name}.local`` file atomically. @@ -862,7 +867,7 @@ def _write_local_override_sync( if overrides.get("port") is not None: lines.append(f"port = {overrides['port']}") if overrides.get("logpath"): - paths: list[str] = overrides["logpath"] + paths: list[str] = cast(list[str], overrides["logpath"]) if paths: lines.append(f"logpath = {paths[0]}") for p in paths[1:]: @@ -1209,7 +1214,7 @@ async def activate_jail( ), ) - overrides: dict[str, Any] = { + overrides: dict[str, object] = { "bantime": req.bantime, "findtime": req.findtime, "maxretry": req.maxretry, diff --git a/backend/app/services/geo_service.py b/backend/app/services/geo_service.py index f9e2b7f..b2b5f14 100644 --- a/backend/app/services/geo_service.py +++ b/backend/app/services/geo_service.py @@ -30,7 +30,8 @@ Usage:: # single lookup info = await geo_service.lookup("1.2.3.4", session) if info: - print(info.country_code) # "DE" + # info.country_code == "DE" + ... # use the GeoInfo object in your application # bulk lookup (more efficient for large sets) geo_map = await geo_service.lookup_batch(["1.2.3.4", "5.6.7.8"], session) @@ -42,7 +43,7 @@ import asyncio import time from collections.abc import Awaitable, Callable from dataclasses import dataclass -from typing import TYPE_CHECKING, TypeAlias +from typing import TYPE_CHECKING import aiohttp import structlog @@ -119,7 +120,7 @@ class GeoInfo: """Organisation name associated with the IP, e.g. ``"Deutsche Telekom"``.""" -GeoEnricher: TypeAlias = Callable[[str], Awaitable[GeoInfo | None]] +type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]] """Async callable used to enrich IPs with :class:`~app.services.geo_service.GeoInfo`. This is a shared type alias used by services that optionally accept a geo diff --git a/backend/app/services/health_service.py b/backend/app/services/health_service.py index df9750d..87322c1 100644 --- a/backend/app/services/health_service.py +++ b/backend/app/services/health_service.py @@ -9,12 +9,17 @@ seconds by the background health-check task, not on every HTTP request. from __future__ import annotations -from typing import Any +from typing import cast import structlog from app.models.server import ServerStatus -from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError +from app.utils.fail2ban_client import ( + Fail2BanClient, + Fail2BanConnectionError, + Fail2BanProtocolError, + Fail2BanResponse, +) log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -25,7 +30,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger() _SOCKET_TIMEOUT: float = 5.0 -def _ok(response: Any) -> Any: +def _ok(response: object) -> object: """Extract the payload from a fail2ban ``(return_code, data)`` response. fail2ban wraps every response in a ``(0, data)`` success tuple or @@ -42,7 +47,7 @@ def _ok(response: Any) -> Any: ValueError: If the response indicates an error (return code ≠ 0). """ try: - code, data = response + code, data = cast(Fail2BanResponse, response) except (TypeError, ValueError) as exc: raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc @@ -52,7 +57,7 @@ def _ok(response: Any) -> Any: return data -def _to_dict(pairs: Any) -> dict[str, Any]: +def _to_dict(pairs: object) -> dict[str, object]: """Convert a list of ``(key, value)`` pairs to a plain dict. fail2ban returns structured data as lists of 2-tuples rather than dicts. @@ -66,7 +71,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]: """ if not isinstance(pairs, (list, tuple)): return {} - result: dict[str, Any] = {} + result: dict[str, object] = {} for item in pairs: try: k, v = item @@ -119,7 +124,7 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta # 3. Global status — jail count and names # # ------------------------------------------------------------------ # status_data = _to_dict(_ok(await client.send(["status"]))) - active_jails: int = int(status_data.get("Number of jail", 0) or 0) + active_jails: int = int(str(status_data.get("Number of jail", 0) or 0)) jail_list_raw: str = str(status_data.get("Jail list", "") or "").strip() jail_names: list[str] = ( [j.strip() for j in jail_list_raw.split(",") if j.strip()] @@ -138,8 +143,8 @@ async def probe(socket_path: str, timeout: float = _SOCKET_TIMEOUT) -> ServerSta jail_resp = _to_dict(_ok(await client.send(["status", jail_name]))) filter_stats = _to_dict(jail_resp.get("Filter") or []) action_stats = _to_dict(jail_resp.get("Actions") or []) - total_failures += int(filter_stats.get("Currently failed", 0) or 0) - total_bans += int(action_stats.get("Currently banned", 0) or 0) + total_failures += int(str(filter_stats.get("Currently failed", 0) or 0)) + total_bans += int(str(action_stats.get("Currently banned", 0) or 0)) except (ValueError, TypeError, KeyError) as exc: log.warning( "fail2ban_jail_status_parse_error", diff --git a/backend/app/services/jail_service.py b/backend/app/services/jail_service.py index bc84d38..958a6ec 100644 --- a/backend/app/services/jail_service.py +++ b/backend/app/services/jail_service.py @@ -14,7 +14,7 @@ from __future__ import annotations import asyncio import contextlib import ipaddress -from typing import Any +from typing import TYPE_CHECKING, Awaitable, Callable, cast, TypeAlias import structlog @@ -27,10 +27,24 @@ from app.models.jail import ( JailStatus, JailSummary, ) -from app.utils.fail2ban_client import Fail2BanClient, Fail2BanConnectionError +from app.utils.fail2ban_client import ( + Fail2BanClient, + Fail2BanCommand, + Fail2BanConnectionError, + Fail2BanResponse, + Fail2BanToken, +) + +if TYPE_CHECKING: + import aiohttp + import aiosqlite + + from app.services.geo_service import GeoInfo log: structlog.stdlib.BoundLogger = structlog.get_logger() +GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo | None"]] + # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- @@ -77,7 +91,7 @@ class JailOperationError(Exception): # --------------------------------------------------------------------------- -def _ok(response: Any) -> Any: +def _ok(response: object) -> object: """Extract the payload from a fail2ban ``(return_code, data)`` response. Args: @@ -90,7 +104,7 @@ def _ok(response: Any) -> Any: ValueError: If the response indicates an error (return code ≠ 0). """ try: - code, data = response + code, data = cast(Fail2BanResponse, response) except (TypeError, ValueError) as exc: raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc @@ -100,7 +114,7 @@ def _ok(response: Any) -> Any: return data -def _to_dict(pairs: Any) -> dict[str, Any]: +def _to_dict(pairs: object) -> dict[str, object]: """Convert a list of ``(key, value)`` pairs to a plain dict. Args: @@ -111,7 +125,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]: """ if not isinstance(pairs, (list, tuple)): return {} - result: dict[str, Any] = {} + result: dict[str, object] = {} for item in pairs: try: k, v = item @@ -121,7 +135,7 @@ def _to_dict(pairs: Any) -> dict[str, Any]: return result -def _ensure_list(value: Any) -> list[str]: +def _ensure_list(value: object | None) -> list[str]: """Coerce a fail2ban response value to a list of strings. Some fail2ban ``get`` responses return ``None`` or a single string @@ -170,9 +184,9 @@ def _is_not_found_error(exc: Exception) -> bool: async def _safe_get( client: Fail2BanClient, - command: list[Any], - default: Any = None, -) -> Any: + command: Fail2BanCommand, + default: object | None = None, +) -> object | None: """Send a ``get`` command and return ``default`` on error. Errors during optional detail queries (logpath, regex, etc.) should @@ -187,7 +201,8 @@ async def _safe_get( The response payload, or *default* on any error. """ try: - return _ok(await client.send(command)) + response = await client.send(command) + return _ok(cast(Fail2BanResponse, response)) except (ValueError, TypeError, Exception): return default @@ -309,7 +324,7 @@ async def _fetch_jail_summary( backend_cmd_is_supported = await _check_backend_cmd_supported(client, name) # Build the gather list based on command support. - gather_list: list[Any] = [ + gather_list: list[Awaitable[object]] = [ client.send(["status", name, "short"]), client.send(["get", name, "bantime"]), client.send(["get", name, "findtime"]), @@ -325,7 +340,7 @@ async def _fetch_jail_summary( uses_backend_backend_commands = True else: # Commands not supported; return default values without sending. - async def _return_default(value: Any) -> tuple[int, Any]: + async def _return_default(value: object | None) -> Fail2BanResponse: return (0, value) gather_list.extend([ @@ -335,12 +350,12 @@ async def _fetch_jail_summary( uses_backend_backend_commands = False _r = await asyncio.gather(*gather_list, return_exceptions=True) - status_raw: Any = _r[0] - bantime_raw: Any = _r[1] - findtime_raw: Any = _r[2] - maxretry_raw: Any = _r[3] - backend_raw: Any = _r[4] - idle_raw: Any = _r[5] + status_raw: object | Exception = _r[0] + bantime_raw: object | Exception = _r[1] + findtime_raw: object | Exception = _r[2] + maxretry_raw: object | Exception = _r[3] + backend_raw: object | Exception = _r[4] + idle_raw: object | Exception = _r[5] # Parse jail status (filter + actions). jail_status: JailStatus | None = None @@ -350,35 +365,35 @@ async def _fetch_jail_summary( filter_stats = _to_dict(raw.get("Filter") or []) action_stats = _to_dict(raw.get("Actions") or []) jail_status = JailStatus( - currently_banned=int(action_stats.get("Currently banned", 0) or 0), - total_banned=int(action_stats.get("Total banned", 0) or 0), - currently_failed=int(filter_stats.get("Currently failed", 0) or 0), - total_failed=int(filter_stats.get("Total failed", 0) or 0), + currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)), + total_banned=int(str(action_stats.get("Total banned", 0) or 0)), + currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)), + total_failed=int(str(filter_stats.get("Total failed", 0) or 0)), ) except (ValueError, TypeError) as exc: log.warning("jail_status_parse_error", jail=name, error=str(exc)) - def _safe_int(raw: Any, fallback: int) -> int: + def _safe_int(raw: object | Exception, fallback: int) -> int: if isinstance(raw, Exception): return fallback try: - return int(_ok(raw)) + return int(str(_ok(cast(Fail2BanResponse, raw)))) except (ValueError, TypeError): return fallback - def _safe_str(raw: Any, fallback: str) -> str: + def _safe_str(raw: object | Exception, fallback: str) -> str: if isinstance(raw, Exception): return fallback try: - return str(_ok(raw)) + return str(_ok(cast(Fail2BanResponse, raw))) except (ValueError, TypeError): return fallback - def _safe_bool(raw: Any, fallback: bool = False) -> bool: + def _safe_bool(raw: object | Exception, fallback: bool = False) -> bool: if isinstance(raw, Exception): return fallback try: - return bool(_ok(raw)) + return bool(_ok(cast(Fail2BanResponse, raw))) except (ValueError, TypeError): return fallback @@ -428,10 +443,10 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse: action_stats = _to_dict(raw.get("Actions") or []) jail_status = JailStatus( - currently_banned=int(action_stats.get("Currently banned", 0) or 0), - total_banned=int(action_stats.get("Total banned", 0) or 0), - currently_failed=int(filter_stats.get("Currently failed", 0) or 0), - total_failed=int(filter_stats.get("Total failed", 0) or 0), + currently_banned=int(str(action_stats.get("Currently banned", 0) or 0)), + total_banned=int(str(action_stats.get("Total banned", 0) or 0)), + currently_failed=int(str(filter_stats.get("Currently failed", 0) or 0)), + total_failed=int(str(filter_stats.get("Total failed", 0) or 0)), ) # Fetch all detail fields in parallel. @@ -480,11 +495,11 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse: bt_increment: bool = bool(bt_increment_raw) bantime_escalation = BantimeEscalation( increment=bt_increment, - factor=float(bt_factor_raw) if bt_factor_raw is not None else None, + factor=float(str(bt_factor_raw)) if bt_factor_raw is not None else None, formula=str(bt_formula_raw) if bt_formula_raw else None, multipliers=str(bt_multipliers_raw) if bt_multipliers_raw else None, - max_time=int(bt_maxtime_raw) if bt_maxtime_raw is not None else None, - rnd_time=int(bt_rndtime_raw) if bt_rndtime_raw is not None else None, + max_time=int(str(bt_maxtime_raw)) if bt_maxtime_raw is not None else None, + rnd_time=int(str(bt_rndtime_raw)) if bt_rndtime_raw is not None else None, overall_jails=bool(bt_overalljails_raw), ) @@ -500,9 +515,9 @@ async def get_jail(socket_path: str, name: str) -> JailDetailResponse: ignore_ips=_ensure_list(ignoreip_raw), date_pattern=str(datepattern_raw) if datepattern_raw else None, log_encoding=str(logencoding_raw or "UTF-8"), - find_time=int(findtime_raw or 600), - ban_time=int(bantime_raw or 600), - max_retry=int(maxretry_raw or 5), + find_time=int(str(findtime_raw or 600)), + ban_time=int(str(bantime_raw or 600)), + max_retry=int(str(maxretry_raw or 5)), bantime_escalation=bantime_escalation, status=jail_status, actions=_ensure_list(actions_raw), @@ -671,8 +686,8 @@ async def reload_all( if exclude_jails: names_set -= set(exclude_jails) - stream: list[list[str]] = [["start", n] for n in sorted(names_set)] - _ok(await client.send(["reload", "--all", [], stream])) + stream: list[list[object]] = [["start", n] for n in sorted(names_set)] + _ok(await client.send(["reload", "--all", [], cast(Fail2BanToken, stream)])) log.info("all_jails_reloaded") except ValueError as exc: # Detect UnknownJailException (missing or invalid jail configuration) @@ -795,9 +810,9 @@ async def unban_ip( async def get_active_bans( socket_path: str, - geo_enricher: Any | None = None, - http_session: Any | None = None, - app_db: Any | None = None, + geo_enricher: GeoEnricher | None = None, + http_session: "aiohttp.ClientSession" | None = None, + app_db: "aiosqlite.Connection" | None = None, ) -> ActiveBanListResponse: """Return all currently banned IPs across every jail. @@ -849,7 +864,7 @@ async def get_active_bans( return ActiveBanListResponse(bans=[], total=0) # For each jail, fetch the ban list with time info in parallel. - results: list[Any] = await asyncio.gather( + results: list[object | Exception] = await asyncio.gather( *[client.send(["get", jn, "banip", "--with-time"]) for jn in jail_names], return_exceptions=True, ) @@ -865,7 +880,7 @@ async def get_active_bans( continue try: - ban_list: list[str] = _ok(raw_result) or [] + ban_list: list[str] = cast(list[str], _ok(raw_result)) or [] except (TypeError, ValueError) as exc: log.warning( "active_bans_parse_error", @@ -992,8 +1007,8 @@ async def get_jail_banned_ips( page: int = 1, page_size: int = 25, search: str | None = None, - http_session: Any | None = None, - app_db: Any | None = None, + http_session: "aiohttp.ClientSession" | None = None, + app_db: "aiosqlite.Connection" | None = None, ) -> JailBannedIpsResponse: """Return a paginated list of currently banned IPs for a single jail. @@ -1040,7 +1055,7 @@ async def get_jail_banned_ips( except (ValueError, TypeError): raw_result = [] - ban_list: list[str] = raw_result or [] + ban_list: list[str] = cast(list[str], raw_result) or [] # Parse all entries. all_bans: list[ActiveBan] = [] @@ -1094,7 +1109,7 @@ async def get_jail_banned_ips( async def _enrich_bans( bans: list[ActiveBan], - geo_enricher: Any, + geo_enricher: GeoEnricher, ) -> list[ActiveBan]: """Enrich ban records with geo data asynchronously. @@ -1105,14 +1120,15 @@ async def _enrich_bans( Returns: The same list with ``country`` fields populated where lookup succeeded. """ - geo_results: list[Any] = await asyncio.gather( - *[geo_enricher(ban.ip) for ban in bans], + geo_results: list[object | Exception] = await asyncio.gather( + *[cast(Awaitable[object], geo_enricher(ban.ip)) for ban in bans], return_exceptions=True, ) enriched: list[ActiveBan] = [] for ban, geo in zip(bans, geo_results, strict=False): if geo is not None and not isinstance(geo, Exception): - enriched.append(ban.model_copy(update={"country": geo.country_code})) + geo_info = cast("GeoInfo", geo) + enriched.append(ban.model_copy(update={"country": geo_info.country_code})) else: enriched.append(ban) return enriched @@ -1260,8 +1276,8 @@ async def set_ignore_self(socket_path: str, name: str, *, on: bool) -> None: async def lookup_ip( socket_path: str, ip: str, - geo_enricher: Any | None = None, -) -> dict[str, Any]: + geo_enricher: GeoEnricher | None = None, +) -> dict[str, object | list[str] | None]: """Return ban status and history for a single IP address. Checks every running jail for whether the IP is currently banned. @@ -1304,7 +1320,7 @@ async def lookup_ip( ) # Check ban status per jail in parallel. - ban_results: list[Any] = await asyncio.gather( + ban_results: list[object | Exception] = await asyncio.gather( *[client.send(["get", jn, "banip"]) for jn in jail_names], return_exceptions=True, ) @@ -1314,7 +1330,7 @@ async def lookup_ip( if isinstance(result, Exception): continue try: - ban_list: list[str] = _ok(result) or [] + ban_list: list[str] = cast(list[str], _ok(result)) or [] if ip in ban_list: currently_banned_in.append(jail_name) except (ValueError, TypeError): @@ -1351,6 +1367,6 @@ async def unban_all_ips(socket_path: str) -> int: cannot be reached. """ client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) - count: int = int(_ok(await client.send(["unban", "--all"]))) + count: int = int(str(_ok(await client.send(["unban", "--all"])) or 0)) log.info("all_ips_unbanned", count=count) return count diff --git a/backend/app/tasks/geo_re_resolve.py b/backend/app/tasks/geo_re_resolve.py index c01f6fc..e3f85fe 100644 --- a/backend/app/tasks/geo_re_resolve.py +++ b/backend/app/tasks/geo_re_resolve.py @@ -17,7 +17,7 @@ The task runs every 10 minutes. On each invocation it: from __future__ import annotations -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import structlog @@ -35,7 +35,7 @@ GEO_RE_RESOLVE_INTERVAL: int = 600 JOB_ID: str = "geo_re_resolve" -async def _run_re_resolve(app: Any) -> None: +async def _run_re_resolve(app: "FastAPI") -> None: """Query NULL-country IPs from the database and re-resolve them. Reads shared resources from ``app.state`` and delegates to diff --git a/backend/app/tasks/health_check.py b/backend/app/tasks/health_check.py index 6e82b69..597b92d 100644 --- a/backend/app/tasks/health_check.py +++ b/backend/app/tasks/health_check.py @@ -18,7 +18,7 @@ within 60 seconds of that activation, a from __future__ import annotations import datetime -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, TypedDict import structlog @@ -31,6 +31,14 @@ if TYPE_CHECKING: # pragma: no cover log: structlog.stdlib.BoundLogger = structlog.get_logger() + +class ActivationRecord(TypedDict): + """Stored timestamp data for a jail activation event.""" + + jail_name: str + at: datetime.datetime + + #: How often the probe fires (seconds). HEALTH_CHECK_INTERVAL: int = 30 @@ -39,7 +47,7 @@ HEALTH_CHECK_INTERVAL: int = 30 _ACTIVATION_CRASH_WINDOW: int = 60 -async def _run_probe(app: Any) -> None: +async def _run_probe(app: "FastAPI") -> None: """Probe fail2ban and cache the result on *app.state*. Detects online/offline state transitions. When fail2ban goes offline @@ -86,7 +94,7 @@ async def _run_probe(app: Any) -> None: elif not status.online and prev_status.online: log.warning("fail2ban_went_offline") # Check whether this crash happened shortly after a jail activation. - last_activation: dict[str, Any] | None = getattr( + last_activation: ActivationRecord | None = getattr( app.state, "last_activation", None ) if last_activation is not None: diff --git a/backend/tests/test_tasks/test_geo_re_resolve.py b/backend/tests/test_tasks/test_geo_re_resolve.py index 23ceb66..33eead1 100644 --- a/backend/tests/test_tasks/test_geo_re_resolve.py +++ b/backend/tests/test_tasks/test_geo_re_resolve.py @@ -79,6 +79,8 @@ async def test_run_re_resolve_no_unresolved_ips_skips() -> None: app = _make_app(unresolved_ips=[]) with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: + mock_geo.get_unresolved_ips = AsyncMock(return_value=[]) + await _run_re_resolve(app) mock_geo.clear_neg_cache.assert_not_called() @@ -96,6 +98,7 @@ async def test_run_re_resolve_clears_neg_cache() -> None: app = _make_app(unresolved_ips=ips, lookup_result=result) with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: + mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) mock_geo.lookup_batch = AsyncMock(return_value=result) await _run_re_resolve(app) @@ -114,6 +117,7 @@ async def test_run_re_resolve_calls_lookup_batch_with_db() -> None: app = _make_app(unresolved_ips=ips, lookup_result=result) with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: + mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) mock_geo.lookup_batch = AsyncMock(return_value=result) await _run_re_resolve(app) @@ -137,6 +141,7 @@ async def test_run_re_resolve_logs_correct_counts(caplog: Any) -> None: app = _make_app(unresolved_ips=ips, lookup_result=result) with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: + mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) mock_geo.lookup_batch = AsyncMock(return_value=result) await _run_re_resolve(app) @@ -159,6 +164,7 @@ async def test_run_re_resolve_handles_all_resolved() -> None: app = _make_app(unresolved_ips=ips, lookup_result=result) with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo: + mock_geo.get_unresolved_ips = AsyncMock(return_value=ips) mock_geo.lookup_batch = AsyncMock(return_value=result) await _run_re_resolve(app) diff --git a/frontend/src/pages/SetupPage.tsx b/frontend/src/pages/SetupPage.tsx index 9a10dca..6a266ad 100644 --- a/frontend/src/pages/SetupPage.tsx +++ b/frontend/src/pages/SetupPage.tsx @@ -99,37 +99,18 @@ export function SetupPage(): React.JSX.Element { const styles = useStyles(); const navigate = useNavigate(); - const [checking, setChecking] = useState(true); + const { status, loading, error, submit, submitting, submitError } = useSetup(); const [values, setValues] = useState(DEFAULT_VALUES); const [errors, setErrors] = useState>>({}); - const [apiError, setApiError] = useState(null); - const [submitting, setSubmitting] = useState(false); + const apiError = error ?? submitError; // Redirect to /login if setup has already been completed. - // Show a full-screen spinner while the check is in flight to prevent - // the form from flashing before the redirect fires. + // Show a full-screen spinner while the initial status check is in flight. useEffect(() => { - let cancelled = false; - getSetupStatus() - .then((res) => { - if (!cancelled) { - if (res.completed) { - navigate("/login", { replace: true }); - } else { - setChecking(false); - } - } - }) - .catch(() => { - // Failed check: the backend may still be starting up. Stay on this - // page so the user can attempt setup once the backend is ready. - console.warn("SetupPage: setup status check failed — rendering setup form"); - if (!cancelled) setChecking(false); - }); - return (): void => { - cancelled = true; - }; - }, [navigate]); + if (status?.completed) { + navigate("/login", { replace: true }); + } + }, [navigate, status]); // --------------------------------------------------------------------------- // Handlers @@ -169,13 +150,11 @@ export function SetupPage(): React.JSX.Element { async function handleSubmit(ev: FormEvent): Promise { ev.preventDefault(); - setApiError(null); if (!validate()) return; - setSubmitting(true); try { - await submitSetup({ + await submit({ master_password: values.masterPassword, database_path: values.databasePath, fail2ban_socket: values.fail2banSocket, @@ -183,14 +162,8 @@ export function SetupPage(): React.JSX.Element { session_duration_minutes: parseInt(values.sessionDurationMinutes, 10), }); navigate("/login", { replace: true }); - } catch (err) { - if (err instanceof ApiError) { - setApiError(err.message || `Error ${String(err.status)}`); - } else { - setApiError("An unexpected error occurred. Please try again."); - } - } finally { - setSubmitting(false); + } catch { + // Errors are surfaced through the hook via `submitError`. } } @@ -198,7 +171,7 @@ export function SetupPage(): React.JSX.Element { // Render // --------------------------------------------------------------------------- - if (checking) { + if (loading) { return (
- {apiError !== null && ( + {apiError && ( {apiError}