"""Tests for the bans router endpoints.""" from __future__ import annotations from pathlib import Path from unittest.mock import AsyncMock, MagicMock, patch import aiosqlite import pytest from httpx import ASGITransport, AsyncClient from app.config import Settings from app.db import init_db from app.main import create_app from app.models.ban import ActiveBan, ActiveBanListResponse from app.utils.fail2ban_client import Fail2BanConnectionError # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- _SETUP_PAYLOAD = { "master_password": "testpassword1", "database_path": "bangui.db", "fail2ban_socket": "/var/run/fail2ban/fail2ban.sock", "timezone": "UTC", "session_duration_minutes": 60, } @pytest.fixture async def bans_client(tmp_path: Path) -> AsyncClient: # type: ignore[misc] """Provide an authenticated ``AsyncClient`` for bans endpoint tests.""" settings = Settings( database_path=str(tmp_path / "bans_test.db"), fail2ban_socket="/tmp/fake.sock", session_secret="test-bans-secret", session_duration_minutes=60, timezone="UTC", log_level="debug", ) app = create_app(settings=settings) db: aiosqlite.Connection = await aiosqlite.connect(settings.database_path) db.row_factory = aiosqlite.Row await init_db(db) app.state.db = db app.state.http_session = MagicMock() transport = ASGITransport(app=app) async with AsyncClient(transport=transport, base_url="http://test") as ac: await ac.post("/api/setup", json=_SETUP_PAYLOAD) login = await ac.post( "/api/auth/login", json={"password": _SETUP_PAYLOAD["master_password"]}, ) assert login.status_code == 200 yield ac await db.close() # --------------------------------------------------------------------------- # GET /api/bans/active # --------------------------------------------------------------------------- class TestGetActiveBans: """Tests for ``GET /api/bans/active``.""" async def test_200_when_authenticated(self, bans_client: AsyncClient) -> None: """GET /api/bans/active returns 200 with an ActiveBanListResponse.""" mock_response = ActiveBanListResponse( bans=[ ActiveBan( ip="1.2.3.4", jail="sshd", banned_at="2025-01-01T12:00:00+00:00", expires_at="2025-01-01T13:00:00+00:00", ban_count=1, country="DE", ) ], total=1, ) with patch( "app.routers.bans.jail_service.get_active_bans", AsyncMock(return_value=mock_response), ): resp = await bans_client.get("/api/bans/active") assert resp.status_code == 200 data = resp.json() assert data["total"] == 1 assert data["bans"][0]["ip"] == "1.2.3.4" assert data["bans"][0]["jail"] == "sshd" async def test_401_when_unauthenticated(self, bans_client: AsyncClient) -> None: """GET /api/bans/active returns 401 without session.""" resp = await AsyncClient( transport=ASGITransport(app=bans_client._transport.app), # type: ignore[attr-defined] base_url="http://test", ).get("/api/bans/active") assert resp.status_code == 401 async def test_empty_when_no_bans(self, bans_client: AsyncClient) -> None: """GET /api/bans/active returns empty list when no bans are active.""" mock_response = ActiveBanListResponse(bans=[], total=0) with patch( "app.routers.bans.jail_service.get_active_bans", AsyncMock(return_value=mock_response), ): resp = await bans_client.get("/api/bans/active") assert resp.status_code == 200 assert resp.json()["total"] == 0 assert resp.json()["bans"] == [] async def test_response_shape(self, bans_client: AsyncClient) -> None: """GET /api/bans/active returns expected fields per ban entry.""" mock_response = ActiveBanListResponse( bans=[ ActiveBan( ip="10.0.0.1", jail="nginx", banned_at=None, expires_at=None, ban_count=1, country=None, ) ], total=1, ) with patch( "app.routers.bans.jail_service.get_active_bans", AsyncMock(return_value=mock_response), ): resp = await bans_client.get("/api/bans/active") ban = resp.json()["bans"][0] assert "ip" in ban assert "jail" in ban assert "banned_at" in ban assert "expires_at" in ban assert "ban_count" in ban # --------------------------------------------------------------------------- # POST /api/bans # --------------------------------------------------------------------------- class TestBanIp: """Tests for ``POST /api/bans``.""" async def test_201_on_success(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 201 when the IP is banned.""" with patch( "app.routers.bans.jail_service.ban_ip", AsyncMock(return_value=None), ): resp = await bans_client.post( "/api/bans", json={"ip": "1.2.3.4", "jail": "sshd"}, ) assert resp.status_code == 201 assert resp.json()["jail"] == "sshd" async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 400 for an invalid IP address.""" with patch( "app.routers.bans.jail_service.ban_ip", AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")), ): resp = await bans_client.post( "/api/bans", json={"ip": "bad", "jail": "sshd"}, ) assert resp.status_code == 400 async def test_404_for_unknown_jail(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 404 when jail does not exist.""" from app.services.jail_service import JailNotFoundError with patch( "app.routers.bans.jail_service.ban_ip", AsyncMock(side_effect=JailNotFoundError("ghost")), ): resp = await bans_client.post( "/api/bans", json={"ip": "1.2.3.4", "jail": "ghost"}, ) assert resp.status_code == 404 async def test_401_when_unauthenticated(self, bans_client: AsyncClient) -> None: """POST /api/bans returns 401 without session.""" resp = await AsyncClient( transport=ASGITransport(app=bans_client._transport.app), # type: ignore[attr-defined] base_url="http://test", ).post("/api/bans", json={"ip": "1.2.3.4", "jail": "sshd"}) assert resp.status_code == 401 # --------------------------------------------------------------------------- # DELETE /api/bans # --------------------------------------------------------------------------- class TestUnbanIp: """Tests for ``DELETE /api/bans``.""" async def test_200_unban_from_all(self, bans_client: AsyncClient) -> None: """DELETE /api/bans with unban_all=true unbans from all jails.""" with patch( "app.routers.bans.jail_service.unban_ip", AsyncMock(return_value=None), ): resp = await bans_client.request( "DELETE", "/api/bans", json={"ip": "1.2.3.4", "unban_all": True}, ) assert resp.status_code == 200 assert "all jails" in resp.json()["message"] async def test_200_unban_from_specific_jail(self, bans_client: AsyncClient) -> None: """DELETE /api/bans with a jail unbans from that jail only.""" with patch( "app.routers.bans.jail_service.unban_ip", AsyncMock(return_value=None), ): resp = await bans_client.request( "DELETE", "/api/bans", json={"ip": "1.2.3.4", "jail": "sshd"}, ) assert resp.status_code == 200 assert "sshd" in resp.json()["message"] async def test_400_for_invalid_ip(self, bans_client: AsyncClient) -> None: """DELETE /api/bans returns 400 for an invalid IP.""" with patch( "app.routers.bans.jail_service.unban_ip", AsyncMock(side_effect=ValueError("Invalid IP address: 'bad'")), ): resp = await bans_client.request( "DELETE", "/api/bans", json={"ip": "bad", "unban_all": True}, ) assert resp.status_code == 400 async def test_404_for_unknown_jail(self, bans_client: AsyncClient) -> None: """DELETE /api/bans returns 404 when jail does not exist.""" from app.services.jail_service import JailNotFoundError with patch( "app.routers.bans.jail_service.unban_ip", AsyncMock(side_effect=JailNotFoundError("ghost")), ): resp = await bans_client.request( "DELETE", "/api/bans", json={"ip": "1.2.3.4", "jail": "ghost"}, ) assert resp.status_code == 404 # --------------------------------------------------------------------------- # DELETE /api/bans/all # --------------------------------------------------------------------------- class TestUnbanAll: """Tests for ``DELETE /api/bans/all``.""" async def test_200_clears_all_bans(self, bans_client: AsyncClient) -> None: """DELETE /api/bans/all returns 200 with count when successful.""" with patch( "app.routers.bans.jail_service.unban_all_ips", AsyncMock(return_value=3), ): resp = await bans_client.request("DELETE", "/api/bans/all") assert resp.status_code == 200 data = resp.json() assert data["count"] == 3 assert "3" in data["message"] async def test_200_with_zero_count(self, bans_client: AsyncClient) -> None: """DELETE /api/bans/all returns 200 with count=0 when no bans existed.""" with patch( "app.routers.bans.jail_service.unban_all_ips", AsyncMock(return_value=0), ): resp = await bans_client.request("DELETE", "/api/bans/all") assert resp.status_code == 200 assert resp.json()["count"] == 0 async def test_502_when_fail2ban_unreachable( self, bans_client: AsyncClient ) -> None: """DELETE /api/bans/all returns 502 when fail2ban is unreachable.""" with patch( "app.routers.bans.jail_service.unban_all_ips", AsyncMock( side_effect=Fail2BanConnectionError( "cannot connect", "/var/run/fail2ban/fail2ban.sock", ) ), ): resp = await bans_client.request("DELETE", "/api/bans/all") assert resp.status_code == 502 async def test_401_when_unauthenticated(self, bans_client: AsyncClient) -> None: """DELETE /api/bans/all returns 401 without session.""" resp = await AsyncClient( transport=ASGITransport(app=bans_client._transport.app), # type: ignore[attr-defined] base_url="http://test", ).request("DELETE", "/api/bans/all") assert resp.status_code == 401