fix(regex_validator): add ReDoS detection via regexploit

Detect catastrophic backtracking patterns before regex compilation
using regexploit library. Add ReDoSDetectedError exception and
_MINIMUM_STARRINESS threshold (>=3) to catch dangerous patterns
like (a+)+b. Update pyproject.toml deps, add tests for detection.
This commit is contained in:
2026-05-03 00:05:33 +02:00
parent e436727942
commit 0817a4cb47
5 changed files with 290 additions and 8 deletions

View File

@@ -5,7 +5,7 @@ Request, response, and domain models used by the ban router and service.
from typing import Literal from typing import Literal
from pydantic import Field from pydantic import Field, field_validator
from app.models.response import BanGuiBaseModel, CollectionResponse, PaginatedListResponse from app.models.response import BanGuiBaseModel, CollectionResponse, PaginatedListResponse
@@ -67,6 +67,18 @@ class Ban(BanGuiBaseModel):
description="Whether this ban came from a blocklist import or fail2ban itself.", description="Whether this ban came from a blocklist import or fail2ban itself.",
) )
@field_validator("country")
@classmethod
def _normalize_empty_country(cls, v: str | None) -> str | None:
"""Coerce empty strings to None for country.
Geo enrichment may produce an empty string instead of None for
unresolved IPs, which breaks frontend truthiness checks.
"""
if v == "":
return None
return v
class BanResponse(BanGuiBaseModel): class BanResponse(BanGuiBaseModel):
"""Response containing a single ban record.""" """Response containing a single ban record."""
@@ -97,6 +109,18 @@ class ActiveBan(BanGuiBaseModel):
ban_count: int = Field(default=1, ge=1, description="Running ban count for this IP.") ban_count: int = Field(default=1, ge=1, description="Running ban count for this IP.")
country: str | None = Field(default=None, description="ISO 3166-1 alpha-2 country code.") country: str | None = Field(default=None, description="ISO 3166-1 alpha-2 country code.")
@field_validator("country")
@classmethod
def _normalize_empty_country(cls, v: str | None) -> str | None:
"""Coerce empty strings to None for country.
Geo enrichment may produce an empty string instead of None for
unresolved IPs, which breaks frontend truthiness checks.
"""
if v == "":
return None
return v
class ActiveBanListResponse(CollectionResponse[ActiveBan]): class ActiveBanListResponse(CollectionResponse[ActiveBan]):
"""List of all currently active bans across all jails. """List of all currently active bans across all jails.
@@ -154,6 +178,20 @@ class DashboardBanItem(BanGuiBaseModel):
description="Whether this ban came from a blocklist import or fail2ban itself.", description="Whether this ban came from a blocklist import or fail2ban itself.",
) )
@field_validator("country_code")
@classmethod
def _normalize_empty_country_code(cls, v: str | None) -> str | None:
"""Coerce empty strings to None for country_code.
The geo enrichment layer may produce an empty string instead of None
for unresolved IPs. Frontend type narrowing uses truthiness, so an
empty string would slip through ``if (ban.country_code)`` checks and
appear as a falsy-but-not-null value — breaking UI rendering.
"""
if v == "":
return None
return v
class DashboardBanListResponse(PaginatedListResponse[DashboardBanItem]): class DashboardBanListResponse(PaginatedListResponse[DashboardBanItem]):
"""Paginated dashboard ban-list response. """Paginated dashboard ban-list response.

View File

