"""Fail2Ban SQLite database repository. This module contains helper functions that query the read-only fail2ban SQLite database file. All functions accept a *db_path* and manage their own connection using aiosqlite in read-only mode. The functions intentionally return plain Python data structures (dataclasses) so service layers can focus on business logic and formatting. """ from __future__ import annotations from dataclasses import dataclass from typing import TYPE_CHECKING import aiosqlite if TYPE_CHECKING: from collections.abc import Iterable from app.models.ban import BanOrigin @dataclass(frozen=True) class BanRecord: """A single row from the fail2ban ``bans`` table.""" jail: str ip: str timeofban: int bancount: int data: str @dataclass(frozen=True) class BanIpCount: """Aggregated ban count for a single IP.""" ip: str event_count: int @dataclass(frozen=True) class JailBanCount: """Aggregated ban count for a single jail.""" jail: str count: int @dataclass(frozen=True) class HistoryRecord: """A single row from the fail2ban ``bans`` table for history queries.""" jail: str ip: str timeofban: int bancount: int data: str # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _make_db_uri(db_path: str) -> str: """Return a read-only sqlite URI for the given file path.""" return f"file:{db_path}?mode=ro" def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]: """Return a SQL fragment and parameters for the origin filter.""" if origin == "blocklist": return " AND jail = ?", ("blocklist-import",) if origin == "selfblock": return " AND jail != ?", ("blocklist-import",) return "", () def _rows_to_ban_records(rows: Iterable[aiosqlite.Row]) -> list[BanRecord]: return [ BanRecord( jail=str(r["jail"]), ip=str(r["ip"]), timeofban=int(r["timeofban"]), bancount=int(r["bancount"]), data=str(r["data"]), ) for r in rows ] def _rows_to_history_records(rows: Iterable[aiosqlite.Row]) -> list[HistoryRecord]: return [ HistoryRecord( jail=str(r["jail"]), ip=str(r["ip"]), timeofban=int(r["timeofban"]), bancount=int(r["bancount"]), data=str(r["data"]), ) for r in rows ] # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- async def check_db_nonempty(db_path: str) -> bool: """Return True if the fail2ban database contains at least one ban row.""" async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db, db.execute( "SELECT 1 FROM bans LIMIT 1" ) as cur: row = await cur.fetchone() return row is not None async def get_currently_banned( db_path: str, since: int, origin: BanOrigin | None = None, *, limit: int | None = None, offset: int | None = None, ) -> tuple[list[BanRecord], int]: """Return a page of currently banned IPs and the total matching count. Args: db_path: File path to the fail2ban SQLite database. since: Unix timestamp to filter bans newer than or equal to. origin: Optional origin filter. limit: Optional maximum number of rows to return. offset: Optional offset for pagination. Returns: A ``(records, total)`` tuple. """ origin_clause, origin_params = _origin_sql_filter(origin) async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, (since, *origin_params), ) as cur: count_row = await cur.fetchone() total: int = int(count_row[0]) if count_row else 0 query = ( "SELECT jail, ip, timeofban, bancount, data " "FROM bans " "WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC" ) params: list[object] = [since, *origin_params] if limit is not None: query += " LIMIT ?" params.append(limit) if offset is not None: query += " OFFSET ?" params.append(offset) async with db.execute(query, params) as cur: rows = await cur.fetchall() return _rows_to_ban_records(rows), total async def get_ban_counts_by_bucket( db_path: str, since: int, bucket_secs: int, num_buckets: int, origin: BanOrigin | None = None, ) -> list[int]: """Return ban counts aggregated into equal-width time buckets.""" origin_clause, origin_params = _origin_sql_filter(origin) async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, " "COUNT(*) AS cnt " "FROM bans " "WHERE timeofban >= ?" + origin_clause + " GROUP BY bucket_idx " "ORDER BY bucket_idx", (since, bucket_secs, since, *origin_params), ) as cur: rows = await cur.fetchall() counts: list[int] = [0] * num_buckets for row in rows: idx: int = int(row["bucket_idx"]) if 0 <= idx < num_buckets: counts[idx] = int(row["cnt"]) return counts async def get_ban_event_counts( db_path: str, since: int, origin: BanOrigin | None = None, ) -> list[BanIpCount]: """Return total ban events per unique IP in the window.""" origin_clause, origin_params = _origin_sql_filter(origin) async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT ip, COUNT(*) AS event_count " "FROM bans " "WHERE timeofban >= ?" + origin_clause + " GROUP BY ip", (since, *origin_params), ) as cur: rows = await cur.fetchall() return [ BanIpCount(ip=str(r["ip"]), event_count=int(r["event_count"])) for r in rows ] async def get_bans_by_jail( db_path: str, since: int, origin: BanOrigin | None = None, ) -> tuple[int, list[JailBanCount]]: """Return per-jail ban counts and the total ban count.""" origin_clause, origin_params = _origin_sql_filter(origin) async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, (since, *origin_params), ) as cur: count_row = await cur.fetchone() total: int = int(count_row[0]) if count_row else 0 async with db.execute( "SELECT jail, COUNT(*) AS cnt " "FROM bans " "WHERE timeofban >= ?" + origin_clause + " GROUP BY jail ORDER BY cnt DESC", (since, *origin_params), ) as cur: rows = await cur.fetchall() return total, [ JailBanCount(jail=str(r["jail"]), count=int(r["cnt"])) for r in rows ] async def get_bans_table_summary( db_path: str, ) -> tuple[int, int | None, int | None]: """Return basic summary stats for the ``bans`` table. Returns: A tuple ``(row_count, min_timeofban, max_timeofban)``. If the table is empty the min/max values will be ``None``. """ async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans" ) as cur: row = await cur.fetchone() if row is None: return 0, None, None return ( int(row[0]), int(row[1]) if row[1] is not None else None, int(row[2]) if row[2] is not None else None, ) async def get_history_page( db_path: str, since: int | None = None, jail: str | None = None, ip_filter: str | None = None, page: int = 1, page_size: int = 100, ) -> tuple[list[HistoryRecord], int]: """Return a paginated list of history records with total count.""" wheres: list[str] = [] params: list[object] = [] if since is not None: wheres.append("timeofban >= ?") params.append(since) if jail is not None: wheres.append("jail = ?") params.append(jail) if ip_filter is not None: wheres.append("ip LIKE ?") params.append(f"{ip_filter}%") where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else "" effective_page_size: int = page_size offset: int = (page - 1) * effective_page_size async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: db.row_factory = aiosqlite.Row async with db.execute( f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608 params, ) as cur: count_row = await cur.fetchone() total: int = int(count_row[0]) if count_row else 0 async with db.execute( f"SELECT jail, ip, timeofban, bancount, data " f"FROM bans {where_sql} " "ORDER BY timeofban DESC " "LIMIT ? OFFSET ?", [*params, effective_page_size, offset], ) as cur: rows = await cur.fetchall() return _rows_to_history_records(rows), total async def get_history_for_ip(db_path: str, ip: str) -> list[HistoryRecord]: """Return the full ban timeline for a specific IP.""" async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: db.row_factory = aiosqlite.Row async with db.execute( "SELECT jail, ip, timeofban, bancount, data " "FROM bans " "WHERE ip = ? " "ORDER BY timeofban DESC", (ip,), ) as cur: rows = await cur.fetchall() return _rows_to_history_records(rows)