- Add get_fail2ban_db_path() to setup_service; it resolves the fail2ban DB path from the configured socket path.
- Add ensure_fail2ban_indexes(), which creates missing performance indexes on the bans table.
- Call ensure_fail2ban_indexes() on every startup, before the first ban query.
- Remove completed tasks from Docs/Tasks.md.
- Update Docs/PERFORMANCE.md with the index findings.
115 lines · 3.5 KiB · Python
"""Tests for fail2ban_db_utils module."""
|
|
|
|
import sqlite3
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from app.utils.fail2ban_db_utils import (
|
|
ensure_fail2ban_indexes,
|
|
escape_like,
|
|
)
|
|
|
|
|
|
@pytest.fixture
def tmp_bans_table(tmp_path: Path) -> str:
    """Create a minimal fail2ban-style SQLite database containing a ``bans`` table.

    Only the columns and indexes the index-maintenance tests need are created.
    In particular ``idx_bans_timeofban_desc`` is *not* created here, so tests
    can observe it being added.

    Args:
        tmp_path: pytest's per-test temporary directory.

    Returns:
        Filesystem path (as ``str``) of the newly created database file.
    """
    db_path = str(tmp_path / "test_f2b.db")
    conn = sqlite3.connect(db_path)
    try:
        conn.execute("CREATE TABLE bans (jail, ip, timeofban, bancount, data)")
        # Pre-existing indexes, presumably mirroring fail2ban's own bans
        # indexes (NOTE(review): confirm). idx_jail_timeofban_ip covers only
        # (jail, timeofban) despite its name.
        conn.execute("CREATE INDEX idx_jail_timeofban_ip ON bans(jail, timeofban)")
        conn.execute("CREATE INDEX idx_jail_ip ON bans(jail, ip)")
        conn.execute("CREATE INDEX idx_ip ON bans(ip)")
        conn.commit()
    finally:
        # Release the file handle even if schema creation fails, so the
        # temporary directory can always be cleaned up.
        conn.close()
    return db_path


@pytest.mark.asyncio
async def test_ensure_fail2ban_indexes_creates_missing_index(tmp_bans_table: str) -> None:
    """Index is created when idx_bans_timeofban_desc does not exist.

    The precondition (index absent before the call) is asserted explicitly so
    the test cannot pass vacuously if the fixture ever starts creating it.
    """

    def bans_index_names() -> set[str]:
        # Read the current index names on ``bans`` straight from the schema table.
        conn = sqlite3.connect(tmp_bans_table)
        try:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='bans'"
            ).fetchall()
        finally:
            conn.close()
        return {str(name) for (name,) in rows}

    assert "idx_bans_timeofban_desc" not in bans_index_names()

    await ensure_fail2ban_indexes(tmp_bans_table)

    assert "idx_bans_timeofban_desc" in bans_index_names()


@pytest.mark.asyncio
async def test_ensure_fail2ban_indexes_idempotent(tmp_bans_table: str) -> None:
    """Calling twice does not raise or duplicate the index.

    SQLite object names are unique within a schema, so counting rows for a
    single index name can never exceed 1 and would only prove existence.
    Comparing the full set of ``bans`` index names after each call also
    catches a second call that adds an extra, differently named index.
    """

    def bans_index_names() -> list[str]:
        # Sorted snapshot of all index names on ``bans`` for stable comparison.
        conn = sqlite3.connect(tmp_bans_table)
        try:
            rows = conn.execute(
                "SELECT name FROM sqlite_master WHERE type='index' AND tbl_name='bans'"
            ).fetchall()
        finally:
            conn.close()
        return sorted(str(name) for (name,) in rows)

    await ensure_fail2ban_indexes(tmp_bans_table)
    after_first_call = bans_index_names()

    await ensure_fail2ban_indexes(tmp_bans_table)
    after_second_call = bans_index_names()

    assert "idx_bans_timeofban_desc" in after_first_call
    assert after_second_call == after_first_call


def test_escape_like_percent_sign() -> None:
    """A trailing ``%`` wildcard is escaped with a backslash prefix."""
    raw_value = "10.0.0%"
    expected = r"10.0.0\%"
    assert escape_like(raw_value) == expected


def test_escape_like_underscore() -> None:
    """An ``_`` single-character wildcard inside the value gets a backslash prefix."""
    raw_value = "10.0.0_1"
    expected = r"10.0.0\_1"
    assert escape_like(raw_value) == expected


def test_escape_like_backslash() -> None:
    """A lone trailing backslash is doubled."""
    backslash = "\\"
    value = "10.0.0" + backslash
    assert escape_like(value) == "10.0.0" + backslash * 2


def test_escape_like_combined_wildcards() -> None:
    """Both wildcards (``_`` and ``%``) appearing in one value are escaped."""
    expected = r"10.0\_\%"
    assert escape_like("10.0_%") == expected


def test_escape_like_combined_with_backslash() -> None:
    """A backslash mixed with wildcards is escaped without double-escaping.

    The backslash has to be handled before the wildcards; otherwise the
    backslashes introduced for ``_`` and ``%`` would themselves be doubled.
    """
    source = r"10\0_%"
    assert escape_like(source) == r"10\\0\_\%"


def test_escape_like_normal_ip() -> None:
    """A plain dotted IP contains no metacharacters and is returned unchanged."""
    ip = "10.0.0.1"
    assert escape_like(ip) == ip


def test_escape_like_empty_string() -> None:
    """An empty input yields an empty output."""
    escaped = escape_like("")
    assert escaped == ""


def test_escape_like_only_backslash() -> None:
    """Every backslash in a backslash-only value is doubled."""
    backslash = "\\"
    assert escape_like(backslash * 2) == backslash * 4


def test_escape_like_only_percent() -> None:
    """Each ``%`` in a run of percent signs is escaped individually."""
    assert escape_like("%" * 3) == r"\%" * 3


def test_escape_like_only_underscore() -> None:
    """Each ``_`` in a run of underscores is escaped individually."""
    assert escape_like("_" * 3) == r"\_" * 3


def test_escape_like_backslash_before_wildcard() -> None:
    """A literal backslash directly preceding wildcards is escaped on its own."""
    # Each input character maps to its own escaped form:
    # backslash -> two backslashes, "_" -> backslash + "_", "%" -> backslash + "%".
    escaped_parts = ("10", "\\\\", "\\_", "\\%")
    assert escape_like(r"10\_%") == "".join(escaped_parts)