diff --git a/backend/app/routers/geo.py b/backend/app/routers/geo.py index 6c1b681..9c5504f 100644 --- a/backend/app/routers/geo.py +++ b/backend/app/routers/geo.py @@ -22,7 +22,7 @@ from app.dependencies import ( HttpSessionDep, get_db, ) -from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, GeoReResolveResponse, IpLookupResponse +from app.models.geo import GeoCacheStatsResponse, GeoReResolveResponse, IpLookupResponse from app.services import geo_service, jail_service from app.utils.fail2ban_client import Fail2BanConnectionError @@ -79,21 +79,7 @@ async def lookup_ip( detail=f"Cannot reach fail2ban: {exc}", ) from exc - raw_geo = result["geo"] - geo_detail: GeoDetail | None = None - if isinstance(raw_geo, GeoInfo): - geo_detail = GeoDetail( - country_code=raw_geo.country_code, - country_name=raw_geo.country_name, - asn=raw_geo.asn, - org=raw_geo.org, - ) - - return IpLookupResponse( - ip=result["ip"], - currently_banned_in=result["currently_banned_in"], - geo=geo_detail, - ) + return IpLookupResponse(**result) # --------------------------------------------------------------------------- diff --git a/backend/app/services/ban_service.py b/backend/app/services/ban_service.py index 866749e..df3bf64 100644 --- a/backend/app/services/ban_service.py +++ b/backend/app/services/ban_service.py @@ -40,7 +40,8 @@ from app.repositories.history_archive_repo import ( get_all_archived_history, get_archived_history, ) -from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso +from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service +from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso if TYPE_CHECKING: import aiohttp @@ -50,6 +51,12 @@ if TYPE_CHECKING: log: structlog.stdlib.BoundLogger = structlog.get_logger() + +async def get_fail2ban_db_path(socket_path: str) -> str: + """Return the fail2ban database path using the shared metadata cache.""" + return await default_fail2ban_metadata_service.get_db_path(socket_path) + + # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- diff --git a/backend/app/services/history_service.py b/backend/app/services/history_service.py index c96c4b1..d2ba99d 100644 --- a/backend/app/services/history_service.py +++ b/backend/app/services/history_service.py @@ -29,10 +29,17 @@ from app.models.history import ( ) from app.repositories import fail2ban_db_repo from app.repositories.history_archive_repo import archive_ban_event -from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso +from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service +from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso log: structlog.stdlib.BoundLogger = structlog.get_logger() + +async def get_fail2ban_db_path(socket_path: str) -> str: + """Return the fail2ban database path using the shared metadata cache.""" + return await default_fail2ban_metadata_service.get_db_path(socket_path) + + # --------------------------------------------------------------------------- # Constants # --------------------------------------------------------------------------- diff --git a/backend/app/utils/config_file_utils.py b/backend/app/utils/config_file_utils.py new file mode 100644 index 0000000..26d339f --- /dev/null +++ b/backend/app/utils/config_file_utils.py @@ -0,0 +1,551 @@ +from __future__ import annotations + +import asyncio +import configparser +import io +import os +import re +import tempfile +from pathlib import Path +from typing import cast + +import structlog + +from app.exceptions import ( + ConfigWriteError, + FilterNameError, + JailNameError, +) +from app.models.config import ( + ActionConfig, + BantimeEscalation, + InactiveJail, + JailValidationIssue, + JailValidationResult, +) +from app.utils import conffile_parser +from app.utils.constants import FAIL2BAN_TRUTHY_VALUES +from app.utils.fail2ban_client import ( + Fail2BanClient, + Fail2BanConnectionError, + Fail2BanResponse, +) + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + +_SOCKET_TIMEOUT: float = 10.0 + +# Allowlist pattern for jail names used in path construction. +_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$") + +# Allowlist pattern for filter names used in path construction. +_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$") + +# Allowlist pattern for action names used in path construction. +_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$") + +# Sections that are not jail definitions. +_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"}) + +# False-ish values for the ``enabled`` key. +_FALSE_VALUES: frozenset[str] = frozenset({"false", "no", "0"}) + + +def _build_parser() -> configparser.RawConfigParser: + """Return a parser configured for fail2ban-style INI files.""" + parser = configparser.RawConfigParser(strict=False, interpolation=None) + parser.optionxform = str + return parser + + +def _is_truthy(value: str) -> bool: + """Return ``True`` if *value* represents a fail2ban boolean true.""" + return value.strip().lower() in FAIL2BAN_TRUTHY_VALUES + + +def _parse_int_safe(value: str) -> int | None: + """Parse *value* as int, returning ``None`` on failure.""" + try: + return int(value.strip()) + except (ValueError, AttributeError): + return None + + +def _parse_time_to_seconds(value: str | None, default: int) -> int: + """Convert a fail2ban time string to seconds.""" + if not value: + return default + stripped = value.strip() + if stripped == "-1": + return -1 + multipliers: dict[str, int] = { + "w": 604800, + "d": 86400, + "h": 3600, + "m": 60, + "s": 1, + } + for suffix, factor in multipliers.items(): + if stripped.endswith(suffix) and len(stripped) > 1: + try: + return int(stripped[:-1]) * factor + except ValueError: + return default + try: + return int(stripped) + except ValueError: + return default + + +def _parse_multiline(raw: str) -> list[str]: + """Split a multi-line INI value into individual non-blank lines.""" + result: list[str] = [] + for line in raw.splitlines(): + stripped = line.strip() + if stripped and not stripped.startswith("#"): + result.append(stripped) + return result + + +def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str: + """Resolve fail2ban variable placeholders in a filter string.""" + result = raw_filter.replace("%(__name__)s", jail_name) + return result.replace("%(mode)s", mode) + + +def _ordered_config_files(config_dir: Path) -> list[Path]: + """Return all jail config files in fail2ban merge order.""" + files: list[Path] = [] + + jail_conf = config_dir / "jail.conf" + if jail_conf.is_file(): + files.append(jail_conf) + + jail_local = config_dir / "jail.local" + if jail_local.is_file(): + files.append(jail_local) + + jail_d = config_dir / "jail.d" + if jail_d.is_dir(): + files.extend(sorted(jail_d.glob("*.conf"))) + files.extend(sorted(jail_d.glob("*.local"))) + + return files + + +def _build_inactive_jail( + name: str, + settings: dict[str, str], + source_file: str, + config_dir: Path | None = None, +) -> InactiveJail: + """Construct an :class:`~app.models.config.InactiveJail` from raw settings.""" + raw_filter = settings.get("filter", "") + mode = settings.get("mode", "normal") + filter_name = _resolve_filter(raw_filter, name, mode) if raw_filter else name + + raw_action = settings.get("action", "") + actions = _parse_multiline(raw_action) if raw_action else [] + + raw_logpath = settings.get("logpath", "") + logpath = _parse_multiline(raw_logpath) if raw_logpath else [] + + enabled_raw = settings.get("enabled", "false") + enabled = _is_truthy(enabled_raw) + + maxretry_raw = settings.get("maxretry", "") + maxretry = _parse_int_safe(maxretry_raw) + + ban_time_seconds = _parse_time_to_seconds(settings.get("bantime"), 600) + find_time_seconds = _parse_time_to_seconds(settings.get("findtime"), 600) + log_encoding = settings.get("logencoding") or "auto" + backend = settings.get("backend") or "auto" + date_pattern = settings.get("datepattern") or None + use_dns = settings.get("usedns") or "warn" + prefregex = settings.get("prefregex") or "" + fail_regex = _parse_multiline(settings.get("failregex", "")) + ignore_regex = _parse_multiline(settings.get("ignoreregex", "")) + + esc_increment = _is_truthy(settings.get("bantime.increment", "false")) + esc_factor_raw = settings.get("bantime.factor") + esc_factor = float(esc_factor_raw) if esc_factor_raw else None + esc_formula = settings.get("bantime.formula") or None + esc_multipliers = settings.get("bantime.multipliers") or None + esc_max_raw = settings.get("bantime.maxtime") + esc_max_time = _parse_time_to_seconds(esc_max_raw, 0) if esc_max_raw else None + esc_rnd_raw = settings.get("bantime.rndtime") + esc_rnd_time = _parse_time_to_seconds(esc_rnd_raw, 0) if esc_rnd_raw else None + esc_overall = _is_truthy(settings.get("bantime.overalljails", "false")) + bantime_escalation = ( + BantimeEscalation( + increment=esc_increment, + factor=esc_factor, + formula=esc_formula, + multipliers=esc_multipliers, + max_time=esc_max_time, + rnd_time=esc_rnd_time, + overall_jails=esc_overall, + ) + if esc_increment + else None + ) + + return InactiveJail( + name=name, + filter=filter_name, + actions=actions, + port=settings.get("port") or None, + logpath=logpath, + bantime=settings.get("bantime") or None, + findtime=settings.get("findtime") or None, + maxretry=maxretry, + ban_time_seconds=ban_time_seconds, + find_time_seconds=find_time_seconds, + log_encoding=log_encoding, + backend=backend, + date_pattern=date_pattern, + use_dns=use_dns, + prefregex=prefregex, + fail_regex=fail_regex, + ignore_regex=ignore_regex, + bantime_escalation=bantime_escalation, + source_file=source_file, + enabled=enabled, + has_local_override=((config_dir / "jail.d" / f"{name}.local").is_file() if config_dir is not None else False), + ) + + +def _parse_jails_sync( + config_dir: Path, +) -> tuple[dict[str, dict[str, str]], dict[str, str]]: + """Synchronously parse all jail configs and return merged definitions.""" + parser = _build_parser() + files = _ordered_config_files(config_dir) + + source_files: dict[str, str] = {} + for path in files: + try: + single = _build_parser() + single.read(str(path), encoding="utf-8") + for section in single.sections(): + if section not in _META_SECTIONS: + source_files[section] = str(path) + except (configparser.Error, OSError) as exc: + log.warning("jail_config_read_error", path=str(path), error=str(exc)) + + try: + parser.read([str(p) for p in files], encoding="utf-8") + except configparser.Error as exc: + log.warning("jail_config_parse_error", error=str(exc)) + + jails: dict[str, dict[str, str]] = {} + for section in parser.sections(): + if section in _META_SECTIONS: + continue + try: + jails[section] = dict(parser.items(section)) + except configparser.Error as exc: + log.warning("jail_section_parse_error", section=section, error=str(exc)) + + log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir)) + return jails, source_files + + +async def _get_active_jail_names(socket_path: str) -> set[str]: + """Fetch the set of currently running jail names from fail2ban.""" + try: + client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) + + def _to_dict_inner(pairs: object) -> dict[str, object]: + if not isinstance(pairs, (list, tuple)): + return {} + result: dict[str, object] = {} + for item in pairs: + try: + k, v = item + result[str(k)] = v + except (TypeError, ValueError): + pass + return result + + def _ok(response: object) -> object: + code, data = cast("Fail2BanResponse", response) + if code != 0: + raise ValueError(f"fail2ban error {code}: {data!r}") + return data + + status_raw = _ok(await client.send(["status"])) + status_dict = _to_dict_inner(status_raw) + jail_list_raw: str = str(status_dict.get("Jail list", "") or "").strip() + if not jail_list_raw: + return set() + return {j.strip() for j in jail_list_raw.split(",") if j.strip()} + except Fail2BanConnectionError: + log.warning("fail2ban_unreachable_during_inactive_list") + return set() + except Exception as exc: # noqa: BLE001 + log.warning("fail2ban_status_error_during_inactive_list", error=str(exc)) + return set() + + +async def _probe_fail2ban_running(socket_path: str) -> bool: + """Return ``True`` when fail2ban responds successfully to a status request.""" + try: + client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT) + response = await client.send(["status"]) + code, _ = cast("Fail2BanResponse", response) + return code == 0 + except Fail2BanConnectionError: + log.warning("fail2ban_unreachable_during_probe", socket_path=str(socket_path)) + return False + except Exception as exc: # noqa: BLE001 + log.warning("fail2ban_probe_error", socket_path=str(socket_path), error=str(exc)) + return False + + +def _extract_action_base_name(action_str: str) -> str | None: + """Return the base action name from an action assignment string.""" + if "%" in action_str or "$" in action_str: + return None + base = action_str.split("[")[0].strip() + if _SAFE_ACTION_NAME_RE.match(base): + return base + return None + + +def _validate_jail_config_sync( + config_dir: Path, + name: str, +) -> JailValidationResult: + """Run synchronous pre-activation checks on a jail configuration.""" + issues: list[JailValidationIssue] = [] + + all_jails, _ = _parse_jails_sync(config_dir) + settings = all_jails.get(name) + + if settings is None: + return JailValidationResult( + jail_name=name, + valid=False, + issues=[ + JailValidationIssue( + field="name", + message=f"Jail {name!r} not found in config files.", + ) + ], + ) + + filter_d = config_dir / "filter.d" + action_d = config_dir / "action.d" + + raw_filter = settings.get("filter", "") + if raw_filter: + mode = settings.get("mode", "normal") + resolved = _resolve_filter(raw_filter, name, mode) + base_filter = _extract_action_base_name(resolved) + if base_filter: + conf_ok = (filter_d / f"{base_filter}.conf").is_file() + local_ok = (filter_d / f"{base_filter}.local").is_file() + if not conf_ok and not local_ok: + issues.append( + JailValidationIssue( + field="filter", + message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"), + ) + ) + + raw_action = settings.get("action", "") + if raw_action: + for action_line in _parse_multiline(raw_action): + action_name = _extract_action_base_name(action_line) + if action_name: + conf_ok = (action_d / f"{action_name}.conf").is_file() + local_ok = (action_d / f"{action_name}.local").is_file() + if not conf_ok and not local_ok: + issues.append( + JailValidationIssue( + field="action", + message=(f"Action file not found: action.d/{action_name}.conf (or .local)"), + ) + ) + + for pattern in _parse_multiline(settings.get("failregex", "")): + try: + re.compile(pattern) + except re.error as exc: + issues.append( + JailValidationIssue( + field="failregex", + message=f"Invalid regex pattern: {exc}", + ) + ) + + for pattern in _parse_multiline(settings.get("ignoreregex", "")): + try: + re.compile(pattern) + except re.error as exc: + issues.append( + JailValidationIssue( + field="ignoreregex", + message=f"Invalid regex pattern: {exc}", + ) + ) + + raw_logpath = settings.get("logpath", "") + if raw_logpath: + for log_path in _parse_multiline(raw_logpath): + if "*" in log_path or "?" in log_path or "%(" in log_path: + continue + if not Path(log_path).exists(): + issues.append( + JailValidationIssue( + field="logpath", + message=f"Log file not found on disk: {log_path}", + ) + ) + + valid = len(issues) == 0 + log.debug( + "jail_validation_complete", + jail=name, + valid=valid, + issue_count=len(issues), + ) + return JailValidationResult(jail_name=name, valid=valid, issues=issues) + + +def _safe_jail_name(name: str) -> str: + """Validate *name* and return it unchanged or raise :class:`JailNameError`.""" + if not _SAFE_JAIL_NAME_RE.match(name): + raise JailNameError( + f"Jail name {name!r} contains invalid characters. " + "Only alphanumeric characters, hyphens, underscores, and dots are " + "allowed; must start with an alphanumeric character." + ) + return name + + +def _safe_filter_name(name: str) -> str: + """Validate *name* and return it unchanged or raise :class:`FilterNameError`.""" + if not _SAFE_FILTER_NAME_RE.match(name): + raise FilterNameError( + f"Filter name {name!r} contains invalid characters. " + "Only alphanumeric characters, hyphens, underscores, and dots are " + "allowed; must start with an alphanumeric character." + ) + return name + + +def _set_jail_local_key_sync( + config_dir: Path, + jail_name: str, + key: str, + value: str, +) -> None: + """Update ``jail.d/{jail_name}.local`` to set a single key in the jail section.""" + jail_d = config_dir / "jail.d" + try: + jail_d.mkdir(parents=True, exist_ok=True) + except OSError as exc: + raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc + + local_path = jail_d / f"{jail_name}.local" + parser = _build_parser() + if local_path.is_file(): + try: + parser.read(str(local_path), encoding="utf-8") + except (configparser.Error, OSError) as exc: + log.warning( + "jail_local_read_for_update_error", + jail=jail_name, + error=str(exc), + ) + + if not parser.has_section(jail_name): + parser.add_section(jail_name) + parser.set(jail_name, key, value) + + buf = io.StringIO() + buf.write("# Managed by BanGUI — do not edit manually\n\n") + parser.write(buf) + content = buf.getvalue() + + try: + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=jail_d, + delete=False, + suffix=".tmp", + ) as tmp: + tmp.write(content) + tmp_name = tmp.name + os.replace(tmp_name, local_path) + except OSError as exc: + with contextlib.suppress(OSError): + os.unlink(tmp_name) # noqa: F821 + raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc + + log.info( + "jail_local_key_set", + jail=jail_name, + key=key, + path=str(local_path), + ) + + +ordered_config_files = _ordered_config_files +build_parser = _build_parser +is_truthy = _is_truthy +parse_multiline = _parse_multiline +parse_jails_sync = _parse_jails_sync +build_inactive_jail = _build_inactive_jail +get_active_jail_names = _get_active_jail_names +validate_jail_config_sync = _validate_jail_config_sync +set_jail_local_key_sync = _set_jail_local_key_sync +safe_jail_name = _safe_jail_name +safe_filter_name = _safe_filter_name +probe_fail2ban_running = _probe_fail2ban_running + + +async def start_daemon(start_cmd_parts: list[str]) -> bool: + """Run the configured fail2ban start command and return whether it launched.""" + process = await asyncio.create_subprocess_exec( + *start_cmd_parts, + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.PIPE, + ) + stdout, stderr = await process.communicate() + if process.returncode != 0: + log.error( + "fail2ban_start_failed", + command=" ".join(start_cmd_parts), + returncode=process.returncode, + stdout=stdout.decode("utf-8", errors="replace"), + stderr=stderr.decode("utf-8", errors="replace"), + ) + return False + log.info( + "fail2ban_start_succeeded", + command=" ".join(start_cmd_parts), + ) + return True + + +async def wait_for_fail2ban( + socket_path: str, + max_wait_seconds: float, + poll_interval: float = 0.5, +) -> bool: + """Probe the fail2ban socket until it is responsive or the timeout expires.""" + deadline = asyncio.get_running_loop().time() + max_wait_seconds + while asyncio.get_running_loop().time() < deadline: + if await _probe_fail2ban_running(socket_path): + return True + await asyncio.sleep(poll_interval) + log.warning( + "wait_for_fail2ban_timeout", + socket_path=str(socket_path), + max_wait_seconds=max_wait_seconds, + ) + return False diff --git a/backend/app/utils/fail2ban_db_utils.py b/backend/app/utils/fail2ban_db_utils.py index 89e20b2..703a00b 100644 --- a/backend/app/utils/fail2ban_db_utils.py +++ b/backend/app/utils/fail2ban_db_utils.py @@ -5,26 +5,12 @@ from __future__ import annotations import json from datetime import UTC, datetime -from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service - def ts_to_iso(unix_ts: int) -> str: """Convert a Unix timestamp to an ISO 8601 UTC string.""" return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat() -async def get_fail2ban_db_path(socket_path: str, *, force_refresh: bool = False) -> str: - """Return the fail2ban database path, using cached metadata when available.""" - return await default_fail2ban_metadata_service.get_db_path( - socket_path, force_refresh=force_refresh - ) - - -def invalidate_fail2ban_db_path(socket_path: str) -> None: - """Invalidate the cached fail2ban database path for the given socket.""" - default_fail2ban_metadata_service.invalidate_db_path(socket_path) - - def parse_data_json(raw: object) -> tuple[list[str], int]: """Extract matches and failure count from the fail2ban bans.data value.""" if raw is None: diff --git a/backend/tests/test_routers/test_geo.py b/backend/tests/test_routers/test_geo.py index 3244133..aad1419 100644 --- a/backend/tests/test_routers/test_geo.py +++ b/backend/tests/test_routers/test_geo.py @@ -12,7 +12,7 @@ from httpx import ASGITransport, AsyncClient from app.config import Settings from app.db import init_db from app.main import create_app -from app.models.geo import GeoInfo +from app.models.geo import GeoDetail, GeoInfo # --------------------------------------------------------------------------- # Fixtures @@ -71,7 +71,7 @@ class TestGeoLookup: async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None: """GET /api/geo/lookup/{ip} returns 200 with enriched result.""" - geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme") + geo = GeoDetail(country_code="DE", country_name="Germany", asn="12345", org="Acme") result: dict[str, object] = { "ip": "1.2.3.4", "currently_banned_in": ["sshd"], @@ -97,7 +97,7 @@ class TestGeoLookup: result: dict[str, object] = { "ip": "8.8.8.8", "currently_banned_in": [], - "geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None), + "geo": GeoDetail(country_code="US", country_name="United States", asn=None, org=None), } with patch( "app.routers.geo.jail_service.lookup_ip",