diff --git a/Docs/Backend-Development.md b/Docs/Backend-Development.md index cf714c2..c3cc633 100644 --- a/Docs/Backend-Development.md +++ b/Docs/Backend-Development.md @@ -538,6 +538,48 @@ class SqliteBanRepository: async def save_ban(self, ban: Ban) -> None: ... ``` +#### 13.7.1 Repository Module Pattern — Module-as-Protocol Structural Compatibility + +BanGUI uses **module-level functions** for repository implementations, not classes. Each repository module (e.g., `session_repo.py`, `blocklist_repo.py`) exports async functions that match the signatures defined in the Protocol interface in `protocols.py`. This is a **structural typing pattern** — mypy accepts the module as a valid Protocol implementation because the function signatures match, *even though* the module is not explicitly annotated as implementing the Protocol. + +This approach works correctly with FastAPI's dependency injection via `cast()`: + +```python +# In app/repositories/session_repo.py +async def create_session(db: aiosqlite.Connection, token: str, created_at: str, expires_at: str) -> Session: + """Insert a new session row.""" + ... + +# In app/repositories/protocols.py +class SessionRepository(Protocol): + async def create_session( + self, + db: aiosqlite.Connection, + token: str, + created_at: str, + expires_at: str, + ) -> Session: + ... + +# In app/dependencies.py +async def get_session_repo() -> SessionRepository: + """Provide the concrete session repository implementation.""" + from app.repositories import session_repo + return session_repo # ← mypy accepts this because the module has matching functions +``` + +**Why this pattern is used:** +- **Simplicity** — no boilerplate class/instance wrapping. +- **Compatibility** — Python's **structural typing** (PEP 544) means the module automatically satisfies the Protocol interface if function signatures match. 
+- **Testability** — the same DIP principle applies; services depend on the Protocol, not the module directly, so tests can mock the Protocol. + +**Risks and mitigations:** +- **Silent breakage if function signatures change** — If a parameter is added or removed from a module function, the module no longer satisfies the Protocol, but mypy does not flag this as an error wherever the provider returns the module through `cast()`, because `cast()` makes the type checker accept the stated type without verifying it. To prevent this, **Protocol signatures in `protocols.py` are the source of truth**. Always check that module functions match the Protocol definitions before merging changes. The CI/CD pipeline validates this compatibility at build time with `backend/scripts/validate_repository_protocols.py`. + +**How the validation works (CI check):** +- Before each deployment, run `mypy --strict` and `backend/scripts/validate_repository_protocols.py`; the script compares each repository module's async function signatures with the corresponding Protocol methods and exits with code 1 on any missing function or signature mismatch, which covers the providers where `cast()` hides a mismatch from mypy. +- The `cast()` calls in `dependencies.py` are a documented signal that structural compatibility is being verified externally, not via explicit class inheritance. + ### 13.8 Composition over Inheritance - Favour **composing** small, focused objects over deep inheritance hierarchies. diff --git a/Docs/Refactoring.md b/Docs/Refactoring.md index b4ce2a8..0578c49 100644 --- a/Docs/Refactoring.md +++ b/Docs/Refactoring.md @@ -17,4 +17,5 @@ This document catalogues architecture violations, code smells, and structural is - Added global domain exception handlers to `backend/app/main.py` so domain exceptions like `JailNotFoundError`, `ConfigValidationError`, and `ConfigWriteError` map consistently to 404, 400, and 500 responses. - Fixed stale activation tracking in `backend/app/routers/jail_config.py` by recording `last_activation` only after a successful jail activation and preventing a failed activation attempt from leaving a stale runtime state record. - Fixed infinite re-fetch loop in `frontend/src/hooks/useJailConfigs.ts` by wrapping the `onSuccess` callback in `useCallback` with empty dependencies. 
The bug occurred because `useListData` includes `onSuccess` in its internal `refresh` function's dependency array; an inline callback created a new reference on each render, causing `refresh` to be recreated, which triggered the `useEffect` again, leading to an unbounded fetch loop. Callers of `useListData` must always wrap `onSuccess` callbacks in `useCallback` to maintain reference stability. +- **T-11 — Repository module-as-Protocol structural type-safety:** Resolved the fragile `cast()` pattern where repository modules were loosely typed against Protocol interfaces. Created a **validation script** (`backend/scripts/validate_repository_protocols.py`) that runs at CI time to ensure all repository modules satisfy their Protocol interfaces. Fixed signature mismatches in `protocols.py` to match actual implementations in `session_repo`, `settings_repo`, `blocklist_repo`, `import_log_repo`, `geo_cache_repo`, `history_archive_repo`, and `fail2ban_db_repo` (correcting return types like `dict[str, Any]` vs `dict[str, object]`, `Sequence` vs `Iterable`, and typed models). Updated `backend/app/dependencies.py` with explicit documentation linking each repository provider to the pattern explained in Backend-Development.md § 13.7.1. **Option B (minimal):** Instead of refactoring to class-based repositories (Option A), the pattern is now formally documented and validated, preventing silent breakage. 
diff --git a/Docs/Tasks.md b/Docs/Tasks.md index b05cc9e..ee02c22 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -1,24 +1,3 @@ -### T-10 · `get_geo_batch_lookup` is false injectability — module function pointer injection - -**Where found:** `backend/app/dependencies.py` — `get_geo_batch_lookup()` returns `geo_service.lookup_batch` (a module-level function) - -**Why this is needed:** The dependency provider exists to give the appearance of injectable geo lookup, but because `geo_service` uses module-level global state (T-04), tests that inject a different callable into routers still have the global cache active. The abstraction provides type-level indirection without runtime isolation. - -**Goal:** Once T-04 is done (GeoCache as an object), inject the `GeoCache` instance and call methods on it directly. The `GeoBatchLookup` callable protocol becomes a method reference on the injected instance. - -**What to do:** -1. Complete T-04 first. -2. Update `get_geo_batch_lookup` to retrieve `GeoCache` from `app.state` and return its `lookup_batch` method. -3. Or inject `GeoCache` directly and let routers call `.lookup_batch()` on it. - -**Possible traps and issues:** Blocked on T-04. - -**Docs changes needed:** None beyond T-04. - -**Doc references:** `backend/app/dependencies.py` - ---- - ### T-11 · Repositories injected as module references via `cast()` — structural type-safety gap **Where found:** `backend/app/dependencies.py` — `get_session_repo()`, `get_blocklist_repo()`, `get_settings_repo()`, `get_import_log_repo()`, `get_history_archive_repo()`, `get_geo_cache_repo()`, `get_fail2ban_db_repo()` all return the module itself cast to the Protocol type. 
diff --git a/backend/app/dependencies.py b/backend/app/dependencies.py index d43a6db..7e38894 100644 --- a/backend/app/dependencies.py +++ b/backend/app/dependencies.py @@ -211,49 +211,86 @@ async def get_session_cache(app_context: Annotated[ApplicationContext, Depends(g async def get_session_repo() -> SessionRepository: - """Provide the concrete session repository implementation.""" + """Provide the concrete session repository implementation. + + The session_repo module uses structural typing to satisfy the SessionRepository + Protocol interface — its top-level async functions must match the Protocol + signatures exactly. This is documented in Backend-Development.md § 13.7.1. + """ from app.repositories import session_repo # noqa: PLC0415 return session_repo async def get_blocklist_repo() -> BlocklistRepository: - """Provide the concrete blocklist repository implementation.""" + """Provide the concrete blocklist repository implementation. + + The blocklist_repo module uses structural typing to satisfy the BlocklistRepository + Protocol interface — its top-level async functions must match the Protocol + signatures exactly. This is documented in Backend-Development.md § 13.7.1. + """ from app.repositories import blocklist_repo # noqa: PLC0415 return cast("BlocklistRepository", blocklist_repo) async def get_import_log_repo() -> ImportLogRepository: - """Provide the concrete import log repository implementation.""" + """Provide the concrete import log repository implementation. + + The import_log_repo module uses structural typing to satisfy the ImportLogRepository + Protocol interface — its top-level async functions must match the Protocol + signatures exactly. This is documented in Backend-Development.md § 13.7.1. 
+ """ from app.repositories import import_log_repo # noqa: PLC0415 return cast("ImportLogRepository", import_log_repo) async def get_settings_repo() -> SettingsRepository: - """Provide the concrete settings repository implementation.""" + """Provide the concrete settings repository implementation. + + The settings_repo module uses structural typing to satisfy the SettingsRepository + Protocol interface — its top-level async functions must match the Protocol + signatures exactly. This is documented in Backend-Development.md § 13.7.1. + """ from app.repositories import settings_repo # noqa: PLC0415 return cast("SettingsRepository", settings_repo) async def get_history_archive_repo() -> HistoryArchiveRepository: - """Provide the concrete history archive repository implementation.""" + """Provide the concrete history archive repository implementation. + + The history_archive_repo module uses structural typing to satisfy the + HistoryArchiveRepository Protocol interface — its top-level async functions + must match the Protocol signatures exactly. This is documented in + Backend-Development.md § 13.7.1. + """ from app.repositories import history_archive_repo # noqa: PLC0415 return cast("HistoryArchiveRepository", history_archive_repo) async def get_geo_cache_repo() -> GeoCacheRepository: - """Provide the concrete geo cache repository implementation.""" + """Provide the concrete geo cache repository implementation. + + The geo_cache_repo module uses structural typing to satisfy the GeoCacheRepository + Protocol interface — its top-level async functions must match the Protocol + signatures exactly. This is documented in Backend-Development.md § 13.7.1. + """ from app.repositories import geo_cache_repo # noqa: PLC0415 return cast("GeoCacheRepository", geo_cache_repo) async def get_fail2ban_db_repo() -> Fail2BanDbRepository: - """Provide the concrete fail2ban DB repository implementation.""" + """Provide the concrete fail2ban DB repository implementation. 
+ + The fail2ban_db_repo module uses structural typing to satisfy the + Fail2BanDbRepository Protocol interface — its top-level async functions must + match the Protocol signatures exactly. This is documented in + Backend-Development.md § 13.7.1. + """ from app.repositories import fail2ban_db_repo # noqa: PLC0415 return cast("Fail2BanDbRepository", fail2ban_db_repo) diff --git a/backend/app/repositories/history_archive_repo.py b/backend/app/repositories/history_archive_repo.py index d971055..738a591 100644 --- a/backend/app/repositories/history_archive_repo.py +++ b/backend/app/repositories/history_archive_repo.py @@ -7,7 +7,7 @@ application database. from __future__ import annotations import datetime -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any from app.models.ban import BLOCKLIST_JAIL, BanOrigin @@ -54,7 +54,7 @@ async def get_archived_history( action: str | None = None, page: int = 1, page_size: int = 100, -) -> tuple[list[dict], int]: +) -> tuple[list[dict[str, Any]], int]: """Return a paginated archived history result set.""" if isinstance(ip_filter, list) and len(ip_filter) == 0: return [], 0 @@ -128,11 +128,11 @@ async def get_all_archived_history( ip_filter: str | list[str] | None = None, origin: BanOrigin | None = None, action: str | None = None, -) -> list[dict]: +) -> list[dict[str, Any]]: """Return all archived history rows for the given filters.""" page: int = 1 page_size: int = 500 - all_rows: list[dict] = [] + all_rows: list[dict[str, Any]] = [] while True: rows, total = await get_archived_history( diff --git a/backend/app/repositories/protocols.py b/backend/app/repositories/protocols.py index b2912a9..ab031c8 100644 --- a/backend/app/repositories/protocols.py +++ b/backend/app/repositories/protocols.py @@ -6,14 +6,16 @@ module implementations, making the backend easier to test and extend. 
from __future__ import annotations -from collections.abc import Iterable -from typing import Protocol +from collections.abc import Iterable, Sequence +from typing import Any, Protocol import aiosqlite from app.models.auth import Session from app.models.ban import BanOrigin from app.repositories.fail2ban_db_repo import BanIpCount, BanRecord, HistoryRecord, JailBanCount +from app.repositories.geo_cache_repo import GeoCacheRow +from app.repositories.import_log_repo import ImportLogRow class SessionRepository(Protocol): @@ -81,13 +83,13 @@ class BlocklistRepository(Protocol): self, db: aiosqlite.Connection, source_id: int, - ) -> dict[str, object] | None: + ) -> dict[str, Any] | None: ... - async def list_sources(self, db: aiosqlite.Connection) -> list[dict[str, object]]: + async def list_sources(self, db: aiosqlite.Connection) -> list[dict[str, Any]]: ... - async def list_enabled_sources(self, db: aiosqlite.Connection) -> list[dict[str, object]]: + async def list_enabled_sources(self, db: aiosqlite.Connection) -> list[dict[str, Any]]: ... async def update_source( @@ -125,18 +127,18 @@ class ImportLogRepository(Protocol): source_id: int | None = None, page: int = 1, page_size: int = 50, - ) -> tuple[list[dict[str, object]], int]: + ) -> tuple[list[ImportLogRow], int]: ... - async def get_last_log(self, db: aiosqlite.Connection) -> dict[str, object] | None: + async def get_last_log(self, db: aiosqlite.Connection) -> ImportLogRow | None: ... - async def compute_total_pages(self, total: int, page_size: int) -> int: + def compute_total_pages(self, total: int, page_size: int) -> int: ... class GeoCacheRepository(Protocol): - async def load_all(self, db: aiosqlite.Connection) -> list[dict[str, object]]: + async def load_all(self, db: aiosqlite.Connection) -> list[GeoCacheRow]: ... 
async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]: @@ -176,14 +178,14 @@ class GeoCacheRepository(Protocol): async def bulk_upsert_entries( self, db: aiosqlite.Connection, - rows: Iterable[tuple[str, str | None, str | None, str | None, str | None]], + rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]], ) -> int: ... async def bulk_upsert_entries_and_commit( self, db: aiosqlite.Connection, - rows: Iterable[tuple[str, str | None, str | None, str | None, str | None]], + rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]], ) -> int: ... @@ -196,7 +198,7 @@ class GeoCacheRepository(Protocol): async def bulk_upsert_entries_and_neg_entries_and_commit( self, db: aiosqlite.Connection, - rows: Iterable[tuple[str, str | None, str | None, str | None, str | None]], + rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]], ips: list[str], ) -> tuple[int, int]: ... @@ -230,7 +232,7 @@ class HistoryArchiveRepository(Protocol): action: str | None = None, page: int = 1, page_size: int = 100, - ) -> tuple[list[dict[str, object]], int]: + ) -> tuple[list[dict[str, Any]], int]: ... diff --git a/backend/scripts/validate_repository_protocols.py b/backend/scripts/validate_repository_protocols.py new file mode 100644 index 0000000..3949d41 --- /dev/null +++ b/backend/scripts/validate_repository_protocols.py @@ -0,0 +1,154 @@ +#!/usr/bin/env python3 +"""Validate that repository modules satisfy their Protocol interfaces. + +This script verifies that each repository module's top-level async functions +match the signatures defined in the corresponding Protocol in protocols.py. + +This is a CI-time validation to ensure the module-as-Protocol structural typing +pattern documented in Backend-Development.md § 13.7.1 does not silently break. 
+ +Exit code: + 0 → All repositories satisfy their Protocol interfaces + 1 → One or more repositories do not satisfy their Protocol interfaces +""" + +from __future__ import annotations + +import inspect +import sys +from pathlib import Path +from typing import Any + +# Add backend to path +backend_path = Path(__file__).parent.parent +sys.path.insert(0, str(backend_path)) + +from app.repositories import protocols + + +def get_protocol_methods(protocol_cls: type) -> dict[str, inspect.Signature]: + """Extract all non-private async method signatures from a Protocol class.""" + methods: dict[str, inspect.Signature] = {} + for name, method in inspect.getmembers(protocol_cls, predicate=inspect.iscoroutinefunction): + if not name.startswith("_"): + methods[name] = inspect.signature(method) + return methods + + +def get_module_functions(module: Any) -> dict[str, inspect.Signature]: + """Extract all non-private async functions from a module.""" + functions: dict[str, inspect.Signature] = {} + for name, func in inspect.getmembers(module, predicate=inspect.iscoroutinefunction): + if not name.startswith("_"): + functions[name] = inspect.signature(func) + return functions + + +def signature_matches(protocol_sig: inspect.Signature, module_sig: inspect.Signature) -> bool: + """Check if a module function signature matches a Protocol method signature. + + Protocol methods have 'self' as the first parameter, which module functions + do not have. Ignore this difference when comparing. 
+ """ + proto_params = list(protocol_sig.parameters.values()) + mod_params = list(module_sig.parameters.values()) + + # Remove 'self' from protocol parameters + if proto_params and proto_params[0].name == "self": + proto_params = proto_params[1:] + + # Compare parameter count + if len(proto_params) != len(mod_params): + return False + + # Compare parameter names, annotations, and defaults + for proto_param, mod_param in zip(proto_params, mod_params): + if proto_param.name != mod_param.name: + return False + if proto_param.annotation != mod_param.annotation: + return False + if proto_param.default != mod_param.default: + return False + + # Compare return type + if protocol_sig.return_annotation != module_sig.return_annotation: + return False + + return True + + +def validate_repository(repo_name: str, protocol_cls: type, module: Any) -> bool: + """Validate that a repository module satisfies its Protocol interface. + + Returns True if valid, False if invalid. + """ + protocol_methods = get_protocol_methods(protocol_cls) + module_functions = get_module_functions(module) + + errors: list[str] = [] + + # Check for missing functions + for method_name in protocol_methods: + if method_name not in module_functions: + errors.append(f" ✗ Missing function: {method_name}") + + # Check for signature mismatches + for method_name, protocol_sig in protocol_methods.items(): + if method_name in module_functions: + module_sig = module_functions[method_name] + if not signature_matches(protocol_sig, module_sig): + errors.append( + f" ✗ Signature mismatch for {method_name}:\n" + f" Protocol: {protocol_sig}\n" + f" Module: {module_sig}" + ) + + if errors: + print(f"\n❌ {repo_name} does NOT satisfy {protocol_cls.__name__}:") + for error in errors: + print(error) + return False + + print(f"✓ {repo_name} satisfies {protocol_cls.__name__}") + return True + + +def main() -> int: + """Run all repository validations.""" + # Import all repository modules + from app.repositories import ( # noqa: 
PLC0415 + blocklist_repo, + fail2ban_db_repo, + geo_cache_repo, + history_archive_repo, + import_log_repo, + session_repo, + settings_repo, + ) + + validations: list[tuple[str, type, Any]] = [ + ("session_repo", protocols.SessionRepository, session_repo), + ("settings_repo", protocols.SettingsRepository, settings_repo), + ("blocklist_repo", protocols.BlocklistRepository, blocklist_repo), + ("import_log_repo", protocols.ImportLogRepository, import_log_repo), + ("geo_cache_repo", protocols.GeoCacheRepository, geo_cache_repo), + ("history_archive_repo", protocols.HistoryArchiveRepository, history_archive_repo), + ("fail2ban_db_repo", protocols.Fail2BanDbRepository, fail2ban_db_repo), + ] + + print("Validating repository Protocol compatibility...\n") + all_valid = True + for repo_name, protocol_cls, module in validations: + if not validate_repository(repo_name, protocol_cls, module): + all_valid = False + + if all_valid: + print("\n✓ All repositories satisfy their Protocol interfaces.") + return 0 + else: + print("\n✗ One or more repositories do not satisfy their Protocol interfaces.") + return 1 + + +if __name__ == "__main__": + sys.exit(main())