|
13 | 13 |
|
14 | 14 | import openai
|
15 | 15 | from openai import OpenAI, AsyncOpenAI
|
16 |
| -from openai._utils import assert_signatures_in_sync |
| 16 | +from openai._utils import consume_sync_iterator, assert_signatures_in_sync |
17 | 17 | from openai._compat import model_copy
|
| 18 | +from openai.types.chat import ChatCompletionChunk |
18 | 19 | from openai.lib.streaming.chat import (
|
19 | 20 | ContentDoneEvent,
|
20 | 21 | ChatCompletionStream,
|
21 | 22 | ChatCompletionStreamEvent,
|
| 23 | + ChatCompletionStreamState, |
22 | 24 | ChatCompletionStreamManager,
|
23 | 25 | ParsedChatCompletionSnapshot,
|
24 | 26 | )
|
@@ -997,6 +999,55 @@ def test_allows_non_strict_tools_but_no_parsing(
|
997 | 999 | )
|
998 | 1000 |
|
999 | 1001 |
|
@pytest.mark.respx(base_url=base_url)
def test_chat_completion_state_helper(client: OpenAI, respx_mock: MockRouter, monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify that ChatCompletionStreamState, fed raw chunks one at a time,
    reconstructs the full parsed completion for the snapshotted stream."""
    stream_state = ChatCompletionStreamState()

    def accumulate_chunks(api_client: OpenAI) -> Iterator[ChatCompletionChunk]:
        # Pass every raw chunk through the state helper before yielding it on,
        # mirroring how a caller would accumulate a manually-consumed stream.
        raw_stream = api_client.chat.completions.create(
            model="gpt-4o-2024-08-06",
            messages=[
                {
                    "role": "user",
                    "content": "What's the weather like in SF?",
                },
            ],
            stream=True,
        )
        for raw_chunk in raw_stream:
            stream_state.handle_chunk(raw_chunk)
            yield raw_chunk

    _make_raw_stream_snapshot_request(
        accumulate_chunks,
        content_snapshot=snapshot(external("e2aad469b71d*.bin")),
        mock_client=client,
        respx_mock=respx_mock,
    )

    # After the stream is exhausted, the state helper should expose the same
    # final parsed completion a ChatCompletionStream would have produced.
    assert print_obj(stream_state.get_final_completion().choices, monkeypatch) == snapshot(
        """\
[
    ParsedChoice[NoneType](
        finish_reason='stop',
        index=0,
        logprobs=None,
        message=ParsedChatCompletionMessage[NoneType](
            audio=None,
            content="I'm unable to provide real-time weather updates. To get the current weather in San Francisco, I
recommend checking a reliable weather website or a weather app.",
            function_call=None,
            parsed=None,
            refusal=None,
            role='assistant',
            tool_calls=[]
        )
    )
]
"""
    )
| 1049 | + |
| 1050 | + |
1000 | 1051 | @pytest.mark.parametrize("sync", [True, False], ids=["sync", "async"])
|
1001 | 1052 | def test_stream_method_in_sync(sync: bool, client: OpenAI, async_client: AsyncOpenAI) -> None:
|
1002 | 1053 | checking_client: OpenAI | AsyncOpenAI = client if sync else async_client
|
@@ -1075,3 +1126,44 @@ def _on_response(response: httpx.Response) -> None:
|
1075 | 1126 | client.close()
|
1076 | 1127 |
|
1077 | 1128 | return listener
|
| 1129 | + |
| 1130 | + |
def _make_raw_stream_snapshot_request(
    func: Callable[[OpenAI], Iterator[ChatCompletionChunk]],
    *,
    content_snapshot: Any,
    respx_mock: MockRouter,
    mock_client: OpenAI,
) -> None:
    """Exhaust the raw chunk iterator produced by ``func`` against either the
    live API or a mocked endpoint.

    When ``OPENAI_LIVE=1`` is set, request interception is disabled and a real
    client is constructed whose response hook checks the bytes on the wire
    against *content_snapshot*. Otherwise the previously recorded snapshot
    bytes are replayed through *mock_client* as a server-sent-event response.
    """
    is_live = os.environ.get("OPENAI_LIVE") == "1"

    if not is_live:
        # Replay the recorded SSE bytes for the mocked chat-completions route.
        respx_mock.post("/chat/completions").mock(
            return_value=httpx.Response(
                200,
                content=content_snapshot._old_value._load_value(),
                headers={"content-type": "text/event-stream"},
            )
        )
        client = mock_client
    else:

        def _check_snapshot(response: httpx.Response) -> None:
            # Compare (and, when recording, update) the external snapshot
            # against the exact response body received from the API.
            assert outsource(response.read()) == content_snapshot

        # Stop intercepting so requests reach the real API.
        respx_mock.stop()
        client = OpenAI(
            http_client=httpx.Client(event_hooks={"response": [_check_snapshot]})
        )

    # Fully drain the stream so every chunk is produced and handled.
    consume_sync_iterator(func(client))

    if is_live:
        # Only close the client we created ourselves; *mock_client* is owned
        # by the caller's fixture.
        client.close()
0 commit comments