"""Tests for blocklist_service — source CRUD, preview, import, schedule.""" from __future__ import annotations from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import aiosqlite import pytest from app.db import init_db from app.models.blocklist import BlocklistSource, ScheduleConfig, ScheduleFrequency from app.services import blocklist_service # --------------------------------------------------------------------------- # Fixture # --------------------------------------------------------------------------- @pytest.fixture async def db(tmp_path: Path) -> aiosqlite.Connection: # type: ignore[misc] """Provide an initialised aiosqlite connection.""" conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_svc.db")) conn.row_factory = aiosqlite.Row await init_db(conn) yield conn await conn.close() def _make_session(text: str, status: int = 200) -> MagicMock: """Build a mock aiohttp session that returns *text* for GET requests.""" mock_resp = AsyncMock() mock_resp.status = status mock_resp.text = AsyncMock(return_value=text) mock_resp.content = AsyncMock() mock_resp.content.read = AsyncMock(return_value=text.encode()) mock_ctx = AsyncMock() mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp) mock_ctx.__aexit__ = AsyncMock(return_value=False) session = MagicMock() session.get = MagicMock(return_value=mock_ctx) return session # --------------------------------------------------------------------------- # Source CRUD # --------------------------------------------------------------------------- class TestSourceCRUD: async def test_create_and_get(self, db: aiosqlite.Connection) -> None: """create_source persists and get_source retrieves a source.""" source = await blocklist_service.create_source(db, "Test", "https://t.test/") assert isinstance(source, BlocklistSource) assert source.name == "Test" assert source.enabled is True fetched = await blocklist_service.get_source(db, source.id) assert fetched is not None assert fetched.id 
== source.id async def test_get_missing_returns_none(self, db: aiosqlite.Connection) -> None: """get_source returns None for a non-existent id.""" result = await blocklist_service.get_source(db, 9999) assert result is None async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None: """list_sources returns empty list when no sources exist.""" sources = await blocklist_service.list_sources(db) assert sources == [] async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None: """list_sources returns all created sources.""" await blocklist_service.create_source(db, "A", "https://a.test/") await blocklist_service.create_source(db, "B", "https://b.test/") sources = await blocklist_service.list_sources(db) assert len(sources) == 2 async def test_update_source_fields(self, db: aiosqlite.Connection) -> None: """update_source modifies specified fields.""" source = await blocklist_service.create_source(db, "Original", "https://orig.test/") updated = await blocklist_service.update_source(db, source.id, name="Updated", enabled=False) assert updated is not None assert updated.name == "Updated" assert updated.enabled is False async def test_update_source_missing_returns_none(self, db: aiosqlite.Connection) -> None: """update_source returns None for a non-existent id.""" result = await blocklist_service.update_source(db, 9999, name="Ghost") assert result is None async def test_delete_source(self, db: aiosqlite.Connection) -> None: """delete_source removes a source and returns True.""" source = await blocklist_service.create_source(db, "Del", "https://del.test/") deleted = await blocklist_service.delete_source(db, source.id) assert deleted is True assert await blocklist_service.get_source(db, source.id) is None async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None: """delete_source returns False for a non-existent id.""" result = await blocklist_service.delete_source(db, 9999) assert result is False # 
# Patch target for the fail2ban ban call made by import_source / import_all.
# Kept in one place so a typo cannot make a single test silently patch a
# non-existent attribute.
_BAN_IP_TARGET = "app.services.jail_service.ban_ip"

# Socket path handed to the import functions. Tests either mock ban_ip or
# fail before reaching it, so nothing is expected to connect to this path.
_FAKE_SOCKET = "/tmp/fake.sock"

# ---------------------------------------------------------------------------
# Preview
# ---------------------------------------------------------------------------


class TestPreview:
    """preview_source parsing, HTTP error handling and sample capping."""

    async def test_preview_valid_ips(self) -> None:
        """preview_source returns valid IPs from the downloaded content."""
        content = "1.2.3.4\n5.6.7.8\n# comment\ninvalid\n9.0.0.1\n"
        session = _make_session(content)
        result = await blocklist_service.preview_source("https://test.test/ips.txt", session)
        assert result.valid_count == 3
        assert result.skipped_count == 1  # "invalid"; the comment line is not counted
        assert "1.2.3.4" in result.entries

    async def test_preview_http_error_raises(self) -> None:
        """preview_source raises ValueError when the server returns non-200."""
        session = _make_session("", status=404)
        with pytest.raises(ValueError, match="HTTP 404"):
            await blocklist_service.preview_source("https://bad.test/", session)

    async def test_preview_limits_entries(self) -> None:
        """preview_source caps entries to sample_lines but still counts every valid IP."""
        ips = "\n".join(f"1.2.3.{i}" for i in range(50))
        session = _make_session(ips)
        result = await blocklist_service.preview_source(
            "https://test.test/", session, sample_lines=10
        )
        assert len(result.entries) <= 10
        assert result.valid_count == 50


# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------


