-
Notifications
You must be signed in to change notification settings - Fork 13.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
core[patch], langchain[patch], templates: move openai functions parse…
…rs to core (#18060) ![Screenshot 2024-02-23 at 7 48 03 PM](https://github.com/langchain-ai/langchain/assets/22008038/e5540c4d-0020-4ece-869f-ae19db2a1f3f)
- Loading branch information
Showing
15 changed files
with
264 additions
and
252 deletions.
There are no files selected for viewing
220 changes: 220 additions & 0 deletions
220
libs/core/langchain_core/output_parsers/openai_functions.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,220 @@ | ||
import copy | ||
import json | ||
from typing import Any, Dict, List, Optional, Type, Union | ||
|
||
import jsonpatch # type: ignore[import] | ||
|
||
from langchain_core.exceptions import OutputParserException | ||
from langchain_core.output_parsers import ( | ||
BaseCumulativeTransformOutputParser, | ||
BaseGenerationOutputParser, | ||
) | ||
from langchain_core.output_parsers.json import parse_partial_json | ||
from langchain_core.outputs import ChatGeneration, Generation | ||
from langchain_core.pydantic_v1 import BaseModel, root_validator | ||
|
||
|
||
class OutputFunctionsParser(BaseGenerationOutputParser[Any]): | ||
"""Parse an output that is one of sets of values.""" | ||
|
||
args_only: bool = True | ||
"""Whether to only return the arguments to the function call.""" | ||
|
||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: | ||
generation = result[0] | ||
if not isinstance(generation, ChatGeneration): | ||
raise OutputParserException( | ||
"This output parser can only be used with a chat generation." | ||
) | ||
message = generation.message | ||
try: | ||
func_call = copy.deepcopy(message.additional_kwargs["function_call"]) | ||
except KeyError as exc: | ||
raise OutputParserException(f"Could not parse function call: {exc}") | ||
|
||
if self.args_only: | ||
return func_call["arguments"] | ||
return func_call | ||
|
||
|
||
class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]): | ||
"""Parse an output as the Json object.""" | ||
|
||
strict: bool = False | ||
"""Whether to allow non-JSON-compliant strings. | ||
See: https://docs.python.org/3/library/json.html#encoders-and-decoders | ||
Useful when the parsed output may include unicode characters or new lines. | ||
""" | ||
|
||
args_only: bool = True | ||
"""Whether to only return the arguments to the function call.""" | ||
|
||
@property | ||
def _type(self) -> str: | ||
return "json_functions" | ||
|
||
def _diff(self, prev: Optional[Any], next: Any) -> Any: | ||
return jsonpatch.make_patch(prev, next).patch | ||
|
||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: | ||
if len(result) != 1: | ||
raise OutputParserException( | ||
f"Expected exactly one result, but got {len(result)}" | ||
) | ||
generation = result[0] | ||
if not isinstance(generation, ChatGeneration): | ||
raise OutputParserException( | ||
"This output parser can only be used with a chat generation." | ||
) | ||
message = generation.message | ||
try: | ||
function_call = message.additional_kwargs["function_call"] | ||
except KeyError as exc: | ||
if partial: | ||
return None | ||
else: | ||
raise OutputParserException(f"Could not parse function call: {exc}") | ||
try: | ||
if partial: | ||
try: | ||
if self.args_only: | ||
return parse_partial_json( | ||
function_call["arguments"], strict=self.strict | ||
) | ||
else: | ||
return { | ||
**function_call, | ||
"arguments": parse_partial_json( | ||
function_call["arguments"], strict=self.strict | ||
), | ||
} | ||
except json.JSONDecodeError: | ||
return None | ||
else: | ||
if self.args_only: | ||
try: | ||
return json.loads( | ||
function_call["arguments"], strict=self.strict | ||
) | ||
except (json.JSONDecodeError, TypeError) as exc: | ||
raise OutputParserException( | ||
f"Could not parse function call data: {exc}" | ||
) | ||
else: | ||
try: | ||
return { | ||
**function_call, | ||
"arguments": json.loads( | ||
function_call["arguments"], strict=self.strict | ||
), | ||
} | ||
except (json.JSONDecodeError, TypeError) as exc: | ||
raise OutputParserException( | ||
f"Could not parse function call data: {exc}" | ||
) | ||
except KeyError: | ||
return None | ||
|
||
# This method would be called by the default implementation of `parse_result` | ||
# but we're overriding that method so it's not needed. | ||
def parse(self, text: str) -> Any: | ||
raise NotImplementedError() | ||
|
||
|
||
class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser): | ||
"""Parse an output as the element of the Json object.""" | ||
|
||
key_name: str | ||
"""The name of the key to return.""" | ||
|
||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: | ||
res = super().parse_result(result, partial=partial) | ||
if partial and res is None: | ||
return None | ||
return res.get(self.key_name) if partial else res[self.key_name] | ||
|
||
|
||
class PydanticOutputFunctionsParser(OutputFunctionsParser): | ||
"""Parse an output as a pydantic object. | ||
This parser is used to parse the output of a ChatModel that uses | ||
OpenAI function format to invoke functions. | ||
The parser extracts the function call invocation and matches | ||
them to the pydantic schema provided. | ||
An exception will be raised if the function call does not match | ||
the provided schema. | ||
Example: | ||
... code-block:: python | ||
message = AIMessage( | ||
content="This is a test message", | ||
additional_kwargs={ | ||
"function_call": { | ||
"name": "cookie", | ||
"arguments": json.dumps({"name": "value", "age": 10}), | ||
} | ||
}, | ||
) | ||
chat_generation = ChatGeneration(message=message) | ||
class Cookie(BaseModel): | ||
name: str | ||
age: int | ||
class Dog(BaseModel): | ||
species: str | ||
# Full output | ||
parser = PydanticOutputFunctionsParser( | ||
pydantic_schema={"cookie": Cookie, "dog": Dog} | ||
) | ||
result = parser.parse_result([chat_generation]) | ||
""" | ||
|
||
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]] | ||
"""The pydantic schema to parse the output with. | ||
If multiple schemas are provided, then the function name will be used to | ||
determine which schema to use. | ||
""" | ||
|
||
@root_validator(pre=True) | ||
def validate_schema(cls, values: Dict) -> Dict: | ||
schema = values["pydantic_schema"] | ||
if "args_only" not in values: | ||
values["args_only"] = isinstance(schema, type) and issubclass( | ||
schema, BaseModel | ||
) | ||
elif values["args_only"] and isinstance(schema, Dict): | ||
raise ValueError( | ||
"If multiple pydantic schemas are provided then args_only should be" | ||
" False." | ||
) | ||
return values | ||
|
||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: | ||
_result = super().parse_result(result) | ||
if self.args_only: | ||
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore | ||
else: | ||
fn_name = _result["name"] | ||
_args = _result["arguments"] | ||
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore # noqa: E501 | ||
return pydantic_args | ||
|
||
|
||
class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser): | ||
"""Parse an output as an attribute of a pydantic object.""" | ||
|
||
attr_name: str | ||
"""The name of the attribute to return.""" | ||
|
||
def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any: | ||
result = super().parse_result(result) | ||
return getattr(result, self.attr_name) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.