-
Notifications
You must be signed in to change notification settings - Fork 13.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add JSON representation of runnable graph to serialized representation #17745
Changes from all commits
0763f04
811d5ba
b9e3de2
1a167f4
526a307
4922147
33b0395
82f70f3
0434f2e
f157117
a32e61d
dadf63f
b871c58
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,9 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
from dataclasses import dataclass, field | ||
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union | ||
from uuid import uuid4 | ||
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Type, Union | ||
from uuid import UUID, uuid4 | ||
|
||
from langchain_core.pydantic_v1 import BaseModel | ||
from langchain_core.runnables.graph_draw import draw | ||
|
@@ -11,11 +12,20 @@ | |
from langchain_core.runnables.base import Runnable as RunnableType | ||
|
||
|
||
def is_uuid(value: str) -> bool: | ||
try: | ||
UUID(value) | ||
return True | ||
except ValueError: | ||
return False | ||
|
||
|
||
class Edge(NamedTuple): | ||
"""Edge in a graph.""" | ||
|
||
source: str | ||
target: str | ||
data: Optional[str] = None | ||
|
||
|
||
class Node(NamedTuple): | ||
|
@@ -25,22 +35,108 @@ class Node(NamedTuple): | |
data: Union[Type[BaseModel], RunnableType] | ||
|
||
|
||
def node_data_str(node: Node) -> str: | ||
from langchain_core.runnables.base import Runnable | ||
|
||
if not is_uuid(node.id): | ||
return node.id | ||
elif isinstance(node.data, Runnable): | ||
try: | ||
data = str(node.data) | ||
if ( | ||
data.startswith("<") | ||
or data[0] != data[0].upper() | ||
or len(data.splitlines()) > 1 | ||
): | ||
data = node.data.__class__.__name__ | ||
elif len(data) > 42: | ||
data = data[:42] + "..." | ||
except Exception: | ||
data = node.data.__class__.__name__ | ||
else: | ||
data = node.data.__name__ | ||
return data if not data.startswith("Runnable") else data[8:] | ||
|
||
|
||
def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]: | ||
from langchain_core.load.serializable import to_json_not_implemented | ||
from langchain_core.runnables.base import Runnable, RunnableSerializable | ||
|
||
if isinstance(node.data, RunnableSerializable): | ||
return { | ||
"type": "runnable", | ||
"data": { | ||
"id": node.data.lc_id(), | ||
"name": node.data.get_name(), | ||
}, | ||
} | ||
elif isinstance(node.data, Runnable): | ||
return { | ||
"type": "runnable", | ||
"data": { | ||
"id": to_json_not_implemented(node.data)["id"], | ||
"name": node.data.get_name(), | ||
}, | ||
} | ||
elif inspect.isclass(node.data) and issubclass(node.data, BaseModel): | ||
return { | ||
"type": "schema", | ||
"data": node.data.schema(), | ||
} | ||
else: | ||
return { | ||
"type": "unknown", | ||
"data": node_data_str(node), | ||
} | ||
|
||
|
||
@dataclass | ||
class Graph: | ||
"""Graph of nodes and edges.""" | ||
|
||
nodes: Dict[str, Node] = field(default_factory=dict) | ||
edges: List[Edge] = field(default_factory=list) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are edges in both directions allowed simultaneously? i.e., if there's node 1 and 2 can we get edge 1 -> 2 and edge 2 -> 1 ? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In theory possible yea |
||
|
||
def to_json(self) -> Dict[str, List[Dict[str, Any]]]: | ||
"""Convert the graph to a JSON-serializable format.""" | ||
stable_node_ids = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what does stable mean here There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. ah ic, so that two graphs with equivalent objects and edges would have same serialization? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. basically if you generate the same chain twice you should get same json graph |
||
node.id: i if is_uuid(node.id) else node.id | ||
for i, node in enumerate(self.nodes.values()) | ||
} | ||
|
||
return { | ||
"nodes": [ | ||
{"id": stable_node_ids[node.id], **node_data_json(node)} | ||
for node in self.nodes.values() | ||
], | ||
"edges": [ | ||
{ | ||
"source": stable_node_ids[edge.source], | ||
"target": stable_node_ids[edge.target], | ||
"data": edge.data, | ||
} | ||
if edge.data is not None | ||
else { | ||
"source": stable_node_ids[edge.source], | ||
"target": stable_node_ids[edge.target], | ||
} | ||
for edge in self.edges | ||
], | ||
} | ||
|
||
def __bool__(self) -> bool: | ||
return bool(self.nodes) | ||
|
||
def next_id(self) -> str: | ||
return uuid4().hex | ||
|
||
def add_node(self, data: Union[Type[BaseModel], RunnableType]) -> Node: | ||
def add_node( | ||
self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None | ||
) -> Node: | ||
"""Add a node to the graph and return it.""" | ||
node = Node(id=self.next_id(), data=data) | ||
if id is not None and id in self.nodes: | ||
raise ValueError(f"Node with id {id} already exists") | ||
node = Node(id=id or self.next_id(), data=data) | ||
self.nodes[node.id] = node | ||
return node | ||
|
||
|
@@ -53,13 +149,13 @@ def remove_node(self, node: Node) -> None: | |
if edge.source != node.id and edge.target != node.id | ||
] | ||
|
||
def add_edge(self, source: Node, target: Node) -> Edge: | ||
def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: I generally prefer |
||
"""Add an edge to the graph and return it.""" | ||
if source.id not in self.nodes: | ||
raise ValueError(f"Source node {source.id} not in graph") | ||
if target.id not in self.nodes: | ||
raise ValueError(f"Target node {target.id} not in graph") | ||
edge = Edge(source=source.id, target=target.id) | ||
edge = Edge(source=source.id, target=target.id, data=data) | ||
self.edges.append(edge) | ||
return edge | ||
|
||
|
@@ -117,28 +213,8 @@ def trim_last_node(self) -> None: | |
self.remove_node(last_node) | ||
|
||
def draw_ascii(self) -> str: | ||
from langchain_core.runnables.base import Runnable | ||
|
||
def node_data(node: Node) -> str: | ||
if isinstance(node.data, Runnable): | ||
try: | ||
data = str(node.data) | ||
if ( | ||
data.startswith("<") | ||
or data[0] != data[0].upper() | ||
or len(data.splitlines()) > 1 | ||
): | ||
data = node.data.__class__.__name__ | ||
elif len(data) > 42: | ||
data = data[:42] + "..." | ||
except Exception: | ||
data = node.data.__class__.__name__ | ||
else: | ||
data = node.data.__name__ | ||
return data if not data.startswith("Runnable") else data[8:] | ||
|
||
return draw( | ||
{node.id: node_data(node) for node in self.nodes.values()}, | ||
{node.id: node_data_str(node) for node in self.nodes.values()}, | ||
[(edge.source, edge.target) for edge in self.edges], | ||
) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unrelated but why the rename
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does it cause issues when you actually import in functions
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yea because I need to import the actual class inside a function