fixed tests
This commit is contained in:
@@ -12,15 +12,36 @@ from httpx import ASGITransport, AsyncClient
|
||||
from app.config import Settings
|
||||
from app.db import init_db
|
||||
from app.main import create_app
|
||||
from app.models.ban import JailBannedIpsResponse
|
||||
from app.models.jail import Jail, JailDetailResponse, JailListResponse, JailStatus, JailSummary
|
||||
from app.services.geo_cache import GeoCache
|
||||
from app.utils.session_cache import NoOpSessionCache
|
||||
from app.utils.setup_state import set_setup_complete_cache
|
||||
|
||||
|
||||
async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str:
|
||||
"""Hash password and write to settings table."""
|
||||
import asyncio
|
||||
|
||||
import bcrypt
|
||||
|
||||
pw_bytes = password.encode()
|
||||
hashed = await asyncio.get_event_loop().run_in_executor(
|
||||
None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode()
|
||||
)
|
||||
await db.execute(
|
||||
"INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
|
||||
("master_password_hash", hashed),
|
||||
)
|
||||
await db.commit()
|
||||
return hashed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "testpassword1",
|
||||
"master_password": "Testpass1!",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
@@ -31,25 +52,41 @@ _SETUP_PAYLOAD = {
|
||||
@pytest.fixture
|
||||
async def jails_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
"""Provide an authenticated ``AsyncClient`` for jail endpoint tests."""
|
||||
import os
|
||||
|
||||
os.makedirs(tmp_path / "fail2ban", exist_ok=True)
|
||||
settings = Settings(
|
||||
database_path=str(tmp_path / "jails_test.db"),
|
||||
fail2ban_socket="/tmp/fake.sock",
|
||||
fail2ban_config_dir=str(tmp_path / "fail2ban"),
|
||||
session_secret="test-jails-secret-0000000000000000000000",
|
||||
session_duration_minutes=60,
|
||||
timezone="UTC",
|
||||
log_level="debug",
|
||||
session_cookie_secure=False,
|
||||
)
|
||||
app = create_app(settings=settings)
|
||||
set_setup_complete_cache(app, True)
|
||||
|
||||
db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
|
||||
db.row_factory = aiosqlite.Row
|
||||
await init_db(db)
|
||||
await _write_password_hash(db, _SETUP_PAYLOAD["master_password"])
|
||||
app.state.db = db
|
||||
app.state.http_session = MagicMock()
|
||||
app.state.session_cache = NoOpSessionCache()
|
||||
app.state.geo_cache = GeoCache()
|
||||
|
||||
async def _override_get_db():
|
||||
yield db
|
||||
|
||||
from app.dependencies import get_db, get_session_cache
|
||||
|
||||
app.dependency_overrides[get_db] = _override_get_db
|
||||
app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache()
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
|
||||
async with AsyncClient(transport=transport, base_url="http://test", headers={"X-BanGUI-Request": "1"}) as ac:
|
||||
login = await ac.post(
|
||||
"/api/v1/auth/login",
|
||||
json={"password": _SETUP_PAYLOAD["master_password"]},
|
||||
@@ -58,6 +95,7 @@ async def jails_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc]
|
||||
yield ac
|
||||
|
||||
await db.close()
|
||||
app.dependency_overrides.clear()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -172,9 +210,19 @@ class TestGetJailDetail:
|
||||
|
||||
async def test_200_for_existing_jail(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd returns 200 with full jail detail."""
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail",
|
||||
AsyncMock(return_value=_detail()),
|
||||
with (
|
||||
patch(
|
||||
"app.routers.jails.jail_service.get_jail",
|
||||
AsyncMock(return_value=_detail()),
|
||||
),
|
||||
patch(
|
||||
"app.routers.jails.jail_service.get_ignore_list",
|
||||
AsyncMock(return_value=["127.0.0.1"]),
|
||||
),
|
||||
patch(
|
||||
"app.routers.jails.jail_service.get_ignore_self",
|
||||
AsyncMock(return_value=False),
|
||||
),
|
||||
):
|
||||
resp = await jails_client.get("/api/v1/jails/sshd")
|
||||
|
||||
@@ -808,25 +856,21 @@ class TestGetJailBannedIps:
|
||||
total: int = 2,
|
||||
page: int = 1,
|
||||
page_size: int = 25,
|
||||
) -> JailBannedIpsResponse:
|
||||
from app.models.ban import ActiveBan, JailBannedIpsResponse
|
||||
):
|
||||
from app.models.jail_domain import DomainActiveBan, DomainJailBannedIps
|
||||
|
||||
ban_items = (
|
||||
[
|
||||
ActiveBan(
|
||||
ip=item.get("ip") or "1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
||||
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
||||
ban_count=1,
|
||||
country=item.get("country", None),
|
||||
)
|
||||
for item in (items or [{"ip": "1.2.3.4"}, {"ip": "5.6.7.8"}])
|
||||
]
|
||||
)
|
||||
return JailBannedIpsResponse(
|
||||
items=ban_items, total=total, page=page, page_size=page_size
|
||||
)
|
||||
ban_items = [
|
||||
DomainActiveBan(
|
||||
ip=item.get("ip") or "1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at=item.get("banned_at", "2025-01-01T10:00:00+00:00"),
|
||||
expires_at=item.get("expires_at", "2025-01-01T10:10:00+00:00"),
|
||||
ban_count=1,
|
||||
country=item.get("country", None),
|
||||
)
|
||||
for item in (items or [{"ip": "1.2.3.4"}, {"ip": "5.6.7.8"}])
|
||||
]
|
||||
return DomainJailBannedIps(items=ban_items, total=total, page=page, page_size=page_size)
|
||||
|
||||
async def test_200_returns_paginated_bans(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned returns 200 with a JailBannedIpsResponse."""
|
||||
@@ -839,10 +883,10 @@ class TestGetJailBannedIps:
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "items" in data
|
||||
assert "total" in data
|
||||
assert "page" in data
|
||||
assert "page_size" in data
|
||||
assert data["total"] == 2
|
||||
assert "pagination" in data
|
||||
assert data["pagination"]["total"] == 2
|
||||
assert data["pagination"]["page"] == 1
|
||||
assert data["pagination"]["page_size"] == 25
|
||||
|
||||
async def test_200_with_search_parameter(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned?search=1.2.3 passes search to service."""
|
||||
@@ -856,9 +900,7 @@ class TestGetJailBannedIps:
|
||||
|
||||
async def test_200_with_page_and_page_size(self, jails_client: AsyncClient) -> None:
|
||||
"""GET /api/jails/sshd/banned?page=2&page_size=10 passes params to service."""
|
||||
mock_fn = AsyncMock(
|
||||
return_value=self._mock_response(page=2, page_size=10, total=0, items=[])
|
||||
)
|
||||
mock_fn = AsyncMock(return_value=self._mock_response(page=2, page_size=10, total=0, items=[]))
|
||||
with patch("app.routers.jails.jail_service.get_jail_banned_ips", mock_fn):
|
||||
resp = await jails_client.get("/api/v1/jails/sshd/banned?page=2&page_size=10")
|
||||
|
||||
@@ -900,17 +942,13 @@ class TestGetJailBannedIps:
|
||||
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail_banned_ips",
|
||||
AsyncMock(
|
||||
side_effect=Fail2BanConnectionError("socket dead", "/tmp/fake.sock")
|
||||
),
|
||||
AsyncMock(side_effect=Fail2BanConnectionError("socket dead", "/tmp/fake.sock")),
|
||||
):
|
||||
resp = await jails_client.get("/api/v1/jails/sshd/banned")
|
||||
|
||||
assert resp.status_code == 502
|
||||
|
||||
async def test_response_items_have_expected_fields(
|
||||
self, jails_client: AsyncClient
|
||||
) -> None:
|
||||
async def test_response_items_have_expected_fields(self, jails_client: AsyncClient) -> None:
|
||||
"""Response items contain ip, jail, banned_at, expires_at, ban_count, country."""
|
||||
with patch(
|
||||
"app.routers.jails.jail_service.get_jail_banned_ips",
|
||||
@@ -933,4 +971,3 @@ class TestGetJailBannedIps:
|
||||
base_url="http://test",
|
||||
).get("/api/v1/jails/sshd/banned")
|
||||
assert resp.status_code == 401
|
||||
|
||||
|
||||
Reference in New Issue
Block a user