Skip to content

Commit

Permalink
Make imprecise constraints handling more robust (#16502)
Browse files Browse the repository at this point in the history
Fixes #16485

My initial implementation of imprecise constraints fallback was really
fragile and ad-hoc, and I now see several edge case scenarios where we
may end up using imprecise constraints for a `ParamSpec` while some
precise ones are available. So I re-organized it: now we just infer
everything as normally, and filter out imprecise (if needed) at the very
end, when we have the full picture. I also fix an accidental omission in
`expand_type()`.
  • Loading branch information
ilevkivskyi committed Nov 22, 2023
1 parent a3e488d commit 3e6b552
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 33 deletions.
76 changes: 43 additions & 33 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,25 +226,22 @@ def infer_constraints_for_callable(
actual_type = mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
)
if (
param_spec
and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2)
and not incomplete_star_mapping
):
if param_spec and callee.arg_kinds[i] in (ARG_STAR, ARG_STAR2):
# If actual arguments are mapped to ParamSpec type, we can't infer individual
# constraints, instead store them and infer single constraint at the end.
# It is impossible to map actual kind to formal kind, so use some heuristic.
# This inference is used as a fallback, so relying on heuristic should be OK.
param_spec_arg_types.append(
mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
if not incomplete_star_mapping:
param_spec_arg_types.append(
mapper.expand_actual_type(
actual_arg_type, arg_kinds[actual], None, arg_kinds[actual]
)
)
)
actual_kind = arg_kinds[actual]
param_spec_arg_kinds.append(
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
)
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
actual_kind = arg_kinds[actual]
param_spec_arg_kinds.append(
ARG_POS if actual_kind not in (ARG_STAR, ARG_STAR2) else actual_kind
)
param_spec_arg_names.append(arg_names[actual] if arg_names else None)
else:
c = infer_constraints(callee.arg_types[i], actual_type, SUPERTYPE_OF)
constraints.extend(c)
Expand All @@ -267,6 +264,9 @@ def infer_constraints_for_callable(
),
)
)
if any(isinstance(v, ParamSpecType) for v in callee.variables):
# As a perf optimization filter imprecise constraints only when we can have them.
constraints = filter_imprecise_kinds(constraints)
return constraints


Expand Down Expand Up @@ -1094,29 +1094,18 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
)

param_spec_target: Type | None = None
skip_imprecise = (
any(c.type_var == param_spec.id for c in res) and cactual.imprecise_arg_kinds
)
if not cactual_ps:
max_prefix_len = len([k for k in cactual.arg_kinds if k in (ARG_POS, ARG_OPT)])
prefix_len = min(prefix_len, max_prefix_len)
# This logic matches top-level callable constraint exception, if we managed
# to get other constraints for ParamSpec, don't infer one with imprecise kinds
if not skip_imprecise:
param_spec_target = Parameters(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
variables=cactual.variables
if not type_state.infer_polymorphic
else [],
imprecise_arg_kinds=cactual.imprecise_arg_kinds,
)
param_spec_target = Parameters(
arg_types=cactual.arg_types[prefix_len:],
arg_kinds=cactual.arg_kinds[prefix_len:],
arg_names=cactual.arg_names[prefix_len:],
variables=cactual.variables if not type_state.infer_polymorphic else [],
imprecise_arg_kinds=cactual.imprecise_arg_kinds,
)
else:
if (
len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types)
and not skip_imprecise
):
if len(param_spec.prefix.arg_types) <= len(cactual_ps.prefix.arg_types):
param_spec_target = cactual_ps.copy_modified(
prefix=Parameters(
arg_types=cactual_ps.prefix.arg_types[prefix_len:],
Expand Down Expand Up @@ -1611,3 +1600,24 @@ def infer_callable_arguments_constraints(
infer_directed_arg_constraints(left_by_name.typ, right_by_name.typ, direction)
)
return res


def filter_imprecise_kinds(cs: list[Constraint]) -> list[Constraint]:
"""For each ParamSpec remove all imprecise constraints, if at least one precise available."""
have_precise = set()
for c in cs:
if not isinstance(c.origin_type_var, ParamSpecType):
continue
if (
isinstance(c.target, ParamSpecType)
or isinstance(c.target, Parameters)
and not c.target.imprecise_arg_kinds
):
have_precise.add(c.type_var)
new_cs = []
for c in cs:
if not isinstance(c.origin_type_var, ParamSpecType) or c.type_var not in have_precise:
new_cs.append(c)
if not isinstance(c.target, Parameters) or not c.target.imprecise_arg_kinds:
new_cs.append(c)
return new_cs
1 change: 1 addition & 0 deletions mypy/expandtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ def visit_param_spec(self, t: ParamSpecType) -> Type:
t.prefix.arg_kinds + repl.arg_kinds,
t.prefix.arg_names + repl.arg_names,
variables=[*t.prefix.variables, *repl.variables],
imprecise_arg_kinds=repl.imprecise_arg_kinds,
)
else:
# We could encode Any as trivial parameters etc., but it would be too verbose.
Expand Down
23 changes: 23 additions & 0 deletions test-data/unit/check-parameter-specification.test
Original file line number Diff line number Diff line change
Expand Up @@ -2163,3 +2163,26 @@ def func2(arg: T) -> List[Union[T, str]]:
reveal_type(func2) # N: Revealed type is "def [T] (arg: T`-1) -> Union[T`-1, builtins.str]"
reveal_type(func2(42)) # N: Revealed type is "Union[builtins.int, builtins.str]"
[builtins fixtures/paramspec.pyi]

[case testParamSpecPreciseKindsUsedIfPossible]
from typing import Callable, Generic
from typing_extensions import ParamSpec

P = ParamSpec('P')

class Case(Generic[P]):
def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None:
pass

def _test(a: int, b: int = 0) -> None: ...

def parametrize(
func: Callable[P, None], *cases: Case[P], **named_cases: Case[P]
) -> Callable[[], None]:
...

parametrize(_test, Case(1, 2), Case(3, 4))
parametrize(_test, Case(1, b=2), Case(3, b=4))
parametrize(_test, Case(1, 2), Case(3))
parametrize(_test, Case(1, 2), Case(3, b=4))
[builtins fixtures/paramspec.pyi]

0 comments on commit 3e6b552

Please sign in to comment.