Task 4 (Better Jail Configuration) implementation:
- Add a fail2ban_config_dir setting to app/config.py
- New file_config_service: list/view/edit/create files in jail.d, filter.d and
  action.d, with path-traversal prevention and a 512 KB content-size limit
- New file_config router: GET/PUT/POST endpoints for jail, filter and action
  files; PUT .../enabled toggles a file on or off
- Extend config_service with delete_log_path() and add_log_path()
- Add DELETE /api/config/jails/{name}/logpath and POST /api/config/jails/{name}/logpath
- Extend the geo router with a re-resolve endpoint; add a geo_re_resolve background task
- Update blocklist_service with revised scheduling helpers
- Update the Docker Compose files with the BANGUI_FAIL2BAN_CONFIG_DIR env var and
  a read-write volume mount for the fail2ban config directory
- Frontend: new Jail Files, Filters and Actions tabs in ConfigPage; a file editor
  with one accordion per file, an editable textarea and save/create; add/delete log paths
- Frontend: types in types/config.ts; API calls in api/config.ts and api/endpoints.ts
- 63 new backend tests (test_file_config_service, test_file_config, test_geo_re_resolve)
- 6 new frontend tests in ConfigPageLogPath.test.tsx
- ruff, mypy --strict, tsc --noEmit, eslint: all clean; 617 backend tests pass

(File view: 340 lines, 14 KiB, Python)
"""Tests for blocklist_service — source CRUD, preview, import, schedule."""
|
|
|
|
from __future__ import annotations

from collections.abc import AsyncIterator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch

import aiosqlite
import pytest

from app.db import init_db
from app.models.blocklist import BlocklistSource, ScheduleConfig, ScheduleFrequency
from app.services import blocklist_service


# ---------------------------------------------------------------------------
# Fixtures and helpers
# ---------------------------------------------------------------------------


@pytest.fixture
async def db(tmp_path: Path) -> AsyncIterator[aiosqlite.Connection]:
    """Provide an initialised aiosqlite connection backed by a per-test file.

    The database file lives under pytest's per-test ``tmp_path`` and is set up
    with ``init_db`` before being handed to the test. Rows are returned as
    ``aiosqlite.Row`` objects.

    The connection is closed on teardown. It is also closed if ``init_db``
    itself raises, so a failing schema setup cannot leak the connection.
    """
    conn: aiosqlite.Connection = await aiosqlite.connect(str(tmp_path / "bl_svc.db"))
    try:
        conn.row_factory = aiosqlite.Row
        await init_db(conn)
        yield conn
    finally:
        await conn.close()


def _make_session(text: str, status: int = 200) -> MagicMock:
|
|
"""Build a mock aiohttp session that returns *text* for GET requests."""
|
|
mock_resp = AsyncMock()
|
|
mock_resp.status = status
|
|
mock_resp.text = AsyncMock(return_value=text)
|
|
mock_resp.content = AsyncMock()
|
|
mock_resp.content.read = AsyncMock(return_value=text.encode())
|
|
|
|
mock_ctx = AsyncMock()
|
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_resp)
|
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
|
|
|
session = MagicMock()
|
|
session.get = MagicMock(return_value=mock_ctx)
|
|
return session
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Source CRUD
# ---------------------------------------------------------------------------


class TestSourceCRUD:
    """Create / read / update / delete round-trips for blocklist sources."""

    async def test_create_and_get(self, db: aiosqlite.Connection) -> None:
        """A created source is enabled by default and can be fetched by its id."""
        created = await blocklist_service.create_source(db, "Test", "https://t.test/")
        assert isinstance(created, BlocklistSource)
        assert created.name == "Test"
        assert created.enabled is True

        roundtrip = await blocklist_service.get_source(db, created.id)
        assert roundtrip is not None
        assert roundtrip.id == created.id

    async def test_get_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """Looking up an unknown id yields None."""
        assert await blocklist_service.get_source(db, 9999) is None

    async def test_list_sources_empty(self, db: aiosqlite.Connection) -> None:
        """A fresh database lists no sources."""
        assert await blocklist_service.list_sources(db) == []

    async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
        """Every created source shows up in list_sources."""
        for label, url in (("A", "https://a.test/"), ("B", "https://b.test/")):
            await blocklist_service.create_source(db, label, url)

        listed = await blocklist_service.list_sources(db)
        assert len(listed) == 2

    async def test_update_source_fields(self, db: aiosqlite.Connection) -> None:
        """update_source applies the keyword field changes it is given."""
        original = await blocklist_service.create_source(db, "Original", "https://orig.test/")

        changed = await blocklist_service.update_source(
            db, original.id, name="Updated", enabled=False
        )
        assert changed is not None
        assert changed.name == "Updated"
        assert changed.enabled is False

    async def test_update_source_missing_returns_none(self, db: aiosqlite.Connection) -> None:
        """Updating an unknown id yields None."""
        assert await blocklist_service.update_source(db, 9999, name="Ghost") is None

    async def test_delete_source(self, db: aiosqlite.Connection) -> None:
        """Deleting an existing source reports True and the source is gone afterwards."""
        victim = await blocklist_service.create_source(db, "Del", "https://del.test/")

        assert await blocklist_service.delete_source(db, victim.id) is True
        assert await blocklist_service.get_source(db, victim.id) is None

    async def test_delete_source_missing_returns_false(self, db: aiosqlite.Connection) -> None:
        """Deleting an unknown id reports False."""
        assert await blocklist_service.delete_source(db, 9999) is False


