Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[patch]: RunnablePassthrough transform to autoupgrade to AddableDict #19051

Merged
merged 2 commits into from
Mar 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
17 changes: 5 additions & 12 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
accepts_config,
accepts_context,
accepts_run_manager,
adapt_first_streaming_chunk,
create_model,
gather_with_concurrency,
get_function_first_arg_dict_keys,
Expand Down Expand Up @@ -1178,7 +1179,7 @@ def transform(

for chunk in input:
if not got_first_val:
final = _adapt_first_streaming_chunk(chunk) # type: ignore
final = adapt_first_streaming_chunk(chunk) # type: ignore
got_first_val = True
else:
# Make a best effort to gather, for any type that supports `+`
Expand Down Expand Up @@ -1211,7 +1212,7 @@ async def atransform(

async for chunk in input:
if not got_first_val:
final = _adapt_first_streaming_chunk(chunk) # type: ignore
final = adapt_first_streaming_chunk(chunk) # type: ignore
got_first_val = True
else:
# Make a best effort to gather, for any type that supports `+`
Expand Down Expand Up @@ -3702,7 +3703,7 @@ def _transform(
final: Optional[Input] = None
for ichunk in input:
if final is None:
final = _adapt_first_streaming_chunk(ichunk) # type: ignore
final = adapt_first_streaming_chunk(ichunk) # type: ignore
else:
try:
final = final + ichunk # type: ignore[operator]
Expand Down Expand Up @@ -3786,7 +3787,7 @@ async def _atransform(
final: Optional[Input] = None
async for ichunk in input:
if final is None:
final = _adapt_first_streaming_chunk(ichunk)
final = adapt_first_streaming_chunk(ichunk)
else:
try:
final = final + ichunk # type: ignore[operator]
Expand Down Expand Up @@ -4698,11 +4699,3 @@ def my_func(fields):
yield chunk
"""
return RunnableLambda(func)


def _adapt_first_streaming_chunk(chunk: Any) -> Any:
"""This might transform the first chunk of a stream into an AddableDict."""
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
return AddableDict(chunk)
else:
return chunk
5 changes: 3 additions & 2 deletions libs/core/langchain_core/runnables/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from langchain_core.runnables.utils import (
AddableDict,
ConfigurableFieldSpec,
adapt_first_streaming_chunk,
create_model,
)
from langchain_core.utils.aiter import atee, py_anext
Expand Down Expand Up @@ -248,7 +249,7 @@ def transform(
for chunk in self._transform_stream_with_config(input, identity, config):
yield chunk
if final is None:
final = chunk
final = adapt_first_streaming_chunk(chunk)
else:
final = final + chunk

Expand Down Expand Up @@ -276,7 +277,7 @@ async def atransform(
):
yield chunk
if final is None:
final = chunk
final = adapt_first_streaming_chunk(chunk)
else:
final = final + chunk

Expand Down
8 changes: 8 additions & 0 deletions libs/core/langchain_core/runnables/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,11 @@ def _create_model_cached(
return _create_model_base(
__model_name, __config__=_SchemaConfig, **field_definitions
)


def adapt_first_streaming_chunk(chunk: Any) -> Any:
"""This might transform the first chunk of a stream into an AddableDict."""
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
return AddableDict(chunk)
else:
return chunk
21 changes: 20 additions & 1 deletion libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5324,7 +5324,7 @@ def invoke(
assert list(runnable.transform(chunks)) == [{"foo": "an"}]


async def test_defualt_atransform_with_dicts() -> None:
async def test_default_atransform_with_dicts() -> None:
"""Test that default transform works with dicts."""

class CustomRunnable(RunnableSerializable[Input, Output]):
Expand All @@ -5342,3 +5342,22 @@ async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]

assert chunks == [{"foo": "an"}]


def test_passthrough_transform_with_dicts() -> None:
"""Test that default transform works with dicts."""
runnable = RunnablePassthrough(lambda x: x)
chunks = [chunk for chunk in runnable.transform(iter([{"foo": "a"}, {"foo": "n"}]))]
assert chunks == [{"foo": "a"}, {"foo": "n"}]


async def test_passthrough_atransform_with_dicts() -> None:
"""Test that default transform works with dicts."""
runnable = RunnablePassthrough(lambda x: x)

async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
yield {"foo": "a"}
yield {"foo": "n"}

chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
assert chunks == [{"foo": "a"}, {"foo": "n"}]