
Commit dc7f6cb

feat(client): make ChatCompletionStreamState public (#1898)

Authored Nov 29, 2024 · 1 parent 6974a98
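This change promotes `ChatCompletionStreamState` to the public `openai.lib.streaming.chat` namespace, so that chunks from a raw `stream=True` request can be manually accumulated into a final `ChatCompletion` in cases where the `.stream()` helper can't be used. A minimal end-to-end sketch based on the docstring added in this commit (the model name and prompt are illustrative placeholders, not part of the commit):

```py
from openai import OpenAI
from openai.lib.streaming.chat import ChatCompletionStreamState

client = OpenAI()
state = ChatCompletionStreamState()

# A raw streaming request; model and prompt are placeholder values.
stream = client.chat.completions.create(
    model="gpt-4o-2024-08-06",
    messages=[{"role": "user", "content": "What's the weather like in SF?"}],
    stream=True,
)
for chunk in stream:
    state.handle_chunk(chunk)  # fold each chunk into the accumulated snapshot

# The partially accumulated completion is also available mid-stream via
# `state.current_completion_snapshot`.
print(state.get_final_completion())
```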

File tree: 3 files changed (+123, −5 lines)


src/openai/lib/streaming/chat/__init__.py (+1)

@@ -21,6 +21,7 @@
 from ._completions import (
     ChatCompletionStream as ChatCompletionStream,
     AsyncChatCompletionStream as AsyncChatCompletionStream,
+    ChatCompletionStreamState as ChatCompletionStreamState,
     ChatCompletionStreamManager as ChatCompletionStreamManager,
     AsyncChatCompletionStreamManager as AsyncChatCompletionStreamManager,
 )

src/openai/lib/streaming/chat/_completions.py (+29, −4)

@@ -287,11 +287,31 @@ async def __aexit__(
 
 
 class ChatCompletionStreamState(Generic[ResponseFormatT]):
+    """Helper class for manually accumulating `ChatCompletionChunk`s into a final `ChatCompletion` object.
+
+    This is useful in cases where you can't always use the `.stream()` method, e.g.
+
+    ```py
+    from openai.lib.streaming.chat import ChatCompletionStreamState
+
+    state = ChatCompletionStreamState()
+
+    stream = client.chat.completions.create(..., stream=True)
+    for chunk in stream:
+        state.handle_chunk(chunk)
+
+    # can also access the accumulated `ChatCompletion` mid-stream
+    state.current_completion_snapshot
+
+    print(state.get_final_completion())
+    ```
+    """
+
     def __init__(
         self,
         *,
-        input_tools: Iterable[ChatCompletionToolParam] | NotGiven,
-        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven,
+        input_tools: Iterable[ChatCompletionToolParam] | NotGiven = NOT_GIVEN,
+        response_format: type[ResponseFormatT] | ResponseFormatParam | NotGiven = NOT_GIVEN,
     ) -> None:
         self.__current_completion_snapshot: ParsedChatCompletionSnapshot | None = None
         self.__choice_event_states: list[ChoiceEventState] = []
@@ -301,6 +321,11 @@ def __init__(
         self._rich_response_format: type | NotGiven = response_format if inspect.isclass(response_format) else NOT_GIVEN
 
     def get_final_completion(self) -> ParsedChatCompletion[ResponseFormatT]:
+        """Parse the final completion object.
+
+        Note this does not provide any guarantees that the stream has actually finished; you must
+        only call this method when the stream is finished.
+        """
         return parse_chat_completion(
             chat_completion=self.current_completion_snapshot,
             response_format=self._rich_response_format,
@@ -312,8 +337,8 @@ def current_completion_snapshot(self) -> ParsedChatCompletionSnapshot:
         assert self.__current_completion_snapshot is not None
         return self.__current_completion_snapshot
 
-    def handle_chunk(self, chunk: ChatCompletionChunk) -> list[ChatCompletionStreamEvent[ResponseFormatT]]:
-        """Accumulate a new chunk into the snapshot and returns a list of events to yield."""
+    def handle_chunk(self, chunk: ChatCompletionChunk) -> Iterable[ChatCompletionStreamEvent[ResponseFormatT]]:
+        """Accumulate a new chunk into the snapshot and return an iterable of events to yield."""
         self.__current_completion_snapshot = self._accumulate_chunk(chunk)
 
         return self._build_events(

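Note that both constructor arguments now default to `NOT_GIVEN`, so the helper can be constructed with no arguments, as the new docstring shows, or configured for structured output. A hedged sketch of the latter, assuming pydantic v2; the `Weather` model, prompt, and `json_schema` request body are illustrative assumptions, not part of this commit:

```py
import pydantic

from openai import OpenAI
from openai.lib.streaming.chat import ChatCompletionStreamState


class Weather(pydantic.BaseModel):
    city: str
    summary: str


client = OpenAI()
# Passing a pydantic class as `response_format` should make `get_final_completion()`
# parse the accumulated message content into that model.
state = ChatCompletionStreamState(response_format=Weather)

stream = client.chat.completions.create(
    model="gpt-4o-2024-08-06",  # placeholder model
    messages=[{"role": "user", "content": "Report the weather in SF as JSON."}],
    response_format={
        "type": "json_schema",
        "json_schema": {"name": "weather", "schema": Weather.model_json_schema()},
    },
    stream=True,
)
for chunk in stream:
    state.handle_chunk(chunk)

completion = state.get_final_completion()
print(completion.choices[0].message.parsed)  # a `Weather` instance if parsing succeeded
```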
tests/lib/chat/test_completions_streaming.py (+93, −1)

@@ -13,12 +13,14 @@
 
 import openai
 from openai import OpenAI, AsyncOpenAI
-from openai._utils import assert_signatures_in_sync
+from openai._utils import consume_sync_iterator, assert_signatures_in_sync
 from openai._compat import model_copy
+from openai.types.chat import ChatCompletionChunk
 from openai.lib.streaming.chat import (
     ContentDoneEvent,
     ChatCompletionStream,
     ChatCompletionStreamEvent,
+    ChatCompletionStreamState,
     ChatCompletionStreamManager,
     ParsedChatCompletionSnapshot,
 )
@@ -997,6 +999,55 @@ def test_allows_non_strict_tools_but_no_parsing(
     )
 
 
+@pytest.mark.respx(base_url=base_url)
+def test_chat_completion_state_helper(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
+    state = ChatCompletionStreamState()
+
+    def streamer(client: OpenAI) -> Iterator[ChatCompletionChunk]:
+        stream = client.chat.completions.create(
+            model="gpt-4o-2024-08-06",
+            messages=[
+                {
+                    "role": "user",
+                    "content": "What's the weather like in SF?",
+                },
+            ],
+            stream=True,
+        )
+        for chunk in stream:
+            state.handle_chunk(chunk)
+            yield chunk
+
+    _make_raw_stream_snapshot_request(
+        streamer,
+        content_snapshot=snapshot(external("e2aad469b71d*.bin")),
+        mock_client=client,
+        respx_mock=respx_mock,
+    )
+
+    assert print_obj(state.get_final_completion().choices, monkeypatch) == snapshot(
+        """\
+[
+    ParsedChoice[NoneType](
+        finish_reason='stop',
+        index=0,
+        logprobs=None,
+        message=ParsedChatCompletionMessage[NoneType](
+            audio=None,
+            content="I'm unable to provide real-time weather updates. To get the current weather in San Francisco, I
+recommend checking a reliable weather website or a weather app.",
+            function_call=None,
+            parsed=None,
+            refusal=None,
+            role='assistant',
+            tool_calls=[]
+        )
+    )
+]
+"""
+    )
+
+
 @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
 def test_stream_method_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
     checking_client: OpenAI | AsyncOpenAI = client if sync else async_client
@@ -1075,3 +1126,44 @@ def _on_response(response: httpx.Response) -> None:
         client.close()
 
     return listener
+
+
+def _make_raw_stream_snapshot_request(
+    func: Callable[[OpenAI], Iterator[ChatCompletionChunk]],
+    *,
+    content_snapshot: Any,
+    respx_mock: MockRouter,
+    mock_client: OpenAI,
+) -> None:
+    live = os.environ.get("OPENAI_LIVE") == "1"
+    if live:
+
+        def _on_response(response: httpx.Response) -> None:
+            # update the content snapshot
+            assert outsource(response.read()) == content_snapshot
+
+        respx_mock.stop()
+
+        client = OpenAI(
+            http_client=httpx.Client(
+                event_hooks={
+                    "response": [_on_response],
+                }
+            )
+        )
+    else:
+        respx_mock.post("/chat/completions").mock(
+            return_value=httpx.Response(
+                200,
+                content=content_snapshot._old_value._load_value(),
+                headers={"content-type": "text/event-stream"},
+            )
+        )
+
+        client = mock_client
+
+    stream = func(client)
+    consume_sync_iterator(stream)
+
+    if live:
+        client.close()
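The new test only exercises the synchronous client. Since `handle_chunk` does no I/O (it just folds each chunk into the in-memory snapshot), the same pattern should carry over to the async client; a sketch under that assumption, not covered by the tests in this commit:

```py
import asyncio

from openai import AsyncOpenAI
from openai.lib.streaming.chat import ChatCompletionStreamState


async def main() -> None:
    client = AsyncOpenAI()
    state = ChatCompletionStreamState()

    # Raw async streaming request; model and prompt are placeholders.
    stream = await client.chat.completions.create(
        model="gpt-4o-2024-08-06",
        messages=[{"role": "user", "content": "What's the weather like in SF?"}],
        stream=True,
    )
    async for chunk in stream:
        state.handle_chunk(chunk)

    print(state.get_final_completion())


asyncio.run(main())
```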
