refactoring-backend #3

Merged
lukas.pupkalipinski merged 403 commits from refactoring-backend into main 2026-05-20 20:23:46 +02:00
3 changed files with 183 additions and 1 deletions
Showing only changes of commit a273b96563 - Show all commits

View File

@@ -1710,6 +1710,26 @@ async def get_session_repo() -> SessionRepository:
**How the validation works (CI check):**
- Before each deployment, run `mypy --strict` to ensure all dependency providers return values compatible with their Protocol types.
- The `cast()` calls in `dependencies.py` are a documented signal that structural compatibility is being verified externally, not via explicit class inheritance.
- Automated tests in `backend/tests/test_repositories/test_protocol_compliance.py` verify that each repository module implements all protocol methods, preventing silent protocol drift.
#### 13.7.1.1 Repository Protocol Coverage Checklist
All public repository functions must be defined in a corresponding Protocol. To add a new repository:
1. **Create the repository module** — `backend/app/repositories/my_repo.py` with async functions.
2. **Define the Protocol** — Add a `MyRepository(Protocol)` class in `backend/app/repositories/protocols.py` with methods matching every public function signature.
3. **Add imports** — If the Protocol uses custom return types, import them in `protocols.py`.
4. **Run compliance tests** — Execute `pytest backend/tests/test_repositories/test_protocol_compliance.py` to verify coverage.
5. **Verify type safety** — Run `mypy --strict backend/app/repositories/protocols.py` to ensure all types are correct.
**Current repository protocol coverage** (all 7 repositories fully covered):
- `SessionRepository` — 4 methods
- `SettingsRepository` — 4 methods
- `BlocklistRepository` — 6 methods
- `ImportLogRepository` — 4 methods
- `GeoCacheRepository` — 13 methods
- `HistoryArchiveRepository` — 5 methods
- `Fail2BanDbRepository` — 8 methods
#### 13.7.2 Session Token Hashing — One-Way Protection Against Database Exposure

View File

