diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 3c0e095..59b58ea 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -16,7 +16,9 @@ This document breaks the entire BanGUI project into development stages, ordered --- -#### TASK B-1 — Create a `fail2ban_db` repository for direct fail2ban database queries +#### TASK B-1 — Create a `fail2ban_db` repository for direct fail2ban database queries ✅ + +**Status:** Completed **Violated rule:** Refactoring.md §2.2 — Services must not perform direct `aiosqlite` calls; go through a repository. @@ -41,6 +43,8 @@ This document breaks the entire BanGUI project into development stages, ordered #### TASK B-2 — Remove direct SQL query from `routers/geo.py` +**Status:** Completed ✅ + **Violated rule:** Refactoring.md §2.1 — Routers must contain zero business logic; no SQL or repository imports. **Files affected:** diff --git a/backend/app/repositories/geo_cache_repo.py b/backend/app/repositories/geo_cache_repo.py new file mode 100644 index 0000000..8e7ed8d --- /dev/null +++ b/backend/app/repositories/geo_cache_repo.py @@ -0,0 +1,33 @@ +"""Repository for the geo cache persistent store. + +This module provides typed, async helpers for querying and mutating the +``geo_cache`` table in the BanGUI application database. + +All functions accept an open :class:`aiosqlite.Connection` and do not manage +connection lifetimes. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + import aiosqlite + + +async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]: + """Return all IPs in ``geo_cache`` where ``country_code`` is NULL. + + Args: + db: Open BanGUI application database connection. + + Returns: + List of IPv4/IPv6 strings that need geo resolution. + """ + ips: list[str] = [] + async with db.execute( + "SELECT ip FROM geo_cache WHERE country_code IS NULL" + ) as cur: + async for row in cur: + ips.append(str(row[0])) + return ips diff --git a/backend/app/routers/geo.py b/backend/app/routers/geo.py index 0200496..2b0abfc 100644 --- a/backend/app/routers/geo.py +++ b/backend/app/routers/geo.py @@ -153,12 +153,7 @@ async def re_resolve_geo( that were retried. """ # Collect all IPs in geo_cache that still lack a country code. - unresolved: list[str] = [] - async with db.execute( - "SELECT ip FROM geo_cache WHERE country_code IS NULL" - ) as cur: - async for row in cur: - unresolved.append(str(row[0])) + unresolved = await geo_service.get_unresolved_ips(db) if not unresolved: return {"resolved": 0, "total": 0} diff --git a/backend/app/services/ban_service.py b/backend/app/services/ban_service.py index ab08ab4..14c9bc7 100644 --- a/backend/app/services/ban_service.py +++ b/backend/app/services/ban_service.py @@ -13,12 +13,15 @@ from __future__ import annotations import asyncio import json import time +from dataclasses import asdict from datetime import UTC, datetime from typing import TYPE_CHECKING, Any -import aiosqlite import structlog +if TYPE_CHECKING: + import aiosqlite + from app.models.ban import ( BLOCKLIST_JAIL, BUCKET_SECONDS, @@ -31,11 +34,11 @@ from app.models.ban import ( BanTrendResponse, DashboardBanItem, DashboardBanListResponse, - JailBanCount, TimeRange, _derive_origin, bucket_count, ) +from app.repositories import fail2ban_db_repo from app.utils.fail2ban_client import Fail2BanClient if TYPE_CHECKING: @@ -244,33 +247,20 @@ async def list_bans( origin=origin, ) - async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: - f2b_db.row_factory = aiosqlite.Row - - async with f2b_db.execute( - "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, - (since, *origin_params), - ) as cur: - count_row = await cur.fetchone() - total: int = int(count_row[0]) if count_row else 0 - - async with f2b_db.execute( - "SELECT jail, ip, timeofban, bancount, data " - "FROM bans " - "WHERE timeofban >= ?" - + origin_clause - + " ORDER BY timeofban DESC " - "LIMIT ? OFFSET ?", - (since, *origin_params, effective_page_size, offset), - ) as cur: - rows = await cur.fetchall() + rows, total = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=since, + origin=origin, + limit=effective_page_size, + offset=offset, + ) # Batch-resolve geo data for all IPs on this page in a single API call. # This avoids hitting the 45 req/min single-IP rate limit when the # page contains many bans (e.g. after a large blocklist import). geo_map: dict[str, Any] = {} if http_session is not None and rows: - page_ips: list[str] = [str(r["ip"]) for r in rows] + page_ips: list[str] = [r.ip for r in rows] try: geo_map = await geo_service.lookup_batch(page_ips, http_session, db=app_db) except Exception: # noqa: BLE001 @@ -278,11 +268,11 @@ async def list_bans( items: list[DashboardBanItem] = [] for row in rows: - jail: str = str(row["jail"]) - ip: str = str(row["ip"]) - banned_at: str = _ts_to_iso(int(row["timeofban"])) - ban_count: int = int(row["bancount"]) - matches, _ = _parse_data_json(row["data"]) + jail: str = row.jail + ip: str = row.ip + banned_at: str = _ts_to_iso(row.timeofban) + ban_count: int = row.bancount + matches, _ = _parse_data_json(row.data) service: str | None = matches[0] if matches else None country_code: str | None = None @@ -395,42 +385,31 @@ async def bans_by_country( origin=origin, ) - async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: - f2b_db.row_factory = aiosqlite.Row + # Total count and companion rows reuse the same SQL query logic. + # Passing limit=0 returns only the total from the count query. + _, total = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=since, + origin=origin, + limit=0, + offset=0, + ) - # Total count for the window. - async with f2b_db.execute( - "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, - (since, *origin_params), - ) as cur: - count_row = await cur.fetchone() - total: int = int(count_row[0]) if count_row else 0 + agg_rows = await fail2ban_db_repo.get_ban_event_counts( + db_path=db_path, + since=since, + origin=origin, + ) - # Aggregation: unique IPs + their total event count. - # No LIMIT here — we need all unique source IPs for accurate country counts. - async with f2b_db.execute( - "SELECT ip, COUNT(*) AS event_count " - "FROM bans " - "WHERE timeofban >= ?" - + origin_clause - + " GROUP BY ip", - (since, *origin_params), - ) as cur: - agg_rows = await cur.fetchall() + companion_rows, _ = await fail2ban_db_repo.get_currently_banned( + db_path=db_path, + since=since, + origin=origin, + limit=_MAX_COMPANION_BANS, + offset=0, + ) - # Companion table: most recent raw rows for display alongside the map. - async with f2b_db.execute( - "SELECT jail, ip, timeofban, bancount, data " - "FROM bans " - "WHERE timeofban >= ?" - + origin_clause - + " ORDER BY timeofban DESC " - "LIMIT ?", - (since, *origin_params, _MAX_COMPANION_BANS), - ) as cur: - companion_rows = await cur.fetchall() - - unique_ips: list[str] = [str(r["ip"]) for r in agg_rows] + unique_ips: list[str] = [r.ip for r in agg_rows] geo_map: dict[str, Any] = {} if http_session is not None and unique_ips: @@ -467,11 +446,11 @@ async def bans_by_country( country_names: dict[str, str] = {} for row in agg_rows: - ip: str = str(row["ip"]) + ip: str = row.ip geo = geo_map.get(ip) cc: str | None = geo.country_code if geo else None cn: str | None = geo.country_name if geo else None - event_count: int = int(row["event_count"]) + event_count: int = row.event_count if cc: countries[cc] = countries.get(cc, 0) + event_count @@ -481,26 +460,26 @@ async def bans_by_country( # Build companion table from recent rows (geo already cached from batch step). bans: list[DashboardBanItem] = [] for row in companion_rows: - ip = str(row["ip"]) + ip = row.ip geo = geo_map.get(ip) cc = geo.country_code if geo else None cn = geo.country_name if geo else None asn: str | None = geo.asn if geo else None org: str | None = geo.org if geo else None - matches, _ = _parse_data_json(row["data"]) + matches, _ = _parse_data_json(row.data) bans.append( DashboardBanItem( ip=ip, - jail=str(row["jail"]), - banned_at=_ts_to_iso(int(row["timeofban"])), + jail=row.jail, + banned_at=_ts_to_iso(row.timeofban), service=matches[0] if matches else None, country_code=cc, country_name=cn, asn=asn, org=org, - ban_count=int(row["bancount"]), - origin=_derive_origin(str(row["jail"])), + ban_count=row.bancount, + origin=_derive_origin(row.jail), ) ) @@ -565,32 +544,18 @@ async def ban_trend( num_buckets=num_buckets, ) - async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: - f2b_db.row_factory = aiosqlite.Row - - async with f2b_db.execute( - "SELECT CAST((timeofban - ?) / ? AS INTEGER) AS bucket_idx, " - "COUNT(*) AS cnt " - "FROM bans " - "WHERE timeofban >= ?" - + origin_clause - + " GROUP BY bucket_idx " - "ORDER BY bucket_idx", - (since, bucket_secs, since, *origin_params), - ) as cur: - rows = await cur.fetchall() - - # Map bucket_idx → count; ignore any out-of-range indices. - counts: dict[int, int] = {} - for row in rows: - idx: int = int(row["bucket_idx"]) - if 0 <= idx < num_buckets: - counts[idx] = int(row["cnt"]) + counts = await fail2ban_db_repo.get_ban_counts_by_bucket( + db_path=db_path, + since=since, + bucket_secs=bucket_secs, + num_buckets=num_buckets, + origin=origin, + ) buckets: list[BanTrendBucket] = [ BanTrendBucket( timestamp=_ts_to_iso(since + i * bucket_secs), - count=counts.get(i, 0), + count=counts[i], ) for i in range(num_buckets) ] @@ -643,50 +608,37 @@ async def bans_by_jail( origin=origin, ) - async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: - f2b_db.row_factory = aiosqlite.Row + total, jails = await fail2ban_db_repo.get_bans_by_jail( + db_path=db_path, + since=since, + origin=origin, + ) - async with f2b_db.execute( - "SELECT COUNT(*) FROM bans WHERE timeofban >= ?" + origin_clause, - (since, *origin_params), - ) as cur: - count_row = await cur.fetchone() - total: int = int(count_row[0]) if count_row else 0 + # Diagnostic guard: if zero results were returned, check whether the table + # has *any* rows and log a warning with min/max timeofban so operators can + # diagnose timezone or filter mismatches from logs. + if total == 0: + table_row_count, min_timeofban, max_timeofban = ( + await fail2ban_db_repo.get_bans_table_summary(db_path) + ) + if table_row_count > 0: + log.warning( + "ban_service_bans_by_jail_empty_despite_data", + table_row_count=table_row_count, + min_timeofban=min_timeofban, + max_timeofban=max_timeofban, + since=since, + range=range_, + ) - # Diagnostic guard: if zero results were returned, check whether the - # table has *any* rows and log a warning with min/max timeofban so - # operators can diagnose timezone or filter mismatches from logs. - if total == 0: - async with f2b_db.execute( - "SELECT COUNT(*), MIN(timeofban), MAX(timeofban) FROM bans" - ) as cur: - diag_row = await cur.fetchone() - if diag_row and diag_row[0] > 0: - log.warning( - "ban_service_bans_by_jail_empty_despite_data", - table_row_count=diag_row[0], - min_timeofban=diag_row[1], - max_timeofban=diag_row[2], - since=since, - range=range_, - ) - - async with f2b_db.execute( - "SELECT jail, COUNT(*) AS cnt " - "FROM bans " - "WHERE timeofban >= ?" - + origin_clause - + " GROUP BY jail ORDER BY cnt DESC", - (since, *origin_params), - ) as cur: - rows = await cur.fetchall() - - jails: list[JailBanCount] = [ - JailBanCount(jail=str(row["jail"]), count=int(row["cnt"])) for row in rows - ] log.debug( "ban_service_bans_by_jail_result", total=total, jail_count=len(jails), ) - return BansByJailResponse(jails=jails, total=total) + + # Pydantic strict validation requires either dicts or model instances. + # Our repository returns dataclasses for simplicity, so convert them here. + jail_dicts: list[dict[str, object]] = [asdict(j) for j in jails] + + return BansByJailResponse(jails=jail_dicts, total=total) diff --git a/backend/app/services/geo_service.py b/backend/app/services/geo_service.py index 325517e..95f5927 100644 --- a/backend/app/services/geo_service.py +++ b/backend/app/services/geo_service.py @@ -46,6 +46,8 @@ from typing import TYPE_CHECKING import aiohttp import structlog +from app.repositories import geo_cache_repo + if TYPE_CHECKING: import aiosqlite import geoip2.database @@ -198,6 +200,18 @@ async def cache_stats(db: aiosqlite.Connection) -> dict[str, int]: } +async def get_unresolved_ips(db: aiosqlite.Connection) -> list[str]: + """Return geo cache IPs where the country code has not yet been resolved. + + Args: + db: Open BanGUI application database connection. + + Returns: + List of IP addresses that are candidates for re-resolution. + """ + return await geo_cache_repo.get_unresolved_ips(db) + + def init_geoip(mmdb_path: str | None) -> None: """Initialise the MaxMind GeoLite2-Country database reader. diff --git a/backend/app/services/history_service.py b/backend/app/services/history_service.py index 26c2f78..bad337b 100644 --- a/backend/app/services/history_service.py +++ b/backend/app/services/history_service.py @@ -13,7 +13,6 @@ from __future__ import annotations from datetime import UTC, datetime from typing import Any -import aiosqlite import structlog from app.models.ban import TIME_RANGE_SECONDS, TimeRange @@ -23,6 +22,7 @@ from app.models.history import ( IpDetailResponse, IpTimelineEvent, ) +from app.repositories import fail2ban_db_repo from app.services.ban_service import _get_fail2ban_db_path, _parse_data_json, _ts_to_iso log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -84,26 +84,11 @@ async def list_history( and the total matching count. """ effective_page_size: int = min(page_size, _MAX_PAGE_SIZE) - offset: int = (page - 1) * effective_page_size # Build WHERE clauses dynamically. - wheres: list[str] = [] - params: list[Any] = [] - + since: int | None = None if range_ is not None: - since: int = _since_unix(range_) - wheres.append("timeofban >= ?") - params.append(since) - - if jail is not None: - wheres.append("jail = ?") - params.append(jail) - - if ip_filter is not None: - wheres.append("ip LIKE ?") - params.append(f"{ip_filter}%") - - where_sql: str = ("WHERE " + " AND ".join(wheres)) if wheres else "" + since = _since_unix(range_) db_path: str = await _get_fail2ban_db_path(socket_path) log.info( @@ -115,32 +100,22 @@ async def list_history( page=page, ) - async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: - f2b_db.row_factory = aiosqlite.Row - - async with f2b_db.execute( - f"SELECT COUNT(*) FROM bans {where_sql}", # noqa: S608 - params, - ) as cur: - count_row = await cur.fetchone() - total: int = int(count_row[0]) if count_row else 0 - - async with f2b_db.execute( - f"SELECT jail, ip, timeofban, bancount, data " # noqa: S608 - f"FROM bans {where_sql} " - "ORDER BY timeofban DESC " - "LIMIT ? OFFSET ?", - [*params, effective_page_size, offset], - ) as cur: - rows = await cur.fetchall() + rows, total = await fail2ban_db_repo.get_history_page( + db_path=db_path, + since=since, + jail=jail, + ip_filter=ip_filter, + page=page, + page_size=effective_page_size, + ) items: list[HistoryBanItem] = [] for row in rows: - jail_name: str = str(row["jail"]) - ip: str = str(row["ip"]) - banned_at: str = _ts_to_iso(int(row["timeofban"])) - ban_count: int = int(row["bancount"]) - matches, failures = _parse_data_json(row["data"]) + jail_name: str = row.jail + ip: str = row.ip + banned_at: str = _ts_to_iso(row.timeofban) + ban_count: int = row.bancount + matches, failures = _parse_data_json(row.data) country_code: str | None = None country_name: str | None = None @@ -205,16 +180,7 @@ async def get_ip_detail( db_path: str = await _get_fail2ban_db_path(socket_path) log.info("history_service_ip_detail", db_path=db_path, ip=ip) - async with aiosqlite.connect(f"file:{db_path}?mode=ro", uri=True) as f2b_db: - f2b_db.row_factory = aiosqlite.Row - async with f2b_db.execute( - "SELECT jail, ip, timeofban, bancount, data " - "FROM bans " - "WHERE ip = ? " - "ORDER BY timeofban DESC", - (ip,), - ) as cur: - rows = await cur.fetchall() + rows = await fail2ban_db_repo.get_history_for_ip(db_path=db_path, ip=ip) if not rows: return None @@ -223,10 +189,10 @@ async def get_ip_detail( total_failures: int = 0 for row in rows: - jail_name: str = str(row["jail"]) - banned_at: str = _ts_to_iso(int(row["timeofban"])) - ban_count: int = int(row["bancount"]) - matches, failures = _parse_data_json(row["data"]) + jail_name: str = row.jail + banned_at: str = _ts_to_iso(row.timeofban) + ban_count: int = row.bancount + matches, failures = _parse_data_json(row.data) total_failures += failures timeline.append( IpTimelineEvent(