refactoring-backend #3
@@ -311,6 +311,59 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None]:
|
||||
|
||||
---
|
||||
|
||||
## 6.1 Database Query Conventions
|
||||
|
||||
### LIKE Queries and Wildcard Escaping
|
||||
|
||||
SQLite's `LIKE` operator treats `%` (any sequence of characters) and `_` (any single character) as wildcards. When querying with user-supplied filters that may contain these characters, you must escape them to prevent unintended matches.
|
||||
|
||||
**The Problem:**
|
||||
```python
|
||||
# Bad — ip_filter="10.0.0_" matches "10.0.0.1", "10.0.0.2", etc.
|
||||
ip_filter = "10.0.0_"
|
||||
await db.execute(
|
||||
"SELECT * FROM bans WHERE ip LIKE ?",
|
||||
(f"{ip_filter}%",) # ← wildcard characters not escaped
|
||||
)
|
||||
```
|
||||
|
||||
**The Solution:**
|
||||
|
||||
Use the `escape_like()` helper from `app.utils.fail2ban_db_utils`:
|
||||
|
||||
```python
|
||||
from app.utils.fail2ban_db_utils import escape_like
|
||||
|
||||
# Good — wildcard characters are escaped
|
||||
ip_filter = "10.0.0_"
|
||||
await db.execute(
|
||||
"SELECT * FROM bans WHERE ip LIKE ? ESCAPE '\\'",
|
||||
(f"{escape_like(ip_filter)}%",) # ← underscores escaped to literal
|
||||
)
|
||||
```
|
||||
|
||||
**How `escape_like()` works:**
|
||||
|
||||
The function escapes backslashes first, then `%` and `_` signs:
|
||||
```python
|
||||
def escape_like(s: str) -> str:
|
||||
return s.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
```
|
||||
|
||||
**Key rules:**
|
||||
1. **Backslash escapes first** — to prevent double-escaping when the input contains backslashes.
|
||||
2. **Add `ESCAPE '\\'` to the SQL** — tells SQLite which character to use for escaping.
|
||||
3. **Dots are not wildcards** — they do not need escaping; normal IP addresses pass through unchanged.
|
||||
|
||||
**Test example:**
|
||||
```python
|
||||
assert escape_like("10.0.0_") == "10.0.0\\_"
|
||||
assert escape_like("10.0.0%test") == "10.0.0\\%test"
|
||||
assert escape_like("10.0.0.1") == "10.0.0.1" # Unchanged
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Logging
|
||||
|
||||
- Use **structlog** for every log message.
|
||||
|
||||
@@ -15,6 +15,8 @@ from typing import TYPE_CHECKING
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app.utils.fail2ban_db_utils import escape_like
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable
|
||||
|
||||
@@ -321,8 +323,8 @@ async def get_history_page(
|
||||
params.append(jail)
|
||||
|
||||
if ip_filter is not None:
|
||||
wheres.append("ip LIKE ?")
|
||||
params.append(f"{ip_filter}%")
|
||||
wheres.append("ip LIKE ? ESCAPE '\\'")
|
||||
params.append(f"{escape_like(ip_filter)}%")
|
||||
|
||||
origin_clause, origin_params = _origin_sql_filter(origin)
|
||||
if origin_clause:
|
||||
|
||||
@@ -10,6 +10,7 @@ import datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from app.models.ban import BLOCKLIST_JAIL, BanOrigin
|
||||
from app.utils.fail2ban_db_utils import escape_like
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
@@ -76,8 +77,8 @@ async def get_archived_history(
|
||||
wheres.append(f"ip IN ({placeholder})")
|
||||
params.extend(ip_filter)
|
||||
else:
|
||||
wheres.append("ip LIKE ?")
|
||||
params.append(f"{ip_filter}%")
|
||||
wheres.append("ip LIKE ? ESCAPE '\\'")
|
||||
params.append(f"{escape_like(ip_filter)}%")
|
||||
|
||||
if origin == "blocklist":
|
||||
wheres.append("jail = ?")
|
||||
|
||||
@@ -6,6 +6,21 @@ import json
|
||||
from datetime import UTC, datetime
|
||||
|
||||
|
||||
def escape_like(s: str) -> str:
|
||||
"""Escape SQLite LIKE wildcard characters in a string.
|
||||
|
||||
SQLite's LIKE operator treats % (any sequence) and _ (any single char) as
|
||||
wildcards. This function escapes them to prevent unintended matches.
|
||||
|
||||
Args:
|
||||
s: The string to escape.
|
||||
|
||||
Returns:
|
||||
The escaped string where backslashes, %, and _ are escaped.
|
||||
"""
|
||||
return s.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
|
||||
|
||||
|
||||
def ts_to_iso(unix_ts: int) -> str:
|
||||
"""Convert a Unix timestamp to an ISO 8601 UTC string."""
|
||||
return datetime.fromtimestamp(unix_ts, tz=UTC).isoformat()
|
||||
|
||||
@@ -191,3 +191,68 @@ async def test_get_history_page_origin_filter(tmp_path: Path) -> None:
|
||||
assert total == 1
|
||||
assert len(page) == 1
|
||||
assert page[0].ip == "1.1.1.1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_page_ip_filter_with_wildcard_like_underscore(tmp_path: Path) -> None:
|
||||
"""Test that ip_filter with underscore does not trigger LIKE wildcard match."""
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
# Insert IPs: one with dots (should match filter "10.0.0"), others with different patterns
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "10.0.0.1", 100, 1, "{}"),
|
||||
("jail1", "10.0.0.2", 150, 1, "{}"),
|
||||
("jail1", "10.0.0_1", 200, 1, "{}"), # This should NOT match "10.0.0_" if unescaped
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# Use ip_filter that contains underscore character
|
||||
page, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path,
|
||||
since=None,
|
||||
jail=None,
|
||||
ip_filter="10.0.0_", # With underscore, should match only the exact IP
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Should only match the IP that starts with exactly "10.0.0_" (one IP)
|
||||
assert total == 1
|
||||
assert len(page) == 1
|
||||
assert page[0].ip == "10.0.0_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_history_page_ip_filter_with_wildcard_like_percent(tmp_path: Path) -> None:
|
||||
"""Test that ip_filter with percent sign does not trigger LIKE wildcard match."""
|
||||
db_path = str(tmp_path / "fail2ban.db")
|
||||
async with aiosqlite.connect(db_path) as db:
|
||||
await _create_bans_table(db)
|
||||
await db.executemany(
|
||||
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
|
||||
[
|
||||
("jail1", "10.0.0.1", 100, 1, "{}"),
|
||||
("jail1", "10.0.0%test", 200, 1, "{}"), # IP with literal %
|
||||
],
|
||||
)
|
||||
await db.commit()
|
||||
|
||||
# Use ip_filter with percent sign - should only match IPs that start with "10.0.0%"
|
||||
page, total = await fail2ban_db_repo.get_history_page(
|
||||
db_path=db_path,
|
||||
since=None,
|
||||
jail=None,
|
||||
ip_filter="10.0.0%",
|
||||
page=1,
|
||||
page_size=10,
|
||||
)
|
||||
|
||||
# Should only match the IP with the literal % character
|
||||
assert total == 1
|
||||
assert len(page) == 1
|
||||
assert page[0].ip == "10.0.0%test"
|
||||
|
||||
|
||||
@@ -86,3 +86,37 @@ async def test_purge_archived_history(app_db: str) -> None:
|
||||
assert deleted == 1
|
||||
rows, total = await get_archived_history(db)
|
||||
assert total == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_archived_history_ip_filter_with_wildcard_like_underscore(app_db: str) -> None:
|
||||
"""Test that ip_filter with underscore does not trigger LIKE wildcard match."""
|
||||
async with aiosqlite.connect(app_db) as db:
|
||||
await archive_ban_event(db, "sshd", "10.0.0.1", 1000, 1, "{}", "ban")
|
||||
await archive_ban_event(db, "sshd", "10.0.0.2", 1100, 1, "{}", "ban")
|
||||
await archive_ban_event(db, "sshd", "10.0.0_1", 1200, 1, "{}", "ban")
|
||||
|
||||
# Use ip_filter that contains underscore - should only match the exact prefix
|
||||
rows, total = await get_archived_history(db, ip_filter="10.0.0_")
|
||||
|
||||
# Should only match the IP that starts with exactly "10.0.0_"
|
||||
assert total == 1
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["ip"] == "10.0.0_1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_archived_history_ip_filter_with_wildcard_like_percent(app_db: str) -> None:
|
||||
"""Test that ip_filter with percent sign does not trigger LIKE wildcard match."""
|
||||
async with aiosqlite.connect(app_db) as db:
|
||||
await archive_ban_event(db, "sshd", "10.0.0.1", 1000, 1, "{}", "ban")
|
||||
await archive_ban_event(db, "sshd", "10.0.0%test", 1100, 1, "{}", "ban")
|
||||
|
||||
# Use ip_filter with percent sign - should only match IPs that start with "10.0.0%"
|
||||
rows, total = await get_archived_history(db, ip_filter="10.0.0%")
|
||||
|
||||
# Should only match the IP with the literal % character
|
||||
assert total == 1
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["ip"] == "10.0.0%test"
|
||||
|
||||
|
||||
60
backend/tests/test_utils/test_fail2ban_db_utils.py
Normal file
60
backend/tests/test_utils/test_fail2ban_db_utils.py
Normal file
@@ -0,0 +1,60 @@
|
||||
"""Tests for fail2ban_db_utils module."""
|
||||
|
||||
from app.utils.fail2ban_db_utils import escape_like
|
||||
|
||||
|
||||
def test_escape_like_percent_sign() -> None:
|
||||
"""Test escaping of percent signs (% wildcard)."""
|
||||
assert escape_like("10.0.0%") == "10.0.0\\%"
|
||||
|
||||
|
||||
def test_escape_like_underscore() -> None:
|
||||
"""Test escaping of underscores (_ wildcard)."""
|
||||
assert escape_like("10.0.0_1") == "10.0.0\\_1"
|
||||
|
||||
|
||||
def test_escape_like_backslash() -> None:
|
||||
"""Test escaping of backslashes."""
|
||||
assert escape_like("10.0.0\\") == "10.0.0\\\\"
|
||||
|
||||
|
||||
def test_escape_like_combined_wildcards() -> None:
|
||||
"""Test escaping when both % and _ are present."""
|
||||
assert escape_like("10.0_%") == "10.0\\_\\%"
|
||||
|
||||
|
||||
def test_escape_like_combined_with_backslash() -> None:
|
||||
"""Test escaping backslash first, then wildcards."""
|
||||
assert escape_like("10\\0_%") == "10\\\\0\\_\\%"
|
||||
|
||||
|
||||
def test_escape_like_normal_ip() -> None:
|
||||
"""Test that normal IPs pass through unchanged (dots are not wildcards)."""
|
||||
assert escape_like("10.0.0.1") == "10.0.0.1"
|
||||
|
||||
|
||||
def test_escape_like_empty_string() -> None:
|
||||
"""Test escaping empty string."""
|
||||
assert escape_like("") == ""
|
||||
|
||||
|
||||
def test_escape_like_only_backslash() -> None:
|
||||
"""Test string with only backslashes."""
|
||||
assert escape_like("\\\\") == "\\\\\\\\"
|
||||
|
||||
|
||||
def test_escape_like_only_percent() -> None:
|
||||
"""Test string with only percent signs."""
|
||||
assert escape_like("%%%") == "\\%\\%\\%"
|
||||
|
||||
|
||||
def test_escape_like_only_underscore() -> None:
|
||||
"""Test string with only underscores."""
|
||||
assert escape_like("___") == "\\_\\_\\_"
|
||||
|
||||
|
||||
def test_escape_like_backslash_before_wildcard() -> None:
|
||||
"""Test that backslash before wildcard is properly escaped."""
|
||||
result = escape_like("10\\_%")
|
||||
# Expected: backslash → \\ , underscore → \_ , percent → \%
|
||||
assert result == "10\\\\\\_\\%"
|
||||
Reference in New Issue
Block a user