diff --git a/Docs/Tasks.md b/Docs/Tasks.md index af05129..551b745 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -253,7 +253,7 @@ fail2ban ships with many action definitions in `action.d/` (iptables, firewalld, --- -### Task 3.3 — Frontend: Actions Tab with Active/Inactive Display and Activation +### Task 3.3 — Frontend: Actions Tab with Active/Inactive Display and Activation ✅ DONE **Goal:** Enhance the Actions tab in the Configuration page to show all actions with active/inactive status and allow editing and assignment. @@ -300,7 +300,7 @@ fail2ban ships with many action definitions in `action.d/` (iptables, firewalld, ## Stage 4 — Unified Config File Service and Shared Utilities -### Task 4.1 — Config File Parser Utility +### Task 4.1 — Config File Parser Utility ✅ DONE **Goal:** Build a robust, reusable parser for fail2ban INI-style config files that all config-related features share. @@ -324,7 +324,7 @@ fail2ban ships with many action definitions in `action.d/` (iptables, firewalld, --- -### Task 4.2 — Config File Writer Utility +### Task 4.2 — Config File Writer Utility ✅ DONE **Goal:** Build a safe writer utility for creating and updating `.local` override files. diff --git a/backend/app/utils/config_parser.py b/backend/app/utils/config_parser.py new file mode 100644 index 0000000..4202917 --- /dev/null +++ b/backend/app/utils/config_parser.py @@ -0,0 +1,358 @@ +"""Fail2ban INI-style config parser with include and interpolation support. + +Provides a :class:`Fail2BanConfigParser` class that wraps Python's +:class:`configparser.RawConfigParser` with fail2ban-specific behaviour: + +- **Merge order**: ``.conf`` file first, then ``.local`` overlay, then ``*.d/`` + directory overrides — each subsequent layer overwrites earlier values. +- **Include directives**: ``[INCLUDES]`` sections can specify ``before`` and + ``after`` filenames. ``before`` is loaded at lower priority (loaded first), + ``after`` at higher priority (loaded last). Both are resolved relative to + the directory of the including file. Circular includes and runaway recursion + are detected and logged. +- **Variable interpolation**: :meth:`interpolate` resolves ``%(variable)s`` + references using the ``[DEFAULT]`` section, the ``[Init]`` section, and any + caller-supplied variables. Multiple passes handle nested references. +- **Multi-line values**: Handled transparently by ``configparser``; the + :meth:`split_multiline` helper further strips blank lines and ``#`` comments. +- **Comments**: ``configparser`` strips full-line ``#``/``;`` comments; inline + comments inside multi-line values are stripped by :meth:`split_multiline`. + +All methods are synchronous. Call from async contexts via +:func:`asyncio.get_event_loop().run_in_executor`. +""" + +from __future__ import annotations + +import configparser +import re +from typing import TYPE_CHECKING + +import structlog + +if TYPE_CHECKING: + from pathlib import Path + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + +# Compiled pattern that matches fail2ban-style %(variable_name)s references. +_INTERPOLATE_RE: re.Pattern[str] = re.compile(r"%\((\w+)\)s") + +# Guard against infinite interpolation loops. +_MAX_INTERPOLATION_PASSES: int = 10 + + +class Fail2BanConfigParser: + """Parse fail2ban INI config files with include resolution and interpolation. + + Typical usage for a ``filter.d/`` file:: + + parser = Fail2BanConfigParser(config_dir=Path("/etc/fail2ban")) + parser.read_with_overrides(Path("/etc/fail2ban/filter.d/sshd.conf")) + section = parser.section_dict("Definition") + failregex = parser.split_multiline(section.get("failregex", "")) + + Args: + config_dir: Optional fail2ban configuration root directory. Used only + by :meth:`ordered_conf_files`; pass ``None`` if not needed. + max_include_depth: Maximum ``[INCLUDES]`` nesting depth before giving up. + """ + + def __init__( + self, + config_dir: Path | None = None, + max_include_depth: int = 10, + ) -> None: + self._config_dir = config_dir + self._max_include_depth = max_include_depth + self._parser: configparser.RawConfigParser = self._make_parser() + # Tracks resolved absolute paths to detect include cycles. + self._read_paths: set[Path] = set() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _make_parser() -> configparser.RawConfigParser: + """Return a case-sensitive :class:`configparser.RawConfigParser`.""" + parser = configparser.RawConfigParser(interpolation=None, strict=False) + # Keep original key casing (fail2ban is case-sensitive in option names). + parser.optionxform = str # type: ignore[assignment] + return parser + + def _get_include( + self, + include_dir: Path, + tmp_parser: configparser.RawConfigParser, + key: str, + ) -> Path | None: + """Return the resolved path for an include directive, or ``None``.""" + if not tmp_parser.has_section("INCLUDES"): + return None + if not tmp_parser.has_option("INCLUDES", key): + return None + raw = tmp_parser.get("INCLUDES", key).strip() + if not raw: + return None + return include_dir / raw + + # ------------------------------------------------------------------ + # Public interface — reading files + # ------------------------------------------------------------------ + + def read_file(self, path: Path, _depth: int = 0) -> None: + """Read *path*, following ``[INCLUDES]`` ``before``/``after`` directives. + + ``before`` references are loaded before the current file (lower + priority); ``after`` references are loaded after (higher priority). + Circular includes are detected by tracking resolved absolute paths. + + Args: + path: Config file to read. + _depth: Current include nesting depth. Internal parameter. + """ + if _depth > self._max_include_depth: + log.warning( + "include_depth_exceeded", + path=str(path), + max_depth=self._max_include_depth, + ) + return + + resolved = path.resolve() + if resolved in self._read_paths: + log.debug("include_cycle_detected", path=str(path)) + return + + try: + content = path.read_text(encoding="utf-8") + except OSError as exc: + log.warning("config_read_error", path=str(path), error=str(exc)) + return + + # Pre-scan for includes without yet committing to the main parser. + tmp = self._make_parser() + try: + tmp.read_string(content) + except configparser.Error as exc: + log.warning("config_parse_error", path=str(path), error=str(exc)) + return + + include_dir = path.parent + before_path = self._get_include(include_dir, tmp, "before") + after_path = self._get_include(include_dir, tmp, "after") + + # Load ``before`` first (lower priority than current file). + if before_path is not None: + self.read_file(before_path, _depth=_depth + 1) + + # Mark this path visited *before* merging to guard against cycles + # introduced by the ``after`` include referencing the same file. + self._read_paths.add(resolved) + + # Merge current file into the accumulating parser. + try: + self._parser.read_string(content, source=str(path)) + except configparser.Error as exc: + log.warning( + "config_parse_string_error", path=str(path), error=str(exc) + ) + + # Load ``after`` last (highest priority). + if after_path is not None: + self.read_file(after_path, _depth=_depth + 1) + + def read_with_overrides(self, conf_path: Path) -> None: + """Read *conf_path* and its ``.local`` override if it exists. + + The ``.local`` file is read after the ``.conf`` file so its values + take precedence. Include directives inside each file are still honoured. + + Args: + conf_path: Path to the ``.conf`` file. The corresponding + ``.local`` is derived by replacing the suffix with ``.local``. + """ + self.read_file(conf_path) + local_path = conf_path.with_suffix(".local") + if local_path.is_file(): + self.read_file(local_path) + + # ------------------------------------------------------------------ + # Public interface — querying parsed data + # ------------------------------------------------------------------ + + def sections(self) -> list[str]: + """Return all section names (excludes the ``[DEFAULT]`` pseudo-section). + + Returns: + Sorted list of section names present in the parsed files. + """ + return list(self._parser.sections()) + + def has_section(self, section: str) -> bool: + """Return whether *section* exists in the parsed configuration. + + Args: + section: Section name to check. + """ + return self._parser.has_section(section) + + def get(self, section: str, key: str) -> str | None: + """Return the raw value for *key* in *section*, or ``None``. + + Args: + section: Section name. + key: Option name. + + Returns: + Raw option value string, or ``None`` if not present. + """ + if self._parser.has_section(section) and self._parser.has_option( + section, key + ): + return self._parser.get(section, key) + return None + + def section_dict( + self, + section: str, + *, + skip: frozenset[str] | None = None, + ) -> dict[str, str]: + """Return all key-value pairs from *section* as a plain :class:`dict`. + + Keys whose names start with ``__`` (configparser internals from + ``DEFAULT`` inheritance) are always excluded. + + Args: + section: Section name to read. + skip: Additional key names to exclude. + + Returns: + Mapping of option name → raw value. Empty dict if section absent. + """ + if not self._parser.has_section(section): + return {} + drop: frozenset[str] = skip or frozenset() + return { + k: v + for k, v in self._parser.items(section) + if not k.startswith("__") and k not in drop + } + + def defaults(self) -> dict[str, str]: + """Return all ``[DEFAULT]`` section key-value pairs. + + Returns: + Dict of default keys and their values. + """ + return dict(self._parser.defaults()) + + # ------------------------------------------------------------------ + # Public interface — interpolation and helpers + # ------------------------------------------------------------------ + + def interpolate( + self, + value: str, + extra_vars: dict[str, str] | None = None, + ) -> str: + """Resolve ``%(variable)s`` references in *value*. + + Variables are resolved in the following priority order (low → high): + + 1. ``[DEFAULT]`` section values. + 2. ``[Init]`` section values (fail2ban action parameters). + 3. *extra_vars* provided by the caller. + + Multiple passes are performed to handle nested references (up to + :data:`_MAX_INTERPOLATION_PASSES` iterations). Unresolvable references + are left unchanged. + + Args: + value: Raw string possibly containing ``%(name)s`` placeholders. + extra_vars: Optional caller-supplied variables (highest priority). + + Returns: + String with ``%(name)s`` references substituted where possible. + """ + vars_: dict[str, str] = {} + vars_.update(self.defaults()) + vars_.update(self.section_dict("Init")) + if extra_vars: + vars_.update(extra_vars) + + def _sub(m: re.Match[str]) -> str: + return vars_.get(m.group(1), m.group(0)) + + result = value + for _ in range(_MAX_INTERPOLATION_PASSES): + new = _INTERPOLATE_RE.sub(_sub, result) + if new == result: + break + result = new + return result + + @staticmethod + def split_multiline(raw: str) -> list[str]: + """Split a multi-line INI value into individual non-blank lines. + + Each line is stripped of surrounding whitespace. Lines that are empty + or that start with ``#`` (comments) are discarded. + + Used for ``failregex``, ``ignoreregex``, ``action``, and ``logpath`` + values which fail2ban allows to span multiple lines. + + Args: + raw: Raw multi-line string from configparser. + + Returns: + List of stripped, non-empty, non-comment strings. + """ + result: list[str] = [] + for line in raw.splitlines(): + stripped = line.strip() + if stripped and not stripped.startswith("#"): + result.append(stripped) + return result + + # ------------------------------------------------------------------ + # Class-level utility — file ordering + # ------------------------------------------------------------------ + + @classmethod + def ordered_conf_files(cls, config_dir: Path, base_name: str) -> list[Path]: + """Return config files for *base_name* in fail2ban merge order. + + Merge order (ascending priority — later entries override earlier): + + 1. ``{config_dir}/{base_name}.conf`` + 2. ``{config_dir}/{base_name}.local`` + 3. ``{config_dir}/{base_name}.d/*.conf`` (sorted alphabetically) + 4. ``{config_dir}/{base_name}.d/*.local`` (sorted alphabetically) + + Args: + config_dir: Fail2ban configuration root directory. + base_name: Config base name without extension (e.g. ``"jail"``). + + Returns: + List of existing :class:`~pathlib.Path` objects in ascending + priority order (only files that actually exist are included). + """ + files: list[Path] = [] + + conf = config_dir / f"{base_name}.conf" + if conf.is_file(): + files.append(conf) + + local = config_dir / f"{base_name}.local" + if local.is_file(): + files.append(local) + + d_dir = config_dir / f"{base_name}.d" + if d_dir.is_dir(): + files.extend(sorted(d_dir.glob("*.conf"))) + files.extend(sorted(d_dir.glob("*.local"))) + + return files diff --git a/backend/app/utils/config_writer.py b/backend/app/utils/config_writer.py new file mode 100644 index 0000000..112f4e6 --- /dev/null +++ b/backend/app/utils/config_writer.py @@ -0,0 +1,303 @@ +"""Atomic config file writer for fail2ban ``.local`` override files. + +All write operations are atomic: content is first written to a temporary file +in the same directory as the target, then :func:`os.replace` is used to rename +it into place. This guarantees that a crash or power failure during the write +never leaves a partially-written file behind. + +A per-file :class:`threading.Lock` prevents concurrent writes from the same +process from racing. + +Security constraints +-------------------- +- Every write function asserts that the target path **ends in ``.local``**. + This prevents accidentally writing to ``.conf`` files (which belong to the + fail2ban package and should never be modified by BanGUI). + +Public functions +---------------- +- :func:`write_local_override` — create or update keys inside a ``.local`` file. +- :func:`remove_local_key` — remove a single key from a ``.local`` file. +- :func:`delete_local_file` — delete an entire ``.local`` file. +""" + +from __future__ import annotations + +import configparser +import contextlib +import io +import os +import tempfile +import threading +from typing import TYPE_CHECKING + +import structlog + +if TYPE_CHECKING: + from pathlib import Path + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + +# --------------------------------------------------------------------------- +# Per-file lock registry +# --------------------------------------------------------------------------- + +# Maps resolved absolute path strings → threading.Lock instances. +_locks: dict[str, threading.Lock] = {} +# Guards the _locks dict itself. +_registry_lock: threading.Lock = threading.Lock() + + +def _get_file_lock(path: Path) -> threading.Lock: + """Return the per-file :class:`threading.Lock` for *path*. + + The lock is created on first access and reused on subsequent calls. + + Args: + path: Target file path (need not exist yet). + + Returns: + :class:`threading.Lock` bound to the resolved absolute path of *path*. + """ + key = str(path.resolve()) + with _registry_lock: + if key not in _locks: + _locks[key] = threading.Lock() + return _locks[key] + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _assert_local_file(path: Path) -> None: + """Raise :class:`ValueError` if *path* does not end with ``.local``. + + This is a safety guard against accidentally modifying ``.conf`` files. + + Args: + path: Path to validate. + + Raises: + ValueError: When *path* does not have a ``.local`` suffix. + """ + if path.suffix != ".local": + raise ValueError( + f"Refusing to write to non-.local file: {path!r}. " + "Only .local override files may be modified by BanGUI." + ) + + +def _make_parser() -> configparser.RawConfigParser: + """Return a case-sensitive :class:`configparser.RawConfigParser`.""" + parser = configparser.RawConfigParser(interpolation=None, strict=False) + parser.optionxform = str # type: ignore[assignment] + return parser + + +def _read_or_new_parser(path: Path) -> configparser.RawConfigParser: + """Read *path* into a parser, or return a fresh empty parser. + + If the file does not exist or cannot be read, a fresh parser is returned. + Any parse errors are logged as warnings (not re-raised). + + Args: + path: Path to the ``.local`` file to read. + + Returns: + Populated (or empty) :class:`configparser.RawConfigParser`. + """ + parser = _make_parser() + if path.is_file(): + try: + content = path.read_text(encoding="utf-8") + parser.read_string(content) + except (OSError, configparser.Error) as exc: + log.warning("local_file_read_error", path=str(path), error=str(exc)) + return parser + + +def _write_parser_atomic( + parser: configparser.RawConfigParser, + path: Path, +) -> None: + """Write *parser* contents to *path* atomically. + + Writes to a temporary file in the same directory as *path*, then renames + the temporary file over *path* using :func:`os.replace`. The temporary + file is cleaned up on failure. + + Args: + parser: Populated parser whose contents should be written. + path: Destination ``.local`` file path. + + Raises: + OSError: On filesystem errors (propagated to caller). + """ + buf = io.StringIO() + parser.write(buf) + content = buf.getvalue() + + path.parent.mkdir(parents=True, exist_ok=True) + + fd, tmp_path_str = tempfile.mkstemp( + dir=str(path.parent), + prefix=f".{path.name}.tmp", + suffix="", + ) + try: + with os.fdopen(fd, "w", encoding="utf-8") as f: + f.write(content) + os.replace(tmp_path_str, str(path)) + except Exception: + with contextlib.suppress(OSError): + os.unlink(tmp_path_str) + raise + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def write_local_override( + base_path: Path, + section: str, + key_values: dict[str, str], +) -> None: + """Create or update keys in a ``.local`` override file. + + If the file already exists, only the specified *key_values* are written + under *section*; all other sections and keys are preserved. + + If the file does not exist, it is created with the given *section* and + *key_values*. + + The write is **atomic**: a temporary file is written and renamed into place. + + Args: + base_path: Absolute path to the ``.local`` file (e.g. + ``filter.d/sshd.local``). The parent directory is created if it + does not already exist. + section: INI section name (e.g. ``"Definition"``, ``"Init"``). + key_values: Mapping of option name → value to write/update. + + Raises: + ValueError: If *base_path* does not end with ``.local``. + """ + _assert_local_file(base_path) + + lock = _get_file_lock(base_path) + with lock: + parser = _read_or_new_parser(base_path) + + if not parser.has_section(section): + parser.add_section(section) + + for key, value in key_values.items(): + parser.set(section, key, value) + + log.info( + "local_override_written", + path=str(base_path), + section=section, + keys=sorted(key_values), + ) + _write_parser_atomic(parser, base_path) + + +def remove_local_key(base_path: Path, section: str, key: str) -> None: + """Remove a single key from a ``.local`` override file. + + Post-removal cleanup: + + - If the section becomes empty after key removal, the section is also + removed. + - If no sections remain after section removal, the file is deleted. + + This function is a no-op when the file, section, or key does not exist. + + Args: + base_path: Path to the ``.local`` file to update. + section: INI section containing the key. + key: Option name to remove. + + Raises: + ValueError: If *base_path* does not end with ``.local``. + """ + _assert_local_file(base_path) + + if not base_path.is_file(): + return + + lock = _get_file_lock(base_path) + with lock: + parser = _read_or_new_parser(base_path) + + if not parser.has_section(section) or not parser.has_option(section, key): + return # Nothing to remove. + + parser.remove_option(section, key) + + # Remove the section if it has no remaining options. + if not parser.options(section): + parser.remove_section(section) + + # Delete the file entirely if it has no remaining sections. + if not parser.sections(): + with contextlib.suppress(OSError): + base_path.unlink() + log.info("local_file_deleted_empty", path=str(base_path)) + return + + log.info( + "local_key_removed", + path=str(base_path), + section=section, + key=key, + ) + _write_parser_atomic(parser, base_path) + + +def delete_local_file(path: Path, *, allow_orphan: bool = False) -> None: + """Delete a ``.local`` override file. + + By default, refuses to delete a ``.local`` file that has no corresponding + ``.conf`` file (an *orphan* ``.local``), because it may be the only copy of + a user-defined config. Pass ``allow_orphan=True`` to override this guard. + + Args: + path: Path to the ``.local`` file to delete. + allow_orphan: When ``True``, delete even if no corresponding ``.conf`` + exists alongside *path*. + + Raises: + ValueError: If *path* does not end with ``.local``. + FileNotFoundError: If *path* does not exist. + OSError: If no corresponding ``.conf`` exists and *allow_orphan* is + ``False``. + """ + _assert_local_file(path) + + if not path.is_file(): + raise FileNotFoundError(f"Local file not found: {path!r}") + + if not allow_orphan: + conf_path = path.with_suffix(".conf") + if not conf_path.is_file(): + raise OSError( + f"No corresponding .conf file found for {path!r}. " + "Pass allow_orphan=True to delete a local-only file." + ) + + lock = _get_file_lock(path) + with lock: + try: + path.unlink() + log.info("local_file_deleted", path=str(path)) + except OSError as exc: + log.error( + "local_file_delete_failed", path=str(path), error=str(exc) + ) + raise diff --git a/backend/tests/test_utils/__init__.py b/backend/tests/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/tests/test_utils/test_config_parser.py b/backend/tests/test_utils/test_config_parser.py new file mode 100644 index 0000000..3ae2343 --- /dev/null +++ b/backend/tests/test_utils/test_config_parser.py @@ -0,0 +1,473 @@ +"""Tests for app.utils.config_parser.Fail2BanConfigParser.""" + +from __future__ import annotations + +from pathlib import Path + +from app.utils.config_parser import Fail2BanConfigParser + +# --------------------------------------------------------------------------- +# Fixtures and helpers +# --------------------------------------------------------------------------- + +_FILTER_CONF = """\ +[INCLUDES] +before = common.conf + +[Definition] +failregex = ^%(host)s .*$ +ignoreregex = +""" + +_COMMON_CONF = """\ +[DEFAULT] +host = +""" + +_FILTER_LOCAL = """\ +[Definition] +failregex = ^OVERRIDE %(host)s$ +""" + +_ACTION_CONF = """\ +[Definition] +actionstart = iptables -N f2b- +actionstop = iptables -X f2b- +actionban = iptables -I INPUT -s -j DROP +actionunban = iptables -D INPUT -s -j DROP + +[Init] +name = default +ip = 1.2.3.4 +""" + + +def _write(tmp_path: Path, name: str, content: str) -> Path: + """Write *content* to *tmp_path/name* and return the path.""" + p = tmp_path / name + p.write_text(content, encoding="utf-8") + return p + + +# --------------------------------------------------------------------------- +# TestOrderedConfFiles +# --------------------------------------------------------------------------- + + +class TestOrderedConfFiles: + def test_empty_dir_returns_empty(self, tmp_path: Path) -> None: + result = Fail2BanConfigParser.ordered_conf_files(tmp_path, "jail") + assert result == [] + + def test_conf_only(self, tmp_path: Path) -> None: + conf = _write(tmp_path, "jail.conf", "[DEFAULT]\n") + result = Fail2BanConfigParser.ordered_conf_files(tmp_path, "jail") + assert result == [conf] + + def test_conf_then_local(self, tmp_path: Path) -> None: + conf = _write(tmp_path, "jail.conf", "[DEFAULT]\n") + local = _write(tmp_path, "jail.local", "[DEFAULT]\n") + result = Fail2BanConfigParser.ordered_conf_files(tmp_path, "jail") + assert result == [conf, local] + + def test_d_dir_overrides_appended(self, tmp_path: Path) -> None: + conf = _write(tmp_path, "jail.conf", "[DEFAULT]\n") + local = _write(tmp_path, "jail.local", "[DEFAULT]\n") + d_dir = tmp_path / "jail.d" + d_dir.mkdir() + d_conf = _write(d_dir, "extra.conf", "[DEFAULT]\n") + d_local = _write(d_dir, "extra.local", "[DEFAULT]\n") + result = Fail2BanConfigParser.ordered_conf_files(tmp_path, "jail") + assert result == [conf, local, d_conf, d_local] + + def test_missing_local_skipped(self, tmp_path: Path) -> None: + conf = _write(tmp_path, "jail.conf", "[DEFAULT]\n") + result = Fail2BanConfigParser.ordered_conf_files(tmp_path, "jail") + assert conf in result + assert len(result) == 1 + + def test_d_dir_sorted(self, tmp_path: Path) -> None: + d_dir = tmp_path / "jail.d" + d_dir.mkdir() + _write(d_dir, "zzz.conf", "[DEFAULT]\n") + _write(d_dir, "aaa.conf", "[DEFAULT]\n") + result = Fail2BanConfigParser.ordered_conf_files(tmp_path, "jail") + names = [p.name for p in result] + assert names == ["aaa.conf", "zzz.conf"] + + +# --------------------------------------------------------------------------- +# TestReadFile +# --------------------------------------------------------------------------- + + +class TestReadFile: + def test_reads_single_file(self, tmp_path: Path) -> None: + _write(tmp_path, "common.conf", _COMMON_CONF) + p = _write(tmp_path, "filter.conf", _FILTER_CONF) + parser = Fail2BanConfigParser() + parser.read_file(p) + assert parser.has_section("Definition") + + def test_before_include_loaded(self, tmp_path: Path) -> None: + _write(tmp_path, "common.conf", _COMMON_CONF) + p = _write(tmp_path, "filter.conf", _FILTER_CONF) + parser = Fail2BanConfigParser() + parser.read_file(p) + # DEFAULT from common.conf should be merged. + defaults = parser.defaults() + assert "host" in defaults + + def test_missing_file_is_silent(self, tmp_path: Path) -> None: + parser = Fail2BanConfigParser() + parser.read_file(tmp_path / "nonexistent.conf") + assert parser.sections() == [] + + def test_after_include_overrides(self, tmp_path: Path) -> None: + after_content = """\ +[Definition] +key = after_value +""" + after = _write(tmp_path, "after.conf", after_content) + _ = after # used via [INCLUDES] + main_content = """\ +[INCLUDES] +after = after.conf + +[Definition] +key = main_value +""" + p = _write(tmp_path, "main.conf", main_content) + parser = Fail2BanConfigParser() + parser.read_file(p) + # 'after' was loaded last → highest priority. + assert parser.get("Definition", "key") == "after_value" + + def test_cycle_detection(self, tmp_path: Path) -> None: + # A includes B, B includes A. + a_content = """\ +[INCLUDES] +before = b.conf + +[Definition] +key = from_a +""" + b_content = """\ +[INCLUDES] +before = a.conf + +[Definition] +key = from_b +""" + _write(tmp_path, "a.conf", a_content) + _write(tmp_path, "b.conf", b_content) + parser = Fail2BanConfigParser() + # Should not infinite-loop; terminates via cycle detection. + parser.read_file(tmp_path / "a.conf") + assert parser.has_section("Definition") + + def test_max_depth_guard(self, tmp_path: Path) -> None: + # Create a chain: 0→1→2→…→max+1 + max_depth = 3 + for i in range(max_depth + 2): + content = f"[INCLUDES]\nbefore = {i + 1}.conf\n\n[s{i}]\nk = v\n" + _write(tmp_path, f"{i}.conf", content) + parser = Fail2BanConfigParser(max_include_depth=max_depth) + parser.read_file(tmp_path / "0.conf") + # Should complete without recursion error; some sections will be missing. + assert isinstance(parser.sections(), list) + + def test_invalid_ini_is_ignored(self, tmp_path: Path) -> None: + bad = _write(tmp_path, "bad.conf", "this is not valid [[[ini\nstuff") + parser = Fail2BanConfigParser() + # Should not raise; parser logs and continues. + parser.read_file(bad) + + +# --------------------------------------------------------------------------- +# TestReadWithOverrides +# --------------------------------------------------------------------------- + + +class TestReadWithOverrides: + def test_local_overrides_conf(self, tmp_path: Path) -> None: + _write(tmp_path, "sshd.conf", "[Definition]\nfailregex = original\n") + _write(tmp_path, "sshd.local", "[Definition]\nfailregex = overridden\n") + parser = Fail2BanConfigParser() + parser.read_with_overrides(tmp_path / "sshd.conf") + assert parser.get("Definition", "failregex") == "overridden" + + def test_no_local_just_reads_conf(self, tmp_path: Path) -> None: + _write(tmp_path, "sshd.conf", "[Definition]\nfailregex = only_conf\n") + parser = Fail2BanConfigParser() + parser.read_with_overrides(tmp_path / "sshd.conf") + assert parser.get("Definition", "failregex") == "only_conf" + + def test_local_adds_new_key(self, tmp_path: Path) -> None: + _write(tmp_path, "sshd.conf", "[Definition]\nfailregex = orig\n") + _write(tmp_path, "sshd.local", "[Definition]\nextrakey = newval\n") + parser = Fail2BanConfigParser() + parser.read_with_overrides(tmp_path / "sshd.conf") + assert parser.get("Definition", "extrakey") == "newval" + + def test_conf_keys_preserved_when_local_overrides_other( + self, tmp_path: Path + ) -> None: + _write( + tmp_path, + "sshd.conf", + "[Definition]\nfailregex = orig\nignoreregex = keep_me\n", + ) + _write(tmp_path, "sshd.local", "[Definition]\nfailregex = new\n") + parser = Fail2BanConfigParser() + parser.read_with_overrides(tmp_path / "sshd.conf") + assert parser.get("Definition", "ignoreregex") == "keep_me" + assert parser.get("Definition", "failregex") == "new" + + +# --------------------------------------------------------------------------- +# TestSections +# --------------------------------------------------------------------------- + + +class TestSections: + def test_sections_excludes_default(self, tmp_path: Path) -> None: + content = "[DEFAULT]\nfoo = bar\n\n[Definition]\nbaz = qux\n" + p = _write(tmp_path, "x.conf", content) + parser = Fail2BanConfigParser() + parser.read_file(p) + secs = parser.sections() + assert "DEFAULT" not in secs + assert "Definition" in secs + + def test_has_section_true(self, tmp_path: Path) -> None: + p = _write(tmp_path, "x.conf", "[Init]\nname = test\n") + parser = Fail2BanConfigParser() + parser.read_file(p) + assert parser.has_section("Init") is True + + def test_has_section_false(self, tmp_path: Path) -> None: + p = _write(tmp_path, "x.conf", "[Init]\nname = test\n") + parser = Fail2BanConfigParser() + parser.read_file(p) + assert parser.has_section("Nonexistent") is False + + def test_get_returns_none_for_missing_section(self, tmp_path: Path) -> None: + p = _write(tmp_path, "x.conf", "[Init]\nname = test\n") + parser = Fail2BanConfigParser() + parser.read_file(p) + assert parser.get("NoSection", "key") is None + + def test_get_returns_none_for_missing_key(self, tmp_path: Path) -> None: + p = _write(tmp_path, "x.conf", "[Init]\nname = test\n") + parser = Fail2BanConfigParser() + parser.read_file(p) + assert parser.get("Init", "nokey") is None + + +# --------------------------------------------------------------------------- +# TestSectionDict +# --------------------------------------------------------------------------- + + +class TestSectionDict: + def test_returns_all_keys(self, tmp_path: Path) -> None: + p = _write(tmp_path, "x.conf", "[Definition]\na = 1\nb = 2\n") + parser = Fail2BanConfigParser() + parser.read_file(p) + d = parser.section_dict("Definition") + assert d == {"a": "1", "b": "2"} + + def test_empty_for_missing_section(self, tmp_path: Path) -> None: + p = _write(tmp_path, "x.conf", "[Definition]\na = 1\n") + parser = Fail2BanConfigParser() + parser.read_file(p) + assert parser.section_dict("Init") == {} + + def test_skip_excludes_keys(self, tmp_path: Path) -> None: + p = _write(tmp_path, "x.conf", "[Definition]\na = 1\nb = 2\nc = 3\n") + parser = Fail2BanConfigParser() + parser.read_file(p) + d = parser.section_dict("Definition", skip=frozenset({"b"})) + assert "b" not in d + assert d["a"] == "1" + + def test_dunder_keys_excluded(self, tmp_path: Path) -> None: + # configparser can inject __name__, __add__ etc. from DEFAULT. + content = "[DEFAULT]\n__name__ = foo\n\n[Definition]\nreal = val\n" + p = _write(tmp_path, "x.conf", content) + parser = Fail2BanConfigParser() + parser.read_file(p) + d = parser.section_dict("Definition") + assert "__name__" not in d + assert "real" in d + + +# --------------------------------------------------------------------------- +# TestDefaults +# --------------------------------------------------------------------------- + + +class TestDefaults: + def test_defaults_from_default_section(self, tmp_path: Path) -> None: + content = "[DEFAULT]\nhost = \n\n[Definition]\nfailregex = ^\n" + p = _write(tmp_path, "x.conf", content) + parser = Fail2BanConfigParser() + parser.read_file(p) + assert parser.defaults().get("host") == "" + + def test_defaults_empty_when_no_default_section(self, tmp_path: Path) -> None: + p = _write(tmp_path, "x.conf", "[Definition]\nfailregex = ^\n") + parser = Fail2BanConfigParser() + parser.read_file(p) + assert parser.defaults() == {} + + +# --------------------------------------------------------------------------- +# TestInterpolate +# --------------------------------------------------------------------------- + + +class TestInterpolate: + def _parser_with(self, tmp_path: Path, content: str) -> Fail2BanConfigParser: + p = _write(tmp_path, "x.conf", content) + parser = Fail2BanConfigParser() + parser.read_file(p) + return parser + + def test_substitutes_default_var(self, tmp_path: Path) -> None: + parser = self._parser_with( + tmp_path, + "[DEFAULT]\nhost = \n\n[Definition]\nrule = match %(host)s\n", + ) + assert parser.interpolate("match %(host)s") == "match " + + def test_substitutes_init_var(self, tmp_path: Path) -> None: + parser = self._parser_with( + tmp_path, + "[Init]\nname = sshd\n", + ) + assert parser.interpolate("f2b-%(name)s") == "f2b-sshd" + + def test_extra_vars_highest_priority(self, tmp_path: Path) -> None: + parser = self._parser_with( + tmp_path, + "[DEFAULT]\nname = default_name\n", + ) + result = parser.interpolate("%(name)s", extra_vars={"name": "override"}) + assert result == "override" + + def test_unresolvable_left_unchanged(self, tmp_path: Path) -> None: + parser = Fail2BanConfigParser() + result = parser.interpolate("value %(unknown)s end") + assert result == "value %(unknown)s end" + + def test_nested_interpolation(self, tmp_path: Path) -> None: + # %(outer)s → %(inner)s → final + parser = self._parser_with( + tmp_path, + "[DEFAULT]\ninner = final\nouter = %(inner)s\n", + ) + assert parser.interpolate("%(outer)s") == "final" + + def test_no_references_returned_unchanged(self, tmp_path: Path) -> None: + parser = Fail2BanConfigParser() + assert parser.interpolate("plain value") == "plain value" + + def test_empty_string(self, tmp_path: Path) -> None: + parser = Fail2BanConfigParser() + assert parser.interpolate("") == "" + + +# --------------------------------------------------------------------------- +# TestSplitMultiline +# --------------------------------------------------------------------------- + + +class TestSplitMultiline: + def test_strips_blank_lines(self) -> None: + raw = "line1\n\nline2\n\n" + assert Fail2BanConfigParser.split_multiline(raw) == ["line1", "line2"] + + def test_strips_comment_lines(self) -> None: + raw = "line1\n# comment\nline2" + assert Fail2BanConfigParser.split_multiline(raw) == ["line1", "line2"] + + def test_strips_leading_whitespace(self) -> None: + raw = " line1\n line2" + assert Fail2BanConfigParser.split_multiline(raw) == ["line1", "line2"] + + def test_empty_input(self) -> None: + assert Fail2BanConfigParser.split_multiline("") == [] + + def test_all_comments(self) -> None: + raw = "# first\n# second" + assert Fail2BanConfigParser.split_multiline(raw) == [] + + def test_single_line(self) -> None: + assert Fail2BanConfigParser.split_multiline("single") == ["single"] + + def test_preserves_internal_spaces(self) -> None: + raw = "iptables -I INPUT -s -j DROP" + assert Fail2BanConfigParser.split_multiline(raw) == [ + "iptables -I INPUT -s -j DROP" + ] + + def test_multiline_regex_list(self) -> None: + raw = ( + "\n" + " ^%(__prefix_line)s Authentication failure for .* from \n" + " # inline comment, skip\n" + " ^%(__prefix_line)s BREAK-IN ATTEMPT by \n" + ) + result = Fail2BanConfigParser.split_multiline(raw) + assert len(result) == 2 + assert all("HOST" in r for r in result) + + +# --------------------------------------------------------------------------- +# TestMultipleFiles (integration-style tests) +# --------------------------------------------------------------------------- + + +class TestMultipleFilesIntegration: + """Tests that combine several files to verify merge order.""" + + def test_local_only_override(self, tmp_path: Path) -> None: + _write(tmp_path, "test.conf", "[Definition]\nfailregex = base\n") + _write(tmp_path, "test.local", "[Definition]\nfailregex = local\n") + parser = Fail2BanConfigParser() + parser.read_with_overrides(tmp_path / "test.conf") + assert parser.get("Definition", "failregex") == "local" + + def test_before_then_conf_then_local(self, tmp_path: Path) -> None: + # before.conf → test.conf → test.local (ascending priority) + _write(tmp_path, "before.conf", "[Definition]\nsource = before\n") + _write( + tmp_path, + "test.conf", + "[INCLUDES]\nbefore = before.conf\n\n[Definition]\nsource = conf\n", + ) + _write(tmp_path, "test.local", "[Definition]\nsource = local\n") + parser = Fail2BanConfigParser() + parser.read_with_overrides(tmp_path / "test.conf") + assert parser.get("Definition", "source") == "local" + + def test_before_key_preserved_if_not_overridden(self, tmp_path: Path) -> None: + _write(tmp_path, "common.conf", "[DEFAULT]\nhost = \n") + _write( + tmp_path, + "filter.conf", + "[INCLUDES]\nbefore = common.conf\n\n[Definition]\nfailregex = ^%(host)s\n", + ) + parser = Fail2BanConfigParser() + parser.read_file(tmp_path / "filter.conf") + assert parser.defaults().get("host") == "" + assert parser.get("Definition", "failregex") == "^%(host)s" + + def test_fresh_parser_has_no_state(self) -> None: + p1 = Fail2BanConfigParser() + p2 = Fail2BanConfigParser() + assert p1.sections() == [] + assert p2.sections() == [] + assert p1 is not p2 diff --git a/backend/tests/test_utils/test_config_writer.py b/backend/tests/test_utils/test_config_writer.py new file mode 100644 index 0000000..f749d21 --- /dev/null +++ b/backend/tests/test_utils/test_config_writer.py @@ -0,0 +1,290 @@ +"""Tests for app.utils.config_writer.""" + +from __future__ import annotations + +from pathlib import Path + +import pytest # noqa: F401 — used by pytest.raises + +from app.utils.config_writer import ( + _get_file_lock, + delete_local_file, + remove_local_key, + write_local_override, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _write(tmp_path: Path, name: str, content: str) -> Path: + p = tmp_path / name + p.write_text(content, encoding="utf-8") + return p + + +def _read(path: Path) -> str: + return path.read_text(encoding="utf-8") + + +# --------------------------------------------------------------------------- +# TestGetFileLock +# --------------------------------------------------------------------------- + + +class TestGetFileLock: + def test_same_path_returns_same_lock(self, tmp_path: Path) -> None: + path = tmp_path / "test.local" + lock_a = _get_file_lock(path) + lock_b = _get_file_lock(path) + assert lock_a is lock_b + + def test_different_paths_return_different_locks(self, tmp_path: Path) -> None: + lock_a = _get_file_lock(tmp_path / "a.local") + lock_b = _get_file_lock(tmp_path / "b.local") + assert lock_a is not lock_b + + +# --------------------------------------------------------------------------- +# TestWriteLocalOverride +# --------------------------------------------------------------------------- + + +class TestWriteLocalOverride: + def test_creates_new_file(self, tmp_path: Path) -> None: + path = tmp_path / "sshd.local" + write_local_override(path, "Definition", {"failregex": "^bad$"}) + assert path.is_file() + + def test_file_contains_written_key(self, tmp_path: Path) -> None: + path = tmp_path / "sshd.local" + write_local_override(path, "Definition", {"failregex": "^bad$"}) + content = _read(path) + assert "failregex" in content + assert "^bad$" in content + + def test_creates_parent_directory(self, tmp_path: Path) -> None: + path = tmp_path / "subdir" / "sshd.local" + write_local_override(path, "Definition", {"key": "val"}) + assert path.is_file() + + def test_updates_existing_key(self, tmp_path: Path) -> None: + path = tmp_path / "sshd.local" + write_local_override(path, "Definition", {"failregex": "original"}) + write_local_override(path, "Definition", {"failregex": "updated"}) + content = _read(path) + assert "updated" in content + assert "original" not in content + + def test_preserves_other_sections(self, tmp_path: Path) -> None: + existing = "[Init]\nname = sshd\n\n[Definition]\nfailregex = orig\n" + path = _write(tmp_path, "sshd.local", existing) + write_local_override(path, "Definition", {"failregex": "new"}) + content = _read(path) + assert "Init" in content + assert "name" in content + + def test_preserves_other_keys_in_section(self, tmp_path: Path) -> None: + existing = "[Definition]\nfailregex = orig\nignoreregex = keep\n" + path = _write(tmp_path, "sshd.local", existing) + write_local_override(path, "Definition", {"failregex": "new"}) + content = _read(path) + assert "ignoreregex" in content + assert "keep" in content + + def test_adds_new_section(self, tmp_path: Path) -> None: + existing = "[Definition]\nfailregex = orig\n" + path = _write(tmp_path, "sshd.local", existing) + write_local_override(path, "Init", {"name": "sshd"}) + content = _read(path) + assert "[Init]" in content + assert "name" in content + + def test_writes_multiple_keys(self, tmp_path: Path) -> None: + path = tmp_path / "sshd.local" + write_local_override(path, "Definition", {"a": "1", "b": "2", "c": "3"}) + content = _read(path) + assert "a" in content + assert "b" in content + assert "c" in content + + def test_raises_for_conf_path(self, tmp_path: Path) -> None: + bad = tmp_path / "sshd.conf" + with pytest.raises(ValueError, match=r"\.local"): + write_local_override(bad, "Definition", {"key": "val"}) + + def test_raises_for_non_local_extension(self, tmp_path: Path) -> None: + bad = tmp_path / "sshd.ini" + with pytest.raises(ValueError): + write_local_override(bad, "Definition", {"key": "val"}) + + +# --------------------------------------------------------------------------- +# TestRemoveLocalKey +# --------------------------------------------------------------------------- + + +class TestRemoveLocalKey: + def test_removes_existing_key(self, tmp_path: Path) -> None: + path = _write( + tmp_path, + "sshd.local", + "[Definition]\nfailregex = bad\nignoreregex = keep\n", + ) + remove_local_key(path, "Definition", "failregex") + content = _read(path) + assert "failregex" not in content + assert "ignoreregex" in content + + def test_noop_for_missing_key(self, tmp_path: Path) -> None: + path = _write(tmp_path, "sshd.local", "[Definition]\nother = val\n") + # Should not raise. + remove_local_key(path, "Definition", "nonexistent") + assert path.is_file() + + def test_noop_for_missing_section(self, tmp_path: Path) -> None: + path = _write(tmp_path, "sshd.local", "[Definition]\nother = val\n") + remove_local_key(path, "Init", "name") + assert path.is_file() + + def test_noop_for_missing_file(self, tmp_path: Path) -> None: + path = tmp_path / "missing.local" + # Should not raise even if file doesn't exist. + remove_local_key(path, "Definition", "key") + + def test_removes_empty_section(self, tmp_path: Path) -> None: + # [Definition] will become empty and be removed; [Init] keeps the file. + path = _write( + tmp_path, + "sshd.local", + "[Definition]\nonly_key = val\n\n[Init]\nname = sshd\n", + ) + remove_local_key(path, "Definition", "only_key") + content = _read(path) + assert "[Definition]" not in content + assert "[Init]" in content + + def test_deletes_file_when_no_sections_remain(self, tmp_path: Path) -> None: + path = _write(tmp_path, "sshd.local", "[Definition]\nonly_key = val\n") + remove_local_key(path, "Definition", "only_key") + assert not path.exists() + + def test_preserves_other_sections_after_removal(self, tmp_path: Path) -> None: + path = _write( + tmp_path, + "sshd.local", + "[Definition]\nkey = val\n\n[Init]\nname = sshd\n", + ) + remove_local_key(path, "Definition", "key") + content = _read(path) + assert "[Init]" in content + assert "name" in content + + def test_raises_for_conf_path(self, tmp_path: Path) -> None: + bad = tmp_path / "sshd.conf" + with pytest.raises(ValueError, match=r"\.local"): + remove_local_key(bad, "Definition", "key") + + +# --------------------------------------------------------------------------- +# TestDeleteLocalFile +# --------------------------------------------------------------------------- + + +class TestDeleteLocalFile: + def test_deletes_existing_local_with_conf(self, tmp_path: Path) -> None: + _write(tmp_path, "sshd.conf", "[Definition]\n") + path = _write(tmp_path, "sshd.local", "[Definition]\nkey = val\n") + delete_local_file(path) + assert not path.exists() + + def test_raises_file_not_found(self, tmp_path: Path) -> None: + _write(tmp_path, "sshd.conf", "[Definition]\n") + missing = tmp_path / "sshd.local" + with pytest.raises(FileNotFoundError): + delete_local_file(missing) + + def test_raises_oserror_for_orphan_without_flag(self, tmp_path: Path) -> None: + path = _write(tmp_path, "orphan.local", "[Definition]\nkey = val\n") + with pytest.raises(OSError, match="No corresponding .conf"): + delete_local_file(path) + + def test_allow_orphan_deletes_local_only_file(self, tmp_path: Path) -> None: + path = _write(tmp_path, "orphan.local", "[Definition]\nkey = val\n") + delete_local_file(path, allow_orphan=True) + assert not path.exists() + + def test_raises_for_conf_path(self, tmp_path: Path) -> None: + bad = _write(tmp_path, "sshd.conf", "[Definition]\n") + with pytest.raises(ValueError, match=r"\.local"): + delete_local_file(bad) + + def test_raises_for_non_local_extension(self, tmp_path: Path) -> None: + bad = tmp_path / "sshd.ini" + bad.write_text("x", encoding="utf-8") + with pytest.raises(ValueError): + delete_local_file(bad) + + +# --------------------------------------------------------------------------- +# TestAtomicWrite (integration) +# --------------------------------------------------------------------------- + + +class TestAtomicWrite: + def test_no_temp_files_left_after_write(self, tmp_path: Path) -> None: + path = tmp_path / "sshd.local" + write_local_override(path, "Definition", {"key": "val"}) + files = list(tmp_path.iterdir()) + # Only the target file should exist. + assert len(files) == 1 + assert files[0].name == "sshd.local" + + def test_write_is_idempotent(self, tmp_path: Path) -> None: + path = tmp_path / "sshd.local" + for _ in range(5): + write_local_override(path, "Definition", {"key": "val"}) + content = _read(path) + # 'key' should appear exactly once. + assert content.count("key") == 1 + + +# --------------------------------------------------------------------------- +# TestEdgeCases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_write_empty_key_values_creates_empty_section( + self, tmp_path: Path + ) -> None: + path = tmp_path / "sshd.local" + write_local_override(path, "Definition", {}) + content = _read(path) + assert "[Definition]" in content + + def test_remove_key_with_unicode_value(self, tmp_path: Path) -> None: + path = _write( + tmp_path, + "sshd.local", + "[Definition]\nkey = 日本語\nother = keep\n", + ) + remove_local_key(path, "Definition", "key") + content = _read(path) + assert "日本語" not in content + assert "other" in content + + def test_write_value_with_newlines(self, tmp_path: Path) -> None: + path = tmp_path / "sshd.local" + # configparser stores multi-line values with continuation indent. + multiline = "line1\n line2\n line3" + write_local_override(path, "Definition", {"failregex": multiline}) + assert path.is_file() + + def test_remove_last_key_of_last_section_deletes_file( + self, tmp_path: Path + ) -> None: + path = _write(tmp_path, "sshd.local", "[Definition]\nlast_key = val\n") + remove_local_key(path, "Definition", "last_key") + assert not path.exists()