refactor: complete Task 2/3 geo decouple + exceptions centralization; mark as done
This commit is contained in:
@@ -33,9 +33,13 @@ from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
|
||||
from app.models.geo import GeoBatchLookup
|
||||
|
||||
log: structlog.stdlib.BoundLogger = structlog.get_logger()
|
||||
|
||||
#: Settings key used to persist the schedule config.
|
||||
@@ -238,6 +242,8 @@ async def import_source(
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
db: aiosqlite.Connection,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
) -> ImportSourceResult:
|
||||
"""Download and apply bans from a single blocklist source.
|
||||
|
||||
@@ -339,12 +345,8 @@ async def import_source(
|
||||
)
|
||||
|
||||
# --- Pre-warm geo cache for newly imported IPs ---
|
||||
if imported_ips:
|
||||
from app.services import geo_service # noqa: PLC0415
|
||||
|
||||
uncached_ips: list[str] = [
|
||||
ip for ip in imported_ips if not geo_service.is_cached(ip)
|
||||
]
|
||||
if imported_ips and geo_is_cached is not None:
|
||||
uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)]
|
||||
skipped_geo: int = len(imported_ips) - len(uncached_ips)
|
||||
|
||||
if skipped_geo > 0:
|
||||
@@ -355,9 +357,9 @@ async def import_source(
|
||||
to_lookup=len(uncached_ips),
|
||||
)
|
||||
|
||||
if uncached_ips:
|
||||
if uncached_ips and geo_batch_lookup is not None:
|
||||
try:
|
||||
await geo_service.lookup_batch(uncached_ips, http_session, db=db)
|
||||
await geo_batch_lookup(uncached_ips, http_session, db=db)
|
||||
log.info(
|
||||
"blocklist_geo_prewarm_complete",
|
||||
source_id=source.id,
|
||||
@@ -383,6 +385,8 @@ async def import_all(
|
||||
db: aiosqlite.Connection,
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
) -> ImportRunResult:
|
||||
"""Import all enabled blocklist sources.
|
||||
|
||||
@@ -406,7 +410,14 @@ async def import_all(
|
||||
|
||||
for row in sources:
|
||||
source = _row_to_source(row)
|
||||
result = await import_source(source, http_session, socket_path, db)
|
||||
result = await import_source(
|
||||
source,
|
||||
http_session,
|
||||
socket_path,
|
||||
db,
|
||||
geo_is_cached=geo_is_cached,
|
||||
geo_batch_lookup=geo_batch_lookup,
|
||||
)
|
||||
results.append(result)
|
||||
total_imported += result.ips_imported
|
||||
total_skipped += result.ips_skipped
|
||||
|
||||
Reference in New Issue
Block a user