Stage 10: external blocklist importer — backend + frontend

- blocklist_repo.py: CRUD for blocklist_sources table
- import_log_repo.py: add/list/get-last log entries
- blocklist_service.py: source CRUD, preview, import (download/validate/ban),
  import_all, schedule get/set/info
- blocklist_import.py: APScheduler task (hourly/daily/weekly schedule triggers)
- blocklist.py router: 9 endpoints (list/create/update/delete/preview/import/
  schedule-get+put/log)
- blocklist.py models: ScheduleFrequency (StrEnum), ScheduleConfig, ScheduleInfo,
  ImportSourceResult, ImportRunResult, PreviewResponse
- 59 new tests (18 repo + 19 service + 22 router); 374 total pass
- ruff clean, mypy clean for Stage 10 files
- types/blocklist.ts, api/blocklist.ts, hooks/useBlocklist.ts
- BlocklistsPage.tsx: source management, schedule picker, import log table
- Frontend tsc + ESLint clean
This commit is contained in:
2026-03-01 15:33:24 +01:00
parent b8f3a1c562
commit 1efa0e973b
15 changed files with 3771 additions and 53 deletions

View File

@@ -0,0 +1,210 @@
"""Tests for blocklist_repo and import_log_repo."""
from __future__ import annotations
from pathlib import Path
import aiosqlite
import pytest
from app.db import init_db
from app.repositories import blocklist_repo, import_log_repo
@pytest.fixture
async def db(tmp_path: Path) -> aiosqlite.Connection: # type: ignore[misc]
"""Provide an initialised aiosqlite connection for repository tests."""
conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_test.db"))
conn.row_factory = aiosqlite.Row
await init_db(conn)
yield conn
await conn.close()
# ---------------------------------------------------------------------------
# blocklist_repo tests
# ---------------------------------------------------------------------------
class TestBlocklistRepo:
async def test_create_source_returns_int_id(self, db: aiosqlite.Connection) -> None:
"""create_source returns a positive integer id."""
source_id = await blocklist_repo.create_source(db, "Test", "https://example.com/list.txt")
assert isinstance(source_id, int)
assert source_id > 0
async def test_get_source_returns_row(self, db: aiosqlite.Connection) -> None:
"""get_source returns the correct row after creation."""
source_id = await blocklist_repo.create_source(db, "Alpha", "https://alpha.test/ips.txt")
row = await blocklist_repo.get_source(db, source_id)
assert row is not None
assert row["name"] == "Alpha"
assert row["url"] == "https://alpha.test/ips.txt"
assert row["enabled"] is True
async def test_get_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
"""get_source returns None for a non-existent id."""
result = await blocklist_repo.get_source(db, 9999)
assert result is None
async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
"""list_sources returns empty list when no sources exist."""
rows = await blocklist_repo.list_sources(db)
assert rows == []
async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
"""list_sources returns all created sources."""
await blocklist_repo.create_source(db, "A", "https://a.test/")
await blocklist_repo.create_source(db, "B", "https://b.test/")
rows = await blocklist_repo.list_sources(db)
assert len(rows) == 2
async def test_list_enabled_sources_filters(self, db: aiosqlite.Connection) -> None:
"""list_enabled_sources only returns rows with enabled=True."""
await blocklist_repo.create_source(db, "Enabled", "https://on.test/", enabled=True)
id2 = await blocklist_repo.create_source(db, "Disabled", "https://off.test/", enabled=False)
await blocklist_repo.update_source(db, id2, enabled=False)
rows = await blocklist_repo.list_enabled_sources(db)
assert len(rows) == 1
assert rows[0]["name"] == "Enabled"
async def test_update_source_name(self, db: aiosqlite.Connection) -> None:
"""update_source changes the name field."""
source_id = await blocklist_repo.create_source(db, "Old", "https://old.test/")
updated = await blocklist_repo.update_source(db, source_id, name="New")
assert updated is True
row = await blocklist_repo.get_source(db, source_id)
assert row is not None
assert row["name"] == "New"
async def test_update_source_enabled_false(self, db: aiosqlite.Connection) -> None:
"""update_source can disable a source."""
source_id = await blocklist_repo.create_source(db, "On", "https://on.test/")
await blocklist_repo.update_source(db, source_id, enabled=False)
row = await blocklist_repo.get_source(db, source_id)
assert row is not None
assert row["enabled"] is False
async def test_update_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
"""update_source returns False for a non-existent id."""
result = await blocklist_repo.update_source(db, 9999, name="Ghost")
assert result is False
async def test_delete_source_removes_row(self, db: aiosqlite.Connection) -> None:
"""delete_source removes the row and returns True."""
source_id = await blocklist_repo.create_source(db, "Del", "https://del.test/")
deleted = await blocklist_repo.delete_source(db, source_id)
assert deleted is True
assert await blocklist_repo.get_source(db, source_id) is None
async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
"""delete_source returns False for a non-existent id."""
result = await blocklist_repo.delete_source(db, 9999)
assert result is False
# ---------------------------------------------------------------------------
# import_log_repo tests
# ---------------------------------------------------------------------------
class TestImportLogRepo:
async def test_add_log_returns_id(self, db: aiosqlite.Connection) -> None:
"""add_log returns a positive integer id."""
log_id = await import_log_repo.add_log(
db,
source_id=None,
source_url="https://example.com/list.txt",
ips_imported=10,
ips_skipped=2,
errors=None,
)
assert isinstance(log_id, int)
assert log_id > 0
async def test_list_logs_returns_all(self, db: aiosqlite.Connection) -> None:
"""list_logs returns all logs when no source_id filter is applied."""
for i in range(3):
await import_log_repo.add_log(
db,
source_id=None,
source_url=f"https://s{i}.test/",
ips_imported=i * 5,
ips_skipped=0,
errors=None,
)
items, total = await import_log_repo.list_logs(db)
assert total == 3
assert len(items) == 3
async def test_list_logs_pagination(self, db: aiosqlite.Connection) -> None:
"""list_logs respects page and page_size."""
for i in range(5):
await import_log_repo.add_log(
db,
source_id=None,
source_url=f"https://p{i}.test/",
ips_imported=1,
ips_skipped=0,
errors=None,
)
items, total = await import_log_repo.list_logs(db, page=2, page_size=2)
assert total == 5
assert len(items) == 2
async def test_list_logs_source_filter(self, db: aiosqlite.Connection) -> None:
"""list_logs filters by source_id."""
source_id = await blocklist_repo.create_source(db, "Src", "https://s.test/")
await import_log_repo.add_log(
db,
source_id=source_id,
source_url="https://s.test/",
ips_imported=5,
ips_skipped=0,
errors=None,
)
await import_log_repo.add_log(
db,
source_id=None,
source_url="https://other.test/",
ips_imported=3,
ips_skipped=0,
errors=None,
)
items, total = await import_log_repo.list_logs(db, source_id=source_id)
assert total == 1
assert items[0]["source_url"] == "https://s.test/"
async def test_get_last_log_empty(self, db: aiosqlite.Connection) -> None:
"""get_last_log returns None when no logs exist."""
result = await import_log_repo.get_last_log(db)
assert result is None
async def test_get_last_log_returns_most_recent(self, db: aiosqlite.Connection) -> None:
"""get_last_log returns the most recently inserted entry."""
await import_log_repo.add_log(
db,
source_id=None,
source_url="https://first.test/",
ips_imported=1,
ips_skipped=0,
errors=None,
)
await import_log_repo.add_log(
db,
source_id=None,
source_url="https://last.test/",
ips_imported=2,
ips_skipped=0,
errors=None,
)
last = await import_log_repo.get_last_log(db)
assert last is not None
assert last["source_url"] == "https://last.test/"
async def test_compute_total_pages(self) -> None:
"""compute_total_pages returns correct page count."""
assert import_log_repo.compute_total_pages(0, 10) == 1
assert import_log_repo.compute_total_pages(10, 10) == 1
assert import_log_repo.compute_total_pages(11, 10) == 2
assert import_log_repo.compute_total_pages(20, 5) == 4

