Implement transactional setup with explicit state machine and crash-safety
to prevent partial commits from leaving inconsistent state.
## Changes
### Core Implementation
1. **settings_repo.py**: Add atomic batch settings write
- New set_settings_batch() method: writes multiple settings in single
transaction (BEGIN IMMEDIATE ... COMMIT). Either all settings persist
or none do, preventing partial state if crash occurs mid-batch.
2. **setup_service.py**: Refactor run_setup() with transactional phases
- Phase 0: Compute password hash early (before any DB writes) to ensure
idempotency. Same hash is used throughout retries, preventing divergent
hashes from bcrypt's random salt.
- Phase 1 (Bootstrap DB transaction): Set setup_state=in_progress and
database_path, then commit. First checkpoint for crash detection.
- Phase 2 (Filesystem): Initialize runtime database (idempotent)
- Phase 3 (Runtime DB transaction): Batch-write all settings atomically
- Phase 4 (Bootstrap DB transaction): Set setup_state=complete and
setup_completed=1. Final commit point.
3. **protocols.py**: Add set_settings_batch to SettingsRepository protocol
### Testing
- Added 6 new transactionality tests covering:
- State machine transitions (None → in_progress → complete)
- Password hash idempotency across retries
- Atomic batch writes (all-or-nothing persistence)
- Bootstrap DB state tracking
- Database path propagation to both DBs
- Recovery on partial failure
- All 18 tests pass (12 existing + 6 new)
### Documentation
- Updated Docs/Architekture.md with new section 6:
- Setup state machine with state transitions
- Transaction boundary documentation
- Password hash idempotency rationale
- Backward compatibility notes
## Design Decisions
### Why This Approach
- Current code was already idempotent via INSERT OR REPLACE, but the
password hash was not idempotent, which created a silent inconsistency risk
- Simpler than multi-state machine: 2 states sufficient for detection
- Maintains backward compatibility (setup_completed key still written)
- Explicit transactions make crash-safety obvious to future maintainers
### Crash Scenarios Now Handled
1. Crash after Phase 1 → detected by setup_state=in_progress on retry
2. Crash after Phase 2 → runtime DB may be partial, safe to retry
3. Crash during Phase 3 → runtime DB rolls back the uncommitted batch on next connection, safe to retry
4. Crash after Phase 4 → setup_completed detected, skipped
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
318 lines
8.1 KiB
Python
318 lines
8.1 KiB
Python
"""Repository interface protocols for dependency injection.
|
|
|
|
Routers and services can depend on these abstractions instead of concrete
|
|
module implementations, making the backend easier to test and extend.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Sequence
|
|
from typing import Any, Protocol
|
|
|
|
import aiosqlite
|
|
|
|
from app.models.auth import Session
|
|
from app.models.ban import BanOrigin
|
|
from app.repositories.fail2ban_db_repo import BanIpCount, BanRecord, HistoryRecord, JailBanCount
|
|
from app.repositories.geo_cache_repo import GeoCacheRow
|
|
from app.repositories.import_log_repo import ImportLogRow
|
|
|
|
|
|
class SessionRepository(Protocol):
    """Protocol for session persistence operations.

    Every method is a coroutine that receives the ``aiosqlite`` connection to
    operate on. Being a ``typing.Protocol``, implementations are matched
    structurally rather than by inheritance.
    """

    async def create_session(
        self,
        db: aiosqlite.Connection,
        token: str,
        created_at: str,
        expires_at: str,
    ) -> Session:
        """Create a session for ``token`` and return it as a ``Session``.

        NOTE(review): ``created_at`` and ``expires_at`` are typed as plain
        strings; presumably ISO timestamps -- confirm with the implementation.
        """
        ...

    async def get_session(
        self,
        db: aiosqlite.Connection,
        token: str,
    ) -> Session | None:
        """Look up the session for ``token``.

        Returns a ``Session`` or ``None`` (presumably when nothing matches).
        """
        ...

    async def delete_session(
        self,
        db: aiosqlite.Connection,
        token: str,
    ) -> None:
        """Delete the session identified by ``token``; nothing is returned."""
        ...

    async def delete_expired_sessions(
        self,
        db: aiosqlite.Connection,
        now_iso: str,
    ) -> int:
        """Delete sessions considered expired relative to ``now_iso``.

        NOTE(review): the ``int`` result is presumably the number of deleted
        sessions -- confirm with the implementation.
        """
        ...
|
|
|
|
|
|
class SettingsRepository(Protocol):
    """Protocol for application settings persistence operations.

    Settings are string keys mapped to string values. Every method is a
    coroutine that receives the ``aiosqlite`` connection to operate on.
    """

    async def get_setting(self, db: aiosqlite.Connection, key: str) -> str | None:
        """Return the value for ``key``, or ``None`` (presumably when unset)."""
        ...

    async def set_setting(self, db: aiosqlite.Connection, key: str, value: str) -> None:
        """Store ``value`` under ``key``; nothing is returned."""
        ...

    async def delete_setting(self, db: aiosqlite.Connection, key: str) -> None:
        """Remove the setting stored under ``key``; nothing is returned."""
        ...

    async def get_all_settings(self, db: aiosqlite.Connection) -> dict[str, str]:
        """Return the settings as a ``key -> value`` dictionary."""
        ...

    async def set_settings_batch(self, db: aiosqlite.Connection, settings: dict[str, str]) -> None:
        """Write every entry of ``settings`` atomically.

        Implementations are expected to perform the writes inside a single
        transaction so that either all settings persist or none do; a crash
        mid-batch must not leave a partial set behind.
        """
        ...
|
|
|
|
|
|
class BlocklistRepository(Protocol):
    """Protocol for blocklist source persistence operations.

    Covers creating, reading, listing, updating and deleting source records.
    Every method is a coroutine that receives the ``aiosqlite`` connection to
    operate on.
    """

    async def create_source(
        self,
        db: aiosqlite.Connection,
        name: str,
        url: str,
        *,
        enabled: bool = True,
    ) -> int:
        """Create a source with the given ``name`` and ``url``.

        ``enabled`` is keyword-only and defaults to ``True``.

        NOTE(review): the ``int`` result is presumably the new source's id --
        confirm with the implementation.
        """
        ...

    async def get_source(
        self,
        db: aiosqlite.Connection,
        source_id: int,
    ) -> dict[str, Any] | None:
        """Return the source ``source_id`` as a dict, or ``None``.

        NOTE(review): the dict keys are not defined by this protocol.
        """
        ...

    async def list_sources(self, db: aiosqlite.Connection) -> list[dict[str, Any]]:
        """Return the sources as a list of dicts."""
        ...

    async def list_enabled_sources(self, db: aiosqlite.Connection) -> list[dict[str, Any]]:
        """Return sources as a list of dicts (by name: only enabled ones)."""
        ...

    async def update_source(
        self,
        db: aiosqlite.Connection,
        source_id: int,
        *,
        name: str | None = None,
        url: str | None = None,
        enabled: bool | None = None,
    ) -> bool:
        """Update fields of the source ``source_id``.

        ``name``, ``url`` and ``enabled`` are keyword-only and default to
        ``None``.

        NOTE(review): ``None`` presumably means "leave unchanged" and the
        ``bool`` result presumably reports whether a row was updated --
        confirm with the implementation.
        """
        ...

    async def delete_source(self, db: aiosqlite.Connection, source_id: int) -> bool:
        """Delete the source ``source_id``.

        NOTE(review): the ``bool`` result presumably reports whether a row
        was deleted -- confirm with the implementation.
        """
        ...
|
|
|
|
|
|
class ImportLogRepository(Protocol):
    """Protocol for import log persistence operations.

    The database methods are coroutines that receive the ``aiosqlite``
    connection to operate on; ``compute_total_pages`` is a plain synchronous
    helper.
    """

    async def add_log(
        self,
        db: aiosqlite.Connection,
        *,
        source_id: int | None,
        source_url: str,
        ips_imported: int,
        ips_skipped: int,
        errors: str | None,
    ) -> int:
        """Record one import log entry.

        All fields are keyword-only and required; ``source_id`` and
        ``errors`` may be ``None``.

        NOTE(review): the ``int`` result is presumably the new log row's id
        -- confirm with the implementation.
        """
        ...

    async def list_logs(
        self,
        db: aiosqlite.Connection,
        *,
        source_id: int | None = None,
        page: int = 1,
        page_size: int = 50,
    ) -> tuple[list[ImportLogRow], int]:
        """Return one page of log rows together with an ``int``.

        ``source_id`` (default ``None``), ``page`` (default ``1``) and
        ``page_size`` (default ``50``) are keyword-only.

        NOTE(review): the ``int`` is presumably the total number of matching
        rows, and ``source_id=None`` presumably means "all sources" --
        confirm with the implementation.
        """
        ...

    async def get_last_log(self, db: aiosqlite.Connection) -> ImportLogRow | None:
        """Return a single ``ImportLogRow`` (by name: the last one) or ``None``."""
        ...

    def compute_total_pages(self, total: int, page_size: int) -> int:
        """Return the page count for ``total`` items at ``page_size`` per page.

        Unlike the other members this is synchronous and takes no connection.

        NOTE(review): rounding and zero handling are implementation-defined.
        """
        ...
|
|
|
|
|
|
class GeoCacheRepository(Protocol):
    """Protocol for geo cache persistence operations.

    Every method is a coroutine that receives the ``aiosqlite`` connection to
    operate on. Several operations come in pairs: a plain variant and an
    ``*_and_commit`` variant with the same parameters.

    NOTE(review): by their names, the ``*_and_commit`` variants also commit
    while the plain ones leave committing to the caller, and "neg" entries
    are presumably negative-cache records for IPs without geo data --
    confirm both with the implementation.
    """

    async def load_all(self, db: aiosqlite.Connection) -> list[GeoCacheRow]:
        """Return the cached rows as a list of ``GeoCacheRow``."""
        ...

    async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]:
        """Return IP strings (by name: those that are still unresolved)."""
        ...

    async def count_unresolved(self, db: aiosqlite.Connection) -> int:
        """Return an ``int`` count (by name: of unresolved entries)."""
        ...

    async def upsert_entry(
        self,
        db: aiosqlite.Connection,
        ip: str,
        country_code: str | None,
        country_name: str | None,
        asn: str | None,
        org: str | None,
    ) -> None:
        """Insert or update the entry for ``ip``.

        Every geo field may be ``None``.
        """
        ...

    async def upsert_entry_and_commit(
        self,
        db: aiosqlite.Connection,
        ip: str,
        country_code: str | None,
        country_name: str | None,
        asn: str | None,
        org: str | None,
    ) -> None:
        """Same parameters as ``upsert_entry``; by name, also commits."""
        ...

    async def upsert_neg_entry(self, db: aiosqlite.Connection, ip: str) -> None:
        """Insert or update a "neg" entry for ``ip``."""
        ...

    async def upsert_neg_entry_and_commit(self, db: aiosqlite.Connection, ip: str) -> None:
        """Same parameters as ``upsert_neg_entry``; by name, also commits."""
        ...

    async def bulk_upsert_entries(
        self,
        db: aiosqlite.Connection,
        rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
    ) -> int:
        """Insert or update many entries from ``rows``.

        Each row is a 5-tuple of one ``str`` followed by four optional
        strings.

        NOTE(review): the tuple order presumably mirrors ``upsert_entry``
        (ip, country_code, country_name, asn, org) and the ``int`` result is
        presumably the number of rows written -- confirm.
        """
        ...

    async def bulk_upsert_entries_and_commit(
        self,
        db: aiosqlite.Connection,
        rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
    ) -> int:
        """Same parameters as ``bulk_upsert_entries``; by name, also commits."""
        ...

    async def bulk_upsert_neg_entries(self, db: aiosqlite.Connection, ips: list[str]) -> int:
        """Insert or update "neg" entries for every IP in ``ips``.

        NOTE(review): the ``int`` result is presumably the number of rows
        written -- confirm.
        """
        ...

    async def bulk_upsert_neg_entries_and_commit(
        self, db: aiosqlite.Connection, ips: list[str]
    ) -> int:
        """Same parameters as ``bulk_upsert_neg_entries``; by name, also commits."""
        ...

    async def bulk_upsert_entries_and_neg_entries_and_commit(
        self,
        db: aiosqlite.Connection,
        rows: Sequence[tuple[str, str | None, str | None, str | None, str | None]],
        ips: list[str],
    ) -> tuple[int, int]:
        """Write both ``rows`` and "neg" entries for ``ips`` in one call.

        NOTE(review): the two ``int`` results are presumably the counts for
        ``rows`` and ``ips`` respectively -- confirm.
        """
        ...

    async def delete_stale_entries(self, db: aiosqlite.Connection, cutoff_iso: str) -> int:
        """Delete entries considered stale relative to ``cutoff_iso``.

        NOTE(review): the ``int`` result is presumably the number of deleted
        rows -- confirm.
        """
        ...
|
|
|
|
|
|
class HistoryArchiveRepository(Protocol):
    """Protocol for archived ban history persistence operations.

    Every method is a coroutine that receives the ``aiosqlite`` connection to
    operate on.
    """

    async def archive_ban_event(
        self,
        db: aiosqlite.Connection,
        jail: str,
        ip: str,
        timeofban: int,
        bancount: int,
        data: str,
        action: str = "ban",
    ) -> bool:
        """Archive one event for ``ip`` in ``jail``.

        ``action`` defaults to ``"ban"``.

        NOTE(review): ``timeofban`` is presumably a Unix timestamp, ``data``
        presumably a serialized payload, and the ``bool`` result presumably
        reports whether a new row was stored -- confirm.
        """
        ...

    async def get_max_timeofban(self, db: aiosqlite.Connection) -> int | None:
        """Return the largest archived ``timeofban``, or ``None``.

        NOTE(review): ``None`` is presumably returned for an empty archive.
        """
        ...

    async def get_archived_history(
        self,
        db: aiosqlite.Connection,
        since: int | None = None,
        jail: str | None = None,
        ip_filter: str | list[str] | None = None,
        origin: BanOrigin | None = None,
        action: str | None = None,
        page: int = 1,
        page_size: int = 100,
    ) -> tuple[list[dict[str, Any]], int]:
        """Return one page of archived history plus an ``int``.

        All filters default to ``None``; ``page`` defaults to ``1`` and
        ``page_size`` to ``100``. ``ip_filter`` accepts a single string or a
        list of strings.

        NOTE(review): a ``None`` filter presumably means "no restriction" and
        the ``int`` is presumably the total match count -- confirm.
        """
        ...

    async def get_all_archived_history(
        self,
        db: aiosqlite.Connection,
        since: int | None = None,
        jail: str | None = None,
        ip_filter: str | list[str] | None = None,
        origin: BanOrigin | None = None,
        action: str | None = None,
    ) -> list[dict[str, Any]]:
        """Return archived history as a list of dicts, without paging.

        Accepts the same filters as ``get_archived_history`` but has no
        ``page``/``page_size`` parameters.
        """
        ...

    async def purge_archived_history(self, db: aiosqlite.Connection, age_seconds: int) -> int:
        """Purge archived history according to ``age_seconds``.

        NOTE(review): presumably removes rows older than ``age_seconds`` and
        returns how many were removed -- confirm.
        """
        ...
|
|
|
|
|
|
class Fail2BanDbRepository(Protocol):
    """Protocol for read access to a fail2ban database.

    Unlike the other repository protocols in this module, these coroutines
    take a ``db_path`` string instead of an open ``aiosqlite`` connection.

    NOTE(review): ``since`` values are presumably Unix timestamps -- confirm
    with the implementation.
    """

    async def check_db_nonempty(self, db_path: str) -> bool:
        """Return a ``bool`` (by name: whether the database has content)."""
        ...

    async def get_currently_banned(
        self,
        db_path: str,
        since: int,
        origin: BanOrigin | None = None,
        *,
        ip_filter: list[str] | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> tuple[list[BanRecord], int]:
        """Return ``BanRecord`` rows plus an ``int``.

        ``ip_filter``, ``limit`` and ``offset`` are keyword-only and default
        to ``None``.

        NOTE(review): the ``int`` is presumably the total match count before
        ``limit``/``offset`` are applied -- confirm.
        """
        ...

    async def get_ban_counts_by_bucket(
        self,
        db_path: str,
        since: int,
        bucket_secs: int,
        num_buckets: int,
        origin: BanOrigin | None = None,
    ) -> list[int]:
        """Return a list of ``int`` counts for time buckets.

        NOTE(review): presumably ``num_buckets`` counts, each covering
        ``bucket_secs`` seconds starting at ``since`` -- confirm.
        """
        ...

    async def get_ban_event_counts(
        self,
        db_path: str,
        since: int,
        origin: BanOrigin | None = None,
    ) -> list[BanIpCount]:
        """Return ban event counts as a list of ``BanIpCount``."""
        ...

    async def get_bans_by_jail(
        self,
        db_path: str,
        since: int,
        origin: BanOrigin | None = None,
    ) -> tuple[int, list[JailBanCount]]:
        """Return an ``int`` plus a list of ``JailBanCount``.

        NOTE(review): the ``int`` is presumably the overall ban total --
        confirm.
        """
        ...

    async def get_bans_table_summary(self, db_path: str) -> tuple[int, int | None, int | None]:
        """Return a 3-tuple summarising the bans table.

        NOTE(review): presumably (row count, oldest timestamp, newest
        timestamp), the latter two ``None`` when the table is empty --
        confirm.
        """
        ...

    async def get_history_page(
        self,
        db_path: str,
        since: int | None = None,
        jail: str | None = None,
        ip_filter: str | None = None,
        origin: BanOrigin | None = None,
        page: int = 1,
        page_size: int = 100,
    ) -> tuple[list[HistoryRecord], int]:
        """Return one page of ``HistoryRecord`` rows plus an ``int``.

        All filters default to ``None``; ``page`` defaults to ``1`` and
        ``page_size`` to ``100``. ``ip_filter`` is a single string here.

        NOTE(review): the ``int`` is presumably the total match count --
        confirm.
        """
        ...

    async def get_history_for_ip(self, db_path: str, ip: str) -> list[HistoryRecord]:
        """Return the ``HistoryRecord`` rows for ``ip``."""
        ...
|