from __future__ import annotations

# --- Standard library -------------------------------------------------------
import asyncio
import configparser
import contextlib
import io
import os
import re
import tempfile
from pathlib import Path
from typing import cast

# --- Third party ------------------------------------------------------------
import structlog

# --- Project ----------------------------------------------------------------
from app.exceptions import (
    ConfigWriteError,
    FilterNameError,
    JailNameError,
)
from app.models.config import (
    BantimeEscalation,
    InactiveJail,
    JailValidationIssue,
    JailValidationResult,
)
from app.utils.constants import FAIL2BAN_SOCKET_TIMEOUT, FAIL2BAN_TRUTHY_VALUES
from app.utils.fail2ban_client import (
    Fail2BanClient,
    Fail2BanConnectionError,
    Fail2BanResponse,
)
from app.utils.fail2ban_response import ok, to_dict
from app.utils.log_sanitizer import sanitize_for_logging

# Module-wide structured logger.
log: structlog.stdlib.BoundLogger = structlog.get_logger()

# The three allowlists below accept 1-128 characters: a leading ASCII
# alphanumeric followed by alphanumerics, dots, underscores or hyphens.

# Allowlist pattern for jail names used in path construction.
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(
    r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
)

# Allowlist pattern for filter names used in path construction.
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(
    r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
)

# Allowlist pattern for action names used in path construction.
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(
    r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$"
)

# Sections that are not jail definitions.
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})

# False-ish values for the ``enabled`` key.
_FALSE_VALUES: frozenset[str] = frozenset({"false", "no", "0"})


def _build_parser() -> configparser.RawConfigParser:
    """Return a parser configured for fail2ban-style INI files.

    The parser tolerates duplicate sections/options (``strict=False``),
    performs no value interpolation, and keeps option names case-sensitive.
    """
    parser = configparser.RawConfigParser(strict=False, interpolation=None)
    # configparser lower-cases option names by default; ``str`` keeps them
    # exactly as written in the file.
    parser.optionxform = str
    return parser


def _is_truthy(value: str) -> bool:
    """Return ``True`` if *value* represents a fail2ban boolean true.

    The comparison is whitespace- and case-insensitive against
    ``FAIL2BAN_TRUTHY_VALUES``.
    """
    return value.strip().lower() in FAIL2BAN_TRUTHY_VALUES


def _parse_int_safe(value: str) -> int | None:
    """Parse *value* as int, returning ``None`` on failure."""
    try:
        return int(value.strip())
    except (ValueError, AttributeError):
        return None


def _parse_float_safe(value: str | None) -> float | None:
    """Parse *value* as float, returning ``None`` when empty or malformed."""
    if not value:
        return None
    try:
        return float(value.strip())
    except (ValueError, AttributeError):
        return None


def _parse_time_to_seconds(value: str | None, default: int) -> int:
    """Convert a fail2ban time string to seconds.

    Accepts a plain integer or an integer followed by a single unit suffix
    (``w``, ``d``, ``h``, ``m``, ``s``). ``"-1"`` is passed through as ``-1``.

    Args:
        value: Raw configuration value, possibly ``None`` or empty.
        default: Value returned when *value* is empty or cannot be parsed.

    Returns:
        The duration in seconds, ``-1`` for the literal ``"-1"``, or
        *default* on any parse failure.
    """
    if not value:
        return default
    stripped = value.strip()
    if stripped == "-1":
        return -1
    multipliers: dict[str, int] = {
        "w": 604800,
        "d": 86400,
        "h": 3600,
        "m": 60,
        "s": 1,
    }
    for suffix, factor in multipliers.items():
        # ``len > 1`` rejects a bare suffix such as "h" with no number.
        if stripped.endswith(suffix) and len(stripped) > 1:
            try:
                return int(stripped[:-1]) * factor
            except ValueError:
                return default
    try:
        return int(stripped)
    except ValueError:
        return default


def _parse_multiline(raw: str) -> list[str]:
    """Split a multi-line INI value into individual non-blank lines.

    Each line is stripped; empty lines and lines starting with ``#`` are
    dropped.
    """
    stripped_lines = (line.strip() for line in raw.splitlines())
    return [line for line in stripped_lines if line and not line.startswith("#")]


def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
    """Resolve fail2ban variable placeholders in a filter string.

    Only ``%(__name__)s`` (replaced by *jail_name*) and ``%(mode)s``
    (replaced by *mode*) are substituted; any other placeholder is left
    untouched.
    """
    result = raw_filter.replace("%(__name__)s", jail_name)
    return result.replace("%(mode)s", mode)


