`_is_not_found_error` in `jail_service` did not match the concatenated form 'unknownjailexception' that fail2ban produces when it serialises `UnknownJailException`. As a result, `JailOperationError` was raised instead of `JailNotFoundError`, every ban attempt in the import loop failed individually, and all 27 840 IPs were skipped before an error was returned. Two changes: (1) Add 'unknownjail' to the phrase list in `_is_not_found_error` so that `UnknownJailException` is correctly mapped to `JailNotFoundError`. (2) In `blocklist_service.import_source`, catch `JailNotFoundError` explicitly and break out of the loop immediately with a warning log instead of attempting a ban for every remaining IP.
261 lines
11 KiB
Python
261 lines
11 KiB
Python
"""Tests for blocklist_service — source CRUD, preview, import, schedule."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
|
|
from app.db import init_db
|
|
from app.models.blocklist import BlocklistSource, ScheduleConfig, ScheduleFrequency
|
|
from app.services import blocklist_service
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixture
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture
async def db(tmp_path: Path) -> aiosqlite.Connection:  # type: ignore[misc]
    """Provide an initialised aiosqlite connection.

    Opens a database file named ``bl_svc.db`` under pytest's ``tmp_path``,
    sets ``aiosqlite.Row`` as the row factory, runs ``init_db`` on the
    connection and yields it to the test.

    Yields:
        The open connection; it is closed during fixture teardown.
    """
    conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_svc.db"))
    try:
        conn.row_factory = aiosqlite.Row
        await init_db(conn)
        yield conn
    finally:
        # Close even when init_db raises, so a failed setup does not leak
        # the open connection.
        await conn.close()
|
|
|
|
|
|
def _make_session(text: str, status: int = 200) -> MagicMock:
    """Create a fake aiohttp session whose ``get`` serves *text*.

    Args:
        text: Body returned by ``await resp.text()``; its encoded bytes are
            returned by ``await resp.content.read()``.
        status: Value exposed as ``resp.status``.

    Returns:
        A ``MagicMock`` whose ``get(...)`` returns an async context manager
        that yields the mocked response.
    """
    # Raw-bytes view of the body, reachable as ``resp.content.read()``.
    body_stream = AsyncMock()
    body_stream.read = AsyncMock(return_value=text.encode())

    response = AsyncMock()
    response.status = status
    response.text = AsyncMock(return_value=text)
    response.content = body_stream

    # ``session.get(...)`` is used as ``async with``, so it must return an
    # object with awaitable __aenter__/__aexit__.
    request_ctx = AsyncMock()
    request_ctx.__aenter__ = AsyncMock(return_value=response)
    request_ctx.__aexit__ = AsyncMock(return_value=False)

    fake_session = MagicMock()
    fake_session.get = MagicMock(return_value=request_ctx)
    return fake_session
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Source CRUD
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSourceCRUD:
    """Create / read / update / delete behaviour of blocklist sources."""

    async def test_create_and_get(self, db: aiosqlite.Connection) -> None:
        """A source stored by create_source can be read back via get_source."""
        created = await blocklist_service.create_source(db, "Test", "https://t.test/")
        assert isinstance(created, BlocklistSource)
        assert created.name == "Test"
        assert created.enabled is True

        looked_up = await blocklist_service.get_source(db, created.id)
        assert looked_up is not None
        assert looked_up.id == created.id

    async def test_get_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """get_source yields None when the id does not exist."""
        missing = await blocklist_service.get_source(db, 9999)
        assert missing is None

    async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
        """list_sources yields an empty list on a fresh database."""
        listed = await blocklist_service.list_sources(db)
        assert listed == []

    async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
        """list_sources yields one entry per created source."""
        await blocklist_service.create_source(db, "A", "https://a.test/")
        await blocklist_service.create_source(db, "B", "https://b.test/")

        listed = await blocklist_service.list_sources(db)
        assert len(listed) == 2

    async def test_update_source_fields(self, db: aiosqlite.Connection) -> None:
        """update_source applies the keyword fields it is given."""
        original = await blocklist_service.create_source(
            db, "Original", "https://orig.test/"
        )
        changed = await blocklist_service.update_source(
            db, original.id, name="Updated", enabled=False
        )
        assert changed is not None
        assert changed.name == "Updated"
        assert changed.enabled is False

    async def test_update_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """update_source yields None when the id does not exist."""
        outcome = await blocklist_service.update_source(db, 9999, name="Ghost")
        assert outcome is None

    async def test_delete_source(self, db: aiosqlite.Connection) -> None:
        """delete_source reports True and the source is gone afterwards."""
        doomed = await blocklist_service.create_source(db, "Del", "https://del.test/")
        was_deleted = await blocklist_service.delete_source(db, doomed.id)
        assert was_deleted is True

        remaining = await blocklist_service.get_source(db, doomed.id)
        assert remaining is None

    async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """delete_source reports False when the id does not exist."""
        was_deleted = await blocklist_service.delete_source(db, 9999)
        assert was_deleted is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Preview
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPreview:
    """Behaviour of preview_source against a mocked HTTP session."""

    async def test_preview_valid_ips(self) -> None:
        """Valid IPs in the download are counted and listed in the entries."""
        body = "1.2.3.4\n5.6.7.8\n# comment\ninvalid\n9.0.0.1\n"
        preview = await blocklist_service.preview_source(
            "https://test.test/ips.txt", _make_session(body)
        )
        assert preview.valid_count == 3
        assert preview.skipped_count == 1  # "invalid"
        assert "1.2.3.4" in preview.entries

    async def test_preview_http_error_raises(self) -> None:
        """A non-200 response makes preview_source raise ValueError."""
        failing_session = _make_session("", status=404)
        with pytest.raises(ValueError, match="HTTP 404"):
            await blocklist_service.preview_source("https://bad.test/", failing_session)

    async def test_preview_limits_entries(self) -> None:
        """entries is capped by sample_lines while valid_count covers everything."""
        fifty_ips = "\n".join(f"1.2.3.{i}" for i in range(50))
        preview = await blocklist_service.preview_source(
            "https://test.test/", _make_session(fifty_ips), sample_lines=10
        )
        assert len(preview.entries) <= 10
        assert preview.valid_count == 50
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Import
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestImport:
    """Behaviour of import_source and import_all.

    Tests that reach the ban step patch ``app.services.jail_service.ban_ip``
    with an ``AsyncMock`` so no fail2ban socket is ever contacted.
    """

    async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
        """import_source calls ban_ip for every valid IP in the blocklist."""
        content = "1.2.3.4\n5.6.7.8\n# skip me\n"
        session = _make_session(content)

        source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")

        with patch(
            "app.services.jail_service.ban_ip", new_callable=AsyncMock
        ) as mock_ban:
            result = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        assert result.ips_imported == 2
        assert result.ips_skipped == 0
        assert result.error is None
        assert mock_ban.call_count == 2

    async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
        """import_source skips CIDR ranges (fail2ban expects individual IPs)."""
        content = "1.2.3.4\n10.0.0.0/24\n"
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")

        with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
            result = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        assert result.ips_imported == 1
        assert result.ips_skipped == 1

    async def test_import_source_records_download_error(self, db: aiosqlite.Connection) -> None:
        """import_source records an error and returns 0 imported on HTTP failure."""
        session = _make_session("", status=503)
        source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")

        result = await blocklist_service.import_source(
            source, session, "/tmp/fake.sock", db
        )

        assert result.ips_imported == 0
        assert result.error is not None

    async def test_import_source_aborts_on_jail_not_found(self, db: aiosqlite.Connection) -> None:
        """import_source aborts immediately and records an error when the target jail
        does not exist in fail2ban instead of silently skipping every IP."""
        from app.services.jail_service import JailNotFoundError

        content = "\n".join(f"1.2.3.{i}" for i in range(100))
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "Missing Jail", "https://mj.test/")

        async def _raise_jail_not_found(socket_path: str, jail: str, ip: str) -> None:
            # Echo back whichever jail name import_source passed in.
            raise JailNotFoundError(jail)

        # Explicit AsyncMock, matching the sibling tests, so the patch does not
        # rely on patch() auto-detecting that ban_ip is a coroutine function.
        with patch(
            "app.services.jail_service.ban_ip",
            new_callable=AsyncMock,
            side_effect=_raise_jail_not_found,
        ) as mock_ban:
            result = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        # Must abort after the first JailNotFoundError — only one ban attempt.
        assert mock_ban.await_count == 1
        assert result.ips_imported == 0
        assert result.error is not None
        assert "not found" in result.error.lower() or "blocklist-import" in result.error

    async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
        """import_all aggregates results across all enabled sources."""
        await blocklist_service.create_source(db, "S1", "https://s1.test/")
        # Disabled source: must not show up in the aggregated results.
        await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)

        content = "1.2.3.4\n5.6.7.8\n"
        session = _make_session(content)

        with patch(
            "app.services.jail_service.ban_ip", new_callable=AsyncMock
        ):
            result = await blocklist_service.import_all(db, session, "/tmp/fake.sock")

        # Only S1 is enabled, S2 is disabled.
        assert len(result.results) == 1
        assert result.results[0].source_url == "https://s1.test/"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Schedule
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestSchedule:
    """Persistence and reporting of the blocklist import schedule."""

    async def test_get_schedule_default(self, db: aiosqlite.Connection) -> None:
        """With nothing saved, get_schedule reports a daily schedule at hour 3."""
        default_cfg = await blocklist_service.get_schedule(db)
        assert default_cfg.frequency == ScheduleFrequency.daily
        assert default_cfg.hour == 3

    async def test_set_and_get_round_trip(self, db: aiosqlite.Connection) -> None:
        """A config stored with set_schedule is returned by get_schedule."""
        saved = ScheduleConfig(
            frequency=ScheduleFrequency.hourly,
            interval_hours=6,
            hour=0,
            minute=0,
            day_of_week=0,
        )
        await blocklist_service.set_schedule(db, saved)

        restored = await blocklist_service.get_schedule(db)
        assert restored.frequency == ScheduleFrequency.hourly
        assert restored.interval_hours == 6

    async def test_get_schedule_info_no_log(self, db: aiosqlite.Connection) -> None:
        """With no import log and None as second argument, both run times are None."""
        schedule_info = await blocklist_service.get_schedule_info(db, None)
        assert schedule_info.last_run_at is None
        assert schedule_info.next_run_at is None
|