From 148756fb792d5a9045d7f1cffc8354bc63349168 Mon Sep 17 00:00:00 2001 From: Lukas Date: Thu, 9 Apr 2026 22:01:11 +0200 Subject: [PATCH] Finish external HTTP client resilience: add shared aiohttp config, retry support, and update task status --- Docs/Tasks.md | 1 + backend/app/config.py | 20 ++++ backend/app/services/blocklist_service.py | 95 +++++++++++++++---- backend/app/startup.py | 18 +++- backend/tests/test_main.py | 51 ++++++++++ .../test_services/test_blocklist_service.py | 21 ++++ 6 files changed, 185 insertions(+), 21 deletions(-) diff --git a/Docs/Tasks.md b/Docs/Tasks.md index 46ad111..11bd672 100644 --- a/Docs/Tasks.md +++ b/Docs/Tasks.md @@ -47,6 +47,7 @@ Reference: `Docs/Refactoring.md` for full analysis of each issue. ### 5. Improve external HTTP client resilience - Where found: `backend/app/startup.py` - Goal: create `aiohttp.ClientSession()` with sensible global timeouts, connection limit settings, and optional retry policy for geo/blocklist API calls. +- Status: completed — configured shared aiohttp session with sensible timeouts, connection limits, and retry support for transient blocklist/geo failures. - Possible traps and issues: - Without timeouts, external lookups can hang request handling or background tasks. - Connection limits must be chosen carefully to avoid underutilization or overload. diff --git a/backend/app/config.py b/backend/app/config.py index a9cc522..c5c0f5d 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -60,6 +60,26 @@ class Settings(BaseSettings): "Ignored when session_cache_enabled is false." ), ) + http_request_timeout_seconds: float = Field( + default=20.0, + ge=0.0, + description="Maximum total time in seconds for outbound external HTTP requests.", + ) + http_connect_timeout_seconds: float = Field( + default=5.0, + ge=0.0, + description="Maximum time in seconds to establish outbound external HTTP connections.", + ) + http_max_connections: int = Field( + default=10, + ge=1, + description="Maximum number of concurrent outbound HTTP connections.", + ) + http_keepalive_timeout_seconds: float = Field( + default=15.0, + ge=0.0, + description="How long idle keepalive connections are retained by the HTTP connector.", + ) timezone: str = Field( default="UTC", description="IANA timezone name used when displaying timestamps in the UI.", diff --git a/backend/app/services/blocklist_service.py b/backend/app/services/blocklist_service.py index 91003c5..07d6ca8 100644 --- a/backend/app/services/blocklist_service.py +++ b/backend/app/services/blocklist_service.py @@ -14,11 +14,13 @@ under the key ``"blocklist_schedule"``. from __future__ import annotations +import asyncio import importlib import json from collections.abc import Awaitable from typing import TYPE_CHECKING +import aiohttp import structlog from app.models.blocklist import ( @@ -57,6 +59,55 @@ _PREVIEW_LINES: int = 20 #: Maximum bytes to download for a preview (first 64 KB). _PREVIEW_MAX_BYTES: int = 65536 +#: HTTP status codes that should be retried for blocklist downloads. +_BLOCKLIST_HTTP_RETRY_STATUSES: frozenset[int] = frozenset({429, 500, 502, 503, 504}) +#: How many attempts to make for transient blocklist download failures. +_BLOCKLIST_HTTP_RETRY_ATTEMPTS: int = 2 +#: Base backoff in seconds used between retry attempts. +_BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS: float = 1.0 + + +async def _download_text_with_retries( + http_session: aiohttp.ClientSession, + url: str, + timeout: aiohttp.ClientTimeout, +) -> tuple[int, str]: + """Download text from *url* with a small retry policy for transient failures.""" + last_exception: Exception | None = None + + for attempt in range(1, _BLOCKLIST_HTTP_RETRY_ATTEMPTS + 1): + try: + async with http_session.get(url, timeout=timeout) as resp: + text = await resp.text(errors="replace") + if resp.status in _BLOCKLIST_HTTP_RETRY_STATUSES and attempt < _BLOCKLIST_HTTP_RETRY_ATTEMPTS: + backoff = _BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) + log.warning( + "blocklist_download_retry", + url=url, + status=resp.status, + attempt=attempt, + backoff=backoff, + ) + await asyncio.sleep(backoff) + continue + return resp.status, text + except Exception as exc: # noqa: BLE001 + last_exception = exc + if attempt >= _BLOCKLIST_HTTP_RETRY_ATTEMPTS: + raise + backoff = _BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) + log.warning( + "blocklist_download_retry_error", + url=url, + attempt=attempt, + error=repr(exc), + backoff=backoff, + ) + await asyncio.sleep(backoff) + + assert last_exception is not None + raise last_exception + # --------------------------------------------------------------------------- # Source CRUD helpers @@ -203,15 +254,18 @@ async def preview_source( ValueError: If the URL cannot be reached or returns a non-200 status. """ try: - async with http_session.get(url, timeout=_aiohttp_timeout(10)) as resp: - if resp.status != 200: - raise ValueError(f"HTTP {resp.status} from {url}") - raw = await resp.content.read(_PREVIEW_MAX_BYTES) + status, raw = await _download_text_with_retries( + http_session, + url, + _aiohttp_timeout(10), + ) + if status != 200: + raise ValueError(f"HTTP {status} from {url}") except Exception as exc: log.warning("blocklist_preview_failed", url=url, error=str(exc)) raise ValueError(str(exc)) from exc - lines = raw.decode(errors="replace").splitlines() + lines = raw.splitlines() entries: list[str] = [] valid = 0 skipped = 0 @@ -272,21 +326,22 @@ async def import_source( """ # --- Download --- try: - async with http_session.get( - source.url, timeout=_aiohttp_timeout(30) - ) as resp: - if resp.status != 200: - error_msg = f"HTTP {resp.status}" - await _log_result(db, source, 0, 0, error_msg) - log.warning("blocklist_import_download_failed", url=source.url, status=resp.status) - return ImportSourceResult( - source_id=source.id, - source_url=source.url, - ips_imported=0, - ips_skipped=0, - error=error_msg, - ) - content = await resp.text(errors="replace") + status, content = await _download_text_with_retries( + http_session, + source.url, + _aiohttp_timeout(30), + ) + if status != 200: + error_msg = f"HTTP {status}" + await _log_result(db, source, 0, 0, error_msg) + log.warning("blocklist_import_download_failed", url=source.url, status=status) + return ImportSourceResult( + source_id=source.id, + source_url=source.url, + ips_imported=0, + ips_skipped=0, + error=error_msg, + ) except Exception as exc: error_msg = str(exc) await _log_result(db, source, 0, 0, error_msg) diff --git a/backend/app/startup.py b/backend/app/startup.py index 1a76027..ebf908a 100644 --- a/backend/app/startup.py +++ b/backend/app/startup.py @@ -33,6 +33,22 @@ async def _ensure_database_schema(database_path: str) -> None: await db.close() +def _create_http_session(settings: Settings) -> aiohttp.ClientSession: + """Build a shared aiohttp session with reasonable global limits and timeouts.""" + timeout = aiohttp.ClientTimeout( + total=settings.http_request_timeout_seconds, + connect=settings.http_connect_timeout_seconds, + sock_read=settings.http_request_timeout_seconds, + ) + connector = aiohttp.TCPConnector( + limit=settings.http_max_connections, + limit_per_host=settings.http_max_connections, + keepalive_timeout=settings.http_keepalive_timeout_seconds, + enable_cleanup_closed=True, + ) + return aiohttp.ClientSession(timeout=timeout, connector=connector) + + async def startup_shared_resources( app: FastAPI, settings: Settings, @@ -82,7 +98,7 @@ async def startup_shared_resources( if unresolved_count > 0: log.warning("geo_cache_unresolved_ips", unresolved=unresolved_count) - http_session: aiohttp.ClientSession = aiohttp.ClientSession() + http_session: aiohttp.ClientSession = _create_http_session(settings) geo_service.init_geoip(settings.geoip_db_path) scheduler: AsyncIOScheduler = AsyncIOScheduler(timezone="UTC") diff --git a/backend/tests/test_main.py b/backend/tests/test_main.py index 9608bf0..95b96d8 100644 --- a/backend/tests/test_main.py +++ b/backend/tests/test_main.py @@ -123,6 +123,57 @@ async def test_lifespan_initialises_and_cleans_up_shared_resources(tmp_path: Pat mock_scheduler.shutdown.assert_called_once_with(wait=False) +async def test_http_session_is_created_with_configured_timeouts_and_limits(tmp_path: Path) -> None: + """The shared HTTP client session is created with the configured limits.""" + settings = Settings( + database_path=str(tmp_path / "bangui.db"), + fail2ban_socket="/tmp/fake_fail2ban.sock", + fail2ban_config_dir=str(tmp_path / "fail2ban"), + session_secret="test-lifespan-secret", + session_duration_minutes=60, + timezone="UTC", + log_level="debug", + http_request_timeout_seconds=12.5, + http_connect_timeout_seconds=1.5, + http_max_connections=5, + http_keepalive_timeout_seconds=8.0, + ) + app = create_app(settings=settings) + + mock_scheduler = MagicMock() + mock_scheduler.start = MagicMock() + mock_scheduler.shutdown = MagicMock() + + mock_http_session = MagicMock() + mock_http_session.close = AsyncMock() + + with ( + patch("app.startup.ensure_jail_configs"), + patch("app.startup.aiohttp.ClientSession", return_value=mock_http_session) as mock_client_session, + patch("app.startup.AsyncIOScheduler", return_value=mock_scheduler), + patch("app.startup.init_db", new=AsyncMock()), + patch("app.services.geo_service.init_geoip"), + patch("app.services.geo_service.load_cache_from_db", new=AsyncMock(return_value=None)), + patch("app.services.geo_service.count_unresolved", new=AsyncMock(return_value=0)), + patch("app.services.setup_service.is_setup_complete", new=AsyncMock(return_value=False)), + patch("app.tasks.health_check.register"), + patch("app.tasks.blocklist_import.register"), + patch("app.tasks.geo_cache_flush.register"), + patch("app.tasks.geo_re_resolve.register"), + patch("app.tasks.history_sync.register"), + ): + async with _lifespan(app): + assert mock_client_session.call_count == 1 + kwargs = mock_client_session.call_args.kwargs + timeout = kwargs["timeout"] + connector = kwargs["connector"] + assert timeout.total == 12.5 + assert timeout.connect == 1.5 + assert timeout.sock_read == 12.5 + assert connector.limit == 5 + assert connector.limit_per_host == 5 + + async def test_startup_overrides_settings_from_persisted_setup(tmp_path: Path) -> None: """Startup should replace env defaults with values persisted by setup.""" env_settings = Settings( diff --git a/backend/tests/test_services/test_blocklist_service.py b/backend/tests/test_services/test_blocklist_service.py index 674c554..0866163 100644 --- a/backend/tests/test_services/test_blocklist_service.py +++ b/backend/tests/test_services/test_blocklist_service.py @@ -125,6 +125,27 @@ class TestPreview: with pytest.raises(ValueError, match="HTTP 404"): await blocklist_service.preview_source("https://bad.test/", session) + async def test_preview_retries_transient_errors(self) -> None: + """preview_source retries transient network failures before succeeding.""" + content = "1.2.3.4\n" + mock_resp = AsyncMock() + mock_resp.status = 200 + mock_resp.text = AsyncMock(return_value=content) + mock_resp.content = AsyncMock() + mock_resp.content.read = AsyncMock(return_value=content.encode()) + + mock_ctx = AsyncMock() + mock_ctx.__aenter__.return_value = mock_resp + mock_ctx.__aexit__.return_value = False + + session = MagicMock() + session.get = MagicMock(side_effect=[Exception("connection reset"), mock_ctx]) + + result = await blocklist_service.preview_source("https://test.test/ips.txt", session) + + assert result.valid_count == 1 + assert session.get.call_count == 2 + async def test_preview_limits_entries(self) -> None: """preview_source caps entries to sample_lines.""" ips = "\n".join(f"1.2.3.{i}" for i in range(50))