def _ordered_config_files(config_dir: Path) -> list[Path]:
    """Return all jail config files in fail2ban merge order.

    fail2ban reads, in this order (later files override earlier ones):

    1. ``jail.conf``
    2. ``jail.d/*.conf`` (alphabetical)
    3. ``jail.local``
    4. ``jail.d/*.local`` (alphabetical)

    Files and directories that do not exist are skipped.
    """
    jail_conf = config_dir / "jail.conf"
    jail_local = config_dir / "jail.local"
    jail_d = config_dir / "jail.d"
    has_jail_d = jail_d.is_dir()

    files: list[Path] = []
    if jail_conf.is_file():
        files.append(jail_conf)
    if has_jail_d:
        files.extend(sorted(jail_d.glob("*.conf")))
    # jail.local must come after jail.d/*.conf so that it overrides them,
    # matching the order fail2ban itself uses.
    if jail_local.is_file():
        files.append(jail_local)
    if has_jail_d:
        files.extend(sorted(jail_d.glob("*.local")))
    return files


def _build_inactive_jail(
    name: str,
    settings: dict[str, str],
    source_file: str,
    config_dir: Path | None = None,
) -> InactiveJail:
    """Construct an :class:`~app.models.config.InactiveJail` from raw settings.

    Args:
        name: Jail (section) name.
        settings: Raw key/value pairs of the jail section.
        source_file: Path of the file the jail definition is attributed to.
        config_dir: fail2ban configuration root. When given, it is used to
            check whether ``jail.d/{name}.local`` exists.

    Returns:
        The populated model. Missing or unparsable numeric values fall back
        to defaults (or ``None``) instead of raising.
    """
    raw_filter = settings.get("filter", "")
    mode = settings.get("mode", "normal")
    # Without an explicit filter, the jail name is used as the filter name.
    filter_name = _resolve_filter(raw_filter, name, mode) if raw_filter else name

    raw_action = settings.get("action", "")
    actions = _parse_multiline(raw_action) if raw_action else []

    raw_logpath = settings.get("logpath", "")
    logpath = _parse_multiline(raw_logpath) if raw_logpath else []

    enabled = _is_truthy(settings.get("enabled", "false"))
    maxretry = _parse_int_safe(settings.get("maxretry", ""))

    ban_time_seconds = _parse_time_to_seconds(settings.get("bantime"), 600)
    find_time_seconds = _parse_time_to_seconds(settings.get("findtime"), 600)

    log_encoding = settings.get("logencoding") or "auto"
    backend = settings.get("backend") or "auto"
    date_pattern = settings.get("datepattern") or None
    use_dns = settings.get("usedns") or "warn"
    prefregex = settings.get("prefregex") or ""
    fail_regex = _parse_multiline(settings.get("failregex", ""))
    ignore_regex = _parse_multiline(settings.get("ignoreregex", ""))

    esc_increment = _is_truthy(settings.get("bantime.increment", "false"))
    # A malformed factor must not abort building the jail; it is treated
    # like a missing value, consistent with the other tolerant parsers here.
    esc_factor = _parse_float_safe(settings.get("bantime.factor"))
    esc_formula = settings.get("bantime.formula") or None
    esc_multipliers = settings.get("bantime.multipliers") or None
    esc_max_raw = settings.get("bantime.maxtime")
    esc_max_time = _parse_time_to_seconds(esc_max_raw, 0) if esc_max_raw else None
    esc_rnd_raw = settings.get("bantime.rndtime")
    esc_rnd_time = _parse_time_to_seconds(esc_rnd_raw, 0) if esc_rnd_raw else None
    esc_overall = _is_truthy(settings.get("bantime.overalljails", "false"))

    # Escalation details are only reported when increment is switched on.
    bantime_escalation = (
        BantimeEscalation(
            increment=esc_increment,
            factor=esc_factor,
            formula=esc_formula,
            multipliers=esc_multipliers,
            max_time=esc_max_time,
            rnd_time=esc_rnd_time,
            overall_jails=esc_overall,
        )
        if esc_increment
        else None
    )

    has_local_override = (
        config_dir is not None
        and (config_dir / "jail.d" / f"{name}.local").is_file()
    )

    return InactiveJail(
        name=name,
        filter=filter_name,
        actions=actions,
        port=settings.get("port") or None,
        logpath=logpath,
        bantime=settings.get("bantime") or None,
        findtime=settings.get("findtime") or None,
        maxretry=maxretry,
        ban_time_seconds=ban_time_seconds,
        find_time_seconds=find_time_seconds,
        log_encoding=log_encoding,
        backend=backend,
        date_pattern=date_pattern,
        use_dns=use_dns,
        prefregex=prefregex,
        fail_regex=fail_regex,
        ignore_regex=ignore_regex,
        bantime_escalation=bantime_escalation,
        source_file=source_file,
        enabled=enabled,
        has_local_override=has_local_override,
    )


