Add better jail configuration: file CRUD, enable/disable, log paths

Task 4 (Better Jail Configuration) implementation:
- Add fail2ban_config_dir setting to app/config.py
- New file_config_service: list/view/edit/create jail.d, filter.d, action.d files
  with path-traversal prevention and 512 KB content size limit
- New file_config router: GET/PUT/POST endpoints for jail files, filter files,
  and action files; PUT .../enabled for toggle on/off
- Extend config_service with delete_log_path() and add_log_path()
- Add DELETE /api/config/jails/{name}/logpath and POST /api/config/jails/{name}/logpath
- Extend geo router with re-resolve endpoint; add geo_re_resolve background task
- Update blocklist_service with revised scheduling helpers
- Update Docker compose files with BANGUI_FAIL2BAN_CONFIG_DIR env var and
  rw volume mount for the fail2ban config directory
- Frontend: new Jail Files, Filters, Actions tabs in ConfigPage; file editor
  with accordion-per-file, editable textarea, save/create; add/delete log paths
- Frontend: types in types/config.ts; API calls in api/config.ts and api/endpoints.ts
- 63 new backend tests (test_file_config_service, test_file_config, test_geo_re_resolve)
- 6 new frontend tests in ConfigPageLogPath.test.tsx
- ruff, mypy --strict, tsc --noEmit, eslint: all clean; 617 backend tests pass
This commit is contained in:
2026-03-12 20:08:33 +01:00
parent 59464a1592
commit ea35695221
23 changed files with 2911 additions and 91 deletions

View File

