Files
BanGUI/backend/tests/test_repositories/test_blocklist.py
Lukas 1efa0e973b Stage 10: external blocklist importer — backend + frontend
- blocklist_repo.py: CRUD for blocklist_sources table
- import_log_repo.py: add/list/get-last log entries
- blocklist_service.py: source CRUD, preview, import (download/validate/ban),
  import_all, schedule get/set/info
- blocklist_import.py: APScheduler task (hourly/daily/weekly schedule triggers)
- blocklist.py router: 9 endpoints (list/create/update/delete/preview/import/
  schedule-get+put/log)
- blocklist.py models: ScheduleFrequency (StrEnum), ScheduleConfig, ScheduleInfo,
  ImportSourceResult, ImportRunResult, PreviewResponse
- 59 new tests (18 repo + 19 service + 22 router); 374 total pass
- ruff clean, mypy clean for Stage 10 files
- types/blocklist.ts, api/blocklist.ts, hooks/useBlocklist.ts
- BlocklistsPage.tsx: source management, schedule picker, import log table
- Frontend tsc + ESLint clean
2026-03-01 15:33:24 +01:00

211 lines
8.5 KiB
Python

"""Tests for blocklist_repo and import_log_repo."""
from __future__ import annotations

from collections.abc import AsyncIterator
from pathlib import Path

import aiosqlite
import pytest

from app.db import init_db
from app.repositories import blocklist_repo, import_log_repo
@pytest.fixture
async def db(tmp_path: Path) -> AsyncIterator[aiosqlite.Connection]:
    """Provide an initialised aiosqlite connection for repository tests.

    Each test gets its own SQLite file under pytest's ``tmp_path``. Rows are
    returned as :class:`aiosqlite.Row`, so tests can read columns by name.
    The connection is always closed on teardown, including when schema
    initialisation itself fails.
    """
    conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_test.db"))
    try:
        conn.row_factory = aiosqlite.Row
        # init_db must prepare the tables the repositories query; the repo
        # tests below read and write them immediately.
        await init_db(conn)
        yield conn
    finally:
        # Runs on normal teardown and also if init_db raises, so the file
        # handle is never leaked.
        await conn.close()
# ---------------------------------------------------------------------------
# blocklist_repo tests
# ---------------------------------------------------------------------------
class TestBlocklistRepo:
    """Tests for the CRUD helpers in ``blocklist_repo``."""

    async def test_create_source_returns_int_id(self, db: aiosqlite.Connection) -> None:
        """create_source returns a positive integer id."""
        source_id = await blocklist_repo.create_source(db, "Test", "https://example.com/list.txt")
        assert isinstance(source_id, int)
        assert source_id > 0

    async def test_get_source_returns_row(self, db: aiosqlite.Connection) -> None:
        """get_source returns the correct row after creation."""
        source_id = await blocklist_repo.create_source(db, "Alpha", "https://alpha.test/ips.txt")
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["name"] == "Alpha"
        assert row["url"] == "https://alpha.test/ips.txt"
        # Without an explicit ``enabled`` argument the source starts enabled,
        # and the repo exposes the flag as a real bool (``is True``).
        assert row["enabled"] is True

    async def test_get_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """get_source returns None for a non-existent id."""
        result = await blocklist_repo.get_source(db, 9999)
        assert result is None

    async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
        """list_sources returns an empty list when no sources exist."""
        rows = await blocklist_repo.list_sources(db)
        assert rows == []

    async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
        """list_sources returns every created source."""
        await blocklist_repo.create_source(db, "A", "https://a.test/")
        await blocklist_repo.create_source(db, "B", "https://b.test/")
        rows = await blocklist_repo.list_sources(db)
        assert len(rows) == 2
        # Compare as a set: this test does not pin list_sources' ordering.
        assert {row["name"] for row in rows} == {"A", "B"}

    async def test_list_enabled_sources_filters(self, db: aiosqlite.Connection) -> None:
        """list_enabled_sources excludes sources disabled at creation or by update."""
        await blocklist_repo.create_source(db, "Enabled", "https://on.test/", enabled=True)
        # Disabled directly by create_source, with no follow-up update. If
        # create_source ignored ``enabled``, this row would leak through.
        await blocklist_repo.create_source(db, "Disabled", "https://off.test/", enabled=False)
        # Created enabled, then disabled through update_source.
        later_id = await blocklist_repo.create_source(
            db, "Disabled later", "https://later.test/"
        )
        await blocklist_repo.update_source(db, later_id, enabled=False)

        rows = await blocklist_repo.list_enabled_sources(db)
        assert len(rows) == 1
        assert rows[0]["name"] == "Enabled"

    async def test_update_source_name(self, db: aiosqlite.Connection) -> None:
        """update_source changes the name field."""
        source_id = await blocklist_repo.create_source(db, "Old", "https://old.test/")
        updated = await blocklist_repo.update_source(db, source_id, name="New")
        assert updated is True
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["name"] == "New"

    async def test_update_source_enabled_false(self, db: aiosqlite.Connection) -> None:
        """update_source can disable a source and reports success."""
        source_id = await blocklist_repo.create_source(db, "On", "https://on.test/")
        updated = await blocklist_repo.update_source(db, source_id, enabled=False)
        assert updated is True
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["enabled"] is False

    async def test_update_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """update_source returns False for a non-existent id."""
        result = await blocklist_repo.update_source(db, 9999, name="Ghost")
        assert result is False

    async def test_delete_source_removes_row(self, db: aiosqlite.Connection) -> None:
        """delete_source removes the row and returns True."""
        source_id = await blocklist_repo.create_source(db, "Del", "https://del.test/")
        deleted = await blocklist_repo.delete_source(db, source_id)
        assert deleted is True
        assert await blocklist_repo.get_source(db, source_id) is None

    async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """delete_source returns False for a non-existent id."""
        result = await blocklist_repo.delete_source(db, 9999)
        assert result is False