# ---------------------------------------------------------------------------
# Preview
# ---------------------------------------------------------------------------


class TestPreview:
    """preview_source: parsing, error handling and sampling of a downloaded list."""

    async def test_preview_valid_ips(self) -> None:
        """Valid IPs are counted, comment lines are ignored and junk lines are skipped."""
        body = "1.2.3.4\n5.6.7.8\n# comment\ninvalid\n9.0.0.1\n"

        preview = await blocklist_service.preview_source(
            "https://test.test/ips.txt", _make_session(body)
        )

        assert preview.valid_count == 3
        assert preview.skipped_count == 1  # only the "invalid" line
        assert "1.2.3.4" in preview.entries

    async def test_preview_http_error_raises(self) -> None:
        """A non-200 response surfaces as ValueError mentioning the HTTP status."""
        failing = _make_session("", status=404)

        with pytest.raises(ValueError, match="HTTP 404"):
            await blocklist_service.preview_source("https://bad.test/", failing)

    async def test_preview_limits_entries(self) -> None:
        """The sampled entries are capped by sample_lines; valid_count still covers everything."""
        fifty_ips = "\n".join(f"1.2.3.{n}" for n in range(50))

        preview = await blocklist_service.preview_source(
            "https://test.test/", _make_session(fifty_ips), sample_lines=10
        )

        assert len(preview.entries) <= 10
        assert preview.valid_count == 50


# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------


class TestImport:
    """import_source / import_all behaviour with jail_service.ban_ip patched out."""

    async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
        """import_source calls (and awaits) ban_ip for every valid IP in the blocklist."""
        content = "1.2.3.4\n5.6.7.8\n# skip me\n"
        session = _make_session(content)

        source = await blocklist_service.create_source(db, "Import Test", "https://t.test/")

        with patch(
            "app.services.jail_service.ban_ip", new_callable=AsyncMock
        ) as mock_ban:
            result = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        assert result.ips_imported == 2
        # The comment line is ignored rather than counted as skipped.
        assert result.ips_skipped == 0
        assert result.error is None
        assert mock_ban.call_count == 2
        # call_count alone would also pass if the ban coroutines were never awaited.
        assert mock_ban.await_count == 2

    async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
        """import_source skips CIDR ranges (fail2ban expects individual IPs)."""
        content = "1.2.3.4\n10.0.0.0/24\n"
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")

        with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
            result = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        assert result.ips_imported == 1
        assert result.ips_skipped == 1

    async def test_import_source_records_download_error(
        self, db: aiosqlite.Connection
    ) -> None:
        """import_source records an error and returns 0 imported on HTTP failure."""
        session = _make_session("", status=503)
        source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")

        result = await blocklist_service.import_source(
            source, session, "/tmp/fake.sock", db
        )

        assert result.ips_imported == 0
        assert result.error is not None

    async def test_import_source_aborts_on_jail_not_found(
        self, db: aiosqlite.Connection
    ) -> None:
        """import_source aborts on the first JailNotFoundError and records an error.

        A missing target jail must stop the import immediately instead of
        silently skipping every IP in the list.
        """
        from app.services.jail_service import JailNotFoundError

        content = "\n".join(f"1.2.3.{i}" for i in range(100))
        session = _make_session(content)
        source = await blocklist_service.create_source(db, "Missing Jail", "https://mj.test/")

        call_count = 0

        async def _raise_jail_not_found(socket_path: str, jail: str, ip: str) -> None:
            nonlocal call_count
            call_count += 1
            raise JailNotFoundError(jail)

        with patch("app.services.jail_service.ban_ip", side_effect=_raise_jail_not_found):
            result = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        # Must abort after the first JailNotFoundError — only one ban attempt.
        assert call_count == 1
        assert result.ips_imported == 0
        assert result.error is not None
        assert "not found" in result.error.lower() or "blocklist-import" in result.error

    async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
        """import_all aggregates results across enabled sources and ignores disabled ones."""
        await blocklist_service.create_source(db, "S1", "https://s1.test/")
        # The disabled source only needs to exist; its return value is not used.
        await blocklist_service.create_source(
            db, "S2", "https://s2.test/", enabled=False
        )

        content = "1.2.3.4\n5.6.7.8\n"
        session = _make_session(content)

        with patch("app.services.jail_service.ban_ip", new_callable=AsyncMock):
            result = await blocklist_service.import_all(db, session, "/tmp/fake.sock")

        # Only S1 is enabled, S2 is disabled.
        assert len(result.results) == 1
        assert result.results[0].source_url == "https://s1.test/"