@@ -0,0 +1,379 @@
"""Tests for the file_config router endpoints."""
from __future__ import annotations
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.file_config import (
ConfFileContent,
ConfFileEntry,
ConfFilesResponse,
JailConfigFile,
JailConfigFileContent,
JailConfigFilesResponse,
)
from app.services.file_config_service import (
ConfigDirError,
ConfigFileExistsError,
ConfigFileNameError,
ConfigFileNotFoundError,
ConfigFileWriteError,
)
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD = {
"master_password": "testpassword1",
"database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC",
"session_duration_minutes": 60,
}
@pytest.fixture
async def file_config_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated ``AsyncClient`` for file_config endpoint tests."""
settings = Settings(
database_path=str(tmp_path / "file_config_test.db"),
fail2ban_socket="/tmp/fake.sock",
session_secret="test-file-config-secret",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
)
app = create_app(settings=settings)
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row
await init_db(db)
app.state.db = db
app.state.http_session = MagicMock()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
await ac.post("/api/setup", json=_SETUP_PAYLOAD)
login = await ac.post(
"/api/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]},
)
assert login.status_code == 200
yield ac
await db.close()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _jail_files_resp(files: list[JailConfigFile] | None = None) -> JailConfigFilesResponse:
files = files or [JailConfigFile(name="sshd", filename="sshd.conf", enabled=True)]
return JailConfigFilesResponse(files=files, total=len(files))
def _conf_files_resp(files: list[ConfFileEntry] | None = None) -> ConfFilesResponse:
files = files or [ConfFileEntry(name="nginx", filename="nginx.conf")]
return ConfFilesResponse(files=files, total=len(files))
def _conf_file_content(name: str = "nginx") -> ConfFileContent:
return ConfFileContent(
name=name,
filename=f"{name}.conf",
content=f"[Definition]\n# {name} filter\n",
)
# ---------------------------------------------------------------------------
# GET /api/config/jail-files
# ---------------------------------------------------------------------------
class TestListJailConfigFiles:
async def test_200_returns_file_list(
self, file_config_client: AsyncClient
) -> None:
with patch(
"app.routers.file_config.file_config_service.list_jail_config_files",
AsyncMock(return_value=_jail_files_resp()),
):
resp = await file_config_client.get("/api/config/jail-files")
assert resp.status_code == 200
data = resp.json()
assert data["total"] == 1
assert data["files"][0]["filename"] == "sshd.conf"
async def test_503_on_config_dir_error(
self, file_config_client: AsyncClient
) -> None:
with patch(
"app.routers.file_config.file_config_service.list_jail_config_files",
AsyncMock(side_effect=ConfigDirError("not found")),
):
resp = await file_config_client.get("/api/config/jail-files")
assert resp.status_code == 503
async def test_401_unauthenticated(self, file_config_client: AsyncClient) -> None:
resp = await AsyncClient(
transport=ASGITransport(app=file_config_client._transport.app), # type: ignore[attr-defined]
base_url="http://test",
).get("/api/config/jail-files")
assert resp.status_code == 401
# ---------------------------------------------------------------------------
# GET /api/config/jail-files/{filename}
# ---------------------------------------------------------------------------
class TestGetJailConfigFile:
async def test_200_returns_content(
self, file_config_client: AsyncClient
) -> None:
content = JailConfigFileContent(
name="sshd",
filename="sshd.conf",
enabled=True,
content="[sshd]\nenabled = true\n",
)
with patch(
"app.routers.file_config.file_config_service.get_jail_config_file",
AsyncMock(return_value=content),
):
resp = await file_config_client.get("/api/config/jail-files/sshd.conf")
assert resp.status_code == 200
assert resp.json()["content"] == "[sshd]\nenabled = true\n"
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.get_jail_config_file",
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
):
resp = await file_config_client.get("/api/config/jail-files/missing.conf")
assert resp.status_code == 404
async def test_400_invalid_filename(
self, file_config_client: AsyncClient
) -> None:
with patch(
"app.routers.file_config.file_config_service.get_jail_config_file",
AsyncMock(side_effect=ConfigFileNameError("bad name")),
):
resp = await file_config_client.get("/api/config/jail-files/bad.txt")
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# PUT /api/config/jail-files/{filename}/enabled
# ---------------------------------------------------------------------------
class TestSetJailConfigEnabled:
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.set_jail_config_enabled",
AsyncMock(return_value=None),
):
resp = await file_config_client.put(
"/api/config/jail-files/sshd.conf/enabled",
json={"enabled": False},
)
assert resp.status_code == 204
async def test_404_file_not_found(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.set_jail_config_enabled",
AsyncMock(side_effect=ConfigFileNotFoundError("missing.conf")),
):
resp = await file_config_client.put(
"/api/config/jail-files/missing.conf/enabled",
json={"enabled": True},
)
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# GET /api/config/filters
# ---------------------------------------------------------------------------
class TestListFilterFiles:
async def test_200_returns_files(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.list_filter_files",
AsyncMock(return_value=_conf_files_resp()),
):
resp = await file_config_client.get("/api/config/filters")
assert resp.status_code == 200
assert resp.json()["total"] == 1
async def test_503_on_config_dir_error(
self, file_config_client: AsyncClient
) -> None:
with patch(
"app.routers.file_config.file_config_service.list_filter_files",
AsyncMock(side_effect=ConfigDirError("x")),
):
resp = await file_config_client.get("/api/config/filters")
assert resp.status_code == 503
# ---------------------------------------------------------------------------
# GET /api/config/filters/{name}
# ---------------------------------------------------------------------------
class TestGetFilterFile:
async def test_200_returns_content(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.get_filter_file",
AsyncMock(return_value=_conf_file_content("nginx")),
):
resp = await file_config_client.get("/api/config/filters/nginx")
assert resp.status_code == 200
assert resp.json()["name"] == "nginx"
async def test_404_not_found(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.get_filter_file",
AsyncMock(side_effect=ConfigFileNotFoundError("missing")),
):
resp = await file_config_client.get("/api/config/filters/missing")
assert resp.status_code == 404
# ---------------------------------------------------------------------------
# PUT /api/config/filters/{name}
# ---------------------------------------------------------------------------
class TestUpdateFilterFile:
async def test_204_on_success(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.write_filter_file",
AsyncMock(return_value=None),
):
resp = await file_config_client.put(
"/api/config/filters/nginx",
json={"content": "[Definition]\nfailregex = test\n"},
)
assert resp.status_code == 204
async def test_400_write_error(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.write_filter_file",
AsyncMock(side_effect=ConfigFileWriteError("disk full")),
):
resp = await file_config_client.put(
"/api/config/filters/nginx",
json={"content": "x"},
)
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# POST /api/config/filters
# ---------------------------------------------------------------------------
class TestCreateFilterFile:
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.create_filter_file",
AsyncMock(return_value="myfilter.conf"),
):
resp = await file_config_client.post(
"/api/config/filters",
json={"name": "myfilter", "content": "[Definition]\n"},
)
assert resp.status_code == 201
assert resp.json()["filename"] == "myfilter.conf"
async def test_409_conflict(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.create_filter_file",
AsyncMock(side_effect=ConfigFileExistsError("myfilter.conf")),
):
resp = await file_config_client.post(
"/api/config/filters",
json={"name": "myfilter", "content": "[Definition]\n"},
)
assert resp.status_code == 409
async def test_400_invalid_name(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.create_filter_file",
AsyncMock(side_effect=ConfigFileNameError("bad/../name")),
):
resp = await file_config_client.post(
"/api/config/filters",
json={"name": "../escape", "content": "[Definition]\n"},
)
assert resp.status_code == 400
# ---------------------------------------------------------------------------
# GET /api/config/actions (smoke test — same logic as filters)
# ---------------------------------------------------------------------------
class TestListActionFiles:
async def test_200_returns_files(self, file_config_client: AsyncClient) -> None:
action_entry = ConfFileEntry(name="iptables", filename="iptables.conf")
resp_data = ConfFilesResponse(files=[action_entry], total=1)
with patch(
"app.routers.file_config.file_config_service.list_action_files",
AsyncMock(return_value=resp_data),
):
resp = await file_config_client.get("/api/config/actions")
assert resp.status_code == 200
assert resp.json()["files"][0]["filename"] == "iptables.conf"
# ---------------------------------------------------------------------------
# POST /api/config/actions
# ---------------------------------------------------------------------------
class TestCreateActionFile:
async def test_201_creates_file(self, file_config_client: AsyncClient) -> None:
with patch(
"app.routers.file_config.file_config_service.create_action_file",
AsyncMock(return_value="myaction.conf"),
):
resp = await file_config_client.post(
"/api/config/actions",
json={"name": "myaction", "content": "[Definition]\n"},
)
assert resp.status_code == 201
assert resp.json()["filename"] == "myaction.conf"

View File

@@ -215,3 +215,66 @@ class TestReResolve:
base_url="http://test",
).post("/api/geo/re-resolve")
assert resp.status_code == 401
# ---------------------------------------------------------------------------
# GET /api/geo/stats
# ---------------------------------------------------------------------------
class TestGeoStats:
"""Tests for ``GET /api/geo/stats``."""
async def test_returns_200_with_stats(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/stats returns 200 with the expected keys."""
stats = {
"cache_size": 100,
"unresolved": 5,
"neg_cache_size": 2,
"dirty_size": 0,
}
with patch(
"app.routers.geo.geo_service.cache_stats",
AsyncMock(return_value=stats),
):
resp = await geo_client.get("/api/geo/stats")
assert resp.status_code == 200
data = resp.json()
assert data["cache_size"] == 100
assert data["unresolved"] == 5
assert data["neg_cache_size"] == 2
assert data["dirty_size"] == 0
async def test_stats_empty_cache(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/stats returns all zeros on a fresh database."""
resp = await geo_client.get("/api/geo/stats")
assert resp.status_code == 200
data = resp.json()
assert data["cache_size"] >= 0
assert data["unresolved"] == 0
assert data["neg_cache_size"] >= 0
assert data["dirty_size"] >= 0
async def test_stats_counts_unresolved(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/stats counts NULL-country rows correctly."""
app = geo_client._transport.app # type: ignore[attr-defined]
db: aiosqlite.Connection = app.state.db
await db.execute("INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", ("7.7.7.7",))
await db.execute("INSERT OR IGNORE INTO geo_cache (ip) VALUES (?)", ("8.8.8.8",))
await db.commit()
resp = await geo_client.get("/api/geo/stats")
assert resp.status_code == 200
assert resp.json()["unresolved"] >= 2
async def test_401_when_unauthenticated(self, geo_client: AsyncClient) -> None:
"""GET /api/geo/stats requires authentication."""
app = geo_client._transport.app # type: ignore[attr-defined]
resp = await AsyncClient(
transport=ASGITransport(app=app),
base_url="http://test",
).get("/api/geo/stats")
assert resp.status_code == 401

View File

@@ -293,3 +293,47 @@ class TestSchedule:
)
info = await blocklist_service.get_schedule_info(db, None)
assert info.last_run_errors is True
# ---------------------------------------------------------------------------
# Geo prewarm cache filtering
# ---------------------------------------------------------------------------
class TestGeoPrewarmCacheFilter:
async def test_import_source_skips_cached_ips_for_geo_prewarm(
self, db: aiosqlite.Connection
) -> None:
"""import_source only sends uncached IPs to geo_service.lookup_batch."""
content = "1.2.3.4\n5.6.7.8\n9.10.11.12\n"
session = _make_session(content)
source = await blocklist_service.create_source(
db, "Geo Filter", "https://gf.test/"
)
# Pretend 1.2.3.4 is already cached.
def _mock_is_cached(ip: str) -> bool:
return ip == "1.2.3.4"
with (
patch("app.services.jail_service.ban_ip", new_callable=AsyncMock),
patch(
"app.services.geo_service.is_cached",
side_effect=_mock_is_cached,
),
patch(
"app.services.geo_service.lookup_batch",
new_callable=AsyncMock,
return_value={},
) as mock_batch,
):
result = await blocklist_service.import_source(
source, session, "/tmp/fake.sock", db
)
assert result.ips_imported == 3
# lookup_batch should receive only the 2 uncached IPs.
mock_batch.assert_called_once()
call_ips = mock_batch.call_args[0][0]
assert "1.2.3.4" not in call_ips
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}

