Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2b80c90

Browse files
stainless-app[bot]stainless-bot
authored andcommittedNov 4, 2024·
fix: support json safe serialization for basemodel subclasses (#1844)
1 parent 258f265 commit 2b80c90

File tree

7 files changed

+52
-21
lines changed

7 files changed

+52
-21
lines changed
 

‎src/openai/_compat.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
44
from datetime import date, datetime
5-
from typing_extensions import Self
5+
from typing_extensions import Self, Literal
66

77
import pydantic
88
from pydantic.fields import FieldInfo
@@ -137,9 +137,11 @@ def model_dump(
137137
exclude_unset: bool = False,
138138
exclude_defaults: bool = False,
139139
warnings: bool = True,
140+
mode: Literal["json", "python"] = "python",
140141
) -> dict[str, Any]:
141-
if PYDANTIC_V2:
142+
if PYDANTIC_V2 or hasattr(model, "model_dump"):
142143
return model.model_dump(
144+
mode=mode,
143145
exclude=exclude,
144146
exclude_unset=exclude_unset,
145147
exclude_defaults=exclude_defaults,

‎src/openai/_models.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
PropertyInfo,
3939
is_list,
4040
is_given,
41+
json_safe,
4142
lru_cache,
4243
is_mapping,
4344
parse_date,
@@ -304,8 +305,8 @@ def model_dump(
304305
Returns:
305306
A dictionary representation of the model.
306307
"""
307-
if mode != "python":
308-
raise ValueError("mode is only supported in Pydantic v2")
308+
if mode not in {"json", "python"}:
309+
raise ValueError("mode must be either 'json' or 'python'")
309310
if round_trip != False:
310311
raise ValueError("round_trip is only supported in Pydantic v2")
311312
if warnings != True:
@@ -314,7 +315,7 @@ def model_dump(
314315
raise ValueError("context is only supported in Pydantic v2")
315316
if serialize_as_any != False:
316317
raise ValueError("serialize_as_any is only supported in Pydantic v2")
317-
return super().dict( # pyright: ignore[reportDeprecated]
318+
dumped = super().dict( # pyright: ignore[reportDeprecated]
318319
include=include,
319320
exclude=exclude,
320321
by_alias=by_alias,
@@ -323,6 +324,8 @@ def model_dump(
323324
exclude_none=exclude_none,
324325
)
325326

327+
return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped
328+
326329
@override
327330
def model_dump_json(
328331
self,

‎src/openai/_utils/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
is_list as is_list,
77
is_given as is_given,
88
is_tuple as is_tuple,
9+
json_safe as json_safe,
910
lru_cache as lru_cache,
1011
is_mapping as is_mapping,
1112
is_tuple_t as is_tuple_t,

‎src/openai/_utils/_transform.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def _transform_recursive(
191191
return data
192192

193193
if isinstance(data, pydantic.BaseModel):
194-
return model_dump(data, exclude_unset=True)
194+
return model_dump(data, exclude_unset=True, mode="json")
195195

196196
annotated_type = _get_annotated_type(annotation)
197197
if annotated_type is None:
@@ -329,7 +329,7 @@ async def _async_transform_recursive(
329329
return data
330330

331331
if isinstance(data, pydantic.BaseModel):
332-
return model_dump(data, exclude_unset=True)
332+
return model_dump(data, exclude_unset=True, mode="json")
333333

334334
annotated_type = _get_annotated_type(annotation)
335335
if annotated_type is None:

‎src/openai/_utils/_utils.py

+17
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
overload,
1717
)
1818
from pathlib import Path
19+
from datetime import date, datetime
1920
from typing_extensions import TypeGuard
2021

2122
import sniffio
@@ -395,3 +396,19 @@ def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
395396
maxsize=maxsize,
396397
)
397398
return cast(Any, wrapper) # type: ignore[no-any-return]
399+
400+
401+
def json_safe(data: object) -> object:
402+
"""Translates a mapping / sequence recursively in the same fashion
403+
as `pydantic` v2's `model_dump(mode="json")`.
404+
"""
405+
if is_mapping(data):
406+
return {json_safe(key): json_safe(value) for key, value in data.items()}
407+
408+
if is_iterable(data) and not isinstance(data, (str, bytes, bytearray)):
409+
return [json_safe(item) for item in data]
410+
411+
if isinstance(data, (datetime, date)):
412+
return data.isoformat()
413+
414+
return data

‎tests/test_models.py

+7-14
Original file line numberDiff line numberDiff line change
@@ -520,19 +520,15 @@ class Model(BaseModel):
520520
assert m3.to_dict(exclude_none=True) == {}
521521
assert m3.to_dict(exclude_defaults=True) == {}
522522

523-
if PYDANTIC_V2:
524-
525-
class Model2(BaseModel):
526-
created_at: datetime
523+
class Model2(BaseModel):
524+
created_at: datetime
527525

528-
time_str = "2024-03-21T11:39:01.275859"
529-
m4 = Model2.construct(created_at=time_str)
530-
assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
531-
assert m4.to_dict(mode="json") == {"created_at": time_str}
532-
else:
533-
with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
534-
m.to_dict(mode="json")
526+
time_str = "2024-03-21T11:39:01.275859"
527+
m4 = Model2.construct(created_at=time_str)
528+
assert m4.to_dict(mode="python") == {"created_at": datetime.fromisoformat(time_str)}
529+
assert m4.to_dict(mode="json") == {"created_at": time_str}
535530

531+
if not PYDANTIC_V2:
536532
with pytest.raises(ValueError, match="warnings is only supported in Pydantic v2"):
537533
m.to_dict(warnings=False)
538534

@@ -558,9 +554,6 @@ class Model(BaseModel):
558554
assert m3.model_dump(exclude_none=True) == {}
559555

560556
if not PYDANTIC_V2:
561-
with pytest.raises(ValueError, match="mode is only supported in Pydantic v2"):
562-
m.model_dump(mode="json")
563-
564557
with pytest.raises(ValueError, match="round_trip is only supported in Pydantic v2"):
565558
m.model_dump(round_trip=True)
566559

‎tests/test_transform.py

+15
Original file line numberDiff line numberDiff line change
@@ -177,17 +177,32 @@ class DateDict(TypedDict, total=False):
177177
foo: Annotated[date, PropertyInfo(format="iso8601")]
178178

179179

180+
class DatetimeModel(BaseModel):
181+
foo: datetime
182+
183+
184+
class DateModel(BaseModel):
185+
foo: Optional[date]
186+
187+
180188
@parametrize
181189
@pytest.mark.asyncio
182190
async def test_iso8601_format(use_async: bool) -> None:
183191
dt = datetime.fromisoformat("2023-02-23T14:16:36.337692+00:00")
192+
tz = "Z" if PYDANTIC_V2 else "+00:00"
184193
assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692+00:00"} # type: ignore[comparison-overlap]
194+
assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692" + tz} # type: ignore[comparison-overlap]
185195

186196
dt = dt.replace(tzinfo=None)
187197
assert await transform({"foo": dt}, DatetimeDict, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap]
198+
assert await transform(DatetimeModel(foo=dt), Any, use_async) == {"foo": "2023-02-23T14:16:36.337692"} # type: ignore[comparison-overlap]
188199

189200
assert await transform({"foo": None}, DateDict, use_async) == {"foo": None} # type: ignore[comparison-overlap]
201+
assert await transform(DateModel(foo=None), Any, use_async) == {"foo": None} # type: ignore
190202
assert await transform({"foo": date.fromisoformat("2023-02-23")}, DateDict, use_async) == {"foo": "2023-02-23"} # type: ignore[comparison-overlap]
203+
assert await transform(DateModel(foo=date.fromisoformat("2023-02-23")), DateDict, use_async) == {
204+
"foo": "2023-02-23"
205+
} # type: ignore[comparison-overlap]
191206

192207

193208
@parametrize

0 commit comments

Comments
 (0)
Please sign in to comment.