Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow using TypedDict for more precise typing of **kwds #13471

Merged
merged 9 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 12 additions & 4 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -730,9 +730,10 @@ def check_overlapping_overloads(self, defn: OverloadedFuncDef) -> None:
# This is to match the direction the implementation's return
# needs to be compatible in.
if impl_type.variables:
impl = unify_generic_callable(
impl_type,
sig1,
impl: CallableType | None = unify_generic_callable(
# Normalize both before unifying
impl_type.with_unpacked_kwargs(),
sig1.with_unpacked_kwargs(),
ignore_return=False,
return_constraint_direction=SUPERTYPE_OF,
)
Expand Down Expand Up @@ -1167,7 +1168,7 @@ def check_func_def(self, defn: FuncItem, typ: CallableType, name: str | None) ->
# builtins.tuple[T] is typing.Tuple[T, ...]
arg_type = self.named_generic_type("builtins.tuple", [arg_type])
elif typ.arg_kinds[i] == nodes.ARG_STAR2:
if not isinstance(arg_type, ParamSpecType):
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
arg_type = self.named_generic_type(
"builtins.dict", [self.str_type(), arg_type]
)
Expand Down Expand Up @@ -1912,6 +1913,13 @@ def check_override(

if fail:
emitted_msg = False

# Normalize signatures, so we get better diagnostics.
if isinstance(override, (CallableType, Overloaded)):
override = override.with_unpacked_kwargs()
if isinstance(original, (CallableType, Overloaded)):
original = original.with_unpacked_kwargs()

if (
isinstance(override, CallableType)
and isinstance(original, CallableType)
Expand Down
4 changes: 4 additions & 0 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1322,6 +1322,8 @@ def check_callable_call(

See the docstring of check_call for more information.
"""
# Always unpack **kwargs before checking a call.
callee = callee.with_unpacked_kwargs()
if callable_name is None and callee.name:
callable_name = callee.name
ret_type = get_proper_type(callee.ret_type)
Expand Down Expand Up @@ -2057,6 +2059,8 @@ def check_overload_call(
context: Context,
) -> tuple[Type, Type]:
"""Checks a call to an overloaded function."""
# Normalize unpacked kwargs before checking the call.
callee = callee.with_unpacked_kwargs()
arg_types = self.infer_arg_types_in_empty_context(args)
# Step 1: Filter call targets to remove ones where the argument counts don't match
plausible_targets = self.plausible_overload_call_targets(
Expand Down
6 changes: 5 additions & 1 deletion mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,9 +763,13 @@ def infer_constraints_from_protocol_members(
return res

def visit_callable_type(self, template: CallableType) -> list[Constraint]:
# Normalize callables before matching against each other.
# Note that non-normalized callables can be created in annotations
# using e.g. callback protocols.
template = template.with_unpacked_kwargs()
if isinstance(self.actual, CallableType):
res: list[Constraint] = []
cactual = self.actual
cactual = self.actual.with_unpacked_kwargs()
param_spec = template.param_spec()
if param_spec is None:
# FIX verify argument counts
Expand Down
18 changes: 17 additions & 1 deletion mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

from typing import Tuple

import mypy.typeops
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
Expand Down Expand Up @@ -141,7 +143,7 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:

def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
"""Return a simple least upper bound given the declared type."""
# TODO: check infinite recursion for aliases here.
# TODO: check infinite recursion for aliases here?
declaration = get_proper_type(declaration)
s = get_proper_type(s)
t = get_proper_type(t)
Expand Down Expand Up @@ -172,6 +174,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
if isinstance(s, UninhabitedType) and not isinstance(t, UninhabitedType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

value = t.accept(TypeJoinVisitor(s))
if declaration is None or is_subtype(value, declaration):
return value
Expand Down Expand Up @@ -229,6 +234,9 @@ def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None)
elif isinstance(t, PlaceholderType):
return AnyType(TypeOfAny.from_error)

# Meets/joins require callable type normalization.
s, t = normalize_callables(s, t)

# Use a visitor to handle non-trivial cases.
return t.accept(TypeJoinVisitor(s, instance_joiner))

Expand Down Expand Up @@ -528,6 +536,14 @@ def is_better(t: Type, s: Type) -> bool:
return False


def normalize_callables(s: ProperType, t: ProperType) -> Tuple[ProperType, ProperType]:
if isinstance(s, (CallableType, Overloaded)):
s = s.with_unpacked_kwargs()
if isinstance(t, (CallableType, Overloaded)):
t = t.with_unpacked_kwargs()
return s, t


def is_similar_callables(t: CallableType, s: CallableType) -> bool:
"""Return True if t and s have identical numbers of
arguments, default arguments and varargs.
Expand Down
4 changes: 4 additions & 0 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def meet_types(s: Type, t: Type) -> ProperType:
return t
if isinstance(s, UnionType) and not isinstance(t, UnionType):
s, t = t, s

# Meets/joins require callable type normalization.
s, t = join.normalize_callables(s, t)

return t.accept(TypeMeetVisitor(s))


Expand Down
5 changes: 4 additions & 1 deletion mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -2392,7 +2392,10 @@ def [T <: int] f(self, x: int, y: T) -> None
name = tp.arg_names[i]
if name:
s += name + ": "
s += format_type_bare(tp.arg_types[i])
type_str = format_type_bare(tp.arg_types[i])
if tp.arg_kinds[i] == ARG_STAR2 and tp.unpack_kwargs:
type_str = f"Unpack[{type_str}]"
s += type_str
if tp.arg_kinds[i].is_optional():
s += " = ..."
if (
Expand Down
26 changes: 26 additions & 0 deletions mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,7 @@
TypeVarLikeType,
TypeVarType,
UnboundType,
UnpackType,
get_proper_type,
get_proper_types,
invalid_recursive_alias,
Expand Down Expand Up @@ -830,6 +831,8 @@ def analyze_func_def(self, defn: FuncDef) -> None:
self.defer(defn)
return
assert isinstance(result, ProperType)
if isinstance(result, CallableType):
result = self.remove_unpack_kwargs(defn, result)
defn.type = result
self.add_type_alias_deps(analyzer.aliases_used)
self.check_function_signature(defn)
Expand Down Expand Up @@ -872,6 +875,29 @@ def analyze_func_def(self, defn: FuncDef) -> None:
defn.type = defn.type.copy_modified(ret_type=ret_type)
self.wrapped_coro_return_types[defn] = defn.type

def remove_unpack_kwargs(self, defn: FuncDef, typ: CallableType) -> CallableType:
if not typ.arg_kinds or typ.arg_kinds[-1] is not ArgKind.ARG_STAR2:
return typ
last_type = get_proper_type(typ.arg_types[-1])
if not isinstance(last_type, UnpackType):
return typ
last_type = get_proper_type(last_type.type)
if not isinstance(last_type, TypedDictType):
self.fail("Unpack item in ** argument must be a TypedDict", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
overlap = set(typ.arg_names) & set(last_type.items)
# It is OK for TypedDict to have a key named 'kwargs'.
overlap.discard(typ.arg_names[-1])
if overlap:
overlapped = ", ".join([f'"{name}"' for name in overlap])
self.fail(f"Overlap between argument names and ** TypedDict items: {overlapped}", defn)
new_arg_types = typ.arg_types[:-1] + [AnyType(TypeOfAny.from_error)]
return typ.copy_modified(arg_types=new_arg_types)
# OK, everything looks right now, mark the callable type as using unpack.
new_arg_types = typ.arg_types[:-1] + [last_type]
return typ.copy_modified(arg_types=new_arg_types, unpack_kwargs=True)

def prepare_method_signature(self, func: FuncDef, info: TypeInfo) -> None:
"""Check basic signature validity and tweak annotation of self/cls argument."""
# Only non-static methods are special.
Expand Down
25 changes: 16 additions & 9 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
Instance,
LiteralType,
NoneType,
NormalizedCallableType,
Overloaded,
Parameters,
ParamSpecType,
Expand Down Expand Up @@ -626,8 +627,10 @@ def visit_unpack_type(self, left: UnpackType) -> bool:
return False

def visit_parameters(self, left: Parameters) -> bool:
right = self.right
if isinstance(right, Parameters) or isinstance(right, CallableType):
if isinstance(self.right, Parameters) or isinstance(self.right, CallableType):
right = self.right
if isinstance(right, CallableType):
right = right.with_unpacked_kwargs()
return are_parameters_compatible(
left,
right,
Expand Down Expand Up @@ -671,7 +674,7 @@ def visit_callable_type(self, left: CallableType) -> bool:
elif isinstance(right, Parameters):
# this doesn't check return types.... but is needed for is_equivalent
return are_parameters_compatible(
left,
left.with_unpacked_kwargs(),
right,
is_compat=self._is_subtype,
ignore_pos_arg_names=self.subtype_context.ignore_pos_arg_names,
Expand Down Expand Up @@ -1249,6 +1252,10 @@ def g(x: int) -> int: ...
If the 'some_check' function is also symmetric, the two calls would be equivalent
whether or not we check the args covariantly.
"""
# Normalize both types before comparing them.
left = left.with_unpacked_kwargs()
right = right.with_unpacked_kwargs()

if is_compat_return is None:
is_compat_return = is_compat

Expand Down Expand Up @@ -1313,8 +1320,8 @@ def g(x: int) -> int: ...


def are_parameters_compatible(
left: Parameters | CallableType,
right: Parameters | CallableType,
left: Parameters | NormalizedCallableType,
right: Parameters | NormalizedCallableType,
*,
is_compat: Callable[[Type, Type], bool],
ignore_pos_arg_names: bool = False,
Expand Down Expand Up @@ -1535,11 +1542,11 @@ def new_is_compat(left: Type, right: Type) -> bool:


def unify_generic_callable(
type: CallableType,
target: CallableType,
type: NormalizedCallableType,
target: NormalizedCallableType,
ignore_return: bool,
return_constraint_direction: int | None = None,
) -> CallableType | None:
) -> NormalizedCallableType | None:
"""Try to unify a generic callable type with another callable type.

Return unified CallableType if successful; otherwise, return None.
Expand Down Expand Up @@ -1576,7 +1583,7 @@ def report(*args: Any) -> None:
)
if had_errors:
return None
return applied
return cast(NormalizedCallableType, applied)


def try_restrict_literal_union(t: UnionType, s: Type) -> list[Type] | None:
Expand Down
2 changes: 1 addition & 1 deletion mypy/typeanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,7 @@ def try_analyze_special_unbound_type(self, t: UnboundType, fullname: str) -> Typ
elif fullname in ("typing.Unpack", "typing_extensions.Unpack"):
# We don't want people to try to use this yet.
if not self.options.enable_incomplete_features:
self.fail('"Unpack" is not supported by mypy yet', t)
self.fail('"Unpack" is not supported yet, use --enable-incomplete-features', t)
return AnyType(TypeOfAny.from_error)
return UnpackType(self.anal_type(t.args[0]), line=t.line, column=t.column)
return None
Expand Down