"""Tests for ban_service.list_bans().""" from __future__ import annotations import json import time from pathlib import Path from typing import Any from unittest.mock import AsyncMock, patch import aiosqlite import pytest from app.services import ban_service # --------------------------------------------------------------------------- # Fixtures # --------------------------------------------------------------------------- _NOW: int = int(time.time()) _ONE_HOUR_AGO: int = _NOW - 3600 _TWO_DAYS_AGO: int = _NOW - 2 * 24 * 3600 async def _create_f2b_db(path: str, rows: list[dict[str, Any]]) -> None: """Create a minimal fail2ban SQLite database with the given ban rows. Args: path: Filesystem path for the new SQLite file. rows: Sequence of dicts with keys ``jail``, ``ip``, ``timeofban``, ``bantime``, ``bancount``, and optionally ``data``. """ async with aiosqlite.connect(path) as db: await db.execute( "CREATE TABLE jails (" "name TEXT NOT NULL UNIQUE, " "enabled INTEGER NOT NULL DEFAULT 1" ")" ) await db.execute( "CREATE TABLE bans (" "jail TEXT NOT NULL, " "ip TEXT, " "timeofban INTEGER NOT NULL, " "bantime INTEGER NOT NULL, " "bancount INTEGER NOT NULL DEFAULT 1, " "data JSON" ")" ) for row in rows: await db.execute( "INSERT INTO bans (jail, ip, timeofban, bantime, bancount, data) " "VALUES (?, ?, ?, ?, ?, ?)", ( row["jail"], row["ip"], row["timeofban"], row.get("bantime", 3600), row.get("bancount", 1), json.dumps(row["data"]) if "data" in row else None, ), ) await db.commit() @pytest.fixture async def f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc] """Return the path to a test fail2ban SQLite database with several bans.""" path = str(tmp_path / "fail2ban_test.sqlite3") await _create_f2b_db( path, [ { "jail": "sshd", "ip": "1.2.3.4", "timeofban": _ONE_HOUR_AGO, "bantime": 3600, "bancount": 2, "data": { "matches": ["Nov 10 10:00 sshd[123]: Failed password for root"], "failures": 5, }, }, { "jail": "nginx", "ip": "5.6.7.8", "timeofban": _ONE_HOUR_AGO, "bantime": 
7200, "bancount": 1, "data": {"matches": ["GET /admin HTTP/1.1"], "failures": 3}, }, { "jail": "sshd", "ip": "9.10.11.12", "timeofban": _TWO_DAYS_AGO, "bantime": 3600, "bancount": 1, "data": {"failures": 6}, # no matches }, ], ) return path @pytest.fixture async def mixed_origin_db_path(tmp_path: Path) -> str: # type: ignore[misc] """Return a database with bans from both blocklist-import and organic jails.""" path = str(tmp_path / "fail2ban_mixed_origin.sqlite3") await _create_f2b_db( path, [ { "jail": "blocklist-import", "ip": "10.0.0.1", "timeofban": _ONE_HOUR_AGO, "bantime": -1, "bancount": 1, }, { "jail": "sshd", "ip": "10.0.0.2", "timeofban": _ONE_HOUR_AGO, "bantime": 3600, "bancount": 3, }, { "jail": "nginx", "ip": "10.0.0.3", "timeofban": _ONE_HOUR_AGO, "bantime": 7200, "bancount": 1, }, ], ) return path @pytest.fixture async def empty_f2b_db_path(tmp_path: Path) -> str: # type: ignore[misc] """Return the path to a fail2ban SQLite database with no ban records.""" path = str(tmp_path / "fail2ban_empty.sqlite3") await _create_f2b_db(path, []) return path # --------------------------------------------------------------------------- # list_bans — happy path # --------------------------------------------------------------------------- class TestListBansHappyPath: """Verify ban_service.list_bans() under normal conditions.""" async def test_returns_bans_in_range(self, f2b_db_path: str) -> None: """Only bans within the selected range are returned.""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): result = await ban_service.list_bans("/fake/sock", "24h") # Two bans within last 24 h; one is 2 days old and excluded. 
assert result.total == 2 assert len(result.items) == 2 async def test_results_sorted_newest_first(self, f2b_db_path: str) -> None: """Items are ordered by ``banned_at`` descending (newest first).""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): result = await ban_service.list_bans("/fake/sock", "24h") timestamps = [item.banned_at for item in result.items] assert timestamps == sorted(timestamps, reverse=True) async def test_ban_fields_present(self, f2b_db_path: str) -> None: """Each item contains ip, jail, banned_at, ban_count.""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): result = await ban_service.list_bans("/fake/sock", "24h") for item in result.items: assert item.ip assert item.jail assert item.banned_at assert item.ban_count >= 1 async def test_service_extracted_from_first_match(self, f2b_db_path: str) -> None: """``service`` field is the first element of ``data.matches``.""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): result = await ban_service.list_bans("/fake/sock", "24h") sshd_item = next(i for i in result.items if i.jail == "sshd") assert sshd_item.service is not None assert "Failed password" in sshd_item.service async def test_service_is_none_when_no_matches(self, f2b_db_path: str) -> None: """``service`` is ``None`` when the ban has no stored matches.""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): # Use 7d to include the older ban with no matches. 
result = await ban_service.list_bans("/fake/sock", "7d") no_match = next(i for i in result.items if i.ip == "9.10.11.12") assert no_match.service is None async def test_empty_db_returns_zero(self, empty_f2b_db_path: str) -> None: """When no bans exist the result has total=0 and no items.""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=empty_f2b_db_path), ): result = await ban_service.list_bans("/fake/sock", "24h") assert result.total == 0 assert result.items == [] async def test_365d_range_includes_old_bans(self, f2b_db_path: str) -> None: """The ``365d`` range includes bans that are 2 days old.""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): result = await ban_service.list_bans("/fake/sock", "365d") assert result.total == 3 # --------------------------------------------------------------------------- # list_bans — geo enrichment # --------------------------------------------------------------------------- class TestListBansGeoEnrichment: """Verify geo enrichment integration in ban_service.list_bans().""" async def test_geo_data_applied_when_enricher_provided( self, f2b_db_path: str ) -> None: """Geo fields are populated when an enricher returns data.""" from app.services.geo_service import GeoInfo async def fake_enricher(ip: str) -> GeoInfo: return GeoInfo( country_code="DE", country_name="Germany", asn="AS3320", org="Deutsche Telekom", ) with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): result = await ban_service.list_bans( "/fake/sock", "24h", geo_enricher=fake_enricher ) for item in result.items: assert item.country_code == "DE" assert item.country_name == "Germany" assert item.asn == "AS3320" async def test_geo_failure_does_not_break_results( self, f2b_db_path: str ) -> None: """A geo enricher that raises still returns ban items (geo fields null).""" async def failing_enricher(ip: str) -> None: raise 
RuntimeError("geo service down") with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): result = await ban_service.list_bans( "/fake/sock", "24h", geo_enricher=failing_enricher ) assert result.total == 2 for item in result.items: assert item.country_code is None # --------------------------------------------------------------------------- # list_bans — pagination # --------------------------------------------------------------------------- class TestListBansPagination: """Verify pagination parameters in list_bans().""" async def test_page_size_respected(self, f2b_db_path: str) -> None: """``page_size=1`` returns at most one item.""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): result = await ban_service.list_bans("/fake/sock", "7d", page_size=1) assert len(result.items) == 1 assert result.page_size == 1 async def test_page_2_returns_remaining_items(self, f2b_db_path: str) -> None: """The second page returns items not on the first page.""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): page1 = await ban_service.list_bans("/fake/sock", "7d", page=1, page_size=1) page2 = await ban_service.list_bans("/fake/sock", "7d", page=2, page_size=1) # Different IPs should appear on different pages. assert page1.items[0].ip != page2.items[0].ip async def test_total_reflects_full_count_not_page_count( self, f2b_db_path: str ) -> None: """``total`` reports all matching records regardless of pagination.""" with patch( "app.services.ban_service._get_fail2ban_db_path", new=AsyncMock(return_value=f2b_db_path), ): result = await ban_service.list_bans("/fake/sock", "7d", page_size=1) assert result.total == 3 # All three bans are within 7d. 
# ---------------------------------------------------------------------------
# list_bans / bans_by_country — origin derivation
# ---------------------------------------------------------------------------


