Skip to content

core[patch]: Adds mermaid graph format #5978

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 57 additions & 22 deletions langchain-core/src/runnables/graph.ts
Original file line number Diff line number Diff line change
@@ -1,19 +1,13 @@
import { zodToJsonSchema } from "zod-to-json-schema";
import { v4 as uuidv4, validate as isUuid } from "uuid";
import type { RunnableInterface, RunnableIOSchema } from "./types.js";
import type {
RunnableInterface,
RunnableIOSchema,
Node,
Edge,
} from "./types.js";
import { isRunnableInterface } from "./utils.js";

interface Edge {
source: string;
target: string;
data?: string;
}

interface Node {
id: string;

data: RunnableIOSchema | RunnableInterface;
}
import { drawMermaid, drawMermaidPng } from "./graph_mermaid.js";

const MAX_DATA_DISPLAY_NAME_LENGTH = 42;

Expand All @@ -22,17 +16,12 @@ export function nodeDataStr(node: Node): string {
return node.id;
} else if (isRunnableInterface(node.data)) {
try {
let data = node.data.toString();
if (
data.startsWith("<") ||
data[0] !== data[0].toUpperCase() ||
data.split("\n").length > 1
) {
data = node.data.getName();
} else if (data.length > MAX_DATA_DISPLAY_NAME_LENGTH) {
let data = node.data.getName();
data = data.startsWith("Runnable") ? data.slice("Runnable".length) : data;
if (data.length > MAX_DATA_DISPLAY_NAME_LENGTH) {
data = `${data.substring(0, MAX_DATA_DISPLAY_NAME_LENGTH)}...`;
}
return data.startsWith("Runnable") ? data.slice("Runnable".length) : data;
return data;
} catch (error) {
return node.data.getName();
}
Expand Down Expand Up @@ -179,4 +168,50 @@ export class Graph {
}
}
}

drawMermaid(params?: {
withStyles?: boolean;
curveStyle?: string;
nodeColors?: Record<string, string>;
wrapLabelNWords?: number;
}): string {
const {
withStyles,
curveStyle,
nodeColors = { start: "#ffdfba", end: "#baffc9", other: "#fad7de" },
wrapLabelNWords,
} = params ?? {};
const nodes: Record<string, string> = {};
for (const node of Object.values(this.nodes)) {
nodes[node.id] = nodeDataStr(node);
}

const firstNode = this.firstNode();
const firstNodeLabel = firstNode ? nodeDataStr(firstNode) : undefined;

const lastNode = this.lastNode();
const lastNodeLabel = lastNode ? nodeDataStr(lastNode) : undefined;

return drawMermaid(nodes, this.edges, {
firstNodeLabel,
lastNodeLabel,
withStyles,
curveStyle,
nodeColors,
wrapLabelNWords,
});
}

async drawMermaidPng(params?: {
withStyles?: boolean;
curveStyle?: string;
nodeColors?: Record<string, string>;
wrapLabelNWords?: number;
backgroundColor?: string;
}): Promise<Blob> {
const mermaidSyntax = this.drawMermaid(params);
return drawMermaidPng(mermaidSyntax, {
backgroundColor: params?.backgroundColor,
});
}
}
177 changes: 177 additions & 0 deletions langchain-core/src/runnables/graph_mermaid.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,177 @@
import { Edge } from "./types.js";

function _escapeNodeLabel(nodeLabel: string): string {
// Escapes the node label for Mermaid syntax.
return nodeLabel.replace(/[^a-zA-Z-_0-9]/g, "_");
}

// Adjusts Mermaid edge to map conditional nodes to pure nodes.
function _adjustMermaidEdge(edge: Edge, nodes: Record<string, string>) {
const sourceNodeLabel = nodes[edge.source] ?? edge.source;
const targetNodeLabel = nodes[edge.target] ?? edge.target;
return [sourceNodeLabel, targetNodeLabel];
}

function _generateMermaidGraphStyles(
nodeColors: Record<string, string>
): string {
let styles = "";
for (const [className, color] of Object.entries(nodeColors)) {
styles += `\tclassDef ${className}class fill:${color};\n`;
}
return styles;
}

