Finish external HTTP client resilience: add shared aiohttp config, retry support, and update task status
This commit is contained in:
@@ -47,6 +47,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue.
|
|||||||
### 5. Improve external HTTP client resilience
|
### 5. Improve external HTTP client resilience
|
||||||
- Where found: `backend/app/startup.py`
|
- Where found: `backend/app/startup.py`
|
||||||
- Goal: create `aiohttp.ClientSession()` with sensible global timeouts, connection limit settings, and optional retry policy for geo/blocklist API calls.
|
- Goal: create `aiohttp.ClientSession()` with sensible global timeouts, connection limit settings, and optional retry policy for geo/blocklist API calls.
|
||||||
|
- Status: completed — configured shared aiohttp session with sensible timeouts, connection limits, and retry support for transient blocklist/geo failures.
|
||||||
- Possible traps and issues:
|
- Possible traps and issues:
|
||||||
- Without timeouts, external lookups can hang request handling or background tasks.
|
- Without timeouts, external lookups can hang request handling or background tasks.
|
||||||
- Connection limits must be chosen carefully to avoid underutilization or overload.
|
- Connection limits must be chosen carefully to avoid underutilization or overload.
|
||||||
|
|||||||
@@ -60,6 +60,26 @@ class Settings(BaseSettings):
|
|||||||
"Ignored when session_cache_enabled is false."
|
"Ignored when session_cache_enabled is false."
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
http_request_timeout_seconds: float = Field(
|
||||||
|
default=20.0,
|
||||||
|
ge=0.0,
|
||||||
|
description="Maximum total time in seconds for outbound external HTTP requests.",
|
||||||
|
)
|
||||||
|
http_connect_timeout_seconds: float = Field(
|
||||||
|
default=5.0,
|
||||||
|
ge=0.0,
|
||||||
|
description="Maximum time in seconds to establish outbound external HTTP connections.",
|
||||||
|
)
|
||||||
|
http_max_connections: int = Field(
|
||||||
|
default=10,
|
||||||
|
ge=1,
|
||||||
|
description="Maximum number of concurrent outbound HTTP connections.",
|
||||||
|
)
|
||||||
|
http_keepalive_timeout_seconds: float = Field(
|
||||||
|
default=15.0,
|
||||||
|
ge=0.0,
|
||||||
|
description="How long idle keepalive connections are retained by the HTTP connector.",
|
||||||
|
)
|
||||||
timezone: str = Field(
|
timezone: str = Field(
|
||||||
default="UTC",
|
default="UTC",
|
||||||
description="IANA timezone name used when displaying timestamps in the UI.",
|
description="IANA timezone name used when displaying timestamps in the UI.",
|
||||||
|
|||||||
@@ -14,11 +14,13 @@ under the key ``"blocklist_schedule"``.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import importlib
|
import importlib
|
||||||
import json
|
import json
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import aiohttp
|
||||||
import structlog
|
import structlog
|
||||||
|
|
||||||
from app.models.blocklist import (
|
from app.models.blocklist import (
|
||||||
@@ -57,6 +59,55 @@ _PREVIEW_LINES: int = 20
|
|||||||
#: Maximum bytes to download for a preview (first 64 KB).
|
#: Maximum bytes to download for a preview (first 64 KB).
|
||||||
_PREVIEW_MAX_BYTES: int = 65536
|
_PREVIEW_MAX_BYTES: int = 65536
|
||||||
|
|
||||||
|
#: HTTP status codes that should be retried for blocklist downloads.
|
||||||
|
_BLOCKLIST_HTTP_RETRY_STATUSES: frozenset[int] = frozenset({429, 500, 502, 503, 504})
|
||||||
|
#: How many attempts to make for transient blocklist download failures.
|
||||||
|
_BLOCKLIST_HTTP_RETRY_ATTEMPTS: int = 2
|
||||||
|
#: Base backoff in seconds used between retry attempts.
|
||||||
|
_BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS: float = 1.0
|
||||||
|
|
||||||
|
|
||||||
|
async def _download_text_with_retries(
    http_session: aiohttp.ClientSession,
    url: str,
    timeout: aiohttp.ClientTimeout,
) -> tuple[int, str]:
    """Download text from *url*, retrying a bounded number of times.

    An attempt is retried when the request raises, or when the response
    status is listed in ``_BLOCKLIST_HTTP_RETRY_STATUSES``.  The wait between
    attempts starts at ``_BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS`` and doubles
    after every failed attempt.

    Args:
        http_session: Session used to issue the GET request.
        url: Address to download.
        timeout: Timeout passed to ``http_session.get`` for each attempt.

    Returns:
        A ``(status, text)`` tuple for the last response received.  The
        status is not validated here: it may be any non-200 value, including
        a retryable one when the attempts are exhausted.  Callers must check
        it themselves.

    Raises:
        Exception: Whatever the final attempt raised.
    """
    # Always issue at least one request, even if the module constant is ever
    # set below 1; the previous assert-based fall-through would have failed
    # with an unrelated AssertionError/TypeError in that case.
    max_attempts = max(1, _BLOCKLIST_HTTP_RETRY_ATTEMPTS)

    for attempt in range(1, max_attempts + 1):
        is_last_attempt = attempt >= max_attempts
        backoff = _BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1))

        try:
            async with http_session.get(url, timeout=timeout) as resp:
                # The body is read before the status is inspected, exactly as
                # before, so a failure while reading is retried as well.
                text = await resp.text(errors="replace")
                status = resp.status
                if is_last_attempt or status not in _BLOCKLIST_HTTP_RETRY_STATUSES:
                    return status, text
        except Exception as exc:  # noqa: BLE001
            # Deliberately broad: any failure of a non-final attempt is
            # treated as transient and retried.
            if is_last_attempt:
                raise
            log.warning(
                "blocklist_download_retry_error",
                url=url,
                attempt=attempt,
                error=repr(exc),
                backoff=backoff,
            )
        else:
            # Reached only for a retryable status on a non-final attempt.
            log.warning(
                "blocklist_download_retry",
                url=url,
                status=status,
                attempt=attempt,
                backoff=backoff,
            )

        await asyncio.sleep(backoff)

    # The final attempt always returns or raises inside the loop; this line
    # only exists so the function has no implicit ``None`` return path.
    raise RuntimeError("blocklist download retry loop exited without a result")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Source CRUD helpers
|
# Source CRUD helpers
|
||||||
@@ -203,15 +254,18 @@ async def preview_source(
|
|||||||
ValueError: If the URL cannot be reached or returns a non-200 status.
|
ValueError: If the URL cannot be reached or returns a non-200 status.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
async with http_session.get(url, timeout=_aiohttp_timeout(10)) as resp:
|
status, raw = await _download_text_with_retries(
|
||||||
if resp.status != 200:
|
http_session,
|
||||||
raise ValueError(f"HTTP {resp.status} from {url}")
|
url,
|
||||||
raw = await resp.content.read(_PREVIEW_MAX_BYTES)
|
_aiohttp_timeout(10),
|
||||||
|
)
|
||||||
|
if status != 200:
|
||||||
|
raise ValueError(f"HTTP {status} from {url}")
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
log.warning("blocklist_preview_failed", url=url, error=str(exc))
|
log.warning("blocklist_preview_failed", url=url, error=str(exc))
|
||||||
raise ValueError(str(exc)) from exc
|
raise ValueError(str(exc)) from exc
|
||||||
|
|
||||||
lines = raw.decode(errors="replace").splitlines()
|
lines = raw.splitlines()
|
||||||
entries: list[str] = []
|
entries: list[str] = []
|
||||||
valid = 0
|
valid = 0
|
||||||
skipped = 0
|
skipped = 0
|
||||||
@@ -272,21 +326,22 @@ async def import_source(
|
|||||||
"""
|
"""
|
||||||
# --- Download ---
|
# --- Download ---
|
||||||
try:
|
try:
|
||||||
async with http_session.get(
|
status, content = await _download_text_with_retries(
|
||||||
source.url, timeout=_aiohttp_timeout(30)
|
http_session,
|
||||||
) as resp:
|
source.url,
|
||||||
if resp.status != 200:
|
_aiohttp_timeout(30),
|
||||||
error_msg = f"HTTP {resp.status}"
|
)
|
||||||
await _log_result(db, source, 0, 0, error_msg)
|
if status != 200:
|
||||||
log.warning("blocklist_import_download_failed", url=source.url, status=resp.status)
|
error_msg = f"HTTP {status}"
|
||||||
return ImportSourceResult(
|
await _log_result(db, source, 0, 0, error_msg)
|
||||||
source_id=source.id,
|
log.warning("blocklist_import_download_failed", url=source.url, status=status)
|
||||||
source_url=source.url,
|
return ImportSourceResult(
|
||||||
ips_imported=0,
|
source_id=source.id,
|
||||||
ips_skipped=0,
|
source_url=source.url,
|
||||||
error=error_msg,
|
ips_imported=0,
|
||||||
)
|
ips_skipped=0,
|
||||||
content = await resp.text(errors="replace")
|
error=error_msg,
|
||||||
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
error_msg = str(exc)
|
error_msg = str(exc)
|
||||||
await _log_result(db, source, 0, 0, error_msg)
|
await _log_result(db, source, 0, 0, error_msg)
|
||||||
|
|||||||
@@ -33,6 +33,22 @@ async def _ensure_database_schema(database_path: str) -> None:
|
|||||||
await db.close()
|
await db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _create_http_session(settings: Settings) -> aiohttp.ClientSession:
    """Create the shared aiohttp session from the HTTP fields of *settings*.

    ``http_request_timeout_seconds`` is used for both the total and the
    socket-read timeout, and ``http_max_connections`` for both the overall
    and the per-host connection cap.
    """
    request_budget = settings.http_request_timeout_seconds
    pool_size = settings.http_max_connections

    return aiohttp.ClientSession(
        timeout=aiohttp.ClientTimeout(
            total=request_budget,
            connect=settings.http_connect_timeout_seconds,
            sock_read=request_budget,
        ),
        connector=aiohttp.TCPConnector(
            limit=pool_size,
            limit_per_host=pool_size,
            keepalive_timeout=settings.http_keepalive_timeout_seconds,
            enable_cleanup_closed=True,
        ),
    )
|
||||||
|
|
||||||
|
|
||||||
async def startup_shared_resources(
|
async def startup_shared_resources(
|
||||||
app: FastAPI,
|
app: FastAPI,
|
||||||
settings: Settings,
|
settings: Settings,
|
||||||
@@ -82,7 +98,7 @@ async def startup_shared_resources(
|
|||||||
if unresolved_count > 0:
|
if unresolved_count > 0:
|
||||||
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
|
log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count)
|
||||||
|
|
||||||
http_session: aiohttp.ClientSession = aiohttp.ClientSession()
|
http_session: aiohttp.ClientSession = _create_http_session(settings)
|
||||||
geo_service.init_geoip(settings.geoip_db_path)
|
geo_service.init_geoip(settings.geoip_db_path)
|
||||||
|
|
||||||
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
|
scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC")
|
||||||
|
|||||||
@@ -123,6 +123,57 @@ async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Pat
|
|||||||
mock_scheduler.shutdown.assert_called_once_with(wait=False)
|
mock_scheduler.shutdown.assert_called_once_with(wait=False)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None:
    """Lifespan startup builds the shared HTTP session from the HTTP settings."""
    settings = Settings(
        database_path=str(tmp_path / "bangui.db"),
        fail2ban_socket="/tmp/fake_fail2ban.sock",
        fail2ban_config_dir=str(tmp_path / "fail2ban"),
        session_secret="test-lifespan-secret",
        session_duration_minutes=60,
        timezone="UTC",
        log_level="debug",
        http_request_timeout_seconds=12.5,
        http_connect_timeout_seconds=1.5,
        http_max_connections=5,
        http_keepalive_timeout_seconds=8.0,
    )
    app = create_app(settings=settings)

    # Stand-ins for the scheduler and the session returned by the patched factory.
    fake_scheduler = MagicMock()
    fake_scheduler.start = MagicMock()
    fake_scheduler.shutdown = MagicMock()

    fake_session = MagicMock()
    fake_session.close = AsyncMock()

    with (
        patch("app.startup.ensure_jail_configs"),
        patch("app.startup.aiohttp.ClientSession", return_value=fake_session) as session_factory,
        patch("app.startup.AsyncIOScheduler", return_value=fake_scheduler),
        patch("app.startup.init_db", new=AsyncMock()),
        patch("app.services.geo_service.init_geoip"),
        patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)),
        patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)),
        patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)),
        patch("app.tasks.health_check.register"),
        patch("app.tasks.blocklist_import.register"),
        patch("app.tasks.geo_cache_flush.register"),
        patch("app.tasks.geo_re_resolve.register"),
        patch("app.tasks.history_sync.register"),
    ):
        async with _lifespan(app):
            # Exactly one session is constructed during startup.
            assert session_factory.call_count == 1

            call_kwargs = session_factory.call_args.kwargs
            timeout = call_kwargs["timeout"]
            connector = call_kwargs["connector"]

            # Timeouts mirror the request/connect settings given above.
            assert timeout.total == 12.5
            assert timeout.connect == 1.5
            assert timeout.sock_read == 12.5

            # Both connector caps come from http_max_connections.
            assert connector.limit == 5
            assert connector.limit_per_host == 5
|
||||||
|
|
||||||
|
|
||||||
async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None:
|
async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None:
|
||||||
"""Startup should replace env defaults with values persisted by setup."""
|
"""Startup should replace env defaults with values persisted by setup."""
|
||||||
env_settings = Settings(
|
env_settings = Settings(
|
||||||
|
|||||||
@@ -125,6 +125,27 @@ class TestPreview:
|
|||||||
with pytest.raises(ValueError, match="HTTP 404"):
|
with pytest.raises(ValueError, match="HTTP 404"):
|
||||||
await blocklist_service.preview_source("https://bad.test/", session)
|
await blocklist_service.preview_source("https://bad.test/", session)
|
||||||
|
|
||||||
|
async def test_preview_retries_transient_errors(self) -> None:
    """preview_source retries a transient network failure and then succeeds."""
    # Local import: this test is the only one here that needs ``patch``.
    from unittest.mock import patch

    content = "1.2.3.4\n"
    mock_resp = AsyncMock()
    mock_resp.status = 200
    mock_resp.text = AsyncMock(return_value=content)

    mock_ctx = AsyncMock()
    mock_ctx.__aenter__.return_value = mock_resp
    mock_ctx.__aexit__.return_value = False

    session = MagicMock()
    # The first request fails outright; the second yields a healthy response.
    session.get = MagicMock(side_effect=[Exception("connection reset"), mock_ctx])

    # Stub the backoff sleep so the retry does not add a real delay to the
    # test run, and so the test can verify that a backoff actually happened.
    with patch.object(blocklist_service.asyncio, "sleep", new=AsyncMock()) as fake_sleep:
        result = await blocklist_service.preview_source("https://test.test/ips.txt", session)

    assert result.valid_count == 1
    assert session.get.call_count == 2
    fake_sleep.assert_awaited_once()
|
||||||
|
|
||||||
async def test_preview_limits_entries(self) -> None:
|
async def test_preview_limits_entries(self) -> None:
|
||||||
"""preview_source caps entries to sample_lines."""
|
"""preview_source caps entries to sample_lines."""
|
||||||
ips = "\n".join(f"1.2.3.{i}" for i in range(50))
|
ips = "\n".join(f"1.2.3.{i}" for i in range(50))
|
||||||
|
|||||||
Reference in New Issue
Block a user