Files
BanGUI/backend/tests/test_routers/test_bans.py
Lukas 96ce516ecf fix(logging): resolve logging_compat keyword arg conflicts
- Fix logging_compat._log() to handle extra keyword arguments properly
- Update config.py, main.py, and test_bans.py for compatibility
- Update Tasks.md and runner.csx
2026-05-10 15:54:00 +02:00

353 lines
12 KiB
Python

"""Tests for the bans router endpoints."""
from __future__ import annotations
from collections.abc import AsyncGenerator
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock, patch
import aiosqlite
import pytest
from httpx import ASGITransport, AsyncClient
from app.config import Settings
from app.db import init_db
from app.main import create_app
from app.models.ban import ActiveBan, ActiveBanListResponse
from app.exceptions import Fail2BanConnectionError
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
# Request body posted to ``/api/v1/setup`` by the ``bans_client`` fixture; the
# ``master_password`` entry is reused afterwards as the login password.
_SETUP_PAYLOAD = {
    "master_password": "Testpass1!",
    "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
    "timezone": "UTC",
    "session_duration_minutes": 60,
}
@pytest.fixture
async def bans_client(tmp_path: Path) -> AsyncGenerator[AsyncClient, None]:
    """Provide an authenticated ``AsyncClient`` for bans endpoint tests.

    Builds an application around a throwaway SQLite database located under
    ``tmp_path``, posts the setup payload, and logs in with the master
    password before handing the client to the test.

    Args:
        tmp_path: Per-test temporary directory supplied by pytest.

    Yields:
        The client on which the login request succeeded.

    Raises:
        AssertionError: If the login request does not return 200.
    """
    (tmp_path / "fail2ban").mkdir()
    settings = Settings(
        database_path=str(tmp_path / "bans_test.db"),
        fail2ban_socket="/tmp/fake.sock",
        session_secret="test-bans-secret-that-is-at-least-32-chars",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_cache_enabled=False,
    )
    app = create_app(settings=settings)

    db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path)
    # Everything after the connection is opened runs under ``try`` so that a
    # failure during setup (schema init, login assertion, ...) still closes
    # the connection instead of leaking it.
    try:
        db.row_factory = aiosqlite.Row
        await init_db(db)
        app.state.db = db
        app.state.http_session = MagicMock()

        async def _override_get_db() -> AsyncGenerator[aiosqlite.Connection, None]:
            # Hand every request the single connection owned by this fixture.
            yield db

        from app.dependencies import get_db

        app.dependency_overrides[get_db] = _override_get_db

        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as ac:
            await ac.post("/api/v1/setup", json=_SETUP_PAYLOAD)
            login = await ac.post(
                "/api/v1/auth/login",
                json={"password": _SETUP_PAYLOAD["master_password"]},
            )
            assert login.status_code == 200
            yield ac
    finally:
        app.dependency_overrides.clear()
        await db.close()