/**
* Draws a Mermaid graph using the provided graph data
*/
export function drawMermaid(
nodes: Record<string, string>,
edges: Edge[],
config?: {
firstNodeLabel?: string;
lastNodeLabel?: string;
curveStyle?: string;
withStyles?: boolean;
nodeColors?: Record<string, string>;
wrapLabelNWords?: number;
}
): string {
const {
firstNodeLabel,
lastNodeLabel,
nodeColors,
withStyles = true,
curveStyle = "linear",
wrapLabelNWords = 9,
} = config ?? {};
// Initialize Mermaid graph configuration
let mermaidGraph = withStyles
? `%%{init: {'flowchart': {'curve': '${curveStyle}'}}}%%\ngraph TD;\n`
: "graph TD;\n";
if (withStyles) {
// Node formatting templates
const defaultClassLabel = "default";
const formatDict: Record<string, string> = {
[defaultClassLabel]: "{0}([{1}]):::otherclass",
};
if (firstNodeLabel !== undefined) {
formatDict[firstNodeLabel] = "{0}[{0}]:::startclass";
}
if (lastNodeLabel !== undefined) {
formatDict[lastNodeLabel] = "{0}[{0}]:::endclass";
}

// Add nodes to the graph
for (const node of Object.values(nodes)) {
const nodeLabel = formatDict[node] ?? formatDict[defaultClassLabel];
const escapedNodeLabel = _escapeNodeLabel(node);
const nodeParts = node.split(":");
const nodeSplit = nodeParts[nodeParts.length - 1];
mermaidGraph += `\t${nodeLabel
.replace(/\{0\}/g, escapedNodeLabel)
.replace(/\{1\}/g, nodeSplit)};\n`;
}
}
let subgraph = "";
// Add edges to the graph
for (const edge of edges) {
const sourcePrefix = edge.source.includes(":")
? edge.source.split(":")[0]
: undefined;
const targetPrefix = edge.target.includes(":")
? edge.target.split(":")[0]
: undefined;
// Exit subgraph if source or target is not in the same subgraph
if (
subgraph !== "" &&
(subgraph !== sourcePrefix || subgraph !== targetPrefix)
) {
mermaidGraph += "\tend\n";
subgraph = "";
}
// Enter subgraph if source and target are in the same subgraph
if (
subgraph === "" &&
sourcePrefix !== undefined &&
sourcePrefix === targetPrefix
) {
mermaidGraph = `\tsubgraph ${sourcePrefix}\n`;
subgraph = sourcePrefix;
}
const [source, target] = _adjustMermaidEdge(edge, nodes);
let edgeLabel = "";
// Add BR every wrapLabelNWords words
if (edge.data !== undefined) {
let edgeData = edge.data;
const words = edgeData.split(" ");
// Group words into chunks of wrapLabelNWords size
if (words.length > wrapLabelNWords) {
edgeData = words
.reduce((acc: string[], word: string, i: number) => {
if (i % wrapLabelNWords === 0) acc.push("");
acc[acc.length - 1] += ` ${word}`;
return acc;
}, [])
.join("<br>");
if (edge.conditional) {
edgeLabel = ` -. ${edgeData} .-> `;
} else {
edgeLabel = ` -- ${edgeData} --> `;
}
}
} else {
if (edge.conditional) {
edgeLabel = ` -.-> `;
} else {
edgeLabel = ` --> `;
}
}
mermaidGraph += `\t${_escapeNodeLabel(
source
)}${edgeLabel}${_escapeNodeLabel(target)};\n`;
}
if (subgraph !== "") {
mermaidGraph += "end\n";
}

// Add custom styles for nodes
if (withStyles && nodeColors !== undefined) {
mermaidGraph += _generateMermaidGraphStyles(nodeColors);
}
return mermaidGraph;
}

/**
* Renders Mermaid graph using the Mermaid.INK API.
*/
export async function drawMermaidPng(
mermaidSyntax: string,
config?: {
backgroundColor?: string;
}
) {
let { backgroundColor = "white" } = config ?? {};
// Use btoa for compatibility, assume ASCII
const mermaidSyntaxEncoded = btoa(mermaidSyntax);
// Check if the background color is a hexadecimal color code using regex
if (backgroundColor !== undefined) {
const hexColorPattern = /^#(?:[0-9a-fA-F]{3}){1,2}$/;
if (!hexColorPattern.test(backgroundColor)) {
backgroundColor = `!${backgroundColor}`;
}
}
const imageUrl = `https://mermaid.ink/img/${mermaidSyntaxEncoded}?bgColor=${backgroundColor}`;
const res = await fetch(imageUrl);
if (!res.ok) {
throw new Error(
[
`Failed to render the graph using the Mermaid.INK API.`,
`Status code: ${res.status}`,
`Status text: ${res.statusText}`,
].join("\n")
);
}
const content = await res.blob();
return content;
}
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
16 changes: 16 additions & 0 deletions langchain-core/src/runnables/tests/runnable_graph.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,4 +87,20 @@ test("Test graph sequence", async () => {
{ source: 2, target: 3 },
],
});
expect(graph.drawMermaid())
.toEqual(`%%{init: {'flowchart': {'curve': 'linear'}}}%%
graph TD;
\tPromptTemplateInput[PromptTemplateInput]:::startclass;
\tPromptTemplate([PromptTemplate]):::otherclass;
\tFakeLLM([FakeLLM]):::otherclass;
\tCommaSeparatedListOutputParser([CommaSeparatedListOutputParser]):::otherclass;
\tCommaSeparatedListOutputParserOutput[CommaSeparatedListOutputParserOutput]:::endclass;
\tPromptTemplateInput --> PromptTemplate;
\tPromptTemplate --> FakeLLM;
\tCommaSeparatedListOutputParser --> CommaSeparatedListOutputParserOutput;
\tFakeLLM --> CommaSeparatedListOutputParser;
\tclassDef startclass fill:#ffdfba;
\tclassDef endclass fill:#baffc9;
\tclassDef otherclass fill:#fad7de;
`);
});
12 changes: 12 additions & 0 deletions langchain-core/src/runnables/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,15 @@ export interface RunnableInterface<

getName(suffix?: string): string;
}

export interface Edge {
source: string;
target: string;
data?: string;
conditional?: boolean;
}

export interface Node {
id: string;
data: RunnableIOSchema | RunnableInterface;
}
Loading