208 lines
5.0 KiB
Python
208 lines
5.0 KiB
Python
"""Repository interface protocols for dependency injection.
|
|
|
|
Routers and services can depend on these abstractions instead of concrete
|
|
module implementations, making the backend easier to test and extend.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from collections.abc import Iterable
|
|
from typing import Protocol
|
|
|
|
import aiosqlite
|
|
|
|
from app.models.auth import Session
|
|
from app.models.ban import BanOrigin
|
|
from app.repositories.fail2ban_db_repo import BanIpCount, BanRecord, HistoryRecord, JailBanCount
|
|
|
|
|
|
class SessionRepository(Protocol):
    """Structural interface for storing and looking up login sessions.

    Every coroutine is handed the ``aiosqlite.Connection`` it should use as
    its ``db`` argument; individual sessions are addressed by a ``token``
    string.
    """

    async def create_session(
        self, db: aiosqlite.Connection, token: str, created_at: str, expires_at: str
    ) -> Session:
        """Persist a session for *token* and return it as a ``Session``.

        ``created_at`` and ``expires_at`` are plain strings; presumably
        ISO-8601 timestamps like ``now_iso`` in ``delete_expired_sessions``
        — confirm against the concrete repository.
        """

    async def get_session(self, db: aiosqlite.Connection, token: str) -> Session | None:
        """Look up the session for *token*; ``None`` presumably means no match."""

    async def delete_session(self, db: aiosqlite.Connection, token: str) -> None:
        """Remove the session addressed by *token*."""

    async def delete_expired_sessions(self, db: aiosqlite.Connection, now_iso: str) -> int:
        """Purge sessions that have expired as of *now_iso*.

        Returns an ``int``; presumably the number of sessions removed —
        confirm against the concrete repository.
        """
|
|
|
|
|
|
class BlocklistRepository(Protocol):
    """Protocol for blocklist source persistence operations.

    A source is addressed by an integer ``source_id`` and carries the
    ``name``, ``url`` and ``enabled`` fields that ``create_source`` and
    ``update_source`` accept. Rows come back as plain ``dict[str, object]``
    mappings whose exact keys are defined by the concrete repository.
    """

    async def create_source(
        self,
        db: aiosqlite.Connection,
        name: str,
        url: str,
        *,
        enabled: bool = True,
    ) -> int:
        """Persist a new source and return an ``int``.

        ``enabled`` is keyword-only and defaults to ``True``. The returned
        ``int`` is presumably the new ``source_id`` — confirm against the
        concrete repository.
        """
        ...

    async def get_source(
        self,
        db: aiosqlite.Connection,
        source_id: int,
    ) -> dict[str, object] | None:
        """Return the row for *source_id*, or ``None``.

        ``None`` presumably means no source has that id.
        """
        ...

    async def list_sources(self, db: aiosqlite.Connection) -> list[dict[str, object]]:
        """Return source rows as a list of dicts; presumably all sources, enabled or not."""
        ...

    async def list_enabled_sources(self, db: aiosqlite.Connection) -> list[dict[str, object]]:
        """Return source rows as a list of dicts; presumably only enabled sources."""
        ...

    async def update_source(
        self,
        db: aiosqlite.Connection,
        source_id: int,
        *,
        name: str | None = None,
        url: str | None = None,
        enabled: bool | None = None,
    ) -> bool:
        """Change fields of the source identified by *source_id*.

        ``name``, ``url`` and ``enabled`` are keyword-only and default to
        ``None``, which presumably means "leave this field unchanged".
        Returns a ``bool``; presumably ``True`` when a matching source was
        updated — confirm against the concrete repository.
        """
        ...

    async def delete_source(self, db: aiosqlite.Connection, source_id: int) -> bool:
        """Delete the source identified by *source_id*.

        The ``bool`` presumably reports whether a row was removed.
        """
        ...
|
|
|
|
|
|
class ImportLogRepository(Protocol):
    """Protocol for recording and paging through import log entries.

    An entry records a ``source_url`` (with an optional ``source_id``), the
    counters ``ips_imported`` / ``ips_skipped`` and an optional ``errors``
    string, as accepted by ``add_log``. Entries are returned as plain
    ``dict[str, object]`` rows.
    """

    async def add_log(
        self,
        db: aiosqlite.Connection,
        *,
        source_id: int | None,
        source_url: str,
        ips_imported: int,
        ips_skipped: int,
        errors: str | None,
    ) -> int:
        """Store one import log entry and return an ``int``.

        Every field is keyword-only and required; ``source_id`` and
        ``errors`` may be ``None``. The ``int`` is presumably the new entry's
        id — confirm against the concrete repository.
        """
        ...

    async def list_logs(
        self,
        db: aiosqlite.Connection,
        *,
        source_id: int | None = None,
        page: int = 1,
        page_size: int = 50,
    ) -> tuple[list[dict[str, object]], int]:
        """Return one page of log rows together with an ``int``.

        ``source_id`` (default ``None``) presumably restricts the result to a
        single source. ``page`` defaults to ``1`` (so presumably 1-based) and
        ``page_size`` to ``50``. The ``int`` is presumably the total number
        of matching rows, suitable for ``compute_total_pages``.
        """
        ...

    async def get_last_log(self, db: aiosqlite.Connection) -> dict[str, object] | None:
        """Return a single log row, presumably the most recent one.

        ``None`` presumably means no entries exist yet.
        """
        ...

    async def compute_total_pages(self, total: int, page_size: int) -> int:
        """Return the page count for *total* rows split into pages of *page_size*.

        NOTE(review): unlike the other methods this takes no ``db`` and looks
        like pure arithmetic, yet it is declared ``async``. Confirm that the
        concrete implementation is also a coroutine. A plain ``def`` would not
        match this protocol under a static type checker, and callers must
        ``await`` it.
        """
        ...
|
|
|
|
|
|
class GeoCacheRepository(Protocol):
    """Protocol for the persistent IP geolocation cache.

    A full entry carries an ``ip`` plus nullable ``country_code``,
    ``country_name``, ``asn`` and ``org`` fields (see ``upsert_entry``).
    "Neg" entries (``upsert_neg_entry``) store only an ``ip``; presumably they
    record IPs whose lookup produced no result — confirm against the
    concrete repository.
    """

    async def load_all(self, db: aiosqlite.Connection) -> list[dict[str, object]]:
        """Return cache rows as a list of dicts; keys are defined by the concrete repository."""
        ...

    async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]:
        """Return IP strings considered unresolved; presumably IPs still awaiting a lookup."""
        ...

    async def count_unresolved(self, db: aiosqlite.Connection) -> int:
        """Return an ``int``; presumably the number of IPs ``get_unresolved_ips`` would return."""
        ...

    async def upsert_entry(
        self,
        db: aiosqlite.Connection,
        ip: str,
        country_code: str | None,
        country_name: str | None,
        asn: str | None,
        org: str | None,
    ) -> None:
        """Insert or update the cache entry for *ip*.

        Every field other than ``ip`` may be ``None``.
        """
        ...

    async def upsert_neg_entry(self, db: aiosqlite.Connection, ip: str) -> None:
        """Insert or update a negative ("neg") entry for *ip*."""
        ...

    async def bulk_upsert_entries(
        self,
        db: aiosqlite.Connection,
        rows: Iterable[tuple[str, str | None, str | None, str | None, str | None]],
    ) -> int:
        """Upsert many entries in one call.

        Each row is a 5-tuple: a ``str`` followed by four ``str | None``
        values. The order is presumably the same as ``upsert_entry``'s
        parameters (``ip``, ``country_code``, ``country_name``, ``asn``,
        ``org``) — confirm. Returns an ``int``, presumably the number of rows
        written.
        """
        ...

    async def bulk_upsert_neg_entries(self, db: aiosqlite.Connection, ips: list[str]) -> int:
        """Upsert negative entries for every IP in *ips*.

        Returns an ``int``, presumably the number of rows written.

        NOTE(review): this takes a ``list[str]`` while ``bulk_upsert_entries``
        accepts any ``Iterable``. Widening it here would require every
        implementation to accept an arbitrary iterable, so only change it
        together with the concrete repository.
        """
        ...
|
|
|
|
|
|
class Fail2BanDbRepository(Protocol):
    """Protocol for querying the fail2ban database.

    Unlike the other protocols in this module, every method receives the
    database as a path (``db_path``) rather than an open
    ``aiosqlite.Connection``.

    NOTE(review): ``BanIpCount`` is imported at module level but no method
    below references it. If the concrete repository returns it from some
    query, that method may be missing from this protocol — confirm.
    """

    async def check_db_nonempty(self, db_path: str) -> bool:
        """Return a ``bool``; presumably whether the database at *db_path* holds any ban data."""
        ...

    async def get_currently_banned(
        self,
        db_path: str,
        since: int,
        origin: BanOrigin | None = None,
        *,
        ip_filter: list[str] | None = None,
        limit: int | None = None,
        offset: int | None = None,
    ) -> tuple[list[BanRecord], int]:
        """Return matching ``BanRecord`` rows plus an ``int``.

        ``since`` is an ``int`` (presumably a Unix-timestamp lower bound).
        ``origin`` (default ``None``) optionally narrows by ``BanOrigin``.
        The keyword-only ``ip_filter`` (a list of IPs), ``limit`` and
        ``offset`` all default to ``None``. The ``int`` is presumably the
        total match count before ``limit``/``offset`` are applied — confirm.
        """
        ...

    async def get_ban_counts_by_bucket(
        self,
        db_path: str,
        since: int,
        bucket_secs: int,
        num_buckets: int,
        origin: BanOrigin | None = None,
    ) -> list[int]:
        """Return a list of ban counts, one per time bucket.

        Presumably the list holds ``num_buckets`` consecutive buckets of
        ``bucket_secs`` seconds each, starting at ``since``. Confirm against
        the concrete repository.
        """
        ...

    async def get_bans_by_jail(
        self,
        db_path: str,
        since: int,
        origin: BanOrigin | None = None,
    ) -> tuple[int, list[JailBanCount]]:
        """Return an ``int`` plus per-jail ``JailBanCount`` entries.

        The ``int`` is presumably the overall ban total across all jails.
        """
        ...

    async def get_bans_table_summary(self, db_path: str) -> tuple[int, int | None, int | None]:
        """Return a 3-tuple: an ``int`` followed by two optional ``int`` values.

        NOTE(review): presumably a row count plus earliest/latest ban
        timestamps, with ``None`` when the table is empty. Confirm against
        the concrete repository.
        """
        ...

    async def get_history_page(
        self,
        db_path: str,
        since: int | None = None,
        jail: str | None = None,
        ip_filter: str | None = None,
        origin: BanOrigin | None = None,
        page: int = 1,
        page_size: int = 100,
    ) -> tuple[list[HistoryRecord], int]:
        """Return one page of ``HistoryRecord`` rows plus an ``int``.

        The ``int`` is presumably the total number of matching rows. The
        filters ``since``, ``jail``, ``ip_filter`` and ``origin`` default to
        ``None``; ``page`` defaults to ``1`` and ``page_size`` to ``100``.

        NOTE(review): ``ip_filter`` is a single ``str`` here but a
        ``list[str]`` in ``get_currently_banned``. Keep both in step with the
        concrete repository.
        """
        ...

    async def get_history_for_ip(self, db_path: str, ip: str) -> list[HistoryRecord]:
        """Return the ``HistoryRecord`` rows associated with *ip*."""
        ...
|