Files
BanGUI/backend/tests/test_routers/test_bans.py
2026-05-15 20:41:05 +02:00

389 lines
14 KiB
Python

"""Tests for the bans router endpoints."""
from __future__ import annotations
from collections.abc import AsyncGenerator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import bcrypt
import pytest
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.exceptions import Fail2BanConnectionError
from app.main import create_app
from app.models.ban_domain import DomainActiveBan, DomainActiveBanList
from app.services.geo_cache import GeoCache
from app.utils.session_cache import NoOpSessionCache
from app.utils.setup_state import set_setup_complete_cache
async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str:
    """Hash ``password`` with bcrypt and store it in the ``settings`` table.

    The hash is written under the ``master_password_hash`` key with
    ``INSERT OR REPLACE`` and the change is committed before returning.

    Args:
        db: Open database connection that receives the hash.
        password: Plain-text password to hash.

    Returns:
        The bcrypt hash decoded to ``str``.
    """
    import asyncio

    pw_bytes = password.encode()
    # Hashing runs in the default executor so the event loop keeps running
    # while bcrypt works. ``get_running_loop`` is the documented way to get
    # the loop from inside a coroutine.
    loop = asyncio.get_running_loop()
    hashed = await loop.run_in_executor(
        None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode()
    )
    await db.execute(
        "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)",
        ("master_password_hash", hashed),
    )
    await db.commit()
    return hashed
@pytest.fixture
async def bans_client(tmp_path: Path) -> AsyncGenerator[AsyncClient, None]:
    """Provide an authenticated ``AsyncClient`` for bans endpoint tests.

    Builds an application backed by a database file under ``tmp_path``, marks
    setup as complete, stores a master-password hash, overrides the database
    and session-cache dependencies, and logs in before yielding the client.

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.

    Yields:
        A client on which ``POST /api/v1/auth/login`` has already returned 200.
    """
    config_dir = tmp_path / "fail2ban"
    config_dir.mkdir()
    settings = Settings(
        database_path=str(tmp_path / "bans_test.db"),
        fail2ban_socket="/tmp/fake.sock",
        session_secret="test-bans-secret-that-is-at-least-32-chars",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
        fail2ban_config_dir=str(config_dir),
        session_cache_enabled=False,
        session_cookie_secure=False,
    )
    app = create_app(settings=settings)
    set_setup_complete_cache(app, True)

    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    # Everything after the connection is opened runs under try/finally so the
    # connection is closed and the overrides are cleared even when setup or
    # the login assertion fails before the test body runs.
    try:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        await _write_password_hash(db, _SETUP_PAYLOAD["master_password"])

        app.state.db = db
        app.state.http_session = MagicMock()
        app.state.session_cache = NoOpSessionCache()
        app.state.geo_cache = GeoCache()

        async def _override_get_db() -> AsyncGenerator[aiosqlite.Connection, None]:
            """Yield the fixture's database connection."""
            yield db

        from app.dependencies import get_db, get_session_cache

        app.dependency_overrides[get_db] = _override_get_db
        app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache()

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            login = await ac.post(
                "/api/v1/auth/login",
                json={"password": _SETUP_PAYLOAD["master_password"]},
            )
            assert login.status_code == 200
            yield ac
    finally:
        await db.close()
        app.dependency_overrides.clear()
