Files
BanGUI/backend/scripts/validate_repository_protocols.py
Lukas b44b72053a T-11: Validate repository Protocol structural compatibility — minimal approach (Option B)
Problem: Repository modules use structural typing to satisfy Protocol interfaces via
cast(). A function rename, parameter change, or signature mismatch would silently pass
mypy but fail at runtime.

Solution (Option B — minimal):
1. Aligned Protocol signatures in protocols.py with actual implementations:
   - BlocklistRepository: dict[str, object] → dict[str, Any] (matches implementation)
   - ImportLogRepository: dict[str, object] → ImportLogRow (typed model)
   - GeoCacheRepository: dict[str, object] → GeoCacheRow; Iterable → Sequence
   - HistoryArchiveRepository: dict[str, object] → dict[str, Any]
   - ImportLogRepository: async compute_total_pages → sync (matches implementation)

2. Created CI validation script (backend/scripts/validate_repository_protocols.py)
   that runs at build time to ensure all repository modules satisfy their Protocol
   interfaces. Exit 0 if valid, 1 if any mismatch. Detects:
   - Missing functions
   - Parameter count mismatches
   - Type annotation mismatches
   - Return type mismatches

3. Updated backend/app/dependencies.py with explicit docstrings linking each
   get_*_repo() provider to Backend-Development.md § 13.7.1, explaining the
   module-as-Protocol pattern and that it is intentional and validated.

4. Documented the pattern in Backend-Development.md § 13.7.1:
   'Repository Module Pattern — Module-as-Protocol Structural Compatibility'
   explaining why the pattern works, risks (silent breakage), and how the
   validation mitigates it.

5. Fixed type annotation in history_archive_repo.py:
   - get_all_archived_history return annotation: list[dict] → list[dict[str, Any]]
   - Imported Any type

Benefits:
- Prevents silent breakage of repository interfaces
- Formalizes the module-as-Protocol pattern as intentional
- CI validation prevents regressions without refactoring cost
- All repository tests pass (53/53)
- mypy --strict passes on modified files

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-25 18:59:49 +02:00

155 lines
5.4 KiB
Python

#!/usr/bin/env python3
"""Validate that repository modules satisfy their Protocol interfaces.
This script verifies that each repository module's top-level async functions
match the signatures defined in the corresponding Protocol in protocols.py.
This is a CI-time validation to ensure the module-as-Protocol structural typing
pattern documented in Backend-Development.md § 13.7.1 does not silently break.
Exit code:
0 → All repositories satisfy their Protocol interfaces
1 → One or more repositories do not satisfy their Protocol interfaces
"""
from __future__ import annotations
import inspect
import sys
from pathlib import Path
from typing import Any
# Put the directory two levels above this file (the backend root, given the
# script lives in ``backend/scripts/``) at the front of ``sys.path`` so the
# ``app`` package imported below resolves when the script is run directly.
backend_path = Path(__file__).parent.parent
sys.path.insert(0, str(backend_path))
from app.repositories import protocols
def _is_sync_or_async_function(member: Any) -> bool:
    """Return True for both ``def`` and ``async def`` functions.

    Filtering on ``inspect.iscoroutinefunction`` alone would drop synchronous
    members, so synchronous Protocol methods would never be validated.
    """
    return inspect.isfunction(member) or inspect.iscoroutinefunction(member)


def _collect_public_signatures(owner: Any) -> dict[str, inspect.Signature]:
    """Map each non-private function found on *owner* to its signature."""
    return {
        name: inspect.signature(member)
        for name, member in inspect.getmembers(owner, predicate=_is_sync_or_async_function)
        if not name.startswith("_")
    }


def _flavour(func: Any) -> str:
    """Describe *func* as ``"async"`` or ``"sync"`` for error messages."""
    return "async" if inspect.iscoroutinefunction(func) else "sync"


def get_protocol_methods(protocol_cls: type) -> dict[str, inspect.Signature]:
    """Extract all non-private method signatures from a Protocol class.

    Both synchronous and asynchronous methods are returned; names starting
    with an underscore are skipped.
    """
    return _collect_public_signatures(protocol_cls)


def get_module_functions(module: Any) -> dict[str, inspect.Signature]:
    """Extract all non-private function signatures from a module.

    Both synchronous and asynchronous functions are returned; names starting
    with an underscore are skipped.
    """
    return _collect_public_signatures(module)


