From db5b4cb77e0f0c3621c4e0b5473bc21bf7cd1562 Mon Sep 17 00:00:00 2001 From: Lukas Date: Fri, 17 Apr 2026 20:54:08 +0200 Subject: [PATCH] Add settings and history archive repository protocols and DI support --- Docs/Architekture.md | 2 ++ Docs/Tasks.md | 2 ++ backend/app/dependencies.py | 18 ++++++++++ backend/app/repositories/protocols.py | 48 +++++++++++++++++++++++++ backend/app/services/ban_service.py | 23 ++++++------ backend/app/services/history_service.py | 21 ++++++----- backend/app/services/setup_service.py | 33 +++++++++++++---- backend/tests/test_dependencies.py | 17 +++++++++ 8 files changed, 137 insertions(+), 27 deletions(-) diff --git a/Docs/Architekture.md b/Docs/Architekture.md index 7d2fb4b..0c030f8 100644 --- a/Docs/Architekture.md +++ b/Docs/Architekture.md @@ -208,6 +208,8 @@ The data access layer. Repositories execute raw SQL queries against the applicat | `geo_cache_repo.py` | Persist and query IP geo resolution cache | | `import_log_repo.py` | Record import run results (timestamp, source, IPs imported, errors) for the import log view | +Every repository in `app/repositories/` has a corresponding protocol in `app/repositories/protocols.py`, including `settings_repo.py` and `history_archive_repo.py`. + #### Models (`app/models/`) Pydantic schemas that define data shapes and validation. Models are split into three categories per domain. diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 3b9c943..46f2ece 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -29,6 +29,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. ### 2. Missing Repository Protocols for `settings_repo` and `history_archive_repo` +**Status:** Completed. + **Where:** `backend/app/repositories/protocols.py` defines protocols for 5 of the 7 repositories. The two missing are `settings_repo.py` and `history_archive_repo.py`. **Goal:** Define `SettingsRepository` and `HistoryArchiveRepository` protocols in `protocols.py` that match the public functions of `settings_repo` (`get_setting`, `set_setting`, `delete_setting`, `get_all_settings`) and `history_archive_repo` (`archive_ban_event`, `get_max_timeofban`, `get_archived_history`). Add corresponding dependency providers in `dependencies.py` and typed aliases so these repositories can be injected the same way as the other five. diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index 566de0a..3b0563e 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -26,7 +26,9 @@ from app.repositories.protocols import ( BlocklistRepository, Fail2BanDbRepository, GeoCacheRepository, + HistoryArchiveRepository, ImportLogRepository, + SettingsRepository, SessionRepository, ) from app.utils.constants import SESSION_COOKIE_NAME @@ -251,6 +253,20 @@ async def get_import_log_repo() -> ImportLogRepository: return cast("ImportLogRepository", import_log_repo) +async def get_settings_repo() -> SettingsRepository: + """Provide the concrete settings repository implementation.""" + from app.repositories import settings_repo # noqa: PLC0415 + + return cast("SettingsRepository", settings_repo) + + +async def get_history_archive_repo() -> HistoryArchiveRepository: + """Provide the concrete history archive repository implementation.""" + from app.repositories import history_archive_repo # noqa: PLC0415 + + return cast("HistoryArchiveRepository", history_archive_repo) + + async def get_geo_cache_repo() -> GeoCacheRepository: """Provide the concrete geo cache repository implementation.""" from app.repositories import geo_cache_repo # noqa: PLC0415 @@ -370,6 +386,8 @@ ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)] PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)] SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)] SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)] +SettingsRepoDep = Annotated[SettingsRepository, Depends(get_settings_repo)] +HistoryArchiveRepositoryDep = Annotated[HistoryArchiveRepository, Depends(get_history_archive_repo)] BlocklistRepositoryDep = Annotated[BlocklistRepository, Depends(get_blocklist_repo)] ImportLogRepositoryDep = Annotated[ImportLogRepository, Depends(get_import_log_repo)] GeoCacheRepositoryDep = Annotated[GeoCacheRepository, Depends(get_geo_cache_repo)] diff --git a/backend/app/repositories/protocols.py b/backend/app/repositories/protocols.py index 85d4ba2..f0fa58b 100644 --- a/backend/app/repositories/protocols.py +++ b/backend/app/repositories/protocols.py @@ -50,6 +50,22 @@ class SessionRepository(Protocol): ... +class SettingsRepository(Protocol): + """Protocol for application settings persistence operations.""" + + async def get_setting(self, db: aiosqlite.Connection, key: str) -> str | None: + ... + + async def set_setting(self, db: aiosqlite.Connection, key: str, value: str) -> None: + ... + + async def delete_setting(self, db: aiosqlite.Connection, key: str) -> None: + ... + + async def get_all_settings(self, db: aiosqlite.Connection) -> dict[str, str]: + ... + + class BlocklistRepository(Protocol): async def create_source( self, @@ -154,6 +170,38 @@ class GeoCacheRepository(Protocol): ... +class HistoryArchiveRepository(Protocol): + """Protocol for archived ban history persistence operations.""" + + async def archive_ban_event( + self, + db: aiosqlite.Connection, + jail: str, + ip: str, + timeofban: int, + bancount: int, + data: str, + action: str = "ban", + ) -> bool: + ... + + async def get_max_timeofban(self, db: aiosqlite.Connection) -> int | None: + ... + + async def get_archived_history( + self, + db: aiosqlite.Connection, + since: int | None = None, + jail: str | None = None, + ip_filter: str | list[str] | None = None, + origin: BanOrigin | None = None, + action: str | None = None, + page: int = 1, + page_size: int = 100, + ) -> tuple[list[dict[str, object]], int]: + ... + + class Fail2BanDbRepository(Protocol): async def check_db_nonempty(self, db_path: str) -> bool: ... diff --git a/backend/app/services/ban_service.py b/backend/app/services/ban_service.py index 6502763..af313f0 100644 --- a/backend/app/services/ban_service.py +++ b/backend/app/services/ban_service.py @@ -40,11 +40,7 @@ from app.models.ban import ( from app.models.ban import ( JailBanCount as JailBanCountModel, ) -from app.repositories import fail2ban_db_repo -from app.repositories.history_archive_repo import ( - get_all_archived_history, - get_archived_history, -) +from app.repositories import fail2ban_db_repo, history_archive_repo as default_history_archive_repo from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso from app.utils.fail2ban_client import ( @@ -57,6 +53,7 @@ if TYPE_CHECKING: import aiosqlite from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo + from app.repositories.protocols import HistoryArchiveRepository log: structlog.stdlib.BoundLogger = structlog.get_logger() @@ -432,6 +429,7 @@ async def list_bans( app_db: aiosqlite.Connection | None = None, geo_batch_lookup: GeoBatchLookup | None = None, geo_enricher: GeoEnricher | None = None, + history_archive_repo: HistoryArchiveRepository = default_history_archive_repo, origin: BanOrigin | None = None, ) -> DashboardBanListResponse: """Return a paginated list of bans within the selected time window. @@ -482,7 +480,7 @@ async def list_bans( if app_db is None: raise ValueError("app_db must be provided when source is 'archive'") - rows, total = await get_archived_history( + rows, total = await history_archive_repo.get_archived_history( db=app_db, since=since, origin=origin, @@ -599,6 +597,7 @@ async def bans_by_country( geo_cache_lookup: GeoCacheLookup | None = None, geo_batch_lookup: GeoBatchLookup | None = None, geo_enricher: GeoEnricher | None = None, + history_archive_repo: HistoryArchiveRepository = default_history_archive_repo, app_db: aiosqlite.Connection | None = None, origin: BanOrigin | None = None, country_code: str | None = None, @@ -648,7 +647,7 @@ async def bans_by_country( if app_db is None: raise ValueError("app_db must be provided when source is 'archive'") - all_rows = await get_all_archived_history( + all_rows = await history_archive_repo.get_all_archived_history( db=app_db, since=since, origin=origin, @@ -726,7 +725,7 @@ async def bans_by_country( companion_rows: list[dict[str, object] | fail2ban_db_repo.BanRecord] if country_code is None: if source == "archive": - companion_rows, _ = await get_archived_history( + companion_rows, _ = await history_archive_repo.get_archived_history( db=app_db, since=since, origin=origin, @@ -751,7 +750,7 @@ async def bans_by_country( if source == "archive": if matched_ips: - companion_rows = await get_all_archived_history( + companion_rows = await history_archive_repo.get_all_archived_history( db=app_db, since=since, origin=origin, @@ -859,6 +858,7 @@ async def ban_trend( range_: TimeRange, *, source: str = "fail2ban", + history_archive_repo: HistoryArchiveRepository = default_history_archive_repo, app_db: aiosqlite.Connection | None = None, origin: BanOrigin | None = None, ) -> BanTrendResponse: @@ -899,7 +899,7 @@ async def ban_trend( if app_db is None: raise ValueError("app_db must be provided when source is 'archive'") - all_rows = await get_all_archived_history( + all_rows = await history_archive_repo.get_all_archived_history( db=app_db, since=since, origin=origin, @@ -966,6 +966,7 @@ async def bans_by_jail( range_: TimeRange, *, source: str = "fail2ban", + history_archive_repo: HistoryArchiveRepository = default_history_archive_repo, app_db: aiosqlite.Connection | None = None, origin: BanOrigin | None = None, ) -> BansByJailResponse: @@ -996,7 +997,7 @@ async def bans_by_jail( if app_db is None: raise ValueError("app_db must be provided when source is 'archive'") - all_rows = await get_all_archived_history( + all_rows = await history_archive_repo.get_all_archived_history( db=app_db, since=since, origin=origin, diff --git a/backend/app/services/history_service.py b/backend/app/services/history_service.py index eb2de4a..50eb4ee 100644 --- a/backend/app/services/history_service.py +++ b/backend/app/services/history_service.py @@ -23,14 +23,14 @@ if TYPE_CHECKING: import aiohttp from app.models.geo import GeoEnricher, GeoInfo + from app.repositories.protocols import HistoryArchiveRepository from app.models.history import ( HistoryBanItem, HistoryListResponse, IpDetailResponse, IpTimelineEvent, ) -from app.repositories import fail2ban_db_repo -from app.repositories.history_archive_repo import archive_ban_event, get_max_timeofban +from app.repositories import fail2ban_db_repo, history_archive_repo as default_history_archive_repo from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso @@ -88,14 +88,18 @@ _HISTORY_SYNC_PAGE_SIZE: int = 500 _HISTORY_SYNC_BACKFILL_WINDOW: int = 648000 -async def _get_last_archive_ts(db: aiosqlite.Connection) -> int | None: +async def _get_last_archive_ts( + db: aiosqlite.Connection, + history_archive_repo: HistoryArchiveRepository = default_history_archive_repo, +) -> int | None: """Return the most recent archived ban timestamp, or ``None`` if empty.""" - return await get_max_timeofban(db) + return await history_archive_repo.get_max_timeofban(db) async def sync_from_fail2ban_db( db: aiosqlite.Connection, socket_path: str, + history_archive_repo: HistoryArchiveRepository = default_history_archive_repo, ) -> int: """Copy new records from the fail2ban DB into the BanGUI archive table. @@ -106,7 +110,7 @@ async def sync_from_fail2ban_db( Returns: Number of fail2ban records scanned and archived. """ - last_ts = await _get_last_archive_ts(db) + last_ts = await _get_last_archive_ts(db, history_archive_repo=history_archive_repo) now_ts = int(datetime.now(tz=UTC).timestamp()) if last_ts is None: @@ -129,7 +133,7 @@ async def sync_from_fail2ban_db( break for row in rows: - await archive_ban_event( + await history_archive_repo.archive_ban_event( db=db, jail=row.jail, ip=row.ip, @@ -167,6 +171,7 @@ async def list_history( http_session: "aiohttp.ClientSession" | None = None, geo_enricher: GeoEnricher | None = None, db: aiosqlite.Connection | None = None, + history_archive_repo: HistoryArchiveRepository = default_history_archive_repo, ) -> HistoryListResponse: """Return a paginated list of historical ban records with optional filters. @@ -214,9 +219,7 @@ async def list_history( if db is None: raise ValueError("db must be provided when source is 'archive'") - from app.repositories.history_archive_repo import get_archived_history - - archived_rows, total = await get_archived_history( + archived_rows, total = await history_archive_repo.get_archived_history( db=db, since=since, jail=jail, diff --git a/backend/app/services/setup_service.py b/backend/app/services/setup_service.py index c1c0185..bfb99a3 100644 --- a/backend/app/services/setup_service.py +++ b/backend/app/services/setup_service.py @@ -14,12 +14,14 @@ import bcrypt import structlog from app.db import init_db, open_db -from app.repositories import settings_repo +from app.repositories import settings_repo as default_settings_repo from app.utils.async_utils import run_blocking if TYPE_CHECKING: import aiosqlite + from app.repositories.protocols import SettingsRepository + log: structlog.stdlib.BoundLogger = structlog.get_logger() # Keys used in the settings table. @@ -34,11 +36,15 @@ _KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium" _KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low" -async def is_setup_complete(db: aiosqlite.Connection) -> bool: +async def is_setup_complete( + db: aiosqlite.Connection, + settings_repo: SettingsRepository = default_settings_repo, +) -> bool: """Return ``True`` if initial setup has already been performed. Args: db: Active aiosqlite connection. + settings_repo: Repository interface for settings persistence. Returns: ``True`` when the ``setup_completed`` key exists in settings. @@ -55,6 +61,7 @@ async def run_setup( fail2ban_socket: str, timezone: str, session_duration_minutes: int, + settings_repo: SettingsRepository = default_settings_repo, ) -> None: """Persist the initial configuration and mark setup as complete. @@ -72,7 +79,7 @@ async def run_setup( Raises: RuntimeError: If setup has already been completed. """ - if await is_setup_complete(db): + if await is_setup_complete(db, settings_repo=settings_repo): raise RuntimeError("Setup has already been completed.") log.info("bangui_setup_started") @@ -120,17 +127,26 @@ async def run_setup( log.info("bangui_setup_completed") -async def get_password_hash(db: aiosqlite.Connection) -> str | None: +async def get_password_hash( + db: aiosqlite.Connection, + settings_repo: SettingsRepository = default_settings_repo, +) -> str | None: """Return the stored bcrypt password hash, or ``None`` if not set.""" return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH) -async def get_runtime_database_path(db: aiosqlite.Connection) -> str | None: +async def get_runtime_database_path( + db: aiosqlite.Connection, + settings_repo: SettingsRepository = default_settings_repo, +) -> str | None: """Return the runtime database path persisted during initial setup.""" return await settings_repo.get_setting(db, _KEY_DATABASE_PATH) -async def get_persisted_runtime_settings(db: aiosqlite.Connection) -> dict[str, str | int]: +async def get_persisted_runtime_settings( + db: aiosqlite.Connection, + settings_repo: SettingsRepository = default_settings_repo, +) -> dict[str, str | int]: """Return runtime configuration values persisted during initial setup.""" runtime_settings: dict[str, str | int] = {} @@ -183,7 +199,10 @@ async def _ensure_database_initialized(database_path: str) -> bool: return True -async def get_timezone(db: aiosqlite.Connection) -> str: +async def get_timezone( + db: aiosqlite.Connection, + settings_repo: SettingsRepository = default_settings_repo, +) -> str: """Return the configured IANA timezone string.""" tz = await settings_repo.get_setting(db, _KEY_TIMEZONE) return tz if tz else "UTC" diff --git a/backend/tests/test_dependencies.py b/backend/tests/test_dependencies.py index 6c8c9f1..1156f9b 100644 --- a/backend/tests/test_dependencies.py +++ b/backend/tests/test_dependencies.py @@ -14,9 +14,11 @@ from app.dependencies import ( get_app_context, get_db, get_http_session, + get_history_archive_repo, get_scheduler, get_settings, get_session_cache, + get_settings_repo, ) from app.main import create_app from app.models.server import ServerStatus @@ -65,6 +67,21 @@ async def test_app_context_dependency_exposes_shared_resources(test_settings: Se await session.close() +@pytest.mark.asyncio +async def test_settings_and_history_archive_repo_dependencies_return_modules() -> None: + settings_repo = await get_settings_repo() + history_archive_repo = await get_history_archive_repo() + + assert hasattr(settings_repo, "get_setting") + assert hasattr(settings_repo, "set_setting") + assert hasattr(settings_repo, "delete_setting") + assert hasattr(settings_repo, "get_all_settings") + + assert hasattr(history_archive_repo, "archive_ban_event") + assert hasattr(history_archive_repo, "get_max_timeofban") + assert hasattr(history_archive_repo, "get_archived_history") + + @pytest.mark.asyncio async def test_get_db_uses_effective_runtime_database_path(test_settings: Settings) -> None: """Database connections should use effective runtime settings when overridden."""