Files
BanGUI/backend/tests/test_services/test_jail_service.py
Lukas 2e221f6852 Refactor: Move module-level mutable flags to JailServiceState
TASK-004: Replace module-level mutable runtime flags in service layer with
injected state holder, eliminating hidden global state and improving testability
and synchronization boundaries.

Changes:
- Create JailServiceState dataclass in app/utils/runtime_state.py to hold
  backend capability cache and synchronization lock
- Add JailServiceState as a field in RuntimeState (with default_factory)
- Remove module-level _backend_cmd_supported and _backend_cmd_lock from
  jail_service.py
- Refactor _check_backend_cmd_supported() to accept state parameter
- Inject JailServiceState into list_jails() and _fetch_jail_summary() via
  parameters
- Add get_jail_service_state() dependency provider in app/dependencies.py
- Add JailServiceStateDep type alias for router injection
- Update jails router to receive and pass state to service functions
- Update all tests to use jail_service_state fixture and pass state to functions
- Remove duplicate _MAX_PAGE_SIZE constant definition
- Document mutable state management in Backend-Development.md
- Update Architecture.md to describe JailServiceState and state nesting pattern

Benefits:
- Eliminates global mutable state and associated race conditions
- Makes state visible to callers (not hidden in module scope)
- Enables test isolation (each test gets fresh state)
- Prepares codebase for multi-worker deployments (state can be extracted to
  shared backend)
- Synchronization boundaries are now explicit (state.get_backend_cmd_lock())

Compliance:
- All tests pass (17 passed in TestListJails, TestGetJail, TestLockInitialization)
- No ruff linting errors
- Type-safe: JailServiceState properly typed with asyncio.Lock, bool | None

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-27 18:42:52 +02:00

1154 lines
46 KiB
Python

"""Tests for jail_service functions."""
from __future__ import annotations
import asyncio
import contextlib
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from app.exceptions import Fail2BanConnectionError
from app.models.ban import ActiveBanListResponse, JailBannedIpsResponse
from app.models.geo import GeoDetail, GeoInfo
from app.models.jail import JailDetailResponse, JailListResponse
from app.services import ban_service, jail_service
from app.services.jail_service import JailNotFoundError, JailOperationError
from app.utils import jail_socket
from app.utils.runtime_state import JailServiceState
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_SOCKET = "/fake/fail2ban.sock"
_JAIL_NAMES = "sshd, nginx"
def _make_global_status(names: str = _JAIL_NAMES) -> tuple[int, list[Any]]:
return (0, [("Number of jail", 2), ("Jail list", names)])
def _make_short_status(
banned: int = 2,
total_banned: int = 10,
failed: int = 3,
total_failed: int = 20,
) -> tuple[int, list[Any]]:
return (
0,
[
("Filter", [("Currently failed", failed), ("Total failed", total_failed)]),
("Actions", [("Currently banned", banned), ("Total banned", total_banned)]),
],
)
def _make_send(responses: dict[str, Any]) -> AsyncMock:
"""Build an ``AsyncMock`` for ``Fail2BanClient.send``.
Responses are keyed by the command joined with a pipe, e.g.
``"status"`` or ``"status|sshd|short"``.
"""
async def _side_effect(command: list[Any]) -> Any:
key = "|".join(str(c) for c in command)
if key in responses:
return responses[key]
# Fall back to partial key matching.
for resp_key, resp_value in responses.items():
if key.startswith(resp_key):
return resp_value
raise KeyError(f"Unexpected command key {key!r}")
return AsyncMock(side_effect=_side_effect)
@contextlib.contextmanager
def _patch_client(responses: dict[str, Any]) -> Any:
    """Context manager that replaces every ``Fail2BanClient`` import site.

    All patched locations share one fake client class whose ``send`` is the
    mock built by ``_make_send(responses)``. The patches are applied only
    when the ``with`` block is entered and are undone when it exits, so
    calling this helper without ``with`` leaves nothing patched.

    Yields:
        The ``ExitStack`` holding the active patches.
    """
    mock_send = _make_send(responses)

    class _FakeClient:
        def __init__(self, **_kw: Any) -> None:
            self.send = mock_send

    targets = (
        "app.services.jail_service.Fail2BanClient",
        "app.services.ban_service.Fail2BanClient",
        "app.utils.jail_socket.Fail2BanClient",
    )
    with contextlib.ExitStack() as stack:
        for target in targets:
            stack.enter_context(patch(target, _FakeClient))
        yield stack
@pytest.fixture
def jail_service_state() -> JailServiceState:
    """Hand each test its own, untouched ``JailServiceState``.

    The fixture uses pytest's default function scope, so values such as the
    backend-capability cache or lock set by one test are never seen by another.
    """
    fresh_state = JailServiceState()
    return fresh_state
