forked from langchain-ai/langchain
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add new beta StructuredPrompt (langchain-ai#19080)
Thank you for contributing to LangChain! - [ ] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [ ] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** a description of the change - **Issue:** the issue # it fixes, if applicable - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - [ ] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17.
- Loading branch information
Showing
3 changed files
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
from typing import ( | ||
Any, | ||
Callable, | ||
Dict, | ||
Iterator, | ||
Mapping, | ||
Optional, | ||
Sequence, | ||
Set, | ||
Type, | ||
Union, | ||
) | ||
|
||
from langchain_core._api.beta_decorator import beta | ||
from langchain_core.language_models.base import BaseLanguageModel | ||
from langchain_core.prompts.chat import ( | ||
BaseChatPromptTemplate, | ||
BaseMessagePromptTemplate, | ||
ChatPromptTemplate, | ||
MessageLikeRepresentation, | ||
MessagesPlaceholder, | ||
_convert_to_message, | ||
) | ||
from langchain_core.pydantic_v1 import BaseModel | ||
from langchain_core.runnables.base import ( | ||
Other, | ||
Runnable, | ||
RunnableSequence, | ||
RunnableSerializable, | ||
) | ||
|
||
|
||
@beta()
class StructuredPrompt(ChatPromptTemplate):
    """Chat prompt template that carries an output schema.

    When this prompt is piped into a language model (``prompt | model`` or
    ``prompt.pipe(model)``), the model is automatically wrapped with
    ``with_structured_output(self.schema_)`` so the resulting chain emits
    output conforming to the schema instead of a raw message.
    """

    # Output schema: a dict representation of a function call, or a
    # Pydantic model class.
    schema_: Union[Dict, Type[BaseModel]]

    @classmethod
    def from_messages_and_schema(
        cls,
        messages: Sequence[MessageLikeRepresentation],
        schema: Union[Dict, Type[BaseModel]],
    ) -> ChatPromptTemplate:
        """Create a chat prompt template from a variety of message formats.

        Examples:

            Instantiation from a list of message templates:

            .. code-block:: python

                class OutputSchema(BaseModel):
                    name: str
                    value: int

                template = ChatPromptTemplate.from_messages(
                    [
                        ("human", "Hello, how are you?"),
                        ("ai", "I'm doing well, thanks!"),
                        ("human", "That's good to hear."),
                    ],
                    OutputSchema,
                )

        Args:
            messages: sequence of message representations.
                A message can be represented using the following formats:
                (1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
                (message type, template); e.g., ("human", "{user_input}"),
                (4) 2-tuple of (message class, template), (5) a string which is
                shorthand for ("human", template); e.g., "{user_input}"
            schema: a dictionary representation of function call, or a Pydantic
                model.

        Returns:
            a structured prompt template
        """
        _messages = [_convert_to_message(message) for message in messages]

        # Automatically infer input variables from messages. Optional
        # placeholders default to an empty list so callers may omit them.
        input_vars: Set[str] = set()
        partial_vars: Dict[str, Any] = {}
        for _message in _messages:
            if isinstance(_message, MessagesPlaceholder) and _message.optional:
                partial_vars[_message.variable_name] = []
            elif isinstance(
                _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
            ):
                input_vars.update(_message.input_variables)

        return cls(
            input_variables=sorted(input_vars),
            messages=_messages,
            partial_variables=partial_vars,
            schema_=schema,
        )

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Callable[[Any], Other],
            Callable[[Iterator[Any]], Iterator[Other]],
            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
        ],
    ) -> RunnableSerializable[Dict, Other]:
        """Pipe this prompt into ``other``, binding the structured-output schema.

        Args:
            other: the next step of the chain; must be a language model or
                expose ``with_structured_output``.

        Returns:
            a runnable sequence of this prompt followed by the model wrapped
            with ``with_structured_output(self.schema_)``.

        Raises:
            NotImplementedError: if ``other`` cannot produce structured output.
        """
        if isinstance(other, BaseLanguageModel) or hasattr(
            other, "with_structured_output"
        ):
            return RunnableSequence(self, other.with_structured_output(self.schema_))
        else:
            raise NotImplementedError(
                "Structured prompts need to be piped to a language model."
            )

    def pipe(
        self,
        *others: Union[Runnable[Any, Other], Callable[[Any], Other]],
        name: Optional[str] = None,
    ) -> RunnableSerializable[Dict, Other]:
        """Pipe this prompt into one or more runnables.

        The first element of ``others`` must be a language model (or expose
        ``with_structured_output``); it is wrapped with the prompt's schema,
        and any remaining runnables are appended to the sequence.

        Args:
            others: the subsequent steps of the chain.
            name: optional name for the resulting sequence.

        Returns:
            a runnable sequence producing structured output.

        Raises:
            NotImplementedError: if ``others`` is empty or its first element
                cannot produce structured output.
        """
        # BUG FIX: the condition previously read
        # ``others and isinstance(...) or hasattr(others[0], ...)``. Since
        # ``and`` binds tighter than ``or``, an empty ``others`` fell through
        # to ``hasattr(others[0], ...)`` and raised IndexError instead of the
        # intended NotImplementedError. Parenthesize so the emptiness guard
        # covers both checks.
        if others and (
            isinstance(others[0], BaseLanguageModel)
            or hasattr(others[0], "with_structured_output")
        ):
            return RunnableSequence(
                self,
                others[0].with_structured_output(self.schema_),
                *others[1:],
                name=name,
            )
        else:
            raise NotImplementedError(
                "Structured prompts need to be piped to a language model."
            )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
