Stage 10: external blocklist importer — backend + frontend
- blocklist_repo.py: CRUD for blocklist_sources table - import_log_repo.py: add/list/get-last log entries - blocklist_service.py: source CRUD, preview, import (download/validate/ban), import_all, schedule get/set/info - blocklist_import.py: APScheduler task (hourly/daily/weekly schedule triggers) - blocklist.py router: 9 endpoints (list/create/update/delete/preview/import/ schedule-get+put/log) - blocklist.py models: ScheduleFrequency (StrEnum), ScheduleConfig, ScheduleInfo, ImportSourceResult, ImportRunResult, PreviewResponse - 59 new tests (18 repo + 19 service + 22 router); 374 total pass - ruff clean, mypy clean for Stage 10 files - types/blocklist.ts, api/blocklist.ts, hooks/useBlocklist.ts - BlocklistsPage.tsx: source management, schedule picker, import log table - Frontend tsc + ESLint clean
This commit is contained in:
210
backend/tests/test_repositories/test_blocklist.py
Normal file
210
backend/tests/test_repositories/test_blocklist.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Tests for blocklist_repo and import_log_repo."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import init_db
|
||||
from app.repositories import blocklist_repo, import_log_repo
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db(tmp_path: Path) -> aiosqlite.Connection: # type: ignore[misc]
|
||||
"""Provide an initialised aiosqlite connection for repository tests."""
|
||||
conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_test.db"))
|
||||
conn.row_factory = aiosqlite.Row
|
||||
await init_db(conn)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# blocklist_repo tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBlocklistRepo:
|
||||
async def test_create_source_returns_int_id(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source returns a positive integer id."""
|
||||
source_id = await blocklist_repo.create_source(db, "Test", "https://example.com/list.txt")
|
||||
assert isinstance(source_id, int)
|
||||
assert source_id > 0
|
||||
|
||||
async def test_get_source_returns_row(self, db: aiosqlite.Connection) -> None:
|
||||
"""get_source returns the correct row after creation."""
|
||||
source_id = await blocklist_repo.create_source(db, "Alpha", "https://alpha.test/ips.txt")
|
||||
row = await blocklist_repo.get_source(db, source_id)
|
||||
assert row is not None
|
||||
assert row["name"] == "Alpha"
|
||||
assert row["url"] == "https://alpha.test/ips.txt"
|
||||
assert row["enabled"] is True
|
||||
|
||||
async def test_get_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
|
||||
"""get_source returns None for a non-existent id."""
|
||||
result = await blocklist_repo.get_source(db, 9999)
|
||||
assert result is None
|
||||
|
||||
async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_sources returns empty list when no sources exist."""
|
||||
rows = await blocklist_repo.list_sources(db)
|
||||
assert rows == []
|
||||
|
||||
async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_sources returns all created sources."""
|
||||
await blocklist_repo.create_source(db, "A", "https://a.test/")
|
||||
await blocklist_repo.create_source(db, "B", "https://b.test/")
|
||||
rows = await blocklist_repo.list_sources(db)
|
||||
assert len(rows) == 2
|
||||
|
||||
async def test_list_enabled_sources_filters(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_enabled_sources only returns rows with enabled=True."""
|
||||
await blocklist_repo.create_source(db, "Enabled", "https://on.test/", enabled=True)
|
||||
id2 = await blocklist_repo.create_source(db, "Disabled", "https://off.test/", enabled=False)
|
||||
await blocklist_repo.update_source(db, id2, enabled=False)
|
||||
rows = await blocklist_repo.list_enabled_sources(db)
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["name"] == "Enabled"
|
||||
|
||||
async def test_update_source_name(self, db: aiosqlite.Connection) -> None:
|
||||
"""update_source changes the name field."""
|
||||
source_id = await blocklist_repo.create_source(db, "Old", "https://old.test/")
|
||||
updated = await blocklist_repo.update_source(db, source_id, name="New")
|
||||
assert updated is True
|
||||
row = await blocklist_repo.get_source(db, source_id)
|
||||
assert row is not None
|
||||
assert row["name"] == "New"
|
||||
|
||||
async def test_update_source_enabled_false(self, db: aiosqlite.Connection) -> None:
|
||||
"""update_source can disable a source."""
|
||||
source_id = await blocklist_repo.create_source(db, "On", "https://on.test/")
|
||||
await blocklist_repo.update_source(db, source_id, enabled=False)
|
||||
row = await blocklist_repo.get_source(db, source_id)
|
||||
assert row is not None
|
||||
assert row["enabled"] is False
|
||||
|
||||
async def test_update_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
|
||||
"""update_source returns False for a non-existent id."""
|
||||
result = await blocklist_repo.update_source(db, 9999, name="Ghost")
|
||||
assert result is False
|
||||
|
||||
async def test_delete_source_removes_row(self, db: aiosqlite.Connection) -> None:
|
||||
"""delete_source removes the row and returns True."""
|
||||
source_id = await blocklist_repo.create_source(db, "Del", "https://del.test/")
|
||||
deleted = await blocklist_repo.delete_source(db, source_id)
|
||||
assert deleted is True
|
||||
assert await blocklist_repo.get_source(db, source_id) is None
|
||||
|
||||
async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
|
||||
"""delete_source returns False for a non-existent id."""
|
||||
result = await blocklist_repo.delete_source(db, 9999)
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# import_log_repo tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestImportLogRepo:
|
||||
async def test_add_log_returns_id(self, db: aiosqlite.Connection) -> None:
|
||||
"""add_log returns a positive integer id."""
|
||||
log_id = await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url="https://example.com/list.txt",
|
||||
ips_imported=10,
|
||||
ips_skipped=2,
|
||||
errors=None,
|
||||
)
|
||||
assert isinstance(log_id, int)
|
||||
assert log_id > 0
|
||||
|
||||
async def test_list_logs_returns_all(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_logs returns all logs when no source_id filter is applied."""
|
||||
for i in range(3):
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url=f"https://s{i}.test/",
|
||||
ips_imported=i * 5,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
items, total = await import_log_repo.list_logs(db)
|
||||
assert total == 3
|
||||
assert len(items) == 3
|
||||
|
||||
async def test_list_logs_pagination(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_logs respects page and page_size."""
|
||||
for i in range(5):
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url=f"https://p{i}.test/",
|
||||
ips_imported=1,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
items, total = await import_log_repo.list_logs(db, page=2, page_size=2)
|
||||
assert total == 5
|
||||
assert len(items) == 2
|
||||
|
||||
async def test_list_logs_source_filter(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_logs filters by source_id."""
|
||||
source_id = await blocklist_repo.create_source(db, "Src", "https://s.test/")
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=source_id,
|
||||
source_url="https://s.test/",
|
||||
ips_imported=5,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url="https://other.test/",
|
||||
ips_imported=3,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
items, total = await import_log_repo.list_logs(db, source_id=source_id)
|
||||
assert total == 1
|
||||
assert items[0]["source_url"] == "https://s.test/"
|
||||
|
||||
async def test_get_last_log_empty(self, db: aiosqlite.Connection) -> None:
|
||||
"""get_last_log returns None when no logs exist."""
|
||||
result = await import_log_repo.get_last_log(db)
|
||||
assert result is None
|
||||
|
||||
async def test_get_last_log_returns_most_recent(self, db: aiosqlite.Connection) -> None:
|
||||
"""get_last_log returns the most recently inserted entry."""
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url="https://first.test/",
|
||||
ips_imported=1,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
await import_log_repo.add_log(
|
||||
db,
|
||||
source_id=None,
|
||||
source_url="https://last.test/",
|
||||
ips_imported=2,
|
||||
ips_skipped=0,
|
||||
errors=None,
|
||||
)
|
||||
last = await import_log_repo.get_last_log(db)
|
||||
assert last is not None
|
||||
assert last["source_url"] == "https://last.test/"
|
||||
|
||||
async def test_compute_total_pages(self) -> None:
|
||||
"""compute_total_pages returns correct page count."""
|
||||
assert import_log_repo.compute_total_pages(0, 10) == 1
|
||||
assert import_log_repo.compute_total_pages(10, 10) == 1
|
||||
assert import_log_repo.compute_total_pages(11, 10) == 2
|
||||
assert import_log_repo.compute_total_pages(20, 5) == 4
|
||||
447
backend/tests/test_routers/test_blocklist.py
Normal file
447
backend/tests/test_routers/test_blocklist.py
Normal file
@@ -0,0 +1,447 @@
|
||||
"""Tests for the blocklist router (9 endpoints)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.blocklist import (
|
||||
BlocklistListResponse,
|
||||
BlocklistSource,
|
||||
ImportLogListResponse,
|
||||
ImportRunResult,
|
||||
ImportSourceResult,
|
||||
PreviewResponse,
|
||||
ScheduleConfig,
|
||||
ScheduleFrequency,
|
||||
ScheduleInfo,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "testpassword1",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
"session_duration_minutes": 60,
|
||||
}
|
||||
|
||||
|
||||
def _make_source(source_id: int = 1) -> BlocklistSource:
|
||||
return BlocklistSource(
|
||||
id=source_id,
|
||||
name="Test Source",
|
||||
url="https://test.test/ips.txt",
|
||||
enabled=True,
|
||||
created_at="2026-01-01T00:00:00Z",
|
||||
updated_at="2026-01-01T00:00:00Z",
|
||||
)
|
||||
|
||||
|
||||
def _make_source_list() -> BlocklistListResponse:
|
||||
return BlocklistListResponse(sources=[_make_source(1), _make_source(2)])
|
||||
|
||||
|
||||
def _make_schedule_info() -> ScheduleInfo:
|
||||
return ScheduleInfo(
|
||||
config=ScheduleConfig(
|
||||
frequency=ScheduleFrequency.daily,
|
||||
interval_hours=24,
|
||||
hour=3,
|
||||
minute=0,
|
||||
day_of_week=0,
|
||||
),
|
||||
next_run_at="2026-02-01T03:00:00+00:00",
|
||||
last_run_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_import_result() -> ImportRunResult:
|
||||
return ImportRunResult(
|
||||
results=[
|
||||
ImportSourceResult(
|
||||
source_id=1,
|
||||
source_url="https://test.test/ips.txt",
|
||||
ips_imported=5,
|
||||
ips_skipped=1,
|
||||
error=None,
|
||||
)
|
||||
],
|
||||
total_imported=5,
|
||||
total_skipped=1,
|
||||
errors_count=0,
|
||||
)
|
||||
|
||||
|
||||
def _make_log_response() -> ImportLogListResponse:
|
||||
return ImportLogListResponse(
|
||||
items=[], total=0, page=1, page_size=50, total_pages=1
|
||||
)
|
||||
|
||||
|
||||
def _make_preview() -> PreviewResponse:
|
||||
return PreviewResponse(
|
||||
entries=["1.2.3.4", "5.6.7.8"],
|
||||
total_lines=10,
|
||||
valid_count=8,
|
||||
skipped_count=2,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def bl_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated AsyncClient for blocklist endpoint tests."""
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "bl_router_test.db"),
|
||||
fail2ban_socket="/tmp/fake_fail2ban.sock",
|
||||
session_secret="test-bl-secret",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
|
||||
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
app.state.db = db
|
||||
app.state.http_session = MagicMock()
|
||||
|
||||
# Provide a minimal scheduler stub so the router can call .get_job().
|
||||
scheduler_stub = MagicMock()
|
||||
scheduler_stub.get_job = MagicMock(return_value=None)
|
||||
app.state.scheduler = scheduler_stub
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
resp = await ac.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
assert resp.status_code == 201
|
||||
|
||||
login_resp = await ac.post(
|
||||
"/api/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
)
|
||||
assert login_resp.status_code == 200
|
||||
|
||||
yield ac
|
||||
|
||||
await db.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/blocklists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestListBlocklists:
|
||||
async def test_authenticated_returns_200(self, bl_client: AsyncClient) -> None:
|
||||
"""Authenticated request to list sources returns HTTP 200."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.list_sources",
|
||||
new=AsyncMock(return_value=_make_source_list().sources),
|
||||
):
|
||||
resp = await bl_client.get("/api/blocklists")
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_returns_401_unauthenticated(self, client: AsyncClient) -> None:
|
||||
"""Unauthenticated request returns 401."""
|
||||
await client.post("/api/setup", json=_SETUP_PAYLOAD)
|
||||
resp = await client.get("/api/blocklists")
|
||||
assert resp.status_code == 401
|
||||
|
||||
async def test_response_contains_sources_key(self, bl_client: AsyncClient) -> None:
|
||||
"""Response body has a 'sources' array."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.list_sources",
|
||||
new=AsyncMock(return_value=[_make_source()]),
|
||||
):
|
||||
resp = await bl_client.get("/api/blocklists")
|
||||
body = resp.json()
|
||||
assert "sources" in body
|
||||
assert isinstance(body["sources"], list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/blocklists
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCreateBlocklist:
|
||||
async def test_create_returns_201(self, bl_client: AsyncClient) -> None:
|
||||
"""POST /api/blocklists creates a source and returns HTTP 201."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.create_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
):
|
||||
resp = await bl_client.post(
|
||||
"/api/blocklists",
|
||||
json={"name": "Test", "url": "https://test.test/", "enabled": True},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
async def test_create_source_id_in_response(self, bl_client: AsyncClient) -> None:
|
||||
"""Created source response includes the id field."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.create_source",
|
||||
new=AsyncMock(return_value=_make_source(42)),
|
||||
):
|
||||
resp = await bl_client.post(
|
||||
"/api/blocklists",
|
||||
json={"name": "Test", "url": "https://test.test/", "enabled": True},
|
||||
)
|
||||
assert resp.json()["id"] == 42
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/blocklists/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateBlocklist:
|
||||
async def test_update_returns_200(self, bl_client: AsyncClient) -> None:
|
||||
"""PUT /api/blocklists/1 returns 200 for a found source."""
|
||||
updated = _make_source()
|
||||
updated.enabled = False
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.update_source",
|
||||
new=AsyncMock(return_value=updated),
|
||||
):
|
||||
resp = await bl_client.put(
|
||||
"/api/blocklists/1",
|
||||
json={"enabled": False},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_update_returns_404_for_missing(self, bl_client: AsyncClient) -> None:
|
||||
"""PUT /api/blocklists/999 returns 404 when source does not exist."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.update_source",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await bl_client.put(
|
||||
"/api/blocklists/999",
|
||||
json={"enabled": False},
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DELETE /api/blocklists/{id}
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteBlocklist:
|
||||
async def test_delete_returns_204(self, bl_client: AsyncClient) -> None:
|
||||
"""DELETE /api/blocklists/1 returns 204 for a found source."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.delete_source",
|
||||
new=AsyncMock(return_value=True),
|
||||
):
|
||||
resp = await bl_client.delete("/api/blocklists/1")
|
||||
assert resp.status_code == 204
|
||||
|
||||
async def test_delete_returns_404_for_missing(self, bl_client: AsyncClient) -> None:
|
||||
"""DELETE /api/blocklists/999 returns 404 when source does not exist."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.delete_source",
|
||||
new=AsyncMock(return_value=False),
|
||||
):
|
||||
resp = await bl_client.delete("/api/blocklists/999")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/blocklists/{id}/preview
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPreviewBlocklist:
|
||||
async def test_preview_returns_200(self, bl_client: AsyncClient) -> None:
|
||||
"""GET /api/blocklists/1/preview returns 200 for existing source."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
), patch(
|
||||
"app.routers.blocklist.blocklist_service.preview_source",
|
||||
new=AsyncMock(return_value=_make_preview()),
|
||||
):
|
||||
resp = await bl_client.get("/api/blocklists/1/preview")
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_preview_returns_404_for_missing(self, bl_client: AsyncClient) -> None:
|
||||
"""GET /api/blocklists/999/preview returns 404 when source not found."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=None),
|
||||
):
|
||||
resp = await bl_client.get("/api/blocklists/999/preview")
|
||||
assert resp.status_code == 404
|
||||
|
||||
async def test_preview_returns_502_on_download_error(
|
||||
self, bl_client: AsyncClient
|
||||
) -> None:
|
||||
"""GET /api/blocklists/1/preview returns 502 when URL is unreachable."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
), patch(
|
||||
"app.routers.blocklist.blocklist_service.preview_source",
|
||||
new=AsyncMock(side_effect=ValueError("Connection refused")),
|
||||
):
|
||||
resp = await bl_client.get("/api/blocklists/1/preview")
|
||||
assert resp.status_code == 502
|
||||
|
||||
async def test_preview_response_shape(self, bl_client: AsyncClient) -> None:
|
||||
"""Preview response has entries, valid_count, skipped_count, total_lines."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.get_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
), patch(
|
||||
"app.routers.blocklist.blocklist_service.preview_source",
|
||||
new=AsyncMock(return_value=_make_preview()),
|
||||
):
|
||||
resp = await bl_client.get("/api/blocklists/1/preview")
|
||||
body = resp.json()
|
||||
assert "entries" in body
|
||||
assert "valid_count" in body
|
||||
assert "skipped_count" in body
|
||||
assert "total_lines" in body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# POST /api/blocklists/import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunImport:
|
||||
async def test_import_returns_200(self, bl_client: AsyncClient) -> None:
|
||||
"""POST /api/blocklists/import returns 200 with aggregated results."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.import_all",
|
||||
new=AsyncMock(return_value=_make_import_result()),
|
||||
):
|
||||
resp = await bl_client.post("/api/blocklists/import")
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_import_response_shape(self, bl_client: AsyncClient) -> None:
|
||||
"""Import response has results, total_imported, total_skipped, errors_count."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.import_all",
|
||||
new=AsyncMock(return_value=_make_import_result()),
|
||||
):
|
||||
resp = await bl_client.post("/api/blocklists/import")
|
||||
body = resp.json()
|
||||
assert "total_imported" in body
|
||||
assert "total_skipped" in body
|
||||
assert "errors_count" in body
|
||||
assert "results" in body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/blocklists/schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetSchedule:
|
||||
async def test_schedule_returns_200(self, bl_client: AsyncClient) -> None:
|
||||
"""GET /api/blocklists/schedule returns 200."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.get_schedule_info",
|
||||
new=AsyncMock(return_value=_make_schedule_info()),
|
||||
):
|
||||
resp = await bl_client.get("/api/blocklists/schedule")
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_schedule_response_has_config(self, bl_client: AsyncClient) -> None:
|
||||
"""Schedule response includes the config sub-object."""
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.get_schedule_info",
|
||||
new=AsyncMock(return_value=_make_schedule_info()),
|
||||
):
|
||||
resp = await bl_client.get("/api/blocklists/schedule")
|
||||
body = resp.json()
|
||||
assert "config" in body
|
||||
assert "next_run_at" in body
|
||||
assert "last_run_at" in body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PUT /api/blocklists/schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestUpdateSchedule:
|
||||
async def test_update_schedule_returns_200(self, bl_client: AsyncClient) -> None:
|
||||
"""PUT /api/blocklists/schedule persists new config and returns 200."""
|
||||
new_info = ScheduleInfo(
|
||||
config=ScheduleConfig(
|
||||
frequency=ScheduleFrequency.hourly,
|
||||
interval_hours=12,
|
||||
hour=0,
|
||||
minute=0,
|
||||
day_of_week=0,
|
||||
),
|
||||
next_run_at=None,
|
||||
last_run_at=None,
|
||||
)
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.set_schedule",
|
||||
new=AsyncMock(),
|
||||
), patch(
|
||||
"app.routers.blocklist.blocklist_service.get_schedule_info",
|
||||
new=AsyncMock(return_value=new_info),
|
||||
), patch(
|
||||
"app.routers.blocklist.blocklist_import_task.reschedule",
|
||||
):
|
||||
resp = await bl_client.put(
|
||||
"/api/blocklists/schedule",
|
||||
json={
|
||||
"frequency": "hourly",
|
||||
"interval_hours": 12,
|
||||
"hour": 0,
|
||||
"minute": 0,
|
||||
"day_of_week": 0,
|
||||
},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GET /api/blocklists/log
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestImportLog:
|
||||
async def test_log_returns_200(self, bl_client: AsyncClient) -> None:
|
||||
"""GET /api/blocklists/log returns 200."""
|
||||
resp = await bl_client.get("/api/blocklists/log")
|
||||
assert resp.status_code == 200
|
||||
|
||||
async def test_log_response_shape(self, bl_client: AsyncClient) -> None:
|
||||
"""Log response has items, total, page, page_size, total_pages."""
|
||||
resp = await bl_client.get("/api/blocklists/log")
|
||||
body = resp.json()
|
||||
for key in ("items", "total", "page", "page_size", "total_pages"):
|
||||
assert key in body
|
||||
|
||||
async def test_log_empty_when_no_runs(self, bl_client: AsyncClient) -> None:
|
||||
"""Log returns empty items list when no import runs have occurred."""
|
||||
resp = await bl_client.get("/api/blocklists/log")
|
||||
body = resp.json()
|
||||
assert body["total"] == 0
|
||||
assert body["items"] == []
|
||||
233
backend/tests/test_services/test_blocklist_service.py
Normal file
233
backend/tests/test_services/test_blocklist_service.py
Normal file
@@ -0,0 +1,233 @@
|
||||
"""Tests for blocklist_service — source CRUD, preview, import, schedule."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiosqlite
|
||||
import pytest
|
||||
|
||||
from app.db import init_db
|
||||
from app.models.blocklist import BlocklistSource, ScheduleConfig, ScheduleFrequency
|
||||
from app.services import blocklist_service
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def db(tmp_path: Path) -> aiosqlite.Connection: # type: ignore[misc]
|
||||
"""Provide an initialised aiosqlite connection."""
|
||||
conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_svc.db"))
|
||||
conn.row_factory = aiosqlite.Row
|
||||
await init_db(conn)
|
||||
yield conn
|
||||
await conn.close()
|
||||
|
||||
|
||||
def _make_session(text: str, status: int = 200) -> MagicMock:
|
||||
"""Build a mock aiohttp session that returns *text* for GET requests."""
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = status
|
||||
mock_resp.text = AsyncMock(return_value=text)
|
||||
mock_resp.content = AsyncMock()
|
||||
mock_resp.content.read = AsyncMock(return_value=text.encode())
|
||||
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
session = MagicMock()
|
||||
session.get = MagicMock(return_value=mock_ctx)
|
||||
return session
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Source CRUD
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSourceCRUD:
|
||||
async def test_create_and_get(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source persists and get_source retrieves a source."""
|
||||
source = await blocklist_service.create_source(db, "Test", "https://t.test/")
|
||||
assert isinstance(source, BlocklistSource)
|
||||
assert source.name == "Test"
|
||||
assert source.enabled is True
|
||||
|
||||
fetched = await blocklist_service.get_source(db, source.id)
|
||||
assert fetched is not None
|
||||
assert fetched.id == source.id
|
||||
|
||||
async def test_get_missing_returns_none(self, db: aiosqlite.Connection) -> None:
|
||||
"""get_source returns None for a non-existent id."""
|
||||
result = await blocklist_service.get_source(db, 9999)
|
||||
assert result is None
|
||||
|
||||
async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_sources returns empty list when no sources exist."""
|
||||
sources = await blocklist_service.list_sources(db)
|
||||
assert sources == []
|
||||
|
||||
async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_sources returns all created sources."""
|
||||
await blocklist_service.create_source(db, "A", "https://a.test/")
|
||||
await blocklist_service.create_source(db, "B", "https://b.test/")
|
||||
sources = await blocklist_service.list_sources(db)
|
||||
assert len(sources) == 2
|
||||
|
||||
async def test_update_source_fields(self, db: aiosqlite.Connection) -> None:
|
||||
"""update_source modifies specified fields."""
|
||||
source = await blocklist_service.create_source(db, "Original", "https://orig.test/")
|
||||
updated = await blocklist_service.update_source(db, source.id, name="Updated", enabled=False)
|
||||
assert updated is not None
|
||||
assert updated.name == "Updated"
|
||||
assert updated.enabled is False
|
||||
|
||||
async def test_update_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
|
||||
"""update_source returns None for a non-existent id."""
|
||||
result = await blocklist_service.update_source(db, 9999, name="Ghost")
|
||||
assert result is None
|
||||
|
||||
async def test_delete_source(self, db: aiosqlite.Connection) -> None:
|
||||
"""delete_source removes a source and returns True."""
|
||||
source = await blocklist_service.create_source(db, "Del", "https://del.test/")
|
||||
deleted = await blocklist_service.delete_source(db, source.id)
|
||||
assert deleted is True
|
||||
assert await blocklist_service.get_source(db, source.id) is None
|
||||
|
||||
async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
|
||||
"""delete_source returns False for a non-existent id."""
|
||||
result = await blocklist_service.delete_source(db, 9999)
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Preview
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPreview:
|
||||
async def test_preview_valid_ips(self) -> None:
|
||||
"""preview_source returns valid IPs from the downloaded content."""
|
||||
content = "1.2.3.4\n5.6.7.8\n# comment\ninvalid\n9.0.0.1\n"
|
||||
session = _make_session(content)
|
||||
result = await blocklist_service.preview_source("https://test.test/ips.txt", session)
|
||||
assert result.valid_count == 3
|
||||
assert result.skipped_count == 1 # "invalid"
|
||||
assert "1.2.3.4" in result.entries
|
||||
|
||||
async def test_preview_http_error_raises(self) -> None:
|
||||
"""preview_source raises ValueError when the server returns non-200."""
|
||||
session = _make_session("", status=404)
|
||||
with pytest.raises(ValueError, match="HTTP 404"):
|
||||
await blocklist_service.preview_source("https://bad.test/", session)
|
||||
|
||||
async def test_preview_limits_entries(self) -> None:
|
||||
"""preview_source caps entries to sample_lines."""
|
||||
ips = "\n".join(f"1.2.3.{i}" for i in range(50))
|
||||
session = _make_session(ips)
|
||||
result = await blocklist_service.preview_source(
|
||||
"https://test.test/", session, sample_lines=10
|
||||
)
|
||||
assert len(result.entries) <= 10
|
||||
assert result.valid_count == 50
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestImport:
|
||||
async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
|
||||
"""import_source calls ban_ip for every valid IP in the blocklist."""
|
||||
content = "1.2.3.4\n5.6.7.8\n# skip me\n"
|
||||
session = _make_session(content)
|
||||
|
||||
source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")
|
||||
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
) as mock_ban:
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
)
|
||||
|
||||
assert result.ips_imported == 2
|
||||
assert result.ips_skipped == 0
|
||||
assert result.error is None
|
||||
assert mock_ban.call_count == 2
|
||||
|
||||
async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
|
||||
"""import_source skips CIDR ranges (fail2ban expects individual IPs)."""
|
||||
content = "1.2.3.4\n10.0.0.0/24\n"
|
||||
session = _make_session(content)
|
||||
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
|
||||
|
||||
with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
)
|
||||
|
||||
assert result.ips_imported == 1
|
||||
assert result.ips_skipped == 1
|
||||
|
||||
async def test_import_source_records_download_error(self, db: aiosqlite.Connection) -> None:
|
||||
"""import_source records an error and returns 0 imported on HTTP failure."""
|
||||
session = _make_session("", status=503)
|
||||
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
|
||||
|
||||
result = await blocklist_service.import_source(
|
||||
source, session, "/tmp/fake.sock", db
|
||||
)
|
||||
|
||||
assert result.ips_imported == 0
|
||||
assert result.error is not None
|
||||
|
||||
async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
|
||||
"""import_all aggregates results across all enabled sources."""
|
||||
await blocklist_service.create_source(db, "S1", "https://s1.test/")
|
||||
s2 = await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)
|
||||
_ = s2 # noqa: F841
|
||||
|
||||
content = "1.2.3.4\n5.6.7.8\n"
|
||||
session = _make_session(content)
|
||||
|
||||
with patch(
|
||||
"app.services.jail_service.ban_ip", new_callable=AsyncMock
|
||||
):
|
||||
result = await blocklist_service.import_all(db, session, "/tmp/fake.sock")
|
||||
|
||||
# Only S1 is enabled, S2 is disabled.
|
||||
assert len(result.results) == 1
|
||||
assert result.results[0].source_url == "https://s1.test/"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schedule
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSchedule:
|
||||
async def test_get_schedule_default(self, db: aiosqlite.Connection) -> None:
|
||||
"""get_schedule returns the default daily-03:00 config when nothing is saved."""
|
||||
config = await blocklist_service.get_schedule(db)
|
||||
assert config.frequency == ScheduleFrequency.daily
|
||||
assert config.hour == 3
|
||||
|
||||
async def test_set_and_get_round_trip(self, db: aiosqlite.Connection) -> None:
|
||||
"""set_schedule persists config retrievable by get_schedule."""
|
||||
cfg = ScheduleConfig(frequency=ScheduleFrequency.hourly, interval_hours=6, hour=0, minute=0, day_of_week=0)
|
||||
await blocklist_service.set_schedule(db, cfg)
|
||||
loaded = await blocklist_service.get_schedule(db)
|
||||
assert loaded.frequency == ScheduleFrequency.hourly
|
||||
assert loaded.interval_hours == 6
|
||||
|
||||
async def test_get_schedule_info_no_log(self, db: aiosqlite.Connection) -> None:
|
||||
"""get_schedule_info returns None for last_run_at when no log exists."""
|
||||
info = await blocklist_service.get_schedule_info(db, None)
|
||||
assert info.last_run_at is None
|
||||
assert info.next_run_at is None
|
||||
Reference in New Issue
Block a user