# ---------------------------------------------------------------------------
# GET /api/v1/bans/active
# ---------------------------------------------------------------------------
class TestGetActiveBans:
    """Tests for ``GET /api/v1/bans/active``."""

    async def test_200_when_authenticated(self, bans_client: AsyncClient) -> None:
        """An authenticated request returns 200 with the mocked ban list."""
        mock_response = ActiveBanListResponse(
            bans=[
                ActiveBan(
                    ip="1.2.3.4",
                    jail="sshd",
                    banned_at="2025-01-01T12:00:00+00:00",
                    expires_at="2025-01-01T13:00:00+00:00",
                    ban_count=1,
                    country="DE",
                )
            ],
            total=1,
        )
        with patch(
            "app.routers.bans.ban_service.get_active_bans",
            AsyncMock(return_value=mock_response),
        ):
            resp = await bans_client.get("/api/v1/bans/active")

        assert resp.status_code == 200
        data = resp.json()
        assert data["total"] == 1
        assert data["bans"][0]["ip"] == "1.2.3.4"
        assert data["bans"][0]["jail"] == "sshd"

    async def test_401_when_unauthenticated(
        self, bans_client: AsyncClient, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A request from a client that never logged in returns 401."""

        class FakeLogger:
            """No-op stand-in for the ``app.main.log`` object."""

            def error(self, *args: object, **kwargs: object) -> None:
                pass

            def warning(self, *args: object, **kwargs: object) -> None:
                pass

            def info(self, *args: object, **kwargs: object) -> None:
                pass

        # NOTE(review): presumably silences logging done while the request is
        # rejected — confirm against app.main.
        monkeypatch.setattr("app.main.log", FakeLogger())

        # A brand-new client reuses the same app but none of the state held by
        # the logged-in fixture client; ``async with`` makes sure it is closed.
        app = bans_client._transport.app  # type: ignore[attr-defined]
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as anonymous:
            resp = await anonymous.get("/api/v1/bans/active")

        assert resp.status_code == 401

    async def test_empty_when_no_bans(self, bans_client: AsyncClient) -> None:
        """An empty ban list is returned as ``total == 0`` and ``bans == []``."""
        mock_response = ActiveBanListResponse(bans=[], total=0)
        with patch(
            "app.routers.bans.ban_service.get_active_bans",
            AsyncMock(return_value=mock_response),
        ):
            resp = await bans_client.get("/api/v1/bans/active")

        assert resp.status_code == 200
        assert resp.json()["total"] == 0
        assert resp.json()["bans"] == []

    async def test_response_shape(self, bans_client: AsyncClient) -> None:
        """Each ban entry exposes the expected keys even when values are null."""
        mock_response = ActiveBanListResponse(
            bans=[
                ActiveBan(
                    ip="10.0.0.1",
                    jail="nginx",
                    banned_at=None,
                    expires_at=None,
                    ban_count=1,
                    country=None,
                )
            ],
            total=1,
        )
        with patch(
            "app.routers.bans.ban_service.get_active_bans",
            AsyncMock(return_value=mock_response),
        ):
            resp = await bans_client.get("/api/v1/bans/active")

        # Check the status first so an error response fails with a clear
        # message instead of a KeyError while indexing the body.
        assert resp.status_code == 200
        ban = resp.json()["bans"][0]
        assert "ip" in ban
        assert "jail" in ban
        assert "banned_at" in ban
        assert "expires_at" in ban
        assert "ban_count" in ban
# ---------------------------------------------------------------------------
# POST /api/v1/bans
# ---------------------------------------------------------------------------
class TestBanIp:
    """Tests for ``POST /api/v1/bans``."""

    async def test_201_on_success(self, bans_client: AsyncClient) -> None:
        """A successful ban returns 201 and echoes the jail name."""
        with patch(
            "app.routers.bans.ban_service.ban_ip",
            AsyncMock(return_value=None),
        ):
            resp = await bans_client.post(
                "/api/v1/bans",
                json={"ip": "1.2.3.4", "jail": "sshd"},
            )

        assert resp.status_code == 201
        assert resp.json()["jail"] == "sshd"

    async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
        """A ``ValueError`` from the ban service is reported as 400."""
        with patch(
            "app.routers.bans.ban_service.ban_ip",
            AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")),
        ):
            resp = await bans_client.post(
                "/api/v1/bans",
                json={"ip": "bad", "jail": "sshd"},
            )

        assert resp.status_code == 400

    async def test_404_for_unknown_jail(self, bans_client: AsyncClient) -> None:
        """A ``JailNotFoundError`` from the ban service is reported as 404."""
        from app.services.jail_service import JailNotFoundError

        with patch(
            "app.routers.bans.ban_service.ban_ip",
            AsyncMock(side_effect=JailNotFoundError("ghost")),
        ):
            resp = await bans_client.post(
                "/api/v1/bans",
                json={"ip": "1.2.3.4", "jail": "ghost"},
            )

        assert resp.status_code == 404

    async def test_401_when_unauthenticated(self, bans_client: AsyncClient) -> None:
        """A request from a client that never logged in returns 401."""
        # A brand-new client reuses the same app but none of the state held by
        # the logged-in fixture client; ``async with`` makes sure it is closed.
        app = bans_client._transport.app  # type: ignore[attr-defined]
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as anonymous:
            resp = await anonymous.post(
                "/api/v1/bans",
                json={"ip": "1.2.3.4", "jail": "sshd"},
            )

        assert resp.status_code == 401
# ---------------------------------------------------------------------------
# DELETE /api/v1/bans
# ---------------------------------------------------------------------------
class TestUnbanIp:
    """Exercise ``DELETE /api/v1/bans`` with the ban service patched out."""

    async def test_200_unban_from_all(self, bans_client: AsyncClient) -> None:
        """``unban_all=true`` yields 200 and a message mentioning all jails."""
        payload = {"ip": "1.2.3.4", "unban_all": True}
        unban_stub = AsyncMock(return_value=None)

        with patch("app.routers.bans.ban_service.unban_ip", new=unban_stub):
            response = await bans_client.request(
                "DELETE", "/api/v1/bans", json=payload
            )

        assert response.status_code == 200
        body = response.json()
        assert "all jails" in body["message"]

    async def test_200_unban_from_specific_jail(self, bans_client: AsyncClient) -> None:
        """Naming a jail yields 200 and a message that mentions that jail."""
        payload = {"ip": "1.2.3.4", "jail": "sshd"}
        unban_stub = AsyncMock(return_value=None)

        with patch("app.routers.bans.ban_service.unban_ip", new=unban_stub):
            response = await bans_client.request(
                "DELETE", "/api/v1/bans", json=payload
            )

        assert response.status_code == 200
        body = response.json()
        assert "sshd" in body["message"]

    async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None:
        """A ``ValueError`` raised by the ban service surfaces as 400."""
        payload = {"ip": "bad", "unban_all": True}
        failing_stub = AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'"))

        with patch("app.routers.bans.ban_service.unban_ip", new=failing_stub):
            response = await bans_client.request(
                "DELETE", "/api/v1/bans", json=payload
            )

        assert response.status_code == 400

    async def test_404_for_unknown_jail(self, bans_client: AsyncClient) -> None:
        """A ``JailNotFoundError`` raised by the ban service surfaces as 404."""
        from app.services.jail_service import JailNotFoundError

        payload = {"ip": "1.2.3.4", "jail": "ghost"}
        failing_stub = AsyncMock(side_effect=JailNotFoundError("ghost"))

        with patch("app.routers.bans.ban_service.unban_ip", new=failing_stub):
            response = await bans_client.request(
                "DELETE", "/api/v1/bans", json=payload
            )

        assert response.status_code == 404
# ---------------------------------------------------------------------------
# DELETE /api/v1/bans/all
# ---------------------------------------------------------------------------
class TestUnbanAll:
    """Tests for ``DELETE /api/v1/bans/all``."""

    async def test_200_clears_all_bans(self, bans_client: AsyncClient) -> None:
        """A successful call returns 200 with the count in body and message."""
        with patch(
            "app.routers.bans.jail_service.unban_all_ips",
            AsyncMock(return_value=3),
        ):
            resp = await bans_client.request("DELETE", "/api/v1/bans/all")

        assert resp.status_code == 200
        data = resp.json()
        assert data["count"] == 3
        assert "3" in data["message"]

    async def test_200_with_zero_count(self, bans_client: AsyncClient) -> None:
        """A service result of zero still returns 200, with ``count == 0``."""
        with patch(
            "app.routers.bans.jail_service.unban_all_ips",
            AsyncMock(return_value=0),
        ):
            resp = await bans_client.request("DELETE", "/api/v1/bans/all")

        assert resp.status_code == 200
        assert resp.json()["count"] == 0

    async def test_502_when_fail2ban_unreachable(
        self, bans_client: AsyncClient
    ) -> None:
        """A ``Fail2BanConnectionError`` from the service is reported as 502."""
        with patch(
            "app.routers.bans.jail_service.unban_all_ips",
            AsyncMock(
                side_effect=Fail2BanConnectionError(
                    "cannot connect",
                    "/var/run/fail2ban/fail2ban.sock",
                )
            ),
        ):
            resp = await bans_client.request("DELETE", "/api/v1/bans/all")

        assert resp.status_code == 502

    async def test_401_when_unauthenticated(self, bans_client: AsyncClient) -> None:
        """A request from a client that never logged in returns 401."""
        # A brand-new client reuses the same app but none of the state held by
        # the logged-in fixture client; ``async with`` makes sure it is closed.
        app = bans_client._transport.app  # type: ignore[attr-defined]
        async with AsyncClient(
            transport=ASGITransport(app=app),
            base_url="http://test",
        ) as anonymous:
            resp = await anonymous.request("DELETE", "/api/v1/bans/all")

        assert resp.status_code == 401