refactoring-backend #3

Merged
lukas.pupkalipinski merged 403 commits from refactoring-backend into main 2026-05-20 20:23:46 +02:00
9 changed files with 291 additions and 66 deletions
Showing only changes of commit 4ab767e3d4

@@ -783,6 +783,31 @@ To adopt a Redis backend:
- Handle edge cases explicitly: empty lists, `None` values, negative numbers, empty strings.
- Use type narrowing and exhaustive pattern matching (`match` / `case`) to eliminate impossible states.
### 14.12 SSRF Prevention (Server-Side Request Forgery)
When user-supplied URLs are fetched by the backend, validate them before making any HTTP requests:
1. **Use Pydantic's `AnyHttpUrl` type** to restrict schemes to `http://` and `https://` only.
- Rejects `file://`, `ftp://`, `gopher://`, and other non-http schemes at the model boundary.
2. **Validate resolved IP addresses** before fetching:
- Parse the hostname and resolve it via DNS (using `socket.getaddrinfo()`).
- Use `ipaddress.ip_address()` and reject private/reserved ranges (`is_private`, `is_loopback`, `is_link_local`, `is_multicast`, `is_reserved`):
- RFC 1918: `10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`
- Loopback: `127.0.0.0/8`, `::1/128`
- Link-local: `169.254.0.0/16`, `fe80::/10`
- IPv6 unique local (`fc00::/7`), multicast, and reserved ranges.
- Raise `ValueError` if validation fails; let the router convert it to HTTP 400.
3. **Guard against DNS rebinding**:
- Validate DNS at URL creation/update time (performed in the service layer, before the source is persisted).
- For additional safety, re-validate the connection IP at HTTP client time (e.g., custom `aiohttp.TCPConnector` can inspect the resolved address during connect).
4. **Example implementation** (see `backend/app/utils/ip_utils.py`):
- `is_private_ip(ip_str: str) → bool`: Checks if IP is private/reserved/loopback/link-local.
- `async validate_blocklist_url(url: str) → None`: Async DNS resolution + private IP check.
- Service layer calls `await validate_blocklist_url(url)` before persisting; router catches `ValueError` and returns 400.
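A minimal sketch of the re-validation idea in step 3, using only the standard library: resolve once, reject any non-public result, and hand the validated addresses to the HTTP client so that it connects to an IP that was actually checked. The helper names `resolve_public_addresses` and `_is_non_public` are hypothetical and are not part of `ip_utils.py`; how the returned addresses are pinned in the HTTP client is left open here.

```python
import asyncio
import ipaddress
import socket


def _is_non_public(ip_str: str) -> bool:
    """Apply the category checks listed in step 2 to a single address."""
    ip = ipaddress.ip_address(ip_str)
    return (
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_reserved
    )


async def resolve_public_addresses(hostname: str, port: int) -> list[str]:
    """Resolve *hostname* and return its addresses, or raise if any is non-public.

    Hypothetical helper: the caller connects to one of the returned IPs
    instead of resolving the hostname a second time.
    """
    loop = asyncio.get_running_loop()
    try:
        infos = await loop.getaddrinfo(hostname, port, type=socket.SOCK_STREAM)
    except socket.gaierror as exc:
        raise ValueError(f"Cannot resolve hostname '{hostname}': {exc}") from exc

    addresses: list[str] = []
    for _family, _type, _proto, _canonname, sockaddr in infos:
        ip_str = str(sockaddr[0])
        if _is_non_public(ip_str):
            raise ValueError(f"'{hostname}' resolves to non-public IP: {ip_str}")
        if ip_str not in addresses:
            addresses.append(ip_str)
    if not addresses:
        raise ValueError(f"No address resolved for hostname '{hostname}'")
    return addresses
```

Rejecting the whole hostname when any one result is non-public is the stricter of two possible policies; filtering out only the offending addresses would also work but makes the outcome depend on resolver ordering.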
---
## 16. Quick Reference — Do / Don't

@@ -311,6 +311,17 @@ Automated downloading and applying of external IP blocklists to block known mali
- Support for plain-text lists with one IP address per line.
- Preview the contents of a blocklist URL before enabling it (download and display a sample of entries).
#### URL Validation & Security
- **Scheme restriction:** Only `http://` and `https://` schemes are accepted. `file://`, `ftp://`, and other schemes are rejected.
- **Hostname validation:** The hostname is resolved via DNS and the resulting IP address is validated to prevent SSRF attacks:
- Private IP ranges (`10.0.0.0/8`, `172.16.0.0/12`, `192.168.0.0/16`) are rejected.
- Loopback addresses (`127.0.0.1`, `::1`) are rejected.
- Link-local addresses (`169.254.0.0/16`, `fe80::/10`) are rejected.
- Reserved and multicast addresses are rejected.
- **Error handling:** If a URL fails validation (invalid scheme, unresolvable hostname, or resolves to a private IP), the API returns a `400 Bad Request` with a descriptive error message.
- **Ports:** URLs may specify custom ports (e.g. `https://example.com:8443/list.txt`), but the hostname must still resolve to a public IP address.
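As an illustration of the rules above, the following sketch classifies a few sample addresses with Python's standard `ipaddress` module. The helper name `is_rejected` and the sample addresses are chosen for this example only and are not taken from the backend code.

```python
import ipaddress


def is_rejected(address: str) -> bool:
    """Return True if *address* falls in a range the validation rules reject."""
    ip = ipaddress.ip_address(address)
    return (
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_reserved
    )


samples = [
    "10.0.0.1",         # RFC 1918 private range
    "127.0.0.1",        # loopback
    "169.254.169.254",  # link-local
    "224.0.0.1",        # multicast
    "fe80::1",          # IPv6 link-local
    "8.8.8.8",          # public address
]
for address in samples:
    verdict = "rejected" if is_rejected(address) else "accepted"
    print(f"{address:>16}  {verdict}")
```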
### Schedule
- Configure when the blocklist import runs using a simple time-and-frequency picker (no raw cron syntax required).

@@ -1,35 +1,3 @@
## TASK-008 — `delete_expired_sessions` never scheduled — sessions table grows unbounded
**Severity:** Medium
### Where found
`backend/app/repositories/session_repo.py`: `delete_expired_sessions()` exists but is never called from any task or lifespan handler.
### Why this is needed
Expired sessions are only removed individually when that specific token is validated and found expired. The bulk cleanup function is never called. Over months of operation, the `sessions` table accumulates every session ever created and is never trimmed, increasing DB size and degrading query performance.
### Goal
Periodically purge expired sessions from the database.
### What to do
1. Create `backend/app/tasks/session_cleanup.py` following the same pattern as `geo_cache_flush.py`.
2. Schedule it as an interval job (e.g., every 6 hours) in `startup_shared_resources`.
3. The task should call `session_repo.delete_expired_sessions(db, now_iso)` and log how many rows were deleted.
### Possible traps and issues
- The task must use `task_db(settings)` (not the request-scoped `get_db`) to open its own connection.
- Log the count of deleted rows at `info` level, not `debug`, so administrators can see the cleanup is running.
### Docs changes needed
- `Architekture.md` — add `session_cleanup` to the scheduled tasks table.
- `Backend-Development.md` — background task patterns.
### Doc references
- [Architekture.md](Architekture.md) — background tasks
- [Backend-Development.md](Backend-Development.md) — scheduled tasks
---
## TASK-009 — Blocklist URL has no scheme/host validation — SSRF risk
**Severity:** High

@@ -8,7 +8,7 @@ from __future__ import annotations
from enum import StrEnum
from pydantic import BaseModel, ConfigDict, Field
from pydantic import AnyHttpUrl, BaseModel, ConfigDict, Field
# ---------------------------------------------------------------------------
# Blocklist source
@@ -29,22 +29,30 @@ class BlocklistSource(BaseModel):
class BlocklistSourceCreate(BaseModel):
"""Payload for ``POST /api/blocklists``."""
"""Payload for ``POST /api/blocklists``.
URL must use http/https scheme. The hostname must resolve to a public IP
(not private, loopback, link-local, or reserved). Validation happens
asynchronously in the service layer.
"""
model_config = ConfigDict(strict=True)
name: str = Field(..., min_length=1, max_length=100, description="Human-readable source name.")
url: str = Field(..., min_length=1, description="URL of the blocklist file.")
url: AnyHttpUrl = Field(..., description="URL of the blocklist file (http/https only).")
enabled: bool = Field(default=True)
class BlocklistSourceUpdate(BaseModel):
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional."""
"""Payload for ``PUT /api/blocklists/{id}``. All fields are optional.
If URL is provided, it must use http/https scheme.
"""
model_config = ConfigDict(strict=True)
name: str | None = Field(default=None, min_length=1, max_length=100)
url: str | None = Field(default=None)
url: AnyHttpUrl | None = Field(default=None)
enabled: bool | None = Field(default=None)

@@ -97,10 +97,16 @@ async def create_blocklist(
Returns:
The newly created :class:`~app.models.blocklist.BlocklistSource`.
Raises:
HTTPException: 400 if URL validation fails.
"""
return await blocklist_service.create_source(
db, payload.name, payload.url, enabled=payload.enabled
)
try:
return await blocklist_service.create_source(
db, payload.name, str(payload.url), enabled=payload.enabled
)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
# ---------------------------------------------------------------------------
@@ -283,15 +289,19 @@ async def update_blocklist(
_auth: Validated session — enforces authentication.
Raises:
HTTPException: 400 if URL validation fails.
HTTPException: 404 if the source does not exist.
"""
updated = await blocklist_service.update_source(
db,
source_id,
name=payload.name,
url=payload.url,
enabled=payload.enabled,
)
try:
updated = await blocklist_service.update_source(
db,
source_id,
name=payload.name,
url=str(payload.url) if payload.url is not None else None,
enabled=payload.enabled,
)
except ValueError as exc:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
if updated is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Blocklist source not found.")
return updated

@@ -166,15 +166,24 @@ async def create_source(
) -> BlocklistSource:
"""Create a new blocklist source and return the persisted record.
Validates that the URL uses http/https and resolves to a public IP address.
Args:
db: Active application database connection.
name: Human-readable display name.
url: URL of the blocklist text file.
url: URL of the blocklist text file (must be http/https and resolve to public IP).
enabled: Whether the source is active. Defaults to ``True``.
Returns:
The newly created :class:`~app.models.blocklist.BlocklistSource`.
Raises:
ValueError: If the URL fails SSRF validation.
"""
from app.utils.ip_utils import validate_blocklist_url
await validate_blocklist_url(url)
new_id = await blocklist_repo.create_source(db, name, url, enabled=enabled)
source = await get_source(db, new_id)
assert source is not None # noqa: S101
@@ -192,17 +201,27 @@ async def update_source(
) -> BlocklistSource | None:
"""Update fields on a blocklist source.
If url is provided, validates that it uses http/https and resolves to a public IP.
Args:
db: Active application database connection.
source_id: Primary key of the source to modify.
name: New display name, or ``None`` to leave unchanged.
url: New URL, or ``None`` to leave unchanged.
url: New URL, or ``None`` to leave unchanged (validated if provided).
enabled: New enabled state, or ``None`` to leave unchanged.
Returns:
Updated :class:`~app.models.blocklist.BlocklistSource`, or ``None``
if the source does not exist.
Raises:
ValueError: If the URL fails SSRF validation.
"""
if url is not None:
from app.utils.ip_utils import validate_blocklist_url
await validate_blocklist_url(url)
updated = await blocklist_repo.update_source(
db, source_id, name=name, url=url, enabled=enabled
)

@@ -4,7 +4,10 @@ All IP handling in BanGUI goes through these helpers to enforce consistency
and prevent malformed addresses from reaching fail2ban.
"""
import asyncio
import ipaddress
import socket
from urllib.parse import urlparse
def is_valid_ip(address: str) -> bool:
@@ -99,3 +102,97 @@ def ip_version(address: str) -> int:
ValueError: If *address* is not a valid IP address.
"""
return ipaddress.ip_address(address).version
def is_private_ip(address: str) -> bool:
    """Return ``True`` if *address* is a private or reserved IP address.

    Private ranges include:
    - RFC 1918: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16
    - Loopback: 127.0.0.0/8 (IPv4), ::1/128 (IPv6)
    - Link-local: 169.254.0.0/16 (IPv4), fe80::/10 (IPv6)
    - IPv6 ULA: fc00::/7
    - Multicast and other reserved ranges

    Args:
        address: A valid IP address string.

    Returns:
        ``True`` if the address is private or reserved, ``False`` if it is public.

    Raises:
        ValueError: If *address* is not a valid IP address.
    """
    ip = ipaddress.ip_address(address)
    return (
        ip.is_private
        or ip.is_loopback
        or ip.is_link_local
        or ip.is_multicast
        or ip.is_reserved
    )

async def validate_blocklist_url(url: str) -> None:
    """Validate that a blocklist URL points to a public HTTP(S) endpoint.

    Checks that:
    - The URL uses HTTP or HTTPS scheme
    - The hostname resolves to a public (non-private, non-reserved) IP address
    - IPv4-mapped IPv6 addresses are checked against IPv4 private ranges

    Performs DNS resolution asynchronously to check the resolved IP.
    This is a point-in-time check; DNS rebinding attacks may still be possible
    at actual fetch time. Callers should re-validate the final connection
    in the HTTP client layer.

    Args:
        url: The blocklist URL to validate.

    Raises:
        ValueError: If the URL has an invalid scheme, hostname cannot be resolved,
            or the resolved IP is private/reserved.
    """
    try:
        parsed = urlparse(url)
    except Exception as exc:
        raise ValueError(f"Invalid URL format: {exc}") from exc

    if parsed.scheme not in ("http", "https"):
        raise ValueError(
            f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
        )
    if not parsed.hostname:
        raise ValueError("URL has no hostname")
    hostname = parsed.hostname

    try:
        # get_running_loop() is the supported way to obtain the loop inside a coroutine.
        loop = asyncio.get_running_loop()
        # Run the blocking resolver call in the default executor.
        addrinfo = await loop.run_in_executor(
            None,
            socket.getaddrinfo,
            hostname,
            # Fall back to the scheme's default port when none is given.
            parsed.port or (443 if parsed.scheme == "https" else 80),
            socket.AF_UNSPEC,
            socket.SOCK_STREAM,
        )
    except socket.gaierror as exc:
        raise ValueError(f"Cannot resolve hostname '{hostname}': {exc}") from exc
    except Exception as exc:
        raise ValueError(f"DNS resolution error for '{hostname}': {exc}") from exc

    if not addrinfo:
        raise ValueError(f"No address resolved for hostname '{hostname}'")

    # Every resolved address must be public; a single private result rejects the URL.
    for _family, _socktype, _proto, _canonname, sockaddr in addrinfo:
        ip_str = str(sockaddr[0])
        try:
            # ipaddress.ip_address() raises ValueError for unparseable input.
            private = is_private_ip(ip_str)
        except ValueError as exc:
            raise ValueError(f"Invalid IP address: {ip_str}") from exc
        if private:
            raise ValueError(
                f"Hostname '{hostname}' resolves to private/reserved IP: {ip_str}"
            )
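The validator relies on `urllib.parse.urlparse` for the scheme, hostname, and port. The following sketch shows what those attributes return for a few inputs and how a scheme-aware default port could be derived. The `DEFAULT_PORTS` mapping and the `split_target` helper are hypothetical and not part of `ip_utils.py`.

```python
from urllib.parse import urlparse

# Hypothetical mapping: default port for each allowed scheme.
DEFAULT_PORTS = {"http": 80, "https": 443}


def split_target(url: str) -> tuple[str, str, int]:
    """Return (scheme, hostname, port) or raise ValueError for unsupported URLs."""
    parsed = urlparse(url)
    if parsed.scheme not in DEFAULT_PORTS:
        raise ValueError(
            f"Invalid scheme '{parsed.scheme}': only http and https are allowed"
        )
    if not parsed.hostname:
        raise ValueError("URL has no hostname")
    # parsed.port is None when the URL carries no explicit port.
    port = parsed.port or DEFAULT_PORTS[parsed.scheme]
    return parsed.scheme, parsed.hostname, port


print(split_target("https://Example.COM/list.txt"))  # hostname is lowercased
print(split_target("http://[::1]:8080/list"))        # brackets are stripped from IPv6 literals
print(split_target("http://user@10.0.0.1/list"))     # userinfo is not part of the hostname
```

Because `hostname` excludes userinfo and brackets, the value passed to DNS resolution is the actual connection target, not whatever text precedes an `@` in the URL.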

@@ -29,7 +29,7 @@ from app.models.blocklist import (
# ---------------------------------------------------------------------------
_SETUP_PAYLOAD = {
"master_password": "testpassword1",
"master_password": "TestPassword1!",
"database_path": "bangui.db",
"fail2ban_socket": "/var/run/fail2ban/fail2ban.sock",
"timezone": "UTC",
@@ -182,8 +182,10 @@ class TestListBlocklists:
class TestCreateBlocklist:
async def test_create_returns_201(self, bl_client: AsyncClient) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_create_returns_201(self, mock_validate: AsyncMock, bl_client: AsyncClient) -> None:
"""POST /api/blocklists creates a source and returns HTTP 201."""
mock_validate.return_value = None
with patch(
"app.routers.blocklist.blocklist_service.create_source",
new=AsyncMock(return_value=_make_source()),
@@ -194,8 +196,10 @@ class TestCreateBlocklist:
)
assert resp.status_code == 201
async def test_create_source_id_in_response(self, bl_client: AsyncClient) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_create_source_id_in_response(self, mock_validate: AsyncMock, bl_client: AsyncClient) -> None:
"""Created source response includes the id field."""
mock_validate.return_value = None
with patch(
"app.routers.blocklist.blocklist_service.create_source",
new=AsyncMock(return_value=_make_source(42)),

@@ -54,8 +54,10 @@ def _make_session(text: str, status: int = 200) -> MagicMock:
class TestSourceCRUD:
async def test_create_and_get(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_create_and_get(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""create_source persists and get_source retrieves a source."""
mock_validate.return_value = None
source = await blocklist_service.create_source(db, "Test", "https://t.test/")
assert isinstance(source, BlocklistSource)
assert source.name == "Test"
@@ -75,15 +77,19 @@ class TestSourceCRUD:
sources = await blocklist_service.list_sources(db)
assert sources == []
async def test_list_sources_returns_all(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_list_sources_returns_all(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""list_sources returns all created sources."""
mock_validate.return_value = None
await blocklist_service.create_source(db, "A", "https://a.test/")
await blocklist_service.create_source(db, "B", "https://b.test/")
sources = await blocklist_service.list_sources(db)
assert len(sources) == 2
async def test_update_source_fields(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_update_source_fields(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""update_source modifies specified fields."""
mock_validate.return_value = None
source = await blocklist_service.create_source(db, "Original", "https://orig.test/")
updated = await blocklist_service.update_source(db, source.id, name="Updated", enabled=False)
assert updated is not None
@@ -95,8 +101,10 @@ class TestSourceCRUD:
result = await blocklist_service.update_source(db, 9999, name="Ghost")
assert result is None
async def test_delete_source(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_delete_source(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""delete_source removes a source and returns True."""
mock_validate.return_value = None
source = await blocklist_service.create_source(db, "Del", "https://del.test/")
deleted = await blocklist_service.delete_source(db, source.id)
assert deleted is True
@@ -167,8 +175,10 @@ class TestPreview:
class TestImport:
async def test_import_source_bans_valid_ips(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_bans_valid_ips(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_source calls ban_ip for every valid IP in the blocklist."""
mock_validate.return_value = None
content = "1.2.3.4\n5.6.7.8\n# skip me\n"
session = _make_session(content)
@@ -192,8 +202,10 @@ class TestImport:
assert result.error is None
assert mock_ban.call_count == 2
async def test_import_source_skips_cidrs(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_skips_cidrs(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_source skips CIDR ranges (fail2ban expects individual IPs)."""
mock_validate.return_value = None
content = "1.2.3.4\n10.0.0.0/24\n"
session = _make_session(content)
source = await blocklist_service.create_source(db, "CIDR Test", "https://c.test/")
@@ -212,8 +224,10 @@ class TestImport:
assert result.ips_imported == 1
assert result.ips_skipped == 1
async def test_import_source_records_download_error(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_records_download_error(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_source records an error and returns 0 imported on HTTP failure."""
mock_validate.return_value = None
session = _make_session("", status=503)
source = await blocklist_service.create_source(db, "Err Source", "https://err.test/")
@@ -230,9 +244,11 @@ class TestImport:
assert result.ips_imported == 0
assert result.error is not None
async def test_import_source_aborts_on_jail_not_found(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_aborts_on_jail_not_found(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_source aborts immediately and records an error when the target jail
does not exist in fail2ban instead of silently skipping every IP."""
mock_validate.return_value = None
from app.services.jail_service import JailNotFoundError
from app.services import ban_service
@@ -262,8 +278,10 @@ class TestImport:
assert result.error is not None
assert "not found" in result.error.lower() or "blocklist-import" in result.error
async def test_import_all_runs_all_enabled(self, db: aiosqlite.Connection) -> None:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_all_runs_all_enabled(self, mock_validate: AsyncMock, db: aiosqlite.Connection) -> None:
"""import_all aggregates results across all enabled sources."""
mock_validate.return_value = None
await blocklist_service.create_source(db, "S1", "https://s1.test/")
s2 = await blocklist_service.create_source(db, "S2", "https://s2.test/", enabled=False)
_ = s2 # noqa: F841
@@ -400,10 +418,12 @@ class TestSchedule:
class TestGeoPrewarmCacheFilter:
@patch("app.utils.ip_utils.validate_blocklist_url")
async def test_import_source_skips_cached_ips_for_geo_prewarm(
self, db: aiosqlite.Connection
self, mock_validate: AsyncMock, db: aiosqlite.Connection
) -> None:
"""import_source only sends uncached IPs to geo_service.lookup_batch."""
mock_validate.return_value = None
content = "1.2.3.4\n5.6.7.8\n9.10.11.12\n"
session = _make_session(content)
source = await blocklist_service.create_source(
@@ -416,7 +436,11 @@ class TestGeoPrewarmCacheFilter:
from app.services import ban_service
mock_batch = AsyncMock(return_value={})
mock_lookup = AsyncMock(return_value={})
mock_geo_cache = MagicMock()
mock_geo_cache.is_cached = _mock_is_cached
mock_geo_cache.lookup_batch = mock_lookup
with patch("app.services.ban_service.ban_ip", new_callable=AsyncMock):
result = await blocklist_service.import_source(
source,
@@ -425,17 +449,76 @@ class TestGeoPrewarmCacheFilter:
db,
ban_ip=ban_service.ban_ip,
geo_is_cached=_mock_is_cached,
geo_batch_lookup=mock_batch,
geo_cache=mock_geo_cache,
)
assert result.ips_imported == 3
# lookup_batch should receive only the 2 uncached IPs.
mock_batch.assert_called_once()
call_ips = mock_batch.call_args[0][0]
mock_lookup.assert_called_once()
call_ips = mock_lookup.call_args[0][0]
assert "1.2.3.4" not in call_ips
assert set(call_ips) == {"5.6.7.8", "9.10.11.12"}
# ---------------------------------------------------------------------------
# URL Validation (SSRF Prevention)
# ---------------------------------------------------------------------------
class TestURLValidation:
    """Test SSRF protection by validating blocklist URLs."""

    async def test_create_source_rejects_file_url(self, db: aiosqlite.Connection) -> None:
        """create_source rejects file:// URLs."""
        with pytest.raises(ValueError, match="Invalid scheme"):
            await blocklist_service.create_source(db, "Bad", "file:///etc/passwd")

    async def test_create_source_rejects_ftp_url(self, db: aiosqlite.Connection) -> None:
        """create_source rejects ftp:// URLs."""
        with pytest.raises(ValueError, match="Invalid scheme"):
            await blocklist_service.create_source(db, "Bad", "ftp://evil.com/file.txt")

    async def test_create_source_rejects_localhost(self, db: aiosqlite.Connection) -> None:
        """create_source rejects localhost (127.0.0.1)."""
        with pytest.raises(ValueError, match="private|reserved"):
            await blocklist_service.create_source(db, "Bad", "http://127.0.0.1/list")

    async def test_create_source_rejects_localhost_name(self, db: aiosqlite.Connection) -> None:
        """create_source rejects localhost hostname."""
        with pytest.raises(ValueError, match="private|reserved"):
            await blocklist_service.create_source(db, "Bad", "http://localhost/list")

    async def test_create_source_rejects_private_network(self, db: aiosqlite.Connection) -> None:
        """create_source rejects private RFC 1918 networks (10.0.0.0/8)."""
        with pytest.raises(ValueError, match="private|reserved"):
            await blocklist_service.create_source(db, "Bad", "http://10.0.0.1/list")

    async def test_create_source_rejects_private_network_172(self, db: aiosqlite.Connection) -> None:
        """create_source rejects private RFC 1918 networks (172.16.0.0/12)."""
        with pytest.raises(ValueError, match="private|reserved"):
            await blocklist_service.create_source(db, "Bad", "http://172.16.0.1/list")

    async def test_create_source_rejects_private_network_192(self, db: aiosqlite.Connection) -> None:
        """create_source rejects private RFC 1918 networks (192.168.0.0/16)."""
        with pytest.raises(ValueError, match="private|reserved"):
            await blocklist_service.create_source(db, "Bad", "http://192.168.1.1/list")

    async def test_create_source_rejects_link_local(self, db: aiosqlite.Connection) -> None:
        """create_source rejects link-local addresses (169.254.x.x)."""
        with pytest.raises(ValueError, match="private|reserved"):
            await blocklist_service.create_source(db, "Bad", "http://169.254.169.254/latest/meta-data")

    @patch("app.utils.ip_utils.validate_blocklist_url")
    async def test_create_source_accepts_valid_public_url(
        self, mock_validate: AsyncMock, db: aiosqlite.Connection
    ) -> None:
        """create_source accepts valid public HTTPS URLs (validation mocked)."""
        mock_validate.return_value = None
        source = await blocklist_service.create_source(db, "Good", "https://example.com/list.txt")
        assert source.name == "Good"
        assert source.url == "https://example.com/list.txt"

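The last test above mocks the validator entirely, so the accept path of the IP check itself is never exercised. A complementary sketch, assuming a validator that resolves through `socket.getaddrinfo`, stubs only the DNS lookup so that both the accept and the reject path run without network access. `check_host` is a hypothetical stand-in for the real validator, and the hostname and addresses are made up for the example.

```python
import ipaddress
import socket
from unittest.mock import patch


def check_host(hostname: str) -> None:
    """Hypothetical stand-in: raise ValueError if *hostname* resolves to a non-public IP."""
    for *_, sockaddr in socket.getaddrinfo(hostname, 443, type=socket.SOCK_STREAM):
        ip = ipaddress.ip_address(sockaddr[0])
        if (
            ip.is_private
            or ip.is_loopback
            or ip.is_link_local
            or ip.is_multicast
            or ip.is_reserved
        ):
            raise ValueError(
                f"Hostname '{hostname}' resolves to private/reserved IP: {ip}"
            )


def fake_addrinfo(ip: str) -> list[tuple]:
    """Build a getaddrinfo-style result list containing a single IPv4 address."""
    return [(socket.AF_INET, socket.SOCK_STREAM, 6, "", (ip, 443))]


def test_public_address_is_accepted() -> None:
    # The stubbed resolver returns a public address, so no exception is expected.
    with patch("socket.getaddrinfo", return_value=fake_addrinfo("8.8.8.8")):
        check_host("blocklist.test")


def test_private_address_is_rejected() -> None:
    # The stubbed resolver returns an RFC 1918 address, which must be rejected.
    with patch("socket.getaddrinfo", return_value=fake_addrinfo("192.168.1.10")):
        try:
            check_host("blocklist.test")
        except ValueError as exc:
            assert "private/reserved" in str(exc)
        else:
            raise AssertionError("expected ValueError")
```

Patching the resolver rather than the validator keeps the scheme and IP checks under test while still avoiding real DNS traffic.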
class TestImportLogPagination:
async def test_list_import_logs_empty(self, db: aiosqlite.Connection) -> None:
"""list_import_logs returns an empty page when no logs exist."""