Files
BanGUI/backend/tests/test_repositories/test_fail2ban_db_repo.py

168 lines
4.9 KiB
Python

"""Tests for the fail2ban_db repository.
These tests use a temporary on-disk SQLite file created under pytest's tmp_path and
exercise the core query functions used by the services.
"""
from pathlib import Path
import aiosqlite
import pytest
from app.repositories import fail2ban_db_repo
async def _create_bans_table(db: aiosqlite.Connection) -> None:
    """Create the ``bans`` table on *db* and commit the change.

    The table has the five columns the tests in this module insert into:
    ``jail``, ``ip``, ``timeofban``, ``bancount`` and ``data``.
    """
    schema_sql = """
        CREATE TABLE bans (
            jail TEXT,
            ip TEXT,
            timeofban INTEGER,
            bancount INTEGER,
            data TEXT
        )
        """
    await db.execute(schema_sql)
    await db.commit()
@pytest.mark.asyncio
async def test_check_db_nonempty_returns_false_when_table_is_empty(tmp_path: Path) -> None:
    """``check_db_nonempty`` is expected to be ``False`` when ``bans`` has no rows."""
    database_file = str(tmp_path / "fail2ban.db")

    # Create the schema only; no rows are inserted.
    async with aiosqlite.connect(database_file) as conn:
        await _create_bans_table(conn)

    is_nonempty = await fail2ban_db_repo.check_db_nonempty(database_file)
    assert is_nonempty is False
@pytest.mark.asyncio
async def test_check_db_nonempty_returns_true_when_row_exists(tmp_path: Path) -> None:
    """``check_db_nonempty`` is expected to be ``True`` once ``bans`` holds a row."""
    database_file = str(tmp_path / "fail2ban.db")
    insert_sql = "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)"
    single_ban = ("jail1", "1.2.3.4", 123, 1, "{}")

    async with aiosqlite.connect(database_file) as conn:
        await _create_bans_table(conn)
        await conn.execute(insert_sql, single_ban)
        await conn.commit()

    is_nonempty = await fail2ban_db_repo.check_db_nonempty(database_file)
    assert is_nonempty is True
@pytest.mark.asyncio
async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> None:
    """``get_currently_banned`` is queried with ``since`` and ``origin`` filters.

    Of the three seeded bans, only the ``jail1`` row with ``timeofban`` 30 is
    expected back for ``since=15`` and ``origin="selfblock"``.
    """
    database_file = str(tmp_path / "fail2ban.db")
    insert_sql = "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)"
    # Three bans; one is from the blocklist-import jail.
    seeded_bans = [
        ("jail1", "1.1.1.1", 10, 1, "{}"),
        ("blocklist-import", "2.2.2.2", 20, 2, "{}"),
        ("jail1", "3.3.3.3", 30, 3, "{}"),
    ]

    async with aiosqlite.connect(database_file) as conn:
        await _create_bans_table(conn)
        await conn.executemany(insert_sql, seeded_bans)
        await conn.commit()

    records, total = await fail2ban_db_repo.get_currently_banned(
        db_path=database_file,
        since=15,
        origin="selfblock",
        limit=10,
        offset=0,
    )

    # Only the non-blocklist row with timeofban >= 15 should remain.
    assert total == 1
    assert len(records) == 1
    first_record = records[0]
    assert first_record.ip == "3.3.3.3"
@pytest.mark.asyncio
async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None:
    """``get_ban_counts_by_bucket`` is expected to skip bans past the last bucket.

    With ``since=0``, ``bucket_secs=10`` and ``num_buckets=3``, the bans at
    ``timeofban`` 5 and 15 are expected in the first two buckets, while the
    ban at 35 is expected to be left out.
    """
    database_file = str(tmp_path / "fail2ban.db")
    insert_sql = "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)"
    seeded_bans = [
        ("jail1", "1.1.1.1", 5, 1, "{}"),
        ("jail1", "2.2.2.2", 15, 1, "{}"),
        ("jail1", "3.3.3.3", 35, 1, "{}"),
    ]

    async with aiosqlite.connect(database_file) as conn:
        await _create_bans_table(conn)
        await conn.executemany(insert_sql, seeded_bans)
        await conn.commit()

    counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
        db_path=database_file,
        since=0,
        bucket_secs=10,
        num_buckets=3,
    )

    expected_counts = [1, 1, 0]
    assert counts == expected_counts
@pytest.mark.asyncio
async def test_get_history_page_and_for_ip(tmp_path: Path) -> None:
    """Exercise ``get_history_page`` with jail/IP filters and ``get_history_for_ip``.

    Two of the three seeded bans belong to ``1.1.1.1``; the ``"1.1.1"`` IP
    filter is expected to return exactly those two. The per-IP lookup for
    ``2.2.2.2`` is expected to return its single ban.
    """
    database_file = str(tmp_path / "fail2ban.db")
    insert_sql = "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)"
    seeded_bans = [
        ("jail1", "1.1.1.1", 100, 1, "{}"),
        ("jail1", "1.1.1.1", 200, 2, "{}"),
        ("jail1", "2.2.2.2", 300, 3, "{}"),
    ]

    async with aiosqlite.connect(database_file) as conn:
        await _create_bans_table(conn)
        await conn.executemany(insert_sql, seeded_bans)
        await conn.commit()

    # Paged history restricted to one jail and an IP filter string.
    page, total = await fail2ban_db_repo.get_history_page(
        db_path=database_file,
        since=None,
        jail="jail1",
        ip_filter="1.1.1",
        page=1,
        page_size=10,
    )
    assert total == 2
    assert len(page) == 2
    first_entry = page[0]
    assert first_entry.ip == "1.1.1.1"

    # Full history for one exact IP address.
    history_for_ip = await fail2ban_db_repo.get_history_for_ip(
        db_path=database_file,
        ip="2.2.2.2",
    )
    assert len(history_for_ip) == 1
    only_entry = history_for_ip[0]
    assert only_entry.ip == "2.2.2.2"
@pytest.mark.asyncio
async def test_get_history_page_origin_filter(tmp_path: Path) -> None:
    """``get_history_page`` is queried with ``origin="selfblock"``.

    One ban is seeded under ``jail1`` and one under ``blocklist-import``;
    only the ``jail1`` ban is expected in the result.
    """
    database_file = str(tmp_path / "fail2ban.db")
    insert_sql = "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)"
    seeded_bans = [
        ("jail1", "1.1.1.1", 100, 1, "{}"),
        ("blocklist-import", "2.2.2.2", 200, 1, "{}"),
    ]

    async with aiosqlite.connect(database_file) as conn:
        await _create_bans_table(conn)
        await conn.executemany(insert_sql, seeded_bans)
        await conn.commit()

    page, total = await fail2ban_db_repo.get_history_page(
        db_path=database_file,
        since=None,
        jail=None,
        ip_filter=None,
        origin="selfblock",
        page=1,
        page_size=10,
    )

    assert total == 1
    assert len(page) == 1
    sole_entry = page[0]
    assert sole_entry.ip == "1.1.1.1"