# ---------------------------------------------------------------------------
# Shared fixture data
# ---------------------------------------------------------------------------
# Values shared by the fixture and tests in this module. In the code visible
# here only ``master_password`` is read: once to seed the stored password hash
# and once as the login credential.
_SETUP_PAYLOAD = {
    "master_password": "Testpass1!",
    "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
    "timezone": "UTC",
    "session_duration_minutes": 60,
    "database_path": "bans_test.db",
}
# ---------------------------------------------------------------------------
# GET /api/v1/bans/active
# ---------------------------------------------------------------------------
class TestGetActiveBans:
    """Tests for ``GET /api/v1/bans/active``."""

    async def test_200_when_authenticated(self, bans_client: AsyncClient) -> None:
        """An authenticated request returns 200 and lists the mocked ban."""
        mock_response = DomainActiveBanList(
            bans=[
                DomainActiveBan(
                    ip="1.2.3.4",
                    jail="sshd",
                    banned_at="2025-01-01T12:00:00+00:00",
                    expires_at="2025-01-01T13:00:00+00:00",
                    ban_count=1,
                    country="DE",
                )
            ],
            total=1,
        )
        with patch(
            "app.routers.bans.ban_service.get_active_bans",
            AsyncMock(return_value=mock_response),
        ):
            resp = await bans_client.get("/api/v1/bans/active")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["items"][0]["ip"] == "1.2.3.4"
        assert data["items"][0]["jail"] == "sshd"

    async def test_401_when_unauthenticated(
        self,
        bans_client: AsyncClient,
        monkeypatch: pytest.MonkeyPatch,
    ) -> None:
        """A request without a session cookie returns 401."""

        class FakeLogger:
            """No-op stand-in that replaces ``app.main.log`` for this test."""

            def error(self, *args: object, **kwargs: object) -> None:
                pass

            def warning(self, *args: object, **kwargs: object) -> None:
                pass

            def info(self, *args: object, **kwargs: object) -> None:
                pass

        monkeypatch.setattr("app.main.log", FakeLogger())

        # A fresh client against the same app carries no login cookie. It is
        # opened with ``async with`` so it is closed when the request finishes.
        async with AsyncClient(
            transport=ASGITransport(app=bans_client._transport.app),  # type: ignore[attr-defined]
            base_url="http://test",
        ) as anonymous:
            resp = await anonymous.get("/api/v1/bans/active")

        assert resp.status_code == 401

    async def test_empty_when_no_bans(self, bans_client: AsyncClient) -> None:
        """An empty ban list is returned as ``total == 0`` and ``items == []``."""
        mock_response = DomainActiveBanList(bans=[], total=0)
        with patch(
            "app.routers.bans.ban_service.get_active_bans",
            AsyncMock(return_value=mock_response),
        ):
            resp = await bans_client.get("/api/v1/bans/active")

        assert resp.status_code == 200
        assert resp.json()["total"] == 0
        assert resp.json()["items"] == []

    async def test_response_shape(self, bans_client: AsyncClient) -> None:
        """Each ban entry exposes the expected keys, even when values are ``None``."""
        mock_response = DomainActiveBanList(
            bans=[
                DomainActiveBan(
                    ip="10.0.0.1",
                    jail="nginx",
                    banned_at=None,
                    expires_at=None,
                    ban_count=1,
                    country=None,
                )
            ],
            total=1,
        )
        with patch(
            "app.routers.bans.ban_service.get_active_bans",
            AsyncMock(return_value=mock_response),
        ):
            resp = await bans_client.get("/api/v1/bans/active")

        # Check the status first so an error response fails here, not with a
        # KeyError on the JSON lookup below.
        assert resp.status_code == 200
        ban = resp.json()["items"][0]
        assert "ip" in ban
        assert "jail" in ban
        assert "banned_at" in ban
        assert "expires_at" in ban
        assert "ban_count" in ban
# ---------------------------------------------------------------------------
# POST /api/v1/bans
# ---------------------------------------------------------------------------
class TestBanIp:
    """Tests for ``POST /api/v1/bans``."""

    async def test_201_on_success(self, bans_client: AsyncClient) -> None:
        """A successful ban returns 201 and echoes the jail name."""
        with patch(
            "app.routers.bans.ban_service.ban_ip",
            AsyncMock(return_value=None),
        ):
            resp = await bans_client.post(
                "/api/v1/bans",
                json={"ip": "1.2.3.4", "jail": "sshd"},
                headers={"X-BanGUI-Request": "1"},
            )

        assert resp.status_code == 201
        assert resp.json()["jail"] == "sshd"

    async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
        """A ``ValueError`` raised by the service is reported as 400."""
        with patch(
            "app.routers.bans.ban_service.ban_ip",
            AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")),
        ):
            resp = await bans_client.post(
                "/api/v1/bans",
                json={"ip": "bad", "jail": "sshd"},
                headers={"X-BanGUI-Request": "1"},
            )

        assert resp.status_code == 400

    async def test_404_for_unknown_jail(self, bans_client: AsyncClient) -> None:
        """A ``JailNotFoundError`` raised by the service is reported as 404."""
        from app.services.jail_service import JailNotFoundError

        with patch(
            "app.routers.bans.ban_service.ban_ip",
            AsyncMock(side_effect=JailNotFoundError("ghost")),
        ):
            resp = await bans_client.post(
                "/api/v1/bans",
                json={"ip": "1.2.3.4", "jail": "ghost"},
                headers={"X-BanGUI-Request": "1"},
            )

        assert resp.status_code == 404

    async def test_401_when_unauthenticated(self, bans_client: AsyncClient) -> None:
        """A request without a session cookie returns 401."""
        # A fresh client against the same app carries no login cookie. It is
        # opened with ``async with`` so it is closed when the request finishes.
        async with AsyncClient(
            transport=ASGITransport(app=bans_client._transport.app),  # type: ignore[attr-defined]
            base_url="http://test",
        ) as anonymous:
            resp = await anonymous.post(
                "/api/v1/bans", json={"ip": "1.2.3.4", "jail": "sshd"}
            )

        assert resp.status_code == 401
