Skip to content

Commit

Permalink
chore(allow_hosts): Use getaddrinfo instead of gethostbyname (#209)
Browse files Browse the repository at this point in the history
Co-authored-by: Mike Fiedler <miketheman@gmail.com>
  • Loading branch information
hasier and miketheman committed Jun 21, 2023
1 parent 0264acd commit 13ca78c
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 15 deletions.
40 changes: 28 additions & 12 deletions pytest_socket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import ipaddress
import itertools
import socket
import typing
from collections import defaultdict

import pytest

Expand Down Expand Up @@ -192,24 +195,26 @@ def is_ipaddress(address: str) -> bool:
return False


def resolve_hostname(hostname):
def resolve_hostnames(hostname: str) -> typing.Set[str]:
try:
return socket.gethostbyname(hostname)
return {
addr_struct[0] for *_, addr_struct in socket.getaddrinfo(hostname, None)
}
except socket.gaierror:
return None
return set()


def normalize_allowed_hosts(allowed_hosts):
"""Convert all items in `allowed_hosts` to an IP address."""
ip_hosts = []
def normalize_allowed_hosts(
allowed_hosts: typing.List[str],
) -> typing.Dict[str, typing.Set[str]]:
"""Map all items in `allowed_hosts` to IP addresses."""
ip_hosts = defaultdict(set)
for host in allowed_hosts:
host = host.strip()
if is_ipaddress(host):
ip_hosts.append(host)
ip_hosts[host].add(host)
else:
resolved = resolve_hostname(host)
if resolved:
ip_hosts.append(resolved)
ip_hosts[host].update(resolve_hostnames(host))

return ip_hosts

Expand All @@ -222,7 +227,18 @@ def socket_allow_hosts(allowed=None, allow_unix_socket=False):
if not isinstance(allowed, list):
return

allowed_hosts = normalize_allowed_hosts(allowed)
allowed_hosts_by_host = normalize_allowed_hosts(allowed)
allowed_hosts = set(itertools.chain(*allowed_hosts_by_host.values()))
allowed_list = sorted(
[
(
host
if len(normalized) == 1 and next(iter(normalized)) == host
else f"{host} ({','.join(sorted(normalized))})"
)
for host, normalized in allowed_hosts_by_host.items()
]
)

def guarded_connect(inst, *args):
host = host_from_connect_args(args)
Expand All @@ -231,7 +247,7 @@ def guarded_connect(inst, *args):
):
return _true_connect(inst, *args)

raise SocketConnectBlockedError(allowed, host)
raise SocketConnectBlockedError(allowed_list, host)

socket.socket.connect = guarded_connect

Expand Down
32 changes: 29 additions & 3 deletions tests/test_restrict_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ def {2}():
assert urlopen('http://{0}:{1}/').getcode() == 200
"""

urlopen_hostname_code_template = """
import pytest
from urllib.request import urlopen
{3}
def {2}():
# Skip {{1}} as we expect {{0}} to be the full hostname with or without port
assert urlopen('http://{0}').getcode() == 200
"""


def assert_host_blocked(result, host):
result.stdout.fnmatch_lines(
Expand All @@ -45,6 +55,7 @@ def assert_socket_connect(should_pass, **kwargs):
test_name = inspect.stack()[1][3]

mark = ""
host = kwargs.get("host", httpbin.host)
cli_arg = kwargs.get("cli_arg", None)
code_template = kwargs.get("code_template", connect_code_template)
mark_arg = kwargs.get("mark_arg", None)
Expand All @@ -55,7 +66,7 @@ def assert_socket_connect(should_pass, **kwargs):
elif isinstance(mark_arg, list):
hosts = '","'.join(mark_arg)
mark = f'@pytest.mark.allow_hosts(["{hosts}"])'
code = code_template.format(httpbin.host, httpbin.port, test_name, mark)
code = code_template.format(host, httpbin.port, test_name, mark)
testdir.makepyfile(code)

if cli_arg:
Expand All @@ -67,7 +78,9 @@ def assert_socket_connect(should_pass, **kwargs):
result.assert_outcomes(1, 0, 0)
else:
result.assert_outcomes(0, 0, 1)
assert_host_blocked(result, httpbin.host)
assert_host_blocked(result, host)

return result

return assert_socket_connect

Expand Down Expand Up @@ -106,10 +119,23 @@ def test_single_cli_arg_connect_enabled(assert_connect):
assert_connect(True, cli_arg=localhost)


def test_single_cli_arg_connect_enabled_hostname_resolved(assert_connect):
def test_single_cli_arg_connect_enabled_localhost_resolved(assert_connect):
assert_connect(True, cli_arg="localhost")


def test_single_cli_arg_connect_disabled_hostname_resolved(assert_connect):
result = assert_connect(
False,
cli_arg="localhost",
host="1.2.3.4",
code_template=urlopen_hostname_code_template,
)
result.stdout.fnmatch_lines(
'*A test tried to use socket.socket.connect() with host "1.2.3.4" '
'(allowed: "localhost (127.0.0.1,::1)")*'
)


def test_single_cli_arg_connect_enabled_hostname_unresolvable(assert_connect):
assert_connect(False, cli_arg="unresolvable")

Expand Down

0 comments on commit 13ca78c

Please sign in to comment.