def _parse_jails_sync(
    config_dir: Path,
) -> tuple[dict[str, dict[str, str]], dict[str, str]]:
    """Synchronously parse all jail configs and return merged definitions.

    Args:
        config_dir: fail2ban configuration root.

    Returns:
        A ``(jails, source_files)`` tuple. ``jails`` maps each jail name to
        its merged settings; ``source_files`` maps each jail name to the
        path of the last file (in merge order) that defines the section.
        Read and parse errors are logged and skipped, never raised.
    """
    parser = _build_parser()
    files = _ordered_config_files(config_dir)

    # First pass: read every file on its own to learn which file defined
    # each section last.
    source_files: dict[str, str] = {}
    for path in files:
        try:
            single = _build_parser()
            single.read(str(path), encoding="utf-8")
            for section in single.sections():
                if section not in _META_SECTIONS:
                    source_files[section] = str(path)
        except (configparser.Error, OSError) as exc:
            log.warning("jail_config_read_error", path=str(path), error=str(exc))

    # Second pass: read all files into one parser so later files override
    # earlier ones.
    try:
        parser.read([str(p) for p in files], encoding="utf-8")
    except configparser.Error as exc:
        log.warning("jail_config_parse_error", error=str(exc))

    jails: dict[str, dict[str, str]] = {}
    for section in parser.sections():
        if section in _META_SECTIONS:
            continue
        try:
            jails[section] = dict(parser.items(section))
        except configparser.Error as exc:
            log.warning("jail_section_parse_error", section=section, error=str(exc))

    log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
    return jails, source_files


async def _get_active_jail_names(socket_path: str) -> set[str]:
    """Fetch the set of currently running jail names from fail2ban.

    Sends ``status`` over the socket and splits the comma-separated
    ``"Jail list"`` entry of the response.

    Returns:
        The jail names, or an empty set when the list is empty, fail2ban is
        unreachable, or any other error occurs (errors are logged).
    """
    try:
        client = Fail2BanClient(socket_path=socket_path, timeout=FAIL2BAN_SOCKET_TIMEOUT)
        status_raw = ok(await client.send(["status"]))
        status_dict = to_dict(status_raw)
        jail_list_raw: str = str(status_dict.get("Jail list", "") or "").strip()
        if not jail_list_raw:
            return set()
        return {j.strip() for j in jail_list_raw.split(",") if j.strip()}
    except Fail2BanConnectionError:
        log.warning("fail2ban_unreachable_during_inactive_list")
        return set()
    except Exception as exc:  # noqa: BLE001
        # Deliberate best-effort: any failure is treated as "no active jails".
        log.warning("fail2ban_status_error_during_inactive_list", error=str(exc))
        return set()


async def _probe_fail2ban_running(socket_path: str) -> bool:
    """Return ``True`` when fail2ban responds successfully to a status request.

    The response is unpacked as a ``(code, payload)`` pair; only ``code == 0``
    counts as success. Connection failures and all other errors are logged
    and reported as ``False``.
    """
    try:
        client = Fail2BanClient(socket_path=socket_path, timeout=FAIL2BAN_SOCKET_TIMEOUT)
        response = await client.send(["status"])
        code, _ = cast("Fail2BanResponse", response)
        return code == 0
    except Fail2BanConnectionError:
        log.warning("fail2ban_unreachable_during_probe", socket_path=str(socket_path))
        return False
    except Exception as exc:  # noqa: BLE001
        log.warning("fail2ban_probe_error", socket_path=str(socket_path), error=str(exc))
        return False


def _extract_action_base_name(action_str: str) -> str | None:
    """Return the base action name from an action assignment string.

    The base name is the text before the first ``[``. ``None`` is returned
    when the string contains ``%`` or ``$`` (unresolved variables) or when
    the base name does not match the action-name allowlist.
    """
    if "%" in action_str or "$" in action_str:
        return None
    base = action_str.split("[")[0].strip()
    if _SAFE_ACTION_NAME_RE.match(base):
        return base
    return None


