Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mypyc] Support iterating over a TypedDict #14747

Merged
merged 5 commits into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 7 additions & 2 deletions mypyc/irbuild/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
ProperType,
TupleType,
Type,
TypedDictType,
TypeOfAny,
UninhabitedType,
UnionType,
Expand Down Expand Up @@ -892,8 +893,12 @@ def get_dict_base_type(self, expr: Expression) -> list[Instance]:

dict_types = []
for t in types:
assert isinstance(t, Instance), t
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
if isinstance(t, TypedDictType):
t = t.fallback
dict_base = next(base for base in t.type.mro if base.fullname == "typing.Mapping")
else:
assert isinstance(t, Instance), t
dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict")
dict_types.append(map_instance_to_supertype(t, dict_base))
return dict_types

Expand Down
69 changes: 69 additions & 0 deletions mypyc/test-data/irbuild-dict.test
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,12 @@ L0:

[case testDictIterationMethods]
from typing import Dict, Union
from typing_extensions import TypedDict

class Person(TypedDict):
name: str
age: int

def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int]) -> None:
for v in d1.values():
if v in d2:
Expand All @@ -229,6 +235,10 @@ def union_of_dicts(d: Union[Dict[str, int], Dict[str, str]]) -> None:
new = {}
for k, v in d.items():
new[k] = int(v)
def typeddict(d: Person) -> None:
for k, v in d.items():
if k == "name":
name = v
[out]
def print_dict_methods(d1, d2):
d1, d2 :: dict
Expand Down Expand Up @@ -370,6 +380,65 @@ L4:
r19 = CPy_NoErrOccured()
L5:
return 1
def typeddict(d):
d :: dict
r0 :: short_int
r1 :: native_int
r2 :: short_int
r3 :: object
r4 :: tuple[bool, short_int, object, object]
r5 :: short_int
r6 :: bool
r7, r8 :: object
r9, k :: str
v :: object
r10 :: str
r11 :: int32
r12 :: bit
r13 :: object
r14, r15, r16 :: bit
name :: object
r17, r18 :: bit
L0:
r0 = 0
r1 = PyDict_Size(d)
r2 = r1 << 1
r3 = CPyDict_GetItemsIter(d)
L1:
r4 = CPyDict_NextItem(r3, r0)
r5 = r4[1]
r0 = r5
r6 = r4[0]
if r6 goto L2 else goto L9 :: bool
L2:
r7 = r4[2]
r8 = r4[3]
r9 = cast(str, r7)
k = r9
v = r8
r10 = 'name'
r11 = PyUnicode_Compare(k, r10)
r12 = r11 == -1
if r12 goto L3 else goto L5 :: bool
L3:
r13 = PyErr_Occurred()
r14 = r13 != 0
if r14 goto L4 else goto L5 :: bool
L4:
r15 = CPy_KeepPropagating()
L5:
r16 = r11 == 0
if r16 goto L6 else goto L7 :: bool
L6:
name = v
L7:
L8:
r17 = CPyDict_CheckSize(d, r2)
goto L1
L9:
r18 = CPy_NoErrOccured()
L10:
return 1

[case testDictLoadAddress]
def f() -> None:
Expand Down
34 changes: 32 additions & 2 deletions mypyc/test-data/run-dicts.test
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,13 @@ assert get_content_set(od) == ({1, 3}, {2, 4}, {(1, 2), (3, 4)})
[typing fixtures/typing-full.pyi]

[case testDictIterationMethodsRun]
from typing import Dict
from typing import Dict, Union
from typing_extensions import TypedDict

class ExtensionDict(TypedDict):
python: str
c: str

def print_dict_methods(d1: Dict[int, int],
d2: Dict[int, int],
d3: Dict[int, int]) -> None:
Expand All @@ -107,13 +113,27 @@ def print_dict_methods(d1: Dict[int, int],
for v in d3.values():
print(v)

def print_dict_methods_special(d1: Union[Dict[int, int], Dict[str, str]],
d2: ExtensionDict) -> None:
for k in d1.keys():
print(k)
for k, v in d1.items():
print(k)
print(v)
for v2 in d2.values():
print(v2)
for k2, v2 in d2.items():
print(k2)
print(v2)


def clear_during_iter(d: Dict[int, int]) -> None:
for k in d:
d.clear()

class Custom(Dict[int, int]): pass
[file driver.py]
from native import print_dict_methods, Custom, clear_during_iter
from native import print_dict_methods, print_dict_methods_special, Custom, clear_during_iter
from collections import OrderedDict
print_dict_methods({}, {}, {})
print_dict_methods({1: 2}, {3: 4, 5: 6}, {7: 8})
Expand All @@ -124,6 +144,7 @@ print('==')
d = OrderedDict([(1, 2), (3, 4)])
print_dict_methods(d, d, d)
print('==')
print_dict_methods_special({1: 2}, {"python": ".py", "c": ".c"})
d.move_to_end(1)
print_dict_methods(d, d, d)
clear_during_iter({}) # OK
Expand Down Expand Up @@ -185,6 +206,15 @@ else:
2
4
==
1
1
2
.py
.c
python
.py
c
.c
3
1
3
Expand Down
5 changes: 4 additions & 1 deletion test-data/unit/lib-stub/typing_extensions.pyi
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from typing import Any, Mapping, Iterator, NoReturn as NoReturn, Dict, Type
from typing import Any, Mapping, Iterable, Iterator, NoReturn as NoReturn, Dict, Tuple, Type
from typing import TYPE_CHECKING as TYPE_CHECKING
from typing import NewType as NewType, overload as overload

Expand Down Expand Up @@ -50,6 +50,9 @@ class _TypedDict(Mapping[str, object]):
# Mypy expects that 'default' has a type variable type.
def pop(self, k: NoReturn, default: _T = ...) -> object: ...
def update(self: _T, __m: _T) -> None: ...
def items(self) -> Iterable[Tuple[str, object]]: ...
def keys(self) -> Iterable[str]: ...
def values(self) -> Iterable[object]: ...
if sys.version_info < (3, 0):
def has_key(self, k: str) -> bool: ...
def __delitem__(self, k: NoReturn) -> None: ...
Expand Down