"""Tests for blocklist_repo and import_log_repo."""

from __future__ import annotations

from collections.abc import AsyncIterator
from pathlib import Path

import aiosqlite
import pytest

from app.db import init_db
from app.repositories import blocklist_repo, import_log_repo


@pytest.fixture
async def db(tmp_path: Path) -> AsyncIterator[aiosqlite.Connection]:
    """Provide an initialised aiosqlite connection for repository tests.

    Each test gets its own database file under pytest's per-test ``tmp_path``,
    so rows created by one test are never visible to another.  The connection
    is closed on teardown, including when ``init_db`` itself fails.

    Yields:
        An open connection whose schema has been created by ``init_db``.
    """
    conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_test.db"))
    try:
        # aiosqlite.Row allows name-based column access (row["name"]) in the tests.
        conn.row_factory = aiosqlite.Row
        await init_db(conn)
        yield conn
    finally:
        await conn.close()


# ---------------------------------------------------------------------------
# blocklist_repo tests
# ---------------------------------------------------------------------------


class TestBlocklistRepo:
    """CRUD behaviour of the blocklist source repository."""

    async def test_create_source_returns_int_id(self, db: aiosqlite.Connection) -> None:
        """create_source returns a positive integer id."""
        source_id = await blocklist_repo.create_source(db, "Test", "https://example.com/list.txt")
        assert isinstance(source_id, int)
        assert source_id > 0

    async def test_get_source_returns_row(self, db: aiosqlite.Connection) -> None:
        """get_source returns the correct row after creation."""
        source_id = await blocklist_repo.create_source(db, "Alpha", "https://alpha.test/ips.txt")
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["name"] == "Alpha"
        assert row["url"] == "https://alpha.test/ips.txt"
        # The repository is expected to expose `enabled` as a real bool, not 0/1.
        assert row["enabled"] is True

    async def test_get_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """get_source returns None for a non-existent id."""
        result = await blocklist_repo.get_source(db, 9999)
        assert result is None

    async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
        """list_sources returns empty list when no sources exist."""
        rows = await blocklist_repo.list_sources(db)
        assert rows == []

    async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
        """list_sources returns all created sources."""
        await blocklist_repo.create_source(db, "A", "https://a.test/")
        await blocklist_repo.create_source(db, "B", "https://b.test/")
        rows = await blocklist_repo.list_sources(db)
        assert len(rows) == 2

    async def test_list_enabled_sources_filters(self, db: aiosqlite.Connection) -> None:
        """list_enabled_sources only returns rows with enabled=True.

        The disabled source is created disabled via ``create_source`` alone, so
        this also verifies that ``create_source`` honours its ``enabled``
        argument (a follow-up ``update_source`` would mask a bug there).
        """
        await blocklist_repo.create_source(db, "Enabled", "https://on.test/", enabled=True)
        await blocklist_repo.create_source(db, "Disabled", "https://off.test/", enabled=False)
        rows = await blocklist_repo.list_enabled_sources(db)
        assert len(rows) == 1
        assert rows[0]["name"] == "Enabled"

    async def test_update_source_name(self, db: aiosqlite.Connection) -> None:
        """update_source changes the name field."""
        source_id = await blocklist_repo.create_source(db, "Old", "https://old.test/")
        updated = await blocklist_repo.update_source(db, source_id, name="New")
        assert updated is True
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["name"] == "New"

    async def test_update_source_enabled_false(self, db: aiosqlite.Connection) -> None:
        """update_source can disable a source."""
        source_id = await blocklist_repo.create_source(db, "On", "https://on.test/")
        await blocklist_repo.update_source(db, source_id, enabled=False)
        row = await blocklist_repo.get_source(db, source_id)
        assert row is not None
        assert row["enabled"] is False

    async def test_update_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """update_source returns False for a non-existent id."""
        result = await blocklist_repo.update_source(db, 9999, name="Ghost")
        assert result is False

    async def test_delete_source_removes_row(self, db: aiosqlite.Connection) -> None:
        """delete_source removes the row and returns True."""
        source_id = await blocklist_repo.create_source(db, "Del", "https://del.test/")
        deleted = await blocklist_repo.delete_source(db, source_id)
        assert deleted is True
        assert await blocklist_repo.get_source(db, source_id) is None

    async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """delete_source returns False for a non-existent id."""
        result = await blocklist_repo.delete_source(db, 9999)
        assert result is False


