fixed tests

This commit is contained in:
2026-05-15 20:41:05 +02:00
parent 96ce516ecf
commit 77df5d5d65
50 changed files with 1482 additions and 5089 deletions

View File

@@ -12,15 +12,36 @@ from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.ban import JailBannedIpsResponse
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
from app.services.geo_cache import GeoCache
from app.utils.session_cache import NoOpSessionCache
from app.utils.setup_state import set_setup_complete_cache
async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str:
"""Hash password and write to settings table."""
import asyncio
import bcrypt
pw_bytes = password.encode()
hashed = await asyncio.get_event_loop().run_in_executor(
None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode()
)
await db.execute(
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
("master_password_hash", hashed),
)
await db.commit()
return hashed
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD = {
"master_password": "testpassword1",
"master_password": "Testpass1!",
"database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC",
@@ -31,25 +52,41 @@ _SETUP_PAYLOAD = {
@pytest.fixture
async def jails_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
"""Provide an authenticated ``AsyncClient`` for jail endpoint tests."""
import os
os.makedirs(tmp_path / "fail2ban", exist_ok=True)
settings = Settings(
database_path=str(tmp_path / "jails_test.db"),
fail2ban_socket="/tmp/fake.sock",
fail2ban_config_dir=str(tmp_path / "fail2ban"),
session_secret="test-jails-secret-0000000000000000000000",
session_duration_minutes=60,
timezone="UTC",
log_level="debug",
session_cookie_secure=False,
)
app = create_app(settings=settings)
set_setup_complete_cache(app, True)
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
db.row_factory = aiosqlite.Row
await init_db(db)
await _write_password_hash(db, _SETUP_PAYLOAD["master_password"])
app.state.db = db
app.state.http_session = MagicMock()
app.state.session_cache = NoOpSessionCache()
app.state.geo_cache = GeoCache()
async def _override_get_db():
yield db
from app.dependencies import get_db, get_session_cache
app.dependency_overrides[get_db] = _override_get_db
app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache()
transport = ASGITransport(app=app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
login = await ac.post(
"/api/v1/auth/login",
json={"password": _SETUP_PAYLOAD["master_password"]},
@@ -58,6 +95,7 @@ async def jails_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
yield ac
await db.close()
app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
@@ -172,9 +210,19 @@ class TestGetJailDetail:
async def test_200_for_existing_jail(self, jails_client: AsyncClient) -> None:
"""GET /api/jails/sshd returns 200 with full jail detail."""
with patch(
"app.routers.jails.jail_service.get_jail",
AsyncMock(return_value=_detail()),
with (
patch(
"app.routers.jails.jail_service.get_jail",
AsyncMock(return_value=_detail()),
),
patch(
"app.routers.jails.jail_service.get_ignore_list",
AsyncMock(return_value=["127.0.0.1"]),
),
patch(
"app.routers.jails.jail_service.get_ignore_self",
AsyncMock(return_value=False),
),
):
resp = await jails_client.get("/api/v1/jails/sshd")
@@ -808,25 +856,21 @@ class TestGetJailBannedIps:
total: int = 2,
page: int = 1,
page_size: int = 25,
) -> JailBannedIpsResponse:
from app.models.ban import ActiveBan, JailBannedIpsResponse
):
from app.models.jail_domain import DomainActiveBan, DomainJailBannedIps
ban_items = (
[
ActiveBan(
ip=item.get("ip") or "1.2.3.4",
jail="sshd",
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
ban_count=1,
country=item.get("country", None),
)
for item in (items or [{"ip": "1.2.3.4"}, {"ip": "5.6.7.8"}])
]
)
return JailBannedIpsResponse(
items=ban_items, total=total, page=page, page_size=page_size
)
ban_items = [
DomainActiveBan(
ip=item.get("ip") or "1.2.3.4",
jail="sshd",
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
ban_count=1,
country=item.get("country", None),
)
for item in (items or [{"ip": "1.2.3.4"}, {"ip": "5.6.7.8"}])
]
return DomainJailBannedIps(items=ban_items, total=total, page=page, page_size=page_size)
async def test_200_returns_paginated_bans(self, jails_client: AsyncClient) -> None:
"""GET /api/jails/sshd/banned returns 200 with a JailBannedIpsResponse."""
@@ -839,10 +883,10 @@ class TestGetJailBannedIps:
assert resp.status_code == 200
data = resp.json()
assert "items" in data
assert "total" in data
assert "page" in data
assert "page_size" in data
assert data["total"] == 2
assert "pagination" in data
assert data["pagination"]["total"] == 2
assert data["pagination"]["page"] == 1
assert data["pagination"]["page_size"] == 25
async def test_200_with_search_parameter(self, jails_client: AsyncClient) -> None:
"""GET /api/jails/sshd/banned?search=1.2.3 passes search to service."""
@@ -856,9 +900,7 @@ class TestGetJailBannedIps:
async def test_200_with_page_and_page_size(self, jails_client: AsyncClient) -> None:
"""GET /api/jails/sshd/banned?page=2&page_size=10 passes params to service."""
mock_fn = AsyncMock(
return_value=self._mock_response(page=2, page_size=10, total=0, items=[])
)
mock_fn = AsyncMock(return_value=self._mock_response(page=2, page_size=10, total=0, items=[]))
with patch("app.routers.jails.jail_service.get_jail_banned_ips", mock_fn):
resp = await jails_client.get("/api/v1/jails/sshd/banned?page=2&page_size=10")
@@ -900,17 +942,13 @@ class TestGetJailBannedIps:
with patch(
"app.routers.jails.jail_service.get_jail_banned_ips",
AsyncMock(
side_effect=Fail2BanConnectionError("socket dead", "/tmp/fake.sock")
),
AsyncMock(side_effect=Fail2BanConnectionError("socket dead", "/tmp/fake.sock")),
):
resp = await jails_client.get("/api/v1/jails/sshd/banned")
assert resp.status_code == 502
async def test_response_items_have_expected_fields(
self, jails_client: AsyncClient
) -> None:
async def test_response_items_have_expected_fields(self, jails_client: AsyncClient) -> None:
"""Response items contain ip, jail, banned_at, expires_at, ban_count, country."""
with patch(
"app.routers.jails.jail_service.get_jail_banned_ips",
@@ -933,4 +971,3 @@ class TestGetJailBannedIps:
base_url="http://test",
).get("/api/v1/jails/sshd/banned")
assert resp.status_code == 401