from functools import partial | ||
from inspect import isclass | ||
from typing import Any, Dict, Type, Union, cast | ||
|
||
from langchain_core.load.dump import dumps | ||
from langchain_core.load.load import loads | ||
from langchain_core.prompts.structured import StructuredPrompt | ||
from langchain_core.pydantic_v1 import BaseModel | ||
from langchain_core.runnables.base import Runnable, RunnableLambda | ||
from tests.unit_tests.fake.chat_model import FakeListChatModel | ||
|
||
|
||
def _fake_runnable( | ||
schema: Union[Dict, Type[BaseModel]], _: Any | ||
) -> Union[BaseModel, Dict]: | ||
if isclass(schema) and issubclass(schema, BaseModel): | ||
return schema(name="yo", value=42) | ||
else: | ||
params = cast(Dict, schema)["parameters"] | ||
return {k: 1 for k, v in params.items()} | ||
|
||
|
||
class FakeStructuredChatModel(FakeListChatModel):
    """Fake ChatModel for testing purposes."""

    def with_structured_output(self, schema: Union[Dict, Type[BaseModel]]) -> Runnable:
        # Close over the schema so the returned runnable produces
        # schema-shaped canned output regardless of its input.
        return RunnableLambda(lambda value: _fake_runnable(schema, value))

    @property
    def _llm_type(self) -> str:
        return "fake-messages-list-chat-model"
|
||
|
||
def test_structured_prompt_pydantic() -> None:
    """Piping a StructuredPrompt with a Pydantic schema yields model instances."""

    class OutputSchema(BaseModel):
        name: str
        value: int

    messages = [("human", "I'm very structured, how about you?")]
    prompt = StructuredPrompt.from_messages_and_schema(messages, OutputSchema)

    chain = prompt | FakeStructuredChatModel(responses=[])

    result = chain.invoke({"hello": "there"})
    assert result == OutputSchema(name="yo", value=42)
|
||
|
||
def test_structured_prompt_dict() -> None:
    """Dict schemas drive the output shape and survive serialization."""
    schema = {
        "name": "yo",
        "description": "a structured output",
        "parameters": {
            "name": {"type": "string"},
            "value": {"type": "integer"},
        },
    }
    prompt = StructuredPrompt.from_messages_and_schema(
        [("human", "I'm very structured, how about you?")], schema
    )

    model = FakeStructuredChatModel(responses=[])
    expected = {"name": 1, "value": 1}

    assert (prompt | model).invoke({"hello": "there"}) == expected

    # Round-trip the prompt through serialization and confirm that both the
    # prompt itself and its piped behavior are preserved.
    round_tripped = loads(dumps(prompt))
    assert round_tripped == prompt
    assert (round_tripped | model).invoke({"hello": "there"}) == expected