refactor: improve backend type safety and import organization
- Add TYPE_CHECKING guards for runtime-expensive imports (aiohttp, aiosqlite) - Reorganize imports to follow PEP 8 conventions - Convert TypeAlias to modern PEP 695 type syntax (where appropriate) - Use Sequence/Mapping from collections.abc for type hints (covariant) - Replace string literals with cast() for improved type inference - Fix casting of Fail2BanResponse and TypedDict patterns - Add IpLookupResult TypedDict for precise return type annotation - Reformat overlong lines for readability (120 char limit) - Add asyncio_mode and filterwarnings to pytest config - Update test fixtures with improved type hints This improves mypy type checking and makes type relationships explicit.
This commit is contained in:
@@ -85,4 +85,4 @@ def get_settings() -> Settings:
|
||||
A validated :class:`Settings` object. Raises :class:`pydantic.ValidationError`
|
||||
if required keys are absent or values fail validation.
|
||||
"""
|
||||
return Settings() # pydantic-settings populates required fields from env vars
|
||||
return Settings() # type: ignore[call-arg] # pydantic-settings populates required fields from env vars
|
||||
|
||||
@@ -92,7 +92,7 @@ async def get_settings(request: Request) -> Settings:
|
||||
Returns:
|
||||
The application settings loaded at startup.
|
||||
"""
|
||||
state = cast(AppState, request.app.state)
|
||||
state = cast("AppState", request.app.state)
|
||||
return state.settings
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from __future__ import annotations
|
||||
from typing import TYPE_CHECKING, TypedDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
||||
@@ -112,7 +114,7 @@ async def upsert_neg_entry(db: aiosqlite.Connection, ip: str) -> None:
|
||||
|
||||
async def bulk_upsert_entries(
|
||||
db: aiosqlite.Connection,
|
||||
rows: list[tuple[str, str | None, str | None, str | None, str | None]],
|
||||
rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
|
||||
) -> int:
|
||||
"""Bulk insert or update multiple geo cache entries."""
|
||||
if not rows:
|
||||
|
||||
@@ -8,10 +8,11 @@ table. All methods are plain async functions that accept a
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Mapping
|
||||
|
||||
import aiosqlite
|
||||
|
||||
|
||||
@@ -165,5 +166,5 @@ def _row_to_dict(row: object) -> ImportLogRow:
|
||||
Returns:
|
||||
Dict mapping column names to Python values.
|
||||
"""
|
||||
mapping = cast(Mapping[str, object], row)
|
||||
return cast(ImportLogRow, dict(mapping))
|
||||
mapping = cast("Mapping[str, object]", row)
|
||||
return cast("ImportLogRow", dict(mapping))
|
||||
|
||||
@@ -44,8 +44,6 @@ import structlog
|
||||
from fastapi import APIRouter, HTTPException, Path, Query, Request, status
|
||||
|
||||
from app.dependencies import AuthDep
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
from app.models.config import (
|
||||
ActionConfig,
|
||||
ActionCreateRequest,
|
||||
@@ -104,6 +102,8 @@ from app.services.jail_service import JailOperationError
|
||||
from app.tasks.health_check import _run_probe
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/config", tags=["Config"])
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -428,9 +428,7 @@ async def restart_fail2ban(
|
||||
await config_file_service.start_daemon(start_cmd_parts)
|
||||
|
||||
# Step 3: probe the socket until fail2ban is responsive or the budget expires.
|
||||
fail2ban_running: bool = await config_file_service.wait_for_fail2ban(
|
||||
socket_path, max_wait_seconds=10.0
|
||||
)
|
||||
fail2ban_running: bool = await config_file_service.wait_for_fail2ban(socket_path, max_wait_seconds=10.0)
|
||||
if not fail2ban_running:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
@@ -604,9 +602,7 @@ async def get_map_color_thresholds(
|
||||
"""
|
||||
from app.services import setup_service
|
||||
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(
|
||||
request.app.state.db
|
||||
)
|
||||
high, medium, low = await setup_service.get_map_color_thresholds(request.app.state.db)
|
||||
return MapColorThresholdsResponse(
|
||||
threshold_high=high,
|
||||
threshold_medium=medium,
|
||||
@@ -696,9 +692,7 @@ async def activate_jail(
|
||||
req = body if body is not None else ActivateJailRequest()
|
||||
|
||||
try:
|
||||
result = await config_file_service.activate_jail(
|
||||
config_dir, socket_path, name, req
|
||||
)
|
||||
result = await config_file_service.activate_jail(config_dir, socket_path, name, req)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -831,9 +825,7 @@ async def delete_jail_local_override(
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
|
||||
try:
|
||||
await config_file_service.delete_jail_local_override(
|
||||
config_dir, socket_path, name
|
||||
)
|
||||
await config_file_service.delete_jail_local_override(config_dir, socket_path, name)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -952,9 +944,7 @@ async def rollback_jail(
|
||||
start_cmd_parts: list[str] = start_cmd.split()
|
||||
|
||||
try:
|
||||
result = await config_file_service.rollback_jail(
|
||||
config_dir, socket_path, name, start_cmd_parts
|
||||
)
|
||||
result = await config_file_service.rollback_jail(config_dir, socket_path, name, start_cmd_parts)
|
||||
except JailNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ConfigWriteError as exc:
|
||||
@@ -1107,9 +1097,7 @@ async def update_filter(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.update_filter(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
return await config_file_service.update_filter(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except FilterNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except FilterNotFoundError:
|
||||
@@ -1159,9 +1147,7 @@ async def create_filter(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.create_filter(
|
||||
config_dir, socket_path, body, do_reload=reload
|
||||
)
|
||||
return await config_file_service.create_filter(config_dir, socket_path, body, do_reload=reload)
|
||||
except FilterNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except FilterAlreadyExistsError as exc:
|
||||
@@ -1257,9 +1243,7 @@ async def assign_filter_to_jail(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await config_file_service.assign_filter_to_jail(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
await config_file_service.assign_filter_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except (JailNameError, FilterNameError) as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -1403,9 +1387,7 @@ async def update_action(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.update_action(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
return await config_file_service.update_action(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except ActionNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ActionNotFoundError:
|
||||
@@ -1451,9 +1433,7 @@ async def create_action(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
return await config_file_service.create_action(
|
||||
config_dir, socket_path, body, do_reload=reload
|
||||
)
|
||||
return await config_file_service.create_action(config_dir, socket_path, body, do_reload=reload)
|
||||
except ActionNameError as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except ActionAlreadyExistsError as exc:
|
||||
@@ -1546,9 +1526,7 @@ async def assign_action_to_jail(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await config_file_service.assign_action_to_jail(
|
||||
config_dir, socket_path, name, body, do_reload=reload
|
||||
)
|
||||
await config_file_service.assign_action_to_jail(config_dir, socket_path, name, body, do_reload=reload)
|
||||
except (JailNameError, ActionNameError) as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -1597,9 +1575,7 @@ async def remove_action_from_jail(
|
||||
config_dir: str = request.app.state.settings.fail2ban_config_dir
|
||||
socket_path: str = request.app.state.settings.fail2ban_socket
|
||||
try:
|
||||
await config_file_service.remove_action_from_jail(
|
||||
config_dir, socket_path, name, action_name, do_reload=reload
|
||||
)
|
||||
await config_file_service.remove_action_from_jail(config_dir, socket_path, name, action_name, do_reload=reload)
|
||||
except (JailNameError, ActionNameError) as exc:
|
||||
raise _bad_request(str(exc)) from exc
|
||||
except JailNotFoundInConfigError:
|
||||
@@ -1689,4 +1665,3 @@ async def get_service_status(
|
||||
return await config_service.get_service_status(socket_path)
|
||||
except Fail2BanConnectionError as exc:
|
||||
raise _bad_gateway(exc) from exc
|
||||
|
||||
|
||||
@@ -13,12 +13,15 @@ from typing import TYPE_CHECKING, Annotated
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
from app.services.jail_service import IpLookupResult
|
||||
|
||||
import aiosqlite
|
||||
from fastapi import APIRouter, Depends, HTTPException, Path, Request, status
|
||||
|
||||
from app.dependencies import AuthDep, get_db
|
||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, IpLookupResponse
|
||||
from app.services import geo_service, jail_service
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
router: APIRouter = APIRouter(prefix="/api/geo", tags=["Geo"])
|
||||
@@ -61,7 +64,7 @@ async def lookup_ip(
|
||||
return await geo_service.lookup(addr, http_session)
|
||||
|
||||
try:
|
||||
result = await jail_service.lookup_ip(
|
||||
result: IpLookupResult = await jail_service.lookup_ip(
|
||||
socket_path,
|
||||
ip,
|
||||
geo_enricher=_enricher,
|
||||
@@ -77,9 +80,9 @@ async def lookup_ip(
|
||||
detail=f"Cannot reach fail2ban: {exc}",
|
||||
) from exc
|
||||
|
||||
raw_geo = result.get("geo")
|
||||
raw_geo = result["geo"]
|
||||
geo_detail: GeoDetail | None = None
|
||||
if raw_geo is not None:
|
||||
if isinstance(raw_geo, GeoInfo):
|
||||
geo_detail = GeoDetail(
|
||||
country_code=raw_geo.country_code,
|
||||
country_name=raw_geo.country_name,
|
||||
|
||||
@@ -14,17 +14,11 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import asdict
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, TypeAlias
|
||||
from typing import TYPE_CHECKING, cast
|
||||
|
||||
import structlog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
from app.models.ban import (
|
||||
BLOCKLIST_JAIL,
|
||||
BUCKET_SECONDS,
|
||||
@@ -37,20 +31,25 @@ from app.models.ban import (
|
||||
BanTrendResponse,
|
||||
DashboardBanItem,
|
||||
DashboardBanListResponse,
|
||||
JailBanCount as JailBanCountModel,
|
||||
TimeRange,
|
||||
_derive_origin,
|
||||
bucket_count,
|
||||
)
|
||||
from app.models.ban import (
|
||||
JailBanCount as JailBanCountModel,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.utils.fail2ban_client import Fail2BanClient
|
||||
from app.utils.fail2ban_client import Fail2BanClient, Fail2BanResponse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo"] | None]
|
||||
type GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -137,7 +136,7 @@ async def _get_fail2ban_db_path(socket_path: str) -> str:
|
||||
response = await client.send(["get", "dbfile"])
|
||||
|
||||
try:
|
||||
code, data = response
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise RuntimeError(f"Unexpected response from fail2ban: {response!r}") from exc
|
||||
|
||||
@@ -276,7 +275,7 @@ async def list_bans(
|
||||
# Batch-resolve geo data for all IPs on this page in a single API call.
|
||||
# This avoids hitting the 45 req/min single-IP rate limit when the
|
||||
# page contains many bans (e.g. after a large blocklist import).
|
||||
geo_map: dict[str, "GeoInfo"] = {}
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
if http_session is not None and rows:
|
||||
page_ips: list[str] = [r.ip for r in rows]
|
||||
try:
|
||||
@@ -428,7 +427,7 @@ async def bans_by_country(
|
||||
)
|
||||
|
||||
unique_ips: list[str] = [r.ip for r in agg_rows]
|
||||
geo_map: dict[str, "GeoInfo"] = {}
|
||||
geo_map: dict[str, GeoInfo] = {}
|
||||
|
||||
if http_session is not None and unique_ips:
|
||||
# Serve only what is already in the in-memory cache — no API calls on
|
||||
@@ -449,7 +448,7 @@ async def bans_by_country(
|
||||
)
|
||||
elif geo_enricher is not None and unique_ips:
|
||||
# Fallback: legacy per-IP enricher (used in tests / older callers).
|
||||
async def _safe_lookup(ip: str) -> tuple[str, "GeoInfo" | None]:
|
||||
async def _safe_lookup(ip: str) -> tuple[str, GeoInfo | None]:
|
||||
try:
|
||||
return ip, await geo_enricher(ip)
|
||||
except Exception: # noqa: BLE001
|
||||
@@ -636,9 +635,7 @@ async def bans_by_jail(
|
||||
# has *any* rows and log a warning with min/max timeofban so operators can
|
||||
# diagnose timezone or filter mismatches from logs.
|
||||
if total == 0:
|
||||
table_row_count, min_timeofban, max_timeofban = (
|
||||
await fail2ban_db_repo.get_bans_table_summary(db_path)
|
||||
)
|
||||
table_row_count, min_timeofban, max_timeofban = await fail2ban_db_repo.get_bans_table_summary(db_path)
|
||||
if table_row_count > 0:
|
||||
log.warning(
|
||||
"ban_service_bans_by_jail_empty_despite_data",
|
||||
|
||||
@@ -542,7 +542,7 @@ async def list_import_logs(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _aiohttp_timeout(seconds: float) -> "aiohttp.ClientTimeout":
|
||||
def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout:
|
||||
"""Return an :class:`aiohttp.ClientTimeout` with the given total timeout.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -28,7 +28,7 @@ import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, cast, TypeAlias
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -59,7 +59,6 @@ from app.services.jail_service import JailNotFoundError as JailNotFoundError
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanCommand,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanResponse,
|
||||
)
|
||||
@@ -73,9 +72,7 @@ log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
_SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
# Allowlist pattern for jail names used in path construction.
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(
|
||||
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
||||
)
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
# Sections that are not jail definitions.
|
||||
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
||||
@@ -167,8 +164,7 @@ class FilterReadonlyError(Exception):
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(
|
||||
f"Filter {name!r} is a shipped default (.conf only); "
|
||||
"only user-created .local files can be deleted."
|
||||
f"Filter {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||
)
|
||||
|
||||
|
||||
@@ -423,9 +419,7 @@ def _parse_jails_sync(
|
||||
# items() merges DEFAULT values automatically.
|
||||
jails[section] = dict(parser.items(section))
|
||||
except configparser.Error as exc:
|
||||
log.warning(
|
||||
"jail_section_parse_error", section=section, error=str(exc)
|
||||
)
|
||||
log.warning("jail_section_parse_error", section=section, error=str(exc))
|
||||
|
||||
log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
|
||||
return jails, source_files
|
||||
@@ -522,11 +516,7 @@ def _build_inactive_jail(
|
||||
bantime_escalation=bantime_escalation,
|
||||
source_file=source_file,
|
||||
enabled=enabled,
|
||||
has_local_override=(
|
||||
(config_dir / "jail.d" / f"{name}.local").is_file()
|
||||
if config_dir is not None
|
||||
else False
|
||||
),
|
||||
has_local_override=((config_dir / "jail.d" / f"{name}.local").is_file() if config_dir is not None else False),
|
||||
)
|
||||
|
||||
|
||||
@@ -557,7 +547,7 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
||||
return result
|
||||
|
||||
def _ok(response: object) -> object:
|
||||
code, data = cast(Fail2BanResponse, response)
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
if code != 0:
|
||||
raise ValueError(f"fail2ban error {code}: {data!r}")
|
||||
return data
|
||||
@@ -572,9 +562,7 @@ async def _get_active_jail_names(socket_path: str) -> set[str]:
|
||||
log.warning("fail2ban_unreachable_during_inactive_list")
|
||||
return set()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"fail2ban_status_error_during_inactive_list", error=str(exc)
|
||||
)
|
||||
log.warning("fail2ban_status_error_during_inactive_list", error=str(exc))
|
||||
return set()
|
||||
|
||||
|
||||
@@ -662,10 +650,7 @@ def _validate_jail_config_sync(
|
||||
issues.append(
|
||||
JailValidationIssue(
|
||||
field="filter",
|
||||
message=(
|
||||
f"Filter file not found: filter.d/{base_filter}.conf"
|
||||
" (or .local)"
|
||||
),
|
||||
message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -681,10 +666,7 @@ def _validate_jail_config_sync(
|
||||
issues.append(
|
||||
JailValidationIssue(
|
||||
field="action",
|
||||
message=(
|
||||
f"Action file not found: action.d/{action_name}.conf"
|
||||
" (or .local)"
|
||||
),
|
||||
message=(f"Action file not found: action.d/{action_name}.conf (or .local)"),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -840,9 +822,7 @@ def _write_local_override_sync(
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create jail.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
@@ -867,7 +847,7 @@ def _write_local_override_sync(
|
||||
if overrides.get("port") is not None:
|
||||
lines.append(f"port = {overrides['port']}")
|
||||
if overrides.get("logpath"):
|
||||
paths: list[str] = cast(list[str], overrides["logpath"])
|
||||
paths: list[str] = cast("list[str]", overrides["logpath"])
|
||||
if paths:
|
||||
lines.append(f"logpath = {paths[0]}")
|
||||
for p in paths[1:]:
|
||||
@@ -890,9 +870,7 @@ def _write_local_override_sync(
|
||||
# Clean up temp file if rename failed.
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821 — only reachable when tmp_name is set
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_written",
|
||||
@@ -921,9 +899,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
|
||||
try:
|
||||
local_path.unlink(missing_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Failed to delete {local_path} during rollback: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to delete {local_path} during rollback: {exc}") from exc
|
||||
return
|
||||
|
||||
tmp_name: str | None = None
|
||||
@@ -941,9 +917,7 @@ def _restore_local_file_sync(local_path: Path, original_content: bytes | None) -
|
||||
with contextlib.suppress(OSError):
|
||||
if tmp_name is not None:
|
||||
os.unlink(tmp_name)
|
||||
raise ConfigWriteError(
|
||||
f"Failed to restore {local_path} during rollback: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to restore {local_path} during rollback: {exc}") from exc
|
||||
|
||||
|
||||
def _validate_regex_patterns(patterns: list[str]) -> None:
|
||||
@@ -979,9 +953,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
||||
try:
|
||||
filter_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create filter.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create filter.d directory: {exc}") from exc
|
||||
|
||||
local_path = filter_d / f"{name}.local"
|
||||
try:
|
||||
@@ -998,9 +970,7 @@ def _write_filter_local_sync(filter_d: Path, name: str, content: str) -> None:
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info("filter_local_written", filter=name, path=str(local_path))
|
||||
|
||||
@@ -1031,9 +1001,7 @@ def _set_jail_local_key_sync(
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create jail.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
@@ -1072,9 +1040,7 @@ def _set_jail_local_key_sync(
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_key_set",
|
||||
@@ -1112,8 +1078,8 @@ async def list_inactive_jails(
|
||||
inactive jails.
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = (
|
||||
await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
parsed_result: tuple[dict[str, dict[str, str]], dict[str, str]] = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, source_files = parsed_result
|
||||
active_names: set[str] = await _get_active_jail_names(socket_path)
|
||||
@@ -1170,9 +1136,7 @@ async def activate_jail(
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
all_jails, _source_files = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
|
||||
if name not in all_jails:
|
||||
raise JailNotFoundInConfigError(name)
|
||||
@@ -1208,10 +1172,7 @@ async def activate_jail(
|
||||
active=False,
|
||||
fail2ban_running=True,
|
||||
validation_warnings=warnings,
|
||||
message=(
|
||||
f"Jail {name!r} cannot be activated: "
|
||||
+ "; ".join(i.message for i in blocking)
|
||||
),
|
||||
message=(f"Jail {name!r} cannot be activated: " + "; ".join(i.message for i in blocking)),
|
||||
)
|
||||
|
||||
overrides: dict[str, object] = {
|
||||
@@ -1254,9 +1215,7 @@ async def activate_jail(
|
||||
jail=name,
|
||||
error=str(exc),
|
||||
)
|
||||
recovered = await _rollback_activation_async(
|
||||
config_dir, name, socket_path, original_content
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
@@ -1272,9 +1231,7 @@ async def activate_jail(
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("reload_after_activate_failed", jail=name, error=str(exc))
|
||||
recovered = await _rollback_activation_async(
|
||||
config_dir, name, socket_path, original_content
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
@@ -1305,9 +1262,7 @@ async def activate_jail(
|
||||
jail=name,
|
||||
message="fail2ban socket unreachable after reload — initiating rollback.",
|
||||
)
|
||||
recovered = await _rollback_activation_async(
|
||||
config_dir, name, socket_path, original_content
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
@@ -1330,9 +1285,7 @@ async def activate_jail(
|
||||
jail=name,
|
||||
message="Jail did not appear in running jails — initiating rollback.",
|
||||
)
|
||||
recovered = await _rollback_activation_async(
|
||||
config_dir, name, socket_path, original_content
|
||||
)
|
||||
recovered = await _rollback_activation_async(config_dir, name, socket_path, original_content)
|
||||
return JailActivationResponse(
|
||||
name=name,
|
||||
active=False,
|
||||
@@ -1388,14 +1341,10 @@ async def _rollback_activation_async(
|
||||
|
||||
# Step 1 — restore original file (or delete it).
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None, _restore_local_file_sync, local_path, original_content
|
||||
)
|
||||
await loop.run_in_executor(None, _restore_local_file_sync, local_path, original_content)
|
||||
log.info("jail_activation_rollback_file_restored", jail=name)
|
||||
except ConfigWriteError as exc:
|
||||
log.error(
|
||||
"jail_activation_rollback_restore_failed", jail=name, error=str(exc)
|
||||
)
|
||||
log.error("jail_activation_rollback_restore_failed", jail=name, error=str(exc))
|
||||
return False
|
||||
|
||||
# Step 2 — reload fail2ban with the restored config.
|
||||
@@ -1403,9 +1352,7 @@ async def _rollback_activation_async(
|
||||
await jail_service.reload_all(socket_path)
|
||||
log.info("jail_activation_rollback_reload_ok", jail=name)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning(
|
||||
"jail_activation_rollback_reload_failed", jail=name, error=str(exc)
|
||||
)
|
||||
log.warning("jail_activation_rollback_reload_failed", jail=name, error=str(exc))
|
||||
return False
|
||||
|
||||
# Step 3 — wait for fail2ban to come back.
|
||||
@@ -1450,9 +1397,7 @@ async def deactivate_jail(
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
all_jails, _source_files = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
|
||||
if name not in all_jails:
|
||||
raise JailNotFoundInConfigError(name)
|
||||
@@ -1510,9 +1455,7 @@ async def delete_jail_local_override(
|
||||
_safe_jail_name(name)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
all_jails, _source_files = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _source_files = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
|
||||
if name not in all_jails:
|
||||
raise JailNotFoundInConfigError(name)
|
||||
@@ -1523,13 +1466,9 @@ async def delete_jail_local_override(
|
||||
|
||||
local_path = Path(config_dir) / "jail.d" / f"{name}.local"
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None, lambda: local_path.unlink(missing_ok=True)
|
||||
)
|
||||
await loop.run_in_executor(None, lambda: local_path.unlink(missing_ok=True))
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Failed to delete {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||
|
||||
log.info("jail_local_override_deleted", jail=name, path=str(local_path))
|
||||
|
||||
@@ -1610,9 +1549,7 @@ async def rollback_jail(
|
||||
log.info("jail_rollback_start_attempted", jail=name, start_ok=started)
|
||||
|
||||
# Wait for the socket to come back.
|
||||
fail2ban_running = await wait_for_fail2ban(
|
||||
socket_path, max_wait_seconds=10.0, poll_interval=2.0
|
||||
)
|
||||
fail2ban_running = await wait_for_fail2ban(socket_path, max_wait_seconds=10.0, poll_interval=2.0)
|
||||
|
||||
active_jails = 0
|
||||
if fail2ban_running:
|
||||
@@ -1626,10 +1563,7 @@ async def rollback_jail(
|
||||
disabled=True,
|
||||
fail2ban_running=True,
|
||||
active_jails=active_jails,
|
||||
message=(
|
||||
f"Jail {name!r} disabled and fail2ban restarted successfully "
|
||||
f"with {active_jails} active jail(s)."
|
||||
),
|
||||
message=(f"Jail {name!r} disabled and fail2ban restarted successfully with {active_jails} active jail(s)."),
|
||||
)
|
||||
|
||||
log.warning("jail_rollback_fail2ban_still_down", jail=name)
|
||||
@@ -1650,9 +1584,7 @@ async def rollback_jail(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Allowlist pattern for filter names used in path construction.
|
||||
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(
|
||||
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
||||
)
|
||||
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
|
||||
class FilterNotFoundError(Exception):
|
||||
@@ -1764,9 +1696,7 @@ def _parse_filters_sync(
|
||||
try:
|
||||
content = conf_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"filter_read_error", name=name, path=str(conf_path), error=str(exc)
|
||||
)
|
||||
log.warning("filter_read_error", name=name, path=str(conf_path), error=str(exc))
|
||||
continue
|
||||
|
||||
if has_local:
|
||||
@@ -1842,9 +1772,7 @@ async def list_filters(
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Run the synchronous scan in a thread-pool executor.
|
||||
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(
|
||||
None, _parse_filters_sync, filter_d
|
||||
)
|
||||
raw_filters: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_filters_sync, filter_d)
|
||||
|
||||
# Fetch active jail names and their configs concurrently.
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
@@ -1857,9 +1785,7 @@ async def list_filters(
|
||||
|
||||
filters: list[FilterConfig] = []
|
||||
for name, filename, content, has_local, source_path in raw_filters:
|
||||
cfg = conffile_parser.parse_filter_file(
|
||||
content, name=name, filename=filename
|
||||
)
|
||||
cfg = conffile_parser.parse_filter_file(content, name=name, filename=filename)
|
||||
used_by = sorted(filter_to_jails.get(name, []))
|
||||
filters.append(
|
||||
FilterConfig(
|
||||
@@ -1947,9 +1873,7 @@ async def get_filter(
|
||||
|
||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||
|
||||
cfg = conffile_parser.parse_filter_file(
|
||||
content, name=base_name, filename=f"{base_name}.conf"
|
||||
)
|
||||
cfg = conffile_parser.parse_filter_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
@@ -2182,9 +2106,7 @@ async def delete_filter(
|
||||
try:
|
||||
local_path.unlink()
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Failed to delete {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||
|
||||
log.info("filter_local_deleted", filter=base_name, path=str(local_path))
|
||||
|
||||
@@ -2226,9 +2148,7 @@ async def assign_filter_to_jail(
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
# Verify the jail exists in config.
|
||||
all_jails, _src = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
if jail_name not in all_jails:
|
||||
raise JailNotFoundInConfigError(jail_name)
|
||||
|
||||
@@ -2276,9 +2196,7 @@ async def assign_filter_to_jail(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Allowlist pattern for action names used in path construction.
|
||||
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(
|
||||
r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
|
||||
)
|
||||
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
|
||||
class ActionNotFoundError(Exception):
|
||||
@@ -2318,8 +2236,7 @@ class ActionReadonlyError(Exception):
|
||||
"""
|
||||
self.name: str = name
|
||||
super().__init__(
|
||||
f"Action {name!r} is a shipped default (.conf only); "
|
||||
"only user-created .local files can be deleted."
|
||||
f"Action {name!r} is a shipped default (.conf only); only user-created .local files can be deleted."
|
||||
)
|
||||
|
||||
|
||||
@@ -2428,9 +2345,7 @@ def _parse_actions_sync(
|
||||
try:
|
||||
content = conf_path.read_text(encoding="utf-8")
|
||||
except OSError as exc:
|
||||
log.warning(
|
||||
"action_read_error", name=name, path=str(conf_path), error=str(exc)
|
||||
)
|
||||
log.warning("action_read_error", name=name, path=str(conf_path), error=str(exc))
|
||||
continue
|
||||
|
||||
if has_local:
|
||||
@@ -2495,9 +2410,7 @@ def _append_jail_action_sync(
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create jail.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
|
||||
@@ -2517,9 +2430,7 @@ def _append_jail_action_sync(
|
||||
|
||||
existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else ""
|
||||
existing_lines = [
|
||||
line.strip()
|
||||
for line in existing_raw.splitlines()
|
||||
if line.strip() and not line.strip().startswith("#")
|
||||
line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
|
||||
]
|
||||
|
||||
# Extract base names from existing entries for duplicate checking.
|
||||
@@ -2533,9 +2444,7 @@ def _append_jail_action_sync(
|
||||
|
||||
if existing_lines:
|
||||
# configparser multi-line: continuation lines start with whitespace.
|
||||
new_value = existing_lines[0] + "".join(
|
||||
f"\n {line}" for line in existing_lines[1:]
|
||||
)
|
||||
new_value = existing_lines[0] + "".join(f"\n {line}" for line in existing_lines[1:])
|
||||
parser.set(jail_name, "action", new_value)
|
||||
else:
|
||||
parser.set(jail_name, "action", action_entry)
|
||||
@@ -2559,9 +2468,7 @@ def _append_jail_action_sync(
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_action_appended",
|
||||
@@ -2612,9 +2519,7 @@ def _remove_jail_action_sync(
|
||||
|
||||
existing_raw = parser.get(jail_name, "action")
|
||||
existing_lines = [
|
||||
line.strip()
|
||||
for line in existing_raw.splitlines()
|
||||
if line.strip() and not line.strip().startswith("#")
|
||||
line.strip() for line in existing_raw.splitlines() if line.strip() and not line.strip().startswith("#")
|
||||
]
|
||||
|
||||
def _base(entry: str) -> str:
|
||||
@@ -2628,9 +2533,7 @@ def _remove_jail_action_sync(
|
||||
return
|
||||
|
||||
if filtered:
|
||||
new_value = filtered[0] + "".join(
|
||||
f"\n {line}" for line in filtered[1:]
|
||||
)
|
||||
new_value = filtered[0] + "".join(f"\n {line}" for line in filtered[1:])
|
||||
parser.set(jail_name, "action", new_value)
|
||||
else:
|
||||
parser.remove_option(jail_name, "action")
|
||||
@@ -2654,9 +2557,7 @@ def _remove_jail_action_sync(
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_action_removed",
|
||||
@@ -2683,9 +2584,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
|
||||
try:
|
||||
action_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Cannot create action.d directory: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Cannot create action.d directory: {exc}") from exc
|
||||
|
||||
local_path = action_d / f"{name}.local"
|
||||
try:
|
||||
@@ -2702,9 +2601,7 @@ def _write_action_local_sync(action_d: Path, name: str, content: str) -> None:
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(
|
||||
f"Failed to write {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info("action_local_written", action=name, path=str(local_path))
|
||||
|
||||
@@ -2740,9 +2637,7 @@ async def list_actions(
|
||||
action_d = Path(config_dir) / "action.d"
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(
|
||||
None, _parse_actions_sync, action_d
|
||||
)
|
||||
raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor(None, _parse_actions_sync, action_d)
|
||||
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
@@ -2754,9 +2649,7 @@ async def list_actions(
|
||||
|
||||
actions: list[ActionConfig] = []
|
||||
for name, filename, content, has_local, source_path in raw_actions:
|
||||
cfg = conffile_parser.parse_action_file(
|
||||
content, name=name, filename=filename
|
||||
)
|
||||
cfg = conffile_parser.parse_action_file(content, name=name, filename=filename)
|
||||
used_by = sorted(action_to_jails.get(name, []))
|
||||
actions.append(
|
||||
ActionConfig(
|
||||
@@ -2843,9 +2736,7 @@ async def get_action(
|
||||
|
||||
content, has_local, source_path = await loop.run_in_executor(None, _read)
|
||||
|
||||
cfg = conffile_parser.parse_action_file(
|
||||
content, name=base_name, filename=f"{base_name}.conf"
|
||||
)
|
||||
cfg = conffile_parser.parse_action_file(content, name=base_name, filename=f"{base_name}.conf")
|
||||
|
||||
all_jails_result, active_names = await asyncio.gather(
|
||||
loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)),
|
||||
@@ -3061,9 +2952,7 @@ async def delete_action(
|
||||
try:
|
||||
local_path.unlink()
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(
|
||||
f"Failed to delete {local_path}: {exc}"
|
||||
) from exc
|
||||
raise ConfigWriteError(f"Failed to delete {local_path}: {exc}") from exc
|
||||
|
||||
log.info("action_local_deleted", action=base_name, path=str(local_path))
|
||||
|
||||
@@ -3105,9 +2994,7 @@ async def assign_action_to_jail(
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
all_jails, _src = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
if jail_name not in all_jails:
|
||||
raise JailNotFoundInConfigError(jail_name)
|
||||
|
||||
@@ -3187,9 +3074,7 @@ async def remove_action_from_jail(
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
all_jails, _src = await loop.run_in_executor(
|
||||
None, _parse_jails_sync, Path(config_dir)
|
||||
)
|
||||
all_jails, _src = await loop.run_in_executor(None, _parse_jails_sync, Path(config_dir))
|
||||
if jail_name not in all_jails:
|
||||
raise JailNotFoundInConfigError(jail_name)
|
||||
|
||||
@@ -3218,4 +3103,3 @@ async def remove_action_from_jail(
|
||||
action=action_name,
|
||||
reload=do_reload,
|
||||
)
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ def _ok(response: object) -> object:
|
||||
ValueError: If the return code indicates an error.
|
||||
"""
|
||||
try:
|
||||
code, data = cast(Fail2BanResponse, response)
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
if code != 0:
|
||||
@@ -128,7 +128,7 @@ def _ensure_list(value: object | None) -> list[str]:
|
||||
return [str(value)]
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
async def _safe_get(
|
||||
@@ -143,13 +143,13 @@ async def _safe_get(
|
||||
return default
|
||||
|
||||
|
||||
async def _safe_get_typed(
|
||||
async def _safe_get_typed[T](
|
||||
client: Fail2BanClient,
|
||||
command: Fail2BanCommand,
|
||||
default: _T,
|
||||
) -> _T:
|
||||
default: T,
|
||||
) -> T:
|
||||
"""Send a command and return the result typed as ``default``'s type."""
|
||||
return cast(_T, await _safe_get(client, command, default))
|
||||
return cast("T", await _safe_get(client, command, default))
|
||||
|
||||
|
||||
def _is_not_found_error(exc: Exception) -> bool:
|
||||
|
||||
@@ -47,7 +47,7 @@ def _ok(response: object) -> object:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = cast(Fail2BanResponse, response)
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
|
||||
@@ -11,10 +11,12 @@ modifies or locks the fail2ban database.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
from app.services.geo_service import GeoEnricher
|
||||
if TYPE_CHECKING:
|
||||
from app.services.geo_service import GeoEnricher
|
||||
|
||||
from app.models.ban import TIME_RANGE_SECONDS, TimeRange
|
||||
from app.models.history import (
|
||||
|
||||
@@ -14,7 +14,8 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import ipaddress
|
||||
from typing import TYPE_CHECKING, Awaitable, Callable, cast, TypeAlias
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, TypedDict, cast
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -27,6 +28,7 @@ from app.models.jail import (
|
||||
JailStatus,
|
||||
JailSummary,
|
||||
)
|
||||
from app.services.geo_service import GeoInfo
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanCommand,
|
||||
@@ -39,11 +41,21 @@ if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.services.geo_service import GeoInfo
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
GeoEnricher: TypeAlias = Callable[[str], Awaitable["GeoInfo | None"]]
|
||||
class IpLookupResult(TypedDict):
|
||||
"""Result returned by :func:`lookup_ip`.
|
||||
|
||||
This is intentionally a :class:`TypedDict` to provide precise typing for
|
||||
callers (e.g. routers) while keeping the implementation flexible.
|
||||
"""
|
||||
|
||||
ip: str
|
||||
currently_banned_in: list[str]
|
||||
geo: GeoInfo | None
|
||||
|
||||
|
||||
GeoEnricher = Callable[[str], Awaitable[GeoInfo | None]]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
@@ -104,7 +116,7 @@ def _ok(response: object) -> object:
|
||||
ValueError: If the response indicates an error (return code ≠ 0).
|
||||
"""
|
||||
try:
|
||||
code, data = cast(Fail2BanResponse, response)
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
except (TypeError, ValueError) as exc:
|
||||
raise ValueError(f"Unexpected fail2ban response shape: {response!r}") from exc
|
||||
|
||||
@@ -202,7 +214,7 @@ async def _safe_get(
|
||||
"""
|
||||
try:
|
||||
response = await client.send(command)
|
||||
return _ok(cast(Fail2BanResponse, response))
|
||||
return _ok(cast("Fail2BanResponse", response))
|
||||
except (ValueError, TypeError, Exception):
|
||||
return default
|
||||
|
||||
@@ -337,7 +349,6 @@ async def _fetch_jail_summary(
|
||||
client.send(["get", name, "backend"]),
|
||||
client.send(["get", name, "idle"]),
|
||||
])
|
||||
uses_backend_backend_commands = True
|
||||
else:
|
||||
# Commands not supported; return default values without sending.
|
||||
async def _return_default(value: object | None) -> Fail2BanResponse:
|
||||
@@ -347,7 +358,6 @@ async def _fetch_jail_summary(
|
||||
_return_default("polling"), # backend default
|
||||
_return_default(False), # idle default
|
||||
])
|
||||
uses_backend_backend_commands = False
|
||||
|
||||
_r = await asyncio.gather(*gather_list, return_exceptions=True)
|
||||
status_raw: object | Exception = _r[0]
|
||||
@@ -377,7 +387,7 @@ async def _fetch_jail_summary(
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return int(str(_ok(cast(Fail2BanResponse, raw))))
|
||||
return int(str(_ok(cast("Fail2BanResponse", raw))))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
@@ -385,7 +395,7 @@ async def _fetch_jail_summary(
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return str(_ok(cast(Fail2BanResponse, raw)))
|
||||
return str(_ok(cast("Fail2BanResponse", raw)))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
@@ -393,7 +403,7 @@ async def _fetch_jail_summary(
|
||||
if isinstance(raw, Exception):
|
||||
return fallback
|
||||
try:
|
||||
return bool(_ok(cast(Fail2BanResponse, raw)))
|
||||
return bool(_ok(cast("Fail2BanResponse", raw)))
|
||||
except (ValueError, TypeError):
|
||||
return fallback
|
||||
|
||||
@@ -687,7 +697,7 @@ async def reload_all(
|
||||
names_set -= set(exclude_jails)
|
||||
|
||||
stream: list[list[object]] = [["start", n] for n in sorted(names_set)]
|
||||
_ok(await client.send(["reload", "--all", [], cast(Fail2BanToken, stream)]))
|
||||
_ok(await client.send(["reload", "--all", [], cast("Fail2BanToken", stream)]))
|
||||
log.info("all_jails_reloaded")
|
||||
except ValueError as exc:
|
||||
# Detect UnknownJailException (missing or invalid jail configuration)
|
||||
@@ -811,8 +821,8 @@ async def unban_ip(
|
||||
async def get_active_bans(
|
||||
socket_path: str,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
http_session: "aiohttp.ClientSession" | None = None,
|
||||
app_db: "aiosqlite.Connection" | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> ActiveBanListResponse:
|
||||
"""Return all currently banned IPs across every jail.
|
||||
|
||||
@@ -880,7 +890,7 @@ async def get_active_bans(
|
||||
continue
|
||||
|
||||
try:
|
||||
ban_list: list[str] = cast(list[str], _ok(raw_result)) or []
|
||||
ban_list: list[str] = cast("list[str]", _ok(raw_result)) or []
|
||||
except (TypeError, ValueError) as exc:
|
||||
log.warning(
|
||||
"active_bans_parse_error",
|
||||
@@ -1007,8 +1017,8 @@ async def get_jail_banned_ips(
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
search: str | None = None,
|
||||
http_session: "aiohttp.ClientSession" | None = None,
|
||||
app_db: "aiosqlite.Connection" | None = None,
|
||||
http_session: aiohttp.ClientSession | None = None,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
) -> JailBannedIpsResponse:
|
||||
"""Return a paginated list of currently banned IPs for a single jail.
|
||||
|
||||
@@ -1055,7 +1065,7 @@ async def get_jail_banned_ips(
|
||||
except (ValueError, TypeError):
|
||||
raw_result = []
|
||||
|
||||
ban_list: list[str] = cast(list[str], raw_result) or []
|
||||
ban_list: list[str] = cast("list[str]", raw_result) or []
|
||||
|
||||
# Parse all entries.
|
||||
all_bans: list[ActiveBan] = []
|
||||
@@ -1121,7 +1131,7 @@ async def _enrich_bans(
|
||||
The same list with ``country`` fields populated where lookup succeeded.
|
||||
"""
|
||||
geo_results: list[object | Exception] = await asyncio.gather(
|
||||
*[cast(Awaitable[object], geo_enricher(ban.ip)) for ban in bans],
|
||||
*[cast("Awaitable[object]", geo_enricher(ban.ip)) for ban in bans],
|
||||
return_exceptions=True,
|
||||
)
|
||||
enriched: list[ActiveBan] = []
|
||||
@@ -1277,7 +1287,7 @@ async def lookup_ip(
|
||||
socket_path: str,
|
||||
ip: str,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
) -> dict[str, object | list[str] | None]:
|
||||
) -> IpLookupResult:
|
||||
"""Return ban status and history for a single IP address.
|
||||
|
||||
Checks every running jail for whether the IP is currently banned.
|
||||
@@ -1330,7 +1340,7 @@ async def lookup_ip(
|
||||
if isinstance(result, Exception):
|
||||
continue
|
||||
try:
|
||||
ban_list: list[str] = cast(list[str], _ok(result)) or []
|
||||
ban_list: list[str] = cast("list[str]", _ok(result)) or []
|
||||
if ip in ban_list:
|
||||
currently_banned_in.append(jail_name)
|
||||
except (ValueError, TypeError):
|
||||
|
||||
@@ -10,7 +10,7 @@ HTTP/FastAPI concerns.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast, TypeAlias
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
@@ -21,7 +21,7 @@ from app.utils.fail2ban_client import Fail2BanClient, Fail2BanCommand, Fail2BanR
|
||||
# Types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Fail2BanSettingValue: TypeAlias = str | int | bool
|
||||
type Fail2BanSettingValue = str | int | bool
|
||||
"""Allowed values for server settings commands."""
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
@@ -106,7 +106,7 @@ async def _safe_get(
|
||||
"""
|
||||
try:
|
||||
response = await client.send(command)
|
||||
return _ok(cast(Fail2BanResponse, response))
|
||||
return _ok(cast("Fail2BanResponse", response))
|
||||
except Exception:
|
||||
return default
|
||||
|
||||
@@ -189,7 +189,7 @@ async def update_settings(socket_path: str, update: ServerSettingsUpdate) -> Non
|
||||
async def _set(key: str, value: Fail2BanSettingValue) -> None:
|
||||
try:
|
||||
response = await client.send(["set", key, value])
|
||||
_ok(cast(Fail2BanResponse, response))
|
||||
_ok(cast("Fail2BanResponse", response))
|
||||
except ValueError as exc:
|
||||
raise ServerOperationError(f"Failed to set {key!r} = {value!r}: {exc}") from exc
|
||||
|
||||
@@ -224,7 +224,7 @@ async def flush_logs(socket_path: str) -> str:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
try:
|
||||
response = await client.send(["flushlogs"])
|
||||
result = _ok(cast(Fail2BanResponse, response))
|
||||
result = _ok(cast("Fail2BanResponse", response))
|
||||
log.info("logs_flushed", result=result)
|
||||
return str(result)
|
||||
except ValueError as exc:
|
||||
|
||||
@@ -35,7 +35,7 @@ GEO_RE_RESOLVE_INTERVAL: int = 600
|
||||
JOB_ID: str = "geo_re_resolve"
|
||||
|
||||
|
||||
async def _run_re_resolve(app: "FastAPI") -> None:
|
||||
async def _run_re_resolve(app: FastAPI) -> None:
|
||||
"""Query NULL-country IPs from the database and re-resolve them.
|
||||
|
||||
Reads shared resources from ``app.state`` and delegates to
|
||||
|
||||
@@ -47,7 +47,7 @@ HEALTH_CHECK_INTERVAL: int = 30
|
||||
_ACTIVATION_CRASH_WINDOW: int = 60
|
||||
|
||||
|
||||
async def _run_probe(app: "FastAPI") -> None:
|
||||
async def _run_probe(app: FastAPI) -> None:
|
||||
"""Probe fail2ban and cache the result on *app.state*.
|
||||
|
||||
Detects online/offline state transitions. When fail2ban goes offline
|
||||
|
||||
@@ -21,34 +21,52 @@ import contextlib
|
||||
import errno
|
||||
import socket
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence, Set
|
||||
from pickle import HIGHEST_PROTOCOL, dumps, loads
|
||||
from typing import TYPE_CHECKING, TypeAlias
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
Fail2BanToken: TypeAlias = str | int | float | bool | None | dict[str, object] | list[object]
|
||||
# Use covariant container types so callers can pass ``list[int]`` / ``dict[str, str]``
|
||||
# without needing to cast. At runtime we only accept the basic built-in
|
||||
# containers supported by fail2ban's protocol (list/dict/set) and stringify
|
||||
# anything else.
|
||||
#
|
||||
# NOTE: ``Sequence`` will also accept tuples, but tuples are stringified at
|
||||
# runtime because fail2ban only understands lists.
|
||||
|
||||
type Fail2BanToken = (
|
||||
str
|
||||
| int
|
||||
| float
|
||||
| bool
|
||||
| None
|
||||
| Mapping[str, object]
|
||||
| Sequence[object]
|
||||
| Set[object]
|
||||
)
|
||||
"""A single token in a fail2ban command.
|
||||
|
||||
Fail2ban accepts simple types (str/int/float/bool) plus compound types
|
||||
(list/dict). Complex objects are stringified before being sent.
|
||||
(list/dict/set). Complex objects are stringified before being sent.
|
||||
"""
|
||||
|
||||
Fail2BanCommand: TypeAlias = list[Fail2BanToken]
|
||||
type Fail2BanCommand = Sequence[Fail2BanToken]
|
||||
"""A command sent to fail2ban over the socket.
|
||||
|
||||
Commands are pickle serialised lists of tokens.
|
||||
Commands are pickle serialised sequences of tokens.
|
||||
"""
|
||||
|
||||
Fail2BanResponse: TypeAlias = tuple[int, object]
|
||||
type Fail2BanResponse = tuple[int, object]
|
||||
"""A typical fail2ban response containing a status code and payload."""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import TracebackType
|
||||
|
||||
import structlog
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# fail2ban protocol constants — inline to avoid a hard import dependency
|
||||
@@ -200,7 +218,7 @@ def _send_command_sync(
|
||||
) from last_oserror
|
||||
|
||||
|
||||
def _coerce_command_token(token: Fail2BanToken) -> Fail2BanToken:
|
||||
def _coerce_command_token(token: object) -> Fail2BanToken:
|
||||
"""Coerce a command token to a type that fail2ban understands.
|
||||
|
||||
fail2ban's ``CSocket.convert`` accepts ``str``, ``bool``, ``int``,
|
||||
|
||||
Reference in New Issue
Block a user