Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore(allow_hosts): Use getaddrinfo instead of gethostbyname #209

Merged
merged 10 commits on Jun 21, 2023
40 changes: 28 additions & 12 deletions pytest_socket.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import ipaddress
import itertools
import socket
import typing
from collections import defaultdict

import pytest

Expand Down Expand Up @@ -192,24 +195,26 @@ def is_ipaddress(address: str) -> bool:
return False


def resolve_hostname(hostname):
def resolve_hostnames(hostname: str) -> typing.Set[str]:
try:
return socket.gethostbyname(hostname)
return {
addr_struct[0] for *_, addr_struct in socket.getaddrinfo(hostname, None)
}
except socket.gaierror:
return None
return set()


def normalize_allowed_hosts(allowed_hosts):
"""Convert all items in `allowed_hosts` to an IP address."""
ip_hosts = []
def normalize_allowed_hosts(
allowed_hosts: typing.List[str],
) -> typing.Dict[str, typing.Set[str]]:
"""Map all items in `allowed_hosts` to IP addresses."""
ip_hosts = defaultdict(set)
for host in allowed_hosts:
host = host.strip()
if is_ipaddress(host):
ip_hosts.append(host)
ip_hosts[host].add(host)
else:
resolved = resolve_hostname(host)
if resolved:
ip_hosts.append(resolved)
ip_hosts[host].update(resolve_hostnames(host))

return ip_hosts

Expand All @@ -222,7 +227,18 @@ def socket_allow_hosts(allowed=None, allow_unix_socket=False):
if not isinstance(allowed, list):
return

allowed_hosts = normalize_allowed_hosts(allowed)
allowed_hosts_by_host = normalize_allowed_hosts(allowed)
allowed_hosts = set(itertools.chain(*allowed_hosts_by_host.values()))
allowed_list = sorted(
[
(
host
if len(normalized) == 1 and next(iter(normalized)) == host
else f"{host} ({','.join(sorted(normalized))})"
)
for host, normalized in allowed_hosts_by_host.items()
]
)

def guarded_connect(inst, *args):
host = host_from_connect_args(args)
Expand All @@ -231,7 +247,7 @@ def guarded_connect(inst, *args):
):
return _true_connect(inst, *args)

raise SocketConnectBlockedError(allowed, host)
raise SocketConnectBlockedError(allowed_list, host)

socket.socket.connect = guarded_connect

Expand Down
32 changes: 29 additions & 3 deletions tests/test_restrict_hosts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ def {2}():
assert urlopen('http://{0}:{1}/').getcode() == 200
"""

urlopen_hostname_code_template = """
import pytest
from urllib.request import urlopen

{3}
def {2}():
# Skip {{1}} as we expect {{0}} to be the full hostname with or without port
assert urlopen('http://{0}').getcode() == 200
"""


def assert_host_blocked(result, host):
result.stdout.fnmatch_lines(
Expand All @@ -45,6 +55,7 @@ def assert_socket_connect(should_pass, **kwargs):
test_name = inspect.stack()[1][3]

mark = ""
host = kwargs.get("host", httpbin.host)
cli_arg = kwargs.get("cli_arg", None)
code_template = kwargs.get("code_template", connect_code_template)
mark_arg = kwargs.get("mark_arg", None)
Expand All @@ -55,7 +66,7 @@ def assert_socket_connect(should_pass, **kwargs):
elif isinstance(mark_arg, list):
hosts = '","'.join(mark_arg)
mark = f'@pytest.mark.allow_hosts(["{hosts}"])'
code = code_template.format(httpbin.host, httpbin.port, test_name, mark)
code = code_template.format(host, httpbin.port, test_name, mark)
testdir.makepyfile(code)

if cli_arg:
Expand All @@ -67,7 +78,9 @@ def assert_socket_connect(should_pass, **kwargs):
result.assert_outcomes(1, 0, 0)
else:
result.assert_outcomes(0, 0, 1)
assert_host_blocked(result, httpbin.host)
assert_host_blocked(result, host)

return result

return assert_socket_connect

Expand Down Expand Up @@ -106,10 +119,23 @@ def test_single_cli_arg_connect_enabled(assert_connect):
assert_connect(True, cli_arg=localhost)


def test_single_cli_arg_connect_enabled_hostname_resolved(assert_connect):
def test_single_cli_arg_connect_enabled_localhost_resolved(assert_connect):
assert_connect(True, cli_arg="localhost")


def test_single_cli_arg_connect_disabled_hostname_resolved(assert_connect):
result = assert_connect(
False,
cli_arg="localhost",
host="1.2.3.4",
code_template=urlopen_hostname_code_template,
)
result.stdout.fnmatch_lines(
'*A test tried to use socket.socket.connect() with host "1.2.3.4" '
'(allowed: "localhost (127.0.0.1,::1)")*'
)


def test_single_cli_arg_connect_enabled_hostname_unresolvable(assert_connect):
assert_connect(False, cli_arg="unresolvable")

Expand Down