# ---------------------------------------------------------------------------
# import_log_repo tests
# ---------------------------------------------------------------------------
class TestImportLogRepo:
    """Tests for the import-log helpers in ``import_log_repo``."""

    @staticmethod
    async def _add_log(
        db: aiosqlite.Connection,
        source_url: str,
        *,
        source_id: int | None = None,
        ips_imported: int = 1,
        ips_skipped: int = 0,
    ) -> int:
        """Insert an error-free log entry, defaulting fields a test does not care about."""
        return await import_log_repo.add_log(
            db,
            source_id=source_id,
            source_url=source_url,
            ips_imported=ips_imported,
            ips_skipped=ips_skipped,
            errors=None,
        )

    async def test_add_log_returns_id(self, db: aiosqlite.Connection) -> None:
        """add_log returns a positive integer id."""
        # Call add_log directly (not via the helper) to pin its full keyword API.
        log_id = await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://example.com/list.txt",
            ips_imported=10,
            ips_skipped=2,
            errors=None,
        )
        assert isinstance(log_id, int)
        assert log_id > 0

    async def test_list_logs_returns_all(self, db: aiosqlite.Connection) -> None:
        """list_logs returns all logs when no source_id filter is applied."""
        for i in range(3):
            await self._add_log(db, f"https://s{i}.test/", ips_imported=i * 5)
        items, total = await import_log_repo.list_logs(db)
        assert total == 3
        assert len(items) == 3

    async def test_list_logs_pagination(self, db: aiosqlite.Connection) -> None:
        """list_logs pages are sized correctly, disjoint, and cover every entry."""
        urls = [f"https://p{i}.test/" for i in range(5)]
        for url in urls:
            await self._add_log(db, url)

        seen: list[str] = []
        # Pages are 1-indexed: 5 entries at page_size=2 give pages of 2, 2, 1.
        for page, expected_len in ((1, 2), (2, 2), (3, 1)):
            items, total = await import_log_repo.list_logs(db, page=page, page_size=2)
            assert total == 5
            assert len(items) == expected_len
            seen.extend(item["source_url"] for item in items)

        # No entry is duplicated across pages and none is skipped.
        assert sorted(seen) == sorted(urls)

    async def test_list_logs_source_filter(self, db: aiosqlite.Connection) -> None:
        """list_logs filters by source_id."""
        source_id = await blocklist_repo.create_source(db, "Src", "https://s.test/")
        await self._add_log(db, "https://s.test/", source_id=source_id, ips_imported=5)
        await self._add_log(db, "https://other.test/", ips_imported=3)

        items, total = await import_log_repo.list_logs(db, source_id=source_id)
        assert total == 1
        assert len(items) == 1
        assert items[0]["source_url"] == "https://s.test/"

    async def test_get_last_log_empty(self, db: aiosqlite.Connection) -> None:
        """get_last_log returns None when no logs exist."""
        result = await import_log_repo.get_last_log(db)
        assert result is None

    async def test_get_last_log_returns_most_recent(self, db: aiosqlite.Connection) -> None:
        """get_last_log returns the most recently inserted entry."""
        await self._add_log(db, "https://first.test/", ips_imported=1)
        await self._add_log(db, "https://last.test/", ips_imported=2)
        last = await import_log_repo.get_last_log(db)
        assert last is not None
        assert last["source_url"] == "https://last.test/"

    def test_compute_total_pages(self) -> None:
        """compute_total_pages returns the ceiling page count, never less than 1."""
        # Pure function with no I/O, so a plain (non-async) test is sufficient.
        assert import_log_repo.compute_total_pages(0, 10) == 1
        assert import_log_repo.compute_total_pages(1, 10) == 1
        assert import_log_repo.compute_total_pages(10, 10) == 1
        assert import_log_repo.compute_total_pages(11, 10) == 2
        assert import_log_repo.compute_total_pages(20, 5) == 4