Fix blocklist service injection and centralize session cookie name
This commit is contained in:
@@ -15,14 +15,13 @@ under the key ``"blocklist_schedule"``.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
import json
|
||||
from collections.abc import Awaitable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import aiohttp
|
||||
import structlog
|
||||
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.models.blocklist import (
|
||||
BlocklistSource,
|
||||
ImportLogEntry,
|
||||
@@ -33,12 +32,11 @@ from app.models.blocklist import (
|
||||
ScheduleConfig,
|
||||
ScheduleInfo,
|
||||
)
|
||||
from app.exceptions import JailNotFoundError
|
||||
from app.repositories import blocklist_repo, import_log_repo, settings_repo
|
||||
from app.utils.ip_utils import is_valid_ip, is_valid_network
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import aiohttp
|
||||
import aiosqlite
|
||||
@@ -299,9 +297,10 @@ async def import_source(
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
db: aiosqlite.Connection,
|
||||
*,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]],
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportSourceResult:
|
||||
"""Download and apply bans from a single blocklist source.
|
||||
|
||||
@@ -360,14 +359,7 @@ async def import_source(
|
||||
ban_error: str | None = None
|
||||
imported_ips: list[str] = []
|
||||
|
||||
if ban_ip is None:
|
||||
try:
|
||||
jail_svc = importlib.import_module("app.services.jail_service")
|
||||
ban_ip_fn = jail_svc.ban_ip
|
||||
except (ModuleNotFoundError, AttributeError) as exc:
|
||||
raise ValueError("ban_ip callback is required") from exc
|
||||
else:
|
||||
ban_ip_fn = ban_ip
|
||||
ban_ip_fn = ban_ip
|
||||
|
||||
for line in content.splitlines():
|
||||
stripped = line.strip()
|
||||
@@ -450,9 +442,10 @@ async def import_all(
|
||||
db: aiosqlite.Connection,
|
||||
http_session: aiohttp.ClientSession,
|
||||
socket_path: str,
|
||||
*,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]],
|
||||
geo_is_cached: Callable[[str], bool] | None = None,
|
||||
geo_batch_lookup: GeoBatchLookup | None = None,
|
||||
ban_ip: Callable[[str, str, str], Awaitable[None]] | None = None,
|
||||
) -> ImportRunResult:
|
||||
"""Import all enabled blocklist sources.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user