Add settings and history archive repository protocols and DI support

This commit is contained in:
2026-04-17 20:54:08 +02:00
parent 7055971163
commit db5b4cb77e
8 changed files with 137 additions and 27 deletions

View File

@@ -208,6 +208,8 @@ The data access layer. Repositories execute raw SQL queries against the applicat
| `geo_cache_repo.py` | Persist and query IP geo resolution cache | | `geo_cache_repo.py` | Persist and query IP geo resolution cache |
| `import_log_repo.py` | Record import run results (timestamp, source, IPs imported, errors) for the import log view | | `import_log_repo.py` | Record import run results (timestamp, source, IPs imported, errors) for the import log view |
Every repository in `app/repositories/` has a corresponding protocol in `app/repositories/protocols.py`, including `settings_repo.py` and `history_archive_repo.py`.
#### Models (`app/models/`) #### Models (`app/models/`)
Pydantic schemas that define data shapes and validation. Models are split into three categories per domain. Pydantic schemas that define data shapes and validation. Models are split into three categories per domain.

View File

@@ -29,6 +29,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
### 2. Missing Repository Protocols for `settings_repo` and `history_archive_repo` ### 2. Missing Repository Protocols for `settings_repo` and `history_archive_repo`
**Status:** Completed.
**Where:** `backend/app/repositories/protocols.py` defines protocols for 5 of the 7 repositories. The two missing are `settings_repo.py` and `history_archive_repo.py`. **Where:** `backend/app/repositories/protocols.py` defines protocols for 5 of the 7 repositories. The two missing are `settings_repo.py` and `history_archive_repo.py`.
**Goal:** Define `SettingsRepository` and `HistoryArchiveRepository` protocols in `protocols.py` that match the public functions of `settings_repo` (`get_setting`, `set_setting`, `delete_setting`, `get_all_settings`) and `history_archive_repo` (`archive_ban_event`, `get_max_timeofban`, `get_archived_history`). Add corresponding dependency providers in `dependencies.py` and typed aliases so these repositories can be injected the same way as the other five. **Goal:** Define `SettingsRepository` and `HistoryArchiveRepository` protocols in `protocols.py` that match the public functions of `settings_repo` (`get_setting`, `set_setting`, `delete_setting`, `get_all_settings`) and `history_archive_repo` (`archive_ban_event`, `get_max_timeofban`, `get_archived_history`). Add corresponding dependency providers in `dependencies.py` and typed aliases so these repositories can be injected the same way as the other five.

View File

@@ -26,7 +26,9 @@ from app.repositories.protocols import (
BlocklistRepository, BlocklistRepository,
Fail2BanDbRepository, Fail2BanDbRepository,
GeoCacheRepository, GeoCacheRepository,
HistoryArchiveRepository,
ImportLogRepository, ImportLogRepository,
SettingsRepository,
SessionRepository, SessionRepository,
) )
from app.utils.constants import SESSION_COOKIE_NAME from app.utils.constants import SESSION_COOKIE_NAME
@@ -251,6 +253,20 @@ async def get_import_log_repo() -> ImportLogRepository:
return cast("ImportLogRepository", import_log_repo) return cast("ImportLogRepository", import_log_repo)
async def get_settings_repo() -> SettingsRepository:
"""Provide the concrete settings repository implementation."""
from app.repositories import settings_repo # noqa: PLC0415
return cast("SettingsRepository", settings_repo)
async def get_history_archive_repo() -> HistoryArchiveRepository:
"""Provide the concrete history archive repository implementation."""
from app.repositories import history_archive_repo # noqa: PLC0415
return cast("HistoryArchiveRepository", history_archive_repo)
async def get_geo_cache_repo() -> GeoCacheRepository: async def get_geo_cache_repo() -> GeoCacheRepository:
"""Provide the concrete geo cache repository implementation.""" """Provide the concrete geo cache repository implementation."""
from app.repositories import geo_cache_repo # noqa: PLC0415 from app.repositories import geo_cache_repo # noqa: PLC0415
@@ -370,6 +386,8 @@ ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)] PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)] SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)] SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
SettingsRepoDep = Annotated[SettingsRepository, Depends(get_settings_repo)]
HistoryArchiveRepositoryDep = Annotated[HistoryArchiveRepository, Depends(get_history_archive_repo)]
BlocklistRepositoryDep = Annotated[BlocklistRepository, Depends(get_blocklist_repo)] BlocklistRepositoryDep = Annotated[BlocklistRepository, Depends(get_blocklist_repo)]
ImportLogRepositoryDep = Annotated[ImportLogRepository, Depends(get_import_log_repo)] ImportLogRepositoryDep = Annotated[ImportLogRepository, Depends(get_import_log_repo)]
GeoCacheRepositoryDep = Annotated[GeoCacheRepository, Depends(get_geo_cache_repo)] GeoCacheRepositoryDep = Annotated[GeoCacheRepository, Depends(get_geo_cache_repo)]

