diff --git a/Docs/Tasks.md b/Docs/Tasks.md index d7ed642..796f3d8 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -260,3 +260,65 @@ Reference config directory: `/home/lukas/Volume/repo/BanGUI/Docker/fail2ban-dev- - Backend: add router integration tests for the new update fields. - Frontend: update `ConfigPageLogPath.test.tsx` mock `JailConfig` to include `use_dns` and `prefregex`. +--- + +## Task 8 — Improve Test Coverage for Background Tasks and Utilities ✅ DONE + +**Goal:** Raise test coverage for the background-task modules and the fail2ban client utility to ≥ 80 %, closing the critical-path gap flagged in the Step 6.2 review. + +**Coverage before this task (from last full run):** + +| Module | Before | +|---|---| +| `app/tasks/blocklist_import.py` | 23 % | +| `app/tasks/health_check.py` | 43 % | +| `app/tasks/geo_cache_flush.py` | 60 % | +| `app/utils/fail2ban_client.py` | 58 % | + +### 8a — Tests for `blocklist_import` + +Create `tests/test_tasks/test_blocklist_import.py`: +- `_run_import` happy path: mock `blocklist_service.import_all`, verify structured log emitted. +- `_run_import` exception path: simulate unexpected exception, verify `log.exception` called. +- `_apply_schedule` hourly: mock scheduler, verify `add_job` called with `"interval"` trigger and correct `hours`. +- `_apply_schedule` daily: verify `"cron"` trigger with `hour` and `minute`. +- `_apply_schedule` weekly: verify `"cron"` trigger with `day_of_week`, `hour`, `minute`. +- `_apply_schedule` replaces an existing job: confirm `remove_job` called first when job already exists. + +### 8b — Tests for `health_check` + +Create `tests/test_tasks/test_health_check.py`: +- `_run_probe` online status: verify `app.state.server_status` is updated correctly. +- `_run_probe` offline→online transition: verify `"fail2ban_came_online"` log event. +- `_run_probe` online→offline transition: verify `"fail2ban_went_offline"` log event. +- `_run_probe` stable online (no transition): verify no transition log events. +- `register`: verify `add_job` is called with `"interval"` trigger and initial offline status set. + +### 8c — Tests for `geo_cache_flush` + +Create `tests/test_tasks/test_geo_cache_flush.py`: +- `_run_flush` with dirty IPs: verify `geo_service.flush_dirty` is called and debug log emitted when count > 0. +- `_run_flush` with nothing: verify `flush_dirty` called but no debug log. +- `register`: verify `add_job` called with correct interval and stable job ID. + +### 8d — Extended tests for `fail2ban_client` + +Extend `tests/test_services/test_fail2ban_client.py`: +- `send()` success path: mock `run_in_executor`, verify response is returned and debug log emitted. +- `send()` `Fail2BanConnectionError`: verify exception is re-raised and warning log emitted. +- `send()` `Fail2BanProtocolError`: verify exception is re-raised and error log emitted. +- `_send_command_sync` connection closed mid-stream (empty chunk): verify `Fail2BanConnectionError`. +- `_send_command_sync` pickle parse error (bad bytes in response): verify `Fail2BanProtocolError`. +- `_coerce_command_token` for `str`, `bool`, `int`, `float`, `list`, `dict`, `set`, and a custom object (stringified). + +**Result:** 50 new tests added (678 total). Coverage after: + +| Module | Before | After | +|---|---|---| +| `app/tasks/blocklist_import.py` | 23 % | 96 % | +| `app/tasks/health_check.py` | 43 % | 100 % | +| `app/tasks/geo_cache_flush.py` | 60 % | 100 % | +| `app/utils/fail2ban_client.py` | 58 % | 96 % | + +Overall backend coverage: 85 % → 87 %. ruff, mypy --strict, tsc, and eslint all clean. + diff --git a/backend/tests/test_services/test_fail2ban_client.py b/backend/tests/test_services/test_fail2ban_client.py index 0b81dc5..66fa33e 100644 --- a/backend/tests/test_services/test_fail2ban_client.py +++ b/backend/tests/test_services/test_fail2ban_client.py @@ -5,9 +5,11 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest from app.utils.fail2ban_client import ( + _PROTO_END, Fail2BanClient, Fail2BanConnectionError, Fail2BanProtocolError, + _coerce_command_token, _send_command_sync, ) @@ -85,3 +87,223 @@ class TestSendCommandSync: command=["status"], timeout=1.0, ) + + +class TestSendCommandSyncProtocol: + """Tests for edge cases in the receive-loop and unpickling logic.""" + + def _make_connected_sock(self) -> MagicMock: + """Return a minimal mock socket that reports a successful connect. + + Returns: + A :class:`unittest.mock.MagicMock` that mimics a socket. + """ + mock_sock = MagicMock() + mock_sock.connect.return_value = None + return mock_sock + + def test_send_command_sync_raises_connection_error_on_empty_chunk(self) -> None: + """Must raise :class:`Fail2BanConnectionError` when the server closes mid-stream.""" + mock_sock = self._make_connected_sock() + # First recv returns empty bytes → server closed the connection. + mock_sock.recv.return_value = b"" + + with ( + patch("socket.socket", return_value=mock_sock), + pytest.raises(Fail2BanConnectionError, match="closed unexpectedly"), + ): + _send_command_sync( + socket_path="/fake/fail2ban.sock", + command=["ping"], + timeout=1.0, + ) + + def test_send_command_sync_raises_protocol_error_on_bad_pickle(self) -> None: + """Must raise :class:`Fail2BanProtocolError` when the response is not valid pickle.""" + mock_sock = self._make_connected_sock() + # Return the end marker directly so the recv-loop terminates immediately, + # but prepend garbage bytes so ``loads`` fails. + mock_sock.recv.side_effect = [ + _PROTO_END, # first call — exits the receive loop + ] + + # Patch loads to raise to simulate a corrupted response. + with ( + patch("socket.socket", return_value=mock_sock), + patch("app.utils.fail2ban_client.loads", side_effect=Exception("bad pickle")), + pytest.raises(Fail2BanProtocolError, match="Failed to unpickle"), + ): + _send_command_sync( + socket_path="/fake/fail2ban.sock", + command=["status"], + timeout=1.0, + ) + + def test_send_command_sync_returns_parsed_response(self) -> None: + """Must return the Python object that was pickled by fail2ban.""" + expected_response = [0, ["sshd", "nginx"]] + mock_sock = self._make_connected_sock() + # Return the proto end-marker so the recv-loop exits, then parse the raw bytes. + mock_sock.recv.return_value = _PROTO_END + + with ( + patch("socket.socket", return_value=mock_sock), + patch("app.utils.fail2ban_client.loads", return_value=expected_response), + ): + result = _send_command_sync( + socket_path="/fake/fail2ban.sock", + command=["status"], + timeout=1.0, + ) + + assert result == expected_response + + +# --------------------------------------------------------------------------- +# Tests for _coerce_command_token +# --------------------------------------------------------------------------- + + +class TestCoerceCommandToken: + """Tests for :func:`~app.utils.fail2ban_client._coerce_command_token`.""" + + def test_coerce_str_unchanged(self) -> None: + """``str`` tokens must pass through unchanged.""" + assert _coerce_command_token("sshd") == "sshd" + + def test_coerce_bool_unchanged(self) -> None: + """``bool`` tokens must pass through unchanged.""" + assert _coerce_command_token(True) is True # noqa: FBT003 + + def test_coerce_int_unchanged(self) -> None: + """``int`` tokens must pass through unchanged.""" + assert _coerce_command_token(42) == 42 + + def test_coerce_float_unchanged(self) -> None: + """``float`` tokens must pass through unchanged.""" + assert _coerce_command_token(1.5) == 1.5 + + def test_coerce_list_unchanged(self) -> None: + """``list`` tokens must pass through unchanged.""" + token: list[int] = [1, 2] + assert _coerce_command_token(token) is token + + def test_coerce_dict_unchanged(self) -> None: + """``dict`` tokens must pass through unchanged.""" + token: dict[str, str] = {"key": "value"} + assert _coerce_command_token(token) is token + + def test_coerce_set_unchanged(self) -> None: + """``set`` tokens must pass through unchanged.""" + token: set[str] = {"a", "b"} + assert _coerce_command_token(token) is token + + def test_coerce_unknown_type_stringified(self) -> None: + """Any other type must be converted to its ``str()`` representation.""" + + class CustomObj: + def __str__(self) -> str: + return "custom_repr" + + assert _coerce_command_token(CustomObj()) == "custom_repr" + + def test_coerce_none_stringified(self) -> None: + """``None`` must be stringified to ``"None"``.""" + assert _coerce_command_token(None) == "None" + + +# --------------------------------------------------------------------------- +# Extended tests for Fail2BanClient.send +# --------------------------------------------------------------------------- + + +class TestFail2BanClientSend: + """Tests for :meth:`Fail2BanClient.send`.""" + + @pytest.mark.asyncio + async def test_send_returns_response_on_success(self) -> None: + """``send()`` must return the response from the executor.""" + expected = [0, "OK"] + client = Fail2BanClient(socket_path="/fake/fail2ban.sock") + # asyncio.get_event_loop().run_in_executor is called inside send(). + # We patch it on the loop object returned by asyncio.get_event_loop(). + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_loop = AsyncMock() + mock_loop.run_in_executor = AsyncMock(return_value=expected) + mock_get_loop.return_value = mock_loop + + result = await client.send(["status"]) + + assert result == expected + + @pytest.mark.asyncio + async def test_send_reraises_connection_error(self) -> None: + """``send()`` must re-raise :class:`Fail2BanConnectionError`.""" + client = Fail2BanClient(socket_path="/fake/fail2ban.sock") + + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_loop = AsyncMock() + mock_loop.run_in_executor = AsyncMock( + side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock") + ) + mock_get_loop.return_value = mock_loop + + with pytest.raises(Fail2BanConnectionError): + await client.send(["status"]) + + @pytest.mark.asyncio + async def test_send_logs_warning_on_connection_error(self) -> None: + """``send()`` must log a warning when a connection error occurs.""" + client = Fail2BanClient(socket_path="/fake/fail2ban.sock") + + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_loop = AsyncMock() + mock_loop.run_in_executor = AsyncMock( + side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock") + ) + mock_get_loop.return_value = mock_loop + + with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError): + await client.send(["ping"]) + + warning_calls = [ + c for c in mock_log.warning.call_args_list + if c[0][0] == "fail2ban_connection_error" + ] + assert len(warning_calls) == 1 + + @pytest.mark.asyncio + async def test_send_reraises_protocol_error(self) -> None: + """``send()`` must re-raise :class:`Fail2BanProtocolError`.""" + client = Fail2BanClient(socket_path="/fake/fail2ban.sock") + + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_loop = AsyncMock() + mock_loop.run_in_executor = AsyncMock( + side_effect=Fail2BanProtocolError("bad pickle") + ) + mock_get_loop.return_value = mock_loop + + with pytest.raises(Fail2BanProtocolError): + await client.send(["status"]) + + @pytest.mark.asyncio + async def test_send_logs_error_on_protocol_error(self) -> None: + """``send()`` must log an error when a protocol error occurs.""" + client = Fail2BanClient(socket_path="/fake/fail2ban.sock") + + with patch("asyncio.get_event_loop") as mock_get_loop: + mock_loop = AsyncMock() + mock_loop.run_in_executor = AsyncMock( + side_effect=Fail2BanProtocolError("corrupt response") + ) + mock_get_loop.return_value = mock_loop + + with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError): + await client.send(["get", "sshd", "banned"]) + + error_calls = [ + c for c in mock_log.error.call_args_list + if c[0][0] == "fail2ban_protocol_error" + ] + assert len(error_calls) == 1 diff --git a/backend/tests/test_tasks/test_blocklist_import.py b/backend/tests/test_tasks/test_blocklist_import.py new file mode 100644 index 0000000..b512601 --- /dev/null +++ b/backend/tests/test_tasks/test_blocklist_import.py @@ -0,0 +1,352 @@ +"""Tests for the blocklist import background task. + +Validates that :func:`~app.tasks.blocklist_import._run_import` correctly +delegates to :func:`~app.services.blocklist_service.import_all`, handles +unexpected exceptions gracefully, and that :func:`~app.tasks.blocklist_import._apply_schedule` +registers the correct APScheduler trigger for each frequency preset. +""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import AsyncMock, MagicMock, call, patch + +import pytest + +from app.models.blocklist import ImportRunResult, ScheduleConfig, ScheduleFrequency +from app.tasks.blocklist_import import JOB_ID, _apply_schedule, _run_import + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app( + import_result: ImportRunResult | None = None, + import_side_effect: Exception | None = None, +) -> MagicMock: + """Build a minimal mock ``app`` for blocklist import task tests. + + Args: + import_result: Value returned by the mocked ``import_all`` call. + import_side_effect: If provided, ``import_all`` raises this exception. + + Returns: + A :class:`unittest.mock.MagicMock` that mimics ``fastapi.FastAPI``. + """ + app = MagicMock() + app.state.db = MagicMock() + app.state.http_session = MagicMock() + app.state.settings.fail2ban_socket = "/var/run/fail2ban/fail2ban.sock" + return app + + +def _make_import_result( + total_imported: int = 50, + total_skipped: int = 5, + errors_count: int = 0, +) -> ImportRunResult: + """Construct a minimal :class:`ImportRunResult` for testing. + + Args: + total_imported: Number of IPs successfully imported. + total_skipped: Number of skipped entries. + errors_count: Number of sources that encountered errors. + + Returns: + An :class:`ImportRunResult` with the given counters. + """ + return ImportRunResult( + results=[], + total_imported=total_imported, + total_skipped=total_skipped, + errors_count=errors_count, + ) + + +def _make_scheduler(has_existing_job: bool = False) -> MagicMock: + """Build a mock APScheduler-like object. + + Args: + has_existing_job: Whether ``get_job(JOB_ID)`` should return a truthy value. + + Returns: + A :class:`unittest.mock.MagicMock` that mimics a scheduler. + """ + scheduler = MagicMock() + scheduler.get_job.return_value = MagicMock() if has_existing_job else None + return scheduler + + +# --------------------------------------------------------------------------- +# Tests for _run_import +# --------------------------------------------------------------------------- + + +class TestRunImport: + """Tests for :func:`~app.tasks.blocklist_import._run_import`.""" + + @pytest.mark.asyncio + async def test_run_import_happy_path_calls_import_all(self) -> None: + """``_run_import`` must delegate to ``blocklist_service.import_all``.""" + app = _make_app() + result = _make_import_result(total_imported=100, total_skipped=2, errors_count=0) + + with patch( + "app.tasks.blocklist_import.blocklist_service.import_all", + new_callable=AsyncMock, + return_value=result, + ) as mock_import_all: + await _run_import(app) + + mock_import_all.assert_awaited_once_with( + app.state.db, + app.state.http_session, + app.state.settings.fail2ban_socket, + ) + + @pytest.mark.asyncio + async def test_run_import_logs_counters_on_success(self) -> None: + """``_run_import`` must emit a structured log event with import counters.""" + app = _make_app() + result = _make_import_result(total_imported=42, total_skipped=3, errors_count=1) + + with patch( + "app.tasks.blocklist_import.blocklist_service.import_all", + new_callable=AsyncMock, + return_value=result, + ), patch("app.tasks.blocklist_import.log") as mock_log: + await _run_import(app) + + info_calls = [c for c in mock_log.info.call_args_list if c[0][0] == "blocklist_import_finished"] + assert len(info_calls) == 1 + kwargs = info_calls[0][1] + assert kwargs["total_imported"] == 42 + assert kwargs["total_skipped"] == 3 + assert kwargs["errors"] == 1 + + @pytest.mark.asyncio + async def test_run_import_logs_start_event(self) -> None: + """``_run_import`` must emit a ``blocklist_import_starting`` event.""" + app = _make_app() + result = _make_import_result() + + with patch( + "app.tasks.blocklist_import.blocklist_service.import_all", + new_callable=AsyncMock, + return_value=result, + ), patch("app.tasks.blocklist_import.log") as mock_log: + await _run_import(app) + + start_calls = [c for c in mock_log.info.call_args_list if c[0][0] == "blocklist_import_starting"] + assert len(start_calls) == 1 + + @pytest.mark.asyncio + async def test_run_import_handles_unexpected_exception(self) -> None: + """``_run_import`` must catch unexpected exceptions and log them.""" + app = _make_app() + + with patch( + "app.tasks.blocklist_import.blocklist_service.import_all", + new_callable=AsyncMock, + side_effect=RuntimeError("unexpected failure"), + ), patch("app.tasks.blocklist_import.log") as mock_log: + # Must not raise — the task swallows unexpected errors. + await _run_import(app) + + mock_log.exception.assert_called_once_with("blocklist_import_unexpected_error") + + +# --------------------------------------------------------------------------- +# Tests for _apply_schedule +# --------------------------------------------------------------------------- + + +class TestApplySchedule: + """Tests for :func:`~app.tasks.blocklist_import._apply_schedule`.""" + + def _make_app_with_scheduler(self, scheduler: Any) -> MagicMock: + """Return a mock ``app`` whose ``state.scheduler`` is *scheduler*. + + Args: + scheduler: Mock scheduler object. + + Returns: + Mock FastAPI application instance. + """ + app = MagicMock() + app.state.scheduler = scheduler + return app + + def test_apply_schedule_daily_registers_cron_trigger(self) -> None: + """Daily frequency must register a ``"cron"`` trigger with hour and minute.""" + scheduler = _make_scheduler() + app = self._make_app_with_scheduler(scheduler) + config = ScheduleConfig(frequency=ScheduleFrequency.daily, hour=3, minute=0) + + _apply_schedule(app, config) + + scheduler.add_job.assert_called_once() + _, kwargs = scheduler.add_job.call_args + assert kwargs["trigger"] == "cron" + assert kwargs["hour"] == 3 + assert kwargs["minute"] == 0 + assert "day_of_week" not in kwargs + + def test_apply_schedule_hourly_registers_interval_trigger(self) -> None: + """Hourly frequency must register an ``"interval"`` trigger with correct hours.""" + scheduler = _make_scheduler() + app = self._make_app_with_scheduler(scheduler) + config = ScheduleConfig(frequency=ScheduleFrequency.hourly, interval_hours=6) + + _apply_schedule(app, config) + + scheduler.add_job.assert_called_once() + _, kwargs = scheduler.add_job.call_args + assert kwargs["trigger"] == "interval" + assert kwargs["hours"] == 6 + + def test_apply_schedule_weekly_registers_cron_trigger_with_day(self) -> None: + """Weekly frequency must register a ``"cron"`` trigger including ``day_of_week``.""" + scheduler = _make_scheduler() + app = self._make_app_with_scheduler(scheduler) + config = ScheduleConfig( + frequency=ScheduleFrequency.weekly, + day_of_week=0, + hour=4, + minute=30, + ) + + _apply_schedule(app, config) + + scheduler.add_job.assert_called_once() + _, kwargs = scheduler.add_job.call_args + assert kwargs["trigger"] == "cron" + assert kwargs["day_of_week"] == 0 + assert kwargs["hour"] == 4 + assert kwargs["minute"] == 30 + + def test_apply_schedule_removes_existing_job_before_adding(self) -> None: + """If a job with ``JOB_ID`` exists, it must be removed before a new one is added.""" + scheduler = _make_scheduler(has_existing_job=True) + app = self._make_app_with_scheduler(scheduler) + config = ScheduleConfig(frequency=ScheduleFrequency.daily) + + _apply_schedule(app, config) + + scheduler.remove_job.assert_called_once_with(JOB_ID) + assert scheduler.remove_job.call_args_list.index(call(JOB_ID)) < len( + scheduler.add_job.call_args_list + ) + + def test_apply_schedule_skips_remove_when_no_existing_job(self) -> None: + """If no job exists, ``remove_job`` must not be called.""" + scheduler = _make_scheduler(has_existing_job=False) + app = self._make_app_with_scheduler(scheduler) + config = ScheduleConfig(frequency=ScheduleFrequency.daily) + + _apply_schedule(app, config) + + scheduler.remove_job.assert_not_called() + + def test_apply_schedule_uses_stable_job_id(self) -> None: + """The registered job must use the module-level ``JOB_ID`` constant.""" + scheduler = _make_scheduler() + app = self._make_app_with_scheduler(scheduler) + config = ScheduleConfig(frequency=ScheduleFrequency.daily) + + _apply_schedule(app, config) + + _, kwargs = scheduler.add_job.call_args + assert kwargs["id"] == JOB_ID + + def test_apply_schedule_passes_app_in_kwargs(self) -> None: + """The scheduled job must receive ``app`` as a kwarg for state access.""" + scheduler = _make_scheduler() + app = self._make_app_with_scheduler(scheduler) + config = ScheduleConfig(frequency=ScheduleFrequency.daily) + + _apply_schedule(app, config) + + _, kwargs = scheduler.add_job.call_args + assert kwargs["kwargs"] == {"app": app} + + +# --------------------------------------------------------------------------- +# Tests for register / reschedule +# --------------------------------------------------------------------------- + + +class TestRegister: + """Tests for :func:`~app.tasks.blocklist_import.register`.""" + + def test_register_calls_apply_schedule_via_event_loop(self) -> None: + """``register`` must call ``_apply_schedule`` after reading the stored config.""" + import asyncio + + from app.tasks.blocklist_import import register + + app = MagicMock() + app.state.db = MagicMock() + app.state.scheduler = MagicMock() + app.state.scheduler.get_job.return_value = None + + config = ScheduleConfig(frequency=ScheduleFrequency.daily, hour=3, minute=0) + + with patch( + "app.tasks.blocklist_import.blocklist_service.get_schedule", + new_callable=AsyncMock, + return_value=config, + ), patch("app.tasks.blocklist_import._apply_schedule") as mock_apply: + # Use a fresh event loop to avoid interference from pytest-asyncio. + loop = asyncio.new_event_loop() + try: + with patch("asyncio.get_event_loop", return_value=loop): + register(app) + finally: + loop.close() + + mock_apply.assert_called_once_with(app, config) + + def test_register_falls_back_to_ensure_future_on_runtime_error(self) -> None: + """When ``run_until_complete`` raises ``RuntimeError``, ``ensure_future`` is used.""" + from app.tasks.blocklist_import import register + + app = MagicMock() + app.state.db = MagicMock() + app.state.scheduler = MagicMock() + + config = ScheduleConfig(frequency=ScheduleFrequency.daily) + + mock_loop = MagicMock() + mock_loop.run_until_complete.side_effect = RuntimeError("already running") + + with ( + patch( + "app.tasks.blocklist_import.blocklist_service.get_schedule", + new_callable=AsyncMock, + return_value=config, + ), + patch("asyncio.get_event_loop", return_value=mock_loop), + patch("asyncio.ensure_future") as mock_ensure_future, + ): + register(app) + + mock_ensure_future.assert_called_once() + + +class TestReschedule: + """Tests for :func:`~app.tasks.blocklist_import.reschedule`.""" + + def test_reschedule_calls_ensure_future(self) -> None: + """``reschedule`` must schedule the re-registration with ``asyncio.ensure_future``.""" + from app.tasks.blocklist_import import reschedule + + app = MagicMock() + app.state.db = MagicMock() + app.state.scheduler = MagicMock() + + with patch("asyncio.ensure_future") as mock_ensure_future: + reschedule(app) + + mock_ensure_future.assert_called_once() diff --git a/backend/tests/test_tasks/test_geo_cache_flush.py b/backend/tests/test_tasks/test_geo_cache_flush.py new file mode 100644 index 0000000..ab65a39 --- /dev/null +++ b/backend/tests/test_tasks/test_geo_cache_flush.py @@ -0,0 +1,136 @@ +"""Tests for the geo cache flush background task. + +Validates that :func:`~app.tasks.geo_cache_flush._run_flush` correctly +delegates to :func:`~app.services.geo_service.flush_dirty` and only logs +when entries were actually flushed, and that +:func:`~app.tasks.geo_cache_flush.register` configures the APScheduler job +with the correct interval and stable job ID. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.tasks.geo_cache_flush import GEO_FLUSH_INTERVAL, JOB_ID, _run_flush, register + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(flush_count: int = 0) -> MagicMock: + """Build a minimal mock ``app`` for geo cache flush task tests. + + Args: + flush_count: The value returned by the mocked ``flush_dirty`` call. + + Returns: + A :class:`unittest.mock.MagicMock` that mimics ``fastapi.FastAPI``. + """ + app = MagicMock() + app.state.db = MagicMock() + app.state.scheduler = MagicMock() + return app + + +# --------------------------------------------------------------------------- +# Tests for _run_flush +# --------------------------------------------------------------------------- + + +class TestRunFlush: + """Tests for :func:`~app.tasks.geo_cache_flush._run_flush`.""" + + @pytest.mark.asyncio + async def test_run_flush_calls_flush_dirty_with_db(self) -> None: + """``_run_flush`` must call ``geo_service.flush_dirty`` with ``app.state.db``.""" + app = _make_app() + + with patch( + "app.tasks.geo_cache_flush.geo_service.flush_dirty", + new_callable=AsyncMock, + return_value=0, + ) as mock_flush: + await _run_flush(app) + + mock_flush.assert_awaited_once_with(app.state.db) + + @pytest.mark.asyncio + async def test_run_flush_logs_when_entries_flushed(self) -> None: + """``_run_flush`` must emit a debug log when ``flush_dirty`` returns > 0.""" + app = _make_app() + + with patch( + "app.tasks.geo_cache_flush.geo_service.flush_dirty", + new_callable=AsyncMock, + return_value=15, + ), patch("app.tasks.geo_cache_flush.log") as mock_log: + await _run_flush(app) + + debug_calls = [c for c in mock_log.debug.call_args_list if c[0][0] == "geo_cache_flush_ran"] + assert len(debug_calls) == 1 + assert debug_calls[0][1]["flushed"] == 15 + + @pytest.mark.asyncio + async def test_run_flush_does_not_log_when_nothing_to_flush(self) -> None: + """``_run_flush`` must not emit any log when ``flush_dirty`` returns 0.""" + app = _make_app() + + with patch( + "app.tasks.geo_cache_flush.geo_service.flush_dirty", + new_callable=AsyncMock, + return_value=0, + ), patch("app.tasks.geo_cache_flush.log") as mock_log: + await _run_flush(app) + + debug_calls = [c for c in mock_log.debug.call_args_list if c[0][0] == "geo_cache_flush_ran"] + assert debug_calls == [] + + +# --------------------------------------------------------------------------- +# Tests for register +# --------------------------------------------------------------------------- + + +class TestRegister: + """Tests for :func:`~app.tasks.geo_cache_flush.register`.""" + + def test_register_adds_interval_job_to_scheduler(self) -> None: + """``register`` must add a job with an ``"interval"`` trigger.""" + app = _make_app() + + register(app) + + app.state.scheduler.add_job.assert_called_once() + _, kwargs = app.state.scheduler.add_job.call_args + assert kwargs["trigger"] == "interval" + assert kwargs["seconds"] == GEO_FLUSH_INTERVAL + + def test_register_uses_stable_job_id(self) -> None: + """``register`` must use the module-level ``JOB_ID`` constant.""" + app = _make_app() + + register(app) + + _, kwargs = app.state.scheduler.add_job.call_args + assert kwargs["id"] == JOB_ID + + def test_register_sets_replace_existing(self) -> None: + """``register`` must use ``replace_existing=True`` to avoid duplicate jobs.""" + app = _make_app() + + register(app) + + _, kwargs = app.state.scheduler.add_job.call_args + assert kwargs["replace_existing"] is True + + def test_register_passes_app_in_kwargs(self) -> None: + """The scheduled job must receive ``app`` as a kwarg for state access.""" + app = _make_app() + + register(app) + + _, kwargs = app.state.scheduler.add_job.call_args + assert kwargs["kwargs"] == {"app": app} diff --git a/backend/tests/test_tasks/test_health_check.py b/backend/tests/test_tasks/test_health_check.py new file mode 100644 index 0000000..2615c8b --- /dev/null +++ b/backend/tests/test_tasks/test_health_check.py @@ -0,0 +1,238 @@ +"""Tests for the health-check background task. + +Validates that :func:`~app.tasks.health_check._run_probe` correctly stores +the probe result on ``app.state.server_status``, logs online/offline +transitions, and that :func:`~app.tasks.health_check.register` configures +the scheduler and primes the initial status. +""" + +from __future__ import annotations + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from app.models.server import ServerStatus +from app.tasks.health_check import HEALTH_CHECK_INTERVAL, _run_probe, register + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_app(prev_online: bool = False) -> MagicMock: + """Build a minimal mock ``app`` for health-check task tests. + + Args: + prev_online: Whether the previous ``server_status`` was online. + + Returns: + A :class:`unittest.mock.MagicMock` that mimics ``fastapi.FastAPI``. + """ + app = MagicMock() + app.state.settings.fail2ban_socket = "/var/run/fail2ban/fail2ban.sock" + app.state.server_status = ServerStatus(online=prev_online) + app.state.scheduler = MagicMock() + return app + + +# --------------------------------------------------------------------------- +# Tests for _run_probe +# --------------------------------------------------------------------------- + + +class TestRunProbe: + """Tests for :func:`~app.tasks.health_check._run_probe`.""" + + @pytest.mark.asyncio + async def test_run_probe_updates_server_status(self) -> None: + """``_run_probe`` must store the probe result on ``app.state.server_status``.""" + app = _make_app(prev_online=False) + new_status = ServerStatus(online=True, version="0.11.2", active_jails=3) + + with patch( + "app.tasks.health_check.health_service.probe", + new_callable=AsyncMock, + return_value=new_status, + ): + await _run_probe(app) + + assert app.state.server_status is new_status + + @pytest.mark.asyncio + async def test_run_probe_logs_came_online_transition(self) -> None: + """When fail2ban comes online, ``"fail2ban_came_online"`` must be logged.""" + app = _make_app(prev_online=False) + new_status = ServerStatus(online=True, version="0.11.2", active_jails=2) + + with patch( + "app.tasks.health_check.health_service.probe", + new_callable=AsyncMock, + return_value=new_status, + ), patch("app.tasks.health_check.log") as mock_log: + await _run_probe(app) + + online_calls = [c for c in mock_log.info.call_args_list if c[0][0] == "fail2ban_came_online"] + assert len(online_calls) == 1 + + @pytest.mark.asyncio + async def test_run_probe_logs_went_offline_transition(self) -> None: + """When fail2ban goes offline, ``"fail2ban_went_offline"`` must be logged.""" + app = _make_app(prev_online=True) + new_status = ServerStatus(online=False) + + with patch( + "app.tasks.health_check.health_service.probe", + new_callable=AsyncMock, + return_value=new_status, + ), patch("app.tasks.health_check.log") as mock_log: + await _run_probe(app) + + offline_calls = [c for c in mock_log.warning.call_args_list if c[0][0] == "fail2ban_went_offline"] + assert len(offline_calls) == 1 + + @pytest.mark.asyncio + async def test_run_probe_stable_online_no_transition_log(self) -> None: + """When status stays online, no transition events must be emitted.""" + app = _make_app(prev_online=True) + new_status = ServerStatus(online=True, version="0.11.2", active_jails=1) + + with patch( + "app.tasks.health_check.health_service.probe", + new_callable=AsyncMock, + return_value=new_status, + ), patch("app.tasks.health_check.log") as mock_log: + await _run_probe(app) + + transition_calls = [ + c + for c in mock_log.info.call_args_list + if c[0][0] in ("fail2ban_came_online", "fail2ban_went_offline") + ] + transition_calls += [ + c + for c in mock_log.warning.call_args_list + if c[0][0] in ("fail2ban_came_online", "fail2ban_went_offline") + ] + assert transition_calls == [] + + @pytest.mark.asyncio + async def test_run_probe_stable_offline_no_transition_log(self) -> None: + """When status stays offline, no transition events must be emitted.""" + app = _make_app(prev_online=False) + new_status = ServerStatus(online=False) + + with patch( + "app.tasks.health_check.health_service.probe", + new_callable=AsyncMock, + return_value=new_status, + ), patch("app.tasks.health_check.log") as mock_log: + await _run_probe(app) + + transition_calls = [ + c + for c in mock_log.info.call_args_list + if c[0][0] == "fail2ban_came_online" + ] + transition_calls += [ + c + for c in mock_log.warning.call_args_list + if c[0][0] == "fail2ban_went_offline" + ] + assert transition_calls == [] + + @pytest.mark.asyncio + async def test_run_probe_uses_socket_path_from_settings(self) -> None: + """``_run_probe`` must pass the socket path from ``app.state.settings``.""" + expected_socket = "/custom/fail2ban.sock" + app = _make_app() + app.state.settings.fail2ban_socket = expected_socket + new_status = ServerStatus(online=False) + + with patch( + "app.tasks.health_check.health_service.probe", + new_callable=AsyncMock, + return_value=new_status, + ) as mock_probe: + await _run_probe(app) + + mock_probe.assert_awaited_once_with(expected_socket) + + @pytest.mark.asyncio + async def test_run_probe_uses_default_offline_status_when_state_missing(self) -> None: + """``_run_probe`` must handle missing ``server_status`` on first run.""" + app = _make_app() + # Simulate first run: no previous server_status attribute set yet. + del app.state.server_status + new_status = ServerStatus(online=True, version="0.11.2", active_jails=0) + + with ( + patch( + "app.tasks.health_check.health_service.probe", + new_callable=AsyncMock, + return_value=new_status, + ), + patch("app.tasks.health_check.log"), + ): + # Must not raise even with no prior status. + await _run_probe(app) + + assert app.state.server_status is new_status + + +# --------------------------------------------------------------------------- +# Tests for register +# --------------------------------------------------------------------------- + + +class TestRegister: + """Tests for :func:`~app.tasks.health_check.register`.""" + + def test_register_adds_interval_job_to_scheduler(self) -> None: + """``register`` must add a job with an ``"interval"`` trigger.""" + app = _make_app() + + register(app) + + app.state.scheduler.add_job.assert_called_once() + _, kwargs = app.state.scheduler.add_job.call_args + assert kwargs["trigger"] == "interval" + assert kwargs["seconds"] == HEALTH_CHECK_INTERVAL + + def test_register_primes_offline_server_status(self) -> None: + """``register`` must set an initial offline status before the first probe fires.""" + app = _make_app() + # Reset any value set by _make_app. + del app.state.server_status + + register(app) + + assert isinstance(app.state.server_status, ServerStatus) + assert app.state.server_status.online is False + + def test_register_uses_stable_job_id(self) -> None: + """``register`` must register the job under the stable id ``"health_check"``.""" + app = _make_app() + + register(app) + + _, kwargs = app.state.scheduler.add_job.call_args + assert kwargs["id"] == "health_check" + + def test_register_sets_replace_existing(self) -> None: + """``register`` must use ``replace_existing=True`` to avoid duplicate jobs.""" + app = _make_app() + + register(app) + + _, kwargs = app.state.scheduler.add_job.call_args + assert kwargs["replace_existing"] is True + + def test_register_passes_app_in_kwargs(self) -> None: + """The scheduled job must receive ``app`` as a kwarg for state access.""" + app = _make_app() + + register(app) + + _, kwargs = app.state.scheduler.add_job.call_args + assert kwargs["kwargs"] == {"app": app}