# ---------------------------------------------------------------------------
# DELETE /api/v1/bans
# ---------------------------------------------------------------------------
class TestUnbanIp:
    """Tests for ``DELETE /api/bans``."""

    async def test_200_unban_from_all(self, bans_client: AsyncClient) -> None:
        """Unbanning with ``unban_all=true`` succeeds and mentions all jails."""
        payload = {"ip": "1.2.3.4", "unban_all": True}
        service_stub = AsyncMock(return_value=None)
        with patch("app.routers.bans.ban_service.unban_ip", new=service_stub):
            response = await bans_client.request(
                "DELETE",
                "/api/v1/bans",
                json=payload,
                headers={"X-BanGUI-Request": "1"},
            )

        assert response.status_code == 200
        message = response.json()["message"]
        assert "all jails" in message

    async def test_200_unban_from_specific_jail(self, bans_client: AsyncClient) -> None:
        """Unbanning from a named jail succeeds and mentions that jail."""
        payload = {"ip": "1.2.3.4", "jail": "sshd"}
        service_stub = AsyncMock(return_value=None)
        with patch("app.routers.bans.ban_service.unban_ip", new=service_stub):
            response = await bans_client.request(
                "DELETE",
                "/api/v1/bans",
                json=payload,
                headers={"X-BanGUI-Request": "1"},
            )

        assert response.status_code == 200
        message = response.json()["message"]
        assert "sshd" in message

    async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
        """A ``ValueError`` from the service maps to a 400 response."""
        payload = {"ip": "bad", "unban_all": True}
        service_stub = AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'"))
        with patch("app.routers.bans.ban_service.unban_ip", new=service_stub):
            response = await bans_client.request(
                "DELETE",
                "/api/v1/bans",
                json=payload,
                headers={"X-BanGUI-Request": "1"},
            )

        assert response.status_code == 400

    async def test_404_for_unknown_jail(self, bans_client: AsyncClient) -> None:
        """A ``JailNotFoundError`` from the service maps to a 404 response."""
        from app.services.jail_service import JailNotFoundError

        payload = {"ip": "1.2.3.4", "jail": "ghost"}
        service_stub = AsyncMock(side_effect=JailNotFoundError("ghost"))
        with patch("app.routers.bans.ban_service.unban_ip", new=service_stub):
            response = await bans_client.request(
                "DELETE",
                "/api/v1/bans",
                json=payload,
                headers={"X-BanGUI-Request": "1"},
            )

        assert response.status_code == 404
# ---------------------------------------------------------------------------
# DELETE /api/v1/bans/all
# ---------------------------------------------------------------------------
class TestUnbanAll:
    """Tests for ``DELETE /api/v1/bans/all``."""

    async def test_200_clears_all_bans(self, bans_client: AsyncClient) -> None:
        """A successful call returns 200 with the unbanned count in the body."""
        with patch(
            "app.routers.bans.jail_service.unban_all_ips",
            AsyncMock(return_value=3),
        ):
            resp = await bans_client.request(
                "DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"}
            )

        assert resp.status_code == 200
        data = resp.json()
        assert data["count"] == 3
        assert "3" in data["message"]

    async def test_200_with_zero_count(self, bans_client: AsyncClient) -> None:
        """The call still returns 200 with ``count == 0`` when nothing was banned."""
        with patch(
            "app.routers.bans.jail_service.unban_all_ips",
            AsyncMock(return_value=0),
        ):
            resp = await bans_client.request(
                "DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"}
            )

        assert resp.status_code == 200
        assert resp.json()["count"] == 0

    async def test_502_when_fail2ban_unreachable(self, bans_client: AsyncClient) -> None:
        """A ``Fail2BanConnectionError`` from the service is reported as 502."""
        with patch(
            "app.routers.bans.jail_service.unban_all_ips",
            AsyncMock(
                side_effect=Fail2BanConnectionError(
                    "cannot connect",
                    "/var/run/fail2ban/fail2ban.sock",
                )
            ),
        ):
            resp = await bans_client.request(
                "DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"}
            )

        assert resp.status_code == 502

    async def test_401_when_unauthenticated(self, bans_client: AsyncClient) -> None:
        """A request without a session cookie returns 401."""
        # A fresh client against the same app carries no login cookie. It is
        # opened with ``async with`` so it is closed when the request finishes.
        async with AsyncClient(
            transport=ASGITransport(app=bans_client._transport.app),  # type: ignore[attr-defined]
            base_url="http://test",
        ) as anonymous:
            resp = await anonymous.request("DELETE", "/api/v1/bans/all")

        assert resp.status_code == 401