277 lines
10 KiB
Python
277 lines
10 KiB
Python
"""Regression tests for the four 500-error bugs discovered on 2026-03-22.
|
|
|
|
Each test targets the exact code path that caused a 500 Internal Server Error.
|
|
These tests call the **real** service/repository functions (not the router)
|
|
so they fail even if the route layer is mocked in router-level tests.
|
|
|
|
Bugs covered:
|
|
1. ``list_history`` rejected the ``origin`` keyword argument (TypeError).
|
|
2. ``jail_config_service`` used ``_get_active_jail_names`` without importing it.
|
|
3. ``filter_config_service`` used ``_parse_jails_sync`` / ``_get_active_jail_names``
|
|
without importing them.
|
|
4. ``config_service.get_service_status`` omitted the required ``bangui_version``
|
|
field from the ``ServiceStatusResponse`` constructor (Pydantic ValidationError).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import inspect
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import aiosqlite
|
|
import pytest
|
|
|
|
# ── Bug 1 ─────────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestHistoryOriginParameter:
    """Bug 1: ``origin`` parameter must be threaded through service → repo."""

    # -- Service layer --

    async def test_list_history_accepts_origin_kwarg(self) -> None:
        """``history_service.list_history()`` must accept an ``origin`` keyword."""
        from app.services import history_service

        sig = inspect.signature(history_service.list_history)
        assert "origin" in sig.parameters, (
            "list_history() is missing the 'origin' parameter — "
            "the router passes origin=… which would cause a TypeError"
        )

    async def test_list_history_forwards_origin_to_repo(
        self, tmp_path: Path
    ) -> None:
        """``list_history(origin='blocklist')`` must forward origin to the DB repo.

        The fixture DB holds one ``blocklist-import`` ban and one ``sshd`` ban;
        only the ``blocklist-import`` row may come back.
        """
        from app.services import history_service

        db_path = str(tmp_path / "f2b.db")
        # One shared "now" timestamp keeps the rows recent, in case the
        # service applies a default time-range filter (not verified here).
        now = int(time.time())
        async with aiosqlite.connect(db_path) as db:
            await db.execute(
                "CREATE TABLE jails (name TEXT, enabled INTEGER DEFAULT 1)"
            )
            await db.execute(
                "CREATE TABLE bans "
                "(jail TEXT, ip TEXT, timeofban INTEGER, bantime INTEGER, "
                "bancount INTEGER DEFAULT 1, data JSON)"
            )
            await db.executemany(
                "INSERT INTO bans VALUES (?, ?, ?, ?, ?, ?)",
                [
                    ("blocklist-import", "10.0.0.1", now, 3600, 1, "{}"),
                    ("sshd", "10.0.0.2", now, 3600, 1, "{}"),
                ],
            )
            await db.commit()

        with patch(
            "app.services.history_service.get_fail2ban_db_path",
            new=AsyncMock(return_value=db_path),
        ):
            result = await history_service.list_history(
                "fake_socket", origin="blocklist"
            )

        # Compare the full set rather than using ``all(...)``: ``all`` of an
        # empty sequence is True, which would let a service that drops every
        # row pass this test.
        jails = {item.jail for item in result.items}
        assert jails == {"blocklist-import"}, (
            "origin='blocklist' must filter to blocklist-import jail only"
        )

    # -- Repository layer --

    async def test_get_history_page_accepts_origin_kwarg(self) -> None:
        """``fail2ban_db_repo.get_history_page()`` must accept ``origin``."""
        from app.repositories import fail2ban_db_repo

        sig = inspect.signature(fail2ban_db_repo.get_history_page)
        assert "origin" in sig.parameters, (
            "get_history_page() is missing the 'origin' parameter"
        )

    async def test_get_history_page_filters_by_origin(
        self, tmp_path: Path
    ) -> None:
        """``get_history_page(origin='selfblock')`` excludes blocklist-import."""
        from app.repositories import fail2ban_db_repo

        db_path = str(tmp_path / "f2b.db")
        async with aiosqlite.connect(db_path) as db:
            await db.execute(
                "CREATE TABLE bans "
                "(jail TEXT, ip TEXT, timeofban INTEGER, bancount INTEGER, data TEXT)"
            )
            await db.executemany(
                "INSERT INTO bans VALUES (?, ?, ?, ?, ?)",
                [
                    ("blocklist-import", "10.0.0.1", 100, 1, "{}"),
                    ("sshd", "10.0.0.2", 200, 1, "{}"),
                    ("sshd", "10.0.0.3", 300, 1, "{}"),
                ],
            )
            await db.commit()

        rows, total = await fail2ban_db_repo.get_history_page(
            db_path=db_path, origin="selfblock"
        )

        assert total == 2
        # A set comparison (instead of ``all(...)``) also fails when the page
        # comes back empty, and proves blocklist-import was excluded.
        assert {r.jail for r in rows} == {"sshd"}
|
|
|
|
|
|
# ── Bug 2 ─────────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestJailConfigImports:
    """Bug 2: ``jail_config_service`` must import ``_get_active_jail_names``."""

    async def test_get_active_jail_names_is_importable(self) -> None:
        """The module must successfully import ``_get_active_jail_names``."""
        import app.services.jail_config_service as mod

        # ``callable(getattr(..., None))`` is False both when the name is
        # missing and when it is bound to a non-callable. The former
        # ``hasattr(...) or callable(...)`` reduced to a bare ``hasattr``
        # because its second operand could never be True.
        assert callable(getattr(mod, "_get_active_jail_names", None)), (
            "_get_active_jail_names is not available in jail_config_service — "
            "any call site will raise NameError → 500"
        )

    async def test_list_inactive_jails_does_not_raise_name_error(
        self, tmp_path: Path
    ) -> None:
        """``list_inactive_jails`` must not crash with NameError."""
        from app.services import jail_config_service

        config_dir = tmp_path / "fail2ban"
        config_dir.mkdir()
        (config_dir / "jail.conf").write_text("[DEFAULT]\n")

        # ``patch`` (without ``create=True``) raises AttributeError when the
        # module never bound the name, so a missing import still fails here.
        with patch(
            "app.services.jail_config_service._get_active_jail_names",
            new=AsyncMock(return_value=set()),
        ):
            result = await jail_config_service.list_inactive_jails(
                str(config_dir), "/fake/socket"
            )

        assert result.total >= 0
|
|
|
|
|
|
# ── Bug 3 ─────────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestFilterConfigImports:
    """Bug 3: ``filter_config_service`` must import ``_parse_jails_sync``
    and ``_get_active_jail_names``."""

    async def test_parse_jails_sync_is_available(self) -> None:
        """``_parse_jails_sync`` must be resolvable at module scope."""
        import app.services.filter_config_service as mod

        # ``callable`` rather than ``hasattr``: the name must be bound to
        # something callable, not merely present on the module.
        assert callable(getattr(mod, "_parse_jails_sync", None)), (
            "_parse_jails_sync is not available in filter_config_service — "
            "list_filters() will raise NameError → 500"
        )

    async def test_get_active_jail_names_is_available(self) -> None:
        """``_get_active_jail_names`` must be resolvable at module scope."""
        import app.services.filter_config_service as mod

        assert callable(getattr(mod, "_get_active_jail_names", None)), (
            "_get_active_jail_names is not available in filter_config_service — "
            "list_filters() will raise NameError → 500"
        )

    async def test_list_filters_does_not_raise_name_error(
        self, tmp_path: Path
    ) -> None:
        """``list_filters`` must not crash with NameError."""
        from app.services import filter_config_service

        config_dir = tmp_path / "fail2ban"
        filter_d = config_dir / "filter.d"
        filter_d.mkdir(parents=True)

        # Create a minimal filter file so _parse_filters_sync has something to scan.
        (filter_d / "sshd.conf").write_text(
            "[Definition]\nfailregex = ^Failed password\n"
        )

        # Both helpers are patched in the module's own namespace; ``patch``
        # raises AttributeError if either name was never imported there.
        with (
            patch(
                "app.services.filter_config_service._parse_jails_sync",
                return_value=({}, {}),
            ),
            patch(
                "app.services.filter_config_service._get_active_jail_names",
                new=AsyncMock(return_value=set()),
            ),
        ):
            result = await filter_config_service.list_filters(
                str(config_dir), "/fake/socket"
            )

        assert result.total >= 0
|
|
|
|
|
|
# ── Bug 4 ─────────────────────────────────────────────────────────────────
|
|
|
|
|
|
class TestServiceStatusBanguiVersion:
    """Bug 4: ``get_service_status`` must report the BanGUI version.

    The original bug was a missing ``bangui_version`` argument to the
    ``ServiceStatusResponse`` constructor. These tests assert that the
    application version (``app.__version__``) is exposed via the response's
    ``version`` field, whether fail2ban is online or offline.
    """

    async def test_online_response_contains_bangui_version(self) -> None:
        """An online status response must carry ``app.__version__`` in ``version``."""
        from app.models.server import ServerStatus
        from app.services import config_service
        import app

        # The fake fail2ban daemon version is deliberately distinct from the
        # BanGUI version. If the two happened to be equal, a service that
        # wrongly echoed the daemon version would still satisfy the assertion.
        online_status = ServerStatus(
            online=True,
            version="0.11.2",
            active_jails=2,
            total_bans=5,
            total_failures=3,
        )

        async def _send(command: list[Any]) -> Any:
            # Answer only the log settings; every other query gets ``None``.
            key = "|".join(str(c) for c in command)
            if key == "get|loglevel":
                return (0, "INFO")
            if key == "get|logtarget":
                return (0, "/var/log/fail2ban.log")
            return (0, None)

        class _FakeClient:
            # Accept any constructor arguments so the fake does not depend on
            # how the service instantiates ``Fail2BanClient``.
            def __init__(self, *_args: Any, **_kw: Any) -> None:
                self.send = AsyncMock(side_effect=_send)

        with patch("app.services.config_service.Fail2BanClient", _FakeClient):
            result = await config_service.get_service_status(
                "/fake/socket",
                probe_fn=AsyncMock(return_value=online_status),
            )

        assert result.version == app.__version__, (
            "ServiceStatusResponse must expose BanGUI version in version field"
        )

    async def test_offline_response_contains_bangui_version(self) -> None:
        """Even when fail2ban is offline, ``version`` must be ``app.__version__``."""
        from app.models.server import ServerStatus
        from app.services import config_service
        import app

        offline_status = ServerStatus(online=False)

        result = await config_service.get_service_status(
            "/fake/socket",
            probe_fn=AsyncMock(return_value=offline_status),
        )

        assert result.version == app.__version__
|