core: Assign missing message ids in BaseChatModel (langchain-ai#19863)
- This ensures ids are stable across streamed chunks
- Multiple messages in batch call get separate ids
- Also fix ids being dropped when combining message chunks
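
A minimal sketch of the intended behaviour (illustrative, not part of the diff; it assumes `GenericFakeChatModel` from `langchain_core.language_models.fake_chat_models`, the module patched below, and a model that does not set ids itself):

```python
from itertools import cycle

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

model = GenericFakeChatModel(messages=cycle([AIMessage(content="hello world")]))

# Every chunk of a single streamed run should now carry the same "run-<run_id>" id.
chunks = list(model.stream("hi"))
assert chunks[0].id is not None
assert len({chunk.id for chunk in chunks}) == 1
```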

Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
  - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"


- [ ] **PR message**: ***Delete this entire checklist*** and replace with
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
    - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out!


- [ ] **Add tests and docs**: If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access, and
  2. an example notebook showing its use, placed in the `docs/docs/integrations` directory.


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17.
nfcampos authored and marlenezw committed Apr 2, 2024
1 parent 1f33504 commit 389eff3
Showing 24 changed files with 8,143 additions and 56,694 deletions.
16 changes: 14 additions & 2 deletions libs/core/langchain_core/language_models/chat_models.py
@@ -224,6 +224,8 @@ def stream(
run_manager.on_llm_new_token(
cast(str, chunk.message.content), chunk=chunk
)
if chunk.message.id is None:
chunk.message.id = f"run-{run_manager.run_id}"
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
yield chunk.message
if generation is None:
@@ -294,6 +296,8 @@ async def astream(
await run_manager.on_llm_new_token(
cast(str, chunk.message.content), chunk=chunk
)
if chunk.message.id is None:
chunk.message.id = f"run-{run_manager.run_id}"
chunk.message.response_metadata = _gen_info_and_msg_metadata(chunk)
yield chunk.message
if generation is None:
@@ -607,6 +611,8 @@ def _generate_with_cache(
chunks: List[ChatGenerationChunk] = []
for chunk in self._stream(messages, stop=stop, **kwargs):
if run_manager:
if chunk.message.id is None:
chunk.message.id = f"run-{run_manager.run_id}"
run_manager.on_llm_new_token(
cast(str, chunk.message.content), chunk=chunk
)
@@ -622,7 +628,9 @@ def _generate_with_cache(
result = self._generate(messages, stop=stop, **kwargs)

# Add response metadata to each generation
for generation in result.generations:
for idx, generation in enumerate(result.generations):
if run_manager and generation.message.id is None:
generation.message.id = f"run-{run_manager.run_id}-{idx}"
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
@@ -684,6 +692,8 @@ async def _agenerate_with_cache(
chunks: List[ChatGenerationChunk] = []
async for chunk in self._astream(messages, stop=stop, **kwargs):
if run_manager:
if chunk.message.id is None:
chunk.message.id = f"run-{run_manager.run_id}"
await run_manager.on_llm_new_token(
cast(str, chunk.message.content), chunk=chunk
)
@@ -699,7 +709,9 @@ async def _agenerate_with_cache(
result = await self._agenerate(messages, stop=stop, **kwargs)

# Add response metadata to each generation
for generation in result.generations:
for idx, generation in enumerate(result.generations):
if run_manager and generation.message.id is None:
generation.message.id = f"run-{run_manager.run_id}-{idx}"
generation.message.response_metadata = _gen_info_and_msg_metadata(
generation
)
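
For reference, a tiny standalone sketch of the fallback id formats introduced above; the values are illustrative and the real `run_id` comes from the callback run manager:

```python
from uuid import uuid4

run_id = uuid4()  # stands in for run_manager.run_id

chunk_id = f"run-{run_id}"                                    # shared by every chunk of one streamed run
generation_ids = [f"run-{run_id}-{idx}" for idx in range(2)]  # one id per generation in a batch result
```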
8 changes: 6 additions & 2 deletions libs/core/langchain_core/language_models/fake_chat_models.py
@@ -223,7 +223,9 @@ def _stream(
content_chunks = cast(List[str], re.split(r"(\s)", content))

for token in content_chunks:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, id=message.id)
)
if run_manager:
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
@@ -240,6 +242,7 @@ def _stream(
for fvalue_chunk in fvalue_chunks:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
id=message.id,
content="",
additional_kwargs={
"function_call": {fkey: fvalue_chunk}
@@ -255,6 +258,7 @@ def _stream(
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
id=message.id,
content="",
additional_kwargs={"function_call": {fkey: fvalue}},
)
@@ -268,7 +272,7 @@ def _stream(
else:
chunk = ChatGenerationChunk(
message=AIMessageChunk(
content="", additional_kwargs={key: value}
id=message.id, content="", additional_kwargs={key: value}
)
)
if run_manager:
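
A quick sketch of the effect of the change above (same import-path assumption as the earlier sketch): a source message that already carries an id now propagates it to every streamed chunk instead of dropping it.

```python
from itertools import cycle

from langchain_core.language_models.fake_chat_models import GenericFakeChatModel
from langchain_core.messages import AIMessage

model = GenericFakeChatModel(messages=cycle([AIMessage(content="a b", id="msg-1")]))

# The preset id survives chunking; BaseChatModel only fills in ids that are missing.
assert all(chunk.id == "msg-1" for chunk in model.stream("hi"))
```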
6 changes: 6 additions & 0 deletions libs/core/langchain_core/load/mapping.py
@@ -971,4 +971,10 @@
"tool",
"ToolMessageChunk",
),
("langchain_core", "prompts", "image", "ImagePromptTemplate"): (
"langchain_core",
"prompts",
"image",
"ImagePromptTemplate",
),
}
1 change: 1 addition & 0 deletions libs/core/langchain_core/messages/ai.py
@@ -56,6 +56,7 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)

return super().__add__(other)
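
In isolation, the `__add__` change above behaves like this (illustrative values; the chat, function, and tool chunk classes below get the same fix):

```python
from langchain_core.messages import AIMessageChunk

left = AIMessageChunk(content="Hello, ", id="run-123")
right = AIMessageChunk(content="world")  # continuation chunk without an id

combined = left + right
assert combined.content == "Hello, world"
assert combined.id == "run-123"  # previously the id was dropped when chunks were merged
```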
2 changes: 2 additions & 0 deletions libs/core/langchain_core/messages/base.py
@@ -34,6 +34,8 @@ class BaseMessage(Serializable):
name: Optional[str] = None

id: Optional[str] = None
"""An optional unique identifier for the message. This should ideally be
provided by the provider/model which created the message."""

class Config:
extra = Extra.allow
2 changes: 2 additions & 0 deletions libs/core/langchain_core/messages/chat.py
@@ -54,6 +54,7 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)
elif isinstance(other, BaseMessageChunk):
return self.__class__(
@@ -65,6 +66,7 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)
else:
return super().__add__(other)
1 change: 1 addition & 0 deletions libs/core/langchain_core/messages/function.py
@@ -54,6 +54,7 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)

return super().__add__(other)
1 change: 1 addition & 0 deletions libs/core/langchain_core/messages/tool.py
@@ -54,6 +54,7 @@ def __add__(self, other: Any) -> BaseMessageChunk: # type: ignore
response_metadata=merge_dicts(
self.response_metadata, other.response_metadata
),
id=self.id,
)

return super().__add__(other)
26 changes: 19 additions & 7 deletions libs/core/langchain_core/runnables/graph.py
@@ -116,7 +116,9 @@ def node_data_str(node: Node) -> str:
return data if not data.startswith("Runnable") else data[8:]


def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
def node_data_json(
node: Node, *, with_schemas: bool = False
) -> Dict[str, Union[str, Dict[str, Any]]]:
from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable

@@ -137,10 +139,17 @@ def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
},
}
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
return {
"type": "schema",
"data": node.data.schema(),
}
return (
{
"type": "schema",
"data": node.data.schema(),
}
if with_schemas
else {
"type": "schema",
"data": node_data_str(node),
}
)
else:
return {
"type": "unknown",
@@ -156,7 +165,7 @@ class Graph:
edges: List[Edge] = field(default_factory=list)
branches: Optional[Dict[str, List[Branch]]] = field(default_factory=dict)

def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
def to_json(self, *, with_schemas: bool = False) -> Dict[str, List[Dict[str, Any]]]:
"""Convert the graph to a JSON-serializable format."""
stable_node_ids = {
node.id: i if is_uuid(node.id) else node.id
@@ -165,7 +174,10 @@ def to_json(self) -> Dict[str, List[Dict[str, Any]]]:

return {
"nodes": [
{"id": stable_node_ids[node.id], **node_data_json(node)}
{
"id": stable_node_ids[node.id],
**node_data_json(node, with_schemas=with_schemas),
}
for node in self.nodes.values()
],
"edges": [
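
A hedged usage sketch of the new keyword, assuming `Runnable.get_graph()` as it existed in langchain-core at the time: schema nodes are collapsed to their name by default, and the full JSON schema is only emitted when `with_schemas=True` is passed.

```python
from langchain_core.runnables import RunnableLambda

graph = RunnableLambda(lambda x: x + 1).get_graph()

compact = graph.to_json()                   # schema nodes: {"type": "schema", "data": "<node name>"}
verbose = graph.to_json(with_schemas=True)  # schema nodes: {"type": "schema", "data": {...full JSON schema...}}
```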
70 changes: 43 additions & 27 deletions libs/core/tests/unit_tests/fake/test_fake_chat_model.py
@@ -8,30 +8,31 @@
from langchain_core.messages import AIMessage, AIMessageChunk, BaseMessage
from langchain_core.messages.human import HumanMessage
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from tests.unit_tests.stubs import AnyStr


def test_generic_fake_chat_model_invoke() -> None:
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = model.invoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = model.invoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())


async def test_generic_fake_chat_model_ainvoke() -> None:
# Will alternate between responding with hello and goodbye
infinite_cycle = cycle([AIMessage(content="hello"), AIMessage(content="goodbye")])
model = GenericFakeChatModel(messages=infinite_cycle)
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())
response = await model.ainvoke("kitty")
assert response == AIMessage(content="goodbye")
assert response == AIMessage(content="goodbye", id=AnyStr())
response = await model.ainvoke("meow")
assert response == AIMessage(content="hello")
assert response == AIMessage(content="hello", id=AnyStr())


async def test_generic_fake_chat_model_stream() -> None:
@@ -44,27 +45,30 @@ async def test_generic_fake_chat_model_stream() -> None:
model = GenericFakeChatModel(messages=infinite_cycle)
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1

chunks = [chunk for chunk in model.stream("meow")]
assert chunks == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1

# Test streaming of additional kwargs.
# Relying on insertion order of the additional kwargs dict
message = AIMessage(content="", additional_kwargs={"foo": 42, "bar": 24})
model = GenericFakeChatModel(messages=cycle([message]))
chunks = [chunk async for chunk in model.astream("meow")]
assert chunks == [
AIMessageChunk(content="", additional_kwargs={"foo": 42}),
AIMessageChunk(content="", additional_kwargs={"bar": 24}),
AIMessageChunk(content="", additional_kwargs={"foo": 42}, id=AnyStr()),
AIMessageChunk(content="", additional_kwargs={"bar": 24}, id=AnyStr()),
]
assert len({chunk.id for chunk in chunks}) == 1

message = AIMessage(
content="",
@@ -81,24 +85,31 @@ async def test_generic_fake_chat_model_stream() -> None:

assert chunks == [
AIMessageChunk(
content="", additional_kwargs={"function_call": {"name": "move_file"}}
content="",
additional_kwargs={"function_call": {"name": "move_file"}},
id=AnyStr(),
),
AIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '{\n "source_path": "foo"'}
"function_call": {"arguments": '{\n "source_path": "foo"'},
},
id=AnyStr(),
),
AIMessageChunk(
content="", additional_kwargs={"function_call": {"arguments": ","}}
content="",
additional_kwargs={"function_call": {"arguments": ","}},
id=AnyStr(),
),
AIMessageChunk(
content="",
additional_kwargs={
"function_call": {"arguments": '\n "destination_path": "bar"\n}'}
"function_call": {"arguments": '\n "destination_path": "bar"\n}'},
},
id=AnyStr(),
),
]
assert len({chunk.id for chunk in chunks}) == 1

accumulate_chunks = None
for chunk in chunks:
@@ -116,6 +127,7 @@ async def test_generic_fake_chat_model_stream() -> None:
'destination_path": "bar"\n}',
}
},
id=chunks[0].id,
)


@@ -128,10 +140,11 @@ async def test_generic_fake_chat_model_astream_log() -> None:
]
final = log_patches[-1]
assert final.state["streamed_output"] == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert len({chunk.id for chunk in final.state["streamed_output"]}) == 1


async def test_callback_handlers() -> None:
@@ -178,16 +191,19 @@ async def on_llm_new_token(
# New model
results = list(model.stream("meow", {"callbacks": [MyCustomAsyncHandler(tokens)]}))
assert results == [
AIMessageChunk(content="hello"),
AIMessageChunk(content=" "),
AIMessageChunk(content="goodbye"),
AIMessageChunk(content="hello", id=AnyStr()),
AIMessageChunk(content=" ", id=AnyStr()),
AIMessageChunk(content="goodbye", id=AnyStr()),
]
assert tokens == ["hello", " ", "goodbye"]
assert len({chunk.id for chunk in results}) == 1


def test_chat_model_inputs() -> None:
fake = ParrotFakeChatModel()

assert fake.invoke("hello") == HumanMessage(content="hello")
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah")
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(content="blah")
assert fake.invoke("hello") == HumanMessage(content="hello", id=AnyStr())
assert fake.invoke([("ai", "blah")]) == AIMessage(content="blah", id=AnyStr())
assert fake.invoke([AIMessage(content="blah")]) == AIMessage(
content="blah", id=AnyStr()
)
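
The updated assertions compare against `AnyStr()` because message ids are now generated at run time. The stub lives in `tests/unit_tests/stubs.py` and is not shown in this diff; a plausible implementation (an assumption, not the file's actual contents) is a `str` subclass that compares equal to any string:

```python
class AnyStr(str):
    """Placeholder that matches any string in equality assertions."""

    def __eq__(self, other: object) -> bool:
        return isinstance(other, str)

    __hash__ = str.__hash__
```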