@@ -12,6 +12,8 @@ from contextlib import contextmanager
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
import structlog import structlog
from regexploit.ast.sre import SreOpParser
from regexploit.redos import Redos, find
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator from collections.abc import Generator
@@ -22,6 +24,10 @@ logger = structlog.get_logger()
MAX_REGEX_LENGTH = 1000 MAX_REGEX_LENGTH = 1000
REGEX_COMPILE_TIMEOUT_SECONDS = 2 REGEX_COMPILE_TIMEOUT_SECONDS = 2
# Minimum starriness threshold for flagging as ReDoS
# Higher values = more severe/numerous nested quantifiers
_MINIMUM_STARRINESS = 3
class RegexTimeoutError(Exception): class RegexTimeoutError(Exception):
"""Raised when regex compilation exceeds the timeout limit.""" """Raised when regex compilation exceeds the timeout limit."""
@@ -41,25 +47,67 @@ class RegexTimeoutError(Exception):
) )
class ReDoSDetectedError(Exception):
"""Raised when a regex pattern is detected to have catastrophic backtracking."""
def __init__(self, pattern: str, redos: Redos) -> None:
"""Initialize with the pattern and detection reason.
Args:
pattern: The regex pattern that was detected as dangerous.
redos: The Redos object containing details about the vulnerability.
"""
self.pattern = pattern
self.starriness = redos.starriness
self.reason = redos.example()
super().__init__(
f"ReDoS pattern detected (starriness={redos.starriness}): {self.reason}"
)
def _check_redos(pattern: str) -> Redos | None:
"""Check if a pattern has catastrophic backtracking.
Args:
pattern: The regex pattern string to check.
Returns:
A Redos object if vulnerability detected, None otherwise.
"""
try:
parsed = SreOpParser().parse_sre(pattern, 0)
except re.error:
# Invalid regex - will be caught by re.compile() later
return None
redos_list = find(parsed)
for redos in redos_list:
if redos.starriness >= _MINIMUM_STARRINESS:
return redos
return None
def validate_regex_pattern(pattern: str) -> None: def validate_regex_pattern(pattern: str) -> None:
"""Validate a regex pattern with length and timeout checks. """Validate a regex pattern with length and ReDoS checks.
Validates a regex pattern by: Validates a regex pattern by:
1. Checking length does not exceed MAX_REGEX_LENGTH characters 1. Checking length does not exceed MAX_REGEX_LENGTH characters
2. Attempting compilation with a timeout to prevent ReDoS attacks 2. Checking for known catastrophic backtracking patterns (ReDoS)
3. Attempting compilation with a timeout to prevent ReDoS attacks
Args: Args:
pattern: The regex pattern string to validate. pattern: The regex pattern string to validate.
Raises: Raises:
ValueError: If the pattern exceeds maximum length. ValueError: If the pattern exceeds maximum length.
ReDoSDetectedError: If the pattern is detected as a ReDoS vulnerability.
RegexTimeoutError: If compilation exceeds the timeout. RegexTimeoutError: If compilation exceeds the timeout.
re.error: If the pattern is syntactically invalid. re.error: If the pattern is syntactically invalid.
Example: Example:
>>> validate_regex_pattern(r'^[a-z]+$') # OK >>> validate_regex_pattern(r'^[a-z]+$') # OK
>>> validate_regex_pattern('a' * 1001) # Raises ValueError >>> validate_regex_pattern('a' * 1001) # Raises ValueError
>>> validate_regex_pattern(r'(a+)+b') # May raise RegexTimeoutError >>> validate_regex_pattern(r'(a+)+b') # Raises ReDoSDetectedError
""" """
# Check length first (fast, no timeout needed) # Check length first (fast, no timeout needed)
if len(pattern) > MAX_REGEX_LENGTH: if len(pattern) > MAX_REGEX_LENGTH:
@@ -67,6 +115,16 @@ def validate_regex_pattern(pattern: str) -> None:
logger.warning("regex_validation_length_exceeded", max_length=MAX_REGEX_LENGTH, actual_length=len(pattern)) logger.warning("regex_validation_length_exceeded", max_length=MAX_REGEX_LENGTH, actual_length=len(pattern))
raise ValueError(msg) raise ValueError(msg)
# Check for ReDoS patterns before compilation
redos = _check_redos(pattern)
if redos is not None:
logger.warning(
"regex_redos_detected",
starriness=redos.starriness,
pattern_preview=pattern[:100],
)
raise ReDoSDetectedError(pattern, redos)
# Attempt compilation with timeout # Attempt compilation with timeout
try: try:
with _timeout_context(REGEX_COMPILE_TIMEOUT_SECONDS): with _timeout_context(REGEX_COMPILE_TIMEOUT_SECONDS):

View File

