refactor: improve backend type safety and import organization

- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite)
- Reorganize imports to follow PEP 8 conventions
- Convert TypeAlias to modern PEP 695 type syntax (where appropriate)
- Use Sequence/Mapping from collections.abc for type hints (covariant)
- Replace string literals with cast() for improved type inference
- Fix casting of Fail2BanResponse and TypedDict patterns
- Add IpLookupResult TypedDict for precise return type annotation
- Reformat overlong lines for readability (120 char limit)
- Add asyncio_mode and filterwarnings to pytest config
- Update test fixtures with improved type hints

This improves mypy type checking and makes type relationships explicit.
This commit is contained in:
2026-03-20 13:44:14 +01:00
parent bdcdd5d672
commit 1c0bac1353
30 changed files with 431 additions and 644 deletions

View File

@@ -85,4 +85,4 @@ def get_settings() -> Settings:
A validated :class:`Settings` object. Raises :class:`pydantic.ValidationError` A validated :class:`Settings` object. Raises :class:`pydantic.ValidationError`
if required keys are absent or values fail validation. if required keys are absent or values fail validation.
""" """
return Settings() # pydantic-settings populates required fields from env vars return Settings() # type: ignore[call-arg] # pydantic-settings populates required fields from env vars

View File

@@ -92,7 +92,7 @@ async def get_settings(request: Request) -> Settings:
Returns: Returns:
The application settings loaded at startup. The application settings loaded at startup.
""" """
state = cast(AppState, request.app.state) state = cast("AppState", request.app.state)
return state.settings return state.settings

View File

@@ -12,6 +12,8 @@ from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Sequence
import aiosqlite import aiosqlite
@@ -112,7 +114,7 @@ async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
async def bulk_upsert_entries( async def bulk_upsert_entries(
db: aiosqlite.Connection, db: aiosqlite.Connection,
rows: list[tuple[str, str | None, str | None, str | None, str | None]], rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
) -> int: ) -> int:
"""Bulk insert or update multiple geo cache entries.""" """Bulk insert or update multiple geo cache entries."""
if not rows: if not rows:

View File

@@ -8,10 +8,11 @@ table. All methods are plain async functions that accept a
from __future__ import annotations from __future__ import annotations
import math import math
from collections.abc import Mapping
from typing import TYPE_CHECKING, TypedDict, cast from typing import TYPE_CHECKING, TypedDict, cast
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Mapping
import aiosqlite import aiosqlite
@@ -165,5 +166,5 @@ def _row_to_dict(row: object) -> ImportLogRow:
Returns: Returns:
Dict mapping column names to Python values. Dict mapping column names to Python values.
""" """
mapping = cast(Mapping[str, object], row) mapping = cast("Mapping[str, object]", row)
return cast(ImportLogRow, dict(mapping)) return cast("ImportLogRow", dict(mapping))

View File