View File

@@ -50,6 +50,22 @@ class SessionRepository(Protocol):
... ...
class SettingsRepository(Protocol):
"""Protocol for application settings persistence operations."""
async def get_setting(self, db: aiosqlite.Connection, key: str) -> str | None:
...
async def set_setting(self, db: aiosqlite.Connection, key: str, value: str) -> None:
...
async def delete_setting(self, db: aiosqlite.Connection, key: str) -> None:
...
async def get_all_settings(self, db: aiosqlite.Connection) -> dict[str, str]:
...
class BlocklistRepository(Protocol): class BlocklistRepository(Protocol):
async def create_source( async def create_source(
self, self,
@@ -154,6 +170,38 @@ class GeoCacheRepository(Protocol):
... ...
class HistoryArchiveRepository(Protocol):
"""Protocol for archived ban history persistence operations."""
async def archive_ban_event(
self,
db: aiosqlite.Connection,
jail: str,
ip: str,
timeofban: int,
bancount: int,
data: str,
action: str = "ban",
) -> bool:
...
async def get_max_timeofban(self, db: aiosqlite.Connection) -> int | None:
...
async def get_archived_history(
self,
db: aiosqlite.Connection,
since: int | None = None,
jail: str | None = None,
ip_filter: str | list[str] | None = None,
origin: BanOrigin | None = None,
action: str | None = None,
page: int = 1,
page_size: int = 100,
) -> tuple[list[dict[str, object]], int]:
...
class Fail2BanDbRepository(Protocol): class Fail2BanDbRepository(Protocol):
async def check_db_nonempty(self, db_path: str) -> bool: async def check_db_nonempty(self, db_path: str) -> bool:
... ...

View File

@@ -40,11 +40,7 @@ from app.models.ban import (
from app.models.ban import ( from app.models.ban import (
JailBanCount as JailBanCountModel, JailBanCount as JailBanCountModel,
) )
from app.repositories import fail2ban_db_repo from app.repositories import fail2ban_db_repo, history_archive_repo as default_history_archive_repo
from app.repositories.history_archive_repo import (
get_all_archived_history,
get_archived_history,
)
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
from app.utils.fail2ban_client import ( from app.utils.fail2ban_client import (
@@ -57,6 +53,7 @@ if TYPE_CHECKING:
import aiosqlite import aiosqlite
from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
from app.repositories.protocols import HistoryArchiveRepository
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
@@ -432,6 +429,7 @@ async def list_bans(
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
geo_batch_lookup: GeoBatchLookup | None = None, geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None, geo_enricher: GeoEnricher | None = None,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
origin: BanOrigin | None = None, origin: BanOrigin | None = None,
) -> DashboardBanListResponse: ) -> DashboardBanListResponse:
"""Return a paginated list of bans within the selected time window. """Return a paginated list of bans within the selected time window.
@@ -482,7 +480,7 @@ async def list_bans(
if app_db is None: if app_db is None:
raise ValueError("app_db must be provided when source is 'archive'") raise ValueError("app_db must be provided when source is 'archive'")
rows, total = await get_archived_history( rows, total = await history_archive_repo.get_archived_history(
db=app_db, db=app_db,
since=since, since=since,
origin=origin, origin=origin,
@@ -599,6 +597,7 @@ async def bans_by_country(
geo_cache_lookup: GeoCacheLookup | None = None, geo_cache_lookup: GeoCacheLookup | None = None,
geo_batch_lookup: GeoBatchLookup | None = None, geo_batch_lookup: GeoBatchLookup | None = None,
geo_enricher: GeoEnricher | None = None, geo_enricher: GeoEnricher | None = None,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None, origin: BanOrigin | None = None,
country_code: str | None = None, country_code: str | None = None,
@@ -648,7 +647,7 @@ async def bans_by_country(
if app_db is None: if app_db is None:
raise ValueError("app_db must be provided when source is 'archive'") raise ValueError("app_db must be provided when source is 'archive'")
all_rows = await get_all_archived_history( all_rows = await history_archive_repo.get_all_archived_history(
db=app_db, db=app_db,
since=since, since=since,
origin=origin, origin=origin,
@@ -726,7 +725,7 @@ async def bans_by_country(
companion_rows: list[dict[str, object] | fail2ban_db_repo.BanRecord] companion_rows: list[dict[str, object] | fail2ban_db_repo.BanRecord]
if country_code is None: if country_code is None:
if source == "archive": if source == "archive":
companion_rows, _ = await get_archived_history( companion_rows, _ = await history_archive_repo.get_archived_history(
db=app_db, db=app_db,
since=since, since=since,
origin=origin, origin=origin,
@@ -751,7 +750,7 @@ async def bans_by_country(
if source == "archive": if source == "archive":
if matched_ips: if matched_ips:
companion_rows = await get_all_archived_history( companion_rows = await history_archive_repo.get_all_archived_history(
db=app_db, db=app_db,
since=since, since=since,
origin=origin, origin=origin,
@@ -859,6 +858,7 @@ async def ban_trend(
range_: TimeRange, range_: TimeRange,
*, *,
source: str = "fail2ban", source: str = "fail2ban",
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None, origin: BanOrigin | None = None,
) -> BanTrendResponse: ) -> BanTrendResponse:
@@ -899,7 +899,7 @@ async def ban_trend(
if app_db is None: if app_db is None:
raise ValueError("app_db must be provided when source is 'archive'") raise ValueError("app_db must be provided when source is 'archive'")
all_rows = await get_all_archived_history( all_rows = await history_archive_repo.get_all_archived_history(
db=app_db, db=app_db,
since=since, since=since,
origin=origin, origin=origin,
@@ -966,6 +966,7 @@ async def bans_by_jail(
range_: TimeRange, range_: TimeRange,
*, *,
source: str = "fail2ban", source: str = "fail2ban",
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
app_db: aiosqlite.Connection | None = None, app_db: aiosqlite.Connection | None = None,
origin: BanOrigin | None = None, origin: BanOrigin | None = None,
) -> BansByJailResponse: ) -> BansByJailResponse:
@@ -996,7 +997,7 @@ async def bans_by_jail(
if app_db is None: if app_db is None:
raise ValueError("app_db must be provided when source is 'archive'") raise ValueError("app_db must be provided when source is 'archive'")
all_rows = await get_all_archived_history( all_rows = await history_archive_repo.get_all_archived_history(
db=app_db, db=app_db,
since=since, since=since,
origin=origin, origin=origin,

View File

@@ -23,14 +23,14 @@ if TYPE_CHECKING:
import aiohttp import aiohttp
from app.models.geo import GeoEnricher, GeoInfo from app.models.geo import GeoEnricher, GeoInfo
from app.repositories.protocols import HistoryArchiveRepository
from app.models.history import ( from app.models.history import (
HistoryBanItem, HistoryBanItem,
HistoryListResponse, HistoryListResponse,
IpDetailResponse, IpDetailResponse,
IpTimelineEvent, IpTimelineEvent,
) )
from app.repositories import fail2ban_db_repo from app.repositories import fail2ban_db_repo, history_archive_repo as default_history_archive_repo
from app.repositories.history_archive_repo import archive_ban_event, get_max_timeofban
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
@@ -88,14 +88,18 @@ _HISTORY_SYNC_PAGE_SIZE: int = 500
_HISTORY_SYNC_BACKFILL_WINDOW: int = 648000 _HISTORY_SYNC_BACKFILL_WINDOW: int = 648000
async def _get_last_archive_ts(db: aiosqlite.Connection) -> int | None: async def _get_last_archive_ts(
db: aiosqlite.Connection,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
) -> int | None:
"""Return the most recent archived ban timestamp, or ``None`` if empty.""" """Return the most recent archived ban timestamp, or ``None`` if empty."""
return await get_max_timeofban(db) return await history_archive_repo.get_max_timeofban(db)
async def sync_from_fail2ban_db( async def sync_from_fail2ban_db(
db: aiosqlite.Connection, db: aiosqlite.Connection,
socket_path: str, socket_path: str,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
) -> int: ) -> int:
"""Copy new records from the fail2ban DB into the BanGUI archive table. """Copy new records from the fail2ban DB into the BanGUI archive table.
@@ -106,7 +110,7 @@ async def sync_from_fail2ban_db(
Returns: Returns:
Number of fail2ban records scanned and archived. Number of fail2ban records scanned and archived.
""" """
last_ts = await _get_last_archive_ts(db) last_ts = await _get_last_archive_ts(db, history_archive_repo=history_archive_repo)
now_ts = int(datetime.now(tz=UTC).timestamp()) now_ts = int(datetime.now(tz=UTC).timestamp())
if last_ts is None: if last_ts is None:
@@ -129,7 +133,7 @@ async def sync_from_fail2ban_db(
break break
for row in rows: for row in rows:
await archive_ban_event( await history_archive_repo.archive_ban_event(
db=db, db=db,
jail=row.jail, jail=row.jail,
ip=row.ip, ip=row.ip,
@@ -167,6 +171,7 @@ async def list_history(
http_session: "aiohttp.ClientSession" | None = None, http_session: "aiohttp.ClientSession" | None = None,
geo_enricher: GeoEnricher | None = None, geo_enricher: GeoEnricher | None = None,
db: aiosqlite.Connection | None = None, db: aiosqlite.Connection | None = None,
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
) -> HistoryListResponse: ) -> HistoryListResponse:
"""Return a paginated list of historical ban records with optional filters. """Return a paginated list of historical ban records with optional filters.
@@ -214,9 +219,7 @@ async def list_history(
if db is None: if db is None:
raise ValueError("db must be provided when source is 'archive'") raise ValueError("db must be provided when source is 'archive'")
from app.repositories.history_archive_repo import get_archived_history archived_rows, total = await history_archive_repo.get_archived_history(
archived_rows, total = await get_archived_history(
db=db, db=db,
since=since, since=since,
jail=jail, jail=jail,

View File

@@ -14,12 +14,14 @@ import bcrypt
import structlog import structlog
from app.db import init_db, open_db from app.db import init_db, open_db
from app.repositories import settings_repo from app.repositories import settings_repo as default_settings_repo
from app.utils.async_utils import run_blocking from app.utils.async_utils import run_blocking
if TYPE_CHECKING: if TYPE_CHECKING:
import aiosqlite import aiosqlite
from app.repositories.protocols import SettingsRepository
log: structlog.stdlib.BoundLogger = structlog.get_logger() log: structlog.stdlib.BoundLogger = structlog.get_logger()
# Keys used in the settings table. # Keys used in the settings table.
@@ -34,11 +36,15 @@ _KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low" _KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
async def is_setup_complete(db: aiosqlite.Connection) -> bool: async def is_setup_complete(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> bool:
"""Return ``True`` if initial setup has already been performed. """Return ``True`` if initial setup has already been performed.
Args: Args:
db: Active aiosqlite connection. db: Active aiosqlite connection.
settings_repo: Repository interface for settings persistence.
Returns: Returns:
``True`` when the ``setup_completed`` key exists in settings. ``True`` when the ``setup_completed`` key exists in settings.
@@ -55,6 +61,7 @@ async def run_setup(
fail2ban_socket: str, fail2ban_socket: str,
timezone: str, timezone: str,
session_duration_minutes: int, session_duration_minutes: int,
settings_repo: SettingsRepository = default_settings_repo,
) -> None: ) -> None:
"""Persist the initial configuration and mark setup as complete. """Persist the initial configuration and mark setup as complete.
@@ -72,7 +79,7 @@ async def run_setup(
Raises: Raises:
RuntimeError: If setup has already been completed. RuntimeError: If setup has already been completed.
""" """
if await is_setup_complete(db): if await is_setup_complete(db, settings_repo=settings_repo):
raise RuntimeError("Setup has already been completed.") raise RuntimeError("Setup has already been completed.")
log.info("bangui_setup_started") log.info("bangui_setup_started")
@@ -120,17 +127,26 @@ async def run_setup(
log.info("bangui_setup_completed") log.info("bangui_setup_completed")
async def get_password_hash(db: aiosqlite.Connection) -> str | None: async def get_password_hash(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> str | None:
"""Return the stored bcrypt password hash, or ``None`` if not set.""" """Return the stored bcrypt password hash, or ``None`` if not set."""
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH) return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
async def get_runtime_database_path(db: aiosqlite.Connection) -> str | None: async def get_runtime_database_path(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> str | None:
"""Return the runtime database path persisted during initial setup.""" """Return the runtime database path persisted during initial setup."""
return await settings_repo.get_setting(db, _KEY_DATABASE_PATH) return await settings_repo.get_setting(db, _KEY_DATABASE_PATH)
async def get_persisted_runtime_settings(db: aiosqlite.Connection) -> dict[str, str | int]: async def get_persisted_runtime_settings(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> dict[str, str | int]:
"""Return runtime configuration values persisted during initial setup.""" """Return runtime configuration values persisted during initial setup."""
runtime_settings: dict[str, str | int] = {} runtime_settings: dict[str, str | int] = {}
@@ -183,7 +199,10 @@ async def _ensure_database_initialized(database_path: str) -> bool:
return True return True
async def get_timezone(db: aiosqlite.Connection) -> str: async def get_timezone(
db: aiosqlite.Connection,
settings_repo: SettingsRepository = default_settings_repo,
) -> str:
"""Return the configured IANA timezone string.""" """Return the configured IANA timezone string."""
tz = await settings_repo.get_setting(db, _KEY_TIMEZONE) tz = await settings_repo.get_setting(db, _KEY_TIMEZONE)
return tz if tz else "UTC" return tz if tz else "UTC"

View File

@@ -14,9 +14,11 @@ from app.dependencies import (
get_app_context, get_app_context,
get_db, get_db,
get_http_session, get_http_session,
get_history_archive_repo,
get_scheduler, get_scheduler,
get_settings, get_settings,
get_session_cache, get_session_cache,
get_settings_repo,
) )
from app.main import create_app from app.main import create_app
from app.models.server import ServerStatus from app.models.server import ServerStatus
@@ -65,6 +67,21 @@ async def test_app_context_dependency_exposes_shared_resources(test_settings: Se
await session.close() await session.close()
@pytest.mark.asyncio
async def test_settings_and_history_archive_repo_dependencies_return_modules() -> None:
settings_repo = await get_settings_repo()
history_archive_repo = await get_history_archive_repo()
assert hasattr(settings_repo, "get_setting")
assert hasattr(settings_repo, "set_setting")
assert hasattr(settings_repo, "delete_setting")
assert hasattr(settings_repo, "get_all_settings")
assert hasattr(history_archive_repo, "archive_ban_event")
assert hasattr(history_archive_repo, "get_max_timeofban")
assert hasattr(history_archive_repo, "get_archived_history")
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_get_db_uses_effective_runtime_database_path(test_settings: Settings) -> None: async def test_get_db_uses_effective_runtime_database_path(test_settings: Settings) -> None:
"""Database connections should use effective runtime settings when overridden.""" """Database connections should use effective runtime settings when overridden."""