"""Tests for DNS-validated socket factory that prevents DNS-rebinding attacks.""" from __future__ import annotations import socket from unittest.mock import patch import pytest from app.services.dns_validated_connector import create_dns_validated_socket_factory class TestDnsValidatedSocketFactory: """Test DNS validation in socket factory.""" def test_socket_factory_allows_public_ipv4(self) -> None: """Test that public IPv4 addresses are allowed.""" factory = create_dns_validated_socket_factory() # Create a mock address_info tuple for a public IPv4 address address_info = (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", 80)) # Should not raise sock = factory(address_info) assert sock is not None assert sock.family == socket.AF_INET assert sock.type == socket.SOCK_STREAM sock.close() def test_socket_factory_allows_public_ipv6(self) -> None: """Test that public IPv6 addresses are allowed.""" factory = create_dns_validated_socket_factory() # Public IPv6 address (Google DNS) address_info = (socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("2606:4700:4700::1111", 80, 0, 0)) sock = factory(address_info) assert sock is not None assert sock.family == socket.AF_INET6 sock.close() def test_socket_factory_blocks_private_ip_192_168(self) -> None: """Test that 192.168.x.x private IPs are blocked.""" factory = create_dns_validated_socket_factory() address_info = (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("192.168.1.1", 80)) with pytest.raises(OSError) as exc_info: factory(address_info) assert "rebinding" in str(exc_info.value).lower() or "private" in str(exc_info.value).lower() def test_socket_factory_blocks_private_ip_10(self) -> None: """Test that 10.x.x.x private IPs are blocked.""" factory = create_dns_validated_socket_factory() address_info = (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("10.0.0.1", 80)) with pytest.raises(OSError) as exc_info: factory(address_info) error_msg = str(exc_info.value).lower() assert "private" in error_msg or "rebinding" in error_msg def test_socket_factory_blocks_private_ip_172_16(self) -> None: """Test that 172.16.x.x private IPs are blocked.""" factory = create_dns_validated_socket_factory() address_info = (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("172.16.0.1", 80)) with pytest.raises(OSError): factory(address_info) def test_socket_factory_blocks_loopback_ipv4(self) -> None: """Test that IPv4 loopback is blocked.""" factory = create_dns_validated_socket_factory() address_info = (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("127.0.0.1", 80)) with pytest.raises(OSError): factory(address_info) def test_socket_factory_blocks_loopback_ipv6(self) -> None: """Test that IPv6 loopback is blocked.""" factory = create_dns_validated_socket_factory() address_info = (socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("::1", 80, 0, 0)) with pytest.raises(OSError): factory(address_info) def test_socket_factory_blocks_link_local_ipv6(self) -> None: """Test that IPv6 link-local addresses are blocked.""" factory = create_dns_validated_socket_factory() address_info = (socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("fe80::1", 80, 0, 0)) with pytest.raises(OSError): factory(address_info) def test_socket_factory_blocks_ipv6_ula(self) -> None: """Test that IPv6 ULA (Unique Local Address) is blocked.""" factory = create_dns_validated_socket_factory() # fc00::/7 is ULA address_info = (socket.AF_INET6, socket.SOCK_STREAM, 6, "", ("fc00::1", 80, 0, 0)) with pytest.raises(OSError): factory(address_info) def test_socket_factory_blocks_ipv4_multicast(self) -> None: """Test that IPv4 multicast addresses are blocked.""" factory = create_dns_validated_socket_factory() # 224.0.0.0/4 is multicast address_info = (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("224.0.0.1", 80)) with pytest.raises(OSError): factory(address_info) def test_socket_factory_blocks_reserved_ips(self) -> None: """Test that reserved IPs are blocked.""" factory = create_dns_validated_socket_factory() # 240.0.0.0/4 is reserved address_info = (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("240.0.0.1", 80)) with pytest.raises(OSError): factory(address_info) def test_socket_factory_blocks_broadcast(self) -> None: """Test that broadcast addresses are blocked.""" factory = create_dns_validated_socket_factory() # 255.255.255.255 is broadcast address_info = (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("255.255.255.255", 80)) with pytest.raises(OSError): factory(address_info) def test_socket_factory_allows_multiple_public_ips(self) -> None: """Test that multiple public IPs can be created.""" factory = create_dns_validated_socket_factory() public_ips = [ (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("8.8.8.8", 80)), (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("1.1.1.1", 443)), (socket.AF_INET, socket.SOCK_STREAM, 6, "", ("208.67.222.222", 53)), ] socks = [] for address_info in public_ips: sock = factory(address_info) assert sock is not None socks.append(sock) # Clean up for sock in socks: sock.close()