Files
BanGUI/backend/tests/test_services/test_blocklist_service.py
Lukas 6e76711940 Fix blocklist import: detect UnknownJailException and abort early
_is_not_found_error in jail_service did not match the concatenated form
'unknownjailexception' that fail2ban produces when it serialises
UnknownJailException, so JailOperationError was raised instead of
JailNotFoundError and every ban attempt in the import loop failed
individually, skipping all 27 840 IPs before returning an error.

Two changes:
- Add 'unknownjail' to the phrase list in _is_not_found_error so that
  UnknownJailException is correctly mapped to JailNotFoundError.
- In blocklist_service.import_source, catch JailNotFoundError explicitly
  and break out of the loop immediately with a warning log instead of
  retrying on every IP.
2026-03-01 21:02:37 +01:00

261 lines
11 KiB
Python

"""Tests for blocklist_service — source CRUD, preview, import, schedule."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
from app.db import init_db
from app.models.blocklist import BlocklistSource, ScheduleConfig, ScheduleFrequency
from app.services import blocklist_service
# ---------------------------------------------------------------------------
# Fixture
# ---------------------------------------------------------------------------
@pytest.fixture
async def db(tmp_path: Path) -> aiosqlite.Connection:  # type: ignore[misc]
    """Yield an initialised aiosqlite connection for a single test.

    The database file ``bl_svc.db`` lives under pytest's ``tmp_path``. Rows
    are returned as ``aiosqlite.Row`` and ``init_db`` has been awaited on the
    connection before it is handed to the test.

    The connection is managed with ``async with`` so that it is closed after
    the test finishes *and* when ``init_db`` itself raises; without that, a
    failing ``init_db`` would leave the connection open.
    """
    async with aiosqlite.connect(str(tmp_path / "bl_svc.db")) as conn:
        conn.row_factory = aiosqlite.Row
        await init_db(conn)
        yield conn
def _make_session(text: str, status: int = 200) -> MagicMock:
    """Return a fake aiohttp-style session whose ``get`` serves *text*.

    ``session.get(...)`` returns an async context manager. Entering it yields
    a response mock with:

    * ``status`` set to *status*,
    * an awaitable ``text()`` returning *text*,
    * an awaitable ``content.read()`` returning *text* encoded to bytes.

    Leaving the context manager returns ``False``, so exceptions raised inside
    the ``async with`` body are not suppressed.
    """
    # Body stream: only ``read`` is given an explicit return value.
    body_stream = AsyncMock()
    body_stream.read = AsyncMock(return_value=text.encode())

    response = AsyncMock(
        status=status,
        text=AsyncMock(return_value=text),
        content=body_stream,
    )

    # Async context manager handed back by ``session.get``.
    request_ctx = AsyncMock()
    request_ctx.__aenter__ = AsyncMock(return_value=response)
    request_ctx.__aexit__ = AsyncMock(return_value=False)

    return MagicMock(get=MagicMock(return_value=request_ctx))
# ---------------------------------------------------------------------------
# Source CRUD
# ---------------------------------------------------------------------------
class TestSourceCRUD:
    """Create / read / update / delete tests for blocklist sources."""

    async def test_create_and_get(self, db: aiosqlite.Connection) -> None:
        """A newly created source can be fetched again by its id."""
        created = await blocklist_service.create_source(db, "Test", "https://t.test/")

        assert isinstance(created, BlocklistSource)
        assert created.name == "Test"
        assert created.enabled is True

        reloaded = await blocklist_service.get_source(db, created.id)
        assert reloaded is not None
        assert reloaded.id == created.id

    async def test_get_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """Looking up an unknown id yields None."""
        missing = await blocklist_service.get_source(db, 9999)
        assert missing is None

    async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
        """With nothing stored, list_sources gives back an empty list."""
        listed = await blocklist_service.list_sources(db)
        assert listed == []

    async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
        """Every source that was created shows up in list_sources."""
        for name, url in (("A", "https://a.test/"), ("B", "https://b.test/")):
            await blocklist_service.create_source(db, name, url)

        listed = await blocklist_service.list_sources(db)
        assert len(listed) == 2

    async def test_update_source_fields(self, db: aiosqlite.Connection) -> None:
        """Fields passed to update_source are reflected in the returned source."""
        original = await blocklist_service.create_source(
            db, "Original", "https://orig.test/"
        )
        changed = await blocklist_service.update_source(
            db, original.id, name="Updated", enabled=False
        )

        assert changed is not None
        assert changed.name == "Updated"
        assert changed.enabled is False

    async def test_update_source_missing_returns_none(
        self, db: aiosqlite.Connection
    ) -> None:
        """Updating an unknown id yields None."""
        outcome = await blocklist_service.update_source(db, 9999, name="Ghost")
        assert outcome is None

    async def test_delete_source(self, db: aiosqlite.Connection) -> None:
        """Deleting an existing source returns True and the source is gone."""
        doomed = await blocklist_service.create_source(db, "Del", "https://del.test/")

        was_deleted = await blocklist_service.delete_source(db, doomed.id)
        assert was_deleted is True

        leftover = await blocklist_service.get_source(db, doomed.id)
        assert leftover is None

    async def test_delete_source_missing_returns_false(
        self, db: aiosqlite.Connection
    ) -> None:
        """Deleting an unknown id returns False."""
        was_deleted = await blocklist_service.delete_source(db, 9999)
        assert was_deleted is False
# ---------------------------------------------------------------------------
# Preview
# ---------------------------------------------------------------------------
class TestPreview:
    """preview_source tests driven by a mocked HTTP session."""

    async def test_preview_valid_ips(self) -> None:
        """Valid IPs in the download are counted; the malformed line is skipped."""
        body = "1.2.3.4\n5.6.7.8\n# comment\ninvalid\n9.0.0.1\n"
        preview = await blocklist_service.preview_source(
            "https://test.test/ips.txt", _make_session(body)
        )

        assert preview.valid_count == 3
        assert preview.skipped_count == 1  # "invalid"
        assert "1.2.3.4" in preview.entries

    async def test_preview_http_error_raises(self) -> None:
        """A non-200 response surfaces as ValueError mentioning the status."""
        failing_session = _make_session("", status=404)
        with pytest.raises(ValueError, match="HTTP 404"):
            await blocklist_service.preview_source("https://bad.test/", failing_session)

    async def test_preview_limits_entries(self) -> None:
        """entries is capped by sample_lines while valid_count covers everything."""
        addresses = [f"1.2.3.{i}" for i in range(50)]
        fake_session = _make_session("\n".join(addresses))

        preview = await blocklist_service.preview_source(
            "https://test.test/",
            fake_session,
            sample_lines=10,
        )

        assert len(preview.entries) <= 10
        assert preview.valid_count == 50
# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------
class TestImport:
    """import_source / import_all tests with ``jail_service.ban_ip`` patched out."""

    async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
        """Each valid IP in the list results in exactly one ban_ip call."""
        source = await blocklist_service.create_source(
            db, "Import Test", "https://t.test/"
        )
        session = _make_session("1.2.3.4\n5.6.7.8\n# skip me\n")

        ban_patch = patch("app.services.jail_service.ban_ip", new_callable=AsyncMock)
        with ban_patch as mock_ban:
            outcome = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        assert outcome.ips_imported == 2
        assert outcome.ips_skipped == 0
        assert outcome.error is None
        assert mock_ban.call_count == 2

    async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
        """CIDR ranges are counted as skipped (fail2ban expects individual IPs)."""
        source = await blocklist_service.create_source(
            db, "CIDR Test", "https://c.test/"
        )
        session = _make_session("1.2.3.4\n10.0.0.0/24\n")

        with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
            outcome = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        assert outcome.ips_imported == 1
        assert outcome.ips_skipped == 1

    async def test_import_source_records_download_error(
        self, db: aiosqlite.Connection
    ) -> None:
        """An HTTP failure yields zero imported IPs and a populated error."""
        session = _make_session("", status=503)
        source = await blocklist_service.create_source(
            db, "Err Source", "https://err.test/"
        )

        outcome = await blocklist_service.import_source(
            source, session, "/tmp/fake.sock", db
        )

        assert outcome.ips_imported == 0
        assert outcome.error is not None

    async def test_import_source_aborts_on_jail_not_found(
        self, db: aiosqlite.Connection
    ) -> None:
        """import_source stops at the first JailNotFoundError and reports it,
        rather than quietly skipping every remaining IP."""
        from app.services.jail_service import JailNotFoundError

        session = _make_session("\n".join(f"1.2.3.{i}" for i in range(100)))
        source = await blocklist_service.create_source(
            db, "Missing Jail", "https://mj.test/"
        )

        # One entry per ban attempt; the length is the number of calls made.
        attempted: list[str] = []

        async def _raise_jail_not_found(socket_path: str, jail: str, ip: str) -> None:
            attempted.append(ip)
            raise JailNotFoundError(jail)

        with patch(
            "app.services.jail_service.ban_ip",
            side_effect=_raise_jail_not_found,
        ):
            outcome = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        # Must abort after the first JailNotFoundError — only one ban attempt.
        assert len(attempted) == 1
        assert outcome.ips_imported == 0
        assert outcome.error is not None
        message = outcome.error
        assert "not found" in message.lower() or "blocklist-import" in message

    async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
        """import_all only reports results for sources that are enabled."""
        await blocklist_service.create_source(db, "S1", "https://s1.test/")
        await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)

        session = _make_session("1.2.3.4\n5.6.7.8\n")
        with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
            summary = await blocklist_service.import_all(db, session, "/tmp/fake.sock")

        # Only S1 is enabled, S2 is disabled.
        assert len(summary.results) == 1
        assert summary.results[0].source_url == "https://s1.test/"
# ---------------------------------------------------------------------------
# Schedule
# ---------------------------------------------------------------------------
class TestSchedule:
    """Schedule configuration persistence and schedule-info tests."""

    async def test_get_schedule_default(self, db: aiosqlite.Connection) -> None:
        """With nothing saved, get_schedule reports a daily run at hour 3."""
        default_cfg = await blocklist_service.get_schedule(db)

        assert default_cfg.frequency == ScheduleFrequency.daily
        assert default_cfg.hour == 3

    async def test_set_and_get_round_trip(self, db: aiosqlite.Connection) -> None:
        """A config stored via set_schedule is read back by get_schedule."""
        wanted = ScheduleConfig(
            frequency=ScheduleFrequency.hourly,
            interval_hours=6,
            hour=0,
            minute=0,
            day_of_week=0,
        )
        await blocklist_service.set_schedule(db, wanted)

        stored = await blocklist_service.get_schedule(db)
        assert stored.frequency == ScheduleFrequency.hourly
        assert stored.interval_hours == 6

    async def test_get_schedule_info_no_log(self, db: aiosqlite.Connection) -> None:
        """get_schedule_info(db, None) has no last_run_at and no next_run_at
        when no import log exists."""
        schedule_info = await blocklist_service.get_schedule_info(db, None)

        assert schedule_info.last_run_at is None
        assert schedule_info.next_run_at is None