Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: generate mermaid syntax and render visual graph #19599

Merged
merged 9 commits into from
Apr 1, 2024
99 changes: 99 additions & 0 deletions libs/core/langchain_core/runnables/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import inspect
from dataclasses import dataclass, field
from enum import Enum
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
List,
NamedTuple,
Expand Down Expand Up @@ -51,6 +53,46 @@ class Node(NamedTuple):
data: Union[Type[BaseModel], RunnableType]


class Branch(NamedTuple):
"""Branch in a graph."""

condition: Callable[..., str]
ends: Optional[dict[str, str]]


class CurveStyle(Enum):
"""Enum for different curve styles supported by Mermaid"""

BASIS = "basis"
BUMP_X = "bumpX"
BUMP_Y = "bumpY"
CARDINAL = "cardinal"
CATMULL_ROM = "catmullRom"
LINEAR = "linear"
MONOTONE_X = "monotoneX"
MONOTONE_Y = "monotoneY"
NATURAL = "natural"
STEP = "step"
STEP_AFTER = "stepAfter"
STEP_BEFORE = "stepBefore"


@dataclass
class NodeColors:
"""Schema for Hexadecimal color codes for different node types"""

start: str = "#ffdfba"
end: str = "#baffc9"
other: str = "#fad7de"


class MermaidDrawMethod(Enum):
"""Enum for different draw methods supported by Mermaid"""

PYPPETEER = "pyppeteer" # Uses Pyppeteer to render the graph
API = "api" # Uses Mermaid.INK API to render the graph


def node_data_str(node: Node) -> str:
from langchain_core.runnables.base import Runnable

Expand Down Expand Up @@ -112,6 +154,7 @@ class Graph:

nodes: Dict[str, Node] = field(default_factory=dict)
edges: List[Edge] = field(default_factory=list)
branches: Optional[Dict[str, List[Branch]]] = field(default_factory=dict)

def to_json(self) -> Dict[str, List[Dict[str, Any]]]:
"""Convert the graph to a JSON-serializable format."""
Expand Down Expand Up @@ -277,3 +320,59 @@ def draw_png(
edges=labels["edges"] if labels is not None else {},
),
).draw(self, output_file_path)

def draw_mermaid(
self,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(
start="#ffdfba", end="#baffc9", other="#fad7de"
),
wrap_label_n_words: int = 9,
) -> str:
from langchain_core.runnables.graph_mermaid import draw_mermaid

nodes = {node.id: node_data_str(node) for node in self.nodes.values()}

first_node = self.first_node()
first_label = node_data_str(first_node) if first_node is not None else None

last_node = self.last_node()
last_label = node_data_str(last_node) if last_node is not None else None

return draw_mermaid(
nodes=nodes,
edges=self.edges,
branches=self.branches,
first_node_label=first_label,
last_node_label=last_label,
curve_style=curve_style,
node_colors=node_colors,
wrap_label_n_words=wrap_label_n_words,
)

def draw_mermaid_png(
self,
curve_style: CurveStyle = CurveStyle.LINEAR,
node_colors: NodeColors = NodeColors(
start="#ffdfba", end="#baffc9", other="#fad7de"
),
wrap_label_n_words: int = 9,
output_file_path: str = "graph.png",
draw_method: MermaidDrawMethod = MermaidDrawMethod.API,
background_color: str = "white",
padding: int = 10,
) -> None:
from langchain_core.runnables.graph_mermaid import draw_mermaid_png

mermaid_syntax = self.draw_mermaid(
curve_style=curve_style,
node_colors=node_colors,
wrap_label_n_words=wrap_label_n_words,
)
draw_mermaid_png(
mermaid_syntax=mermaid_syntax,
output_file_path=output_file_path,
draw_method=draw_method,
background_color=background_color,
padding=padding,
)