Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

partners: AI21 Labs Contextual Answers support #18270

Merged
merged 26 commits into from
Mar 5, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
9976ca2
feat: Contextual answers in langchain
asafgardin Feb 28, 2024
ae411b6
docs: Updated readme
asafgardin Feb 28, 2024
f0e6aa0
Merge branch 'langchain-ai:master' into master
Josephasafg Feb 28, 2024
3eebce9
Merge branch 'master' into Josephasafg/master
efriis Feb 28, 2024
ccb7561
deps
efriis Feb 28, 2024
c054713
docs: Added ca examples to llm docs
asafgardin Feb 29, 2024
931ff88
fix: readme
asafgardin Feb 29, 2024
9b1799d
feat: Added chunking to embedding
asafgardin Feb 29, 2024
0b0ede6
fix: Used default chunk size
asafgardin Feb 29, 2024
57eb9dd
fix: default chunk
asafgardin Feb 29, 2024
4669ef9
refactor: use itertools
asafgardin Mar 3, 2024
a94f423
fix: Moved function outside of class and added a test
asafgardin Mar 3, 2024
7a03e63
fix: test name
asafgardin Mar 3, 2024
a6f1d91
fix: lint
asafgardin Mar 3, 2024
383f522
fix: rename
asafgardin Mar 3, 2024
c5d1d43
Merge pull request #1 from Josephasafg/add_chunking_in_embeddings
Josephasafg Mar 3, 2024
5f4d507
revert: embeddings
asafgardin Mar 4, 2024
16db69a
Merge pull request #2 from Josephasafg/revert_embeddings
Josephasafg Mar 4, 2024
570eea0
Merge branch 'langchain-ai:master' into master
Josephasafg Mar 4, 2024
7546b9b
Merge branch 'master' into master
Josephasafg Mar 5, 2024
67eee9b
fix: notebook
asafgardin Mar 5, 2024
85dc73b
Merge pull request #4 from Josephasafg/notebook_format
Josephasafg Mar 5, 2024
2a439b7
Merge branch 'langchain-ai:master' into master
Josephasafg Mar 5, 2024
dcc74d1
Merge branch 'master' into master
efriis Mar 5, 2024
65ff240
Merge branch 'master' into Josephasafg/master
efriis Mar 5, 2024
6f788a3
cr
efriis Mar 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
32 changes: 32 additions & 0 deletions libs/partners/ai21/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,35 @@ from langchain_ai21 import AI21Embeddings
embeddings = AI21Embeddings()
embeddings.embed_documents(["Hello! This is document 1", "And this is document 2!"])
```

## Task Specific Models

### Contextual Answers

You can use AI21's contextual answers model to receive text or a document, serving as a context,
and a question, and get an answer based entirely on this context.

This means that if the answer to your question is not in the document,
the model will indicate it (instead of providing a false answer).

### Query

```python
from langchain_ai21 import AI21ContextualAnswers

tsm = AI21ContextualAnswers()

response = tsm.invoke(input={"context": "Your context", "question": "Your question"})
```
You can also use it with chains, output parsers, and vector DBs:
```python
from langchain_ai21 import AI21ContextualAnswers
from langchain_core.output_parsers import StrOutputParser

tsm = AI21ContextualAnswers()
chain = tsm | StrOutputParser()

