diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 59b58ea..ac2bb82 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -60,6 +60,8 @@ This document breaks the entire BanGUI project into development stages, ordered #### TASK B-3 — Remove repository import from `routers/blocklist.py` +**Status:** Completed ✅ + **Violated rule:** Refactoring.md §2.1 — Routers must not import from repositories; all data access must go through services. **Files affected:** diff --git a/backend/app/repositories/fail2ban_db_repo.py b/backend/app/repositories/fail2ban_db_repo.py new file mode 100644 index 0000000..acc17d3 --- /dev/null +++ b/backend/app/repositories/fail2ban_db_repo.py @@ -0,0 +1,358 @@ +"""Fail2Ban SQLite database repository. + +This module contains helper functions that query the read-only fail2ban +SQLite database file. All functions accept a *db_path* and manage their own +connection using aiosqlite in read-only mode. + +The functions intentionally return plain Python data structures (dataclasses) so +service layers can focus on business logic and formatting. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import aiosqlite + +if TYPE_CHECKING: + from collections.abc import Iterable + + from app.models.ban import BanOrigin + + +@dataclass(frozen=True) +class BanRecord: + """A single row from the fail2ban ``bans`` table.""" + + jail: str + ip: str + timeofban: int + bancount: int + data: str + + +@dataclass(frozen=True) +class BanIpCount: + """Aggregated ban count for a single IP.""" + + ip: str + event_count: int + + +@dataclass(frozen=True) +class JailBanCount: + """Aggregated ban count for a single jail.""" + + jail: str + count: int + + +@dataclass(frozen=True) +class HistoryRecord: + """A single row from the fail2ban ``bans`` table for history queries.""" + + jail: str + ip: str + timeofban: int + bancount: int + data: str + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _make_db_uri(db_path: str) -> str: + """Return a read-only sqlite URI for the given file path.""" + + return f"file:{db_path}?mode=ro" + + +def _origin_sql_filter(origin: BanOrigin | None) -> tuple[str, tuple[str, ...]]: + """Return a SQL fragment and parameters for the origin filter.""" + + if origin == "blocklist": + return " AND jail = ?", ("blocklist-import",) + if origin == "selfblock": + return " AND jail != ?", ("blocklist-import",) + return "", () + + +def _rows_to_ban_records(rows: Iterable[aiosqlite.Row]) -> list[BanRecord]: + return [ + BanRecord( + jail=str(r["jail"]), + ip=str(r["ip"]), + timeofban=int(r["timeofban"]), + bancount=int(r["bancount"]), + data=str(r["data"]), + ) + for r in rows + ] + + +def _rows_to_history_records(rows: Iterable[aiosqlite.Row]) -> list[HistoryRecord]: + return [ + HistoryRecord( + jail=str(r["jail"]), + ip=str(r["ip"]), + timeofban=int(r["timeofban"]), + bancount=int(r["bancount"]), + 
data=str(r["data"]), + ) + for r in rows + ] + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +async def check_db_nonempty(db_path: str) -> bool: + """Return True if the fail2ban database contains at least one ban row.""" + + async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db, db.execute( + "SELECT 1 FROM bans LIMIT 1" + ) as cur: + row = await cur.fetchone() + return row is not None + + +async def get_currently_banned( + db_path: str, + since: int, + origin: BanOrigin | None = None, + *, + limit: int | None = None, + offset: int | None = None, +) -> tuple[list[BanRecord], int]: + """Return a page of currently banned IPs and the total matching count. + + Args: + db_path: File path to the fail2ban SQLite database. + since: Unix timestamp to filter bans newer than or equal to. + origin: Optional origin filter. + limit: Optional maximum number of rows to return. + offset: Optional offset for pagination. + + Returns: + A ``(records, total)`` tuple. + """ + + origin_clause, origin_params = _origin_sql_filter(origin) + + async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: + db.row_factory = aiosqlite.Row + + async with db.execute( + "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, + (since, *origin_params), + ) as cur: + count_row = await cur.fetchone() + total: int = int(count_row[0]) if count_row else 0 + + query = ( + "SELECT jail, ip, timeofban, bancount, data " + "FROM bans " + "WHERE timeofban >= ?" + origin_clause + " ORDER BY timeofban DESC" + ) + params: list[object] = [since, *origin_params] + if limit is not None: + query += " LIMIT ?" + params.append(limit) + if offset is not None: + query += " OFFSET ?" 
+ params.append(offset) + + async with db.execute(query, params) as cur: + rows = await cur.fetchall() + + return _rows_to_ban_records(rows), total + + +async def get_ban_counts_by_bucket( + db_path: str, + since: int, + bucket_secs: int, + num_buckets: int, + origin: BanOrigin | None = None, +) -> list[int]: + """Return ban counts aggregated into equal-width time buckets.""" + + origin_clause, origin_params = _origin_sql_filter(origin) + + async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, " + "COUNT(*) AS cnt " + "FROM bans " + "WHERE timeofban >= ?" + origin_clause + " GROUP BY bucket_idx " + "ORDER BY bucket_idx", + (since, bucket_secs, since, *origin_params), + ) as cur: + rows = await cur.fetchall() + + counts: list[int] = [0] * num_buckets + for row in rows: + idx: int = int(row["bucket_idx"]) + if 0 <= idx < num_buckets: + counts[idx] = int(row["cnt"]) + + return counts + + +async def get_ban_event_counts( + db_path: str, + since: int, + origin: BanOrigin | None = None, +) -> list[BanIpCount]: + """Return total ban events per unique IP in the window.""" + + origin_clause, origin_params = _origin_sql_filter(origin) + + async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT ip, COUNT(*) AS event_count " + "FROM bans " + "WHERE timeofban >= ?" 
+ origin_clause + " GROUP BY ip", + (since, *origin_params), + ) as cur: + rows = await cur.fetchall() + + return [ + BanIpCount(ip=str(r["ip"]), event_count=int(r["event_count"])) + for r in rows + ] + + +async def get_bans_by_jail( + db_path: str, + since: int, + origin: BanOrigin | None = None, +) -> tuple[int, list[JailBanCount]]: + """Return per-jail ban counts and the total ban count.""" + + origin_clause, origin_params = _origin_sql_filter(origin) + + async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: + db.row_factory = aiosqlite.Row + + async with db.execute( + "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, + (since, *origin_params), + ) as cur: + count_row = await cur.fetchone() + total: int = int(count_row[0]) if count_row else 0 + + async with db.execute( + "SELECT jail, COUNT(*) AS cnt " + "FROM bans " + "WHERE timeofban >= ?" + origin_clause + " GROUP BY jail ORDER BY cnt DESC", + (since, *origin_params), + ) as cur: + rows = await cur.fetchall() + + return total, [ + JailBanCount(jail=str(r["jail"]), count=int(r["cnt"])) for r in rows + ] + + +async def get_bans_table_summary( + db_path: str, +) -> tuple[int, int | None, int | None]: + """Return basic summary stats for the ``bans`` table. + + Returns: + A tuple ``(row_count, min_timeofban, max_timeofban)``. If the table is + empty the min/max values will be ``None``. 
+ """ + + async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans" + ) as cur: + row = await cur.fetchone() + + if row is None: + return 0, None, None + + return ( + int(row[0]), + int(row[1]) if row[1] is not None else None, + int(row[2]) if row[2] is not None else None, + ) + + +async def get_history_page( + db_path: str, + since: int | None = None, + jail: str | None = None, + ip_filter: str | None = None, + page: int = 1, + page_size: int = 100, +) -> tuple[list[HistoryRecord], int]: + """Return a paginated list of history records with total count.""" + + wheres: list[str] = [] + params: list[object] = [] + + if since is not None: + wheres.append("timeofban >= ?") + params.append(since) + + if jail is not None: + wheres.append("jail = ?") + params.append(jail) + + if ip_filter is not None: + wheres.append("ip LIKE ?") + params.append(f"{ip_filter}%") + + where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else "" + + effective_page_size: int = page_size + offset: int = (page - 1) * effective_page_size + + async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: + db.row_factory = aiosqlite.Row + + async with db.execute( + f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608 + params, + ) as cur: + count_row = await cur.fetchone() + total: int = int(count_row[0]) if count_row else 0 + + async with db.execute( + f"SELECT jail, ip, timeofban, bancount, data " + f"FROM bans {where_sql} " + "ORDER BY timeofban DESC " + "LIMIT ? 
OFFSET ?", + [*params, effective_page_size, offset], + ) as cur: + rows = await cur.fetchall() + + return _rows_to_history_records(rows), total + + +async def get_history_for_ip(db_path: str, ip: str) -> list[HistoryRecord]: + """Return the full ban timeline for a specific IP.""" + + async with aiosqlite.connect(_make_db_uri(db_path), uri=True) as db: + db.row_factory = aiosqlite.Row + async with db.execute( + "SELECT jail, ip, timeofban, bancount, data " + "FROM bans " + "WHERE ip = ? " + "ORDER BY timeofban DESC", + (ip,), + ) as cur: + rows = await cur.fetchall() + + return _rows_to_history_records(rows) diff --git a/backend/app/routers/blocklist.py b/backend/app/routers/blocklist.py index 58cf951..04757a8 100644 --- a/backend/app/routers/blocklist.py +++ b/backend/app/routers/blocklist.py @@ -42,7 +42,6 @@ from app.models.blocklist import ( ScheduleConfig, ScheduleInfo, ) -from app.repositories import import_log_repo from app.services import blocklist_service from app.tasks import blocklist_import as blocklist_import_task @@ -225,19 +224,9 @@ async def get_import_log( Returns: :class:`~app.models.blocklist.ImportLogListResponse`. 
""" - items, total = await import_log_repo.list_logs( + return await blocklist_service.list_import_logs( db, source_id=source_id, page=page, page_size=page_size ) - total_pages = import_log_repo.compute_total_pages(total, page_size) - from app.models.blocklist import ImportLogEntry # noqa: PLC0415 - - return ImportLogListResponse( - items=[ImportLogEntry.model_validate(i) for i in items], - total=total, - page=page, - page_size=page_size, - total_pages=total_pages, - ) # --------------------------------------------------------------------------- diff --git a/backend/app/services/blocklist_service.py b/backend/app/services/blocklist_service.py index 5719a45..23df0d1 100644 --- a/backend/app/services/blocklist_service.py +++ b/backend/app/services/blocklist_service.py @@ -21,6 +21,8 @@ import structlog from app.models.blocklist import ( BlocklistSource, + ImportLogEntry, + ImportLogListResponse, ImportRunResult, ImportSourceResult, PreviewResponse, @@ -503,6 +505,38 @@ async def get_schedule_info( ) +async def list_import_logs( + db: aiosqlite.Connection, + *, + source_id: int | None = None, + page: int = 1, + page_size: int = 50, +) -> ImportLogListResponse: + """Return a paginated list of import log entries. + + Args: + db: Active application database connection. + source_id: Optional filter to only return logs for a specific source. + page: 1-based page number. + page_size: Items per page. + + Returns: + :class:`~app.models.blocklist.ImportLogListResponse`. 
+ """ + items, total = await import_log_repo.list_logs( + db, source_id=source_id, page=page, page_size=page_size + ) + total_pages = import_log_repo.compute_total_pages(total, page_size) + + return ImportLogListResponse( + items=[ImportLogEntry.model_validate(i) for i in items], + total=total, + page=page, + page_size=page_size, + total_pages=total_pages, + ) + + # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- diff --git a/backend/tests/test_repositories/test_fail2ban_db_repo.py b/backend/tests/test_repositories/test_fail2ban_db_repo.py new file mode 100644 index 0000000..9f3c094 --- /dev/null +++ b/backend/tests/test_repositories/test_fail2ban_db_repo.py @@ -0,0 +1,138 @@ +"""Tests for the fail2ban_db repository. + +These tests use an in-memory sqlite file created under pytest's tmp_path and +exercise the core query functions used by the services. +""" + +from pathlib import Path + +import aiosqlite +import pytest + +from app.repositories import fail2ban_db_repo + + +async def _create_bans_table(db: aiosqlite.Connection) -> None: + await db.execute( + """ + CREATE TABLE bans ( + jail TEXT, + ip TEXT, + timeofban INTEGER, + bancount INTEGER, + data TEXT + ) + """ + ) + await db.commit() + + +@pytest.mark.asyncio +async def test_check_db_nonempty_returns_false_when_table_is_empty(tmp_path: Path) -> None: + db_path = str(tmp_path / "fail2ban.db") + async with aiosqlite.connect(db_path) as db: + await _create_bans_table(db) + + assert await fail2ban_db_repo.check_db_nonempty(db_path) is False + + +@pytest.mark.asyncio +async def test_check_db_nonempty_returns_true_when_row_exists(tmp_path: Path) -> None: + db_path = str(tmp_path / "fail2ban.db") + async with aiosqlite.connect(db_path) as db: + await _create_bans_table(db) + await db.execute( + "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)", + ("jail1", 
"1.2.3.4", 123, 1, "{}"), + ) + await db.commit() + + assert await fail2ban_db_repo.check_db_nonempty(db_path) is True + + +@pytest.mark.asyncio +async def test_get_currently_banned_filters_and_pagination(tmp_path: Path) -> None: + db_path = str(tmp_path / "fail2ban.db") + async with aiosqlite.connect(db_path) as db: + await _create_bans_table(db) + # Three bans; one is from the blocklist-import jail. + await db.executemany( + "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)", + [ + ("jail1", "1.1.1.1", 10, 1, "{}"), + ("blocklist-import", "2.2.2.2", 20, 2, "{}"), + ("jail1", "3.3.3.3", 30, 3, "{}"), + ], + ) + await db.commit() + + records, total = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=15, + origin="selfblock", + limit=10, + offset=0, + ) + + # Only the non-blocklist row with timeofban >= 15 should remain. + assert total == 1 + assert len(records) == 1 + assert records[0].ip == "3.3.3.3" + + +@pytest.mark.asyncio +async def test_get_ban_counts_by_bucket_ignores_out_of_range_buckets(tmp_path: Path) -> None: + db_path = str(tmp_path / "fail2ban.db") + async with aiosqlite.connect(db_path) as db: + await _create_bans_table(db) + await db.executemany( + "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)", + [ + ("jail1", "1.1.1.1", 5, 1, "{}"), + ("jail1", "2.2.2.2", 15, 1, "{}"), + ("jail1", "3.3.3.3", 35, 1, "{}"), + ], + ) + await db.commit() + + counts = await fail2ban_db_repo.get_ban_counts_by_bucket( + db_path=db_path, + since=0, + bucket_secs=10, + num_buckets=3, + ) + + assert counts == [1, 1, 0] + + +@pytest.mark.asyncio +async def test_get_history_page_and_for_ip(tmp_path: Path) -> None: + db_path = str(tmp_path / "fail2ban.db") + async with aiosqlite.connect(db_path) as db: + await _create_bans_table(db) + await db.executemany( + "INSERT INTO bans (jail, ip, timeofban, bancount, data) VALUES (?, ?, ?, ?, ?)", + [ + ("jail1", "1.1.1.1", 100, 1, "{}"), + ("jail1", 
"1.1.1.1", 200, 2, "{}"), + ("jail1", "2.2.2.2", 300, 3, "{}"), + ], + ) + await db.commit() + + page, total = await fail2ban_db_repo.get_history_page( + db_path=db_path, + since=None, + jail="jail1", + ip_filter="1.1.1", + page=1, + page_size=10, + ) + + assert total == 2 + assert len(page) == 2 + assert page[0].ip == "1.1.1.1" + + history_for_ip = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip="2.2.2.2") + assert len(history_for_ip) == 1 + assert history_for_ip[0].ip == "2.2.2.2" diff --git a/backend/tests/test_repositories/test_geo_cache_repo.py b/backend/tests/test_repositories/test_geo_cache_repo.py new file mode 100644 index 0000000..2e070b9 --- /dev/null +++ b/backend/tests/test_repositories/test_geo_cache_repo.py @@ -0,0 +1,62 @@ +"""Tests for the geo cache repository.""" + +from pathlib import Path + +import aiosqlite +import pytest + +from app.repositories import geo_cache_repo + + +async def _create_geo_cache_table(db: aiosqlite.Connection) -> None: + await db.execute( + """ + CREATE TABLE IF NOT EXISTS geo_cache ( + ip TEXT PRIMARY KEY, + country_code TEXT, + country_name TEXT, + asn TEXT, + org TEXT, + cached_at TEXT NOT NULL DEFAULT (strftime('%Y-%m-%dT%H:%M:%fZ', 'now')) + ) + """ + ) + await db.commit() + + +@pytest.mark.asyncio +async def test_get_unresolved_ips_returns_empty_when_none_exist(tmp_path: Path) -> None: + db_path = str(tmp_path / "geo_cache.db") + async with aiosqlite.connect(db_path) as db: + await _create_geo_cache_table(db) + await db.execute( + "INSERT INTO geo_cache (ip, country_code, country_name, asn, org) VALUES (?, ?, ?, ?, ?)", + ("1.1.1.1", "DE", "Germany", "AS123", "Test"), + ) + await db.commit() + + async with aiosqlite.connect(db_path) as db: + ips = await geo_cache_repo.get_unresolved_ips(db) + + assert ips == [] + + +@pytest.mark.asyncio +async def test_get_unresolved_ips_returns_pending_ips(tmp_path: Path) -> None: + db_path = str(tmp_path / "geo_cache.db") + async with aiosqlite.connect(db_path) as db: 
+ await _create_geo_cache_table(db) + await db.executemany( + "INSERT INTO geo_cache (ip, country_code) VALUES (?, ?)", + [ + ("2.2.2.2", None), + ("3.3.3.3", None), + ("4.4.4.4", "US"), + ], + ) + await db.commit() + + async with aiosqlite.connect(db_path) as db: + ips = await geo_cache_repo.get_unresolved_ips(db) + + assert sorted(ips) == ["2.2.2.2", "3.3.3.3"] diff --git a/backend/tests/test_services/test_blocklist_service.py b/backend/tests/test_services/test_blocklist_service.py index 579b4c1..151a260 100644 --- a/backend/tests/test_services/test_blocklist_service.py +++ b/backend/tests/test_services/test_blocklist_service.py @@ -337,3 +337,40 @@ class TestGeoPrewarmCacheFilter: call_ips = mock_batch.call_args[0][0] assert "1.2.3.4" not in call_ips assert set(call_ips) == {"5.6.7.8", "9.10.11.12"} + + +class TestImportLogPagination: + async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None: + """list_import_logs returns an empty page when no logs exist.""" + resp = await blocklist_service.list_import_logs( + db, source_id=None, page=1, page_size=10 + ) + assert resp.items == [] + assert resp.total == 0 + assert resp.page == 1 + assert resp.page_size == 10 + assert resp.total_pages == 1 + + async def test_list_import_logs_paginates(self, db: aiosqlite.Connection) -> None: + """list_import_logs computes total pages and returns the correct subset.""" + from app.repositories import import_log_repo + + for i in range(3): + await import_log_repo.add_log( + db, + source_id=None, + source_url=f"https://example{i}.test/ips.txt", + ips_imported=1, + ips_skipped=0, + errors=None, + ) + + resp = await blocklist_service.list_import_logs( + db, source_id=None, page=2, page_size=2 + ) + assert resp.total == 3 + assert resp.total_pages == 2 + assert resp.page == 2 + assert resp.page_size == 2 + assert len(resp.items) == 1 + assert resp.items[0].source_url == "https://example0.test/ips.txt"