Skip to content

Commit

Permalink
Merge changes from langchain-ai#19084
Browse files Browse the repository at this point in the history
  • Loading branch information
billytrend-cohere committed Mar 20, 2024
1 parent c727bae commit bbe6ae4
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 45 deletions.
21 changes: 15 additions & 6 deletions libs/partners/cohere/langchain_cohere/chat_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional

from langchain_community.llms.cohere import BaseCohere
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand All @@ -19,8 +20,6 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult

from langchain_cohere.llms import BaseCohere


def get_role(message: BaseMessage) -> str:
"""Get the role of the message.
Expand Down Expand Up @@ -80,7 +79,7 @@ def get_cohere_chat_request(
"AUTO" if documents is not None or connectors is not None else None
)

return {
req = {
"message": messages[-1].content,
"chat_history": [
{"role": get_role(x), "message": x.content} for x in messages[:-1]
Expand All @@ -91,6 +90,8 @@ def get_cohere_chat_request(
**kwargs,
}

return {k: v for k, v in req.items() if v is not None}


class ChatCohere(BaseChatModel, BaseCohere):
"""`Cohere` chat large language models.
Expand Down Expand Up @@ -142,7 +143,11 @@ def _stream(
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = self.client.chat(**request, stream=True)

if hasattr(self.client, "chat_stream"): # detect and support sdk v5
stream = self.client.chat_stream(**request)
else:
stream = self.client.chat(**request, stream=True)

for data in stream:
if data.event_type == "text-generation":
Expand All @@ -160,7 +165,11 @@ async def _astream(
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
stream = await self.async_client.chat(**request, stream=True)

if hasattr(self.async_client, "chat_stream"): # detect and support sdk v5
stream = self.async_client.chat_stream(**request)
else:
stream = self.async_client.chat(**request, stream=True)

async for data in stream:
if data.event_type == "text-generation":
Expand Down Expand Up @@ -220,7 +229,7 @@ async def _agenerate(
return await agenerate_from_stream(stream_iter)

request = get_cohere_chat_request(messages, **self._default_params, **kwargs)
response = self.client.chat(**request, stream=False)
response = self.client.chat(**request)

message = AIMessage(content=response.text)
generation_info = None
Expand Down
31 changes: 26 additions & 5 deletions libs/partners/cohere/langchain_cohere/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import typing
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import cohere
from langchain_community.llms.cohere import _create_retry_decorator
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator
from langchain_core.utils import get_from_dict_or_env
Expand Down Expand Up @@ -31,12 +32,12 @@ class CohereEmbeddings(BaseModel, Embeddings):
model: str = "embed-english-v2.0"
"""Model name to use."""

truncate: typing.Optional[cohere.EmbedRequestTruncate] = None
truncate: Optional[str] = None
"""Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""

cohere_api_key: Optional[str] = None

max_retries: Optional[int] = 3
max_retries: int = 3
"""Maximum number of retries to make when generating."""
request_timeout: Optional[float] = None
"""Timeout in seconds for the Cohere API request."""
Expand Down Expand Up @@ -70,13 +71,33 @@ def validate_environment(cls, values: Dict) -> Dict:

return values

def embed_with_retry(self, **kwargs: Any) -> Any:
    """Call the synchronous client's ``embed`` behind a tenacity retry wrapper.

    Every keyword argument is forwarded unchanged to ``self.client.embed``.
    The retry decorator is built from ``self.max_retries``.
    """

    # The inner function keeps its name: retry hooks may report the name of
    # the wrapped callable.
    def _embed_with_retry(**call_kwargs: Any) -> Any:
        return self.client.embed(**call_kwargs)

    guarded_embed = _create_retry_decorator(self.max_retries)(_embed_with_retry)
    return guarded_embed(**kwargs)

def aembed_with_retry(self, **kwargs: Any) -> Any:
    """Call the async client's ``embed`` behind a tenacity retry wrapper.

    This is a plain (non-``async``) method: it returns whatever the
    retry-wrapped coroutine function returns when called, so callers are
    expected to ``await`` the result. Keyword arguments are forwarded
    unchanged to ``self.async_client.embed``.
    """

    # The inner function stays ``async def`` and keeps its name so the retry
    # wrapper sees the same coroutine function as before.
    async def _embed_with_retry(**call_kwargs: Any) -> Any:
        return await self.async_client.embed(**call_kwargs)

    guarded_embed = _create_retry_decorator(self.max_retries)(_embed_with_retry)
    return guarded_embed(**kwargs)

def embed(
self,
texts: List[str],
*,
input_type: typing.Optional[cohere.EmbedInputType] = None,
) -> List[List[float]]:
embeddings = self.client.embed(
embeddings = self.embed_with_retry(
model=self.model,
texts=texts,
input_type=input_type,
Expand All @@ -91,7 +112,7 @@ async def aembed(
input_type: typing.Optional[cohere.EmbedInputType] = None,
) -> List[List[float]]:
embeddings = (
await self.async_client.embed(
await self.aembed_with_retry(
model=self.model,
texts=texts,
input_type=input_type,
Expand Down
34 changes: 6 additions & 28 deletions libs/partners/cohere/langchain_cohere/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import re
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional

import cohere
from langchain_core.callbacks import (
Expand All @@ -12,17 +12,9 @@
from langchain_core.language_models.llms import LLM
from langchain_core.load.serializable import Serializable
from langchain_core.pydantic_v1 import Extra, Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
)
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from .utils import _create_retry_decorator


def enforce_stop_tokens(text: str, stop: List[str]) -> str:
Expand All @@ -33,23 +25,9 @@ def enforce_stop_tokens(text: str, stop: List[str]) -> str:
logger = logging.getLogger(__name__)


def _create_retry_decorator(llm: Cohere) -> Callable[[Any], Any]:
    """Build the tenacity ``retry`` decorator used for completion calls.

    The attempt limit is read from ``llm.max_retries``. Only
    ``cohere.errors.InternalServerError`` triggers a retry, and the last
    exception is re-raised once the attempts are exhausted.
    """
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    attempt_limit = stop_after_attempt(llm.max_retries)
    backoff = wait_exponential(multiplier=1, min=4, max=10)
    retryable = retry_if_exception_type(cohere.errors.InternalServerError)
    # Emit a WARNING through the module logger before each sleep.
    warn_before_sleep = before_sleep_log(logger, logging.WARNING)

    return retry(
        reraise=True,
        stop=attempt_limit,
        wait=backoff,
        retry=retryable,
        before_sleep=warn_before_sleep,
    )


def completion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
retry_decorator = _create_retry_decorator(llm.max_retries)

@retry_decorator
def _completion_with_retry(**kwargs: Any) -> Any:
Expand All @@ -60,7 +38,7 @@ def _completion_with_retry(**kwargs: Any) -> Any:

def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm)
retry_decorator = _create_retry_decorator(llm.max_retries)

@retry_decorator
async def _completion_with_retry(**kwargs: Any) -> Any:
Expand Down
35 changes: 35 additions & 0 deletions libs/partners/cohere/langchain_cohere/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from __future__ import annotations

import logging
from typing import Any, Callable

import cohere
import logger
from tenacity import (
before_sleep_log,
retry,
retry_if_exception_type,
stop_after_attempt,
wait_exponential,
)


def _create_retry_decorator(max_retries: int) -> Callable[[Any], Any]:
    """Build a tenacity ``retry`` decorator for Cohere API calls.

    Args:
        max_retries: Attempt limit, forwarded to tenacity's
            ``stop_after_attempt``.

    Returns:
        A decorator that retries the wrapped callable with exponential
        backoff and re-raises the last exception once the attempts are
        exhausted (``reraise=True``).
    """
    # ``before_sleep_log`` must be given a ``logging.Logger``. The name
    # ``logger`` in this module is bound by ``import logger`` (a module, not a
    # Logger), so obtain a real logger here instead.
    # NOTE(review): the module-level ``import logger`` should be replaced by
    # ``logger = logging.getLogger(__name__)``.
    retry_logger = logging.getLogger(__name__)

    # support v4 and v5: when the SDK exposes ``cohere.error`` retry only on
    # ``CohereError``; otherwise fall back to retrying on any ``Exception``.
    if hasattr(cohere, "error"):
        retry_conditions = retry_if_exception_type(cohere.error.CohereError)
    else:
        retry_conditions = retry_if_exception_type(Exception)

    min_seconds = 4
    max_seconds = 10
    # Wait 2^x * 1 second between each retry starting with
    # 4 seconds, then up to 10 seconds, then 10 seconds afterwards
    return retry(
        reraise=True,
        stop=stop_after_attempt(max_retries),
        wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds),
        retry=retry_conditions,
        before_sleep=before_sleep_log(retry_logger, logging.WARNING),
    )
10 changes: 5 additions & 5 deletions libs/partners/cohere/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/partners/cohere/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = ">=0.0.12"
cohere = "5.0.0a9"
cohere = "5.0.0a12"

[tool.poetry.group.test]
optional = true
Expand Down

0 comments on commit bbe6ae4

Please sign in to comment.