Files
BanGUI/backend/app/repositories/protocols.py
Lukas a273b96563 feat: Complete repository protocol coverage
- Add missing protocol methods to Fail2BanDbRepository:
  - get_ban_event_counts: Aggregate ban events per IP (used in ban_service)

- Add missing protocol methods to GeoCacheRepository:
  - delete_stale_entries: Remove old geo cache entries (used in geo_cache_cleanup)

- Add missing protocol methods to HistoryArchiveRepository:
  - purge_archived_history: Remove archived entries older than a given age threshold

- Add comprehensive protocol compliance tests:
  - Created test_protocol_compliance.py with 8 test classes
  - Validates all 7 repository modules fully implement their protocols
  - Prevents silent protocol drift when methods change signatures
  - Tests verify that repository modules expose no unexpected public methods

- Update documentation:
  - Add Repository Protocol Coverage Checklist to Backend-Development.md
  - Document procedure for adding new repositories with protocol definitions
  - List current protocol coverage (all 7 repositories, 44 total methods)

- All repositories now have 100% protocol coverage:
  - SessionRepository: 4 methods
  - SettingsRepository: 4 methods
  - BlocklistRepository: 6 methods
  - ImportLogRepository: 4 methods
  - GeoCacheRepository: 13 methods
  - HistoryArchiveRepository: 5 methods
  - Fail2BanDbRepository: 8 methods

This ensures:
- Enhanced mockability for testing
- Static contract verification
- Prevention of protocol drift
- Better IDE support and type checking

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-28 07:58:57 +02:00

315 lines
8.0 KiB
Python

