Skip to content

Commit

Permalink
core: mustache prompt templates (#19980)
Browse files Browse the repository at this point in the history
Co-authored-by: Erick Friis <erick@langchain.dev>
  • Loading branch information
2 people authored and hinthornw committed Apr 26, 2024
1 parent 5133b56 commit 71e7711
Show file tree
Hide file tree
Showing 6 changed files with 904 additions and 12 deletions.
33 changes: 25 additions & 8 deletions libs/core/langchain_core/prompts/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Any,
Dict,
List,
Literal,
Optional,
Sequence,
Set,
Expand Down Expand Up @@ -929,6 +930,7 @@ def from_strings(
def from_messages(
cls,
messages: Sequence[MessageLikeRepresentation],
template_format: Literal["f-string", "mustache"] = "f-string",
) -> ChatPromptTemplate:
"""Create a chat prompt template from a variety of message formats.
Expand Down Expand Up @@ -964,7 +966,9 @@ def from_messages(
Returns:
a chat prompt template
"""
_messages = [_convert_to_message(message) for message in messages]
_messages = [
_convert_to_message(message, template_format) for message in messages
]

# Automatically infer input variables from messages
input_vars: Set[str] = set()
Expand Down Expand Up @@ -1121,7 +1125,9 @@ def pretty_repr(self, html: bool = False) -> str:


def _create_template_from_message_type(
message_type: str, template: Union[str, list]
message_type: str,
template: Union[str, list],
template_format: Literal["f-string", "mustache"] = "f-string",
) -> BaseMessagePromptTemplate:
"""Create a message prompt template from a message type and template string.
Expand All @@ -1134,12 +1140,16 @@ def _create_template_from_message_type(
"""
if message_type in ("human", "user"):
message: BaseMessagePromptTemplate = HumanMessagePromptTemplate.from_template(
template
template, template_format=template_format
)
elif message_type in ("ai", "assistant"):
message = AIMessagePromptTemplate.from_template(cast(str, template))
message = AIMessagePromptTemplate.from_template(
cast(str, template), template_format=template_format
)
elif message_type == "system":
message = SystemMessagePromptTemplate.from_template(cast(str, template))
message = SystemMessagePromptTemplate.from_template(
cast(str, template), template_format=template_format
)
elif message_type == "placeholder":
if isinstance(template, str):
if template[0] != "{" or template[-1] != "}":
Expand Down Expand Up @@ -1180,6 +1190,7 @@ def _create_template_from_message_type(

def _convert_to_message(
message: MessageLikeRepresentation,
template_format: Literal["f-string", "mustache"] = "f-string",
) -> Union[BaseMessage, BaseMessagePromptTemplate, BaseChatPromptTemplate]:
"""Instantiate a message from a variety of message formats.
Expand All @@ -1204,16 +1215,22 @@ def _convert_to_message(
elif isinstance(message, BaseMessage):
_message = message
elif isinstance(message, str):
_message = _create_template_from_message_type("human", message)
_message = _create_template_from_message_type(
"human", message, template_format=template_format
)
elif isinstance(message, tuple):
if len(message) != 2:
raise ValueError(f"Expected 2-tuple of (role, template), got {message}")
message_type_str, template = message
if isinstance(message_type_str, str):
_message = _create_template_from_message_type(message_type_str, template)
_message = _create_template_from_message_type(
message_type_str, template, template_format=template_format
)
else:
_message = message_type_str(
prompt=PromptTemplate.from_template(cast(str, template))
prompt=PromptTemplate.from_template(
cast(str, template), template_format=template_format
)
)
else:
raise NotImplementedError(f"Unsupported message type: {type(message)}")
Expand Down
17 changes: 14 additions & 3 deletions libs/core/langchain_core/prompts/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
StringPromptTemplate,
check_valid_template,
get_template_variables,
mustache_schema,
)
from langchain_core.pydantic_v1 import root_validator
from langchain_core.pydantic_v1 import BaseModel, root_validator
from langchain_core.runnables.config import RunnableConfig


class PromptTemplate(StringPromptTemplate):
Expand Down Expand Up @@ -65,12 +67,19 @@ def get_lc_namespace(cls) -> List[str]:
template: str
"""The prompt template."""

template_format: Literal["f-string", "jinja2"] = "f-string"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string"
"""The format of the prompt template.
Options are: 'f-string', 'mustache', 'jinja2'."""

validate_template: bool = False
"""Whether or not to try validating the template."""

def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]:
if self.template_format != "mustache":
return super().get_input_schema(config)

return mustache_schema(self.template)

def __add__(self, other: Any) -> PromptTemplate:
"""Override the + operator to allow for combining prompt templates."""
# Allow for easy combining
Expand Down Expand Up @@ -121,6 +130,8 @@ def format(self, **kwargs: Any) -> str:
def template_is_valid(cls, values: Dict) -> Dict:
"""Check that template and input variables are consistent."""
if values["validate_template"]:
if values["template_format"] == "mustache":
raise ValueError("Mustache templates cannot be validated.")
all_inputs = values["input_variables"] + list(values["partial_variables"])
check_valid_template(
values["template"], values["template_format"], all_inputs
Expand Down
68 changes: 67 additions & 1 deletion libs/core/langchain_core/prompts/string.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Set
from typing import Any, Callable, Dict, List, Set, Tuple, Type

import langchain_core.utils.mustache as mustache
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.utils import get_colored_text
from langchain_core.utils.formatting import formatter
from langchain_core.utils.interactive_env import is_interactive_env
Expand Down Expand Up @@ -85,8 +87,70 @@ def _get_jinja2_variables_from_template(template: str) -> Set[str]:
return variables


def mustache_formatter(template: str, **kwargs: Any) -> str:
"""Format a template using mustache."""
return mustache.render(template, kwargs)


def mustache_template_vars(
template: str,
) -> Set[str]:
"""Get the variables from a mustache template."""
vars: Set[str] = set()
in_section = False
for type, key in mustache.tokenize(template):
if type == "end":
in_section = False
elif in_section:
continue
elif type in ("variable", "section") and key != ".":
vars.add(key.split(".")[0])
if type == "section":
in_section = True
return vars


Defs = Dict[str, "Defs"]


def mustache_schema(
template: str,
) -> Type[BaseModel]:
"""Get the variables from a mustache template."""
fields = set()
prefix: Tuple[str, ...] = ()
for type, key in mustache.tokenize(template):
if key == ".":
continue
if type == "end":
prefix = prefix[: -key.count(".")]
elif type == "section":
prefix = prefix + tuple(key.split("."))
elif type == "variable":
fields.add(prefix + tuple(key.split(".")))
defs: Defs = {} # None means leaf node
while fields:
field = fields.pop()
current = defs
for part in field[:-1]:
current = current.setdefault(part, {})
current[field[-1]] = {}
return _create_model_recursive("PromptInput", defs)


def _create_model_recursive(name: str, defs: Defs) -> Type:
return create_model( # type: ignore[call-overload]
name,
**{
k: (_create_model_recursive(k, v), None) if v else (str, None)
for k, v in defs.items()
},
)


DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = {
"f-string": formatter.format,
"mustache": mustache_formatter,
"jinja2": jinja2_formatter,
}

Expand Down Expand Up @@ -145,6 +209,8 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
input_variables = {
v for _, v, _, _ in Formatter().parse(template) if v is not None
}
elif template_format == "mustache":
input_variables = mustache_template_vars(template)
else:
raise ValueError(f"Unsupported template format: {template_format}")

Expand Down

0 comments on commit 71e7711

Please sign in to comment.