Files
BanGUI/backend/app/services/blocklist_service.py
Lukas 4ab767e3d4 TASK-009: Mitigate SSRF vulnerability in blocklist URL validation
- Change BlocklistSourceCreate.url from str to AnyHttpUrl (Pydantic type)
  - Rejects non-http schemes (file://, ftp://, etc.) at model boundary

- Add is_private_ip() utility to detect private and otherwise non-public ranges:
  - 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 (RFC 1918)
  - 127.0.0.0/8, ::1/128 (loopback)
  - 169.254.0.0/16, fe80::/10 (link-local)
  - IPv6 site-local, multicast, and reserved ranges

- Add async validate_blocklist_url() function:
  - Resolves hostname via DNS using loop.run_in_executor()
  - Rejects if hostname resolves to private/reserved IP
  - Raises ValueError on validation failure

- Integrate validation into service layer:
  - create_source() calls validate_blocklist_url() before persist
  - update_source() conditionally validates if url provided
  - Both raise ValueError on failure

- Update router endpoints with error handling:
  - create_blocklist() and update_blocklist() catch ValueError
  - Return HTTP 400 Bad Request with descriptive error message

- Add comprehensive test coverage (9 new SSRF tests):
  - file://, ftp://, localhost, 127.0.0.1, 192.168.x.x
  - 10.x.x.x, 172.16.x.x, 169.254.x.x (link-local)
  - Valid public URLs (passes validation)
  - All 36 service tests passing

- Update documentation:
  - Features.md: Document URL validation constraints
  - Backend-Development.md: Add SSRF prevention pattern section

Fixes SSRF vulnerability where authenticated users could supply
file://, ftp://, or private IP URLs and the backend would fetch them.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-04-26 12:57:23 +02:00

765 lines
24 KiB
Python

"""Blocklist service.
Manages blocklist source CRUD, URL preview, IP import (download → validate →
ban via fail2ban), and schedule persistence.
All ban operations target a dedicated fail2ban jail (default:
``"blocklist-import"``) so blocklist-origin bans are tracked separately from
regular bans. If that jail does not exist or fail2ban is unreachable, the
error is recorded in the import log and processing continues.
Schedule configuration is stored as JSON in the application settings table
under the key ``"blocklist_schedule"``.
"""
from __future__ import annotations
import asyncio
import json
from typing import TYPE_CHECKING
import aiohttp
import structlog
from app.exceptions import JailNotFoundError
from app.models.blocklist import (
BlocklistSource,
ImportLogEntry,
ImportLogListResponse,
ImportRunResult,
ImportSourceResult,
PreviewResponse,
ScheduleConfig,
ScheduleFrequency,
ScheduleInfo,
)
from app.repositories import blocklist_repo, import_log_repo, settings_repo
from app.utils.ip_utils import is_valid_ip, is_valid_network
if TYPE_CHECKING:
from collections.abc import Awaitable, Callable
import aiohttp
import aiosqlite
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from app.config import Settings
from app.services.geo_cache import GeoCache
# Module-wide structured logger; every function below emits its events here.
log: structlog.stdlib.BoundLogger = structlog.get_logger()

#: Settings key used to persist the schedule config.
_SCHEDULE_SETTINGS_KEY: str = "blocklist_schedule"

#: fail2ban jail name for blocklist-origin bans.
BLOCKLIST_JAIL: str = "blocklist-import"

#: Maximum number of sample entries returned by the preview endpoint.
#: Used as the default for ``preview_source(sample_lines=...)``.
_PREVIEW_LINES: int = 20

#: Maximum bytes to download for a preview (first 64 KB).
#: NOTE(review): no code visible in this module references this constant;
#: the preview path appears to read the whole response body -- confirm
#: whether the cap is meant to be enforced.
_PREVIEW_MAX_BYTES: int = 65536

#: HTTP status codes that should be retried for blocklist downloads.
_BLOCKLIST_HTTP_RETRY_STATUSES: frozenset[int] = frozenset({429, 500, 502, 503, 504})

#: How many attempts to make for transient blocklist download failures.
#: This is the total number of attempts (first try included), so ``2``
#: means at most one retry.
_BLOCKLIST_HTTP_RETRY_ATTEMPTS: int = 2

#: Base backoff in seconds used between retry attempts.
#: The delay doubles with every attempt: ``base * 2 ** (attempt - 1)``.
_BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS: float = 1.0
async def _download_text_with_retries(
    http_session: aiohttp.ClientSession,
    url: str,
    timeout: aiohttp.ClientTimeout,
) -> tuple[int, str]:
    """Download text from *url* with a small retry policy for transient failures.

    An attempt is retried when the request raises, or when the response
    status is listed in ``_BLOCKLIST_HTTP_RETRY_STATUSES``.  The delay before
    the next attempt doubles each time, starting at
    ``_BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS``.

    Args:
        http_session: Session used to issue the GET request.
        url: Address to download.
        timeout: Timeout passed to every individual request.

    Returns:
        ``(status, text)`` of the last response received.  On the final
        attempt a retryable status is returned to the caller as-is.

    Raises:
        Exception: Whatever the final attempt raised, unchanged.
    """
    # Always try at least once, even if the constant is misconfigured to a
    # non-positive value (previously that ended in a failed ``assert``).
    max_attempts = max(1, _BLOCKLIST_HTTP_RETRY_ATTEMPTS)

    for attempt in range(1, max_attempts + 1):
        is_last_attempt = attempt >= max_attempts
        backoff = _BLOCKLIST_HTTP_BACKOFF_BASE_SECONDS * (2 ** (attempt - 1))

        try:
            async with http_session.get(url, timeout=timeout) as resp:
                status = resp.status
                text = await resp.text(errors="replace")
        except Exception as exc:  # noqa: BLE001
            if is_last_attempt:
                raise
            log.warning(
                "blocklist_download_retry_error",
                url=url,
                attempt=attempt,
                error=repr(exc),
                backoff=backoff,
            )
        else:
            if is_last_attempt or status not in _BLOCKLIST_HTTP_RETRY_STATUSES:
                return status, text
            log.warning(
                "blocklist_download_retry",
                url=url,
                status=status,
                attempt=attempt,
                backoff=backoff,
            )

        # Back off only after the response context has been exited, so the
        # response is not held open while waiting.
        await asyncio.sleep(backoff)

    # The last loop iteration always returns or raises; this line only
    # exists to make that invariant explicit.
    raise RuntimeError("blocklist download retry loop ended without a result")
# ---------------------------------------------------------------------------
# Source CRUD helpers
# ---------------------------------------------------------------------------
def _row_to_source(row: dict[str, object]) -> BlocklistSource:
    """Build a :class:`BlocklistSource` model from a repository row.

    Args:
        row: Mapping handed over by the repository layer; it is passed
            unchanged to ``BlocklistSource.model_validate``.

    Returns:
        The validated :class:`~app.models.blocklist.BlocklistSource`.
    """
    source = BlocklistSource.model_validate(row)
    return source
async def list_sources(db: aiosqlite.Connection) -> list[BlocklistSource]:
    """Fetch every configured blocklist source.

    Args:
        db: Active application database connection.

    Returns:
        One :class:`~app.models.blocklist.BlocklistSource` per row returned
        by the repository, in repository order.
    """
    sources: list[BlocklistSource] = []
    for row in await blocklist_repo.list_sources(db):
        sources.append(_row_to_source(row))
    return sources
async def get_source(
    db: aiosqlite.Connection,
    source_id: int,
) -> BlocklistSource | None:
    """Look up one blocklist source by primary key.

    Args:
        db: Active application database connection.
        source_id: Primary key of the desired source.

    Returns:
        The matching :class:`~app.models.blocklist.BlocklistSource`, or
        ``None`` when the repository has no such row.
    """
    row = await blocklist_repo.get_source(db, source_id)
    if row is None:
        return None
    return _row_to_source(row)
async def create_source(
    db: aiosqlite.Connection,
    name: str,
    url: str,
    *,
    enabled: bool = True,
) -> BlocklistSource:
    """Validate, persist and return a new blocklist source.

    The URL is handed to ``validate_blocklist_url`` before anything is
    written, so a rejected URL never reaches the database.

    Args:
        db: Active application database connection.
        name: Human-readable display name.
        url: URL of the blocklist text file; it must pass
            ``validate_blocklist_url``.
        enabled: Whether the source is active.  Defaults to ``True``.

    Returns:
        The freshly stored :class:`~app.models.blocklist.BlocklistSource`,
        re-read from the database.

    Raises:
        ValueError: If the URL fails SSRF validation.
    """
    from app.utils.ip_utils import validate_blocklist_url

    # SSRF gate: must run before the insert.
    await validate_blocklist_url(url)

    source_id = await blocklist_repo.create_source(db, name, url, enabled=enabled)
    created = await get_source(db, source_id)
    # The row was inserted a moment ago, so the lookup is expected to succeed.
    assert created is not None  # noqa: S101

    log.info("blocklist_source_created", id=source_id, name=name, url=url)
    return created
async def update_source(
    db: aiosqlite.Connection,
    source_id: int,
    *,
    name: str | None = None,
    url: str | None = None,
    enabled: bool | None = None,
) -> BlocklistSource | None:
    """Apply a partial update to a blocklist source.

    A new URL, when supplied, is passed through ``validate_blocklist_url``
    before the repository is touched.

    Args:
        db: Active application database connection.
        source_id: Primary key of the source to modify.
        name: New display name, or ``None`` to leave unchanged.
        url: New URL, or ``None`` to leave unchanged (validated if given).
        enabled: New enabled state, or ``None`` to leave unchanged.

    Returns:
        The re-read :class:`~app.models.blocklist.BlocklistSource`, or
        ``None`` when the repository reports that nothing was updated.

    Raises:
        ValueError: If the URL fails SSRF validation.
    """
    url_supplied = url is not None
    if url_supplied:
        from app.utils.ip_utils import validate_blocklist_url

        await validate_blocklist_url(url)

    was_updated = await blocklist_repo.update_source(
        db,
        source_id,
        name=name,
        url=url,
        enabled=enabled,
    )
    if was_updated:
        refreshed = await get_source(db, source_id)
        log.info("blocklist_source_updated", id=source_id)
        return refreshed
    return None
async def delete_source(db: aiosqlite.Connection, source_id: int) -> bool:
    """Remove a blocklist source.

    Args:
        db: Active application database connection.
        source_id: Primary key of the source to delete.

    Returns:
        The repository's result: truthy when a row was deleted, falsy when
        no such source existed.
    """
    was_deleted = await blocklist_repo.delete_source(db, source_id)
    if not was_deleted:
        return was_deleted
    log.info("blocklist_source_deleted", id=source_id)
    return was_deleted
# ---------------------------------------------------------------------------
# Preview
# ---------------------------------------------------------------------------
async def preview_source(
    url: str,
    http_session: aiohttp.ClientSession,
    *,
    sample_lines: int = _PREVIEW_LINES,
) -> PreviewResponse:
    """Download a blocklist URL and return a preview of its entries.

    Blank lines and lines starting with ``#`` are ignored.  Every other
    line is classified as valid (accepted by ``is_valid_ip`` or
    ``is_valid_network``) or skipped.

    Args:
        url: URL to download.
        http_session: Shared :class:`aiohttp.ClientSession`.
        sample_lines: Maximum number of valid entries to include in the
            preview sample.

    Returns:
        :class:`~app.models.blocklist.PreviewResponse` with a sample of
        valid entries and validation statistics.  ``total_lines`` counts
        every line of the response, including blanks and comments.

    Raises:
        ValueError: If the URL cannot be reached or returns a non-200 status.
    """
    # NOTE(review): the complete response body is read here; the
    # ``_PREVIEW_MAX_BYTES`` cap defined in this module is not applied.
    try:
        status, raw = await _download_text_with_retries(
            http_session,
            url,
            _aiohttp_timeout(10),
        )
    except Exception as exc:
        # Some exceptions (e.g. ``asyncio.TimeoutError``) stringify to "";
        # fall back to ``repr`` so the caller never receives a blank message.
        reason = str(exc) or repr(exc)
        log.warning("blocklist_preview_failed", url=url, error=reason)
        raise ValueError(reason) from exc

    if status != 200:
        reason = f"HTTP {status} from {url}"
        log.warning("blocklist_preview_failed", url=url, error=reason)
        raise ValueError(reason)

    lines = raw.splitlines()
    entries: list[str] = []
    valid = 0
    skipped = 0
    for line in lines:
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        if is_valid_ip(stripped) or is_valid_network(stripped):
            valid += 1
            if len(entries) < sample_lines:
                entries.append(stripped)
        else:
            skipped += 1

    return PreviewResponse(
        entries=entries,
        total_lines=len(lines),
        valid_count=valid,
        skipped_count=skipped,
    )
# ---------------------------------------------------------------------------
# Import
# ---------------------------------------------------------------------------
async def import_source(
    source: BlocklistSource,
    http_session: aiohttp.ClientSession,
    socket_path: str,
    db: aiosqlite.Connection,
    *,
    ban_ip: Callable[[str, str, str], Awaitable[None]],
    geo_is_cached: Callable[[str], bool] | None = None,
    geo_cache: GeoCache | None = None,
) -> ImportSourceResult:
    """Download and apply bans from a single blocklist source.

    The function downloads ``source.url`` and walks the body line by line.
    Blank lines and lines starting with ``#`` are ignored.  Each remaining
    line accepted by ``is_valid_ip`` is banned in :data:`BLOCKLIST_JAIL`
    through *ban_ip*; every other line (CIDR ranges, malformed entries) is
    counted as skipped.  Download errors are written to the import log and
    reported in the result instead of being raised.

    After the ban loop, and only when *geo_is_cached* is provided, the
    imported IPs that are not cached yet are resolved in one batch through
    *geo_cache*.  A failure of that pre-warm step is logged and ignored.

    Args:
        source: The :class:`~app.models.blocklist.BlocklistSource` to import.
        http_session: Shared :class:`aiohttp.ClientSession`.
        socket_path: Path to the fail2ban Unix socket.
        db: Application database for logging.
        ban_ip: Coroutine function called as
            ``ban_ip(socket_path, BLOCKLIST_JAIL, ip)`` for each valid IP.
        geo_is_cached: Optional predicate telling whether an IP already has
            a geo cache entry.  When ``None`` no pre-warming takes place.
        geo_cache: Optional cache whose ``lookup_batch`` is awaited with the
            uncached IPs.

    Returns:
        :class:`~app.models.blocklist.ImportSourceResult` with counters and
        the first error message encountered (``None`` on success).
    """

    async def _finish(
        imported_count: int,
        skipped_count: int,
        error: str | None,
    ) -> ImportSourceResult:
        """Write the import-log row and build the matching result object."""
        await _log_result(db, source, imported_count, skipped_count, error)
        return ImportSourceResult(
            source_id=source.id,
            source_url=source.url,
            ips_imported=imported_count,
            ips_skipped=skipped_count,
            error=error,
        )

    def _describe(exc: BaseException) -> str:
        """Return a non-empty message for *exc*.

        Exceptions such as ``asyncio.TimeoutError`` stringify to ``""``,
        which would leave the import log without any hint of what failed.
        """
        return str(exc) or repr(exc)

    # --- Download ---
    try:
        status, content = await _download_text_with_retries(
            http_session,
            source.url,
            _aiohttp_timeout(30),
        )
    except Exception as exc:  # noqa: BLE001
        error_msg = _describe(exc)
        failure = await _finish(0, 0, error_msg)
        log.warning("blocklist_import_download_error", url=source.url, error=error_msg)
        return failure

    # Handled outside the ``try`` so that a failure while logging is not
    # mistaken for a download error and logged a second time.
    if status != 200:
        error_msg = f"HTTP {status}"
        failure = await _finish(0, 0, error_msg)
        log.warning("blocklist_import_download_failed", url=source.url, status=status)
        return failure

    # --- Validate and ban ---
    imported = 0
    skipped = 0
    ban_error: str | None = None
    imported_ips: list[str] = []
    for line in content.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        if not is_valid_ip(stripped):
            # Skip CIDRs and malformed entries gracefully.
            skipped += 1
            continue
        try:
            await ban_ip(socket_path, BLOCKLIST_JAIL, stripped)
        except JailNotFoundError as exc:
            # The target jail does not exist in fail2ban -- there is no point
            # continuing because every subsequent ban would also fail.
            ban_error = _describe(exc)
            log.warning(
                "blocklist_jail_not_found",
                jail=BLOCKLIST_JAIL,
                error=str(exc),
            )
            break
        except Exception as exc:  # noqa: BLE001
            skipped += 1
            # Only the first failure is reported in the result.
            if ban_error is None:
                ban_error = _describe(exc)
            log.debug("blocklist_ban_failed", ip=stripped, error=str(exc))
        else:
            imported += 1
            imported_ips.append(stripped)

    result = await _finish(imported, skipped, ban_error)
    log.info(
        "blocklist_source_imported",
        source_id=source.id,
        url=source.url,
        imported=imported,
        skipped=skipped,
        error=ban_error,
    )

    # --- Pre-warm geo cache for newly imported IPs ---
    if imported_ips and geo_is_cached is not None:
        uncached_ips: list[str] = [ip for ip in imported_ips if not geo_is_cached(ip)]
        skipped_geo: int = len(imported_ips) - len(uncached_ips)
        if skipped_geo > 0:
            log.info(
                "blocklist_geo_prewarm_cache_hit",
                source_id=source.id,
                skipped=skipped_geo,
                to_lookup=len(uncached_ips),
            )
        if uncached_ips and geo_cache is not None:
            try:
                await geo_cache.lookup_batch(uncached_ips, http_session, db=db)
                log.info(
                    "blocklist_geo_prewarm_complete",
                    source_id=source.id,
                    count=len(uncached_ips),
                )
            except Exception as exc:  # noqa: BLE001
                # Best effort only: the bans are already in place.
                log.warning(
                    "blocklist_geo_prewarm_failed",
                    source_id=source.id,
                    error=str(exc),
                )

    return result
async def import_all(
    db: aiosqlite.Connection,
    http_session: aiohttp.ClientSession,
    socket_path: str,
    *,
    ban_ip: Callable[[str, str, str], Awaitable[None]],
    geo_is_cached: Callable[[str], bool] | None = None,
    geo_cache: GeoCache | None = None,
) -> ImportRunResult:
    """Run :func:`import_source` for every enabled source and sum the results.

    Sources are processed one after another, in the order returned by
    ``blocklist_repo.list_enabled_sources``.

    Args:
        db: Application database connection.
        http_session: Shared :class:`aiohttp.ClientSession`.
        socket_path: fail2ban socket path.
        ban_ip: Ban callable forwarded unchanged to :func:`import_source`.
        geo_is_cached: Optional cache predicate forwarded to
            :func:`import_source`.
        geo_cache: Optional geo cache forwarded to :func:`import_source`.

    Returns:
        :class:`~app.models.blocklist.ImportRunResult` with the per-source
        results, the summed counters and the number of sources whose result
        carries an error.
    """
    enabled_rows = await blocklist_repo.list_enabled_sources(db)

    per_source: list[ImportSourceResult] = []
    for row in enabled_rows:
        outcome = await import_source(
            _row_to_source(row),
            http_session,
            socket_path,
            db,
            ban_ip=ban_ip,
            geo_is_cached=geo_is_cached,
            geo_cache=geo_cache,
        )
        per_source.append(outcome)

    imported_sum = sum(outcome.ips_imported for outcome in per_source)
    skipped_sum = sum(outcome.ips_skipped for outcome in per_source)
    failed_sources = sum(1 for outcome in per_source if outcome.error is not None)

    log.info(
        "blocklist_import_all_complete",
        sources=len(enabled_rows),
        total_imported=imported_sum,
        total_skipped=skipped_sum,
        errors=failed_sources,
    )
    return ImportRunResult(
        results=per_source,
        total_imported=imported_sum,
        total_skipped=skipped_sum,
        errors_count=failed_sources,
    )
# ---------------------------------------------------------------------------
# Schedule
# ---------------------------------------------------------------------------
#: Fallback schedule used by ``get_schedule`` when no valid configuration is
#: stored; built from the field defaults of ``ScheduleConfig``.
_DEFAULT_SCHEDULE = ScheduleConfig()

#: Stable APScheduler job id for the blocklist import job.
#: Used to look up, remove and re-register the job.
JOB_ID: str = "blocklist_import"
def _get_job_next_run_at(scheduler: AsyncIOScheduler) -> str | None:
    """Return the import job's next run time as an ISO 8601 string.

    Args:
        scheduler: Scheduler that may hold the job registered under
            :data:`JOB_ID`.

    Returns:
        ``next_run_time.isoformat()`` of the job, or ``None`` when the job
        is not registered or has no next run time.
    """
    job = scheduler.get_job(JOB_ID)
    next_run = None if job is None else job.next_run_time
    return None if next_run is None else next_run.isoformat()
def schedule_blocklist_job(
    scheduler: AsyncIOScheduler,
    settings: Settings,
    http_session: aiohttp.ClientSession,
    config: ScheduleConfig,
    run_import_callback: Callable[[Settings, aiohttp.ClientSession], Awaitable[None]],
) -> None:
    """Register the blocklist import job, replacing any existing one.

    The trigger depends on ``config.frequency``: hourly uses an
    ``"interval"`` trigger every ``config.interval_hours`` hours, weekly
    uses a ``"cron"`` trigger on ``config.day_of_week`` at
    ``config.hour:config.minute``, and any other frequency uses a
    ``"cron"`` trigger at ``config.hour:config.minute`` only.

    Args:
        scheduler: Scheduler that owns the job.
        settings: Passed to the callback as the ``settings`` keyword.
        http_session: Passed to the callback as the ``http_session`` keyword.
        config: Schedule configuration to translate into a trigger.
        run_import_callback: Callable the scheduler invokes on each run.
    """
    existing_job = scheduler.get_job(JOB_ID)
    if existing_job:
        scheduler.remove_job(JOB_ID)

    job_kwargs: dict[str, object] = {
        "settings": settings,
        "http_session": http_session,
    }

    frequency = config.frequency
    if frequency == ScheduleFrequency.hourly:
        trigger_name = "interval"
        trigger_args = {"hours": config.interval_hours}
    else:
        # Weekly and all remaining frequencies share the cron time of day;
        # weekly additionally pins the weekday.
        trigger_name = "cron"
        trigger_args = {"hour": config.hour, "minute": config.minute}
        if frequency == ScheduleFrequency.weekly:
            trigger_args = {"day_of_week": config.day_of_week, **trigger_args}

    scheduler.add_job(
        run_import_callback,
        trigger=trigger_name,
        id=JOB_ID,
        kwargs=job_kwargs,
        **trigger_args,
    )
    log.info(
        "blocklist_import_scheduled",
        frequency=config.frequency,
        trigger=trigger_name,
        trigger_kwargs=trigger_args,
    )
async def get_schedule(db: aiosqlite.Connection) -> ScheduleConfig:
    """Read the import schedule config from the settings table.

    Falls back to the default schedule (the field defaults of
    ``ScheduleConfig``) when nothing has been saved yet or when the stored
    value cannot be parsed or validated.

    Args:
        db: Active application database connection.

    Returns:
        The stored :class:`~app.models.blocklist.ScheduleConfig`, or a copy
        of the default one.
    """
    raw = await settings_repo.get_setting(db, _SCHEDULE_SETTINGS_KEY)
    if raw is None:
        # Return a copy so callers can never mutate the shared module-level
        # default instance.
        return _DEFAULT_SCHEDULE.model_copy()
    try:
        return ScheduleConfig.model_validate(json.loads(raw))
    except Exception as exc:  # noqa: BLE001
        # Deliberately best effort: a corrupt setting must not break the
        # schedule endpoints.  Record the reason for troubleshooting.
        log.warning("blocklist_schedule_invalid", raw=raw, error=repr(exc))
        return _DEFAULT_SCHEDULE.model_copy()
async def set_schedule(
    db: aiosqlite.Connection,
    config: ScheduleConfig,
) -> ScheduleConfig:
    """Store *config* as JSON in the settings table.

    Args:
        db: Active application database connection.
        config: The :class:`~app.models.blocklist.ScheduleConfig` to store.

    Returns:
        The very same *config* object that was passed in.
    """
    serialized = config.model_dump_json()
    await settings_repo.set_setting(db, _SCHEDULE_SETTINGS_KEY, serialized)
    log.info("blocklist_schedule_updated", frequency=config.frequency, hour=config.hour)
    return config
async def get_schedule_info(
    db: aiosqlite.Connection,
    next_run_at: str | None,
) -> ScheduleInfo:
    """Combine the schedule config with metadata about the last import.

    Args:
        db: Active application database connection.
        next_run_at: ISO 8601 string of the next scheduled run, or ``None``
            if not yet scheduled (supplied by the caller).

    Returns:
        :class:`~app.models.blocklist.ScheduleInfo` holding the config, the
        given *next_run_at*, and the timestamp and error flag of the most
        recent import log entry (both ``None`` when there is no entry).
    """
    config = await get_schedule(db)
    last_log = await import_log_repo.get_last_log(db)

    if last_log:
        last_run_at = last_log["timestamp"]
        # Any non-NULL ``errors`` value marks the run as failed.
        last_run_errors: bool | None = last_log["errors"] is not None
    else:
        last_run_at = None
        last_run_errors = None

    return ScheduleInfo(
        config=config,
        next_run_at=next_run_at,
        last_run_at=last_run_at,
        last_run_errors=last_run_errors,
    )
async def get_schedule_info_with_runtime(
    db: aiosqlite.Connection,
    scheduler: AsyncIOScheduler,
) -> ScheduleInfo:
    """Return schedule info with the next run time read from *scheduler*.

    Args:
        db: Active application database connection.
        scheduler: Scheduler queried for the import job's next run time.

    Returns:
        The :class:`~app.models.blocklist.ScheduleInfo` built by
        :func:`get_schedule_info`.
    """
    return await get_schedule_info(db, _get_job_next_run_at(scheduler))
async def update_schedule(
    db: aiosqlite.Connection,
    scheduler: AsyncIOScheduler,
    http_session: aiohttp.ClientSession,
    settings: Settings,
    config: ScheduleConfig,
    run_import_callback: Callable[[Settings, aiohttp.ClientSession], Awaitable[None]],
) -> ScheduleInfo:
    """Save *config*, re-register the import job and report the new state.

    Args:
        db: Active application database connection.
        scheduler: Scheduler on which the job is re-registered.
        http_session: Session forwarded to the scheduled job.
        settings: Application settings forwarded to the scheduled job.
        config: New schedule configuration.
        run_import_callback: Callable the scheduler runs on each trigger.

    Returns:
        :class:`~app.models.blocklist.ScheduleInfo` reflecting the stored
        config and the job's next run time after re-registration.
    """
    # Persist first, then swap the job.
    await set_schedule(db, config)
    schedule_blocklist_job(scheduler, settings, http_session, config, run_import_callback)

    next_run_at = _get_job_next_run_at(scheduler)
    info = await get_schedule_info(db, next_run_at)
    return info
async def list_import_logs(
    db: aiosqlite.Connection,
    *,
    source_id: int | None = None,
    page: int = 1,
    page_size: int = 50,
) -> ImportLogListResponse:
    """Fetch one page of import log entries.

    Args:
        db: Active application database connection.
        source_id: Optional filter to only return logs for a specific source.
        page: 1-based page number.
        page_size: Items per page.

    Returns:
        :class:`~app.models.blocklist.ImportLogListResponse` with the page's
        entries plus the total count and page count from the repository.
    """
    rows, total = await import_log_repo.list_logs(
        db,
        source_id=source_id,
        page=page,
        page_size=page_size,
    )
    page_count = import_log_repo.compute_total_pages(total, page_size)

    entries: list[ImportLogEntry] = []
    for row in rows:
        entries.append(ImportLogEntry.model_validate(row))

    return ImportLogListResponse(
        items=entries,
        total=total,
        page=page,
        page_size=page_size,
        total_pages=page_count,
    )
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _aiohttp_timeout(seconds: float) -> aiohttp.ClientTimeout:
    """Build an :class:`aiohttp.ClientTimeout` limited to *seconds* in total.

    Args:
        seconds: Value passed as the timeout's ``total``.

    Returns:
        The new :class:`aiohttp.ClientTimeout` instance.
    """
    from aiohttp import ClientTimeout  # noqa: PLC0415

    return ClientTimeout(total=seconds)
async def _log_result(
    db: aiosqlite.Connection,
    source: BlocklistSource,
    ips_imported: int,
    ips_skipped: int,
    error: str | None,
) -> None:
    """Append one import log row describing a finished source import.

    Args:
        db: Application database connection.
        source: The source that was imported; its ``id`` and ``url`` are
            stored with the entry.
        ips_imported: Count of successfully banned IPs.
        ips_skipped: Count of skipped/invalid entries.
        error: Error string, or ``None`` on success.
    """
    entry_fields = {
        "source_id": source.id,
        "source_url": source.url,
        "ips_imported": ips_imported,
        "ips_skipped": ips_skipped,
        "errors": error,
    }
    await import_log_repo.add_log(db, **entry_fields)