# ---------------------------------------------------------------------------
# list_jails
# ---------------------------------------------------------------------------
class TestListJails:
    """Unit tests for :func:`~app.services.jail_service.list_jails`."""

    async def test_returns_jail_list_response(self, jail_service_state: JailServiceState) -> None:
        """list_jails returns a JailListResponse."""
        responses = {
            "status": _make_global_status("sshd"),
            "status|sshd|short": _make_short_status(),
            "get|sshd|bantime": (0, 600),
            "get|sshd|findtime": (0, 600),
            "get|sshd|maxretry": (0, 5),
            "get|sshd|backend": (0, "polling"),
            "get|sshd|idle": (0, False),
        }
        with _patch_client(responses):
            result = await jail_service.list_jails(_SOCKET, jail_service_state)
        assert isinstance(result, JailListResponse)
        assert result.total == 1
        assert result.jails[0].name == "sshd"

    async def test_empty_jail_list(self, jail_service_state: JailServiceState) -> None:
        """list_jails returns empty response when no jails are active."""
        responses = {"status": (0, [("Number of jail", 0), ("Jail list", "")])}
        with _patch_client(responses):
            result = await jail_service.list_jails(_SOCKET, jail_service_state)
        assert result.total == 0
        assert result.jails == []

    async def test_jail_status_populated(self, jail_service_state: JailServiceState) -> None:
        """list_jails populates JailStatus with failed/banned counters."""
        responses = {
            "status": _make_global_status("sshd"),
            "status|sshd|short": _make_short_status(banned=5, total_banned=50),
            "get|sshd|bantime": (0, 600),
            "get|sshd|findtime": (0, 600),
            "get|sshd|maxretry": (0, 5),
            "get|sshd|backend": (0, "polling"),
            "get|sshd|idle": (0, False),
        }
        with _patch_client(responses):
            result = await jail_service.list_jails(_SOCKET, jail_service_state)
        jail = result.jails[0]
        assert jail.status is not None
        assert jail.status.currently_banned == 5
        assert jail.status.total_banned == 50

    async def test_jail_config_populated(self, jail_service_state: JailServiceState) -> None:
        """list_jails populates ban_time, find_time, max_retry, backend."""
        responses = {
            "status": _make_global_status("sshd"),
            "status|sshd|short": _make_short_status(),
            "get|sshd|bantime": (0, 3600),
            "get|sshd|findtime": (0, 300),
            "get|sshd|maxretry": (0, 3),
            "get|sshd|backend": (0, "systemd"),
            "get|sshd|idle": (0, True),
        }
        with _patch_client(responses):
            result = await jail_service.list_jails(_SOCKET, jail_service_state)
        jail = result.jails[0]
        assert jail.ban_time == 3600
        assert jail.find_time == 300
        assert jail.max_retry == 3
        assert jail.backend == "systemd"
        assert jail.idle is True

    async def test_multiple_jails_returned(self, jail_service_state: JailServiceState) -> None:
        """list_jails fetches all jails listed in the global status."""
        responses = {
            "status": _make_global_status("sshd, nginx"),
            "status|sshd|short": _make_short_status(),
            "status|nginx|short": _make_short_status(banned=0),
            "get|sshd|bantime": (0, 600),
            "get|sshd|findtime": (0, 600),
            "get|sshd|maxretry": (0, 5),
            "get|sshd|backend": (0, "polling"),
            "get|sshd|idle": (0, False),
            "get|nginx|bantime": (0, 1800),
            "get|nginx|findtime": (0, 600),
            "get|nginx|maxretry": (0, 5),
            "get|nginx|backend": (0, "polling"),
            "get|nginx|idle": (0, False),
        }
        with _patch_client(responses):
            result = await jail_service.list_jails(_SOCKET, jail_service_state)
        assert result.total == 2
        names = {j.name for j in result.jails}
        assert names == {"sshd", "nginx"}

    async def test_connection_error_propagates(self, jail_service_state: JailServiceState) -> None:
        """list_jails raises Fail2BanConnectionError when socket unreachable."""

        class _FailClient:
            def __init__(self, **_kw: Any) -> None:
                self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))

        with patch("app.services.jail_service.Fail2BanClient", _FailClient), pytest.raises(Fail2BanConnectionError):
            await jail_service.list_jails(_SOCKET, jail_service_state)

    async def test_backend_idle_commands_unsupported(self, jail_service_state: JailServiceState) -> None:
        """list_jails handles unsupported backend and idle commands gracefully.

        When the fail2ban daemon does not support get ... backend/idle commands,
        list_jails should not send them, avoiding "Invalid command" errors in the
        fail2ban log.
        """
        # The fixture is already fresh; resetting also exercises the reset API.
        await jail_service_state.reset_backend_capability_cache()
        responses = {
            "status": _make_global_status("sshd"),
            "status|sshd|short": _make_short_status(),
            # Capability probe: get backend fails (command not supported).
            "get|sshd|backend": (1, Exception("Invalid command (no get action or not yet implemented)")),
            # Subsequent gets should still work.
            "get|sshd|bantime": (0, 600),
            "get|sshd|findtime": (0, 600),
            "get|sshd|maxretry": (0, 5),
        }
        with _patch_client(responses):
            result = await jail_service.list_jails(_SOCKET, jail_service_state)
        # Verify the result uses the default values for backend and idle.
        jail = result.jails[0]
        assert jail.backend == "polling"  # default
        assert jail.idle is False  # default
        # Capability should now be cached as False.
        assert jail_service_state.backend_cmd_supported is False

    async def test_backend_idle_commands_supported(self, jail_service_state: JailServiceState) -> None:
        """list_jails detects and sends backend/idle commands when supported."""
        # The fixture is already fresh; resetting also exercises the reset API.
        await jail_service_state.reset_backend_capability_cache()
        responses = {
            "status": _make_global_status("sshd"),
            "status|sshd|short": _make_short_status(),
            # Capability probe: get backend succeeds.
            "get|sshd|backend": (0, "systemd"),
            # All other commands.
            "get|sshd|bantime": (0, 600),
            "get|sshd|findtime": (0, 600),
            "get|sshd|maxretry": (0, 5),
            "get|sshd|idle": (0, True),
        }
        with _patch_client(responses):
            result = await jail_service.list_jails(_SOCKET, jail_service_state)
        # Verify real values are returned.
        jail = result.jails[0]
        assert jail.backend == "systemd"  # real value
        assert jail.idle is True  # real value
        # Capability should now be cached as True.
        assert jail_service_state.backend_cmd_supported is True

    async def test_backend_idle_commands_cached_after_first_probe(self, jail_service_state: JailServiceState) -> None:
        """list_jails caches capability result and reuses it across polling cycles."""
        # The fixture is already fresh; resetting also exercises the reset API.
        await jail_service_state.reset_backend_capability_cache()
        responses = {
            "status": _make_global_status("sshd, nginx"),
            # Probes happen once per polling cycle (for the first jail listed).
            "status|sshd|short": _make_short_status(),
            "status|nginx|short": _make_short_status(),
            # Capability probe: backend is unsupported.
            "get|sshd|backend": (1, Exception("Invalid command")),
            # Subsequent jails do not trigger another probe; they use cached result.
            # (The mock doesn't have get|nginx|backend because it shouldn't be called.)
            "get|sshd|bantime": (0, 600),
            "get|sshd|findtime": (0, 600),
            "get|sshd|maxretry": (0, 5),
            "get|nginx|bantime": (0, 600),
            "get|nginx|findtime": (0, 600),
            "get|nginx|maxretry": (0, 5),
        }
        with _patch_client(responses):
            result = await jail_service.list_jails(_SOCKET, jail_service_state)
        # Both jails must be present, or the loop below could pass vacuously.
        assert result.total == 2
        # Both jails should return default values (cached result is False).
        for jail in result.jails:
            assert jail.backend == "polling"
            assert jail.idle is False
        # The single probe must have cached "unsupported" on the state object.
        assert jail_service_state.backend_cmd_supported is False
