Skip to content

Commit

Permalink
Allow using TypedDict for more precise typing of **kwds (#13471)
Browse files Browse the repository at this point in the history
Fixes #4441

This uses a different approach than the initial attempt, but I re-used some of the test cases from the older PR. The initial idea was to eagerly expand the signature of the function during semantic analysis, but it didn't work well with fine-grained mode and also mypy in general relies on function definition and its type being consistent (and rewriting `FuncDef` sounds too sketchy). So instead I add a boolean flag to `CallableType` to indicate whether type of `**kwargs` is each item type or the "packed" type.

I also add few helpers and safety net in form of a `NewType()`, but in general I am surprised how few places needed normalizing the signatures (because most relevant code paths go through `check_callable_call()` and/or `is_callable_compatible()`). Currently `Unpack[...]` is hidden behind `--enable-incomplete-features`, so this will be too, but IMO this part is 99% complete (you can see even some more exotic use cases like generic TypedDicts and callback protocols in test cases).
  • Loading branch information
ilevkivskyi authored and jhance committed Sep 9, 2022
1 parent dd2e020 commit 35bc1a2
Show file tree
Hide file tree
Showing 16 changed files with 505 additions and 21 deletions.
16 changes: 12 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,9 +728,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# This is to match the direction the implementation's return
# needs to be compatible in.
if impl_type.variables:
impl = unify_generic_callable(
impl_type,
sig1,
impl: CallableType | None = unify_generic_callable(
# Normalize both before unifying
impl_type.with_unpacked_kwargs(),
sig1.with_unpacked_kwargs(),
ignore_return=False,
return_constraint_direction=SUPERTYPE_OF,
)
Expand Down Expand Up @@ -1165,7 +1166,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: str | None) ->
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
if not isinstance(arg_type, ParamSpecType):
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
arg_type = self.named_generic_type(
"builtins.dict", [self.str_type(), arg_type]
)
Expand Down Expand Up @@ -1887,6 +1888,13 @@ def check_override(

if fail:
emitted_msg = False

# Normalize signatures, so we get better diagnostics.
if isinstance(override, (CallableType, Overloaded)):
override = override.with_unpacked_kwargs()
if isinstance(original, (CallableType, Overloaded)):
original = original.with_unpacked_kwargs()

if (
isinstance(override, CallableType)
and isinstance(original, CallableType)
Expand Down
4 changes: 4 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,8 @@ def check_callable_call(
See the docstring of check_call for more information.
"""
# Always unpack **kwargs before checking a call.
callee = callee.with_unpacked_kwargs()
if callable_name is None and callee.name:
callable_name = callee.name
ret_type = get_proper_type(callee.ret_type)
Expand Down Expand Up @@ -2057,6 +2059,8 @@ def check_overload_call(
context: Context,
) -> tuple[Type, Type]:
"""Checks a call to an overloaded function."""
# Normalize unpacked kwargs before checking the call.
callee = callee.with_unpacked_kwargs()
arg_types = self.infer_arg_types_in_empty_context(args)
# Step 1: Filter call targets to remove ones where the argument counts don't match
plausible_targets = self.plausible_overload_call_targets(
Expand Down
6 changes: 5 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,9 +735,13 @@ def infer_constraints_from_protocol_members(
return res

def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# Normalize callables before matching against each other.
# Note that non-normalized callables can be created in annotations
# using e.g. callback protocols.
template = template.with_unpacked_kwargs()
if isinstance(self.actual, CallableType):
res: list[Constraint] = []
cactual = self.actual
cactual = self.actual.with_unpacked_kwargs()
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
Expand Down
18 changes: 17 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import Tuple

import mypy.typeops
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
Expand Down Expand Up @@ -141,7 +143,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:

def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
"""Return a simple least upper bound given the declared type."""
# TODO: check infinite recursion for aliases here.
# TODO: check infinite recursion for aliases here?
declaration = get_proper_type(declaration)
s = get_proper_type(s)
t = get_proper_type(t)
Expand Down Expand Up @@ -172,6 +174,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

value = t.accept(TypeJoinVisitor(s))
if declaration is None or is_subtype(value, declaration):
return value
Expand Down Expand Up @@ -229,6 +234,9 @@ def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None)
elif isinstance(t, PlaceholderType):
return AnyType(TypeOfAny.from_error)

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

# Use a visitor to handle non-trivial cases.
return t.accept(TypeJoinVisitor(s, instance_joiner))

Expand Down Expand Up @@ -528,6 +536,14 @@ def is_better(t: Type, s: Type) -> bool:
return False


def normalize_callables(s: ProperType, t: ProperType) -> Tuple[ProperType, ProperType]:
if isinstance(s, (CallableType, Overloaded)):
s = s.with_unpacked_kwargs()
if isinstance(t, (CallableType, Overloaded)):
t = t.with_unpacked_kwargs()
return s, t


def is_similar_callables(t: CallableType, s: CallableType) -> bool:
"""Return True if t and s have identical numbers of
arguments, default arguments and varargs.
Expand Down
4 changes: 4 additions & 0 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def meet_types(s: Type, t: Type) -> ProperType:
return t
if isinstance(s, UnionType) and not isinstance(t, UnionType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = join.normalize_callables(s, t)

return t.accept(TypeMeetVisitor(s))


Expand Down
5 changes: 4 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2391,7 +2391,10 @@ def [T <: int] f(self, x: int, y: T) -> None
name = tp.arg_names[i]
if name:
s += name + ": "
s += format_type_bare(tp.arg_types[i])
type_str = format_type_bare(tp.arg_types[i])
if tp.arg_kinds[i] == ARG_STAR2 and tp.unpack_kwargs:
type_str = f"Unpack[{type_str}]"
s += type_str
if tp.arg_kinds[i].is_optional():
s += " = ..."

Expand Down
26 changes: 26 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@
TypeVarLikeType,
TypeVarType,
UnboundType,
UnpackType,
get_proper_type,
get_proper_types,
invalid_recursive_alias,
Expand Down Expand Up @@ -832,6 +833,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
self.defer(defn)
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
result = self.remove_unpack_kwargs(defn, result)
defn.type = result
self.add_type_alias_deps(analyzer.aliases_used)
self.check_function_signature(defn)
Expand Down Expand Up @@ -874,6 +877,29 @@ def analyze_func_def(self, defn: FuncDef) -> None:
defn.type = defn.type.copy_modified(ret_type=ret_type)
self.wrapped_coro_return_types[defn] = defn.type

def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType:
if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2:
return typ
last_type = get_proper_type(typ.arg_types[-1])
if not isinstance(last_type, UnpackType):
return typ
last_type = get_proper_type(last_type.type)
if not isinstance(last_type, TypedDictType):
self.fail("Unpack item in ** argument must be a TypedDict", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
overlap = set(typ.arg_names) & set(last_type.items)
# It is OK for TypedDict to have a key named 'kwargs'.
overlap.discard(typ.arg_names[-1])
if overlap:
overlapped = ", ".join([f'"{name}"' for name in overlap])
self.fail(f"Overlap between argument names and ** TypedDict items: {overlapped}", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
# OK, everything looks right now, mark the callable type as using unpack.
new_arg_types = typ.arg_types[:-1] + [last_type]
return typ.copy_modified(arg_types=new_arg_types, unpack_kwargs=True)

def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:
"""Check basic signature validity and tweak annotation of self/cls argument."""
# Only non-static methods are special.
Expand Down
25 changes: 16 additions & 9 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Instance,
LiteralType,
NoneType,
NormalizedCallableType,
Overloaded,
Parameters,
ParamSpecType,
Expand Down Expand Up @@ -591,8 +592,10 @@ def visit_unpack_type(self, left: UnpackType) -> bool:
return False

def visit_parameters(self, left: Parameters) -> bool:
right = self.right
if isinstance(right, Parameters) or isinstance(right, CallableType):
if isinstance(self.right, Parameters) or isinstance(self.right, CallableType):
right = self.right
if isinstance(right, CallableType):
right = right.with_unpacked_kwargs()
return are_parameters_compatible(
left,
right,
Expand Down Expand Up @@ -636,7 +639,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
elif isinstance(right, Parameters):
# this doesn't check return types.... but is needed for is_equivalent
return are_parameters_compatible(
left,
left.with_unpacked_kwargs(),
right,
is_compat=self._is_subtype,
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
Expand Down Expand Up @@ -1213,6 +1216,10 @@ def g(x: int) -> int: ...
If the 'some_check' function is also symmetric, the two calls would be equivalent
whether or not we check the args covariantly.
"""
# Normalize both types before comparing them.
left = left.with_unpacked_kwargs()
right = right.with_unpacked_kwargs()

if is_compat_return is None:
is_compat_return = is_compat

Expand Down Expand Up @@ -1277,8 +1284,8 @@ def g(x: int) -> int: ...


def are_parameters_compatible(
left: Parameters | CallableType,
right: Parameters | CallableType,
left: Parameters | NormalizedCallableType,
right: Parameters | NormalizedCallableType,
*,
is_compat: Callable[[Type, Type], bool],
ignore_pos_arg_names: bool = False,
Expand Down Expand Up @@ -1499,11 +1506,11 @@ def new_is_compat(left: Type, right: Type) -> bool:


def unify_generic_callable(
type: CallableType,
target: CallableType,
type: NormalizedCallableType,
target: NormalizedCallableType,
ignore_return: bool,
return_constraint_direction: int | None = None,
) -> CallableType | None:
) -> NormalizedCallableType | None:
"""Try to unify a generic callable type with another callable type.
Return unified CallableType if successful; otherwise, return None.
Expand Down Expand Up @@ -1540,7 +1547,7 @@ def report(*args: Any) -> None:
)
if had_errors:
return None
return applied
return cast(NormalizedCallableType, applied)


def try_restrict_literal_union(t: UnionType, s: Type) -> list[Type] | None:
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
elif fullname in ("typing.Unpack", "typing_extensions.Unpack"):
# We don't want people to try to use this yet.
if not self.options.enable_incomplete_features:
self.fail('"Unpack" is not supported by mypy yet', t)
self.fail('"Unpack" is not supported yet, use --enable-incomplete-features', t)
return AnyType(TypeOfAny.from_error)
return UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column)
return None
Expand Down

0 comments on commit 35bc1a2

Please sign in to comment.