Task 25: extend service/repository protocol coverage and wire DI aliases

This commit is contained in:
2026-04-14 12:32:42 +02:00
parent b1fba79a2e
commit 09c764cebc
4 changed files with 554 additions and 741 deletions

View File

@@ -22,8 +22,23 @@ from app.models.auth import Session
from app.models.config import PendingRecovery
from app.models.geo import GeoBatchLookup
from app.models.server import ServerStatus
from app.repositories.protocols import SessionRepository
from app.services.protocols import AuthService, JailService
from app.repositories.protocols import (
BlocklistRepository,
Fail2BanDbRepository,
GeoCacheRepository,
ImportLogRepository,
SessionRepository,
)
from app.services.protocols import (
AuthService,
BlocklistService,
ConfigService,
GeoService,
HealthService,
HistoryService,
JailService,
ServerService,
)
from app.utils.constants import SESSION_COOKIE_NAME
from app.utils.runtime_state import RuntimeState
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache, SessionCache
@@ -239,6 +254,48 @@ async def get_jail_service() -> JailService:
return cast("JailService", jail_service)
async def get_blocklist_service() -> BlocklistService:
"""Provide the concrete blocklist service implementation."""
from app.services import blocklist_service # noqa: PLC0415
return cast("BlocklistService", blocklist_service)
async def get_config_service() -> ConfigService:
"""Provide the concrete configuration service implementation."""
from app.services import config_service # noqa: PLC0415
return cast("ConfigService", config_service)
async def get_history_service() -> HistoryService:
"""Provide the concrete history service implementation."""
from app.services import history_service # noqa: PLC0415
return cast("HistoryService", history_service)
async def get_geo_service() -> GeoService:
"""Provide the concrete geo service implementation."""
from app.services import geo_service # noqa: PLC0415
return cast("GeoService", geo_service)
async def get_health_service() -> HealthService:
"""Provide the concrete health service implementation."""
from app.services import health_service # noqa: PLC0415
return cast("HealthService", health_service)
async def get_server_service() -> ServerService:
"""Provide the concrete server service implementation."""
from app.services import server_service # noqa: PLC0415
return cast("ServerService", server_service)
async def get_session_repo() -> SessionRepository:
"""Provide the concrete session repository implementation."""
from app.repositories import session_repo # noqa: PLC0415
@@ -246,6 +303,34 @@ async def get_session_repo() -> SessionRepository:
return session_repo
async def get_blocklist_repo() -> BlocklistRepository:
"""Provide the concrete blocklist repository implementation."""
from app.repositories import blocklist_repo # noqa: PLC0415
return cast("BlocklistRepository", blocklist_repo)
async def get_import_log_repo() -> ImportLogRepository:
"""Provide the concrete import log repository implementation."""
from app.repositories import import_log_repo # noqa: PLC0415
return cast("ImportLogRepository", import_log_repo)
async def get_geo_cache_repo() -> GeoCacheRepository:
"""Provide the concrete geo cache repository implementation."""
from app.repositories import geo_cache_repo # noqa: PLC0415
return cast("GeoCacheRepository", geo_cache_repo)
async def get_fail2ban_db_repo() -> Fail2BanDbRepository:
"""Provide the concrete fail2ban DB repository implementation."""
from app.repositories import fail2ban_db_repo # noqa: PLC0415
return cast("Fail2BanDbRepository", fail2ban_db_repo)
async def get_app_state(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ApplicationContext:
"""Provide the application state object for the current request."""
return app_context
@@ -351,7 +436,17 @@ PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recov
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]
JailServiceDep = Annotated[JailService, Depends(get_jail_service)]
BlocklistServiceDep = Annotated[BlocklistService, Depends(get_blocklist_service)]
ConfigServiceDep = Annotated[ConfigService, Depends(get_config_service)]
HistoryServiceDep = Annotated[HistoryService, Depends(get_history_service)]
GeoServiceDep = Annotated[GeoService, Depends(get_geo_service)]
HealthServiceDep = Annotated[HealthService, Depends(get_health_service)]
ServerServiceDep = Annotated[ServerService, Depends(get_server_service)]
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
AppStateDep = Annotated[AppState, Depends(get_app_state)]
BlocklistRepositoryDep = Annotated[BlocklistRepository, Depends(get_blocklist_repo)]
ImportLogRepositoryDep = Annotated[ImportLogRepository, Depends(get_import_log_repo)]
GeoCacheRepositoryDep = Annotated[GeoCacheRepository, Depends(get_geo_cache_repo)]
Fail2BanDbRepositoryDep = Annotated[Fail2BanDbRepository, Depends(get_fail2ban_db_repo)]
AppStateDep = Annotated[ApplicationContext, Depends(get_app_state)]
AppDep = Annotated[FastAPI, Depends(get_app)]
AuthDep = Annotated[Session, Depends(require_auth)]

