refactoring-backend #3
@@ -22,7 +22,7 @@ from app.dependencies import (
|
||||
HttpSessionDep,
|
||||
get_db,
|
||||
)
|
||||
from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, GeoReResolveResponse, IpLookupResponse
|
||||
from app.models.geo import GeoCacheStatsResponse, GeoReResolveResponse, IpLookupResponse
|
||||
from app.services import geo_service, jail_service
|
||||
from app.utils.fail2ban_client import Fail2BanConnectionError
|
||||
|
||||
@@ -79,21 +79,7 @@ async def lookup_ip(
|
||||
detail=f"Cannot reach fail2ban: {exc}",
|
||||
) from exc
|
||||
|
||||
raw_geo = result["geo"]
|
||||
geo_detail: GeoDetail | None = None
|
||||
if isinstance(raw_geo, GeoInfo):
|
||||
geo_detail = GeoDetail(
|
||||
country_code=raw_geo.country_code,
|
||||
country_name=raw_geo.country_name,
|
||||
asn=raw_geo.asn,
|
||||
org=raw_geo.org,
|
||||
)
|
||||
|
||||
return IpLookupResponse(
|
||||
ip=result["ip"],
|
||||
currently_banned_in=result["currently_banned_in"],
|
||||
geo=geo_detail,
|
||||
)
|
||||
return IpLookupResponse(**result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -40,7 +40,8 @@ from app.repositories.history_archive_repo import (
|
||||
get_all_archived_history,
|
||||
get_archived_history,
|
||||
)
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
@@ -50,6 +51,12 @@ if TYPE_CHECKING:
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_fail2ban_db_path(socket_path: str) -> str:
|
||||
"""Return the fail2ban database path using the shared metadata cache."""
|
||||
return await default_fail2ban_metadata_service.get_db_path(socket_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -29,10 +29,17 @@ from app.models.history import (
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.repositories.history_archive_repo import archive_ban_event
|
||||
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
|
||||
async def get_fail2ban_db_path(socket_path: str) -> str:
|
||||
"""Return the fail2ban database path using the shared metadata cache."""
|
||||
return await default_fail2ban_metadata_service.get_db_path(socket_path)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
551
backend/app/utils/config_file_utils.py
Normal file
551
backend/app/utils/config_file_utils.py
Normal file
@@ -0,0 +1,551 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import configparser
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
|
||||
import structlog
|
||||
|
||||
from app.exceptions import (
|
||||
ConfigWriteError,
|
||||
FilterNameError,
|
||||
JailNameError,
|
||||
)
|
||||
from app.models.config import (
|
||||
ActionConfig,
|
||||
BantimeEscalation,
|
||||
InactiveJail,
|
||||
JailValidationIssue,
|
||||
JailValidationResult,
|
||||
)
|
||||
from app.utils import conffile_parser
|
||||
from app.utils.constants import FAIL2BAN_TRUTHY_VALUES
|
||||
from app.utils.fail2ban_client import (
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanResponse,
|
||||
)
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
_SOCKET_TIMEOUT: float = 10.0
|
||||
|
||||
# Allowlist pattern for jail names used in path construction.
|
||||
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
# Allowlist pattern for filter names used in path construction.
|
||||
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
# Allowlist pattern for action names used in path construction.
|
||||
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
|
||||
|
||||
# Sections that are not jail definitions.
|
||||
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
|
||||
|
||||
# False-ish values for the ``enabled`` key.
|
||||
_FALSE_VALUES: frozenset[str] = frozenset({"false", "no", "0"})
|
||||
|
||||
|
||||
def _build_parser() -> configparser.RawConfigParser:
|
||||
"""Return a parser configured for fail2ban-style INI files."""
|
||||
parser = configparser.RawConfigParser(strict=False, interpolation=None)
|
||||
parser.optionxform = str
|
||||
return parser
|
||||
|
||||
|
||||
def _is_truthy(value: str) -> bool:
|
||||
"""Return ``True`` if *value* represents a fail2ban boolean true."""
|
||||
return value.strip().lower() in FAIL2BAN_TRUTHY_VALUES
|
||||
|
||||
|
||||
def _parse_int_safe(value: str) -> int | None:
|
||||
"""Parse *value* as int, returning ``None`` on failure."""
|
||||
try:
|
||||
return int(value.strip())
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
|
||||
|
||||
def _parse_time_to_seconds(value: str | None, default: int) -> int:
|
||||
"""Convert a fail2ban time string to seconds."""
|
||||
if not value:
|
||||
return default
|
||||
stripped = value.strip()
|
||||
if stripped == "-1":
|
||||
return -1
|
||||
multipliers: dict[str, int] = {
|
||||
"w": 604800,
|
||||
"d": 86400,
|
||||
"h": 3600,
|
||||
"m": 60,
|
||||
"s": 1,
|
||||
}
|
||||
for suffix, factor in multipliers.items():
|
||||
if stripped.endswith(suffix) and len(stripped) > 1:
|
||||
try:
|
||||
return int(stripped[:-1]) * factor
|
||||
except ValueError:
|
||||
return default
|
||||
try:
|
||||
return int(stripped)
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
|
||||
def _parse_multiline(raw: str) -> list[str]:
|
||||
"""Split a multi-line INI value into individual non-blank lines."""
|
||||
result: list[str] = []
|
||||
for line in raw.splitlines():
|
||||
stripped = line.strip()
|
||||
if stripped and not stripped.startswith("#"):
|
||||
result.append(stripped)
|
||||
return result
|
||||
|
||||
|
||||
def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
|
||||
"""Resolve fail2ban variable placeholders in a filter string."""
|
||||
result = raw_filter.replace("%(__name__)s", jail_name)
|
||||
return result.replace("%(mode)s", mode)
|
||||
|
||||
|
||||
def _ordered_config_files(config_dir: Path) -> list[Path]:
|
||||
"""Return all jail config files in fail2ban merge order."""
|
||||
files: list[Path] = []
|
||||
|
||||
jail_conf = config_dir / "jail.conf"
|
||||
if jail_conf.is_file():
|
||||
files.append(jail_conf)
|
||||
|
||||
jail_local = config_dir / "jail.local"
|
||||
if jail_local.is_file():
|
||||
files.append(jail_local)
|
||||
|
||||
jail_d = config_dir / "jail.d"
|
||||
if jail_d.is_dir():
|
||||
files.extend(sorted(jail_d.glob("*.conf")))
|
||||
files.extend(sorted(jail_d.glob("*.local")))
|
||||
|
||||
return files
|
||||
|
||||
|
||||
def _build_inactive_jail(
|
||||
name: str,
|
||||
settings: dict[str, str],
|
||||
source_file: str,
|
||||
config_dir: Path | None = None,
|
||||
) -> InactiveJail:
|
||||
"""Construct an :class:`~app.models.config.InactiveJail` from raw settings."""
|
||||
raw_filter = settings.get("filter", "")
|
||||
mode = settings.get("mode", "normal")
|
||||
filter_name = _resolve_filter(raw_filter, name, mode) if raw_filter else name
|
||||
|
||||
raw_action = settings.get("action", "")
|
||||
actions = _parse_multiline(raw_action) if raw_action else []
|
||||
|
||||
raw_logpath = settings.get("logpath", "")
|
||||
logpath = _parse_multiline(raw_logpath) if raw_logpath else []
|
||||
|
||||
enabled_raw = settings.get("enabled", "false")
|
||||
enabled = _is_truthy(enabled_raw)
|
||||
|
||||
maxretry_raw = settings.get("maxretry", "")
|
||||
maxretry = _parse_int_safe(maxretry_raw)
|
||||
|
||||
ban_time_seconds = _parse_time_to_seconds(settings.get("bantime"), 600)
|
||||
find_time_seconds = _parse_time_to_seconds(settings.get("findtime"), 600)
|
||||
log_encoding = settings.get("logencoding") or "auto"
|
||||
backend = settings.get("backend") or "auto"
|
||||
date_pattern = settings.get("datepattern") or None
|
||||
use_dns = settings.get("usedns") or "warn"
|
||||
prefregex = settings.get("prefregex") or ""
|
||||
fail_regex = _parse_multiline(settings.get("failregex", ""))
|
||||
ignore_regex = _parse_multiline(settings.get("ignoreregex", ""))
|
||||
|
||||
esc_increment = _is_truthy(settings.get("bantime.increment", "false"))
|
||||
esc_factor_raw = settings.get("bantime.factor")
|
||||
esc_factor = float(esc_factor_raw) if esc_factor_raw else None
|
||||
esc_formula = settings.get("bantime.formula") or None
|
||||
esc_multipliers = settings.get("bantime.multipliers") or None
|
||||
esc_max_raw = settings.get("bantime.maxtime")
|
||||
esc_max_time = _parse_time_to_seconds(esc_max_raw, 0) if esc_max_raw else None
|
||||
esc_rnd_raw = settings.get("bantime.rndtime")
|
||||
esc_rnd_time = _parse_time_to_seconds(esc_rnd_raw, 0) if esc_rnd_raw else None
|
||||
esc_overall = _is_truthy(settings.get("bantime.overalljails", "false"))
|
||||
bantime_escalation = (
|
||||
BantimeEscalation(
|
||||
increment=esc_increment,
|
||||
factor=esc_factor,
|
||||
formula=esc_formula,
|
||||
multipliers=esc_multipliers,
|
||||
max_time=esc_max_time,
|
||||
rnd_time=esc_rnd_time,
|
||||
overall_jails=esc_overall,
|
||||
)
|
||||
if esc_increment
|
||||
else None
|
||||
)
|
||||
|
||||
return InactiveJail(
|
||||
name=name,
|
||||
filter=filter_name,
|
||||
actions=actions,
|
||||
port=settings.get("port") or None,
|
||||
logpath=logpath,
|
||||
bantime=settings.get("bantime") or None,
|
||||
findtime=settings.get("findtime") or None,
|
||||
maxretry=maxretry,
|
||||
ban_time_seconds=ban_time_seconds,
|
||||
find_time_seconds=find_time_seconds,
|
||||
log_encoding=log_encoding,
|
||||
backend=backend,
|
||||
date_pattern=date_pattern,
|
||||
use_dns=use_dns,
|
||||
prefregex=prefregex,
|
||||
fail_regex=fail_regex,
|
||||
ignore_regex=ignore_regex,
|
||||
bantime_escalation=bantime_escalation,
|
||||
source_file=source_file,
|
||||
enabled=enabled,
|
||||
has_local_override=((config_dir / "jail.d" / f"{name}.local").is_file() if config_dir is not None else False),
|
||||
)
|
||||
|
||||
|
||||
def _parse_jails_sync(
|
||||
config_dir: Path,
|
||||
) -> tuple[dict[str, dict[str, str]], dict[str, str]]:
|
||||
"""Synchronously parse all jail configs and return merged definitions."""
|
||||
parser = _build_parser()
|
||||
files = _ordered_config_files(config_dir)
|
||||
|
||||
source_files: dict[str, str] = {}
|
||||
for path in files:
|
||||
try:
|
||||
single = _build_parser()
|
||||
single.read(str(path), encoding="utf-8")
|
||||
for section in single.sections():
|
||||
if section not in _META_SECTIONS:
|
||||
source_files[section] = str(path)
|
||||
except (configparser.Error, OSError) as exc:
|
||||
log.warning("jail_config_read_error", path=str(path), error=str(exc))
|
||||
|
||||
try:
|
||||
parser.read([str(p) for p in files], encoding="utf-8")
|
||||
except configparser.Error as exc:
|
||||
log.warning("jail_config_parse_error", error=str(exc))
|
||||
|
||||
jails: dict[str, dict[str, str]] = {}
|
||||
for section in parser.sections():
|
||||
if section in _META_SECTIONS:
|
||||
continue
|
||||
try:
|
||||
jails[section] = dict(parser.items(section))
|
||||
except configparser.Error as exc:
|
||||
log.warning("jail_section_parse_error", section=section, error=str(exc))
|
||||
|
||||
log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
|
||||
return jails, source_files
|
||||
|
||||
|
||||
async def _get_active_jail_names(socket_path: str) -> set[str]:
|
||||
"""Fetch the set of currently running jail names from fail2ban."""
|
||||
try:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
|
||||
def _to_dict_inner(pairs: object) -> dict[str, object]:
|
||||
if not isinstance(pairs, (list, tuple)):
|
||||
return {}
|
||||
result: dict[str, object] = {}
|
||||
for item in pairs:
|
||||
try:
|
||||
k, v = item
|
||||
result[str(k)] = v
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return result
|
||||
|
||||
def _ok(response: object) -> object:
|
||||
code, data = cast("Fail2BanResponse", response)
|
||||
if code != 0:
|
||||
raise ValueError(f"fail2ban error {code}: {data!r}")
|
||||
return data
|
||||
|
||||
status_raw = _ok(await client.send(["status"]))
|
||||
status_dict = _to_dict_inner(status_raw)
|
||||
jail_list_raw: str = str(status_dict.get("Jail list", "") or "").strip()
|
||||
if not jail_list_raw:
|
||||
return set()
|
||||
return {j.strip() for j in jail_list_raw.split(",") if j.strip()}
|
||||
except Fail2BanConnectionError:
|
||||
log.warning("fail2ban_unreachable_during_inactive_list")
|
||||
return set()
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("fail2ban_status_error_during_inactive_list", error=str(exc))
|
||||
return set()
|
||||
|
||||
|
||||
async def _probe_fail2ban_running(socket_path: str) -> bool:
|
||||
"""Return ``True`` when fail2ban responds successfully to a status request."""
|
||||
try:
|
||||
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
|
||||
response = await client.send(["status"])
|
||||
code, _ = cast("Fail2BanResponse", response)
|
||||
return code == 0
|
||||
except Fail2BanConnectionError:
|
||||
log.warning("fail2ban_unreachable_during_probe", socket_path=str(socket_path))
|
||||
return False
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("fail2ban_probe_error", socket_path=str(socket_path), error=str(exc))
|
||||
return False
|
||||
|
||||
|
||||
def _extract_action_base_name(action_str: str) -> str | None:
|
||||
"""Return the base action name from an action assignment string."""
|
||||
if "%" in action_str or "$" in action_str:
|
||||
return None
|
||||
base = action_str.split("[")[0].strip()
|
||||
if _SAFE_ACTION_NAME_RE.match(base):
|
||||
return base
|
||||
return None
|
||||
|
||||
|
||||
def _validate_jail_config_sync(
|
||||
config_dir: Path,
|
||||
name: str,
|
||||
) -> JailValidationResult:
|
||||
"""Run synchronous pre-activation checks on a jail configuration."""
|
||||
issues: list[JailValidationIssue] = []
|
||||
|
||||
all_jails, _ = _parse_jails_sync(config_dir)
|
||||
settings = all_jails.get(name)
|
||||
|
||||
if settings is None:
|
||||
return JailValidationResult(
|
||||
jail_name=name,
|
||||
valid=False,
|
||||
issues=[
|
||||
JailValidationIssue(
|
||||
field="name",
|
||||
message=f"Jail {name!r} not found in config files.",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
filter_d = config_dir / "filter.d"
|
||||
action_d = config_dir / "action.d"
|
||||
|
||||
raw_filter = settings.get("filter", "")
|
||||
if raw_filter:
|
||||
mode = settings.get("mode", "normal")
|
||||
resolved = _resolve_filter(raw_filter, name, mode)
|
||||
base_filter = _extract_action_base_name(resolved)
|
||||
if base_filter:
|
||||
conf_ok = (filter_d / f"{base_filter}.conf").is_file()
|
||||
local_ok = (filter_d / f"{base_filter}.local").is_file()
|
||||
if not conf_ok and not local_ok:
|
||||
issues.append(
|
||||
JailValidationIssue(
|
||||
field="filter",
|
||||
message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"),
|
||||
)
|
||||
)
|
||||
|
||||
raw_action = settings.get("action", "")
|
||||
if raw_action:
|
||||
for action_line in _parse_multiline(raw_action):
|
||||
action_name = _extract_action_base_name(action_line)
|
||||
if action_name:
|
||||
conf_ok = (action_d / f"{action_name}.conf").is_file()
|
||||
local_ok = (action_d / f"{action_name}.local").is_file()
|
||||
if not conf_ok and not local_ok:
|
||||
issues.append(
|
||||
JailValidationIssue(
|
||||
field="action",
|
||||
message=(f"Action file not found: action.d/{action_name}.conf (or .local)"),
|
||||
)
|
||||
)
|
||||
|
||||
for pattern in _parse_multiline(settings.get("failregex", "")):
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as exc:
|
||||
issues.append(
|
||||
JailValidationIssue(
|
||||
field="failregex",
|
||||
message=f"Invalid regex pattern: {exc}",
|
||||
)
|
||||
)
|
||||
|
||||
for pattern in _parse_multiline(settings.get("ignoreregex", "")):
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as exc:
|
||||
issues.append(
|
||||
JailValidationIssue(
|
||||
field="ignoreregex",
|
||||
message=f"Invalid regex pattern: {exc}",
|
||||
)
|
||||
)
|
||||
|
||||
raw_logpath = settings.get("logpath", "")
|
||||
if raw_logpath:
|
||||
for log_path in _parse_multiline(raw_logpath):
|
||||
if "*" in log_path or "?" in log_path or "%(" in log_path:
|
||||
continue
|
||||
if not Path(log_path).exists():
|
||||
issues.append(
|
||||
JailValidationIssue(
|
||||
field="logpath",
|
||||
message=f"Log file not found on disk: {log_path}",
|
||||
)
|
||||
)
|
||||
|
||||
valid = len(issues) == 0
|
||||
log.debug(
|
||||
"jail_validation_complete",
|
||||
jail=name,
|
||||
valid=valid,
|
||||
issue_count=len(issues),
|
||||
)
|
||||
return JailValidationResult(jail_name=name, valid=valid, issues=issues)
|
||||
|
||||
|
||||
def _safe_jail_name(name: str) -> str:
|
||||
"""Validate *name* and return it unchanged or raise :class:`JailNameError`."""
|
||||
if not _SAFE_JAIL_NAME_RE.match(name):
|
||||
raise JailNameError(
|
||||
f"Jail name {name!r} contains invalid characters. "
|
||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
||||
"allowed; must start with an alphanumeric character."
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def _safe_filter_name(name: str) -> str:
|
||||
"""Validate *name* and return it unchanged or raise :class:`FilterNameError`."""
|
||||
if not _SAFE_FILTER_NAME_RE.match(name):
|
||||
raise FilterNameError(
|
||||
f"Filter name {name!r} contains invalid characters. "
|
||||
"Only alphanumeric characters, hyphens, underscores, and dots are "
|
||||
"allowed; must start with an alphanumeric character."
|
||||
)
|
||||
return name
|
||||
|
||||
|
||||
def _set_jail_local_key_sync(
|
||||
config_dir: Path,
|
||||
jail_name: str,
|
||||
key: str,
|
||||
value: str,
|
||||
) -> None:
|
||||
"""Update ``jail.d/{jail_name}.local`` to set a single key in the jail section."""
|
||||
jail_d = config_dir / "jail.d"
|
||||
try:
|
||||
jail_d.mkdir(parents=True, exist_ok=True)
|
||||
except OSError as exc:
|
||||
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
|
||||
|
||||
local_path = jail_d / f"{jail_name}.local"
|
||||
parser = _build_parser()
|
||||
if local_path.is_file():
|
||||
try:
|
||||
parser.read(str(local_path), encoding="utf-8")
|
||||
except (configparser.Error, OSError) as exc:
|
||||
log.warning(
|
||||
"jail_local_read_for_update_error",
|
||||
jail=jail_name,
|
||||
error=str(exc),
|
||||
)
|
||||
|
||||
if not parser.has_section(jail_name):
|
||||
parser.add_section(jail_name)
|
||||
parser.set(jail_name, key, value)
|
||||
|
||||
buf = io.StringIO()
|
||||
buf.write("# Managed by BanGUI — do not edit manually\n\n")
|
||||
parser.write(buf)
|
||||
content = buf.getvalue()
|
||||
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(
|
||||
mode="w",
|
||||
encoding="utf-8",
|
||||
dir=jail_d,
|
||||
delete=False,
|
||||
suffix=".tmp",
|
||||
) as tmp:
|
||||
tmp.write(content)
|
||||
tmp_name = tmp.name
|
||||
os.replace(tmp_name, local_path)
|
||||
except OSError as exc:
|
||||
with contextlib.suppress(OSError):
|
||||
os.unlink(tmp_name) # noqa: F821
|
||||
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
|
||||
|
||||
log.info(
|
||||
"jail_local_key_set",
|
||||
jail=jail_name,
|
||||
key=key,
|
||||
path=str(local_path),
|
||||
)
|
||||
|
||||
|
||||
ordered_config_files = _ordered_config_files
|
||||
build_parser = _build_parser
|
||||
is_truthy = _is_truthy
|
||||
parse_multiline = _parse_multiline
|
||||
parse_jails_sync = _parse_jails_sync
|
||||
build_inactive_jail = _build_inactive_jail
|
||||
get_active_jail_names = _get_active_jail_names
|
||||
validate_jail_config_sync = _validate_jail_config_sync
|
||||
set_jail_local_key_sync = _set_jail_local_key_sync
|
||||
safe_jail_name = _safe_jail_name
|
||||
safe_filter_name = _safe_filter_name
|
||||
probe_fail2ban_running = _probe_fail2ban_running
|
||||
|
||||
|
||||
async def start_daemon(start_cmd_parts: list[str]) -> bool:
|
||||
"""Run the configured fail2ban start command and return whether it launched."""
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
*start_cmd_parts,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
stdout, stderr = await process.communicate()
|
||||
if process.returncode != 0:
|
||||
log.error(
|
||||
"fail2ban_start_failed",
|
||||
command=" ".join(start_cmd_parts),
|
||||
returncode=process.returncode,
|
||||
stdout=stdout.decode("utf-8", errors="replace"),
|
||||
stderr=stderr.decode("utf-8", errors="replace"),
|
||||
)
|
||||
return False
|
||||
log.info(
|
||||
"fail2ban_start_succeeded",
|
||||
command=" ".join(start_cmd_parts),
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
async def wait_for_fail2ban(
|
||||
socket_path: str,
|
||||
max_wait_seconds: float,
|
||||
poll_interval: float = 0.5,
|
||||
) -> bool:
|
||||
"""Probe the fail2ban socket until it is responsive or the timeout expires."""
|
||||
deadline = asyncio.get_running_loop().time() + max_wait_seconds
|
||||
while asyncio.get_running_loop().time() < deadline:
|
||||
if await _probe_fail2ban_running(socket_path):
|
||||
return True
|
||||
await asyncio.sleep(poll_interval)
|
||||
log.warning(
|
||||
"wait_for_fail2ban_timeout",
|
||||
socket_path=str(socket_path),
|
||||
max_wait_seconds=max_wait_seconds,
|
||||
)
|
||||
return False
|
||||
@@ -5,26 +5,12 @@ from __future__ import annotations
|
||||
import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
|
||||
|
||||
def ts_to_iso(unix_ts: int) -> str:
|
||||
"""Convert a Unix timestamp to an ISO 8601 UTC string."""
|
||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||
|
||||
|
||||
async def get_fail2ban_db_path(socket_path: str, *, force_refresh: bool = False) -> str:
|
||||
"""Return the fail2ban database path, using cached metadata when available."""
|
||||
return await default_fail2ban_metadata_service.get_db_path(
|
||||
socket_path, force_refresh=force_refresh
|
||||
)
|
||||
|
||||
|
||||
def invalidate_fail2ban_db_path(socket_path: str) -> None:
|
||||
"""Invalidate the cached fail2ban database path for the given socket."""
|
||||
default_fail2ban_metadata_service.invalidate_db_path(socket_path)
|
||||
|
||||
|
||||
def parse_data_json(raw: object) -> tuple[list[str], int]:
|
||||
"""Extract matches and failure count from the fail2ban bans.data value."""
|
||||
if raw is None:
|
||||
|
||||
@@ -12,7 +12,7 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.geo import GeoInfo
|
||||
from app.models.geo import GeoDetail, GeoInfo
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
@@ -71,7 +71,7 @@ class TestGeoLookup:
|
||||
|
||||
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
|
||||
"""GET /api/geo/lookup/{ip} returns 200 with enriched result."""
|
||||
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
|
||||
geo = GeoDetail(country_code="DE", country_name="Germany", asn="12345", org="Acme")
|
||||
result: dict[str, object] = {
|
||||
"ip": "1.2.3.4",
|
||||
"currently_banned_in": ["sshd"],
|
||||
@@ -97,7 +97,7 @@ class TestGeoLookup:
|
||||
result: dict[str, object] = {
|
||||
"ip": "8.8.8.8",
|
||||
"currently_banned_in": [],
|
||||
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
|
||||
"geo": GeoDetail(country_code="US", country_name="United States", asn=None, org=None),
|
||||
}
|
||||
with patch(
|
||||
"app.routers.geo.jail_service.lookup_ip",
|
||||
|
||||
Reference in New Issue
Block a user