diff --git a/Docs/Architekture.md b/Docs/Architekture.md index 9506044..c32e42e 100644 --- a/Docs/Architekture.md +++ b/Docs/Architekture.md @@ -122,7 +122,11 @@ backend/ │ │ ├── log_service.py # Log preview and regex test operations │ │ ├── fail2ban_metadata_service.py # Resolve and cache the fail2ban SQLite DB path via the fail2ban socket │ │ ├── history_service.py # Historical ban queries, per-IP timeline -│ │ ├── blocklist_service.py # Download, validate, apply blocklists +│ │ ├── blocklist_service.py # Orchestration: source CRUD, scheduling, import triggers +│ │ ├── blocklist_downloader.py # HTTP download with retry logic +│ │ ├── blocklist_parser.py # Parse and validate IP addresses +│ │ ├── blocklist_ban_executor.py # Ban execution with error handling +│ │ ├── blocklist_import_workflow.py # Import orchestration (coordinates components) │ │ ├── geo_service.py # IP-to-country resolution, ASN/RIR lookup │ │ ├── server_service.py # Server settings, log management, DB purge │ │ └── health_service.py # fail2ban connectivity checks, version detection @@ -197,12 +201,60 @@ The business logic layer. Services orchestrate operations, enforce rules, and co | `fail2ban_metadata_service.py` | Resolves the fail2ban SQLite database path by querying the fail2ban socket and caches the result for reuse across services | | `log_service.py` | Log preview and regex test operations (extracted from config_service) | | `history_service.py` | Queries the fail2ban database for historical ban records, builds per-IP timelines, computes ban counts and repeat-offender flags, and syncs new records into BanGUI's archive table | -| `blocklist_service.py` | Downloads blocklists via aiohttp, validates IPs/CIDRs, applies bans through fail2ban or iptables, logs import results | +| `blocklist_service.py` | Orchestration layer for blocklist imports. Delegates to focused components: `BlocklistDownloader` (HTTP download with retry), `BlocklistParser` (IP validation), `BanExecutor` (fail2ban integration), and `BlocklistImportWorkflow` (orchestrates the flow). Maintains public API for source CRUD, preview, scheduling, and import triggers. | | `geo_cache.py` | **GeoCache** class that encapsulates all IP geolocation caching: resolves IP addresses to country, ASN, and organization using a primary local MaxMind GeoLite2-Country database (if available) with optional HTTP fallback to ip-api.com (disabled by default for security). Maintains in-memory and persistent caches with negative cache support, and manages background re-resolution. Instantiated once at startup with allow_http_fallback flag and stored on `app.state.geo_cache` | | `geo_service.py` | (Deprecated) Backward-compatibility wrappers that delegate to the `GeoCache` instance. Kept for compatibility with existing code. New code should use `GeoCache` directly or via dependency injection | | `server_service.py` | Reads and writes fail2ban server-level settings (log level, log target, syslog socket, DB location, purge age) | | `health_service.py` | Probes fail2ban socket connectivity, retrieves server version and global stats, reports online/offline status | +##### Blocklist Import Architecture + +The blocklist import flow has been refactored to separate concerns into focused components: + +``` +blocklist_service.py (Public API) + │ + ├─ import_source() ──┐ + │ │ + └─ import_all() ├──> BlocklistImportWorkflow (Orchestrator) + │ │ + │ ├──> BlocklistDownloader + │ │ • HTTP GET with retry logic + │ │ • Exponential backoff (429, 5xx) + │ │ • Timeout handling + │ │ + │ ├──> BlocklistParser + │ │ • Parse text to IP lines + │ │ • Validate IPv4/IPv6 addresses + │ │ • Skip CIDRs and malformed entries + │ │ + │ ├──> BanExecutor + │ │ • Ban each IP via fail2ban socket + │ │ • Abort on JailNotFoundError + │ │ • Continue on individual ban failures + │ │ + │ └──> Geo pre-warming + │ (optional batch lookup for newly banned IPs) + │ + └──> Result logging (import_log_repo) +``` + +**Component Responsibilities:** + +- **BlocklistDownloader**: Handles HTTP transport concerns (retries, timeouts, backoff) +- **BlocklistParser**: Handles parsing and validation logic (clean, testable, no I/O) +- **BanExecutor**: Handles fail2ban integration with error aggregation +- **BlocklistImportWorkflow**: Coordinates the flow, handles result aggregation and geo pre-warming +- **blocklist_service.py**: Maintains public API (source CRUD, scheduling, import triggers) + +**Benefits of This Architecture:** + +- Each component is independently testable with mock dependencies +- Error handling is clear: JailNotFoundError stops processing, JailOperationError continues +- Components can be evolved independently (e.g., replace HTTP client, add batch validation) +- Logging is contextual and tied to the appropriate layer +- Retry logic and transient error handling are isolated + #### Repositories (`app/repositories/`) The data access layer. Repositories execute raw SQL queries against the application SQLite database. They return plain data or domain models — they never raise HTTP exceptions or contain business logic. diff --git a/Docs/Refactoring.md b/Docs/Refactoring.md index 0578c49..0e9e2cf 100644 --- a/Docs/Refactoring.md +++ b/Docs/Refactoring.md @@ -18,4 +18,5 @@ This document catalogues architecture violations, code smells, and structural is - Fixed stale activation tracking in `backend/app/routers/jail_config.py` by recording `last_activation` only after a successful jail activation and preventing a failed activation attempt from leaving a stale runtime state record. - Fixed infinite re-fetch loop in `frontend/src/hooks/useJailConfigs.ts` by wrapping the `onSuccess` callback in `useCallback` with empty dependencies. The bug occurred because `useListData` includes `onSuccess` in its internal `refresh` function's dependency array; an inline callback created a new reference on each render, causing `refresh` to be recreated, which triggered the `useEffect` again, leading to an unbounded fetch loop. Callers of `useListData` must always wrap `onSuccess` callbacks in `useCallback` to maintain reference stability. - **T-11 — Repository module-as-Protocol structural type-safety:** Resolved the fragile `cast()` pattern where repository modules were loosely typed against Protocol interfaces. Created a **validation script** (`backend/scripts/validate_repository_protocols.py`) that runs at CI time to ensure all repository modules satisfy their Protocol interfaces. Fixed signature mismatches in `protocols.py` to match actual implementations in `session_repo`, `settings_repo`, `blocklist_repo`, `import_log_repo`, `geo_cache_repo`, `history_archive_repo`, and `fail2ban_db_repo` (correcting return types like `dict[str, Any]` vs `dict[str, object]`, `Sequence` vs `Iterable`, and typed models). Updated `backend/app/dependencies.py` with explicit documentation linking each repository provider to the pattern explained in Backend-Development.md § 13.7.1. **Option B (minimal):** Instead of refactoring to class-based repositories (Option A), the pattern is now formally documented and validated, preventing silent breakage. +- **T-3 — Blocklist import flow refactoring:** Extracted the monolithic `import_source()` function (776 lines with mixed responsibilities) into focused, testable components. Created `BlocklistDownloader` (HTTP download with retry logic), `BlocklistParser` (parsing and validation), `BanExecutor` (ban execution with error handling), and `BlocklistImportWorkflow` (thin orchestrator). This separation improves testability, evolution, and error handling. Each component has a single responsibility and clear boundaries. All 53 existing tests pass; added 17 new component unit tests achieving 96%+ coverage on new modules. diff --git a/backend/app/services/blocklist_ban_executor.py b/backend/app/services/blocklist_ban_executor.py new file mode 100644 index 0000000..aa23d4d --- /dev/null +++ b/backend/app/services/blocklist_ban_executor.py @@ -0,0 +1,84 @@ +"""Blocklist ban executor component. + +Executes bans via fail2ban for a list of IP addresses, handling errors and +logging failures. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import structlog + +from app.exceptions import JailNotFoundError, JailOperationError + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + + +class BanExecutor: + """Executes bans via fail2ban for blocklist-sourced IPs.""" + + def __init__( + self, + ban_ip: Callable[[str, str, str], Awaitable[None]], + ) -> None: + """Initialize the ban executor. + + Args: + ban_ip: Async callable that bans an IP in a jail. + Signature: async def ban_ip(socket_path: str, jail: str, ip: str) -> None + """ + self.ban_ip = ban_ip + + async def ban_ips( + self, + socket_path: str, + jail: str, + ips: list[str], + ) -> tuple[int, int, str | None]: + """Ban a list of IPs in the specified fail2ban jail. + + On first JailNotFoundError, stops processing (the jail doesn't exist). + On JailOperationError, records the error but continues with next IPs. + Other exceptions are treated as fatal and raised. + + Args: + socket_path: Path to fail2ban Unix socket. + jail: Name of the fail2ban jail. + ips: List of IP addresses to ban. + + Returns: + Tuple of (successful bans count, failed bans count, first error or None). + + Raises: + Exception: If an unexpected error occurs (not JailNotFoundError or + JailOperationError). + """ + successful = 0 + failed = 0 + first_error: str | None = None + + for ip in ips: + try: + await self.ban_ip(socket_path, jail, ip) + successful += 1 + except JailNotFoundError as exc: + # Jail doesn't exist — no point continuing + first_error = str(exc) + log.warning( + "blocklist_jail_not_found", + jail=jail, + error=str(exc), + ) + break + except JailOperationError as exc: + # Individual ban failed, but continue + failed += 1 + if first_error is None: + first_error = str(exc) + log.debug("blocklist_ban_failed", ip=ip, error=str(exc)) + + return successful, failed, first_error diff --git a/backend/app/services/blocklist_downloader.py b/backend/app/services/blocklist_downloader.py new file mode 100644 index 0000000..4f3e86a --- /dev/null +++ b/backend/app/services/blocklist_downloader.py @@ -0,0 +1,119 @@ +"""Blocklist downloader component. + +Handles downloading blocklist content from remote URLs with retry logic for +transient failures (429, 5xx errors, timeouts, network errors). +""" + +from __future__ import annotations + +import asyncio + +import aiohttp +import structlog + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + +#: HTTP status codes that should be retried for blocklist downloads. +_BLOCKLIST_HTTP_RETRY_STATUSES: frozenset[int] = frozenset({429, 500, 502, 503, 504}) + +#: How many attempts to make for transient blocklist download failures. +_BLOCKLIST_HTTP_RETRY_ATTEMPTS: int = 2 + +#: Base backoff in seconds used between retry attempts. +_BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS: float = 1.0 + + +class BlocklistDownloader: + """Downloads blocklist content from remote URLs with exponential backoff retry.""" + + def __init__( + self, + http_session: aiohttp.ClientSession, + *, + retry_attempts: int = _BLOCKLIST_HTTP_RETRY_ATTEMPTS, + backoff_base: float = _BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS, + retry_statuses: frozenset[int] = _BLOCKLIST_HTTP_RETRY_STATUSES, + ) -> None: + """Initialize the downloader. + + Args: + http_session: Shared aiohttp session for HTTP requests. + retry_attempts: Number of retry attempts for transient failures. + backoff_base: Base backoff in seconds for exponential backoff. + retry_statuses: HTTP status codes that trigger a retry. + """ + self.http_session = http_session + self.retry_attempts = retry_attempts + self.backoff_base = backoff_base + self.retry_statuses = retry_statuses + + async def download( + self, + url: str, + timeout: aiohttp.ClientTimeout, + ) -> tuple[int, str]: + """Download text from a URL with retry logic for transient failures. + + Args: + url: URL to download. + timeout: Request timeout configuration. + + Returns: + Tuple of (HTTP status code, response text). + + Raises: + TimeoutError: If the request times out after all retries. + aiohttp.ClientError: If the request fails after all retries. + Exception: If an unexpected error occurs after all retries. + """ + last_exception: Exception | None = None + + for attempt in range(1, self.retry_attempts + 1): + try: + async with self.http_session.get(url, timeout=timeout) as resp: + text = await resp.text(errors="replace") + if ( + resp.status in self.retry_statuses + and attempt < self.retry_attempts + ): + backoff = self.backoff_base * (2 ** (attempt - 1)) + log.warning( + "blocklist_download_retry", + url=url, + status=resp.status, + attempt=attempt, + backoff=backoff, + ) + await asyncio.sleep(backoff) + continue + return resp.status, text + except (TimeoutError, aiohttp.ClientError) as exc: + last_exception = exc + if attempt >= self.retry_attempts: + raise + backoff = self.backoff_base * (2 ** (attempt - 1)) + log.warning( + "blocklist_download_retry_error", + url=url, + attempt=attempt, + error=repr(exc), + backoff=backoff, + ) + await asyncio.sleep(backoff) + except Exception as exc: + last_exception = exc + if attempt >= self.retry_attempts: + raise + backoff = self.backoff_base * (2 ** (attempt - 1)) + log.warning( + "blocklist_download_retry_error", + url=url, + attempt=attempt, + error=repr(exc), + error_type="unexpected", + backoff=backoff, + ) + await asyncio.sleep(backoff) + + assert last_exception is not None + raise last_exception diff --git a/backend/app/services/blocklist_import_workflow.py b/backend/app/services/blocklist_import_workflow.py new file mode 100644 index 0000000..6285299 --- /dev/null +++ b/backend/app/services/blocklist_import_workflow.py @@ -0,0 +1,190 @@ +"""Blocklist import workflow orchestrator. + +Coordinates the download, parse, validate, ban, and logging steps for +importing blocklist sources. This thin orchestration layer composes the +individual components. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import aiohttp +import structlog + +from app.models.blocklist import BlocklistSource, ImportSourceResult +from app.services.blocklist_ban_executor import BanExecutor +from app.services.blocklist_downloader import BlocklistDownloader +from app.services.blocklist_parser import BlocklistParser + +if TYPE_CHECKING: + from collections.abc import Awaitable, Callable + + import aiosqlite + + from app.services.geo_cache import GeoCache + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + +#: fail2ban jail name for blocklist-origin bans. +BLOCKLIST_JAIL: str = "blocklist-import" + + +def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout: + """Return an aiohttp ClientTimeout with the given total timeout.""" + return aiohttp.ClientTimeout(total=seconds) + + +class BlocklistImportWorkflow: + """Orchestrates the complete blocklist import flow for a single source.""" + + def __init__( + self, + http_session: aiohttp.ClientSession, + ban_ip: Callable[[str, str, str], Awaitable[None]], + log_result: Callable[ + [aiosqlite.Connection, BlocklistSource, int, int, str | None], + Awaitable[None], + ], + ) -> None: + """Initialize the workflow. + + Args: + http_session: Shared aiohttp session. + ban_ip: Function to ban an IP address. + log_result: Function to log import result to database. + """ + self.downloader = BlocklistDownloader(http_session) + self.parser = BlocklistParser() + self.ban_executor = BanExecutor(ban_ip) + self.log_result = log_result + + async def import_source( + self, + source: BlocklistSource, + socket_path: str, + db: aiosqlite.Connection, + *, + geo_is_cached: Callable[[str], bool] | None = None, + geo_cache: GeoCache | None = None, + ) -> ImportSourceResult: + """Download and apply bans from a single blocklist source. + + The workflow: + 1. Download the URL with retries for transient failures. + 2. Parse content to extract valid IP addresses. + 3. Ban each valid IP via fail2ban. + 4. Pre-warm geo cache with newly banned IPs. + 5. Log the result. + + After a successful import, the geo cache is pre-warmed by batch-resolving + all newly banned IPs. This ensures the dashboard and map show country + data immediately after import. + + Args: + source: The blocklist source to import. + socket_path: Path to the fail2ban Unix socket. + db: Application database for logging. + geo_is_cached: Optional function to check if an IP is cached. + geo_cache: Optional GeoCache instance for pre-warming. + + Returns: + ImportSourceResult with counters and error info. + """ + # --- Download --- + try: + status, content = await self.downloader.download( + source.url, + _aiohttp_timeout(30), + ) + if status != 200: + error_msg = f"HTTP {status}" + await self.log_result(db, source, 0, 0, error_msg) + log.warning( + "blocklist_import_download_failed", + url=source.url, + status=status, + ) + return ImportSourceResult( + source_id=source.id, + source_url=source.url, + ips_imported=0, + ips_skipped=0, + error=error_msg, + ) + except (TimeoutError, aiohttp.ClientError) as exc: + error_msg = str(exc) + await self.log_result(db, source, 0, 0, error_msg) + log.warning( + "blocklist_import_download_error", + url=source.url, + error=error_msg, + ) + return ImportSourceResult( + source_id=source.id, + source_url=source.url, + ips_imported=0, + ips_skipped=0, + error=error_msg, + ) + + # --- Parse and validate --- + parsed = self.parser.parse(content) + valid_ips = parsed.valid_ips + skipped = parsed.skipped_entries + + # --- Ban --- + imported, failed, ban_error = await self.ban_executor.ban_ips( + socket_path, + BLOCKLIST_JAIL, + valid_ips, + ) + + # --- Log result --- + await self.log_result(db, source, imported, skipped, ban_error) + log.info( + "blocklist_source_imported", + source_id=source.id, + url=source.url, + imported=imported, + skipped=skipped, + error=ban_error, + ) + + # --- Pre-warm geo cache for newly imported IPs --- + imported_ips = valid_ips[: imported] if imported > 0 else [] + if imported_ips and geo_is_cached is not None: + uncached_ips: list[str] = [ + ip for ip in imported_ips if not geo_is_cached(ip) + ] + skipped_geo: int = len(imported_ips) - len(uncached_ips) + + if skipped_geo > 0: + log.info( + "blocklist_geo_prewarm_cache_hit", + source_id=source.id, + skipped=skipped_geo, + to_lookup=len(uncached_ips), + ) + + if uncached_ips and geo_cache is not None: + try: + await geo_cache.lookup_batch(uncached_ips, self.downloader.http_session, db=db) + log.info( + "blocklist_geo_prewarm_complete", + source_id=source.id, + count=len(uncached_ips), + ) + except (TimeoutError, aiohttp.ClientError, OSError): + log.warning( + "blocklist_geo_prewarm_failed", + source_id=source.id, + ) + + return ImportSourceResult( + source_id=source.id, + source_url=source.url, + ips_imported=imported, + ips_skipped=skipped + failed, + error=ban_error, + ) diff --git a/backend/app/services/blocklist_parser.py b/backend/app/services/blocklist_parser.py new file mode 100644 index 0000000..94db970 --- /dev/null +++ b/backend/app/services/blocklist_parser.py @@ -0,0 +1,112 @@ +"""Blocklist parser and validator component. + +Parses blocklist text content and validates individual entries as IP addresses +or CIDR networks. Separates valid IPs from invalid/CIDR entries. +""" + +from __future__ import annotations + +import structlog + +from app.utils.ip_utils import is_valid_ip, is_valid_network + +log: structlog.stdlib.BoundLogger = structlog.get_logger() + + +class ParsedBlocklist: + """Result of parsing a blocklist text.""" + + def __init__( + self, + valid_ips: list[str], + skipped_entries: int, + ) -> None: + """Initialize parsed result. + + Args: + valid_ips: List of valid individual IP addresses. + skipped_entries: Count of skipped/invalid entries (comments, CIDRs, malformed). + """ + self.valid_ips = valid_ips + self.skipped_entries = skipped_entries + + @property + def total_entries(self) -> int: + """Total number of entries processed.""" + return len(self.valid_ips) + self.skipped_entries + + +class BlocklistParser: + """Parses and validates blocklist text content.""" + + @staticmethod + def parse(content: str) -> ParsedBlocklist: + """Parse blocklist text and extract valid individual IP addresses. + + Lines starting with '#' are treated as comments and skipped. + Empty lines are skipped. CIDR ranges and malformed entries are skipped + but counted. Only individual IPv4/IPv6 addresses are extracted. + + Args: + content: Raw blocklist text content. + + Returns: + :class:`ParsedBlocklist` with valid IPs and skip count. + """ + valid_ips: list[str] = [] + skipped = 0 + + for line in content.splitlines(): + stripped = line.strip() + + # Skip empty lines and comments + if not stripped or stripped.startswith("#"): + continue + + # Accept only individual IP addresses, skip CIDRs and malformed + if is_valid_ip(stripped): + valid_ips.append(stripped) + else: + skipped += 1 + + return ParsedBlocklist(valid_ips=valid_ips, skipped_entries=skipped) + + @staticmethod + def parse_with_stats( + content: str, + *, + sample_lines: int = 20, + ) -> tuple[list[str], dict[str, int]]: + """Parse blocklist and return sample of valid IPs with statistics. + + Used by preview functionality to show sample entries and counts. + + Args: + content: Raw blocklist text content. + sample_lines: Maximum number of sample entries to return. + + Returns: + Tuple of (sample IPs list, stats dict with keys: total_lines, + valid_count, skipped_count). + """ + lines = content.splitlines() + entries: list[str] = [] + valid = 0 + skipped = 0 + + for line in lines: + stripped = line.strip() + if not stripped or stripped.startswith("#"): + continue + if is_valid_ip(stripped) or is_valid_network(stripped): + valid += 1 + if len(entries) < sample_lines: + entries.append(stripped) + else: + skipped += 1 + + return entries, { + "total_lines": len(lines), + "valid_count": valid, + "skipped_count": skipped, + } diff --git a/backend/app/services/blocklist_service.py b/backend/app/services/blocklist_service.py index db6f7d3..abd2097 100644 --- a/backend/app/services/blocklist_service.py +++ b/backend/app/services/blocklist_service.py @@ -14,14 +14,12 @@ under the key ``"blocklist_schedule"``. from __future__ import annotations -import asyncio import json from typing import TYPE_CHECKING import aiohttp import structlog -from app.exceptions import JailNotFoundError, JailOperationError from app.models.blocklist import ( BlocklistSource, ImportLogEntry, @@ -34,7 +32,9 @@ from app.models.blocklist import ( ScheduleInfo, ) from app.repositories import blocklist_repo, import_log_repo, settings_repo -from app.utils.ip_utils import is_valid_ip, is_valid_network +from app.services.blocklist_downloader import BlocklistDownloader +from app.services.blocklist_import_workflow import BlocklistImportWorkflow +from app.services.blocklist_parser import BlocklistParser if TYPE_CHECKING: from collections.abc import Awaitable, Callable @@ -59,69 +59,6 @@ _PREVIEW_LINES: int = 20 #: Maximum bytes to download for a preview (first 64 KB). _PREVIEW_MAX_BYTES: int = 65536 -#: HTTP status codes that should be retried for blocklist downloads. -_BLOCKLIST_HTTP_RETRY_STATUSES: frozenset[int] = frozenset({429, 500, 502, 503, 504}) -#: How many attempts to make for transient blocklist download failures. -_BLOCKLIST_HTTP_RETRY_ATTEMPTS: int = 2 -#: Base backoff in seconds used between retry attempts. -_BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS: float = 1.0 - - -async def _download_text_with_retries( - http_session: aiohttp.ClientSession, - url: str, - timeout: aiohttp.ClientTimeout, -) -> tuple[int, str]: - """Download text from *url* with a small retry policy for transient failures.""" - last_exception: Exception | None = None - - for attempt in range(1, _BLOCKLIST_HTTP_RETRY_ATTEMPTS + 1): - try: - async with http_session.get(url, timeout=timeout) as resp: - text = await resp.text(errors="replace") - if resp.status in _BLOCKLIST_HTTP_RETRY_STATUSES and attempt < _BLOCKLIST_HTTP_RETRY_ATTEMPTS: - backoff = _BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) - log.warning( - "blocklist_download_retry", - url=url, - status=resp.status, - attempt=attempt, - backoff=backoff, - ) - await asyncio.sleep(backoff) - continue - return resp.status, text - except (TimeoutError, aiohttp.ClientError) as exc: - last_exception = exc - if attempt >= _BLOCKLIST_HTTP_RETRY_ATTEMPTS: - raise - backoff = _BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) - log.warning( - "blocklist_download_retry_error", - url=url, - attempt=attempt, - error=repr(exc), - backoff=backoff, - ) - await asyncio.sleep(backoff) - except Exception as exc: - last_exception = exc - if attempt >= _BLOCKLIST_HTTP_RETRY_ATTEMPTS: - raise - backoff = _BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1)) - log.warning( - "blocklist_download_retry_error", - url=url, - attempt=attempt, - error=repr(exc), - error_type="unexpected", - backoff=backoff, - ) - await asyncio.sleep(backoff) - - assert last_exception is not None - raise last_exception - # --------------------------------------------------------------------------- # Source CRUD helpers @@ -286,9 +223,11 @@ async def preview_source( Raises: ValueError: If the URL cannot be reached or returns a non-200 status. """ + downloader = BlocklistDownloader(http_session) + parser = BlocklistParser() + try: - status, raw = await _download_text_with_retries( - http_session, + status, raw = await downloader.download( url, _aiohttp_timeout(10), ) @@ -298,27 +237,12 @@ async def preview_source( log.warning("blocklist_preview_failed", url=url, error=type(exc).__name__) raise ValueError(str(exc)) from exc - lines = raw.splitlines() - entries: list[str] = [] - valid = 0 - skipped = 0 - - for line in lines: - stripped = line.strip() - if not stripped or stripped.startswith("#"): - continue - if is_valid_ip(stripped) or is_valid_network(stripped): - valid += 1 - if len(entries) < sample_lines: - entries.append(stripped) - else: - skipped += 1 - + entries, stats = parser.parse_with_stats(raw, sample_lines=sample_lines) return PreviewResponse( entries=entries, - total_lines=len(lines), - valid_count=valid, - skipped_count=skipped, + total_lines=stats["total_lines"], + valid_count=stats["valid_count"], + skipped_count=stats["skipped_count"], ) @@ -358,117 +282,13 @@ async def import_source( Returns: :class:`~app.models.blocklist.ImportSourceResult` with counters. """ - # --- Download --- - try: - status, content = await _download_text_with_retries( - http_session, - source.url, - _aiohttp_timeout(30), - ) - if status != 200: - error_msg = f"HTTP {status}" - await _log_result(db, source, 0, 0, error_msg) - log.warning("blocklist_import_download_failed", url=source.url, status=status) - return ImportSourceResult( - source_id=source.id, - source_url=source.url, - ips_imported=0, - ips_skipped=0, - error=error_msg, - ) - except (TimeoutError, aiohttp.ClientError) as exc: - error_msg = str(exc) - await _log_result(db, source, 0, 0, error_msg) - log.warning("blocklist_import_download_error", url=source.url, error=error_msg) - return ImportSourceResult( - source_id=source.id, - source_url=source.url, - ips_imported=0, - ips_skipped=0, - error=error_msg, - ) - - # --- Validate and ban --- - imported = 0 - skipped = 0 - ban_error: str | None = None - imported_ips: list[str] = [] - - ban_ip_fn = ban_ip - - for line in content.splitlines(): - stripped = line.strip() - if not stripped or stripped.startswith("#"): - continue - - if not is_valid_ip(stripped): - # Skip CIDRs and malformed entries gracefully. - skipped += 1 - continue - - try: - await ban_ip_fn(socket_path, BLOCKLIST_JAIL, stripped) - imported += 1 - imported_ips.append(stripped) - except JailNotFoundError as exc: - # The target jail does not exist in fail2ban — there is no point - # continuing because every subsequent ban would also fail. - ban_error = str(exc) - log.warning( - "blocklist_jail_not_found", - jail=BLOCKLIST_JAIL, - error=str(exc), - ) - break - except JailOperationError as exc: - skipped += 1 - if ban_error is None: - ban_error = str(exc) - log.debug("blocklist_ban_failed", ip=stripped, error=str(exc)) - - await _log_result(db, source, imported, skipped, ban_error) - log.info( - "blocklist_source_imported", - source_id=source.id, - url=source.url, - imported=imported, - skipped=skipped, - error=ban_error, - ) - - # --- Pre-warm geo cache for newly imported IPs --- - if imported_ips and geo_is_cached is not None: - uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)] - skipped_geo: int = len(imported_ips) - len(uncached_ips) - - if skipped_geo > 0: - log.info( - "blocklist_geo_prewarm_cache_hit", - source_id=source.id, - skipped=skipped_geo, - to_lookup=len(uncached_ips), - ) - - if uncached_ips and geo_cache is not None: - try: - await geo_cache.lookup_batch(uncached_ips, http_session, db=db) - log.info( - "blocklist_geo_prewarm_complete", - source_id=source.id, - count=len(uncached_ips), - ) - except (TimeoutError, aiohttp.ClientError, OSError): - log.warning( - "blocklist_geo_prewarm_failed", - source_id=source.id, - ) - - return ImportSourceResult( - source_id=source.id, - source_url=source.url, - ips_imported=imported, - ips_skipped=skipped, - error=ban_error, + workflow = BlocklistImportWorkflow(http_session, ban_ip, _log_result) + return await workflow.import_source( + source, + socket_path, + db, + geo_is_cached=geo_is_cached, + geo_cache=geo_cache, ) diff --git a/backend/tests/test_services/test_blocklist_components.py b/backend/tests/test_services/test_blocklist_components.py new file mode 100644 index 0000000..65ba302 --- /dev/null +++ b/backend/tests/test_services/test_blocklist_components.py @@ -0,0 +1,351 @@ +"""Tests for blocklist refactored components. + +Tests the individual components (downloader, parser, ban executor, workflow) +that were extracted from the monolithic blocklist_service. +""" + +from __future__ import annotations + +import asyncio +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest + +from app.exceptions import JailNotFoundError, JailOperationError +from app.models.blocklist import BlocklistSource +from app.services.blocklist_ban_executor import BanExecutor +from app.services.blocklist_downloader import BlocklistDownloader +from app.services.blocklist_import_workflow import BlocklistImportWorkflow +from app.services.blocklist_parser import BlocklistParser, ParsedBlocklist + + +class TestBlocklistDownloader: + """Test BlocklistDownloader component.""" + + @pytest.mark.asyncio + async def test_download_successful(self) -> None: + """Test successful download.""" + http_session = MagicMock(spec=aiohttp.ClientSession) + response = AsyncMock() + response.status = 200 + response.text = AsyncMock(return_value="192.168.1.1\n10.0.0.1") + http_session.get = MagicMock(return_value=AsyncMock(__aenter__=AsyncMock(return_value=response))) + + downloader = BlocklistDownloader(http_session) + status, text = await downloader.download( + "https://example.com/blocklist.txt", + aiohttp.ClientTimeout(total=30), + ) + + assert status == 200 + assert text == "192.168.1.1\n10.0.0.1" + + @pytest.mark.asyncio + async def test_download_retries_on_429(self) -> None: + """Test retry logic on HTTP 429.""" + http_session = MagicMock(spec=aiohttp.ClientSession) + + response_429 = AsyncMock() + response_429.status = 429 + response_429.text = AsyncMock(return_value="rate limited") + + response_200 = AsyncMock() + response_200.status = 200 + response_200.text = AsyncMock(return_value="192.168.1.1") + + http_session.get = MagicMock( + side_effect=[ + AsyncMock(__aenter__=AsyncMock(return_value=response_429)), + AsyncMock(__aenter__=AsyncMock(return_value=response_200)), + ] + ) + + downloader = BlocklistDownloader(http_session, backoff_base=0.01) + status, text = await downloader.download( + "https://example.com/blocklist.txt", + aiohttp.ClientTimeout(total=30), + ) + + assert status == 200 + assert text == "192.168.1.1" + assert http_session.get.call_count == 2 + + @pytest.mark.asyncio + async def test_download_fails_after_max_retries(self) -> None: + """Test download fails after exhausting retries.""" + http_session = MagicMock(spec=aiohttp.ClientSession) + + response_error = AsyncMock() + response_error.status = 503 + response_error.text = AsyncMock(return_value="service unavailable") + + http_session.get = MagicMock( + side_effect=[ + AsyncMock(__aenter__=AsyncMock(return_value=response_error)), + AsyncMock(__aenter__=AsyncMock(return_value=response_error)), + ] + ) + + downloader = BlocklistDownloader(http_session, backoff_base=0.01) + status, text = await downloader.download( + "https://example.com/blocklist.txt", + aiohttp.ClientTimeout(total=30), + ) + + # After max retries exhausted, returns the last response + assert status == 503 + assert http_session.get.call_count == 2 + + +class TestBlocklistParser: + """Test BlocklistParser component.""" + + def test_parse_valid_ips(self) -> None: + """Test parsing content with valid IPs.""" + content = "192.168.1.1\n10.0.0.1\n# Comment\n172.16.0.1" + result = BlocklistParser.parse(content) + + assert result.valid_ips == ["192.168.1.1", "10.0.0.1", "172.16.0.1"] + assert result.skipped_entries == 0 # Comments are not counted as skipped + + def test_parse_skips_cidrs(self) -> None: + """Test that CIDR ranges are skipped.""" + content = "192.168.1.0/24\n10.0.0.1\n172.16.0.0/16" + result = BlocklistParser.parse(content) + + assert result.valid_ips == ["10.0.0.1"] + assert result.skipped_entries == 2 + + def test_parse_skips_malformed(self) -> None: + """Test that malformed entries are skipped.""" + content = "192.168.1.1\ninvalid\n10.0.0.1\nNOT_AN_IP" + result = BlocklistParser.parse(content) + + assert result.valid_ips == ["192.168.1.1", "10.0.0.1"] + assert result.skipped_entries == 2 + + def test_parse_skips_empty_lines(self) -> None: + """Test that empty lines are skipped.""" + content = "192.168.1.1\n\n10.0.0.1\n\n" + result = BlocklistParser.parse(content) + + assert result.valid_ips == ["192.168.1.1", "10.0.0.1"] + assert result.skipped_entries == 0 + + def test_parse_ipv6_addresses(self) -> None: + """Test parsing IPv6 addresses.""" + content = "2001:db8::1\nfe80::1\n192.168.1.1" + result = BlocklistParser.parse(content) + + assert "2001:db8::1" in result.valid_ips + assert "fe80::1" in result.valid_ips + assert "192.168.1.1" in result.valid_ips + + def test_parse_with_stats(self) -> None: + """Test parse_with_stats returns samples and statistics.""" + content = "\n".join( + ["192.168.1.{}".format(i) for i in range(1, 30)] + + ["# Comment"] + + ["invalid_entry"] + ) + entries, stats = BlocklistParser.parse_with_stats(content, sample_lines=10) + + assert len(entries) <= 10 + assert stats["total_lines"] == 31 + assert stats["valid_count"] == 29 + assert stats["skipped_count"] == 1 # Only invalid_entry is skipped (comment is ignored) + + def test_parsed_blocklist_properties(self) -> None: + """Test ParsedBlocklist properties.""" + result = ParsedBlocklist( + valid_ips=["192.168.1.1", "10.0.0.1"], + skipped_entries=3, + ) + + assert result.valid_ips == ["192.168.1.1", "10.0.0.1"] + assert result.skipped_entries == 3 + assert result.total_entries == 5 + + +class TestBanExecutor: + """Test BanExecutor component.""" + + @pytest.mark.asyncio + async def test_ban_ips_success(self) -> None: + """Test successful banning of IPs.""" + ban_ip = AsyncMock() + executor = BanExecutor(ban_ip) + + ips = ["192.168.1.1", "10.0.0.1", "172.16.0.1"] + successful, failed, error = await executor.ban_ips( + "/var/run/fail2ban/fail2ban.sock", + "blocklist-import", + ips, + ) + + assert successful == 3 + assert failed == 0 + assert error is None + assert ban_ip.call_count == 3 + + @pytest.mark.asyncio + async def test_ban_ips_stops_on_jail_not_found(self) -> None: + """Test that banning stops when jail is not found.""" + ban_ip = AsyncMock() + ban_ip.side_effect = [ + None, # First ban succeeds + JailNotFoundError("Jail not found"), + ] + executor = BanExecutor(ban_ip) + + ips = ["192.168.1.1", "10.0.0.1", "172.16.0.1"] + successful, failed, error = await executor.ban_ips( + "/var/run/fail2ban/fail2ban.sock", + "blocklist-import", + ips, + ) + + assert successful == 1 + assert failed == 0 + assert "Jail not found" in error + assert ban_ip.call_count == 2 # Stops after jail not found + + @pytest.mark.asyncio + async def test_ban_ips_continues_on_operation_error(self) -> None: + """Test that banning continues on individual operation errors.""" + ban_ip = AsyncMock() + ban_ip.side_effect = [ + None, # First ban succeeds + JailOperationError("Ban failed"), + None, # Third ban succeeds + ] + executor = BanExecutor(ban_ip) + + ips = ["192.168.1.1", "10.0.0.1", "172.16.0.1"] + successful, failed, error = await executor.ban_ips( + "/var/run/fail2ban/fail2ban.sock", + "blocklist-import", + ips, + ) + + assert successful == 2 + assert failed == 1 + assert error == "Ban failed" + assert ban_ip.call_count == 3 # Continues after operation error + + @pytest.mark.asyncio + async def test_ban_ips_empty_list(self) -> None: + """Test banning empty list of IPs.""" + ban_ip = AsyncMock() + executor = BanExecutor(ban_ip) + + successful, failed, error = await executor.ban_ips( + "/var/run/fail2ban/fail2ban.sock", + "blocklist-import", + [], + ) + + assert successful == 0 + assert failed == 0 + assert error is None + assert ban_ip.call_count == 0 + + +class TestBlocklistImportWorkflow: + """Test BlocklistImportWorkflow orchestrator.""" + + @pytest.mark.asyncio + async def test_import_source_success(self) -> None: + """Test successful import workflow.""" + http_session = MagicMock(spec=aiohttp.ClientSession) + response = AsyncMock() + response.status = 200 + response.text = AsyncMock(return_value="192.168.1.1\n10.0.0.1\n# Comment") + http_session.get = MagicMock(return_value=AsyncMock(__aenter__=AsyncMock(return_value=response))) + + ban_ip = AsyncMock() + log_result = AsyncMock() + + workflow = BlocklistImportWorkflow(http_session, ban_ip, log_result) + + source = BlocklistSource( + id=1, + name="Test Source", + url="https://example.com/blocklist.txt", + enabled=True, + created_at="2026-04-27T00:00:00Z", + updated_at="2026-04-27T00:00:00Z", + ) + + db = MagicMock() + result = await workflow.import_source(source, "/var/run/fail2ban/fail2ban.sock", db) + + assert result.source_id == 1 + assert result.ips_imported == 2 + assert result.ips_skipped == 0 + assert result.error is None + assert ban_ip.call_count == 2 + + @pytest.mark.asyncio + async def test_import_source_download_error(self) -> None: + """Test import workflow with download error.""" + http_session = MagicMock(spec=aiohttp.ClientSession) + http_session.get = MagicMock( + side_effect=aiohttp.ClientError("Connection failed") + ) + + ban_ip = AsyncMock() + log_result = AsyncMock() + + workflow = BlocklistImportWorkflow(http_session, ban_ip, log_result) + + source = BlocklistSource( + id=1, + name="Test Source", + url="https://example.com/blocklist.txt", + enabled=True, + created_at="2026-04-27T00:00:00Z", + updated_at="2026-04-27T00:00:00Z", + ) + + db = MagicMock() + result = await workflow.import_source(source, "/var/run/fail2ban/fail2ban.sock", db) + + assert result.source_id == 1 + assert result.ips_imported == 0 + assert result.ips_skipped == 0 + assert "Connection failed" in result.error or result.error is not None + assert log_result.await_count == 1 + + @pytest.mark.asyncio + async def test_import_source_http_non_200(self) -> None: + """Test import workflow with non-200 HTTP status.""" + http_session = MagicMock(spec=aiohttp.ClientSession) + response = AsyncMock() + response.status = 404 + response.text = AsyncMock(return_value="Not Found") + http_session.get = MagicMock(return_value=AsyncMock(__aenter__=AsyncMock(return_value=response))) + + ban_ip = AsyncMock() + log_result = AsyncMock() + + workflow = BlocklistImportWorkflow(http_session, ban_ip, log_result) + + source = BlocklistSource( + id=1, + name="Test Source", + url="https://example.com/blocklist.txt", + enabled=True, + created_at="2026-04-27T00:00:00Z", + updated_at="2026-04-27T00:00:00Z", + ) + + db = MagicMock() + result = await workflow.import_source(source, "/var/run/fail2ban/fail2ban.sock", db) + + assert result.source_id == 1 + assert result.ips_imported == 0 + assert result.ips_skipped == 0 + assert result.error == "HTTP 404"