Task 25: extend service/repository protocol coverage and wire DI aliases
This commit is contained in:
@@ -22,8 +22,23 @@ from app.models.auth import Session
|
||||
from app.models.config import PendingRecovery
|
||||
from app.models.geo import GeoBatchLookup
|
||||
from app.models.server import ServerStatus
|
||||
from app.repositories.protocols import SessionRepository
|
||||
from app.services.protocols import AuthService, JailService
|
||||
from app.repositories.protocols import (
|
||||
BlocklistRepository,
|
||||
Fail2BanDbRepository,
|
||||
GeoCacheRepository,
|
||||
ImportLogRepository,
|
||||
SessionRepository,
|
||||
)
|
||||
from app.services.protocols import (
|
||||
AuthService,
|
||||
BlocklistService,
|
||||
ConfigService,
|
||||
GeoService,
|
||||
HealthService,
|
||||
HistoryService,
|
||||
JailService,
|
||||
ServerService,
|
||||
)
|
||||
from app.utils.constants import SESSION_COOKIE_NAME
|
||||
from app.utils.runtime_state import RuntimeState
|
||||
from app.utils.session_cache import InMemorySessionCache, NoOpSessionCache, SessionCache
|
||||
@@ -239,6 +254,48 @@ async def get_jail_service() -> JailService:
|
||||
return cast("JailService", jail_service)
|
||||
|
||||
|
||||
async def get_blocklist_service() -> BlocklistService:
|
||||
"""Provide the concrete blocklist service implementation."""
|
||||
from app.services import blocklist_service # noqa: PLC0415
|
||||
|
||||
return cast("BlocklistService", blocklist_service)
|
||||
|
||||
|
||||
async def get_config_service() -> ConfigService:
|
||||
"""Provide the concrete configuration service implementation."""
|
||||
from app.services import config_service # noqa: PLC0415
|
||||
|
||||
return cast("ConfigService", config_service)
|
||||
|
||||
|
||||
async def get_history_service() -> HistoryService:
|
||||
"""Provide the concrete history service implementation."""
|
||||
from app.services import history_service # noqa: PLC0415
|
||||
|
||||
return cast("HistoryService", history_service)
|
||||
|
||||
|
||||
async def get_geo_service() -> GeoService:
|
||||
"""Provide the concrete geo service implementation."""
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
return cast("GeoService", geo_service)
|
||||
|
||||
|
||||
async def get_health_service() -> HealthService:
|
||||
"""Provide the concrete health service implementation."""
|
||||
from app.services import health_service # noqa: PLC0415
|
||||
|
||||
return cast("HealthService", health_service)
|
||||
|
||||
|
||||
async def get_server_service() -> ServerService:
|
||||
"""Provide the concrete server service implementation."""
|
||||
from app.services import server_service # noqa: PLC0415
|
||||
|
||||
return cast("ServerService", server_service)
|
||||
|
||||
|
||||
async def get_session_repo() -> SessionRepository:
|
||||
"""Provide the concrete session repository implementation."""
|
||||
from app.repositories import session_repo # noqa: PLC0415
|
||||
@@ -246,6 +303,34 @@ async def get_session_repo() -> SessionRepository:
|
||||
return session_repo
|
||||
|
||||
|
||||
async def get_blocklist_repo() -> BlocklistRepository:
|
||||
"""Provide the concrete blocklist repository implementation."""
|
||||
from app.repositories import blocklist_repo # noqa: PLC0415
|
||||
|
||||
return cast("BlocklistRepository", blocklist_repo)
|
||||
|
||||
|
||||
async def get_import_log_repo() -> ImportLogRepository:
|
||||
"""Provide the concrete import log repository implementation."""
|
||||
from app.repositories import import_log_repo # noqa: PLC0415
|
||||
|
||||
return cast("ImportLogRepository", import_log_repo)
|
||||
|
||||
|
||||
async def get_geo_cache_repo() -> GeoCacheRepository:
|
||||
"""Provide the concrete geo cache repository implementation."""
|
||||
from app.repositories import geo_cache_repo # noqa: PLC0415
|
||||
|
||||
return cast("GeoCacheRepository", geo_cache_repo)
|
||||
|
||||
|
||||
async def get_fail2ban_db_repo() -> Fail2BanDbRepository:
|
||||
"""Provide the concrete fail2ban DB repository implementation."""
|
||||
from app.repositories import fail2ban_db_repo # noqa: PLC0415
|
||||
|
||||
return cast("Fail2BanDbRepository", fail2ban_db_repo)
|
||||
|
||||
|
||||
async def get_app_state(app_context: Annotated[ApplicationContext, Depends(get_app_context)]) -> ApplicationContext:
|
||||
"""Provide the application state object for the current request."""
|
||||
return app_context
|
||||
@@ -351,7 +436,17 @@ PendingRecoveryDep = Annotated[PendingRecovery | None, Depends(get_pending_recov
|
||||
SessionCacheDep = Annotated[SessionCache, Depends(get_session_cache)]
|
||||
AuthServiceDep = Annotated[AuthService, Depends(get_auth_service)]
|
||||
JailServiceDep = Annotated[JailService, Depends(get_jail_service)]
|
||||
BlocklistServiceDep = Annotated[BlocklistService, Depends(get_blocklist_service)]
|
||||
ConfigServiceDep = Annotated[ConfigService, Depends(get_config_service)]
|
||||
HistoryServiceDep = Annotated[HistoryService, Depends(get_history_service)]
|
||||
GeoServiceDep = Annotated[GeoService, Depends(get_geo_service)]
|
||||
HealthServiceDep = Annotated[HealthService, Depends(get_health_service)]
|
||||
ServerServiceDep = Annotated[ServerService, Depends(get_server_service)]
|
||||
SessionRepoDep = Annotated[SessionRepository, Depends(get_session_repo)]
|
||||
AppStateDep = Annotated[AppState, Depends(get_app_state)]
|
||||
BlocklistRepositoryDep = Annotated[BlocklistRepository, Depends(get_blocklist_repo)]
|
||||
ImportLogRepositoryDep = Annotated[ImportLogRepository, Depends(get_import_log_repo)]
|
||||
GeoCacheRepositoryDep = Annotated[GeoCacheRepository, Depends(get_geo_cache_repo)]
|
||||
Fail2BanDbRepositoryDep = Annotated[Fail2BanDbRepository, Depends(get_fail2ban_db_repo)]
|
||||
AppStateDep = Annotated[ApplicationContext, Depends(get_app_state)]
|
||||
AppDep = Annotated[FastAPI, Depends(get_app)]
|
||||
AuthDep = Annotated[Session, Depends(require_auth)]
|
||||
|
||||
@@ -6,11 +6,14 @@ module implementations, making the backend easier to test and extend.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Protocol
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from app.models.auth import Session
|
||||
from app.models.ban import BanOrigin
|
||||
from app.repositories.fail2ban_db_repo import BanIpCount, BanRecord, HistoryRecord, JailBanCount
|
||||
|
||||
|
||||
class SessionRepository(Protocol):
|
||||
@@ -45,3 +48,160 @@ class SessionRepository(Protocol):
|
||||
now_iso: str,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
|
||||
class BlocklistRepository(Protocol):
|
||||
async def create_source(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
name: str,
|
||||
url: str,
|
||||
*,
|
||||
enabled: bool = True,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
async def get_source(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
source_id: int,
|
||||
) -> dict[str, object] | None:
|
||||
...
|
||||
|
||||
async def list_sources(self, db: aiosqlite.Connection) -> list[dict[str, object]]:
|
||||
...
|
||||
|
||||
async def list_enabled_sources(self, db: aiosqlite.Connection) -> list[dict[str, object]]:
|
||||
...
|
||||
|
||||
async def update_source(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
source_id: int,
|
||||
*,
|
||||
name: str | None = None,
|
||||
url: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
) -> bool:
|
||||
...
|
||||
|
||||
async def delete_source(self, db: aiosqlite.Connection, source_id: int) -> bool:
|
||||
...
|
||||
|
||||
|
||||
class ImportLogRepository(Protocol):
|
||||
async def add_log(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
source_id: int | None,
|
||||
source_url: str,
|
||||
ips_imported: int,
|
||||
ips_skipped: int,
|
||||
errors: str | None,
|
||||
) -> int:
|
||||
...
|
||||
|
||||
async def list_logs(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
source_id: int | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> tuple[list[dict[str, object]], int]:
|
||||
...
|
||||
|
||||
async def get_last_log(self, db: aiosqlite.Connection) -> dict[str, object] | None:
|
||||
...
|
||||
|
||||
async def compute_total_pages(self, total: int, page_size: int) -> int:
|
||||
...
|
||||
|
||||
|
||||
class GeoCacheRepository(Protocol):
|
||||
async def load_all(self, db: aiosqlite.Connection) -> list[dict[str, object]]:
|
||||
...
|
||||
|
||||
async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]:
|
||||
...
|
||||
|
||||
async def count_unresolved(self, db: aiosqlite.Connection) -> int:
|
||||
...
|
||||
|
||||
async def upsert_entry(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
ip: str,
|
||||
country_code: str | None,
|
||||
country_name: str | None,
|
||||
asn: str | None,
|
||||
org: str | None,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def upsert_neg_entry(self, db: aiosqlite.Connection, ip: str) -> None:
|
||||
...
|
||||
|
||||
async def bulk_upsert_entries(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
rows: Iterable[tuple[str, str | None, str | None, str | None, str | None]],
|
||||
) -> int:
|
||||
...
|
||||
|
||||
async def bulk_upsert_neg_entries(self, db: aiosqlite.Connection, ips: list[str]) -> int:
|
||||
...
|
||||
|
||||
|
||||
class Fail2BanDbRepository(Protocol):
|
||||
async def check_db_nonempty(self, db_path: str) -> bool:
|
||||
...
|
||||
|
||||
async def get_currently_banned(
|
||||
self,
|
||||
db_path: str,
|
||||
since: int,
|
||||
origin: BanOrigin | None = None,
|
||||
*,
|
||||
ip_filter: list[str] | None = None,
|
||||
limit: int | None = None,
|
||||
offset: int | None = None,
|
||||
) -> tuple[list[BanRecord], int]:
|
||||
...
|
||||
|
||||
async def get_ban_counts_by_bucket(
|
||||
self,
|
||||
db_path: str,
|
||||
since: int,
|
||||
bucket_secs: int,
|
||||
num_buckets: int,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> list[int]:
|
||||
...
|
||||
|
||||
async def get_bans_by_jail(
|
||||
self,
|
||||
db_path: str,
|
||||
since: int,
|
||||
origin: BanOrigin | None = None,
|
||||
) -> tuple[int, list[JailBanCount]]:
|
||||
...
|
||||
|
||||
async def get_bans_table_summary(self, db_path: str) -> tuple[int, int | None, int | None]:
|
||||
...
|
||||
|
||||
async def get_history_page(
|
||||
self,
|
||||
db_path: str,
|
||||
since: int | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 100,
|
||||
) -> tuple[list[HistoryRecord], int]:
|
||||
...
|
||||
|
||||
async def get_history_for_ip(self, db_path: str, ip: str) -> list[HistoryRecord]:
|
||||
...
|
||||
|
||||
@@ -6,13 +6,41 @@ layers depend on, without binding them to concrete module implementations.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Protocol
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
import aiosqlite
|
||||
import aiohttp
|
||||
|
||||
from app.models.auth import Session
|
||||
from app.models.ban import JailBannedIpsResponse
|
||||
from app.models.jail import JailDetailResponse, JailListResponse
|
||||
from app.models.ban import BanOrigin, JailBannedIpsResponse, TimeRange
|
||||
from app.models.blocklist import (
|
||||
BlocklistSource,
|
||||
ImportLogListResponse,
|
||||
ImportRunResult,
|
||||
ImportSourceResult,
|
||||
PreviewResponse,
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.models.config import (
|
||||
AddLogPathRequest,
|
||||
GlobalConfigResponse,
|
||||
GlobalConfigUpdate,
|
||||
JailConfigListResponse,
|
||||
JailConfigResponse,
|
||||
JailConfigUpdate,
|
||||
LogPreviewRequest,
|
||||
LogPreviewResponse,
|
||||
MapColorThresholdsResponse,
|
||||
MapColorThresholdsUpdate,
|
||||
RegexTestResponse,
|
||||
Fail2BanLogResponse,
|
||||
ServiceStatusResponse,
|
||||
)
|
||||
from app.models.geo import GeoBatchLookup, GeoEnricher, GeoInfo
|
||||
from app.models.history import HistoryListResponse, IpDetailResponse
|
||||
from app.models.server import ServerSettingsResponse, ServerSettingsUpdate, ServerStatus
|
||||
|
||||
|
||||
class AuthService(Protocol):
|
||||
@@ -103,3 +131,268 @@ class JailService(Protocol):
|
||||
geo_enricher: object,
|
||||
) -> object:
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class BlocklistService(Protocol):
|
||||
async def list_sources(self, db: aiosqlite.Connection) -> list[BlocklistSource]:
|
||||
...
|
||||
|
||||
async def get_source(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
source_id: int,
|
||||
) -> BlocklistSource | None:
|
||||
...
|
||||
|
||||
async def create_source(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
name: str,
|
||||
url: str,
|
||||
*,
|
||||
enabled: bool = True,
|
||||
) -> BlocklistSource:
|
||||
...
|
||||
|
||||
async def update_source(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
source_id: int,
|
||||
*,
|
||||
name: str | None = None,
|
||||
url: str | None = None,
|
||||
enabled: bool | None = None,
|
||||
) -> BlocklistSource | None:
|
||||
...
|
||||
|
||||
async def delete_source(self, db: aiosqlite.Connection, source_id: int) -> bool:
|
||||
...
|
||||
|
||||
async def preview_source(
|
||||
self,
|
||||
url: str,
|
||||
http_session: aiohttp.ClientSession,
|
||||
*,
|
||||
sample_lines: int = ...,
|
||||
) -> PreviewResponse:
|
||||
...
|
||||
|
||||
async def import_source(
|
||||
self,
|
||||
source: BlocklistSource,
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]],
|
||||
) -> ImportSourceResult:
|
||||
...
|
||||
|
||||
async def import_all(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
*,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]],
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
) -> ImportRunResult:
|
||||
...
|
||||
|
||||
async def get_schedule(self, db: aiosqlite.Connection) -> ScheduleConfig:
|
||||
...
|
||||
|
||||
async def set_schedule(self, db: aiosqlite.Connection, update: ScheduleConfig) -> None:
|
||||
...
|
||||
|
||||
async def get_schedule_info(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
next_run_at: str | None,
|
||||
) -> ScheduleInfo:
|
||||
...
|
||||
|
||||
async def list_import_logs(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
source_id: int | None = None,
|
||||
page: int = 1,
|
||||
page_size: int = 50,
|
||||
) -> ImportLogListResponse:
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ConfigService(Protocol):
|
||||
async def get_jail_config(self, socket_path: str, name: str) -> JailConfigResponse:
|
||||
...
|
||||
|
||||
async def list_jail_configs(self, socket_path: str) -> JailConfigListResponse:
|
||||
...
|
||||
|
||||
async def update_jail_config(
|
||||
self,
|
||||
socket_path: str,
|
||||
name: str,
|
||||
update: JailConfigUpdate,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def get_global_config(self, socket_path: str) -> GlobalConfigResponse:
|
||||
...
|
||||
|
||||
async def update_global_config(
|
||||
self,
|
||||
socket_path: str,
|
||||
update: GlobalConfigUpdate,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
def test_regex(self, request: object) -> RegexTestResponse:
|
||||
...
|
||||
|
||||
async def add_log_path(
|
||||
self,
|
||||
socket_path: str,
|
||||
jail: str,
|
||||
req: AddLogPathRequest,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def delete_log_path(self, socket_path: str, jail: str, log_path: str) -> None:
|
||||
...
|
||||
|
||||
async def preview_log(
|
||||
self,
|
||||
req: LogPreviewRequest,
|
||||
preview_fn: Callable[[LogPreviewRequest], Awaitable[LogPreviewResponse]] | None = None,
|
||||
) -> LogPreviewResponse:
|
||||
...
|
||||
|
||||
async def get_map_color_thresholds(self, db: aiosqlite.Connection) -> MapColorThresholdsResponse:
|
||||
...
|
||||
|
||||
async def update_map_color_thresholds(
|
||||
self,
|
||||
db: aiosqlite.Connection,
|
||||
update: MapColorThresholdsUpdate,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def read_fail2ban_log(
|
||||
self,
|
||||
socket_path: str,
|
||||
lines: int,
|
||||
filter_text: str | None = None,
|
||||
) -> Fail2BanLogResponse:
|
||||
...
|
||||
|
||||
async def get_service_status(
|
||||
self,
|
||||
socket_path: str,
|
||||
probe_fn: Callable[[str], Awaitable[ServiceStatusResponse]] | None = None,
|
||||
) -> ServiceStatusResponse:
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HistoryService(Protocol):
|
||||
async def list_history(
|
||||
self,
|
||||
socket_path: str,
|
||||
*,
|
||||
range_: TimeRange | None = None,
|
||||
jail: str | None = None,
|
||||
ip_filter: str | None = None,
|
||||
origin: BanOrigin | None = None,
|
||||
source: str = "fail2ban",
|
||||
page: int = 1,
|
||||
page_size: int = 100,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
db: aiosqlite.Connection | None = None,
|
||||
) -> HistoryListResponse:
|
||||
...
|
||||
|
||||
async def get_ip_detail(
|
||||
self,
|
||||
socket_path: str,
|
||||
ip: str,
|
||||
*,
|
||||
geo_enricher: GeoEnricher | None = None,
|
||||
) -> IpDetailResponse | None:
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class GeoService(Protocol):
|
||||
def clear_cache(self) -> None:
|
||||
...
|
||||
|
||||
def clear_neg_cache(self) -> None:
|
||||
...
|
||||
|
||||
def is_cached(self, ip: str) -> bool:
|
||||
...
|
||||
|
||||
def init_geoip(self, mmdb_path: str | None) -> None:
|
||||
...
|
||||
|
||||
async def cache_stats(self, db: aiosqlite.Connection) -> dict[str, int]:
|
||||
...
|
||||
|
||||
async def count_unresolved(self, db: aiosqlite.Connection) -> int:
|
||||
...
|
||||
|
||||
async def get_unresolved_ips(self, db: aiosqlite.Connection) -> list[str]:
|
||||
...
|
||||
|
||||
async def load_cache_from_db(self, db: aiosqlite.Connection) -> None:
|
||||
...
|
||||
|
||||
async def lookup(
|
||||
self,
|
||||
ip: str,
|
||||
http_session: aiohttp.ClientSession,
|
||||
) -> GeoInfo | None:
|
||||
...
|
||||
|
||||
async def lookup_batch(
|
||||
self,
|
||||
ips: list[str],
|
||||
http_session: aiohttp.ClientSession,
|
||||
db: aiosqlite.Connection | None = None,
|
||||
) -> dict[str, GeoInfo]:
|
||||
...
|
||||
|
||||
def lookup_cached_only(self, ip: str) -> GeoInfo | None:
|
||||
...
|
||||
|
||||
async def flush_dirty(self, db: aiosqlite.Connection) -> int:
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class HealthService(Protocol):
|
||||
async def probe(self, socket_path: str, timeout: float = ...) -> ServerStatus:
|
||||
...
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class ServerService(Protocol):
|
||||
async def get_settings(self, socket_path: str) -> ServerSettingsResponse:
|
||||
...
|
||||
|
||||
async def update_settings(
|
||||
self,
|
||||
socket_path: str,
|
||||
update: ServerSettingsUpdate,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
async def flush_logs(self, socket_path: str) -> str:
|
||||
...
|
||||
|
||||
Reference in New Issue
Block a user