class TestImport:
    """import_source / import_all behaviour with fail2ban's ban_ip mocked out."""

    async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
        """import_source calls ban_ip for every valid IP in the blocklist."""
        content = "1.2.3.4\n5.6.7.8\n# skip me\n"
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")

        with patch(_BAN_IP_TARGET, new_callable=AsyncMock) as mock_ban:
            result = await blocklist_service.import_source(source, session, _FAKE_SOCKET, db)

        assert result.ips_imported == 2
        assert result.ips_skipped == 0  # comment lines are not "skipped"
        assert result.error is None
        assert mock_ban.call_count == 2

    async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
        """import_source skips CIDR ranges (fail2ban expects individual IPs)."""
        content = "1.2.3.4\n10.0.0.0/24\n"
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")

        with patch(_BAN_IP_TARGET, new_callable=AsyncMock):
            result = await blocklist_service.import_source(source, session, _FAKE_SOCKET, db)

        assert result.ips_imported == 1
        assert result.ips_skipped == 1

    async def test_import_source_records_download_error(self, db: aiosqlite.Connection) -> None:
        """import_source records an error and returns 0 imported on HTTP failure."""
        session = _make_session("", status=503)
        source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")

        result = await blocklist_service.import_source(source, session, _FAKE_SOCKET, db)

        assert result.ips_imported == 0
        assert result.error is not None

    async def test_import_source_aborts_on_jail_not_found(self, db: aiosqlite.Connection) -> None:
        """import_source aborts on the first JailNotFoundError and records an error.

        When the target jail does not exist in fail2ban, every ban would fail
        the same way. The import must stop immediately rather than silently
        skipping all 100 IPs.
        """
        from app.services.jail_service import JailNotFoundError

        content = "\n".join(f"1.2.3.{i}" for i in range(100))
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "Missing Jail", "https://mj.test/")

        async def _raise_jail_not_found(socket_path: str, jail: str, ip: str) -> None:
            raise JailNotFoundError(jail)

        with patch(
            _BAN_IP_TARGET, new_callable=AsyncMock, side_effect=_raise_jail_not_found
        ) as mock_ban:
            result = await blocklist_service.import_source(source, session, _FAKE_SOCKET, db)

        # Must abort after the first JailNotFoundError: only one ban attempt.
        assert mock_ban.call_count == 1
        assert result.ips_imported == 0
        assert result.error is not None
        assert "not found" in result.error.lower() or "blocklist-import" in result.error

    async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
        """import_all aggregates results across enabled sources and ignores disabled ones."""
        await blocklist_service.create_source(db, "S1", "https://s1.test/")
        await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)

        content = "1.2.3.4\n5.6.7.8\n"
        session = _make_session(content)

        with patch(_BAN_IP_TARGET, new_callable=AsyncMock):
            result = await blocklist_service.import_all(db, session, _FAKE_SOCKET)

        # Only S1 is enabled; S2 is disabled and must not be imported.
        assert len(result.results) == 1
        assert result.results[0].source_url == "https://s1.test/"
        assert result.results[0].ips_imported == 2


# ---------------------------------------------------------------------------
# Schedule
# ---------------------------------------------------------------------------


class TestSchedule:
    """Schedule persistence and last-run reporting."""

    async def test_get_schedule_default(self, db: aiosqlite.Connection) -> None:
        """get_schedule returns the default daily-03:00 config when nothing is saved."""
        config = await blocklist_service.get_schedule(db)
        assert config.frequency == ScheduleFrequency.daily
        assert config.hour == 3

    async def test_set_and_get_round_trip(self, db: aiosqlite.Connection) -> None:
        """set_schedule persists config retrievable by get_schedule."""
        cfg = ScheduleConfig(
            frequency=ScheduleFrequency.hourly,
            interval_hours=6,
            hour=0,
            minute=0,
            day_of_week=0,
        )
        await blocklist_service.set_schedule(db, cfg)
        loaded = await blocklist_service.get_schedule(db)
        assert loaded.frequency == ScheduleFrequency.hourly
        assert loaded.interval_hours == 6

    async def test_get_schedule_info_no_log(self, db: aiosqlite.Connection) -> None:
        """get_schedule_info reports no runs when no import log exists.

        With an empty log and ``None`` as the second argument, last_run_at,
        next_run_at and last_run_errors are all None.
        """
        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_at is None
        assert info.next_run_at is None
        assert info.last_run_errors is None

    async def test_get_schedule_info_no_errors_when_clean(
        self, db: aiosqlite.Connection
    ) -> None:
        """get_schedule_info returns last_run_errors=False when the last run had no errors."""
        from app.repositories import import_log_repo

        await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://example.test/ips.txt",
            ips_imported=10,
            ips_skipped=0,
            errors=None,
        )
        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_errors is False

    async def test_get_schedule_info_errors_flag_when_failed(
        self, db: aiosqlite.Connection
    ) -> None:
        """get_schedule_info returns last_run_errors=True when the last run had errors."""
        from app.repositories import import_log_repo

        await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://example.test/ips.txt",
            ips_imported=0,
            ips_skipped=0,
            errors="Connection timeout",
        )
        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_errors is True


# ---------------------------------------------------------------------------
# Geo prewarm cache filtering
# ---------------------------------------------------------------------------


class TestGeoPrewarmCacheFilter:
    """import_source must not re-request geo data for IPs already cached."""

    async def test_import_source_skips_cached_ips_for_geo_prewarm(
        self, db: aiosqlite.Connection
    ) -> None:
        """import_source only sends uncached IPs to geo_service.lookup_batch."""
        content = "1.2.3.4\n5.6.7.8\n9.10.11.12\n"
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "Geo Filter", "https://gf.test/")

        # Pretend 1.2.3.4 is already cached.
        def _mock_is_cached(ip: str) -> bool:
            return ip == "1.2.3.4"

        with (
            patch(_BAN_IP_TARGET, new_callable=AsyncMock),
            patch("app.services.geo_service.is_cached", side_effect=_mock_is_cached),
            patch(
                "app.services.geo_service.lookup_batch",
                new_callable=AsyncMock,
                return_value={},
            ) as mock_batch,
        ):
            result = await blocklist_service.import_source(source, session, _FAKE_SOCKET, db)

        assert result.ips_imported == 3

        # lookup_batch should receive only the 2 uncached IPs, each exactly once
        # (the set comparison alone would hide duplicate lookups).
        mock_batch.assert_called_once()
        call_ips = mock_batch.call_args.args[0]
        assert "1.2.3.4" not in call_ips
        assert len(call_ips) == 2
        assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}