core: mustache prompt templates #19980

Merged · 9 commits · Apr 10, 2024
Changes from all commits
33 changes: 25 additions & 8 deletions libs/core/langchain_core/prompts/chat.py
@@ -8,6 +8,7 @@
     Any,
     Dict,
     List,
+    Literal,
     Optional,
     Sequence,
     Set,
@@ -929,6 +930,7 @@ def from_strings(
     def from_messages(
         cls,
         messages: Sequence[MessageLikeRepresentation],
+        template_format: Literal["f-string", "mustache"] = "f-string",
     ) -> ChatPromptTemplate:
         """Create a chat prompt template from a variety of message formats.

@@ -964,7 +966,9 @@ def from_messages(
         Returns:
             a chat prompt template
         """
-        _messages = [_convert_to_message(message) for message in messages]
+        _messages = [
+            _convert_to_message(message, template_format) for message in messages
+        ]

         # Automatically infer input variables from messages
         input_vars: Set[str] = set()
@@ -1121,7 +1125,9 @@ def pretty_repr(self, html: bool = False) -> str:


 def _create_template_from_message_type(
-    message_type: str, template: Union[str, list]
+    message_type: str,
+    template: Union[str, list],
+    template_format: Literal["f-string", "mustache"] = "f-string",
 ) -> BaseMessagePromptTemplate:
     """Create a message prompt template from a message type and template string.

@@ -1134,12 +1140,16 @@ def _create_template_from_message_type(
"""
if message_type in ("human", "user"):
message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template(
template
template, template_format=template_format
)
elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(cast(str, template))
message = AIMessagePromptTemplate.from_template(
cast(str, template), template_format=template_format
)
elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(cast(str, template))
message = SystemMessagePromptTemplate.from_template(
cast(str, template), template_format=template_format
)
elif message_type == "placeholder":
if isinstance(template, str):
if template[0] != "{" or template[-1] != "}":
Expand Down Expand Up @@ -1180,6 +1190,7 @@ def _create_template_from_message_type(

 def _convert_to_message(
     message: MessageLikeRepresentation,
+    template_format: Literal["f-string", "mustache"] = "f-string",

Review comment (Collaborator): doc-string out of date

 ) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
     """Instantiate a message from a variety of message formats.

@@ -1204,16 +1215,22 @@ def _convert_to_message(
     elif isinstance(message, BaseMessage):
         _message = message
     elif isinstance(message, str):
-        _message = _create_template_from_message_type("human", message)
+        _message = _create_template_from_message_type(
+            "human", message, template_format=template_format
+        )
     elif isinstance(message, tuple):
         if len(message) != 2:
             raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
         message_type_str, template = message
         if isinstance(message_type_str, str):
-            _message = _create_template_from_message_type(message_type_str, template)
+            _message = _create_template_from_message_type(
+                message_type_str, template, template_format=template_format
+            )
         else:
             _message = message_type_str(
-                prompt=PromptTemplate.from_template(cast(str, template))
+                prompt=PromptTemplate.from_template(
+                    cast(str, template), template_format=template_format
+                )
             )
     else:
         raise NotImplementedError(f"Unsupported message type: {type(message)}")
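For orientation, here is a minimal usage sketch of the template_format parameter this file adds to ChatPromptTemplate.from_messages. The roles, template strings, and variable names are illustrative, not taken from the PR:

```python
from langchain_core.prompts import ChatPromptTemplate

# Mustache templates interpolate {{variable}} instead of f-string {variable}.
prompt = ChatPromptTemplate.from_messages(
    [
        ("system", "You are a helpful assistant named {{name}}."),
        ("human", "{{question}}"),
    ],
    template_format="mustache",
)

# Input variables are inferred from the mustache tokens in each message.
messages = prompt.format_messages(name="Parrot", question="What is mustache?")
```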
17 changes: 14 additions & 3 deletions libs/core/langchain_core/prompts/prompt.py
@@ -10,8 +10,10 @@
     StringPromptTemplate,
     check_valid_template,
     get_template_variables,
+    mustache_schema,
 )
-from langchain_core.pydantic_v1 import root_validator
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+from langchain_core.runnables.config import RunnableConfig


 class PromptTemplate(StringPromptTemplate):
@@ -65,12 +67,19 @@ def get_lc_namespace(cls) -> List[str]:
     template: str
     """The prompt template."""

-    template_format: Literal["f-string", "jinja2"] = "f-string"
-    """The format of the prompt template. Options are: 'f-string', 'jinja2'."""
+    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"
+    """The format of the prompt template.
+    Options are: 'f-string', 'mustache', 'jinja2'."""

Review comment (Collaborator): A link to mustache spec would be good here

     validate_template: bool = False
     """Whether or not to try validating the template."""

+    def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
+        if self.template_format != "mustache":
+            return super().get_input_schema(config)
+
+        return mustache_schema(self.template)
+
     def __add__(self, other: Any) -> PromptTemplate:
         """Override the + operator to allow for combining prompt templates."""
         # Allow for easy combining
@@ -121,6 +130,8 @@ def format(self, **kwargs: Any) -> str:
     def template_is_valid(cls, values: Dict) -> Dict:
         """Check that template and input variables are consistent."""
         if values["validate_template"]:
+            if values["template_format"] == "mustache":
+                raise ValueError("Mustache templates cannot be validated.")
             all_inputs = values["input_variables"] + list(values["partial_variables"])
             check_valid_template(
                 values["template"], values["template_format"], all_inputs
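To see what the get_input_schema override buys, a short sketch (template text and inputs are illustrative; schema() is the pydantic-v1-style method exposed through langchain_core.pydantic_v1):

```python
from langchain_core.prompts import PromptTemplate

prompt = PromptTemplate.from_template(
    "Hello {{user.name}}, you asked: {{question}}",
    template_format="mustache",
)

# For mustache templates the input schema is derived from the template itself,
# so a dotted key like user.name should surface as a nested model field.
schema = prompt.get_input_schema()
print(schema.schema()["properties"].keys())  # expect: "user", "question"

print(prompt.format(user={"name": "Ada"}, question="hi"))
```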
68 changes: 67 additions & 1 deletion libs/core/langchain_core/prompts/string.py
@@ -5,10 +5,12 @@
 import warnings
 from abc import ABC
 from string import Formatter
-from typing import Any, Callable, Dict, List, Set
+from typing import Any, Callable, Dict, List, Set, Tuple, Type

+import langchain_core.utils.mustache as mustache
 from langchain_core.prompt_values import PromptValue, StringPromptValue
 from langchain_core.prompts.base import BasePromptTemplate
+from langchain_core.pydantic_v1 import BaseModel, create_model
 from langchain_core.utils import get_colored_text
 from langchain_core.utils.formatting import formatter
 from langchain_core.utils.interactive_env import is_interactive_env
@@ -85,8 +87,70 @@ def _get_jinja2_variables_from_template(template: str) -> Set[str]:
     return variables


+def mustache_formatter(template: str, **kwargs: Any) -> str:
+    """Format a template using mustache."""
+    return mustache.render(template, kwargs)

Review comment (Collaborator): Could we convert all of these into private functions unless we need them used externally?

+def mustache_template_vars(
+    template: str,
+) -> Set[str]:
+    """Get the variables from a mustache template."""
+    vars: Set[str] = set()
+    in_section = False
+    for type, key in mustache.tokenize(template):
+        if type == "end":
+            in_section = False
+        elif in_section:
+            continue
+        elif type in ("variable", "section") and key != ".":
+            vars.add(key.split(".")[0])
+            if type == "section":
+                in_section = True
+    return vars


+Defs = Dict[str, "Defs"]


+def mustache_schema(
+    template: str,
+) -> Type[BaseModel]:
+    """Get the variables from a mustache template."""
+    fields = set()
+    prefix: Tuple[str, ...] = ()
+    for type, key in mustache.tokenize(template):
+        if key == ".":
+            continue
+        if type == "end":
+            prefix = prefix[: -key.count(".")]
+        elif type == "section":
+            prefix = prefix + tuple(key.split("."))
+        elif type == "variable":
+            fields.add(prefix + tuple(key.split(".")))
+    defs: Defs = {}  # None means leaf node
+    while fields:
+        field = fields.pop()
+        current = defs
+        for part in field[:-1]:
+            current = current.setdefault(part, {})
+        current[field[-1]] = {}
+    return _create_model_recursive("PromptInput", defs)


+def _create_model_recursive(name: str, defs: Defs) -> Type:
+    return create_model(  # type: ignore[call-overload]
+        name,
+        **{
+            k: (_create_model_recursive(k, v), None) if v else (str, None)
+            for k, v in defs.items()
+        },
+    )

Review comment (Collaborator): Fun will review this fully tomorrow :)


 DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
     "f-string": formatter.format,
+    "mustache": mustache_formatter,
     "jinja2": jinja2_formatter,
 }
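To make the tokenize-driven helpers above concrete, here is roughly how they behave on a small template. The import path assumes these names stay public as in this diff; the reviewer's suggestion to privatize them would change that:

```python
from langchain_core.prompts.string import mustache_schema, mustache_template_vars

template = "{{name}} {{#address}}{{city}}, {{country}}{{/address}}"

# Only top-level names are reported; keys inside a section are skipped.
print(mustache_template_vars(template))  # expect: {"name", "address"}

# Section and dotted keys become nested models, one level per path segment.
model = mustache_schema(template)
print(model.schema()["properties"].keys())  # expect: "name", "address"
```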

@@ -145,6 +209,8 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
         input_variables = {
             v for _, v, _, _ in Formatter().parse(template) if v is not None
         }
+    elif template_format == "mustache":
+        input_variables = mustache_template_vars(template)
     else:
         raise ValueError(f"Unsupported template format: {template_format}")
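Finally, a quick sketch of the dispatch this last hunk adds to get_template_variables (same caveat about the name remaining public):

```python
from langchain_core.prompts.string import get_template_variables

print(get_template_variables("Hello {name}", "f-string"))    # expect: ["name"]
print(get_template_variables("Hello {{name}}", "mustache"))  # expect: ["name"]
```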