Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core[minor]: Add Runnable.batch_as_completed #17603

Merged
merged 4 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
235 changes: 235 additions & 0 deletions libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,76 @@ def invoke(input: Input, config: RunnableConfig) -> Union[Output, Exception]:
with get_executor_for_config(configs[0]) as executor:
return cast(List[Output], list(executor.map(invoke, inputs, configs)))

@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Output]]:
...

@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
...

def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
"""Run invoke in parallel on a list of inputs,
yielding results as they complete."""

if not inputs:
return

configs = get_config_list(config, len(inputs))

def invoke(
i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = self.invoke(input, config, **kwargs)
except Exception as e:
out = e
else:
out = self.invoke(input, config, **kwargs)

return (i, out)

if len(inputs) == 1:
yield invoke(0, inputs[0], configs[0])
return

with get_executor_for_config(configs[0]) as executor:
futures = {
executor.submit(invoke, i, input, config)
for i, (input, config) in enumerate(zip(inputs, configs))
}

try:
while futures:
done, futures = wait(futures, return_when=FIRST_COMPLETED)
while done:
yield done.pop().result()
finally:
for future in futures:
future.cancel()

async def abatch(
self,
inputs: List[Input],
Expand Down Expand Up @@ -564,6 +634,64 @@ async def ainvoke(
coros = map(ainvoke, inputs, configs)
return await gather_with_concurrency(configs[0].get("max_concurrency"), *coros)

@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]:
...

@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
...

async def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
"""Run ainvoke in parallel on a list of inputs,
yielding results as they complete."""

if not inputs:
return

configs = get_config_list(config, len(inputs))

async def ainvoke(
i: int, input: Input, config: RunnableConfig
) -> Tuple[int, Union[Output, Exception]]:
if return_exceptions:
try:
out: Union[Output, Exception] = await self.ainvoke(
input, config, **kwargs
)
except Exception as e:
out = e
else:
out = await self.ainvoke(input, config, **kwargs)

return (i, out)

coros = map(ainvoke, range(len(inputs)), inputs, configs)

for coro in asyncio.as_completed(coros):
yield await coro

def stream(
self,
input: Input,
Expand Down Expand Up @@ -4149,6 +4277,113 @@ async def abatch(
**{**self.kwargs, **kwargs},
)

@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Any,
) -> Iterator[Tuple[int, Output]]:
...

@overload
def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Any,
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
...

def batch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> Iterator[Tuple[int, Union[Output, Exception]]]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [self._merge_configs(config) for _ in range(len(inputs))]
# lol mypy
if return_exceptions:
yield from self.bound.batch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)
else:
yield from self.bound.batch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
)

@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[False] = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Output]]:
...

@overload
def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: Literal[True],
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
...

async def abatch_as_completed(
self,
inputs: List[Input],
config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
*,
return_exceptions: bool = False,
**kwargs: Optional[Any],
) -> AsyncIterator[Tuple[int, Union[Output, Exception]]]:
if isinstance(config, list):
configs = cast(
List[RunnableConfig],
[self._merge_configs(conf) for conf in config],
)
else:
configs = [self._merge_configs(config) for _ in range(len(inputs))]
if return_exceptions:
async for item in self.bound.abatch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
):
yield item
else:
async for item in self.bound.abatch_as_completed(
inputs,
configs,
return_exceptions=return_exceptions,
**{**self.kwargs, **kwargs},
):
yield item

def stream(
self,
input: Input,
Expand Down
67 changes: 67 additions & 0 deletions libs/core/tests/unit_tests/runnables/test_runnable.py
Original file line number Diff line number Diff line change
Expand Up @@ -1428,6 +1428,30 @@ async def test_with_config(mocker: MockerFixture) -> None:

spy.reset_mock()

assert sorted(
c
for c in fake.with_config(recursion_limit=5).batch_as_completed(
["hello", "wooorld"],
[dict(tags=["a-tag"]), dict(metadata={"key": "value"})],
)
) == [(0, 5), (1, 7)]

assert len(spy.call_args_list) == 2
for i, call in enumerate(
sorted(spy.call_args_list, key=lambda x: 0 if x.args[0] == "hello" else 1)
):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
if i == 0:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == ["a-tag"]
assert call.args[1].get("metadata") == {}
else:
assert call.args[1].get("recursion_limit") == 5
assert call.args[1].get("tags") == []
assert call.args[1].get("metadata") == {"key": "value"}

spy.reset_mock()

assert fake.with_config(metadata={"a": "b"}).batch(
["hello", "wooorld"], dict(tags=["a-tag"])
) == [5, 7]
Expand All @@ -1438,6 +1462,15 @@ async def test_with_config(mocker: MockerFixture) -> None:
assert call.args[1].get("metadata") == {"a": "b"}
spy.reset_mock()

assert sorted(
c for c in fake.batch_as_completed(["hello", "wooorld"], dict(tags=["a-tag"]))
) == [(0, 5), (1, 7)]
assert len(spy.call_args_list) == 2
for i, call in enumerate(spy.call_args_list):
assert call.args[0] == ("hello" if i == 0 else "wooorld")
assert call.args[1].get("tags") == ["a-tag"]
spy.reset_mock()

handler = ConsoleCallbackHandler()
assert (
await fake.with_config(metadata={"a": "b"}).ainvoke(
Expand Down Expand Up @@ -1484,6 +1517,40 @@ async def test_with_config(mocker: MockerFixture) -> None:
),
),
]
spy.reset_mock()

assert sorted(
[
c
async for c in fake.with_config(
recursion_limit=5, tags=["c"]
).abatch_as_completed(["hello", "wooorld"], dict(metadata={"key": "value"}))
]
) == [
(0, 5),
(1, 7),
]
assert len(spy.call_args_list) == 2
first_call = next(call for call in spy.call_args_list if call.args[0] == "hello")
assert first_call == mocker.call(
"hello",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
recursion_limit=5,
),
)
second_call = next(call for call in spy.call_args_list if call.args[0] == "wooorld")
assert second_call == mocker.call(
"wooorld",
dict(
metadata={"key": "value"},
tags=["c"],
callbacks=None,
recursion_limit=5,
),
)


async def test_default_method_implementations(mocker: MockerFixture) -> None:
Expand Down