# ---------------------------------------------------------------------------
# import_log_repo tests
# ---------------------------------------------------------------------------


class TestImportLogRepo:
    """Insertion, listing, pagination and lookup of import log entries."""

    async def test_add_log_returns_id(self, db: aiosqlite.Connection) -> None:
        """add_log returns a positive integer id."""
        log_id = await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://example.com/list.txt",
            ips_imported=10,
            ips_skipped=2,
            errors=None,
        )
        assert isinstance(log_id, int)
        assert log_id > 0

    async def test_list_logs_returns_all(self, db: aiosqlite.Connection) -> None:
        """list_logs returns all logs when no source_id filter is applied."""
        for i in range(3):
            await import_log_repo.add_log(
                db,
                source_id=None,
                source_url=f"https://s{i}.test/",
                ips_imported=i * 5,
                ips_skipped=0,
                errors=None,
            )
        items, total = await import_log_repo.list_logs(db)
        assert total == 3
        assert len(items) == 3

    async def test_list_logs_pagination(self, db: aiosqlite.Connection) -> None:
        """list_logs respects page and page_size."""
        for i in range(5):
            await import_log_repo.add_log(
                db,
                source_id=None,
                source_url=f"https://p{i}.test/",
                ips_imported=1,
                ips_skipped=0,
                errors=None,
            )
        items, total = await import_log_repo.list_logs(db, page=2, page_size=2)
        # `total` counts every matching row; `items` holds only the requested page.
        assert total == 5
        assert len(items) == 2

    async def test_list_logs_source_filter(self, db: aiosqlite.Connection) -> None:
        """list_logs filters by source_id."""
        source_id = await blocklist_repo.create_source(db, "Src", "https://s.test/")
        await import_log_repo.add_log(
            db,
            source_id=source_id,
            source_url="https://s.test/",
            ips_imported=5,
            ips_skipped=0,
            errors=None,
        )
        await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://other.test/",
            ips_imported=3,
            ips_skipped=0,
            errors=None,
        )
        items, total = await import_log_repo.list_logs(db, source_id=source_id)
        assert total == 1
        assert items[0]["source_url"] == "https://s.test/"

    async def test_get_last_log_empty(self, db: aiosqlite.Connection) -> None:
        """get_last_log returns None when no logs exist."""
        result = await import_log_repo.get_last_log(db)
        assert result is None

    async def test_get_last_log_returns_most_recent(self, db: aiosqlite.Connection) -> None:
        """get_last_log returns the most recently inserted entry."""
        # NOTE(review): both rows are inserted back-to-back; this assumes
        # get_last_log breaks timestamp ties by insertion order (e.g. id) — confirm.
        await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://first.test/",
            ips_imported=1,
            ips_skipped=0,
            errors=None,
        )
        await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://last.test/",
            ips_imported=2,
            ips_skipped=0,
            errors=None,
        )
        last = await import_log_repo.get_last_log(db)
        assert last is not None
        assert last["source_url"] == "https://last.test/"

    def test_compute_total_pages(self) -> None:
        """compute_total_pages returns correct page count.

        Pure function — no database or event loop needed, so the test is sync.
        An empty result set still reports one page.
        """
        assert import_log_repo.compute_total_pages(0, 10) == 1
        assert import_log_repo.compute_total_pages(10, 10) == 1
        assert import_log_repo.compute_total_pages(11, 10) == 2
        assert import_log_repo.compute_total_pages(20, 5) == 4