core: allow LLMs async streaming to fallback on sync streaming #18960

Merged · 8 commits · Mar 15, 2024
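In practice, this change means an `astream()` call on an LLM that only overrides the sync `_stream` hook now streams real chunks instead of falling back to `ainvoke()`. Below is a minimal sketch of that behavior, closely mirroring the new unit test added in this PR; the `SyncOnlyLLM` name is illustrative, not part of the library.

```python
import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import GenerationChunk, LLMResult


class SyncOnlyLLM(BaseLLM):
    """Toy model that only implements the sync _stream hook."""

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        raise NotImplementedError

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        yield GenerationChunk(text="a")
        yield GenerationChunk(text="b")

    @property
    def _llm_type(self) -> str:
        return "sync-only-llm"


async def main() -> None:
    # Before this PR, astream() ignored the sync _stream hook and fell back to
    # ainvoke(); with it, the sync generator is wrapped and streamed chunk by chunk.
    chunks = [chunk async for chunk in SyncOnlyLLM().astream("anything")]
    print(chunks)  # ["a", "b"]


asyncio.run(main())
```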
128 changes: 83 additions & 45 deletions libs/core/langchain_core/language_models/llms.py
@@ -11,6 +11,7 @@
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Callable,
Dict,
@@ -119,6 +120,26 @@ def _before_sleep(retry_state: RetryCallState) -> None:
)


def _as_async_iterator(sync_iterator: Callable) -> Callable:
"""Convert a sync iterator into an async iterator."""

async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
done = object()
while True:
item = await run_in_executor(
None,
next,
iterator,
done, # type: ignore[call-arg, arg-type]
)
if item is done:
break
yield item # type: ignore[misc]

return _as_sync_iterator


def get_prompts(
params: Dict[str, Any], prompts: List[str]
) -> Tuple[Dict[int, List], str, List[int], List[str]]:
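The new `_as_async_iterator` helper above drives a blocking iterator from a worker thread and passes a sentinel to `next()` so that `StopIteration` never has to cross the executor boundary. Here is a self-contained sketch of the same idea, substituting `asyncio.to_thread` for LangChain's internal `run_in_executor` helper; the `as_async_iterator` and `chunks` names are illustrative only.

```python
import asyncio
from typing import Any, AsyncGenerator, Callable, Iterator


def as_async_iterator(
    make_iterator: Callable[..., Iterator[Any]]
) -> Callable[..., AsyncGenerator[Any, None]]:
    """Wrap a blocking iterator factory so it can be consumed with `async for`."""

    async def wrapper(*args: Any, **kwargs: Any) -> AsyncGenerator[Any, None]:
        # Create the iterator off the event loop in case construction itself blocks.
        iterator = await asyncio.to_thread(make_iterator, *args, **kwargs)
        done = object()  # sentinel returned by next() instead of raising StopIteration
        while True:
            item = await asyncio.to_thread(next, iterator, done)
            if item is done:
                break
            yield item

    return wrapper


async def demo() -> None:
    def chunks(text: str) -> Iterator[str]:
        yield from text.split()

    async for word in as_async_iterator(chunks)("streaming without an async hook"):
        print(word)


asyncio.run(demo())
```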
@@ -461,54 +482,71 @@ async def astream(
stop: Optional[List[str]] = None,
**kwargs: Any,
) -> AsyncIterator[str]:
-        if type(self)._astream == BaseLLM._astream:
+        if type(self)._astream is not BaseLLM._astream:
            # model doesn't implement streaming, so use default implementation
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
-        else:
-            prompt = self._convert_input(input).to_string()
-            config = ensure_config(config)
-            params = self.dict()
-            params["stop"] = stop
-            params = {**params, **kwargs}
-            options = {"stop": stop}
-            callback_manager = AsyncCallbackManager.configure(
-                config.get("callbacks"),
-                self.callbacks,
-                self.verbose,
-                config.get("tags"),
-                self.tags,
-                config.get("metadata"),
-                self.metadata,
+            _stream_implementation = self._astream
+        elif type(self)._stream is not BaseLLM._stream:
+            # Then stream is implemented, so we can create an async iterator from it
+            # The typing is hard to type correctly with mypy here, so we cast
+            # and do a type ignore, this code is unit tested and should be fine.
+            _stream_implementation = cast(  # type: ignore
+                Callable[
+                    [
+                        str,
+                        Optional[List[str]],
+                        CallbackManagerForLLMRun,
+                        Any,
+                    ],
+                    AsyncIterator[GenerationChunk],
+                ],
+                _as_async_iterator(self._stream),
            )
-            (run_manager,) = await callback_manager.on_llm_start(
-                dumpd(self),
-                [prompt],
-                invocation_params=params,
-                options=options,
-                name=config.get("run_name"),
-                batch_size=1,
-            )
-            generation: Optional[GenerationChunk] = None
-            try:
-                async for chunk in self._astream(
-                    prompt, stop=stop, run_manager=run_manager, **kwargs
-                ):
-                    yield chunk.text
-                    if generation is None:
-                        generation = chunk
-                    else:
-                        generation += chunk
-                assert generation is not None
-            except BaseException as e:
-                await run_manager.on_llm_error(
-                    e,
-                    response=LLMResult(
-                        generations=[[generation]] if generation else []
-                    ),
-                )
-                raise e
-            else:
-                await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
+        else:
+            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            return
+
+        prompt = self._convert_input(input).to_string()
+        config = ensure_config(config)
+        params = self.dict()
+        params["stop"] = stop
+        params = {**params, **kwargs}
+        options = {"stop": stop}
+        callback_manager = AsyncCallbackManager.configure(
+            config.get("callbacks"),
+            self.callbacks,
+            self.verbose,
+            config.get("tags"),
+            self.tags,
+            config.get("metadata"),
+            self.metadata,
+        )
+        (run_manager,) = await callback_manager.on_llm_start(
+            dumpd(self),
+            [prompt],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
+            batch_size=1,
+        )
+        generation: Optional[GenerationChunk] = None
+        try:
+            async for chunk in _stream_implementation(
+                prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                yield chunk.text
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+        except BaseException as e:
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=[[generation]] if generation else []),
            )
+            raise e
+        else:
+            await run_manager.on_llm_end(LLMResult(generations=[[generation]]))

# --- Custom methods ---

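The dispatch in the new `astream` hinges on whether a subclass actually overrode `_astream` or `_stream`, checked by comparing the attribute looked up on the concrete class against the one defined on `BaseLLM`. A tiny standalone illustration of that identity check follows; the class names here are made up.

```python
class Base:
    def _stream(self) -> str:
        return "base implementation"


class WithOverride(Base):
    def _stream(self) -> str:
        return "subclass implementation"


class WithoutOverride(Base):
    pass


# `type(obj)._stream is Base._stream` is True only when the subclass did NOT
# override the hook, which is how the new astream decides whether to fall back.
print(type(WithOverride())._stream is Base._stream)     # False -> override exists
print(type(WithoutOverride())._stream is Base._stream)  # True  -> fall back
```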
106 changes: 105 additions & 1 deletion libs/core/tests/unit_tests/language_models/llms/test_base.py
@@ -1,6 +1,13 @@
from typing import Any, AsyncIterator, Iterator, List, Optional

import pytest

from langchain_core.outputs.llm_result import LLMResult
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.callbacks import (
BaseFakeCallbackHandler,
@@ -113,3 +120,100 @@ def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
pass

eval_response(cb_sync, i)


async def test_astream_fallback_to_ainvoke() -> None:
"""Test astream uses appropriate implementation."""

class ModelWithGenerate(BaseLLM):
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
generations = [Generation(text="hello")]
return LLMResult(generations=[generations])

@property
def _llm_type(self) -> str:
return "fake-chat-model"

model = ModelWithGenerate()
chunks = [chunk for chunk in model.stream("anything")]
assert chunks == ["hello"]

chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == ["hello"]


async def test_astream_implementation_fallback_to_stream() -> None:
"""Test astream uses appropriate implementation."""

class ModelWithSyncStream(BaseLLM):
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
raise NotImplementedError()

def _stream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Stream the output of the model."""
yield GenerationChunk(text="a")
yield GenerationChunk(text="b")

@property
def _llm_type(self) -> str:
return "fake-chat-model"

model = ModelWithSyncStream()
chunks = [chunk for chunk in model.stream("anything")]
assert chunks == ["a", "b"]
assert type(model)._astream == BaseLLM._astream
astream_chunks = [chunk async for chunk in model.astream("anything")]
assert astream_chunks == ["a", "b"]


async def test_astream_implementation_uses_astream() -> None:
"""Test astream uses appropriate implementation."""

class ModelWithAsyncStream(BaseLLM):
def _generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
raise NotImplementedError()

async def _astream(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
"""Stream the output of the model."""
yield GenerationChunk(text="a")
yield GenerationChunk(text="b")

@property
def _llm_type(self) -> str:
return "fake-chat-model"

model = ModelWithAsyncStream()
chunks = [chunk async for chunk in model.astream("anything")]
assert chunks == ["a", "b"]