@@ -13,7 +13,7 @@ import aiosqlite
from app.models.auth import Session
from app.models.ban import BanOrigin
from app.repositories.fail2ban_db_repo import BanRecord, HistoryRecord, JailBanCount
from app.repositories.fail2ban_db_repo import BanIpCount, BanRecord, HistoryRecord, JailBanCount
from app.repositories.geo_cache_repo import GeoCacheRow
from app.repositories.import_log_repo import ImportLogRow
@@ -203,6 +203,9 @@ class GeoCacheRepository(Protocol):
) -> tuple[int, int]:
...
async def delete_stale_entries(self, db: aiosqlite.Connection, cutoff_iso: str) -> int:
...
class HistoryArchiveRepository(Protocol):
"""Protocol for archived ban history persistence operations."""
@@ -246,6 +249,9 @@ class HistoryArchiveRepository(Protocol):
) -> list[dict[str, Any]]:
...
async def purge_archived_history(self, db: aiosqlite.Connection, age_seconds: int) -> int:
...
class Fail2BanDbRepository(Protocol):
async def check_db_nonempty(self, db_path: str) -> bool:
@@ -273,6 +279,14 @@ class Fail2BanDbRepository(Protocol):
) -> list[int]:
...
async def get_ban_event_counts(
self,
db_path: str,
since: int,
origin: BanOrigin | None = None,
) -> list[BanIpCount]:
...
async def get_bans_by_jail(
self,
db_path: str,

View File

@@ -0,0 +1,148 @@
"""Tests validating repository protocol compliance.
These tests ensure that each repository module exports all methods defined
in its corresponding Protocol class in protocols.py, with matching signatures.
This validates structural typing compatibility and prevents silent breakage
when protocol methods change.
"""
from __future__ import annotations
import inspect
from typing import Protocol, get_type_hints
import pytest
from app.repositories import (
blocklist_repo,
fail2ban_db_repo,
geo_cache_repo,
history_archive_repo,
import_log_repo,
session_repo,
settings_repo,
)
from app.repositories.protocols import (
BlocklistRepository,
Fail2BanDbRepository,
GeoCacheRepository,
HistoryArchiveRepository,
ImportLogRepository,
SessionRepository,
SettingsRepository,
)
def _get_protocol_methods(protocol_class: type[Protocol]) -> set[str]:
"""Extract public async/sync method names from a Protocol class."""
methods: set[str] = set()
for name, member in inspect.getmembers(protocol_class):
# Skip private/magic methods and non-callables
if name.startswith("_") or not callable(member):
continue
# Include both async and sync methods
if inspect.iscoroutinefunction(member) or (
hasattr(member, "__func__") and callable(member)
):
methods.add(name)
return methods
def _get_module_methods(module: object) -> set[str]:
"""Extract public async/sync function names from a module."""
methods: set[str] = set()
for name, member in inspect.getmembers(module):
# Skip private functions, classes, and non-callables
if name.startswith("_") or not callable(member):
continue
if inspect.isclass(member):
continue
# Include both async and sync functions
if inspect.iscoroutinefunction(member) or callable(member):
methods.add(name)
return methods
class TestSessionRepositoryCompliance:
"""Validate SessionRepository protocol compliance."""
def test_implements_all_protocol_methods(self) -> None:
"""Session repository module implements all SessionRepository methods."""
protocol_methods = _get_protocol_methods(SessionRepository)
module_methods = _get_module_methods(session_repo)
missing = protocol_methods - module_methods
assert not missing, f"Missing methods: {missing}"
def test_no_unexpected_public_methods(self) -> None:
"""Session repository has no unexpected public methods."""
protocol_methods = _get_protocol_methods(SessionRepository)
module_methods = _get_module_methods(session_repo)
extra = module_methods - protocol_methods
# Allow _hash_token and other private functions
assert not extra, f"Unexpected public methods: {extra}"
class TestSettingsRepositoryCompliance:
"""Validate SettingsRepository protocol compliance."""
def test_implements_all_protocol_methods(self) -> None:
"""Settings repository module implements all SettingsRepository methods."""
protocol_methods = _get_protocol_methods(SettingsRepository)
module_methods = _get_module_methods(settings_repo)
missing = protocol_methods - module_methods
assert not missing, f"Missing methods: {missing}"
class TestBlocklistRepositoryCompliance:
"""Validate BlocklistRepository protocol compliance."""
def test_implements_all_protocol_methods(self) -> None:
"""Blocklist repository module implements all BlocklistRepository methods."""
protocol_methods = _get_protocol_methods(BlocklistRepository)
module_methods = _get_module_methods(blocklist_repo)
missing = protocol_methods - module_methods
assert not missing, f"Missing methods: {missing}"
class TestImportLogRepositoryCompliance:
"""Validate ImportLogRepository protocol compliance."""
def test_implements_all_protocol_methods(self) -> None:
"""ImportLog repository module implements all ImportLogRepository methods."""
protocol_methods = _get_protocol_methods(ImportLogRepository)
module_methods = _get_module_methods(import_log_repo)
missing = protocol_methods - module_methods
assert not missing, f"Missing methods: {missing}"
class TestGeoCacheRepositoryCompliance:
"""Validate GeoCacheRepository protocol compliance."""
def test_implements_all_protocol_methods(self) -> None:
"""GeoCache repository module implements all GeoCacheRepository methods."""
protocol_methods = _get_protocol_methods(GeoCacheRepository)
module_methods = _get_module_methods(geo_cache_repo)
missing = protocol_methods - module_methods
assert not missing, f"Missing methods: {missing}"
class TestHistoryArchiveRepositoryCompliance:
"""Validate HistoryArchiveRepository protocol compliance."""
def test_implements_all_protocol_methods(self) -> None:
"""HistoryArchive repository module implements all HistoryArchiveRepository methods."""
protocol_methods = _get_protocol_methods(HistoryArchiveRepository)
module_methods = _get_module_methods(history_archive_repo)
missing = protocol_methods - module_methods
assert not missing, f"Missing methods: {missing}"
class TestFail2BanDbRepositoryCompliance:
"""Validate Fail2BanDbRepository protocol compliance."""
def test_implements_all_protocol_methods(self) -> None:
"""Fail2BanDb repository module implements all Fail2BanDbRepository methods."""
protocol_methods = _get_protocol_methods(Fail2BanDbRepository)
module_methods = _get_module_methods(fail2ban_db_repo)
missing = protocol_methods - module_methods
assert not missing, f"Missing methods: {missing}"