"""Tests for ban_service.list_bans()."""

from __future__ import annotations

import json
import time
from pathlib import Path
from typing import Any
from unittest.mock import AsyncMock, patch

import aiosqlite
import pytest

from app.services import ban_service

# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------

# Reference timestamps, computed once at import time, used to place bans
# inside or outside the time ranges requested by the tests.
_NOW: int = int(time.time())
_ONE_HOUR_AGO: int = _NOW - 3600
_TWO_DAYS_AGO: int = _NOW - 2 * 24 * 3600

# Dotted path of the helper that tells ban_service where the fail2ban
# SQLite database lives; every test patches it to point at a temp file.
_DB_PATH_TARGET = "app.services.ban_service._get_fail2ban_db_path"


def _patch_db_path(path: str) -> Any:
    """Return a ``patch`` context manager redirecting ban_service to *path*.

    ``_get_fail2ban_db_path`` is replaced with an ``AsyncMock`` whose awaited
    result is *path*, so ``list_bans`` reads the test database built by the
    fixtures below.

    Args:
        path: Filesystem path of the SQLite database to expose.

    Returns:
        The (not yet entered) ``unittest.mock.patch`` context manager.
    """
    return patch(_DB_PATH_TARGET, new=AsyncMock(return_value=path))


async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None:
    """Create a minimal fail2ban SQLite database with the given ban rows.

    Args:
        path: Filesystem path for the new SQLite file.
        rows: Sequence of dicts with keys ``jail``, ``ip``, ``timeofban``,
            ``bantime``, ``bancount``, and optionally ``data``.
    """
    async with aiosqlite.connect(path) as db:
        await db.execute(
            "CREATE TABLE jails ("
            "name TEXT NOT NULL UNIQUE, "
            "enabled INTEGER NOT NULL DEFAULT 1"
            ")"
        )
        await db.execute(
            "CREATE TABLE bans ("
            "jail TEXT NOT NULL, "
            "ip TEXT, "
            "timeofban INTEGER NOT NULL, "
            "bantime INTEGER NOT NULL, "
            "bancount INTEGER NOT NULL DEFAULT 1, "
            "data JSON"
            ")"
        )
        for row in rows:
            await db.execute(
                "INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) "
                "VALUES (?, ?, ?, ?, ?, ?)",
                (
                    row["jail"],
                    row["ip"],
                    row["timeofban"],
                    row.get("bantime", 3600),
                    row.get("bancount", 1),
                    # ``data`` is stored as a JSON string, or NULL when absent.
                    json.dumps(row["data"]) if "data" in row else None,
                ),
            )
        await db.commit()


@pytest.fixture
async def f2b_db_path(tmp_path: Path) -> str:  # type: ignore[misc]
    """Return the path to a test fail2ban SQLite database with several bans."""
    path = str(tmp_path / "fail2ban_test.sqlite3")
    await _create_f2b_db(
        path,
        [
            {
                "jail": "sshd",
                "ip": "1.2.3.4",
                "timeofban": _ONE_HOUR_AGO,
                "bantime": 3600,
                "bancount": 2,
                "data": {
                    "matches": ["Nov 10 10:00 sshd[123]: Failed password for root"],
                    "failures": 5,
                },
            },
            {
                "jail": "nginx",
                "ip": "5.6.7.8",
                "timeofban": _ONE_HOUR_AGO,
                "bantime": 7200,
                "bancount": 1,
                "data": {"matches": ["GET /admin HTTP/1.1"], "failures": 3},
            },
            {
                "jail": "sshd",
                "ip": "9.10.11.12",
                "timeofban": _TWO_DAYS_AGO,
                "bantime": 3600,
                "bancount": 1,
                "data": {"failures": 6},  # no matches
            },
        ],
    )
    return path


@pytest.fixture
async def empty_f2b_db_path(tmp_path: Path) -> str:  # type: ignore[misc]
    """Return the path to a fail2ban SQLite database with no ban records."""
    path = str(tmp_path / "fail2ban_empty.sqlite3")
    await _create_f2b_db(path, [])
    return path


# ---------------------------------------------------------------------------
# list_bans — happy path
# ---------------------------------------------------------------------------