# ---------------------------------------------------------------------------
# Schedule
# ---------------------------------------------------------------------------


class TestSchedule:
    """Schedule persistence (get_schedule / set_schedule) and get_schedule_info."""

    @staticmethod
    async def _record_run(
        db: aiosqlite.Connection, *, ips_imported: int, errors: str | None
    ) -> None:
        """Write one source-less import-log row for the example blocklist URL."""
        from app.repositories import import_log_repo

        await import_log_repo.add_log(
            db,
            source_id=None,
            source_url="https://example.test/ips.txt",
            ips_imported=ips_imported,
            ips_skipped=0,
            errors=errors,
        )

    async def test_get_schedule_default(self, db: aiosqlite.Connection) -> None:
        """With nothing saved, get_schedule falls back to a daily run at hour 3 (03:00)."""
        schedule = await blocklist_service.get_schedule(db)
        assert schedule.frequency == ScheduleFrequency.daily
        assert schedule.hour == 3

    async def test_set_and_get_round_trip(self, db: aiosqlite.Connection) -> None:
        """A config stored with set_schedule is what get_schedule reads back."""
        saved = ScheduleConfig(
            frequency=ScheduleFrequency.hourly,
            interval_hours=6,
            hour=0,
            minute=0,
            day_of_week=0,
        )
        await blocklist_service.set_schedule(db, saved)

        reloaded = await blocklist_service.get_schedule(db)
        assert reloaded.frequency == ScheduleFrequency.hourly
        assert reloaded.interval_hours == 6

    async def test_get_schedule_info_no_log(self, db: aiosqlite.Connection) -> None:
        """With no import log and None as second argument, all run fields are None.

        Checks last_run_at, next_run_at and last_run_errors.
        """
        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_at is None
        assert info.next_run_at is None
        assert info.last_run_errors is None

    async def test_get_schedule_info_no_errors_when_clean(
        self, db: aiosqlite.Connection
    ) -> None:
        """last_run_errors is False when the last logged run recorded no errors."""
        await self._record_run(db, ips_imported=10, errors=None)

        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_errors is False

    async def test_get_schedule_info_errors_flag_when_failed(
        self, db: aiosqlite.Connection
    ) -> None:
        """last_run_errors is True when the last logged run recorded an error."""
        await self._record_run(db, ips_imported=0, errors="Connection timeout")

        info = await blocklist_service.get_schedule_info(db, None)
        assert info.last_run_errors is True


# ---------------------------------------------------------------------------
# Geo prewarm cache filtering
# ---------------------------------------------------------------------------


class TestGeoPrewarmCacheFilter:
    """Geo prewarming during import must not re-look-up IPs already in the geo cache."""

    async def test_import_source_skips_cached_ips_for_geo_prewarm(
        self, db: aiosqlite.Connection
    ) -> None:
        """import_source only sends uncached IPs to geo_service.lookup_batch."""
        already_cached = "1.2.3.4"
        session = _make_session("1.2.3.4\n5.6.7.8\n9.10.11.12\n")
        source = await blocklist_service.create_source(db, "Geo Filter", "https://gf.test/")

        with (
            patch("app.services.jail_service.ban_ip", new_callable=AsyncMock),
            # Pretend only ``already_cached`` is present in the geo cache.
            patch(
                "app.services.geo_service.is_cached",
                side_effect=lambda ip: ip == already_cached,
            ),
            patch(
                "app.services.geo_service.lookup_batch",
                new_callable=AsyncMock,
                return_value={},
            ) as mock_batch,
        ):
            result = await blocklist_service.import_source(
                source, session, "/tmp/fake.sock", db
            )

        assert result.ips_imported == 3

        # Exactly one batch lookup, carrying only the two uncached IPs.
        mock_batch.assert_called_once()
        batched_ips = mock_batch.call_args.args[0]
        assert already_cached not in batched_ips
        assert set(batched_ips) == {"5.6.7.8", "9.10.11.12"}