"""Tests for fail2ban_db_utils module.""" import sqlite3 from pathlib import Path import pytest from app.utils.fail2ban_db_utils import ( ensure_fail2ban_indexes, escape_like, ) @pytest.fixture def tmp_bans_table(tmp_path: Path) -> str: """Create a minimal fail2ban-style database with bans table.""" db_path = str(tmp_path / "test_f2b.db") conn = sqlite3.connect(db_path) conn.execute("CREATE TABLE bans (jail, ip, timeofban, bancount, data)") conn.execute("CREATE INDEX idx_jail_timeofban_ip ON bans(jail, timeofban)") conn.execute("CREATE INDEX idx_jail_ip ON bans(jail, ip)") conn.execute("CREATE INDEX idx_ip ON bans(ip)") conn.commit() conn.close() return db_path @pytest.mark.asyncio async def test_ensure_fail2ban_indexes_creates_missing_index(tmp_bans_table: str) -> None: """Index is created when idx_bans_timeofban_desc does not exist.""" await ensure_fail2ban_indexes(tmp_bans_table) conn = sqlite3.connect(tmp_bans_table) conn.row_factory = sqlite3.Row cur = conn.execute( "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='bans'" ) index_names = [str(r["name"]) for r in cur.fetchall()] conn.close() assert "idx_bans_timeofban_desc" in index_names @pytest.mark.asyncio async def test_ensure_fail2ban_indexes_idempotent(tmp_bans_table: str) -> None: """Calling twice does not raise or duplicate the index.""" await ensure_fail2ban_indexes(tmp_bans_table) await ensure_fail2ban_indexes(tmp_bans_table) conn = sqlite3.connect(tmp_bans_table) cur = conn.execute( "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='bans' AND name='idx_bans_timeofban_desc'" ) count = len(cur.fetchall()) conn.close() assert count == 1 def test_escape_like_percent_sign() -> None: """Test escaping of percent signs (% wildcard).""" assert escape_like("10.0.0%") == "10.0.0\\%" def test_escape_like_underscore() -> None: """Test escaping of underscores (_ wildcard).""" assert escape_like("10.0.0_1") == "10.0.0\\_1" def test_escape_like_backslash() -> None: """Test escaping of backslashes.""" assert escape_like("10.0.0\\") == "10.0.0\\\\" def test_escape_like_combined_wildcards() -> None: """Test escaping when both % and _ are present.""" assert escape_like("10.0_%") == "10.0\\_\\%" def test_escape_like_combined_with_backslash() -> None: """Test escaping backslash first, then wildcards.""" assert escape_like("10\\0_%") == "10\\\\0\\_\\%" def test_escape_like_normal_ip() -> None: """Test that normal IPs pass through unchanged (dots are not wildcards).""" assert escape_like("10.0.0.1") == "10.0.0.1" def test_escape_like_empty_string() -> None: """Test escaping empty string.""" assert escape_like("") == "" def test_escape_like_only_backslash() -> None: """Test string with only backslashes.""" assert escape_like("\\\\") == "\\\\\\\\" def test_escape_like_only_percent() -> None: """Test string with only percent signs.""" assert escape_like("%%%") == "\\%\\%\\%" def test_escape_like_only_underscore() -> None: """Test string with only underscores.""" assert escape_like("___") == "\\_\\_\\_" def test_escape_like_backslash_before_wildcard() -> None: """Test that backslash before wildcard is properly escaped.""" result = escape_like("10\\_%") # Expected: backslash → \\ , underscore → \_ , percent → \% assert result == "10\\\\\\_\\%"