def _validate_jail_config_sync(
    config_dir: Path,
    name: str,
) -> JailValidationResult:
    """Run synchronous pre-activation checks on a jail configuration.

    Checks performed, in order: the jail exists in the parsed config, the
    filter file exists, every action file exists, every ``failregex`` /
    ``ignoreregex`` line compiles as a Python regex, and every literal
    ``logpath`` exists on disk.

    Args:
        config_dir: fail2ban configuration root.
        name: Jail name to validate.

    Returns:
        A result whose ``valid`` flag is ``True`` only when no issue was found.
    """
    issues: list[JailValidationIssue] = []

    all_jails, _ = _parse_jails_sync(config_dir)
    settings = all_jails.get(name)
    if settings is None:
        return JailValidationResult(
            jail_name=name,
            valid=False,
            issues=[
                JailValidationIssue(
                    field="name",
                    message=f"Jail {name!r} not found in config files.",
                )
            ],
        )

    filter_d = config_dir / "filter.d"
    action_d = config_dir / "action.d"

    raw_filter = settings.get("filter", "")
    if raw_filter:
        mode = settings.get("mode", "normal")
        resolved = _resolve_filter(raw_filter, name, mode)
        # The action-name extractor is reused here: it strips a trailing
        # "[...]" part and rejects values with unresolved variables.
        base_filter = _extract_action_base_name(resolved)
        if base_filter:
            conf_ok = (filter_d / f"{base_filter}.conf").is_file()
            local_ok = (filter_d / f"{base_filter}.local").is_file()
            if not conf_ok and not local_ok:
                issues.append(
                    JailValidationIssue(
                        field="filter",
                        message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"),
                    )
                )

    raw_action = settings.get("action", "")
    if raw_action:
        for action_line in _parse_multiline(raw_action):
            action_name = _extract_action_base_name(action_line)
            if action_name:
                conf_ok = (action_d / f"{action_name}.conf").is_file()
                local_ok = (action_d / f"{action_name}.local").is_file()
                if not conf_ok and not local_ok:
                    issues.append(
                        JailValidationIssue(
                            field="action",
                            message=(f"Action file not found: action.d/{action_name}.conf (or .local)"),
                        )
                    )

    # failregex issues are reported before ignoreregex issues.
    for regex_field in ("failregex", "ignoreregex"):
        for pattern in _parse_multiline(settings.get(regex_field, "")):
            try:
                re.compile(pattern)
            except re.error as exc:
                issues.append(
                    JailValidationIssue(
                        field=regex_field,
                        message=f"Invalid regex pattern: {exc}",
                    )
                )

    raw_logpath = settings.get("logpath", "")
    if raw_logpath:
        for log_path in _parse_multiline(raw_logpath):
            # Globs and unresolved %(...)s variables cannot be checked here.
            if "*" in log_path or "?" in log_path or "%(" in log_path:
                continue
            if not Path(log_path).exists():
                issues.append(
                    JailValidationIssue(
                        field="logpath",
                        message=f"Log file not found on disk: {log_path}",
                    )
                )

    valid = len(issues) == 0
    log.debug(
        "jail_validation_complete",
        jail=name,
        valid=valid,
        issue_count=len(issues),
    )
    return JailValidationResult(jail_name=name, valid=valid, issues=issues)


def _safe_jail_name(name: str) -> str:
    """Validate *name* and return it unchanged or raise :class:`JailNameError`."""
    if not _SAFE_JAIL_NAME_RE.match(name):
        raise JailNameError(
            f"Jail name {name!r} contains invalid characters. "
            "Only alphanumeric characters, hyphens, underscores, and dots are "
            "allowed; must start with an alphanumeric character."
        )
    return name


def _safe_filter_name(name: str) -> str:
    """Validate *name* and return it unchanged or raise :class:`FilterNameError`."""
    if not _SAFE_FILTER_NAME_RE.match(name):
        raise FilterNameError(
            f"Filter name {name!r} contains invalid characters. "
            "Only alphanumeric characters, hyphens, underscores, and dots are "
            "allowed; must start with an alphanumeric character."
        )
    return name