View File

@@ -6,11 +6,14 @@ module implementations, making the backend easier to test and extend.
from __future__ import annotations
from collections.abc import Iterable
from typing import Protocol
import aiosqlite
from app.models.auth import Session
from app.models.ban import BanOrigin
from app.repositories.fail2ban_db_repo import BanIpCount, BanRecord, HistoryRecord, JailBanCount
class SessionRepository(Protocol):
@@ -45,3 +48,160 @@ class SessionRepository(Protocol):
now_iso: str,
) -> int:
...
class BlocklistRepository(Protocol):
async def create_source(
self,
db: aiosqlite.Connection,
name: str,
url: str,
*,
enabled: bool = True,
) -> int:
...
async def get_source(
self,
db: aiosqlite.Connection,
source_id: int,
) -> dict[str, object] | None:
...
async def list_sources(self, db: aiosqlite.Connection) -> list[dict[str, object]]:
...
async def list_enabled_sources(self, db: aiosqlite.Connection) -> list[dict[str, object]]:
...
async def update_source(
self,
db: aiosqlite.Connection,
source_id: int,
*,
name: str | None = None,
url: str | None = None,
enabled: bool | None = None,
) -> bool:
...
async def delete_source(self, db: aiosqlite.Connection, source_id: int) -> bool:
...
class ImportLogRepository(Protocol):
async def add_log(
self,
db: aiosqlite.Connection,
*,
source_id: int | None,
source_url: str,
ips_imported: int,
ips_skipped: int,
errors: str | None,
) -> int:
...
async def list_logs(
self,
db: aiosqlite.Connection,
*,
source_id: int | None = None,
page: int = 1,
page_size: int = 50,
) -> tuple[list[dict[str, object]], int]:
...
async def get_last_log(self, db: aiosqlite.Connection) -> dict[str, object] | None:
...
async def compute_total_pages(self, total: int, page_size: int) -> int:
...
class GeoCacheRepository(Protocol):
async def load_all(self, db: aiosqlite.Connection) -> list[dict[str, object]]:
...
async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]:
...
async def count_unresolved(self, db: aiosqlite.Connection) -> int:
...
async def upsert_entry(
self,
db: aiosqlite.Connection,
ip: str,
country_code: str | None,
country_name: str | None,
asn: str | None,
org: str | None,
) -> None:
...
async def upsert_neg_entry(self, db: aiosqlite.Connection, ip: str) -> None:
...
async def bulk_upsert_entries(
self,
db: aiosqlite.Connection,
rows: Iterable[tuple[str, str | None, str | None, str | None, str | None]],
) -> int:
...
async def bulk_upsert_neg_entries(self, db: aiosqlite.Connection, ips: list[str]) -> int:
...
class Fail2BanDbRepository(Protocol):
async def check_db_nonempty(self, db_path: str) -> bool:
...
async def get_currently_banned(
self,
db_path: str,
since: int,
origin: BanOrigin | None = None,
*,
ip_filter: list[str] | None = None,
limit: int | None = None,
offset: int | None = None,
) -> tuple[list[BanRecord], int]:
...
async def get_ban_counts_by_bucket(
self,
db_path: str,
since: int,
bucket_secs: int,
num_buckets: int,
origin: BanOrigin | None = None,
) -> list[int]:
...
async def get_bans_by_jail(
self,
db_path: str,
since: int,
origin: BanOrigin | None = None,
) -> tuple[int, list[JailBanCount]]:
...
async def get_bans_table_summary(self, db_path: str) -> tuple[int, int | None, int | None]:
...
async def get_history_page(
self,
db_path: str,
since: int | None = None,
jail: str | None = None,
ip_filter: str | None = None,
origin: BanOrigin | None = None,
page: int = 1,
page_size: int = 100,
) -> tuple[list[HistoryRecord], int]:
...
async def get_history_for_ip(self, db_path: str, ip: str) -> list[HistoryRecord]:
...