response = chain.invoke(
{"context": "Your context", "question": "Your question"},
)
```
2 changes: 2 additions & 0 deletions libs/partners/ai21/langchain_ai21/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from langchain_ai21.chat_models import ChatAI21
from langchain_ai21.contextual_answers import AI21ContextualAnswers
from langchain_ai21.embeddings import AI21Embeddings
from langchain_ai21.llms import AI21LLM

__all__ = [
"AI21LLM",
"ChatAI21",
"AI21Embeddings",
"AI21ContextualAnswers",
]
108 changes: 108 additions & 0 deletions libs/partners/ai21/langchain_ai21/contextual_answers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from typing import (
Any,
List,
Optional,
Tuple,
Type,
TypedDict,
Union,
)

from langchain_core.documents import Document
from langchain_core.runnables import RunnableConfig, RunnableSerializable, ensure_config

from langchain_ai21.ai21_base import AI21Base

# Sentinel answer returned when the model cannot find the answer in the
# supplied context (callers may override it via `response_if_no_answer_found`).
ANSWER_NOT_IN_CONTEXT_RESPONSE = "Answer not in context"

# A context is either a single string, or a list mixing plain strings and
# langchain `Document`s (whose `page_content` is used).
ContextType = Union[str, List[Union[Document, str]]]


class ContextualAnswerInput(TypedDict):
    """Input mapping for :class:`AI21ContextualAnswers`."""

    # Text/documents the answer must be grounded in.
    context: ContextType
    # The question to answer from the context.
    question: str


class AI21ContextualAnswers(RunnableSerializable[ContextualAnswerInput, str], AI21Base):
    """Runnable wrapping AI21's task-specific Contextual Answers model.

    Takes a ``context`` (a string, or a list of strings and/or ``Document``\\ s)
    together with a ``question``, and returns an answer based entirely on the
    given context. When the model finds no answer in the context, a fallback
    string is returned instead (``ANSWER_NOT_IN_CONTEXT_RESPONSE`` by default).
    """

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @property
    def InputType(self) -> Type[ContextualAnswerInput]:
        """Get the input type for this runnable."""
        return ContextualAnswerInput

    @property
    def OutputType(self) -> Type[str]:
        """Get the output type for this runnable."""
        return str

    def invoke(
        self,
        input: ContextualAnswerInput,
        config: Optional[RunnableConfig] = None,
        response_if_no_answer_found: str = ANSWER_NOT_IN_CONTEXT_RESPONSE,
        **kwargs: Any,
    ) -> str:
        """Answer ``input["question"]`` using only ``input["context"]``.

        Args:
            input: Mapping with ``context`` and ``question`` keys.
            config: Optional runnable configuration (callbacks, tags, ...).
            response_if_no_answer_found: Returned when the model reports that
                the answer is not present in the context.

        Returns:
            The model's answer, or ``response_if_no_answer_found``.
        """
        config = ensure_config(config)
        # Route through _call_with_config so callbacks/tracing see this as
        # an "llm" run.
        return self._call_with_config(
            func=lambda inner_input: self._call_contextual_answers(
                inner_input, response_if_no_answer_found
            ),
            input=input,
            config=config,
            run_type="llm",
        )

    def _call_contextual_answers(
        self,
        input: ContextualAnswerInput,
        response_if_no_answer_found: str,
    ) -> str:
        """Call the AI21 answer endpoint, substituting the fallback on no answer."""
        context, question = self._convert_input(input)
        response = self.client.answer.create(context=context, question=question)

        if response.answer is None:
            return response_if_no_answer_found

        return response.answer

    def _convert_input(self, input: ContextualAnswerInput) -> Tuple[str, str]:
        """Validate the input mapping and flatten the context to a single string."""
        context, question = self._extract_context_and_question(input)

        context = self._parse_context(context)

        return context, question

    def _extract_context_and_question(
        self,
        input: ContextualAnswerInput,
    ) -> Tuple[ContextType, str]:
        """Pull ``context``/``question`` out of *input*, raising on bad input.

        Raises:
            ValueError: If either field is missing/empty, or ``context`` is
                neither a string nor a list.
        """
        context = input.get("context")
        question = input.get("question")

        if not context or not question:
            raise ValueError(
                f"Input must contain 'context' and 'question' fields. Got {input}"
            )

        if not isinstance(context, (list, str)):
            # Report the offending context type, not the type of the whole
            # input mapping (which is always a dict and thus uninformative).
            raise ValueError(
                f"Expected 'context' to be a string or a list of strings or"
                f" Documents. Received {type(context)}"
            )

        return context, question

    def _parse_context(self, context: ContextType) -> str:
        """Join a list context into one newline-separated string.

        ``Document`` items contribute their ``page_content``; plain strings
        are used as-is. A string context is returned unchanged.
        """
        if isinstance(context, str):
            return context

        docs = [
            item.page_content if isinstance(item, Document) else item
            for item in context
        ]

        return "\n".join(docs)
12 changes: 6 additions & 6 deletions libs/partners/ai21/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion libs/partners/ai21/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ readme = "README.md"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.22"
ai21 = "^2.0.0"
ai21 = "2.0.5"

[tool.poetry.group.test]
optional = true
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import Runnable

from langchain_ai21.contextual_answers import (
ANSWER_NOT_IN_CONTEXT_RESPONSE,
AI21ContextualAnswers,
)

# Shared fixture context passed to the Contextual Answers model in each test.
context = """
Albert Einstein German: 14 March 1879 – 18 April 1955)
was a German-born theoretical physicist who is widely held
to be one of the greatest and most influential scientists
"""


# A question answerable from the context above.
_GOOD_QUESTION = "When did Albert Einstein born?"
# A question deliberately outside the context, to trigger the fallback answer.
_BAD_QUESTION = "What color is Yoda's light saber?"
# Substring expected in a correct answer to _GOOD_QUESTION.
_EXPECTED_PARTIAL_RESPONSE = "March 14, 1879"


def test_invoke__when_good_question() -> None:
    """A question covered by the context should not yield the fallback answer."""
    tsm = AI21ContextualAnswers()

    answer = tsm.invoke(
        {"context": context, "question": _GOOD_QUESTION},
        config={"metadata": {"name": "I AM A TEST"}},
    )

    assert answer != ANSWER_NOT_IN_CONTEXT_RESPONSE


def test_invoke__when_bad_question__should_return_answer_not_in_context() -> None:
    """A question outside the context must return the default fallback answer."""
    tsm = AI21ContextualAnswers()
    answer = tsm.invoke(input={"context": context, "question": _BAD_QUESTION})
    assert answer == ANSWER_NOT_IN_CONTEXT_RESPONSE


def test_invoke__when_response_if_no_answer_passed__should_use_it() -> None:
    """A caller-supplied fallback should replace the default no-answer response."""
    custom_fallback = "This should be the response"
    tsm = AI21ContextualAnswers()

    answer = tsm.invoke(
        input={"context": context, "question": _BAD_QUESTION},
        response_if_no_answer_found=custom_fallback,
    )

    assert answer == custom_fallback


def test_invoke_when_used_in_a_simple_chain_with_no_vectorstore() -> None:
    """The runnable should compose with an output parser via the | operator."""
    answers: Runnable = AI21ContextualAnswers() | StrOutputParser()

    result = answers.invoke(
        {"context": context, "question": _GOOD_QUESTION},
    )

    assert result != ANSWER_NOT_IN_CONTEXT_RESPONSE
22 changes: 18 additions & 4 deletions libs/partners/ai21/tests/unit_tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import os
from contextlib import contextmanager
from typing import Generator
from unittest.mock import Mock

import pytest
from ai21 import AI21Client
from ai21 import AI21Client, AI21EnvConfig
from ai21.models import (
AnswerResponse,
ChatOutput,
ChatResponse,
Completion,
Expand Down Expand Up @@ -84,8 +84,22 @@ def temporarily_unset_api_key() -> Generator:
"""
Unset and set environment key for testing purpose for when an API KEY is not set
"""
api_key = os.environ.pop("API_KEY", None)
api_key = AI21EnvConfig.api_key
AI21EnvConfig.api_key = None
yield

if api_key is not None:
os.environ["API_KEY"] = api_key
AI21EnvConfig.api_key = api_key


@pytest.fixture
def mock_client_with_contextual_answers(mocker: MockerFixture) -> Mock:
mock_client = mocker.MagicMock(spec=AI21Client)
mock_client.answer = mocker.MagicMock()
mock_client.answer.create.return_value = AnswerResponse(
id="some_id",
answer="some answer",
answer_in_context=False,
)

return mock_client