"""Tests for config_service functions.""" from __future__ import annotations from typing import Any from unittest.mock import AsyncMock, patch import pytest from app.models.config import ( GlobalConfigUpdate, JailConfigListResponse, JailConfigResponse, LogPreviewRequest, RegexTestRequest, ) from app.services import config_service from app.services.config_service import ( ConfigValidationError, JailNotFoundError, ) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- _SOCKET = "/fake/fail2ban.sock" def _make_global_status(names: str = "sshd") -> tuple[int, list[Any]]: return (0, [("Number of jail", 1), ("Jail list", names)]) def _make_short_status() -> tuple[int, list[Any]]: return ( 0, [ ("Filter", [("Currently failed", 3), ("Total failed", 20)]), ("Actions", [("Currently banned", 2), ("Total banned", 10)]), ], ) def _make_send(responses: dict[str, Any]) -> AsyncMock: async def _side_effect(command: list[Any]) -> Any: key = "|".join(str(c) for c in command) if key in responses: return responses[key] for resp_key, resp_value in responses.items(): if key.startswith(resp_key): return resp_value return (0, None) return AsyncMock(side_effect=_side_effect) def _patch_client(responses: dict[str, Any]) -> Any: mock_send = _make_send(responses) class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = mock_send return patch("app.services.config_service.Fail2BanClient", _FakeClient) _DEFAULT_JAIL_RESPONSES: dict[str, Any] = { "status|sshd|short": _make_short_status(), "get|sshd|bantime": (0, 600), "get|sshd|findtime": (0, 600), "get|sshd|maxretry": (0, 5), "get|sshd|failregex": (0, ["regex1", "regex2"]), "get|sshd|ignoreregex": (0, []), "get|sshd|logpath": (0, ["/var/log/auth.log"]), "get|sshd|datepattern": (0, None), "get|sshd|logencoding": (0, "UTF-8"), "get|sshd|backend": (0, "polling"), "get|sshd|usedns": (0, "warn"), "get|sshd|prefregex": (0, 
""), "get|sshd|actions": (0, ["iptables"]), } # --------------------------------------------------------------------------- # get_jail_config # --------------------------------------------------------------------------- class TestGetJailConfig: """Unit tests for :func:`~app.services.config_service.get_jail_config`.""" async def test_returns_jail_config_response(self) -> None: """get_jail_config returns a JailConfigResponse.""" with _patch_client(_DEFAULT_JAIL_RESPONSES): result = await config_service.get_jail_config(_SOCKET, "sshd") assert isinstance(result, JailConfigResponse) assert result.jail.name == "sshd" assert result.jail.ban_time == 600 assert result.jail.max_retry == 5 assert result.jail.fail_regex == ["regex1", "regex2"] assert result.jail.log_paths == ["/var/log/auth.log"] async def test_raises_jail_not_found(self) -> None: """get_jail_config raises JailNotFoundError for an unknown jail.""" async def _send(command: list[Any]) -> Any: raise Exception("Unknown jail 'missing'") class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) # Patch the client to raise on status command. 
async def _faulty_send(command: list[Any]) -> Any: if command[0] == "status": return (1, "unknown jail 'missing'") return (0, None) with patch( "app.services.config_service.Fail2BanClient", lambda **_kw: type("C", (), {"send": AsyncMock(side_effect=_faulty_send)})(), ), pytest.raises(JailNotFoundError): await config_service.get_jail_config(_SOCKET, "missing") async def test_actions_parsed_correctly(self) -> None: """get_jail_config includes actions list.""" with _patch_client(_DEFAULT_JAIL_RESPONSES): result = await config_service.get_jail_config(_SOCKET, "sshd") assert "iptables" in result.jail.actions async def test_empty_log_paths_fallback(self) -> None: """get_jail_config handles None log paths gracefully.""" responses = {**_DEFAULT_JAIL_RESPONSES, "get|sshd|logpath": (0, None)} with _patch_client(responses): result = await config_service.get_jail_config(_SOCKET, "sshd") assert result.jail.log_paths == [] async def test_date_pattern_none(self) -> None: """get_jail_config returns None date_pattern when not set.""" with _patch_client(_DEFAULT_JAIL_RESPONSES): result = await config_service.get_jail_config(_SOCKET, "sshd") assert result.jail.date_pattern is None async def test_use_dns_populated(self) -> None: """get_jail_config returns use_dns from the socket response.""" responses = {**_DEFAULT_JAIL_RESPONSES, "get|sshd|usedns": (0, "no")} with _patch_client(responses): result = await config_service.get_jail_config(_SOCKET, "sshd") assert result.jail.use_dns == "no" async def test_use_dns_default_when_missing(self) -> None: """get_jail_config defaults use_dns to 'warn' when socket returns None.""" responses = {**_DEFAULT_JAIL_RESPONSES, "get|sshd|usedns": (0, None)} with _patch_client(responses): result = await config_service.get_jail_config(_SOCKET, "sshd") assert result.jail.use_dns == "warn" async def test_prefregex_populated(self) -> None: """get_jail_config returns prefregex from the socket response.""" responses = { **_DEFAULT_JAIL_RESPONSES, 
"get|sshd|prefregex": (0, r"^%(__prefix_line)s"), } with _patch_client(responses): result = await config_service.get_jail_config(_SOCKET, "sshd") assert result.jail.prefregex == r"^%(__prefix_line)s" async def test_prefregex_empty_when_missing(self) -> None: """get_jail_config returns empty string prefregex when socket returns None.""" responses = {**_DEFAULT_JAIL_RESPONSES, "get|sshd|prefregex": (0, None)} with _patch_client(responses): result = await config_service.get_jail_config(_SOCKET, "sshd") assert result.jail.prefregex == "" # --------------------------------------------------------------------------- # list_jail_configs # --------------------------------------------------------------------------- class TestListJailConfigs: """Unit tests for :func:`~app.services.config_service.list_jail_configs`.""" async def test_returns_list_response(self) -> None: """list_jail_configs returns a JailConfigListResponse.""" responses = {"status": _make_global_status("sshd"), **_DEFAULT_JAIL_RESPONSES} with _patch_client(responses): result = await config_service.list_jail_configs(_SOCKET) assert isinstance(result, JailConfigListResponse) assert result.total == 1 assert result.jails[0].name == "sshd" async def test_empty_when_no_jails(self) -> None: """list_jail_configs returns empty list when no jails are active.""" responses = {"status": (0, [("Jail list", ""), ("Number of jail", 0)])} with _patch_client(responses): result = await config_service.list_jail_configs(_SOCKET) assert result.total == 0 assert result.jails == [] async def test_multiple_jails(self) -> None: """list_jail_configs handles comma-separated jail names.""" nginx_responses = { k.replace("sshd", "nginx"): v for k, v in _DEFAULT_JAIL_RESPONSES.items() } responses = { "status": _make_global_status("sshd, nginx"), **_DEFAULT_JAIL_RESPONSES, **nginx_responses, } with _patch_client(responses): result = await config_service.list_jail_configs(_SOCKET) assert result.total == 2 names = {j.name for j in 
result.jails} assert names == {"sshd", "nginx"} # --------------------------------------------------------------------------- # update_jail_config # --------------------------------------------------------------------------- class TestUpdateJailConfig: """Unit tests for :func:`~app.services.config_service.update_jail_config`.""" async def test_updates_numeric_fields(self) -> None: """update_jail_config sends set commands for numeric fields.""" sent_commands: list[list[Any]] = [] async def _send(command: list[Any]) -> Any: sent_commands.append(command) return (0, "OK") class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) from app.models.config import JailConfigUpdate update = JailConfigUpdate(ban_time=3600, max_retry=10) with patch("app.services.config_service.Fail2BanClient", _FakeClient): await config_service.update_jail_config(_SOCKET, "sshd", update) keys = [cmd[2] for cmd in sent_commands if len(cmd) >= 3 and cmd[0] == "set"] assert "bantime" in keys assert "maxretry" in keys async def test_raises_validation_error_on_bad_regex(self) -> None: """update_jail_config raises ConfigValidationError for invalid regex.""" from app.models.config import JailConfigUpdate update = JailConfigUpdate(fail_regex=["[invalid"]) with pytest.raises(ConfigValidationError, match="Invalid regex"): await config_service.update_jail_config(_SOCKET, "sshd", update) async def test_skips_none_fields(self) -> None: """update_jail_config does not send commands for None fields.""" sent_commands: list[list[Any]] = [] async def _send(command: list[Any]) -> Any: sent_commands.append(command) return (0, "OK") class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) from app.models.config import JailConfigUpdate update = JailConfigUpdate(ban_time=None, max_retry=None, find_time=None) with patch("app.services.config_service.Fail2BanClient", _FakeClient): await config_service.update_jail_config(_SOCKET, "sshd", 
update) set_commands = [cmd for cmd in sent_commands if len(cmd) >= 3 and cmd[0] == "set"] assert set_commands == [] async def test_replaces_fail_regex(self) -> None: """update_jail_config deletes old regexes and adds new ones.""" sent_commands: list[list[Any]] = [] async def _send(command: list[Any]) -> Any: sent_commands.append(command) if command[0] == "get": return (0, ["old_pattern"]) return (0, "OK") class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) from app.models.config import JailConfigUpdate update = JailConfigUpdate(fail_regex=["new_pattern"]) with patch("app.services.config_service.Fail2BanClient", _FakeClient): await config_service.update_jail_config(_SOCKET, "sshd", update) add_cmd = next( (c for c in sent_commands if len(c) >= 4 and c[2] == "addfailregex"), None, ) assert add_cmd is not None assert add_cmd[3] == "new_pattern" async def test_sets_dns_mode(self) -> None: """update_jail_config sends 'set usedns' for dns_mode.""" from app.models.config import JailConfigUpdate sent_commands: list[list[Any]] = [] async def _send(command: list[Any]) -> Any: sent_commands.append(command) return (0, "OK") class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) update = JailConfigUpdate(dns_mode="no") with patch("app.services.config_service.Fail2BanClient", _FakeClient): await config_service.update_jail_config(_SOCKET, "sshd", update) usedns_cmd = next( (c for c in sent_commands if len(c) >= 4 and c[2] == "usedns"), None, ) assert usedns_cmd is not None assert usedns_cmd[3] == "no" async def test_sets_prefregex(self) -> None: """update_jail_config sends 'set prefregex' for prefregex.""" from app.models.config import JailConfigUpdate sent_commands: list[list[Any]] = [] async def _send(command: list[Any]) -> Any: sent_commands.append(command) return (0, "OK") class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) update = 
JailConfigUpdate(prefregex=r"^%(__prefix_line)s") with patch("app.services.config_service.Fail2BanClient", _FakeClient): await config_service.update_jail_config(_SOCKET, "sshd", update) prefregex_cmd = next( (c for c in sent_commands if len(c) >= 4 and c[2] == "prefregex"), None, ) assert prefregex_cmd is not None assert prefregex_cmd[3] == r"^%(__prefix_line)s" async def test_skips_none_prefregex(self) -> None: """update_jail_config does not send prefregex command when field is None.""" from app.models.config import JailConfigUpdate sent_commands: list[list[Any]] = [] async def _send(command: list[Any]) -> Any: sent_commands.append(command) return (0, "OK") class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) update = JailConfigUpdate(prefregex=None) with patch("app.services.config_service.Fail2BanClient", _FakeClient): await config_service.update_jail_config(_SOCKET, "sshd", update) prefregex_cmd = next( (c for c in sent_commands if len(c) >= 4 and c[2] == "prefregex"), None, ) assert prefregex_cmd is None async def test_raises_validation_error_on_invalid_prefregex(self) -> None: """update_jail_config raises ConfigValidationError for an invalid prefregex.""" from app.models.config import JailConfigUpdate update = JailConfigUpdate(prefregex="[invalid") with pytest.raises(ConfigValidationError, match="prefregex"): await config_service.update_jail_config(_SOCKET, "sshd", update) # --------------------------------------------------------------------------- # get_global_config # --------------------------------------------------------------------------- class TestGetGlobalConfig: """Unit tests for :func:`~app.services.config_service.get_global_config`.""" async def test_returns_global_config(self) -> None: """get_global_config returns parsed GlobalConfigResponse.""" responses = { "get|loglevel": (0, "WARNING"), "get|logtarget": (0, "/var/log/fail2ban.log"), "get|dbpurgeage": (0, 86400), "get|dbmaxmatches": (0, 10), } with 
_patch_client(responses): result = await config_service.get_global_config(_SOCKET) assert result.log_level == "WARNING" assert result.log_target == "/var/log/fail2ban.log" assert result.db_purge_age == 86400 assert result.db_max_matches == 10 async def test_defaults_used_on_error(self) -> None: """get_global_config uses fallback defaults when commands fail.""" responses: dict[str, Any] = {} with _patch_client(responses): result = await config_service.get_global_config(_SOCKET) assert result.log_level is not None assert result.log_target is not None # --------------------------------------------------------------------------- # update_global_config # --------------------------------------------------------------------------- class TestUpdateGlobalConfig: """Unit tests for :func:`~app.services.config_service.update_global_config`.""" async def test_sends_set_commands(self) -> None: """update_global_config sends set commands for non-None fields.""" sent: list[list[Any]] = [] async def _send(command: list[Any]) -> Any: sent.append(command) return (0, "OK") class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) update = GlobalConfigUpdate(log_level="debug", db_purge_age=3600) with patch("app.services.config_service.Fail2BanClient", _FakeClient): await config_service.update_global_config(_SOCKET, update) keys = [cmd[1] for cmd in sent if len(cmd) >= 3 and cmd[0] == "set"] assert "loglevel" in keys assert "dbpurgeage" in keys async def test_log_level_uppercased(self) -> None: """update_global_config uppercases log_level before sending.""" sent: list[list[Any]] = [] async def _send(command: list[Any]) -> Any: sent.append(command) return (0, "OK") class _FakeClient: def __init__(self, **_kw: Any) -> None: self.send = AsyncMock(side_effect=_send) update = GlobalConfigUpdate(log_level="debug") with patch("app.services.config_service.Fail2BanClient", _FakeClient): await config_service.update_global_config(_SOCKET, update) cmd = 
next(c for c in sent if len(c) >= 3 and c[1] == "loglevel") assert cmd[2] == "DEBUG" # --------------------------------------------------------------------------- # test_regex (synchronous) # --------------------------------------------------------------------------- class TestTestRegex: """Unit tests for :func:`~app.services.config_service.test_regex`.""" def test_matching_pattern(self) -> None: """test_regex returns matched=True for a valid match.""" req = RegexTestRequest( log_line="Failed password for user from 1.2.3.4", fail_regex=r"(?P\d+\.\d+\.\d+\.\d+)", ) result = config_service.test_regex(req) assert result.matched is True assert "1.2.3.4" in result.groups assert result.error is None def test_non_matching_pattern(self) -> None: """test_regex returns matched=False when pattern does not match.""" req = RegexTestRequest( log_line="Normal log line here", fail_regex=r"BANME", ) result = config_service.test_regex(req) assert result.matched is False assert result.groups == [] def test_invalid_pattern_returns_error(self) -> None: """test_regex returns error message for an invalid regex.""" req = RegexTestRequest( log_line="any line", fail_regex=r"[invalid", ) result = config_service.test_regex(req) assert result.matched is False assert result.error is not None assert len(result.error) > 0 def test_empty_groups_when_no_capture(self) -> None: """test_regex returns empty groups when pattern has no capture groups.""" req = RegexTestRequest( log_line="fail here", fail_regex=r"fail", ) result = config_service.test_regex(req) assert result.matched is True assert result.groups == [] def test_multiple_capture_groups(self) -> None: """test_regex returns all captured groups.""" req = RegexTestRequest( log_line="user=root ip=1.2.3.4", fail_regex=r"user=(\w+) ip=([\d.]+)", ) result = config_service.test_regex(req) assert result.matched is True assert len(result.groups) == 2 # --------------------------------------------------------------------------- # preview_log # 
# ---------------------------------------------------------------------------


