Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

anthropic[minor]: package move #17974

Merged
merged 10 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 3 additions & 2 deletions libs/partners/anthropic/langchain_anthropic/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from langchain_anthropic.chat_models import ChatAnthropicMessages
from langchain_anthropic.chat_models import ChatAnthropic, ChatAnthropicMessages
from langchain_anthropic.llms import Anthropic, AnthropicLLM

__all__ = ["ChatAnthropicMessages"]
__all__ = ["ChatAnthropicMessages", "ChatAnthropic", "Anthropic", "AnthropicLLM"]
50 changes: 38 additions & 12 deletions libs/partners/anthropic/langchain_anthropic/chat_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Tuple

import anthropic
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
Expand All @@ -14,7 +15,11 @@
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils import (
build_extra_kwargs,
convert_to_secret_str,
get_pydantic_field_names,
)

_message_type_lookups = {"human": "user", "ai": "assistant"}

Expand Down Expand Up @@ -50,7 +55,7 @@ def _format_messages(messages: List[BaseMessage]) -> Tuple[Optional[str], List[D
return system, formatted_messages


class ChatAnthropicMessages(BaseChatModel):
class ChatAnthropic(BaseChatModel):
"""ChatAnthropicMessages chat model.

Example:
Expand All @@ -61,13 +66,18 @@ class ChatAnthropicMessages(BaseChatModel):
model = ChatAnthropicMessages()
"""

_client: anthropic.Client = Field(default_factory=anthropic.Client)
_async_client: anthropic.AsyncClient = Field(default_factory=anthropic.AsyncClient)
class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True

_client: anthropic.Client = Field(default=None)
_async_client: anthropic.AsyncClient = Field(default=None)

model: str = Field(alias="model_name")
"""Model name to use."""

max_tokens: int = Field(default=256)
max_tokens: int = Field(default=256, alias="max_tokens_to_sample")
"""Denotes the number of tokens to predict per generation."""

temperature: Optional[float] = None
Expand All @@ -88,16 +98,20 @@ class ChatAnthropicMessages(BaseChatModel):

model_kwargs: Dict[str, Any] = Field(default_factory=dict)

class Config:
"""Configuration for this pydantic object."""

allow_population_by_field_name = True

@property
def _llm_type(self) -> str:
    """Identifier used to tag this chat model's type in tracing/serialization."""
    model_type = "chat-anthropic-messages"
    return model_type

@root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict:
    """Collect any constructor kwargs that are not declared fields into
    ``model_kwargs`` so they are forwarded to the underlying API client."""
    declared_fields = get_pydantic_field_names(cls)
    existing_extra = values.get("model_kwargs", {})
    values["model_kwargs"] = build_extra_kwargs(
        existing_extra, values, declared_fields
    )
    return values

@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
anthropic_api_key = convert_to_secret_str(
Expand Down Expand Up @@ -130,6 +144,7 @@ def _format_params(
"top_p": self.top_p,
"stop_sequences": stop,
"system": system,
**self.model_kwargs,
}
rtn = {k: v for k, v in rtn.items() if v is not None}

Expand All @@ -145,7 +160,10 @@ def _stream(
params = self._format_params(messages=messages, stop=stop, **kwargs)
with self._client.messages.stream(**params) as stream:
for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk

async def _astream(
self,
Expand All @@ -157,7 +175,10 @@ async def _astream(
params = self._format_params(messages=messages, stop=stop, **kwargs)
async with self._async_client.messages.stream(**params) as stream:
async for text in stream.text_stream:
yield ChatGenerationChunk(message=AIMessageChunk(content=text))
chunk = ChatGenerationChunk(message=AIMessageChunk(content=text))
if run_manager:
await run_manager.on_llm_new_token(text, chunk=chunk)
yield chunk

def _generate(
self,
Expand Down Expand Up @@ -190,3 +211,8 @@ async def _agenerate(
],
llm_output=data,
)


# Backward-compatible alias kept for callers that still import the old name;
# the @deprecated decorator emits a warning steering them to ChatAnthropic.
@deprecated(since="0.1.0", removal="0.2.0", alternative="ChatAnthropic")
class ChatAnthropicMessages(ChatAnthropic):
    pass