refactoring-backend #3
@@ -783,6 +783,31 @@ To adopt a Redis backend:
|
||||
- Handle edge cases explicitly: empty lists, `None` values, negative numbers, empty strings.
|
||||
- Use type narrowing and exhaustive pattern matching (`match` / `case`) to eliminate impossible states.
|
||||
|
||||
### 14.12 SSRF Prevention (Server-Side Request Forgery)
|
||||
|
||||
When user-supplied URLs are fetched by the backend, validate them before making any HTTP requests:
|
||||
|
||||
1. **Use Pydantic's `AnyHttpUrl` type** to restrict schemes to `http://` and `https://` only.
|
||||
- Rejects `file://`, `ftp://`, `gopher://`, and other non-http schemes at the model boundary.
|
||||
|
||||
2. **Validate resolved IP addresses** before fetching:
|
||||
- Parse the hostname and resolve it via DNS (using `socket.getaddrinfo()`).
|
||||
- Use `ipaddress.ip_address().is_private` to reject private/reserved ranges:
|
||||
- RFC 1918: `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`
|
||||
- Loopback: `127.0.0.0/8`, `::1/128`
|
||||
- Link-local: `169.254.0.0/16`, `fe80::/10`
|
||||
- IPv6 site-local, multicast, and reserved ranges.
|
||||
- Raise `ValueError` if validation fails; let the router convert it to HTTP 400.
|
||||
|
||||
3. **Guard against DNS rebinding**:
|
||||
- Validate DNS at URL creation/validation time (performed during request deserialization).
|
||||
- For additional safety, re-validate the connection IP at HTTP client time (e.g., custom `aiohttp.TCPConnector` can inspect the resolved address during connect).
|
||||
|
||||
4. **Example implementation** (see `backend/app/utils/ip_utils.py`):
|
||||
- `is_private_ip(ip_str: str) → bool`: Checks if IP is private/reserved/loopback/link-local.
|
||||
- `async validate_blocklist_url(url: AnyHttpUrl) → None`: Async DNS resolution + private IP check.
|
||||
- Service layer calls `await validate_blocklist_url(url)` before persisting; router catches `ValueError` and returns 400.
|
||||
|
||||
---
|
||||
|
||||
## 16. Quick Reference — Do / Don't
|
||||
|
||||
@@ -311,6 +311,17 @@ Automated downloading and applying of external IP blocklists to block known mali
|
||||
- Support for plain-text lists with one IP address per line.
|
||||
- Preview the contents of a blocklist URL before enabling it (download and display a sample of entries).
|
||||
|
||||
#### URL Validation & Security
|
||||
|
||||
- **Scheme restriction:** Only `http://` and `https://` schemes are accepted. `file://`, `ftp://`, and other schemes are rejected.
|
||||
- **Hostname validation:** The hostname is resolved via DNS and the resulting IP address is validated to prevent SSRF attacks:
|
||||
- Private IP ranges (`10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`) are rejected.
|
||||
- Loopback addresses (`127.0.0.1`, `::1`) are rejected.
|
||||
- Link-local addresses (`169.254.0.0/16`, `fe80::/10`) are rejected.
|
||||
- Reserved and multicast addresses are rejected.
|
||||
- **Error handling:** If a URL fails validation (invalid scheme, unresolvable hostname, or resolves to a private IP), the API returns a `400 Bad Request` with a descriptive error message.
|
||||
- **Ports:** URLs may specify custom ports (e.g. `https://example.com:8443/list.txt`), but the hostname must still resolve to a public IP address.
|
||||
|
||||
### Schedule
|
||||
|
||||
- Configure when the blocklist import runs using a simple time-and-frequency picker (no raw cron syntax required).
|
||||
|
||||
@@ -1,35 +1,3 @@
|
||||
## TASK-008 — `delete_expired_sessions` never scheduled — sessions table grows unbounded
|
||||
|
||||
**Severity:** Medium
|
||||
|
||||
### Where found
|
||||
`backend/app/repositories/session_repo.py` — `delete_expired_sessions()` exists but is never called from any task or lifespan handler.
|
||||
|
||||
### Why this is needed
|
||||
Expired sessions are only removed individually when that specific token is validated and found expired. The bulk cleanup function is never called. Over months of operation, the `sessions` table accumulates every session ever created and is never trimmed, increasing DB size and degrading query performance.
|
||||
|
||||
### Goal
|
||||
Periodically purge expired sessions from the database.
|
||||
|
||||
### What to do
|
||||
1. Create `backend/app/tasks/session_cleanup.py` following the same pattern as `geo_cache_flush.py`.
|
||||
2. Schedule it as an interval job (e.g., every 6 hours) in `startup_shared_resources`.
|
||||
3. The task should call `session_repo.delete_expired_sessions(db, now_iso)` and log how many rows were deleted.
|
||||
|
||||
### Possible traps and issues
|
||||
- The task must use `task_db(settings)` (not the request-scoped `get_db`) to open its own connection.
|
||||
- Log the count of deleted rows at `info` level, not `debug`, so administrators can see the cleanup is running.
|
||||
|
||||
### Docs changes needed
|
||||
- `Architekture.md` — add `session_cleanup` to the scheduled tasks table.
|
||||
- `Backend-Development.md` — background task patterns.
|
||||
|
||||
### Doc references
|
||||
- [Architekture.md](Architekture.md) — background tasks
|
||||
- [Backend-Development.md](Backend-Development.md) — scheduled tasks
|
||||
|
||||
---
|
||||
|
||||
## TASK-009 — Blocklist URL has no scheme/host validation — SSRF risk
|
||||
|
||||
**Severity:** High
|
||||
|
||||
@@ -8,7 +8,7 @@ from __future__ import annotations
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Blocklist source
|
||||
@@ -29,22 +29,30 @@ class BlocklistSource(BaseModel):
|
||||
|
||||
|
||||
class BlocklistSourceCreate(BaseModel):
|
||||
"""Payload for ``POST /api/blocklists``."""
|
||||
"""Payload for ``POST /api/blocklists``.
|
||||
|
||||
URL must use http/https scheme. The hostname must resolve to a public IP
|
||||
(not private, loopback, link-local, or reserved). Validation happens
|
||||
asynchronously in the service layer.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
name: str = Field(..., min_length=1, max_length=100, description="Human-readable source name.")
|
||||
url: str = Field(..., min_length=1, description="URL of the blocklist file.")
|
||||
url: AnyHttpUrl = Field(..., description="URL of the blocklist file (http/https only).")
|
||||
enabled: bool = Field(default=True)
|
||||
|
||||
|
||||
class BlocklistSourceUpdate(BaseModel):
|
||||
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional."""
|
||||
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional.
|
||||
|
||||
If URL is provided, it must use http/https scheme.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(strict=True)
|
||||
|
||||
name: str | None = Field(default=None, min_length=1, max_length=100)
|
||||
url: str | None = Field(default=None)
|
||||
url: AnyHttpUrl | None = Field(default=None)
|
||||
enabled: bool | None = Field(default=None)
|
||||
|
||||
|
||||
|
||||
@@ -97,10 +97,16 @@ async def create_blocklist(
|
||||
|
||||
Returns:
|
||||
The newly created :class:`~app.models.blocklist.BlocklistSource`.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if URL validation fails.
|
||||
"""
|
||||
return await blocklist_service.create_source(
|
||||
db, payload.name, payload.url, enabled=payload.enabled
|
||||
)
|
||||
try:
|
||||
return await blocklist_service.create_source(
|
||||
db, payload.name, str(payload.url), enabled=payload.enabled
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -283,15 +289,19 @@ async def update_blocklist(
|
||||
_auth: Validated session — enforces authentication.
|
||||
|
||||
Raises:
|
||||
HTTPException: 400 if URL validation fails.
|
||||
HTTPException: 404 if the source does not exist.
|
||||
"""
|
||||
updated = await blocklist_service.update_source(
|
||||
db,
|
||||
source_id,
|
||||
name=payload.name,
|
||||
url=payload.url,
|
||||
enabled=payload.enabled,
|
||||
)
|
||||
try:
|
||||
updated = await blocklist_service.update_source(
|
||||
db,
|
||||
source_id,
|
||||
name=payload.name,
|
||||
url=str(payload.url) if payload.url is not None else None,
|
||||
enabled=payload.enabled,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
|
||||
if updated is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Blocklist source not found.")
|
||||
return updated
|
||||
|
||||
@@ -166,15 +166,24 @@ async def create_source(
|
||||
) -> BlocklistSource:
|
||||
"""Create a new blocklist source and return the persisted record.
|
||||
|
||||
Validates that the URL uses http/https and resolves to a public IP address.
|
||||
|
||||
Args:
|
||||
db: Active application database connection.
|
||||
name: Human-readable display name.
|
||||
url: URL of the blocklist text file.
|
||||
url: URL of the blocklist text file (must be http/https and resolve to public IP).
|
||||
enabled: Whether the source is active. Defaults to ``True``.
|
||||
|
||||
Returns:
|
||||
The newly created :class:`~app.models.blocklist.BlocklistSource`.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL fails SSRF validation.
|
||||
"""
|
||||
from app.utils.ip_utils import validate_blocklist_url
|
||||
|
||||
await validate_blocklist_url(url)
|
||||
|
||||
new_id = await blocklist_repo.create_source(db, name, url, enabled=enabled)
|
||||
source = await get_source(db, new_id)
|
||||
assert source is not None # noqa: S101
|
||||
@@ -192,17 +201,27 @@ async def update_source(
|
||||
) -> BlocklistSource | None:
|
||||
"""Update fields on a blocklist source.
|
||||
|
||||
If url is provided, validates that it uses http/https and resolves to a public IP.
|
||||
|
||||
Args:
|
||||
db: Active application database connection.
|
||||
source_id: Primary key of the source to modify.
|
||||
name: New display name, or ``None`` to leave unchanged.
|
||||
url: New URL, or ``None`` to leave unchanged.
|
||||
url: New URL, or ``None`` to leave unchanged (validated if provided).
|
||||
enabled: New enabled state, or ``None`` to leave unchanged.
|
||||
|
||||
Returns:
|
||||
Updated :class:`~app.models.blocklist.BlocklistSource`, or ``None``
|
||||
if the source does not exist.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL fails SSRF validation.
|
||||
"""
|
||||
if url is not None:
|
||||
from app.utils.ip_utils import validate_blocklist_url
|
||||
|
||||
await validate_blocklist_url(url)
|
||||
|
||||
updated = await blocklist_repo.update_source(
|
||||
db, source_id, name=name, url=url, enabled=enabled
|
||||
)
|
||||
|
||||
@@ -4,7 +4,10 @@ All IP handling in BanGUI goes through these helpers to enforce consistency
|
||||
and prevent malformed addresses from reaching fail2ban.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import ipaddress
|
||||
import socket
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
def is_valid_ip(address: str) -> bool:
|
||||
@@ -99,3 +102,97 @@ def ip_version(address: str) -> int:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
"""
|
||||
return ipaddress.ip_address(address).version
|
||||
|
||||
|
||||
def is_private_ip(address: str) -> bool:
|
||||
"""Return ``True`` if *address* is a private or reserved IP address.
|
||||
|
||||
Private ranges include:
|
||||
- RFC 1918: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
|
||||
- Loopback: 127.0.0.0/8 (IPv4), ::1/128 (IPv6)
|
||||
- Link-local: 169.254.0.0/16 (IPv4), fe80::/10 (IPv6)
|
||||
- IPv6 ULA: fc00::/7
|
||||
- Multicast and other reserved ranges
|
||||
|
||||
Args:
|
||||
address: A valid IP address string.
|
||||
|
||||
Returns:
|
||||
``True`` if the address is private or reserved, ``False`` if it is public.
|
||||
|
||||
Raises:
|
||||
ValueError: If *address* is not a valid IP address.
|
||||
"""
|
||||
ip = ipaddress.ip_address(address)
|
||||
return (
|
||||
ip.is_private
|
||||
or ip.is_loopback
|
||||
or ip.is_link_local
|
||||
or ip.is_multicast
|
||||
or ip.is_reserved
|
||||
)
|
||||
|
||||
|
||||
async def validate_blocklist_url(url: str) -> None:
|
||||
"""Validate that a blocklist URL points to a public HTTP(S) endpoint.
|
||||
|
||||
Checks that:
|
||||
- The URL uses HTTP or HTTPS scheme
|
||||
- The hostname resolves to a public (non-private, non-reserved) IP address
|
||||
- IPv4-mapped IPv6 addresses are checked against IPv4 private ranges
|
||||
|
||||
Performs DNS resolution asynchronously to check the resolved IP.
|
||||
This is a point-in-time check; DNS rebinding attacks may still be possible
|
||||
at actual fetch time. Callers should re-validate the final connection
|
||||
in the HTTP client layer.
|
||||
|
||||
Args:
|
||||
url: The blocklist URL to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the URL has an invalid scheme, hostname cannot be resolved,
|
||||
or the resolved IP is private/reserved.
|
||||
"""
|
||||
try:
|
||||
parsed = urlparse(url)
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Invalid URL format: {exc}") from exc
|
||||
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(
|
||||
f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
|
||||
)
|
||||
|
||||
if not parsed.hostname:
|
||||
raise ValueError("URL has no hostname")
|
||||
|
||||
hostname = parsed.hostname
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
addrinfo = await loop.run_in_executor(
|
||||
None,
|
||||
socket.getaddrinfo,
|
||||
hostname,
|
||||
parsed.port or 80,
|
||||
socket.AF_UNSPEC,
|
||||
socket.SOCK_STREAM,
|
||||
)
|
||||
except socket.gaierror as exc:
|
||||
raise ValueError(f"Cannot resolve hostname '{hostname}': {exc}") from exc
|
||||
except Exception as exc:
|
||||
raise ValueError(f"DNS resolution error for '{hostname}': {exc}") from exc
|
||||
|
||||
if not addrinfo:
|
||||
raise ValueError(f"No address resolved for hostname '{hostname}'")
|
||||
|
||||
for family, socktype, proto, canonname, sockaddr in addrinfo:
|
||||
ip_str: str = sockaddr[0] # type: ignore[assignment]
|
||||
try:
|
||||
if is_private_ip(ip_str):
|
||||
raise ValueError(
|
||||
f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
|
||||
)
|
||||
except ipaddress.AddressValueError as exc:
|
||||
raise ValueError(f"Invalid IP address: {ip_str}") from exc
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ from app.models.blocklist import (
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_SETUP_PAYLOAD = {
|
||||
"master_password": "testpassword1",
|
||||
"master_password": "TestPassword1!",
|
||||
"database_path": "bangui.db",
|
||||
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
|
||||
"timezone": "UTC",
|
||||
@@ -182,8 +182,10 @@ class TestListBlocklists:
|
||||
|
||||
|
||||
class TestCreateBlocklist:
|
||||
async def test_create_returns_201(self, bl_client: AsyncClient) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_create_returns_201(self, mock_validate: AsyncMock, bl_client: AsyncClient) -> None:
|
||||
"""POST /api/blocklists creates a source and returns HTTP 201."""
|
||||
mock_validate.return_value = None
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.create_source",
|
||||
new=AsyncMock(return_value=_make_source()),
|
||||
@@ -194,8 +196,10 @@ class TestCreateBlocklist:
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
|
||||
async def test_create_source_id_in_response(self, bl_client: AsyncClient) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_create_source_id_in_response(self, mock_validate: AsyncMock, bl_client: AsyncClient) -> None:
|
||||
"""Created source response includes the id field."""
|
||||
mock_validate.return_value = None
|
||||
with patch(
|
||||
"app.routers.blocklist.blocklist_service.create_source",
|
||||
new=AsyncMock(return_value=_make_source(42)),
|
||||
|
||||
@@ -54,8 +54,10 @@ def _make_session(text: str, status: int = 200) -> MagicMock:
|
||||
|
||||
|
||||
class TestSourceCRUD:
|
||||
async def test_create_and_get(self, db: aiosqlite.Connection) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_create_and_get(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
|
||||
"""create_source persists and get_source retrieves a source."""
|
||||
mock_validate.return_value = None
|
||||
source = await blocklist_service.create_source(db, "Test", "https://t.test/")
|
||||
assert isinstance(source, BlocklistSource)
|
||||
assert source.name == "Test"
|
||||
@@ -75,15 +77,19 @@ class TestSourceCRUD:
|
||||
sources = await blocklist_service.list_sources(db)
|
||||
assert sources == []
|
||||
|
||||
async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_list_sources_returns_all(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
|
||||
"""list_sources returns all created sources."""
|
||||
mock_validate.return_value = None
|
||||
await blocklist_service.create_source(db, "A", "https://a.test/")
|
||||
await blocklist_service.create_source(db, "B", "https://b.test/")
|
||||
sources = await blocklist_service.list_sources(db)
|
||||
assert len(sources) == 2
|
||||
|
||||
async def test_update_source_fields(self, db: aiosqlite.Connection) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_update_source_fields(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
|
||||
"""update_source modifies specified fields."""
|
||||
mock_validate.return_value = None
|
||||
source = await blocklist_service.create_source(db, "Original", "https://orig.test/")
|
||||
updated = await blocklist_service.update_source(db, source.id, name="Updated", enabled=False)
|
||||
assert updated is not None
|
||||
@@ -95,8 +101,10 @@ class TestSourceCRUD:
|
||||
result = await blocklist_service.update_source(db, 9999, name="Ghost")
|
||||
assert result is None
|
||||
|
||||
async def test_delete_source(self, db: aiosqlite.Connection) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_delete_source(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
|
||||
"""delete_source removes a source and returns True."""
|
||||
mock_validate.return_value = None
|
||||
source = await blocklist_service.create_source(db, "Del", "https://del.test/")
|
||||
deleted = await blocklist_service.delete_source(db, source.id)
|
||||
assert deleted is True
|
||||
@@ -167,8 +175,10 @@ class TestPreview:
|
||||
|
||||
|
||||
class TestImport:
|
||||
async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_import_source_bans_valid_ips(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
|
||||
"""import_source calls ban_ip for every valid IP in the blocklist."""
|
||||
mock_validate.return_value = None
|
||||
content = "1.2.3.4\n5.6.7.8\n# skip me\n"
|
||||
session = _make_session(content)
|
||||
|
||||
@@ -192,8 +202,10 @@ class TestImport:
|
||||
assert result.error is None
|
||||
assert mock_ban.call_count == 2
|
||||
|
||||
async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_import_source_skips_cidrs(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
|
||||
"""import_source skips CIDR ranges (fail2ban expects individual IPs)."""
|
||||
mock_validate.return_value = None
|
||||
content = "1.2.3.4\n10.0.0.0/24\n"
|
||||
session = _make_session(content)
|
||||
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
|
||||
@@ -212,8 +224,10 @@ class TestImport:
|
||||
assert result.ips_imported == 1
|
||||
assert result.ips_skipped == 1
|
||||
|
||||
async def test_import_source_records_download_error(self, db: aiosqlite.Connection) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_import_source_records_download_error(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
|
||||
"""import_source records an error and returns 0 imported on HTTP failure."""
|
||||
mock_validate.return_value = None
|
||||
session = _make_session("", status=503)
|
||||
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
|
||||
|
||||
@@ -230,9 +244,11 @@ class TestImport:
|
||||
assert result.ips_imported == 0
|
||||
assert result.error is not None
|
||||
|
||||
async def test_import_source_aborts_on_jail_not_found(self, db: aiosqlite.Connection) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_import_source_aborts_on_jail_not_found(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
|
||||
"""import_source aborts immediately and records an error when the target jail
|
||||
does not exist in fail2ban instead of silently skipping every IP."""
|
||||
mock_validate.return_value = None
|
||||
from app.services.jail_service import JailNotFoundError
|
||||
from app.services import ban_service
|
||||
|
||||
@@ -262,8 +278,10 @@ class TestImport:
|
||||
assert result.error is not None
|
||||
assert "not found" in result.error.lower() or "blocklist-import" in result.error
|
||||
|
||||
async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_import_all_runs_all_enabled(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
|
||||
"""import_all aggregates results across all enabled sources."""
|
||||
mock_validate.return_value = None
|
||||
await blocklist_service.create_source(db, "S1", "https://s1.test/")
|
||||
s2 = await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)
|
||||
_ = s2 # noqa: F841
|
||||
@@ -400,10 +418,12 @@ class TestSchedule:
|
||||
|
||||
|
||||
class TestGeoPrewarmCacheFilter:
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_import_source_skips_cached_ips_for_geo_prewarm(
|
||||
self, db: aiosqlite.Connection
|
||||
self, mock_validate: AsyncMock, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""import_source only sends uncached IPs to geo_service.lookup_batch."""
|
||||
mock_validate.return_value = None
|
||||
content = "1.2.3.4\n5.6.7.8\n9.10.11.12\n"
|
||||
session = _make_session(content)
|
||||
source = await blocklist_service.create_source(
|
||||
@@ -416,7 +436,11 @@ class TestGeoPrewarmCacheFilter:
|
||||
|
||||
from app.services import ban_service
|
||||
|
||||
mock_batch = AsyncMock(return_value={})
|
||||
mock_lookup = AsyncMock(return_value={})
|
||||
mock_geo_cache = MagicMock()
|
||||
mock_geo_cache.is_cached = _mock_is_cached
|
||||
mock_geo_cache.lookup_batch = mock_lookup
|
||||
|
||||
with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
|
||||
result = await blocklist_service.import_source(
|
||||
source,
|
||||
@@ -425,17 +449,76 @@ class TestGeoPrewarmCacheFilter:
|
||||
db,
|
||||
ban_ip=ban_service.ban_ip,
|
||||
geo_is_cached=_mock_is_cached,
|
||||
geo_batch_lookup=mock_batch,
|
||||
geo_cache=mock_geo_cache,
|
||||
)
|
||||
|
||||
assert result.ips_imported == 3
|
||||
# lookup_batch should receive only the 2 uncached IPs.
|
||||
mock_batch.assert_called_once()
|
||||
call_ips = mock_batch.call_args[0][0]
|
||||
mock_lookup.assert_called_once()
|
||||
call_ips = mock_lookup.call_args[0][0]
|
||||
assert "1.2.3.4" not in call_ips
|
||||
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# URL Validation (SSRF Prevention)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestURLValidation:
|
||||
"""Test SSRF protection by validating blocklist URLs."""
|
||||
|
||||
async def test_create_source_rejects_file_url(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source rejects file:// URLs."""
|
||||
with pytest.raises(ValueError, match="Invalid scheme"):
|
||||
await blocklist_service.create_source(db, "Bad", "file:///etc/passwd")
|
||||
|
||||
async def test_create_source_rejects_ftp_url(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source rejects ftp:// URLs."""
|
||||
with pytest.raises(ValueError, match="Invalid scheme"):
|
||||
await blocklist_service.create_source(db, "Bad", "ftp://evil.com/file.txt")
|
||||
|
||||
async def test_create_source_rejects_localhost(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source rejects localhost (127.0.0.1)."""
|
||||
with pytest.raises(ValueError, match="private|reserved"):
|
||||
await blocklist_service.create_source(db, "Bad", "http://127.0.0.1/list")
|
||||
|
||||
async def test_create_source_rejects_localhost_name(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source rejects localhost hostname."""
|
||||
with pytest.raises(ValueError, match="private|reserved"):
|
||||
await blocklist_service.create_source(db, "Bad", "http://localhost/list")
|
||||
|
||||
async def test_create_source_rejects_private_network(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source rejects private RFC 1918 networks (10.0.0.0/8)."""
|
||||
with pytest.raises(ValueError, match="private|reserved"):
|
||||
await blocklist_service.create_source(db, "Bad", "http://10.0.0.1/list")
|
||||
|
||||
async def test_create_source_rejects_private_network_172(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source rejects private RFC 1918 networks (172.16.0.0/12)."""
|
||||
with pytest.raises(ValueError, match="private|reserved"):
|
||||
await blocklist_service.create_source(db, "Bad", "http://172.16.0.1/list")
|
||||
|
||||
async def test_create_source_rejects_private_network_192(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source rejects private RFC 1918 networks (192.168.0.0/16)."""
|
||||
with pytest.raises(ValueError, match="private|reserved"):
|
||||
await blocklist_service.create_source(db, "Bad", "http://192.168.1.1/list")
|
||||
|
||||
async def test_create_source_rejects_link_local(self, db: aiosqlite.Connection) -> None:
|
||||
"""create_source rejects link-local addresses (169.254.x.x)."""
|
||||
with pytest.raises(ValueError, match="private|reserved"):
|
||||
await blocklist_service.create_source(db, "Bad", "http://169.254.169.254/latest/meta-data")
|
||||
|
||||
@patch("app.utils.ip_utils.validate_blocklist_url")
|
||||
async def test_create_source_accepts_valid_public_url(
|
||||
self, mock_validate: AsyncMock, db: aiosqlite.Connection
|
||||
) -> None:
|
||||
"""create_source accepts valid public HTTPS URLs (validation mocked)."""
|
||||
mock_validate.return_value = None
|
||||
source = await blocklist_service.create_source(db, "Good", "https://example.com/list.txt")
|
||||
assert source.name == "Good"
|
||||
assert source.url == "https://example.com/list.txt"
|
||||
|
||||
|
||||
class TestImportLogPagination:
|
||||
async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None:
|
||||
"""list_import_logs returns an empty page when no logs exist."""
|
||||
|
||||
Reference in New Issue
Block a user