- blocklist_repo.py: CRUD for blocklist_sources table - import_log_repo.py: add/list/get-last log entries - blocklist_service.py: source CRUD, preview, import (download/validate/ban), import_all, schedule get/set/info - blocklist_import.py: APScheduler task (hourly/daily/weekly schedule triggers) - blocklist.py router: 9 endpoints (list/create/update/delete/preview/import/schedule-get+put/log) - blocklist.py models: ScheduleFrequency (StrEnum), ScheduleConfig, ScheduleInfo, ImportSourceResult, ImportRunResult, PreviewResponse - 59 new tests (18 repo + 19 service + 22 router); 374 total pass - ruff clean, mypy clean for Stage 10 files - types/blocklist.ts, api/blocklist.ts, hooks/useBlocklist.ts - BlocklistsPage.tsx: source management, schedule picker, import log table - Frontend tsc + ESLint clean
211 lines
8.5 KiB
Python
211 lines
8.5 KiB
Python
"""Tests for blocklist_repo and import_log_repo."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterator
from pathlib import Path
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
|
|
from app.db import init_db
|
|
from app.repositories import blocklist_repo, import_log_repo
|
|
|
|
|
|
@pytest.fixture
async def db(tmp_path: Path) -> AsyncIterator[aiosqlite.Connection]:
    """Provide an initialised aiosqlite connection for repository tests.

    A fresh SQLite database file (``bl_test.db``) is created under pytest's
    per-test ``tmp_path``, initialised with ``app.db.init_db``, and configured
    to return ``aiosqlite.Row`` objects. The connection is yielded to the test
    and closed on teardown.

    Yields:
        An open ``aiosqlite.Connection`` bound to the temporary database.
    """
    conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_test.db"))
    try:
        conn.row_factory = aiosqlite.Row
        await init_db(conn)
        yield conn
    finally:
        # Close even when init_db raises before the yield, so a failed setup
        # does not leave an open connection behind.
        await conn.close()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# blocklist_repo tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestBlocklistRepo:
    """Tests for the blocklist_sources CRUD helpers in ``blocklist_repo``."""

    async def test_create_source_returns_int_id(self, db: aiosqlite.Connection) -> None:
        """create_source returns a positive integer id."""
        source_id = await blocklist_repo.create_source(db, "Test", "https://example.com/list.txt")
        assert isinstance(source_id, int)
        assert source_id > 0

    async def test_get_source_returns_row(self, db: aiosqlite.Connection) -> None:
        """get_source returns the correct row after creation."""
        source_id = await blocklist_repo.create_source(db, "Alpha", "https://alpha.test/ips.txt")
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["name"] == "Alpha"
        assert row["url"] == "https://alpha.test/ips.txt"
        assert row["enabled"] is True

    async def test_get_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """get_source returns None for a non-existent id."""
        result = await blocklist_repo.get_source(db, 9999)
        assert result is None

    async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
        """list_sources returns empty list when no sources exist."""
        rows = await blocklist_repo.list_sources(db)
        assert rows == []

    async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
        """list_sources returns all created sources."""
        await blocklist_repo.create_source(db, "A", "https://a.test/")
        await blocklist_repo.create_source(db, "B", "https://b.test/")
        rows = await blocklist_repo.list_sources(db)
        assert len(rows) == 2
        # Compare as a set so the test does not depend on the repository's row order.
        assert {row["name"] for row in rows} == {"A", "B"}

    async def test_list_enabled_sources_filters(self, db: aiosqlite.Connection) -> None:
        """list_enabled_sources only returns rows with enabled=True."""
        await blocklist_repo.create_source(db, "Enabled", "https://on.test/", enabled=True)
        # Rely solely on create_source's ``enabled`` argument: a follow-up
        # update_source call would hide a bug where the flag is ignored on insert.
        await blocklist_repo.create_source(db, "Disabled", "https://off.test/", enabled=False)
        rows = await blocklist_repo.list_enabled_sources(db)
        assert len(rows) == 1
        assert rows[0]["name"] == "Enabled"

    async def test_update_source_name(self, db: aiosqlite.Connection) -> None:
        """update_source changes the name field."""
        source_id = await blocklist_repo.create_source(db, "Old", "https://old.test/")
        updated = await blocklist_repo.update_source(db, source_id, name="New")
        assert updated is True
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["name"] == "New"

    async def test_update_source_enabled_false(self, db: aiosqlite.Connection) -> None:
        """update_source can disable a source."""
        source_id = await blocklist_repo.create_source(db, "On", "https://on.test/")
        updated = await blocklist_repo.update_source(db, source_id, enabled=False)
        assert updated is True
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["enabled"] is False

    async def test_update_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """update_source returns False for a non-existent id."""
        result = await blocklist_repo.update_source(db, 9999, name="Ghost")
        assert result is False

    async def test_delete_source_removes_row(self, db: aiosqlite.Connection) -> None:
        """delete_source removes the row and returns True."""
        source_id = await blocklist_repo.create_source(db, "Del", "https://del.test/")
        deleted = await blocklist_repo.delete_source(db, source_id)
        assert deleted is True
        assert await blocklist_repo.get_source(db, source_id) is None

    async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """delete_source returns False for a non-existent id."""
        result = await blocklist_repo.delete_source(db, 9999)
        assert result is False
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# import_log_repo tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestImportLogRepo:
    """Tests for the import-log helpers in ``import_log_repo``."""

    @staticmethod
    async def _add_log(
        db: aiosqlite.Connection,
        source_url: str,
        *,
        source_id: int | None = None,
        ips_imported: int = 1,
        ips_skipped: int = 0,
    ) -> int:
        """Insert one error-free log entry via add_log and return its id.

        Keeps each test focused on the fields it actually varies instead of
        repeating the full add_log argument list.
        """
        return await import_log_repo.add_log(
            db,
            source_id=source_id,
            source_url=source_url,
            ips_imported=ips_imported,
            ips_skipped=ips_skipped,
            errors=None,
        )

    async def test_add_log_returns_id(self, db: aiosqlite.Connection) -> None:
        """add_log returns a positive integer id."""
        # Call add_log directly (not via the helper) so its full signature is exercised.
        log_id = await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://example.com/list.txt",
            ips_imported=10,
            ips_skipped=2,
            errors=None,
        )
        assert isinstance(log_id, int)
        assert log_id > 0

    async def test_list_logs_returns_all(self, db: aiosqlite.Connection) -> None:
        """list_logs returns all logs when no source_id filter is applied."""
        for i in range(3):
            await self._add_log(db, f"https://s{i}.test/", ips_imported=i * 5)
        items, total = await import_log_repo.list_logs(db)
        assert total == 3
        assert len(items) == 3

    async def test_list_logs_pagination(self, db: aiosqlite.Connection) -> None:
        """list_logs respects page and page_size, including a partial last page."""
        for i in range(5):
            await self._add_log(db, f"https://p{i}.test/")

        items, total = await import_log_repo.list_logs(db, page=2, page_size=2)
        assert total == 5
        assert len(items) == 2

        # 5 entries at 2 per page: page 3 holds only the single remaining entry.
        last_items, last_total = await import_log_repo.list_logs(db, page=3, page_size=2)
        assert last_total == 5
        assert len(last_items) == 1

    async def test_list_logs_source_filter(self, db: aiosqlite.Connection) -> None:
        """list_logs filters by source_id."""
        source_id = await blocklist_repo.create_source(db, "Src", "https://s.test/")
        await self._add_log(db, "https://s.test/", source_id=source_id, ips_imported=5)
        await self._add_log(db, "https://other.test/", ips_imported=3)
        items, total = await import_log_repo.list_logs(db, source_id=source_id)
        assert total == 1
        # Check the length before indexing so a wrong result fails with a clear message.
        assert len(items) == 1
        assert items[0]["source_url"] == "https://s.test/"

    async def test_get_last_log_empty(self, db: aiosqlite.Connection) -> None:
        """get_last_log returns None when no logs exist."""
        result = await import_log_repo.get_last_log(db)
        assert result is None

    async def test_get_last_log_returns_most_recent(self, db: aiosqlite.Connection) -> None:
        """get_last_log returns the most recently inserted entry."""
        await self._add_log(db, "https://first.test/", ips_imported=1)
        await self._add_log(db, "https://last.test/", ips_imported=2)
        last = await import_log_repo.get_last_log(db)
        assert last is not None
        assert last["source_url"] == "https://last.test/"

    def test_compute_total_pages(self) -> None:
        """compute_total_pages returns correct page count."""
        # Synchronous: compute_total_pages is called without await, so no event loop is needed.
        assert import_log_repo.compute_total_pages(0, 10) == 1
        assert import_log_repo.compute_total_pages(10, 10) == 1
        assert import_log_repo.compute_total_pages(11, 10) == 2
        assert import_log_repo.compute_total_pages(20, 5) == 4
|