Skip to content

Commit

Permalink
core[patch], langchain[patch], templates: move openai functions parse…
Browse files Browse the repository at this point in the history
  • Loading branch information
baskaryan committed Feb 26, 2024
1 parent 96bff0e commit 767523f
Show file tree
Hide file tree
Showing 15 changed files with 264 additions and 252 deletions.
220 changes: 220 additions & 0 deletions libs/core/langchain_core/output_parsers/openai_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,220 @@
import copy
import json
from typing import Any, Dict, List, Optional, Type, Union

import jsonpatch # type: ignore[import]

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import (
BaseCumulativeTransformOutputParser,
BaseGenerationOutputParser,
)
from langchain_core.output_parsers.json import parse_partial_json
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, root_validator


class OutputFunctionsParser(BaseGenerationOutputParser[Any]):
"""Parse an output that is one of sets of values."""

args_only: bool = True
"""Whether to only return the arguments to the function call."""

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
func_call = copy.deepcopy(message.additional_kwargs["function_call"])
except KeyError as exc:
raise OutputParserException(f"Could not parse function call: {exc}")

if self.args_only:
return func_call["arguments"]
return func_call


class JsonOutputFunctionsParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse an output as the Json object."""

strict: bool = False
"""Whether to allow non-JSON-compliant strings.
See: https://docs.python.org/3/library/json.html#encoders-and-decoders
Useful when the parsed output may include unicode characters or new lines.
"""

args_only: bool = True
"""Whether to only return the arguments to the function call."""

@property
def _type(self) -> str:
return "json_functions"

def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
if len(result) != 1:
raise OutputParserException(
f"Expected exactly one result, but got {len(result)}"
)
generation = result[0]
if not isinstance(generation, ChatGeneration):
raise OutputParserException(
"This output parser can only be used with a chat generation."
)
message = generation.message
try:
function_call = message.additional_kwargs["function_call"]
except KeyError as exc:
if partial:
return None
else:
raise OutputParserException(f"Could not parse function call: {exc}")
try:
if partial:
try:
if self.args_only:
return parse_partial_json(
function_call["arguments"], strict=self.strict
)
else:
return {
**function_call,
"arguments": parse_partial_json(
function_call["arguments"], strict=self.strict
),
}
except json.JSONDecodeError:
return None
else:
if self.args_only:
try:
return json.loads(
function_call["arguments"], strict=self.strict
)
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
else:
try:
return {
**function_call,
"arguments": json.loads(
function_call["arguments"], strict=self.strict
),
}
except (json.JSONDecodeError, TypeError) as exc:
raise OutputParserException(
f"Could not parse function call data: {exc}"
)
except KeyError:
return None

# This method would be called by the default implementation of `parse_result`
# but we're overriding that method so it's not needed.
def parse(self, text: str) -> Any:
raise NotImplementedError()


class JsonKeyOutputFunctionsParser(JsonOutputFunctionsParser):
"""Parse an output as the element of the Json object."""

key_name: str
"""The name of the key to return."""

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
res = super().parse_result(result, partial=partial)
if partial and res is None:
return None
return res.get(self.key_name) if partial else res[self.key_name]


class PydanticOutputFunctionsParser(OutputFunctionsParser):
"""Parse an output as a pydantic object.
This parser is used to parse the output of a ChatModel that uses
OpenAI function format to invoke functions.
The parser extracts the function call invocation and matches
them to the pydantic schema provided.
An exception will be raised if the function call does not match
the provided schema.
Example:
... code-block:: python
message = AIMessage(
content="This is a test message",
additional_kwargs={
"function_call": {
"name": "cookie",
"arguments": json.dumps({"name": "value", "age": 10}),
}
},
)
chat_generation = ChatGeneration(message=message)
class Cookie(BaseModel):
name: str
age: int
class Dog(BaseModel):
species: str
# Full output
parser = PydanticOutputFunctionsParser(
pydantic_schema={"cookie": Cookie, "dog": Dog}
)
result = parser.parse_result([chat_generation])
"""

pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
"""The pydantic schema to parse the output with.
If multiple schemas are provided, then the function name will be used to
determine which schema to use.
"""

@root_validator(pre=True)
def validate_schema(cls, values: Dict) -> Dict:
schema = values["pydantic_schema"]
if "args_only" not in values:
values["args_only"] = isinstance(schema, type) and issubclass(
schema, BaseModel
)
elif values["args_only"] and isinstance(schema, Dict):
raise ValueError(
"If multiple pydantic schemas are provided then args_only should be"
" False."
)
return values

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
_result = super().parse_result(result)
if self.args_only:
pydantic_args = self.pydantic_schema.parse_raw(_result) # type: ignore
else:
fn_name = _result["name"]
_args = _result["arguments"]
pydantic_args = self.pydantic_schema[fn_name].parse_raw(_args) # type: ignore # noqa: E501
return pydantic_args


class PydanticAttrOutputFunctionsParser(PydanticOutputFunctionsParser):
"""Parse an output as an attribute of a pydantic object."""

attr_name: str
"""The name of the attribute to return."""

def parse_result(self, result: List[Generation], *, partial: bool = False) -> Any:
result = super().parse_result(result)
return getattr(result, self.attr_name)
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
from typing import Any, Dict

import pytest

from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage
from langchain_core.outputs import ChatGeneration

from langchain.output_parsers.openai_functions import (
from langchain_core.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain.pydantic_v1 import BaseModel
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel


def test_json_output_function_parser() -> None:
Expand Down
6 changes: 3 additions & 3 deletions libs/langchain/langchain/chains/openai_functions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from langchain_core.output_parsers import (
BaseLLMOutputParser,
)
from langchain_core.output_parsers.openai_functions import (
PydanticAttrOutputFunctionsParser,
)
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import (
Expand All @@ -27,9 +30,6 @@
create_structured_output_runnable,
get_openai_output_parser,
)
from langchain.output_parsers.openai_functions import (
PydanticAttrOutputFunctionsParser,
)

__all__ = [
"get_openai_output_parser",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,12 @@

from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers.openai_functions import PydanticOutputFunctionsParser
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import get_llm_kwargs
from langchain.output_parsers.openai_functions import (
PydanticOutputFunctionsParser,
)


class FactWithEvidence(BaseModel):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Any, List, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers.openai_functions import (
JsonKeyOutputFunctionsParser,
PydanticAttrOutputFunctionsParser,
)
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel

Expand All @@ -11,10 +15,6 @@
_resolve_schema_references,
get_llm_kwargs,
)
from langchain.output_parsers.openai_functions import (
JsonKeyOutputFunctionsParser,
PydanticAttrOutputFunctionsParser,
)


def _get_extraction_function(entity_schema: dict) -> dict:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from langchain_community.utilities.openapi import OpenAPISpec
from langchain_core.callbacks import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain_core.prompts import BasePromptTemplate, ChatPromptTemplate
from langchain_core.utils.input import get_colored_text
from requests import Response

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.sequential import SequentialChain
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.tools import APIOperation

if TYPE_CHECKING:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,16 @@
from langchain_core.language_models import BaseLanguageModel
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import BaseLLMOutputParser
from langchain_core.output_parsers.openai_functions import (
OutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import get_llm_kwargs
from langchain.output_parsers.openai_functions import (
OutputFunctionsParser,
PydanticOutputFunctionsParser,
)


class AnswerWithSources(BaseModel):
Expand Down
8 changes: 4 additions & 4 deletions libs/langchain/langchain/chains/openai_functions/tagging.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import Any, Optional

from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain_core.prompts import ChatPromptTemplate

from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.chains.openai_functions.utils import _convert_schema, get_llm_kwargs
from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticOutputFunctionsParser,
)


def _get_tagging_function(schema: dict) -> dict:
Expand Down
10 changes: 5 additions & 5 deletions libs/langchain/langchain/chains/structured_output/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@
BaseOutputParser,
JsonOutputParser,
)
from langchain_core.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticAttrOutputFunctionsParser,
PydanticOutputFunctionsParser,
)
from langchain_core.prompts import BasePromptTemplate
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables import Runnable
Expand All @@ -19,11 +24,6 @@
PydanticOutputParser,
PydanticToolsParser,
)
from langchain.output_parsers.openai_functions import (
JsonOutputFunctionsParser,
PydanticAttrOutputFunctionsParser,
PydanticOutputFunctionsParser,
)


def create_openai_fn_runnable(
Expand Down

0 comments on commit 767523f

Please sign in to comment.