"""Tests validating repository protocol compliance. These tests ensure that each repository module exports all methods defined in its corresponding Protocol class in protocols.py, with matching signatures. This validates structural typing compatibility and prevents silent breakage when protocol methods change. """ from __future__ import annotations import inspect from typing import Protocol, get_type_hints import pytest from app.repositories import ( blocklist_repo, fail2ban_db_repo, geo_cache_repo, history_archive_repo, import_log_repo, session_repo, settings_repo, ) from app.repositories.protocols import ( BlocklistRepository, Fail2BanDbRepository, GeoCacheRepository, HistoryArchiveRepository, ImportLogRepository, SessionRepository, SettingsRepository, ) def _get_protocol_methods(protocol_class: type[Protocol]) -> set[str]: """Extract public async/sync method names from a Protocol class.""" methods: set[str] = set() for name, member in inspect.getmembers(protocol_class): # Skip private/magic methods and non-callables if name.startswith("_") or not callable(member): continue # Include both async and sync methods if inspect.iscoroutinefunction(member) or ( hasattr(member, "__func__") and callable(member) ): methods.add(name) return methods def _get_module_methods(module: object) -> set[str]: """Extract public async/sync function names from a module.""" methods: set[str] = set() for name, member in inspect.getmembers(module): # Skip private functions, classes, and non-callables if name.startswith("_") or not callable(member): continue if inspect.isclass(member): continue # Include both async and sync functions if inspect.iscoroutinefunction(member) or callable(member): methods.add(name) return methods class TestSessionRepositoryCompliance: """Validate SessionRepository protocol compliance.""" def test_implements_all_protocol_methods(self) -> None: """Session repository module implements all SessionRepository methods.""" protocol_methods = _get_protocol_methods(SessionRepository) module_methods = _get_module_methods(session_repo) missing = protocol_methods - module_methods assert not missing, f"Missing methods: {missing}" def test_no_unexpected_public_methods(self) -> None: """Session repository has no unexpected public methods.""" protocol_methods = _get_protocol_methods(SessionRepository) module_methods = _get_module_methods(session_repo) extra = module_methods - protocol_methods # Allow _hash_token and other private functions assert not extra, f"Unexpected public methods: {extra}" class TestSettingsRepositoryCompliance: """Validate SettingsRepository protocol compliance.""" def test_implements_all_protocol_methods(self) -> None: """Settings repository module implements all SettingsRepository methods.""" protocol_methods = _get_protocol_methods(SettingsRepository) module_methods = _get_module_methods(settings_repo) missing = protocol_methods - module_methods assert not missing, f"Missing methods: {missing}" class TestBlocklistRepositoryCompliance: """Validate BlocklistRepository protocol compliance.""" def test_implements_all_protocol_methods(self) -> None: """Blocklist repository module implements all BlocklistRepository methods.""" protocol_methods = _get_protocol_methods(BlocklistRepository) module_methods = _get_module_methods(blocklist_repo) missing = protocol_methods - module_methods assert not missing, f"Missing methods: {missing}" class TestImportLogRepositoryCompliance: """Validate ImportLogRepository protocol compliance.""" def test_implements_all_protocol_methods(self) -> None: """ImportLog repository module implements all ImportLogRepository methods.""" protocol_methods = _get_protocol_methods(ImportLogRepository) module_methods = _get_module_methods(import_log_repo) missing = protocol_methods - module_methods assert not missing, f"Missing methods: {missing}" class TestGeoCacheRepositoryCompliance: """Validate GeoCacheRepository protocol compliance.""" def test_implements_all_protocol_methods(self) -> None: """GeoCache repository module implements all GeoCacheRepository methods.""" protocol_methods = _get_protocol_methods(GeoCacheRepository) module_methods = _get_module_methods(geo_cache_repo) missing = protocol_methods - module_methods assert not missing, f"Missing methods: {missing}" class TestHistoryArchiveRepositoryCompliance: """Validate HistoryArchiveRepository protocol compliance.""" def test_implements_all_protocol_methods(self) -> None: """HistoryArchive repository module implements all HistoryArchiveRepository methods.""" protocol_methods = _get_protocol_methods(HistoryArchiveRepository) module_methods = _get_module_methods(history_archive_repo) missing = protocol_methods - module_methods assert not missing, f"Missing methods: {missing}" class TestFail2BanDbRepositoryCompliance: """Validate Fail2BanDbRepository protocol compliance.""" def test_implements_all_protocol_methods(self) -> None: """Fail2BanDb repository module implements all Fail2BanDbRepository methods.""" protocol_methods = _get_protocol_methods(Fail2BanDbRepository) module_methods = _get_module_methods(fail2ban_db_repo) missing = protocol_methods - module_methods assert not missing, f"Missing methods: {missing}"