Files
BanGUI/backend/tests/test_services/test_blocklist_service.py

555 lines
23 KiB
Python

"""Tests for blocklist_service — source CRUD, preview, import, schedule."""
from __future__ import annotations

from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiosqlite
import pytest

from app.db import init_db
from app.models.blocklist import (
    BlocklistSource,
    ScheduleConfig,
    ScheduleFrequency,
)
from app.services import blocklist_service
# ---------------------------------------------------------------------------
# Fixture
# ---------------------------------------------------------------------------
@pytest.fixture
async def db(tmp_path: Path) -> AsyncIterator[aiosqlite.Connection]:
    """Provide an initialised aiosqlite connection backed by a per-test temp file.

    The database lives at ``<tmp_path>/bl_svc.db`` and is initialised with
    ``init_db`` before being handed to the test. Rows come back as
    ``aiosqlite.Row`` so columns can be read by name.

    The connection is closed when the test finishes. Because the connection is
    managed by ``async with``, it is also closed if ``init_db`` raises during
    setup. The original code closed it only after ``yield``, so a setup failure
    leaked it.
    """
    async with aiosqlite.connect(str(tmp_path / "bl_svc.db")) as conn:
        conn.row_factory = aiosqlite.Row
        await init_db(conn)
        yield conn
def _make_session(text: str, status: int = 200) -> MagicMock:
"""Build a mock aiohttp session that returns *text* for GET requests."""
mock_resp = AsyncMock()
mock_resp.status = status
mock_resp.text = AsyncMock(return_value=text)
mock_resp.content = AsyncMock()
mock_resp.content.read = AsyncMock(return_value=text.encode())
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
session = MagicMock()
session.get = MagicMock(return_value=mock_ctx)
return session
# ---------------------------------------------------------------------------
# Source CRUD
# ---------------------------------------------------------------------------
class TestSourceCRUD:
    """Create / read / update / delete of blocklist sources on a temp SQLite DB.

    Tests that create sources patch ``app.utils.ip_utils.validate_blocklist_url``
    and make it return ``None``. This presumably lets the placeholder ``*.test``
    URLs through validation.
    NOTE(review): the patch targets the name in ``app.utils.ip_utils``. It only
    takes effect if ``blocklist_service`` resolves the validator through that
    module at call time — confirm.
    """

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_create_and_get(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
        """create_source persists and get_source retrieves a source."""
        mock_validate.return_value = None
        source = await blocklist_service.create_source(db, "Test", "https://t.test/")
        assert isinstance(source, BlocklistSource)
        assert source.name == "Test"
        assert source.url == "https://t.test/"
        assert source.enabled is True
        fetched = await blocklist_service.get_source(db, source.id)
        assert fetched is not None
        assert fetched.id == source.id
        # The stored row must carry the same data, not merely the same id.
        assert fetched.name == "Test"
        assert fetched.url == "https://t.test/"

    async def test_get_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """get_source returns None for a non-existent id."""
        result = await blocklist_service.get_source(db, 9999)
        assert result is None

    async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
        """list_sources returns empty list when no sources exist."""
        sources = await blocklist_service.list_sources(db)
        assert sources == []

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_list_sources_returns_all(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
        """list_sources returns all created sources."""
        mock_validate.return_value = None
        await blocklist_service.create_source(db, "A", "https://a.test/")
        await blocklist_service.create_source(db, "B", "https://b.test/")
        sources = await blocklist_service.list_sources(db)
        assert len(sources) == 2
        # Check the contents too, so a list of duplicates or wrong rows fails.
        assert sorted(s.name for s in sources) == ["A", "B"]

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_update_source_fields(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
        """update_source modifies only the specified fields and persists them."""
        mock_validate.return_value = None
        source = await blocklist_service.create_source(db, "Original", "https://orig.test/")
        updated = await blocklist_service.update_source(db, source.id, name="Updated", enabled=False)
        assert updated is not None
        assert updated.id == source.id
        assert updated.name == "Updated"
        assert updated.enabled is False
        # Fields that were not passed must be left untouched.
        assert updated.url == "https://orig.test/"
        # The change must be visible on a fresh read, not just in the return value.
        reloaded = await blocklist_service.get_source(db, source.id)
        assert reloaded is not None
        assert reloaded.name == "Updated"
        assert reloaded.enabled is False

    async def test_update_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """update_source returns None for a non-existent id."""
        result = await blocklist_service.update_source(db, 9999, name="Ghost")
        assert result is None

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_delete_source(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
        """delete_source removes a source and returns True."""
        mock_validate.return_value = None
        source = await blocklist_service.create_source(db, "Del", "https://del.test/")
        deleted = await blocklist_service.delete_source(db, source.id)
        assert deleted is True
        assert await blocklist_service.get_source(db, source.id) is None

    async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """delete_source returns False for a non-existent id."""
        result = await blocklist_service.delete_source(db, 9999)
        assert result is False
# ---------------------------------------------------------------------------
# Preview
# ---------------------------------------------------------------------------
class TestPreview:
    """preview_source: download a list and summarise its valid/skipped entries."""

    async def test_preview_valid_ips(self) -> None:
        """preview_source returns valid IPs from the downloaded content."""
        content = "1.2.3.4\n5.6.7.8\n# comment\ninvalid\n9.0.0.1\n"
        session = _make_session(content)
        result = await blocklist_service.preview_source("https://test.test/ips.txt", session)
        assert result.valid_count == 3
        # Only "invalid" counts as skipped; the "# comment" line is ignored entirely.
        assert result.skipped_count == 1
        assert "1.2.3.4" in result.entries
        assert "invalid" not in result.entries

    async def test_preview_http_error_raises(self) -> None:
        """preview_source raises ValueError when the server returns non-200."""
        session = _make_session("", status=404)
        with pytest.raises(ValueError, match="HTTP 404"):
            await blocklist_service.preview_source("https://bad.test/", session)

    async def test_preview_retries_transient_errors(self) -> None:
        """preview_source retries transient network failures before succeeding."""
        session = _make_session("1.2.3.4\n")
        # Reuse the helper's canned 200 context. The first GET raises, the second succeeds.
        ok_ctx = session.get.return_value
        session.get.side_effect = [Exception("connection reset"), ok_ctx]
        result = await blocklist_service.preview_source("https://test.test/ips.txt", session)
        assert result.valid_count == 1
        assert session.get.call_count == 2

    async def test_preview_limits_entries(self) -> None:
        """preview_source caps entries to sample_lines while counting every valid IP."""
        ips = "\n".join(f"1.2.3.{i}" for i in range(50))
        session = _make_session(ips)
        result = await blocklist_service.preview_source(
            "https://test.test/", session, sample_lines=10
        )
        # Non-empty: with 50 valid IPs, an empty sample would be a bug, not a cap.
        assert 0 < len(result.entries) <= 10
        assert result.valid_count == 50
# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------
class TestImport:
    """import_source / import_all with the fail2ban ``ban_ip`` call mocked out.

    ``ban_service.ban_ip`` is read off the module *inside* each ``patch`` block.
    That way the injected ``ban_ip`` callable is the mock, and no test can reach
    a real fail2ban socket.
    """

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_import_source_bans_valid_ips(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
        """import_source calls ban_ip for every valid IP in the blocklist."""
        mock_validate.return_value = None
        content = "1.2.3.4\n5.6.7.8\n# skip me\n"
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
        from app.services import ban_service

        with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock) as mock_ban:
            result = await blocklist_service.import_source(
                source,
                session,
                "/tmp/fake.sock",
                db,
                ban_ip=ban_service.ban_ip,
            )
        assert result.ips_imported == 2
        # The comment line is ignored rather than counted as skipped.
        assert result.ips_skipped == 0
        assert result.error is None
        assert mock_ban.call_count == 2

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_import_source_skips_cidrs(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
        """import_source skips CIDR ranges (fail2ban expects individual IPs)."""
        mock_validate.return_value = None
        content = "1.2.3.4\n10.0.0.0/24\n"
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
        from app.services import ban_service

        with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
            result = await blocklist_service.import_source(
                source,
                session,
                "/tmp/fake.sock",
                db,
                ban_ip=ban_service.ban_ip,
            )
        assert result.ips_imported == 1
        assert result.ips_skipped == 1

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_import_source_records_download_error(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
        """import_source records an error, imports nothing and bans nothing on HTTP failure."""
        mock_validate.return_value = None
        session = _make_session("", status=503)
        source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
        from app.services import ban_service

        # Patched so that even a regression cannot call the real ban_ip.
        with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock) as mock_ban:
            result = await blocklist_service.import_source(
                source,
                session,
                "/tmp/fake.sock",
                db,
                ban_ip=ban_service.ban_ip,
            )
        assert result.ips_imported == 0
        assert result.error is not None
        mock_ban.assert_not_called()

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_import_source_aborts_on_jail_not_found(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
        """import_source aborts immediately and records an error when the target jail
        does not exist in fail2ban instead of silently skipping every IP."""
        mock_validate.return_value = None
        from app.services.jail_service import JailNotFoundError
        from app.services import ban_service

        content = "\n".join(f"1.2.3.{i}" for i in range(100))
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "Missing Jail", "https://mj.test/")

        async def _raise_jail_not_found(socket_path: str, jail: str, ip: str) -> None:
            # Echo back whichever jail name the service asked for.
            raise JailNotFoundError(jail)

        with patch(
            "app.services.ban_service.ban_ip",
            new_callable=AsyncMock,
            side_effect=_raise_jail_not_found,
        ) as mock_ban:
            result = await blocklist_service.import_source(
                source,
                session,
                "/tmp/fake.sock",
                db,
                ban_ip=ban_service.ban_ip,
            )
        # Must abort after the first JailNotFoundError — only one ban attempt.
        assert mock_ban.call_count == 1
        assert result.ips_imported == 0
        assert result.error is not None
        assert "not found" in result.error.lower() or "blocklist-import" in result.error

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_import_all_runs_all_enabled(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
        """import_all aggregates results across all enabled sources."""
        mock_validate.return_value = None
        from app.services import ban_service

        await blocklist_service.create_source(db, "S1", "https://s1.test/")
        # Disabled: import_all must not fetch or import it.
        await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)
        session = _make_session("1.2.3.4\n5.6.7.8\n")

        with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock) as mock_ban:
            result = await blocklist_service.import_all(
                db,
                session,
                "/tmp/fake.sock",
                ban_ip=ban_service.ban_ip,
            )
        # Only S1 is enabled, S2 is disabled.
        assert len(result.results) == 1
        assert result.results[0].source_url == "https://s1.test/"
        # Only S1's two addresses were banned.
        assert mock_ban.call_count == 2
# ---------------------------------------------------------------------------
# Schedule
# ---------------------------------------------------------------------------
class TestSchedule:
    """Schedule persistence, last/next-run reporting and scheduler job wiring."""

    async def test_get_schedule_default(self, db: aiosqlite.Connection) -> None:
        """get_schedule returns the default daily-03:00 config when nothing is saved."""
        config = await blocklist_service.get_schedule(db)
        assert config.frequency == ScheduleFrequency.daily
        assert config.hour == 3
        assert config.minute == 0

    async def test_set_and_get_round_trip(self, db: aiosqlite.Connection) -> None:
        """set_schedule persists config retrievable by get_schedule."""
        cfg = ScheduleConfig(frequency=ScheduleFrequency.hourly, interval_hours=6, hour=0, minute=0, day_of_week=0)
        await blocklist_service.set_schedule(db, cfg)
        loaded = await blocklist_service.get_schedule(db)
        assert loaded.frequency == ScheduleFrequency.hourly
        assert loaded.interval_hours == 6

    async def test_get_schedule_info_no_log(self, db: aiosqlite.Connection) -> None:
        """get_schedule_info reports no last run, no next run and no error flag
        when the import log is empty and ``None`` is passed as the second argument."""
        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_at is None
        assert info.next_run_at is None
        assert info.last_run_errors is None

    async def test_get_schedule_info_no_errors_when_clean(
        self, db: aiosqlite.Connection
    ) -> None:
        """get_schedule_info returns last_run_errors=False when the last run had no errors."""
        from app.repositories import import_log_repo

        await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://example.test/ips.txt",
            ips_imported=10,
            ips_skipped=0,
            errors=None,
        )
        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_errors is False

    async def test_get_schedule_info_errors_flag_when_failed(
        self, db: aiosqlite.Connection
    ) -> None:
        """get_schedule_info returns last_run_errors=True when the last run had errors."""
        from app.repositories import import_log_repo

        await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://example.test/ips.txt",
            ips_imported=0,
            ips_skipped=0,
            errors="Connection timeout",
        )
        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_errors is True

    async def test_get_schedule_info_with_runtime_uses_scheduler_metadata(
        self, db: aiosqlite.Connection
    ) -> None:
        """get_schedule_info_with_runtime derives next_run_at from the scheduler."""
        next_run = MagicMock()
        next_run.isoformat.return_value = "2099-01-01T00:00:00+00:00"
        scheduler = MagicMock()
        scheduler.get_job.return_value = MagicMock(next_run_time=next_run)
        info = await blocklist_service.get_schedule_info_with_runtime(db, scheduler)
        assert info.next_run_at == "2099-01-01T00:00:00+00:00"

    async def test_update_schedule_persists_and_schedules_job(
        self, db: aiosqlite.Connection
    ) -> None:
        """update_schedule must persist the config and schedule a job."""
        settings = MagicMock(
            fail2ban_socket="/var/run/fail2ban/fail2ban.sock",
            database_path=":memory:",
        )
        http_session = MagicMock()
        scheduler = MagicMock()
        # No existing job, so update_schedule has to add one.
        scheduler.get_job.return_value = None
        config = ScheduleConfig(
            frequency=ScheduleFrequency.daily,
            hour=4,
            minute=15,
        )
        run_import_callback = AsyncMock(return_value=None)
        info = await blocklist_service.update_schedule(
            db,
            scheduler,
            http_session,
            settings,
            config,
            run_import_callback,
        )
        assert info.config.frequency == ScheduleFrequency.daily
        scheduler.add_job.assert_called_once()
        # The docstring promises persistence, so verify it with a fresh read.
        stored = await blocklist_service.get_schedule(db)
        assert stored.frequency == ScheduleFrequency.daily
        assert (stored.hour, stored.minute) == (4, 15)
# ---------------------------------------------------------------------------
# Geo prewarm cache filtering
# ---------------------------------------------------------------------------
class TestGeoPrewarmCacheFilter:
    """Geo pre-warming during import must skip addresses already in the geo cache."""

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_import_source_skips_cached_ips_for_geo_prewarm(
        self, mock_validate: AsyncMock, db: aiosqlite.Connection
    ) -> None:
        """import_source only sends uncached IPs to geo_service.lookup_batch."""
        mock_validate.return_value = None
        already_cached = "1.2.3.4"
        session = _make_session("1.2.3.4\n5.6.7.8\n9.10.11.12\n")
        source = await blocklist_service.create_source(db, "Geo Filter", "https://gf.test/")

        def _mock_is_cached(ip: str) -> bool:
            # Exactly one address of the list counts as cached.
            return ip == already_cached

        from app.services import ban_service

        batch_lookup = AsyncMock(return_value={})
        fake_geo_cache = MagicMock()
        fake_geo_cache.is_cached = _mock_is_cached
        fake_geo_cache.lookup_batch = batch_lookup

        with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
            result = await blocklist_service.import_source(
                source,
                session,
                "/tmp/fake.sock",
                db,
                ban_ip=ban_service.ban_ip,
                geo_is_cached=_mock_is_cached,
                geo_cache=fake_geo_cache,
            )

        # Every address is still banned; only the geo pre-warm is filtered.
        assert result.ips_imported == 3
        batch_lookup.assert_called_once()
        requested = batch_lookup.call_args.args[0]
        assert already_cached not in requested
        assert set(requested) == {"5.6.7.8", "9.10.11.12"}
# ---------------------------------------------------------------------------
# URL Validation (SSRF Prevention)
# ---------------------------------------------------------------------------
class TestURLValidation:
    """SSRF guard: create_source must refuse non-HTTP schemes and internal addresses."""

    # Regex fragments matched against the ValueError raised by create_source.
    _BAD_SCHEME = "Invalid scheme"
    _BAD_ADDRESS = "private|reserved"

    @staticmethod
    async def _expect_rejection(db: aiosqlite.Connection, url: str, pattern: str) -> None:
        """Assert that creating a source named "Bad" for *url* raises ValueError matching *pattern*."""
        with pytest.raises(ValueError, match=pattern):
            await blocklist_service.create_source(db, "Bad", url)

    async def test_create_source_rejects_file_url(self, db: aiosqlite.Connection) -> None:
        """Local-file URLs (file://) are refused."""
        await self._expect_rejection(db, "file:///etc/passwd", self._BAD_SCHEME)

    async def test_create_source_rejects_ftp_url(self, db: aiosqlite.Connection) -> None:
        """FTP URLs are refused."""
        await self._expect_rejection(db, "ftp://evil.com/file.txt", self._BAD_SCHEME)

    async def test_create_source_rejects_localhost(self, db: aiosqlite.Connection) -> None:
        """The IPv4 loopback literal 127.0.0.1 is refused."""
        await self._expect_rejection(db, "http://127.0.0.1/list", self._BAD_ADDRESS)

    async def test_create_source_rejects_localhost_name(self, db: aiosqlite.Connection) -> None:
        """The hostname "localhost" is refused."""
        await self._expect_rejection(db, "http://localhost/list", self._BAD_ADDRESS)

    async def test_create_source_rejects_private_network(self, db: aiosqlite.Connection) -> None:
        """RFC 1918 range 10.0.0.0/8 is refused."""
        await self._expect_rejection(db, "http://10.0.0.1/list", self._BAD_ADDRESS)

    async def test_create_source_rejects_private_network_172(self, db: aiosqlite.Connection) -> None:
        """RFC 1918 range 172.16.0.0/12 is refused."""
        await self._expect_rejection(db, "http://172.16.0.1/list", self._BAD_ADDRESS)

    async def test_create_source_rejects_private_network_192(self, db: aiosqlite.Connection) -> None:
        """RFC 1918 range 192.168.0.0/16 is refused."""
        await self._expect_rejection(db, "http://192.168.1.1/list", self._BAD_ADDRESS)

    async def test_create_source_rejects_link_local(self, db: aiosqlite.Connection) -> None:
        """Link-local 169.254.x.x (e.g. the cloud metadata endpoint) is refused."""
        await self._expect_rejection(
            db, "http://169.254.169.254/latest/meta-data", self._BAD_ADDRESS
        )

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_create_source_accepts_valid_public_url(
        self, mock_validate: AsyncMock, db: aiosqlite.Connection
    ) -> None:
        """A public HTTPS URL is accepted (validator mocked out)."""
        mock_validate.return_value = None
        created = await blocklist_service.create_source(db, "Good", "https://example.com/list.txt")
        assert created.name == "Good"
        assert created.url == "https://example.com/list.txt"
# ---------------------------------------------------------------------------
# Import log pagination
# ---------------------------------------------------------------------------
class TestImportLogPagination:
    """Paging behaviour of blocklist_service.list_import_logs."""

    async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None:
        """list_import_logs returns an empty page when no logs exist."""
        resp = await blocklist_service.list_import_logs(
            db, source_id=None, page=1, page_size=10
        )
        assert resp.items == []
        assert resp.total == 0
        assert resp.page == 1
        assert resp.page_size == 10

    async def test_list_import_logs_paginates(self, db: aiosqlite.Connection) -> None:
        """list_import_logs reports the total row count and returns the requested page.

        Three logs are inserted in the order example0, example1, example2. With
        ``page_size=2``, page 2 must hold only example0. Together the two pages
        must cover every log exactly once. These expectations encode
        newest-first ordering.
        """
        from app.repositories import import_log_repo

        for i in range(3):
            await import_log_repo.add_log(
                db,
                source_id=None,
                source_url=f"https://example{i}.test/ips.txt",
                ips_imported=1,
                ips_skipped=0,
                errors=None,
            )

        first = await blocklist_service.list_import_logs(
            db, source_id=None, page=1, page_size=2
        )
        assert first.total == 3
        assert {item.source_url for item in first.items} == {
            "https://example1.test/ips.txt",
            "https://example2.test/ips.txt",
        }

        resp = await blocklist_service.list_import_logs(
            db, source_id=None, page=2, page_size=2
        )
        assert resp.total == 3
        assert resp.page == 2
        assert resp.page_size == 2
        assert len(resp.items) == 1
        assert resp.items[0].source_url == "https://example0.test/ips.txt"