"""Repository interface protocols for dependency injection.
Routers and services can depend on these abstractions instead of concrete
module implementations, making the backend easier to test and extend.
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import Any, Protocol
import aiosqlite
from app.models.auth import Session
from app.models.ban import BanOrigin
from app.repositories.fail2ban_db_repo import BanIpCount, BanRecord, HistoryRecord, JailBanCount
from app.repositories.geo_cache_repo import GeoCacheRow
from app.repositories.import_log_repo import ImportLogRow
class SessionRepository(Protocol):
    """Protocol for session persistence operations.

    Every method is a coroutine that operates on the open
    :class:`aiosqlite.Connection` passed as ``db``.
    """

    async def create_session(
        self,
        db: aiosqlite.Connection,
        token: str,
        created_at: str,
        expires_at: str,
    ) -> Session:
        """Create a session identified by ``token`` and return it.

        ``created_at`` and ``expires_at`` are timestamp strings.
        """
        ...

    async def get_session(
        self,
        db: aiosqlite.Connection,
        token: str,
    ) -> Session | None:
        """Return the session identified by ``token``, or ``None``."""
        ...

    async def delete_session(
        self,
        db: aiosqlite.Connection,
        token: str,
    ) -> None:
        """Delete the session identified by ``token``."""
        ...

    async def delete_expired_sessions(
        self,
        db: aiosqlite.Connection,
        now_iso: str,
    ) -> int:
        """Delete the sessions that have expired as of ``now_iso``.

        ``now_iso`` is the reference time as a string (ISO format, judging by
        the name). The ``int`` result is presumably the number of sessions
        removed; confirm against the implementation.
        """
        ...
class SettingsRepository(Protocol):
    """Protocol for application settings persistence operations.

    Settings are string values addressed by string keys; every method is a
    coroutine acting on the open connection passed as ``db``.
    """

    async def get_setting(
        self,
        db: aiosqlite.Connection,
        key: str,
    ) -> str | None:
        """Return the value stored under ``key``, or ``None``."""
        ...

    async def set_setting(
        self,
        db: aiosqlite.Connection,
        key: str,
        value: str,
    ) -> None:
        """Store ``value`` under ``key``."""
        ...

    async def delete_setting(
        self,
        db: aiosqlite.Connection,
        key: str,
    ) -> None:
        """Remove the setting stored under ``key``."""
        ...

    async def get_all_settings(
        self,
        db: aiosqlite.Connection,
    ) -> dict[str, str]:
        """Return every stored setting as a ``key -> value`` mapping."""
        ...
class BlocklistRepository(Protocol):
    """Protocol for blocklist source persistence operations.

    A source has a ``name``, a ``url`` and an ``enabled`` flag and is addressed
    by an integer ``source_id``. Sources are returned as plain
    ``dict[str, Any]`` rows. Every method is a coroutine acting on the open
    connection passed as ``db``.
    """

    async def create_source(
        self,
        db: aiosqlite.Connection,
        name: str,
        url: str,
        *,
        enabled: bool = True,
    ) -> int:
        """Create a source named ``name`` for ``url``.

        The source is enabled unless ``enabled=False`` is passed. The ``int``
        result is presumably the new source's id.
        """
        ...

    async def get_source(
        self,
        db: aiosqlite.Connection,
        source_id: int,
    ) -> dict[str, Any] | None:
        """Return the source with id ``source_id`` as a dict, or ``None``."""
        ...

    async def list_sources(self, db: aiosqlite.Connection) -> list[dict[str, Any]]:
        """Return all sources."""
        ...

    async def list_enabled_sources(
        self,
        db: aiosqlite.Connection,
    ) -> list[dict[str, Any]]:
        """Return only the enabled sources."""
        ...

    async def update_source(
        self,
        db: aiosqlite.Connection,
        source_id: int,
        *,
        name: str | None = None,
        url: str | None = None,
        enabled: bool | None = None,
    ) -> bool:
        """Update fields of the source with id ``source_id``.

        Fields passed as ``None`` are presumably left unchanged. The ``bool``
        result presumably reports whether a matching source was found.
        """
        ...

    async def delete_source(self, db: aiosqlite.Connection, source_id: int) -> bool:
        """Delete the source with id ``source_id``.

        The ``bool`` result presumably reports whether a row was removed.
        """
        ...
class ImportLogRepository(Protocol):
    """Protocol for import log persistence operations.

    Each log entry records the ``source_url`` an import ran against, and
    optionally its ``source_id``. It also holds the number of IPs imported and
    skipped, plus optional ``errors`` text. All members except
    :meth:`compute_total_pages` are coroutines acting on the connection
    passed as ``db``.
    """

    async def add_log(
        self,
        db: aiosqlite.Connection,
        *,
        source_id: int | None,
        source_url: str,
        ips_imported: int,
        ips_skipped: int,
        errors: str | None,
    ) -> int:
        """Append one import log entry.

        The ``int`` result is presumably the new entry's id.
        """
        ...

    async def list_logs(
        self,
        db: aiosqlite.Connection,
        *,
        source_id: int | None = None,
        page: int = 1,
        page_size: int = 50,
    ) -> tuple[list[ImportLogRow], int]:
        """Return one page of log rows plus an ``int`` total.

        ``source_id`` optionally restricts the rows to a single source, and
        ``page``/``page_size`` select the page. The total is presumably the
        unpaginated row count, suitable for :meth:`compute_total_pages`.
        """
        ...

    async def get_last_log(self, db: aiosqlite.Connection) -> ImportLogRow | None:
        """Return the last log row, or ``None``."""
        ...

    def compute_total_pages(self, total: int, page_size: int) -> int:
        """Return how many pages of ``page_size`` rows are needed for ``total`` rows.

        This member is synchronous (not a coroutine).
        """
        ...
class GeoCacheRepository(Protocol):
    """Protocol for IP geolocation cache persistence operations.

    A cache entry stores ``country_code``, ``country_name``, ``asn`` and
    ``org`` for an ``ip``; each of those fields may be ``None``. Row tuples
    passed to the bulk methods are typed ``(str, str | None, str | None,
    str | None, str | None)``, presumably in the same order as the arguments
    of :meth:`upsert_entry`. "Neg" entries are presumably negative-cache
    markers for IPs that could not be resolved.

    Methods suffixed ``_and_commit`` presumably commit on ``db`` themselves,
    while the plain variants leave committing to the caller. Confirm this
    against the implementation.
    """

    async def load_all(self, db: aiosqlite.Connection) -> list[GeoCacheRow]:
        """Return every cached row."""
        ...

    async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]:
        """Return the IPs the cache marks as unresolved."""
        ...

    async def count_unresolved(self, db: aiosqlite.Connection) -> int:
        """Return the number of unresolved IPs in the cache."""
        ...

    async def upsert_entry(
        self,
        db: aiosqlite.Connection,
        ip: str,
        country_code: str | None,
        country_name: str | None,
        asn: str | None,
        org: str | None,
    ) -> None:
        """Insert or update the cache entry for ``ip``."""
        ...

    async def upsert_entry_and_commit(
        self,
        db: aiosqlite.Connection,
        ip: str,
        country_code: str | None,
        country_name: str | None,
        asn: str | None,
        org: str | None,
    ) -> None:
        """Insert or update the cache entry for ``ip``, then commit."""
        ...

    async def upsert_neg_entry(self, db: aiosqlite.Connection, ip: str) -> None:
        """Insert or update a negative entry for ``ip``."""
        ...

    async def upsert_neg_entry_and_commit(
        self,
        db: aiosqlite.Connection,
        ip: str,
    ) -> None:
        """Insert or update a negative entry for ``ip``, then commit."""
        ...

    async def bulk_upsert_entries(
        self,
        db: aiosqlite.Connection,
        rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
    ) -> int:
        """Upsert every tuple in ``rows``.

        The ``int`` result is presumably the number of rows written.
        """
        ...

    async def bulk_upsert_entries_and_commit(
        self,
        db: aiosqlite.Connection,
        rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
    ) -> int:
        """Upsert every tuple in ``rows``, then commit.

        The ``int`` result is presumably the number of rows written.
        """
        ...

    async def bulk_upsert_neg_entries(
        self,
        db: aiosqlite.Connection,
        ips: list[str],
    ) -> int:
        """Upsert a negative entry for each IP in ``ips``.

        The ``int`` result is presumably the number of entries written.
        """
        ...

    async def bulk_upsert_neg_entries_and_commit(
        self,
        db: aiosqlite.Connection,
        ips: list[str],
    ) -> int:
        """Upsert a negative entry for each IP in ``ips``, then commit.

        The ``int`` result is presumably the number of entries written.
        """
        ...

    async def bulk_upsert_entries_and_neg_entries_and_commit(
        self,
        db: aiosqlite.Connection,
        rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
        ips: list[str],
    ) -> tuple[int, int]:
        """Upsert ``rows`` and negative entries for ``ips`` in one call, then commit.

        The pair presumably holds the counts for ``rows`` and ``ips``, in that
        order.
        """
        ...

    async def delete_stale_entries(self, db: aiosqlite.Connection, cutoff_iso: str) -> int:
        """Remove old cache entries (used by ``geo_cache_cleanup``).

        ``cutoff_iso`` is a timestamp string (ISO format, judging by the name).
        Entries older than it are presumably removed. The ``int`` result is
        presumably the number of entries deleted.
        """
        ...
class HistoryArchiveRepository(Protocol):
    """Protocol for archived ban history persistence operations.

    Every method is a coroutine acting on the open connection passed as
    ``db``. The query methods share the optional filters ``since``, ``jail``,
    ``ip_filter`` (a single IP string or a list of them), ``origin`` and
    ``action``.
    """

    async def archive_ban_event(
        self,
        db: aiosqlite.Connection,
        jail: str,
        ip: str,
        timeofban: int,
        bancount: int,
        data: str,
        action: str = "ban",
    ) -> bool:
        """Archive one event for ``ip`` in ``jail``; ``action`` defaults to ``"ban"``.

        The ``bool`` result presumably reports whether a row was written.
        """
        ...

    async def get_max_timeofban(self, db: aiosqlite.Connection) -> int | None:
        """Return the largest archived ``timeofban``, or ``None``.

        ``None`` is presumably returned when the archive is empty.
        """
        ...

    async def get_archived_history(
        self,
        db: aiosqlite.Connection,
        since: int | None = None,
        jail: str | None = None,
        ip_filter: str | list[str] | None = None,
        origin: BanOrigin | None = None,
        action: str | None = None,
        page: int = 1,
        page_size: int = 100,
    ) -> tuple[list[dict[str, Any]], int]:
        """Return one page of archived rows matching the filters, plus an ``int`` total.

        The total is presumably the unpaginated match count.
        """
        ...

    async def get_all_archived_history(
        self,
        db: aiosqlite.Connection,
        since: int | None = None,
        jail: str | None = None,
        ip_filter: str | list[str] | None = None,
        origin: BanOrigin | None = None,
        action: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return every archived row matching the filters, without pagination."""
        ...

    async def purge_archived_history(self, db: aiosqlite.Connection, age_seconds: int) -> int:
        """Remove archived entries older than ``age_seconds`` seconds.

        The ``int`` result is presumably the number of entries removed.
        """
        ...
