-
Notifications
You must be signed in to change notification settings - Fork 13.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.
Already on GitHub? Sign in to your account
core: mustache prompt templates #19980
Changes from all commits
2e65654
4558174
7a3a1c3
916dd61
4625505
3d6a80e
f060d75
37bb39f
29a3bb4
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,8 +10,10 @@ | |
StringPromptTemplate, | ||
check_valid_template, | ||
get_template_variables, | ||
mustache_schema, | ||
) | ||
from langchain_core.pydantic_v1 import root_validator | ||
from langchain_core.pydantic_v1 import BaseModel, root_validator | ||
from langchain_core.runnables.config import RunnableConfig | ||
|
||
|
||
class PromptTemplate(StringPromptTemplate): | ||
|
@@ -65,12 +67,19 @@ def get_lc_namespace(cls) -> List[str]: | |
template: str | ||
"""The prompt template.""" | ||
|
||
template_format: Literal["f-string", "jinja2"] = "f-string" | ||
"""The format of the prompt template. Options are: 'f-string', 'jinja2'.""" | ||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. A link to mustache spec would be good here |
||
"""The format of the prompt template. | ||
Options are: 'f-string', 'mustache', 'jinja2'.""" | ||
|
||
validate_template: bool = False | ||
"""Whether or not to try validating the template.""" | ||
|
||
def get_input_schema(self, config: RunnableConfig | None = None) -> type[BaseModel]: | ||
if self.template_format != "mustache": | ||
return super().get_input_schema(config) | ||
|
||
return mustache_schema(self.template) | ||
|
||
def __add__(self, other: Any) -> PromptTemplate: | ||
"""Override the + operator to allow for combining prompt templates.""" | ||
# Allow for easy combining | ||
|
@@ -121,6 +130,8 @@ def format(self, **kwargs: Any) -> str: | |
def template_is_valid(cls, values: Dict) -> Dict: | ||
"""Check that template and input variables are consistent.""" | ||
if values["validate_template"]: | ||
if values["template_format"] == "mustache": | ||
raise ValueError("Mustache templates cannot be validated.") | ||
all_inputs = values["input_variables"] + list(values["partial_variables"]) | ||
check_valid_template( | ||
values["template"], values["template_format"], all_inputs | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,10 +5,12 @@ | |
import warnings | ||
from abc import ABC | ||
from string import Formatter | ||
from typing import Any, Callable, Dict, List, Set | ||
from typing import Any, Callable, Dict, List, Set, Tuple, Type | ||
|
||
import langchain_core.utils.mustache as mustache | ||
from langchain_core.prompt_values import PromptValue, StringPromptValue | ||
from langchain_core.prompts.base import BasePromptTemplate | ||
from langchain_core.pydantic_v1 import BaseModel, create_model | ||
from langchain_core.utils import get_colored_text | ||
from langchain_core.utils.formatting import formatter | ||
from langchain_core.utils.interactive_env import is_interactive_env | ||
|
@@ -85,8 +87,70 @@ def _get_jinja2_variables_from_template(template: str) -> Set[str]: | |
return variables | ||
|
||
|
||
def mustache_formatter(template: str, **kwargs: Any) -> str: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we convert all of these into private functions unless we need them used externally? |
||
"""Format a template using mustache.""" | ||
return mustache.render(template, kwargs) | ||
|
||
|
||
def mustache_template_vars( | ||
template: str, | ||
) -> Set[str]: | ||
"""Get the variables from a mustache template.""" | ||
vars: Set[str] = set() | ||
in_section = False | ||
for type, key in mustache.tokenize(template): | ||
if type == "end": | ||
in_section = False | ||
elif in_section: | ||
continue | ||
elif type in ("variable", "section") and key != ".": | ||
vars.add(key.split(".")[0]) | ||
if type == "section": | ||
in_section = True | ||
return vars | ||
|
||
|
||
Defs = Dict[str, "Defs"] | ||
|
||
|
||
def mustache_schema( | ||
template: str, | ||
) -> Type[BaseModel]: | ||
"""Get the variables from a mustache template.""" | ||
fields = set() | ||
prefix: Tuple[str, ...] = () | ||
for type, key in mustache.tokenize(template): | ||
if key == ".": | ||
continue | ||
if type == "end": | ||
prefix = prefix[: -key.count(".")] | ||
elif type == "section": | ||
prefix = prefix + tuple(key.split(".")) | ||
elif type == "variable": | ||
fields.add(prefix + tuple(key.split("."))) | ||
defs: Defs = {} # None means leaf node | ||
while fields: | ||
field = fields.pop() | ||
current = defs | ||
for part in field[:-1]: | ||
current = current.setdefault(part, {}) | ||
current[field[-1]] = {} | ||
return _create_model_recursive("PromptInput", defs) | ||
|
||
|
||
def _create_model_recursive(name: str, defs: Defs) -> Type: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Fun will review this fully tomorrow :) |
||
return create_model( # type: ignore[call-overload] | ||
name, | ||
**{ | ||
k: (_create_model_recursive(k, v), None) if v else (str, None) | ||
for k, v in defs.items() | ||
}, | ||
) | ||
|
||
|
||
DEFAULT_FORMATTER_MAPPING: Dict[str, Callable] = { | ||
"f-string": formatter.format, | ||
"mustache": mustache_formatter, | ||
"jinja2": jinja2_formatter, | ||
} | ||
|
||
|
@@ -145,6 +209,8 @@ def get_template_variables(template: str, template_format: str) -> List[str]: | |
input_variables = { | ||
v for _, v, _, _ in Formatter().parse(template) if v is not None | ||
} | ||
elif template_format == "mustache": | ||
input_variables = mustache_template_vars(template) | ||
else: | ||
raise ValueError(f"Unsupported template format: {template_format}") | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doc-string out of date