refactoring-backend #3

Merged
lukas.pupkalipinski merged 403 commits from refactoring-backend into main 2026-05-20 20:23:46 +02:00
6 changed files with 572 additions and 35 deletions
Showing only changes of commit a79f5339bc - Show all commits

View File

@@ -22,7 +22,7 @@ from app.dependencies import (
HttpSessionDep,
get_db,
)
from app.models.geo import GeoCacheStatsResponse, GeoDetail, GeoInfo, GeoReResolveResponse, IpLookupResponse
from app.models.geo import GeoCacheStatsResponse, GeoReResolveResponse, IpLookupResponse
from app.services import geo_service, jail_service
from app.utils.fail2ban_client import Fail2BanConnectionError
@@ -79,21 +79,7 @@ async def lookup_ip(
detail=f"Cannot reach fail2ban: {exc}",
) from exc
raw_geo = result["geo"]
geo_detail: GeoDetail | None = None
if isinstance(raw_geo, GeoInfo):
geo_detail = GeoDetail(
country_code=raw_geo.country_code,
country_name=raw_geo.country_name,
asn=raw_geo.asn,
org=raw_geo.org,
)
return IpLookupResponse(
ip=result["ip"],
currently_banned_in=result["currently_banned_in"],
geo=geo_detail,
)
return IpLookupResponse(**result)
# ---------------------------------------------------------------------------

View File

@@ -40,7 +40,8 @@ from app.repositories.history_archive_repo import (
get_all_archived_history,
get_archived_history,
)
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
if TYPE_CHECKING:
import aiohttp
@@ -50,6 +51,12 @@ if TYPE_CHECKING:
log: structlog.stdlib.BoundLogger = structlog.get_logger()
async def get_fail2ban_db_path(socket_path: str) -> str:
"""Return the fail2ban database path using the shared metadata cache."""
return await default_fail2ban_metadata_service.get_db_path(socket_path)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

View File

