#!/usr/bin/env python3 """Validate that repository modules satisfy their Protocol interfaces. This script verifies that each repository module's top-level async functions match the signatures defined in the corresponding Protocol in protocols.py. This is a CI-time validation to ensure the module-as-Protocol structural typing pattern documented in Backend-Development.md § 13.7.1 does not silently break. Exit code: 0 → All repositories satisfy their Protocol interfaces 1 → One or more repositories do not satisfy their Protocol interfaces """ from __future__ import annotations import inspect import sys from pathlib import Path from typing import Any # Add backend to path backend_path = Path(__file__).parent.parent sys.path.insert(0, str(backend_path)) from app.repositories import protocols def get_protocol_methods(protocol_cls: type) -> dict[str, inspect.Signature]: """Extract all non-private async method signatures from a Protocol class.""" methods: dict[str, inspect.Signature] = {} for name, method in inspect.getmembers(protocol_cls, predicate=inspect.iscoroutinefunction): if not name.startswith("_"): methods[name] = inspect.signature(method) return methods def get_module_functions(module: Any) -> dict[str, inspect.Signature]: """Extract all non-private async functions from a module.""" functions: dict[str, inspect.Signature] = {} for name, func in inspect.getmembers(module, predicate=inspect.iscoroutinefunction): if not name.startswith("_"): functions[name] = inspect.signature(func) return functions def signature_matches(protocol_sig: inspect.Signature, module_sig: inspect.Signature) -> bool: """Check if a module function signature matches a Protocol method signature. Protocol methods have 'self' as the first parameter, which module functions do not have. Ignore this difference when comparing. """ proto_params = list(protocol_sig.parameters.values()) mod_params = list(module_sig.parameters.values()) # Remove 'self' from protocol parameters if proto_params and proto_params[0].name == "self": proto_params = proto_params[1:] # Compare parameter count if len(proto_params) != len(mod_params): return False # Compare parameter names, annotations, and defaults for proto_param, mod_param in zip(proto_params, mod_params): if proto_param.name != mod_param.name: return False if proto_param.annotation != mod_param.annotation: return False if proto_param.default != mod_param.default: return False # Compare return type if protocol_sig.return_annotation != module_sig.return_annotation: return False return True def validate_repository(repo_name: str, protocol_cls: type, module: Any) -> bool: """Validate that a repository module satisfies its Protocol interface. Returns True if valid, False if invalid. """ protocol_methods = get_protocol_methods(protocol_cls) module_functions = get_module_functions(module) errors: list[str] = [] # Check for missing functions for method_name in protocol_methods: if method_name not in module_functions: errors.append(f" ✗ Missing function: {method_name}") # Check for signature mismatches for method_name, protocol_sig in protocol_methods.items(): if method_name in module_functions: module_sig = module_functions[method_name] if not signature_matches(protocol_sig, module_sig): errors.append( f" ✗ Signature mismatch for {method_name}:\n" f" Protocol: {protocol_sig}\n" f" Module: {module_sig}" ) if errors: print(f"\n❌ {repo_name} does NOT satisfy {protocol_cls.__name__}:") for error in errors: print(error) return False print(f"✓ {repo_name} satisfies {protocol_cls.__name__}") return True def main() -> int: """Run all repository validations.""" # Import all repository modules from app.repositories import ( # noqa: PLC0415 blocklist_repo, fail2ban_db_repo, geo_cache_repo, history_archive_repo, import_log_repo, session_repo, settings_repo, ) validations: list[tuple[str, type, Any]] = [ ("session_repo", protocols.SessionRepository, session_repo), ("settings_repo", protocols.SettingsRepository, settings_repo), ("blocklist_repo", protocols.BlocklistRepository, blocklist_repo), ("import_log_repo", protocols.ImportLogRepository, import_log_repo), ("geo_cache_repo", protocols.GeoCacheRepository, geo_cache_repo), ("history_archive_repo", protocols.HistoryArchiveRepository, history_archive_repo), ("fail2ban_db_repo", protocols.Fail2BanDbRepository, fail2ban_db_repo), ] print("Validating repository Protocol compatibility...\n") all_valid = True for repo_name, protocol_cls, module in validations: if not validate_repository(repo_name, protocol_cls, module): all_valid = False if all_valid: print("\n✓ All repositories satisfy their Protocol interfaces.") return 0 else: print("\n✗ One or more repositories do not satisfy their Protocol interfaces.") return 1 if __name__ == "__main__": sys.exit(main())