Add new beta StructuredPrompt (#19080)
Thank you for contributing to LangChain!

- [ ] **PR title**: "package: description"
  - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"


- [ ] **PR message**: ***Delete this entire checklist*** and replace with
  - **Description:** a description of the change
  - **Issue:** the issue # it fixes, if applicable
  - **Dependencies:** any dependencies required for this change
  - **Twitter handle:** if your PR gets announced and you'd like a mention, we'll gladly shout you out!


- [ ] **Add tests and docs**: If you're adding a new integration, please include
  1. a test for the integration, preferably unit tests that do not rely on network access, and
  2. an example notebook showing its use, which lives in the `docs/docs/integrations` directory.


- [ ] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See the contribution guidelines for more: https://python.langchain.com/docs/contributing/

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in langchain.

If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, hwchase17.
nfcampos authored and hinthornw committed Apr 26, 2024
1 parent f751ae1 commit bc3da71
Showing 3 changed files with 217 additions and 0 deletions.
6 changes: 6 additions & 0 deletions libs/core/langchain_core/load/mapping.py
@@ -522,6 +522,12 @@
"image",
"ImagePromptTemplate",
),
("langchain", "prompts", "chat", "StructuredPrompt"): (
"langchain_core",
"prompts",
"structured",
"StructuredPrompt",
),
}

# Needed for backwards compatibility for a few versions where we serialized
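Why this mapping entry matters: `StructuredPrompt` inherits its serialization namespace from `ChatPromptTemplate`, so `dumps` records its id under the legacy `("langchain", "prompts", "chat", ...)` path, and the new entry tells `loads` to resolve that path to the class's real home in `langchain_core.prompts.structured`. A minimal round-trip sketch (the dict-schema form mirrors the unit test below; inspecting the blob assumes the standard `lc`/`type`/`id`/`kwargs` serialization layout):

```python
import json

from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.prompts.structured import StructuredPrompt

prompt = StructuredPrompt.from_messages_and_schema(
    [("human", "{question}")],
    {
        "name": "output",
        "description": "a structured output",
        "parameters": {"value": {"type": "integer"}},
    },
)

blob = json.loads(dumps(prompt))
# The serialized id uses the legacy namespace inherited from ChatPromptTemplate,
# which is exactly the key added to the mapping above.
assert blob["id"] == ["langchain", "prompts", "chat", "StructuredPrompt"]
# loads() consults the mapping to import the class from langchain_core.prompts.structured.
assert loads(dumps(prompt)) == prompt
```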
133 changes: 133 additions & 0 deletions libs/core/langchain_core/prompts/structured.py
@@ -0,0 +1,133 @@
from typing import (
    Any,
    Callable,
    Dict,
    Iterator,
    Mapping,
    Optional,
    Sequence,
    Set,
    Type,
    Union,
)

from langchain_core._api.beta_decorator import beta
from langchain_core.language_models.base import BaseLanguageModel
from langchain_core.prompts.chat import (
    BaseChatPromptTemplate,
    BaseMessagePromptTemplate,
    ChatPromptTemplate,
    MessageLikeRepresentation,
    MessagesPlaceholder,
    _convert_to_message,
)
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import (
    Other,
    Runnable,
    RunnableSequence,
    RunnableSerializable,
)


@beta()
class StructuredPrompt(ChatPromptTemplate):
    schema_: Union[Dict, Type[BaseModel]]

    @classmethod
    def from_messages_and_schema(
        cls,
        messages: Sequence[MessageLikeRepresentation],
        schema: Union[Dict, Type[BaseModel]],
    ) -> ChatPromptTemplate:
        """Create a structured prompt template from a variety of message formats.

        Examples:
            Instantiation from a list of message templates:

            .. code-block:: python

                class OutputSchema(BaseModel):
                    name: str
                    value: int

                template = StructuredPrompt.from_messages_and_schema(
                    [
                        ("human", "Hello, how are you?"),
                        ("ai", "I'm doing well, thanks!"),
                        ("human", "That's good to hear."),
                    ],
                    OutputSchema,
                )

        Args:
            messages: sequence of message representations.
                A message can be represented using the following formats:
                (1) BaseMessagePromptTemplate, (2) BaseMessage, (3) 2-tuple of
                (message type, template); e.g., ("human", "{user_input}"),
                (4) 2-tuple of (message class, template), (5) a string, which is
                shorthand for ("human", template); e.g., "{user_input}".
            schema: a dictionary representation of a function call, or a
                Pydantic model.

        Returns:
            a structured prompt template
        """
        _messages = [_convert_to_message(message) for message in messages]

        # Automatically infer input variables from messages
        input_vars: Set[str] = set()
        partial_vars: Dict[str, Any] = {}
        for _message in _messages:
            if isinstance(_message, MessagesPlaceholder) and _message.optional:
                partial_vars[_message.variable_name] = []
            elif isinstance(
                _message, (BaseChatPromptTemplate, BaseMessagePromptTemplate)
            ):
                input_vars.update(_message.input_variables)

        return cls(
            input_variables=sorted(input_vars),
            messages=_messages,
            partial_variables=partial_vars,
            schema_=schema,
        )

    def __or__(
        self,
        other: Union[
            Runnable[Any, Other],
            Callable[[Any], Other],
            Callable[[Iterator[Any]], Iterator[Other]],
            Mapping[str, Union[Runnable[Any, Other], Callable[[Any], Other], Any]],
        ],
    ) -> RunnableSerializable[Dict, Other]:
        if isinstance(other, BaseLanguageModel) or hasattr(
            other, "with_structured_output"
        ):
            return RunnableSequence(self, other.with_structured_output(self.schema_))
        else:
            raise NotImplementedError(
                "Structured prompts need to be piped to a language model."
            )

    def pipe(
        self,
        *others: Union[Runnable[Any, Other], Callable[[Any], Other]],
        name: Optional[str] = None,
    ) -> RunnableSerializable[Dict, Other]:
        # Require at least one step before touching others[0]; the first step
        # after the prompt must be a model that supports structured output.
        if others and (
            isinstance(others[0], BaseLanguageModel)
            or hasattr(others[0], "with_structured_output")
        ):
            return RunnableSequence(
                self,
                others[0].with_structured_output(self.schema_),
                *others[1:],
                name=name,
            )
        else:
            raise NotImplementedError(
                "Structured prompts need to be piped to a language model."
            )
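For context, a hedged usage sketch of the new class outside the test harness: piping a `StructuredPrompt` into a chat model calls the model's `with_structured_output(self.schema_)` for you, so the chain yields parsed objects rather than raw messages. This assumes a model that implements `with_structured_output` (e.g. `langchain_openai.ChatOpenAI`, with `langchain_openai` installed and an API key configured); the `Joke` schema is illustrative, not part of this PR:

```python
from langchain_core.prompts.structured import StructuredPrompt
from langchain_core.pydantic_v1 import BaseModel
from langchain_openai import ChatOpenAI  # any model exposing with_structured_output


class Joke(BaseModel):
    setup: str
    punchline: str


prompt = StructuredPrompt.from_messages_and_schema(
    [("human", "Tell me a joke about {topic}")],
    Joke,
)

# `prompt | model` builds RunnableSequence(prompt, model.with_structured_output(Joke)),
# so invoking the chain returns a Joke instance instead of an AIMessage.
chain = prompt | ChatOpenAI()
joke = chain.invoke({"topic": "cats"})
print(joke.setup, joke.punchline)
```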
78 changes: 78 additions & 0 deletions libs/core/tests/unit_tests/prompts/test_structured.py
@@ -0,0 +1,78 @@
from functools import partial
from inspect import isclass
from typing import Any, Dict, Type, Union, cast

from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.prompts.structured import StructuredPrompt
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.base import Runnable, RunnableLambda
from tests.unit_tests.fake.chat_model import FakeListChatModel


def _fake_runnable(
    schema: Union[Dict, Type[BaseModel]], _: Any
) -> Union[BaseModel, Dict]:
    if isclass(schema) and issubclass(schema, BaseModel):
        return schema(name="yo", value=42)
    else:
        params = cast(Dict, schema)["parameters"]
        return {k: 1 for k, v in params.items()}


class FakeStructuredChatModel(FakeListChatModel):
    """Fake ChatModel for testing purposes."""

    def with_structured_output(self, schema: Union[Dict, Type[BaseModel]]) -> Runnable:
        return RunnableLambda(partial(_fake_runnable, schema))

    @property
    def _llm_type(self) -> str:
        return "fake-messages-list-chat-model"


def test_structured_prompt_pydantic() -> None:
    class OutputSchema(BaseModel):
        name: str
        value: int

    prompt = StructuredPrompt.from_messages_and_schema(
        [
            ("human", "I'm very structured, how about you?"),
        ],
        OutputSchema,
    )

    model = FakeStructuredChatModel(responses=[])

    chain = prompt | model

    assert chain.invoke({"hello": "there"}) == OutputSchema(name="yo", value=42)


def test_structured_prompt_dict() -> None:
    prompt = StructuredPrompt.from_messages_and_schema(
        [
            ("human", "I'm very structured, how about you?"),
        ],
        {
            "name": "yo",
            "description": "a structured output",
            "parameters": {
                "name": {"type": "string"},
                "value": {"type": "integer"},
            },
        },
    )

    model = FakeStructuredChatModel(responses=[])

    chain = prompt | model

    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}

    assert loads(dumps(prompt)) == prompt

    chain = loads(dumps(prompt)) | model

    assert chain.invoke({"hello": "there"}) == {"name": 1, "value": 1}