View File

@@ -0,0 +1,447 @@
"""Tests for the blocklist router (9 endpoints)."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.blocklist import (
BlocklistListResponse,
BlocklistSource,
ImportLogListResponse,
ImportRunResult,
ImportSourceResult,
PreviewResponse,
ScheduleConfig,
ScheduleFrequency,
ScheduleInfo,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD = {
"master_password": "testpassword1",
"database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC",
"session_duration_minutes": 60,
}
def _make_source(source_id: int = 1) -> BlocklistSource:
return BlocklistSource(
id=source_id,
name="Test Source",
url="https://test.test/ips.txt",
enabled=True,
created_at="2026-01-01T00:00:00Z",
updated_at="2026-01-01T00:00:00Z",
)
def _make_source_list() -> BlocklistListResponse:
return BlocklistListResponse(sources=[_make_source(1), _make_source(2)])
def _make_schedule_info() -> ScheduleInfo:
return ScheduleInfo(
config=ScheduleConfig(
frequency=ScheduleFrequency.daily,
interval_hours=24,
hour=3,
minute=0,
day_of_week=0,
),
next_run_at="2026-02-01T03:00:00+00:00",
last_run_at=None,
)
def _make_import_result() -> ImportRunResult:
return ImportRunResult(
results=[
ImportSourceResult(
source_id=1,
source_url="https://test.test/ips.txt",
ips_imported=5,
ips_skipped=1,
error=None,
)
],
total_imported=5,
total_skipped=1,
errors_count=0,
)
def _make_log_response() -> ImportLogListResponse:
return ImportLogListResponse(
items=[], total=0, page=1, page_size=50, total_pages=1
)
def _make_preview() -> PreviewResponse:
return PreviewResponse(
entries=["1.2.3.4", "5.6.7.8"],
total_lines=10,
valid_count=8,
skipped_count=2,
)
# ---------------------------------------------------------------------------
# Fixture
# ---------------------------------------------------------------------------
@pytest.fixture
async def bl_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated AsyncClient for blocklist endpoint tests."""
settings = Settings(
database_path=str(tmp_path / "bl_router_test.db"),
fail2ban_socket="/tmp/fake_fail2ban.sock",
session_secret="test-bl-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row
await init_db(db)
app.state.db = db
app.state.http_session = MagicMock()
# Provide a minimal scheduler stub so the router can call .get_job().
scheduler_stub = MagicMock()
scheduler_stub.get_job = MagicMock(return_value=None)
app.state.scheduler = scheduler_stub
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
resp = await ac.post("/api/setup", json=_SETUP_PAYLOAD)
assert resp.status_code == 201
login_resp = await ac.post(
"/api/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]},
)
assert login_resp.status_code == 200
yield ac
await db.close()
# ---------------------------------------------------------------------------
# GET /api/blocklists
# ---------------------------------------------------------------------------
class TestListBlocklists:
async def test_authenticated_returns_200(self, bl_client: AsyncClient) -> None:
"""Authenticated request to list sources returns HTTP 200."""
with patch(
"app.routers.blocklist.blocklist_service.list_sources",
new=AsyncMock(return_value=_make_source_list().sources),
):
resp = await bl_client.get("/api/blocklists")
assert resp.status_code == 200
async def test_returns_401_unauthenticated(self, client: AsyncClient) -> None:
"""Unauthenticated request returns 401."""
await client.post("/api/setup", json=_SETUP_PAYLOAD)
resp = await client.get("/api/blocklists")
assert resp.status_code == 401
async def test_response_contains_sources_key(self, bl_client: AsyncClient) -> None:
"""Response body has a 'sources' array."""
with patch(
"app.routers.blocklist.blocklist_service.list_sources",
new=AsyncMock(return_value=[_make_source()]),
):
resp = await bl_client.get("/api/blocklists")
body = resp.json()
assert "sources" in body
assert isinstance(body["sources"], list)
# ---------------------------------------------------------------------------
# POST /api/blocklists
# ---------------------------------------------------------------------------
class TestCreateBlocklist:
async def test_create_returns_201(self, bl_client: AsyncClient) -> None:
"""POST /api/blocklists creates a source and returns HTTP 201."""
with patch(
"app.routers.blocklist.blocklist_service.create_source",
new=AsyncMock(return_value=_make_source()),
):
resp = await bl_client.post(
"/api/blocklists",
json={"name": "Test", "url": "https://test.test/", "enabled": True},
)
assert resp.status_code == 201
async def test_create_source_id_in_response(self, bl_client: AsyncClient) -> None:
"""Created source response includes the id field."""
with patch(
"app.routers.blocklist.blocklist_service.create_source",
new=AsyncMock(return_value=_make_source(42)),
):
resp = await bl_client.post(
"/api/blocklists",
json={"name": "Test", "url": "https://test.test/", "enabled": True},
)
assert resp.json()["id"] == 42
# ---------------------------------------------------------------------------
# PUT /api/blocklists/{id}
# ---------------------------------------------------------------------------
class TestUpdateBlocklist:
async def test_update_returns_200(self, bl_client: AsyncClient) -> None:
"""PUT /api/blocklists/1 returns 200 for a found source."""
updated = _make_source()
updated.enabled = False
with patch(
"app.routers.blocklist.blocklist_service.update_source",
new=AsyncMock(return_value=updated),
):
resp = await bl_client.put(
"/api/blocklists/1",
json={"enabled": False},
)
assert resp.status_code == 200
async def test_update_returns_404_for_missing(self, bl_client: AsyncClient) -> None:
"""PUT /api/blocklists/999 returns 404 when source does not exist."""
with patch(
"app.routers.blocklist.blocklist_service.update_source",
new=AsyncMock(return_value=None),
):
resp = await bl_client.put(
"/api/blocklists/999",
json={"enabled": False},
)
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# DELETE /api/blocklists/{id}
# ---------------------------------------------------------------------------
class TestDeleteBlocklist:
async def test_delete_returns_204(self, bl_client: AsyncClient) -> None:
"""DELETE /api/blocklists/1 returns 204 for a found source."""
with patch(
"app.routers.blocklist.blocklist_service.delete_source",
new=AsyncMock(return_value=True),
):
resp = await bl_client.delete("/api/blocklists/1")
assert resp.status_code == 204
async def test_delete_returns_404_for_missing(self, bl_client: AsyncClient) -> None:
"""DELETE /api/blocklists/999 returns 404 when source does not exist."""
with patch(
"app.routers.blocklist.blocklist_service.delete_source",
new=AsyncMock(return_value=False),
):
resp = await bl_client.delete("/api/blocklists/999")
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# GET /api/blocklists/{id}/preview
# ---------------------------------------------------------------------------
class TestPreviewBlocklist:
async def test_preview_returns_200(self, bl_client: AsyncClient) -> None:
"""GET /api/blocklists/1/preview returns 200 for existing source."""
with patch(
"app.routers.blocklist.blocklist_service.get_source",
new=AsyncMock(return_value=_make_source()),
), patch(
"app.routers.blocklist.blocklist_service.preview_source",
new=AsyncMock(return_value=_make_preview()),
):
resp = await bl_client.get("/api/blocklists/1/preview")
assert resp.status_code == 200
async def test_preview_returns_404_for_missing(self, bl_client: AsyncClient) -> None:
"""GET /api/blocklists/999/preview returns 404 when source not found."""
with patch(
"app.routers.blocklist.blocklist_service.get_source",
new=AsyncMock(return_value=None),
):
resp = await bl_client.get("/api/blocklists/999/preview")
assert resp.status_code == 404
async def test_preview_returns_502_on_download_error(
self, bl_client: AsyncClient
) -> None:
"""GET /api/blocklists/1/preview returns 502 when URL is unreachable."""
with patch(
"app.routers.blocklist.blocklist_service.get_source",
new=AsyncMock(return_value=_make_source()),
), patch(
"app.routers.blocklist.blocklist_service.preview_source",
new=AsyncMock(side_effect=ValueError("Connection refused")),
):
resp = await bl_client.get("/api/blocklists/1/preview")
assert resp.status_code == 502
async def test_preview_response_shape(self, bl_client: AsyncClient) -> None:
"""Preview response has entries, valid_count, skipped_count, total_lines."""
with patch(
"app.routers.blocklist.blocklist_service.get_source",
new=AsyncMock(return_value=_make_source()),
), patch(
"app.routers.blocklist.blocklist_service.preview_source",
new=AsyncMock(return_value=_make_preview()),
):
resp = await bl_client.get("/api/blocklists/1/preview")
body = resp.json()
assert "entries" in body
assert "valid_count" in body
assert "skipped_count" in body
assert "total_lines" in body
# ---------------------------------------------------------------------------
# POST /api/blocklists/import
# ---------------------------------------------------------------------------
class TestRunImport:
async def test_import_returns_200(self, bl_client: AsyncClient) -> None:
"""POST /api/blocklists/import returns 200 with aggregated results."""
with patch(
"app.routers.blocklist.blocklist_service.import_all",
new=AsyncMock(return_value=_make_import_result()),
):
resp = await bl_client.post("/api/blocklists/import")
assert resp.status_code == 200
async def test_import_response_shape(self, bl_client: AsyncClient) -> None:
"""Import response has results, total_imported, total_skipped, errors_count."""
with patch(
"app.routers.blocklist.blocklist_service.import_all",
new=AsyncMock(return_value=_make_import_result()),
):
resp = await bl_client.post("/api/blocklists/import")
body = resp.json()
assert "total_imported" in body
assert "total_skipped" in body
assert "errors_count" in body
assert "results" in body
# ---------------------------------------------------------------------------
# GET /api/blocklists/schedule
# ---------------------------------------------------------------------------
class TestGetSchedule:
async def test_schedule_returns_200(self, bl_client: AsyncClient) -> None:
"""GET /api/blocklists/schedule returns 200."""
with patch(
"app.routers.blocklist.blocklist_service.get_schedule_info",
new=AsyncMock(return_value=_make_schedule_info()),
):
resp = await bl_client.get("/api/blocklists/schedule")
assert resp.status_code == 200
async def test_schedule_response_has_config(self, bl_client: AsyncClient) -> None:
"""Schedule response includes the config sub-object."""
with patch(
"app.routers.blocklist.blocklist_service.get_schedule_info",
new=AsyncMock(return_value=_make_schedule_info()),
):
resp = await bl_client.get("/api/blocklists/schedule")
body = resp.json()
assert "config" in body
assert "next_run_at" in body
assert "last_run_at" in body
# ---------------------------------------------------------------------------
# PUT /api/blocklists/schedule
# ---------------------------------------------------------------------------
class TestUpdateSchedule:
async def test_update_schedule_returns_200(self, bl_client: AsyncClient) -> None:
"""PUT /api/blocklists/schedule persists new config and returns 200."""
new_info = ScheduleInfo(
config=ScheduleConfig(
frequency=ScheduleFrequency.hourly,
interval_hours=12,
hour=0,
minute=0,
day_of_week=0,
),
next_run_at=None,
last_run_at=None,
)
with patch(
"app.routers.blocklist.blocklist_service.set_schedule",
new=AsyncMock(),
), patch(
"app.routers.blocklist.blocklist_service.get_schedule_info",
new=AsyncMock(return_value=new_info),
), patch(
"app.routers.blocklist.blocklist_import_task.reschedule",
):
resp = await bl_client.put(
"/api/blocklists/schedule",
json={
"frequency": "hourly",
"interval_hours": 12,
"hour": 0,
"minute": 0,
"day_of_week": 0,
},
)
assert resp.status_code == 200
# ---------------------------------------------------------------------------
# GET /api/blocklists/log
# ---------------------------------------------------------------------------
class TestImportLog:
async def test_log_returns_200(self, bl_client: AsyncClient) -> None:
"""GET /api/blocklists/log returns 200."""
resp = await bl_client.get("/api/blocklists/log")
assert resp.status_code == 200
async def test_log_response_shape(self, bl_client: AsyncClient) -> None:
"""Log response has items, total, page, page_size, total_pages."""
resp = await bl_client.get("/api/blocklists/log")
body = resp.json()
for key in ("items", "total", "page", "page_size", "total_pages"):
assert key in body
async def test_log_empty_when_no_runs(self, bl_client: AsyncClient) -> None:
"""Log returns empty items list when no import runs have occurred."""
resp = await bl_client.get("/api/blocklists/log")
body = resp.json()
assert body["total"] == 0
assert body["items"] == []

View File

@@ -0,0 +1,233 @@
"""Tests for blocklist_service — source CRUD, preview, import, schedule."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
from app.db import init_db
from app.models.blocklist import BlocklistSource, ScheduleConfig, ScheduleFrequency
from app.services import blocklist_service
# ---------------------------------------------------------------------------
# Fixture
# ---------------------------------------------------------------------------
@pytest.fixture
async def db(tmp_path: Path) -> aiosqlite.Connection: # type: ignore[misc]
"""Provide an initialised aiosqlite connection."""
conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_svc.db"))
conn.row_factory = aiosqlite.Row
await init_db(conn)
yield conn
await conn.close()
def _make_session(text: str, status: int = 200) -> MagicMock:
"""Build a mock aiohttp session that returns *text* for GET requests."""
mock_resp = AsyncMock()
mock_resp.status = status
mock_resp.text = AsyncMock(return_value=text)
mock_resp.content = AsyncMock()
mock_resp.content.read = AsyncMock(return_value=text.encode())
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
session = MagicMock()
session.get = MagicMock(return_value=mock_ctx)
return session
# ---------------------------------------------------------------------------
# Source CRUD
# ---------------------------------------------------------------------------
class TestSourceCRUD:
async def test_create_and_get(self, db: aiosqlite.Connection) -> None:
"""create_source persists and get_source retrieves a source."""
source = await blocklist_service.create_source(db, "Test", "https://t.test/")
assert isinstance(source, BlocklistSource)
assert source.name == "Test"
assert source.enabled is True
fetched = await blocklist_service.get_source(db, source.id)
assert fetched is not None
assert fetched.id == source.id
async def test_get_missing_returns_none(self, db: aiosqlite.Connection) -> None:
"""get_source returns None for a non-existent id."""
result = await blocklist_service.get_source(db, 9999)
assert result is None
async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
"""list_sources returns empty list when no sources exist."""
sources = await blocklist_service.list_sources(db)
assert sources == []
async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
"""list_sources returns all created sources."""
await blocklist_service.create_source(db, "A", "https://a.test/")
await blocklist_service.create_source(db, "B", "https://b.test/")
sources = await blocklist_service.list_sources(db)
assert len(sources) == 2
async def test_update_source_fields(self, db: aiosqlite.Connection) -> None:
"""update_source modifies specified fields."""
source = await blocklist_service.create_source(db, "Original", "https://orig.test/")
updated = await blocklist_service.update_source(db, source.id, name="Updated", enabled=False)
assert updated is not None
assert updated.name == "Updated"
assert updated.enabled is False
async def test_update_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
"""update_source returns None for a non-existent id."""
result = await blocklist_service.update_source(db, 9999, name="Ghost")
assert result is None
async def test_delete_source(self, db: aiosqlite.Connection) -> None:
"""delete_source removes a source and returns True."""
source = await blocklist_service.create_source(db, "Del", "https://del.test/")
deleted = await blocklist_service.delete_source(db, source.id)
assert deleted is True
assert await blocklist_service.get_source(db, source.id) is None
async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
"""delete_source returns False for a non-existent id."""
result = await blocklist_service.delete_source(db, 9999)
assert result is False
# ---------------------------------------------------------------------------
# Preview
# ---------------------------------------------------------------------------
class TestPreview:
async def test_preview_valid_ips(self) -> None:
"""preview_source returns valid IPs from the downloaded content."""
content = "1.2.3.4\n5.6.7.8\n# comment\ninvalid\n9.0.0.1\n"
session = _make_session(content)
result = await blocklist_service.preview_source("https://test.test/ips.txt", session)
assert result.valid_count == 3
assert result.skipped_count == 1 # "invalid"
assert "1.2.3.4" in result.entries
async def test_preview_http_error_raises(self) -> None:
"""preview_source raises ValueError when the server returns non-200."""
session = _make_session("", status=404)
with pytest.raises(ValueError, match="HTTP 404"):
await blocklist_service.preview_source("https://bad.test/", session)
async def test_preview_limits_entries(self) -> None:
"""preview_source caps entries to sample_lines."""
ips = "\n".join(f"1.2.3.{i}" for i in range(50))
session = _make_session(ips)
result = await blocklist_service.preview_source(
"https://test.test/", session, sample_lines=10
)
assert len(result.entries) <= 10
assert result.valid_count == 50
# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------
class TestImport:
async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
"""import_source calls ban_ip for every valid IP in the blocklist."""
content = "1.2.3.4\n5.6.7.8\n# skip me\n"
session = _make_session(content)
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
with patch(
"app.services.jail_service.ban_ip", new_callable=AsyncMock
) as mock_ban:
result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db
)
assert result.ips_imported == 2
assert result.ips_skipped == 0
assert result.error is None
assert mock_ban.call_count == 2
async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
"""import_source skips CIDR ranges (fail2ban expects individual IPs)."""
content = "1.2.3.4\n10.0.0.0/24\n"
session = _make_session(content)
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db
)
assert result.ips_imported == 1
assert result.ips_skipped == 1
async def test_import_source_records_download_error(self, db: aiosqlite.Connection) -> None:
"""import_source records an error and returns 0 imported on HTTP failure."""
session = _make_session("", status=503)
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db
)
assert result.ips_imported == 0
assert result.error is not None
async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
"""import_all aggregates results across all enabled sources."""
await blocklist_service.create_source(db, "S1", "https://s1.test/")
s2 = await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)
_ = s2 # noqa: F841
content = "1.2.3.4\n5.6.7.8\n"
session = _make_session(content)
with patch(
"app.services.jail_service.ban_ip", new_callable=AsyncMock
):
result = await blocklist_service.import_all(db, session, "/tmp/fake.sock")
# Only S1 is enabled, S2 is disabled.
assert len(result.results) == 1
assert result.results[0].source_url == "https://s1.test/"
# ---------------------------------------------------------------------------
# Schedule
# ---------------------------------------------------------------------------
class TestSchedule:
async def test_get_schedule_default(self, db: aiosqlite.Connection) -> None:
"""get_schedule returns the default daily-03:00 config when nothing is saved."""
config = await blocklist_service.get_schedule(db)
assert config.frequency == ScheduleFrequency.daily
assert config.hour == 3
async def test_set_and_get_round_trip(self, db: aiosqlite.Connection) -> None:
"""set_schedule persists config retrievable by get_schedule."""
cfg = ScheduleConfig(frequency=ScheduleFrequency.hourly, interval_hours=6, hour=0, minute=0, day_of_week=0)
await blocklist_service.set_schedule(db, cfg)
loaded = await blocklist_service.get_schedule(db)
assert loaded.frequency == ScheduleFrequency.hourly
assert loaded.interval_hours == 6
async def test_get_schedule_info_no_log(self, db: aiosqlite.Connection) -> None:
"""get_schedule_info returns None for last_run_at when no log exists."""
info = await blocklist_service.get_schedule_info(db, None)
assert info.last_run_at is None
assert info.next_run_at is None