Refactor blocklist log retrieval via service layer and add fail2ban DB repo

This commit is contained in:
2026-03-17 08:58:04 +01:00
parent 93f0feabde
commit 1ce5da9e23
7 changed files with 632 additions and 12 deletions

View File

@@ -60,6 +60,8 @@ This document breaks the entire BanGUI project into development stages, ordered
#### TASK B-3 — Remove repository import from `routers/blocklist.py`
**Status:** Completed ✅
**Violated rule:** Refactoring.md §2.1 — Routers must not import from repositories; all data access must go through services.
**Files affected:**

View File

@@ -0,0 +1,358 @@
"""Fail2Ban SQLite database repository.
This module contains helper functions that query the read-only fail2ban
SQLite database file. All functions accept a *db_path* and manage their own
connection using aiosqlite in read-only mode.
The functions intentionally return plain Python data structures (dataclasses) so
service layers can focus on business logic and formatting.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import aiosqlite
if TYPE_CHECKING:
from collections.abc import Iterable
from app.models.ban import BanOrigin
@dataclass(frozen=True)
class BanRecord:
"""A single row from the fail2ban ``bans`` table."""
jail: str
ip: str
timeofban: int
bancount: int
data: str
@dataclass(frozen=True)
class BanIpCount:
"""Aggregated ban count for a single IP."""
ip: str
event_count: int
@dataclass(frozen=True)
class JailBanCount:
"""Aggregated ban count for a single jail."""
jail: str
count: int
@dataclass(frozen=True)
class HistoryRecord:
"""A single row from the fail2ban ``bans`` table for history queries."""
jail: str
ip: str
timeofban: int
bancount: int
data: str
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _make_db_uri(db_path: str) -> str:
"""Return a read-only sqlite URI for the given file path."""
return f"file:{db_path}?mode=ro"
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
"""Return a SQL fragment and parameters for the origin filter."""
if origin == "blocklist":
return " AND jail = ?", ("blocklist-import",)
if origin == "selfblock":
return " AND jail != ?", ("blocklist-import",)
return "", ()
def _rows_to_ban_records(rows: Iterable[aiosqlite.Row]) -> list[BanRecord]:
return [
BanRecord(
jail=str(r["jail"]),
ip=str(r["ip"]),
timeofban=int(r["timeofban"]),
bancount=int(r["bancount"]),
data=str(r["data"]),
)
for r in rows
]
def _rows_to_history_records(rows: Iterable[aiosqlite.Row]) -> list[HistoryRecord]:
return [
HistoryRecord(
jail=str(r["jail"]),
ip=str(r["ip"]),
timeofban=int(r["timeofban"]),
bancount=int(r["bancount"]),
data=str(r["data"]),
)
for r in rows
]
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def check_db_nonempty(db_path: str) -> bool:
"""Return True if the fail2ban database contains at least one ban row."""
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db, db.execute(
"SELECT 1 FROM bans LIMIT 1"
) as cur:
row = await cur.fetchone()
return row is not None
async def get_currently_banned(
db_path: str,
since: int,
origin: BanOrigin | None = None,
*,
limit: int | None = None,
offset: int | None = None,
) -> tuple[list[BanRecord], int]:
"""Return a page of currently banned IPs and the total matching count.
Args:
db_path: File path to the fail2ban SQLite database.
since: Unix timestamp to filter bans newer than or equal to.
origin: Optional origin filter.
limit: Optional maximum number of rows to return.
offset: Optional offset for pagination.
Returns:
A ``(records, total)`` tuple.
"""
origin_clause, origin_params = _origin_sql_filter(origin)
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
(since, *origin_params),
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
query = (
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC"
)
params: list[object] = [since, *origin_params]
if limit is not None:
query += " LIMIT ?"
params.append(limit)
if offset is not None:
query += " OFFSET ?"
params.append(offset)
async with db.execute(query, params) as cur:
rows = await cur.fetchall()
return _rows_to_ban_records(rows), total
async def get_ban_counts_by_bucket(
db_path: str,
since: int,
bucket_secs: int,
num_buckets: int,
origin: BanOrigin | None = None,
) -> list[int]:
"""Return ban counts aggregated into equal-width time buckets."""
origin_clause, origin_params = _origin_sql_filter(origin)
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
"COUNT(*) AS cnt "
"FROM bans "
"WHERE timeofban >= ?" + origin_clause + " GROUP BY bucket_idx "
"ORDER BY bucket_idx",
(since, bucket_secs, since, *origin_params),
) as cur:
rows = await cur.fetchall()
counts: list[int] = [0] * num_buckets
for row in rows:
idx: int = int(row["bucket_idx"])
if 0 <= idx < num_buckets:
counts[idx] = int(row["cnt"])
return counts
async def get_ban_event_counts(
db_path: str,
since: int,
origin: BanOrigin | None = None,
) -> list[BanIpCount]:
"""Return total ban events per unique IP in the window."""
origin_clause, origin_params = _origin_sql_filter(origin)
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT ip, COUNT(*) AS event_count "
"FROM bans "
"WHERE timeofban >= ?" + origin_clause + " GROUP BY ip",
(since, *origin_params),
) as cur:
rows = await cur.fetchall()
return [
BanIpCount(ip=str(r["ip"]), event_count=int(r["event_count"]))
for r in rows
]
async def get_bans_by_jail(
db_path: str,
since: int,
origin: BanOrigin | None = None,
) -> tuple[int, list[JailBanCount]]:
"""Return per-jail ban counts and the total ban count."""
origin_clause, origin_params = _origin_sql_filter(origin)
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
(since, *origin_params),
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
async with db.execute(
"SELECT jail, COUNT(*) AS cnt "
"FROM bans "
"WHERE timeofban >= ?" + origin_clause + " GROUP BY jail ORDER BY cnt DESC",
(since, *origin_params),
) as cur:
rows = await cur.fetchall()
return total, [
JailBanCount(jail=str(r["jail"]), count=int(r["cnt"])) for r in rows
]
async def get_bans_table_summary(
db_path: str,
) -> tuple[int, int | None, int | None]:
"""Return basic summary stats for the ``bans`` table.
Returns:
A tuple ``(row_count, min_timeofban, max_timeofban)``. If the table is
empty the min/max values will be ``None``.
"""
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
) as cur:
row = await cur.fetchone()
if row is None:
return 0, None, None
return (
int(row[0]),
int(row[1]) if row[1] is not None else None,
int(row[2]) if row[2] is not None else None,
)
async def get_history_page(
db_path: str,
since: int | None = None,
jail: str | None = None,
ip_filter: str | None = None,
page: int = 1,
page_size: int = 100,
) -> tuple[list[HistoryRecord], int]:
"""Return a paginated list of history records with total count."""
wheres: list[str] = []
params: list[object] = []
if since is not None:
wheres.append("timeofban >= ?")
params.append(since)
if jail is not None:
wheres.append("jail = ?")
params.append(jail)
if ip_filter is not None:
wheres.append("ip LIKE ?")
params.append(f"{ip_filter}%")
where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
effective_page_size: int = page_size
offset: int = (page - 1) * effective_page_size
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608
params,
) as cur:
count_row = await cur.fetchone()
total: int = int(count_row[0]) if count_row else 0
async with db.execute(
f"SELECT jail, ip, timeofban, bancount, data "
f"FROM bans {where_sql} "
"ORDER BY timeofban DESC "
"LIMIT ? OFFSET ?",
[*params, effective_page_size, offset],
) as cur:
rows = await cur.fetchall()
return _rows_to_history_records(rows), total
async def get_history_for_ip(db_path: str, ip: str) -> list[HistoryRecord]:
"""Return the full ban timeline for a specific IP."""
async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
db.row_factory = aiosqlite.Row
async with db.execute(
"SELECT jail, ip, timeofban, bancount, data "
"FROM bans "
"WHERE ip = ? "
"ORDER BY timeofban DESC",
(ip,),
) as cur:
rows = await cur.fetchall()
return _rows_to_history_records(rows)

View File

@@ -42,7 +42,6 @@ from app.models.blocklist import (
ScheduleConfig,
ScheduleInfo,
)
from app.repositories import import_log_repo
from app.services import blocklist_service
from app.tasks import blocklist_import as blocklist_import_task
@@ -225,19 +224,9 @@ async def get_import_log(
Returns:
:class:`~app.models.blocklist.ImportLogListResponse`.
"""
items, total = await import_log_repo.list_logs(
return await blocklist_service.list_import_logs(
db, source_id=source_id, page=page, page_size=page_size
)
total_pages = import_log_repo.compute_total_pages(total, page_size)
from app.models.blocklist import ImportLogEntry # noqa: PLC0415
return ImportLogListResponse(
items=[ImportLogEntry.model_validate(i) for i in items],
total=total,
page=page,
page_size=page_size,
total_pages=total_pages,
)
# ---------------------------------------------------------------------------

View File

@@ -21,6 +21,8 @@ import structlog
from app.models.blocklist import (
BlocklistSource,
ImportLogEntry,
ImportLogListResponse,
ImportRunResult,
ImportSourceResult,
PreviewResponse,
@@ -503,6 +505,38 @@ async def get_schedule_info(
)
async def list_import_logs(
db: aiosqlite.Connection,
*,
source_id: int | None = None,
page: int = 1,
page_size: int = 50,
) -> ImportLogListResponse:
"""Return a paginated list of import log entries.
Args:
db: Active application database connection.
source_id: Optional filter to only return logs for a specific source.
page: 1-based page number.
page_size: Items per page.
Returns:
:class:`~app.models.blocklist.ImportLogListResponse`.
"""
items, total = await import_log_repo.list_logs(
db, source_id=source_id, page=page, page_size=page_size
)
total_pages = import_log_repo.compute_total_pages(total, page_size)
return ImportLogListResponse(
items=[ImportLogEntry.model_validate(i) for i in items],
total=total,
page=page,
page_size=page_size,
total_pages=total_pages,
)
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------

View File

@@ -0,0 +1,138 @@
"""Tests for the fail2ban_db repository.
These tests use an in-memory sqlite file created under pytest's tmp_path and
exercise the core query functions used by the services.
"""
from pathlib import Path
import aiosqlite
import pytest
from app.repositories import fail2ban_db_repo
async def _create_bans_table(db: aiosqlite.Connection) -> None:
await db.execute(
"""
CREATE TABLE bans (
jail TEXT,
ip TEXT,
timeofban INTEGER,
bancount INTEGER,
data TEXT
)
"""
)
await db.commit()
@pytest.mark.asyncio
async def test_check_db_nonempty_returns_false_when_table_is_empty(tmp_path: Path) -> None:
db_path = str(tmp_path / "fail2ban.db")
async with aiosqlite.connect(db_path) as db:
await _create_bans_table(db)
assert await fail2ban_db_repo.check_db_nonempty(db_path) is False
@pytest.mark.asyncio
async def test_check_db_nonempty_returns_true_when_row_exists(tmp_path: Path) -> None:
db_path = str(tmp_path / "fail2ban.db")
async with aiosqlite.connect(db_path) as db:
await _create_bans_table(db)
await db.execute(
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
("jail1", "1.2.3.4", 123, 1, "{}"),
)
await db.commit()
assert await fail2ban_db_repo.check_db_nonempty(db_path) is True
@pytest.mark.asyncio
async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> None:
db_path = str(tmp_path / "fail2ban.db")
async with aiosqlite.connect(db_path) as db:
await _create_bans_table(db)
# Three bans; one is from the blocklist-import jail.
await db.executemany(
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
[
("jail1", "1.1.1.1", 10, 1, "{}"),
("blocklist-import", "2.2.2.2", 20, 2, "{}"),
("jail1", "3.3.3.3", 30, 3, "{}"),
],
)
await db.commit()
records, total = await fail2ban_db_repo.get_currently_banned(
db_path=db_path,
since=15,
origin="selfblock",
limit=10,
offset=0,
)
# Only the non-blocklist row with timeofban >= 15 should remain.
assert total == 1
assert len(records) == 1
assert records[0].ip == "3.3.3.3"
@pytest.mark.asyncio
async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None:
db_path = str(tmp_path / "fail2ban.db")
async with aiosqlite.connect(db_path) as db:
await _create_bans_table(db)
await db.executemany(
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
[
("jail1", "1.1.1.1", 5, 1, "{}"),
("jail1", "2.2.2.2", 15, 1, "{}"),
("jail1", "3.3.3.3", 35, 1, "{}"),
],
)
await db.commit()
counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
db_path=db_path,
since=0,
bucket_secs=10,
num_buckets=3,
)
assert counts == [1, 1, 0]
@pytest.mark.asyncio
async def test_get_history_page_and_for_ip(tmp_path: Path) -> None:
db_path = str(tmp_path / "fail2ban.db")
async with aiosqlite.connect(db_path) as db:
await _create_bans_table(db)
await db.executemany(
"INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
[
("jail1", "1.1.1.1", 100, 1, "{}"),
("jail1", "1.1.1.1", 200, 2, "{}"),
("jail1", "2.2.2.2", 300, 3, "{}"),
],
)
await db.commit()
page, total = await fail2ban_db_repo.get_history_page(
db_path=db_path,
since=None,
jail="jail1",
ip_filter="1.1.1",
page=1,
page_size=10,
)
assert total == 2
assert len(page) == 2
assert page[0].ip == "1.1.1.1"
history_for_ip = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip="2.2.2.2")
assert len(history_for_ip) == 1
assert history_for_ip[0].ip == "2.2.2.2"