class Fail2BanDbRepository(Protocol):
    """Protocol for querying the fail2ban database.

    Unlike the other repository protocols, these coroutines take the database
    file path ``db_path`` instead of an open connection. Time arguments such
    as ``since`` are integers, presumably Unix epoch seconds (confirm against
    the implementation). ``origin`` optionally filters results by
    :class:`BanOrigin`.
    """

    async def check_db_nonempty(self, db_path: str) -> bool:
        """Return whether the database at ``db_path`` contains data."""
        ...

    async def get_currently_banned(
        self,
        db_path: str,
        since: int,
        origin: BanOrigin | None = None,
        *,
        ip_filter: list[str] | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> tuple[list[BanRecord], int]:
        """Return currently banned entries as :class:`BanRecord` rows, plus an ``int`` total.

        ``ip_filter`` optionally restricts the IPs, and ``limit``/``offset``
        page the results. The total is presumably the unpaginated match count.
        """
        ...

    async def get_ban_counts_by_bucket(
        self,
        db_path: str,
        since: int,
        bucket_secs: int,
        num_buckets: int,
        origin: BanOrigin | None = None,
    ) -> list[int]:
        """Return ban counts grouped into ``bucket_secs``-wide time buckets from ``since``.

        The list presumably holds one count per bucket (``num_buckets``
        entries).
        """
        ...

    async def get_ban_event_counts(
        self,
        db_path: str,
        since: int,
        origin: BanOrigin | None = None,
    ) -> list[BanIpCount]:
        """Aggregate ban events per IP since ``since`` (used by ``ban_service``)."""
        ...

    async def get_bans_by_jail(
        self,
        db_path: str,
        since: int,
        origin: BanOrigin | None = None,
    ) -> tuple[int, list[JailBanCount]]:
        """Return per-jail ban counts since ``since``, preceded by an ``int``.

        The ``int`` is presumably the overall total across jails.
        """
        ...

    async def get_bans_table_summary(self, db_path: str) -> tuple[int, int | None, int | None]:
        """Return a three-field summary of the bans table.

        The fields are presumably a row count followed by the earliest and
        latest ban times, with ``None`` when the table is empty. Confirm this
        against the implementation.
        """
        ...

    async def get_history_page(
        self,
        db_path: str,
        since: int | None = None,
        jail: str | None = None,
        ip_filter: str | None = None,
        origin: BanOrigin | None = None,
        page: int = 1,
        page_size: int = 100,
    ) -> tuple[list[HistoryRecord], int]:
        """Return one page of history records matching the filters, plus an ``int`` total.

        The total is presumably the unpaginated match count.
        """
        ...

    async def get_history_for_ip(self, db_path: str, ip: str) -> list[HistoryRecord]:
        """Return the history records for ``ip``."""
        ...