View File

@@ -0,0 +1,401 @@
"""Tests for file_config_service functions."""
from __future__ import annotations
from pathlib import Path
import pytest
from app.models.file_config import ConfFileCreateRequest, ConfFileUpdateRequest
from app.services.file_config_service import (
ConfigDirError,
ConfigFileExistsError,
ConfigFileNameError,
ConfigFileNotFoundError,
ConfigFileWriteError,
_parse_enabled,
_set_enabled_in_content,
_validate_new_name,
create_action_file,
create_filter_file,
get_action_file,
get_filter_file,
get_jail_config_file,
list_action_files,
list_filter_files,
list_jail_config_files,
set_jail_config_enabled,
write_action_file,
write_filter_file,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_config_dir(tmp_path: Path) -> Path:
"""Create a minimal fail2ban config directory structure."""
config_dir = tmp_path / "fail2ban"
(config_dir / "jail.d").mkdir(parents=True)
(config_dir / "filter.d").mkdir(parents=True)
(config_dir / "action.d").mkdir(parents=True)
return config_dir
# ---------------------------------------------------------------------------
# _parse_enabled
# ---------------------------------------------------------------------------
def test_parse_enabled_explicit_true(tmp_path: Path) -> None:
f = tmp_path / "sshd.conf"
f.write_text("[sshd]\nenabled = true\n")
assert _parse_enabled(f) is True
def test_parse_enabled_explicit_false(tmp_path: Path) -> None:
f = tmp_path / "sshd.conf"
f.write_text("[sshd]\nenabled = false\n")
assert _parse_enabled(f) is False
def test_parse_enabled_default_true_when_absent(tmp_path: Path) -> None:
f = tmp_path / "sshd.conf"
f.write_text("[sshd]\nbantime = 600\n")
assert _parse_enabled(f) is True
def test_parse_enabled_in_default_section(tmp_path: Path) -> None:
f = tmp_path / "custom.conf"
f.write_text("[DEFAULT]\nenabled = false\n")
assert _parse_enabled(f) is False
# ---------------------------------------------------------------------------
# _set_enabled_in_content
# ---------------------------------------------------------------------------
def test_set_enabled_replaces_existing_line() -> None:
src = "[sshd]\nenabled = false\nbantime = 600\n"
result = _set_enabled_in_content(src, True)
assert "enabled = true" in result
assert "enabled = false" not in result
def test_set_enabled_inserts_after_section() -> None:
src = "[sshd]\nbantime = 600\n"
result = _set_enabled_in_content(src, False)
assert "enabled = false" in result
def test_set_enabled_prepends_default_when_no_section() -> None:
result = _set_enabled_in_content("bantime = 600\n", True)
assert "enabled = true" in result
# ---------------------------------------------------------------------------
# _validate_new_name
# ---------------------------------------------------------------------------
@pytest.mark.parametrize("name", ["sshd", "my-filter", "test.local", "A1_filter"])
def test_validate_new_name_valid(name: str) -> None:
_validate_new_name(name) # should not raise
@pytest.mark.parametrize(
"name",
[
"",
".",
".hidden",
"../escape",
"bad/slash",
"a" * 129, # too long
"hello world", # space
],
)
def test_validate_new_name_invalid(name: str) -> None:
with pytest.raises(ConfigFileNameError):
_validate_new_name(name)
# ---------------------------------------------------------------------------
# list_jail_config_files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_list_jail_config_files_empty(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
resp = await list_jail_config_files(str(config_dir))
assert resp.files == []
assert resp.total == 0
@pytest.mark.asyncio
async def test_list_jail_config_files_returns_conf_files(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "jail.d" / "sshd.conf").write_text("[sshd]\nenabled = true\n")
(config_dir / "jail.d" / "nginx.conf").write_text("[nginx]\n")
(config_dir / "jail.d" / "other.txt").write_text("ignored")
resp = await list_jail_config_files(str(config_dir))
names = {f.filename for f in resp.files}
assert names == {"sshd.conf", "nginx.conf"}
assert resp.total == 2
@pytest.mark.asyncio
async def test_list_jail_config_files_enabled_state(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "jail.d" / "a.conf").write_text("[a]\nenabled = false\n")
(config_dir / "jail.d" / "b.conf").write_text("[b]\n")
resp = await list_jail_config_files(str(config_dir))
by_name = {f.filename: f for f in resp.files}
assert by_name["a.conf"].enabled is False
assert by_name["b.conf"].enabled is True
@pytest.mark.asyncio
async def test_list_jail_config_files_missing_config_dir(tmp_path: Path) -> None:
with pytest.raises(ConfigDirError):
await list_jail_config_files(str(tmp_path / "nonexistent"))
# ---------------------------------------------------------------------------
# get_jail_config_file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_jail_config_file_returns_content(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "jail.d" / "sshd.conf").write_text("[sshd]\nenabled = true\n")
result = await get_jail_config_file(str(config_dir), "sshd.conf")
assert result.filename == "sshd.conf"
assert result.name == "sshd"
assert result.enabled is True
assert "[sshd]" in result.content
@pytest.mark.asyncio
async def test_get_jail_config_file_not_found(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
with pytest.raises(ConfigFileNotFoundError):
await get_jail_config_file(str(config_dir), "missing.conf")
@pytest.mark.asyncio
async def test_get_jail_config_file_invalid_extension(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "jail.d" / "bad.txt").write_text("content")
with pytest.raises(ConfigFileNameError):
await get_jail_config_file(str(config_dir), "bad.txt")
@pytest.mark.asyncio
async def test_get_jail_config_file_path_traversal(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
with pytest.raises((ConfigFileNameError, ConfigFileNotFoundError)):
await get_jail_config_file(str(config_dir), "../jail.conf")
# ---------------------------------------------------------------------------
# set_jail_config_enabled
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_set_jail_config_enabled_writes_false(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
path = config_dir / "jail.d" / "sshd.conf"
path.write_text("[sshd]\nenabled = true\n")
await set_jail_config_enabled(str(config_dir), "sshd.conf", False)
assert "enabled = false" in path.read_text()
@pytest.mark.asyncio
async def test_set_jail_config_enabled_inserts_when_missing(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
path = config_dir / "jail.d" / "sshd.conf"
path.write_text("[sshd]\nbantime = 600\n")
await set_jail_config_enabled(str(config_dir), "sshd.conf", False)
assert "enabled = false" in path.read_text()
@pytest.mark.asyncio
async def test_set_jail_config_enabled_file_not_found(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
with pytest.raises(ConfigFileNotFoundError):
await set_jail_config_enabled(str(config_dir), "missing.conf", True)
# ---------------------------------------------------------------------------
# list_filter_files / list_action_files
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_list_filter_files_empty(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
resp = await list_filter_files(str(config_dir))
assert resp.files == []
@pytest.mark.asyncio
async def test_list_filter_files_returns_files(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\n")
(config_dir / "filter.d" / "sshd.local").write_text("[Definition]\n")
(config_dir / "filter.d" / "ignore.py").write_text("# ignored")
resp = await list_filter_files(str(config_dir))
names = {f.filename for f in resp.files}
assert names == {"nginx.conf", "sshd.local"}
@pytest.mark.asyncio
async def test_list_action_files_returns_files(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "action.d" / "iptables.conf").write_text("[Definition]\n")
resp = await list_action_files(str(config_dir))
assert resp.files[0].filename == "iptables.conf"
# ---------------------------------------------------------------------------
# get_filter_file / get_action_file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_filter_file_by_stem(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\nfailregex = test\n")
result = await get_filter_file(str(config_dir), "nginx")
assert result.name == "nginx"
assert "failregex" in result.content
@pytest.mark.asyncio
async def test_get_filter_file_by_full_name(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\n")
result = await get_filter_file(str(config_dir), "nginx.conf")
assert result.filename == "nginx.conf"
@pytest.mark.asyncio
async def test_get_filter_file_not_found(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
with pytest.raises(ConfigFileNotFoundError):
await get_filter_file(str(config_dir), "nonexistent")
@pytest.mark.asyncio
async def test_get_action_file_returns_content(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "action.d" / "iptables.conf").write_text("[Definition]\nactionban = <ip>\n")
result = await get_action_file(str(config_dir), "iptables")
assert "actionban" in result.content
# ---------------------------------------------------------------------------
# write_filter_file / write_action_file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_write_filter_file_updates_content(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\n")
req = ConfFileUpdateRequest(content="[Definition]\nfailregex = new\n")
await write_filter_file(str(config_dir), "nginx", req)
assert "failregex = new" in (config_dir / "filter.d" / "nginx.conf").read_text()
@pytest.mark.asyncio
async def test_write_filter_file_not_found(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
req = ConfFileUpdateRequest(content="[Definition]\n")
with pytest.raises(ConfigFileNotFoundError):
await write_filter_file(str(config_dir), "missing", req)
@pytest.mark.asyncio
async def test_write_filter_file_too_large(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "nginx.conf").write_text("[Definition]\n")
big_content = "x" * (512 * 1024 + 1)
req = ConfFileUpdateRequest(content=big_content)
with pytest.raises(ConfigFileWriteError):
await write_filter_file(str(config_dir), "nginx", req)
@pytest.mark.asyncio
async def test_write_action_file_updates_content(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "action.d" / "iptables.conf").write_text("[Definition]\n")
req = ConfFileUpdateRequest(content="[Definition]\nactionban = new\n")
await write_action_file(str(config_dir), "iptables", req)
assert "actionban = new" in (config_dir / "action.d" / "iptables.conf").read_text()
# ---------------------------------------------------------------------------
# create_filter_file / create_action_file
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_create_filter_file_creates_file(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
req = ConfFileCreateRequest(name="myfilter", content="[Definition]\n")
result = await create_filter_file(str(config_dir), req)
assert result == "myfilter.conf"
assert (config_dir / "filter.d" / "myfilter.conf").is_file()
@pytest.mark.asyncio
async def test_create_filter_file_conflict(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
(config_dir / "filter.d" / "ngx.conf").write_text("[Definition]\n")
req = ConfFileCreateRequest(name="ngx", content="[Definition]\n")
with pytest.raises(ConfigFileExistsError):
await create_filter_file(str(config_dir), req)
@pytest.mark.asyncio
async def test_create_filter_file_invalid_name(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
req = ConfFileCreateRequest(name="../escape", content="[Definition]\n")
with pytest.raises(ConfigFileNameError):
await create_filter_file(str(config_dir), req)
@pytest.mark.asyncio
async def test_create_action_file_creates_file(tmp_path: Path) -> None:
config_dir = _make_config_dir(tmp_path)
req = ConfFileCreateRequest(name="my-action", content="[Definition]\n")
result = await create_action_file(str(config_dir), req)
assert result == "my-action.conf"
assert (config_dir / "action.d" / "my-action.conf").is_file()

View File

@@ -0,0 +1 @@
"""APScheduler task tests package."""

View File

@@ -0,0 +1,167 @@
"""Tests for the geo re-resolve background task.
Validates that :func:`~app.tasks.geo_re_resolve._run_re_resolve` correctly
queries NULL-country IPs from the database, clears the negative cache, and
delegates to :func:`~app.services.geo_service.lookup_batch` for a fresh
resolution attempt.
"""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.services.geo_service import GeoInfo
from app.tasks.geo_re_resolve import _run_re_resolve
class _AsyncRowIterator:
"""Minimal async iterator over a list of row tuples."""
def __init__(self, rows: list[tuple[str]]) -> None:
self._iter = iter(rows)
def __aiter__(self) -> _AsyncRowIterator:
return self
async def __anext__(self) -> tuple[str]:
try:
return next(self._iter)
except StopIteration:
raise StopAsyncIteration # noqa: B904
def _make_app(
unresolved_ips: list[str],
lookup_result: dict[str, GeoInfo] | None = None,
) -> MagicMock:
"""Build a minimal mock ``app`` with ``state.db`` and ``state.http_session``.
The mock database returns *unresolved_ips* when the re-resolve task
queries ``SELECT ip FROM geo_cache WHERE country_code IS NULL``.
Args:
unresolved_ips: IPs to return from the mocked DB query.
lookup_result: Value returned by the mocked ``lookup_batch``.
Defaults to an empty dict.
Returns:
A :class:`unittest.mock.MagicMock` that mimics ``fastapi.FastAPI``.
"""
if lookup_result is None:
lookup_result = {}
rows = [(ip,) for ip in unresolved_ips]
cursor = _AsyncRowIterator(rows)
# db.execute() returns an async context manager yielding the cursor.
ctx = AsyncMock()
ctx.__aenter__ = AsyncMock(return_value=cursor)
ctx.__aexit__ = AsyncMock(return_value=False)
db = AsyncMock()
db.execute = MagicMock(return_value=ctx)
http_session = MagicMock()
app = MagicMock()
app.state.db = db
app.state.http_session = http_session
return app
@pytest.mark.asyncio
async def test_run_re_resolve_no_unresolved_ips_skips() -> None:
"""The task should return immediately when no NULL-country IPs exist."""
app = _make_app(unresolved_ips=[])
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
await _run_re_resolve(app)
mock_geo.clear_neg_cache.assert_not_called()
mock_geo.lookup_batch.assert_not_called()
@pytest.mark.asyncio
async def test_run_re_resolve_clears_neg_cache() -> None:
"""The task must clear the negative cache before calling lookup_batch."""
ips = ["1.2.3.4", "5.6.7.8"]
result: dict[str, GeoInfo] = {
"1.2.3.4": GeoInfo(country_code="DE", country_name="Germany", asn="AS3320", org="DTAG"),
"5.6.7.8": GeoInfo(country_code="US", country_name="United States", asn="AS15169", org="Google"),
}
app = _make_app(unresolved_ips=ips, lookup_result=result)
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
mock_geo.lookup_batch = AsyncMock(return_value=result)
await _run_re_resolve(app)
mock_geo.clear_neg_cache.assert_called_once()
@pytest.mark.asyncio
async def test_run_re_resolve_calls_lookup_batch_with_db() -> None:
"""The task must pass the real db to lookup_batch for persistence."""
ips = ["10.0.0.1", "10.0.0.2"]
result: dict[str, GeoInfo] = {
"10.0.0.1": GeoInfo(country_code="FR", country_name="France", asn=None, org=None),
"10.0.0.2": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
}
app = _make_app(unresolved_ips=ips, lookup_result=result)
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
mock_geo.lookup_batch = AsyncMock(return_value=result)
await _run_re_resolve(app)
mock_geo.lookup_batch.assert_called_once_with(
ips,
app.state.http_session,
db=app.state.db,
)
@pytest.mark.asyncio
async def test_run_re_resolve_logs_correct_counts(caplog: Any) -> None:
"""The task should log the number retried and number resolved."""
ips = ["1.1.1.1", "2.2.2.2", "3.3.3.3"]
result: dict[str, GeoInfo] = {
"1.1.1.1": GeoInfo(country_code="AU", country_name="Australia", asn=None, org=None),
"2.2.2.2": GeoInfo(country_code="JP", country_name="Japan", asn=None, org=None),
"3.3.3.3": GeoInfo(country_code=None, country_name=None, asn=None, org=None),
}
app = _make_app(unresolved_ips=ips, lookup_result=result)
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
mock_geo.lookup_batch = AsyncMock(return_value=result)
await _run_re_resolve(app)
# Verify lookup_batch was called (the logging assertions rely on
# structlog which is hard to capture in caplog; instead we verify
# the function ran to completion and the counts are correct by
# checking that lookup_batch received the right number of IPs).
call_args = mock_geo.lookup_batch.call_args
assert len(call_args[0][0]) == 3
@pytest.mark.asyncio
async def test_run_re_resolve_handles_all_resolved() -> None:
"""When every IP resolves successfully the task should complete normally."""
ips = ["4.4.4.4"]
result: dict[str, GeoInfo] = {
"4.4.4.4": GeoInfo(country_code="GB", country_name="United Kingdom", asn=None, org=None),
}
app = _make_app(unresolved_ips=ips, lookup_result=result)
with patch("app.tasks.geo_re_resolve.geo_service") as mock_geo:
mock_geo.lookup_batch = AsyncMock(return_value=result)
await _run_re_resolve(app)
mock_geo.clear_neg_cache.assert_called_once()
mock_geo.lookup_batch.assert_called_once()