@@ -44,8 +44,6 @@ import structlog
from fastapi import APIRouter, HTTPException, Path, Query, Request, status from fastapi import APIRouter, HTTPException, Path, Query, Request, status
from app.dependencies import AuthDep from app.dependencies import AuthDep
log: structlog.stdlib.BoundLogger = structlog.get_logger()
from app.models.config import ( from app.models.config import (
ActionConfig, ActionConfig,
ActionCreateRequest, ActionCreateRequest,
@@ -104,6 +102,8 @@ from app.services.jail_service import JailOperationError
from app.tasks.health_check import _run_probe from app.tasks.health_check import _run_probe
from app.utils.fail2ban_client import Fail2BanConnectionError from app.utils.fail2ban_client import Fail2BanConnectionError
log: structlog.stdlib.BoundLogger = structlog.get_logger()
router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"]) router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"])
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -428,9 +428,7 @@ async def restart_fail2ban(
await config_file_service.start_daemon(start_cmd_parts) await config_file_service.start_daemon(start_cmd_parts)
# Step 3: probe the socket until fail2ban is responsive or the budget expires. # Step 3: probe the socket until fail2ban is responsive or the budget expires.
fail2ban_running: bool = await config_file_service.wait_for_fail2ban( fail2ban_running: bool = await config_file_service.wait_for_fail2ban(socket_path, max_wait_seconds=10.0)
socket_path, max_wait_seconds=10.0
)
if not fail2ban_running: if not fail2ban_running:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE, status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
@@ -604,9 +602,7 @@ async def get_map_color_thresholds(
""" """
from app.services import setup_service from app.services import setup_service
high, medium, low = await setup_service.get_map_color_thresholds( high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db)
request.app.state.db
)
return MapColorThresholdsResponse( return MapColorThresholdsResponse(
threshold_high=high, threshold_high=high,
threshold_medium=medium, threshold_medium=medium,
@@ -696,9 +692,7 @@ async def activate_jail(
req = body if body is not None else ActivateJailRequest() req = body if body is not None else ActivateJailRequest()
try: try:
result = await config_file_service.activate_jail( result = await config_file_service.activate_jail(config_dir, socket_path, name, req)
config_dir, socket_path, name, req
)
except JailNameError as exc: except JailNameError as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except JailNotFoundInConfigError: except JailNotFoundInConfigError:
@@ -831,9 +825,7 @@ async def delete_jail_local_override(
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
await config_file_service.delete_jail_local_override( await config_file_service.delete_jail_local_override(config_dir, socket_path, name)
config_dir, socket_path, name
)
except JailNameError as exc: except JailNameError as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except JailNotFoundInConfigError: except JailNotFoundInConfigError:
@@ -952,9 +944,7 @@ async def rollback_jail(
start_cmd_parts: list[str] = start_cmd.split() start_cmd_parts: list[str] = start_cmd.split()
try: try:
result = await config_file_service.rollback_jail( result = await config_file_service.rollback_jail(config_dir, socket_path, name, start_cmd_parts)
config_dir, socket_path, name, start_cmd_parts
)
except JailNameError as exc: except JailNameError as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except ConfigWriteError as exc: except ConfigWriteError as exc:
@@ -1107,9 +1097,7 @@ async def update_filter(
config_dir: str = request.app.state.settings.fail2ban_config_dir config_dir: str = request.app.state.settings.fail2ban_config_dir
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
return await config_file_service.update_filter( return await config_file_service.update_filter(config_dir, socket_path, name, body, do_reload=reload)
config_dir, socket_path, name, body, do_reload=reload
)
except FilterNameError as exc: except FilterNameError as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except FilterNotFoundError: except FilterNotFoundError:
@@ -1159,9 +1147,7 @@ async def create_filter(
config_dir: str = request.app.state.settings.fail2ban_config_dir config_dir: str = request.app.state.settings.fail2ban_config_dir
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
return await config_file_service.create_filter( return await config_file_service.create_filter(config_dir, socket_path, body, do_reload=reload)
config_dir, socket_path, body, do_reload=reload
)
except FilterNameError as exc: except FilterNameError as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except FilterAlreadyExistsError as exc: except FilterAlreadyExistsError as exc:
@@ -1257,9 +1243,7 @@ async def assign_filter_to_jail(
config_dir: str = request.app.state.settings.fail2ban_config_dir config_dir: str = request.app.state.settings.fail2ban_config_dir
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
await config_file_service.assign_filter_to_jail( await config_file_service.assign_filter_to_jail(config_dir, socket_path, name, body, do_reload=reload)
config_dir, socket_path, name, body, do_reload=reload
)
except (JailNameError, FilterNameError) as exc: except (JailNameError, FilterNameError) as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except JailNotFoundInConfigError: except JailNotFoundInConfigError:
@@ -1403,9 +1387,7 @@ async def update_action(
config_dir: str = request.app.state.settings.fail2ban_config_dir config_dir: str = request.app.state.settings.fail2ban_config_dir
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
return await config_file_service.update_action( return await config_file_service.update_action(config_dir, socket_path, name, body, do_reload=reload)
config_dir, socket_path, name, body, do_reload=reload
)
except ActionNameError as exc: except ActionNameError as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except ActionNotFoundError: except ActionNotFoundError:
@@ -1451,9 +1433,7 @@ async def create_action(
config_dir: str = request.app.state.settings.fail2ban_config_dir config_dir: str = request.app.state.settings.fail2ban_config_dir
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
return await config_file_service.create_action( return await config_file_service.create_action(config_dir, socket_path, body, do_reload=reload)
config_dir, socket_path, body, do_reload=reload
)
except ActionNameError as exc: except ActionNameError as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except ActionAlreadyExistsError as exc: except ActionAlreadyExistsError as exc:
@@ -1546,9 +1526,7 @@ async def assign_action_to_jail(
config_dir: str = request.app.state.settings.fail2ban_config_dir config_dir: str = request.app.state.settings.fail2ban_config_dir
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
await config_file_service.assign_action_to_jail( await config_file_service.assign_action_to_jail(config_dir, socket_path, name, body, do_reload=reload)
config_dir, socket_path, name, body, do_reload=reload
)
except (JailNameError, ActionNameError) as exc: except (JailNameError, ActionNameError) as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except JailNotFoundInConfigError: except JailNotFoundInConfigError:
@@ -1597,9 +1575,7 @@ async def remove_action_from_jail(
config_dir: str = request.app.state.settings.fail2ban_config_dir config_dir: str = request.app.state.settings.fail2ban_config_dir
socket_path: str = request.app.state.settings.fail2ban_socket socket_path: str = request.app.state.settings.fail2ban_socket
try: try:
await config_file_service.remove_action_from_jail( await config_file_service.remove_action_from_jail(config_dir, socket_path, name, action_name, do_reload=reload)
config_dir, socket_path, name, action_name, do_reload=reload
)
except (JailNameError, ActionNameError) as exc: except (JailNameError, ActionNameError) as exc:
raise _bad_request(str(exc)) from exc raise _bad_request(str(exc)) from exc
except JailNotFoundInConfigError: except JailNotFoundInConfigError:
@@ -1689,4 +1665,3 @@ async def get_service_status(
return await config_service.get_service_status(socket_path) return await config_service.get_service_status(socket_path)
except Fail2BanConnectionError as exc: except Fail2BanConnectionError as exc:
raise _bad_gateway(exc) from exc raise _bad_gateway(exc) from exc

View File

@@ -13,12 +13,15 @@ from typing import TYPE_CHECKING, Annotated
if TYPE_CHECKING: if TYPE_CHECKING:
import aiohttp import aiohttp
from app.services.jail_service import IpLookupResult
import aiosqlite import aiosqlite
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
from app.dependencies import AuthDep, get_db from app.dependencies import AuthDep, get_db
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
from app.services import geo_service, jail_service from app.services import geo_service, jail_service
from app.services.geo_service import GeoInfo
from app.utils.fail2ban_client import Fail2BanConnectionError from app.utils.fail2ban_client import Fail2BanConnectionError
router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"]) router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"])
@@ -61,7 +64,7 @@ async def lookup_ip(
return await geo_service.lookup(addr, http_session) return await geo_service.lookup(addr, http_session)
try: try:
result = await jail_service.lookup_ip( result: IpLookupResult = await jail_service.lookup_ip(
socket_path, socket_path,
ip, ip,
geo_enricher=_enricher, geo_enricher=_enricher,
@@ -77,9 +80,9 @@ async def lookup_ip(
detail=f"Cannot reach fail2ban: {exc}", detail=f"Cannot reach fail2ban: {exc}",
) from exc ) from exc
raw_geo = result.get("geo") raw_geo = result["geo"]
geo_detail: GeoDetail | None = None geo_detail: GeoDetail | None = None
if raw_geo is not None: if isinstance(raw_geo, GeoInfo):
geo_detail = GeoDetail( geo_detail = GeoDetail(
country_code=raw_geo.country_code, country_code=raw_geo.country_code,
country_name=raw_geo.country_name, country_name=raw_geo.country_name,

View File

@@ -14,17 +14,11 @@ import asyncio
import json import json
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from dataclasses import asdict
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, TypeAlias from typing import TYPE_CHECKING, cast
import structlog import structlog
if TYPE_CHECKING:
import aiosqlite
from app.services.geo_service import GeoInfo
from app.models.ban import ( from app.models.ban import (
BLOCKLIST_JAIL, BLOCKLIST_JAIL,
BUCKET_SECONDS, BUCKET_SECONDS,
@@ -37,20 +31,25 @@ from app.models.ban import (
BanTrendResponse, BanTrendResponse,
DashboardBanItem, DashboardBanItem,
DashboardBanListResponse, DashboardBanListResponse,
JailBanCount as JailBanCountModel,
TimeRange, TimeRange,
_derive_origin, _derive_origin,
bucket_count, bucket_count,
) )
from app.models.ban import (
JailBanCount as JailBanCountModel,
)
from app.repositories import fail2ban_db_repo from app.repositories import fail2ban_db_repo
from app.utils.fail2ban_client import Fail2BanClient from app.utils.fail2ban_client import Fail2BanClient, Fail2BanResponse
if TYPE_CHECKING: if TYPE_CHECKING:
import aiohttp import aiohttp
import aiosqlite
from app.services.geo_service import GeoInfo
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo"] | None] type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Constants # Constants
@@ -137,7 +136,7 @@ async def _get_fail2ban_db_path(socket_path: str) -> str:
response = await client.send(["get", "dbfile"]) response = await client.send(["get", "dbfile"])
try: try:
code, data = response code, data = cast("Fail2BanResponse", response)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
@@ -276,7 +275,7 @@ async def list_bans(
# Batch-resolve geo data for all IPs on this page in a single API call. # Batch-resolve geo data for all IPs on this page in a single API call.
# This avoids hitting the 45 req/min single-IP rate limit when the # This avoids hitting the 45 req/min single-IP rate limit when the
# page contains many bans (e.g. after a large blocklist import). # page contains many bans (e.g. after a large blocklist import).
geo_map: dict[str, "GeoInfo"] = {} geo_map: dict[str, GeoInfo] = {}
if http_session is not None and rows: if http_session is not None and rows:
page_ips: list[str] = [r.ip for r in rows] page_ips: list[str] = [r.ip for r in rows]
try: try:
@@ -428,7 +427,7 @@ async def bans_by_country(
) )
unique_ips: list[str] = [r.ip for r in agg_rows] unique_ips: list[str] = [r.ip for r in agg_rows]
geo_map: dict[str, "GeoInfo"] = {} geo_map: dict[str, GeoInfo] = {}
if http_session is not None and unique_ips: if http_session is not None and unique_ips:
# Serve only what is already in the in-memory cache — no API calls on # Serve only what is already in the in-memory cache — no API calls on
@@ -449,7 +448,7 @@ async def bans_by_country(
) )
elif geo_enricher is not None and unique_ips: elif geo_enricher is not None and unique_ips:
# Fallback: legacy per-IP enricher (used in tests / older callers). # Fallback: legacy per-IP enricher (used in tests / older callers).
async def _safe_lookup(ip: str) -> tuple[str, "GeoInfo" | None]: async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
try: try:
return ip, await geo_enricher(ip) return ip, await geo_enricher(ip)
except Exception: # noqa: BLE001 except Exception: # noqa: BLE001
@@ -636,9 +635,7 @@ async def bans_by_jail(
# has *any* rows and log a warning with min/max timeofban so operators can # has *any* rows and log a warning with min/max timeofban so operators can
# diagnose timezone or filter mismatches from logs. # diagnose timezone or filter mismatches from logs.
if total == 0: if total == 0:
table_row_count, min_timeofban, max_timeofban = ( table_row_count, min_timeofban, max_timeofban = await fail2ban_db_repo.get_bans_table_summary(db_path)
await fail2ban_db_repo.get_bans_table_summary(db_path)
)
if table_row_count > 0: if table_row_count > 0:
log.warning( log.warning(
"ban_service_bans_by_jail_empty_despite_data", "ban_service_bans_by_jail_empty_despite_data",

View File

@@ -542,7 +542,7 @@ async def list_import_logs(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _aiohttp_timeout(seconds: float) -> "aiohttp.ClientTimeout": def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout:
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout. """Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
Args: Args:

View File

@@ -28,7 +28,7 @@ import os
import re import re
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, cast, TypeAlias from typing import cast
import structlog import structlog
@@ -59,7 +59,6 @@ from app.services.jail_service import JailNotFoundError as JailNotFoundError
from app.utils import conffile_parser from app.utils import conffile_parser
from app.utils.fail2ban_client import ( from app.utils.fail2ban_client import (
Fail2BanClient, Fail2BanClient,
Fail2BanCommand,
Fail2BanConnectionError, Fail2BanConnectionError,
Fail2BanResponse, Fail2BanResponse,
) )
@@ -73,9 +72,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
_SOCKET_TIMEOUT: float = 10.0 _SOCKET_TIMEOUT: float = 10.0
# Allowlist pattern for jail names used in path construction. # Allowlist pattern for jail names used in path construction.
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile( _SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
)
# Sections that are not jail definitions. # Sections that are not jail definitions.
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"}) _META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
@@ -167,8 +164,7 @@ class FilterReadonlyError(Exception):
""" """
self.name: str = name self.name: str = name
super().__init__( super().__init__(
f"Filter {name!r} is a shipped default (.conf only); " f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
"only user-created .local files can be deleted."
) )
@@ -423,9 +419,7 @@ def _parse_jails_sync(
# items() merges DEFAULT values automatically. # items() merges DEFAULT values automatically.
jails[section] = dict(parser.items(section)) jails[section] = dict(parser.items(section))
except configparser.Error as exc: except configparser.Error as exc:
log.warning( log.warning("jail_section_parse_error", section=section, error=str(exc))
"jail_section_parse_error", section=section, error=str(exc)
)
log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir)) log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
return jails, source_files return jails, source_files
@@ -522,11 +516,7 @@ def _build_inactive_jail(
bantime_escalation=bantime_escalation, bantime_escalation=bantime_escalation,
source_file=source_file, source_file=source_file,
enabled=enabled, enabled=enabled,
has_local_override=( has_local_override=((config_dir / "jail.d" / f"{name}.local").is_file() if config_dir is not None else False),
(config_dir / "jail.d" / f"{name}.local").is_file()
if config_dir is not None
else False
),
) )
@@ -557,7 +547,7 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
return result return result
def _ok(response: object) -> object: def _ok(response: object) -> object:
code, data = cast(Fail2BanResponse, response) code, data = cast("Fail2BanResponse", response)
if code != 0: if code != 0:
raise ValueError(f"fail2ban error {code}: {data!r}") raise ValueError(f"fail2ban error {code}: {data!r}")
return data return data
@@ -572,9 +562,7 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
log.warning("fail2ban_unreachable_during_inactive_list") log.warning("fail2ban_unreachable_during_inactive_list")
return set() return set()
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning("fail2ban_status_error_during_inactive_list", error=str(exc))
"fail2ban_status_error_during_inactive_list", error=str(exc)
)
return set() return set()
@@ -662,10 +650,7 @@ def _validate_jail_config_sync(
issues.append( issues.append(
JailValidationIssue( JailValidationIssue(
field="filter", field="filter",
message=( message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"),
f"Filter file not found: filter.d/{base_filter}.conf"
" (or .local)"
),
) )
) )
@@ -681,10 +666,7 @@ def _validate_jail_config_sync(
issues.append( issues.append(
JailValidationIssue( JailValidationIssue(
field="action", field="action",
message=( message=(f"Action file not found: action.d/{action_name}.conf (or .local)"),
f"Action file not found: action.d/{action_name}.conf"
" (or .local)"
),
) )
) )
@@ -840,9 +822,7 @@ def _write_local_override_sync(
try: try:
jail_d.mkdir(parents=True, exist_ok=True) jail_d.mkdir(parents=True, exist_ok=True)
except OSError as exc: except OSError as exc:
raise ConfigWriteError( raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
f"Cannot create jail.d directory: {exc}"
) from exc
local_path = jail_d / f"{jail_name}.local" local_path = jail_d / f"{jail_name}.local"
@@ -867,7 +847,7 @@ def _write_local_override_sync(
if overrides.get("port") is not None: if overrides.get("port") is not None:
lines.append(f"port = {overrides['port']}") lines.append(f"port = {overrides['port']}")
if overrides.get("logpath"): if overrides.get("logpath"):
paths: list[str] = cast(list[str], overrides["logpath"]) paths: list[str] = cast("list[str]", overrides["logpath"])
if paths: if paths:
lines.append(f"logpath = {paths[0]}") lines.append(f"logpath = {paths[0]}")
for p in paths[1:]: for p in paths[1:]:
@@ -890,9 +870,7 @@ def _write_local_override_sync(
# Clean up temp file if rename failed. # Clean up temp file if rename failed.
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
raise ConfigWriteError( raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
f"Failed to write {local_path}: {exc}"
) from exc
log.info( log.info(
"jail_local_written", "jail_local_written",
@@ -921,9 +899,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
try: try:
local_path.unlink(missing_ok=True) local_path.unlink(missing_ok=True)
except OSError as exc: except OSError as exc:
raise ConfigWriteError( raise ConfigWriteError(f"Failed to delete {local_path} during rollback: {exc}") from exc
f"Failed to delete {local_path} during rollback: {exc}"
) from exc
return return
tmp_name: str | None = None tmp_name: str | None = None
@@ -941,9 +917,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
if tmp_name is not None: if tmp_name is not None:
os.unlink(tmp_name) os.unlink(tmp_name)
raise ConfigWriteError( raise ConfigWriteError(f"Failed to restore {local_path} during rollback: {exc}") from exc
f"Failed to restore {local_path} during rollback: {exc}"
) from exc
def _validate_regex_patterns(patterns: list[str]) -> None: def _validate_regex_patterns(patterns: list[str]) -> None:
@@ -979,9 +953,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
try: try:
filter_d.mkdir(parents=True, exist_ok=True) filter_d.mkdir(parents=True, exist_ok=True)
except OSError as exc: except OSError as exc:
raise ConfigWriteError( raise ConfigWriteError(f"Cannot create filter.d directory: {exc}") from exc
f"Cannot create filter.d directory: {exc}"
) from exc
local_path = filter_d / f"{name}.local" local_path = filter_d / f"{name}.local"
try: try:
@@ -998,9 +970,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
except OSError as exc: except OSError as exc:
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
os.unlink(tmp_name) # noqa: F821 os.unlink(tmp_name) # noqa: F821
raise ConfigWriteError( raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
f"Failed to write {local_path}: {exc}"
) from exc
log.info("filter_local_written", filter=name, path=str(local_path)) log.info("filter_local_written", filter=name, path=str(local_path))
@@ -1031,9 +1001,7 @@ def _set_jail_local_key_sync(
try: try:
jail_d.mkdir(parents=True, exist_ok=True) jail_d.mkdir(parents=True, exist_ok=True)
except OSError as exc: except OSError as exc:
raise ConfigWriteError( raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
f"Cannot create jail.d directory: {exc}"
) from exc
local_path = jail_d / f"{jail_name}.local" local_path = jail_d / f"{jail_name}.local"
@@ -1072,9 +1040,7 @@ def _set_jail_local_key_sync(
except OSError as exc: except OSError as exc:
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
os.unlink(tmp_name) # noqa: F821 os.unlink(tmp_name) # noqa: F821
raise ConfigWriteError( raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
f"Failed to write {local_path}: {exc}"
) from exc
log.info( log.info(
"jail_local_key_set", "jail_local_key_set",
@@ -1112,8 +1078,8 @@ async def list_inactive_jails(
inactive jails. inactive jails.
""" """
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = ( parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = await loop.run_in_executor(
await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)) None, _parse_jails_sync, Path(config_dir)
) )
all_jails, source_files = parsed_result all_jails, source_files = parsed_result
active_names: set[str] = await _get_active_jail_names(socket_path) active_names: set[str] = await _get_active_jail_names(socket_path)
@@ -1170,9 +1136,7 @@ async def activate_jail(
_safe_jail_name(name) _safe_jail_name(name)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
all_jails, _source_files = await loop.run_in_executor( all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
None, _parse_jails_sync, Path(config_dir)
)
if name not in all_jails: if name not in all_jails:
raise JailNotFoundInConfigError(name) raise JailNotFoundInConfigError(name)
@@ -1208,10 +1172,7 @@ async def activate_jail(
active=False, active=False,
fail2ban_running=True, fail2ban_running=True,
validation_warnings=warnings, validation_warnings=warnings,
message=( message=(f"Jail {name!r} cannot be activated: " + "; ".join(i.message for i in blocking)),
f"Jail {name!r} cannot be activated: "
+ "; ".join(i.message for i in blocking)
),
) )
overrides: dict[str, object] = { overrides: dict[str, object] = {
@@ -1254,9 +1215,7 @@ async def activate_jail(
jail=name, jail=name,
error=str(exc), error=str(exc),
) )
recovered = await _rollback_activation_async( recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
config_dir, name, socket_path, original_content
)
return JailActivationResponse( return JailActivationResponse(
name=name, name=name,
active=False, active=False,
@@ -1272,9 +1231,7 @@ async def activate_jail(
) )
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning("reload_after_activate_failed", jail=name, error=str(exc)) log.warning("reload_after_activate_failed", jail=name, error=str(exc))
recovered = await _rollback_activation_async( recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
config_dir, name, socket_path, original_content
)
return JailActivationResponse( return JailActivationResponse(
name=name, name=name,
active=False, active=False,
@@ -1305,9 +1262,7 @@ async def activate_jail(
jail=name, jail=name,
message="fail2ban socket unreachable after reload — initiating rollback.", message="fail2ban socket unreachable after reload — initiating rollback.",
) )
recovered = await _rollback_activation_async( recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
config_dir, name, socket_path, original_content
)
return JailActivationResponse( return JailActivationResponse(
name=name, name=name,
active=False, active=False,
@@ -1330,9 +1285,7 @@ async def activate_jail(
jail=name, jail=name,
message="Jail did not appear in running jails — initiating rollback.", message="Jail did not appear in running jails — initiating rollback.",
) )
recovered = await _rollback_activation_async( recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
config_dir, name, socket_path, original_content
)
return JailActivationResponse( return JailActivationResponse(
name=name, name=name,
active=False, active=False,
@@ -1388,14 +1341,10 @@ async def _rollback_activation_async(
# Step 1 — restore original file (or delete it). # Step 1 — restore original file (or delete it).
try: try:
await loop.run_in_executor( await loop.run_in_executor(None, _restore_local_file_sync, local_path, original_content)
None, _restore_local_file_sync, local_path, original_content
)
log.info("jail_activation_rollback_file_restored", jail=name) log.info("jail_activation_rollback_file_restored", jail=name)
except ConfigWriteError as exc: except ConfigWriteError as exc:
log.error( log.error("jail_activation_rollback_restore_failed", jail=name, error=str(exc))
"jail_activation_rollback_restore_failed", jail=name, error=str(exc)
)
return False return False
# Step 2 — reload fail2ban with the restored config. # Step 2 — reload fail2ban with the restored config.
@@ -1403,9 +1352,7 @@ async def _rollback_activation_async(
await jail_service.reload_all(socket_path) await jail_service.reload_all(socket_path)
log.info("jail_activation_rollback_reload_ok", jail=name) log.info("jail_activation_rollback_reload_ok", jail=name)
except Exception as exc: # noqa: BLE001 except Exception as exc: # noqa: BLE001
log.warning( log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
"jail_activation_rollback_reload_failed", jail=name, error=str(exc)
)
return False return False
# Step 3 — wait for fail2ban to come back. # Step 3 — wait for fail2ban to come back.
@@ -1450,9 +1397,7 @@ async def deactivate_jail(
_safe_jail_name(name) _safe_jail_name(name)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
all_jails, _source_files = await loop.run_in_executor( all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
None, _parse_jails_sync, Path(config_dir)
)
if name not in all_jails: if name not in all_jails:
raise JailNotFoundInConfigError(name) raise JailNotFoundInConfigError(name)
@@ -1510,9 +1455,7 @@ async def delete_jail_local_override(
_safe_jail_name(name) _safe_jail_name(name)
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
all_jails, _source_files = await loop.run_in_executor( all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
None, _parse_jails_sync, Path(config_dir)
)
if name not in all_jails: if name not in all_jails:
raise JailNotFoundInConfigError(name) raise JailNotFoundInConfigError(name)
@@ -1523,13 +1466,9 @@ async def delete_jail_local_override(
local_path = Path(config_dir) / "jail.d" / f"{name}.local" local_path = Path(config_dir) / "jail.d" / f"{name}.local"
try: try:
await loop.run_in_executor( await loop.run_in_executor(None, lambda: local_path.unlink(missing_ok=True))
None, lambda: local_path.unlink(missing_ok=True)
)
except OSError as exc: except OSError as exc:
raise ConfigWriteError( raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
f"Failed to delete {local_path}: {exc}"
) from exc
log.info("jail_local_override_deleted", jail=name, path=str(local_path)) log.info("jail_local_override_deleted", jail=name, path=str(local_path))
@@ -1610,9 +1549,7 @@ async def rollback_jail(
log.info("jail_rollback_start_attempted", jail=name, start_ok=started) log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
# Wait for the socket to come back. # Wait for the socket to come back.
fail2ban_running = await wait_for_fail2ban( fail2ban_running = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0, poll_interval=2.0)
socket_path, max_wait_seconds=10.0, poll_interval=2.0
)
active_jails = 0 active_jails = 0
if fail2ban_running: if fail2ban_running:
@@ -1626,10 +1563,7 @@ async def rollback_jail(
disabled=True, disabled=True,
fail2ban_running=True, fail2ban_running=True,
active_jails=active_jails, active_jails=active_jails,
message=( message=(f"Jail {name!r} disabled and fail2ban restarted successfully with {active_jails} active jail(s)."),
f"Jail {name!r} disabled and fail2ban restarted successfully "
f"with {active_jails} active jail(s)."
),
) )
log.warning("jail_rollback_fail2ban_still_down", jail=name) log.warning("jail_rollback_fail2ban_still_down", jail=name)
@@ -1650,9 +1584,7 @@ async def rollback_jail(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Allowlist pattern for filter names used in path construction. # Allowlist pattern for filter names used in path construction.
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile( _SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
)
class FilterNotFoundError(Exception): class FilterNotFoundError(Exception):
@@ -1764,9 +1696,7 @@ def _parse_filters_sync(
try: try:
content = conf_path.read_text(encoding="utf-8") content = conf_path.read_text(encoding="utf-8")
except OSError as exc: except OSError as exc:
log.warning( log.warning("filter_read_error", name=name, path=str(conf_path), error=str(exc))
"filter_read_error", name=name, path=str(conf_path), error=str(exc)
)
continue continue
if has_local: if has_local:
@@ -1842,9 +1772,7 @@ async def list_filters(
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# Run the synchronous scan in a thread-pool executor. # Run the synchronous scan in a thread-pool executor.
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor( raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_filters_sync, filter_d)
None, _parse_filters_sync, filter_d
)
# Fetch active jail names and their configs concurrently. # Fetch active jail names and their configs concurrently.
all_jails_result, active_names = await asyncio.gather( all_jails_result, active_names = await asyncio.gather(
@@ -1857,9 +1785,7 @@ async def list_filters(
filters: list[FilterConfig] = [] filters: list[FilterConfig] = []
for name, filename, content, has_local, source_path in raw_filters: for name, filename, content, has_local, source_path in raw_filters:
cfg = conffile_parser.parse_filter_file( cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename)
content, name=name, filename=filename
)
used_by = sorted(filter_to_jails.get(name, [])) used_by = sorted(filter_to_jails.get(name, []))
filters.append( filters.append(
FilterConfig( FilterConfig(
@@ -1947,9 +1873,7 @@ async def get_filter(
content, has_local, source_path = await loop.run_in_executor(None, _read) content, has_local, source_path = await loop.run_in_executor(None, _read)
cfg = conffile_parser.parse_filter_file( cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
content, name=base_name, filename=f"{base_name}.conf"
)
all_jails_result, active_names = await asyncio.gather( all_jails_result, active_names = await asyncio.gather(
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)), loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
@@ -2182,9 +2106,7 @@ async def delete_filter(
try: try:
local_path.unlink() local_path.unlink()
except OSError as exc: except OSError as exc:
raise ConfigWriteError( raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
f"Failed to delete {local_path}: {exc}"
) from exc
log.info("filter_local_deleted", filter=base_name, path=str(local_path)) log.info("filter_local_deleted", filter=base_name, path=str(local_path))
@@ -2226,9 +2148,7 @@ async def assign_filter_to_jail(
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
# Verify the jail exists in config. # Verify the jail exists in config.
all_jails, _src = await loop.run_in_executor( all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
None, _parse_jails_sync, Path(config_dir)
)
if jail_name not in all_jails: if jail_name not in all_jails:
raise JailNotFoundInConfigError(jail_name) raise JailNotFoundInConfigError(jail_name)
@@ -2276,9 +2196,7 @@ async def assign_filter_to_jail(
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Allowlist pattern for action names used in path construction. # Allowlist pattern for action names used in path construction.
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile( _SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
)
class ActionNotFoundError(Exception): class ActionNotFoundError(Exception):
@@ -2318,8 +2236,7 @@ class ActionReadonlyError(Exception):
""" """
self.name: str = name self.name: str = name
super().__init__( super().__init__(
f"Action {name!r} is a shipped default (.conf only); " f"Action {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
"only user-created .local files can be deleted."
) )
@@ -2428,9 +2345,7 @@ def _parse_actions_sync(
try: try:
content = conf_path.read_text(encoding="utf-8") content = conf_path.read_text(encoding="utf-8")
except OSError as exc: except OSError as exc:
log.warning( log.warning("action_read_error", name=name, path=str(conf_path), error=str(exc))
"action_read_error", name=name, path=str(conf_path), error=str(exc)
)
continue continue
if has_local: if has_local:
@@ -2495,9 +2410,7 @@ def _append_jail_action_sync(
try: try:
jail_d.mkdir(parents=True, exist_ok=True) jail_d.mkdir(parents=True, exist_ok=True)
except OSError as exc: except OSError as exc:
raise ConfigWriteError( raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
f"Cannot create jail.d directory: {exc}"
) from exc
local_path = jail_d / f"{jail_name}.local" local_path = jail_d / f"{jail_name}.local"
@@ -2517,9 +2430,7 @@ def _append_jail_action_sync(
existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else "" existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else ""
existing_lines = [ existing_lines = [
line.strip() line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
for line in existing_raw.splitlines()
if line.strip() and not line.strip().startswith("#")
] ]
# Extract base names from existing entries for duplicate checking. # Extract base names from existing entries for duplicate checking.
@@ -2533,9 +2444,7 @@ def _append_jail_action_sync(
if existing_lines: if existing_lines:
# configparser multi-line: continuation lines start with whitespace. # configparser multi-line: continuation lines start with whitespace.
new_value = existing_lines[0] + "".join( new_value = existing_lines[0] + "".join(f"\n {line}" for line in existing_lines[1:])
f"\n {line}" for line in existing_lines[1:]
)
parser.set(jail_name, "action", new_value) parser.set(jail_name, "action", new_value)
else: else:
parser.set(jail_name, "action", action_entry) parser.set(jail_name, "action", action_entry)
@@ -2559,9 +2468,7 @@ def _append_jail_action_sync(
except OSError as exc: except OSError as exc:
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
os.unlink(tmp_name) # noqa: F821 os.unlink(tmp_name) # noqa: F821
raise ConfigWriteError( raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
f"Failed to write {local_path}: {exc}"
) from exc
log.info( log.info(
"jail_action_appended", "jail_action_appended",
@@ -2612,9 +2519,7 @@ def _remove_jail_action_sync(
existing_raw = parser.get(jail_name, "action") existing_raw = parser.get(jail_name, "action")
existing_lines = [ existing_lines = [
line.strip() line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
for line in existing_raw.splitlines()
if line.strip() and not line.strip().startswith("#")
] ]
def _base(entry: str) -> str: def _base(entry: str) -> str:
@@ -2628,9 +2533,7 @@ def _remove_jail_action_sync(
return return
if filtered: if filtered:
new_value = filtered[0] + "".join( new_value = filtered[0] + "".join(f"\n {line}" for line in filtered[1:])
f"\n {line}" for line in filtered[1:]
)
parser.set(jail_name, "action", new_value) parser.set(jail_name, "action", new_value)
else: else:
parser.remove_option(jail_name, "action") parser.remove_option(jail_name, "action")
@@ -2654,9 +2557,7 @@ def _remove_jail_action_sync(
except OSError as exc: except OSError as exc:
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
os.unlink(tmp_name) # noqa: F821 os.unlink(tmp_name) # noqa: F821
raise ConfigWriteError( raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
f"Failed to write {local_path}: {exc}"
) from exc
log.info( log.info(
"jail_action_removed", "jail_action_removed",
@@ -2683,9 +2584,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
try: try:
action_d.mkdir(parents=True, exist_ok=True) action_d.mkdir(parents=True, exist_ok=True)
except OSError as exc: except OSError as exc:
raise ConfigWriteError( raise ConfigWriteError(f"Cannot create action.d directory: {exc}") from exc
f"Cannot create action.d directory: {exc}"
) from exc
local_path = action_d / f"{name}.local" local_path = action_d / f"{name}.local"
try: try:
@@ -2702,9 +2601,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
except OSError as exc: except OSError as exc:
with contextlib.suppress(OSError): with contextlib.suppress(OSError):
os.unlink(tmp_name) # noqa: F821 os.unlink(tmp_name) # noqa: F821
raise ConfigWriteError( raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
f"Failed to write {local_path}: {exc}"
) from exc
log.info("action_local_written", action=name, path=str(local_path)) log.info("action_local_written", action=name, path=str(local_path))
@@ -2740,9 +2637,7 @@ async def list_actions(
action_d = Path(config_dir) / "action.d" action_d = Path(config_dir) / "action.d"
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor( raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_actions_sync, action_d)
None, _parse_actions_sync, action_d
)
all_jails_result, active_names = await asyncio.gather( all_jails_result, active_names = await asyncio.gather(
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)), loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
@@ -2754,9 +2649,7 @@ async def list_actions(
actions: list[ActionConfig] = [] actions: list[ActionConfig] = []
for name, filename, content, has_local, source_path in raw_actions: for name, filename, content, has_local, source_path in raw_actions:
cfg = conffile_parser.parse_action_file( cfg = conffile_parser.parse_action_file(content, name=name, filename=filename)
content, name=name, filename=filename
)
used_by = sorted(action_to_jails.get(name, [])) used_by = sorted(action_to_jails.get(name, []))
actions.append( actions.append(
ActionConfig( ActionConfig(
@@ -2843,9 +2736,7 @@ async def get_action(
content, has_local, source_path = await loop.run_in_executor(None, _read) content, has_local, source_path = await loop.run_in_executor(None, _read)
cfg = conffile_parser.parse_action_file( cfg = conffile_parser.parse_action_file(content, name=base_name, filename=f"{base_name}.conf")
content, name=base_name, filename=f"{base_name}.conf"
)
all_jails_result, active_names = await asyncio.gather( all_jails_result, active_names = await asyncio.gather(
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)), loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
@@ -3061,9 +2952,7 @@ async def delete_action(
try: try:
local_path.unlink() local_path.unlink()
except OSError as exc: except OSError as exc:
raise ConfigWriteError( raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
f"Failed to delete {local_path}: {exc}"
) from exc
log.info("action_local_deleted", action=base_name, path=str(local_path)) log.info("action_local_deleted", action=base_name, path=str(local_path))
@@ -3105,9 +2994,7 @@ async def assign_action_to_jail(
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
all_jails, _src = await loop.run_in_executor( all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
None, _parse_jails_sync, Path(config_dir)
)
if jail_name not in all_jails: if jail_name not in all_jails:
raise JailNotFoundInConfigError(jail_name) raise JailNotFoundInConfigError(jail_name)
@@ -3187,9 +3074,7 @@ async def remove_action_from_jail(
loop = asyncio.get_event_loop() loop = asyncio.get_event_loop()
all_jails, _src = await loop.run_in_executor( all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
None, _parse_jails_sync, Path(config_dir)
)
if jail_name not in all_jails: if jail_name not in all_jails:
raise JailNotFoundInConfigError(jail_name) raise JailNotFoundInConfigError(jail_name)
@@ -3218,4 +3103,3 @@ async def remove_action_from_jail(
action=action_name, action=action_name,
reload=do_reload, reload=do_reload,
) )

View File

@@ -95,7 +95,7 @@ def _ok(response: object) -> object:
ValueError: If the return code indicates an error. ValueError: If the return code indicates an error.
""" """
try: try:
code, data = cast(Fail2BanResponse, response) code, data = cast("Fail2BanResponse", response)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
if code != 0: if code != 0:
@@ -128,7 +128,7 @@ def _ensure_list(value: object | None) -> list[str]:
return [str(value)] return [str(value)]
_T = TypeVar("_T") T = TypeVar("T")
async def _safe_get( async def _safe_get(
@@ -143,13 +143,13 @@ async def _safe_get(
return default return default
async def _safe_get_typed( async def _safe_get_typed[T](
client: Fail2BanClient, client: Fail2BanClient,
command: Fail2BanCommand, command: Fail2BanCommand,
default: _T, default: T,
) -> _T: ) -> T:
"""Send a command and return the result typed as ``default``'s type.""" """Send a command and return the result typed as ``default``'s type."""
return cast(_T, await _safe_get(client, command, default)) return cast("T", await _safe_get(client, command, default))
def _is_not_found_error(exc: Exception) -> bool: def _is_not_found_error(exc: Exception) -> bool:

View File

@@ -47,7 +47,7 @@ def _ok(response: object) -> object:
ValueError: If the response indicates an error (return code ≠ 0). ValueError: If the response indicates an error (return code ≠ 0).
""" """
try: try:
code, data = cast(Fail2BanResponse, response) code, data = cast("Fail2BanResponse", response)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc

View File

@@ -11,9 +11,11 @@ modifies or locks the fail2ban database.
from __future__ import annotations from __future__ import annotations
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING
import structlog import structlog
if TYPE_CHECKING:
from app.services.geo_service import GeoEnricher from app.services.geo_service import GeoEnricher
from app.models.ban import TIME_RANGE_SECONDS, TimeRange from app.models.ban import TIME_RANGE_SECONDS, TimeRange

View File

@@ -14,7 +14,8 @@ from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import ipaddress import ipaddress
from typing import TYPE_CHECKING, Awaitable, Callable, cast, TypeAlias from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, TypedDict, cast
import structlog import structlog
@@ -27,6 +28,7 @@ from app.models.jail import (
JailStatus, JailStatus,
JailSummary, JailSummary,
) )
from app.services.geo_service import GeoInfo
from app.utils.fail2ban_client import ( from app.utils.fail2ban_client import (
Fail2BanClient, Fail2BanClient,
Fail2BanCommand, Fail2BanCommand,
@@ -39,11 +41,21 @@ if TYPE_CHECKING:
import aiohttp import aiohttp
import aiosqlite import aiosqlite
from app.services.geo_service import GeoInfo
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo | None"]] class IpLookupResult(TypedDict):
"""Result returned by :func:`lookup_ip`.
This is intentionally a :class:`TypedDict` to provide precise typing for
callers (e.g. routers) while keeping the implementation flexible.
"""
ip: str
currently_banned_in: list[str]
geo: GeoInfo | None
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Constants # Constants
@@ -104,7 +116,7 @@ def _ok(response: object) -> object:
ValueError: If the response indicates an error (return code ≠ 0). ValueError: If the response indicates an error (return code ≠ 0).
""" """
try: try:
code, data = cast(Fail2BanResponse, response) code, data = cast("Fail2BanResponse", response)
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
@@ -202,7 +214,7 @@ async def _safe_get(
""" """
try: try:
response = await client.send(command) response = await client.send(command)
return _ok(cast(Fail2BanResponse, response)) return _ok(cast("Fail2BanResponse", response))
except (ValueError, TypeError, Exception): except (ValueError, TypeError, Exception):
return default return default
@@ -337,7 +349,6 @@ async def _fetch_jail_summary(
client.send(["get", name, "backend"]), client.send(["get", name, "backend"]),
client.send(["get", name, "idle"]), client.send(["get", name, "idle"]),
]) ])
uses_backend_backend_commands = True
else: else:
# Commands not supported; return default values without sending. # Commands not supported; return default values without sending.
async def _return_default(value: object | None) -> Fail2BanResponse: async def _return_default(value: object | None) -> Fail2BanResponse:
@@ -347,7 +358,6 @@ async def _fetch_jail_summary(
_return_default("polling"), # backend default _return_default("polling"), # backend default
_return_default(False), # idle default _return_default(False), # idle default
]) ])
uses_backend_backend_commands = False
_r = await asyncio.gather(*gather_list, return_exceptions=True) _r = await asyncio.gather(*gather_list, return_exceptions=True)
status_raw: object | Exception = _r[0] status_raw: object | Exception = _r[0]
@@ -377,7 +387,7 @@ async def _fetch_jail_summary(
if isinstance(raw, Exception): if isinstance(raw, Exception):
return fallback return fallback
try: try:
return int(str(_ok(cast(Fail2BanResponse, raw)))) return int(str(_ok(cast("Fail2BanResponse", raw))))
except (ValueError, TypeError): except (ValueError, TypeError):
return fallback return fallback
@@ -385,7 +395,7 @@ async def _fetch_jail_summary(
if isinstance(raw, Exception): if isinstance(raw, Exception):
return fallback return fallback
try: try:
return str(_ok(cast(Fail2BanResponse, raw))) return str(_ok(cast("Fail2BanResponse", raw)))
except (ValueError, TypeError): except (ValueError, TypeError):
return fallback return fallback
@@ -393,7 +403,7 @@ async def _fetch_jail_summary(
if isinstance(raw, Exception): if isinstance(raw, Exception):
return fallback return fallback
try: try:
return bool(_ok(cast(Fail2BanResponse, raw))) return bool(_ok(cast("Fail2BanResponse", raw)))
except (ValueError, TypeError): except (ValueError, TypeError):
return fallback return fallback
@@ -687,7 +697,7 @@ async def reload_all(
names_set -= set(exclude_jails) names_set -= set(exclude_jails)
stream: list[list[object]] = [["start", n] for n in sorted(names_set)] stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
_ok(await client.send(["reload", "--all", [], cast(Fail2BanToken, stream)])) _ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
log.info("all_jails_reloaded") log.info("all_jails_reloaded")
except ValueError as exc: except ValueError as exc:
# Detect UnknownJailException (missing or invalid jail configuration) # Detect UnknownJailException (missing or invalid jail configuration)
@@ -811,8 +821,8 @@ async def unban_ip(
async def get_active_bans( async def get_active_bans(
socket_path: str, socket_path: str,
geo_enricher: GeoEnricher | None = None, geo_enricher: GeoEnricher | None = None,
http_session: "aiohttp.ClientSession" | None = None, http_session: aiohttp.ClientSession | None = None,
app_db: "aiosqlite.Connection" | None = None, app_db: aiosqlite.Connection | None = None,
) -> ActiveBanListResponse: ) -> ActiveBanListResponse:
"""Return all currently banned IPs across every jail. """Return all currently banned IPs across every jail.
@@ -880,7 +890,7 @@ async def get_active_bans(
continue continue
try: try:
ban_list: list[str] = cast(list[str], _ok(raw_result)) or [] ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
except (TypeError, ValueError) as exc: except (TypeError, ValueError) as exc:
log.warning( log.warning(
"active_bans_parse_error", "active_bans_parse_error",
@@ -1007,8 +1017,8 @@ async def get_jail_banned_ips(
page: int = 1, page: int = 1,
page_size: int = 25, page_size: int = 25,
search: str | None = None, search: str | None = None,
http_session: "aiohttp.ClientSession" | None = None, http_session: aiohttp.ClientSession | None = None,
app_db: "aiosqlite.Connection" | None = None, app_db: aiosqlite.Connection | None = None,
) -> JailBannedIpsResponse: ) -> JailBannedIpsResponse:
"""Return a paginated list of currently banned IPs for a single jail. """Return a paginated list of currently banned IPs for a single jail.
@@ -1055,7 +1065,7 @@ async def get_jail_banned_ips(
except (ValueError, TypeError): except (ValueError, TypeError):
raw_result = [] raw_result = []
ban_list: list[str] = cast(list[str], raw_result) or [] ban_list: list[str] = cast("list[str]", raw_result) or []
# Parse all entries. # Parse all entries.
all_bans: list[ActiveBan] = [] all_bans: list[ActiveBan] = []
@@ -1121,7 +1131,7 @@ async def _enrich_bans(
The same list with ``country`` fields populated where lookup succeeded. The same list with ``country`` fields populated where lookup succeeded.
""" """
geo_results: list[object | Exception] = await asyncio.gather( geo_results: list[object | Exception] = await asyncio.gather(
*[cast(Awaitable[object], geo_enricher(ban.ip)) for ban in bans], *[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
return_exceptions=True, return_exceptions=True,
) )
enriched: list[ActiveBan] = [] enriched: list[ActiveBan] = []
@@ -1277,7 +1287,7 @@ async def lookup_ip(
socket_path: str, socket_path: str,
ip: str, ip: str,
geo_enricher: GeoEnricher | None = None, geo_enricher: GeoEnricher | None = None,
) -> dict[str, object | list[str] | None]: ) -> IpLookupResult:
"""Return ban status and history for a single IP address. """Return ban status and history for a single IP address.
Checks every running jail for whether the IP is currently banned. Checks every running jail for whether the IP is currently banned.
@@ -1330,7 +1340,7 @@ async def lookup_ip(
if isinstance(result, Exception): if isinstance(result, Exception):
continue continue
try: try:
ban_list: list[str] = cast(list[str], _ok(result)) or [] ban_list: list[str] = cast("list[str]", _ok(result)) or []
if ip in ban_list: if ip in ban_list:
currently_banned_in.append(jail_name) currently_banned_in.append(jail_name)
except (ValueError, TypeError): except (ValueError, TypeError):

View File

@@ -10,7 +10,7 @@ HTTP/FastAPI concerns.
from __future__ import annotations from __future__ import annotations
from typing import cast, TypeAlias from typing import cast
import structlog import structlog
@@ -21,7 +21,7 @@ from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanR
# Types # Types
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
Fail2BanSettingValue: TypeAlias = str | int | bool type Fail2BanSettingValue = str | int | bool
"""Allowed values for server settings commands.""" """Allowed values for server settings commands."""
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -106,7 +106,7 @@ async def _safe_get(
""" """
try: try:
response = await client.send(command) response = await client.send(command)
return _ok(cast(Fail2BanResponse, response)) return _ok(cast("Fail2BanResponse", response))
except Exception: except Exception:
return default return default
@@ -189,7 +189,7 @@ async def update_settings(socket_path: str, update: ServerSettingsUpdate) -> Non
async def _set(key: str, value: Fail2BanSettingValue) -> None: async def _set(key: str, value: Fail2BanSettingValue) -> None:
try: try:
response = await client.send(["set", key, value]) response = await client.send(["set", key, value])
_ok(cast(Fail2BanResponse, response)) _ok(cast("Fail2BanResponse", response))
except ValueError as exc: except ValueError as exc:
raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc
@@ -224,7 +224,7 @@ async def flush_logs(socket_path: str) -> str:
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
try: try:
response = await client.send(["flushlogs"]) response = await client.send(["flushlogs"])
result = _ok(cast(Fail2BanResponse, response)) result = _ok(cast("Fail2BanResponse", response))
log.info("logs_flushed", result=result) log.info("logs_flushed", result=result)
return str(result) return str(result)
except ValueError as exc: except ValueError as exc:

View File

@@ -35,7 +35,7 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
JOB_ID: str = "geo_re_resolve" JOB_ID: str = "geo_re_resolve"
async def _run_re_resolve(app: "FastAPI") -> None: async def _run_re_resolve(app: FastAPI) -> None:
"""Query NULL-country IPs from the database and re-resolve them. """Query NULL-country IPs from the database and re-resolve them.
Reads shared resources from ``app.state`` and delegates to Reads shared resources from ``app.state`` and delegates to

View File

@@ -47,7 +47,7 @@ HEALTH_CHECK_INTERVAL: int = 30
_ACTIVATION_CRASH_WINDOW: int = 60 _ACTIVATION_CRASH_WINDOW: int = 60
async def _run_probe(app: "FastAPI") -> None: async def _run_probe(app: FastAPI) -> None:
"""Probe fail2ban and cache the result on *app.state*. """Probe fail2ban and cache the result on *app.state*.
Detects online/offline state transitions. When fail2ban goes offline Detects online/offline state transitions. When fail2ban goes offline

View File

@@ -21,34 +21,52 @@ import contextlib
import errno import errno
import socket import socket
import time import time
from collections.abc import Mapping, Sequence, Set
from pickle import HIGHEST_PROTOCOL, dumps, loads from pickle import HIGHEST_PROTOCOL, dumps, loads
from typing import TYPE_CHECKING, TypeAlias from typing import TYPE_CHECKING
import structlog
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Types # Types
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
Fail2BanToken: TypeAlias = str | int | float | bool | None | dict[str, object] | list[object] # Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]``
# without needing to cast. At runtime we only accept the basic built-in
# containers supported by fail2ban's protocol (list/dict/set) and stringify
# anything else.
#
# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at
# runtime because fail2ban only understands lists.
type Fail2BanToken = (
str
| int
| float
| bool
| None
| Mapping[str, object]
| Sequence[object]
| Set[object]
)
"""A single token in a fail2ban command. """A single token in a fail2ban command.
Fail2ban accepts simple types (str/int/float/bool) plus compound types Fail2ban accepts simple types (str/int/float/bool) plus compound types
(list/dict). Complex objects are stringified before being sent. (list/dict/set). Complex objects are stringified before being sent.
""" """
Fail2BanCommand: TypeAlias = list[Fail2BanToken] type Fail2BanCommand = Sequence[Fail2BanToken]
"""A command sent to fail2ban over the socket. """A command sent to fail2ban over the socket.
Commands are pickle serialised lists of tokens. Commands are pickle serialised sequences of tokens.
""" """
Fail2BanResponse: TypeAlias = tuple[int, object] type Fail2BanResponse = tuple[int, object]
"""A typical fail2ban response containing a status code and payload.""" """A typical fail2ban response containing a status code and payload."""
if TYPE_CHECKING: if TYPE_CHECKING:
from types import TracebackType from types import TracebackType
import structlog
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
# fail2ban protocol constants — inline to avoid a hard import dependency # fail2ban protocol constants — inline to avoid a hard import dependency
@@ -200,7 +218,7 @@ def _send_command_sync(
) from last_oserror ) from last_oserror
def _coerce_command_token(token: Fail2BanToken) -> Fail2BanToken: def _coerce_command_token(token: object) -> Fail2BanToken:
"""Coerce a command token to a type that fail2ban understands. """Coerce a command token to a type that fail2ban understands.
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``, fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,

View File

@@ -60,4 +60,5 @@ plugins = ["pydantic.mypy"]
asyncio_mode = "auto" asyncio_mode = "auto"
pythonpath = [".", "../fail2ban-master"] pythonpath = [".", "../fail2ban-master"]
testpaths = ["tests"] testpaths = ["tests"]
addopts = "--cov=app --cov-report=term-missing" addopts = "--asyncio-mode=auto --cov=app --cov-report=term-missing"
filterwarnings = ["ignore::pytest.PytestRemovedIn9Warning"]

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Generator
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -157,12 +158,12 @@ class TestRequireAuthSessionCache:
"""In-memory session token cache inside ``require_auth``.""" """In-memory session token cache inside ``require_auth``."""
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_cache(self) -> None: # type: ignore[misc] def reset_cache(self) -> Generator[None, None, None]:
"""Flush the session cache before and after every test in this class.""" """Flush the session cache before and after every test in this class."""
from app import dependencies from app import dependencies
dependencies.clear_session_cache() dependencies.clear_session_cache()
yield # type: ignore[misc] yield
dependencies.clear_session_cache() dependencies.clear_session_cache()
async def test_second_request_skips_db(self, client: AsyncClient) -> None: async def test_second_request_skips_db(self, client: AsyncClient) -> None:

View File

@@ -70,7 +70,7 @@ class TestGeoLookup:
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None: async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/lookup/{ip} returns 200 with enriched result.""" """GET /api/geo/lookup/{ip} returns 200 with enriched result."""
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme") geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
result = { result: dict[str, object] = {
"ip": "1.2.3.4", "ip": "1.2.3.4",
"currently_banned_in": ["sshd"], "currently_banned_in": ["sshd"],
"geo": geo, "geo": geo,
@@ -92,7 +92,7 @@ class TestGeoLookup:
async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None: async def test_200_when_not_banned(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere.""" """GET /api/geo/lookup/{ip} returns empty list when IP is not banned anywhere."""
result = { result: dict[str, object] = {
"ip": "8.8.8.8", "ip": "8.8.8.8",
"currently_banned_in": [], "currently_banned_in": [],
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None), "geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
@@ -108,7 +108,7 @@ class TestGeoLookup:
async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None: async def test_200_with_no_geo(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/lookup/{ip} returns null geo when enricher fails.""" """GET /api/geo/lookup/{ip} returns null geo when enricher fails."""
result = { result: dict[str, object] = {
"ip": "1.2.3.4", "ip": "1.2.3.4",
"currently_banned_in": [], "currently_banned_in": [],
"geo": None, "geo": None,
@@ -144,7 +144,7 @@ class TestGeoLookup:
async def test_ipv6_address(self, geo_client: AsyncClient) -> None: async def test_ipv6_address(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/lookup/{ip} handles IPv6 addresses.""" """GET /api/geo/lookup/{ip} handles IPv6 addresses."""
result = { result: dict[str, object] = {
"ip": "2001:db8::1", "ip": "2001:db8::1",
"currently_banned_in": [], "currently_banned_in": [],
"geo": None, "geo": None,

View File

@@ -12,6 +12,7 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings from app.config import Settings
from app.db import init_db from app.db import init_db
from app.main import create_app from app.main import create_app
from app.models.ban import JailBannedIpsResponse
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -801,17 +802,17 @@ class TestGetJailBannedIps:
def _mock_response( def _mock_response(
self, self,
*, *,
items: list[dict] | None = None, items: list[dict[str, str | None]] | None = None,
total: int = 2, total: int = 2,
page: int = 1, page: int = 1,
page_size: int = 25, page_size: int = 25,
) -> "JailBannedIpsResponse": # type: ignore[name-defined] ) -> JailBannedIpsResponse:
from app.models.ban import ActiveBan, JailBannedIpsResponse from app.models.ban import ActiveBan, JailBannedIpsResponse
ban_items = ( ban_items = (
[ [
ActiveBan( ActiveBan(
ip=item.get("ip", "1.2.3.4"), ip=item.get("ip") or "1.2.3.4",
jail="sshd", jail="sshd",
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"), banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"), expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),

View File

@@ -247,9 +247,9 @@ class TestSetupCompleteCaching:
assert not getattr(app.state, "_setup_complete_cached", False) assert not getattr(app.state, "_setup_complete_cached", False)
# First non-exempt request — middleware queries DB and sets the flag. # First non-exempt request — middleware queries DB and sets the flag.
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload] await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
assert app.state._setup_complete_cached is True # type: ignore[attr-defined] assert app.state._setup_complete_cached is True
async def test_cached_path_skips_is_setup_complete( async def test_cached_path_skips_is_setup_complete(
self, self,
@@ -267,12 +267,12 @@ class TestSetupCompleteCaching:
# Do setup and warm the cache. # Do setup and warm the cache.
await client.post("/api/setup", json=_SETUP_PAYLOAD) await client.post("/api/setup", json=_SETUP_PAYLOAD)
await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}) # type: ignore[call-overload] await client.post("/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]})
assert app.state._setup_complete_cached is True # type: ignore[attr-defined] assert app.state._setup_complete_cached is True
call_count = 0 call_count = 0
async def _counting(db): # type: ignore[no-untyped-def] async def _counting(db: aiosqlite.Connection) -> bool:
nonlocal call_count nonlocal call_count
call_count += 1 call_count += 1
return True return True

View File

@@ -73,7 +73,7 @@ class TestCheckPasswordAsync:
auth_service._check_password("secret", hashed), # noqa: SLF001 auth_service._check_password("secret", hashed), # noqa: SLF001
auth_service._check_password("wrong", hashed), # noqa: SLF001 auth_service._check_password("wrong", hashed), # noqa: SLF001
) )
assert results == [True, False] assert tuple(results) == (True, False)
class TestLogin: class TestLogin:

View File

@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
@pytest.fixture @pytest.fixture
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc] async def f2b_db_path(tmp_path: Path) -> str:
"""Return the path to a test fail2ban SQLite database with several bans.""" """Return the path to a test fail2ban SQLite database with several bans."""
path = str(tmp_path / "fail2ban_test.sqlite3") path = str(tmp_path / "fail2ban_test.sqlite3")
await _create_f2b_db( await _create_f2b_db(
@@ -103,7 +103,7 @@ async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc]
@pytest.fixture @pytest.fixture
async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc] async def mixed_origin_db_path(tmp_path: Path) -> str:
"""Return a database with bans from both blocklist-import and organic jails.""" """Return a database with bans from both blocklist-import and organic jails."""
path = str(tmp_path / "fail2ban_mixed_origin.sqlite3") path = str(tmp_path / "fail2ban_mixed_origin.sqlite3")
await _create_f2b_db( await _create_f2b_db(
@@ -136,7 +136,7 @@ async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc]
@pytest.fixture @pytest.fixture
async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc] async def empty_f2b_db_path(tmp_path: Path) -> str:
"""Return the path to a fail2ban SQLite database with no ban records.""" """Return the path to a fail2ban SQLite database with no ban records."""
path = str(tmp_path / "fail2ban_empty.sqlite3") path = str(tmp_path / "fail2ban_empty.sqlite3")
await _create_f2b_db(path, []) await _create_f2b_db(path, [])
@@ -632,13 +632,13 @@ class TestBansbyCountryBackground:
from app.services import geo_service from app.services import geo_service
# Pre-populate the cache for all three IPs in the fixture. # Pre-populate the cache for all three IPs in the fixture.
geo_service._cache["10.0.0.1"] = geo_service.GeoInfo( # type: ignore[attr-defined] geo_service._cache["10.0.0.1"] = geo_service.GeoInfo(
country_code="DE", country_name="Germany", asn=None, org=None country_code="DE", country_name="Germany", asn=None, org=None
) )
geo_service._cache["10.0.0.2"] = geo_service.GeoInfo( # type: ignore[attr-defined] geo_service._cache["10.0.0.2"] = geo_service.GeoInfo(
country_code="US", country_name="United States", asn=None, org=None country_code="US", country_name="United States", asn=None, org=None
) )
geo_service._cache["10.0.0.3"] = geo_service.GeoInfo( # type: ignore[attr-defined] geo_service._cache["10.0.0.3"] = geo_service.GeoInfo(
country_code="JP", country_name="Japan", asn=None, org=None country_code="JP", country_name="Japan", asn=None, org=None
) )

View File

@@ -114,13 +114,13 @@ async def _seed_f2b_db(path: str, n: int) -> list[str]:
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def event_loop_policy() -> None: # type: ignore[misc] def event_loop_policy() -> None:
"""Use the default event loop policy for module-scoped fixtures.""" """Use the default event loop policy for module-scoped fixtures."""
return None return None
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
async def perf_db_path(tmp_path_factory: Any) -> str: # type: ignore[misc] async def perf_db_path(tmp_path_factory: Any) -> str:
"""Return the path to a fail2ban DB seeded with 10 000 synthetic bans. """Return the path to a fail2ban DB seeded with 10 000 synthetic bans.
Module-scoped so the database is created only once for all perf tests. Module-scoped so the database is created only once for all perf tests.

View File

@@ -13,15 +13,19 @@ from app.services.config_file_service import (
JailNameError, JailNameError,
JailNotFoundInConfigError, JailNotFoundInConfigError,
_build_inactive_jail, _build_inactive_jail,
_extract_action_base_name,
_extract_filter_base_name,
_ordered_config_files, _ordered_config_files,
_parse_jails_sync, _parse_jails_sync,
_resolve_filter, _resolve_filter,
_safe_jail_name, _safe_jail_name,
_validate_jail_config_sync,
_write_local_override_sync, _write_local_override_sync,
activate_jail, activate_jail,
deactivate_jail, deactivate_jail,
list_inactive_jails, list_inactive_jails,
rollback_jail, rollback_jail,
validate_jail_config,
) )
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -292,9 +296,7 @@ class TestBuildInactiveJail:
def test_has_local_override_absent(self, tmp_path: Path) -> None: def test_has_local_override_absent(self, tmp_path: Path) -> None:
"""has_local_override is False when no .local file exists.""" """has_local_override is False when no .local file exists."""
jail = _build_inactive_jail( jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
"sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path
)
assert jail.has_local_override is False assert jail.has_local_override is False
def test_has_local_override_present(self, tmp_path: Path) -> None: def test_has_local_override_present(self, tmp_path: Path) -> None:
@@ -302,9 +304,7 @@ class TestBuildInactiveJail:
local = tmp_path / "jail.d" / "sshd.local" local = tmp_path / "jail.d" / "sshd.local"
local.parent.mkdir(parents=True, exist_ok=True) local.parent.mkdir(parents=True, exist_ok=True)
local.write_text("[sshd]\nenabled = false\n") local.write_text("[sshd]\nenabled = false\n")
jail = _build_inactive_jail( jail = _build_inactive_jail("sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path)
"sshd", {}, "/etc/fail2ban/jail.d/sshd.conf", config_dir=tmp_path
)
assert jail.has_local_override is True assert jail.has_local_override is True
def test_has_local_override_no_config_dir(self) -> None: def test_has_local_override_no_config_dir(self) -> None:
@@ -363,9 +363,7 @@ class TestWriteLocalOverrideSync:
assert "2222" in content assert "2222" in content
def test_override_logpath_list(self, tmp_path: Path) -> None: def test_override_logpath_list(self, tmp_path: Path) -> None:
_write_local_override_sync( _write_local_override_sync(tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]})
tmp_path, "sshd", True, {"logpath": ["/var/log/auth.log", "/var/log/secure"]}
)
content = (tmp_path / "jail.d" / "sshd.local").read_text() content = (tmp_path / "jail.d" / "sshd.local").read_text()
assert "/var/log/auth.log" in content assert "/var/log/auth.log" in content
assert "/var/log/secure" in content assert "/var/log/secure" in content
@@ -447,9 +445,7 @@ class TestListInactiveJails:
assert "sshd" in names assert "sshd" in names
assert "apache-auth" in names assert "apache-auth" in names
async def test_has_local_override_true_when_local_file_exists( async def test_has_local_override_true_when_local_file_exists(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""has_local_override is True for a jail whose jail.d .local file exists.""" """has_local_override is True for a jail whose jail.d .local file exists."""
_write(tmp_path / "jail.conf", JAIL_CONF) _write(tmp_path / "jail.conf", JAIL_CONF)
local = tmp_path / "jail.d" / "apache-auth.local" local = tmp_path / "jail.d" / "apache-auth.local"
@@ -463,9 +459,7 @@ class TestListInactiveJails:
jail = next(j for j in result.jails if j.name == "apache-auth") jail = next(j for j in result.jails if j.name == "apache-auth")
assert jail.has_local_override is True assert jail.has_local_override is True
async def test_has_local_override_false_when_no_local_file( async def test_has_local_override_false_when_no_local_file(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""has_local_override is False when no jail.d .local file exists.""" """has_local_override is False when no jail.d .local file exists."""
_write(tmp_path / "jail.conf", JAIL_CONF) _write(tmp_path / "jail.conf", JAIL_CONF)
with patch( with patch(
@@ -608,7 +602,8 @@ class TestActivateJail:
patch( patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
),pytest.raises(JailNotFoundInConfigError) ),
pytest.raises(JailNotFoundInConfigError),
): ):
await activate_jail(str(tmp_path), "/fake.sock", "nonexistent", req) await activate_jail(str(tmp_path), "/fake.sock", "nonexistent", req)
@@ -621,7 +616,8 @@ class TestActivateJail:
patch( patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value={"sshd"}), new=AsyncMock(return_value={"sshd"}),
),pytest.raises(JailAlreadyActiveError) ),
pytest.raises(JailAlreadyActiveError),
): ):
await activate_jail(str(tmp_path), "/fake.sock", "sshd", req) await activate_jail(str(tmp_path), "/fake.sock", "sshd", req)
@@ -691,7 +687,8 @@ class TestDeactivateJail:
patch( patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value={"sshd"}), new=AsyncMock(return_value={"sshd"}),
),pytest.raises(JailNotFoundInConfigError) ),
pytest.raises(JailNotFoundInConfigError),
): ):
await deactivate_jail(str(tmp_path), "/fake.sock", "nonexistent") await deactivate_jail(str(tmp_path), "/fake.sock", "nonexistent")
@@ -701,7 +698,8 @@ class TestDeactivateJail:
patch( patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
),pytest.raises(JailAlreadyInactiveError) ),
pytest.raises(JailAlreadyInactiveError),
): ):
await deactivate_jail(str(tmp_path), "/fake.sock", "apache-auth") await deactivate_jail(str(tmp_path), "/fake.sock", "apache-auth")
@@ -710,38 +708,6 @@ class TestDeactivateJail:
await deactivate_jail(str(tmp_path), "/fake.sock", "a/b") await deactivate_jail(str(tmp_path), "/fake.sock", "a/b")
# ---------------------------------------------------------------------------
# _extract_filter_base_name
# ---------------------------------------------------------------------------
class TestExtractFilterBaseName:
def test_simple_name(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name("sshd") == "sshd"
def test_name_with_mode(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name("sshd[mode=aggressive]") == "sshd"
def test_name_with_variable_mode(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name("sshd[mode=%(mode)s]") == "sshd"
def test_whitespace_stripped(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name(" nginx ") == "nginx"
def test_empty_string(self) -> None:
from app.services.config_file_service import _extract_filter_base_name
assert _extract_filter_base_name("") == ""
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _build_filter_to_jails_map # _build_filter_to_jails_map
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -757,9 +723,7 @@ class TestBuildFilterToJailsMap:
def test_inactive_jail_not_included(self) -> None: def test_inactive_jail_not_included(self) -> None:
from app.services.config_file_service import _build_filter_to_jails_map from app.services.config_file_service import _build_filter_to_jails_map
result = _build_filter_to_jails_map( result = _build_filter_to_jails_map({"apache-auth": {"filter": "apache-auth"}}, set())
{"apache-auth": {"filter": "apache-auth"}}, set()
)
assert result == {} assert result == {}
def test_multiple_jails_sharing_filter(self) -> None: def test_multiple_jails_sharing_filter(self) -> None:
@@ -775,9 +739,7 @@ class TestBuildFilterToJailsMap:
def test_mode_suffix_stripped(self) -> None: def test_mode_suffix_stripped(self) -> None:
from app.services.config_file_service import _build_filter_to_jails_map from app.services.config_file_service import _build_filter_to_jails_map
result = _build_filter_to_jails_map( result = _build_filter_to_jails_map({"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"})
{"sshd": {"filter": "sshd[mode=aggressive]"}}, {"sshd"}
)
assert "sshd" in result assert "sshd" in result
def test_missing_filter_key_falls_back_to_jail_name(self) -> None: def test_missing_filter_key_falls_back_to_jail_name(self) -> None:
@@ -988,10 +950,13 @@ class TestGetFilter:
async def test_raises_filter_not_found(self, tmp_path: Path) -> None: async def test_raises_filter_not_found(self, tmp_path: Path) -> None:
from app.services.config_file_service import FilterNotFoundError, get_filter from app.services.config_file_service import FilterNotFoundError, get_filter
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), pytest.raises(FilterNotFoundError): ),
pytest.raises(FilterNotFoundError),
):
await get_filter(str(tmp_path), "/fake.sock", "nonexistent") await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
async def test_has_local_override_detected(self, tmp_path: Path) -> None: async def test_has_local_override_detected(self, tmp_path: Path) -> None:
@@ -1093,10 +1058,13 @@ class TestGetFilterLocalOnly:
async def test_raises_when_neither_conf_nor_local(self, tmp_path: Path) -> None: async def test_raises_when_neither_conf_nor_local(self, tmp_path: Path) -> None:
from app.services.config_file_service import FilterNotFoundError, get_filter from app.services.config_file_service import FilterNotFoundError, get_filter
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), pytest.raises(FilterNotFoundError): ),
pytest.raises(FilterNotFoundError),
):
await get_filter(str(tmp_path), "/fake.sock", "nonexistent") await get_filter(str(tmp_path), "/fake.sock", "nonexistent")
async def test_accepts_local_extension(self, tmp_path: Path) -> None: async def test_accepts_local_extension(self, tmp_path: Path) -> None:
@@ -1212,9 +1180,7 @@ class TestSetJailLocalKeySync:
jail_d = tmp_path / "jail.d" jail_d = tmp_path / "jail.d"
jail_d.mkdir() jail_d.mkdir()
(jail_d / "sshd.local").write_text( (jail_d / "sshd.local").write_text("[sshd]\nenabled = true\n")
"[sshd]\nenabled = true\n"
)
_set_jail_local_key_sync(tmp_path, "sshd", "filter", "newfilter") _set_jail_local_key_sync(tmp_path, "sshd", "filter", "newfilter")
@@ -1300,10 +1266,13 @@ class TestUpdateFilter:
from app.models.config import FilterUpdateRequest from app.models.config import FilterUpdateRequest
from app.services.config_file_service import FilterNotFoundError, update_filter from app.services.config_file_service import FilterNotFoundError, update_filter
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), pytest.raises(FilterNotFoundError): ),
pytest.raises(FilterNotFoundError),
):
await update_filter( await update_filter(
str(tmp_path), str(tmp_path),
"/fake.sock", "/fake.sock",
@@ -1321,10 +1290,13 @@ class TestUpdateFilter:
filter_d = tmp_path / "filter.d" filter_d = tmp_path / "filter.d"
_write(filter_d / "sshd.conf", _FILTER_CONF_WITH_REGEX) _write(filter_d / "sshd.conf", _FILTER_CONF_WITH_REGEX)
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), pytest.raises(FilterInvalidRegexError): ),
pytest.raises(FilterInvalidRegexError),
):
await update_filter( await update_filter(
str(tmp_path), str(tmp_path),
"/fake.sock", "/fake.sock",
@@ -1351,13 +1323,16 @@ class TestUpdateFilter:
filter_d = tmp_path / "filter.d" filter_d = tmp_path / "filter.d"
_write(filter_d / "sshd.conf", _FILTER_CONF) _write(filter_d / "sshd.conf", _FILTER_CONF)
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), patch( ),
patch(
"app.services.config_file_service.jail_service.reload_all", "app.services.config_file_service.jail_service.reload_all",
new=AsyncMock(), new=AsyncMock(),
) as mock_reload: ) as mock_reload,
):
await update_filter( await update_filter(
str(tmp_path), str(tmp_path),
"/fake.sock", "/fake.sock",
@@ -1405,10 +1380,13 @@ class TestCreateFilter:
filter_d = tmp_path / "filter.d" filter_d = tmp_path / "filter.d"
_write(filter_d / "sshd.conf", _FILTER_CONF) _write(filter_d / "sshd.conf", _FILTER_CONF)
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), pytest.raises(FilterAlreadyExistsError): ),
pytest.raises(FilterAlreadyExistsError),
):
await create_filter( await create_filter(
str(tmp_path), str(tmp_path),
"/fake.sock", "/fake.sock",
@@ -1422,10 +1400,13 @@ class TestCreateFilter:
filter_d = tmp_path / "filter.d" filter_d = tmp_path / "filter.d"
_write(filter_d / "custom.local", "[Definition]\n") _write(filter_d / "custom.local", "[Definition]\n")
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), pytest.raises(FilterAlreadyExistsError): ),
pytest.raises(FilterAlreadyExistsError),
):
await create_filter( await create_filter(
str(tmp_path), str(tmp_path),
"/fake.sock", "/fake.sock",
@@ -1436,10 +1417,13 @@ class TestCreateFilter:
from app.models.config import FilterCreateRequest from app.models.config import FilterCreateRequest
from app.services.config_file_service import FilterInvalidRegexError, create_filter from app.services.config_file_service import FilterInvalidRegexError, create_filter
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), pytest.raises(FilterInvalidRegexError): ),
pytest.raises(FilterInvalidRegexError),
):
await create_filter( await create_filter(
str(tmp_path), str(tmp_path),
"/fake.sock", "/fake.sock",
@@ -1461,13 +1445,16 @@ class TestCreateFilter:
from app.models.config import FilterCreateRequest from app.models.config import FilterCreateRequest
from app.services.config_file_service import create_filter from app.services.config_file_service import create_filter
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), patch( ),
patch(
"app.services.config_file_service.jail_service.reload_all", "app.services.config_file_service.jail_service.reload_all",
new=AsyncMock(), new=AsyncMock(),
) as mock_reload: ) as mock_reload,
):
await create_filter( await create_filter(
str(tmp_path), str(tmp_path),
"/fake.sock", "/fake.sock",
@@ -1485,9 +1472,7 @@ class TestCreateFilter:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestDeleteFilter: class TestDeleteFilter:
async def test_deletes_local_file_when_conf_and_local_exist( async def test_deletes_local_file_when_conf_and_local_exist(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
from app.services.config_file_service import delete_filter from app.services.config_file_service import delete_filter
filter_d = tmp_path / "filter.d" filter_d = tmp_path / "filter.d"
@@ -1524,9 +1509,7 @@ class TestDeleteFilter:
with pytest.raises(FilterNotFoundError): with pytest.raises(FilterNotFoundError):
await delete_filter(str(tmp_path), "nonexistent") await delete_filter(str(tmp_path), "nonexistent")
async def test_accepts_filter_name_error_for_invalid_name( async def test_accepts_filter_name_error_for_invalid_name(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
from app.services.config_file_service import FilterNameError, delete_filter from app.services.config_file_service import FilterNameError, delete_filter
with pytest.raises(FilterNameError): with pytest.raises(FilterNameError):
@@ -1607,9 +1590,7 @@ class TestAssignFilterToJail:
AssignFilterRequest(filter_name="sshd"), AssignFilterRequest(filter_name="sshd"),
) )
async def test_raises_filter_name_error_for_invalid_filter( async def test_raises_filter_name_error_for_invalid_filter(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
from app.models.config import AssignFilterRequest from app.models.config import AssignFilterRequest
from app.services.config_file_service import FilterNameError, assign_filter_to_jail from app.services.config_file_service import FilterNameError, assign_filter_to_jail
@@ -1719,34 +1700,26 @@ class TestBuildActionToJailsMap:
def test_active_jail_maps_to_action(self) -> None: def test_active_jail_maps_to_action(self) -> None:
from app.services.config_file_service import _build_action_to_jails_map from app.services.config_file_service import _build_action_to_jails_map
result = _build_action_to_jails_map( result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, {"sshd"})
{"sshd": {"action": "iptables-multiport"}}, {"sshd"}
)
assert result == {"iptables-multiport": ["sshd"]} assert result == {"iptables-multiport": ["sshd"]}
def test_inactive_jail_not_included(self) -> None: def test_inactive_jail_not_included(self) -> None:
from app.services.config_file_service import _build_action_to_jails_map from app.services.config_file_service import _build_action_to_jails_map
result = _build_action_to_jails_map( result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport"}}, set())
{"sshd": {"action": "iptables-multiport"}}, set()
)
assert result == {} assert result == {}
def test_multiple_actions_per_jail(self) -> None: def test_multiple_actions_per_jail(self) -> None:
from app.services.config_file_service import _build_action_to_jails_map from app.services.config_file_service import _build_action_to_jails_map
result = _build_action_to_jails_map( result = _build_action_to_jails_map({"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"})
{"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"}
)
assert "iptables-multiport" in result assert "iptables-multiport" in result
assert "iptables-ipset" in result assert "iptables-ipset" in result
def test_parameter_block_stripped(self) -> None: def test_parameter_block_stripped(self) -> None:
from app.services.config_file_service import _build_action_to_jails_map from app.services.config_file_service import _build_action_to_jails_map
result = _build_action_to_jails_map( result = _build_action_to_jails_map({"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"})
{"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"}
)
assert "iptables" in result assert "iptables" in result
def test_multiple_jails_sharing_action(self) -> None: def test_multiple_jails_sharing_action(self) -> None:
@@ -2001,10 +1974,13 @@ class TestGetAction:
async def test_raises_for_unknown_action(self, tmp_path: Path) -> None: async def test_raises_for_unknown_action(self, tmp_path: Path) -> None:
from app.services.config_file_service import ActionNotFoundError, get_action from app.services.config_file_service import ActionNotFoundError, get_action
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), pytest.raises(ActionNotFoundError): ),
pytest.raises(ActionNotFoundError),
):
await get_action(str(tmp_path), "/fake.sock", "nonexistent") await get_action(str(tmp_path), "/fake.sock", "nonexistent")
async def test_local_only_action_returned(self, tmp_path: Path) -> None: async def test_local_only_action_returned(self, tmp_path: Path) -> None:
@@ -2118,10 +2094,13 @@ class TestUpdateAction:
from app.models.config import ActionUpdateRequest from app.models.config import ActionUpdateRequest
from app.services.config_file_service import ActionNotFoundError, update_action from app.services.config_file_service import ActionNotFoundError, update_action
with patch( with (
patch(
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), pytest.raises(ActionNotFoundError): ),
pytest.raises(ActionNotFoundError),
):
await update_action( await update_action(
str(tmp_path), str(tmp_path),
"/fake.sock", "/fake.sock",
@@ -2587,9 +2566,7 @@ class TestRemoveActionFromJail:
"app.services.config_file_service._get_active_jail_names", "app.services.config_file_service._get_active_jail_names",
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
): ):
await remove_action_from_jail( await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables-multiport")
str(tmp_path), "/fake.sock", "sshd", "iptables-multiport"
)
content = (jail_d / "sshd.local").read_text() content = (jail_d / "sshd.local").read_text()
assert "iptables-multiport" not in content assert "iptables-multiport" not in content
@@ -2601,17 +2578,13 @@ class TestRemoveActionFromJail:
) )
with pytest.raises(JailNotFoundInConfigError): with pytest.raises(JailNotFoundInConfigError):
await remove_action_from_jail( await remove_action_from_jail(str(tmp_path), "/fake.sock", "nonexistent", "iptables")
str(tmp_path), "/fake.sock", "nonexistent", "iptables"
)
async def test_raises_jail_name_error(self, tmp_path: Path) -> None: async def test_raises_jail_name_error(self, tmp_path: Path) -> None:
from app.services.config_file_service import JailNameError, remove_action_from_jail from app.services.config_file_service import JailNameError, remove_action_from_jail
with pytest.raises(JailNameError): with pytest.raises(JailNameError):
await remove_action_from_jail( await remove_action_from_jail(str(tmp_path), "/fake.sock", "../evil", "iptables")
str(tmp_path), "/fake.sock", "../evil", "iptables"
)
async def test_raises_action_name_error(self, tmp_path: Path) -> None: async def test_raises_action_name_error(self, tmp_path: Path) -> None:
from app.services.config_file_service import ActionNameError, remove_action_from_jail from app.services.config_file_service import ActionNameError, remove_action_from_jail
@@ -2619,9 +2592,7 @@ class TestRemoveActionFromJail:
_write(tmp_path / "jail.conf", JAIL_CONF) _write(tmp_path / "jail.conf", JAIL_CONF)
with pytest.raises(ActionNameError): with pytest.raises(ActionNameError):
await remove_action_from_jail( await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "../evil")
str(tmp_path), "/fake.sock", "sshd", "../evil"
)
async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None: async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None:
from app.services.config_file_service import remove_action_from_jail from app.services.config_file_service import remove_action_from_jail
@@ -2640,9 +2611,7 @@ class TestRemoveActionFromJail:
new=AsyncMock(), new=AsyncMock(),
) as mock_reload, ) as mock_reload,
): ):
await remove_action_from_jail( await remove_action_from_jail(str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True)
str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True
)
mock_reload.assert_awaited_once() mock_reload.assert_awaited_once()
@@ -2680,13 +2649,9 @@ class TestActivateJailReloadArgs:
mock_js.reload_all = AsyncMock() mock_js.reload_all = AsyncMock()
await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req) await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
mock_js.reload_all.assert_awaited_once_with( mock_js.reload_all.assert_awaited_once_with("/fake.sock", include_jails=["apache-auth"])
"/fake.sock", include_jails=["apache-auth"]
)
async def test_activate_returns_active_true_when_jail_starts( async def test_activate_returns_active_true_when_jail_starts(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""activate_jail returns active=True when the jail appears in post-reload names.""" """activate_jail returns active=True when the jail appears in post-reload names."""
_write(tmp_path / "jail.conf", JAIL_CONF) _write(tmp_path / "jail.conf", JAIL_CONF)
from app.models.config import ActivateJailRequest, JailValidationResult from app.models.config import ActivateJailRequest, JailValidationResult
@@ -2708,16 +2673,12 @@ class TestActivateJailReloadArgs:
), ),
): ):
mock_js.reload_all = AsyncMock() mock_js.reload_all = AsyncMock()
result = await activate_jail( result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
str(tmp_path), "/fake.sock", "apache-auth", req
)
assert result.active is True assert result.active is True
assert "activated" in result.message.lower() assert "activated" in result.message.lower()
async def test_activate_returns_active_false_when_jail_does_not_start( async def test_activate_returns_active_false_when_jail_does_not_start(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""activate_jail returns active=False when the jail is absent after reload. """activate_jail returns active=False when the jail is absent after reload.
This covers the Stage 3.1 requirement: if the jail config is invalid This covers the Stage 3.1 requirement: if the jail config is invalid
@@ -2746,9 +2707,7 @@ class TestActivateJailReloadArgs:
), ),
): ):
mock_js.reload_all = AsyncMock() mock_js.reload_all = AsyncMock()
result = await activate_jail( result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
str(tmp_path), "/fake.sock", "apache-auth", req
)
assert result.active is False assert result.active is False
assert "apache-auth" in result.name assert "apache-auth" in result.name
@@ -2776,23 +2735,13 @@ class TestDeactivateJailReloadArgs:
mock_js.reload_all = AsyncMock() mock_js.reload_all = AsyncMock()
await deactivate_jail(str(tmp_path), "/fake.sock", "sshd") await deactivate_jail(str(tmp_path), "/fake.sock", "sshd")
mock_js.reload_all.assert_awaited_once_with( mock_js.reload_all.assert_awaited_once_with("/fake.sock", exclude_jails=["sshd"])
"/fake.sock", exclude_jails=["sshd"]
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _validate_jail_config_sync (Task 3) # _validate_jail_config_sync (Task 3)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
from app.services.config_file_service import ( # noqa: E402 (added after block)
_validate_jail_config_sync,
_extract_filter_base_name,
_extract_action_base_name,
validate_jail_config,
rollback_jail,
)
class TestExtractFilterBaseName: class TestExtractFilterBaseName:
def test_plain_name(self) -> None: def test_plain_name(self) -> None:
@@ -2938,11 +2887,11 @@ class TestRollbackJail:
with ( with (
patch( patch(
"app.services.config_file_service._start_daemon", "app.services.config_file_service.start_daemon",
new=AsyncMock(return_value=True), new=AsyncMock(return_value=True),
), ),
patch( patch(
"app.services.config_file_service._wait_for_fail2ban", "app.services.config_file_service.wait_for_fail2ban",
new=AsyncMock(return_value=True), new=AsyncMock(return_value=True),
), ),
patch( patch(
@@ -2950,9 +2899,7 @@ class TestRollbackJail:
new=AsyncMock(return_value=set()), new=AsyncMock(return_value=set()),
), ),
): ):
result = await rollback_jail( result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
assert result.disabled is True assert result.disabled is True
assert result.fail2ban_running is True assert result.fail2ban_running is True
@@ -2968,26 +2915,22 @@ class TestRollbackJail:
with ( with (
patch( patch(
"app.services.config_file_service._start_daemon", "app.services.config_file_service.start_daemon",
new=AsyncMock(return_value=False), new=AsyncMock(return_value=False),
), ),
patch( patch(
"app.services.config_file_service._wait_for_fail2ban", "app.services.config_file_service.wait_for_fail2ban",
new=AsyncMock(return_value=False), new=AsyncMock(return_value=False),
), ),
): ):
result = await rollback_jail( result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
assert result.fail2ban_running is False assert result.fail2ban_running is False
assert result.disabled is True assert result.disabled is True
async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None: async def test_rollback_rejects_unsafe_name(self, tmp_path: Path) -> None:
with pytest.raises(JailNameError): with pytest.raises(JailNameError):
await rollback_jail( await rollback_jail(str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"])
str(tmp_path), "/fake.sock", "../evil", ["fail2ban-client", "start"]
)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -3096,9 +3039,7 @@ class TestActivateJailBlocking:
class TestActivateJailRollback: class TestActivateJailRollback:
"""Rollback logic in activate_jail restores the .local file and recovers.""" """Rollback logic in activate_jail restores the .local file and recovers."""
async def test_activate_jail_rollback_on_reload_failure( async def test_activate_jail_rollback_on_reload_failure(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""Rollback when reload_all raises on the activation reload. """Rollback when reload_all raises on the activation reload.
Expects: Expects:
@@ -3135,23 +3076,17 @@ class TestActivateJailRollback:
), ),
patch( patch(
"app.services.config_file_service._validate_jail_config_sync", "app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult( return_value=JailValidationResult(jail_name="apache-auth", valid=True),
jail_name="apache-auth", valid=True
),
), ),
): ):
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect) mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
result = await activate_jail( result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
str(tmp_path), "/fake.sock", "apache-auth", req
)
assert result.active is False assert result.active is False
assert result.recovered is True assert result.recovered is True
assert local_path.read_text() == original_local assert local_path.read_text() == original_local
async def test_activate_jail_rollback_on_health_check_failure( async def test_activate_jail_rollback_on_health_check_failure(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""Rollback when fail2ban is unreachable after the activation reload. """Rollback when fail2ban is unreachable after the activation reload.
Expects: Expects:
@@ -3190,15 +3125,11 @@ class TestActivateJailRollback:
), ),
patch( patch(
"app.services.config_file_service._validate_jail_config_sync", "app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult( return_value=JailValidationResult(jail_name="apache-auth", valid=True),
jail_name="apache-auth", valid=True
),
), ),
): ):
mock_js.reload_all = AsyncMock() mock_js.reload_all = AsyncMock()
result = await activate_jail( result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
str(tmp_path), "/fake.sock", "apache-auth", req
)
assert result.active is False assert result.active is False
assert result.recovered is True assert result.recovered is True
@@ -3232,25 +3163,17 @@ class TestActivateJailRollback:
), ),
patch( patch(
"app.services.config_file_service._validate_jail_config_sync", "app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult( return_value=JailValidationResult(jail_name="apache-auth", valid=True),
jail_name="apache-auth", valid=True
),
), ),
): ):
# Both the activation reload and the recovery reload fail. # Both the activation reload and the recovery reload fail.
mock_js.reload_all = AsyncMock( mock_js.reload_all = AsyncMock(side_effect=RuntimeError("fail2ban unavailable"))
side_effect=RuntimeError("fail2ban unavailable") result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
)
result = await activate_jail(
str(tmp_path), "/fake.sock", "apache-auth", req
)
assert result.active is False assert result.active is False
assert result.recovered is False assert result.recovered is False
async def test_activate_jail_rollback_on_jail_not_found_error( async def test_activate_jail_rollback_on_jail_not_found_error(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""Rollback when reload_all raises JailNotFoundError (invalid config). """Rollback when reload_all raises JailNotFoundError (invalid config).
When fail2ban cannot create a jail due to invalid configuration When fail2ban cannot create a jail due to invalid configuration
@@ -3294,16 +3217,12 @@ class TestActivateJailRollback:
), ),
patch( patch(
"app.services.config_file_service._validate_jail_config_sync", "app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult( return_value=JailValidationResult(jail_name="apache-auth", valid=True),
jail_name="apache-auth", valid=True
),
), ),
): ):
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect) mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
mock_js.JailNotFoundError = JailNotFoundError mock_js.JailNotFoundError = JailNotFoundError
result = await activate_jail( result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
str(tmp_path), "/fake.sock", "apache-auth", req
)
assert result.active is False assert result.active is False
assert result.recovered is True assert result.recovered is True
@@ -3311,9 +3230,7 @@ class TestActivateJailRollback:
# Verify the error message mentions logpath issues. # Verify the error message mentions logpath issues.
assert "logpath" in result.message.lower() or "check that all logpath" in result.message.lower() assert "logpath" in result.message.lower() or "check that all logpath" in result.message.lower()
async def test_activate_jail_rollback_deletes_file_when_no_prior_local( async def test_activate_jail_rollback_deletes_file_when_no_prior_local(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""Rollback deletes the .local file when none existed before activation. """Rollback deletes the .local file when none existed before activation.
When a jail had no .local override before activation, activate_jail When a jail had no .local override before activation, activate_jail
@@ -3355,15 +3272,11 @@ class TestActivateJailRollback:
), ),
patch( patch(
"app.services.config_file_service._validate_jail_config_sync", "app.services.config_file_service._validate_jail_config_sync",
return_value=JailValidationResult( return_value=JailValidationResult(jail_name="apache-auth", valid=True),
jail_name="apache-auth", valid=True
),
), ),
): ):
mock_js.reload_all = AsyncMock(side_effect=reload_side_effect) mock_js.reload_all = AsyncMock(side_effect=reload_side_effect)
result = await activate_jail( result = await activate_jail(str(tmp_path), "/fake.sock", "apache-auth", req)
str(tmp_path), "/fake.sock", "apache-auth", req
)
assert result.active is False assert result.active is False
assert result.recovered is True assert result.recovered is True
@@ -3376,7 +3289,7 @@ class TestActivateJailRollback:
@pytest.mark.asyncio @pytest.mark.asyncio
class TestRollbackJail: class TestRollbackJailIntegration:
"""Integration tests for :func:`~app.services.config_file_service.rollback_jail`.""" """Integration tests for :func:`~app.services.config_file_service.rollback_jail`."""
async def test_local_file_written_enabled_false(self, tmp_path: Path) -> None: async def test_local_file_written_enabled_false(self, tmp_path: Path) -> None:
@@ -3419,15 +3332,11 @@ class TestRollbackJail:
AsyncMock(return_value={"other"}), AsyncMock(return_value={"other"}),
), ),
): ):
await rollback_jail( await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
mock_start.assert_awaited_once_with(["fail2ban-client", "start"]) mock_start.assert_awaited_once_with(["fail2ban-client", "start"])
async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit( async def test_fail2ban_running_reflects_socket_probe_not_subprocess_exit(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""fail2ban_running in the response reflects the socket probe result. """fail2ban_running in the response reflects the socket probe result.
Even when start_daemon returns True (subprocess exit 0), if the socket Even when start_daemon returns True (subprocess exit 0), if the socket
@@ -3443,15 +3352,11 @@ class TestRollbackJail:
AsyncMock(return_value=False), # socket still unresponsive AsyncMock(return_value=False), # socket still unresponsive
), ),
): ):
result = await rollback_jail( result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
assert result.fail2ban_running is False assert result.fail2ban_running is False
async def test_active_jails_zero_when_fail2ban_not_running( async def test_active_jails_zero_when_fail2ban_not_running(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""active_jails is 0 in the response when fail2ban_running is False.""" """active_jails is 0 in the response when fail2ban_running is False."""
with ( with (
patch( patch(
@@ -3463,15 +3368,11 @@ class TestRollbackJail:
AsyncMock(return_value=False), AsyncMock(return_value=False),
), ),
): ):
result = await rollback_jail( result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
assert result.active_jails == 0 assert result.active_jails == 0
async def test_active_jails_count_from_socket_when_running( async def test_active_jails_count_from_socket_when_running(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""active_jails reflects the actual jail count from the socket when fail2ban is up.""" """active_jails reflects the actual jail count from the socket when fail2ban is up."""
with ( with (
patch( patch(
@@ -3487,15 +3388,11 @@ class TestRollbackJail:
AsyncMock(return_value={"sshd", "nginx", "apache-auth"}), AsyncMock(return_value={"sshd", "nginx", "apache-auth"}),
), ),
): ):
result = await rollback_jail( result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
assert result.active_jails == 3 assert result.active_jails == 3
async def test_fail2ban_down_at_start_still_succeeds_file_write( async def test_fail2ban_down_at_start_still_succeeds_file_write(self, tmp_path: Path) -> None:
self, tmp_path: Path
) -> None:
"""rollback_jail writes the local file even when fail2ban is down at call time.""" """rollback_jail writes the local file even when fail2ban is down at call time."""
# fail2ban is down: start_daemon fails and wait_for_fail2ban returns False. # fail2ban is down: start_daemon fails and wait_for_fail2ban returns False.
with ( with (
@@ -3508,12 +3405,9 @@ class TestRollbackJail:
AsyncMock(return_value=False), AsyncMock(return_value=False),
), ),
): ):
result = await rollback_jail( result = await rollback_jail(str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"])
str(tmp_path), "/fake.sock", "sshd", ["fail2ban-client", "start"]
)
local = tmp_path / "jail.d" / "sshd.local" local = tmp_path / "jail.d" / "sshd.local"
assert local.is_file(), "local file must be written even when fail2ban is down" assert local.is_file(), "local file must be written even when fail2ban is down"
assert result.disabled is True assert result.disabled is True
assert result.fail2ban_running is False assert result.fail2ban_running is False

View File

@@ -2,6 +2,7 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping, Sequence
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@@ -44,7 +45,7 @@ def _make_session(response_json: dict[str, object], status: int = 200) -> MagicM
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def clear_geo_cache() -> None: # type: ignore[misc] def clear_geo_cache() -> None:
"""Flush the module-level geo cache before every test.""" """Flush the module-level geo cache before every test."""
geo_service.clear_cache() geo_service.clear_cache()
@@ -68,7 +69,7 @@ class TestLookupSuccess:
"org": "AS3320 Deutsche Telekom AG", "org": "AS3320 Deutsche Telekom AG",
} }
) )
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] result = await geo_service.lookup("1.2.3.4", session)
assert result is not None assert result is not None
assert result.country_code == "DE" assert result.country_code == "DE"
@@ -84,7 +85,7 @@ class TestLookupSuccess:
"org": "Google LLC", "org": "Google LLC",
} }
) )
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] result = await geo_service.lookup("8.8.8.8", session)
assert result is not None assert result is not None
assert result.country_name == "United States" assert result.country_name == "United States"
@@ -100,7 +101,7 @@ class TestLookupSuccess:
"org": "Deutsche Telekom", "org": "Deutsche Telekom",
} }
) )
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] result = await geo_service.lookup("1.2.3.4", session)
assert result is not None assert result is not None
assert result.asn == "AS3320" assert result.asn == "AS3320"
@@ -116,7 +117,7 @@ class TestLookupSuccess:
"org": "Google LLC", "org": "Google LLC",
} }
) )
result = await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] result = await geo_service.lookup("8.8.8.8", session)
assert result is not None assert result is not None
assert result.org == "Google LLC" assert result.org == "Google LLC"
@@ -142,8 +143,8 @@ class TestLookupCaching:
} }
) )
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] await geo_service.lookup("1.2.3.4", session)
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] await geo_service.lookup("1.2.3.4", session)
# The session.get() should only have been called once. # The session.get() should only have been called once.
assert session.get.call_count == 1 assert session.get.call_count == 1
@@ -160,9 +161,9 @@ class TestLookupCaching:
} }
) )
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type] await geo_service.lookup("2.3.4.5", session)
geo_service.clear_cache() geo_service.clear_cache()
await geo_service.lookup("2.3.4.5", session) # type: ignore[arg-type] await geo_service.lookup("2.3.4.5", session)
assert session.get.call_count == 2 assert session.get.call_count == 2
@@ -172,8 +173,8 @@ class TestLookupCaching:
{"status": "fail", "message": "reserved range"} {"status": "fail", "message": "reserved range"}
) )
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type] await geo_service.lookup("192.168.1.1", session)
await geo_service.lookup("192.168.1.1", session) # type: ignore[arg-type] await geo_service.lookup("192.168.1.1", session)
# Second call is blocked by the negative cache — only one API hit. # Second call is blocked by the negative cache — only one API hit.
assert session.get.call_count == 1 assert session.get.call_count == 1
@@ -190,7 +191,7 @@ class TestLookupFailures:
async def test_non_200_response_returns_null_geo_info(self) -> None: async def test_non_200_response_returns_null_geo_info(self) -> None:
"""A 429 or 500 status returns GeoInfo with null fields (not None).""" """A 429 or 500 status returns GeoInfo with null fields (not None)."""
session = _make_session({}, status=429) session = _make_session({}, status=429)
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] result = await geo_service.lookup("1.2.3.4", session)
assert result is not None assert result is not None
assert isinstance(result, GeoInfo) assert isinstance(result, GeoInfo)
assert result.country_code is None assert result.country_code is None
@@ -203,7 +204,7 @@ class TestLookupFailures:
mock_ctx.__aexit__ = AsyncMock(return_value=False) mock_ctx.__aexit__ = AsyncMock(return_value=False)
session.get = MagicMock(return_value=mock_ctx) session.get = MagicMock(return_value=mock_ctx)
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] result = await geo_service.lookup("10.0.0.1", session)
assert result is not None assert result is not None
assert isinstance(result, GeoInfo) assert isinstance(result, GeoInfo)
assert result.country_code is None assert result.country_code is None
@@ -211,7 +212,7 @@ class TestLookupFailures:
async def test_failed_status_returns_geo_info_with_nulls(self) -> None: async def test_failed_status_returns_geo_info_with_nulls(self) -> None:
"""When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached).""" """When ip-api returns ``status=fail`` a GeoInfo with null fields is returned (but not cached)."""
session = _make_session({"status": "fail", "message": "private range"}) session = _make_session({"status": "fail", "message": "private range"})
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] result = await geo_service.lookup("10.0.0.1", session)
assert result is not None assert result is not None
assert isinstance(result, GeoInfo) assert isinstance(result, GeoInfo)
@@ -231,8 +232,8 @@ class TestNegativeCache:
"""After a failed lookup the second call is served from the neg cache.""" """After a failed lookup the second call is served from the neg cache."""
session = _make_session({"status": "fail", "message": "private range"}) session = _make_session({"status": "fail", "message": "private range"})
r1 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type] r1 = await geo_service.lookup("192.0.2.1", session)
r2 = await geo_service.lookup("192.0.2.1", session) # type: ignore[arg-type] r2 = await geo_service.lookup("192.0.2.1", session)
# Only one HTTP call should have been made; second served from neg cache. # Only one HTTP call should have been made; second served from neg cache.
assert session.get.call_count == 1 assert session.get.call_count == 1
@@ -243,12 +244,12 @@ class TestNegativeCache:
"""When the neg-cache entry is older than the TTL a new API call is made.""" """When the neg-cache entry is older than the TTL a new API call is made."""
session = _make_session({"status": "fail", "message": "private range"}) session = _make_session({"status": "fail", "message": "private range"})
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type] await geo_service.lookup("192.0.2.2", session)
# Manually expire the neg-cache entry. # Manually expire the neg-cache entry.
geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1 # type: ignore[attr-defined] geo_service._neg_cache["192.0.2.2"] -= geo_service._NEG_CACHE_TTL + 1
await geo_service.lookup("192.0.2.2", session) # type: ignore[arg-type] await geo_service.lookup("192.0.2.2", session)
# Both calls should have hit the API. # Both calls should have hit the API.
assert session.get.call_count == 2 assert session.get.call_count == 2
@@ -257,9 +258,9 @@ class TestNegativeCache:
"""After clearing the neg cache the IP is eligible for a new API call.""" """After clearing the neg cache the IP is eligible for a new API call."""
session = _make_session({"status": "fail", "message": "private range"}) session = _make_session({"status": "fail", "message": "private range"})
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type] await geo_service.lookup("192.0.2.3", session)
geo_service.clear_neg_cache() geo_service.clear_neg_cache()
await geo_service.lookup("192.0.2.3", session) # type: ignore[arg-type] await geo_service.lookup("192.0.2.3", session)
assert session.get.call_count == 2 assert session.get.call_count == 2
@@ -275,9 +276,9 @@ class TestNegativeCache:
} }
) )
await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] await geo_service.lookup("1.2.3.4", session)
assert "1.2.3.4" not in geo_service._neg_cache # type: ignore[attr-defined] assert "1.2.3.4" not in geo_service._neg_cache
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -307,7 +308,7 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("DE", "Germany") mock_reader = self._make_geoip_reader("DE", "Germany")
with patch.object(geo_service, "_geoip_reader", mock_reader): with patch.object(geo_service, "_geoip_reader", mock_reader):
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] result = await geo_service.lookup("1.2.3.4", session)
mock_reader.country.assert_called_once_with("1.2.3.4") mock_reader.country.assert_called_once_with("1.2.3.4")
assert result is not None assert result is not None
@@ -320,12 +321,12 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("US", "United States") mock_reader = self._make_geoip_reader("US", "United States")
with patch.object(geo_service, "_geoip_reader", mock_reader): with patch.object(geo_service, "_geoip_reader", mock_reader):
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] await geo_service.lookup("8.8.8.8", session)
# Second call must be served from positive cache without hitting API. # Second call must be served from positive cache without hitting API.
await geo_service.lookup("8.8.8.8", session) # type: ignore[arg-type] await geo_service.lookup("8.8.8.8", session)
assert session.get.call_count == 1 assert session.get.call_count == 1
assert "8.8.8.8" in geo_service._cache # type: ignore[attr-defined] assert "8.8.8.8" in geo_service._cache
async def test_geoip_fallback_not_called_on_api_success(self) -> None: async def test_geoip_fallback_not_called_on_api_success(self) -> None:
"""When ip-api succeeds, the geoip2 reader must not be consulted.""" """When ip-api succeeds, the geoip2 reader must not be consulted."""
@@ -341,7 +342,7 @@ class TestGeoipFallback:
mock_reader = self._make_geoip_reader("XX", "Nowhere") mock_reader = self._make_geoip_reader("XX", "Nowhere")
with patch.object(geo_service, "_geoip_reader", mock_reader): with patch.object(geo_service, "_geoip_reader", mock_reader):
result = await geo_service.lookup("1.2.3.4", session) # type: ignore[arg-type] result = await geo_service.lookup("1.2.3.4", session)
mock_reader.country.assert_not_called() mock_reader.country.assert_not_called()
assert result is not None assert result is not None
@@ -352,7 +353,7 @@ class TestGeoipFallback:
session = _make_session({"status": "fail", "message": "private range"}) session = _make_session({"status": "fail", "message": "private range"})
with patch.object(geo_service, "_geoip_reader", None): with patch.object(geo_service, "_geoip_reader", None):
result = await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] result = await geo_service.lookup("10.0.0.1", session)
assert result is not None assert result is not None
assert result.country_code is None assert result.country_code is None
@@ -363,7 +364,7 @@ class TestGeoipFallback:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
def _make_batch_session(batch_response: list[dict[str, object]]) -> MagicMock: def _make_batch_session(batch_response: Sequence[Mapping[str, object]]) -> MagicMock:
"""Build a mock aiohttp.ClientSession for batch POST calls. """Build a mock aiohttp.ClientSession for batch POST calls.
Args: Args:
@@ -412,7 +413,7 @@ class TestLookupBatchSingleCommit:
session = _make_batch_session(batch_response) session = _make_batch_session(batch_response)
db = _make_async_db() db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] await geo_service.lookup_batch(ips, session, db=db)
db.commit.assert_awaited_once() db.commit.assert_awaited_once()
@@ -426,7 +427,7 @@ class TestLookupBatchSingleCommit:
session = _make_batch_session(batch_response) session = _make_batch_session(batch_response)
db = _make_async_db() db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] await geo_service.lookup_batch(ips, session, db=db)
db.commit.assert_awaited_once() db.commit.assert_awaited_once()
@@ -452,13 +453,13 @@ class TestLookupBatchSingleCommit:
async def test_no_commit_for_all_cached_ips(self) -> None: async def test_no_commit_for_all_cached_ips(self) -> None:
"""When all IPs are already cached, no HTTP call and no commit occur.""" """When all IPs are already cached, no HTTP call and no commit occur."""
geo_service._cache["5.5.5.5"] = GeoInfo( # type: ignore[attr-defined] geo_service._cache["5.5.5.5"] = GeoInfo(
country_code="FR", country_name="France", asn="AS1", org="ISP" country_code="FR", country_name="France", asn="AS1", org="ISP"
) )
db = _make_async_db() db = _make_async_db()
session = _make_batch_session([]) session = _make_batch_session([])
result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db) # type: ignore[arg-type] result = await geo_service.lookup_batch(["5.5.5.5"], session, db=db)
assert result["5.5.5.5"].country_code == "FR" assert result["5.5.5.5"].country_code == "FR"
db.commit.assert_not_awaited() db.commit.assert_not_awaited()
@@ -476,26 +477,26 @@ class TestDirtySetTracking:
def test_successful_resolution_adds_to_dirty(self) -> None: def test_successful_resolution_adds_to_dirty(self) -> None:
"""Storing a GeoInfo with a country_code adds the IP to _dirty.""" """Storing a GeoInfo with a country_code adds the IP to _dirty."""
info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP") info = GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")
geo_service._store("1.2.3.4", info) # type: ignore[attr-defined] geo_service._store("1.2.3.4", info)
assert "1.2.3.4" in geo_service._dirty # type: ignore[attr-defined] assert "1.2.3.4" in geo_service._dirty
def test_null_country_does_not_add_to_dirty(self) -> None: def test_null_country_does_not_add_to_dirty(self) -> None:
"""Storing a GeoInfo with country_code=None must not pollute _dirty.""" """Storing a GeoInfo with country_code=None must not pollute _dirty."""
info = GeoInfo(country_code=None, country_name=None, asn=None, org=None) info = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
geo_service._store("10.0.0.1", info) # type: ignore[attr-defined] geo_service._store("10.0.0.1", info)
assert "10.0.0.1" not in geo_service._dirty # type: ignore[attr-defined] assert "10.0.0.1" not in geo_service._dirty
def test_clear_cache_also_clears_dirty(self) -> None: def test_clear_cache_also_clears_dirty(self) -> None:
"""clear_cache() must discard any pending dirty entries.""" """clear_cache() must discard any pending dirty entries."""
info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP") info = GeoInfo(country_code="US", country_name="United States", asn="AS1", org="ISP")
geo_service._store("8.8.8.8", info) # type: ignore[attr-defined] geo_service._store("8.8.8.8", info)
assert geo_service._dirty # type: ignore[attr-defined] assert geo_service._dirty
geo_service.clear_cache() geo_service.clear_cache()
assert not geo_service._dirty # type: ignore[attr-defined] assert not geo_service._dirty
async def test_lookup_batch_populates_dirty(self) -> None: async def test_lookup_batch_populates_dirty(self) -> None:
"""After lookup_batch() with db=None, resolved IPs appear in _dirty.""" """After lookup_batch() with db=None, resolved IPs appear in _dirty."""
@@ -509,7 +510,7 @@ class TestDirtySetTracking:
await geo_service.lookup_batch(ips, session, db=None) await geo_service.lookup_batch(ips, session, db=None)
for ip in ips: for ip in ips:
assert ip in geo_service._dirty # type: ignore[attr-defined] assert ip in geo_service._dirty
class TestFlushDirty: class TestFlushDirty:
@@ -518,8 +519,8 @@ class TestFlushDirty:
async def test_flush_writes_and_clears_dirty(self) -> None: async def test_flush_writes_and_clears_dirty(self) -> None:
"""flush_dirty() inserts all dirty IPs and clears _dirty afterwards.""" """flush_dirty() inserts all dirty IPs and clears _dirty afterwards."""
info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT") info = GeoInfo(country_code="GB", country_name="United Kingdom", asn="AS2856", org="BT")
geo_service._store("100.0.0.1", info) # type: ignore[attr-defined] geo_service._store("100.0.0.1", info)
assert "100.0.0.1" in geo_service._dirty # type: ignore[attr-defined] assert "100.0.0.1" in geo_service._dirty
db = _make_async_db() db = _make_async_db()
count = await geo_service.flush_dirty(db) count = await geo_service.flush_dirty(db)
@@ -527,7 +528,7 @@ class TestFlushDirty:
assert count == 1 assert count == 1
db.executemany.assert_awaited_once() db.executemany.assert_awaited_once()
db.commit.assert_awaited_once() db.commit.assert_awaited_once()
assert "100.0.0.1" not in geo_service._dirty # type: ignore[attr-defined] assert "100.0.0.1" not in geo_service._dirty
async def test_flush_returns_zero_when_nothing_dirty(self) -> None: async def test_flush_returns_zero_when_nothing_dirty(self) -> None:
"""flush_dirty() returns 0 and makes no DB calls when _dirty is empty.""" """flush_dirty() returns 0 and makes no DB calls when _dirty is empty."""
@@ -541,7 +542,7 @@ class TestFlushDirty:
async def test_flush_re_adds_to_dirty_on_db_error(self) -> None: async def test_flush_re_adds_to_dirty_on_db_error(self) -> None:
"""When the DB write fails, entries are re-added to _dirty for retry.""" """When the DB write fails, entries are re-added to _dirty for retry."""
info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP") info = GeoInfo(country_code="AU", country_name="Australia", asn="AS1", org="ISP")
geo_service._store("200.0.0.1", info) # type: ignore[attr-defined] geo_service._store("200.0.0.1", info)
db = _make_async_db() db = _make_async_db()
db.executemany = AsyncMock(side_effect=OSError("disk full")) db.executemany = AsyncMock(side_effect=OSError("disk full"))
@@ -549,7 +550,7 @@ class TestFlushDirty:
count = await geo_service.flush_dirty(db) count = await geo_service.flush_dirty(db)
assert count == 0 assert count == 0
assert "200.0.0.1" in geo_service._dirty # type: ignore[attr-defined] assert "200.0.0.1" in geo_service._dirty
async def test_flush_batch_and_lookup_batch_integration(self) -> None: async def test_flush_batch_and_lookup_batch_integration(self) -> None:
"""lookup_batch() populates _dirty; flush_dirty() then persists them.""" """lookup_batch() populates _dirty; flush_dirty() then persists them."""
@@ -562,14 +563,14 @@ class TestFlushDirty:
# Resolve without DB to populate only in-memory cache and _dirty. # Resolve without DB to populate only in-memory cache and _dirty.
await geo_service.lookup_batch(ips, session, db=None) await geo_service.lookup_batch(ips, session, db=None)
assert geo_service._dirty == set(ips) # type: ignore[attr-defined] assert geo_service._dirty == set(ips)
# Now flush to the DB. # Now flush to the DB.
db = _make_async_db() db = _make_async_db()
count = await geo_service.flush_dirty(db) count = await geo_service.flush_dirty(db)
assert count == 2 assert count == 2
assert not geo_service._dirty # type: ignore[attr-defined] assert not geo_service._dirty
db.commit.assert_awaited_once() db.commit.assert_awaited_once()
@@ -585,7 +586,7 @@ class TestLookupBatchThrottling:
"""When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called """When more than _BATCH_SIZE IPs are sent, asyncio.sleep is called
between consecutive batch HTTP calls with at least _BATCH_DELAY.""" between consecutive batch HTTP calls with at least _BATCH_DELAY."""
# Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls. # Generate _BATCH_SIZE + 1 IPs so we get exactly 2 batch calls.
batch_size: int = geo_service._BATCH_SIZE # type: ignore[attr-defined] batch_size: int = geo_service._BATCH_SIZE
ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)] ips = [f"10.0.{i // 256}.{i % 256}" for i in range(batch_size + 1)]
def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]: def _make_result(chunk: list[str], _session: object) -> dict[str, GeoInfo]:
@@ -608,7 +609,7 @@ class TestLookupBatchThrottling:
assert mock_batch.call_count == 2 assert mock_batch.call_count == 2
mock_sleep.assert_awaited_once() mock_sleep.assert_awaited_once()
delay_arg: float = mock_sleep.call_args[0][0] delay_arg: float = mock_sleep.call_args[0][0]
assert delay_arg >= geo_service._BATCH_DELAY # type: ignore[attr-defined] assert delay_arg >= geo_service._BATCH_DELAY
async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None: async def test_lookup_batch_retries_on_full_chunk_failure(self) -> None:
"""When a chunk returns all-None on first try, it retries and succeeds.""" """When a chunk returns all-None on first try, it retries and succeeds."""
@@ -650,7 +651,7 @@ class TestLookupBatchThrottling:
_empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None) _empty = GeoInfo(country_code=None, country_name=None, asn=None, org=None)
_failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty) _failure: dict[str, GeoInfo] = dict.fromkeys(ips, _empty)
max_retries: int = geo_service._BATCH_MAX_RETRIES # type: ignore[attr-defined] max_retries: int = geo_service._BATCH_MAX_RETRIES
with ( with (
patch( patch(
@@ -667,11 +668,11 @@ class TestLookupBatchThrottling:
# IP should have no country. # IP should have no country.
assert result["9.9.9.9"].country_code is None assert result["9.9.9.9"].country_code is None
# Negative cache should contain the IP. # Negative cache should contain the IP.
assert "9.9.9.9" in geo_service._neg_cache # type: ignore[attr-defined] assert "9.9.9.9" in geo_service._neg_cache
# Sleep called for each retry with exponential backoff. # Sleep called for each retry with exponential backoff.
assert mock_sleep.call_count == max_retries assert mock_sleep.call_count == max_retries
backoff_values = [call.args[0] for call in mock_sleep.call_args_list] backoff_values = [call.args[0] for call in mock_sleep.call_args_list]
batch_delay: float = geo_service._BATCH_DELAY # type: ignore[attr-defined] batch_delay: float = geo_service._BATCH_DELAY
for i, val in enumerate(backoff_values): for i, val in enumerate(backoff_values):
expected = batch_delay * (2 ** (i + 1)) expected = batch_delay * (2 ** (i + 1))
assert val == pytest.approx(expected) assert val == pytest.approx(expected)
@@ -709,7 +710,7 @@ class TestErrorLogging:
import structlog.testing import structlog.testing
with structlog.testing.capture_logs() as captured: with structlog.testing.capture_logs() as captured:
result = await geo_service.lookup("197.221.98.153", session) # type: ignore[arg-type] result = await geo_service.lookup("197.221.98.153", session)
assert result is not None assert result is not None
assert result.country_code is None assert result.country_code is None
@@ -733,7 +734,7 @@ class TestErrorLogging:
import structlog.testing import structlog.testing
with structlog.testing.capture_logs() as captured: with structlog.testing.capture_logs() as captured:
await geo_service.lookup("10.0.0.1", session) # type: ignore[arg-type] await geo_service.lookup("10.0.0.1", session)
request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"] request_failed = [e for e in captured if e.get("event") == "geo_lookup_request_failed"]
assert len(request_failed) == 1 assert len(request_failed) == 1
@@ -757,7 +758,7 @@ class TestErrorLogging:
import structlog.testing import structlog.testing
with structlog.testing.capture_logs() as captured: with structlog.testing.capture_logs() as captured:
result = await geo_service._batch_api_call(["1.2.3.4"], session) # type: ignore[attr-defined] result = await geo_service._batch_api_call(["1.2.3.4"], session)
assert result["1.2.3.4"].country_code is None assert result["1.2.3.4"].country_code is None
@@ -778,7 +779,7 @@ class TestLookupCachedOnly:
def test_returns_cached_ips(self) -> None: def test_returns_cached_ips(self) -> None:
"""IPs already in the cache are returned in the geo_map.""" """IPs already in the cache are returned in the geo_map."""
geo_service._cache["1.1.1.1"] = GeoInfo( # type: ignore[attr-defined] geo_service._cache["1.1.1.1"] = GeoInfo(
country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare" country_code="AU", country_name="Australia", asn="AS13335", org="Cloudflare"
) )
geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"]) geo_map, uncached = geo_service.lookup_cached_only(["1.1.1.1"])
@@ -798,7 +799,7 @@ class TestLookupCachedOnly:
"""IPs in the negative cache within TTL are not re-queued as uncached.""" """IPs in the negative cache within TTL are not re-queued as uncached."""
import time import time
geo_service._neg_cache["10.0.0.1"] = time.monotonic() # type: ignore[attr-defined] geo_service._neg_cache["10.0.0.1"] = time.monotonic()
geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"]) geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.1"])
@@ -807,7 +808,7 @@ class TestLookupCachedOnly:
def test_expired_neg_cache_requeued(self) -> None: def test_expired_neg_cache_requeued(self) -> None:
"""IPs whose neg-cache entry has expired are listed as uncached.""" """IPs whose neg-cache entry has expired are listed as uncached."""
geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired # type: ignore[attr-defined] geo_service._neg_cache["10.0.0.2"] = 0.0 # epoch 0 → expired
_geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"]) _geo_map, uncached = geo_service.lookup_cached_only(["10.0.0.2"])
@@ -815,12 +816,12 @@ class TestLookupCachedOnly:
def test_mixed_ips(self) -> None: def test_mixed_ips(self) -> None:
"""A mix of cached, neg-cached, and unknown IPs is split correctly.""" """A mix of cached, neg-cached, and unknown IPs is split correctly."""
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined] geo_service._cache["1.2.3.4"] = GeoInfo(
country_code="DE", country_name="Germany", asn=None, org=None country_code="DE", country_name="Germany", asn=None, org=None
) )
import time import time
geo_service._neg_cache["5.5.5.5"] = time.monotonic() # type: ignore[attr-defined] geo_service._neg_cache["5.5.5.5"] = time.monotonic()
geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"]) geo_map, uncached = geo_service.lookup_cached_only(["1.2.3.4", "5.5.5.5", "9.9.9.9"])
@@ -829,7 +830,7 @@ class TestLookupCachedOnly:
def test_deduplication(self) -> None: def test_deduplication(self) -> None:
"""Duplicate IPs in the input appear at most once in the output.""" """Duplicate IPs in the input appear at most once in the output."""
geo_service._cache["1.2.3.4"] = GeoInfo( # type: ignore[attr-defined] geo_service._cache["1.2.3.4"] = GeoInfo(
country_code="US", country_name="United States", asn=None, org=None country_code="US", country_name="United States", asn=None, org=None
) )
@@ -866,7 +867,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response) session = _make_batch_session(batch_response)
db = _make_async_db() db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] await geo_service.lookup_batch(ips, session, db=db)
# One executemany for the positive rows. # One executemany for the positive rows.
assert db.executemany.await_count >= 1 assert db.executemany.await_count >= 1
@@ -883,7 +884,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response) session = _make_batch_session(batch_response)
db = _make_async_db() db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] await geo_service.lookup_batch(ips, session, db=db)
assert db.executemany.await_count >= 1 assert db.executemany.await_count >= 1
db.execute.assert_not_awaited() db.execute.assert_not_awaited()
@@ -905,7 +906,7 @@ class TestLookupBatchBulkWrites:
session = _make_batch_session(batch_response) session = _make_batch_session(batch_response)
db = _make_async_db() db = _make_async_db()
await geo_service.lookup_batch(ips, session, db=db) # type: ignore[arg-type] await geo_service.lookup_batch(ips, session, db=db)
# One executemany for positives, one for negatives. # One executemany for positives, one for negatives.
assert db.executemany.await_count == 2 assert db.executemany.await_count == 2

