Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add JSON representation of runnable graph to serialized representation #17745

Merged
merged 13 commits into from
Feb 20, 2024
15 changes: 14 additions & 1 deletion libs/core/langchain_core/load/serializable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
from abc import ABC
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
from typing import (
Any,
Dict,
List,
Literal,
Optional,
TypedDict,
Union,
cast,
)

from typing_extensions import NotRequired

from langchain_core.pydantic_v1 import BaseModel, PrivateAttr

Expand All @@ -9,6 +20,8 @@ class BaseSerialized(TypedDict):

lc: int
id: List[str]
name: NotRequired[str]
graph: NotRequired[Dict[str, Any]]


class SerializedConstructor(BaseSerialized):
Expand Down
16 changes: 15 additions & 1 deletion libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@

from langchain_core._api import beta_decorator
from langchain_core.load.dump import dumpd
from langchain_core.load.serializable import Serializable
from langchain_core.load.serializable import (
Serializable,
SerializedConstructor,
SerializedNotImplemented,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import (
RunnableConfig,
Expand Down Expand Up @@ -1630,6 +1634,16 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
name: Optional[str] = None
"""The name of the runnable. Used for debugging and tracing."""

def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
"""Serialize the runnable to JSON."""
dumped = super().to_json()
try:
dumped["name"] = self.get_name()
dumped["graph"] = self.get_graph().to_json()
except Exception:
pass
return dumped

def configurable_fields(
self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]:
Expand Down
130 changes: 103 additions & 27 deletions libs/core/langchain_core/runnables/graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import inspect
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
from uuid import uuid4
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Type, Union
from uuid import UUID, uuid4

from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.graph_draw import draw
Expand All @@ -11,11 +12,20 @@
from langchain_core.runnables.base import Runnable as RunnableType
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unrelated but why the rename

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it cause issues when you actually import in functions

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yea because I need to import the actual class inside a function



def is_uuid(value: str) -> bool:
try:
UUID(value)
return True
except ValueError:
return False


class Edge(NamedTuple):
"""Edge in a graph."""

source: str
target: str
data: Optional[str] = None


class Node(NamedTuple):
Expand All @@ -25,22 +35,108 @@ class Node(NamedTuple):
data: Union[Type[BaseModel], RunnableType]


def node_data_str(node: Node) -> str:
from langchain_core.runnables.base import Runnable

if not is_uuid(node.id):
return node.id
elif isinstance(node.data, Runnable):
try:
data = str(node.data)
if (
data.startswith("<")
or data[0] != data[0].upper()
or len(data.splitlines()) > 1
):
data = node.data.__class__.__name__
elif len(data) > 42:
data = data[:42] + "..."
except Exception:
data = node.data.__class__.__name__
else:
data = node.data.__name__
return data if not data.startswith("Runnable") else data[8:]


def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
from langchain_core.load.serializable import to_json_not_implemented
from langchain_core.runnables.base import Runnable, RunnableSerializable

if isinstance(node.data, RunnableSerializable):
return {
"type": "runnable",
"data": {
"id": node.data.lc_id(),
"name": node.data.get_name(),
},
}
elif isinstance(node.data, Runnable):
return {
"type": "runnable",
"data": {
"id": to_json_not_implemented(node.data)["id"],
"name": node.data.get_name(),
},
}
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel):
return {
"type": "schema",
"data": node.data.schema(),
}
else:
return {
"type": "unknown",
"data": node_data_str(node),
}


@dataclass
class Graph:
"""Graph of nodes and edges."""

nodes: Dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list)
Copy link
Collaborator

@eyurtsev eyurtsev Feb 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are edges in both directions allowed simultaneously?

i.e., if there's node 1 and 2

can we get edge 1 -> 2 and edge 2 -> 1 ?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In theory possible yea


def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
"""Convert the graph to a JSON-serializable format."""
stable_node_ids = {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does stable mean here

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah ic, so that two graphs with equivalent objects and edges would have same serialization?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

basically if you generate the same chain twice you should get same json graph

node.id: i if is_uuid(node.id) else node.id
for i, node in enumerate(self.nodes.values())
}

return {
"nodes": [
{"id": stable_node_ids[node.id], **node_data_json(node)}
for node in self.nodes.values()
],
"edges": [
{
"source": stable_node_ids[edge.source],
"target": stable_node_ids[edge.target],
"data": edge.data,
}
if edge.data is not None
else {
"source": stable_node_ids[edge.source],
"target": stable_node_ids[edge.target],
}
for edge in self.edges
],
}

def __bool__(self) -> bool:
return bool(self.nodes)

def next_id(self) -> str:
return uuid4().hex

def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node:
def add_node(
self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
) -> Node:
"""Add a node to the graph and return it."""
node = Node(id=self.next_id(), data=data)
if id is not None and id in self.nodes:
raise ValueError(f"Node with id {id} already exists")
node = Node(id=id or self.next_id(), data=data)
self.nodes[node.id] = node
return node

Expand All @@ -53,13 +149,13 @@ def remove_node(self, node: Node) -> None:
if edge.source != node.id and edge.target != node.id
]

def add_edge(self, source: Node, target: Node) -> Edge:
def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I generally prefer * prior to all arguments with None defaults.

"""Add an edge to the graph and return it."""
if source.id not in self.nodes:
raise ValueError(f"Source node {source.id} not in graph")
if target.id not in self.nodes:
raise ValueError(f"Target node {target.id} not in graph")
edge = Edge(source=source.id, target=target.id)
edge = Edge(source=source.id, target=target.id, data=data)
self.edges.append(edge)
return edge

Expand Down Expand Up @@ -117,28 +213,8 @@ def trim_last_node(self) -> None:
self.remove_node(last_node)

def draw_ascii(self) -> str:
from langchain_core.runnables.base import Runnable

def node_data(node: Node) -> str:
if isinstance(node.data, Runnable):
try:
data = str(node.data)
if (
data.startswith("<")
or data[0] != data[0].upper()
or len(data.splitlines()) > 1
):
data = node.data.__class__.__name__
elif len(data) > 42:
data = data[:42] + "..."
except Exception:
data = node.data.__class__.__name__
else:
data = node.data.__name__
return data if not data.startswith("Runnable") else data[8:]

return draw(
{node.id: node_data(node) for node in self.nodes.values()},
{node.id: node_data_str(node) for node in self.nodes.values()},
[(edge.source, edge.target) for edge in self.edges],
)

Expand Down