Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve type detection for annotations #3556

Merged
Merged
7 changes: 7 additions & 0 deletions hypothesis-python/RELEASE.rst
@@ -0,0 +1,7 @@
RELEASE_TYPE: patch

This patch fixes invalid annotations detected for the tests generated by
:doc:`Ghostwritter <ghostwriter>`. It will now correctly generate `Optional`
types with just one type argument and handle union expressions inside of type
arguments correctly. Additionally, it now supports code with the
`from __future__ import annotations` marker for Python 3.10 and newer.
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
54 changes: 35 additions & 19 deletions hypothesis-python/src/hypothesis/extra/ghostwriter.py
Expand Up @@ -443,11 +443,11 @@ def _guess_strategy_by_argname(name: str) -> st.SearchStrategy:
return st.nothing()


def _get_params(func: Callable) -> Dict[str, inspect.Parameter]:
def _get_params(func: Callable, eval_str: bool = False) -> Dict[str, inspect.Parameter]:
"""Get non-vararg parameters of `func` as an ordered dict."""
var_param_kinds = (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
try:
params = list(get_signature(func).parameters.values())
params = list(get_signature(func, eval_str=eval_str).parameters.values())
except Exception:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
except Exception:
except Exception:
if eval_str:
return _get_params(func)

Evaluating string annotations can fail due to syntax errors, name resolution, or whatever arbitrary code execution we invoked failing; so in those cases I'd rather fall back to unevaluated annotations (above) or switching to annotate=False for that argument. I trust your judgement as to which is preferable.

In either case, a test demonstrating that would be nice - say f(a: "invalid ::: syntax", b: "1/0").

if (
isinstance(func, (types.BuiltinFunctionType, types.BuiltinMethodType))
Expand Down Expand Up @@ -831,7 +831,7 @@ def _annotate_args(
) -> Iterable[str]:
arg_parameters: DefaultDict[str, Set[Any]] = defaultdict(set)
for func in funcs:
for key, param in _get_params(func).items():
for key, param in _get_params(func, eval_str=True).items():
arg_parameters[key].add(param.annotation)

for argname in argnames:
Expand Down Expand Up @@ -884,6 +884,15 @@ def _join_generics(
if origin_type_data is None:
return None

# because typing.Optional is converted to a Union, it also contains None
# since typing.Optional only accepts one type variable, we need to remove it
if origin_type_data is not None and origin_type_data[0] == "typing.Optional":
annotations = (
annotation
for annotation in annotations
if annotation is None or annotation.type_name != "None"
)

origin_type, imports = origin_type_data
joined = _join_argument_annotations(annotations)
if joined is None or not joined[0]:
Expand Down Expand Up @@ -920,35 +929,41 @@ def _parameter_to_annotation(parameter: Any) -> Optional[_AnnotationData]:
return None
return _parameter_to_annotation(forwarded_value)

# the arguments of Callable are in a list
if isinstance(parameter, list):
joined = _join_argument_annotations(
_parameter_to_annotation(param) for param in parameter
)
if joined is None:
return None
arg_type_names, new_imports = joined
return _AnnotationData("[{}]".format(", ".join(arg_type_names)), new_imports)

if isinstance(parameter, type):
if parameter.__module__ == "builtins":
return _AnnotationData(
"None" if parameter.__name__ == "NoneType" else parameter.__name__,
set(),
)
return _AnnotationData(
f"{parameter.__module__}.{parameter.__name__}", {parameter.__module__}
)

# the arguments of Callable are in a list
if isinstance(parameter, list):
joined = _join_argument_annotations(map(_parameter_to_annotation, parameter))
if joined is None:
return None
arg_type_names, new_imports = joined
return _AnnotationData("[{}]".format(", ".join(arg_type_names)), new_imports)
type_name = f"{parameter.__module__}.{parameter.__name__}"

# the types.UnionType does not support type arguments and needs to be translated
if type_name == "types.UnionType":
return _AnnotationData("typing.Union", {"typing"})
else:
if hasattr(parameter, "__module__") and hasattr(parameter, "__name__"):
type_name = f"{parameter.__module__}.{parameter.__name__}"
else:
type_name = str(parameter)

origin_type = get_origin(parameter)

# if not generic or no generic arguments
if origin_type is None or origin_type == parameter:
type_name = str(parameter)
if type_name.startswith("typing."):
return _AnnotationData(type_name, {"typing"})
return _AnnotationData(type_name, set())
return _AnnotationData(type_name, set(type_name.rsplit(".", maxsplit=1)[:-1]))

arg_types = get_args(parameter)
type_name = str(parameter)

# typing types get translated to classes that don't support generics
origin_annotation: Optional[_AnnotationData]
Expand All @@ -963,7 +978,8 @@ def _parameter_to_annotation(parameter: Any) -> Optional[_AnnotationData]:

if arg_types:
return _join_generics(
origin_annotation, map(_parameter_to_annotation, arg_types)
origin_annotation,
(_parameter_to_annotation(arg_type) for arg_type in arg_types),
)
return origin_annotation

Expand Down
24 changes: 22 additions & 2 deletions hypothesis-python/src/hypothesis/internal/reflection.py
Expand Up @@ -127,7 +127,9 @@ def check_signature(sig: inspect.Signature) -> None:
)


def get_signature(target: Any, *, follow_wrapped: bool = True) -> inspect.Signature:
def get_signature(
target: Any, *, follow_wrapped: bool = True, eval_str: bool = False
) -> inspect.Signature:
# Special case for use of `@unittest.mock.patch` decorator, mimicking the
# behaviour of getfullargspec instead of reporting unusable arguments.
patches = getattr(target, "patchings", None)
Expand Down Expand Up @@ -164,11 +166,29 @@ def get_signature(target: Any, *, follow_wrapped: bool = True) -> inspect.Signat
return sig.replace(
parameters=[v for k, v in sig.parameters.items() if k != "self"]
)
sig = inspect.signature(target, follow_wrapped=follow_wrapped)
sig = _inspect_signature(target, follow_wrapped=follow_wrapped, eval_str=eval_str)
Zac-HD marked this conversation as resolved.
Show resolved Hide resolved
check_signature(sig)
return sig


# eval_str is only supported by Python 3.10 and newer
if sys.version_info[:2] >= (3, 10):

def _inspect_signature(
target: Any, *, follow_wrapped: bool = True, eval_str: bool = False
) -> inspect.Signature:
return inspect.signature(
target, follow_wrapped=follow_wrapped, eval_str=eval_str
)

else:

def _inspect_signature(
target: Any, *, follow_wrapped: bool = True, eval_str: bool = False
) -> inspect.Signature: # pragma: no cover
return inspect.signature(target, follow_wrapped=follow_wrapped)


def arg_is_required(param):
return param.default is inspect.Parameter.empty and param.kind in (
inspect.Parameter.POSITIONAL_OR_KEYWORD,
Expand Down
9 changes: 9 additions & 0 deletions hypothesis-python/tests/ghostwriter/example_code/__init__.py
@@ -0,0 +1,9 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.
@@ -0,0 +1,30 @@
# This file is part of Hypothesis, which may be found at
# https://github.com/HypothesisWorks/hypothesis/
#
# Copyright the Hypothesis Authors.
# Individual contributors are listed in AUTHORS.rst and the git log.
#
# This Source Code Form is subject to the terms of the Mozilla Public License,
# v. 2.0. If a copy of the MPL was not distributed with this file, You can
# obtain one at https://mozilla.org/MPL/2.0/.

from __future__ import annotations

import collections.abc


class CustomClass:
def __init__(self, number: int) -> None:
self.number = number


def add_custom_classes(c1: CustomClass, c2: CustomClass | None = None) -> CustomClass:
if c2 is None:
return CustomClass(c1.number)
return CustomClass(c1.number + c2.number)


def merge_dicts(
map1: collections.abc.Mapping[str, int], map2: collections.abc.Mapping[str, int]
) -> collections.abc.Mapping[str, int]:
return {**map1, **map2}
@@ -0,0 +1,15 @@
# This test code was written by the `hypothesis.extra.ghostwriter` module
# and is provided under the Creative Commons Zero public domain dedication.

import example_code.future_annotations
import typing
from example_code.future_annotations import CustomClass
from hypothesis import given, strategies as st


@given(c1=st.builds(CustomClass), c2=st.one_of(st.none(), st.builds(CustomClass)))
def test_fuzz_add_custom_classes(
c1: example_code.future_annotations.CustomClass,
c2: typing.Union[example_code.future_annotations.CustomClass, None],
) -> None:
example_code.future_annotations.add_custom_classes(c1=c1, c2=c2)
Expand Up @@ -4,6 +4,7 @@
import datetime
import hypothesis
import hypothesis.strategies
import hypothesis.strategies._internal.strategies
import random
import typing
from hypothesis import given, settings, strategies as st
Expand All @@ -29,7 +30,9 @@ def test_fuzz_event(value: str) -> None:
database_key=st.one_of(st.none(), st.binary()),
)
def test_fuzz_find(
specifier: hypothesis.strategies.SearchStrategy,
specifier: hypothesis.strategies.SearchStrategy[
hypothesis.strategies._internal.strategies.Ex
],
condition: typing.Callable[[typing.Any], bool],
settings: typing.Union[hypothesis.settings, None],
random: typing.Union[random.Random, None],
Expand Down
23 changes: 23 additions & 0 deletions hypothesis-python/tests/ghostwriter/recorded/merge_dicts.txt
@@ -0,0 +1,23 @@
# This test code was written by the `hypothesis.extra.ghostwriter` module
# and is provided under the Creative Commons Zero public domain dedication.

import collections.abc
import example_code.future_annotations
from collections import ChainMap
from hypothesis import given, strategies as st


@given(
map1=st.one_of(
st.dictionaries(keys=st.text(), values=st.integers()),
st.dictionaries(keys=st.text(), values=st.integers()).map(ChainMap),
),
map2=st.one_of(
st.dictionaries(keys=st.text(), values=st.integers()),
st.dictionaries(keys=st.text(), values=st.integers()).map(ChainMap),
),
)
def test_fuzz_merge_dicts(
map1: collections.abc.Mapping[str, int], map2: collections.abc.Mapping[str, int]
) -> None:
example_code.future_annotations.merge_dicts(map1=map1, map2=map2)
@@ -0,0 +1,11 @@
# This test code was written by the `hypothesis.extra.ghostwriter` module
# and is provided under the Creative Commons Zero public domain dedication.

import test_expected_output
import typing
from hypothesis import given, strategies as st


@given(a=st.floats(), b=st.one_of(st.none(), st.floats()))
def test_fuzz_optional_parameter(a: float, b: typing.Optional[float]) -> None:
test_expected_output.optional_parameter(a=a, b=b)
@@ -0,0 +1,11 @@
# This test code was written by the `hypothesis.extra.ghostwriter` module
# and is provided under the Creative Commons Zero public domain dedication.

import test_expected_output
import typing
from hypothesis import given, strategies as st


@given(a=st.floats(), b=st.one_of(st.none(), st.floats()))
def test_fuzz_optional_parameter(a: float, b: typing.Union[float, None]) -> None:
test_expected_output.optional_parameter(a=a, b=b)
@@ -0,0 +1,13 @@
# This test code was written by the `hypothesis.extra.ghostwriter` module
# and is provided under the Creative Commons Zero public domain dedication.

import test_expected_output
import typing
from hypothesis import given, strategies as st


@given(a=st.floats(), b=st.one_of(st.none(), st.floats(), st.integers()))
def test_fuzz_optional_union_parameter(
a: float, b: typing.Union[float, int, None]
) -> None:
test_expected_output.optional_union_parameter(a=a, b=b)
@@ -0,0 +1,13 @@
# This test code was written by the `hypothesis.extra.ghostwriter` module
# and is provided under the Creative Commons Zero public domain dedication.

import test_expected_output
import typing
from hypothesis import given, strategies as st


@given(items=st.one_of(st.binary(), st.lists(st.one_of(st.floats(), st.integers()))))
def test_fuzz_union_sequence_parameter(
items: typing.Sequence[typing.Union[float, int]]
) -> None:
test_expected_output.union_sequence_parameter(items=items)
40 changes: 39 additions & 1 deletion hypothesis-python/tests/ghostwriter/test_expected_output.py
Expand Up @@ -21,10 +21,11 @@
import pathlib
import re
import sys
from typing import Sequence
from typing import Optional, Sequence, Union

import numpy
import pytest
from example_code.future_annotations import add_custom_classes, merge_dicts

import hypothesis
from hypothesis.extra import ghostwriter
Expand Down Expand Up @@ -82,6 +83,25 @@ def divide(a: int, b: int) -> float:
return a / b


def optional_parameter(a: float, b: Optional[float]) -> float:
return optional_union_parameter(a, b)


def optional_union_parameter(a: float, b: Optional[Union[float, int]]) -> float:
return a if b is None else a + b


if sys.version_info[:2] >= (3, 10):

def union_sequence_parameter(items: Sequence[float | int]) -> float:
return sum(items)

else:

def union_sequence_parameter(items: Sequence[Union[float, int]]) -> float:
return sum(items)


# Note: for some of the `expected` outputs, we replace away some small
# parts which vary between minor versions of Python.
@pytest.mark.parametrize(
Expand All @@ -94,6 +114,24 @@ def divide(a: int, b: int) -> float:
("fuzz_staticmethod", ghostwriter.fuzz(A_Class.a_staticmethod)),
("fuzz_ufunc", ghostwriter.fuzz(numpy.add)),
("magic_gufunc", ghostwriter.magic(numpy.matmul)),
pytest.param(
("optional_parameter", ghostwriter.magic(optional_parameter)),
marks=pytest.mark.skipif("sys.version_info[:2] < (3, 9)"),
),
pytest.param(
("optional_parameter_pre_py_3_9", ghostwriter.magic(optional_parameter)),
marks=pytest.mark.skipif("sys.version_info[:2] >= (3, 9)"),
),
("optional_union_parameter", ghostwriter.magic(optional_union_parameter)),
("union_sequence_parameter", ghostwriter.magic(union_sequence_parameter)),
pytest.param(
("add_custom_classes", ghostwriter.magic(add_custom_classes)),
marks=pytest.mark.skipif("sys.version_info[:2] < (3, 10)"),
),
pytest.param(
("merge_dicts", ghostwriter.magic(merge_dicts)),
marks=pytest.mark.skipif("sys.version_info[:2] < (3, 10)"),
),
("magic_base64_roundtrip", ghostwriter.magic(base64.b64encode)),
(
"magic_base64_roundtrip_with_annotations",
Expand Down