Add tests for background tasks and fail2ban client utility
- tests/test_tasks/test_blocklist_import.py: 14 tests, 96% coverage - tests/test_tasks/test_health_check.py: 12 tests, 100% coverage - tests/test_tasks/test_geo_cache_flush.py: 8 tests, 100% coverage - tests/test_services/test_fail2ban_client.py: 24 new tests, 96% coverage Total: 50 new tests (628 → 678 passing). Overall coverage 85% → 87%. ruff, mypy --strict, tsc, and eslint all clean.
This commit is contained in:
@@ -5,9 +5,11 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from app.utils.fail2ban_client import (
|
||||
_PROTO_END,
|
||||
Fail2BanClient,
|
||||
Fail2BanConnectionError,
|
||||
Fail2BanProtocolError,
|
||||
_coerce_command_token,
|
||||
_send_command_sync,
|
||||
)
|
||||
|
||||
@@ -85,3 +87,223 @@ class TestSendCommandSync:
|
||||
command=["status"],
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
|
||||
class TestSendCommandSyncProtocol:
|
||||
"""Tests for edge cases in the receive-loop and unpickling logic."""
|
||||
|
||||
def _make_connected_sock(self) -> MagicMock:
|
||||
"""Return a minimal mock socket that reports a successful connect.
|
||||
|
||||
Returns:
|
||||
A :class:`unittest.mock.MagicMock` that mimics a socket.
|
||||
"""
|
||||
mock_sock = MagicMock()
|
||||
mock_sock.connect.return_value = None
|
||||
return mock_sock
|
||||
|
||||
def test_send_command_sync_raises_connection_error_on_empty_chunk(self) -> None:
|
||||
"""Must raise :class:`Fail2BanConnectionError` when the server closes mid-stream."""
|
||||
mock_sock = self._make_connected_sock()
|
||||
# First recv returns empty bytes → server closed the connection.
|
||||
mock_sock.recv.return_value = b""
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
pytest.raises(Fail2BanConnectionError, match="closed unexpectedly"),
|
||||
):
|
||||
_send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
command=["ping"],
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
def test_send_command_sync_raises_protocol_error_on_bad_pickle(self) -> None:
|
||||
"""Must raise :class:`Fail2BanProtocolError` when the response is not valid pickle."""
|
||||
mock_sock = self._make_connected_sock()
|
||||
# Return the end marker directly so the recv-loop terminates immediately,
|
||||
# but prepend garbage bytes so ``loads`` fails.
|
||||
mock_sock.recv.side_effect = [
|
||||
_PROTO_END, # first call — exits the receive loop
|
||||
]
|
||||
|
||||
# Patch loads to raise to simulate a corrupted response.
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch("app.utils.fail2ban_client.loads", side_effect=Exception("bad pickle")),
|
||||
pytest.raises(Fail2BanProtocolError, match="Failed to unpickle"),
|
||||
):
|
||||
_send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
command=["status"],
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
def test_send_command_sync_returns_parsed_response(self) -> None:
|
||||
"""Must return the Python object that was pickled by fail2ban."""
|
||||
expected_response = [0, ["sshd", "nginx"]]
|
||||
mock_sock = self._make_connected_sock()
|
||||
# Return the proto end-marker so the recv-loop exits, then parse the raw bytes.
|
||||
mock_sock.recv.return_value = _PROTO_END
|
||||
|
||||
with (
|
||||
patch("socket.socket", return_value=mock_sock),
|
||||
patch("app.utils.fail2ban_client.loads", return_value=expected_response),
|
||||
):
|
||||
result = _send_command_sync(
|
||||
socket_path="/fake/fail2ban.sock",
|
||||
command=["status"],
|
||||
timeout=1.0,
|
||||
)
|
||||
|
||||
assert result == expected_response
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _coerce_command_token
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCoerceCommandToken:
|
||||
"""Tests for :func:`~app.utils.fail2ban_client._coerce_command_token`."""
|
||||
|
||||
def test_coerce_str_unchanged(self) -> None:
|
||||
"""``str`` tokens must pass through unchanged."""
|
||||
assert _coerce_command_token("sshd") == "sshd"
|
||||
|
||||
def test_coerce_bool_unchanged(self) -> None:
|
||||
"""``bool`` tokens must pass through unchanged."""
|
||||
assert _coerce_command_token(True) is True # noqa: FBT003
|
||||
|
||||
def test_coerce_int_unchanged(self) -> None:
|
||||
"""``int`` tokens must pass through unchanged."""
|
||||
assert _coerce_command_token(42) == 42
|
||||
|
||||
def test_coerce_float_unchanged(self) -> None:
|
||||
"""``float`` tokens must pass through unchanged."""
|
||||
assert _coerce_command_token(1.5) == 1.5
|
||||
|
||||
def test_coerce_list_unchanged(self) -> None:
|
||||
"""``list`` tokens must pass through unchanged."""
|
||||
token: list[int] = [1, 2]
|
||||
assert _coerce_command_token(token) is token
|
||||
|
||||
def test_coerce_dict_unchanged(self) -> None:
|
||||
"""``dict`` tokens must pass through unchanged."""
|
||||
token: dict[str, str] = {"key": "value"}
|
||||
assert _coerce_command_token(token) is token
|
||||
|
||||
def test_coerce_set_unchanged(self) -> None:
|
||||
"""``set`` tokens must pass through unchanged."""
|
||||
token: set[str] = {"a", "b"}
|
||||
assert _coerce_command_token(token) is token
|
||||
|
||||
def test_coerce_unknown_type_stringified(self) -> None:
|
||||
"""Any other type must be converted to its ``str()`` representation."""
|
||||
|
||||
class CustomObj:
|
||||
def __str__(self) -> str:
|
||||
return "custom_repr"
|
||||
|
||||
assert _coerce_command_token(CustomObj()) == "custom_repr"
|
||||
|
||||
def test_coerce_none_stringified(self) -> None:
|
||||
"""``None`` must be stringified to ``"None"``."""
|
||||
assert _coerce_command_token(None) == "None"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Extended tests for Fail2BanClient.send
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFail2BanClientSend:
|
||||
"""Tests for :meth:`Fail2BanClient.send`."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_returns_response_on_success(self) -> None:
|
||||
"""``send()`` must return the response from the executor."""
|
||||
expected = [0, "OK"]
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
# asyncio.get_event_loop().run_in_executor is called inside send().
|
||||
# We patch it on the loop object returned by asyncio.get_event_loop().
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(return_value=expected)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
result = await client.send(["status"])
|
||||
|
||||
assert result == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reraises_connection_error(self) -> None:
|
||||
"""``send()`` must re-raise :class:`Fail2BanConnectionError`."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("unreachable", "/fake/fail2ban.sock")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["status"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_logs_warning_on_connection_error(self) -> None:
|
||||
"""``send()`` must log a warning when a connection error occurs."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("refused", "/fake/fail2ban.sock")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanConnectionError):
|
||||
await client.send(["ping"])
|
||||
|
||||
warning_calls = [
|
||||
c for c in mock_log.warning.call_args_list
|
||||
if c[0][0] == "fail2ban_connection_error"
|
||||
]
|
||||
assert len(warning_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_reraises_protocol_error(self) -> None:
|
||||
"""``send()`` must re-raise :class:`Fail2BanProtocolError`."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("bad pickle")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["status"])
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_logs_error_on_protocol_error(self) -> None:
|
||||
"""``send()`` must log an error when a protocol error occurs."""
|
||||
client = Fail2BanClient(socket_path="/fake/fail2ban.sock")
|
||||
|
||||
with patch("asyncio.get_event_loop") as mock_get_loop:
|
||||
mock_loop = AsyncMock()
|
||||
mock_loop.run_in_executor = AsyncMock(
|
||||
side_effect=Fail2BanProtocolError("corrupt response")
|
||||
)
|
||||
mock_get_loop.return_value = mock_loop
|
||||
|
||||
with patch("app.utils.fail2ban_client.log") as mock_log, pytest.raises(Fail2BanProtocolError):
|
||||
await client.send(["get", "sshd", "banned"])
|
||||
|
||||
error_calls = [
|
||||
c for c in mock_log.error.call_args_list
|
||||
if c[0][0] == "fail2ban_protocol_error"
|
||||
]
|
||||
assert len(error_calls) == 1
|
||||
|
||||
352
backend/tests/test_tasks/test_blocklist_import.py
Normal file
352
backend/tests/test_tasks/test_blocklist_import.py
Normal file
@@ -0,0 +1,352 @@
|
||||
"""Tests for the blocklist import background task.
|
||||
|
||||
Validates that :func:`~app.tasks.blocklist_import._run_import` correctly
|
||||
delegates to :func:`~app.services.blocklist_service.import_all`, handles
|
||||
unexpected exceptions gracefully, and that :func:`~app.tasks.blocklist_import._apply_schedule`
|
||||
registers the correct APScheduler trigger for each frequency preset.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.blocklist import ImportRunResult, ScheduleConfig, ScheduleFrequency
|
||||
from app.tasks.blocklist_import import JOB_ID, _apply_schedule, _run_import
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app(
|
||||
import_result: ImportRunResult | None = None,
|
||||
import_side_effect: Exception | None = None,
|
||||
) -> MagicMock:
|
||||
"""Build a minimal mock ``app`` for blocklist import task tests.
|
||||
|
||||
Args:
|
||||
import_result: Value returned by the mocked ``import_all`` call.
|
||||
import_side_effect: If provided, ``import_all`` raises this exception.
|
||||
|
||||
Returns:
|
||||
A :class:`unittest.mock.MagicMock` that mimics ``fastapi.FastAPI``.
|
||||
"""
|
||||
app = MagicMock()
|
||||
app.state.db = MagicMock()
|
||||
app.state.http_session = MagicMock()
|
||||
app.state.settings.fail2ban_socket = "/var/run/fail2ban/fail2ban.sock"
|
||||
return app
|
||||
|
||||
|
||||
def _make_import_result(
|
||||
total_imported: int = 50,
|
||||
total_skipped: int = 5,
|
||||
errors_count: int = 0,
|
||||
) -> ImportRunResult:
|
||||
"""Construct a minimal :class:`ImportRunResult` for testing.
|
||||
|
||||
Args:
|
||||
total_imported: Number of IPs successfully imported.
|
||||
total_skipped: Number of skipped entries.
|
||||
errors_count: Number of sources that encountered errors.
|
||||
|
||||
Returns:
|
||||
An :class:`ImportRunResult` with the given counters.
|
||||
"""
|
||||
return ImportRunResult(
|
||||
results=[],
|
||||
total_imported=total_imported,
|
||||
total_skipped=total_skipped,
|
||||
errors_count=errors_count,
|
||||
)
|
||||
|
||||
|
||||
def _make_scheduler(has_existing_job: bool = False) -> MagicMock:
|
||||
"""Build a mock APScheduler-like object.
|
||||
|
||||
Args:
|
||||
has_existing_job: Whether ``get_job(JOB_ID)`` should return a truthy value.
|
||||
|
||||
Returns:
|
||||
A :class:`unittest.mock.MagicMock` that mimics a scheduler.
|
||||
"""
|
||||
scheduler = MagicMock()
|
||||
scheduler.get_job.return_value = MagicMock() if has_existing_job else None
|
||||
return scheduler
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _run_import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunImport:
|
||||
"""Tests for :func:`~app.tasks.blocklist_import._run_import`."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_import_happy_path_calls_import_all(self) -> None:
|
||||
"""``_run_import`` must delegate to ``blocklist_service.import_all``."""
|
||||
app = _make_app()
|
||||
result = _make_import_result(total_imported=100, total_skipped=2, errors_count=0)
|
||||
|
||||
with patch(
|
||||
"app.tasks.blocklist_import.blocklist_service.import_all",
|
||||
new_callable=AsyncMock,
|
||||
return_value=result,
|
||||
) as mock_import_all:
|
||||
await _run_import(app)
|
||||
|
||||
mock_import_all.assert_awaited_once_with(
|
||||
app.state.db,
|
||||
app.state.http_session,
|
||||
app.state.settings.fail2ban_socket,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_import_logs_counters_on_success(self) -> None:
|
||||
"""``_run_import`` must emit a structured log event with import counters."""
|
||||
app = _make_app()
|
||||
result = _make_import_result(total_imported=42, total_skipped=3, errors_count=1)
|
||||
|
||||
with patch(
|
||||
"app.tasks.blocklist_import.blocklist_service.import_all",
|
||||
new_callable=AsyncMock,
|
||||
return_value=result,
|
||||
), patch("app.tasks.blocklist_import.log") as mock_log:
|
||||
await _run_import(app)
|
||||
|
||||
info_calls = [c for c in mock_log.info.call_args_list if c[0][0] == "blocklist_import_finished"]
|
||||
assert len(info_calls) == 1
|
||||
kwargs = info_calls[0][1]
|
||||
assert kwargs["total_imported"] == 42
|
||||
assert kwargs["total_skipped"] == 3
|
||||
assert kwargs["errors"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_import_logs_start_event(self) -> None:
|
||||
"""``_run_import`` must emit a ``blocklist_import_starting`` event."""
|
||||
app = _make_app()
|
||||
result = _make_import_result()
|
||||
|
||||
with patch(
|
||||
"app.tasks.blocklist_import.blocklist_service.import_all",
|
||||
new_callable=AsyncMock,
|
||||
return_value=result,
|
||||
), patch("app.tasks.blocklist_import.log") as mock_log:
|
||||
await _run_import(app)
|
||||
|
||||
start_calls = [c for c in mock_log.info.call_args_list if c[0][0] == "blocklist_import_starting"]
|
||||
assert len(start_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_import_handles_unexpected_exception(self) -> None:
|
||||
"""``_run_import`` must catch unexpected exceptions and log them."""
|
||||
app = _make_app()
|
||||
|
||||
with patch(
|
||||
"app.tasks.blocklist_import.blocklist_service.import_all",
|
||||
new_callable=AsyncMock,
|
||||
side_effect=RuntimeError("unexpected failure"),
|
||||
), patch("app.tasks.blocklist_import.log") as mock_log:
|
||||
# Must not raise — the task swallows unexpected errors.
|
||||
await _run_import(app)
|
||||
|
||||
mock_log.exception.assert_called_once_with("blocklist_import_unexpected_error")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _apply_schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestApplySchedule:
|
||||
"""Tests for :func:`~app.tasks.blocklist_import._apply_schedule`."""
|
||||
|
||||
def _make_app_with_scheduler(self, scheduler: Any) -> MagicMock:
|
||||
"""Return a mock ``app`` whose ``state.scheduler`` is *scheduler*.
|
||||
|
||||
Args:
|
||||
scheduler: Mock scheduler object.
|
||||
|
||||
Returns:
|
||||
Mock FastAPI application instance.
|
||||
"""
|
||||
app = MagicMock()
|
||||
app.state.scheduler = scheduler
|
||||
return app
|
||||
|
||||
def test_apply_schedule_daily_registers_cron_trigger(self) -> None:
|
||||
"""Daily frequency must register a ``"cron"`` trigger with hour and minute."""
|
||||
scheduler = _make_scheduler()
|
||||
app = self._make_app_with_scheduler(scheduler)
|
||||
config = ScheduleConfig(frequency=ScheduleFrequency.daily, hour=3, minute=0)
|
||||
|
||||
_apply_schedule(app, config)
|
||||
|
||||
scheduler.add_job.assert_called_once()
|
||||
_, kwargs = scheduler.add_job.call_args
|
||||
assert kwargs["trigger"] == "cron"
|
||||
assert kwargs["hour"] == 3
|
||||
assert kwargs["minute"] == 0
|
||||
assert "day_of_week" not in kwargs
|
||||
|
||||
def test_apply_schedule_hourly_registers_interval_trigger(self) -> None:
|
||||
"""Hourly frequency must register an ``"interval"`` trigger with correct hours."""
|
||||
scheduler = _make_scheduler()
|
||||
app = self._make_app_with_scheduler(scheduler)
|
||||
config = ScheduleConfig(frequency=ScheduleFrequency.hourly, interval_hours=6)
|
||||
|
||||
_apply_schedule(app, config)
|
||||
|
||||
scheduler.add_job.assert_called_once()
|
||||
_, kwargs = scheduler.add_job.call_args
|
||||
assert kwargs["trigger"] == "interval"
|
||||
assert kwargs["hours"] == 6
|
||||
|
||||
def test_apply_schedule_weekly_registers_cron_trigger_with_day(self) -> None:
|
||||
"""Weekly frequency must register a ``"cron"`` trigger including ``day_of_week``."""
|
||||
scheduler = _make_scheduler()
|
||||
app = self._make_app_with_scheduler(scheduler)
|
||||
config = ScheduleConfig(
|
||||
frequency=ScheduleFrequency.weekly,
|
||||
day_of_week=0,
|
||||
hour=4,
|
||||
minute=30,
|
||||
)
|
||||
|
||||
_apply_schedule(app, config)
|
||||
|
||||
scheduler.add_job.assert_called_once()
|
||||
_, kwargs = scheduler.add_job.call_args
|
||||
assert kwargs["trigger"] == "cron"
|
||||
assert kwargs["day_of_week"] == 0
|
||||
assert kwargs["hour"] == 4
|
||||
assert kwargs["minute"] == 30
|
||||
|
||||
def test_apply_schedule_removes_existing_job_before_adding(self) -> None:
|
||||
"""If a job with ``JOB_ID`` exists, it must be removed before a new one is added."""
|
||||
scheduler = _make_scheduler(has_existing_job=True)
|
||||
app = self._make_app_with_scheduler(scheduler)
|
||||
config = ScheduleConfig(frequency=ScheduleFrequency.daily)
|
||||
|
||||
_apply_schedule(app, config)
|
||||
|
||||
scheduler.remove_job.assert_called_once_with(JOB_ID)
|
||||
assert scheduler.remove_job.call_args_list.index(call(JOB_ID)) < len(
|
||||
scheduler.add_job.call_args_list
|
||||
)
|
||||
|
||||
def test_apply_schedule_skips_remove_when_no_existing_job(self) -> None:
|
||||
"""If no job exists, ``remove_job`` must not be called."""
|
||||
scheduler = _make_scheduler(has_existing_job=False)
|
||||
app = self._make_app_with_scheduler(scheduler)
|
||||
config = ScheduleConfig(frequency=ScheduleFrequency.daily)
|
||||
|
||||
_apply_schedule(app, config)
|
||||
|
||||
scheduler.remove_job.assert_not_called()
|
||||
|
||||
def test_apply_schedule_uses_stable_job_id(self) -> None:
|
||||
"""The registered job must use the module-level ``JOB_ID`` constant."""
|
||||
scheduler = _make_scheduler()
|
||||
app = self._make_app_with_scheduler(scheduler)
|
||||
config = ScheduleConfig(frequency=ScheduleFrequency.daily)
|
||||
|
||||
_apply_schedule(app, config)
|
||||
|
||||
_, kwargs = scheduler.add_job.call_args
|
||||
assert kwargs["id"] == JOB_ID
|
||||
|
||||
def test_apply_schedule_passes_app_in_kwargs(self) -> None:
|
||||
"""The scheduled job must receive ``app`` as a kwarg for state access."""
|
||||
scheduler = _make_scheduler()
|
||||
app = self._make_app_with_scheduler(scheduler)
|
||||
config = ScheduleConfig(frequency=ScheduleFrequency.daily)
|
||||
|
||||
_apply_schedule(app, config)
|
||||
|
||||
_, kwargs = scheduler.add_job.call_args
|
||||
assert kwargs["kwargs"] == {"app": app}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for register / reschedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegister:
|
||||
"""Tests for :func:`~app.tasks.blocklist_import.register`."""
|
||||
|
||||
def test_register_calls_apply_schedule_via_event_loop(self) -> None:
|
||||
"""``register`` must call ``_apply_schedule`` after reading the stored config."""
|
||||
import asyncio
|
||||
|
||||
from app.tasks.blocklist_import import register
|
||||
|
||||
app = MagicMock()
|
||||
app.state.db = MagicMock()
|
||||
app.state.scheduler = MagicMock()
|
||||
app.state.scheduler.get_job.return_value = None
|
||||
|
||||
config = ScheduleConfig(frequency=ScheduleFrequency.daily, hour=3, minute=0)
|
||||
|
||||
with patch(
|
||||
"app.tasks.blocklist_import.blocklist_service.get_schedule",
|
||||
new_callable=AsyncMock,
|
||||
return_value=config,
|
||||
), patch("app.tasks.blocklist_import._apply_schedule") as mock_apply:
|
||||
# Use a fresh event loop to avoid interference from pytest-asyncio.
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
with patch("asyncio.get_event_loop", return_value=loop):
|
||||
register(app)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
mock_apply.assert_called_once_with(app, config)
|
||||
|
||||
def test_register_falls_back_to_ensure_future_on_runtime_error(self) -> None:
|
||||
"""When ``run_until_complete`` raises ``RuntimeError``, ``ensure_future`` is used."""
|
||||
from app.tasks.blocklist_import import register
|
||||
|
||||
app = MagicMock()
|
||||
app.state.db = MagicMock()
|
||||
app.state.scheduler = MagicMock()
|
||||
|
||||
config = ScheduleConfig(frequency=ScheduleFrequency.daily)
|
||||
|
||||
mock_loop = MagicMock()
|
||||
mock_loop.run_until_complete.side_effect = RuntimeError("already running")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.tasks.blocklist_import.blocklist_service.get_schedule",
|
||||
new_callable=AsyncMock,
|
||||
return_value=config,
|
||||
),
|
||||
patch("asyncio.get_event_loop", return_value=mock_loop),
|
||||
patch("asyncio.ensure_future") as mock_ensure_future,
|
||||
):
|
||||
register(app)
|
||||
|
||||
mock_ensure_future.assert_called_once()
|
||||
|
||||
|
||||
class TestReschedule:
|
||||
"""Tests for :func:`~app.tasks.blocklist_import.reschedule`."""
|
||||
|
||||
def test_reschedule_calls_ensure_future(self) -> None:
|
||||
"""``reschedule`` must schedule the re-registration with ``asyncio.ensure_future``."""
|
||||
from app.tasks.blocklist_import import reschedule
|
||||
|
||||
app = MagicMock()
|
||||
app.state.db = MagicMock()
|
||||
app.state.scheduler = MagicMock()
|
||||
|
||||
with patch("asyncio.ensure_future") as mock_ensure_future:
|
||||
reschedule(app)
|
||||
|
||||
mock_ensure_future.assert_called_once()
|
||||
136
backend/tests/test_tasks/test_geo_cache_flush.py
Normal file
136
backend/tests/test_tasks/test_geo_cache_flush.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Tests for the geo cache flush background task.
|
||||
|
||||
Validates that :func:`~app.tasks.geo_cache_flush._run_flush` correctly
|
||||
delegates to :func:`~app.services.geo_service.flush_dirty` and only logs
|
||||
when entries were actually flushed, and that
|
||||
:func:`~app.tasks.geo_cache_flush.register` configures the APScheduler job
|
||||
with the correct interval and stable job ID.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.tasks.geo_cache_flush import GEO_FLUSH_INTERVAL, JOB_ID, _run_flush, register
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app(flush_count: int = 0) -> MagicMock:
|
||||
"""Build a minimal mock ``app`` for geo cache flush task tests.
|
||||
|
||||
Args:
|
||||
flush_count: The value returned by the mocked ``flush_dirty`` call.
|
||||
|
||||
Returns:
|
||||
A :class:`unittest.mock.MagicMock` that mimics ``fastapi.FastAPI``.
|
||||
"""
|
||||
app = MagicMock()
|
||||
app.state.db = MagicMock()
|
||||
app.state.scheduler = MagicMock()
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _run_flush
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunFlush:
|
||||
"""Tests for :func:`~app.tasks.geo_cache_flush._run_flush`."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_flush_calls_flush_dirty_with_db(self) -> None:
|
||||
"""``_run_flush`` must call ``geo_service.flush_dirty`` with ``app.state.db``."""
|
||||
app = _make_app()
|
||||
|
||||
with patch(
|
||||
"app.tasks.geo_cache_flush.geo_service.flush_dirty",
|
||||
new_callable=AsyncMock,
|
||||
return_value=0,
|
||||
) as mock_flush:
|
||||
await _run_flush(app)
|
||||
|
||||
mock_flush.assert_awaited_once_with(app.state.db)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_flush_logs_when_entries_flushed(self) -> None:
|
||||
"""``_run_flush`` must emit a debug log when ``flush_dirty`` returns > 0."""
|
||||
app = _make_app()
|
||||
|
||||
with patch(
|
||||
"app.tasks.geo_cache_flush.geo_service.flush_dirty",
|
||||
new_callable=AsyncMock,
|
||||
return_value=15,
|
||||
), patch("app.tasks.geo_cache_flush.log") as mock_log:
|
||||
await _run_flush(app)
|
||||
|
||||
debug_calls = [c for c in mock_log.debug.call_args_list if c[0][0] == "geo_cache_flush_ran"]
|
||||
assert len(debug_calls) == 1
|
||||
assert debug_calls[0][1]["flushed"] == 15
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_flush_does_not_log_when_nothing_to_flush(self) -> None:
|
||||
"""``_run_flush`` must not emit any log when ``flush_dirty`` returns 0."""
|
||||
app = _make_app()
|
||||
|
||||
with patch(
|
||||
"app.tasks.geo_cache_flush.geo_service.flush_dirty",
|
||||
new_callable=AsyncMock,
|
||||
return_value=0,
|
||||
), patch("app.tasks.geo_cache_flush.log") as mock_log:
|
||||
await _run_flush(app)
|
||||
|
||||
debug_calls = [c for c in mock_log.debug.call_args_list if c[0][0] == "geo_cache_flush_ran"]
|
||||
assert debug_calls == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for register
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegister:
|
||||
"""Tests for :func:`~app.tasks.geo_cache_flush.register`."""
|
||||
|
||||
def test_register_adds_interval_job_to_scheduler(self) -> None:
|
||||
"""``register`` must add a job with an ``"interval"`` trigger."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
app.state.scheduler.add_job.assert_called_once()
|
||||
_, kwargs = app.state.scheduler.add_job.call_args
|
||||
assert kwargs["trigger"] == "interval"
|
||||
assert kwargs["seconds"] == GEO_FLUSH_INTERVAL
|
||||
|
||||
def test_register_uses_stable_job_id(self) -> None:
|
||||
"""``register`` must use the module-level ``JOB_ID`` constant."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
_, kwargs = app.state.scheduler.add_job.call_args
|
||||
assert kwargs["id"] == JOB_ID
|
||||
|
||||
def test_register_sets_replace_existing(self) -> None:
|
||||
"""``register`` must use ``replace_existing=True`` to avoid duplicate jobs."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
_, kwargs = app.state.scheduler.add_job.call_args
|
||||
assert kwargs["replace_existing"] is True
|
||||
|
||||
def test_register_passes_app_in_kwargs(self) -> None:
|
||||
"""The scheduled job must receive ``app`` as a kwarg for state access."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
_, kwargs = app.state.scheduler.add_job.call_args
|
||||
assert kwargs["kwargs"] == {"app": app}
|
||||
238
backend/tests/test_tasks/test_health_check.py
Normal file
238
backend/tests/test_tasks/test_health_check.py
Normal file
@@ -0,0 +1,238 @@
|
||||
"""Tests for the health-check background task.
|
||||
|
||||
Validates that :func:`~app.tasks.health_check._run_probe` correctly stores
|
||||
the probe result on ``app.state.server_status``, logs online/offline
|
||||
transitions, and that :func:`~app.tasks.health_check.register` configures
|
||||
the scheduler and primes the initial status.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models.server import ServerStatus
|
||||
from app.tasks.health_check import HEALTH_CHECK_INTERVAL, _run_probe, register
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_app(prev_online: bool = False) -> MagicMock:
|
||||
"""Build a minimal mock ``app`` for health-check task tests.
|
||||
|
||||
Args:
|
||||
prev_online: Whether the previous ``server_status`` was online.
|
||||
|
||||
Returns:
|
||||
A :class:`unittest.mock.MagicMock` that mimics ``fastapi.FastAPI``.
|
||||
"""
|
||||
app = MagicMock()
|
||||
app.state.settings.fail2ban_socket = "/var/run/fail2ban/fail2ban.sock"
|
||||
app.state.server_status = ServerStatus(online=prev_online)
|
||||
app.state.scheduler = MagicMock()
|
||||
return app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for _run_probe
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunProbe:
|
||||
"""Tests for :func:`~app.tasks.health_check._run_probe`."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_probe_updates_server_status(self) -> None:
|
||||
"""``_run_probe`` must store the probe result on ``app.state.server_status``."""
|
||||
app = _make_app(prev_online=False)
|
||||
new_status = ServerStatus(online=True, version="0.11.2", active_jails=3)
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=new_status,
|
||||
):
|
||||
await _run_probe(app)
|
||||
|
||||
assert app.state.server_status is new_status
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_probe_logs_came_online_transition(self) -> None:
|
||||
"""When fail2ban comes online, ``"fail2ban_came_online"`` must be logged."""
|
||||
app = _make_app(prev_online=False)
|
||||
new_status = ServerStatus(online=True, version="0.11.2", active_jails=2)
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=new_status,
|
||||
), patch("app.tasks.health_check.log") as mock_log:
|
||||
await _run_probe(app)
|
||||
|
||||
online_calls = [c for c in mock_log.info.call_args_list if c[0][0] == "fail2ban_came_online"]
|
||||
assert len(online_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_probe_logs_went_offline_transition(self) -> None:
|
||||
"""When fail2ban goes offline, ``"fail2ban_went_offline"`` must be logged."""
|
||||
app = _make_app(prev_online=True)
|
||||
new_status = ServerStatus(online=False)
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=new_status,
|
||||
), patch("app.tasks.health_check.log") as mock_log:
|
||||
await _run_probe(app)
|
||||
|
||||
offline_calls = [c for c in mock_log.warning.call_args_list if c[0][0] == "fail2ban_went_offline"]
|
||||
assert len(offline_calls) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_probe_stable_online_no_transition_log(self) -> None:
|
||||
"""When status stays online, no transition events must be emitted."""
|
||||
app = _make_app(prev_online=True)
|
||||
new_status = ServerStatus(online=True, version="0.11.2", active_jails=1)
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=new_status,
|
||||
), patch("app.tasks.health_check.log") as mock_log:
|
||||
await _run_probe(app)
|
||||
|
||||
transition_calls = [
|
||||
c
|
||||
for c in mock_log.info.call_args_list
|
||||
if c[0][0] in ("fail2ban_came_online", "fail2ban_went_offline")
|
||||
]
|
||||
transition_calls += [
|
||||
c
|
||||
for c in mock_log.warning.call_args_list
|
||||
if c[0][0] in ("fail2ban_came_online", "fail2ban_went_offline")
|
||||
]
|
||||
assert transition_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_probe_stable_offline_no_transition_log(self) -> None:
|
||||
"""When status stays offline, no transition events must be emitted."""
|
||||
app = _make_app(prev_online=False)
|
||||
new_status = ServerStatus(online=False)
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=new_status,
|
||||
), patch("app.tasks.health_check.log") as mock_log:
|
||||
await _run_probe(app)
|
||||
|
||||
transition_calls = [
|
||||
c
|
||||
for c in mock_log.info.call_args_list
|
||||
if c[0][0] == "fail2ban_came_online"
|
||||
]
|
||||
transition_calls += [
|
||||
c
|
||||
for c in mock_log.warning.call_args_list
|
||||
if c[0][0] == "fail2ban_went_offline"
|
||||
]
|
||||
assert transition_calls == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_probe_uses_socket_path_from_settings(self) -> None:
|
||||
"""``_run_probe`` must pass the socket path from ``app.state.settings``."""
|
||||
expected_socket = "/custom/fail2ban.sock"
|
||||
app = _make_app()
|
||||
app.state.settings.fail2ban_socket = expected_socket
|
||||
new_status = ServerStatus(online=False)
|
||||
|
||||
with patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=new_status,
|
||||
) as mock_probe:
|
||||
await _run_probe(app)
|
||||
|
||||
mock_probe.assert_awaited_once_with(expected_socket)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_probe_uses_default_offline_status_when_state_missing(self) -> None:
|
||||
"""``_run_probe`` must handle missing ``server_status`` on first run."""
|
||||
app = _make_app()
|
||||
# Simulate first run: no previous server_status attribute set yet.
|
||||
del app.state.server_status
|
||||
new_status = ServerStatus(online=True, version="0.11.2", active_jails=0)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"app.tasks.health_check.health_service.probe",
|
||||
new_callable=AsyncMock,
|
||||
return_value=new_status,
|
||||
),
|
||||
patch("app.tasks.health_check.log"),
|
||||
):
|
||||
# Must not raise even with no prior status.
|
||||
await _run_probe(app)
|
||||
|
||||
assert app.state.server_status is new_status
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests for register
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRegister:
|
||||
"""Tests for :func:`~app.tasks.health_check.register`."""
|
||||
|
||||
def test_register_adds_interval_job_to_scheduler(self) -> None:
|
||||
"""``register`` must add a job with an ``"interval"`` trigger."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
app.state.scheduler.add_job.assert_called_once()
|
||||
_, kwargs = app.state.scheduler.add_job.call_args
|
||||
assert kwargs["trigger"] == "interval"
|
||||
assert kwargs["seconds"] == HEALTH_CHECK_INTERVAL
|
||||
|
||||
def test_register_primes_offline_server_status(self) -> None:
|
||||
"""``register`` must set an initial offline status before the first probe fires."""
|
||||
app = _make_app()
|
||||
# Reset any value set by _make_app.
|
||||
del app.state.server_status
|
||||
|
||||
register(app)
|
||||
|
||||
assert isinstance(app.state.server_status, ServerStatus)
|
||||
assert app.state.server_status.online is False
|
||||
|
||||
def test_register_uses_stable_job_id(self) -> None:
|
||||
"""``register`` must register the job under the stable id ``"health_check"``."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
_, kwargs = app.state.scheduler.add_job.call_args
|
||||
assert kwargs["id"] == "health_check"
|
||||
|
||||
def test_register_sets_replace_existing(self) -> None:
|
||||
"""``register`` must use ``replace_existing=True`` to avoid duplicate jobs."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
_, kwargs = app.state.scheduler.add_job.call_args
|
||||
assert kwargs["replace_existing"] is True
|
||||
|
||||
def test_register_passes_app_in_kwargs(self) -> None:
|
||||
"""The scheduled job must receive ``app`` as a kwarg for state access."""
|
||||
app = _make_app()
|
||||
|
||||
register(app)
|
||||
|
||||
_, kwargs = app.state.scheduler.add_job.call_args
|
||||
assert kwargs["kwargs"] == {"app": app}
|
||||
Reference in New Issue
Block a user