Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 4e5b368

Browse files
stainless-app[bot]stainless-bot
authored andcommittedFeb 6, 2025·
chore(internal): fix type traversing dictionary params (#2097)
1 parent 9a95db9 commit 4e5b368

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed
 

‎src/openai/_utils/_transform.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
is_annotated_type,
2626
strip_annotated_type,
2727
)
28-
from .._compat import model_dump, is_typeddict
28+
from .._compat import get_origin, model_dump, is_typeddict
2929

3030
_T = TypeVar("_T")
3131

@@ -164,9 +164,14 @@ def _transform_recursive(
164164
inner_type = annotation
165165

166166
stripped_type = strip_annotated_type(inner_type)
167+
origin = get_origin(stripped_type) or stripped_type
167168
if is_typeddict(stripped_type) and is_mapping(data):
168169
return _transform_typeddict(data, stripped_type)
169170

171+
if origin == dict and is_mapping(data):
172+
items_type = get_args(stripped_type)[1]
173+
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
174+
170175
if (
171176
# List[T]
172177
(is_list_type(stripped_type) and is_list(data))
@@ -307,9 +312,14 @@ async def _async_transform_recursive(
307312
inner_type = annotation
308313

309314
stripped_type = strip_annotated_type(inner_type)
315+
origin = get_origin(stripped_type) or stripped_type
310316
if is_typeddict(stripped_type) and is_mapping(data):
311317
return await _async_transform_typeddict(data, stripped_type)
312318

319+
if origin == dict and is_mapping(data):
320+
items_type = get_args(stripped_type)[1]
321+
return {key: _transform_recursive(value, annotation=items_type) for key, value in data.items()}
322+
313323
if (
314324
# List[T]
315325
(is_list_type(stripped_type) and is_list(data))

‎tests/test_transform.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import io
44
import pathlib
5-
from typing import Any, List, Union, TypeVar, Iterable, Optional, cast
5+
from typing import Any, Dict, List, Union, TypeVar, Iterable, Optional, cast
66
from datetime import date, datetime
77
from typing_extensions import Required, Annotated, TypedDict
88

@@ -388,6 +388,15 @@ def my_iter() -> Iterable[Baz8]:
388388
}
389389

390390

391+
@parametrize
392+
@pytest.mark.asyncio
393+
async def test_dictionary_items(use_async: bool) -> None:
394+
class DictItems(TypedDict):
395+
foo_baz: Annotated[str, PropertyInfo(alias="fooBaz")]
396+
397+
assert await transform({"foo": {"foo_baz": "bar"}}, Dict[str, DictItems], use_async) == {"foo": {"fooBaz": "bar"}}
398+
399+
391400
class TypedDictIterableUnionStr(TypedDict):
392401
foo: Annotated[Union[str, Iterable[Baz8]], PropertyInfo(alias="FOO")]
393402

0 commit comments

Comments
 (0)
Please sign in to comment.