From 1c0bac13530086b44ed99ab71378ca735a8790c4 Mon Sep 17 00:00:00 2001 From: Lukas Date: Fri, 20 Mar 2026 13:44:14 +0100 Subject: [PATCH] refactor: improve backend type safety and import organization - Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite) - Reorganize imports to follow PEP 8 conventions - Convert TypeAlias to modern PEP 695 type syntax (where appropriate) - Use Sequence/Mapping from collections.abc for type hints (covariant) - Replace string literals with cast() for improved type inference - Fix casting of Fail2BanResponse and TypedDict patterns - Add IpLookupResult TypedDict for precise return type annotation - Reformat overlong lines for readability (120 char limit) - Add asyncio_mode and filterwarnings to pytest config - Update test fixtures with improved type hints This improves mypy type checking and makes type relationships explicit. --- backend/app/config.py | 2 +- backend/app/dependencies.py | 2 +- backend/app/repositories/geo_cache_repo.py | 4 +- backend/app/repositories/import_log_repo.py | 7 +- backend/app/routers/config.py | 53 +-- backend/app/routers/geo.py | 9 +- backend/app/services/ban_service.py | 31 +- backend/app/services/blocklist_service.py | 2 +- backend/app/services/config_file_service.py | 236 +++------- backend/app/services/config_service.py | 12 +- backend/app/services/health_service.py | 2 +- backend/app/services/history_service.py | 4 +- backend/app/services/jail_service.py | 52 ++- backend/app/services/server_service.py | 10 +- backend/app/tasks/geo_re_resolve.py | 2 +- backend/app/tasks/health_check.py | 2 +- backend/app/utils/fail2ban_client.py | 36 +- backend/pyproject.toml | 3 +- backend/tests/test_routers/test_auth.py | 5 +- backend/tests/test_routers/test_geo.py | 8 +- backend/tests/test_routers/test_jails.py | 7 +- backend/tests/test_routers/test_setup.py | 10 +- .../tests/test_services/test_auth_service.py | 2 +- .../tests/test_services/test_ban_service.py | 12 +- .../test_services/test_ban_service_perf.py | 4 +- .../test_services/test_config_file_service.py | 412 +++++++----------- .../tests/test_services/test_geo_service.py | 133 +++--- .../test_services/test_history_service.py | 2 +- .../tests/test_services/test_jail_service.py | 3 - backend/tests/test_tasks/test_health_check.py | 8 +- 30 files changed, 431 insertions(+), 644 deletions(-) diff --git a/backend/app/config.py b/backend/app/config.py index 4e89da2..0f73ce5 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -85,4 +85,4 @@ def get_settings() -> Settings: A validated :class:`Settings` object. Raises :class:`pydantic.ValidationError` if required keys are absent or values fail validation. """ - return Settings() # pydantic-settings populates required fields from env vars + return Settings() # type: ignore[call-arg] # pydantic-settings populates required fields from env vars diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 7505073..b4d701c 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -92,7 +92,7 @@ async def get_settings(request: Request) -> Settings: Returns: The application settings loaded at startup. """ - state = cast(AppState, request.app.state) + state = cast("AppState", request.app.state) return state.settings diff --git a/backend/app/repositories/geo_cache_repo.py b/backend/app/repositories/geo_cache_repo.py index 51de260..6fb4e5b 100644 --- a/backend/app/repositories/geo_cache_repo.py +++ b/backend/app/repositories/geo_cache_repo.py @@ -12,6 +12,8 @@ from __future__ import annotations from typing import TYPE_CHECKING, TypedDict if TYPE_CHECKING: + from collections.abc import Sequence + import aiosqlite @@ -112,7 +114,7 @@ async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None: async def bulk_upsert_entries( db: aiosqlite.Connection, - rows: list[tuple[str, str | None, str | None, str | None, str | None]], + rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]], ) -> int: """Bulk insert or update multiple geo cache entries.""" if not rows: diff --git a/backend/app/repositories/import_log_repo.py b/backend/app/repositories/import_log_repo.py index 860fc51..b62ccce 100644 --- a/backend/app/repositories/import_log_repo.py +++ b/backend/app/repositories/import_log_repo.py @@ -8,10 +8,11 @@ table. All methods are plain async functions that accept a from __future__ import annotations import math -from collections.abc import Mapping from typing import TYPE_CHECKING, TypedDict, cast if TYPE_CHECKING: + from collections.abc import Mapping + import aiosqlite @@ -165,5 +166,5 @@ def _row_to_dict(row: object) -> ImportLogRow: Returns: Dict mapping column names to Python values. """ - mapping = cast(Mapping[str, object], row) - return cast(ImportLogRow, dict(mapping)) + mapping = cast("Mapping[str, object]", row) + return cast("ImportLogRow", dict(mapping)) diff --git a/backend/app/routers/config.py b/backend/app/routers/config.py index 8bee91d..04d1463 100644 --- a/backend/app/routers/config.py +++ b/backend/app/routers/config.py @@ -44,8 +44,6 @@ import structlog from fastapi import APIRouter, HTTPException, Path, Query, Request, status from app.dependencies import AuthDep - -log: structlog.stdlib.BoundLogger = structlog.get_logger() from app.models.config import ( ActionConfig, ActionCreateRequest, @@ -104,6 +102,8 @@ from app.services.jail_service import JailOperationError from app.tasks.health_check import _run_probe from app.utils.fail2ban_client import Fail2BanConnectionError +log: structlog.stdlib.BoundLogger = structlog.get_logger() + router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"]) # --------------------------------------------------------------------------- @@ -428,9 +428,7 @@ async def restart_fail2ban( await config_file_service.start_daemon(start_cmd_parts) # Step 3: probe the socket until fail2ban is responsive or the budget expires. - fail2ban_running: bool = await config_file_service.wait_for_fail2ban( - socket_path, max_wait_seconds=10.0 - ) + fail2ban_running: bool = await config_file_service.wait_for_fail2ban(socket_path, max_wait_seconds=10.0) if not fail2ban_running: raise HTTPException( status_code=status.HTTP_503_SERVICE_UNAVAILABLE, @@ -604,9 +602,7 @@ async def get_map_color_thresholds( """ from app.services import setup_service - high, medium, low = await setup_service.get_map_color_thresholds( - request.app.state.db - ) + high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db) return MapColorThresholdsResponse( threshold_high=high, threshold_medium=medium, @@ -696,9 +692,7 @@ async def activate_jail( req = body if body is not None else ActivateJailRequest() try: - result = await config_file_service.activate_jail( - config_dir, socket_path, name, req - ) + result = await config_file_service.activate_jail(config_dir, socket_path, name, req) except JailNameError as exc: raise _bad_request(str(exc)) from exc except JailNotFoundInConfigError: @@ -831,9 +825,7 @@ async def delete_jail_local_override( socket_path: str = request.app.state.settings.fail2ban_socket try: - await config_file_service.delete_jail_local_override( - config_dir, socket_path, name - ) + await config_file_service.delete_jail_local_override(config_dir, socket_path, name) except JailNameError as exc: raise _bad_request(str(exc)) from exc except JailNotFoundInConfigError: @@ -952,9 +944,7 @@ async def rollback_jail( start_cmd_parts: list[str] = start_cmd.split() try: - result = await config_file_service.rollback_jail( - config_dir, socket_path, name, start_cmd_parts - ) + result = await config_file_service.rollback_jail(config_dir, socket_path, name, start_cmd_parts) except JailNameError as exc: raise _bad_request(str(exc)) from exc except ConfigWriteError as exc: @@ -1107,9 +1097,7 @@ async def update_filter( config_dir: str = request.app.state.settings.fail2ban_config_dir socket_path: str = request.app.state.settings.fail2ban_socket try: - return await config_file_service.update_filter( - config_dir, socket_path, name, body, do_reload=reload - ) + return await config_file_service.update_filter(config_dir, socket_path, name, body, do_reload=reload) except FilterNameError as exc: raise _bad_request(str(exc)) from exc except FilterNotFoundError: @@ -1159,9 +1147,7 @@ async def create_filter( config_dir: str = request.app.state.settings.fail2ban_config_dir socket_path: str = request.app.state.settings.fail2ban_socket try: - return await config_file_service.create_filter( - config_dir, socket_path, body, do_reload=reload - ) + return await config_file_service.create_filter(config_dir, socket_path, body, do_reload=reload) except FilterNameError as exc: raise _bad_request(str(exc)) from exc except FilterAlreadyExistsError as exc: @@ -1257,9 +1243,7 @@ async def assign_filter_to_jail( config_dir: str = request.app.state.settings.fail2ban_config_dir socket_path: str = request.app.state.settings.fail2ban_socket try: - await config_file_service.assign_filter_to_jail( - config_dir, socket_path, name, body, do_reload=reload - ) + await config_file_service.assign_filter_to_jail(config_dir, socket_path, name, body, do_reload=reload) except (JailNameError, FilterNameError) as exc: raise _bad_request(str(exc)) from exc except JailNotFoundInConfigError: @@ -1403,9 +1387,7 @@ async def update_action( config_dir: str = request.app.state.settings.fail2ban_config_dir socket_path: str = request.app.state.settings.fail2ban_socket try: - return await config_file_service.update_action( - config_dir, socket_path, name, body, do_reload=reload - ) + return await config_file_service.update_action(config_dir, socket_path, name, body, do_reload=reload) except ActionNameError as exc: raise _bad_request(str(exc)) from exc except ActionNotFoundError: @@ -1451,9 +1433,7 @@ async def create_action( config_dir: str = request.app.state.settings.fail2ban_config_dir socket_path: str = request.app.state.settings.fail2ban_socket try: - return await config_file_service.create_action( - config_dir, socket_path, body, do_reload=reload - ) + return await config_file_service.create_action(config_dir, socket_path, body, do_reload=reload) except ActionNameError as exc: raise _bad_request(str(exc)) from exc except ActionAlreadyExistsError as exc: @@ -1546,9 +1526,7 @@ async def assign_action_to_jail( config_dir: str = request.app.state.settings.fail2ban_config_dir socket_path: str = request.app.state.settings.fail2ban_socket try: - await config_file_service.assign_action_to_jail( - config_dir, socket_path, name, body, do_reload=reload - ) + await config_file_service.assign_action_to_jail(config_dir, socket_path, name, body, do_reload=reload) except (JailNameError, ActionNameError) as exc: raise _bad_request(str(exc)) from exc except JailNotFoundInConfigError: @@ -1597,9 +1575,7 @@ async def remove_action_from_jail( config_dir: str = request.app.state.settings.fail2ban_config_dir socket_path: str = request.app.state.settings.fail2ban_socket try: - await config_file_service.remove_action_from_jail( - config_dir, socket_path, name, action_name, do_reload=reload - ) + await config_file_service.remove_action_from_jail(config_dir, socket_path, name, action_name, do_reload=reload) except (JailNameError, ActionNameError) as exc: raise _bad_request(str(exc)) from exc except JailNotFoundInConfigError: @@ -1689,4 +1665,3 @@ async def get_service_status( return await config_service.get_service_status(socket_path) except Fail2BanConnectionError as exc: raise _bad_gateway(exc) from exc - diff --git a/backend/app/routers/geo.py b/backend/app/routers/geo.py index 2b0abfc..8e4e874 100644 --- a/backend/app/routers/geo.py +++ b/backend/app/routers/geo.py @@ -13,12 +13,15 @@ from typing import TYPE_CHECKING, Annotated if TYPE_CHECKING: import aiohttp + from app.services.jail_service import IpLookupResult + import aiosqlite from fastapi import APIRouter, Depends, HTTPException, Path, Request, status from app.dependencies import AuthDep, get_db from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse from app.services import geo_service, jail_service +from app.services.geo_service import GeoInfo from app.utils.fail2ban_client import Fail2BanConnectionError router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"]) @@ -61,7 +64,7 @@ async def lookup_ip( return await geo_service.lookup(addr, http_session) try: - result = await jail_service.lookup_ip( + result: IpLookupResult = await jail_service.lookup_ip( socket_path, ip, geo_enricher=_enricher, @@ -77,9 +80,9 @@ async def lookup_ip( detail=f"Cannot reach fail2ban: {exc}", ) from exc - raw_geo = result.get("geo") + raw_geo = result["geo"] geo_detail: GeoDetail | None = None - if raw_geo is not None: + if isinstance(raw_geo, GeoInfo): geo_detail = GeoDetail( country_code=raw_geo.country_code, country_name=raw_geo.country_name, diff --git a/backend/app/services/ban_service.py b/backend/app/services/ban_service.py index 26d1687..b457e7e 100644 --- a/backend/app/services/ban_service.py +++ b/backend/app/services/ban_service.py @@ -14,17 +14,11 @@ import asyncio import json import time from collections.abc import Awaitable, Callable -from dataclasses import asdict from datetime import UTC, datetime -from typing import TYPE_CHECKING, TypeAlias +from typing import TYPE_CHECKING, cast import structlog -if TYPE_CHECKING: - import aiosqlite - - from app.services.geo_service import GeoInfo - from app.models.ban import ( BLOCKLIST_JAIL, BUCKET_SECONDS, @@ -37,20 +31,25 @@ from app.models.ban import ( BanTrendResponse, DashboardBanItem, DashboardBanListResponse, - JailBanCount as JailBanCountModel, TimeRange, _derive_origin, bucket_count, ) +from app.models.ban import ( + JailBanCount as JailBanCountModel, +) from app.repositories import fail2ban_db_repo -from app.utils.fail2ban_client import Fail2BanClient +from app.utils.fail2ban_client import Fail2BanClient, Fail2BanResponse if TYPE_CHECKING: import aiohttp + import aiosqlite + + from app.services.geo_service import GeoInfo log: structlog.stdlib.BoundLogger = structlog.get_logger() -GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo"] | None] +type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]] # --------------------------------------------------------------------------- # Constants @@ -137,7 +136,7 @@ async def _get_fail2ban_db_path(socket_path: str) -> str: response = await client.send(["get", "dbfile"]) try: - code, data = response + code, data = cast("Fail2BanResponse", response) except (TypeError, ValueError) as exc: raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc @@ -276,7 +275,7 @@ async def list_bans( # Batch-resolve geo data for all IPs on this page in a single API call. # This avoids hitting the 45 req/min single-IP rate limit when the # page contains many bans (e.g. after a large blocklist import). - geo_map: dict[str, "GeoInfo"] = {} + geo_map: dict[str, GeoInfo] = {} if http_session is not None and rows: page_ips: list[str] = [r.ip for r in rows] try: @@ -428,7 +427,7 @@ async def bans_by_country( ) unique_ips: list[str] = [r.ip for r in agg_rows] - geo_map: dict[str, "GeoInfo"] = {} + geo_map: dict[str, GeoInfo] = {} if http_session is not None and unique_ips: # Serve only what is already in the in-memory cache — no API calls on @@ -449,7 +448,7 @@ async def bans_by_country( ) elif geo_enricher is not None and unique_ips: # Fallback: legacy per-IP enricher (used in tests / older callers). - async def _safe_lookup(ip: str) -> tuple[str, "GeoInfo" | None]: + async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]: try: return ip, await geo_enricher(ip) except Exception: # noqa: BLE001 @@ -636,9 +635,7 @@ async def bans_by_jail( # has *any* rows and log a warning with min/max timeofban so operators can # diagnose timezone or filter mismatches from logs. if total == 0: - table_row_count, min_timeofban, max_timeofban = ( - await fail2ban_db_repo.get_bans_table_summary(db_path) - ) + table_row_count, min_timeofban, max_timeofban = await fail2ban_db_repo.get_bans_table_summary(db_path) if table_row_count > 0: log.warning( "ban_service_bans_by_jail_empty_despite_data", diff --git a/backend/app/services/blocklist_service.py b/backend/app/services/blocklist_service.py index 91c7671..9350f37 100644 --- a/backend/app/services/blocklist_service.py +++ b/backend/app/services/blocklist_service.py @@ -542,7 +542,7 @@ async def list_import_logs( # --------------------------------------------------------------------------- -def _aiohttp_timeout(seconds: float) -> "aiohttp.ClientTimeout": +def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout: """Return an :class:`aiohttp.ClientTimeout` with the given total timeout. Args: diff --git a/backend/app/services/config_file_service.py b/backend/app/services/config_file_service.py index c31a24a..a9e3222 100644 --- a/backend/app/services/config_file_service.py +++ b/backend/app/services/config_file_service.py @@ -28,7 +28,7 @@ import os import re import tempfile from pathlib import Path -from typing import TYPE_CHECKING, cast, TypeAlias +from typing import cast import structlog @@ -59,7 +59,6 @@ from app.services.jail_service import JailNotFoundError as JailNotFoundError from app.utils import conffile_parser from app.utils.fail2ban_client import ( Fail2BanClient, - Fail2BanCommand, Fail2BanConnectionError, Fail2BanResponse, ) @@ -73,9 +72,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger() _SOCKET_TIMEOUT: float = 10.0 # Allowlist pattern for jail names used in path construction. -_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile( - r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$" -) +_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$") # Sections that are not jail definitions. _META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"}) @@ -167,8 +164,7 @@ class FilterReadonlyError(Exception): """ self.name: str = name super().__init__( - f"Filter {name!r} is a shipped default (.conf only); " - "only user-created .local files can be deleted." + f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted." ) @@ -423,9 +419,7 @@ def _parse_jails_sync( # items() merges DEFAULT values automatically. jails[section] = dict(parser.items(section)) except configparser.Error as exc: - log.warning( - "jail_section_parse_error", section=section, error=str(exc) - ) + log.warning("jail_section_parse_error", section=section, error=str(exc)) log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir)) return jails, source_files @@ -522,11 +516,7 @@ def _build_inactive_jail( bantime_escalation=bantime_escalation, source_file=source_file, enabled=enabled, - has_local_override=( - (config_dir / "jail.d" / f"{name}.local").is_file() - if config_dir is not None - else False - ), + has_local_override=((config_dir / "jail.d" / f"{name}.local").is_file() if config_dir is not None else False), ) @@ -557,7 +547,7 @@ async def _get_active_jail_names(socket_path: str) -> set[str]: return result def _ok(response: object) -> object: - code, data = cast(Fail2BanResponse, response) + code, data = cast("Fail2BanResponse", response) if code != 0: raise ValueError(f"fail2ban error {code}: {data!r}") return data @@ -572,9 +562,7 @@ async def _get_active_jail_names(socket_path: str) -> set[str]: log.warning("fail2ban_unreachable_during_inactive_list") return set() except Exception as exc: # noqa: BLE001 - log.warning( - "fail2ban_status_error_during_inactive_list", error=str(exc) - ) + log.warning("fail2ban_status_error_during_inactive_list", error=str(exc)) return set() @@ -662,10 +650,7 @@ def _validate_jail_config_sync( issues.append( JailValidationIssue( field="filter", - message=( - f"Filter file not found: filter.d/{base_filter}.conf" - " (or .local)" - ), + message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"), ) ) @@ -681,10 +666,7 @@ def _validate_jail_config_sync( issues.append( JailValidationIssue( field="action", - message=( - f"Action file not found: action.d/{action_name}.conf" - " (or .local)" - ), + message=(f"Action file not found: action.d/{action_name}.conf (or .local)"), ) ) @@ -840,9 +822,7 @@ def _write_local_override_sync( try: jail_d.mkdir(parents=True, exist_ok=True) except OSError as exc: - raise ConfigWriteError( - f"Cannot create jail.d directory: {exc}" - ) from exc + raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc local_path = jail_d / f"{jail_name}.local" @@ -867,7 +847,7 @@ def _write_local_override_sync( if overrides.get("port") is not None: lines.append(f"port = {overrides['port']}") if overrides.get("logpath"): - paths: list[str] = cast(list[str], overrides["logpath"]) + paths: list[str] = cast("list[str]", overrides["logpath"]) if paths: lines.append(f"logpath = {paths[0]}") for p in paths[1:]: @@ -890,9 +870,7 @@ def _write_local_override_sync( # Clean up temp file if rename failed. with contextlib.suppress(OSError): os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set - raise ConfigWriteError( - f"Failed to write {local_path}: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc log.info( "jail_local_written", @@ -921,9 +899,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) - try: local_path.unlink(missing_ok=True) except OSError as exc: - raise ConfigWriteError( - f"Failed to delete {local_path} during rollback: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to delete {local_path} during rollback: {exc}") from exc return tmp_name: str | None = None @@ -941,9 +917,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) - with contextlib.suppress(OSError): if tmp_name is not None: os.unlink(tmp_name) - raise ConfigWriteError( - f"Failed to restore {local_path} during rollback: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to restore {local_path} during rollback: {exc}") from exc def _validate_regex_patterns(patterns: list[str]) -> None: @@ -979,9 +953,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None: try: filter_d.mkdir(parents=True, exist_ok=True) except OSError as exc: - raise ConfigWriteError( - f"Cannot create filter.d directory: {exc}" - ) from exc + raise ConfigWriteError(f"Cannot create filter.d directory: {exc}") from exc local_path = filter_d / f"{name}.local" try: @@ -998,9 +970,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None: except OSError as exc: with contextlib.suppress(OSError): os.unlink(tmp_name) # noqa: F821 - raise ConfigWriteError( - f"Failed to write {local_path}: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc log.info("filter_local_written", filter=name, path=str(local_path)) @@ -1031,9 +1001,7 @@ def _set_jail_local_key_sync( try: jail_d.mkdir(parents=True, exist_ok=True) except OSError as exc: - raise ConfigWriteError( - f"Cannot create jail.d directory: {exc}" - ) from exc + raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc local_path = jail_d / f"{jail_name}.local" @@ -1072,9 +1040,7 @@ def _set_jail_local_key_sync( except OSError as exc: with contextlib.suppress(OSError): os.unlink(tmp_name) # noqa: F821 - raise ConfigWriteError( - f"Failed to write {local_path}: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc log.info( "jail_local_key_set", @@ -1112,8 +1078,8 @@ async def list_inactive_jails( inactive jails. """ loop = asyncio.get_event_loop() - parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = ( - await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)) + parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = await loop.run_in_executor( + None, _parse_jails_sync, Path(config_dir) ) all_jails, source_files = parsed_result active_names: set[str] = await _get_active_jail_names(socket_path) @@ -1170,9 +1136,7 @@ async def activate_jail( _safe_jail_name(name) loop = asyncio.get_event_loop() - all_jails, _source_files = await loop.run_in_executor( - None, _parse_jails_sync, Path(config_dir) - ) + all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)) if name not in all_jails: raise JailNotFoundInConfigError(name) @@ -1208,10 +1172,7 @@ async def activate_jail( active=False, fail2ban_running=True, validation_warnings=warnings, - message=( - f"Jail {name!r} cannot be activated: " - + "; ".join(i.message for i in blocking) - ), + message=(f"Jail {name!r} cannot be activated: " + "; ".join(i.message for i in blocking)), ) overrides: dict[str, object] = { @@ -1254,9 +1215,7 @@ async def activate_jail( jail=name, error=str(exc), ) - recovered = await _rollback_activation_async( - config_dir, name, socket_path, original_content - ) + recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content) return JailActivationResponse( name=name, active=False, @@ -1272,9 +1231,7 @@ async def activate_jail( ) except Exception as exc: # noqa: BLE001 log.warning("reload_after_activate_failed", jail=name, error=str(exc)) - recovered = await _rollback_activation_async( - config_dir, name, socket_path, original_content - ) + recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content) return JailActivationResponse( name=name, active=False, @@ -1305,9 +1262,7 @@ async def activate_jail( jail=name, message="fail2ban socket unreachable after reload — initiating rollback.", ) - recovered = await _rollback_activation_async( - config_dir, name, socket_path, original_content - ) + recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content) return JailActivationResponse( name=name, active=False, @@ -1330,9 +1285,7 @@ async def activate_jail( jail=name, message="Jail did not appear in running jails — initiating rollback.", ) - recovered = await _rollback_activation_async( - config_dir, name, socket_path, original_content - ) + recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content) return JailActivationResponse( name=name, active=False, @@ -1388,14 +1341,10 @@ async def _rollback_activation_async( # Step 1 — restore original file (or delete it). try: - await loop.run_in_executor( - None, _restore_local_file_sync, local_path, original_content - ) + await loop.run_in_executor(None, _restore_local_file_sync, local_path, original_content) log.info("jail_activation_rollback_file_restored", jail=name) except ConfigWriteError as exc: - log.error( - "jail_activation_rollback_restore_failed", jail=name, error=str(exc) - ) + log.error("jail_activation_rollback_restore_failed", jail=name, error=str(exc)) return False # Step 2 — reload fail2ban with the restored config. @@ -1403,9 +1352,7 @@ async def _rollback_activation_async( await jail_service.reload_all(socket_path) log.info("jail_activation_rollback_reload_ok", jail=name) except Exception as exc: # noqa: BLE001 - log.warning( - "jail_activation_rollback_reload_failed", jail=name, error=str(exc) - ) + log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc)) return False # Step 3 — wait for fail2ban to come back. @@ -1450,9 +1397,7 @@ async def deactivate_jail( _safe_jail_name(name) loop = asyncio.get_event_loop() - all_jails, _source_files = await loop.run_in_executor( - None, _parse_jails_sync, Path(config_dir) - ) + all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)) if name not in all_jails: raise JailNotFoundInConfigError(name) @@ -1510,9 +1455,7 @@ async def delete_jail_local_override( _safe_jail_name(name) loop = asyncio.get_event_loop() - all_jails, _source_files = await loop.run_in_executor( - None, _parse_jails_sync, Path(config_dir) - ) + all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)) if name not in all_jails: raise JailNotFoundInConfigError(name) @@ -1523,13 +1466,9 @@ async def delete_jail_local_override( local_path = Path(config_dir) / "jail.d" / f"{name}.local" try: - await loop.run_in_executor( - None, lambda: local_path.unlink(missing_ok=True) - ) + await loop.run_in_executor(None, lambda: local_path.unlink(missing_ok=True)) except OSError as exc: - raise ConfigWriteError( - f"Failed to delete {local_path}: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc log.info("jail_local_override_deleted", jail=name, path=str(local_path)) @@ -1610,9 +1549,7 @@ async def rollback_jail( log.info("jail_rollback_start_attempted", jail=name, start_ok=started) # Wait for the socket to come back. - fail2ban_running = await wait_for_fail2ban( - socket_path, max_wait_seconds=10.0, poll_interval=2.0 - ) + fail2ban_running = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0, poll_interval=2.0) active_jails = 0 if fail2ban_running: @@ -1626,10 +1563,7 @@ async def rollback_jail( disabled=True, fail2ban_running=True, active_jails=active_jails, - message=( - f"Jail {name!r} disabled and fail2ban restarted successfully " - f"with {active_jails} active jail(s)." - ), + message=(f"Jail {name!r} disabled and fail2ban restarted successfully with {active_jails} active jail(s)."), ) log.warning("jail_rollback_fail2ban_still_down", jail=name) @@ -1650,9 +1584,7 @@ async def rollback_jail( # --------------------------------------------------------------------------- # Allowlist pattern for filter names used in path construction. -_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile( - r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$" -) +_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$") class FilterNotFoundError(Exception): @@ -1764,9 +1696,7 @@ def _parse_filters_sync( try: content = conf_path.read_text(encoding="utf-8") except OSError as exc: - log.warning( - "filter_read_error", name=name, path=str(conf_path), error=str(exc) - ) + log.warning("filter_read_error", name=name, path=str(conf_path), error=str(exc)) continue if has_local: @@ -1842,9 +1772,7 @@ async def list_filters( loop = asyncio.get_event_loop() # Run the synchronous scan in a thread-pool executor. - raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor( - None, _parse_filters_sync, filter_d - ) + raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_filters_sync, filter_d) # Fetch active jail names and their configs concurrently. all_jails_result, active_names = await asyncio.gather( @@ -1857,9 +1785,7 @@ async def list_filters( filters: list[FilterConfig] = [] for name, filename, content, has_local, source_path in raw_filters: - cfg = conffile_parser.parse_filter_file( - content, name=name, filename=filename - ) + cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename) used_by = sorted(filter_to_jails.get(name, [])) filters.append( FilterConfig( @@ -1947,9 +1873,7 @@ async def get_filter( content, has_local, source_path = await loop.run_in_executor(None, _read) - cfg = conffile_parser.parse_filter_file( - content, name=base_name, filename=f"{base_name}.conf" - ) + cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf") all_jails_result, active_names = await asyncio.gather( loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)), @@ -2182,9 +2106,7 @@ async def delete_filter( try: local_path.unlink() except OSError as exc: - raise ConfigWriteError( - f"Failed to delete {local_path}: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc log.info("filter_local_deleted", filter=base_name, path=str(local_path)) @@ -2226,9 +2148,7 @@ async def assign_filter_to_jail( loop = asyncio.get_event_loop() # Verify the jail exists in config. - all_jails, _src = await loop.run_in_executor( - None, _parse_jails_sync, Path(config_dir) - ) + all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)) if jail_name not in all_jails: raise JailNotFoundInConfigError(jail_name) @@ -2276,9 +2196,7 @@ async def assign_filter_to_jail( # --------------------------------------------------------------------------- # Allowlist pattern for action names used in path construction. -_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile( - r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$" -) +_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$") class ActionNotFoundError(Exception): @@ -2318,8 +2236,7 @@ class ActionReadonlyError(Exception): """ self.name: str = name super().__init__( - f"Action {name!r} is a shipped default (.conf only); " - "only user-created .local files can be deleted." + f"Action {name!r} is a shipped default (.conf only); only user-created .local files can be deleted." ) @@ -2428,9 +2345,7 @@ def _parse_actions_sync( try: content = conf_path.read_text(encoding="utf-8") except OSError as exc: - log.warning( - "action_read_error", name=name, path=str(conf_path), error=str(exc) - ) + log.warning("action_read_error", name=name, path=str(conf_path), error=str(exc)) continue if has_local: @@ -2495,9 +2410,7 @@ def _append_jail_action_sync( try: jail_d.mkdir(parents=True, exist_ok=True) except OSError as exc: - raise ConfigWriteError( - f"Cannot create jail.d directory: {exc}" - ) from exc + raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc local_path = jail_d / f"{jail_name}.local" @@ -2517,9 +2430,7 @@ def _append_jail_action_sync( existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else "" existing_lines = [ - line.strip() - for line in existing_raw.splitlines() - if line.strip() and not line.strip().startswith("#") + line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#") ] # Extract base names from existing entries for duplicate checking. @@ -2533,9 +2444,7 @@ def _append_jail_action_sync( if existing_lines: # configparser multi-line: continuation lines start with whitespace. - new_value = existing_lines[0] + "".join( - f"\n {line}" for line in existing_lines[1:] - ) + new_value = existing_lines[0] + "".join(f"\n {line}" for line in existing_lines[1:]) parser.set(jail_name, "action", new_value) else: parser.set(jail_name, "action", action_entry) @@ -2559,9 +2468,7 @@ def _append_jail_action_sync( except OSError as exc: with contextlib.suppress(OSError): os.unlink(tmp_name) # noqa: F821 - raise ConfigWriteError( - f"Failed to write {local_path}: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc log.info( "jail_action_appended", @@ -2612,9 +2519,7 @@ def _remove_jail_action_sync( existing_raw = parser.get(jail_name, "action") existing_lines = [ - line.strip() - for line in existing_raw.splitlines() - if line.strip() and not line.strip().startswith("#") + line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#") ] def _base(entry: str) -> str: @@ -2628,9 +2533,7 @@ def _remove_jail_action_sync( return if filtered: - new_value = filtered[0] + "".join( - f"\n {line}" for line in filtered[1:] - ) + new_value = filtered[0] + "".join(f"\n {line}" for line in filtered[1:]) parser.set(jail_name, "action", new_value) else: parser.remove_option(jail_name, "action") @@ -2654,9 +2557,7 @@ def _remove_jail_action_sync( except OSError as exc: with contextlib.suppress(OSError): os.unlink(tmp_name) # noqa: F821 - raise ConfigWriteError( - f"Failed to write {local_path}: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc log.info( "jail_action_removed", @@ -2683,9 +2584,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None: try: action_d.mkdir(parents=True, exist_ok=True) except OSError as exc: - raise ConfigWriteError( - f"Cannot create action.d directory: {exc}" - ) from exc + raise ConfigWriteError(f"Cannot create action.d directory: {exc}") from exc local_path = action_d / f"{name}.local" try: @@ -2702,9 +2601,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None: except OSError as exc: with contextlib.suppress(OSError): os.unlink(tmp_name) # noqa: F821 - raise ConfigWriteError( - f"Failed to write {local_path}: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc log.info("action_local_written", action=name, path=str(local_path)) @@ -2740,9 +2637,7 @@ async def list_actions( action_d = Path(config_dir) / "action.d" loop = asyncio.get_event_loop() - raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor( - None, _parse_actions_sync, action_d - ) + raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_actions_sync, action_d) all_jails_result, active_names = await asyncio.gather( loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)), @@ -2754,9 +2649,7 @@ async def list_actions( actions: list[ActionConfig] = [] for name, filename, content, has_local, source_path in raw_actions: - cfg = conffile_parser.parse_action_file( - content, name=name, filename=filename - ) + cfg = conffile_parser.parse_action_file(content, name=name, filename=filename) used_by = sorted(action_to_jails.get(name, [])) actions.append( ActionConfig( @@ -2843,9 +2736,7 @@ async def get_action( content, has_local, source_path = await loop.run_in_executor(None, _read) - cfg = conffile_parser.parse_action_file( - content, name=base_name, filename=f"{base_name}.conf" - ) + cfg = conffile_parser.parse_action_file(content, name=base_name, filename=f"{base_name}.conf") all_jails_result, active_names = await asyncio.gather( loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)), @@ -3061,9 +2952,7 @@ async def delete_action( try: local_path.unlink() except OSError as exc: - raise ConfigWriteError( - f"Failed to delete {local_path}: {exc}" - ) from exc + raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc log.info("action_local_deleted", action=base_name, path=str(local_path)) @@ -3105,9 +2994,7 @@ async def assign_action_to_jail( loop = asyncio.get_event_loop() - all_jails, _src = await loop.run_in_executor( - None, _parse_jails_sync, Path(config_dir) - ) + all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)) if jail_name not in all_jails: raise JailNotFoundInConfigError(jail_name) @@ -3187,9 +3074,7 @@ async def remove_action_from_jail( loop = asyncio.get_event_loop() - all_jails, _src = await loop.run_in_executor( - None, _parse_jails_sync, Path(config_dir) - ) + all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)) if jail_name not in all_jails: raise JailNotFoundInConfigError(jail_name) @@ -3218,4 +3103,3 @@ async def remove_action_from_jail( action=action_name, reload=do_reload, ) - diff --git a/backend/app/services/config_service.py b/backend/app/services/config_service.py index 362a51a..88f9c2f 100644 --- a/backend/app/services/config_service.py +++ b/backend/app/services/config_service.py @@ -95,7 +95,7 @@ def _ok(response: object) -> object: ValueError: If the return code indicates an error. """ try: - code, data = cast(Fail2BanResponse, response) + code, data = cast("Fail2BanResponse", response) except (TypeError, ValueError) as exc: raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc if code != 0: @@ -128,7 +128,7 @@ def _ensure_list(value: object | None) -> list[str]: return [str(value)] -_T = TypeVar("_T") +T = TypeVar("T") async def _safe_get( @@ -143,13 +143,13 @@ async def _safe_get( return default -async def _safe_get_typed( +async def _safe_get_typed[T]( client: Fail2BanClient, command: Fail2BanCommand, - default: _T, -) -> _T: + default: T, +) -> T: """Send a command and return the result typed as ``default``'s type.""" - return cast(_T, await _safe_get(client, command, default)) + return cast("T", await _safe_get(client, command, default)) def _is_not_found_error(exc: Exception) -> bool: diff --git a/backend/app/services/health_service.py b/backend/app/services/health_service.py index 87322c1..685391f 100644 --- a/backend/app/services/health_service.py +++ b/backend/app/services/health_service.py @@ -47,7 +47,7 @@ def _ok(response: object) -> object: ValueError: If the response indicates an error (return code ≠ 0). """ try: - code, data = cast(Fail2BanResponse, response) + code, data = cast("Fail2BanResponse", response) except (TypeError, ValueError) as exc: raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc diff --git a/backend/app/services/history_service.py b/backend/app/services/history_service.py index fee58c0..dc31298 100644 --- a/backend/app/services/history_service.py +++ b/backend/app/services/history_service.py @@ -11,10 +11,12 @@ modifies or locks the fail2ban database. from __future__ import annotations from datetime import UTC, datetime +from typing import TYPE_CHECKING import structlog -from app.services.geo_service import GeoEnricher +if TYPE_CHECKING: + from app.services.geo_service import GeoEnricher from app.models.ban import TIME_RANGE_SECONDS, TimeRange from app.models.history import ( diff --git a/backend/app/services/jail_service.py b/backend/app/services/jail_service.py index 958a6ec..5df436f 100644 --- a/backend/app/services/jail_service.py +++ b/backend/app/services/jail_service.py @@ -14,7 +14,8 @@ from __future__ import annotations import asyncio import contextlib import ipaddress -from typing import TYPE_CHECKING, Awaitable, Callable, cast, TypeAlias +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, TypedDict, cast import structlog @@ -27,6 +28,7 @@ from app.models.jail import ( JailStatus, JailSummary, ) +from app.services.geo_service import GeoInfo from app.utils.fail2ban_client import ( Fail2BanClient, Fail2BanCommand, @@ -39,11 +41,21 @@ if TYPE_CHECKING: import aiohttp import aiosqlite - from app.services.geo_service import GeoInfo - log: structlog.stdlib.BoundLogger = structlog.get_logger() -GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo | None"]] +class IpLookupResult(TypedDict): + """Result returned by :func:`lookup_ip`. + + This is intentionally a :class:`TypedDict` to provide precise typing for + callers (e.g. routers) while keeping the implementation flexible. + """ + + ip: str + currently_banned_in: list[str] + geo: GeoInfo | None + + +GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]] # --------------------------------------------------------------------------- # Constants @@ -104,7 +116,7 @@ def _ok(response: object) -> object: ValueError: If the response indicates an error (return code ≠ 0). """ try: - code, data = cast(Fail2BanResponse, response) + code, data = cast("Fail2BanResponse", response) except (TypeError, ValueError) as exc: raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc @@ -202,7 +214,7 @@ async def _safe_get( """ try: response = await client.send(command) - return _ok(cast(Fail2BanResponse, response)) + return _ok(cast("Fail2BanResponse", response)) except (ValueError, TypeError, Exception): return default @@ -337,7 +349,6 @@ async def _fetch_jail_summary( client.send(["get", name, "backend"]), client.send(["get", name, "idle"]), ]) - uses_backend_backend_commands = True else: # Commands not supported; return default values without sending. async def _return_default(value: object | None) -> Fail2BanResponse: @@ -347,7 +358,6 @@ async def _fetch_jail_summary( _return_default("polling"), # backend default _return_default(False), # idle default ]) - uses_backend_backend_commands = False _r = await asyncio.gather(*gather_list, return_exceptions=True) status_raw: object | Exception = _r[0] @@ -377,7 +387,7 @@ async def _fetch_jail_summary( if isinstance(raw, Exception): return fallback try: - return int(str(_ok(cast(Fail2BanResponse, raw)))) + return int(str(_ok(cast("Fail2BanResponse", raw)))) except (ValueError, TypeError): return fallback @@ -385,7 +395,7 @@ async def _fetch_jail_summary( if isinstance(raw, Exception): return fallback try: - return str(_ok(cast(Fail2BanResponse, raw))) + return str(_ok(cast("Fail2BanResponse", raw))) except (ValueError, TypeError): return fallback @@ -393,7 +403,7 @@ async def _fetch_jail_summary( if isinstance(raw, Exception): return fallback try: - return bool(_ok(cast(Fail2BanResponse, raw))) + return bool(_ok(cast("Fail2BanResponse", raw))) except (ValueError, TypeError): return fallback @@ -687,7 +697,7 @@ async def reload_all( names_set -= set(exclude_jails) stream: list[list[object]] = [["start", n] for n in sorted(names_set)] - _ok(await client.send(["reload", "--all", [], cast(Fail2BanToken, stream)])) + _ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)])) log.info("all_jails_reloaded") except ValueError as exc: # Detect UnknownJailException (missing or invalid jail configuration) @@ -811,8 +821,8 @@ async def unban_ip( async def get_active_bans( socket_path: str, geo_enricher: GeoEnricher | None = None, - http_session: "aiohttp.ClientSession" | None = None, - app_db: "aiosqlite.Connection" | None = None, + http_session: aiohttp.ClientSession | None = None, + app_db: aiosqlite.Connection | None = None, ) -> ActiveBanListResponse: """Return all currently banned IPs across every jail. @@ -880,7 +890,7 @@ async def get_active_bans( continue try: - ban_list: list[str] = cast(list[str], _ok(raw_result)) or [] + ban_list: list[str] = cast("list[str]", _ok(raw_result)) or [] except (TypeError, ValueError) as exc: log.warning( "active_bans_parse_error", @@ -1007,8 +1017,8 @@ async def get_jail_banned_ips( page: int = 1, page_size: int = 25, search: str | None = None, - http_session: "aiohttp.ClientSession" | None = None, - app_db: "aiosqlite.Connection" | None = None, + http_session: aiohttp.ClientSession | None = None, + app_db: aiosqlite.Connection | None = None, ) -> JailBannedIpsResponse: """Return a paginated list of currently banned IPs for a single jail. @@ -1055,7 +1065,7 @@ async def get_jail_banned_ips( except (ValueError, TypeError): raw_result = [] - ban_list: list[str] = cast(list[str], raw_result) or [] + ban_list: list[str] = cast("list[str]", raw_result) or [] # Parse all entries. all_bans: list[ActiveBan] = [] @@ -1121,7 +1131,7 @@ async def _enrich_bans( The same list with ``country`` fields populated where lookup succeeded. """ geo_results: list[object | Exception] = await asyncio.gather( - *[cast(Awaitable[object], geo_enricher(ban.ip)) for ban in bans], + *[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans], return_exceptions=True, ) enriched: list[ActiveBan] = [] @@ -1277,7 +1287,7 @@ async def lookup_ip( socket_path: str, ip: str, geo_enricher: GeoEnricher | None = None, -) -> dict[str, object | list[str] | None]: +) -> IpLookupResult: """Return ban status and history for a single IP address. Checks every running jail for whether the IP is currently banned. @@ -1330,7 +1340,7 @@ async def lookup_ip( if isinstance(result, Exception): continue try: - ban_list: list[str] = cast(list[str], _ok(result)) or [] + ban_list: list[str] = cast("list[str]", _ok(result)) or [] if ip in ban_list: currently_banned_in.append(jail_name) except (ValueError, TypeError): diff --git a/backend/app/services/server_service.py b/backend/app/services/server_service.py index 85d5914..1943ef6 100644 --- a/backend/app/services/server_service.py +++ b/backend/app/services/server_service.py @@ -10,7 +10,7 @@ HTTP/FastAPI concerns. from __future__ import annotations -from typing import cast, TypeAlias +from typing import cast import structlog @@ -21,7 +21,7 @@ from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanR # Types # --------------------------------------------------------------------------- -Fail2BanSettingValue: TypeAlias = str | int | bool +type Fail2BanSettingValue = str | int | bool """Allowed values for server settings commands.""" log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -106,7 +106,7 @@ async def _safe_get( """ try: response = await client.send(command) - return _ok(cast(Fail2BanResponse, response)) + return _ok(cast("Fail2BanResponse", response)) except Exception: return default @@ -189,7 +189,7 @@ async def update_settings(socket_path: str, update: ServerSettingsUpdate) -> Non async def _set(key: str, value: Fail2BanSettingValue) -> None: try: response = await client.send(["set", key, value]) - _ok(cast(Fail2BanResponse, response)) + _ok(cast("Fail2BanResponse", response)) except ValueError as exc: raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc @@ -224,7 +224,7 @@ async def flush_logs(socket_path: str) -> str: client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) try: response = await client.send(["flushlogs"]) - result = _ok(cast(Fail2BanResponse, response)) + result = _ok(cast("Fail2BanResponse", response)) log.info("logs_flushed", result=result) return str(result) except ValueError as exc: diff --git a/backend/app/tasks/geo_re_resolve.py b/backend/app/tasks/geo_re_resolve.py index e3f85fe..81e93d7 100644 --- a/backend/app/tasks/geo_re_resolve.py +++ b/backend/app/tasks/geo_re_resolve.py @@ -35,7 +35,7 @@ GEO_RE_RESOLVE_INTERVAL: int = 600 JOB_ID: str = "geo_re_resolve" -async def _run_re_resolve(app: "FastAPI") -> None: +async def _run_re_resolve(app: FastAPI) -> None: """Query NULL-country IPs from the database and re-resolve them. Reads shared resources from ``app.state`` and delegates to diff --git a/backend/app/tasks/health_check.py b/backend/app/tasks/health_check.py index 597b92d..996bdd4 100644 --- a/backend/app/tasks/health_check.py +++ b/backend/app/tasks/health_check.py @@ -47,7 +47,7 @@ HEALTH_CHECK_INTERVAL: int = 30 _ACTIVATION_CRASH_WINDOW: int = 60 -async def _run_probe(app: "FastAPI") -> None: +async def _run_probe(app: FastAPI) -> None: """Probe fail2ban and cache the result on *app.state*. Detects online/offline state transitions. When fail2ban goes offline diff --git a/backend/app/utils/fail2ban_client.py b/backend/app/utils/fail2ban_client.py index 6e84cf6..d02a6a5 100644 --- a/backend/app/utils/fail2ban_client.py +++ b/backend/app/utils/fail2ban_client.py @@ -21,34 +21,52 @@ import contextlib import errno import socket import time +from collections.abc import Mapping, Sequence, Set from pickle import HIGHEST_PROTOCOL, dumps, loads -from typing import TYPE_CHECKING, TypeAlias +from typing import TYPE_CHECKING + +import structlog # --------------------------------------------------------------------------- # Types # --------------------------------------------------------------------------- -Fail2BanToken: TypeAlias = str | int | float | bool | None | dict[str, object] | list[object] +# Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]`` +# without needing to cast. At runtime we only accept the basic built-in +# containers supported by fail2ban's protocol (list/dict/set) and stringify +# anything else. +# +# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at +# runtime because fail2ban only understands lists. + +type Fail2BanToken = ( + str + | int + | float + | bool + | None + | Mapping[str, object] + | Sequence[object] + | Set[object] +) """A single token in a fail2ban command. Fail2ban accepts simple types (str/int/float/bool) plus compound types -(list/dict). Complex objects are stringified before being sent. +(list/dict/set). Complex objects are stringified before being sent. """ -Fail2BanCommand: TypeAlias = list[Fail2BanToken] +type Fail2BanCommand = Sequence[Fail2BanToken] """A command sent to fail2ban over the socket. -Commands are pickle serialised lists of tokens. +Commands are pickle serialised sequences of tokens. """ -Fail2BanResponse: TypeAlias = tuple[int, object] +type Fail2BanResponse = tuple[int, object] """A typical fail2ban response containing a status code and payload.""" if TYPE_CHECKING: from types import TracebackType -import structlog - log: structlog.stdlib.BoundLogger = structlog.get_logger() # fail2ban protocol constants — inline to avoid a hard import dependency @@ -200,7 +218,7 @@ def _send_command_sync( ) from last_oserror -def _coerce_command_token(token: Fail2BanToken) -> Fail2BanToken: +def _coerce_command_token(token: object) -> Fail2BanToken: """Coerce a command token to a type that fail2ban understands. fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``, diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 5938a4c..4649476 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -60,4 +60,5 @@ plugins = ["pydantic.mypy"] asyncio_mode = "auto" pythonpath = [".", "../fail2ban-master"] testpaths = ["tests"] -addopts = "--cov=app --cov-report=term-missing" +addopts = "--asyncio-mode=auto --cov=app --cov-report=term-missing" +filterwarnings = ["ignore::pytest.PytestRemovedIn9Warning"] diff --git a/backend/tests/test_routers/test_auth.py b/backend/tests/test_routers/test_auth.py index afd59d7..8d5ebe9 100644 --- a/backend/tests/test_routers/test_auth.py +++ b/backend/tests/test_routers/test_auth.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Generator from unittest.mock import patch import pytest @@ -157,12 +158,12 @@ class TestRequireAuthSessionCache: """In-memory session token cache inside ``require_auth``.""" @pytest.fixture(autouse=True) - def reset_cache(self) -> None: # type: ignore[misc] + def reset_cache(self) -> Generator[None, None, None]: """Flush the session cache before and after every test in this class.""" from app import dependencies dependencies.clear_session_cache() - yield # type: ignore[misc] + yield dependencies.clear_session_cache() async def test_second_request_skips_db(self, client: AsyncClient) -> None: diff --git a/backend/tests/test_routers/test_geo.py b/backend/tests/test_routers/test_geo.py index c57363e..2f8ef43 100644 --- a/backend/tests/test_routers/test_geo.py +++ b/backend/tests/test_routers/test_geo.py @@ -70,7 +70,7 @@ class TestGeoLookup: async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} returns 200 with enriched result.""" geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme") - result = { + result: dict[str, object] = { "ip": "1.2.3.4", "currently_banned_in": ["sshd"], "geo": geo, @@ -92,7 +92,7 @@ class TestGeoLookup: async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere.""" - result = { + result: dict[str, object] = { "ip": "8.8.8.8", "currently_banned_in": [], "geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None), @@ -108,7 +108,7 @@ class TestGeoLookup: async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} returns null geo when enricher fails.""" - result = { + result: dict[str, object] = { "ip": "1.2.3.4", "currently_banned_in": [], "geo": None, @@ -144,7 +144,7 @@ class TestGeoLookup: async def test_ipv6_address(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} handles IPv6 addresses.""" - result = { + result: dict[str, object] = { "ip": "2001:db8::1", "currently_banned_in": [], "geo": None, diff --git a/backend/tests/test_routers/test_jails.py b/backend/tests/test_routers/test_jails.py index 4954e23..eee7c46 100644 --- a/backend/tests/test_routers/test_jails.py +++ b/backend/tests/test_routers/test_jails.py @@ -12,6 +12,7 @@ from httpx import ASGITransport, AsyncClient from app.config import Settings from app.db import init_db from app.main import create_app +from app.models.ban import JailBannedIpsResponse from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary # --------------------------------------------------------------------------- @@ -801,17 +802,17 @@ class TestGetJailBannedIps: def _mock_response( self, *, - items: list[dict] | None = None, + items: list[dict[str, str | None]] | None = None, total: int = 2, page: int = 1, page_size: int = 25, - ) -> "JailBannedIpsResponse": # type: ignore[name-defined] + ) -> JailBannedIpsResponse: from app.models.ban import ActiveBan, JailBannedIpsResponse ban_items = ( [ ActiveBan( - ip=item.get("ip", "1.2.3.4"), + ip=item.get("ip") or "1.2.3.4", jail="sshd", banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"), expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"), diff --git a/backend/tests/test_routers/test_setup.py b/backend/tests/test_routers/test_setup.py index 0fc7040..da9e623 100644 --- a/backend/tests/test_routers/test_setup.py +++ b/backend/tests/test_routers/test_setup.py @@ -247,9 +247,9 @@ class TestSetupCompleteCaching: assert not getattr(app.state, "_setup_complete_cached", False) # First non-exempt request — middleware queries DB and sets the flag. - await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload] + await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) - assert app.state._setup_complete_cached is True # type: ignore[attr-defined] + assert app.state._setup_complete_cached is True async def test_cached_path_skips_is_setup_complete( self, @@ -267,12 +267,12 @@ class TestSetupCompleteCaching: # Do setup and warm the cache. await client.post("/api/setup", json=_SETUP_PAYLOAD) - await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload] - assert app.state._setup_complete_cached is True # type: ignore[attr-defined] + await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) + assert app.state._setup_complete_cached is True call_count = 0 - async def _counting(db): # type: ignore[no-untyped-def] + async def _counting(db: aiosqlite.Connection) -> bool: nonlocal call_count call_count += 1 return True diff --git a/backend/tests/test_services/test_auth_service.py b/backend/tests/test_services/test_auth_service.py index d30a8b5..1df04c0 100644 --- a/backend/tests/test_services/test_auth_service.py +++ b/backend/tests/test_services/test_auth_service.py @@ -73,7 +73,7 @@ class TestCheckPasswordAsync: auth_service._check_password("secret", hashed), # noqa: SLF001 auth_service._check_password("wrong", hashed), # noqa: SLF001 ) - assert results == [True, False] + assert tuple(results) == (True, False) class TestLogin: diff --git a/backend/tests/test_services/test_ban_service.py b/backend/tests/test_services/test_ban_service.py index d0d93b7..de2faaa 100644 --- a/backend/tests/test_services/test_ban_service.py +++ b/backend/tests/test_services/test_ban_service.py @@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None: @pytest.fixture -async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc] +async def f2b_db_path(tmp_path: Path) -> str: """Return the path to a test fail2ban SQLite database with several bans.""" path = str(tmp_path / "fail2ban_test.sqlite3") await _create_f2b_db( @@ -103,7 +103,7 @@ async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc] @pytest.fixture -async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc] +async def mixed_origin_db_path(tmp_path: Path) -> str: """Return a database with bans from both blocklist-import and organic jails.""" path = str(tmp_path / "fail2ban_mixed_origin.sqlite3") await _create_f2b_db( @@ -136,7 +136,7 @@ async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc] @pytest.fixture -async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc] +async def empty_f2b_db_path(tmp_path: Path) -> str: """Return the path to a fail2ban SQLite database with no ban records.""" path = str(tmp_path / "fail2ban_empty.sqlite3") await _create_f2b_db(path, []) @@ -632,13 +632,13 @@ class TestBansbyCountryBackground: from app.services import geo_service # Pre-populate the cache for all three IPs in the fixture. - geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( # type: ignore[attr-defined] + geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( country_code="DE", country_name="Germany", asn=None, org=None ) - geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( # type: ignore[attr-defined] + geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( country_code="US", country_name="United States", asn=None, org=None ) - geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( # type: ignore[attr-defined] + geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( country_code="JP", country_name="Japan", asn=None, org=None ) diff --git a/backend/tests/test_services/test_ban_service_perf.py b/backend/tests/test_services/test_ban_service_perf.py index bbf007b..7f898be 100644 --- a/backend/tests/test_services/test_ban_service_perf.py +++ b/backend/tests/test_services/test_ban_service_perf.py @@ -114,13 +114,13 @@ async def _seed_f2b_db(path: str, n: int) -> list[str]: @pytest.fixture(scope="module") -def event_loop_policy() -> None: # type: ignore[misc] +def event_loop_policy() -> None: """Use the default event loop policy for module-scoped fixtures.""" return None @pytest.fixture(scope="module") -async def perf_db_path(tmp_path_factory: Any) -> str: # type: ignore[misc] +async def perf_db_path(tmp_path_factory: Any) -> str: """Return the path to a fail2ban DB seeded with 10 000 synthetic bans. Module-scoped so the database is created only once for all perf tests. diff --git a/backend/tests/test_services/test_config_file_service.py b/backend/tests/test_services/test_config_file_service.py index e648fe8..26b7918 100644 --- a/backend/tests/test_services/test_config_file_service.py +++ b/backend/tests/test_services/test_config_file_service.py @@ -13,15 +13,19 @@ from app.services.config_file_service import ( JailNameError, JailNotFoundInConfigError, _build_inactive_jail, + _extract_action_base_name, + _extract_filter_base_name, _ordered_config_files, _parse_jails_sync, _resolve_filter, _safe_jail_name, + _validate_jail_config_sync, _write_local_override_sync, activate_jail, deactivate_jail, list_inactive_jails, rollback_jail, + validate_jail_config, ) # --------------------------------------------------------------------------- @@ -292,9 +296,7 @@ class TestBuildInactiveJail: def test_has_local_override_absent(self, tmp_path: Path) -> None: """has_local_override is False when no .local file exists.""" - jail = _build_inactive_jail( - "sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path - ) + jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path) assert jail.has_local_override is False def test_has_local_override_present(self, tmp_path: Path) -> None: @@ -302,9 +304,7 @@ class TestBuildInactiveJail: local = tmp_path / "jail.d" / "sshd.local" local.parent.mkdir(parents=True, exist_ok=True) local.write_text("[sshd]\nenabled = false\n") - jail = _build_inactive_jail( - "sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path - ) + jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path) assert jail.has_local_override is True def test_has_local_override_no_config_dir(self) -> None: @@ -363,9 +363,7 @@ class TestWriteLocalOverrideSync: assert "2222" in content def test_override_logpath_list(self, tmp_path: Path) -> None: - _write_local_override_sync( - tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]} - ) + _write_local_override_sync(tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]}) content = (tmp_path / "jail.d" / "sshd.local").read_text() assert "/var/log/auth.log" in content assert "/var/log/secure" in content @@ -447,9 +445,7 @@ class TestListInactiveJails: assert "sshd" in names assert "apache-auth" in names - async def test_has_local_override_true_when_local_file_exists( - self, tmp_path: Path - ) -> None: + async def test_has_local_override_true_when_local_file_exists(self, tmp_path: Path) -> None: """has_local_override is True for a jail whose jail.d .local file exists.""" _write(tmp_path / "jail.conf", JAIL_CONF) local = tmp_path / "jail.d" / "apache-auth.local" @@ -463,9 +459,7 @@ class TestListInactiveJails: jail = next(j for j in result.jails if j.name == "apache-auth") assert jail.has_local_override is True - async def test_has_local_override_false_when_no_local_file( - self, tmp_path: Path - ) -> None: + async def test_has_local_override_false_when_no_local_file(self, tmp_path: Path) -> None: """has_local_override is False when no jail.d .local file exists.""" _write(tmp_path / "jail.conf", JAIL_CONF) with patch( @@ -608,7 +602,8 @@ class TestActivateJail: patch( "app.services.config_file_service._get_active_jail_names", new=AsyncMock(return_value=set()), - ),pytest.raises(JailNotFoundInConfigError) + ), + pytest.raises(JailNotFoundInConfigError), ): await activate_jail(str(tmp_path), "/fake.sock", "nonexistent", req) @@ -621,7 +616,8 @@ class TestActivateJail: patch( "app.services.config_file_service._get_active_jail_names", new=AsyncMock(return_value={"sshd"}), - ),pytest.raises(JailAlreadyActiveError) + ), + pytest.raises(JailAlreadyActiveError), ): await activate_jail(str(tmp_path), "/fake.sock", "sshd", req) @@ -691,7 +687,8 @@ class TestDeactivateJail: patch( "app.services.config_file_service._get_active_jail_names", new=AsyncMock(return_value={"sshd"}), - ),pytest.raises(JailNotFoundInConfigError) + ), + pytest.raises(JailNotFoundInConfigError), ): await deactivate_jail(str(tmp_path), "/fake.sock", "nonexistent") @@ -701,7 +698,8 @@ class TestDeactivateJail: patch( "app.services.config_file_service._get_active_jail_names", new=AsyncMock(return_value=set()), - ),pytest.raises(JailAlreadyInactiveError) + ), + pytest.raises(JailAlreadyInactiveError), ): await deactivate_jail(str(tmp_path), "/fake.sock", "apache-auth") @@ -710,38 +708,6 @@ class TestDeactivateJail: await deactivate_jail(str(tmp_path), "/fake.sock", "a/b") -# --------------------------------------------------------------------------- -# _extract_filter_base_name -# --------------------------------------------------------------------------- - - -class TestExtractFilterBaseName: - def test_simple_name(self) -> None: - from app.services.config_file_service import _extract_filter_base_name - - assert _extract_filter_base_name("sshd") == "sshd" - - def test_name_with_mode(self) -> None: - from app.services.config_file_service import _extract_filter_base_name - - assert _extract_filter_base_name("sshd[mode=aggressive]") == "sshd" - - def test_name_with_variable_mode(self) -> None: - from app.services.config_file_service import _extract_filter_base_name - - assert _extract_filter_base_name("sshd[mode=%(mode)s]") == "sshd" - - def test_whitespace_stripped(self) -> None: - from app.services.config_file_service import _extract_filter_base_name - - assert _extract_filter_base_name(" nginx ") == "nginx" - - def test_empty_string(self) -> None: - from app.services.config_file_service import _extract_filter_base_name - - assert _extract_filter_base_name("") == "" - - # --------------------------------------------------------------------------- # _build_filter_to_jails_map # --------------------------------------------------------------------------- @@ -757,9 +723,7 @@ class TestBuildFilterToJailsMap: def test_inactive_jail_not_included(self) -> None: from app.services.config_file_service import _build_filter_to_jails_map - result = _build_filter_to_jails_map( - {"apache-auth": {"filter": "apache-auth"}}, set() - ) + result = _build_filter_to_jails_map({"apache-auth": {"filter": "apache-auth"}}, set()) assert result == {} def test_multiple_jails_sharing_filter(self) -> None: @@ -775,9 +739,7 @@ class TestBuildFilterToJailsMap: def test_mode_suffix_stripped(self) -> None: from app.services.config_file_service import _build_filter_to_jails_map - result = _build_filter_to_jails_map( - {"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"} - ) + result = _build_filter_to_jails_map({"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"}) assert "sshd" in result def test_missing_filter_key_falls_back_to_jail_name(self) -> None: @@ -988,10 +950,13 @@ class TestGetFilter: async def test_raises_filter_not_found(self, tmp_path: Path) -> None: from app.services.config_file_service import FilterNotFoundError, get_filter - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), pytest.raises(FilterNotFoundError): + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + pytest.raises(FilterNotFoundError), + ): await get_filter(str(tmp_path), "/fake.sock", "nonexistent") async def test_has_local_override_detected(self, tmp_path: Path) -> None: @@ -1093,10 +1058,13 @@ class TestGetFilterLocalOnly: async def test_raises_when_neither_conf_nor_local(self, tmp_path: Path) -> None: from app.services.config_file_service import FilterNotFoundError, get_filter - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), pytest.raises(FilterNotFoundError): + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + pytest.raises(FilterNotFoundError), + ): await get_filter(str(tmp_path), "/fake.sock", "nonexistent") async def test_accepts_local_extension(self, tmp_path: Path) -> None: @@ -1212,9 +1180,7 @@ class TestSetJailLocalKeySync: jail_d = tmp_path / "jail.d" jail_d.mkdir() - (jail_d / "sshd.local").write_text( - "[sshd]\nenabled = true\n" - ) + (jail_d / "sshd.local").write_text("[sshd]\nenabled = true\n") _set_jail_local_key_sync(tmp_path, "sshd", "filter", "newfilter") @@ -1300,10 +1266,13 @@ class TestUpdateFilter: from app.models.config import FilterUpdateRequest from app.services.config_file_service import FilterNotFoundError, update_filter - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), pytest.raises(FilterNotFoundError): + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + pytest.raises(FilterNotFoundError), + ): await update_filter( str(tmp_path), "/fake.sock", @@ -1321,10 +1290,13 @@ class TestUpdateFilter: filter_d = tmp_path / "filter.d" _write(filter_d / "sshd.conf", _FILTER_CONF_WITH_REGEX) - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), pytest.raises(FilterInvalidRegexError): + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + pytest.raises(FilterInvalidRegexError), + ): await update_filter( str(tmp_path), "/fake.sock", @@ -1351,13 +1323,16 @@ class TestUpdateFilter: filter_d = tmp_path / "filter.d" _write(filter_d / "sshd.conf", _FILTER_CONF) - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), patch( - "app.services.config_file_service.jail_service.reload_all", - new=AsyncMock(), - ) as mock_reload: + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + patch( + "app.services.config_file_service.jail_service.reload_all", + new=AsyncMock(), + ) as mock_reload, + ): await update_filter( str(tmp_path), "/fake.sock", @@ -1405,10 +1380,13 @@ class TestCreateFilter: filter_d = tmp_path / "filter.d" _write(filter_d / "sshd.conf", _FILTER_CONF) - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), pytest.raises(FilterAlreadyExistsError): + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + pytest.raises(FilterAlreadyExistsError), + ): await create_filter( str(tmp_path), "/fake.sock", @@ -1422,10 +1400,13 @@ class TestCreateFilter: filter_d = tmp_path / "filter.d" _write(filter_d / "custom.local", "[Definition]\n") - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), pytest.raises(FilterAlreadyExistsError): + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + pytest.raises(FilterAlreadyExistsError), + ): await create_filter( str(tmp_path), "/fake.sock", @@ -1436,10 +1417,13 @@ class TestCreateFilter: from app.models.config import FilterCreateRequest from app.services.config_file_service import FilterInvalidRegexError, create_filter - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), pytest.raises(FilterInvalidRegexError): + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + pytest.raises(FilterInvalidRegexError), + ): await create_filter( str(tmp_path), "/fake.sock", @@ -1461,13 +1445,16 @@ class TestCreateFilter: from app.models.config import FilterCreateRequest from app.services.config_file_service import create_filter - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), patch( - "app.services.config_file_service.jail_service.reload_all", - new=AsyncMock(), - ) as mock_reload: + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + patch( + "app.services.config_file_service.jail_service.reload_all", + new=AsyncMock(), + ) as mock_reload, + ): await create_filter( str(tmp_path), "/fake.sock", @@ -1485,9 +1472,7 @@ class TestCreateFilter: @pytest.mark.asyncio class TestDeleteFilter: - async def test_deletes_local_file_when_conf_and_local_exist( - self, tmp_path: Path - ) -> None: + async def test_deletes_local_file_when_conf_and_local_exist(self, tmp_path: Path) -> None: from app.services.config_file_service import delete_filter filter_d = tmp_path / "filter.d" @@ -1524,9 +1509,7 @@ class TestDeleteFilter: with pytest.raises(FilterNotFoundError): await delete_filter(str(tmp_path), "nonexistent") - async def test_accepts_filter_name_error_for_invalid_name( - self, tmp_path: Path - ) -> None: + async def test_accepts_filter_name_error_for_invalid_name(self, tmp_path: Path) -> None: from app.services.config_file_service import FilterNameError, delete_filter with pytest.raises(FilterNameError): @@ -1607,9 +1590,7 @@ class TestAssignFilterToJail: AssignFilterRequest(filter_name="sshd"), ) - async def test_raises_filter_name_error_for_invalid_filter( - self, tmp_path: Path - ) -> None: + async def test_raises_filter_name_error_for_invalid_filter(self, tmp_path: Path) -> None: from app.models.config import AssignFilterRequest from app.services.config_file_service import FilterNameError, assign_filter_to_jail @@ -1719,34 +1700,26 @@ class TestBuildActionToJailsMap: def test_active_jail_maps_to_action(self) -> None: from app.services.config_file_service import _build_action_to_jails_map - result = _build_action_to_jails_map( - {"sshd": {"action": "iptables-multiport"}}, {"sshd"} - ) + result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, {"sshd"}) assert result == {"iptables-multiport": ["sshd"]} def test_inactive_jail_not_included(self) -> None: from app.services.config_file_service import _build_action_to_jails_map - result = _build_action_to_jails_map( - {"sshd": {"action": "iptables-multiport"}}, set() - ) + result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, set()) assert result == {} def test_multiple_actions_per_jail(self) -> None: from app.services.config_file_service import _build_action_to_jails_map - result = _build_action_to_jails_map( - {"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"} - ) + result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"}) assert "iptables-multiport" in result assert "iptables-ipset" in result def test_parameter_block_stripped(self) -> None: from app.services.config_file_service import _build_action_to_jails_map - result = _build_action_to_jails_map( - {"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"} - ) + result = _build_action_to_jails_map({"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"}) assert "iptables" in result def test_multiple_jails_sharing_action(self) -> None: @@ -2001,10 +1974,13 @@ class TestGetAction: async def test_raises_for_unknown_action(self, tmp_path: Path) -> None: from app.services.config_file_service import ActionNotFoundError, get_action - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), pytest.raises(ActionNotFoundError): + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + pytest.raises(ActionNotFoundError), + ): await get_action(str(tmp_path), "/fake.sock", "nonexistent") async def test_local_only_action_returned(self, tmp_path: Path) -> None: @@ -2118,10 +2094,13 @@ class TestUpdateAction: from app.models.config import ActionUpdateRequest from app.services.config_file_service import ActionNotFoundError, update_action - with patch( - "app.services.config_file_service._get_active_jail_names", - new=AsyncMock(return_value=set()), - ), pytest.raises(ActionNotFoundError): + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + pytest.raises(ActionNotFoundError), + ): await update_action( str(tmp_path), "/fake.sock", @@ -2587,9 +2566,7 @@ class TestRemoveActionFromJail: "app.services.config_file_service._get_active_jail_names", new=AsyncMock(return_value=set()), ): - await remove_action_from_jail( - str(tmp_path), "/fake.sock", "sshd", "iptables-multiport" - ) + await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables-multiport") content = (jail_d / "sshd.local").read_text() assert "iptables-multiport" not in content @@ -2601,17 +2578,13 @@ class TestRemoveActionFromJail: ) with pytest.raises(JailNotFoundInConfigError): - await remove_action_from_jail( - str(tmp_path), "/fake.sock", "nonexistent", "iptables" - ) + await remove_action_from_jail(str(tmp_path), "/fake.sock", "nonexistent", "iptables") async def test_raises_jail_name_error(self, tmp_path: Path) -> None: from app.services.config_file_service import JailNameError, remove_action_from_jail with pytest.raises(JailNameError): - await remove_action_from_jail( - str(tmp_path), "/fake.sock", "../evil", "iptables" - ) + await remove_action_from_jail(str(tmp_path), "/fake.sock", "../evil", "iptables") async def test_raises_action_name_error(self, tmp_path: Path) -> None: from app.services.config_file_service import ActionNameError, remove_action_from_jail @@ -2619,9 +2592,7 @@ class TestRemoveActionFromJail: _write(tmp_path / "jail.conf", JAIL_CONF) with pytest.raises(ActionNameError): - await remove_action_from_jail( - str(tmp_path), "/fake.sock", "sshd", "../evil" - ) + await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "../evil") async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None: from app.services.config_file_service import remove_action_from_jail @@ -2640,9 +2611,7 @@ class TestRemoveActionFromJail: new=AsyncMock(), ) as mock_reload, ): - await remove_action_from_jail( - str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True - ) + await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True) mock_reload.assert_awaited_once() @@ -2680,13 +2649,9 @@ class TestActivateJailReloadArgs: mock_js.reload_all = AsyncMock() await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req) - mock_js.reload_all.assert_awaited_once_with( - "/fake.sock", include_jails=["apache-auth"] - ) + mock_js.reload_all.assert_awaited_once_with("/fake.sock", include_jails=["apache-auth"]) - async def test_activate_returns_active_true_when_jail_starts( - self, tmp_path: Path - ) -> None: + async def test_activate_returns_active_true_when_jail_starts(self, tmp_path: Path) -> None: """activate_jail returns active=True when the jail appears in post-reload names.""" _write(tmp_path / "jail.conf", JAIL_CONF) from app.models.config import ActivateJailRequest, JailValidationResult @@ -2708,16 +2673,12 @@ class TestActivateJailReloadArgs: ), ): mock_js.reload_all = AsyncMock() - result = await activate_jail( - str(tmp_path), "/fake.sock", "apache-auth", req - ) + result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req) assert result.active is True assert "activated" in result.message.lower() - async def test_activate_returns_active_false_when_jail_does_not_start( - self, tmp_path: Path - ) -> None: + async def test_activate_returns_active_false_when_jail_does_not_start(self, tmp_path: Path) -> None: """activate_jail returns active=False when the jail is absent after reload. This covers the Stage 3.1 requirement: if the jail config is invalid @@ -2746,9 +2707,7 @@ class TestActivateJailReloadArgs: ), ): mock_js.reload_all = AsyncMock() - result = await activate_jail( - str(tmp_path), "/fake.sock", "apache-auth", req - ) + result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req) assert result.active is False assert "apache-auth" in result.name @@ -2776,23 +2735,13 @@ class TestDeactivateJailReloadArgs: mock_js.reload_all = AsyncMock() await deactivate_jail(str(tmp_path), "/fake.sock", "sshd") - mock_js.reload_all.assert_awaited_once_with( - "/fake.sock", exclude_jails=["sshd"] - ) + mock_js.reload_all.assert_awaited_once_with("/fake.sock", exclude_jails=["sshd"]) # --------------------------------------------------------------------------- # _validate_jail_config_sync (Task 3) # --------------------------------------------------------------------------- -from app.services.config_file_service import ( # noqa: E402 (added after block) - _validate_jail_config_sync, - _extract_filter_base_name, - _extract_action_base_name, - validate_jail_config, - rollback_jail, -) - class TestExtractFilterBaseName: def test_plain_name(self) -> None: @@ -2938,11 +2887,11 @@ class TestRollbackJail: with ( patch( - "app.services.config_file_service._start_daemon", + "app.services.config_file_service.start_daemon", new=AsyncMock(return_value=True), ), patch( - "app.services.config_file_service._wait_for_fail2ban", + "app.services.config_file_service.wait_for_fail2ban", new=AsyncMock(return_value=True), ), patch( @@ -2950,9 +2899,7 @@ class TestRollbackJail: new=AsyncMock(return_value=set()), ), ): - result = await rollback_jail( - str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"] - ) + result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]) assert result.disabled is True assert result.fail2ban_running is True @@ -2968,26 +2915,22 @@ class TestRollbackJail: with ( patch( - "app.services.config_file_service._start_daemon", + "app.services.config_file_service.start_daemon", new=AsyncMock(return_value=False), ), patch( - "app.services.config_file_service._wait_for_fail2ban", + "app.services.config_file_service.wait_for_fail2ban", new=AsyncMock(return_value=False), ), ): - result = await rollback_jail( - str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"] - ) + result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]) assert result.fail2ban_running is False assert result.disabled is True async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None: with pytest.raises(JailNameError): - await rollback_jail( - str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"] - ) + await rollback_jail(str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"]) # --------------------------------------------------------------------------- @@ -3096,9 +3039,7 @@ class TestActivateJailBlocking: class TestActivateJailRollback: """Rollback logic in activate_jail restores the .local file and recovers.""" - async def test_activate_jail_rollback_on_reload_failure( - self, tmp_path: Path - ) -> None: + async def test_activate_jail_rollback_on_reload_failure(self, tmp_path: Path) -> None: """Rollback when reload_all raises on the activation reload. Expects: @@ -3135,23 +3076,17 @@ class TestActivateJailRollback: ), patch( "app.services.config_file_service._validate_jail_config_sync", - return_value=JailValidationResult( - jail_name="apache-auth", valid=True - ), + return_value=JailValidationResult(jail_name="apache-auth", valid=True), ), ): mock_js.reload_all = AsyncMock(side_effect=reload_side_effect) - result = await activate_jail( - str(tmp_path), "/fake.sock", "apache-auth", req - ) + result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req) assert result.active is False assert result.recovered is True assert local_path.read_text() == original_local - async def test_activate_jail_rollback_on_health_check_failure( - self, tmp_path: Path - ) -> None: + async def test_activate_jail_rollback_on_health_check_failure(self, tmp_path: Path) -> None: """Rollback when fail2ban is unreachable after the activation reload. Expects: @@ -3190,15 +3125,11 @@ class TestActivateJailRollback: ), patch( "app.services.config_file_service._validate_jail_config_sync", - return_value=JailValidationResult( - jail_name="apache-auth", valid=True - ), + return_value=JailValidationResult(jail_name="apache-auth", valid=True), ), ): mock_js.reload_all = AsyncMock() - result = await activate_jail( - str(tmp_path), "/fake.sock", "apache-auth", req - ) + result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req) assert result.active is False assert result.recovered is True @@ -3232,25 +3163,17 @@ class TestActivateJailRollback: ), patch( "app.services.config_file_service._validate_jail_config_sync", - return_value=JailValidationResult( - jail_name="apache-auth", valid=True - ), + return_value=JailValidationResult(jail_name="apache-auth", valid=True), ), ): # Both the activation reload and the recovery reload fail. - mock_js.reload_all = AsyncMock( - side_effect=RuntimeError("fail2ban unavailable") - ) - result = await activate_jail( - str(tmp_path), "/fake.sock", "apache-auth", req - ) + mock_js.reload_all = AsyncMock(side_effect=RuntimeError("fail2ban unavailable")) + result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req) assert result.active is False assert result.recovered is False - async def test_activate_jail_rollback_on_jail_not_found_error( - self, tmp_path: Path - ) -> None: + async def test_activate_jail_rollback_on_jail_not_found_error(self, tmp_path: Path) -> None: """Rollback when reload_all raises JailNotFoundError (invalid config). When fail2ban cannot create a jail due to invalid configuration @@ -3294,16 +3217,12 @@ class TestActivateJailRollback: ), patch( "app.services.config_file_service._validate_jail_config_sync", - return_value=JailValidationResult( - jail_name="apache-auth", valid=True - ), + return_value=JailValidationResult(jail_name="apache-auth", valid=True), ), ): mock_js.reload_all = AsyncMock(side_effect=reload_side_effect) mock_js.JailNotFoundError = JailNotFoundError - result = await activate_jail( - str(tmp_path), "/fake.sock", "apache-auth", req - ) + result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req) assert result.active is False assert result.recovered is True @@ -3311,9 +3230,7 @@ class TestActivateJailRollback: # Verify the error message mentions logpath issues. assert "logpath" in result.message.lower() or "check that all logpath" in result.message.lower() - async def test_activate_jail_rollback_deletes_file_when_no_prior_local( - self, tmp_path: Path - ) -> None: + async def test_activate_jail_rollback_deletes_file_when_no_prior_local(self, tmp_path: Path) -> None: """Rollback deletes the .local file when none existed before activation. When a jail had no .local override before activation, activate_jail @@ -3355,15 +3272,11 @@ class TestActivateJailRollback: ), patch( "app.services.config_file_service._validate_jail_config_sync", - return_value=JailValidationResult( - jail_name="apache-auth", valid=True - ), + return_value=JailValidationResult(jail_name="apache-auth", valid=True), ), ): mock_js.reload_all = AsyncMock(side_effect=reload_side_effect) - result = await activate_jail( - str(tmp_path), "/fake.sock", "apache-auth", req - ) + result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req) assert result.active is False assert result.recovered is True @@ -3376,7 +3289,7 @@ class TestActivateJailRollback: @pytest.mark.asyncio -class TestRollbackJail: +class TestRollbackJailIntegration: """Integration tests for :func:`~app.services.config_file_service.rollback_jail`.""" async def test_local_file_written_enabled_false(self, tmp_path: Path) -> None: @@ -3419,15 +3332,11 @@ class TestRollbackJail: AsyncMock(return_value={"other"}), ), ): - await rollback_jail( - str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"] - ) + await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]) mock_start.assert_awaited_once_with(["fail2ban-client", "start"]) - async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit( - self, tmp_path: Path - ) -> None: + async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit(self, tmp_path: Path) -> None: """fail2ban_running in the response reflects the socket probe result. Even when start_daemon returns True (subprocess exit 0), if the socket @@ -3443,15 +3352,11 @@ class TestRollbackJail: AsyncMock(return_value=False), # socket still unresponsive ), ): - result = await rollback_jail( - str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"] - ) + result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]) assert result.fail2ban_running is False - async def test_active_jails_zero_when_fail2ban_not_running( - self, tmp_path: Path - ) -> None: + async def test_active_jails_zero_when_fail2ban_not_running(self, tmp_path: Path) -> None: """active_jails is 0 in the response when fail2ban_running is False.""" with ( patch( @@ -3463,15 +3368,11 @@ class TestRollbackJail: AsyncMock(return_value=False), ), ): - result = await rollback_jail( - str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"] - ) + result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]) assert result.active_jails == 0 - async def test_active_jails_count_from_socket_when_running( - self, tmp_path: Path - ) -> None: + async def test_active_jails_count_from_socket_when_running(self, tmp_path: Path) -> None: """active_jails reflects the actual jail count from the socket when fail2ban is up.""" with ( patch( @@ -3487,15 +3388,11 @@ class TestRollbackJail: AsyncMock(return_value={"sshd", "nginx", "apache-auth"}), ), ): - result = await rollback_jail( - str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"] - ) + result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]) assert result.active_jails == 3 - async def test_fail2ban_down_at_start_still_succeeds_file_write( - self, tmp_path: Path - ) -> None: + async def test_fail2ban_down_at_start_still_succeeds_file_write(self, tmp_path: Path) -> None: """rollback_jail writes the local file even when fail2ban is down at call time.""" # fail2ban is down: start_daemon fails and wait_for_fail2ban returns False. with ( @@ -3508,12 +3405,9 @@ class TestRollbackJail: AsyncMock(return_value=False), ), ): - result = await rollback_jail( - str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"] - ) + result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]) local = tmp_path / "jail.d" / "sshd.local" assert local.is_file(), "local file must be written even when fail2ban is down" assert result.disabled is True assert result.fail2ban_running is False - diff --git a/backend/tests/test_services/test_geo_service.py b/backend/tests/test_services/test_geo_service.py index f400059..ceb7469 100644 --- a/backend/tests/test_services/test_geo_service.py +++ b/backend/tests/test_services/test_geo_service.py @@ -2,6 +2,7 @@ from __future__ import annotations +from collections.abc import Mapping, Sequence from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -44,7 +45,7 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM @pytest.fixture(autouse=True) -def clear_geo_cache() -> None: # type: ignore[misc] +def clear_geo_cache() -> None: """Flush the module-level geo cache before every test.""" geo_service.clear_cache() @@ -68,7 +69,7 @@ class TestLookupSuccess: "org": "AS3320 Deutsche Telekom AG", } ) - result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] + result = await geo_service.lookup("1.2.3.4", session) assert result is not None assert result.country_code == "DE" @@ -84,7 +85,7 @@ class TestLookupSuccess: "org": "Google LLC", } ) - result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] + result = await geo_service.lookup("8.8.8.8", session) assert result is not None assert result.country_name == "United States" @@ -100,7 +101,7 @@ class TestLookupSuccess: "org": "Deutsche Telekom", } ) - result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] + result = await geo_service.lookup("1.2.3.4", session) assert result is not None assert result.asn == "AS3320" @@ -116,7 +117,7 @@ class TestLookupSuccess: "org": "Google LLC", } ) - result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] + result = await geo_service.lookup("8.8.8.8", session) assert result is not None assert result.org == "Google LLC" @@ -142,8 +143,8 @@ class TestLookupCaching: } ) - await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] - await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] + await geo_service.lookup("1.2.3.4", session) + await geo_service.lookup("1.2.3.4", session) # The session.get() should only have been called once. assert session.get.call_count == 1 @@ -160,9 +161,9 @@ class TestLookupCaching: } ) - await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type] + await geo_service.lookup("2.3.4.5", session) geo_service.clear_cache() - await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type] + await geo_service.lookup("2.3.4.5", session) assert session.get.call_count == 2 @@ -172,8 +173,8 @@ class TestLookupCaching: {"status": "fail", "message": "reserved range"} ) - await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type] - await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type] + await geo_service.lookup("192.168.1.1", session) + await geo_service.lookup("192.168.1.1", session) # Second call is blocked by the negative cache — only one API hit. assert session.get.call_count == 1 @@ -190,7 +191,7 @@ class TestLookupFailures: async def test_non_200_response_returns_null_geo_info(self) -> None: """A 429 or 500 status returns GeoInfo with null fields (not None).""" session = _make_session({}, status=429) - result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] + result = await geo_service.lookup("1.2.3.4", session) assert result is not None assert isinstance(result, GeoInfo) assert result.country_code is None @@ -203,7 +204,7 @@ class TestLookupFailures: mock_ctx.__aexit__ = AsyncMock(return_value=False) session.get = MagicMock(return_value=mock_ctx) - result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] + result = await geo_service.lookup("10.0.0.1", session) assert result is not None assert isinstance(result, GeoInfo) assert result.country_code is None @@ -211,7 +212,7 @@ class TestLookupFailures: async def test_failed_status_returns_geo_info_with_nulls(self) -> None: """When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached).""" session = _make_session({"status": "fail", "message": "private range"}) - result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] + result = await geo_service.lookup("10.0.0.1", session) assert result is not None assert isinstance(result, GeoInfo) @@ -231,8 +232,8 @@ class TestNegativeCache: """After a failed lookup the second call is served from the neg cache.""" session = _make_session({"status": "fail", "message": "private range"}) - r1 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type] - r2 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type] + r1 = await geo_service.lookup("192.0.2.1", session) + r2 = await geo_service.lookup("192.0.2.1", session) # Only one HTTP call should have been made; second served from neg cache. assert session.get.call_count == 1 @@ -243,12 +244,12 @@ class TestNegativeCache: """When the neg-cache entry is older than the TTL a new API call is made.""" session = _make_session({"status": "fail", "message": "private range"}) - await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type] + await geo_service.lookup("192.0.2.2", session) # Manually expire the neg-cache entry. - geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 # type: ignore[attr-defined] + geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 - await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type] + await geo_service.lookup("192.0.2.2", session) # Both calls should have hit the API. assert session.get.call_count == 2 @@ -257,9 +258,9 @@ class TestNegativeCache: """After clearing the neg cache the IP is eligible for a new API call.""" session = _make_session({"status": "fail", "message": "private range"}) - await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type] + await geo_service.lookup("192.0.2.3", session) geo_service.clear_neg_cache() - await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type] + await geo_service.lookup("192.0.2.3", session) assert session.get.call_count == 2 @@ -275,9 +276,9 @@ class TestNegativeCache: } ) - await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] + await geo_service.lookup("1.2.3.4", session) - assert "1.2.3.4" not in geo_service._neg_cache # type: ignore[attr-defined] + assert "1.2.3.4" not in geo_service._neg_cache # --------------------------------------------------------------------------- @@ -307,7 +308,7 @@ class TestGeoipFallback: mock_reader = self._make_geoip_reader("DE", "Germany") with patch.object(geo_service, "_geoip_reader", mock_reader): - result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] + result = await geo_service.lookup("1.2.3.4", session) mock_reader.country.assert_called_once_with("1.2.3.4") assert result is not None @@ -320,12 +321,12 @@ class TestGeoipFallback: mock_reader = self._make_geoip_reader("US", "United States") with patch.object(geo_service, "_geoip_reader", mock_reader): - await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] + await geo_service.lookup("8.8.8.8", session) # Second call must be served from positive cache without hitting API. - await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] + await geo_service.lookup("8.8.8.8", session) assert session.get.call_count == 1 - assert "8.8.8.8" in geo_service._cache # type: ignore[attr-defined] + assert "8.8.8.8" in geo_service._cache async def test_geoip_fallback_not_called_on_api_success(self) -> None: """When ip-api succeeds, the geoip2 reader must not be consulted.""" @@ -341,7 +342,7 @@ class TestGeoipFallback: mock_reader = self._make_geoip_reader("XX", "Nowhere") with patch.object(geo_service, "_geoip_reader", mock_reader): - result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] + result = await geo_service.lookup("1.2.3.4", session) mock_reader.country.assert_not_called() assert result is not None @@ -352,7 +353,7 @@ class TestGeoipFallback: session = _make_session({"status": "fail", "message": "private range"}) with patch.object(geo_service, "_geoip_reader", None): - result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] + result = await geo_service.lookup("10.0.0.1", session) assert result is not None assert result.country_code is None @@ -363,7 +364,7 @@ class TestGeoipFallback: # --------------------------------------------------------------------------- -def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock: +def _make_batch_session(batch_response: Sequence[Mapping[str, object]]) -> MagicMock: """Build a mock aiohttp.ClientSession for batch POST calls. Args: @@ -412,7 +413,7 @@ class TestLookupBatchSingleCommit: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] + await geo_service.lookup_batch(ips, session, db=db) db.commit.assert_awaited_once() @@ -426,7 +427,7 @@ class TestLookupBatchSingleCommit: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] + await geo_service.lookup_batch(ips, session, db=db) db.commit.assert_awaited_once() @@ -452,13 +453,13 @@ class TestLookupBatchSingleCommit: async def test_no_commit_for_all_cached_ips(self) -> None: """When all IPs are already cached, no HTTP call and no commit occur.""" - geo_service._cache["5.5.5.5"] = GeoInfo( # type: ignore[attr-defined] + geo_service._cache["5.5.5.5"] = GeoInfo( country_code="FR", country_name="France", asn="AS1", org="ISP" ) db = _make_async_db() session = _make_batch_session([]) - result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) # type: ignore[arg-type] + result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) assert result["5.5.5.5"].country_code == "FR" db.commit.assert_not_awaited() @@ -476,26 +477,26 @@ class TestDirtySetTracking: def test_successful_resolution_adds_to_dirty(self) -> None: """Storing a GeoInfo with a country_code adds the IP to _dirty.""" info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP") - geo_service._store("1.2.3.4", info) # type: ignore[attr-defined] + geo_service._store("1.2.3.4", info) - assert "1.2.3.4" in geo_service._dirty # type: ignore[attr-defined] + assert "1.2.3.4" in geo_service._dirty def test_null_country_does_not_add_to_dirty(self) -> None: """Storing a GeoInfo with country_code=None must not pollute _dirty.""" info = GeoInfo(country_code=None, country_name=None, asn=None, org=None) - geo_service._store("10.0.0.1", info) # type: ignore[attr-defined] + geo_service._store("10.0.0.1", info) - assert "10.0.0.1" not in geo_service._dirty # type: ignore[attr-defined] + assert "10.0.0.1" not in geo_service._dirty def test_clear_cache_also_clears_dirty(self) -> None: """clear_cache() must discard any pending dirty entries.""" info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP") - geo_service._store("8.8.8.8", info) # type: ignore[attr-defined] - assert geo_service._dirty # type: ignore[attr-defined] + geo_service._store("8.8.8.8", info) + assert geo_service._dirty geo_service.clear_cache() - assert not geo_service._dirty # type: ignore[attr-defined] + assert not geo_service._dirty async def test_lookup_batch_populates_dirty(self) -> None: """After lookup_batch() with db=None, resolved IPs appear in _dirty.""" @@ -509,7 +510,7 @@ class TestDirtySetTracking: await geo_service.lookup_batch(ips, session, db=None) for ip in ips: - assert ip in geo_service._dirty # type: ignore[attr-defined] + assert ip in geo_service._dirty class TestFlushDirty: @@ -518,8 +519,8 @@ class TestFlushDirty: async def test_flush_writes_and_clears_dirty(self) -> None: """flush_dirty() inserts all dirty IPs and clears _dirty afterwards.""" info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT") - geo_service._store("100.0.0.1", info) # type: ignore[attr-defined] - assert "100.0.0.1" in geo_service._dirty # type: ignore[attr-defined] + geo_service._store("100.0.0.1", info) + assert "100.0.0.1" in geo_service._dirty db = _make_async_db() count = await geo_service.flush_dirty(db) @@ -527,7 +528,7 @@ class TestFlushDirty: assert count == 1 db.executemany.assert_awaited_once() db.commit.assert_awaited_once() - assert "100.0.0.1" not in geo_service._dirty # type: ignore[attr-defined] + assert "100.0.0.1" not in geo_service._dirty async def test_flush_returns_zero_when_nothing_dirty(self) -> None: """flush_dirty() returns 0 and makes no DB calls when _dirty is empty.""" @@ -541,7 +542,7 @@ class TestFlushDirty: async def test_flush_re_adds_to_dirty_on_db_error(self) -> None: """When the DB write fails, entries are re-added to _dirty for retry.""" info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP") - geo_service._store("200.0.0.1", info) # type: ignore[attr-defined] + geo_service._store("200.0.0.1", info) db = _make_async_db() db.executemany = AsyncMock(side_effect=OSError("disk full")) @@ -549,7 +550,7 @@ class TestFlushDirty: count = await geo_service.flush_dirty(db) assert count == 0 - assert "200.0.0.1" in geo_service._dirty # type: ignore[attr-defined] + assert "200.0.0.1" in geo_service._dirty async def test_flush_batch_and_lookup_batch_integration(self) -> None: """lookup_batch() populates _dirty; flush_dirty() then persists them.""" @@ -562,14 +563,14 @@ class TestFlushDirty: # Resolve without DB to populate only in-memory cache and _dirty. await geo_service.lookup_batch(ips, session, db=None) - assert geo_service._dirty == set(ips) # type: ignore[attr-defined] + assert geo_service._dirty == set(ips) # Now flush to the DB. db = _make_async_db() count = await geo_service.flush_dirty(db) assert count == 2 - assert not geo_service._dirty # type: ignore[attr-defined] + assert not geo_service._dirty db.commit.assert_awaited_once() @@ -585,7 +586,7 @@ class TestLookupBatchThrottling: """When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called between consecutive batch HTTP calls with at least _BATCH_DELAY.""" # Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls. - batch_size: int = geo_service._BATCH_SIZE # type: ignore[attr-defined] + batch_size: int = geo_service._BATCH_SIZE ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)] def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]: @@ -608,7 +609,7 @@ class TestLookupBatchThrottling: assert mock_batch.call_count == 2 mock_sleep.assert_awaited_once() delay_arg: float = mock_sleep.call_args[0][0] - assert delay_arg >= geo_service._BATCH_DELAY # type: ignore[attr-defined] + assert delay_arg >= geo_service._BATCH_DELAY async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None: """When a chunk returns all-None on first try, it retries and succeeds.""" @@ -650,7 +651,7 @@ class TestLookupBatchThrottling: _empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None) _failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty) - max_retries: int = geo_service._BATCH_MAX_RETRIES # type: ignore[attr-defined] + max_retries: int = geo_service._BATCH_MAX_RETRIES with ( patch( @@ -667,11 +668,11 @@ class TestLookupBatchThrottling: # IP should have no country. assert result["9.9.9.9"].country_code is None # Negative cache should contain the IP. - assert "9.9.9.9" in geo_service._neg_cache # type: ignore[attr-defined] + assert "9.9.9.9" in geo_service._neg_cache # Sleep called for each retry with exponential backoff. assert mock_sleep.call_count == max_retries backoff_values = [call.args[0] for call in mock_sleep.call_args_list] - batch_delay: float = geo_service._BATCH_DELAY # type: ignore[attr-defined] + batch_delay: float = geo_service._BATCH_DELAY for i, val in enumerate(backoff_values): expected = batch_delay * (2 ** (i + 1)) assert val == pytest.approx(expected) @@ -709,7 +710,7 @@ class TestErrorLogging: import structlog.testing with structlog.testing.capture_logs() as captured: - result = await geo_service.lookup("197.221.98.153", session) # type: ignore[arg-type] + result = await geo_service.lookup("197.221.98.153", session) assert result is not None assert result.country_code is None @@ -733,7 +734,7 @@ class TestErrorLogging: import structlog.testing with structlog.testing.capture_logs() as captured: - await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] + await geo_service.lookup("10.0.0.1", session) request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"] assert len(request_failed) == 1 @@ -757,7 +758,7 @@ class TestErrorLogging: import structlog.testing with structlog.testing.capture_logs() as captured: - result = await geo_service._batch_api_call(["1.2.3.4"], session) # type: ignore[attr-defined] + result = await geo_service._batch_api_call(["1.2.3.4"], session) assert result["1.2.3.4"].country_code is None @@ -778,7 +779,7 @@ class TestLookupCachedOnly: def test_returns_cached_ips(self) -> None: """IPs already in the cache are returned in the geo_map.""" - geo_service._cache["1.1.1.1"] = GeoInfo( # type: ignore[attr-defined] + geo_service._cache["1.1.1.1"] = GeoInfo( country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare" ) geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"]) @@ -798,7 +799,7 @@ class TestLookupCachedOnly: """IPs in the negative cache within TTL are not re-queued as uncached.""" import time - geo_service._neg_cache["10.0.0.1"] = time.monotonic() # type: ignore[attr-defined] + geo_service._neg_cache["10.0.0.1"] = time.monotonic() geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"]) @@ -807,7 +808,7 @@ class TestLookupCachedOnly: def test_expired_neg_cache_requeued(self) -> None: """IPs whose neg-cache entry has expired are listed as uncached.""" - geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired # type: ignore[attr-defined] + geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired _geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"]) @@ -815,12 +816,12 @@ class TestLookupCachedOnly: def test_mixed_ips(self) -> None: """A mix of cached, neg-cached, and unknown IPs is split correctly.""" - geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined] + geo_service._cache["1.2.3.4"] = GeoInfo( country_code="DE", country_name="Germany", asn=None, org=None ) import time - geo_service._neg_cache["5.5.5.5"] = time.monotonic() # type: ignore[attr-defined] + geo_service._neg_cache["5.5.5.5"] = time.monotonic() geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"]) @@ -829,7 +830,7 @@ class TestLookupCachedOnly: def test_deduplication(self) -> None: """Duplicate IPs in the input appear at most once in the output.""" - geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined] + geo_service._cache["1.2.3.4"] = GeoInfo( country_code="US", country_name="United States", asn=None, org=None ) @@ -866,7 +867,7 @@ class TestLookupBatchBulkWrites: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] + await geo_service.lookup_batch(ips, session, db=db) # One executemany for the positive rows. assert db.executemany.await_count >= 1 @@ -883,7 +884,7 @@ class TestLookupBatchBulkWrites: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] + await geo_service.lookup_batch(ips, session, db=db) assert db.executemany.await_count >= 1 db.execute.assert_not_awaited() @@ -905,7 +906,7 @@ class TestLookupBatchBulkWrites: session = _make_batch_session(batch_response) db = _make_async_db() - await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] + await geo_service.lookup_batch(ips, session, db=db) # One executemany for positives, one for negatives. assert db.executemany.await_count == 2 diff --git a/backend/tests/test_services/test_history_service.py b/backend/tests/test_services/test_history_service.py index 425fbc0..508a69f 100644 --- a/backend/tests/test_services/test_history_service.py +++ b/backend/tests/test_services/test_history_service.py @@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None: @pytest.fixture -async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc] +async def f2b_db_path(tmp_path: Path) -> str: """Return the path to a test fail2ban SQLite database.""" path = str(tmp_path / "fail2ban_test.sqlite3") await _create_f2b_db( diff --git a/backend/tests/test_services/test_jail_service.py b/backend/tests/test_services/test_jail_service.py index 4afb718..0824332 100644 --- a/backend/tests/test_services/test_jail_service.py +++ b/backend/tests/test_services/test_jail_service.py @@ -996,9 +996,6 @@ class TestGetJailBannedIps: async def test_unknown_jail_raises_jail_not_found_error(self) -> None: """get_jail_banned_ips raises JailNotFoundError for unknown jail.""" - responses = { - "status|ghost|short": (0, pytest.raises), # will be overridden - } # Simulate fail2ban returning an "unknown jail" error. class _FakeClient: def __init__(self, **_kw: Any) -> None: diff --git a/backend/tests/test_tasks/test_health_check.py b/backend/tests/test_tasks/test_health_check.py index 4a8512b..0af33f1 100644 --- a/backend/tests/test_tasks/test_health_check.py +++ b/backend/tests/test_tasks/test_health_check.py @@ -270,7 +270,7 @@ class TestCrashDetection: async def test_crash_within_window_creates_pending_recovery(self) -> None: """An online→offline transition within 60 s of activation must set pending_recovery.""" app = _make_app(prev_online=True) - now = datetime.datetime.now(tz=datetime.timezone.utc) + now = datetime.datetime.now(tz=datetime.UTC) app.state.last_activation = { "jail_name": "sshd", "at": now - datetime.timedelta(seconds=10), @@ -297,7 +297,7 @@ class TestCrashDetection: app = _make_app(prev_online=True) app.state.last_activation = { "jail_name": "sshd", - "at": datetime.datetime.now(tz=datetime.timezone.utc) + "at": datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=120), } app.state.pending_recovery = None @@ -315,8 +315,8 @@ class TestCrashDetection: async def test_came_online_marks_pending_recovery_resolved(self) -> None: """An offline→online transition must mark an existing pending_recovery as recovered.""" app = _make_app(prev_online=False) - activated_at = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(seconds=30) - detected_at = datetime.datetime.now(tz=datetime.timezone.utc) + activated_at = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=30) + detected_at = datetime.datetime.now(tz=datetime.UTC) app.state.pending_recovery = PendingRecovery( jail_name="sshd", activated_at=activated_at,