diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 73be16e..af05129 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -208,7 +208,7 @@ fail2ban ships with a large collection of filter definitions in `filter.d/` (ove fail2ban ships with many action definitions in `action.d/` (iptables, firewalld, cloudflare, sendmail, etc.). Users need to see all available actions, understand which are in use, and assign them to jails. -### Task 3.1 — Backend: List All Available Actions with Active/Inactive Status +### Task 3.1 — Backend: List All Available Actions with Active/Inactive Status ✅ DONE **Goal:** Enumerate all action config files and mark each as active or inactive based on jail usage. @@ -231,7 +231,7 @@ fail2ban ships with many action definitions in `action.d/` (iptables, firewalld, --- -### Task 3.2 — Backend: Activate and Edit Actions +### Task 3.2 — Backend: Activate and Edit Actions ✅ DONE **Goal:** Allow users to assign actions to jails, edit action definitions, and create new actions. @@ -281,7 +281,7 @@ fail2ban ships with many action definitions in `action.d/` (iptables, firewalld, --- -### Task 3.4 — Tests: Action Discovery and Management +### Task 3.4 — Tests: Action Discovery and Management ✅ DONE **Goal:** Test coverage for action listing, editing, creation, and assignment. diff --git a/backend/app/models/config.py b/backend/app/models/config.py index 1f7a644..ae55fc4 100644 --- a/backend/app/models/config.py +++ b/backend/app/models/config.py @@ -508,6 +508,33 @@ class ActionConfig(BaseModel): default_factory=dict, description="Runtime parameters that can be overridden per jail.", ) + # Active-status fields — populated by config_file_service.list_actions / + # get_action; default to safe "inactive" values when not computed. + active: bool = Field( + default=False, + description=( + "``True`` when this action is referenced by at least one currently " + "enabled (running) jail." + ), + ) + used_by_jails: list[str] = Field( + default_factory=list, + description=( + "Names of currently enabled jails that reference this action. " + "Empty when ``active`` is ``False``." + ), + ) + source_file: str = Field( + default="", + description="Absolute path to the ``.conf`` source file for this action.", + ) + has_local_override: bool = Field( + default=False, + description=( + "``True`` when a ``.local`` override file exists alongside the " + "base ``.conf`` file." + ), + ) class ActionConfigUpdate(BaseModel): @@ -527,6 +554,110 @@ class ActionConfigUpdate(BaseModel): init_vars: dict[str, str] | None = Field(default=None) +class ActionListResponse(BaseModel): + """Response for ``GET /api/config/actions``.""" + + model_config = ConfigDict(strict=True) + + actions: list[ActionConfig] = Field( + default_factory=list, + description=( + "All discovered actions, each annotated with active/inactive status " + "and the jails that reference them." + ), + ) + total: int = Field(..., ge=0, description="Total number of actions found.") + + +class ActionUpdateRequest(BaseModel): + """Payload for ``PUT /api/config/actions/{name}``. + + Accepts only the user-editable ``[Definition]`` lifecycle fields and + ``[Init]`` parameters. Fields left as ``None`` are not changed. + """ + + model_config = ConfigDict(strict=True) + + actionstart: str | None = Field( + default=None, + description="Updated ``actionstart`` command. ``None`` = keep existing.", + ) + actionstop: str | None = Field( + default=None, + description="Updated ``actionstop`` command. ``None`` = keep existing.", + ) + actioncheck: str | None = Field( + default=None, + description="Updated ``actioncheck`` command. ``None`` = keep existing.", + ) + actionban: str | None = Field( + default=None, + description="Updated ``actionban`` command. ``None`` = keep existing.", + ) + actionunban: str | None = Field( + default=None, + description="Updated ``actionunban`` command. ``None`` = keep existing.", + ) + actionflush: str | None = Field( + default=None, + description="Updated ``actionflush`` command. ``None`` = keep existing.", + ) + definition_vars: dict[str, str] | None = Field( + default=None, + description="Additional ``[Definition]`` variables to set. ``None`` = keep existing.", + ) + init_vars: dict[str, str] | None = Field( + default=None, + description="``[Init]`` parameters to set. ``None`` = keep existing.", + ) + + +class ActionCreateRequest(BaseModel): + """Payload for ``POST /api/config/actions``. + + Creates a new user-defined action at ``action.d/{name}.local``. + """ + + model_config = ConfigDict(strict=True) + + name: str = Field( + ..., + description="Action base name (e.g. ``my-custom-action``). Must not already exist.", + ) + actionstart: str | None = Field(default=None, description="Command to execute at jail start.") + actionstop: str | None = Field(default=None, description="Command to execute at jail stop.") + actioncheck: str | None = Field(default=None, description="Command to execute before each ban.") + actionban: str | None = Field(default=None, description="Command to execute to ban an IP.") + actionunban: str | None = Field(default=None, description="Command to execute to unban an IP.") + actionflush: str | None = Field(default=None, description="Command to flush all bans on shutdown.") + definition_vars: dict[str, str] = Field( + default_factory=dict, + description="Additional ``[Definition]`` variables.", + ) + init_vars: dict[str, str] = Field( + default_factory=dict, + description="``[Init]`` runtime parameters.", + ) + + +class AssignActionRequest(BaseModel): + """Payload for ``POST /api/config/jails/{jail_name}/action``.""" + + model_config = ConfigDict(strict=True) + + action_name: str = Field( + ..., + description="Action base name to add to the jail (e.g. ``iptables-multiport``).", + ) + params: dict[str, str] = Field( + default_factory=dict, + description=( + "Optional per-jail action parameters written as " + "``action_name[key=value, ...]`` in the jail config." + ), + ) + + # --------------------------------------------------------------------------- # Jail file config models (Task 6.1) # --------------------------------------------------------------------------- diff --git a/backend/app/routers/config.py b/backend/app/routers/config.py index c81c906..e39a496 100644 --- a/backend/app/routers/config.py +++ b/backend/app/routers/config.py @@ -10,6 +10,8 @@ global settings, test regex patterns, add log paths, and preview log files. * ``POST /api/config/jails/{name}/activate`` — activate an inactive jail * ``POST /api/config/jails/{name}/deactivate`` — deactivate an active jail * ``POST /api/config/jails/{name}/filter`` — assign a filter to a jail +* ``POST /api/config/jails/{name}/action`` — add an action to a jail +* ``DELETE /api/config/jails/{name}/action/{action_name}`` — remove an action from a jail * ``GET /api/config/global`` — global fail2ban settings * ``PUT /api/config/global`` — update global settings * ``POST /api/config/reload`` — reload fail2ban @@ -21,6 +23,11 @@ global settings, test regex patterns, add log paths, and preview log files. * ``PUT /api/config/filters/{name}`` — update a filter's .local override * ``POST /api/config/filters`` — create a new user-defined filter * ``DELETE /api/config/filters/{name}`` — delete a filter's .local file +* ``GET /api/config/actions`` — list all actions with active/inactive status +* ``GET /api/config/actions/{name}`` — full parsed detail for one action +* ``PUT /api/config/actions/{name}`` — update an action's .local override +* ``POST /api/config/actions`` — create a new user-defined action +* ``DELETE /api/config/actions/{name}`` — delete an action's .local file """ from __future__ import annotations @@ -31,8 +38,13 @@ from fastapi import APIRouter, HTTPException, Path, Query, Request, status from app.dependencies import AuthDep from app.models.config import ( + ActionConfig, + ActionCreateRequest, + ActionListResponse, + ActionUpdateRequest, ActivateJailRequest, AddLogPathRequest, + AssignActionRequest, AssignFilterRequest, FilterConfig, FilterCreateRequest, @@ -54,6 +66,10 @@ from app.models.config import ( ) from app.services import config_file_service, config_service, jail_service from app.services.config_file_service import ( + ActionAlreadyExistsError, + ActionNameError, + ActionNotFoundError, + ActionReadonlyError, ConfigWriteError, FilterAlreadyExistsError, FilterInvalidRegexError, @@ -968,3 +984,338 @@ async def assign_filter_to_jail( detail=f"Failed to write jail override: {exc}", ) from exc + +# --------------------------------------------------------------------------- +# Action discovery endpoints (Task 3.1) +# --------------------------------------------------------------------------- + +_ActionNamePath = Annotated[ + str, + Path(description="Action base name, e.g. ``iptables`` or ``iptables.conf``."), +] + + +def _action_not_found(name: str) -> HTTPException: + return HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Action not found: {name!r}", + ) + + +@router.get( + "/actions", + response_model=ActionListResponse, + summary="List all available actions with active/inactive status", +) +async def list_actions( + request: Request, + _auth: AuthDep, +) -> ActionListResponse: + """Return all actions discovered in ``action.d/`` with active/inactive status. + + Scans ``{config_dir}/action.d/`` for ``.conf`` files, merges any + corresponding ``.local`` overrides, and cross-references each action's + name against the ``action`` fields of currently running jails to determine + whether it is active. + + Active actions (those used by at least one running jail) are sorted to the + top of the list; inactive actions follow. Both groups are sorted + alphabetically within themselves. + + Args: + request: FastAPI request object. + _auth: Validated session — enforces authentication. + + Returns: + :class:`~app.models.config.ActionListResponse` with all discovered + actions. + """ + config_dir: str = request.app.state.settings.fail2ban_config_dir + socket_path: str = request.app.state.settings.fail2ban_socket + result = await config_file_service.list_actions(config_dir, socket_path) + result.actions.sort(key=lambda a: (not a.active, a.name.lower())) + return result + + +@router.get( + "/actions/{name}", + response_model=ActionConfig, + summary="Return full parsed detail for a single action", +) +async def get_action( + request: Request, + _auth: AuthDep, + name: _ActionNamePath, +) -> ActionConfig: + """Return the full parsed configuration and active/inactive status for one action. + + Reads ``{config_dir}/action.d/{name}.conf``, merges any corresponding + ``.local`` override, and annotates the result with ``active``, + ``used_by_jails``, ``source_file``, and ``has_local_override``. + + Args: + request: FastAPI request object. + _auth: Validated session — enforces authentication. + name: Action base name (with or without ``.conf`` extension). + + Returns: + :class:`~app.models.config.ActionConfig`. + + Raises: + HTTPException: 404 if the action is not found in ``action.d/``. + """ + config_dir: str = request.app.state.settings.fail2ban_config_dir + socket_path: str = request.app.state.settings.fail2ban_socket + try: + return await config_file_service.get_action(config_dir, socket_path, name) + except ActionNotFoundError: + raise _action_not_found(name) from None + + +# --------------------------------------------------------------------------- +# Action write endpoints (Task 3.2) +# --------------------------------------------------------------------------- + + +@router.put( + "/actions/{name}", + response_model=ActionConfig, + summary="Update an action's .local override with new lifecycle command values", +) +async def update_action( + request: Request, + _auth: AuthDep, + name: _ActionNamePath, + body: ActionUpdateRequest, + reload: bool = Query(default=False, description="Reload fail2ban after writing."), +) -> ActionConfig: + """Update an action's ``[Definition]`` fields by writing a ``.local`` override. + + Only non-``null`` fields in the request body are written. The original + ``.conf`` file is never modified. + + Args: + request: FastAPI request object. + _auth: Validated session. + name: Action base name (with or without ``.conf`` extension). + body: Partial update — lifecycle commands and ``[Init]`` parameters. + reload: When ``true``, trigger a fail2ban reload after writing. + + Returns: + Updated :class:`~app.models.config.ActionConfig`. + + Raises: + HTTPException: 400 if *name* contains invalid characters. + HTTPException: 404 if the action does not exist. + HTTPException: 500 if writing the ``.local`` file fails. + """ + config_dir: str = request.app.state.settings.fail2ban_config_dir + socket_path: str = request.app.state.settings.fail2ban_socket + try: + return await config_file_service.update_action( + config_dir, socket_path, name, body, do_reload=reload + ) + except ActionNameError as exc: + raise _bad_request(str(exc)) from exc + except ActionNotFoundError: + raise _action_not_found(name) from None + except ConfigWriteError as exc: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to write action override: {exc}", + ) from exc + + +@router.post( + "/actions", + response_model=ActionConfig, + status_code=status.HTTP_201_CREATED, + summary="Create a new user-defined action", +) +async def create_action( + request: Request, + _auth: AuthDep, + body: ActionCreateRequest, + reload: bool = Query(default=False, description="Reload fail2ban after creating."), +) -> ActionConfig: + """Create a new user-defined action at ``action.d/{name}.local``. + + Returns 409 if a ``.conf`` or ``.local`` for the requested name already + exists. + + Args: + request: FastAPI request object. + _auth: Validated session. + body: Action name and ``[Definition]`` lifecycle fields. + reload: When ``true``, trigger a fail2ban reload after creating. + + Returns: + :class:`~app.models.config.ActionConfig` for the new action. + + Raises: + HTTPException: 400 if the name contains invalid characters. + HTTPException: 409 if the action already exists. + HTTPException: 500 if writing fails. + """ + config_dir: str = request.app.state.settings.fail2ban_config_dir + socket_path: str = request.app.state.settings.fail2ban_socket + try: + return await config_file_service.create_action( + config_dir, socket_path, body, do_reload=reload + ) + except ActionNameError as exc: + raise _bad_request(str(exc)) from exc + except ActionAlreadyExistsError as exc: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=f"Action {exc.name!r} already exists.", + ) from exc + except ConfigWriteError as exc: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to write action: {exc}", + ) from exc + + +@router.delete( + "/actions/{name}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Delete a user-created action's .local file", +) +async def delete_action( + request: Request, + _auth: AuthDep, + name: _ActionNamePath, +) -> None: + """Delete a user-created action's ``.local`` override file. + + Shipped ``.conf``-only actions cannot be deleted (returns 409). When + both a ``.conf`` and a ``.local`` exist, only the ``.local`` is removed. + + Args: + request: FastAPI request object. + _auth: Validated session. + name: Action base name. + + Raises: + HTTPException: 400 if *name* contains invalid characters. + HTTPException: 404 if the action does not exist. + HTTPException: 409 if the action is a shipped default (conf-only). + HTTPException: 500 if deletion fails. + """ + config_dir: str = request.app.state.settings.fail2ban_config_dir + try: + await config_file_service.delete_action(config_dir, name) + except ActionNameError as exc: + raise _bad_request(str(exc)) from exc + except ActionNotFoundError: + raise _action_not_found(name) from None + except ActionReadonlyError as exc: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=str(exc), + ) from exc + except ConfigWriteError as exc: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete action: {exc}", + ) from exc + + +@router.post( + "/jails/{name}/action", + status_code=status.HTTP_204_NO_CONTENT, + summary="Add an action to a jail", +) +async def assign_action_to_jail( + request: Request, + _auth: AuthDep, + name: _NamePath, + body: AssignActionRequest, + reload: bool = Query(default=False, description="Reload fail2ban after assigning."), +) -> None: + """Append an action entry to the jail's ``.local`` config. + + Existing keys in the jail's ``.local`` file are preserved. If the file + does not exist it is created. The action is not duplicated if it is + already present. + + Args: + request: FastAPI request object. + _auth: Validated session. + name: Jail name. + body: Action to add plus optional per-jail parameters. + reload: When ``true``, trigger a fail2ban reload after writing. + + Raises: + HTTPException: 400 if *name* or *action_name* contain invalid characters. + HTTPException: 404 if the jail or action does not exist. + HTTPException: 500 if writing fails. + """ + config_dir: str = request.app.state.settings.fail2ban_config_dir + socket_path: str = request.app.state.settings.fail2ban_socket + try: + await config_file_service.assign_action_to_jail( + config_dir, socket_path, name, body, do_reload=reload + ) + except (JailNameError, ActionNameError) as exc: + raise _bad_request(str(exc)) from exc + except JailNotFoundInConfigError: + raise _not_found(name) from None + except ActionNotFoundError as exc: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Action not found: {exc.name!r}", + ) from exc + except ConfigWriteError as exc: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to write jail override: {exc}", + ) from exc + + +@router.delete( + "/jails/{name}/action/{action_name}", + status_code=status.HTTP_204_NO_CONTENT, + summary="Remove an action from a jail", +) +async def remove_action_from_jail( + request: Request, + _auth: AuthDep, + name: _NamePath, + action_name: Annotated[str, Path(description="Action base name to remove.")], + reload: bool = Query(default=False, description="Reload fail2ban after removing."), +) -> None: + """Remove an action from the jail's ``.local`` config. + + If the jail has no ``.local`` file or the action is not listed there, + the call is silently idempotent. + + Args: + request: FastAPI request object. + _auth: Validated session. + name: Jail name. + action_name: Base name of the action to remove. + reload: When ``true``, trigger a fail2ban reload after writing. + + Raises: + HTTPException: 400 if *name* or *action_name* contain invalid characters. + HTTPException: 404 if the jail is not found in config files. + HTTPException: 500 if writing fails. + """ + config_dir: str = request.app.state.settings.fail2ban_config_dir + socket_path: str = request.app.state.settings.fail2ban_socket + try: + await config_file_service.remove_action_from_jail( + config_dir, socket_path, name, action_name, do_reload=reload + ) + except (JailNameError, ActionNameError) as exc: + raise _bad_request(str(exc)) from exc + except JailNotFoundInConfigError: + raise _not_found(name) from None + except ConfigWriteError as exc: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to write jail override: {exc}", + ) from exc + diff --git a/backend/app/services/config_file_service.py b/backend/app/services/config_file_service.py index d05b63c..e98c2d6 100644 --- a/backend/app/services/config_file_service.py +++ b/backend/app/services/config_file_service.py @@ -33,7 +33,13 @@ from typing import Any import structlog from app.models.config import ( + ActionConfig, + ActionConfigUpdate, + ActionCreateRequest, + ActionListResponse, + ActionUpdateRequest, ActivateJailRequest, + AssignActionRequest, AssignFilterRequest, FilterConfig, FilterConfigUpdate, @@ -1504,3 +1510,952 @@ async def assign_filter_to_jail( reload=do_reload, ) + +# --------------------------------------------------------------------------- +# Action discovery helpers (Task 3.1) +# --------------------------------------------------------------------------- + +# Allowlist pattern for action names used in path construction. +_SAFE_ACTION_NAME_RE: re.Pattern[str] = re.compile( + r"^[A-Za-z0-9][A-Za-z0-9._-]{0,127}$" +) + + +class ActionNotFoundError(Exception): + """Raised when the requested action name is not found in ``action.d/``.""" + + def __init__(self, name: str) -> None: + """Initialise with the action name that was not found. + + Args: + name: The action name that could not be located. + """ + self.name: str = name + super().__init__(f"Action not found: {name!r}") + + +class ActionAlreadyExistsError(Exception): + """Raised when trying to create an action whose ``.conf`` or ``.local`` already exists.""" + + def __init__(self, name: str) -> None: + """Initialise with the action name that already exists. + + Args: + name: The action name that already exists. + """ + self.name: str = name + super().__init__(f"Action already exists: {name!r}") + + +class ActionReadonlyError(Exception): + """Raised when trying to delete a shipped ``.conf`` action with no ``.local`` override.""" + + def __init__(self, name: str) -> None: + """Initialise with the action name that cannot be deleted. + + Args: + name: The action name that is read-only (shipped ``.conf`` only). + """ + self.name: str = name + super().__init__( + f"Action {name!r} is a shipped default (.conf only); " + "only user-created .local files can be deleted." + ) + + +class ActionNameError(Exception): + """Raised when an action name contains invalid characters.""" + + +def _safe_action_name(name: str) -> str: + """Validate *name* and return it unchanged or raise :class:`ActionNameError`. + + Args: + name: Proposed action name (without extension). + + Returns: + The name unchanged if valid. + + Raises: + ActionNameError: If *name* contains unsafe characters. + """ + if not _SAFE_ACTION_NAME_RE.match(name): + raise ActionNameError( + f"Action name {name!r} contains invalid characters. " + "Only alphanumeric characters, hyphens, underscores, and dots are " + "allowed; must start with an alphanumeric character." + ) + return name + + +def _build_action_to_jails_map( + all_jails: dict[str, dict[str, str]], + active_names: set[str], +) -> dict[str, list[str]]: + """Return a mapping of action base name → list of active jail names. + + Iterates over every jail whose name is in *active_names*, resolves each + entry in its ``action`` config key to an action base name (stripping + ``[…]`` parameter blocks), and records the jail against each base name. + + Args: + all_jails: Merged jail config dict — ``{jail_name: {key: value}}``. + active_names: Set of jail names currently running in fail2ban. + + Returns: + ``{action_base_name: [jail_name, …]}``. + """ + mapping: dict[str, list[str]] = {} + for jail_name, settings in all_jails.items(): + if jail_name not in active_names: + continue + raw_action = settings.get("action", "") + if not raw_action: + continue + for line in raw_action.splitlines(): + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + # Strip optional [key=value] parameter block to get the base name. + bracket = stripped.find("[") + base = stripped[:bracket].strip() if bracket != -1 else stripped + if base: + mapping.setdefault(base, []).append(jail_name) + return mapping + + +def _parse_actions_sync( + action_d: Path, +) -> list[tuple[str, str, str, bool, str]]: + """Synchronously scan ``action.d/`` and return per-action tuples. + + Each tuple contains: + + - ``name`` — action base name (``"iptables"``). + - ``filename`` — actual filename (``"iptables.conf"``). + - ``content`` — merged file content (``conf`` overridden by ``local``). + - ``has_local`` — whether a ``.local`` override exists alongside a ``.conf``. + - ``source_path`` — absolute path to the primary (``conf``) source file, or + to the ``.local`` file for user-created (local-only) actions. + + Also discovers ``.local``-only files (user-created actions with no + corresponding ``.conf``). + + Args: + action_d: Path to the ``action.d`` directory. + + Returns: + List of ``(name, filename, content, has_local, source_path)`` tuples, + sorted by name. + """ + if not action_d.is_dir(): + log.warning("action_d_not_found", path=str(action_d)) + return [] + + conf_names: set[str] = set() + results: list[tuple[str, str, str, bool, str]] = [] + + # ---- .conf-based actions (with optional .local override) ---------------- + for conf_path in sorted(action_d.glob("*.conf")): + if not conf_path.is_file(): + continue + name = conf_path.stem + filename = conf_path.name + conf_names.add(name) + local_path = conf_path.with_suffix(".local") + has_local = local_path.is_file() + + try: + content = conf_path.read_text(encoding="utf-8") + except OSError as exc: + log.warning( + "action_read_error", name=name, path=str(conf_path), error=str(exc) + ) + continue + + if has_local: + try: + local_content = local_path.read_text(encoding="utf-8") + content = content + "\n" + local_content + except OSError as exc: + log.warning( + "action_local_read_error", + name=name, + path=str(local_path), + error=str(exc), + ) + + results.append((name, filename, content, has_local, str(conf_path))) + + # ---- .local-only actions (user-created, no corresponding .conf) ---------- + for local_path in sorted(action_d.glob("*.local")): + if not local_path.is_file(): + continue + name = local_path.stem + if name in conf_names: + continue + try: + content = local_path.read_text(encoding="utf-8") + except OSError as exc: + log.warning( + "action_local_read_error", + name=name, + path=str(local_path), + error=str(exc), + ) + continue + results.append((name, local_path.name, content, False, str(local_path))) + + results.sort(key=lambda t: t[0]) + log.debug("actions_scanned", count=len(results), action_d=str(action_d)) + return results + + +def _append_jail_action_sync( + config_dir: Path, + jail_name: str, + action_entry: str, +) -> None: + """Append an action entry to the ``action`` key in ``jail.d/{jail_name}.local``. + + If the ``.local`` file already contains an ``action`` key under the jail + section, the new entry is appended as an additional line (multi-line + configparser format) unless it is already present. If no ``action`` key + exists, one is created. + + Args: + config_dir: The fail2ban configuration root directory. + jail_name: Validated jail name. + action_entry: Full action string including any ``[…]`` parameters. + + Raises: + ConfigWriteError: If writing fails. + """ + jail_d = config_dir / "jail.d" + try: + jail_d.mkdir(parents=True, exist_ok=True) + except OSError as exc: + raise ConfigWriteError( + f"Cannot create jail.d directory: {exc}" + ) from exc + + local_path = jail_d / f"{jail_name}.local" + + parser = _build_parser() + if local_path.is_file(): + try: + parser.read(str(local_path), encoding="utf-8") + except (configparser.Error, OSError) as exc: + log.warning( + "jail_local_read_for_update_error", + jail=jail_name, + error=str(exc), + ) + + if not parser.has_section(jail_name): + parser.add_section(jail_name) + + existing_raw = parser.get(jail_name, "action") if parser.has_option(jail_name, "action") else "" + existing_lines = [ + line.strip() + for line in existing_raw.splitlines() + if line.strip() and not line.strip().startswith("#") + ] + + # Extract base names from existing entries for duplicate checking. + def _base(entry: str) -> str: + bracket = entry.find("[") + return entry[:bracket].strip() if bracket != -1 else entry.strip() + + new_base = _base(action_entry) + if not any(_base(e) == new_base for e in existing_lines): + existing_lines.append(action_entry) + + if existing_lines: + # configparser multi-line: continuation lines start with whitespace. + new_value = existing_lines[0] + "".join( + f"\n {line}" for line in existing_lines[1:] + ) + parser.set(jail_name, "action", new_value) + else: + parser.set(jail_name, "action", action_entry) + + buf = io.StringIO() + buf.write("# Managed by BanGUI — do not edit manually\n\n") + parser.write(buf) + content = buf.getvalue() + + try: + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=jail_d, + delete=False, + suffix=".tmp", + ) as tmp: + tmp.write(content) + tmp_name = tmp.name + os.replace(tmp_name, local_path) + except OSError as exc: + with contextlib.suppress(OSError): + os.unlink(tmp_name) # noqa: F821 + raise ConfigWriteError( + f"Failed to write {local_path}: {exc}" + ) from exc + + log.info( + "jail_action_appended", + jail=jail_name, + action=action_entry, + path=str(local_path), + ) + + +def _remove_jail_action_sync( + config_dir: Path, + jail_name: str, + action_name: str, +) -> None: + """Remove an action entry from the ``action`` key in ``jail.d/{jail_name}.local``. + + Reads the ``.local`` file, removes any ``action`` entries whose base name + matches *action_name*, and writes the result back atomically. If no + ``.local`` file exists, this is a no-op. + + Args: + config_dir: The fail2ban configuration root directory. + jail_name: Validated jail name. + action_name: Base name of the action to remove (without ``[…]``). + + Raises: + ConfigWriteError: If writing fails. + """ + jail_d = config_dir / "jail.d" + local_path = jail_d / f"{jail_name}.local" + + if not local_path.is_file(): + return + + parser = _build_parser() + try: + parser.read(str(local_path), encoding="utf-8") + except (configparser.Error, OSError) as exc: + log.warning( + "jail_local_read_for_update_error", + jail=jail_name, + error=str(exc), + ) + return + + if not parser.has_section(jail_name) or not parser.has_option(jail_name, "action"): + return + + existing_raw = parser.get(jail_name, "action") + existing_lines = [ + line.strip() + for line in existing_raw.splitlines() + if line.strip() and not line.strip().startswith("#") + ] + + def _base(entry: str) -> str: + bracket = entry.find("[") + return entry[:bracket].strip() if bracket != -1 else entry.strip() + + filtered = [e for e in existing_lines if _base(e) != action_name] + + if len(filtered) == len(existing_lines): + # Action was not found — silently return (idempotent). + return + + if filtered: + new_value = filtered[0] + "".join( + f"\n {line}" for line in filtered[1:] + ) + parser.set(jail_name, "action", new_value) + else: + parser.remove_option(jail_name, "action") + + buf = io.StringIO() + buf.write("# Managed by BanGUI — do not edit manually\n\n") + parser.write(buf) + content = buf.getvalue() + + try: + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=jail_d, + delete=False, + suffix=".tmp", + ) as tmp: + tmp.write(content) + tmp_name = tmp.name + os.replace(tmp_name, local_path) + except OSError as exc: + with contextlib.suppress(OSError): + os.unlink(tmp_name) # noqa: F821 + raise ConfigWriteError( + f"Failed to write {local_path}: {exc}" + ) from exc + + log.info( + "jail_action_removed", + jail=jail_name, + action=action_name, + path=str(local_path), + ) + + +def _write_action_local_sync(action_d: Path, name: str, content: str) -> None: + """Write *content* to ``action.d/{name}.local`` atomically. + + The write is atomic: content is written to a temp file first, then + renamed into place. The ``action.d/`` directory is created if absent. + + Args: + action_d: Path to the ``action.d`` directory. + name: Validated action base name (used as filename stem). + content: Full serialized action content to write. + + Raises: + ConfigWriteError: If writing fails. + """ + try: + action_d.mkdir(parents=True, exist_ok=True) + except OSError as exc: + raise ConfigWriteError( + f"Cannot create action.d directory: {exc}" + ) from exc + + local_path = action_d / f"{name}.local" + try: + with tempfile.NamedTemporaryFile( + mode="w", + encoding="utf-8", + dir=action_d, + delete=False, + suffix=".tmp", + ) as tmp: + tmp.write(content) + tmp_name = tmp.name + os.replace(tmp_name, local_path) + except OSError as exc: + with contextlib.suppress(OSError): + os.unlink(tmp_name) # noqa: F821 + raise ConfigWriteError( + f"Failed to write {local_path}: {exc}" + ) from exc + + log.info("action_local_written", action=name, path=str(local_path)) + + +# --------------------------------------------------------------------------- +# Public API — action discovery (Task 3.1) +# --------------------------------------------------------------------------- + + +async def list_actions( + config_dir: str, + socket_path: str, +) -> ActionListResponse: + """Return all available actions from ``action.d/`` with active/inactive status. + + Scans ``{config_dir}/action.d/`` for ``.conf`` files, merges any + corresponding ``.local`` overrides, parses each file into an + :class:`~app.models.config.ActionConfig`, and cross-references with the + currently running jails to determine which actions are active. + + An action is considered *active* when its base name appears in the + ``action`` field of at least one currently running jail. + + Args: + config_dir: Absolute path to the fail2ban configuration directory. + socket_path: Path to the fail2ban Unix domain socket. + + Returns: + :class:`~app.models.config.ActionListResponse` with all actions + sorted alphabetically, active ones carrying non-empty + ``used_by_jails`` lists. + """ + action_d = Path(config_dir) / "action.d" + loop = asyncio.get_event_loop() + + raw_actions: list[tuple[str, str, str, bool, str]] = await loop.run_in_executor( + None, _parse_actions_sync, action_d + ) + + all_jails_result, active_names = await asyncio.gather( + loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)), + _get_active_jail_names(socket_path), + ) + all_jails, _source_files = all_jails_result + + action_to_jails = _build_action_to_jails_map(all_jails, active_names) + + actions: list[ActionConfig] = [] + for name, filename, content, has_local, source_path in raw_actions: + cfg = conffile_parser.parse_action_file( + content, name=name, filename=filename + ) + used_by = sorted(action_to_jails.get(name, [])) + actions.append( + ActionConfig( + name=cfg.name, + filename=cfg.filename, + before=cfg.before, + after=cfg.after, + actionstart=cfg.actionstart, + actionstop=cfg.actionstop, + actioncheck=cfg.actioncheck, + actionban=cfg.actionban, + actionunban=cfg.actionunban, + actionflush=cfg.actionflush, + definition_vars=cfg.definition_vars, + init_vars=cfg.init_vars, + active=len(used_by) > 0, + used_by_jails=used_by, + source_file=source_path, + has_local_override=has_local, + ) + ) + + log.info("actions_listed", total=len(actions), active=sum(1 for a in actions if a.active)) + return ActionListResponse(actions=actions, total=len(actions)) + + +async def get_action( + config_dir: str, + socket_path: str, + name: str, +) -> ActionConfig: + """Return a single action from ``action.d/`` with active/inactive status. + + Reads ``{config_dir}/action.d/{name}.conf``, merges any ``.local`` + override, and enriches the parsed :class:`~app.models.config.ActionConfig` + with ``active``, ``used_by_jails``, ``source_file``, and + ``has_local_override``. + + Args: + config_dir: Absolute path to the fail2ban configuration directory. + socket_path: Path to the fail2ban Unix domain socket. + name: Action base name (e.g. ``"iptables"`` or ``"iptables.conf"``). + + Returns: + :class:`~app.models.config.ActionConfig` with status fields populated. + + Raises: + ActionNotFoundError: If no ``{name}.conf`` or ``{name}.local`` file + exists in ``action.d/``. + """ + if name.endswith(".conf"): + base_name = name[:-5] + elif name.endswith(".local"): + base_name = name[:-6] + else: + base_name = name + + action_d = Path(config_dir) / "action.d" + conf_path = action_d / f"{base_name}.conf" + local_path = action_d / f"{base_name}.local" + loop = asyncio.get_event_loop() + + def _read() -> tuple[str, bool, str]: + """Read action content and return (content, has_local_override, source_path).""" + has_local = local_path.is_file() + if conf_path.is_file(): + content = conf_path.read_text(encoding="utf-8") + if has_local: + try: + content += "\n" + local_path.read_text(encoding="utf-8") + except OSError as exc: + log.warning( + "action_local_read_error", + name=base_name, + path=str(local_path), + error=str(exc), + ) + return content, has_local, str(conf_path) + elif has_local: + content = local_path.read_text(encoding="utf-8") + return content, False, str(local_path) + else: + raise ActionNotFoundError(base_name) + + content, has_local, source_path = await loop.run_in_executor(None, _read) + + cfg = conffile_parser.parse_action_file( + content, name=base_name, filename=f"{base_name}.conf" + ) + + all_jails_result, active_names = await asyncio.gather( + loop.run_in_executor(None, _parse_jails_sync, Path(config_dir)), + _get_active_jail_names(socket_path), + ) + all_jails, _source_files = all_jails_result + action_to_jails = _build_action_to_jails_map(all_jails, active_names) + + used_by = sorted(action_to_jails.get(base_name, [])) + log.info("action_fetched", name=base_name, active=len(used_by) > 0) + return ActionConfig( + name=cfg.name, + filename=cfg.filename, + before=cfg.before, + after=cfg.after, + actionstart=cfg.actionstart, + actionstop=cfg.actionstop, + actioncheck=cfg.actioncheck, + actionban=cfg.actionban, + actionunban=cfg.actionunban, + actionflush=cfg.actionflush, + definition_vars=cfg.definition_vars, + init_vars=cfg.init_vars, + active=len(used_by) > 0, + used_by_jails=used_by, + source_file=source_path, + has_local_override=has_local, + ) + + +# --------------------------------------------------------------------------- +# Public API — action write operations (Task 3.2) +# --------------------------------------------------------------------------- + + +async def update_action( + config_dir: str, + socket_path: str, + name: str, + req: ActionUpdateRequest, + do_reload: bool = False, +) -> ActionConfig: + """Update an action's ``.local`` override with new lifecycle command values. + + Reads the current merged configuration for *name* (``conf`` + any existing + ``local``), applies the non-``None`` fields in *req* on top of it, and + writes the resulting definition to ``action.d/{name}.local``. The + original ``.conf`` file is never modified. + + Args: + config_dir: Absolute path to the fail2ban configuration directory. + socket_path: Path to the fail2ban Unix domain socket. + name: Action base name (e.g. ``"iptables"`` or ``"iptables.conf"``). + req: Partial update — only non-``None`` fields are applied. + do_reload: When ``True``, trigger a full fail2ban reload after writing. + + Returns: + :class:`~app.models.config.ActionConfig` reflecting the updated state. + + Raises: + ActionNameError: If *name* contains invalid characters. + ActionNotFoundError: If no ``{name}.conf`` or ``{name}.local`` exists. + ConfigWriteError: If writing the ``.local`` file fails. + """ + base_name = name[:-5] if name.endswith((".conf", ".local")) else name + _safe_action_name(base_name) + + current = await get_action(config_dir, socket_path, base_name) + + update = ActionConfigUpdate( + actionstart=req.actionstart, + actionstop=req.actionstop, + actioncheck=req.actioncheck, + actionban=req.actionban, + actionunban=req.actionunban, + actionflush=req.actionflush, + definition_vars=req.definition_vars, + init_vars=req.init_vars, + ) + + merged = conffile_parser.merge_action_update(current, update) + content = conffile_parser.serialize_action_config(merged) + + action_d = Path(config_dir) / "action.d" + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, _write_action_local_sync, action_d, base_name, content) + + if do_reload: + try: + await jail_service.reload_all(socket_path) + except Exception as exc: # noqa: BLE001 + log.warning( + "reload_after_action_update_failed", + action=base_name, + error=str(exc), + ) + + log.info("action_updated", action=base_name, reload=do_reload) + return await get_action(config_dir, socket_path, base_name) + + +async def create_action( + config_dir: str, + socket_path: str, + req: ActionCreateRequest, + do_reload: bool = False, +) -> ActionConfig: + """Create a brand-new user-defined action in ``action.d/{name}.local``. + + No ``.conf`` is written; fail2ban loads ``.local`` files directly. If a + ``.conf`` or ``.local`` file already exists for the requested name, an + :class:`ActionAlreadyExistsError` is raised. + + Args: + config_dir: Absolute path to the fail2ban configuration directory. + socket_path: Path to the fail2ban Unix domain socket. + req: Action name and definition fields. + do_reload: When ``True``, trigger a full fail2ban reload after writing. + + Returns: + :class:`~app.models.config.ActionConfig` for the newly created action. + + Raises: + ActionNameError: If ``req.name`` contains invalid characters. + ActionAlreadyExistsError: If a ``.conf`` or ``.local`` already exists. + ConfigWriteError: If writing fails. + """ + _safe_action_name(req.name) + + action_d = Path(config_dir) / "action.d" + conf_path = action_d / f"{req.name}.conf" + local_path = action_d / f"{req.name}.local" + + def _check_not_exists() -> None: + if conf_path.is_file() or local_path.is_file(): + raise ActionAlreadyExistsError(req.name) + + loop = asyncio.get_event_loop() + await loop.run_in_executor(None, _check_not_exists) + + cfg = ActionConfig( + name=req.name, + filename=f"{req.name}.local", + actionstart=req.actionstart, + actionstop=req.actionstop, + actioncheck=req.actioncheck, + actionban=req.actionban, + actionunban=req.actionunban, + actionflush=req.actionflush, + definition_vars=req.definition_vars, + init_vars=req.init_vars, + ) + content = conffile_parser.serialize_action_config(cfg) + + await loop.run_in_executor(None, _write_action_local_sync, action_d, req.name, content) + + if do_reload: + try: + await jail_service.reload_all(socket_path) + except Exception as exc: # noqa: BLE001 + log.warning( + "reload_after_action_create_failed", + action=req.name, + error=str(exc), + ) + + log.info("action_created", action=req.name, reload=do_reload) + return await get_action(config_dir, socket_path, req.name) + + +async def delete_action( + config_dir: str, + name: str, +) -> None: + """Delete a user-created action's ``.local`` file. + + Deletion rules: + - If only a ``.conf`` file exists (shipped default, no user override) → + :class:`ActionReadonlyError`. + - If a ``.local`` file exists (whether or not a ``.conf`` also exists) → + only the ``.local`` file is deleted. + - If neither file exists → :class:`ActionNotFoundError`. + + Args: + config_dir: Absolute path to the fail2ban configuration directory. + name: Action base name (e.g. ``"iptables"``). + + Raises: + ActionNameError: If *name* contains invalid characters. + ActionNotFoundError: If no action file is found for *name*. + ActionReadonlyError: If only a shipped ``.conf`` exists (no ``.local``). + ConfigWriteError: If deletion of the ``.local`` file fails. + """ + base_name = name[:-5] if name.endswith((".conf", ".local")) else name + _safe_action_name(base_name) + + action_d = Path(config_dir) / "action.d" + conf_path = action_d / f"{base_name}.conf" + local_path = action_d / f"{base_name}.local" + + loop = asyncio.get_event_loop() + + def _delete() -> None: + has_conf = conf_path.is_file() + has_local = local_path.is_file() + + if not has_conf and not has_local: + raise ActionNotFoundError(base_name) + + if has_conf and not has_local: + raise ActionReadonlyError(base_name) + + try: + local_path.unlink() + except OSError as exc: + raise ConfigWriteError( + f"Failed to delete {local_path}: {exc}" + ) from exc + + log.info("action_local_deleted", action=base_name, path=str(local_path)) + + await loop.run_in_executor(None, _delete) + + +async def assign_action_to_jail( + config_dir: str, + socket_path: str, + jail_name: str, + req: AssignActionRequest, + do_reload: bool = False, +) -> None: + """Add an action to a jail by updating the jail's ``.local`` file. + + Appends ``{req.action_name}[{params}]`` (or just ``{req.action_name}`` when + no params are given) to the ``action`` key in the ``[{jail_name}]`` section + of ``jail.d/{jail_name}.local``. If the action is already listed it is not + duplicated. If the ``.local`` file does not exist it is created. + + Args: + config_dir: Absolute path to the fail2ban configuration directory. + socket_path: Path to the fail2ban Unix domain socket. + jail_name: Name of the jail to update. + req: Request containing the action name and optional parameters. + do_reload: When ``True``, trigger a full fail2ban reload after writing. + + Raises: + JailNameError: If *jail_name* contains invalid characters. + ActionNameError: If ``req.action_name`` contains invalid characters. + JailNotFoundInConfigError: If *jail_name* is not defined in any config + file. + ActionNotFoundError: If ``req.action_name`` does not exist in + ``action.d/``. + ConfigWriteError: If writing fails. + """ + _safe_jail_name(jail_name) + _safe_action_name(req.action_name) + + loop = asyncio.get_event_loop() + + all_jails, _src = await loop.run_in_executor( + None, _parse_jails_sync, Path(config_dir) + ) + if jail_name not in all_jails: + raise JailNotFoundInConfigError(jail_name) + + action_d = Path(config_dir) / "action.d" + + def _check_action() -> None: + if ( + not (action_d / f"{req.action_name}.conf").is_file() + and not (action_d / f"{req.action_name}.local").is_file() + ): + raise ActionNotFoundError(req.action_name) + + await loop.run_in_executor(None, _check_action) + + # Build the action string with optional parameters. + if req.params: + param_str = ", ".join(f"{k}={v}" for k, v in sorted(req.params.items())) + action_entry = f"{req.action_name}[{param_str}]" + else: + action_entry = req.action_name + + await loop.run_in_executor( + None, + _append_jail_action_sync, + Path(config_dir), + jail_name, + action_entry, + ) + + if do_reload: + try: + await jail_service.reload_all(socket_path) + except Exception as exc: # noqa: BLE001 + log.warning( + "reload_after_assign_action_failed", + jail=jail_name, + action=req.action_name, + error=str(exc), + ) + + log.info( + "action_assigned_to_jail", + jail=jail_name, + action=req.action_name, + reload=do_reload, + ) + + +async def remove_action_from_jail( + config_dir: str, + socket_path: str, + jail_name: str, + action_name: str, + do_reload: bool = False, +) -> None: + """Remove an action from a jail's ``.local`` config. + + Reads ``jail.d/{jail_name}.local``, removes the line(s) that reference + ``{action_name}`` from the ``action`` key (including any ``[…]`` parameter + blocks), and writes the file back atomically. + + Args: + config_dir: Absolute path to the fail2ban configuration directory. + socket_path: Path to the fail2ban Unix domain socket. + jail_name: Name of the jail to update. + action_name: Base name of the action to remove. + do_reload: When ``True``, trigger a full fail2ban reload after writing. + + Raises: + JailNameError: If *jail_name* contains invalid characters. + ActionNameError: If *action_name* contains invalid characters. + JailNotFoundInConfigError: If *jail_name* is not defined in any config. + ConfigWriteError: If writing fails. + """ + _safe_jail_name(jail_name) + _safe_action_name(action_name) + + loop = asyncio.get_event_loop() + + all_jails, _src = await loop.run_in_executor( + None, _parse_jails_sync, Path(config_dir) + ) + if jail_name not in all_jails: + raise JailNotFoundInConfigError(jail_name) + + await loop.run_in_executor( + None, + _remove_jail_action_sync, + Path(config_dir), + jail_name, + action_name, + ) + + if do_reload: + try: + await jail_service.reload_all(socket_path) + except Exception as exc: # noqa: BLE001 + log.warning( + "reload_after_remove_action_failed", + jail=jail_name, + action=action_name, + error=str(exc), + ) + + log.info( + "action_removed_from_jail", + jail=jail_name, + action=action_name, + reload=do_reload, + ) + diff --git a/backend/tests/test_routers/test_config.py b/backend/tests/test_routers/test_config.py index 706fd21..4bcd47d 100644 --- a/backend/tests/test_routers/test_config.py +++ b/backend/tests/test_routers/test_config.py @@ -1289,3 +1289,425 @@ class TestAssignFilterToJail: ).post("/api/config/jails/sshd/filter", json={"filter_name": "sshd"}) assert resp.status_code == 401 + +# =========================================================================== +# Action router tests (Task 3.1 + 3.2) +# =========================================================================== + + +@pytest.mark.asyncio +class TestListActionsRouter: + async def test_200_returns_action_list(self, config_client: AsyncClient) -> None: + from app.models.config import ActionConfig, ActionListResponse + + mock_action = ActionConfig( + name="iptables", + filename="iptables.conf", + actionban="/sbin/iptables -I f2b- 1 -s -j DROP", + ) + mock_response = ActionListResponse(actions=[mock_action], total=1) + + with patch( + "app.routers.config.config_file_service.list_actions", + AsyncMock(return_value=mock_response), + ): + resp = await config_client.get("/api/config/actions") + + assert resp.status_code == 200 + data = resp.json() + assert data["total"] == 1 + assert data["actions"][0]["name"] == "iptables" + + async def test_active_sorted_first(self, config_client: AsyncClient) -> None: + from app.models.config import ActionConfig, ActionListResponse + + inactive = ActionConfig(name="aaa", filename="aaa.conf", active=False) + active = ActionConfig(name="zzz", filename="zzz.conf", active=True) + mock_response = ActionListResponse(actions=[inactive, active], total=2) + + with patch( + "app.routers.config.config_file_service.list_actions", + AsyncMock(return_value=mock_response), + ): + resp = await config_client.get("/api/config/actions") + + data = resp.json() + assert data["actions"][0]["name"] == "zzz" # active comes first + + async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None: + resp = await AsyncClient( + transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined] + base_url="http://test", + ).get("/api/config/actions") + assert resp.status_code == 401 + + +@pytest.mark.asyncio +class TestGetActionRouter: + async def test_200_returns_action(self, config_client: AsyncClient) -> None: + from app.models.config import ActionConfig + + mock_action = ActionConfig( + name="iptables", + filename="iptables.conf", + actionban="/sbin/iptables -I f2b- 1 -s -j DROP", + ) + + with patch( + "app.routers.config.config_file_service.get_action", + AsyncMock(return_value=mock_action), + ): + resp = await config_client.get("/api/config/actions/iptables") + + assert resp.status_code == 200 + assert resp.json()["name"] == "iptables" + + async def test_404_when_not_found(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionNotFoundError + + with patch( + "app.routers.config.config_file_service.get_action", + AsyncMock(side_effect=ActionNotFoundError("missing")), + ): + resp = await config_client.get("/api/config/actions/missing") + + assert resp.status_code == 404 + + async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None: + resp = await AsyncClient( + transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined] + base_url="http://test", + ).get("/api/config/actions/iptables") + assert resp.status_code == 401 + + +@pytest.mark.asyncio +class TestUpdateActionRouter: + async def test_200_returns_updated_action(self, config_client: AsyncClient) -> None: + from app.models.config import ActionConfig + + updated = ActionConfig( + name="iptables", + filename="iptables.local", + actionban="echo ban", + ) + + with patch( + "app.routers.config.config_file_service.update_action", + AsyncMock(return_value=updated), + ): + resp = await config_client.put( + "/api/config/actions/iptables", + json={"actionban": "echo ban"}, + ) + + assert resp.status_code == 200 + assert resp.json()["actionban"] == "echo ban" + + async def test_404_when_not_found(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionNotFoundError + + with patch( + "app.routers.config.config_file_service.update_action", + AsyncMock(side_effect=ActionNotFoundError("missing")), + ): + resp = await config_client.put( + "/api/config/actions/missing", json={} + ) + + assert resp.status_code == 404 + + async def test_400_for_bad_name(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionNameError + + with patch( + "app.routers.config.config_file_service.update_action", + AsyncMock(side_effect=ActionNameError()), + ): + resp = await config_client.put( + "/api/config/actions/badname", json={} + ) + + assert resp.status_code == 400 + + async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None: + resp = await AsyncClient( + transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined] + base_url="http://test", + ).put("/api/config/actions/iptables", json={}) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +class TestCreateActionRouter: + async def test_201_returns_created_action(self, config_client: AsyncClient) -> None: + from app.models.config import ActionConfig + + created = ActionConfig( + name="custom", + filename="custom.local", + actionban="echo ban", + ) + + with patch( + "app.routers.config.config_file_service.create_action", + AsyncMock(return_value=created), + ): + resp = await config_client.post( + "/api/config/actions", + json={"name": "custom", "actionban": "echo ban"}, + ) + + assert resp.status_code == 201 + assert resp.json()["name"] == "custom" + + async def test_409_when_already_exists(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionAlreadyExistsError + + with patch( + "app.routers.config.config_file_service.create_action", + AsyncMock(side_effect=ActionAlreadyExistsError("iptables")), + ): + resp = await config_client.post( + "/api/config/actions", + json={"name": "iptables"}, + ) + + assert resp.status_code == 409 + + async def test_400_for_bad_name(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionNameError + + with patch( + "app.routers.config.config_file_service.create_action", + AsyncMock(side_effect=ActionNameError()), + ): + resp = await config_client.post( + "/api/config/actions", + json={"name": "badname"}, + ) + + assert resp.status_code == 400 + + async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None: + resp = await AsyncClient( + transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined] + base_url="http://test", + ).post("/api/config/actions", json={"name": "x"}) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +class TestDeleteActionRouter: + async def test_204_on_delete(self, config_client: AsyncClient) -> None: + with patch( + "app.routers.config.config_file_service.delete_action", + AsyncMock(return_value=None), + ): + resp = await config_client.delete("/api/config/actions/custom") + + assert resp.status_code == 204 + + async def test_404_when_not_found(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionNotFoundError + + with patch( + "app.routers.config.config_file_service.delete_action", + AsyncMock(side_effect=ActionNotFoundError("missing")), + ): + resp = await config_client.delete("/api/config/actions/missing") + + assert resp.status_code == 404 + + async def test_409_when_readonly(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionReadonlyError + + with patch( + "app.routers.config.config_file_service.delete_action", + AsyncMock(side_effect=ActionReadonlyError("iptables")), + ): + resp = await config_client.delete("/api/config/actions/iptables") + + assert resp.status_code == 409 + + async def test_400_for_bad_name(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionNameError + + with patch( + "app.routers.config.config_file_service.delete_action", + AsyncMock(side_effect=ActionNameError()), + ): + resp = await config_client.delete("/api/config/actions/badname") + + assert resp.status_code == 400 + + async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None: + resp = await AsyncClient( + transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined] + base_url="http://test", + ).delete("/api/config/actions/iptables") + assert resp.status_code == 401 + + +@pytest.mark.asyncio +class TestAssignActionToJailRouter: + async def test_204_on_success(self, config_client: AsyncClient) -> None: + with patch( + "app.routers.config.config_file_service.assign_action_to_jail", + AsyncMock(return_value=None), + ): + resp = await config_client.post( + "/api/config/jails/sshd/action", + json={"action_name": "iptables"}, + ) + + assert resp.status_code == 204 + + async def test_404_when_jail_not_found(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import JailNotFoundInConfigError + + with patch( + "app.routers.config.config_file_service.assign_action_to_jail", + AsyncMock(side_effect=JailNotFoundInConfigError("missing")), + ): + resp = await config_client.post( + "/api/config/jails/missing/action", + json={"action_name": "iptables"}, + ) + + assert resp.status_code == 404 + + async def test_404_when_action_not_found(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionNotFoundError + + with patch( + "app.routers.config.config_file_service.assign_action_to_jail", + AsyncMock(side_effect=ActionNotFoundError("missing")), + ): + resp = await config_client.post( + "/api/config/jails/sshd/action", + json={"action_name": "missing"}, + ) + + assert resp.status_code == 404 + + async def test_400_for_bad_jail_name(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import JailNameError + + with patch( + "app.routers.config.config_file_service.assign_action_to_jail", + AsyncMock(side_effect=JailNameError()), + ): + resp = await config_client.post( + "/api/config/jails/badjailname/action", + json={"action_name": "iptables"}, + ) + + assert resp.status_code == 400 + + async def test_400_for_bad_action_name(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionNameError + + with patch( + "app.routers.config.config_file_service.assign_action_to_jail", + AsyncMock(side_effect=ActionNameError()), + ): + resp = await config_client.post( + "/api/config/jails/sshd/action", + json={"action_name": "badaction"}, + ) + + assert resp.status_code == 400 + + async def test_reload_param_passed(self, config_client: AsyncClient) -> None: + with patch( + "app.routers.config.config_file_service.assign_action_to_jail", + AsyncMock(return_value=None), + ) as mock_assign: + resp = await config_client.post( + "/api/config/jails/sshd/action?reload=true", + json={"action_name": "iptables"}, + ) + + assert resp.status_code == 204 + assert mock_assign.call_args.kwargs.get("do_reload") is True + + async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None: + resp = await AsyncClient( + transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined] + base_url="http://test", + ).post("/api/config/jails/sshd/action", json={"action_name": "iptables"}) + assert resp.status_code == 401 + + +@pytest.mark.asyncio +class TestRemoveActionFromJailRouter: + async def test_204_on_success(self, config_client: AsyncClient) -> None: + with patch( + "app.routers.config.config_file_service.remove_action_from_jail", + AsyncMock(return_value=None), + ): + resp = await config_client.delete( + "/api/config/jails/sshd/action/iptables" + ) + + assert resp.status_code == 204 + + async def test_404_when_jail_not_found(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import JailNotFoundInConfigError + + with patch( + "app.routers.config.config_file_service.remove_action_from_jail", + AsyncMock(side_effect=JailNotFoundInConfigError("missing")), + ): + resp = await config_client.delete( + "/api/config/jails/missing/action/iptables" + ) + + assert resp.status_code == 404 + + async def test_400_for_bad_jail_name(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import JailNameError + + with patch( + "app.routers.config.config_file_service.remove_action_from_jail", + AsyncMock(side_effect=JailNameError()), + ): + resp = await config_client.delete( + "/api/config/jails/badjailname/action/iptables" + ) + + assert resp.status_code == 400 + + async def test_400_for_bad_action_name(self, config_client: AsyncClient) -> None: + from app.services.config_file_service import ActionNameError + + with patch( + "app.routers.config.config_file_service.remove_action_from_jail", + AsyncMock(side_effect=ActionNameError()), + ): + resp = await config_client.delete( + "/api/config/jails/sshd/action/badactionname" + ) + + assert resp.status_code == 400 + + async def test_reload_param_passed(self, config_client: AsyncClient) -> None: + with patch( + "app.routers.config.config_file_service.remove_action_from_jail", + AsyncMock(return_value=None), + ) as mock_rm: + resp = await config_client.delete( + "/api/config/jails/sshd/action/iptables?reload=true" + ) + + assert resp.status_code == 204 + assert mock_rm.call_args.kwargs.get("do_reload") is True + + async def test_401_when_unauthenticated(self, config_client: AsyncClient) -> None: + resp = await AsyncClient( + transport=ASGITransport(app=config_client._transport.app), # type: ignore[attr-defined] + base_url="http://test", + ).delete("/api/config/jails/sshd/action/iptables") + assert resp.status_code == 401 + diff --git a/backend/tests/test_services/test_config_file_service.py b/backend/tests/test_services/test_config_file_service.py index 56a94cb..2e4cc6f 100644 --- a/backend/tests/test_services/test_config_file_service.py +++ b/backend/tests/test_services/test_config_file_service.py @@ -1487,3 +1487,1007 @@ class TestAssignFilterToJail: mock_reload.assert_awaited_once() + +# =========================================================================== +# Action service tests (Task 3.1 + 3.2) +# =========================================================================== + +_ACTION_CONF = """\ +[Definition] + +actionstart = /sbin/iptables -N f2b- +actionstop = /sbin/iptables -D INPUT -j f2b- +actionban = /sbin/iptables -I f2b- 1 -s -j DROP +actionunban = /sbin/iptables -D f2b- -s -j DROP + +[Init] + +name = default +port = ssh +protocol = tcp +""" + +_ACTION_CONF_MINIMAL = """\ +[Definition] + +actionban = echo ban +actionunban = echo unban +""" + + +# --------------------------------------------------------------------------- +# _safe_action_name +# --------------------------------------------------------------------------- + + +class TestSafeActionName: + def test_valid_simple(self) -> None: + from app.services.config_file_service import _safe_action_name + + assert _safe_action_name("iptables") == "iptables" + + def test_valid_with_hyphen(self) -> None: + from app.services.config_file_service import _safe_action_name + + assert _safe_action_name("iptables-multiport") == "iptables-multiport" + + def test_valid_with_dot(self) -> None: + from app.services.config_file_service import _safe_action_name + + assert _safe_action_name("my.action") == "my.action" + + def test_invalid_path_traversal(self) -> None: + from app.services.config_file_service import ActionNameError, _safe_action_name + + with pytest.raises(ActionNameError): + _safe_action_name("../evil") + + def test_invalid_empty(self) -> None: + from app.services.config_file_service import ActionNameError, _safe_action_name + + with pytest.raises(ActionNameError): + _safe_action_name("") + + def test_invalid_slash(self) -> None: + from app.services.config_file_service import ActionNameError, _safe_action_name + + with pytest.raises(ActionNameError): + _safe_action_name("a/b") + + +# --------------------------------------------------------------------------- +# _build_action_to_jails_map +# --------------------------------------------------------------------------- + + +class TestBuildActionToJailsMap: + def test_active_jail_maps_to_action(self) -> None: + from app.services.config_file_service import _build_action_to_jails_map + + result = _build_action_to_jails_map( + {"sshd": {"action": "iptables-multiport"}}, {"sshd"} + ) + assert result == {"iptables-multiport": ["sshd"]} + + def test_inactive_jail_not_included(self) -> None: + from app.services.config_file_service import _build_action_to_jails_map + + result = _build_action_to_jails_map( + {"sshd": {"action": "iptables-multiport"}}, set() + ) + assert result == {} + + def test_multiple_actions_per_jail(self) -> None: + from app.services.config_file_service import _build_action_to_jails_map + + result = _build_action_to_jails_map( + {"sshd": {"action": "iptables-multiport\niptables-ipset"}}, {"sshd"} + ) + assert "iptables-multiport" in result + assert "iptables-ipset" in result + + def test_parameter_block_stripped(self) -> None: + from app.services.config_file_service import _build_action_to_jails_map + + result = _build_action_to_jails_map( + {"sshd": {"action": "iptables[port=ssh, protocol=tcp]"}}, {"sshd"} + ) + assert "iptables" in result + + def test_multiple_jails_sharing_action(self) -> None: + from app.services.config_file_service import _build_action_to_jails_map + + all_jails = { + "sshd": {"action": "iptables-multiport"}, + "apache": {"action": "iptables-multiport"}, + } + result = _build_action_to_jails_map(all_jails, {"sshd", "apache"}) + assert sorted(result["iptables-multiport"]) == ["apache", "sshd"] + + def test_jail_with_no_action_key(self) -> None: + from app.services.config_file_service import _build_action_to_jails_map + + result = _build_action_to_jails_map({"sshd": {}}, {"sshd"}) + assert result == {} + + +# --------------------------------------------------------------------------- +# _parse_actions_sync +# --------------------------------------------------------------------------- + + +class TestParseActionsSync: + def test_returns_empty_for_missing_dir(self, tmp_path: Path) -> None: + from app.services.config_file_service import _parse_actions_sync + + result = _parse_actions_sync(tmp_path / "nonexistent") + assert result == [] + + def test_single_action_returned(self, tmp_path: Path) -> None: + from app.services.config_file_service import _parse_actions_sync + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + result = _parse_actions_sync(action_d) + + assert len(result) == 1 + name, filename, content, has_local, source_path = result[0] + assert name == "iptables" + assert filename == "iptables.conf" + assert "actionban" in content + assert has_local is False + assert source_path.endswith("iptables.conf") + + def test_local_override_detected(self, tmp_path: Path) -> None: + from app.services.config_file_service import _parse_actions_sync + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + _write(action_d / "iptables.local", "[Definition]\n# override\n") + + result = _parse_actions_sync(action_d) + + _, _, _, has_local, _ = result[0] + assert has_local is True + + def test_local_content_merged_into_content(self, tmp_path: Path) -> None: + from app.services.config_file_service import _parse_actions_sync + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + _write(action_d / "iptables.local", "[Definition]\n# local override tweak\n") + + result = _parse_actions_sync(action_d) + + _, _, content, _, _ = result[0] + assert "local override tweak" in content + + def test_local_only_action_included(self, tmp_path: Path) -> None: + from app.services.config_file_service import _parse_actions_sync + + action_d = tmp_path / "action.d" + _write(action_d / "custom.local", _ACTION_CONF_MINIMAL) + + result = _parse_actions_sync(action_d) + + assert len(result) == 1 + name, filename, _, has_local, source_path = result[0] + assert name == "custom" + assert filename == "custom.local" + assert has_local is False # local-only: no .conf to pair with + assert source_path.endswith("custom.local") + + def test_sorted_alphabetically(self, tmp_path: Path) -> None: + from app.services.config_file_service import _parse_actions_sync + + action_d = tmp_path / "action.d" + for n in ("zzz", "aaa", "mmm"): + _write(action_d / f"{n}.conf", _ACTION_CONF_MINIMAL) + + result = _parse_actions_sync(action_d) + + assert [r[0] for r in result] == ["aaa", "mmm", "zzz"] + + +# --------------------------------------------------------------------------- +# list_actions +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestListActions: + async def test_returns_all_actions(self, tmp_path: Path) -> None: + from app.services.config_file_service import list_actions + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + _write(action_d / "cloudflare.conf", _ACTION_CONF_MINIMAL) + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await list_actions(str(tmp_path), "/fake.sock") + + assert result.total == 2 + names = {a.name for a in result.actions} + assert "iptables" in names + assert "cloudflare" in names + + async def test_active_flag_set_for_used_action(self, tmp_path: Path) -> None: + from app.services.config_file_service import list_actions + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + _write(tmp_path / "jail.conf", JAIL_CONF) + + all_jails_with_action = { + "sshd": { + "enabled": "true", + "filter": "sshd", + "action": "iptables", + }, + "apache-auth": {"enabled": "false"}, + } + + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value={"sshd"}), + ), + patch( + "app.services.config_file_service._parse_jails_sync", + return_value=(all_jails_with_action, {}), + ), + ): + result = await list_actions(str(tmp_path), "/fake.sock") + + iptables = next(a for a in result.actions if a.name == "iptables") + assert iptables.active is True + assert "sshd" in iptables.used_by_jails + + async def test_inactive_action_has_active_false(self, tmp_path: Path) -> None: + from app.services.config_file_service import list_actions + + action_d = tmp_path / "action.d" + _write(action_d / "cloudflare.conf", _ACTION_CONF_MINIMAL) + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await list_actions(str(tmp_path), "/fake.sock") + + cf = next(a for a in result.actions if a.name == "cloudflare") + assert cf.active is False + assert cf.used_by_jails == [] + + async def test_has_local_override_detected(self, tmp_path: Path) -> None: + from app.services.config_file_service import list_actions + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + _write(action_d / "iptables.local", "[Definition]\n") + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await list_actions(str(tmp_path), "/fake.sock") + + ipt = next(a for a in result.actions if a.name == "iptables") + assert ipt.has_local_override is True + + async def test_empty_action_d_returns_empty(self, tmp_path: Path) -> None: + from app.services.config_file_service import list_actions + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await list_actions(str(tmp_path), "/fake.sock") + + assert result.actions == [] + assert result.total == 0 + + async def test_total_matches_actions_count(self, tmp_path: Path) -> None: + from app.services.config_file_service import list_actions + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await list_actions(str(tmp_path), "/fake.sock") + + assert result.total == len(result.actions) + + +# --------------------------------------------------------------------------- +# get_action +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestGetAction: + async def test_returns_action_config(self, tmp_path: Path) -> None: + from app.services.config_file_service import get_action + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await get_action(str(tmp_path), "/fake.sock", "iptables") + + assert result.name == "iptables" + assert result.actionban is not None + assert "iptables" in (result.actionban or "") + + async def test_strips_conf_extension(self, tmp_path: Path) -> None: + from app.services.config_file_service import get_action + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await get_action(str(tmp_path), "/fake.sock", "iptables.conf") + + assert result.name == "iptables" + + async def test_raises_for_unknown_action(self, tmp_path: Path) -> None: + from app.services.config_file_service import ActionNotFoundError, get_action + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), pytest.raises(ActionNotFoundError): + await get_action(str(tmp_path), "/fake.sock", "nonexistent") + + async def test_local_only_action_returned(self, tmp_path: Path) -> None: + from app.services.config_file_service import get_action + + action_d = tmp_path / "action.d" + _write(action_d / "custom.local", _ACTION_CONF_MINIMAL) + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await get_action(str(tmp_path), "/fake.sock", "custom") + + assert result.name == "custom" + + async def test_active_status_populated(self, tmp_path: Path) -> None: + from app.services.config_file_service import get_action + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + all_jails_with_action = { + "sshd": {"enabled": "true", "filter": "sshd", "action": "iptables"}, + } + + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value={"sshd"}), + ), + patch( + "app.services.config_file_service._parse_jails_sync", + return_value=(all_jails_with_action, {}), + ), + ): + result = await get_action(str(tmp_path), "/fake.sock", "iptables") + + assert result.active is True + assert "sshd" in result.used_by_jails + + +# --------------------------------------------------------------------------- +# _write_action_local_sync +# --------------------------------------------------------------------------- + + +class TestWriteActionLocalSync: + def test_writes_file(self, tmp_path: Path) -> None: + from app.services.config_file_service import _write_action_local_sync + + action_d = tmp_path / "action.d" + action_d.mkdir() + _write_action_local_sync(action_d, "myaction", "[Definition]\n") + + local = action_d / "myaction.local" + assert local.is_file() + assert "[Definition]" in local.read_text() + + def test_creates_action_d_if_missing(self, tmp_path: Path) -> None: + from app.services.config_file_service import _write_action_local_sync + + action_d = tmp_path / "action.d" + _write_action_local_sync(action_d, "test", "[Definition]\n") + assert (action_d / "test.local").is_file() + + def test_overwrites_existing_file(self, tmp_path: Path) -> None: + from app.services.config_file_service import _write_action_local_sync + + action_d = tmp_path / "action.d" + action_d.mkdir() + (action_d / "myaction.local").write_text("old content") + + _write_action_local_sync(action_d, "myaction", "[Definition]\nnew=1\n") + + assert "new=1" in (action_d / "myaction.local").read_text() + assert "old content" not in (action_d / "myaction.local").read_text() + + +# --------------------------------------------------------------------------- +# update_action (Task 3.2) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestUpdateAction: + async def test_updates_actionban(self, tmp_path: Path) -> None: + from app.models.config import ActionUpdateRequest + from app.services.config_file_service import update_action + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await update_action( + str(tmp_path), + "/fake.sock", + "iptables", + ActionUpdateRequest(actionban="echo ban "), + ) + + local = action_d / "iptables.local" + assert local.is_file() + assert "echo ban" in local.read_text() + assert result.name == "iptables" + + async def test_raises_not_found_for_unknown_action(self, tmp_path: Path) -> None: + from app.models.config import ActionUpdateRequest + from app.services.config_file_service import ActionNotFoundError, update_action + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), pytest.raises(ActionNotFoundError): + await update_action( + str(tmp_path), + "/fake.sock", + "nonexistent", + ActionUpdateRequest(), + ) + + async def test_raises_name_error_for_invalid_name(self, tmp_path: Path) -> None: + from app.models.config import ActionUpdateRequest + from app.services.config_file_service import ActionNameError, update_action + + with pytest.raises(ActionNameError): + await update_action( + str(tmp_path), + "/fake.sock", + "../evil", + ActionUpdateRequest(), + ) + + async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None: + from app.models.config import ActionUpdateRequest + from app.services.config_file_service import update_action + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + patch( + "app.services.config_file_service.jail_service.reload_all", + new=AsyncMock(), + ) as mock_reload, + ): + await update_action( + str(tmp_path), + "/fake.sock", + "iptables", + ActionUpdateRequest(actionban="echo ban "), + do_reload=True, + ) + + mock_reload.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# create_action (Task 3.2) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestCreateAction: + async def test_creates_local_file(self, tmp_path: Path) -> None: + from app.models.config import ActionCreateRequest + from app.services.config_file_service import create_action + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + result = await create_action( + str(tmp_path), + "/fake.sock", + ActionCreateRequest( + name="my-action", + actionban="echo ban ", + actionunban="echo unban ", + ), + ) + + local = tmp_path / "action.d" / "my-action.local" + assert local.is_file() + assert result.name == "my-action" + + async def test_raises_already_exists_for_conf(self, tmp_path: Path) -> None: + from app.models.config import ActionCreateRequest + from app.services.config_file_service import ( + ActionAlreadyExistsError, + create_action, + ) + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with pytest.raises(ActionAlreadyExistsError): + await create_action( + str(tmp_path), + "/fake.sock", + ActionCreateRequest(name="iptables"), + ) + + async def test_raises_already_exists_for_local(self, tmp_path: Path) -> None: + from app.models.config import ActionCreateRequest + from app.services.config_file_service import ( + ActionAlreadyExistsError, + create_action, + ) + + action_d = tmp_path / "action.d" + _write(action_d / "custom.local", _ACTION_CONF_MINIMAL) + + with pytest.raises(ActionAlreadyExistsError): + await create_action( + str(tmp_path), + "/fake.sock", + ActionCreateRequest(name="custom"), + ) + + async def test_raises_name_error_for_invalid_name(self, tmp_path: Path) -> None: + from app.models.config import ActionCreateRequest + from app.services.config_file_service import ActionNameError, create_action + + with pytest.raises(ActionNameError): + await create_action( + str(tmp_path), + "/fake.sock", + ActionCreateRequest(name="../evil"), + ) + + async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None: + from app.models.config import ActionCreateRequest + from app.services.config_file_service import create_action + + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + patch( + "app.services.config_file_service.jail_service.reload_all", + new=AsyncMock(), + ) as mock_reload, + ): + await create_action( + str(tmp_path), + "/fake.sock", + ActionCreateRequest(name="new-action"), + do_reload=True, + ) + + mock_reload.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# delete_action (Task 3.2) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestDeleteAction: + async def test_deletes_local_file(self, tmp_path: Path) -> None: + from app.services.config_file_service import delete_action + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + _write(action_d / "iptables.local", "[Definition]\n") + + await delete_action(str(tmp_path), "iptables") + + assert not (action_d / "iptables.local").is_file() + assert (action_d / "iptables.conf").is_file() # original untouched + + async def test_raises_readonly_for_conf_only(self, tmp_path: Path) -> None: + from app.services.config_file_service import ActionReadonlyError, delete_action + + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with pytest.raises(ActionReadonlyError): + await delete_action(str(tmp_path), "iptables") + + async def test_raises_not_found_for_missing(self, tmp_path: Path) -> None: + from app.services.config_file_service import ActionNotFoundError, delete_action + + with pytest.raises(ActionNotFoundError): + await delete_action(str(tmp_path), "nonexistent") + + async def test_deletes_local_only_action(self, tmp_path: Path) -> None: + from app.services.config_file_service import delete_action + + action_d = tmp_path / "action.d" + _write(action_d / "custom.local", _ACTION_CONF_MINIMAL) + + await delete_action(str(tmp_path), "custom") + + assert not (action_d / "custom.local").is_file() + + async def test_raises_name_error_for_invalid_name(self, tmp_path: Path) -> None: + from app.services.config_file_service import ActionNameError, delete_action + + with pytest.raises(ActionNameError): + await delete_action(str(tmp_path), "../etc/evil") + + +# --------------------------------------------------------------------------- +# _append_jail_action_sync +# --------------------------------------------------------------------------- + + +class TestAppendJailActionSync: + def test_creates_local_with_action(self, tmp_path: Path) -> None: + from app.services.config_file_service import _append_jail_action_sync + + _append_jail_action_sync(tmp_path, "sshd", "iptables-multiport") + + local = tmp_path / "jail.d" / "sshd.local" + assert local.is_file() + assert "iptables-multiport" in local.read_text() + + def test_appends_to_existing_action_list(self, tmp_path: Path) -> None: + from app.services.config_file_service import _append_jail_action_sync + + jail_d = tmp_path / "jail.d" + _write(jail_d / "sshd.local", "[sshd]\naction = iptables-multiport\n") + + _append_jail_action_sync(tmp_path, "sshd", "cloudflare") + + content = (jail_d / "sshd.local").read_text() + assert "iptables-multiport" in content + assert "cloudflare" in content + + def test_does_not_duplicate_action(self, tmp_path: Path) -> None: + from app.services.config_file_service import _append_jail_action_sync + + jail_d = tmp_path / "jail.d" + _write(jail_d / "sshd.local", "[sshd]\naction = iptables-multiport\n") + + _append_jail_action_sync(tmp_path, "sshd", "iptables-multiport") + _append_jail_action_sync(tmp_path, "sshd", "iptables-multiport") + + content = (jail_d / "sshd.local").read_text() + # Should appear only once in the action list + assert content.count("iptables-multiport") == 1 + + def test_does_not_duplicate_when_params_differ(self, tmp_path: Path) -> None: + from app.services.config_file_service import _append_jail_action_sync + + jail_d = tmp_path / "jail.d" + _write( + jail_d / "sshd.local", + "[sshd]\naction = iptables[port=ssh]\n", + ) + + # Same base name, different params — should not duplicate. + _append_jail_action_sync(tmp_path, "sshd", "iptables[port=22]") + + content = (jail_d / "sshd.local").read_text() + assert content.count("iptables") == 1 + + +# --------------------------------------------------------------------------- +# _remove_jail_action_sync +# --------------------------------------------------------------------------- + + +class TestRemoveJailActionSync: + def test_removes_action_from_list(self, tmp_path: Path) -> None: + from app.services.config_file_service import _remove_jail_action_sync + + jail_d = tmp_path / "jail.d" + _write( + jail_d / "sshd.local", + "[sshd]\naction = iptables-multiport\n", + ) + + _remove_jail_action_sync(tmp_path, "sshd", "iptables-multiport") + + content = (jail_d / "sshd.local").read_text() + assert "iptables-multiport" not in content + + def test_removes_only_targeted_action(self, tmp_path: Path) -> None: + from app.services.config_file_service import ( + _append_jail_action_sync, + _remove_jail_action_sync, + ) + + jail_d = tmp_path / "jail.d" + jail_d.mkdir(parents=True, exist_ok=True) + _append_jail_action_sync(tmp_path, "sshd", "iptables-multiport") + _append_jail_action_sync(tmp_path, "sshd", "cloudflare") + + _remove_jail_action_sync(tmp_path, "sshd", "iptables-multiport") + + content = (jail_d / "sshd.local").read_text() + assert "iptables-multiport" not in content + assert "cloudflare" in content + + def test_is_noop_when_no_local_file(self, tmp_path: Path) -> None: + from app.services.config_file_service import _remove_jail_action_sync + + # Should not raise; no .local file to modify. + _remove_jail_action_sync(tmp_path, "sshd", "iptables-multiport") + + def test_is_noop_when_action_not_in_list(self, tmp_path: Path) -> None: + from app.services.config_file_service import _remove_jail_action_sync + + jail_d = tmp_path / "jail.d" + _write(jail_d / "sshd.local", "[sshd]\naction = cloudflare\n") + + _remove_jail_action_sync(tmp_path, "sshd", "iptables-multiport") + + content = (jail_d / "sshd.local").read_text() + assert "cloudflare" in content # untouched + + +# --------------------------------------------------------------------------- +# assign_action_to_jail (Task 3.2) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestAssignActionToJail: + async def test_creates_local_with_action(self, tmp_path: Path) -> None: + from app.models.config import AssignActionRequest + from app.services.config_file_service import assign_action_to_jail + + _write(tmp_path / "jail.conf", JAIL_CONF) + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + await assign_action_to_jail( + str(tmp_path), + "/fake.sock", + "sshd", + AssignActionRequest(action_name="iptables"), + ) + + local = tmp_path / "jail.d" / "sshd.local" + assert local.is_file() + assert "iptables" in local.read_text() + + async def test_params_written_to_action_entry(self, tmp_path: Path) -> None: + from app.models.config import AssignActionRequest + from app.services.config_file_service import assign_action_to_jail + + _write(tmp_path / "jail.conf", JAIL_CONF) + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + await assign_action_to_jail( + str(tmp_path), + "/fake.sock", + "sshd", + AssignActionRequest(action_name="iptables", params={"port": "ssh"}), + ) + + content = (tmp_path / "jail.d" / "sshd.local").read_text() + assert "port=ssh" in content + + async def test_raises_jail_not_found(self, tmp_path: Path) -> None: + from app.models.config import AssignActionRequest + from app.services.config_file_service import ( + JailNotFoundInConfigError, + assign_action_to_jail, + ) + + with pytest.raises(JailNotFoundInConfigError): + await assign_action_to_jail( + str(tmp_path), + "/fake.sock", + "nonexistent", + AssignActionRequest(action_name="iptables"), + ) + + async def test_raises_action_not_found(self, tmp_path: Path) -> None: + from app.models.config import AssignActionRequest + from app.services.config_file_service import ( + ActionNotFoundError, + assign_action_to_jail, + ) + + _write(tmp_path / "jail.conf", JAIL_CONF) + + with pytest.raises(ActionNotFoundError): + await assign_action_to_jail( + str(tmp_path), + "/fake.sock", + "sshd", + AssignActionRequest(action_name="nonexistent-action"), + ) + + async def test_raises_jail_name_error(self, tmp_path: Path) -> None: + from app.models.config import AssignActionRequest + from app.services.config_file_service import JailNameError, assign_action_to_jail + + with pytest.raises(JailNameError): + await assign_action_to_jail( + str(tmp_path), + "/fake.sock", + "../etc/evil", + AssignActionRequest(action_name="iptables"), + ) + + async def test_raises_action_name_error(self, tmp_path: Path) -> None: + from app.models.config import AssignActionRequest + from app.services.config_file_service import ( + ActionNameError, + assign_action_to_jail, + ) + + with pytest.raises(ActionNameError): + await assign_action_to_jail( + str(tmp_path), + "/fake.sock", + "sshd", + AssignActionRequest(action_name="../evil"), + ) + + async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None: + from app.models.config import AssignActionRequest + from app.services.config_file_service import assign_action_to_jail + + _write(tmp_path / "jail.conf", JAIL_CONF) + action_d = tmp_path / "action.d" + _write(action_d / "iptables.conf", _ACTION_CONF) + + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + patch( + "app.services.config_file_service.jail_service.reload_all", + new=AsyncMock(), + ) as mock_reload, + ): + await assign_action_to_jail( + str(tmp_path), + "/fake.sock", + "sshd", + AssignActionRequest(action_name="iptables"), + do_reload=True, + ) + + mock_reload.assert_awaited_once() + + +# --------------------------------------------------------------------------- +# remove_action_from_jail (Task 3.2) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +class TestRemoveActionFromJail: + async def test_removes_action_from_local(self, tmp_path: Path) -> None: + from app.services.config_file_service import remove_action_from_jail + + _write(tmp_path / "jail.conf", JAIL_CONF) + jail_d = tmp_path / "jail.d" + _write(jail_d / "sshd.local", "[sshd]\naction = iptables-multiport\n") + + with patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ): + await remove_action_from_jail( + str(tmp_path), "/fake.sock", "sshd", "iptables-multiport" + ) + + content = (jail_d / "sshd.local").read_text() + assert "iptables-multiport" not in content + + async def test_raises_jail_not_found(self, tmp_path: Path) -> None: + from app.services.config_file_service import ( + JailNotFoundInConfigError, + remove_action_from_jail, + ) + + with pytest.raises(JailNotFoundInConfigError): + await remove_action_from_jail( + str(tmp_path), "/fake.sock", "nonexistent", "iptables" + ) + + async def test_raises_jail_name_error(self, tmp_path: Path) -> None: + from app.services.config_file_service import JailNameError, remove_action_from_jail + + with pytest.raises(JailNameError): + await remove_action_from_jail( + str(tmp_path), "/fake.sock", "../evil", "iptables" + ) + + async def test_raises_action_name_error(self, tmp_path: Path) -> None: + from app.services.config_file_service import ActionNameError, remove_action_from_jail + + _write(tmp_path / "jail.conf", JAIL_CONF) + + with pytest.raises(ActionNameError): + await remove_action_from_jail( + str(tmp_path), "/fake.sock", "sshd", "../evil" + ) + + async def test_triggers_reload_when_requested(self, tmp_path: Path) -> None: + from app.services.config_file_service import remove_action_from_jail + + _write(tmp_path / "jail.conf", JAIL_CONF) + jail_d = tmp_path / "jail.d" + _write(jail_d / "sshd.local", "[sshd]\naction = iptables\n") + + with ( + patch( + "app.services.config_file_service._get_active_jail_names", + new=AsyncMock(return_value=set()), + ), + patch( + "app.services.config_file_service.jail_service.reload_all", + new=AsyncMock(), + ) as mock_reload, + ): + await remove_action_from_jail( + str(tmp_path), "/fake.sock", "sshd", "iptables", do_reload=True + ) + + mock_reload.assert_awaited_once() +