@@ -19,6 +19,7 @@ dependencies = [
"bcrypt>=4.2.0", "bcrypt>=4.2.0",
"geoip2>=4.8.0", "geoip2>=4.8.0",
"prometheus-client>=0.21.0", "prometheus-client>=0.21.0",
"regexploit>=1.0.0",
] ]
[project.optional-dependencies] [project.optional-dependencies]

View File

@@ -3,12 +3,13 @@
import pytest import pytest
from pydantic import ValidationError from pydantic import ValidationError
from app.models.config import GlobalConfigUpdate, GlobalConfigResponse from app.models.config import GlobalConfigResponse, GlobalConfigUpdate
def test_add_log_path_request_default_tail_is_true() -> None: def test_add_log_path_request_default_tail_is_true() -> None:
"""Tail defaults to True.""" """Tail defaults to True."""
from app.models.config import AddLogPathRequest from app.models.config import AddLogPathRequest
req = AddLogPathRequest(log_path="/var/log/app.log") req = AddLogPathRequest(log_path="/var/log/app.log")
assert req.tail is True assert req.tail is True
@@ -16,7 +17,7 @@ def test_add_log_path_request_default_tail_is_true() -> None:
def test_add_log_path_request_can_be_created() -> None: def test_add_log_path_request_can_be_created() -> None:
"""AddLogPathRequest can be created with valid data (no validators in model).""" """AddLogPathRequest can be created with valid data (no validators in model)."""
from app.models.config import AddLogPathRequest from app.models.config import AddLogPathRequest
req = AddLogPathRequest(log_path="/etc/passwd", tail=True) req = AddLogPathRequest(log_path="/etc/passwd", tail=True)
# Note: path validation is now in the router layer, not in the model # Note: path validation is now in the router layer, not in the model
assert req.log_path == "/etc/passwd" assert req.log_path == "/etc/passwd"
@@ -188,3 +189,132 @@ def test_setup_request_master_password_complexity_still_enforced() -> None:
with pytest.raises(ValidationError) as exc_info: with pytest.raises(ValidationError) as exc_info:
SetupRequest(master_password="Password1") SetupRequest(master_password="Password1")
assert "special character" in str(exc_info.value) assert "special character" in str(exc_info.value)
# ---------------------------------------------------------------------------
# DashboardBanItem country_code validator
# ---------------------------------------------------------------------------
def test_dashboard_ban_item_country_code_null() -> None:
"""DashboardBanItem accepts None for country_code."""
from app.models.ban import DashboardBanItem
item = DashboardBanItem(
ip="1.2.3.4",
jail="sshd",
banned_at="2026-04-28T07:00:00+00:00",
ban_count=1,
origin="selfblock",
country_code=None,
)
assert item.country_code is None
def test_dashboard_ban_item_country_code_valid() -> None:
"""DashboardBanItem accepts a valid 2-char uppercase country code."""
from app.models.ban import DashboardBanItem
item = DashboardBanItem(
ip="1.2.3.4",
jail="sshd",
banned_at="2026-04-28T07:00:00+00:00",
ban_count=1,
origin="selfblock",
country_code="US",
)
assert item.country_code == "US"
def test_dashboard_ban_item_country_code_empty_string_coerced_to_none() -> None:
"""DashboardBanItem coerces empty-string country_code to None."""
from app.models.ban import DashboardBanItem
item = DashboardBanItem(
ip="1.2.3.4",
jail="sshd",
banned_at="2026-04-28T07:00:00+00:00",
ban_count=1,
origin="selfblock",
country_code="",
)
assert item.country_code is None
# ---------------------------------------------------------------------------
# ActiveBan country validator
# ---------------------------------------------------------------------------
def test_active_ban_country_null() -> None:
"""ActiveBan accepts None for country."""
from app.models.ban import ActiveBan
ban = ActiveBan(ip="1.2.3.4", jail="sshd", country=None)
assert ban.country is None
def test_active_ban_country_valid() -> None:
"""ActiveBan accepts a valid country code."""
from app.models.ban import ActiveBan
ban = ActiveBan(ip="1.2.3.4", jail="sshd", country="DE")
assert ban.country == "DE"
def test_active_ban_country_empty_string_coerced_to_none() -> None:
"""ActiveBan coerces empty-string country to None."""
from app.models.ban import ActiveBan
ban = ActiveBan(ip="1.2.3.4", jail="sshd", country="")
assert ban.country is None
# ---------------------------------------------------------------------------
# Ban country validator
# ---------------------------------------------------------------------------
def test_ban_country_null() -> None:
"""Ban accepts None for country."""
from app.models.ban import Ban
ban = Ban(
ip="1.2.3.4",
jail="sshd",
banned_at="2026-04-28T07:00:00+00:00",
ban_count=1,
origin="selfblock",
country=None,
)
assert ban.country is None
def test_ban_country_valid() -> None:
"""Ban accepts a valid country code."""
from app.models.ban import Ban
ban = Ban(
ip="1.2.3.4",
jail="sshd",
banned_at="2026-04-28T07:00:00+00:00",
ban_count=1,
origin="selfblock",
country="FR",
)
assert ban.country == "FR"
def test_ban_country_empty_string_coerced_to_none() -> None:
"""Ban coerces empty-string country to None."""
from app.models.ban import Ban
ban = Ban(
ip="1.2.3.4",
jail="sshd",
banned_at="2026-04-28T07:00:00+00:00",
ban_count=1,
origin="selfblock",
country="",
)
assert ban.country is None