class TestListBansHappyPath:
    """Verify ban_service.list_bans() under normal conditions."""

    async def test_returns_bans_in_range(self, f2b_db_path: str) -> None:
        """Only bans within the selected range are returned."""
        with _patch_db_path(f2b_db_path):
            result = await ban_service.list_bans("/fake/sock", "24h")

        # Two bans within last 24 h; one is 2 days old and excluded.
        assert result.total == 2
        assert len(result.items) == 2

    async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None:
        """Items are ordered by ``banned_at`` descending (newest first)."""
        # Use 7d so the result mixes distinct ban times: in the 24h window both
        # bans share the same timestamp, and any ordering would pass.
        with _patch_db_path(f2b_db_path):
            result = await ban_service.list_bans("/fake/sock", "7d")

        timestamps = [item.banned_at for item in result.items]
        assert timestamps == sorted(timestamps, reverse=True)
        # The two-day-old ban is the oldest, so it must come last.
        assert result.items[-1].ip == "9.10.11.12"

    async def test_ban_fields_present(self, f2b_db_path: str) -> None:
        """Each item contains ip, jail, banned_at, ban_count."""
        with _patch_db_path(f2b_db_path):
            result = await ban_service.list_bans("/fake/sock", "24h")

        # Guard against the loop below passing vacuously on an empty result.
        assert result.items
        for item in result.items:
            assert item.ip
            assert item.jail
            assert item.banned_at
            assert item.ban_count >= 1

    async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None:
        """``service`` field is the first element of ``data.matches``."""
        with _patch_db_path(f2b_db_path):
            result = await ban_service.list_bans("/fake/sock", "24h")

        sshd_item = next(i for i in result.items if i.jail == "sshd")
        assert sshd_item.service is not None
        assert "Failed password" in sshd_item.service

    async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None:
        """``service`` is ``None`` when the ban has no stored matches."""
        with _patch_db_path(f2b_db_path):
            # Use 7d to include the older ban with no matches.
            result = await ban_service.list_bans("/fake/sock", "7d")

        no_match = next(i for i in result.items if i.ip == "9.10.11.12")
        assert no_match.service is None

    async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None:
        """When no bans exist the result has total=0 and no items."""
        with _patch_db_path(empty_f2b_db_path):
            result = await ban_service.list_bans("/fake/sock", "24h")

        assert result.total == 0
        assert result.items == []

    async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None:
        """The ``365d`` range includes bans that are 2 days old."""
        with _patch_db_path(f2b_db_path):
            result = await ban_service.list_bans("/fake/sock", "365d")

        assert result.total == 3


# ---------------------------------------------------------------------------
# list_bans — geo enrichment
# ---------------------------------------------------------------------------


class TestListBansGeoEnrichment:
    """Verify geo enrichment integration in ban_service.list_bans()."""

    async def test_geo_data_applied_when_enricher_provided(
        self, f2b_db_path: str
    ) -> None:
        """Geo fields are populated when an enricher returns data."""
        from app.services.geo_service import GeoInfo

        async def fake_enricher(ip: str) -> GeoInfo:
            return GeoInfo(
                country_code="DE",
                country_name="Germany",
                asn="AS3320",
                org="Deutsche Telekom",
            )

        with _patch_db_path(f2b_db_path):
            result = await ban_service.list_bans(
                "/fake/sock", "24h", geo_enricher=fake_enricher
            )

        # Guard against the loop below passing vacuously on an empty result.
        assert result.items
        for item in result.items:
            assert item.country_code == "DE"
            assert item.country_name == "Germany"
            assert item.asn == "AS3320"

    async def test_geo_failure_does_not_break_results(
        self, f2b_db_path: str
    ) -> None:
        """A geo enricher that raises still returns ban items (geo fields null)."""

        async def failing_enricher(ip: str) -> None:
            raise RuntimeError("geo service down")

        with _patch_db_path(f2b_db_path):
            result = await ban_service.list_bans(
                "/fake/sock", "24h", geo_enricher=failing_enricher
            )

        assert result.total == 2
        for item in result.items:
            assert item.country_code is None


# ---------------------------------------------------------------------------
# list_bans — pagination
# ---------------------------------------------------------------------------


class TestListBansPagination:
    """Verify pagination parameters in list_bans()."""

    async def test_page_size_respected(self, f2b_db_path: str) -> None:
        """``page_size=1`` returns at most one item."""
        with _patch_db_path(f2b_db_path):
            result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)

        assert len(result.items) == 1
        assert result.page_size == 1

    async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None:
        """The second page returns items not on the first page."""
        with _patch_db_path(f2b_db_path):
            page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1)
            page2 = await ban_service.list_bans("/fake/sock", "7d", page=2, page_size=1)

        # Fail with a clear assertion (not an IndexError) if a page is empty.
        assert len(page1.items) == 1
        assert len(page2.items) == 1
        # Different IPs should appear on different pages.
        assert page1.items[0].ip != page2.items[0].ip

    async def test_total_reflects_full_count_not_page_count(
        self, f2b_db_path: str
    ) -> None:
        """``total`` reports all matching records regardless of pagination."""
        with _patch_db_path(f2b_db_path):
            result = await ban_service.list_bans("/fake/sock", "7d", page_size=1)

        assert result.total == 3  # All three bans are within 7d.