Add settings and history archive repository protocols and DI support
This commit is contained in:
@@ -208,6 +208,8 @@ The data access layer. Repositories execute raw SQL queries against the applicat
|
||||
| `geo_cache_repo.py` | Persist and query IP geo resolution cache |
|
||||
| `import_log_repo.py` | Record import run results (timestamp, source, IPs imported, errors) for the import log view |
|
||||
|
||||
Every repository in `app/repositories/` has a corresponding protocol in `app/repositories/protocols.py`, including `settings_repo.py` and `history_archive_repo.py`.
|
||||
|
||||
#### Models (`app/models/`)
|
||||
|
||||
Pydantic schemas that define data shapes and validation. Models are split into three categories per domain.
|
||||
|
||||
@@ -29,6 +29,8 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
||||
|
||||
### 2. Missing Repository Protocols for `settings_repo` and `history_archive_repo`
|
||||
|
||||
**Status:** Completed.
|
||||
|
||||
**Where:** `backend/app/repositories/protocols.py` defines protocols for 5 of the 7 repositories. The two missing are `settings_repo.py` and `history_archive_repo.py`.
|
||||
|
||||
**Goal:** Define `SettingsRepository` and `HistoryArchiveRepository` protocols in `protocols.py` that match the public functions of `settings_repo` (`get_setting`, `set_setting`, `delete_setting`, `get_all_settings`) and `history_archive_repo` (`archive_ban_event`, `get_max_timeofban`, `get_archived_history`). Add corresponding dependency providers in `dependencies.py` and typed aliases so these repositories can be injected the same way as the other five.
|
||||
|
||||
@@ -26,7 +26,9 @@ from app.repositories.protocols import (
|
||||
BlocklistRepository,
|
||||
Fail2BanDbRepository,
|
||||
GeoCacheRepository,
|
||||
HistoryArchiveRepository,
|
||||
ImportLogRepository,
|
||||
SettingsRepository,
|
||||
SessionRepository,
|
||||
)
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
@@ -251,6 +253,20 @@ async def get_import_log_repo() -> ImportLogRepository:
|
||||
return cast("ImportLogRepository", import_log_repo)
|
||||
|
||||
|
||||
async def get_settings_repo() -> SettingsRepository:
|
||||
"""Provide the concrete settings repository implementation."""
|
||||
from app.repositories import settings_repo # noqa: PLC0415
|
||||
|
||||
return cast("SettingsRepository", settings_repo)
|
||||
|
||||
|
||||
async def get_history_archive_repo() -> HistoryArchiveRepository:
|
||||
"""Provide the concrete history archive repository implementation."""
|
||||
from app.repositories import history_archive_repo # noqa: PLC0415
|
||||
|
||||
return cast("HistoryArchiveRepository", history_archive_repo)
|
||||
|
||||
|
||||
async def get_geo_cache_repo() -> GeoCacheRepository:
|
||||
"""Provide the concrete geo cache repository implementation."""
|
||||
from app.repositories import geo_cache_repo # noqa: PLC0415
|
||||
@@ -370,6 +386,8 @@ ServerStatusDep = Annotated[ServerStatus, Depends(get_server_status)]
|
||||
PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recovery)]
|
||||
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
|
||||
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
|
||||
SettingsRepoDep = Annotated[SettingsRepository, Depends(get_settings_repo)]
|
||||
HistoryArchiveRepositoryDep = Annotated[HistoryArchiveRepository, Depends(get_history_archive_repo)]
|
||||
BlocklistRepositoryDep = Annotated[BlocklistRepository, Depends(get_blocklist_repo)]
|
||||
ImportLogRepositoryDep = Annotated[ImportLogRepository, Depends(get_import_log_repo)]
|
||||
GeoCacheRepositoryDep = Annotated[GeoCacheRepository, Depends(get_geo_cache_repo)]
|
||||
|
||||
@@ -50,6 +50,22 @@ class SessionRepository(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class SettingsRepository(Protocol):
|
||||
"""Protocol for application settings persistence operations."""
|
||||
|
||||
async def get_setting(self, db: aiosqlite.Connection, key: str) -> str | None:
|
||||
...
|
||||
|
||||
async def set_setting(self, db: aiosqlite.Connection, key: str, value: str) -> None:
|
||||
...
|
||||
|
||||
async def delete_setting(self, db: aiosqlite.Connection, key: str) -> None:
|
||||
...
|
||||
|
||||
async def get_all_settings(self, db: aiosqlite.Connection) -> dict[str, str]:
|
||||
...
|
||||
|
||||
|
||||
class BlocklistRepository(Protocol):
|
||||
async def create_source(
|
||||
self,
|
||||
@@ -154,6 +170,38 @@ class GeoCacheRepository(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class HistoryArchiveRepository(Protocol):
|
||||
"""Protocol for archived ban history persistence operations."""
|
||||
|
||||
async def archive_ban_event(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
jail: str,
|
||||
ip: str,
|
||||
timeofban: int,
|
||||
bancount: int,
|
||||
data: str,
|
||||
action: str = "ban",
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
async def get_max_timeofban(self, db: aiosqlite.Connection) -> int | None:
|
||||
...
|
||||
|
||||
async def get_archived_history(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
since: int | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | list[str] | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
action: str | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 100,
|
||||
) -> tuple[list[dict[str, object]], int]:
|
||||
...
|
||||
|
||||
|
||||
class Fail2BanDbRepository(Protocol):
|
||||
async def check_db_nonempty(self, db_path: str) -> bool:
|
||||
...
|
||||
|
||||
@@ -40,11 +40,7 @@ from app.models.ban import (
|
||||
from app.models.ban import (
|
||||
JailBanCount as JailBanCountModel,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.repositories.history_archive_repo import (
|
||||
get_all_archived_history,
|
||||
get_archived_history,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo, history_archive_repo as default_history_archive_repo
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
|
||||
from app.utils.fail2ban_client import (
|
||||
@@ -57,6 +53,7 @@ if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup, GeoCacheLookup, GeoEnricher, GeoInfo
|
||||
from app.repositories.protocols import HistoryArchiveRepository
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
@@ -432,6 +429,7 @@ async def list_bans(
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> DashboardBanListResponse:
|
||||
"""Return a paginated list of bans within the selected time window.
|
||||
@@ -482,7 +480,7 @@ async def list_bans(
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
|
||||
rows, total = await get_archived_history(
|
||||
rows, total = await history_archive_repo.get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
@@ -599,6 +597,7 @@ async def bans_by_country(
|
||||
geo_cache_lookup: GeoCacheLookup | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
country_code: str | None = None,
|
||||
@@ -648,7 +647,7 @@ async def bans_by_country(
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
|
||||
all_rows = await get_all_archived_history(
|
||||
all_rows = await history_archive_repo.get_all_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
@@ -726,7 +725,7 @@ async def bans_by_country(
|
||||
companion_rows: list[dict[str, object] | fail2ban_db_repo.BanRecord]
|
||||
if country_code is None:
|
||||
if source == "archive":
|
||||
companion_rows, _ = await get_archived_history(
|
||||
companion_rows, _ = await history_archive_repo.get_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
@@ -751,7 +750,7 @@ async def bans_by_country(
|
||||
|
||||
if source == "archive":
|
||||
if matched_ips:
|
||||
companion_rows = await get_all_archived_history(
|
||||
companion_rows = await history_archive_repo.get_all_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
@@ -859,6 +858,7 @@ async def ban_trend(
|
||||
range_: TimeRange,
|
||||
*,
|
||||
source: str = "fail2ban",
|
||||
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> BanTrendResponse:
|
||||
@@ -899,7 +899,7 @@ async def ban_trend(
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
|
||||
all_rows = await get_all_archived_history(
|
||||
all_rows = await history_archive_repo.get_all_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
@@ -966,6 +966,7 @@ async def bans_by_jail(
|
||||
range_: TimeRange,
|
||||
*,
|
||||
source: str = "fail2ban",
|
||||
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
|
||||
app_db: aiosqlite.Connection | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> BansByJailResponse:
|
||||
@@ -996,7 +997,7 @@ async def bans_by_jail(
|
||||
if app_db is None:
|
||||
raise ValueError("app_db must be provided when source is 'archive'")
|
||||
|
||||
all_rows = await get_all_archived_history(
|
||||
all_rows = await history_archive_repo.get_all_archived_history(
|
||||
db=app_db,
|
||||
since=since,
|
||||
origin=origin,
|
||||
|
||||
@@ -23,14 +23,14 @@ if TYPE_CHECKING:
|
||||
import aiohttp
|
||||
|
||||
from app.models.geo import GeoEnricher, GeoInfo
|
||||
from app.repositories.protocols import HistoryArchiveRepository
|
||||
from app.models.history import (
|
||||
HistoryBanItem,
|
||||
HistoryListResponse,
|
||||
IpDetailResponse,
|
||||
IpTimelineEvent,
|
||||
)
|
||||
from app.repositories import fail2ban_db_repo
|
||||
from app.repositories.history_archive_repo import archive_ban_event, get_max_timeofban
|
||||
from app.repositories import fail2ban_db_repo, history_archive_repo as default_history_archive_repo
|
||||
from app.services.fail2ban_metadata_service import default_fail2ban_metadata_service
|
||||
from app.utils.fail2ban_db_utils import parse_data_json, ts_to_iso
|
||||
|
||||
@@ -88,14 +88,18 @@ _HISTORY_SYNC_PAGE_SIZE: int = 500
|
||||
_HISTORY_SYNC_BACKFILL_WINDOW: int = 648000
|
||||
|
||||
|
||||
async def _get_last_archive_ts(db: aiosqlite.Connection) -> int | None:
|
||||
async def _get_last_archive_ts(
|
||||
db: aiosqlite.Connection,
|
||||
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
|
||||
) -> int | None:
|
||||
"""Return the most recent archived ban timestamp, or ``None`` if empty."""
|
||||
return await get_max_timeofban(db)
|
||||
return await history_archive_repo.get_max_timeofban(db)
|
||||
|
||||
|
||||
async def sync_from_fail2ban_db(
|
||||
db: aiosqlite.Connection,
|
||||
socket_path: str,
|
||||
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
|
||||
) -> int:
|
||||
"""Copy new records from the fail2ban DB into the BanGUI archive table.
|
||||
|
||||
@@ -106,7 +110,7 @@ async def sync_from_fail2ban_db(
|
||||
Returns:
|
||||
Number of fail2ban records scanned and archived.
|
||||
"""
|
||||
last_ts = await _get_last_archive_ts(db)
|
||||
last_ts = await _get_last_archive_ts(db, history_archive_repo=history_archive_repo)
|
||||
now_ts = int(datetime.now(tz=UTC).timestamp())
|
||||
|
||||
if last_ts is None:
|
||||
@@ -129,7 +133,7 @@ async def sync_from_fail2ban_db(
|
||||
break
|
||||
|
||||
for row in rows:
|
||||
await archive_ban_event(
|
||||
await history_archive_repo.archive_ban_event(
|
||||
db=db,
|
||||
jail=row.jail,
|
||||
ip=row.ip,
|
||||
@@ -167,6 +171,7 @@ async def list_history(
|
||||
http_session: "aiohttp.ClientSession" | None = None,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
db: aiosqlite.Connection | None = None,
|
||||
history_archive_repo: HistoryArchiveRepository = default_history_archive_repo,
|
||||
) -> HistoryListResponse:
|
||||
"""Return a paginated list of historical ban records with optional filters.
|
||||
|
||||
@@ -214,9 +219,7 @@ async def list_history(
|
||||
if db is None:
|
||||
raise ValueError("db must be provided when source is 'archive'")
|
||||
|
||||
from app.repositories.history_archive_repo import get_archived_history
|
||||
|
||||
archived_rows, total = await get_archived_history(
|
||||
archived_rows, total = await history_archive_repo.get_archived_history(
|
||||
db=db,
|
||||
since=since,
|
||||
jail=jail,
|
||||
|
||||
@@ -14,12 +14,14 @@ import bcrypt
|
||||
import structlog
|
||||
|
||||
from app.db import init_db, open_db
|
||||
from app.repositories import settings_repo
|
||||
from app.repositories import settings_repo as default_settings_repo
|
||||
from app.utils.async_utils import run_blocking
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import aiosqlite
|
||||
|
||||
from app.repositories.protocols import SettingsRepository
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
# Keys used in the settings table.
|
||||
@@ -34,11 +36,15 @@ _KEY_MAP_COLOR_THRESHOLD_MEDIUM = "map_color_threshold_medium"
|
||||
_KEY_MAP_COLOR_THRESHOLD_LOW = "map_color_threshold_low"
|
||||
|
||||
|
||||
async def is_setup_complete(db: aiosqlite.Connection) -> bool:
|
||||
async def is_setup_complete(
|
||||
db: aiosqlite.Connection,
|
||||
settings_repo: SettingsRepository = default_settings_repo,
|
||||
) -> bool:
|
||||
"""Return ``True`` if initial setup has already been performed.
|
||||
|
||||
Args:
|
||||
db: Active aiosqlite connection.
|
||||
settings_repo: Repository interface for settings persistence.
|
||||
|
||||
Returns:
|
||||
``True`` when the ``setup_completed`` key exists in settings.
|
||||
@@ -55,6 +61,7 @@ async def run_setup(
|
||||
fail2ban_socket: str,
|
||||
timezone: str,
|
||||
session_duration_minutes: int,
|
||||
settings_repo: SettingsRepository = default_settings_repo,
|
||||
) -> None:
|
||||
"""Persist the initial configuration and mark setup as complete.
|
||||
|
||||
@@ -72,7 +79,7 @@ async def run_setup(
|
||||
Raises:
|
||||
RuntimeError: If setup has already been completed.
|
||||
"""
|
||||
if await is_setup_complete(db):
|
||||
if await is_setup_complete(db, settings_repo=settings_repo):
|
||||
raise RuntimeError("Setup has already been completed.")
|
||||
|
||||
log.info("bangui_setup_started")
|
||||
@@ -120,17 +127,26 @@ async def run_setup(
|
||||
log.info("bangui_setup_completed")
|
||||
|
||||
|
||||
async def get_password_hash(db: aiosqlite.Connection) -> str | None:
|
||||
async def get_password_hash(
|
||||
db: aiosqlite.Connection,
|
||||
settings_repo: SettingsRepository = default_settings_repo,
|
||||
) -> str | None:
|
||||
"""Return the stored bcrypt password hash, or ``None`` if not set."""
|
||||
return await settings_repo.get_setting(db, _KEY_PASSWORD_HASH)
|
||||
|
||||
|
||||
async def get_runtime_database_path(db: aiosqlite.Connection) -> str | None:
|
||||
async def get_runtime_database_path(
|
||||
db: aiosqlite.Connection,
|
||||
settings_repo: SettingsRepository = default_settings_repo,
|
||||
) -> str | None:
|
||||
"""Return the runtime database path persisted during initial setup."""
|
||||
return await settings_repo.get_setting(db, _KEY_DATABASE_PATH)
|
||||
|
||||
|
||||
async def get_persisted_runtime_settings(db: aiosqlite.Connection) -> dict[str, str | int]:
|
||||
async def get_persisted_runtime_settings(
|
||||
db: aiosqlite.Connection,
|
||||
settings_repo: SettingsRepository = default_settings_repo,
|
||||
) -> dict[str, str | int]:
|
||||
"""Return runtime configuration values persisted during initial setup."""
|
||||
runtime_settings: dict[str, str | int] = {}
|
||||
|
||||
@@ -183,7 +199,10 @@ async def _ensure_database_initialized(database_path: str) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
async def get_timezone(db: aiosqlite.Connection) -> str:
|
||||
async def get_timezone(
|
||||
db: aiosqlite.Connection,
|
||||
settings_repo: SettingsRepository = default_settings_repo,
|
||||
) -> str:
|
||||
"""Return the configured IANA timezone string."""
|
||||
tz = await settings_repo.get_setting(db, _KEY_TIMEZONE)
|
||||
return tz if tz else "UTC"
|
||||
|
||||
@@ -14,9 +14,11 @@ from app.dependencies import (
|
||||
get_app_context,
|
||||
get_db,
|
||||
get_http_session,
|
||||
get_history_archive_repo,
|
||||
get_scheduler,
|
||||
get_settings,
|
||||
get_session_cache,
|
||||
get_settings_repo,
|
||||
)
|
||||
from app.main import create_app
|
||||
from app.models.server import ServerStatus
|
||||
@@ -65,6 +67,21 @@ async def test_app_context_dependency_exposes_shared_resources(test_settings: Se
|
||||
await session.close()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_settings_and_history_archive_repo_dependencies_return_modules() -> None:
|
||||
settings_repo = await get_settings_repo()
|
||||
history_archive_repo = await get_history_archive_repo()
|
||||
|
||||
assert hasattr(settings_repo, "get_setting")
|
||||
assert hasattr(settings_repo, "set_setting")
|
||||
assert hasattr(settings_repo, "delete_setting")
|
||||
assert hasattr(settings_repo, "get_all_settings")
|
||||
|
||||
assert hasattr(history_archive_repo, "archive_ban_event")
|
||||
assert hasattr(history_archive_repo, "get_max_timeofban")
|
||||
assert hasattr(history_archive_repo, "get_archived_history")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_db_uses_effective_runtime_database_path(test_settings: Settings) -> None:
|
||||
"""Database connections should use effective runtime settings when overridden."""
|
||||
|
||||
Reference in New Issue
Block a user