diff --git a/backend/app/models/ban.py b/backend/app/models/ban.py index 17f05e7..23187a7 100644 --- a/backend/app/models/ban.py +++ b/backend/app/models/ban.py @@ -5,7 +5,7 @@ Request, response, and domain models used by the ban router and service. from typing import Literal -from pydantic import Field +from pydantic import Field, field_validator from app.models.response import BanGuiBaseModel, CollectionResponse, PaginatedListResponse @@ -67,6 +67,18 @@ class Ban(BanGuiBaseModel): description="Whether this ban came from a blocklist import or fail2ban itself.", ) + @field_validator("country") + @classmethod + def _normalize_empty_country(cls, v: str | None) -> str | None: + """Coerce empty strings to None for country. + + Geo enrichment may produce an empty string instead of None for + unresolved IPs, which breaks frontend truthiness checks. + """ + if v == "": + return None + return v + class BanResponse(BanGuiBaseModel): """Response containing a single ban record.""" @@ -97,6 +109,18 @@ class ActiveBan(BanGuiBaseModel): ban_count: int = Field(default=1, ge=1, description="Running ban count for this IP.") country: str | None = Field(default=None, description="ISO 3166-1 alpha-2 country code.") + @field_validator("country") + @classmethod + def _normalize_empty_country(cls, v: str | None) -> str | None: + """Coerce empty strings to None for country. + + Geo enrichment may produce an empty string instead of None for + unresolved IPs, which breaks frontend truthiness checks. + """ + if v == "": + return None + return v + class ActiveBanListResponse(CollectionResponse[ActiveBan]): """List of all currently active bans across all jails. @@ -154,6 +178,20 @@ class DashboardBanItem(BanGuiBaseModel): description="Whether this ban came from a blocklist import or fail2ban itself.", ) + @field_validator("country_code") + @classmethod + def _normalize_empty_country_code(cls, v: str | None) -> str | None: + """Coerce empty strings to None for country_code. + + The geo enrichment layer may produce an empty string instead of None + for unresolved IPs. Frontend type narrowing uses truthiness, so an + empty string would slip through ``if (ban.country_code)`` checks and + appear as a falsy-but-not-null value — breaking UI rendering. + """ + if v == "": + return None + return v + class DashboardBanListResponse(PaginatedListResponse[DashboardBanItem]): """Paginated dashboard ban-list response. diff --git a/backend/app/utils/regex_validator.py b/backend/app/utils/regex_validator.py index 68610b6..9ab43bd 100644 --- a/backend/app/utils/regex_validator.py +++ b/backend/app/utils/regex_validator.py @@ -12,6 +12,8 @@ from contextlib import contextmanager from typing import TYPE_CHECKING import structlog +from regexploit.ast.sre import SreOpParser +from regexploit.redos import Redos, find if TYPE_CHECKING: from collections.abc import Generator @@ -22,6 +24,10 @@ logger = structlog.get_logger() MAX_REGEX_LENGTH = 1000 REGEX_COMPILE_TIMEOUT_SECONDS = 2 +# Minimum starriness threshold for flagging as ReDoS +# Higher values = more severe/numerous nested quantifiers +_MINIMUM_STARRINESS = 3 + class RegexTimeoutError(Exception): """Raised when regex compilation exceeds the timeout limit.""" @@ -41,25 +47,67 @@ class RegexTimeoutError(Exception): ) +class ReDoSDetectedError(Exception): + """Raised when a regex pattern is detected to have catastrophic backtracking.""" + + def __init__(self, pattern: str, redos: Redos) -> None: + """Initialize with the pattern and detection reason. + + Args: + pattern: The regex pattern that was detected as dangerous. + redos: The Redos object containing details about the vulnerability. + """ + self.pattern = pattern + self.starriness = redos.starriness + self.reason = redos.example() + super().__init__( + f"ReDoS pattern detected (starriness={redos.starriness}): {self.reason}" + ) + + +def _check_redos(pattern: str) -> Redos | None: + """Check if a pattern has catastrophic backtracking. + + Args: + pattern: The regex pattern string to check. + + Returns: + A Redos object if vulnerability detected, None otherwise. + """ + try: + parsed = SreOpParser().parse_sre(pattern, 0) + except re.error: + # Invalid regex - will be caught by re.compile() later + return None + + redos_list = find(parsed) + for redos in redos_list: + if redos.starriness >= _MINIMUM_STARRINESS: + return redos + return None + + def validate_regex_pattern(pattern: str) -> None: - """Validate a regex pattern with length and timeout checks. + """Validate a regex pattern with length and ReDoS checks. Validates a regex pattern by: 1. Checking length does not exceed MAX_REGEX_LENGTH characters - 2. Attempting compilation with a timeout to prevent ReDoS attacks + 2. Checking for known catastrophic backtracking patterns (ReDoS) + 3. Attempting compilation with a timeout to prevent ReDoS attacks Args: pattern: The regex pattern string to validate. Raises: ValueError: If the pattern exceeds maximum length. + ReDoSDetectedError: If the pattern is detected as a ReDoS vulnerability. RegexTimeoutError: If compilation exceeds the timeout. re.error: If the pattern is syntactically invalid. Example: >>> validate_regex_pattern(r'^[a-z]+$') # OK >>> validate_regex_pattern('a' * 1001) # Raises ValueError - >>> validate_regex_pattern(r'(a+)+b') # May raise RegexTimeoutError + >>> validate_regex_pattern(r'(a+)+b') # Raises ReDoSDetectedError """ # Check length first (fast, no timeout needed) if len(pattern) > MAX_REGEX_LENGTH: @@ -67,6 +115,16 @@ def validate_regex_pattern(pattern: str) -> None: logger.warning("regex_validation_length_exceeded", max_length=MAX_REGEX_LENGTH, actual_length=len(pattern)) raise ValueError(msg) + # Check for ReDoS patterns before compilation + redos = _check_redos(pattern) + if redos is not None: + logger.warning( + "regex_redos_detected", + starriness=redos.starriness, + pattern_preview=pattern[:100], + ) + raise ReDoSDetectedError(pattern, redos) + # Attempt compilation with timeout try: with _timeout_context(REGEX_COMPILE_TIMEOUT_SECONDS): diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 173ea99..44cbfc5 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -19,6 +19,7 @@ dependencies = [ "bcrypt>=4.2.0", "geoip2>=4.8.0", "prometheus-client>=0.21.0", + "regexploit>=1.0.0", ] [project.optional-dependencies] diff --git a/backend/tests/test_models.py b/backend/tests/test_models.py index 0a5ffb8..42bfe6b 100644 --- a/backend/tests/test_models.py +++ b/backend/tests/test_models.py @@ -3,12 +3,13 @@ import pytest from pydantic import ValidationError -from app.models.config import GlobalConfigUpdate, GlobalConfigResponse +from app.models.config import GlobalConfigResponse, GlobalConfigUpdate + def test_add_log_path_request_default_tail_is_true() -> None: """Tail defaults to True.""" from app.models.config import AddLogPathRequest - + req = AddLogPathRequest(log_path="/var/log/app.log") assert req.tail is True @@ -16,7 +17,7 @@ def test_add_log_path_request_default_tail_is_true() -> None: def test_add_log_path_request_can_be_created() -> None: """AddLogPathRequest can be created with valid data (no validators in model).""" from app.models.config import AddLogPathRequest - + req = AddLogPathRequest(log_path="/etc/passwd", tail=True) # Note: path validation is now in the router layer, not in the model assert req.log_path == "/etc/passwd" @@ -188,3 +189,132 @@ def test_setup_request_master_password_complexity_still_enforced() -> None: with pytest.raises(ValidationError) as exc_info: SetupRequest(master_password="Password1") assert "special character" in str(exc_info.value) + + +# --------------------------------------------------------------------------- +# DashboardBanItem country_code validator +# --------------------------------------------------------------------------- + + +def test_dashboard_ban_item_country_code_null() -> None: + """DashboardBanItem accepts None for country_code.""" + from app.models.ban import DashboardBanItem + + item = DashboardBanItem( + ip="1.2.3.4", + jail="sshd", + banned_at="2026-04-28T07:00:00+00:00", + ban_count=1, + origin="selfblock", + country_code=None, + ) + assert item.country_code is None + + +def test_dashboard_ban_item_country_code_valid() -> None: + """DashboardBanItem accepts a valid 2-char uppercase country code.""" + from app.models.ban import DashboardBanItem + + item = DashboardBanItem( + ip="1.2.3.4", + jail="sshd", + banned_at="2026-04-28T07:00:00+00:00", + ban_count=1, + origin="selfblock", + country_code="US", + ) + assert item.country_code == "US" + + +def test_dashboard_ban_item_country_code_empty_string_coerced_to_none() -> None: + """DashboardBanItem coerces empty-string country_code to None.""" + from app.models.ban import DashboardBanItem + + item = DashboardBanItem( + ip="1.2.3.4", + jail="sshd", + banned_at="2026-04-28T07:00:00+00:00", + ban_count=1, + origin="selfblock", + country_code="", + ) + assert item.country_code is None + + +# --------------------------------------------------------------------------- +# ActiveBan country validator +# --------------------------------------------------------------------------- + + +def test_active_ban_country_null() -> None: + """ActiveBan accepts None for country.""" + from app.models.ban import ActiveBan + + ban = ActiveBan(ip="1.2.3.4", jail="sshd", country=None) + assert ban.country is None + + +def test_active_ban_country_valid() -> None: + """ActiveBan accepts a valid country code.""" + from app.models.ban import ActiveBan + + ban = ActiveBan(ip="1.2.3.4", jail="sshd", country="DE") + assert ban.country == "DE" + + +def test_active_ban_country_empty_string_coerced_to_none() -> None: + """ActiveBan coerces empty-string country to None.""" + from app.models.ban import ActiveBan + + ban = ActiveBan(ip="1.2.3.4", jail="sshd", country="") + assert ban.country is None + + +# --------------------------------------------------------------------------- +# Ban country validator +# --------------------------------------------------------------------------- + + +def test_ban_country_null() -> None: + """Ban accepts None for country.""" + from app.models.ban import Ban + + ban = Ban( + ip="1.2.3.4", + jail="sshd", + banned_at="2026-04-28T07:00:00+00:00", + ban_count=1, + origin="selfblock", + country=None, + ) + assert ban.country is None + + +def test_ban_country_valid() -> None: + """Ban accepts a valid country code.""" + from app.models.ban import Ban + + ban = Ban( + ip="1.2.3.4", + jail="sshd", + banned_at="2026-04-28T07:00:00+00:00", + ban_count=1, + origin="selfblock", + country="FR", + ) + assert ban.country == "FR" + + +def test_ban_country_empty_string_coerced_to_none() -> None: + """Ban coerces empty-string country to None.""" + from app.models.ban import Ban + + ban = Ban( + ip="1.2.3.4", + jail="sshd", + banned_at="2026-04-28T07:00:00+00:00", + ban_count=1, + origin="selfblock", + country="", + ) + assert ban.country is None diff --git a/backend/tests/test_utils/test_regex_validator.py b/backend/tests/test_utils/test_regex_validator.py index 2395433..73e54aa 100644 --- a/backend/tests/test_utils/test_regex_validator.py +++ b/backend/tests/test_utils/test_regex_validator.py @@ -8,7 +8,7 @@ import pytest from app.utils.regex_validator import ( MAX_REGEX_LENGTH, - REGEX_COMPILE_TIMEOUT_SECONDS, + ReDoSDetectedError, RegexTimeoutError, validate_regex_pattern, ) @@ -116,6 +116,61 @@ class TestRegexTimeoutError: assert exc.timeout_seconds == timeout_seconds +class TestReDoSDetection: + """Tests for ReDoS pattern detection via regexploit.""" + + def test_redos_pattern_raises_error(self) -> None: + """Known catastrophic backtracking patterns should raise ReDoSDetectedError.""" + redos_patterns = [ + r"(a+)+b", + r"([a-zA-Z]+)*d", + r"(x+)+y", + ] + for pattern in redos_patterns: + with pytest.raises(ReDoSDetectedError, match="ReDoS pattern detected"): + validate_regex_pattern(pattern) + + def test_redos_error_message_contains_reason(self) -> None: + """ReDoSDetectedError should include the detection reason.""" + pattern = r"(a+)+b" + from regexploit.ast.sre import SreOpParser + from regexploit.redos import find + parsed = SreOpParser().parse_sre(pattern, 0) + redos_obj = list(find(parsed))[0] + exc = ReDoSDetectedError(pattern, redos_obj) + assert "ReDoS pattern detected" in str(exc) + assert str(redos_obj.starriness) in str(exc) # starriness is included + + def test_redos_error_attributes(self) -> None: + """ReDoSDetectedError should store pattern and starriness.""" + pattern = r"(x+)+y" + from regexploit.ast.sre import SreOpParser + from regexploit.redos import find + parsed = SreOpParser().parse_sre(pattern, 0) + redos_obj = list(find(parsed))[0] + exc = ReDoSDetectedError(pattern, redos_obj) + assert exc.pattern == pattern + assert exc.starriness == redos_obj.starriness + assert exc.reason is not None + + def test_non_redos_complex_pattern_passes(self) -> None: + """Complex but safe patterns should pass validation.""" + safe_patterns = [ + r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}", + r"^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$", + r"(?:foo|bar|baz)", + ] + for pattern in safe_patterns: + validate_regex_pattern(pattern) + + def test_redos_detection_before_timeout(self) -> None: + """ReDoS detection should occur before timeout check.""" + # This pattern is detected as ReDoS by regexploit + redos_pattern = r"(a+)+b" + with pytest.raises(ReDoSDetectedError): + validate_regex_pattern(redos_pattern) + + class TestValidateRegexPatternEdgeCases: """Test edge cases and boundary conditions."""