refactor: complete Task 2/3 geo decouple + exceptions centralization; mark as done

This commit is contained in:
2026-03-21 17:15:02 +01:00
parent 3aba2b6446
commit a442836c5c
28 changed files with 803 additions and 571 deletions

View File

@@ -33,9 +33,13 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
from app.utils.ip_utils import is_valid_ip, is_valid_network
if TYPE_CHECKING:
from collections.abc import Callable
import aiohttp
import aiosqlite
from app.models.geo import GeoBatchLookup
log: structlog.stdlib.BoundLogger = structlog.get_logger()
#: Settings key used to persist the schedule config.
@@ -238,6 +242,8 @@ async def import_source(
http_session: aiohttp.ClientSession,
socket_path: str,
db: aiosqlite.Connection,
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
) -> ImportSourceResult:
"""Download and apply bans from a single blocklist source.
@@ -339,12 +345,8 @@ async def import_source(
)
# --- Pre-warm geo cache for newly imported IPs ---
if imported_ips:
from app.services import geo_service # noqa: PLC0415
uncached_ips: list[str] = [
ip for ip in imported_ips if not geo_service.is_cached(ip)
]
if imported_ips and geo_is_cached is not None:
uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)]
skipped_geo: int = len(imported_ips) - len(uncached_ips)
if skipped_geo > 0:
@@ -355,9 +357,9 @@ async def import_source(
to_lookup=len(uncached_ips),
)
if uncached_ips:
if uncached_ips and geo_batch_lookup is not None:
try:
await geo_service.lookup_batch(uncached_ips, http_session, db=db)
await geo_batch_lookup(uncached_ips, http_session, db=db)
log.info(
"blocklist_geo_prewarm_complete",
source_id=source.id,
@@ -383,6 +385,8 @@ async def import_all(
db: aiosqlite.Connection,
http_session: aiohttp.ClientSession,
socket_path: str,
geo_is_cached: Callable[[str], bool] | None = None,
geo_batch_lookup: GeoBatchLookup | None = None,
) -> ImportRunResult:
"""Import all enabled blocklist sources.
@@ -406,7 +410,14 @@ async def import_all(
for row in sources:
source = _row_to_source(row)
result = await import_source(source, http_session, socket_path, db)
result = await import_source(
source,
http_session,
socket_path,
db,
geo_is_cached=geo_is_cached,
geo_batch_lookup=geo_batch_lookup,
)
results.append(result)
total_imported += result.ips_imported
total_skipped += result.ips_skipped