Skip to content

Commit

Permalink
Add JSON representation of runnable graph to serialized representation (
Browse files Browse the repository at this point in the history
langchain-ai#17745)

Sent to LangSmith

Thank you for contributing to LangChain!

Checklist:

- [ ] PR title: Please title your PR "package: description", where
"package" is whichever of langchain, community, core, experimental, etc.
is being modified. Use "docs: ..." for purely docs changes, "templates:
..." for template changes, "infra: ..." for CI changes.
  - Example: "community: add foobar LLM"
- [ ] PR message: **Delete this entire template message** and replace it
with the following bulleted list
    - **Description:** a description of the change
    - **Issue:** the issue # it fixes, if applicable
    - **Dependencies:** any dependencies required for this change
- **Twitter handle:** if your PR gets announced, and you'd like a
mention, we'll gladly shout you out!
- [ ] Pass lint and test: Run `make format`, `make lint` and `make test`
from the root of the package(s) you've modified to check that you're
passing lint and testing. See contribution guidelines for more
information on how to write/run tests, lint, etc:
https://python.langchain.com/docs/contributing/
- [ ] Add tests and docs: If you're adding a new integration, please
include
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

Additional guidelines:
- Make sure optional dependencies are imported within a function.
- Please do not add dependencies to pyproject.toml files (even optional
ones) unless they are required for unit tests.
- Most PRs should not touch more than one package.
- Changes should be backwards compatible.
- If you are adding something to community, do not re-import it in
langchain.

If no one reviews your PR within a few days, please @-mention one of
baskaryan, efriis, eyurtsev, hwchase17.
  • Loading branch information
nfcampos authored and al1p-R committed Feb 27, 2024
1 parent 32f1670 commit 1444dfe
Show file tree
Hide file tree
Showing 7 changed files with 51,499 additions and 1,213 deletions.
15 changes: 14 additions & 1 deletion libs/core/langchain_core/load/serializable.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
from abc import ABC
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union, cast
from typing import (
Any,
Dict,
List,
Literal,
Optional,
TypedDict,
Union,
cast,
)

from typing_extensions import NotRequired

from langchain_core.pydantic_v1 import BaseModel, PrivateAttr

Expand All @@ -9,6 +20,8 @@ class BaseSerialized(TypedDict):

lc: int
id: List[str]
name: NotRequired[str]
graph: NotRequired[Dict[str, Any]]


class SerializedConstructor(BaseSerialized):
Expand Down
16 changes: 15 additions & 1 deletion libs/core/langchain_core/runnables/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@

from langchain_core._api import beta_decorator
from langchain_core.load.dump import dumpd
from langchain_core.load.serializable import Serializable
from langchain_core.load.serializable import (
Serializable,
SerializedConstructor,
SerializedNotImplemented,
)
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables.config import (
RunnableConfig,
Expand Down Expand Up @@ -1630,6 +1634,16 @@ class RunnableSerializable(Serializable, Runnable[Input, Output]):
name: Optional[str] = None
"""The name of the runnable. Used for debugging and tracing."""

def to_json(self) -> Union[SerializedConstructor, SerializedNotImplemented]:
    """Serialize the runnable to JSON.

    Starts from the parent class's serialized form and then, on a
    best-effort basis, stores the runnable's name under ``"name"`` and the
    JSON form of its graph under ``"graph"``.

    Returns:
        The serialized representation. ``"name"`` and/or ``"graph"`` are
        absent when computing them raised an exception.
    """
    serialized = super().to_json()
    try:
        serialized["name"] = self.get_name()
        serialized["graph"] = self.get_graph().to_json()
    except Exception:
        # Deliberately best-effort: a failure while computing the extra
        # fields must not prevent the base serialization from being returned.
        pass
    return serialized

def configurable_fields(
self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]:
Expand Down
130 changes: 103 additions & 27 deletions libs/core/langchain_core/runnables/graph.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from __future__ import annotations

import inspect
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Dict, List, NamedTuple, Optional, Type, Union
from uuid import uuid4
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Type, Union
from uuid import UUID, uuid4

from langchain_core.pydantic_v1 import BaseModel
from langchain_core.runnables.graph_draw import draw
Expand All @@ -11,11 +12,20 @@
from langchain_core.runnables.base import Runnable as RunnableType


def is_uuid(value: str) -> bool:
    """Tell whether ``value`` can be parsed as a UUID.

    Args:
        value: Candidate string.

    Returns:
        True if ``UUID(value)`` succeeds, False if it raises ``ValueError``.
    """
    try:
        UUID(value)
    except ValueError:
        return False
    else:
        return True


class Edge(NamedTuple):
    """Edge in a graph."""

    # Id of the node the edge starts from.
    source: str
    # Id of the node the edge points to.
    target: str
    # Optional string attached to the edge; Graph.to_json leaves the "data"
    # key out of the edge's JSON entry when this is None.
    data: Optional[str] = None


class Node(NamedTuple):
Expand All @@ -25,22 +35,108 @@ class Node(NamedTuple):
data: Union[Type[BaseModel], RunnableType]


def node_data_str(node: Node) -> str:
    """Build a short display label for a node.

    * A node whose id is not a UUID is labelled with that id.
    * A runnable is labelled with ``str(runnable)``, truncated to 42
      characters plus ``"..."``. When that text starts with ``"<"``, does not
      start with an upper-case character, spans several lines, or cannot be
      computed, the class name is used instead.
    * Anything else is labelled with its ``__name__``.

    A leading ``"Runnable"`` is dropped from the final label.
    """
    from langchain_core.runnables.base import Runnable

    # Ids that are not UUIDs are used verbatim as labels.
    if not is_uuid(node.id):
        return node.id

    if isinstance(node.data, Runnable):
        try:
            label = str(node.data)
            unreadable = (
                label.startswith("<")
                or label[0] != label[0].upper()
                or len(label.splitlines()) > 1
            )
            if unreadable:
                label = node.data.__class__.__name__
            elif len(label) > 42:
                label = label[:42] + "..."
        except Exception:
            # str() of an arbitrary runnable (or indexing an empty string)
            # may fail; fall back to the class name.
            label = node.data.__class__.__name__
    else:
        label = node.data.__name__

    if label.startswith("Runnable"):
        return label[8:]
    return label


def node_data_json(node: Node) -> Dict[str, Union[str, Dict[str, Any]]]:
    """Describe a node's payload as a JSON-serializable dict.

    The result always has a ``"type"`` and a ``"data"`` key:

    * ``"runnable"`` - ``data`` holds the runnable's ``id`` and ``name``;
    * ``"schema"`` - ``data`` is the ``schema()`` of a ``BaseModel`` subclass;
    * ``"unknown"`` - ``data`` is the label returned by ``node_data_str``.
    """
    from langchain_core.load.serializable import to_json_not_implemented
    from langchain_core.runnables.base import Runnable, RunnableSerializable

    payload = node.data

    if isinstance(payload, RunnableSerializable):
        runnable_info = {
            "id": payload.lc_id(),
            "name": payload.get_name(),
        }
        return {"type": "runnable", "data": runnable_info}

    if isinstance(payload, Runnable):
        # Plain runnables have no lc_id(); borrow the id from the
        # "not implemented" serialized form.
        runnable_info = {
            "id": to_json_not_implemented(payload)["id"],
            "name": payload.get_name(),
        }
        return {"type": "runnable", "data": runnable_info}

    if inspect.isclass(payload) and issubclass(payload, BaseModel):
        return {"type": "schema", "data": payload.schema()}

    return {"type": "unknown", "data": node_data_str(node)}


@dataclass
class Graph:
"""Graph of nodes and edges."""

nodes: Dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list)

def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
    """Convert the graph to a JSON-serializable format.

    Returns:
        A dict with a ``"nodes"`` list and an ``"edges"`` list. Node ids that
        are UUIDs are replaced by the node's position in ``self.nodes``;
        other ids are kept as they are. An edge entry carries a ``"data"``
        key only when the edge's data is not None.
    """
    # Map every node id to the id used in the output.
    stable_node_ids = {}
    for position, node in enumerate(self.nodes.values()):
        stable_node_ids[node.id] = position if is_uuid(node.id) else node.id

    nodes_json = [
        {"id": stable_node_ids[node.id], **node_data_json(node)}
        for node in self.nodes.values()
    ]

    edges_json = []
    for edge in self.edges:
        entry = {
            "source": stable_node_ids[edge.source],
            "target": stable_node_ids[edge.target],
        }
        if edge.data is not None:
            entry["data"] = edge.data
        edges_json.append(entry)

    return {"nodes": nodes_json, "edges": edges_json}

def __bool__(self) -> bool:
    """Return True when the graph has at least one node, False otherwise."""
    return bool(self.nodes)

def next_id(self) -> str:
    """Return a new node id: the hex string of a freshly generated UUID4."""
    return uuid4().hex

def add_node(
    self, data: Union[Type[BaseModel], RunnableType], id: Optional[str] = None
) -> Node:
    """Add a node to the graph and return it.

    Args:
        data: The payload stored on the node.
        id: Optional explicit node id. When omitted (or falsy) a new id is
            taken from ``self.next_id()``.

    Returns:
        The newly created node, which is also stored in ``self.nodes``.

    Raises:
        ValueError: If ``id`` is given and a node with that id already exists.
    """
    if id is not None and id in self.nodes:
        raise ValueError(f"Node with id {id} already exists")
    node = Node(id=id or self.next_id(), data=data)
    self.nodes[node.id] = node
    return node
return node

Expand All @@ -53,13 +149,13 @@ def remove_node(self, node: Node) -> None:
if edge.source != node.id and edge.target != node.id
]