def signature_matches(protocol_sig: inspect.Signature, module_sig: inspect.Signature) -> bool:
    """Check if a module function signature matches a Protocol method signature.

    Protocol methods have 'self' as the first parameter, which module functions
    do not have. Ignore this difference when comparing.

    After dropping ``self``, the two signatures must agree on parameter count
    and, position by position, on name, kind (positional / keyword-only /
    ``*args`` / ``**kwargs``), annotation and default, as well as on the
    return annotation.
    """
    proto_params = list(protocol_sig.parameters.values())
    mod_params = list(module_sig.parameters.values())

    # Remove 'self' from protocol parameters
    if proto_params and proto_params[0].name == "self":
        proto_params = proto_params[1:]

    # Compare parameter count
    if len(proto_params) != len(mod_params):
        return False

    for proto_param, mod_param in zip(proto_params, mod_params):
        if proto_param.name != mod_param.name:
            return False
        # A keyword-only or variadic parameter is not interchangeable with a
        # regular one even when name and annotation agree: calls that work
        # against the Protocol would raise TypeError against the module.
        if proto_param.kind != mod_param.kind:
            return False
        if proto_param.annotation != mod_param.annotation:
            return False
        if proto_param.default != mod_param.default:
            return False

    # Compare return type
    if protocol_sig.return_annotation != module_sig.return_annotation:
        return False
    return True


def validate_repository(repo_name: str, protocol_cls: type, module: Any) -> bool:
    """Validate that a repository module satisfies its Protocol interface.

    Every public Protocol method (sync or async) must exist in *module* as a
    function with the same sync/async flavour and a matching signature.
    A report is printed to stdout.

    Returns True if valid, False if invalid.
    """
    protocol_methods = get_protocol_methods(protocol_cls)
    module_functions = get_module_functions(module)

    # Missing functions are reported first, then per-function mismatches.
    errors: list[str] = [
        f" ✗ Missing function: {method_name}"
        for method_name in protocol_methods
        if method_name not in module_functions
    ]

    for method_name, protocol_sig in protocol_methods.items():
        module_sig = module_functions.get(method_name)
        if module_sig is None:
            continue

        # inspect.Signature does not record whether a function is a coroutine
        # function, so that has to be compared on the callables themselves.
        protocol_member = getattr(protocol_cls, method_name)
        module_member = getattr(module, method_name)
        if inspect.iscoroutinefunction(protocol_member) != inspect.iscoroutinefunction(module_member):
            errors.append(
                f" ✗ Async/sync mismatch for {method_name}: "
                f"Protocol is {_flavour(protocol_member)}, module is {_flavour(module_member)}"
            )

        if not signature_matches(protocol_sig, module_sig):
            errors.append(
                f" ✗ Signature mismatch for {method_name}:\n"
                f" Protocol: {protocol_sig}\n"
                f" Module: {module_sig}"
            )

    if errors:
        print(f"\n{repo_name} does NOT satisfy {protocol_cls.__name__}:")
        for error in errors:
            print(error)
        return False

    print(f"{repo_name} satisfies {protocol_cls.__name__}")
    return True
def main() -> int:
    """Validate every known repository module against its Protocol.

    Prints a per-repository report followed by a summary line.

    Returns:
        0 when all repositories pass validation, 1 otherwise.
    """
    # Imported here rather than at module level, as in the original layout.
    from app.repositories import (  # noqa: PLC0415
        blocklist_repo,
        fail2ban_db_repo,
        geo_cache_repo,
        history_archive_repo,
        import_log_repo,
        session_repo,
        settings_repo,
    )

    checks: list[tuple[str, type, Any]] = [
        ("session_repo", protocols.SessionRepository, session_repo),
        ("settings_repo", protocols.SettingsRepository, settings_repo),
        ("blocklist_repo", protocols.BlocklistRepository, blocklist_repo),
        ("import_log_repo", protocols.ImportLogRepository, import_log_repo),
        ("geo_cache_repo", protocols.GeoCacheRepository, geo_cache_repo),
        ("history_archive_repo", protocols.HistoryArchiveRepository, history_archive_repo),
        ("fail2ban_db_repo", protocols.Fail2BanDbRepository, fail2ban_db_repo),
    ]

    print("Validating repository Protocol compatibility...\n")

    # Build the full list first so every repository is checked and reported,
    # even after an earlier one has failed.
    outcomes = [validate_repository(label, proto, repo) for label, proto, repo in checks]

    if not all(outcomes):
        print("\n✗ One or more repositories do not satisfy their Protocol interfaces.")
        return 1

    print("\n✓ All repositories satisfy their Protocol interfaces.")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())