@@ -29,10 +29,17 @@ from app.models.history import (
)
from app.repositories import fail2ban_db_repo
from app.repositories.history_archive_repo import archive_ban_event
from app.utils.fail2ban_db_utils import get_fail2ban_db_path, parse_data_json, ts_to_iso
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
log: structlog.stdlib.BoundLogger = structlog.get_logger()
async def get_fail2ban_db_path(socket_path: str) -> str:
"""Return the fail2ban database path using the shared metadata cache."""
return await default_fail2ban_metadata_service.get_db_path(socket_path)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,551 @@
from __future__ import annotations
import asyncio
import configparser
import io
import os
import re
import tempfile
from pathlib import Path
from typing import cast
import structlog
from app.exceptions import (
ConfigWriteError,
FilterNameError,
JailNameError,
)
from app.models.config import (
ActionConfig,
BantimeEscalation,
InactiveJail,
JailValidationIssue,
JailValidationResult,
)
from app.utils import conffile_parser
from app.utils.constants import FAIL2BAN_TRUTHY_VALUES
from app.utils.fail2ban_client import (
Fail2BanClient,
Fail2BanConnectionError,
Fail2BanResponse,
)
log: structlog.stdlib.BoundLogger = structlog.get_logger()
_SOCKET_TIMEOUT: float = 10.0
# Allowlist pattern for jail names used in path construction.
_SAFE_JAIL_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
# Allowlist pattern for filter names used in path construction.
_SAFE_FILTER_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
# Allowlist pattern for action names used in path construction.
_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile(r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$")
# Sections that are not jail definitions.
_META_SECTIONS: frozenset[str] = frozenset({"INCLUDES", "DEFAULT"})
# False-ish values for the ``enabled`` key.
_FALSE_VALUES: frozenset[str] = frozenset({"false", "no", "0"})
def _build_parser() -> configparser.RawConfigParser:
"""Return a parser configured for fail2ban-style INI files."""
parser = configparser.RawConfigParser(strict=False, interpolation=None)
parser.optionxform = str
return parser
def _is_truthy(value: str) -> bool:
"""Return ``True`` if *value* represents a fail2ban boolean true."""
return value.strip().lower() in FAIL2BAN_TRUTHY_VALUES
def _parse_int_safe(value: str) -> int | None:
"""Parse *value* as int, returning ``None`` on failure."""
try:
return int(value.strip())
except (ValueError, AttributeError):
return None
def _parse_time_to_seconds(value: str | None, default: int) -> int:
"""Convert a fail2ban time string to seconds."""
if not value:
return default
stripped = value.strip()
if stripped == "-1":
return -1
multipliers: dict[str, int] = {
"w": 604800,
"d": 86400,
"h": 3600,
"m": 60,
"s": 1,
}
for suffix, factor in multipliers.items():
if stripped.endswith(suffix) and len(stripped) > 1:
try:
return int(stripped[:-1]) * factor
except ValueError:
return default
try:
return int(stripped)
except ValueError:
return default
def _parse_multiline(raw: str) -> list[str]:
"""Split a multi-line INI value into individual non-blank lines."""
result: list[str] = []
for line in raw.splitlines():
stripped = line.strip()
if stripped and not stripped.startswith("#"):
result.append(stripped)
return result
def _resolve_filter(raw_filter: str, jail_name: str, mode: str) -> str:
"""Resolve fail2ban variable placeholders in a filter string."""
result = raw_filter.replace("%(__name__)s", jail_name)
return result.replace("%(mode)s", mode)
def _ordered_config_files(config_dir: Path) -> list[Path]:
"""Return all jail config files in fail2ban merge order."""
files: list[Path] = []
jail_conf = config_dir / "jail.conf"
if jail_conf.is_file():
files.append(jail_conf)
jail_local = config_dir / "jail.local"
if jail_local.is_file():
files.append(jail_local)
jail_d = config_dir / "jail.d"
if jail_d.is_dir():
files.extend(sorted(jail_d.glob("*.conf")))
files.extend(sorted(jail_d.glob("*.local")))
return files
def _build_inactive_jail(
name: str,
settings: dict[str, str],
source_file: str,
config_dir: Path | None = None,
) -> InactiveJail:
"""Construct an :class:`~app.models.config.InactiveJail` from raw settings."""
raw_filter = settings.get("filter", "")
mode = settings.get("mode", "normal")
filter_name = _resolve_filter(raw_filter, name, mode) if raw_filter else name
raw_action = settings.get("action", "")
actions = _parse_multiline(raw_action) if raw_action else []
raw_logpath = settings.get("logpath", "")
logpath = _parse_multiline(raw_logpath) if raw_logpath else []
enabled_raw = settings.get("enabled", "false")
enabled = _is_truthy(enabled_raw)
maxretry_raw = settings.get("maxretry", "")
maxretry = _parse_int_safe(maxretry_raw)
ban_time_seconds = _parse_time_to_seconds(settings.get("bantime"), 600)
find_time_seconds = _parse_time_to_seconds(settings.get("findtime"), 600)
log_encoding = settings.get("logencoding") or "auto"
backend = settings.get("backend") or "auto"
date_pattern = settings.get("datepattern") or None
use_dns = settings.get("usedns") or "warn"
prefregex = settings.get("prefregex") or ""
fail_regex = _parse_multiline(settings.get("failregex", ""))
ignore_regex = _parse_multiline(settings.get("ignoreregex", ""))
esc_increment = _is_truthy(settings.get("bantime.increment", "false"))
esc_factor_raw = settings.get("bantime.factor")
esc_factor = float(esc_factor_raw) if esc_factor_raw else None
esc_formula = settings.get("bantime.formula") or None
esc_multipliers = settings.get("bantime.multipliers") or None
esc_max_raw = settings.get("bantime.maxtime")
esc_max_time = _parse_time_to_seconds(esc_max_raw, 0) if esc_max_raw else None
esc_rnd_raw = settings.get("bantime.rndtime")
esc_rnd_time = _parse_time_to_seconds(esc_rnd_raw, 0) if esc_rnd_raw else None
esc_overall = _is_truthy(settings.get("bantime.overalljails", "false"))
bantime_escalation = (
BantimeEscalation(
increment=esc_increment,
factor=esc_factor,
formula=esc_formula,
multipliers=esc_multipliers,
max_time=esc_max_time,
rnd_time=esc_rnd_time,
overall_jails=esc_overall,
)
if esc_increment
else None
)
return InactiveJail(
name=name,
filter=filter_name,
actions=actions,
port=settings.get("port") or None,
logpath=logpath,
bantime=settings.get("bantime") or None,
findtime=settings.get("findtime") or None,
maxretry=maxretry,
ban_time_seconds=ban_time_seconds,
find_time_seconds=find_time_seconds,
log_encoding=log_encoding,
backend=backend,
date_pattern=date_pattern,
use_dns=use_dns,
prefregex=prefregex,
fail_regex=fail_regex,
ignore_regex=ignore_regex,
bantime_escalation=bantime_escalation,
source_file=source_file,
enabled=enabled,
has_local_override=((config_dir / "jail.d" / f"{name}.local").is_file() if config_dir is not None else False),
)
def _parse_jails_sync(
config_dir: Path,
) -> tuple[dict[str, dict[str, str]], dict[str, str]]:
"""Synchronously parse all jail configs and return merged definitions."""
parser = _build_parser()
files = _ordered_config_files(config_dir)
source_files: dict[str, str] = {}
for path in files:
try:
single = _build_parser()
single.read(str(path), encoding="utf-8")
for section in single.sections():
if section not in _META_SECTIONS:
source_files[section] = str(path)
except (configparser.Error, OSError) as exc:
log.warning("jail_config_read_error", path=str(path), error=str(exc))
try:
parser.read([str(p) for p in files], encoding="utf-8")
except configparser.Error as exc:
log.warning("jail_config_parse_error", error=str(exc))
jails: dict[str, dict[str, str]] = {}
for section in parser.sections():
if section in _META_SECTIONS:
continue
try:
jails[section] = dict(parser.items(section))
except configparser.Error as exc:
log.warning("jail_section_parse_error", section=section, error=str(exc))
log.debug("jails_parsed", count=len(jails), config_dir=str(config_dir))
return jails, source_files
async def _get_active_jail_names(socket_path: str) -> set[str]:
"""Fetch the set of currently running jail names from fail2ban."""
try:
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
def _to_dict_inner(pairs: object) -> dict[str, object]:
if not isinstance(pairs, (list, tuple)):
return {}
result: dict[str, object] = {}
for item in pairs:
try:
k, v = item
result[str(k)] = v
except (TypeError, ValueError):
pass
return result
def _ok(response: object) -> object:
code, data = cast("Fail2BanResponse", response)
if code != 0:
raise ValueError(f"fail2ban error {code}: {data!r}")
return data
status_raw = _ok(await client.send(["status"]))
status_dict = _to_dict_inner(status_raw)
jail_list_raw: str = str(status_dict.get("Jail list", "") or "").strip()
if not jail_list_raw:
return set()
return {j.strip() for j in jail_list_raw.split(",") if j.strip()}
except Fail2BanConnectionError:
log.warning("fail2ban_unreachable_during_inactive_list")
return set()
except Exception as exc: # noqa: BLE001
log.warning("fail2ban_status_error_during_inactive_list", error=str(exc))
return set()
async def _probe_fail2ban_running(socket_path: str) -> bool:
"""Return ``True`` when fail2ban responds successfully to a status request."""
try:
client = Fail2BanClient(socket_path=socket_path, timeout=_SOCKET_TIMEOUT)
response = await client.send(["status"])
code, _ = cast("Fail2BanResponse", response)
return code == 0
except Fail2BanConnectionError:
log.warning("fail2ban_unreachable_during_probe", socket_path=str(socket_path))
return False
except Exception as exc: # noqa: BLE001
log.warning("fail2ban_probe_error", socket_path=str(socket_path), error=str(exc))
return False
def _extract_action_base_name(action_str: str) -> str | None:
"""Return the base action name from an action assignment string."""
if "%" in action_str or "$" in action_str:
return None
base = action_str.split("[")[0].strip()
if _SAFE_ACTION_NAME_RE.match(base):
return base
return None
def _validate_jail_config_sync(
config_dir: Path,
name: str,
) -> JailValidationResult:
"""Run synchronous pre-activation checks on a jail configuration."""
issues: list[JailValidationIssue] = []
all_jails, _ = _parse_jails_sync(config_dir)
settings = all_jails.get(name)
if settings is None:
return JailValidationResult(
jail_name=name,
valid=False,
issues=[
JailValidationIssue(
field="name",
message=f"Jail {name!r} not found in config files.",
)
],
)
filter_d = config_dir / "filter.d"
action_d = config_dir / "action.d"
raw_filter = settings.get("filter", "")
if raw_filter:
mode = settings.get("mode", "normal")
resolved = _resolve_filter(raw_filter, name, mode)
base_filter = _extract_action_base_name(resolved)
if base_filter:
conf_ok = (filter_d / f"{base_filter}.conf").is_file()
local_ok = (filter_d / f"{base_filter}.local").is_file()
if not conf_ok and not local_ok:
issues.append(
JailValidationIssue(
field="filter",
message=(f"Filter file not found: filter.d/{base_filter}.conf (or .local)"),
)
)
raw_action = settings.get("action", "")
if raw_action:
for action_line in _parse_multiline(raw_action):
action_name = _extract_action_base_name(action_line)
if action_name:
conf_ok = (action_d / f"{action_name}.conf").is_file()
local_ok = (action_d / f"{action_name}.local").is_file()
if not conf_ok and not local_ok:
issues.append(
JailValidationIssue(
field="action",
message=(f"Action file not found: action.d/{action_name}.conf (or .local)"),
)
)
for pattern in _parse_multiline(settings.get("failregex", "")):
try:
re.compile(pattern)
except re.error as exc:
issues.append(
JailValidationIssue(
field="failregex",
message=f"Invalid regex pattern: {exc}",
)
)
for pattern in _parse_multiline(settings.get("ignoreregex", "")):
try:
re.compile(pattern)
except re.error as exc:
issues.append(
JailValidationIssue(
field="ignoreregex",
message=f"Invalid regex pattern: {exc}",
)
)
raw_logpath = settings.get("logpath", "")
if raw_logpath:
for log_path in _parse_multiline(raw_logpath):
if "*" in log_path or "?" in log_path or "%(" in log_path:
continue
if not Path(log_path).exists():
issues.append(
JailValidationIssue(
field="logpath",
message=f"Log file not found on disk: {log_path}",
)
)
valid = len(issues) == 0
log.debug(
"jail_validation_complete",
jail=name,
valid=valid,
issue_count=len(issues),
)
return JailValidationResult(jail_name=name, valid=valid, issues=issues)
def _safe_jail_name(name: str) -> str:
"""Validate *name* and return it unchanged or raise :class:`JailNameError`."""
if not _SAFE_JAIL_NAME_RE.match(name):
raise JailNameError(
f"Jail name {name!r} contains invalid characters. "
"Only alphanumeric characters, hyphens, underscores, and dots are "
"allowed; must start with an alphanumeric character."
)
return name
def _safe_filter_name(name: str) -> str:
"""Validate *name* and return it unchanged or raise :class:`FilterNameError`."""
if not _SAFE_FILTER_NAME_RE.match(name):
raise FilterNameError(
f"Filter name {name!r} contains invalid characters. "
"Only alphanumeric characters, hyphens, underscores, and dots are "
"allowed; must start with an alphanumeric character."
)
return name
def _set_jail_local_key_sync(
config_dir: Path,
jail_name: str,
key: str,
value: str,
) -> None:
"""Update ``jail.d/{jail_name}.local`` to set a single key in the jail section."""
jail_d = config_dir / "jail.d"
try:
jail_d.mkdir(parents=True, exist_ok=True)
except OSError as exc:
raise ConfigWriteError(f"Cannot create jail.d directory: {exc}") from exc
local_path = jail_d / f"{jail_name}.local"
parser = _build_parser()
if local_path.is_file():
try:
parser.read(str(local_path), encoding="utf-8")
except (configparser.Error, OSError) as exc:
log.warning(
"jail_local_read_for_update_error",
jail=jail_name,
error=str(exc),
)
if not parser.has_section(jail_name):
parser.add_section(jail_name)
parser.set(jail_name, key, value)
buf = io.StringIO()
buf.write("# Managed by BanGUI — do not edit manually\n\n")
parser.write(buf)
content = buf.getvalue()
try:
with tempfile.NamedTemporaryFile(
mode="w",
encoding="utf-8",
dir=jail_d,
delete=False,
suffix=".tmp",
) as tmp:
tmp.write(content)
tmp_name = tmp.name
os.replace(tmp_name, local_path)
except OSError as exc:
with contextlib.suppress(OSError):
os.unlink(tmp_name) # noqa: F821
raise ConfigWriteError(f"Failed to write {local_path}: {exc}") from exc
log.info(
"jail_local_key_set",
jail=jail_name,
key=key,
path=str(local_path),
)
ordered_config_files = _ordered_config_files
build_parser = _build_parser
is_truthy = _is_truthy
parse_multiline = _parse_multiline
parse_jails_sync = _parse_jails_sync
build_inactive_jail = _build_inactive_jail
get_active_jail_names = _get_active_jail_names
validate_jail_config_sync = _validate_jail_config_sync
set_jail_local_key_sync = _set_jail_local_key_sync
safe_jail_name = _safe_jail_name
safe_filter_name = _safe_filter_name
probe_fail2ban_running = _probe_fail2ban_running
async def start_daemon(start_cmd_parts: list[str]) -> bool:
"""Run the configured fail2ban start command and return whether it launched."""
process = await asyncio.create_subprocess_exec(
*start_cmd_parts,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
stdout, stderr = await process.communicate()
if process.returncode != 0:
log.error(
"fail2ban_start_failed",
command=" ".join(start_cmd_parts),
returncode=process.returncode,
stdout=stdout.decode("utf-8", errors="replace"),
stderr=stderr.decode("utf-8", errors="replace"),
)
return False
log.info(
"fail2ban_start_succeeded",
command=" ".join(start_cmd_parts),
)
return True
async def wait_for_fail2ban(
socket_path: str,
max_wait_seconds: float,
poll_interval: float = 0.5,
) -> bool:
"""Probe the fail2ban socket until it is responsive or the timeout expires."""
deadline = asyncio.get_running_loop().time() + max_wait_seconds
while asyncio.get_running_loop().time() < deadline:
if await _probe_fail2ban_running(socket_path):
return True
await asyncio.sleep(poll_interval)
log.warning(
"wait_for_fail2ban_timeout",
socket_path=str(socket_path),
max_wait_seconds=max_wait_seconds,
)
return False

View File

@@ -5,26 +5,12 @@ from __future__ import annotations
import json
from datetime import UTC, datetime
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
def ts_to_iso(unix_ts: int) -> str:
"""Convert a Unix timestamp to an ISO 8601 UTC string."""
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
async def get_fail2ban_db_path(socket_path: str, *, force_refresh: bool = False) -> str:
"""Return the fail2ban database path, using cached metadata when available."""
return await default_fail2ban_metadata_service.get_db_path(
socket_path, force_refresh=force_refresh
)
def invalidate_fail2ban_db_path(socket_path: str) -> None:
"""Invalidate the cached fail2ban database path for the given socket."""
default_fail2ban_metadata_service.invalidate_db_path(socket_path)
def parse_data_json(raw: object) -> tuple[list[str], int]:
"""Extract matches and failure count from the fail2ban bans.data value."""
if raw is None:

View File

@@ -12,7 +12,7 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.geo import GeoInfo
from app.models.geo import GeoDetail, GeoInfo
# ---------------------------------------------------------------------------
# Fixtures
@@ -71,7 +71,7 @@ class TestGeoLookup:
async def test_200_with_geo_info(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/lookup/{ip} returns 200 with enriched result."""
geo = GeoInfo(country_code="DE", country_name="Germany", asn="12345", org="Acme")
geo = GeoDetail(country_code="DE", country_name="Germany", asn="12345", org="Acme")
result: dict[str, object] = {
"ip": "1.2.3.4",
"currently_banned_in": ["sshd"],
@@ -97,7 +97,7 @@ class TestGeoLookup:
result: dict[str, object] = {
"ip": "8.8.8.8",
"currently_banned_in": [],
"geo": GeoInfo(country_code="US", country_name="United States", asn=None, org=None),
"geo": GeoDetail(country_code="US", country_name="United States", asn=None, org=None),
}
with patch(
"app.routers.geo.jail_service.lookup_ip",