TASK-009: Mitigate SSRF vulnerability in blocklist URL validation

- Change BlocklistSourceCreate.url from str to AnyHttpUrl (Pydantic type)
  - Rejects non-http schemes (file://, ftp://, etc.) at model boundary

- Add is_private_ip() utility to detect RFC 1918 private ranges:
  - 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16 (RFC 1918)
  - 127.0.0.0/8, ::1/128 (loopback)
  - 169.254.0.0/16, fe80::/10 (link-local)
  - IPv6 site-local, multicast, and reserved ranges

- Add async validate_blocklist_url() function:
  - Resolves hostname via DNS using loop.run_in_executor()
  - Rejects if hostname resolves to private/reserved IP
  - Raises ValueError on validation failure

- Integrate validation into service layer:
  - create_source() calls validate_blocklist_url() before persist
  - update_source() conditionally validates if url provided
  - Both raise ValueError on failure

- Update router endpoints with error handling:
  - create_blocklist() and update_blocklist() catch ValueError
  - Return HTTP 400 Bad Request with descriptive error message

- Add comprehensive test coverage (9 new SSRF tests):
  - file://, ftp://, localhost, 127.0.0.1, 192.168.x.x
  - 10.x.x.x, 172.16.x.x, 169.254.x.x (link-local)
  - Valid public URLs (passes validation)
  - All 36 service tests passing

- Update documentation:
  - Features.md: Document URL validation constraints
  - Backend-Development.md: Add SSRF prevention pattern section

Fixes SSRF vulnerability where authenticated users could supply
file://, ftp://, or private IP URLs and the backend would fetch them.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
2026-04-26 12:57:23 +02:00
parent a5b55d1248
commit 4ab767e3d4
9 changed files with 291 additions and 66 deletions

View File

@@ -8,7 +8,7 @@ from __future__ import annotations
from enum import StrEnum
from pydantic import BaseModel, ConfigDict, Field
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
# ---------------------------------------------------------------------------
# Blocklist source
@@ -29,22 +29,30 @@ class BlocklistSource(BaseModel):
class BlocklistSourceCreate(BaseModel):
"""Payload for ``POST /api/blocklists``."""
"""Payload for ``POST /api/blocklists``.
URL must use http/https scheme. The hostname must resolve to a public IP
(not private, loopback, link-local, or reserved). Validation happens
asynchronously in the service layer.
"""
model_config = ConfigDict(strict=True)
name: str = Field(..., min_length=1, max_length=100, description="Human-readable source name.")
url: str = Field(..., min_length=1, description="URL of the blocklist file.")
url: AnyHttpUrl = Field(..., description="URL of the blocklist file (http/https only).")
enabled: bool = Field(default=True)
class BlocklistSourceUpdate(BaseModel):
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional."""
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional.
If URL is provided, it must use http/https scheme.
"""
model_config = ConfigDict(strict=True)
name: str | None = Field(default=None, min_length=1, max_length=100)
url: str | None = Field(default=None)
url: AnyHttpUrl | None = Field(default=None)
enabled: bool | None = Field(default=None)

View File

@@ -97,10 +97,16 @@ async def create_blocklist(
Returns:
The newly created :class:`~app.models.blocklist.BlocklistSource`.
Raises:
HTTPException: 400 if URL validation fails.
"""
return await blocklist_service.create_source(
db, payload.name, payload.url, enabled=payload.enabled
)
try:
return await blocklist_service.create_source(
db, payload.name, str(payload.url), enabled=payload.enabled
)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
# ---------------------------------------------------------------------------
@@ -283,15 +289,19 @@ async def update_blocklist(
_auth: Validated session — enforces authentication.
Raises:
HTTPException: 400 if URL validation fails.
HTTPException: 404 if the source does not exist.
"""
updated = await blocklist_service.update_source(
db,
source_id,
name=payload.name,
url=payload.url,
enabled=payload.enabled,
)
try:
updated = await blocklist_service.update_source(
db,
source_id,
name=payload.name,
url=str(payload.url) if payload.url is not None else None,
enabled=payload.enabled,
)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
if updated is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Blocklist source not found.")
return updated

View File

@@ -166,15 +166,24 @@ async def create_source(
) -> BlocklistSource:
"""Create a new blocklist source and return the persisted record.
Validates that the URL uses http/https and resolves to a public IP address.
Args:
db: Active application database connection.
name: Human-readable display name.
url: URL of the blocklist text file.
url: URL of the blocklist text file (must be http/https and resolve to public IP).
enabled: Whether the source is active. Defaults to ``True``.
Returns:
The newly created :class:`~app.models.blocklist.BlocklistSource`.
Raises:
ValueError: If the URL fails SSRF validation.
"""
from app.utils.ip_utils import validate_blocklist_url
await validate_blocklist_url(url)
new_id = await blocklist_repo.create_source(db, name, url, enabled=enabled)
source = await get_source(db, new_id)
assert source is not None # noqa: S101
@@ -192,17 +201,27 @@ async def update_source(
) -> BlocklistSource | None:
"""Update fields on a blocklist source.
If url is provided, validates that it uses http/https and resolves to a public IP.
Args:
db: Active application database connection.
source_id: Primary key of the source to modify.
name: New display name, or ``None`` to leave unchanged.
url: New URL, or ``None`` to leave unchanged.
url: New URL, or ``None`` to leave unchanged (validated if provided).
enabled: New enabled state, or ``None`` to leave unchanged.
Returns:
Updated :class:`~app.models.blocklist.BlocklistSource`, or ``None``
if the source does not exist.
Raises:
ValueError: If the URL fails SSRF validation.
"""
if url is not None:
from app.utils.ip_utils import validate_blocklist_url
await validate_blocklist_url(url)
updated = await blocklist_repo.update_source(
db, source_id, name=name, url=url, enabled=enabled
)

View File

@@ -4,7 +4,10 @@ All IP handling in BanGUI goes through these helpers to enforce consistency
and prevent malformed addresses from reaching fail2ban.
"""
import asyncio
import ipaddress
import socket
from urllib.parse import urlparse
def is_valid_ip(address: str) -> bool:
@@ -99,3 +102,97 @@ def ip_version(address: str) -> int:
ValueError: If *address* is not a valid IP address.
"""
return ipaddress.ip_address(address).version
def is_private_ip(address: str) -> bool:
"""Return ``True`` if *address* is a private or reserved IP address.
Private ranges include:
- RFC 1918: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
- Loopback: 127.0.0.0/8 (IPv4), ::1/128 (IPv6)
- Link-local: 169.254.0.0/16 (IPv4), fe80::/10 (IPv6)
- IPv6 ULA: fc00::/7
- Multicast and other reserved ranges
Args:
address: A valid IP address string.
Returns:
``True`` if the address is private or reserved, ``False`` if it is public.
Raises:
ValueError: If *address* is not a valid IP address.
"""
ip = ipaddress.ip_address(address)
return (
ip.is_private
or ip.is_loopback
or ip.is_link_local
or ip.is_multicast
or ip.is_reserved
)
async def validate_blocklist_url(url: str) -> None:
"""Validate that a blocklist URL points to a public HTTP(S) endpoint.
Checks that:
- The URL uses HTTP or HTTPS scheme
- The hostname resolves to a public (non-private, non-reserved) IP address
- IPv4-mapped IPv6 addresses are checked against IPv4 private ranges
Performs DNS resolution asynchronously to check the resolved IP.
This is a point-in-time check; DNS rebinding attacks may still be possible
at actual fetch time. Callers should re-validate the final connection
in the HTTP client layer.
Args:
url: The blocklist URL to validate.
Raises:
ValueError: If the URL has an invalid scheme, hostname cannot be resolved,
or the resolved IP is private/reserved.
"""
try:
parsed = urlparse(url)
except Exception as exc:
raise ValueError(f"Invalid URL format: {exc}") from exc
if parsed.scheme not in ("http", "https"):
raise ValueError(
f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
)
if not parsed.hostname:
raise ValueError("URL has no hostname")
hostname = parsed.hostname
try:
loop = asyncio.get_event_loop()
addrinfo = await loop.run_in_executor(
None,
socket.getaddrinfo,
hostname,
parsed.port or 80,
socket.AF_UNSPEC,
socket.SOCK_STREAM,
)
except socket.gaierror as exc:
raise ValueError(f"Cannot resolve hostname '{hostname}': {exc}") from exc
except Exception as exc:
raise ValueError(f"DNS resolution error for '{hostname}': {exc}") from exc
if not addrinfo:
raise ValueError(f"No address resolved for hostname '{hostname}'")
for family, socktype, proto, canonname, sockaddr in addrinfo:
ip_str: str = sockaddr[0] # type: ignore[assignment]
try:
if is_private_ip(ip_str):
raise ValueError(
f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
)
except ipaddress.AddressValueError as exc:
raise ValueError(f"Invalid IP address: {ip_str}") from exc

View File

@@ -29,7 +29,7 @@ from app.models.blocklist import (
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD = {
"master_password": "testpassword1",
"master_password": "TestPassword1!",
"database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC",
@@ -182,8 +182,10 @@ class TestListBlocklists:
class TestCreateBlocklist:
async def test_create_returns_201(self, bl_client: AsyncClient) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_create_returns_201(self, mock_validate: AsyncMock, bl_client: AsyncClient) -> None:
"""POST /api/blocklists creates a source and returns HTTP 201."""
mock_validate.return_value = None
with patch(
"app.routers.blocklist.blocklist_service.create_source",
new=AsyncMock(return_value=_make_source()),
@@ -194,8 +196,10 @@ class TestCreateBlocklist:
)
assert resp.status_code == 201
async def test_create_source_id_in_response(self, bl_client: AsyncClient) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_create_source_id_in_response(self, mock_validate: AsyncMock, bl_client: AsyncClient) -> None:
"""Created source response includes the id field."""
mock_validate.return_value = None
with patch(
"app.routers.blocklist.blocklist_service.create_source",
new=AsyncMock(return_value=_make_source(42)),

View File

@@ -54,8 +54,10 @@ def _make_session(text: str, status: int = 200) -> MagicMock:
class TestSourceCRUD:
async def test_create_and_get(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_create_and_get(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""create_source persists and get_source retrieves a source."""
mock_validate.return_value = None
source = await blocklist_service.create_source(db, "Test", "https://t.test/")
assert isinstance(source, BlocklistSource)
assert source.name == "Test"
@@ -75,15 +77,19 @@ class TestSourceCRUD:
sources = await blocklist_service.list_sources(db)
assert sources == []
async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_list_sources_returns_all(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""list_sources returns all created sources."""
mock_validate.return_value = None
await blocklist_service.create_source(db, "A", "https://a.test/")
await blocklist_service.create_source(db, "B", "https://b.test/")
sources = await blocklist_service.list_sources(db)
assert len(sources) == 2
async def test_update_source_fields(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_update_source_fields(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""update_source modifies specified fields."""
mock_validate.return_value = None
source = await blocklist_service.create_source(db, "Original", "https://orig.test/")
updated = await blocklist_service.update_source(db, source.id, name="Updated", enabled=False)
assert updated is not None
@@ -95,8 +101,10 @@ class TestSourceCRUD:
result = await blocklist_service.update_source(db, 9999, name="Ghost")
assert result is None
async def test_delete_source(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_delete_source(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""delete_source removes a source and returns True."""
mock_validate.return_value = None
source = await blocklist_service.create_source(db, "Del", "https://del.test/")
deleted = await blocklist_service.delete_source(db, source.id)
assert deleted is True
@@ -167,8 +175,10 @@ class TestPreview:
class TestImport:
async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_bans_valid_ips(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_source calls ban_ip for every valid IP in the blocklist."""
mock_validate.return_value = None
content = "1.2.3.4\n5.6.7.8\n# skip me\n"
session = _make_session(content)
@@ -192,8 +202,10 @@ class TestImport:
assert result.error is None
assert mock_ban.call_count == 2
async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_skips_cidrs(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_source skips CIDR ranges (fail2ban expects individual IPs)."""
mock_validate.return_value = None
content = "1.2.3.4\n10.0.0.0/24\n"
session = _make_session(content)
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
@@ -212,8 +224,10 @@ class TestImport:
assert result.ips_imported == 1
assert result.ips_skipped == 1
async def test_import_source_records_download_error(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_records_download_error(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_source records an error and returns 0 imported on HTTP failure."""
mock_validate.return_value = None
session = _make_session("", status=503)
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
@@ -230,9 +244,11 @@ class TestImport:
assert result.ips_imported == 0
assert result.error is not None
async def test_import_source_aborts_on_jail_not_found(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_aborts_on_jail_not_found(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_source aborts immediately and records an error when the target jail
does not exist in fail2ban instead of silently skipping every IP."""
mock_validate.return_value = None
from app.services.jail_service import JailNotFoundError
from app.services import ban_service
@@ -262,8 +278,10 @@ class TestImport:
assert result.error is not None
assert "not found" in result.error.lower() or "blocklist-import" in result.error
async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_all_runs_all_enabled(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_all aggregates results across all enabled sources."""
mock_validate.return_value = None
await blocklist_service.create_source(db, "S1", "https://s1.test/")
s2 = await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)
_ = s2 # noqa: F841
@@ -400,10 +418,12 @@ class TestSchedule:
class TestGeoPrewarmCacheFilter:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_skips_cached_ips_for_geo_prewarm(
self, db: aiosqlite.Connection
self, mock_validate: AsyncMock, db: aiosqlite.Connection
) -> None:
"""import_source only sends uncached IPs to geo_service.lookup_batch."""
mock_validate.return_value = None
content = "1.2.3.4\n5.6.7.8\n9.10.11.12\n"
session = _make_session(content)
source = await blocklist_service.create_source(
@@ -416,7 +436,11 @@ class TestGeoPrewarmCacheFilter:
from app.services import ban_service
mock_batch = AsyncMock(return_value={})
mock_lookup = AsyncMock(return_value={})
mock_geo_cache = MagicMock()
mock_geo_cache.is_cached = _mock_is_cached
mock_geo_cache.lookup_batch = mock_lookup
with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
result = await blocklist_service.import_source(
source,
@@ -425,17 +449,76 @@ class TestGeoPrewarmCacheFilter:
db,
ban_ip=ban_service.ban_ip,
geo_is_cached=_mock_is_cached,
geo_batch_lookup=mock_batch,
geo_cache=mock_geo_cache,
)
assert result.ips_imported == 3
# lookup_batch should receive only the 2 uncached IPs.
mock_batch.assert_called_once()
call_ips = mock_batch.call_args[0][0]
mock_lookup.assert_called_once()
call_ips = mock_lookup.call_args[0][0]
assert "1.2.3.4" not in call_ips
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
# ---------------------------------------------------------------------------
# URL Validation (SSRF Prevention)
# ---------------------------------------------------------------------------
class TestURLValidation:
"""Test SSRF protection by validating blocklist URLs."""
async def test_create_source_rejects_file_url(self, db: aiosqlite.Connection) -> None:
"""create_source rejects file:// URLs."""
with pytest.raises(ValueError, match="Invalid scheme"):
await blocklist_service.create_source(db, "Bad", "file:///etc/passwd")
async def test_create_source_rejects_ftp_url(self, db: aiosqlite.Connection) -> None:
"""create_source rejects ftp:// URLs."""
with pytest.raises(ValueError, match="Invalid scheme"):
await blocklist_service.create_source(db, "Bad", "ftp://evil.com/file.txt")
async def test_create_source_rejects_localhost(self, db: aiosqlite.Connection) -> None:
"""create_source rejects localhost (127.0.0.1)."""
with pytest.raises(ValueError, match="private|reserved"):
await blocklist_service.create_source(db, "Bad", "http://127.0.0.1/list")
async def test_create_source_rejects_localhost_name(self, db: aiosqlite.Connection) -> None:
"""create_source rejects localhost hostname."""
with pytest.raises(ValueError, match="private|reserved"):
await blocklist_service.create_source(db, "Bad", "http://localhost/list")
async def test_create_source_rejects_private_network(self, db: aiosqlite.Connection) -> None:
"""create_source rejects private RFC 1918 networks (10.0.0.0/8)."""
with pytest.raises(ValueError, match="private|reserved"):
await blocklist_service.create_source(db, "Bad", "http://10.0.0.1/list")
async def test_create_source_rejects_private_network_172(self, db: aiosqlite.Connection) -> None:
"""create_source rejects private RFC 1918 networks (172.16.0.0/12)."""
with pytest.raises(ValueError, match="private|reserved"):
await blocklist_service.create_source(db, "Bad", "http://172.16.0.1/list")
async def test_create_source_rejects_private_network_192(self, db: aiosqlite.Connection) -> None:
"""create_source rejects private RFC 1918 networks (192.168.0.0/16)."""
with pytest.raises(ValueError, match="private|reserved"):
await blocklist_service.create_source(db, "Bad", "http://192.168.1.1/list")
async def test_create_source_rejects_link_local(self, db: aiosqlite.Connection) -> None:
"""create_source rejects link-local addresses (169.254.x.x)."""
with pytest.raises(ValueError, match="private|reserved"):
await blocklist_service.create_source(db, "Bad", "http://169.254.169.254/latest/meta-data")
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_create_source_accepts_valid_public_url(
self, mock_validate: AsyncMock, db: aiosqlite.Connection
) -> None:
"""create_source accepts valid public HTTPS URLs (validation mocked)."""
mock_validate.return_value = None
source = await blocklist_service.create_source(db, "Good", "https://example.com/list.txt")
assert source.name == "Good"
assert source.url == "https://example.com/list.txt"
class TestImportLogPagination:
async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None:
"""list_import_logs returns an empty page when no logs exist."""