def _set_jail_local_key_sync(
    config_dir: Path,
    jail_name: str,
    key: str,
    value: str,
) -> None:
    """Update ``jail.d/{jail_name}.local`` to set a single key in the jail section.

    The existing file (if any) is parsed, the key is set, and the file is
    rewritten from the parsed state via a temporary file plus
    :func:`os.replace`, so readers never see a half-written file.

    Args:
        config_dir: fail2ban configuration root.
        jail_name: Jail whose override file is updated. It is used verbatim
            in the file name and is NOT validated here — callers are
            expected to have checked it (e.g. with ``_safe_jail_name``).
        key: Option name to set.
        value: Option value to set.

    Raises:
        ConfigWriteError: If ``jail.d`` cannot be created or the file cannot
            be written.
    """
    jail_d = config_dir / "jail.d"
    try:
        jail_d.mkdir(parents=True, exist_ok=True)
    except OSError as exc:
        raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc

    local_path = jail_d / f"{jail_name}.local"

    parser = _build_parser()
    if local_path.is_file():
        try:
            parser.read(str(local_path), encoding="utf-8")
        except (configparser.Error, OSError) as exc:
            # Best-effort: continue with whatever was parsed; the file is
            # rewritten below.
            log.warning(
                "jail_local_read_for_update_error",
                jail=jail_name,
                error=str(exc),
            )

    if not parser.has_section(jail_name):
        parser.add_section(jail_name)
    parser.set(jail_name, key, value)

    buf = io.StringIO()
    buf.write("# Managed by BanGUI — do not edit manually\n\n")
    parser.write(buf)
    content = buf.getvalue()

    # Bound before the ``try`` so the cleanup path below never references an
    # undefined name when creating or writing the temp file fails.
    tmp_name: str | None = None
    try:
        with tempfile.NamedTemporaryFile(
            mode="w",
            encoding="utf-8",
            dir=jail_d,
            delete=False,
            suffix=".tmp",
        ) as tmp:
            # Record the path first so a failed write can still be cleaned up.
            tmp_name = tmp.name
            tmp.write(content)
        os.replace(tmp_name, local_path)
    except OSError as exc:
        if tmp_name is not None:
            with contextlib.suppress(OSError):
                os.unlink(tmp_name)
        raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc

    log.info(
        "jail_local_key_set",
        jail=jail_name,
        key=key,
        path=str(local_path),
    )


# Public aliases for the private helpers above.
ordered_config_files = _ordered_config_files
build_parser = _build_parser
is_truthy = _is_truthy
parse_multiline = _parse_multiline
parse_jails_sync = _parse_jails_sync
build_inactive_jail = _build_inactive_jail
get_active_jail_names = _get_active_jail_names
validate_jail_config_sync = _validate_jail_config_sync
set_jail_local_key_sync = _set_jail_local_key_sync
safe_jail_name = _safe_jail_name
safe_filter_name = _safe_filter_name
probe_fail2ban_running = _probe_fail2ban_running


async def start_daemon(start_cmd_parts: list[str]) -> bool:
    """Run the configured fail2ban start command and return whether it launched.

    The command is executed without a shell and awaited to completion.

    Args:
        start_cmd_parts: Program followed by its arguments.

    Returns:
        ``True`` when the process exits with return code 0; otherwise
        ``False`` (the sanitized stdout/stderr are logged).
    """
    process = await asyncio.create_subprocess_exec(
        *start_cmd_parts,
        stdout=asyncio.subprocess.PIPE,
        stderr=asyncio.subprocess.PIPE,
    )
    stdout, stderr = await process.communicate()
    if process.returncode != 0:
        log.error(
            "fail2ban_start_failed",
            command=" ".join(start_cmd_parts),
            returncode=process.returncode,
            stdout=sanitize_for_logging(stdout.decode("utf-8", errors="replace")),
            stderr=sanitize_for_logging(stderr.decode("utf-8", errors="replace")),
        )
        return False
    log.info(
        "fail2ban_start_succeeded",
        command=" ".join(start_cmd_parts),
    )
    return True


async def wait_for_fail2ban(
    socket_path: str,
    max_wait_seconds: float,
    poll_interval: float = 0.5,
) -> bool:
    """Probe the fail2ban socket until it is responsive or the timeout expires.

    Args:
        socket_path: Path of the fail2ban control socket.
        max_wait_seconds: Maximum time to keep probing.
        poll_interval: Pause between two probes, in seconds.

    Returns:
        ``True`` as soon as a probe succeeds, ``False`` when the deadline
        passes first (a warning is logged).
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait_seconds
    while loop.time() < deadline:
        if await _probe_fail2ban_running(socket_path):
            return True
        await asyncio.sleep(poll_interval)
    log.warning(
        "wait_for_fail2ban_timeout",
        socket_path=str(socket_path),
        max_wait_seconds=max_wait_seconds,
    )
    return False