View File

@@ -8,7 +8,7 @@ import pytest
from app.utils.regex_validator import ( from app.utils.regex_validator import (
MAX_REGEX_LENGTH, MAX_REGEX_LENGTH,
REGEX_COMPILE_TIMEOUT_SECONDS, ReDoSDetectedError,
RegexTimeoutError, RegexTimeoutError,
validate_regex_pattern, validate_regex_pattern,
) )
@@ -116,6 +116,61 @@ class TestRegexTimeoutError:
assert exc.timeout_seconds == timeout_seconds assert exc.timeout_seconds == timeout_seconds
class TestReDoSDetection:
"""Tests for ReDoS pattern detection via regexploit."""
def test_redos_pattern_raises_error(self) -> None:
"""Known catastrophic backtracking patterns should raise ReDoSDetectedError."""
redos_patterns = [
r"(a+)+b",
r"([a-zA-Z]+)*d",
r"(x+)+y",
]
for pattern in redos_patterns:
with pytest.raises(ReDoSDetectedError, match="ReDoS pattern detected"):
validate_regex_pattern(pattern)
def test_redos_error_message_contains_reason(self) -> None:
"""ReDoSDetectedError should include the detection reason."""
pattern = r"(a+)+b"
from regexploit.ast.sre import SreOpParser
from regexploit.redos import find
parsed = SreOpParser().parse_sre(pattern, 0)
redos_obj = list(find(parsed))[0]
exc = ReDoSDetectedError(pattern, redos_obj)
assert "ReDoS pattern detected" in str(exc)
assert str(redos_obj.starriness) in str(exc) # starriness is included
def test_redos_error_attributes(self) -> None:
"""ReDoSDetectedError should store pattern and starriness."""
pattern = r"(x+)+y"
from regexploit.ast.sre import SreOpParser
from regexploit.redos import find
parsed = SreOpParser().parse_sre(pattern, 0)
redos_obj = list(find(parsed))[0]
exc = ReDoSDetectedError(pattern, redos_obj)
assert exc.pattern == pattern
assert exc.starriness == redos_obj.starriness
assert exc.reason is not None
def test_non_redos_complex_pattern_passes(self) -> None:
"""Complex but safe patterns should pass validation."""
safe_patterns = [
r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
r"^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$",
r"(?:foo|bar|baz)",
]
for pattern in safe_patterns:
validate_regex_pattern(pattern)
def test_redos_detection_before_timeout(self) -> None:
"""ReDoS detection should occur before timeout check."""
# This pattern is detected as ReDoS by regexploit
redos_pattern = r"(a+)+b"
with pytest.raises(ReDoSDetectedError):
validate_regex_pattern(redos_pattern)
class TestValidateRegexPatternEdgeCases: class TestValidateRegexPatternEdgeCases:
"""Test edge cases and boundary conditions.""" """Test edge cases and boundary conditions."""