View File

@@ -6,13 +6,41 @@ layers depend on, without binding them to concrete module implementations.
from __future__ import annotations
from typing import Protocol
from collections.abc import Awaitable, Callable
from typing import Protocol, runtime_checkable
import aiosqlite
import aiohttp
from app.models.auth import Session
from app.models.ban import JailBannedIpsResponse
from app.models.jail import JailDetailResponse, JailListResponse
from app.models.ban import BanOrigin, JailBannedIpsResponse, TimeRange
from app.models.blocklist import (
BlocklistSource,
ImportLogListResponse,
ImportRunResult,
ImportSourceResult,
PreviewResponse,
ScheduleConfig,
ScheduleInfo,
)
from app.models.config import (
AddLogPathRequest,
GlobalConfigResponse,
GlobalConfigUpdate,
JailConfigListResponse,
JailConfigResponse,
JailConfigUpdate,
LogPreviewRequest,
LogPreviewResponse,
MapColorThresholdsResponse,
MapColorThresholdsUpdate,
RegexTestResponse,
Fail2BanLogResponse,
ServiceStatusResponse,
)
from app.models.geo import GeoBatchLookup, GeoEnricher, GeoInfo
from app.models.history import HistoryListResponse, IpDetailResponse
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate, ServerStatus
class AuthService(Protocol):
@@ -103,3 +131,268 @@ class JailService(Protocol):
geo_enricher: object,
) -> object:
...
@runtime_checkable
class BlocklistService(Protocol):
async def list_sources(self, db: aiosqlite.Connection) -> list[BlocklistSource]:
...
async def get_source(
self,
db: aiosqlite.Connection,
source_id: int,
) -> BlocklistSource | None:
...
async def create_source(
self,
db: aiosqlite.Connection,
name: str,
url: str,
*,
enabled: bool = True,
) -> BlocklistSource:
...
async def update_source(
self,
db: aiosqlite.Connection,
source_id: int,
*,
name: str | None = None,
url: str | None = None,
enabled: bool | None = None,
) -> BlocklistSource | None:
...
async def delete_source(self, db: aiosqlite.Connection, source_id: int) -> bool:
...
async def preview_source(
self,
url: str,
http_session: aiohttp.ClientSession,
*,
sample_lines: int = ...,
) -> PreviewResponse:
...
async def import_source(
self,
source: BlocklistSource,
http_session: aiohttp.ClientSession,
socket_path: str,
db: aiosqlite.Connection,
*,
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
ban_ip: Callable[[str, str, str], Awaitable[None]],
) -> ImportSourceResult:
...
async def import_all(
self,
db: aiosqlite.Connection,
http_session: aiohttp.ClientSession,
socket_path: str,
*,
ban_ip: Callable[[str, str, str], Awaitable[None]],
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
) -> ImportRunResult:
...
async def get_schedule(self, db: aiosqlite.Connection) -> ScheduleConfig:
...
async def set_schedule(self, db: aiosqlite.Connection, update: ScheduleConfig) -> None:
...
async def get_schedule_info(
self,
db: aiosqlite.Connection,
next_run_at: str | None,
) -> ScheduleInfo:
...
async def list_import_logs(
self,
db: aiosqlite.Connection,
*,
source_id: int | None = None,
page: int = 1,
page_size: int = 50,
) -> ImportLogListResponse:
...
@runtime_checkable
class ConfigService(Protocol):
async def get_jail_config(self, socket_path: str, name: str) -> JailConfigResponse:
...
async def list_jail_configs(self, socket_path: str) -> JailConfigListResponse:
...
async def update_jail_config(
self,
socket_path: str,
name: str,
update: JailConfigUpdate,
) -> None:
...
async def get_global_config(self, socket_path: str) -> GlobalConfigResponse:
...
async def update_global_config(
self,
socket_path: str,
update: GlobalConfigUpdate,
) -> None:
...
def test_regex(self, request: object) -> RegexTestResponse:
...
async def add_log_path(
self,
socket_path: str,
jail: str,
req: AddLogPathRequest,
) -> None:
...
async def delete_log_path(self, socket_path: str, jail: str, log_path: str) -> None:
...
async def preview_log(
self,
req: LogPreviewRequest,
preview_fn: Callable[[LogPreviewRequest], Awaitable[LogPreviewResponse]] | None = None,
) -> LogPreviewResponse:
...
async def get_map_color_thresholds(self, db: aiosqlite.Connection) -> MapColorThresholdsResponse:
...
async def update_map_color_thresholds(
self,
db: aiosqlite.Connection,
update: MapColorThresholdsUpdate,
) -> None:
...
async def read_fail2ban_log(
self,
socket_path: str,
lines: int,
filter_text: str | None = None,
) -> Fail2BanLogResponse:
...
async def get_service_status(
self,
socket_path: str,
probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None,
) -> ServiceStatusResponse:
...
@runtime_checkable
class HistoryService(Protocol):
async def list_history(
self,
socket_path: str,
*,
range_: TimeRange | None = None,
jail: str | None = None,
ip_filter: str | None = None,
origin: BanOrigin | None = None,
source: str = "fail2ban",
page: int = 1,
page_size: int = 100,
geo_enricher: GeoEnricher | None = None,
db: aiosqlite.Connection | None = None,
) -> HistoryListResponse:
...
async def get_ip_detail(
self,
socket_path: str,
ip: str,
*,
geo_enricher: GeoEnricher | None = None,
) -> IpDetailResponse | None:
...
@runtime_checkable
class GeoService(Protocol):
def clear_cache(self) -> None:
...
def clear_neg_cache(self) -> None:
...
def is_cached(self, ip: str) -> bool:
...
def init_geoip(self, mmdb_path: str | None) -> None:
...
async def cache_stats(self, db: aiosqlite.Connection) -> dict[str, int]:
...
async def count_unresolved(self, db: aiosqlite.Connection) -> int:
...
async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]:
...
async def load_cache_from_db(self, db: aiosqlite.Connection) -> None:
...
async def lookup(
self,
ip: str,
http_session: aiohttp.ClientSession,
) -> GeoInfo | None:
...
async def lookup_batch(
self,
ips: list[str],
http_session: aiohttp.ClientSession,
db: aiosqlite.Connection | None = None,
) -> dict[str, GeoInfo]:
...
def lookup_cached_only(self, ip: str) -> GeoInfo | None:
...
async def flush_dirty(self, db: aiosqlite.Connection) -> int:
...
@runtime_checkable
class HealthService(Protocol):
async def probe(self, socket_path: str, timeout: float = ...) -> ServerStatus:
...
@runtime_checkable
class ServerService(Protocol):
async def get_settings(self, socket_path: str) -> ServerSettingsResponse:
...
async def update_settings(
self,
socket_path: str,
update: ServerSettingsUpdate,
) -> None:
...
async def flush_logs(self, socket_path: str) -> str:
...