Skip to content

Commit

Permalink
core[minor]: Add Runnable.batch_as_completed (#17603)
Browse files Browse the repository at this point in the history
This PR adds `batch as completed` method to the standard Runnable
interface. It takes in a list of inputs and yields the corresponding
outputs as the inputs are completed.
  • Loading branch information
nfcampos authored and hinthornw committed Apr 26, 2024
1 parent 167e129 commit 17ded1b
Show file tree
Hide file tree
Showing 2 changed files with 302 additions and 0 deletions.
235 changes: 235 additions & 0 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,76 @@ def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
with get_executor_for_config(configs[0]) as executor:
return cast(List[Output], list(executor.map(invoke, inputs, configs)))

@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Output]]:
...

@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
...

def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
"""Run invoke in parallel on a list of inputs,
yielding results as they complete."""

if not inputs:
return

configs = get_config_list(config, len(inputs))

def invoke(
i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
except Exception as e:
out = e
else:
out = self.invoke(input, config, **kwargs)

return (i, out)

if len(inputs) == 1:
yield invoke(0, inputs[0], configs[0])
return

with get_executor_for_config(configs[0]) as executor:
futures = {
executor.submit(invoke, i, input, config)
for i, (input, config) in enumerate(zip(inputs, configs))
}

try:
while futures:
done, futures = wait(futures, return_when=FIRST_COMPLETED)
while done:
yield done.pop().result()
finally:
for future in futures:
future.cancel()

async def abatch(
self,
inputs: List[Input],
Expand Down Expand Up @@ -564,6 +634,64 @@ async def ainvoke(
coros = map(ainvoke, inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)

@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]:
...

@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
...

async def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
"""Run ainvoke in parallel on a list of inputs,
yielding results as they complete."""

if not inputs:
return

configs = get_config_list(config, len(inputs))

async def ainvoke(
i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = await self.ainvoke(
input, config, **kwargs
)
except Exception as e:
out = e
else:
out = await self.ainvoke(input, config, **kwargs)

return (i, out)

coros = map(ainvoke, range(len(inputs)), inputs, configs)

for coro in asyncio.as_completed(coros):
yield await coro

def stream(
self,
input: Input,
Expand Down Expand Up @@ -4149,6 +4277,113 @@ async def abatch(
**{**self.kwargs, **kwargs},
)

@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Output]]:
...

@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
...

def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [self._merge_configs(config) for _ in range(len(inputs))]
# lol mypy
if return_exceptions:
yield from self.bound.batch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)
else:
yield from self.bound.batch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)

@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]:
...

@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
...

async def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [self._merge_configs(config) for _ in range(len(inputs))]
if return_exceptions:
async for item in self.bound.abatch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
):
yield item
else:
async for item in self.bound.abatch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
):
yield item

def stream(
self,
input: Input,
Expand Down
67 changes: 67 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,30 @@ async def test_with_config(mocker: MockerFixture) -> None:

spy.reset_mock()

assert sorted(
c
for c in fake.with_config(recursion_limit=5).batch_as_completed(
["hello", "wooorld"],
[dict(tags=["a-tag"]), dict(metadata={"key": "value"})],
)
) == [(0, 5), (1, 7)]

assert len(spy.call_args_list) == 2
for i, call in enumerate(
sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)
):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
if i == 0:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
else:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"}

spy.reset_mock()

assert fake.with_config(metadata={"a": "b"}).batch(
["hello", "wooorld"], dict(tags=["a-tag"])
) == [5, 7]
Expand All @@ -1438,6 +1462,15 @@ async def test_with_config(mocker: MockerFixture) -> None:
assert call.args[1].get("metadata") == {"a": "b"}
spy.reset_mock()

assert sorted(
c for c in fake.batch_as_completed(["hello", "wooorld"], dict(tags=["a-tag"]))
) == [(0, 5), (1, 7)]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
assert call.args[1].get("tags") == ["a-tag"]
spy.reset_mock()

handler = ConsoleCallbackHandler()
assert (
await fake.with_config(metadata={"a": "b"}).ainvoke(
Expand Down Expand Up @@ -1484,6 +1517,40 @@ async def test_with_config(mocker: MockerFixture) -> None:
),
),
]
spy.reset_mock()

assert sorted(
[
c
async for c in fake.with_config(
recursion_limit=5, tags=["c"]
).abatch_as_completed(["hello", "wooorld"], dict(metadata={"key": "value"}))
]
) == [
(0, 5),
(1, 7),
]
assert len(spy.call_args_list) == 2
first_call = next(call for call in spy.call_args_list if call.args[0] == "hello")
assert first_call == mocker.call(
"hello",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
recursion_limit=5,
),
)
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
assert second_call == mocker.call(
"wooorld",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
recursion_limit=5,
),
)


async def test_default_method_implementations(mocker: MockerFixture) -> None:
Expand Down

0 comments on commit 17ded1b

Please sign in to comment.