Skip to content

Commit

Permalink
core[patch]: RunnablePassthrough transform to autoupgrade to AddableD…
Browse files Browse the repository at this point in the history
  • Loading branch information
eyurtsev authored and Dave Bechberger committed Mar 29, 2024
1 parent e576ddf commit 2ef9f74
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 15 deletions.
17 changes: 5 additions & 12 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
accepts_config,
accepts_context,
accepts_run_manager,
adapt_first_streaming_chunk,
create_model,
gather_with_concurrency,
get_function_first_arg_dict_keys,
Expand Down Expand Up @@ -1207,7 +1208,7 @@ def transform(

for chunk in input:
if not got_first_val:
final = _adapt_first_streaming_chunk(chunk) # type: ignore
final = adapt_first_streaming_chunk(chunk) # type: ignore
got_first_val = True
else:
# Make a best effort to gather, for any type that supports `+`
Expand Down Expand Up @@ -1240,7 +1241,7 @@ async def atransform(

async for chunk in input:
if not got_first_val:
final = _adapt_first_streaming_chunk(chunk) # type: ignore
final = adapt_first_streaming_chunk(chunk) # type: ignore
got_first_val = True
else:
# Make a best effort to gather, for any type that supports `+`
Expand Down Expand Up @@ -3731,7 +3732,7 @@ def _transform(
final: Optional[Input] = None
for ichunk in input:
if final is None:
final = _adapt_first_streaming_chunk(ichunk) # type: ignore
final = adapt_first_streaming_chunk(ichunk) # type: ignore
else:
try:
final = final + ichunk # type: ignore[operator]
Expand Down Expand Up @@ -3815,7 +3816,7 @@ async def _atransform(
final: Optional[Input] = None
async for ichunk in input:
if final is None:
final = _adapt_first_streaming_chunk(ichunk)
final = adapt_first_streaming_chunk(ichunk)
else:
try:
final = final + ichunk # type: ignore[operator]
Expand Down Expand Up @@ -4727,11 +4728,3 @@ def my_func(fields):
yield chunk
"""
return RunnableLambda(func)


def _adapt_first_streaming_chunk(chunk: Any) -> Any:
"""This might transform the first chunk of a stream into an AddableDict."""
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
return AddableDict(chunk)
else:
return chunk
5 changes: 3 additions & 2 deletions libs/core/langchain_core/runnables/passthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from langchain_core.runnables.utils import (
AddableDict,
ConfigurableFieldSpec,
adapt_first_streaming_chunk,
create_model,
)
from langchain_core.utils.aiter import atee, py_anext
Expand Down Expand Up @@ -248,7 +249,7 @@ def transform(
for chunk in self._transform_stream_with_config(input, identity, config):
yield chunk
if final is None:
final = chunk
final = adapt_first_streaming_chunk(chunk)
else:
final = final + chunk

Expand Down Expand Up @@ -276,7 +277,7 @@ async def atransform(
):
yield chunk
if final is None:
final = chunk
final = adapt_first_streaming_chunk(chunk)
else:
final = final + chunk

Expand Down
8 changes: 8 additions & 0 deletions libs/core/langchain_core/runnables/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,3 +521,11 @@ def _create_model_cached(
return _create_model_base(
__model_name, __config__=_SchemaConfig, **field_definitions
)


def adapt_first_streaming_chunk(chunk: Any) -> Any:
"""This might transform the first chunk of a stream into an AddableDict."""
if isinstance(chunk, dict) and not isinstance(chunk, AddableDict):
return AddableDict(chunk)
else:
return chunk
21 changes: 20 additions & 1 deletion libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -5324,7 +5324,7 @@ def invoke(
assert list(runnable.transform(chunks)) == [{"foo": "an"}]


async def test_defualt_atransform_with_dicts() -> None:
async def test_default_atransform_with_dicts() -> None:
"""Test that default transform works with dicts."""

class CustomRunnable(RunnableSerializable[Input, Output]):
Expand All @@ -5342,3 +5342,22 @@ async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]

assert chunks == [{"foo": "an"}]


def test_passthrough_transform_with_dicts() -> None:
"""Test that default transform works with dicts."""
runnable = RunnablePassthrough(lambda x: x)
chunks = [chunk for chunk in runnable.transform(iter([{"foo": "a"}, {"foo": "n"}]))]
assert chunks == [{"foo": "a"}, {"foo": "n"}]


async def test_passthrough_atransform_with_dicts() -> None:
"""Test that default transform works with dicts."""
runnable = RunnablePassthrough(lambda x: x)

async def chunk_iterator() -> AsyncIterator[Dict[str, str]]:
yield {"foo": "a"}
yield {"foo": "n"}

chunks = [chunk async for chunk in runnable.atransform(chunk_iterator())]
assert chunks == [{"foo": "a"}, {"foo": "n"}]

0 comments on commit 2ef9f74

Please sign in to comment.