Skip to content

Commit

Permalink
openai[patch]: Partially Revert Update openai chat model to new base …
Browse files Browse the repository at this point in the history
…class interface (#19871)

Partially Reverts #19729

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
  • Loading branch information
nfcampos and efriis committed Apr 1, 2024
1 parent be92cf5 commit aa5797d
Showing 1 changed file with 24 additions and 11 deletions.
35 changes: 24 additions & 11 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
Expand Down Expand Up @@ -474,6 +478,8 @@ def _stream(
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
if run_manager:
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
yield chunk

def _generate(
Expand All @@ -483,12 +489,13 @@ def _generate(
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._stream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return generate_from_stream(stream_iter)
message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": self.streaming} if self.streaming else {}),
**kwargs,
}
params = {**params, **kwargs}
response = self.client.create(messages=message_dicts, **params)
return self._create_chat_result(response)

Expand Down Expand Up @@ -569,6 +576,10 @@ async def _astream(
chunk = ChatGenerationChunk(
message=chunk, generation_info=generation_info or None
)
if run_manager:
await run_manager.on_llm_new_token(
token=chunk.text, chunk=chunk, logprobs=logprobs
)
yield chunk

async def _agenerate(
Expand All @@ -578,12 +589,14 @@ async def _agenerate(
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
stream_iter = self._astream(
messages, stop=stop, run_manager=run_manager, **kwargs
)
return await agenerate_from_stream(stream_iter)

message_dicts, params = self._create_message_dicts(messages, stop)
params = {
**params,
**({"stream": self.streaming} if self.streaming else {}),
**kwargs,
}
params = {**params, **kwargs}
response = await self.async_client.create(messages=message_dicts, **params)
return self._create_chat_result(response)

Expand Down

0 comments on commit aa5797d

Please sign in to comment.