def add_edge(self, source: Node, target: Node, data: Optional[str] = None) -> Edge:
    """Add an edge to the graph and return it.

    Args:
        source: Node the edge starts from; must already be in the graph.
        target: Node the edge points to; must already be in the graph.
        data: Optional string stored on the edge.

    Returns:
        The newly created edge, which is also appended to ``self.edges``.

    Raises:
        ValueError: If ``source`` or ``target`` is not a node of this graph.
    """
    if source.id not in self.nodes:
        raise ValueError(f"Source node {source.id} not in graph")
    if target.id not in self.nodes:
        raise ValueError(f"Target node {target.id} not in graph")
    edge = Edge(source=source.id, target=target.id, data=data)
    self.edges.append(edge)
    return edge

Expand Down Expand Up @@ -117,28 +213,8 @@ def trim_last_node(self) -> None:
self.remove_node(last_node)

def draw_ascii(self) -> str:
    """Return the string produced by ``draw`` for this graph.

    Each node is passed to ``draw`` as ``id -> label`` (the label comes from
    ``node_data_str``) and each edge as a ``(source, target)`` pair of ids.
    """
    return draw(
        {node.id: node_data_str(node) for node in self.nodes.values()},
        [(edge.source, edge.target) for edge in self.edges],
    )

Expand Down

0 comments on commit 1444dfe

Please sign in to comment.