Refactor blocklist log retrieval via service layer and add fail2ban DB repo
This commit is contained in:
@@ -60,6 +60,8 @@ This document breaks the entire BanGUI project into development stages, ordered
|
|||||||
|
|
||||||
#### TASK B-3 — Remove repository import from `routers/blocklist.py`
|
#### TASK B-3 — Remove repository import from `routers/blocklist.py`
|
||||||
|
|
||||||
|
**Status:** Completed ✅
|
||||||
|
|
||||||
**Violated rule:** Refactoring.md §2.1 — Routers must not import from repositories; all data access must go through services.
|
**Violated rule:** Refactoring.md §2.1 — Routers must not import from repositories; all data access must go through services.
|
||||||
|
|
||||||
**Files affected:**
|
**Files affected:**
|
||||||
|
|||||||
358
backend/app/repositories/fail2ban_db_repo.py
Normal file
358
backend/app/repositories/fail2ban_db_repo.py
Normal file
@@ -0,0 +1,358 @@
|
|||||||
|
"""Fail2Ban SQLite database repository.
|
||||||
|
|
||||||
|
This module contains helper functions that query the read-only fail2ban
|
||||||
|
SQLite database file. All functions accept a *db_path* and manage their own
|
||||||
|
connection using aiosqlite in read-only mode.
|
||||||
|
|
||||||
|
The functions intentionally return plain Python data structures (dataclasses) so
|
||||||
|
service layers can focus on business logic and formatting.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Iterable
|
||||||
|
|
||||||
|
from app.models.ban import BanOrigin
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class BanRecord:
    """A single row from the fail2ban ``bans`` table."""

    jail: str       # name of the fail2ban jail that recorded the ban
    ip: str         # banned IP address
    timeofban: int  # ban timestamp (unix epoch seconds)
    bancount: int   # cumulative ban count stored by fail2ban for this IP
    data: str       # opaque data column as stored by fail2ban (JSON-ish text)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class BanIpCount:
    """Aggregated ban count for a single IP."""

    ip: str           # IP address the count applies to
    event_count: int  # number of ban events recorded for this IP
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class JailBanCount:
    """Aggregated ban count for a single jail."""

    jail: str   # jail name
    count: int  # number of ban rows attributed to this jail
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class HistoryRecord:
    """A single row from the fail2ban ``bans`` table for history queries.

    NOTE(review): structurally identical to :class:`BanRecord`; kept as a
    separate type, presumably so history queries have their own return type —
    consider unifying if that distinction is not needed.
    """

    jail: str       # name of the fail2ban jail that recorded the ban
    ip: str         # banned IP address
    timeofban: int  # ban timestamp (unix epoch seconds)
    bancount: int   # cumulative ban count stored by fail2ban for this IP
    data: str       # opaque data column as stored by fail2ban (JSON-ish text)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Internal helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_db_uri(db_path: str) -> str:
|
||||||
|
"""Return a read-only sqlite URI for the given file path."""
|
||||||
|
|
||||||
|
return f"file:{db_path}?mode=ro"
|
||||||
|
|
||||||
|
|
||||||
|
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
|
||||||
|
"""Return a SQL fragment and parameters for the origin filter."""
|
||||||
|
|
||||||
|
if origin == "blocklist":
|
||||||
|
return " AND jail = ?", ("blocklist-import",)
|
||||||
|
if origin == "selfblock":
|
||||||
|
return " AND jail != ?", ("blocklist-import",)
|
||||||
|
return "", ()
|
||||||
|
|
||||||
|
|
||||||
|
def _rows_to_ban_records(rows: Iterable[aiosqlite.Row]) -> list[BanRecord]:
    """Convert raw sqlite rows into :class:`BanRecord` instances."""
    records: list[BanRecord] = []
    for row in rows:
        record = BanRecord(
            jail=str(row["jail"]),
            ip=str(row["ip"]),
            timeofban=int(row["timeofban"]),
            bancount=int(row["bancount"]),
            data=str(row["data"]),
        )
        records.append(record)
    return records
|
||||||
|
|
||||||
|
|
||||||
|
def _rows_to_history_records(rows: Iterable[aiosqlite.Row]) -> list[HistoryRecord]:
    """Convert raw sqlite rows into :class:`HistoryRecord` instances."""
    records: list[HistoryRecord] = []
    for row in rows:
        record = HistoryRecord(
            jail=str(row["jail"]),
            ip=str(row["ip"]),
            timeofban=int(row["timeofban"]),
            bancount=int(row["bancount"]),
            data=str(row["data"]),
        )
        records.append(record)
    return records
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Public API
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
async def check_db_nonempty(db_path: str) -> bool:
    """Return True if the fail2ban database contains at least one ban row."""
    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        async with db.execute("SELECT 1 FROM bans LIMIT 1") as cur:
            first = await cur.fetchone()
    return first is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_currently_banned(
    db_path: str,
    since: int,
    origin: BanOrigin | None = None,
    *,
    limit: int | None = None,
    offset: int | None = None,
) -> tuple[list[BanRecord], int]:
    """Return a page of currently banned IPs and the total matching count.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp to filter bans newer than or equal to.
        origin: Optional origin filter.
        limit: Optional maximum number of rows to return.
        offset: Optional offset for pagination.

    Returns:
        A ``(records, total)`` tuple.
    """

    origin_clause, origin_params = _origin_sql_filter(origin)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row

        # Total count of matching rows, independent of pagination.
        async with db.execute(
            "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
            (since, *origin_params),
        ) as cur:
            count_row = await cur.fetchone()
            total: int = int(count_row[0]) if count_row else 0

        query = (
            "SELECT jail, ip, timeofban, bancount, data "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC"
        )
        params: list[object] = [since, *origin_params]
        if limit is not None:
            query += " LIMIT ?"
            params.append(limit)
        elif offset is not None:
            # BUG FIX: SQLite rejects OFFSET without a LIMIT clause, so an
            # offset-only call used to raise a syntax error. "LIMIT -1"
            # means "no limit" and makes the OFFSET clause legal.
            query += " LIMIT -1"
        if offset is not None:
            query += " OFFSET ?"
            params.append(offset)

        async with db.execute(query, params) as cur:
            rows = await cur.fetchall()

    return _rows_to_ban_records(rows), total
|
||||||
|
|
||||||
|
|
||||||
|
async def get_ban_counts_by_bucket(
    db_path: str,
    since: int,
    bucket_secs: int,
    num_buckets: int,
    origin: BanOrigin | None = None,
) -> list[int]:
    """Return ban counts aggregated into equal-width time buckets.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp marking the start of the first bucket.
        bucket_secs: Width of one bucket in seconds.
        num_buckets: Length of the returned list.
        origin: Optional origin filter.

    Returns:
        A list of exactly ``num_buckets`` counts; buckets with no bans are 0.
    """

    origin_clause, origin_params = _origin_sql_filter(origin)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            # Integer division maps each timestamp onto its bucket index
            # relative to ``since``.
            "SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
            "COUNT(*) AS cnt "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " GROUP BY bucket_idx "
            "ORDER BY bucket_idx",
            (since, bucket_secs, since, *origin_params),
        ) as cur:
            rows = await cur.fetchall()

    counts: list[int] = [0] * num_buckets
    for row in rows:
        idx: int = int(row["bucket_idx"])
        # Rows whose timestamp falls beyond the last bucket are dropped
        # rather than growing the result list.
        if 0 <= idx < num_buckets:
            counts[idx] = int(row["cnt"])

    return counts
|
||||||
|
|
||||||
|
|
||||||
|
async def get_ban_event_counts(
    db_path: str,
    since: int,
    origin: BanOrigin | None = None,
) -> list[BanIpCount]:
    """Return total ban events per unique IP in the window.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp to filter bans newer than or equal to.
        origin: Optional origin filter.

    Returns:
        One :class:`BanIpCount` per distinct IP seen in the window.
    """

    origin_clause, origin_params = _origin_sql_filter(origin)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            "SELECT ip, COUNT(*) AS event_count "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " GROUP BY ip",
            (since, *origin_params),
        ) as cur:
            rows = await cur.fetchall()

    return [
        BanIpCount(ip=str(r["ip"]), event_count=int(r["event_count"]))
        for r in rows
    ]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_bans_by_jail(
    db_path: str,
    since: int,
    origin: BanOrigin | None = None,
) -> tuple[int, list[JailBanCount]]:
    """Return per-jail ban counts and the total ban count.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp to filter bans newer than or equal to.
        origin: Optional origin filter.

    Returns:
        A ``(total, per_jail_counts)`` tuple; jails are ordered by count
        descending.
    """

    origin_clause, origin_params = _origin_sql_filter(origin)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row

        # Overall total across all jails in the window.
        async with db.execute(
            "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause,
            (since, *origin_params),
        ) as cur:
            count_row = await cur.fetchone()
            total: int = int(count_row[0]) if count_row else 0

        # Per-jail breakdown, largest jail first.
        async with db.execute(
            "SELECT jail, COUNT(*) AS cnt "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " GROUP BY jail ORDER BY cnt DESC",
            (since, *origin_params),
        ) as cur:
            rows = await cur.fetchall()

    return total, [
        JailBanCount(jail=str(r["jail"]), count=int(r["cnt"])) for r in rows
    ]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_bans_table_summary(
    db_path: str,
) -> tuple[int, int | None, int | None]:
    """Return basic summary stats for the ``bans`` table.

    Returns:
        A tuple ``(row_count, min_timeofban, max_timeofban)``. If the table is
        empty the min/max values will be ``None``.
    """

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            "SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
        ) as cur:
            row = await cur.fetchone()

    # An aggregate query always yields one row, so this is purely defensive.
    if row is None:
        return 0, None, None

    # MIN/MAX are NULL for an empty table even though COUNT(*) is 0.
    return (
        int(row[0]),
        int(row[1]) if row[1] is not None else None,
        int(row[2]) if row[2] is not None else None,
    )
|
||||||
|
|
||||||
|
|
||||||
|
async def get_history_page(
    db_path: str,
    since: int | None = None,
    jail: str | None = None,
    ip_filter: str | None = None,
    page: int = 1,
    page_size: int = 100,
) -> tuple[list[HistoryRecord], int]:
    """Return a paginated list of history records with total count.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Optional lower bound (unix epoch) on ``timeofban``.
        jail: Optional exact jail-name filter.
        ip_filter: Optional IP prefix filter (matched as ``LIKE prefix%``).
        page: 1-based page number; values below 1 are clamped to 1.
        page_size: Items per page; values below 1 are clamped to 1.

    Returns:
        A ``(records, total)`` tuple.
    """

    wheres: list[str] = []
    params: list[object] = []

    if since is not None:
        wheres.append("timeofban >= ?")
        params.append(since)

    if jail is not None:
        wheres.append("jail = ?")
        params.append(jail)

    if ip_filter is not None:
        # Prefix match: "1.2.3" matches "1.2.3.4", "1.2.30.1", ...
        wheres.append("ip LIKE ?")
        params.append(f"{ip_filter}%")

    where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""

    # ROBUSTNESS: clamp pagination inputs. Without this, page <= 0 yields a
    # negative OFFSET and page_size <= 0 yields a negative LIMIT, which
    # SQLite interprets as "no limit" and returns every row.
    effective_page_size: int = max(page_size, 1)
    offset: int = (max(page, 1) - 1) * effective_page_size

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row

        # Only column/filter fragments built above are interpolated; all
        # user values are bound via placeholders.
        async with db.execute(
            f"SELECT COUNT(*) FROM bans {where_sql}",  # noqa: S608
            params,
        ) as cur:
            count_row = await cur.fetchone()
            total: int = int(count_row[0]) if count_row else 0

        async with db.execute(
            "SELECT jail, ip, timeofban, bancount, data "
            f"FROM bans {where_sql} "
            "ORDER BY timeofban DESC "
            "LIMIT ? OFFSET ?",
            [*params, effective_page_size, offset],
        ) as cur:
            rows = await cur.fetchall()

    return _rows_to_history_records(rows), total
|
||||||
|
|
||||||
|
|
||||||
|
async def get_history_for_ip(db_path: str, ip: str) -> list[HistoryRecord]:
    """Return the full ban timeline for a specific IP."""
    sql = (
        "SELECT jail, ip, timeofban, bancount, data "
        "FROM bans "
        "WHERE ip = ? "
        "ORDER BY timeofban DESC"
    )
    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(sql, (ip,)) as cur:
            rows = await cur.fetchall()
    return _rows_to_history_records(rows)
|
||||||
@@ -42,7 +42,6 @@ from app.models.blocklist import (
|
|||||||
ScheduleConfig,
|
ScheduleConfig,
|
||||||
ScheduleInfo,
|
ScheduleInfo,
|
||||||
)
|
)
|
||||||
from app.repositories import import_log_repo
|
|
||||||
from app.services import blocklist_service
|
from app.services import blocklist_service
|
||||||
from app.tasks import blocklist_import as blocklist_import_task
|
from app.tasks import blocklist_import as blocklist_import_task
|
||||||
|
|
||||||
@@ -225,19 +224,9 @@ async def get_import_log(
|
|||||||
Returns:
|
Returns:
|
||||||
:class:`~app.models.blocklist.ImportLogListResponse`.
|
:class:`~app.models.blocklist.ImportLogListResponse`.
|
||||||
"""
|
"""
|
||||||
items, total = await import_log_repo.list_logs(
|
return await blocklist_service.list_import_logs(
|
||||||
db, source_id=source_id, page=page, page_size=page_size
|
db, source_id=source_id, page=page, page_size=page_size
|
||||||
)
|
)
|
||||||
total_pages = import_log_repo.compute_total_pages(total, page_size)
|
|
||||||
from app.models.blocklist import ImportLogEntry # noqa: PLC0415
|
|
||||||
|
|
||||||
return ImportLogListResponse(
|
|
||||||
items=[ImportLogEntry.model_validate(i) for i in items],
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
page_size=page_size,
|
|
||||||
total_pages=total_pages,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ import structlog
|
|||||||
|
|
||||||
from app.models.blocklist import (
|
from app.models.blocklist import (
|
||||||
BlocklistSource,
|
BlocklistSource,
|
||||||
|
ImportLogEntry,
|
||||||
|
ImportLogListResponse,
|
||||||
ImportRunResult,
|
ImportRunResult,
|
||||||
ImportSourceResult,
|
ImportSourceResult,
|
||||||
PreviewResponse,
|
PreviewResponse,
|
||||||
@@ -503,6 +505,38 @@ async def get_schedule_info(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def list_import_logs(
    db: aiosqlite.Connection,
    *,
    source_id: int | None = None,
    page: int = 1,
    page_size: int = 50,
) -> ImportLogListResponse:
    """Return a paginated list of import log entries.

    Args:
        db: Active application database connection.
        source_id: Optional filter to only return logs for a specific source.
        page: 1-based page number.
        page_size: Items per page.

    Returns:
        :class:`~app.models.blocklist.ImportLogListResponse`.
    """

    # Repository returns raw rows plus the unpaginated total.
    items, total = await import_log_repo.list_logs(
        db, source_id=source_id, page=page, page_size=page_size
    )
    total_pages = import_log_repo.compute_total_pages(total, page_size)

    return ImportLogListResponse(
        items=[ImportLogEntry.model_validate(i) for i in items],
        total=total,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
    )
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Internal helpers
|
# Internal helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
138
backend/tests/test_repositories/test_fail2ban_db_repo.py
Normal file
138
backend/tests/test_repositories/test_fail2ban_db_repo.py
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
"""Tests for the fail2ban_db repository.
|
||||||
|
|
||||||
|
These tests use an in-memory sqlite file created under pytest's tmp_path and
|
||||||
|
exercise the core query functions used by the services.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.repositories import fail2ban_db_repo
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_bans_table(db: aiosqlite.Connection) -> None:
    """Create a minimal ``bans`` table matching the columns the repo queries."""
    await db.execute(
        """
        CREATE TABLE bans (
            jail TEXT,
            ip TEXT,
            timeofban INTEGER,
            bancount INTEGER,
            data TEXT
        )
        """
    )
    await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_check_db_nonempty_returns_false_when_table_is_empty(tmp_path: Path) -> None:
    """An empty ``bans`` table reports the database as empty."""
    db_path = str(tmp_path / "fail2ban.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_bans_table(db)

    assert await fail2ban_db_repo.check_db_nonempty(db_path) is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_check_db_nonempty_returns_true_when_row_exists(tmp_path: Path) -> None:
    """A single ban row is enough for the database to count as non-empty."""
    db_path = str(tmp_path / "fail2ban.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_bans_table(db)
        await db.execute(
            "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
            ("jail1", "1.2.3.4", 123, 1, "{}"),
        )
        await db.commit()

    assert await fail2ban_db_repo.check_db_nonempty(db_path) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> None:
    """``since`` and ``origin`` filters combine; pagination args are honoured."""
    db_path = str(tmp_path / "fail2ban.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_bans_table(db)
        # Three bans; one is from the blocklist-import jail.
        await db.executemany(
            "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
            [
                ("jail1", "1.1.1.1", 10, 1, "{}"),
                ("blocklist-import", "2.2.2.2", 20, 2, "{}"),
                ("jail1", "3.3.3.3", 30, 3, "{}"),
            ],
        )
        await db.commit()

    records, total = await fail2ban_db_repo.get_currently_banned(
        db_path=db_path,
        since=15,
        origin="selfblock",
        limit=10,
        offset=0,
    )

    # Only the non-blocklist row with timeofban >= 15 should remain.
    assert total == 1
    assert len(records) == 1
    assert records[0].ip == "3.3.3.3"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None:
    """Rows whose timestamp falls past the last bucket are silently dropped."""
    db_path = str(tmp_path / "fail2ban.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_bans_table(db)
        # Timestamps 5 and 15 land in buckets 0 and 1; 35 is past bucket 2.
        await db.executemany(
            "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
            [
                ("jail1", "1.1.1.1", 5, 1, "{}"),
                ("jail1", "2.2.2.2", 15, 1, "{}"),
                ("jail1", "3.3.3.3", 35, 1, "{}"),
            ],
        )
        await db.commit()

    counts = await fail2ban_db_repo.get_ban_counts_by_bucket(
        db_path=db_path,
        since=0,
        bucket_secs=10,
        num_buckets=3,
    )

    assert counts == [1, 1, 0]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_history_page_and_for_ip(tmp_path: Path) -> None:
    """History paging combines jail + IP-prefix filters; per-IP lookup is exact."""
    db_path = str(tmp_path / "fail2ban.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_bans_table(db)
        await db.executemany(
            "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)",
            [
                ("jail1", "1.1.1.1", 100, 1, "{}"),
                ("jail1", "1.1.1.1", 200, 2, "{}"),
                ("jail1", "2.2.2.2", 300, 3, "{}"),
            ],
        )
        await db.commit()

    page, total = await fail2ban_db_repo.get_history_page(
        db_path=db_path,
        since=None,
        jail="jail1",
        ip_filter="1.1.1",
        page=1,
        page_size=10,
    )

    # The prefix filter keeps only the two 1.1.1.1 rows.
    assert total == 2
    assert len(page) == 2
    assert page[0].ip == "1.1.1.1"

    history_for_ip = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip="2.2.2.2")
    assert len(history_for_ip) == 1
    assert history_for_ip[0].ip == "2.2.2.2"
|
||||||
62
backend/tests/test_repositories/test_geo_cache_repo.py
Normal file
62
backend/tests/test_repositories/test_geo_cache_repo.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Tests for the geo cache repository."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiosqlite
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.repositories import geo_cache_repo
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_geo_cache_table(db: aiosqlite.Connection) -> None:
    """Create the ``geo_cache`` table used by the geo cache repository tests."""
    await db.execute(
        """
        CREATE TABLE IF NOT EXISTS geo_cache (
            ip TEXT PRIMARY KEY,
            country_code TEXT,
            country_name TEXT,
            asn TEXT,
            org TEXT,
            cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now'))
        )
        """
    )
    await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_unresolved_ips_returns_empty_when_none_exist(tmp_path: Path) -> None:
    """A fully resolved cache entry is not reported as pending."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await db.execute(
            "INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)",
            ("1.1.1.1", "DE", "Germany", "AS123", "Test"),
        )
        await db.commit()

    async with aiosqlite.connect(db_path) as db:
        ips = await geo_cache_repo.get_unresolved_ips(db)

    assert ips == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None:
    """Entries with a NULL country_code are reported as unresolved."""
    db_path = str(tmp_path / "geo_cache.db")
    async with aiosqlite.connect(db_path) as db:
        await _create_geo_cache_table(db)
        await db.executemany(
            "INSERT INTO geo_cache (ip, country_code) VALUES (?, ?)",
            [
                ("2.2.2.2", None),
                ("3.3.3.3", None),
                ("4.4.4.4", "US"),
            ],
        )
        await db.commit()

    async with aiosqlite.connect(db_path) as db:
        ips = await geo_cache_repo.get_unresolved_ips(db)

    # Only the two NULL-country rows are pending; 4.4.4.4 is resolved.
    assert sorted(ips) == ["2.2.2.2", "3.3.3.3"]
|
||||||
@@ -337,3 +337,40 @@ class TestGeoPrewarmCacheFilter:
|
|||||||
call_ips = mock_batch.call_args[0][0]
|
call_ips = mock_batch.call_args[0][0]
|
||||||
assert "1.2.3.4" not in call_ips
|
assert "1.2.3.4" not in call_ips
|
||||||
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
|
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestImportLogPagination:
    """Tests for the service-layer import-log pagination wrapper."""

    async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None:
        """list_import_logs returns an empty page when no logs exist."""
        resp = await blocklist_service.list_import_logs(
            db, source_id=None, page=1, page_size=10
        )
        assert resp.items == []
        assert resp.total == 0
        assert resp.page == 1
        assert resp.page_size == 10
        # An empty result set still reports one (empty) page.
        assert resp.total_pages == 1

    async def test_list_import_logs_paginates(self, db: aiosqlite.Connection) -> None:
        """list_import_logs computes total pages and returns the correct subset."""
        from app.repositories import import_log_repo

        for i in range(3):
            await import_log_repo.add_log(
                db,
                source_id=None,
                source_url=f"https://example{i}.test/ips.txt",
                ips_imported=1,
                ips_skipped=0,
                errors=None,
            )

        resp = await blocklist_service.list_import_logs(
            db, source_id=None, page=2, page_size=2
        )
        assert resp.total == 3
        assert resp.total_pages == 2
        assert resp.page == 2
        assert resp.page_size == 2
        # With 3 logs and page_size=2, page 2 holds exactly one entry — the
        # first-added log, which implies newest-first ordering (NOTE: verify
        # against import_log_repo.list_logs).
        assert len(resp.items) == 1
        assert resp.items[0].source_url == "https://example0.test/ips.txt"
|
||||||
|
|||||||
Reference in New Issue
Block a user