"""Tests for the bans router endpoints.""" from __future__ import annotations from collections.abc import AsyncGenerator from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import aiosqlite import bcrypt import pytest from httpx import ASGITransport, AsyncClient from app.config import Settings from app.db import init_db from app.exceptions import Fail2BanConnectionError from app.main import create_app from app.models.ban_domain import DomainActiveBan, DomainActiveBanList from app.services.geo_cache import GeoCache from app.utils.session_cache import NoOpSessionCache from app.utils.setup_state import set_setup_complete_cache async def _write_password_hash(db: aiosqlite.Connection, password: str) -> str: """Hash password and write to settings table.""" pw_bytes = password.encode() import asyncio hashed = await asyncio.get_event_loop().run_in_executor( None, lambda: bcrypt.hashpw(pw_bytes, bcrypt.gensalt()).decode() ) await db.execute( "INSERT OR REPLACE INTO settings (key, value) VALUES (?, ?)", ("master_password_hash", hashed), ) await db.commit() return hashed @pytest.fixture async def bans_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] """Provide an authenticated ``AsyncClient`` for bans endpoint tests.""" (tmp_path / "fail2ban").mkdir() settings = Settings( database_path=str(tmp_path / "bans_test.db"), fail2ban_socket="/tmp/fake.sock", session_secret="test-bans-secret-that-is-at-least-32-chars", session_duration_minutes=60, timezone="UTC", log_level="debug", fail2ban_config_dir=str(tmp_path / "fail2ban"), session_cache_enabled=False, session_cookie_secure=False, ) app = create_app(settings=settings) set_setup_complete_cache(app, True) db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path) db.row_factory = aiosqlite.Row await init_db(db) await _write_password_hash(db, _SETUP_PAYLOAD["master_password"]) app.state.db = db app.state.http_session = MagicMock() app.state.session_cache = NoOpSessionCache() app.state.geo_cache = GeoCache() async def _override_get_db() -> AsyncGenerator[aiosqlite.Connection, None]: yield db from app.dependencies import get_db, get_session_cache app.dependency_overrides[get_db] = _override_get_db app.dependency_overrides[get_session_cache] = lambda: NoOpSessionCache() transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: login = await ac.post( "/api/v1/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}, ) assert login.status_code == 200 yield ac await db.close() app.dependency_overrides.clear() # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- _SETUP_PAYLOAD = { "master_password": "Testpass1!", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "timezone": "UTC", "session_duration_minutes": 60, "database_path": "bans_test.db", } # --------------------------------------------------------------------------- # GET /api/bans/active # --------------------------------------------------------------------------- class TestGetActiveBans: """Tests for ``GET /api/bans/active``.""" async def test_200_when_authenticated(self, bans_client: AsyncClient) -> None: """GET /api/bans/active returns 200 with an ActiveBanListResponse.""" from app.models.ban_domain import DomainActiveBan, DomainActiveBanList mock_response = DomainActiveBanList( bans=[ DomainActiveBan( ip="1.2.3.4", jail="sshd", banned_at="2025-01-01T12:00:00+00:00", expires_at="2025-01-01T13:00:00+00:00", ban_count=1, country="DE", ) ], total=1, ) with patch( "app.routers.bans.ban_service.get_active_bans", AsyncMock(return_value=mock_response), ): resp = await bans_client.get("/api/v1/bans/active") assert resp.status_code == 200 data = resp.json() assert data["total"] == 1 assert data["items"][0]["ip"] == "1.2.3.4" assert data["items"][0]["jail"] == "sshd" async def test_401_when_unauthenticated(self, bans_client: AsyncClient, monkeypatch: pytest.MonkeyPatch) -> None: """GET /api/bans/active returns 401 without session.""" class FakeLogger: def error(self, *args, **kwargs): pass def warning(self, *args, **kwargs): pass def info(self, *args, **kwargs): pass monkeypatch.setattr("app.main.log", FakeLogger()) resp = await AsyncClient( transport=ASGITransport(app=bans_client._transport.app), # type: ignore[attr-defined] base_url="http://test", ).get("/api/v1/bans/active") assert resp.status_code == 401 async def test_empty_when_no_bans(self, bans_client: AsyncClient) -> None: """GET /api/bans/active returns empty list when no bans are active.""" mock_response = DomainActiveBanList(bans=[], total=0) with patch( "app.routers.bans.ban_service.get_active_bans", AsyncMock(return_value=mock_response), ): resp = await bans_client.get("/api/v1/bans/active") assert resp.status_code == 200 assert resp.json()["total"] == 0 assert resp.json()["items"] == [] async def test_response_shape(self, bans_client: AsyncClient) -> None: """GET /api/bans/active returns expected fields per ban entry.""" mock_response = DomainActiveBanList( bans=[ DomainActiveBan( ip="10.0.0.1", jail="nginx", banned_at=None, expires_at=None, ban_count=1, country=None, ) ], total=1, ) with patch( "app.routers.bans.ban_service.get_active_bans", AsyncMock(return_value=mock_response), ): resp = await bans_client.get("/api/v1/bans/active") ban = resp.json()["items"][0] assert "ip" in ban assert "jail" in ban assert "banned_at" in ban assert "expires_at" in ban assert "ban_count" in ban # --------------------------------------------------------------------------- # POST /api/bans # --------------------------------------------------------------------------- class TestBanIp: """Tests for ``POST /api/bans``.""" async def test_201_on_success(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 201 when the IP is banned.""" with patch( "app.routers.bans.ban_service.ban_ip", AsyncMock(return_value=None), ): resp = await bans_client.post( "/api/v1/bans", json={"ip": "1.2.3.4", "jail": "sshd"}, headers={"X-BanGUI-Request": "1"}, ) assert resp.status_code == 201 assert resp.json()["jail"] == "sshd" async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 400 for an invalid IP address.""" with patch( "app.routers.bans.ban_service.ban_ip", AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")), ): resp = await bans_client.post( "/api/v1/bans", json={"ip": "bad", "jail": "sshd"}, headers={"X-BanGUI-Request": "1"}, ) assert resp.status_code == 400 async def test_404_for_unknown_jail(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 404 when jail does not exist.""" from app.services.jail_service import JailNotFoundError with patch( "app.routers.bans.ban_service.ban_ip", AsyncMock(side_effect=JailNotFoundError("ghost")), ): resp = await bans_client.post( "/api/v1/bans", json={"ip": "1.2.3.4", "jail": "ghost"}, headers={"X-BanGUI-Request": "1"}, ) assert resp.status_code == 404 async def test_401_when_unauthenticated(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 401 without session.""" resp = await AsyncClient( transport=ASGITransport(app=bans_client._transport.app), # type: ignore[attr-defined] base_url="http://test", ).post("/api/v1/bans", json={"ip": "1.2.3.4", "jail": "sshd"}) assert resp.status_code == 401 # --------------------------------------------------------------------------- # DELETE /api/bans # --------------------------------------------------------------------------- class TestUnbanIp: """Tests for ``DELETE /api/bans``.""" async def test_200_unban_from_all(self, bans_client: AsyncClient) -> None: """DELETE /api/bans with unban_all=true unbans from all jails.""" with patch( "app.routers.bans.ban_service.unban_ip", AsyncMock(return_value=None), ): resp = await bans_client.request( "DELETE", "/api/v1/bans", json={"ip": "1.2.3.4", "unban_all": True}, headers={"X-BanGUI-Request": "1"}, ) assert resp.status_code == 200 assert "all jails" in resp.json()["message"] async def test_200_unban_from_specific_jail(self, bans_client: AsyncClient) -> None: """DELETE /api/bans with a jail unbans from that jail only.""" with patch( "app.routers.bans.ban_service.unban_ip", AsyncMock(return_value=None), ): resp = await bans_client.request( "DELETE", "/api/v1/bans", json={"ip": "1.2.3.4", "jail": "sshd"}, headers={"X-BanGUI-Request": "1"}, ) assert resp.status_code == 200 assert "sshd" in resp.json()["message"] async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None: """DELETE /api/bans returns 400 for an invalid IP.""" with patch( "app.routers.bans.ban_service.unban_ip", AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")), ): resp = await bans_client.request( "DELETE", "/api/v1/bans", json={"ip": "bad", "unban_all": True}, headers={"X-BanGUI-Request": "1"}, ) assert resp.status_code == 400 async def test_404_for_unknown_jail(self, bans_client: AsyncClient) -> None: """DELETE /api/bans returns 404 when jail does not exist.""" from app.services.jail_service import JailNotFoundError with patch( "app.routers.bans.ban_service.unban_ip", AsyncMock(side_effect=JailNotFoundError("ghost")), ): resp = await bans_client.request( "DELETE", "/api/v1/bans", json={"ip": "1.2.3.4", "jail": "ghost"}, headers={"X-BanGUI-Request": "1"}, ) assert resp.status_code == 404 # --------------------------------------------------------------------------- # DELETE /api/bans/all # --------------------------------------------------------------------------- class TestUnbanAll: """Tests for ``DELETE /api/bans/all``.""" async def test_200_clears_all_bans(self, bans_client: AsyncClient) -> None: """DELETE /api/bans/all returns 200 with count when successful.""" with patch( "app.routers.bans.jail_service.unban_all_ips", AsyncMock(return_value=3), ): resp = await bans_client.request("DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"}) assert resp.status_code == 200 data = resp.json() assert data["count"] == 3 assert "3" in data["message"] async def test_200_with_zero_count(self, bans_client: AsyncClient) -> None: """DELETE /api/bans/all returns 200 with count=0 when no bans existed.""" with patch( "app.routers.bans.jail_service.unban_all_ips", AsyncMock(return_value=0), ): resp = await bans_client.request("DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"}) assert resp.status_code == 200 assert resp.json()["count"] == 0 async def test_502_when_fail2ban_unreachable(self, bans_client: AsyncClient) -> None: """DELETE /api/bans/all returns 502 when fail2ban is unreachable.""" with patch( "app.routers.bans.jail_service.unban_all_ips", AsyncMock( side_effect=Fail2BanConnectionError( "cannot connect", "/var/run/fail2ban/fail2ban.sock", ) ), ): resp = await bans_client.request("DELETE", "/api/v1/bans/all", headers={"X-BanGUI-Request": "1"}) assert resp.status_code == 502 async def test_401_when_unauthenticated(self, bans_client: AsyncClient) -> None: """DELETE /api/bans/all returns 401 without session.""" resp = await AsyncClient( transport=ASGITransport(app=bans_client._transport.app), # type: ignore[attr-defined] base_url="http://test", ).request("DELETE", "/api/v1/bans/all") assert resp.status_code == 401