View File

@@ -64,7 +64,7 @@ async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
@pytest.fixture @pytest.fixture
async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc] async def f2b_db_path(tmp_path: Path) -> str:
"""Return the path to a test fail2ban SQLite database.""" """Return the path to a test fail2ban SQLite database."""
path = str(tmp_path / "fail2ban_test.sqlite3") path = str(tmp_path / "fail2ban_test.sqlite3")
await _create_f2b_db( await _create_f2b_db(

View File

@@ -996,9 +996,6 @@ class TestGetJailBannedIps:
async def test_unknown_jail_raises_jail_not_found_error(self) -> None: async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
"""get_jail_banned_ips raises JailNotFoundError for unknown jail.""" """get_jail_banned_ips raises JailNotFoundError for unknown jail."""
responses = {
"status|ghost|short": (0, pytest.raises), # will be overridden
}
# Simulate fail2ban returning an "unknown jail" error. # Simulate fail2ban returning an "unknown jail" error.
class _FakeClient: class _FakeClient:
def __init__(self, **_kw: Any) -> None: def __init__(self, **_kw: Any) -> None:

View File

@@ -270,7 +270,7 @@ class TestCrashDetection:
async def test_crash_within_window_creates_pending_recovery(self) -> None: async def test_crash_within_window_creates_pending_recovery(self) -> None:
"""An online→offline transition within 60 s of activation must set pending_recovery.""" """An online→offline transition within 60 s of activation must set pending_recovery."""
app = _make_app(prev_online=True) app = _make_app(prev_online=True)
now = datetime.datetime.now(tz=datetime.timezone.utc) now = datetime.datetime.now(tz=datetime.UTC)
app.state.last_activation = { app.state.last_activation = {
"jail_name": "sshd", "jail_name": "sshd",
"at": now - datetime.timedelta(seconds=10), "at": now - datetime.timedelta(seconds=10),
@@ -297,7 +297,7 @@ class TestCrashDetection:
app = _make_app(prev_online=True) app = _make_app(prev_online=True)
app.state.last_activation = { app.state.last_activation = {
"jail_name": "sshd", "jail_name": "sshd",
"at": datetime.datetime.now(tz=datetime.timezone.utc) "at": datetime.datetime.now(tz=datetime.UTC)
- datetime.timedelta(seconds=120), - datetime.timedelta(seconds=120),
} }
app.state.pending_recovery = None app.state.pending_recovery = None
@@ -315,8 +315,8 @@ class TestCrashDetection:
async def test_came_online_marks_pending_recovery_resolved(self) -> None: async def test_came_online_marks_pending_recovery_resolved(self) -> None:
"""An offline→online transition must mark an existing pending_recovery as recovered.""" """An offline→online transition must mark an existing pending_recovery as recovered."""
app = _make_app(prev_online=False) app = _make_app(prev_online=False)
activated_at = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(seconds=30) activated_at = datetime.datetime.now(tz=datetime.UTC) - datetime.timedelta(seconds=30)
detected_at = datetime.datetime.now(tz=datetime.timezone.utc) detected_at = datetime.datetime.now(tz=datetime.UTC)
app.state.pending_recovery = PendingRecovery( app.state.pending_recovery = PendingRecovery(
jail_name="sshd", jail_name="sshd",
activated_at=activated_at, activated_at=activated_at,