Add settings and history archive repository protocols and DI support

This commit is contained in:
2026-04-17 20:54:08 +02:00
parent 7055971163
commit db5b4cb77e
8 changed files with 137 additions and 27 deletions

View File

@@ -40,11 +40,7 @@ from app.models.ban import (
from app.models.ban import (
JailBanCount as JailBanCountModel,
)
from app.repositories import fail2ban_db_repo
from app.repositories.history_archive_repo import (
get_all_archived_history,
get_archived_history,
)
from app.repositories import fail2ban_db_repo, history_archive_repo as default_history_archive_repo
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
from app.utils.fail2ban_client import (
@@ -57,6 +53,7 @@ if TYPE_CHECKING:
import aiosqlite
from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
from app.repositories.protocols import HistoryArchiveRepository
log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -432,6 +429,7 @@ async def list_bans(
app_db: aiosqlite.Connection | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
origin: BanOrigin | None = None,
) -> DashboardBanListResponse:
"""Return a paginated list of bans within the selected time window.
@@ -482,7 +480,7 @@ async def list_bans(
if app_db is None:
raise ValueError("app_db must be provided when source is 'archive'")
rows, total = await get_archived_history(
rows, total = await history_archive_repo.get_archived_history(
db=app_db,
since=since,
origin=origin,
@@ -599,6 +597,7 @@ async def bans_by_country(
geo_cache_lookup: GeoCacheLookup | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None,
country_code: str | None = None,
@@ -648,7 +647,7 @@ async def bans_by_country(
if app_db is None:
raise ValueError("app_db must be provided when source is 'archive'")
all_rows = await get_all_archived_history(
all_rows = await history_archive_repo.get_all_archived_history(
db=app_db,
since=since,
origin=origin,
@@ -726,7 +725,7 @@ async def bans_by_country(
companion_rows: list[dict[str, object] | fail2ban_db_repo.BanRecord]
if country_code is None:
if source == "archive":
companion_rows, _ = await get_archived_history(
companion_rows, _ = await history_archive_repo.get_archived_history(
db=app_db,
since=since,
origin=origin,
@@ -751,7 +750,7 @@ async def bans_by_country(
if source == "archive":
if matched_ips:
companion_rows = await get_all_archived_history(
companion_rows = await history_archive_repo.get_all_archived_history(
db=app_db,
since=since,
origin=origin,
@@ -859,6 +858,7 @@ async def ban_trend(
range_: TimeRange,
*,
source: str = "fail2ban",
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None,
) -> BanTrendResponse:
@@ -899,7 +899,7 @@ async def ban_trend(
if app_db is None:
raise ValueError("app_db must be provided when source is 'archive'")
all_rows = await get_all_archived_history(
all_rows = await history_archive_repo.get_all_archived_history(
db=app_db,
since=since,
origin=origin,
@@ -966,6 +966,7 @@ async def bans_by_jail(
range_: TimeRange,
*,
source: str = "fail2ban",
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None,
) -> BansByJailResponse:
@@ -996,7 +997,7 @@ async def bans_by_jail(
if app_db is None:
raise ValueError("app_db must be provided when source is 'archive'")
all_rows = await get_all_archived_history(
all_rows = await history_archive_repo.get_all_archived_history(
db=app_db,
since=since,
origin=origin,

View File

@@ -23,14 +23,14 @@ if TYPE_CHECKING:
import aiohttp
from app.models.geo import GeoEnricher, GeoInfo
from app.repositories.protocols import HistoryArchiveRepository
from app.models.history import (
HistoryBanItem,
HistoryListResponse,
IpDetailResponse,
IpTimelineEvent,
)
from app.repositories import fail2ban_db_repo
from app.repositories.history_archive_repo import archive_ban_event, get_max_timeofban
from app.repositories import fail2ban_db_repo, history_archive_repo as default_history_archive_repo
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
@@ -88,14 +88,18 @@ _HISTORY_SYNC_PAGE_SIZE: int = 500
_HISTORY_SYNC_BACKFILL_WINDOW: int = 648000
async def _get_last_archive_ts(db: aiosqlite.Connection) -> int | None:
async def _get_last_archive_ts(
db: aiosqlite.Connection,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
) -> int | None:
"""Return the most recent archived ban timestamp, or ``None`` if empty."""
return await get_max_timeofban(db)
return await history_archive_repo.get_max_timeofban(db)
async def sync_from_fail2ban_db(
db: aiosqlite.Connection,
socket_path: str,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
) -> int:
"""Copy new records from the fail2ban DB into the BanGUI archive table.
@@ -106,7 +110,7 @@ async def sync_from_fail2ban_db(
Returns:
Number of fail2ban records scanned and archived.
"""
last_ts = await _get_last_archive_ts(db)
last_ts = await _get_last_archive_ts(db, history_archive_repo=history_archive_repo)
now_ts = int(datetime.now(tz=UTC).timestamp())
if last_ts is None:
@@ -129,7 +133,7 @@ async def sync_from_fail2ban_db(
break
for row in rows:
await archive_ban_event(
await history_archive_repo.archive_ban_event(
db=db,
jail=row.jail,
ip=row.ip,
@@ -167,6 +171,7 @@ async def list_history(
http_session: "aiohttp.ClientSession" | None = None,
geo_enricher: GeoEnricher | None = None,
db: aiosqlite.Connection | None = None,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
) -> HistoryListResponse:
"""Return a paginated list of historical ban records with optional filters.
@@ -214,9 +219,7 @@ async def list_history(
if db is None:
raise ValueError("db must be provided when source is 'archive'")
from app.repositories.history_archive_repo import get_archived_history
archived_rows, total = await get_archived_history(
archived_rows, total = await history_archive_repo.get_archived_history(
db=db,
since=since,
jail=jail,

View File

@@ -14,12 +14,14 @@ import bcrypt
import structlog
from app.db import init_db, open_db
from app.repositories import settings_repo
from app.repositories import settings_repo as default_settings_repo
from app.utils.async_utils import run_blocking
if TYPE_CHECKING:
import aiosqlite
from app.repositories.protocols import SettingsRepository
log: structlog.stdlib.BoundLogger = structlog.get_logger()
# Keys used in the settings table.
@@ -34,11 +36,15 @@ _KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
async def is_setup_complete(db: aiosqlite.Connection) -> bool:
async def is_setup_complete(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> bool:
"""Return ``True`` if initial setup has already been performed.
Args:
db: Active aiosqlite connection.
settings_repo: Repository interface for settings persistence.
Returns:
``True`` when the ``setup_completed`` key exists in settings.
@@ -55,6 +61,7 @@ async def run_setup(
fail2ban_socket: str,
timezone: str,
session_duration_minutes: int,
settings_repo: SettingsRepository = default_settings_repo,
) -> None:
"""Persist the initial configuration and mark setup as complete.
@@ -72,7 +79,7 @@ async def run_setup(
Raises:
RuntimeError: If setup has already been completed.
"""
if await is_setup_complete(db):
if await is_setup_complete(db, settings_repo=settings_repo):
raise RuntimeError("Setup has already been completed.")
log.info("bangui_setup_started")
@@ -120,17 +127,26 @@ async def run_setup(
log.info("bangui_setup_completed")
async def get_password_hash(db: aiosqlite.Connection) -> str | None:
async def get_password_hash(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> str | None:
"""Return the stored bcrypt password hash, or ``None`` if not set."""
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
async def get_runtime_database_path(db: aiosqlite.Connection) -> str | None:
async def get_runtime_database_path(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> str | None:
"""Return the runtime database path persisted during initial setup."""
return await settings_repo.get_setting(db, _KEY_DATABASE_PATH)
async def get_persisted_runtime_settings(db: aiosqlite.Connection) -> dict[str, str | int]:
async def get_persisted_runtime_settings(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> dict[str, str | int]:
"""Return runtime configuration values persisted during initial setup."""
runtime_settings: dict[str, str | int] = {}
@@ -183,7 +199,10 @@ async def _ensure_database_initialized(database_path: str) -> bool:
return True
async def get_timezone(db: aiosqlite.Connection) -> str:
async def get_timezone(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> str:
"""Return the configured IANA timezone string."""
tz = await settings_repo.get_setting(db, _KEY_TIMEZONE)
return tz if tz else "UTC"