refactoring-backend #3
@@ -5,7 +5,7 @@ Request, response, and domain models used by the ban router and service.
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from app.models.response import BanGuiBaseModel, CollectionResponse, PaginatedListResponse
|
||||
|
||||
@@ -67,6 +67,18 @@ class Ban(BanGuiBaseModel):
|
||||
description="Whether this ban came from a blocklist import or fail2ban itself.",
|
||||
)
|
||||
|
||||
@field_validator("country")
|
||||
@classmethod
|
||||
def _normalize_empty_country(cls, v: str | None) -> str | None:
|
||||
"""Coerce empty strings to None for country.
|
||||
|
||||
Geo enrichment may produce an empty string instead of None for
|
||||
unresolved IPs, which breaks frontend truthiness checks.
|
||||
"""
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
class BanResponse(BanGuiBaseModel):
|
||||
"""Response containing a single ban record."""
|
||||
|
||||
@@ -97,6 +109,18 @@ class ActiveBan(BanGuiBaseModel):
|
||||
ban_count: int = Field(default=1, ge=1, description="Running ban count for this IP.")
|
||||
country: str | None = Field(default=None, description="ISO 3166-1 alpha-2 country code.")
|
||||
|
||||
@field_validator("country")
|
||||
@classmethod
|
||||
def _normalize_empty_country(cls, v: str | None) -> str | None:
|
||||
"""Coerce empty strings to None for country.
|
||||
|
||||
Geo enrichment may produce an empty string instead of None for
|
||||
unresolved IPs, which breaks frontend truthiness checks.
|
||||
"""
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
class ActiveBanListResponse(CollectionResponse[ActiveBan]):
|
||||
"""List of all currently active bans across all jails.
|
||||
|
||||
@@ -154,6 +178,20 @@ class DashboardBanItem(BanGuiBaseModel):
|
||||
description="Whether this ban came from a blocklist import or fail2ban itself.",
|
||||
)
|
||||
|
||||
@field_validator("country_code")
|
||||
@classmethod
|
||||
def _normalize_empty_country_code(cls, v: str | None) -> str | None:
|
||||
"""Coerce empty strings to None for country_code.
|
||||
|
||||
The geo enrichment layer may produce an empty string instead of None
|
||||
for unresolved IPs. Frontend type narrowing uses truthiness, so an
|
||||
empty string would slip through ``if (ban.country_code)`` checks and
|
||||
appear as a falsy-but-not-null value — breaking UI rendering.
|
||||
"""
|
||||
if v == "":
|
||||
return None
|
||||
return v
|
||||
|
||||
class DashboardBanListResponse(PaginatedListResponse[DashboardBanItem]):
|
||||
"""Paginated dashboard ban-list response.
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from contextlib import contextmanager
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import structlog
|
||||
from regexploit.ast.sre import SreOpParser
|
||||
from regexploit.redos import Redos, find
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
@@ -22,6 +24,10 @@ logger = structlog.get_logger()
|
||||
MAX_REGEX_LENGTH = 1000
|
||||
REGEX_COMPILE_TIMEOUT_SECONDS = 2
|
||||
|
||||
# Minimum starriness threshold for flagging as ReDoS
|
||||
# Higher values = more severe/numerous nested quantifiers
|
||||
_MINIMUM_STARRINESS = 3
|
||||
|
||||
|
||||
class RegexTimeoutError(Exception):
|
||||
"""Raised when regex compilation exceeds the timeout limit."""
|
||||
@@ -41,25 +47,67 @@ class RegexTimeoutError(Exception):
|
||||
)
|
||||
|
||||
|
||||
class ReDoSDetectedError(Exception):
|
||||
"""Raised when a regex pattern is detected to have catastrophic backtracking."""
|
||||
|
||||
def __init__(self, pattern: str, redos: Redos) -> None:
|
||||
"""Initialize with the pattern and detection reason.
|
||||
|
||||
Args:
|
||||
pattern: The regex pattern that was detected as dangerous.
|
||||
redos: The Redos object containing details about the vulnerability.
|
||||
"""
|
||||
self.pattern = pattern
|
||||
self.starriness = redos.starriness
|
||||
self.reason = redos.example()
|
||||
super().__init__(
|
||||
f"ReDoS pattern detected (starriness={redos.starriness}): {self.reason}"
|
||||
)
|
||||
|
||||
|
||||
def _check_redos(pattern: str) -> Redos | None:
|
||||
"""Check if a pattern has catastrophic backtracking.
|
||||
|
||||
Args:
|
||||
pattern: The regex pattern string to check.
|
||||
|
||||
Returns:
|
||||
A Redos object if vulnerability detected, None otherwise.
|
||||
"""
|
||||
try:
|
||||
parsed = SreOpParser().parse_sre(pattern, 0)
|
||||
except re.error:
|
||||
# Invalid regex - will be caught by re.compile() later
|
||||
return None
|
||||
|
||||
redos_list = find(parsed)
|
||||
for redos in redos_list:
|
||||
if redos.starriness >= _MINIMUM_STARRINESS:
|
||||
return redos
|
||||
return None
|
||||
|
||||
|
||||
def validate_regex_pattern(pattern: str) -> None:
|
||||
"""Validate a regex pattern with length and timeout checks.
|
||||
"""Validate a regex pattern with length and ReDoS checks.
|
||||
|
||||
Validates a regex pattern by:
|
||||
1. Checking length does not exceed MAX_REGEX_LENGTH characters
|
||||
2. Attempting compilation with a timeout to prevent ReDoS attacks
|
||||
2. Checking for known catastrophic backtracking patterns (ReDoS)
|
||||
3. Attempting compilation with a timeout to prevent ReDoS attacks
|
||||
|
||||
Args:
|
||||
pattern: The regex pattern string to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the pattern exceeds maximum length.
|
||||
ReDoSDetectedError: If the pattern is detected as a ReDoS vulnerability.
|
||||
RegexTimeoutError: If compilation exceeds the timeout.
|
||||
re.error: If the pattern is syntactically invalid.
|
||||
|
||||
Example:
|
||||
>>> validate_regex_pattern(r'^[a-z]+$') # OK
|
||||
>>> validate_regex_pattern('a' * 1001) # Raises ValueError
|
||||
>>> validate_regex_pattern(r'(a+)+b') # May raise RegexTimeoutError
|
||||
>>> validate_regex_pattern(r'(a+)+b') # Raises ReDoSDetectedError
|
||||
"""
|
||||
# Check length first (fast, no timeout needed)
|
||||
if len(pattern) > MAX_REGEX_LENGTH:
|
||||
@@ -67,6 +115,16 @@ def validate_regex_pattern(pattern: str) -> None:
|
||||
logger.warning("regex_validation_length_exceeded", max_length=MAX_REGEX_LENGTH, actual_length=len(pattern))
|
||||
raise ValueError(msg)
|
||||
|
||||
# Check for ReDoS patterns before compilation
|
||||
redos = _check_redos(pattern)
|
||||
if redos is not None:
|
||||
logger.warning(
|
||||
"regex_redos_detected",
|
||||
starriness=redos.starriness,
|
||||
pattern_preview=pattern[:100],
|
||||
)
|
||||
raise ReDoSDetectedError(pattern, redos)
|
||||
|
||||
# Attempt compilation with timeout
|
||||
try:
|
||||
with _timeout_context(REGEX_COMPILE_TIMEOUT_SECONDS):
|
||||
|
||||
@@ -19,6 +19,7 @@ dependencies = [
|
||||
"bcrypt>=4.2.0",
|
||||
"geoip2>=4.8.0",
|
||||
"prometheus-client>=0.21.0",
|
||||
"regexploit>=1.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -3,12 +3,13 @@
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.models.config import GlobalConfigUpdate, GlobalConfigResponse
|
||||
from app.models.config import GlobalConfigResponse, GlobalConfigUpdate
|
||||
|
||||
|
||||
def test_add_log_path_request_default_tail_is_true() -> None:
|
||||
"""Tail defaults to True."""
|
||||
from app.models.config import AddLogPathRequest
|
||||
|
||||
|
||||
req = AddLogPathRequest(log_path="/var/log/app.log")
|
||||
assert req.tail is True
|
||||
|
||||
@@ -16,7 +17,7 @@ def test_add_log_path_request_default_tail_is_true() -> None:
|
||||
def test_add_log_path_request_can_be_created() -> None:
|
||||
"""AddLogPathRequest can be created with valid data (no validators in model)."""
|
||||
from app.models.config import AddLogPathRequest
|
||||
|
||||
|
||||
req = AddLogPathRequest(log_path="/etc/passwd", tail=True)
|
||||
# Note: path validation is now in the router layer, not in the model
|
||||
assert req.log_path == "/etc/passwd"
|
||||
@@ -188,3 +189,132 @@ def test_setup_request_master_password_complexity_still_enforced() -> None:
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
SetupRequest(master_password="Password1")
|
||||
assert "special character" in str(exc_info.value)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DashboardBanItem country_code validator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_dashboard_ban_item_country_code_null() -> None:
|
||||
"""DashboardBanItem accepts None for country_code."""
|
||||
from app.models.ban import DashboardBanItem
|
||||
|
||||
item = DashboardBanItem(
|
||||
ip="1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at="2026-04-28T07:00:00+00:00",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
country_code=None,
|
||||
)
|
||||
assert item.country_code is None
|
||||
|
||||
|
||||
def test_dashboard_ban_item_country_code_valid() -> None:
|
||||
"""DashboardBanItem accepts a valid 2-char uppercase country code."""
|
||||
from app.models.ban import DashboardBanItem
|
||||
|
||||
item = DashboardBanItem(
|
||||
ip="1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at="2026-04-28T07:00:00+00:00",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
country_code="US",
|
||||
)
|
||||
assert item.country_code == "US"
|
||||
|
||||
|
||||
def test_dashboard_ban_item_country_code_empty_string_coerced_to_none() -> None:
|
||||
"""DashboardBanItem coerces empty-string country_code to None."""
|
||||
from app.models.ban import DashboardBanItem
|
||||
|
||||
item = DashboardBanItem(
|
||||
ip="1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at="2026-04-28T07:00:00+00:00",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
country_code="",
|
||||
)
|
||||
assert item.country_code is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ActiveBan country validator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_active_ban_country_null() -> None:
|
||||
"""ActiveBan accepts None for country."""
|
||||
from app.models.ban import ActiveBan
|
||||
|
||||
ban = ActiveBan(ip="1.2.3.4", jail="sshd", country=None)
|
||||
assert ban.country is None
|
||||
|
||||
|
||||
def test_active_ban_country_valid() -> None:
|
||||
"""ActiveBan accepts a valid country code."""
|
||||
from app.models.ban import ActiveBan
|
||||
|
||||
ban = ActiveBan(ip="1.2.3.4", jail="sshd", country="DE")
|
||||
assert ban.country == "DE"
|
||||
|
||||
|
||||
def test_active_ban_country_empty_string_coerced_to_none() -> None:
|
||||
"""ActiveBan coerces empty-string country to None."""
|
||||
from app.models.ban import ActiveBan
|
||||
|
||||
ban = ActiveBan(ip="1.2.3.4", jail="sshd", country="")
|
||||
assert ban.country is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Ban country validator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_ban_country_null() -> None:
|
||||
"""Ban accepts None for country."""
|
||||
from app.models.ban import Ban
|
||||
|
||||
ban = Ban(
|
||||
ip="1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at="2026-04-28T07:00:00+00:00",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
country=None,
|
||||
)
|
||||
assert ban.country is None
|
||||
|
||||
|
||||
def test_ban_country_valid() -> None:
|
||||
"""Ban accepts a valid country code."""
|
||||
from app.models.ban import Ban
|
||||
|
||||
ban = Ban(
|
||||
ip="1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at="2026-04-28T07:00:00+00:00",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
country="FR",
|
||||
)
|
||||
assert ban.country == "FR"
|
||||
|
||||
|
||||
def test_ban_country_empty_string_coerced_to_none() -> None:
|
||||
"""Ban coerces empty-string country to None."""
|
||||
from app.models.ban import Ban
|
||||
|
||||
ban = Ban(
|
||||
ip="1.2.3.4",
|
||||
jail="sshd",
|
||||
banned_at="2026-04-28T07:00:00+00:00",
|
||||
ban_count=1,
|
||||
origin="selfblock",
|
||||
country="",
|
||||
)
|
||||
assert ban.country is None
|
||||
|
||||
@@ -8,7 +8,7 @@ import pytest
|
||||
|
||||
from app.utils.regex_validator import (
|
||||
MAX_REGEX_LENGTH,
|
||||
REGEX_COMPILE_TIMEOUT_SECONDS,
|
||||
ReDoSDetectedError,
|
||||
RegexTimeoutError,
|
||||
validate_regex_pattern,
|
||||
)
|
||||
@@ -116,6 +116,61 @@ class TestRegexTimeoutError:
|
||||
assert exc.timeout_seconds == timeout_seconds
|
||||
|
||||
|
||||
class TestReDoSDetection:
|
||||
"""Tests for ReDoS pattern detection via regexploit."""
|
||||
|
||||
def test_redos_pattern_raises_error(self) -> None:
|
||||
"""Known catastrophic backtracking patterns should raise ReDoSDetectedError."""
|
||||
redos_patterns = [
|
||||
r"(a+)+b",
|
||||
r"([a-zA-Z]+)*d",
|
||||
r"(x+)+y",
|
||||
]
|
||||
for pattern in redos_patterns:
|
||||
with pytest.raises(ReDoSDetectedError, match="ReDoS pattern detected"):
|
||||
validate_regex_pattern(pattern)
|
||||
|
||||
def test_redos_error_message_contains_reason(self) -> None:
|
||||
"""ReDoSDetectedError should include the detection reason."""
|
||||
pattern = r"(a+)+b"
|
||||
from regexploit.ast.sre import SreOpParser
|
||||
from regexploit.redos import find
|
||||
parsed = SreOpParser().parse_sre(pattern, 0)
|
||||
redos_obj = list(find(parsed))[0]
|
||||
exc = ReDoSDetectedError(pattern, redos_obj)
|
||||
assert "ReDoS pattern detected" in str(exc)
|
||||
assert str(redos_obj.starriness) in str(exc) # starriness is included
|
||||
|
||||
def test_redos_error_attributes(self) -> None:
|
||||
"""ReDoSDetectedError should store pattern and starriness."""
|
||||
pattern = r"(x+)+y"
|
||||
from regexploit.ast.sre import SreOpParser
|
||||
from regexploit.redos import find
|
||||
parsed = SreOpParser().parse_sre(pattern, 0)
|
||||
redos_obj = list(find(parsed))[0]
|
||||
exc = ReDoSDetectedError(pattern, redos_obj)
|
||||
assert exc.pattern == pattern
|
||||
assert exc.starriness == redos_obj.starriness
|
||||
assert exc.reason is not None
|
||||
|
||||
def test_non_redos_complex_pattern_passes(self) -> None:
|
||||
"""Complex but safe patterns should pass validation."""
|
||||
safe_patterns = [
|
||||
r"[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}",
|
||||
r"^(?:25[0-5]|2[0-4][0-9]|[01]?[0-9][0-9]?)$",
|
||||
r"(?:foo|bar|baz)",
|
||||
]
|
||||
for pattern in safe_patterns:
|
||||
validate_regex_pattern(pattern)
|
||||
|
||||
def test_redos_detection_before_timeout(self) -> None:
|
||||
"""ReDoS detection should occur before timeout check."""
|
||||
# This pattern is detected as ReDoS by regexploit
|
||||
redos_pattern = r"(a+)+b"
|
||||
with pytest.raises(ReDoSDetectedError):
|
||||
validate_regex_pattern(redos_pattern)
|
||||
|
||||
|
||||
class TestValidateRegexPatternEdgeCases:
|
||||
"""Test edge cases and boundary conditions."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user