core[minor]: allow LLMs async streaming to fallback on sync streaming (langchain-ai#18960)

- **Description:** Handle falling back to sync streaming when async streaming is called on an LLM that doesn't support it (a usage sketch follows below).
- **Issue:** langchain-ai#18920
- **Twitter handle:** @maximeperrin_

---------

Co-authored-by: Maxime Perrin <mperrin@doing.fr>
2 people authored and gkorland committed Mar 30, 2024
1 parent be6d1d3 commit 53b4eb0
Showing 2 changed files with 188 additions and 46 deletions.
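As a quick illustration of the new behavior, here is a minimal usage sketch (hypothetical class name SyncOnlyLLM; it mirrors the unit tests added below, where only the sync _stream is implemented):

import asyncio
from typing import Any, Iterator, List, Optional

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import GenerationChunk, LLMResult


class SyncOnlyLLM(BaseLLM):
    """Hypothetical model that only implements the sync _stream API."""

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        raise NotImplementedError()

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        yield GenerationChunk(text="a")
        yield GenerationChunk(text="b")

    @property
    def _llm_type(self) -> str:
        return "sync-only-llm"


async def main() -> None:
    model = SyncOnlyLLM()
    # astream() now yields "a" then "b": the sync _stream iterator is driven
    # in a worker thread instead of falling back to a single ainvoke() blob.
    async for chunk in model.astream("anything"):
        print(chunk)


asyncio.run(main())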
128 changes: 83 additions & 45 deletions libs/core/langchain_core/language_models/llms.py
@@ -12,6 +12,7 @@
 from pathlib import Path
 from typing import (
     Any,
+    AsyncGenerator,
     AsyncIterator,
     Callable,
     Dict,
@@ -113,6 +114,26 @@ def _before_sleep(retry_state: RetryCallState) -> None:
 )
 
 
+def _as_async_iterator(sync_iterator: Callable) -> Callable:
+    """Convert a sync iterator into an async iterator."""
+
+    async def _as_sync_iterator(*args: Any, **kwargs: Any) -> AsyncGenerator:
+        iterator = await run_in_executor(None, sync_iterator, *args, **kwargs)
+        done = object()
+        while True:
+            item = await run_in_executor(
+                None,
+                next,
+                iterator,
+                done,  # type: ignore[call-arg, arg-type]
+            )
+            if item is done:
+                break
+            yield item  # type: ignore[misc]
+
+    return _as_sync_iterator
+
+
 def get_prompts(
     params: Dict[str, Any], prompts: List[str]
 ) -> Tuple[Dict[int, List], str, List[int], List[str]]:
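The helper above bridges a blocking iterator onto the event loop by offloading each next() call to an executor. The same idea in plain asyncio, independent of LangChain's run_in_executor helper (a standalone sketch; names are illustrative):

import asyncio
import functools
from typing import Any, AsyncGenerator, Callable, Iterator

_SENTINEL = object()  # returned by next() once the sync iterator is exhausted


def as_async_iterator(make_iterator: Callable[..., Iterator[Any]]) -> Callable[..., AsyncGenerator]:
    """Wrap a sync-iterator factory so its output can be consumed with `async for`."""

    async def gen(*args: Any, **kwargs: Any) -> AsyncGenerator:
        loop = asyncio.get_running_loop()
        # Build the iterator in a worker thread in case construction itself blocks.
        iterator = await loop.run_in_executor(
            None, functools.partial(make_iterator, *args, **kwargs)
        )
        while True:
            # Each potentially blocking next() call also runs in the thread pool.
            item = await loop.run_in_executor(None, next, iterator, _SENTINEL)
            if item is _SENTINEL:
                break
            yield item

    return gen


async def demo() -> None:
    numbers = as_async_iterator(lambda: iter(range(3)))
    async for n in numbers():
        print(n)  # prints 0, 1, 2 without blocking the event loop


asyncio.run(demo())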
@@ -434,54 +455,71 @@ async def astream(
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> AsyncIterator[str]:
-        if type(self)._astream == BaseLLM._astream:
+        if type(self)._astream is not BaseLLM._astream:
             # model doesn't implement streaming, so use default implementation
-            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
-        else:
-            prompt = self._convert_input(input).to_string()
-            config = ensure_config(config)
-            params = self.dict()
-            params["stop"] = stop
-            params = {**params, **kwargs}
-            options = {"stop": stop}
-            callback_manager = AsyncCallbackManager.configure(
-                config.get("callbacks"),
-                self.callbacks,
-                self.verbose,
-                config.get("tags"),
-                self.tags,
-                config.get("metadata"),
-                self.metadata,
+            _stream_implementation = self._astream
+        elif type(self)._stream is not BaseLLM._stream:
+            # Then stream is implemented, so we can create an async iterator from it
+            # The typing is hard to type correctly with mypy here, so we cast
+            # and do a type ignore, this code is unit tested and should be fine.
+            _stream_implementation = cast(  # type: ignore
+                Callable[
+                    [
+                        str,
+                        Optional[List[str]],
+                        CallbackManagerForLLMRun,
+                        Any,
+                    ],
+                    AsyncIterator[GenerationChunk],
+                ],
+                _as_async_iterator(self._stream),
             )
-            (run_manager,) = await callback_manager.on_llm_start(
-                dumpd(self),
-                [prompt],
-                invocation_params=params,
-                options=options,
-                name=config.get("run_name"),
-                batch_size=1,
+        else:
+            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+            return
+
+        prompt = self._convert_input(input).to_string()
+        config = ensure_config(config)
+        params = self.dict()
+        params["stop"] = stop
+        params = {**params, **kwargs}
+        options = {"stop": stop}
+        callback_manager = AsyncCallbackManager.configure(
+            config.get("callbacks"),
+            self.callbacks,
+            self.verbose,
+            config.get("tags"),
+            self.tags,
+            config.get("metadata"),
+            self.metadata,
+        )
+        (run_manager,) = await callback_manager.on_llm_start(
+            dumpd(self),
+            [prompt],
+            invocation_params=params,
+            options=options,
+            name=config.get("run_name"),
+            batch_size=1,
+        )
+        generation: Optional[GenerationChunk] = None
+        try:
+            async for chunk in _stream_implementation(
+                prompt, stop=stop, run_manager=run_manager, **kwargs
+            ):
+                yield chunk.text
+                if generation is None:
+                    generation = chunk
+                else:
+                    generation += chunk
+            assert generation is not None
+        except BaseException as e:
+            await run_manager.on_llm_error(
+                e,
+                response=LLMResult(generations=[[generation]] if generation else []),
             )
-            generation: Optional[GenerationChunk] = None
-            try:
-                async for chunk in self._astream(
-                    prompt, stop=stop, run_manager=run_manager, **kwargs
-                ):
-                    yield chunk.text
-                    if generation is None:
-                        generation = chunk
-                    else:
-                        generation += chunk
-                assert generation is not None
-            except BaseException as e:
-                await run_manager.on_llm_error(
-                    e,
-                    response=LLMResult(
-                        generations=[[generation]] if generation else []
-                    ),
-                )
-                raise e
-            else:
-                await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
+            raise e
+        else:
+            await run_manager.on_llm_end(LLMResult(generations=[[generation]]))
 
     # --- Custom methods ---
 
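The dispatch in astream relies on detecting whether a subclass overrides a base-class method by comparing the function objects on the classes (type(self)._astream is not BaseLLM._astream). A standalone sketch of that pattern, with hypothetical class names:

class Base:
    def stream(self):
        # Generic behavior used when the subclass provides nothing better.
        return "fallback"


class Custom(Base):
    def stream(self):
        return "native streaming"


class Plain(Base):
    pass


def has_native_stream(obj: Base) -> bool:
    # Compare the function objects on the classes, not bound methods on instances:
    # a subclass that overrides stream() resolves to a different function via the MRO.
    return type(obj).stream is not Base.stream


print(has_native_stream(Custom()))  # True  -> use the subclass implementation
print(has_native_stream(Plain()))   # False -> fall back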
106 changes: 105 additions & 1 deletion libs/core/tests/unit_tests/language_models/llms/test_base.py
@@ -1,6 +1,13 @@
 from typing import Any, AsyncIterator, Iterator, List, Optional
 
 import pytest
 
-from langchain_core.outputs.llm_result import LLMResult
+from langchain_core.callbacks import (
+    AsyncCallbackManagerForLLMRun,
+    CallbackManagerForLLMRun,
+)
+from langchain_core.language_models.llms import BaseLLM
+from langchain_core.outputs import Generation, GenerationChunk, LLMResult
 from langchain_core.tracers.context import collect_runs
 from tests.unit_tests.fake.callbacks import (
     BaseFakeCallbackHandler,
@@ -113,3 +120,100 @@ def eval_response(callback: BaseFakeCallbackHandler, i: int) -> None:
                 pass
 
         eval_response(cb_sync, i)
+
+
+async def test_astream_fallback_to_ainvoke() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithGenerate(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            generations = [Generation(text="hello")]
+            return LLMResult(generations=[generations])
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithGenerate()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == ["hello"]
+
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == ["hello"]
+
+
+async def test_astream_implementation_fallback_to_stream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithSyncStream(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        def _stream(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> Iterator[GenerationChunk]:
+            """Stream the output of the model."""
+            yield GenerationChunk(text="a")
+            yield GenerationChunk(text="b")
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithSyncStream()
+    chunks = [chunk for chunk in model.stream("anything")]
+    assert chunks == ["a", "b"]
+    assert type(model)._astream == BaseLLM._astream
+    astream_chunks = [chunk async for chunk in model.astream("anything")]
+    assert astream_chunks == ["a", "b"]
+
+
+async def test_astream_implementation_uses_astream() -> None:
+    """Test astream uses appropriate implementation."""
+
+    class ModelWithAsyncStream(BaseLLM):
+        def _generate(
+            self,
+            prompts: List[str],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> LLMResult:
+            """Top Level call"""
+            raise NotImplementedError()
+
+        async def _astream(
+            self,
+            prompt: str,
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+            **kwargs: Any,
+        ) -> AsyncIterator[GenerationChunk]:
+            """Stream the output of the model."""
+            yield GenerationChunk(text="a")
+            yield GenerationChunk(text="b")
+
+        @property
+        def _llm_type(self) -> str:
+            return "fake-chat-model"
+
+    model = ModelWithAsyncStream()
+    chunks = [chunk async for chunk in model.astream("anything")]
+    assert chunks == ["a", "b"]