View File

@@ -0,0 +1,62 @@
"""Tests for the geo cache repository."""
from pathlib import Path
import aiosqlite
import pytest
from app.repositories import geo_cache_repo
async def _create_geo_cache_table(db: aiosqlite.Connection) -> None:
await db.execute(
"""
CREATE TABLE IF NOT EXISTS geo_cache (
ip TEXT PRIMARY KEY,
country_code TEXT,
country_name TEXT,
asn TEXT,
org TEXT,
cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
)
"""
)
await db.commit()
@pytest.mark.asyncio
async def test_get_unresolved_ips_returns_empty_when_none_exist(tmp_path: Path) -> None:
db_path = str(tmp_path / "geo_cache.db")
async with aiosqlite.connect(db_path) as db:
await _create_geo_cache_table(db)
await db.execute(
"INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
("1.1.1.1", "DE", "Germany", "AS123", "Test"),
)
await db.commit()
async with aiosqlite.connect(db_path) as db:
ips = await geo_cache_repo.get_unresolved_ips(db)
assert ips == []
@pytest.mark.asyncio
async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None:
db_path = str(tmp_path / "geo_cache.db")
async with aiosqlite.connect(db_path) as db:
await _create_geo_cache_table(db)
await db.executemany(
"INSERT INTO geo_cache (ip, country_code) VALUES (?, ?)",
[
("2.2.2.2", None),
("3.3.3.3", None),
("4.4.4.4", "US"),
],
)
await db.commit()
async with aiosqlite.connect(db_path) as db:
ips = await geo_cache_repo.get_unresolved_ips(db)
assert sorted(ips) == ["2.2.2.2", "3.3.3.3"]

View File

@@ -337,3 +337,40 @@ class TestGeoPrewarmCacheFilter:
call_ips = mock_batch.call_args[0][0]
assert "1.2.3.4" not in call_ips
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
class TestImportLogPagination:
async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None:
"""list_import_logs returns an empty page when no logs exist."""
resp = await blocklist_service.list_import_logs(
db, source_id=None, page=1, page_size=10
)
assert resp.items == []
assert resp.total == 0
assert resp.page == 1
assert resp.page_size == 10
assert resp.total_pages == 1
async def test_list_import_logs_paginates(self, db: aiosqlite.Connection) -> None:
"""list_import_logs computes total pages and returns the correct subset."""
from app.repositories import import_log_repo
for i in range(3):
await import_log_repo.add_log(
db,
source_id=None,
source_url=f"https://example{i}.test/ips.txt",
ips_imported=1,
ips_skipped=0,
errors=None,
)
resp = await blocklist_service.list_import_logs(
db, source_id=None, page=2, page_size=2
)
assert resp.total == 3
assert resp.total_pages == 2
assert resp.page == 2
assert resp.page_size == 2
assert len(resp.items) == 1
assert resp.items[0].source_url == "https://example0.test/ips.txt"