def _patch_db_path(db_path: str) -> Any:
    """Build a patcher that points ban_service at the SQLite file *db_path*.

    The returned object is an unstarted ``unittest.mock.patch`` context
    manager; use it in a ``with`` statement around the service call.
    """
    return patch(
        "app.services.ban_service._get_fail2ban_db_path",
        new=AsyncMock(return_value=db_path),
    )


class TestBanOriginDerivation:
    """Check that ban_service maps jail names onto the ``origin`` field."""

    async def test_blocklist_import_jail_yields_blocklist_origin(
        self, mixed_origin_db_path: str
    ) -> None:
        """A ban stored under ``blocklist-import`` reports ``origin == "blocklist"``."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.list_bans("/fake/sock", "24h")

        imported = [ban for ban in outcome.items if ban.jail == "blocklist-import"]
        assert len(imported) == 1
        assert imported[0].origin == "blocklist"

    async def test_organic_jail_yields_selfblock_origin(
        self, mixed_origin_db_path: str
    ) -> None:
        """Bans from any other jail (sshd, nginx, …) report ``origin == "selfblock"``."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.list_bans("/fake/sock", "24h")

        organic = [ban for ban in outcome.items if ban.jail != "blocklist-import"]
        assert len(organic) == 2
        for ban in organic:
            assert ban.origin == "selfblock"

    async def test_all_items_carry_origin_field(
        self, mixed_origin_db_path: str
    ) -> None:
        """No returned item carries an origin outside the two known values."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.list_bans("/fake/sock", "24h")

        allowed = ("blocklist", "selfblock")
        for ban in outcome.items:
            assert ban.origin in allowed

    async def test_bans_by_country_blocklist_origin(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country`` tags blocklist-import bans as ``"blocklist"`` too."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.bans_by_country("/fake/sock", "24h")

        imported = [ban for ban in outcome.bans if ban.jail == "blocklist-import"]
        assert len(imported) == 1
        assert imported[0].origin == "blocklist"

    async def test_bans_by_country_selfblock_origin(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country`` tags bans from organic jails as ``"selfblock"``."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.bans_by_country("/fake/sock", "24h")

        organic = [ban for ban in outcome.bans if ban.jail != "blocklist-import"]
        assert len(organic) == 2
        for ban in organic:
            assert ban.origin == "selfblock"


# ---------------------------------------------------------------------------
# list_bans / bans_by_country — origin filter parameter
# ---------------------------------------------------------------------------


class TestOriginFilter:
    """Check that the ``origin`` argument narrows which bans come back."""

    async def test_list_bans_blocklist_filter_returns_only_blocklist(
        self, mixed_origin_db_path: str
    ) -> None:
        """Filtering on ``'blocklist'`` keeps just the blocklist-import ban."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.list_bans(
                "/fake/sock", "24h", origin="blocklist"
            )

        assert outcome.total == 1
        assert len(outcome.items) == 1
        sole = outcome.items[0]
        assert sole.jail == "blocklist-import"
        assert sole.origin == "blocklist"

    async def test_list_bans_selfblock_filter_excludes_blocklist(
        self, mixed_origin_db_path: str
    ) -> None:
        """Filtering on ``'selfblock'`` drops the blocklist-import jail."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.list_bans(
                "/fake/sock", "24h", origin="selfblock"
            )

        assert outcome.total == 2
        assert len(outcome.items) == 2
        for ban in outcome.items:
            assert ban.jail != "blocklist-import"
            assert ban.origin == "selfblock"

    async def test_list_bans_no_filter_returns_all(
        self, mixed_origin_db_path: str
    ) -> None:
        """Passing ``origin=None`` leaves every ban in the result."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.list_bans("/fake/sock", "24h", origin=None)

        assert outcome.total == 3

    async def test_bans_by_country_blocklist_filter(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country(origin='blocklist')`` counts blocklist bans only."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.bans_by_country(
                "/fake/sock", "24h", origin="blocklist"
            )

        assert outcome.total == 1
        assert all(ban.jail == "blocklist-import" for ban in outcome.bans)

    async def test_bans_by_country_selfblock_filter(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country(origin='selfblock')`` leaves out blocklist jails."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.bans_by_country(
                "/fake/sock", "24h", origin="selfblock"
            )

        assert outcome.total == 2
        assert all(ban.jail != "blocklist-import" for ban in outcome.bans)

    async def test_bans_by_country_no_filter_returns_all(
        self, mixed_origin_db_path: str
    ) -> None:
        """``bans_by_country(origin=None)`` returns every ban."""
        with _patch_db_path(mixed_origin_db_path):
            outcome = await ban_service.bans_by_country(
                "/fake/sock", "24h", origin=None
            )

        assert outcome.total == 3