class TestLockInitialization:
    """Regression tests for lazy asyncio lock creation.

    Covers the module-level reload-all lock in ``jail_socket`` and the backend
    capability probe lock held by ``JailServiceState``.
    """

    async def test_reload_all_lock_is_lazy_initialised(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """The reload-all lock should be created lazily on first use."""
        # monkeypatch restores the module global after the test, so forcing it
        # to None here does not leak a lock (or a None) into other tests.
        monkeypatch.setattr(jail_socket, "_reload_all_lock", None)
        lock = jail_socket._get_reload_all_lock()
        assert isinstance(lock, asyncio.Lock)
        assert jail_socket._reload_all_lock is lock

    async def test_backend_cmd_lock_is_lazy_initialised(self, jail_service_state: JailServiceState) -> None:
        """The backend capability probe lock should be created lazily on first use."""
        # Ensure state starts with no lock.
        jail_service_state.backend_cmd_lock = None
        lock = jail_service_state.get_backend_cmd_lock()
        assert isinstance(lock, asyncio.Lock)
        assert jail_service_state.backend_cmd_lock is lock
class TestGetJail:
    """Unit tests for :func:`~app.services.jail_service.get_jail`.

    ``get_jail`` is called here without a ``JailServiceState``, so these tests
    do not request the ``jail_service_state`` fixture.
    """

    def _full_responses(self, name: str = "sshd") -> dict[str, Any]:
        """Return socket replies covering every query get_jail makes for *name*."""
        return {
            f"status|{name}|short": _make_short_status(),
            f"get|{name}|logpath": (0, ["/var/log/auth.log"]),
            f"get|{name}|failregex": (0, ["^.*Failed.*from <HOST>"]),
            f"get|{name}|ignoreregex": (0, []),
            f"get|{name}|ignoreip": (0, ["127.0.0.1"]),
            f"get|{name}|datepattern": (0, None),
            f"get|{name}|logencoding": (0, "UTF-8"),
            f"get|{name}|bantime": (0, 600),
            f"get|{name}|findtime": (0, 600),
            f"get|{name}|maxretry": (0, 5),
            f"get|{name}|backend": (0, "polling"),
            f"get|{name}|idle": (0, False),
            f"get|{name}|actions": (0, ["iptables-multiport"]),
        }

    async def test_returns_jail_detail_response(self) -> None:
        """get_jail returns a JailDetailResponse."""
        with _patch_client(self._full_responses()):
            result = await jail_service.get_jail(_SOCKET, "sshd")
        assert isinstance(result, JailDetailResponse)
        assert result.jail.name == "sshd"

    async def test_log_paths_parsed(self) -> None:
        """get_jail populates log_paths from fail2ban."""
        with _patch_client(self._full_responses()):
            result = await jail_service.get_jail(_SOCKET, "sshd")
        assert result.jail.log_paths == ["/var/log/auth.log"]

    async def test_fail_regex_parsed(self) -> None:
        """get_jail populates fail_regex list."""
        with _patch_client(self._full_responses()):
            result = await jail_service.get_jail(_SOCKET, "sshd")
        assert "^.*Failed.*from <HOST>" in result.jail.fail_regex

    async def test_ignore_ips_parsed(self) -> None:
        """get_jail populates ignore_ips list."""
        with _patch_client(self._full_responses()):
            result = await jail_service.get_jail(_SOCKET, "sshd")
        assert "127.0.0.1" in result.jail.ignore_ips

    async def test_actions_parsed(self) -> None:
        """get_jail populates actions list."""
        with _patch_client(self._full_responses()):
            result = await jail_service.get_jail(_SOCKET, "sshd")
        assert result.jail.actions == ["iptables-multiport"]

    async def test_jail_not_found_raises(self) -> None:
        """get_jail raises JailNotFoundError when jail is unknown."""
        not_found_response = (1, Exception("Unknown jail: 'ghost'"))
        with _patch_client({"status|ghost|short": not_found_response}), pytest.raises(JailNotFoundError):
            await jail_service.get_jail(_SOCKET, "ghost")
# ---------------------------------------------------------------------------
# Jail control commands
# ---------------------------------------------------------------------------
class TestJailControls:
    """Tests for the jail lifecycle commands: start, stop, idle, reload, restart."""

    @staticmethod
    async def _restart_daemon_with_default_command() -> Any:
        """Call restart_daemon with a fresh ``fail2ban-client start`` command list."""
        return await jail_service.restart_daemon(
            _SOCKET,
            ["fail2ban-client", "start"],
        )

    async def test_start_jail_success(self) -> None:
        """Starting a known jail completes without raising."""
        replies = {"start|sshd": (0, None)}
        with _patch_client(replies):
            await jail_service.start_jail(_SOCKET, "sshd")

    async def test_stop_jail_success(self) -> None:
        """Stopping a known jail completes without raising."""
        replies = {"stop|sshd": (0, None)}
        with _patch_client(replies):
            await jail_service.stop_jail(_SOCKET, "sshd")

    async def test_set_idle_on(self) -> None:
        """``on=True`` is translated to the ``idle on`` setter."""
        replies = {"set|sshd|idle|on": (0, True)}
        with _patch_client(replies):
            await jail_service.set_idle(_SOCKET, "sshd", on=True)

    async def test_set_idle_off(self) -> None:
        """``on=False`` is translated to the ``idle off`` setter."""
        replies = {"set|sshd|idle|off": (0, True)}
        with _patch_client(replies):
            await jail_service.set_idle(_SOCKET, "sshd", on=False)

    async def test_reload_jail_success(self) -> None:
        """A single-jail reload carries a stream that only restarts that jail."""
        replies = {"reload|sshd|[]|[['start', 'sshd']]": (0, "OK")}
        with _patch_client(replies):
            await jail_service.reload_jail(_SOCKET, "sshd")

    async def test_reload_all_success(self) -> None:
        """reload_all reads the current jails and restarts each in its stream."""
        replies = {
            "status": _make_global_status("sshd, nginx"),
            "reload|--all|[]|[['start', 'nginx'], ['start', 'sshd']]": (0, "OK"),
        }
        with _patch_client(replies):
            await jail_service.reload_all(_SOCKET)

    async def test_reload_all_no_jails_still_sends_reload(self) -> None:
        """With no active jails, reload_all still issues a reload (empty stream)."""
        replies = {
            "status": (0, [("Number of jail", 0), ("Jail list", "")]),
            "reload|--all|[]|[]": (0, "OK"),
        }
        with _patch_client(replies):
            await jail_service.reload_all(_SOCKET)

    async def test_reload_all_include_jails(self) -> None:
        """Jails passed via include_jails are added to the reload stream."""
        replies = {
            "status": _make_global_status("sshd, nginx"),
            "reload|--all|[]|[['start', 'apache-auth'], ['start', 'nginx'], ['start', 'sshd']]": (0, "OK"),
        }
        with _patch_client(replies):
            await jail_service.reload_all(_SOCKET, include_jails=["apache-auth"])

    async def test_reload_all_exclude_jails(self) -> None:
        """Jails passed via exclude_jails are dropped from the reload stream."""
        replies = {
            "status": _make_global_status("sshd, nginx"),
            "reload|--all|[]|[['start', 'nginx']]": (0, "OK"),
        }
        with _patch_client(replies):
            await jail_service.reload_all(_SOCKET, exclude_jails=["sshd"])

    async def test_reload_all_include_and_exclude(self) -> None:
        """Inclusions and exclusions are both honoured in the same call."""
        replies = {
            "status": _make_global_status("old, nginx"),
            "reload|--all|[]|[['start', 'new'], ['start', 'nginx']]": (0, "OK"),
        }
        with _patch_client(replies):
            await jail_service.reload_all(_SOCKET, include_jails=["new"], exclude_jails=["old"])

    async def test_reload_all_unknown_jail_raises_jail_not_found(self) -> None:
        """An UnknownJailException during reload surfaces as JailNotFoundError.

        fail2ban reports UnknownJailException when it cannot load a jail, for
        example because of a bad configuration such as a missing logpath.
        reload_all must map that to JailNotFoundError (naming the jail) rather
        than the generic JailOperationError.
        """
        replies = {
            "status": _make_global_status("sshd"),
            "reload|--all|[]|[['start', 'airsonic-auth'], ['start', 'sshd']]": (
                1,
                Exception("UnknownJailException('airsonic-auth')"),
            ),
        }
        with (
            _patch_client(replies),
            pytest.raises(jail_service.JailNotFoundError) as exc_info,
        ):
            await jail_service.reload_all(_SOCKET, include_jails=["airsonic-auth"])
        assert exc_info.value.name == "airsonic-auth"

    async def test_restart_sends_stop_command(self) -> None:
        """restart() talks to the socket with the bare ``stop`` command."""
        replies = {"stop": (0, None)}
        with _patch_client(replies):
            await jail_service.restart(_SOCKET)

    async def test_restart_operation_error_raises(self) -> None:
        """A rejected ``stop`` is reported as JailOperationError."""
        replies = {"stop": (1, Exception("cannot stop"))}
        with _patch_client(replies), pytest.raises(JailOperationError):
            await jail_service.restart(_SOCKET)

    async def test_restart_connection_error_propagates(self) -> None:
        """An unreachable socket surfaces from restart() as Fail2BanConnectionError."""

        class _FailClient:
            def __init__(self, **_kw: Any) -> None:
                self.send = AsyncMock(side_effect=Fail2BanConnectionError("no socket", _SOCKET))

        with (
            patch("app.services.jail_service.Fail2BanClient", _FailClient),
            pytest.raises(Fail2BanConnectionError),
        ):
            await jail_service.restart(_SOCKET)

    async def test_restart_daemon_returns_true_on_success(self) -> None:
        """restart_daemon reports True once stop, start and the readiness probe succeed."""
        with (
            patch("app.services.jail_service.restart", AsyncMock(return_value=None)),
            patch("app.services.jail_service.start_daemon", AsyncMock(return_value=True)),
            patch("app.services.jail_service.wait_for_fail2ban", AsyncMock(return_value=True)),
        ):
            outcome = await self._restart_daemon_with_default_command()
        assert outcome is True

    async def test_restart_daemon_returns_false_when_start_fails(self) -> None:
        """restart_daemon reports False when the start command fails."""
        with (
            patch("app.services.jail_service.restart", AsyncMock(return_value=None)),
            patch("app.services.jail_service.start_daemon", AsyncMock(return_value=False)),
        ):
            outcome = await self._restart_daemon_with_default_command()
        assert outcome is False

    async def test_restart_daemon_returns_false_when_wait_fails(self) -> None:
        """restart_daemon reports False when fail2ban never becomes responsive."""
        with (
            patch("app.services.jail_service.restart", AsyncMock(return_value=None)),
            patch("app.services.jail_service.start_daemon", AsyncMock(return_value=True)),
            patch("app.services.jail_service.wait_for_fail2ban", AsyncMock(return_value=False)),
        ):
            outcome = await self._restart_daemon_with_default_command()
        assert outcome is False

    async def test_start_not_found_raises(self) -> None:
        """Starting an unknown jail raises JailNotFoundError."""
        replies = {"start|ghost": (1, Exception("Unknown jail: 'ghost'"))}
        with _patch_client(replies), pytest.raises(JailNotFoundError):
            await jail_service.start_jail(_SOCKET, "ghost")

    async def test_stop_jail_already_stopped_is_noop(self) -> None:
        """Stopping a jail fail2ban no longer knows about is tolerated (idempotent)."""
        replies = {"stop|sshd": (1, Exception("UnknownJailException('sshd')"))}
        with _patch_client(replies):
            await jail_service.stop_jail(_SOCKET, "sshd")

    async def test_stop_operation_error_raises(self) -> None:
        """Any other stop failure is reported as JailOperationError."""
        replies = {"stop|sshd": (1, Exception("cannot stop"))}
        with _patch_client(replies), pytest.raises(JailOperationError):
            await jail_service.stop_jail(_SOCKET, "sshd")
# ---------------------------------------------------------------------------
# ban_ip / unban_ip
# ---------------------------------------------------------------------------
class TestBanUnban:
    """Tests for :func:`~app.services.ban_service.ban_ip` and
    :func:`~app.services.ban_service.unban_ip`.
    """

    async def test_ban_ip_success(self) -> None:
        """Banning a well-formed IPv4 address issues ``set <jail> banip``."""
        replies = {"set|sshd|banip|1.2.3.4": (0, 1)}
        with _patch_client(replies):
            await ban_service.ban_ip(_SOCKET, "sshd", "1.2.3.4")

    async def test_ban_ip_invalid_raises(self) -> None:
        """A value that is not an IP address is rejected with ValueError."""
        with pytest.raises(ValueError, match="Invalid IP"):
            await ban_service.ban_ip(_SOCKET, "sshd", "not-an-ip")

    async def test_ban_ip_unknown_jail_exception_raises_jail_not_found(self) -> None:
        """An UnknownJailException reply is mapped to JailNotFoundError.

        fail2ban serialises the exception name without a space
        (``UnknownJailException``, not ``Unknown JailException``), so
        _is_not_found_error has to recognise the concatenated form
        ``"unknownjail"``.
        """
        unknown_jail_reply = (1, Exception("UnknownJailException('blocklist-import')"))
        replies = {"set|missing-jail|banip|1.2.3.4": unknown_jail_reply}
        with (
            _patch_client(replies),
            pytest.raises(JailNotFoundError, match="missing-jail"),
        ):
            await ban_service.ban_ip(_SOCKET, "missing-jail", "1.2.3.4")

    async def test_ban_ipv6_success(self) -> None:
        """IPv6 addresses are accepted for banning."""
        replies = {"set|sshd|banip|::1": (0, 1)}
        with _patch_client(replies):
            await ban_service.ban_ip(_SOCKET, "sshd", "::1")

    async def test_unban_ip_all_jails(self) -> None:
        """Without a jail, unban_ip falls back to the global ``unban`` command."""
        replies = {"unban|1.2.3.4": (0, 1)}
        with _patch_client(replies):
            await ban_service.unban_ip(_SOCKET, "1.2.3.4")

    async def test_unban_ip_specific_jail(self) -> None:
        """With a jail, unban_ip uses ``set <jail> unbanip``."""
        replies = {"set|sshd|unbanip|1.2.3.4": (0, 1)}
        with _patch_client(replies):
            await ban_service.unban_ip(_SOCKET, "1.2.3.4", jail="sshd")

    async def test_unban_invalid_ip_raises(self) -> None:
        """Unbanning something that is not an IP address raises ValueError."""
        with pytest.raises(ValueError, match="Invalid IP"):
            await ban_service.unban_ip(_SOCKET, "bad-ip")
# ---------------------------------------------------------------------------
# get_active_bans
# ---------------------------------------------------------------------------
class TestGetActiveBans:
    """Unit tests for :func:`~app.services.ban_service.get_active_bans`."""

    async def test_returns_active_ban_list_response(self) -> None:
        """get_active_bans returns an ActiveBanListResponse."""
        responses = {
            "status": _make_global_status("sshd"),
            "get|sshd|banip|--with-time": (
                0,
                ["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
            ),
        }
        with _patch_client(responses):
            result = await ban_service.get_active_bans(_SOCKET)
        assert isinstance(result, ActiveBanListResponse)
        assert result.total == 1
        assert result.bans[0].ip == "1.2.3.4"
        assert result.bans[0].jail == "sshd"

    async def test_empty_when_no_jails(self) -> None:
        """get_active_bans returns empty list when no jails are active."""
        responses = {"status": (0, [("Number of jail", 0), ("Jail list", "")])}
        with _patch_client(responses):
            result = await ban_service.get_active_bans(_SOCKET)
        assert result.total == 0
        assert result.bans == []

    async def test_empty_when_no_bans(self) -> None:
        """get_active_bans returns empty list when all jails have zero bans."""
        responses = {
            "status": _make_global_status("sshd"),
            "get|sshd|banip|--with-time": (0, []),
        }
        with _patch_client(responses):
            result = await ban_service.get_active_bans(_SOCKET)
        assert result.total == 0

    async def test_ban_time_parsed(self) -> None:
        """get_active_bans populates banned_at and expires_at from the entry."""
        responses = {
            "status": _make_global_status("sshd"),
            "get|sshd|banip|--with-time": (
                0,
                ["10.0.0.1 \t2025-03-01 08:00:00 + 7200 = 2025-03-01 10:00:00"],
            ),
        }
        with _patch_client(responses):
            result = await ban_service.get_active_bans(_SOCKET)
        ban = result.bans[0]
        assert ban.banned_at is not None
        assert "2025-03-01" in ban.banned_at
        assert ban.expires_at is not None
        assert "2025-03-01" in ban.expires_at

    async def test_error_in_jail_tolerated(self) -> None:
        """get_active_bans skips a jail that errors during the ban-list fetch."""
        responses = {
            "status": _make_global_status("sshd, nginx"),
            "get|sshd|banip|--with-time": (
                0,
                ["1.2.3.4 \t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"],
            ),
            "get|nginx|banip|--with-time": Fail2BanConnectionError("no nginx", _SOCKET),
        }

        async def _side(*args: Any) -> Any:
            # Exact-match only; exception values are raised instead of returned.
            key = "|".join(str(a) for a in args[0])
            resp = responses.get(key)
            if isinstance(resp, Exception):
                raise resp
            if resp is None:
                raise KeyError(f"Unexpected key {key!r}")
            return resp

        class _FakeClientPartial:
            def __init__(self, **_kw: Any) -> None:
                self.send = AsyncMock(side_effect=_side)

        with patch("app.services.ban_service.Fail2BanClient", _FakeClientPartial):
            result = await ban_service.get_active_bans(_SOCKET)
        # Only sshd ban returned (nginx silently skipped)
        assert result.total == 1
        assert result.bans[0].jail == "sshd"

    async def test_http_session_triggers_lookup_batch(self) -> None:
        """With http_session, the injected geo_batch_lookup enriches the bans."""
        responses = {
            "status": _make_global_status("sshd"),
            "get|sshd|banip|--with-time": (
                0,
                ["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
            ),
        }
        mock_geo = {"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS1", org="ISP")}
        mock_batch = AsyncMock(return_value=mock_geo)
        with _patch_client(responses):
            mock_session = AsyncMock()
            result = await ban_service.get_active_bans(
                _SOCKET,
                http_session=mock_session,
                geo_batch_lookup=mock_batch,
            )
        mock_batch.assert_awaited_once()
        assert result.total == 1
        assert result.bans[0].country == "DE"

    async def test_http_session_batch_failure_graceful(self) -> None:
        """When the injected batch lookup raises, bans are returned without geo."""
        responses = {
            "status": _make_global_status("sshd"),
            "get|sshd|banip|--with-time": (
                0,
                ["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
            ),
        }
        failing_batch = AsyncMock(side_effect=RuntimeError("geo down"))
        with _patch_client(responses):
            mock_session = AsyncMock()
            result = await ban_service.get_active_bans(
                _SOCKET,
                http_session=mock_session,
                geo_batch_lookup=failing_batch,
            )
        assert result.total == 1
        assert result.bans[0].country is None

    async def test_geo_enricher_still_used_without_http_session(self) -> None:
        """Legacy geo_enricher is still called when http_session is not provided."""
        responses = {
            "status": _make_global_status("sshd"),
            "get|sshd|banip|--with-time": (
                0,
                ["1.2.3.4 \t2025-01-01 12:00:00 + 3600 = 2025-01-01 13:00:00"],
            ),
        }

        async def _enricher(ip: str) -> GeoInfo | None:
            return GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)

        with _patch_client(responses):
            result = await ban_service.get_active_bans(_SOCKET, geo_enricher=_enricher)
        assert result.total == 1
        assert result.bans[0].country == "JP"
# ---------------------------------------------------------------------------
# Ignore list
# ---------------------------------------------------------------------------
class TestIgnoreList:
    """Tests for the per-jail ignore list and ``ignoreself`` helpers."""

    async def test_get_ignore_list(self) -> None:
        """The ignore list comes back as plain IP/CIDR strings."""
        replies = {"get|sshd|ignoreip": (0, ["127.0.0.1", "10.0.0.0/8"])}
        with _patch_client(replies):
            entries = await jail_service.get_ignore_list(_SOCKET, "sshd")
        for expected in ("127.0.0.1", "10.0.0.0/8"):
            assert expected in entries

    async def test_add_ignore_ip(self) -> None:
        """A valid CIDR is forwarded with ``addignoreip``."""
        replies = {"set|sshd|addignoreip|192.168.0.0/24": (0, "OK")}
        with _patch_client(replies):
            await jail_service.add_ignore_ip(_SOCKET, "sshd", "192.168.0.0/24")

    async def test_add_ignore_ip_invalid_raises(self) -> None:
        """An unparsable network string is rejected with ValueError."""
        with pytest.raises(ValueError, match="Invalid IP"):
            await jail_service.add_ignore_ip(_SOCKET, "sshd", "not-a-cidr")

    async def test_del_ignore_ip(self) -> None:
        """Removing an entry is forwarded with ``delignoreip``."""
        replies = {"set|sshd|delignoreip|127.0.0.1": (0, "OK")}
        with _patch_client(replies):
            await jail_service.del_ignore_ip(_SOCKET, "sshd", "127.0.0.1")

    async def test_get_ignore_self(self) -> None:
        """The ``ignoreself`` flag is returned as a real boolean."""
        replies = {"get|sshd|ignoreself": (0, True)}
        with _patch_client(replies):
            flag = await jail_service.get_ignore_self(_SOCKET, "sshd")
        assert flag is True

    async def test_set_ignore_self_on(self) -> None:
        """Enabling ``ignoreself`` sends the literal ``true``."""
        replies = {"set|sshd|ignoreself|true": (0, True)}
        with _patch_client(replies):
            await jail_service.set_ignore_self(_SOCKET, "sshd", on=True)
# ---------------------------------------------------------------------------
# lookup_ip
# ---------------------------------------------------------------------------
class TestLookupIp:
    """Tests for :func:`~app.services.jail_service.lookup_ip`."""

    @staticmethod
    def _lookup_responses(ip: str, sshd_bans: list[str]) -> dict[str, Any]:
        """Build replies for lookup_ip on *ip* with a single ``sshd`` jail.

        ``get --all banned <ip>`` answers with an empty list, and ``sshd``
        reports *sshd_bans* as its banned addresses.
        """
        return {
            f"get|--all|banned|{ip}": (0, []),
            "status": _make_global_status("sshd"),
            "get|sshd|banip": (0, sshd_bans),
        }

    async def test_basic_lookup(self) -> None:
        """A banned IP lists the jail that holds it in currently_banned_in."""
        replies = self._lookup_responses("1.2.3.4", ["1.2.3.4", "5.6.7.8"])
        with _patch_client(replies):
            outcome = await jail_service.lookup_ip(_SOCKET, "1.2.3.4")
        assert outcome["ip"] == "1.2.3.4"
        assert "sshd" in outcome["currently_banned_in"]

    async def test_geo_enricher_returns_geo_detail(self) -> None:
        """A GeoInfo produced by geo_enricher is exposed as a GeoDetail."""
        replies = self._lookup_responses("1.2.3.4", ["1.2.3.4", "5.6.7.8"])

        async def _enricher(ip: str) -> GeoInfo:
            return GeoInfo(country_code="DE", country_name="Germany", asn="AS123", org="Acme")

        with _patch_client(replies):
            outcome = await jail_service.lookup_ip(_SOCKET, "1.2.3.4", geo_enricher=_enricher)
        geo = outcome["geo"]
        assert isinstance(geo, GeoDetail)
        assert geo.country_code == "DE"
        assert geo.country_name == "Germany"
        assert geo.asn == "AS123"
        assert geo.org == "Acme"

    async def test_http_session_uses_geo_service_lookup(self) -> None:
        """Given http_session, lookup_ip delegates to geo_service.lookup."""
        replies = self._lookup_responses("1.2.3.4", ["1.2.3.4", "5.6.7.8"])
        japan = GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None)
        session = AsyncMock()
        with (
            _patch_client(replies),
            patch(
                "app.services.jail_service.geo_service.lookup",
                AsyncMock(return_value=japan),
            ) as mock_lookup,
        ):
            outcome = await jail_service.lookup_ip(_SOCKET, "1.2.3.4", http_session=session)
        mock_lookup.assert_awaited_once_with("1.2.3.4", session)
        geo = outcome["geo"]
        assert isinstance(geo, GeoDetail)
        assert geo.country_code == "JP"
        assert geo.country_name == "Japan"
        assert geo.asn is None
        assert geo.org is None

    async def test_invalid_ip_raises(self) -> None:
        """A malformed address is rejected with ValueError."""
        with pytest.raises(ValueError, match="Invalid IP"):
            await jail_service.lookup_ip(_SOCKET, "not-an-ip")

    async def test_not_banned_returns_empty_list(self) -> None:
        """An IP absent from every jail yields an empty currently_banned_in."""
        replies = self._lookup_responses("9.9.9.9", ["1.2.3.4"])
        with _patch_client(replies):
            outcome = await jail_service.lookup_ip(_SOCKET, "9.9.9.9")
        assert outcome["currently_banned_in"] == []
# ---------------------------------------------------------------------------
# unban_all_ips
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
class TestUnbanAllIps:
    """Behavioural tests for :func:`~app.services.jail_service.unban_all_ips`."""

    async def test_unban_all_ips_returns_count(self) -> None:
        """The integer reported by fail2ban is handed back unchanged."""
        with _patch_client({"unban|--all": (0, 5)}):
            unbanned = await jail_service.unban_all_ips(_SOCKET)
        assert unbanned == 5

    async def test_unban_all_ips_returns_zero_when_none_banned(self) -> None:
        """A zero count from fail2ban (nothing banned) is returned as 0."""
        with _patch_client({"unban|--all": (0, 0)}):
            unbanned = await jail_service.unban_all_ips(_SOCKET)
        assert unbanned == 0

    async def test_unban_all_ips_raises_on_connection_error(self) -> None:
        """A Fail2BanConnectionError raised by the client is not swallowed."""
        unreachable = Fail2BanConnectionError("unreachable", _SOCKET)
        with (
            patch("app.services.jail_service.Fail2BanClient", side_effect=unreachable),
            pytest.raises(Fail2BanConnectionError),
        ):
            await jail_service.unban_all_ips(_SOCKET)
# ---------------------------------------------------------------------------
# get_jail_banned_ips
# ---------------------------------------------------------------------------
#: A raw ban entry string in the format produced by fail2ban --with-time:
#: ``<ip>\t<start> + <n> = <end>``.  The timestamps are consistent with ``n``
#: being a duration in seconds (10:00:00 + 600 = 10:10:00).
_BAN_ENTRY_1 = "1.2.3.4\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"
#: Second sample entry, same format, banned one hour after ``_BAN_ENTRY_1``.
_BAN_ENTRY_2 = "5.6.7.8\t2025-01-01 11:00:00 + 600 = 2025-01-01 11:10:00"
#: Third sample entry; used by the tests that need a second page of results.
_BAN_ENTRY_3 = "9.10.11.12\t2025-01-01 12:00:00 + 600 = 2025-01-01 12:10:00"
def _banned_ips_responses(jail: str = "sshd", entries: list[str] | None = None) -> dict[str, Any]:
    """Return the mock fail2ban replies that get_jail_banned_ips consumes.

    Args:
        jail: Name of the jail embedded in both command keys.
        entries: Raw ``--with-time`` ban strings; when omitted, the first two
            sample entries (``_BAN_ENTRY_1`` and ``_BAN_ENTRY_2``) are used.

    Returns:
        A mapping with the short-status reply and the banned-IP list reply.
    """
    ban_list = [_BAN_ENTRY_1, _BAN_ENTRY_2] if entries is None else entries
    status_key = f"status|{jail}|short"
    banip_key = f"get|{jail}|banip|--with-time"
    return {status_key: _make_short_status(), banip_key: (0, ban_list)}
class TestGetJailBannedIps:
    """Unit tests for :func:`~app.services.jail_service.get_jail_banned_ips`."""

    async def test_returns_jail_banned_ips_response(self) -> None:
        """get_jail_banned_ips returns a JailBannedIpsResponse."""
        with _patch_client(_banned_ips_responses()):
            result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
        assert isinstance(result, JailBannedIpsResponse)

    async def test_total_reflects_all_entries(self) -> None:
        """total equals the number of parsed ban entries."""
        with _patch_client(_banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])):
            result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
        assert result.total == 3

    async def test_page_1_returns_first_n_items(self) -> None:
        """page=1 with page_size=2 returns the first two entries."""
        with _patch_client(
            _banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
        ):
            result = await jail_service.get_jail_banned_ips(
                _SOCKET, "sshd", page=1, page_size=2
            )
        assert len(result.items) == 2
        assert result.items[0].ip == "1.2.3.4"
        assert result.items[1].ip == "5.6.7.8"
        assert result.total == 3

    async def test_page_2_returns_remaining_items(self) -> None:
        """page=2 with page_size=2 returns the third entry."""
        with _patch_client(
            _banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
        ):
            result = await jail_service.get_jail_banned_ips(
                _SOCKET, "sshd", page=2, page_size=2
            )
        assert len(result.items) == 1
        assert result.items[0].ip == "9.10.11.12"

    async def test_page_beyond_last_returns_empty_items(self) -> None:
        """Requesting a page past the end returns an empty items list."""
        with _patch_client(_banned_ips_responses()):
            result = await jail_service.get_jail_banned_ips(
                _SOCKET, "sshd", page=99, page_size=25
            )
        assert result.items == []
        assert result.total == 2

    async def test_search_filter_narrows_results(self) -> None:
        """search parameter filters entries by IP substring."""
        with _patch_client(_banned_ips_responses()):
            result = await jail_service.get_jail_banned_ips(
                _SOCKET, "sshd", search="1.2.3"
            )
        assert result.total == 1
        assert result.items[0].ip == "1.2.3.4"

    async def test_search_filter_case_insensitive(self) -> None:
        """search filter is case-insensitive."""
        # Dotted IPv4 digits have no letter case, so an IPv4-only fixture could
        # never detect a case-sensitive match. Use an IPv6 address containing
        # hex letters (lowercase) and search with the opposite (upper) case.
        entries = ["2001:db8::1\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"]
        with _patch_client(_banned_ips_responses(entries=entries)):
            result = await jail_service.get_jail_banned_ips(
                _SOCKET, "sshd", search="2001:DB8"
            )
        assert result.total == 1
        assert result.items[0].ip == "2001:db8::1"

    async def test_search_no_match_returns_empty(self) -> None:
        """search that matches nothing returns empty items and total=0."""
        with _patch_client(_banned_ips_responses()):
            result = await jail_service.get_jail_banned_ips(
                _SOCKET, "sshd", search="999.999"
            )
        assert result.total == 0
        assert result.items == []

    async def test_empty_ban_list_returns_total_zero(self) -> None:
        """get_jail_banned_ips handles an empty ban list gracefully."""
        responses = {
            "status|sshd|short": _make_short_status(),
            "get|sshd|banip|--with-time": (0, []),
        }
        with _patch_client(responses):
            result = await jail_service.get_jail_banned_ips(_SOCKET, "sshd")
        assert result.total == 0
        assert result.items == []

    async def test_page_size_clamped_to_max(self) -> None:
        """page_size values above 100 are silently clamped to 100."""
        # More entries than the cap are required: with exactly 100 entries an
        # *unclamped* page_size=200 would also return 100 items, so the test
        # could not distinguish clamping from "fewer entries than requested".
        entries = [
            f"10.0.0.{i}\t2025-01-01 10:00:00 + 600 = 2025-01-01 10:10:00"
            for i in range(1, 151)
        ]
        responses = {
            "status|sshd|short": _make_short_status(),
            "get|sshd|banip|--with-time": (0, entries),
        }
        with _patch_client(responses):
            result = await jail_service.get_jail_banned_ips(
                _SOCKET, "sshd", page=1, page_size=200
            )
        assert len(result.items) == 100
        assert result.total == 150
        # The clamped first page is the first 100 entries, in input order.
        assert result.items[0].ip == "10.0.0.1"
        assert result.items[-1].ip == "10.0.0.100"

    async def test_geo_enrichment_called_for_page_slice_only(self) -> None:
        """Geo enrichment is requested only for IPs in the current page."""
        from unittest.mock import MagicMock

        from app.services import geo_service

        http_session = MagicMock()
        geo_enrichment_ips: list[list[str]] = []

        async def _mock_lookup_batch(
            ips: list[str], _session: Any, **_kw: Any
        ) -> dict[str, Any]:
            # Record every batch so the test can inspect exactly what was asked for.
            geo_enrichment_ips.append(list(ips))
            return {}

        with (
            _patch_client(
                _banned_ips_responses(entries=[_BAN_ENTRY_1, _BAN_ENTRY_2, _BAN_ENTRY_3])
            ),
            patch.object(geo_service, "lookup_batch", side_effect=_mock_lookup_batch),
        ):
            result = await jail_service.get_jail_banned_ips(
                _SOCKET,
                "sshd",
                page=1,
                page_size=2,
                http_session=http_session,
                # Read inside the ``with`` so the patched callable is passed.
                geo_batch_lookup=geo_service.lookup_batch,
            )

        # Only the 2-IP page slice should be passed to geo enrichment, and it
        # must be exactly the IPs on page 1 (order-insensitive).
        assert len(geo_enrichment_ips) == 1
        assert len(geo_enrichment_ips[0]) == 2
        assert sorted(geo_enrichment_ips[0]) == ["1.2.3.4", "5.6.7.8"]
        assert result.total == 3

    async def test_unknown_jail_raises_jail_not_found_error(self) -> None:
        """get_jail_banned_ips raises JailNotFoundError for unknown jail."""

        # Simulate fail2ban returning an "unknown jail" error.
        class _FakeClient:
            def __init__(self, **_kw: Any) -> None:
                pass

            async def send(self, command: list[Any]) -> Any:
                raise ValueError("Unknown jail: ghost")

        with (
            patch("app.services.jail_service.Fail2BanClient", _FakeClient),
            pytest.raises(JailNotFoundError),
        ):
            await jail_service.get_jail_banned_ips(_SOCKET, "ghost")

    async def test_connection_error_propagates(self) -> None:
        """get_jail_banned_ips propagates Fail2BanConnectionError."""

        class _FailClient:
            def __init__(self, **_kw: Any) -> None:
                self.send = AsyncMock(
                    side_effect=Fail2BanConnectionError("no socket", _SOCKET)
                )

        with (
            patch("app.services.jail_service.Fail2BanClient", _FailClient),
            pytest.raises(Fail2BanConnectionError),
        ):
            await jail_service.get_jail_banned_ips(_SOCKET, "sshd")