"""Regex pattern validation with security checks against ReDoS attacks. Provides timeout and complexity limits to prevent catastrophic backtracking (ReDoS - Regular Expression Denial of Service). """ from __future__ import annotations import re import signal from contextlib import contextmanager from typing import TYPE_CHECKING import structlog try: from regexploit.ast.sre import SreOpParser from regexploit.redos import Redos, find _REGEXPLOIT_AVAILABLE = True except ImportError: SreOpParser = Redos = find = None _REGEXPLOIT_AVAILABLE = False if TYPE_CHECKING: from collections.abc import Generator logger = structlog.get_logger() # Constants for regex validation MAX_REGEX_LENGTH = 1000 REGEX_COMPILE_TIMEOUT_SECONDS = 2 # Minimum starriness threshold for flagging as ReDoS # Higher values = more severe/numerous nested quantifiers _MINIMUM_STARRINESS = 3 class RegexTimeoutError(Exception): """Raised when regex compilation exceeds the timeout limit.""" def __init__(self, pattern: str, timeout_seconds: int) -> None: """Initialize with the pattern and timeout value. Args: pattern: The regex pattern that timed out. timeout_seconds: The timeout value in seconds. """ self.pattern = pattern self.timeout_seconds = timeout_seconds super().__init__( f"Regex pattern compilation timed out after {timeout_seconds}s " f"(possible ReDoS attack): {pattern!r}" ) class ReDoSDetectedError(Exception): """Raised when a regex pattern is detected to have catastrophic backtracking.""" def __init__(self, pattern: str, redos: Redos) -> None: """Initialize with the pattern and detection reason. Args: pattern: The regex pattern that was detected as dangerous. redos: The Redos object containing details about the vulnerability. """ self.pattern = pattern self.starriness = redos.starriness self.reason = redos.example() super().__init__( f"ReDoS pattern detected (starriness={redos.starriness}): {self.reason}" ) def _check_redos(pattern: str) -> "Redos | None": """Check if a pattern has catastrophic backtracking. Args: pattern: The regex pattern string to check. Returns: A Redos object if vulnerability detected, None otherwise. """ if not _REGEXPLOIT_AVAILABLE: return None try: parsed = SreOpParser().parse_sre(pattern, 0) except re.error: # Invalid regex - will be caught by re.compile() later return None redos_list = find(parsed) for redos in redos_list: if redos.starriness >= _MINIMUM_STARRINESS: return redos return None def validate_regex_pattern(pattern: str) -> None: """Validate a regex pattern with length and ReDoS checks. Validates a regex pattern by: 1. Checking length does not exceed MAX_REGEX_LENGTH characters 2. Checking for known catastrophic backtracking patterns (ReDoS) 3. Attempting compilation with a timeout to prevent ReDoS attacks Args: pattern: The regex pattern string to validate. Raises: ValueError: If the pattern exceeds maximum length. ReDoSDetectedError: If the pattern is detected as a ReDoS vulnerability. RegexTimeoutError: If compilation exceeds the timeout. re.error: If the pattern is syntactically invalid. Example: >>> validate_regex_pattern(r'^[a-z]+$') # OK >>> validate_regex_pattern('a' * 1001) # Raises ValueError >>> validate_regex_pattern(r'(a+)+b') # Raises ReDoSDetectedError """ # Check length first (fast, no timeout needed) if len(pattern) > MAX_REGEX_LENGTH: msg = f"Regex pattern exceeds maximum length of {MAX_REGEX_LENGTH} characters: {len(pattern)} provided" logger.warning("regex_validation_length_exceeded", max_length=MAX_REGEX_LENGTH, actual_length=len(pattern)) raise ValueError(msg) # Check for ReDoS patterns before compilation redos = _check_redos(pattern) if redos is not None: logger.warning( "regex_redos_detected", starriness=redos.starriness, pattern_preview=pattern[:100], ) raise ReDoSDetectedError(pattern, redos) # Attempt compilation with timeout try: with _timeout_context(REGEX_COMPILE_TIMEOUT_SECONDS): re.compile(pattern) except TimeoutError as exc: logger.warning( "regex_compilation_timeout", timeout_seconds=REGEX_COMPILE_TIMEOUT_SECONDS, pattern_preview=pattern[:100], ) raise RegexTimeoutError(pattern, REGEX_COMPILE_TIMEOUT_SECONDS) from exc @contextmanager def _timeout_context(timeout_seconds: int) -> Generator[None, None, None]: """Context manager to enforce a timeout using signal.alarm(). Works on Unix-like systems (Linux, macOS, etc.). On Windows or other platforms where signal.SIGALRM is unavailable, compilation proceeds without timeout (not ideal, but graceful degradation). Args: timeout_seconds: Timeout duration in seconds. Yields: None. Raises: TimeoutError: If the timeout is exceeded. Note: This uses signal.alarm() which is only available on Unix. On Windows, timeouts are not enforced (limitation of the platform). """ # Check if signal.SIGALRM is available (Unix-like systems) if not hasattr(signal, "SIGALRM"): # Windows or other platforms without SIGALRM # Just proceed without timeout (not ideal, but prevents crashes) yield return def _timeout_handler(signum: int, frame: object) -> None: raise TimeoutError("Timeout exceeded") # Set up signal handler old_handler = signal.signal(signal.SIGALRM, _timeout_handler) signal.alarm(timeout_seconds) try: yield finally: # Always disable the alarm, even if an exception occurred signal.alarm(0) signal.signal(signal.SIGALRM, old_handler)