diff --git a/Docs/Backend-Development.md b/Docs/Backend-Development.md index f75ec3b..6652fb6 100644 --- a/Docs/Backend-Development.md +++ b/Docs/Backend-Development.md @@ -311,6 +311,59 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None]: --- +## 6.1 Database Query Conventions + +### LIKE Queries and Wildcard Escaping + +SQLite's `LIKE` operator treats `%` (any sequence of characters) and `_` (any single character) as wildcards. When querying with user-supplied filters that may contain these characters, you must escape them to prevent unintended matches. + +**The Problem:** +```python +# Bad — ip_filter="10.0.0_" matches "10.0.0.1", "10.0.0.2", etc. +ip_filter = "10.0.0_" +await db.execute( + "SELECT * FROM bans WHERE ip LIKE ?", + (f"{ip_filter}%",) # ← wildcard characters not escaped +) +``` + +**The Solution:** + +Use the `escape_like()` helper from `app.utils.fail2ban_db_utils`: + +```python +from app.utils.fail2ban_db_utils import escape_like + +# Good — wildcard characters are escaped +ip_filter = "10.0.0_" +await db.execute( + "SELECT * FROM bans WHERE ip LIKE ? ESCAPE '\\'", + (f"{escape_like(ip_filter)}%",) # ← underscores escaped to literal +) +``` + +**How `escape_like()` works:** + +The function escapes backslashes first, then `%` and `_` signs: +```python +def escape_like(s: str) -> str: + return s.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") +``` + +**Key rules:** +1. **Backslash escapes first** — to prevent double-escaping when the input contains backslashes. +2. **Add `ESCAPE '\\'` to the SQL** — tells SQLite which character to use for escaping. +3. **Dots are not wildcards** — they do not need escaping; normal IP addresses pass through unchanged. + +**Test example:** +```python +assert escape_like("10.0.0_") == "10.0.0\\_" +assert escape_like("10.0.0%test") == "10.0.0\\%test" +assert escape_like("10.0.0.1") == "10.0.0.1" # Unchanged +``` + +--- + ## 7. Logging - Use **structlog** for every log message. diff --git a/backend/app/repositories/fail2ban_db_repo.py b/backend/app/repositories/fail2ban_db_repo.py index ce6c10c..5b9b0c8 100644 --- a/backend/app/repositories/fail2ban_db_repo.py +++ b/backend/app/repositories/fail2ban_db_repo.py @@ -15,6 +15,8 @@ from typing import TYPE_CHECKING import aiosqlite +from app.utils.fail2ban_db_utils import escape_like + if TYPE_CHECKING: from collections.abc import Iterable @@ -321,8 +323,8 @@ async def get_history_page( params.append(jail) if ip_filter is not None: - wheres.append("ip LIKE ?") - params.append(f"{ip_filter}%") + wheres.append("ip LIKE ? ESCAPE '\\'") + params.append(f"{escape_like(ip_filter)}%") origin_clause, origin_params = _origin_sql_filter(origin) if origin_clause: diff --git a/backend/app/repositories/history_archive_repo.py b/backend/app/repositories/history_archive_repo.py index 738a591..891cc41 100644 --- a/backend/app/repositories/history_archive_repo.py +++ b/backend/app/repositories/history_archive_repo.py @@ -10,6 +10,7 @@ import datetime from typing import TYPE_CHECKING, Any from app.models.ban import BLOCKLIST_JAIL, BanOrigin +from app.utils.fail2ban_db_utils import escape_like if TYPE_CHECKING: import aiosqlite @@ -76,8 +77,8 @@ async def get_archived_history( wheres.append(f"ip IN ({placeholder})") params.extend(ip_filter) else: - wheres.append("ip LIKE ?") - params.append(f"{ip_filter}%") + wheres.append("ip LIKE ? ESCAPE '\\'") + params.append(f"{escape_like(ip_filter)}%") if origin == "blocklist": wheres.append("jail = ?") diff --git a/backend/app/utils/fail2ban_db_utils.py b/backend/app/utils/fail2ban_db_utils.py index 703a00b..a1d9479 100644 --- a/backend/app/utils/fail2ban_db_utils.py +++ b/backend/app/utils/fail2ban_db_utils.py @@ -6,6 +6,21 @@ import json from datetime import UTC, datetime +def escape_like(s: str) -> str: + """Escape SQLite LIKE wildcard characters in a string. + + SQLite's LIKE operator treats % (any sequence) and _ (any single char) as + wildcards. This function escapes them to prevent unintended matches. + + Args: + s: The string to escape. + + Returns: + The escaped string where backslashes, %, and _ are escaped. + """ + return s.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") + + def ts_to_iso(unix_ts: int) -> str: """Convert a Unix timestamp to an ISO 8601 UTC string.""" return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat() diff --git a/backend/tests/test_repositories/test_fail2ban_db_repo.py b/backend/tests/test_repositories/test_fail2ban_db_repo.py index 5f0c429..3a29ecd 100644 --- a/backend/tests/test_repositories/test_fail2ban_db_repo.py +++ b/backend/tests/test_repositories/test_fail2ban_db_repo.py @@ -191,3 +191,68 @@ async def test_get_history_page_origin_filter(tmp_path: Path) -> None: assert total == 1 assert len(page) == 1 assert page[0].ip == "1.1.1.1" + + +@pytest.mark.asyncio +async def test_get_history_page_ip_filter_with_wildcard_like_underscore(tmp_path: Path) -> None: + """Test that ip_filter with underscore does not trigger LIKE wildcard match.""" + db_path = str(tmp_path / "fail2ban.db") + async with aiosqlite.connect(db_path) as db: + await _create_bans_table(db) + # Insert IPs: one with dots (should match filter "10.0.0"), others with different patterns + await db.executemany( + "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)", + [ + ("jail1", "10.0.0.1", 100, 1, "{}"), + ("jail1", "10.0.0.2", 150, 1, "{}"), + ("jail1", "10.0.0_1", 200, 1, "{}"), # This should NOT match "10.0.0_" if unescaped + ], + ) + await db.commit() + + # Use ip_filter that contains underscore character + page, total = await fail2ban_db_repo.get_history_page( + db_path=db_path, + since=None, + jail=None, + ip_filter="10.0.0_", # With underscore, should match only the exact IP + page=1, + page_size=10, + ) + + # Should only match the IP that starts with exactly "10.0.0_" (one IP) + assert total == 1 + assert len(page) == 1 + assert page[0].ip == "10.0.0_1" + + +@pytest.mark.asyncio +async def test_get_history_page_ip_filter_with_wildcard_like_percent(tmp_path: Path) -> None: + """Test that ip_filter with percent sign does not trigger LIKE wildcard match.""" + db_path = str(tmp_path / "fail2ban.db") + async with aiosqlite.connect(db_path) as db: + await _create_bans_table(db) + await db.executemany( + "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)", + [ + ("jail1", "10.0.0.1", 100, 1, "{}"), + ("jail1", "10.0.0%test", 200, 1, "{}"), # IP with literal % + ], + ) + await db.commit() + + # Use ip_filter with percent sign - should only match IPs that start with "10.0.0%" + page, total = await fail2ban_db_repo.get_history_page( + db_path=db_path, + since=None, + jail=None, + ip_filter="10.0.0%", + page=1, + page_size=10, + ) + + # Should only match the IP with the literal % character + assert total == 1 + assert len(page) == 1 + assert page[0].ip == "10.0.0%test" + diff --git a/backend/tests/test_repositories/test_history_archive_repo.py b/backend/tests/test_repositories/test_history_archive_repo.py index 1a86cc5..e69b1d5 100644 --- a/backend/tests/test_repositories/test_history_archive_repo.py +++ b/backend/tests/test_repositories/test_history_archive_repo.py @@ -86,3 +86,37 @@ async def test_purge_archived_history(app_db: str) -> None: assert deleted == 1 rows, total = await get_archived_history(db) assert total == 1 + + +@pytest.mark.asyncio +async def test_get_archived_history_ip_filter_with_wildcard_like_underscore(app_db: str) -> None: + """Test that ip_filter with underscore does not trigger LIKE wildcard match.""" + async with aiosqlite.connect(app_db) as db: + await archive_ban_event(db, "sshd", "10.0.0.1", 1000, 1, "{}", "ban") + await archive_ban_event(db, "sshd", "10.0.0.2", 1100, 1, "{}", "ban") + await archive_ban_event(db, "sshd", "10.0.0_1", 1200, 1, "{}", "ban") + + # Use ip_filter that contains underscore - should only match the exact prefix + rows, total = await get_archived_history(db, ip_filter="10.0.0_") + + # Should only match the IP that starts with exactly "10.0.0_" + assert total == 1 + assert len(rows) == 1 + assert rows[0]["ip"] == "10.0.0_1" + + +@pytest.mark.asyncio +async def test_get_archived_history_ip_filter_with_wildcard_like_percent(app_db: str) -> None: + """Test that ip_filter with percent sign does not trigger LIKE wildcard match.""" + async with aiosqlite.connect(app_db) as db: + await archive_ban_event(db, "sshd", "10.0.0.1", 1000, 1, "{}", "ban") + await archive_ban_event(db, "sshd", "10.0.0%test", 1100, 1, "{}", "ban") + + # Use ip_filter with percent sign - should only match IPs that start with "10.0.0%" + rows, total = await get_archived_history(db, ip_filter="10.0.0%") + + # Should only match the IP with the literal % character + assert total == 1 + assert len(rows) == 1 + assert rows[0]["ip"] == "10.0.0%test" + diff --git a/backend/tests/test_utils/test_fail2ban_db_utils.py b/backend/tests/test_utils/test_fail2ban_db_utils.py new file mode 100644 index 0000000..40849fb --- /dev/null +++ b/backend/tests/test_utils/test_fail2ban_db_utils.py @@ -0,0 +1,60 @@ +"""Tests for fail2ban_db_utils module.""" + +from app.utils.fail2ban_db_utils import escape_like + + +def test_escape_like_percent_sign() -> None: + """Test escaping of percent signs (% wildcard).""" + assert escape_like("10.0.0%") == "10.0.0\\%" + + +def test_escape_like_underscore() -> None: + """Test escaping of underscores (_ wildcard).""" + assert escape_like("10.0.0_1") == "10.0.0\\_1" + + +def test_escape_like_backslash() -> None: + """Test escaping of backslashes.""" + assert escape_like("10.0.0\\") == "10.0.0\\\\" + + +def test_escape_like_combined_wildcards() -> None: + """Test escaping when both % and _ are present.""" + assert escape_like("10.0_%") == "10.0\\_\\%" + + +def test_escape_like_combined_with_backslash() -> None: + """Test escaping backslash first, then wildcards.""" + assert escape_like("10\\0_%") == "10\\\\0\\_\\%" + + +def test_escape_like_normal_ip() -> None: + """Test that normal IPs pass through unchanged (dots are not wildcards).""" + assert escape_like("10.0.0.1") == "10.0.0.1" + + +def test_escape_like_empty_string() -> None: + """Test escaping empty string.""" + assert escape_like("") == "" + + +def test_escape_like_only_backslash() -> None: + """Test string with only backslashes.""" + assert escape_like("\\\\") == "\\\\\\\\" + + +def test_escape_like_only_percent() -> None: + """Test string with only percent signs.""" + assert escape_like("%%%") == "\\%\\%\\%" + + +def test_escape_like_only_underscore() -> None: + """Test string with only underscores.""" + assert escape_like("___") == "\\_\\_\\_" + + +def test_escape_like_backslash_before_wildcard() -> None: + """Test that backslash before wildcard is properly escaped.""" + result = escape_like("10\\_%") + # Expected: backslash → \\ , underscore → \_ , percent → \% + assert result == "10\\\\\\_\\%"