class TestPreviewLog:
    """Unit tests for :func:`~app.services.config_service.preview_log`."""

    async def test_returns_error_for_invalid_regex(self, tmp_path: Any) -> None:
        """An uncompilable pattern sets regex_error and scans no lines."""
        request = LogPreviewRequest(
            log_path=str(tmp_path / "fake.log"),
            fail_regex="[bad",
        )
        preview = await config_service.preview_log(request)

        assert preview.regex_error is not None
        assert preview.total_lines == 0

    async def test_returns_error_for_missing_file(self) -> None:
        """A log path that does not exist is reported through regex_error."""
        preview = await config_service.preview_log(
            LogPreviewRequest(
                log_path="/nonexistent/path/log.txt",
                fail_regex=r"test",
            )
        )

        assert preview.regex_error is not None

    async def test_matches_lines_in_file(self, tmp_path: Any) -> None:
        """Lines are counted in total, and only pattern hits count as matched."""
        path = tmp_path / "test.log"
        path.write_text("FAIL login from 1.2.3.4\nOK normal line\nFAIL from 5.6.7.8\n")

        preview = await config_service.preview_log(
            LogPreviewRequest(log_path=str(path), fail_regex=r"FAIL")
        )

        assert preview.total_lines == 3
        assert preview.matched_count == 2

    async def test_matched_line_has_groups(self, tmp_path: Any) -> None:
        """Each matched line carries the pattern's captured groups."""
        path = tmp_path / "test.log"
        path.write_text("error from 1.2.3.4 port 22\n")
        request = LogPreviewRequest(
            log_path=str(path),
            fail_regex=r"from (\d+\.\d+\.\d+\.\d+)",
        )

        preview = await config_service.preview_log(request)
        hits = [line for line in preview.lines if line.matched]

        assert len(hits) == 1
        first_hit = hits[0]
        assert "1.2.3.4" in first_hit.groups

    async def test_num_lines_limit(self, tmp_path: Any) -> None:
        """total_lines never exceeds the requested num_lines."""
        path = tmp_path / "big.log"
        path.write_text("\n".join(f"line {idx}" for idx in range(500)) + "\n")
        request = LogPreviewRequest(
            log_path=str(path),
            fail_regex=r"line",
            num_lines=50,
        )

        preview = await config_service.preview_log(request)

        assert preview.total_lines <= 50