Files
BanGUI/backend/app/repositories/fail2ban_db_repo.py

359 lines
10 KiB
Python

"""Fail2Ban SQLite database repository.
This module contains helper functions that query the read-only fail2ban
SQLite database file. All public functions accept a *db_path* and open their own
short-lived connection using aiosqlite in read-only mode.
The functions intentionally return plain Python data structures (dataclasses,
lists, tuples and ints) so service layers can focus on business logic and
formatting.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import aiosqlite
if TYPE_CHECKING:
from collections.abc import Iterable
from app.models.ban import BanOrigin
@dataclass(frozen=True)
class BanRecord:
    """Immutable snapshot of one row selected from the fail2ban ``bans`` table.

    Only the five columns this repository selects are carried over.
    """

    # Name of the jail the row belongs to (``jail`` column).
    jail: str
    # Banned address as stored in the ``ip`` column.
    ip: str
    # Unix timestamp (seconds) stored in the ``timeofban`` column.
    timeofban: int
    # Value of the ``bancount`` column.
    bancount: int
    # Raw ``data`` column contents; not decoded here.
    data: str
@dataclass(frozen=True)
class BanIpCount:
    """Immutable pair of an IP address and how many ban rows it has."""

    # Address value taken from the ``ip`` column.
    ip: str
    # Number of matching ``bans`` rows for that address.
    event_count: int
@dataclass(frozen=True)
class JailBanCount:
    """Immutable pair of a jail name and how many ban rows it holds."""

    # Jail name taken from the ``jail`` column.
    jail: str
    # Number of matching ``bans`` rows for that jail.
    count: int
@dataclass(frozen=True)
class HistoryRecord:
    """Immutable snapshot of one ``bans`` row as returned by the history queries.

    Carries the same columns as ``BanRecord`` but is a separate type, so the
    two do not compare equal to each other.
    """

    # Name of the jail the row belongs to (``jail`` column).
    jail: str
    # Banned address as stored in the ``ip`` column.
    ip: str
    # Unix timestamp (seconds) stored in the ``timeofban`` column.
    timeofban: int
    # Value of the ``bancount`` column.
    bancount: int
    # Raw ``data`` column contents; not decoded here.
    data: str
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _make_db_uri(db_path: str) -> str:
"""Return a read-only sqlite URI for the given file path."""
return f"file:{db_path}?mode=ro"
def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]:
"""Return a SQL fragment and parameters for the origin filter."""
if origin == "blocklist":
return " AND jail = ?", ("blocklist-import",)
if origin == "selfblock":
return " AND jail != ?", ("blocklist-import",)
return "", ()
def _rows_to_ban_records(rows: Iterable[aiosqlite.Row]) -> list[BanRecord]:
    """Convert ``bans`` result rows into ``BanRecord`` objects, preserving order.

    Each row must be indexable by the column names ``jail``, ``ip``,
    ``timeofban``, ``bancount`` and ``data``. Text columns are coerced with
    ``str()`` and numeric columns with ``int()``.

    NOTE(review): ``str()`` turns a SQL NULL into the literal string ``"None"``,
    and ``int()`` raises ``TypeError`` on NULL.
    """

    def to_record(row: aiosqlite.Row) -> BanRecord:
        return BanRecord(
            jail=str(row["jail"]),
            ip=str(row["ip"]),
            timeofban=int(row["timeofban"]),
            bancount=int(row["bancount"]),
            data=str(row["data"]),
        )

    return list(map(to_record, rows))
def _rows_to_history_records(rows: Iterable[aiosqlite.Row]) -> list[HistoryRecord]:
    """Convert ``bans`` result rows into ``HistoryRecord`` objects, preserving order.

    Each row must be indexable by the column names ``jail``, ``ip``,
    ``timeofban``, ``bancount`` and ``data``. Text columns are coerced with
    ``str()`` and numeric columns with ``int()``.

    NOTE(review): ``str()`` turns a SQL NULL into the literal string ``"None"``,
    and ``int()`` raises ``TypeError`` on NULL.
    """

    def to_history(row: aiosqlite.Row) -> HistoryRecord:
        return HistoryRecord(
            jail=str(row["jail"]),
            ip=str(row["ip"]),
            timeofban=int(row["timeofban"]),
            bancount=int(row["bancount"]),
            data=str(row["data"]),
        )

    return list(map(to_history, rows))
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
async def check_db_nonempty(db_path: str) -> bool:
    """Tell whether the fail2ban ``bans`` table has any rows at all.

    Args:
        db_path: File path to the fail2ban SQLite database.

    Returns:
        ``True`` if at least one row exists in ``bans``, otherwise ``False``.
    """
    probe_sql = "SELECT 1 FROM bans LIMIT 1"
    uri = _make_db_uri(db_path)
    async with aiosqlite.connect(uri, uri=True) as conn, conn.execute(
        probe_sql
    ) as cursor:
        first_row = await cursor.fetchone()
    return first_row is not None
async def get_currently_banned(
    db_path: str,
    since: int,
    origin: BanOrigin | None = None,
    *,
    limit: int | None = None,
    offset: int | None = None,
) -> tuple[list[BanRecord], int]:
    """Return a page of currently banned IPs and the total matching count.

    Rows are selected only by ``timeofban >= since`` and the optional origin
    filter. No ban-expiry information is consulted here.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp to filter bans newer than or equal to.
        origin: Optional origin filter.
        limit: Optional maximum number of rows to return (``None`` = no limit).
        offset: Optional number of leading rows to skip. It may be given with
            or without *limit*.

    Returns:
        A ``(records, total)`` tuple. *records* is ordered newest first, and
        *total* counts every matching row regardless of *limit*/*offset*.
    """
    origin_clause, origin_params = _origin_sql_filter(origin)
    where_sql = "WHERE timeofban >= ?" + origin_clause
    where_params = (since, *origin_params)

    query = (
        "SELECT jail, ip, timeofban, bancount, data "
        "FROM bans " + where_sql + " ORDER BY timeofban DESC"
    )
    params: list[object] = [*where_params]
    if limit is not None or offset is not None:
        # SQLite only accepts OFFSET as part of a LIMIT clause. A negative
        # LIMIT means "no upper bound", which lets an offset be used on its own.
        query += " LIMIT ?"
        params.append(limit if limit is not None else -1)
        if offset is not None:
            query += " OFFSET ?"
            params.append(offset)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            "SELECT COUNT(*) FROM bans " + where_sql,
            where_params,
        ) as cur:
            count_row = await cur.fetchone()
        total: int = int(count_row[0]) if count_row else 0
        async with db.execute(query, params) as cur:
            rows = await cur.fetchall()
    return _rows_to_ban_records(rows), total
async def get_ban_counts_by_bucket(
    db_path: str,
    since: int,
    bucket_secs: int,
    num_buckets: int,
    origin: BanOrigin | None = None,
) -> list[int]:
    """Return ban counts aggregated into equal-width time buckets.

    Bucket ``i`` counts bans with
    ``since + i * bucket_secs <= timeofban < since + (i + 1) * bucket_secs``.
    Bans falling at or after the end of the last bucket are ignored.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp marking the start of bucket 0.
        bucket_secs: Width of each bucket in seconds; must be positive.
        num_buckets: Number of buckets to return.
        origin: Optional origin filter.

    Returns:
        A list of ``num_buckets`` counts, with ``0`` for empty buckets. The
        list is empty if ``num_buckets <= 0``.

    Raises:
        ValueError: If *bucket_secs* is not positive.
    """
    if bucket_secs <= 0:
        # With a zero width, SQLite's division yields NULL, which would then
        # crash in int() below. A negative width maps bans into bucket <= 0,
        # which is meaningless.
        raise ValueError(f"bucket_secs must be positive, got {bucket_secs!r}")
    origin_clause, origin_params = _origin_sql_filter(origin)
    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            "SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, "
            "COUNT(*) AS cnt "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " GROUP BY bucket_idx "
            "ORDER BY bucket_idx",
            (since, bucket_secs, since, *origin_params),
        ) as cur:
            rows = await cur.fetchall()
    counts: list[int] = [0] * num_buckets
    for row in rows:
        idx: int = int(row["bucket_idx"])
        # Drop bans past the last requested bucket.
        if 0 <= idx < num_buckets:
            counts[idx] = int(row["cnt"])
    return counts
async def get_ban_event_counts(
    db_path: str,
    since: int,
    origin: BanOrigin | None = None,
) -> list[BanIpCount]:
    """Count ban rows per distinct IP with ``timeofban >= since``.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp; only rows at or after it are counted.
        origin: Optional origin filter.

    Returns:
        One ``BanIpCount`` per IP. No ``ORDER BY`` is applied, so the order
        is whatever SQLite produces for the ``GROUP BY``.
    """
    filter_sql, filter_args = _origin_sql_filter(origin)
    sql = (
        "SELECT ip, COUNT(*) AS event_count "
        "FROM bans "
        "WHERE timeofban >= ?" + filter_sql + " GROUP BY ip"
    )
    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(sql, (since, *filter_args)) as cursor:
            grouped = await cursor.fetchall()
    return [
        BanIpCount(ip=str(entry["ip"]), event_count=int(entry["event_count"]))
        for entry in grouped
    ]
async def get_bans_by_jail(
    db_path: str,
    since: int,
    origin: BanOrigin | None = None,
) -> tuple[int, list[JailBanCount]]:
    """Return the total ban count and per-jail ban counts.

    Every matching row falls into exactly one ``GROUP BY jail`` group, so the
    per-jail counts always sum to ``COUNT(*)``. Deriving the total from them
    needs only one query. It also guarantees that the total and the breakdown
    come from the same snapshot, so they stay consistent even if fail2ban
    writes concurrently.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: Unix timestamp; only rows with ``timeofban >= since`` count.
        origin: Optional origin filter.

    Returns:
        ``(total, per_jail)``, where *per_jail* is sorted by count, descending.
    """
    origin_clause, origin_params = _origin_sql_filter(origin)
    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            "SELECT jail, COUNT(*) AS cnt "
            "FROM bans "
            "WHERE timeofban >= ?" + origin_clause + " GROUP BY jail ORDER BY cnt DESC",
            (since, *origin_params),
        ) as cur:
            rows = await cur.fetchall()
    per_jail = [JailBanCount(jail=str(r["jail"]), count=int(r["cnt"])) for r in rows]
    total: int = sum(entry.count for entry in per_jail)
    return total, per_jail
async def get_bans_table_summary(
    db_path: str,
) -> tuple[int, int | None, int | None]:
    """Summarise the ``bans`` table as a row count plus its time span.

    Args:
        db_path: File path to the fail2ban SQLite database.

    Returns:
        ``(row_count, min_timeofban, max_timeofban)``. For an empty table,
        SQL ``MIN``/``MAX`` yield NULL, so both bounds are ``None``.
    """

    def as_optional_int(value: object) -> int | None:
        return None if value is None else int(value)

    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(
            "SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans"
        ) as cursor:
            summary = await cursor.fetchone()
    if summary is None:
        return 0, None, None
    row_count, oldest, newest = summary[0], summary[1], summary[2]
    return int(row_count), as_optional_int(oldest), as_optional_int(newest)
async def get_history_page(
    db_path: str,
    since: int | None = None,
    jail: str | None = None,
    ip_filter: str | None = None,
    page: int = 1,
    page_size: int = 100,
) -> tuple[list[HistoryRecord], int]:
    """Return one page of ban history (newest first) plus the total match count.

    Args:
        db_path: File path to the fail2ban SQLite database.
        since: If given, keep only rows with ``timeofban >= since``.
        jail: If given, keep only rows from exactly this jail.
        ip_filter: If given, keep only rows whose ``ip`` starts with this
            literal prefix. ``%``, ``_`` and backslash are escaped, so user
            input cannot inject ``LIKE`` wildcards. SQLite's default ``LIKE``
            is case-insensitive for ASCII.
        page: 1-based page number.
        page_size: Maximum number of rows on the page.

    Returns:
        ``(records, total)``, where *total* counts all matching rows and
        ignores pagination.
    """
    wheres: list[str] = []
    params: list[object] = []
    if since is not None:
        wheres.append("timeofban >= ?")
        params.append(since)
    if jail is not None:
        wheres.append("jail = ?")
        params.append(jail)
    if ip_filter is not None:
        # Escape the escape character first, then the LIKE wildcards, so the
        # filter is matched as a literal prefix.
        escaped = (
            ip_filter.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        )
        wheres.append("ip LIKE ? ESCAPE '\\'")
        params.append(f"{escaped}%")
    # where_sql is assembled only from the fixed fragments above; all user
    # values travel as bound parameters.
    where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else ""
    offset: int = (page - 1) * page_size
    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db:
        db.row_factory = aiosqlite.Row
        async with db.execute(
            f"SELECT COUNT(*) FROM bans {where_sql}",  # noqa: S608
            params,
        ) as cur:
            count_row = await cur.fetchone()
        total: int = int(count_row[0]) if count_row else 0
        async with db.execute(
            f"SELECT jail, ip, timeofban, bancount, data "  # noqa: S608
            f"FROM bans {where_sql} "
            "ORDER BY timeofban DESC "
            "LIMIT ? OFFSET ?",
            [*params, page_size, offset],
        ) as cur:
            rows = await cur.fetchall()
    return _rows_to_history_records(rows), total
async def get_history_for_ip(db_path: str, ip: str) -> list[HistoryRecord]:
    """Return every recorded ban of *ip*, newest first.

    Args:
        db_path: File path to the fail2ban SQLite database.
        ip: Address compared for exact equality against the ``ip`` column.

    Returns:
        All matching rows as ``HistoryRecord`` objects (empty if none).
    """
    timeline_sql = (
        "SELECT jail, ip, timeofban, bancount, data "
        "FROM bans "
        "WHERE ip = ? "
        "ORDER BY timeofban DESC"
    )
    async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as conn:
        conn.row_factory = aiosqlite.Row
        async with conn.execute(timeline_sql, (ip,)) as cursor:
            matched = await cursor.fetchall()
    return _rows_to_history_records(matched)