Skip to content

Commit 73ee114

Browse files
cohere[minor]: Add support for tool calling cohere (#5810)
* feat: Add support for tool calling cohere * fix: lint errors + minor fixes * fix: apply format and lint, revert some unexpected changes to json files * Update tsconfig.json * Remove added deps from package json file * nit: Fix typo * address PR comments * make methods private * update yarn lock * nit * Lint, support OpenAI tool schemas * Fix and enable standard tool calling tests * Bump lower-bound dep --------- Co-authored-by: jacoblee93 <[email protected]>
1 parent bac6138 commit 73ee114

16 files changed

+620
-94
lines changed

docs/core_docs/docs/integrations/chat/cohere.mdx

+8
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,14 @@ import StatefulChatExample from "@examples/models/chat/cohere/stateful_conversat
6262
You can see the LangSmith traces from this example [here](https://smith.langchain.com/public/8e67b05a-4e63-414e-ac91-a91acf21b262/r) and [here](https://smith.langchain.com/public/50fabc25-46fe-4727-a59c-7e4eb0de8e70/r)
6363
:::
6464

65+
### Tools
66+
67+
The Cohere API supports tool calling, along with multi-hop-tool calling. The following example demonstrates how to call tools:
68+
69+
import ToolCallingExample from "@examples/models/chat/cohere/tool_calling.ts";
70+
71+
<CodeBlock language="typescript">{ToolCallingExample}</CodeBlock>
72+
6573
### RAG
6674

6775
Cohere also comes out of the box with RAG support.

examples/src/models/chat/cohere/chat_cohere.ts

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ import { ChatPromptTemplate } from "@langchain/core/prompts";
33

44
const model = new ChatCohere({
55
apiKey: process.env.COHERE_API_KEY, // Default
6-
model: "command", // Default
76
});
87
const prompt = ChatPromptTemplate.fromMessages([
98
["ai", "You are a helpful assistant"],

examples/src/models/chat/cohere/chat_stream_cohere.ts

-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import { StringOutputParser } from "@langchain/core/output_parsers";
44

55
const model = new ChatCohere({
66
apiKey: process.env.COHERE_API_KEY, // Default
7-
model: "command", // Default
87
});
98
const prompt = ChatPromptTemplate.fromMessages([
109
["ai", "You are a helpful assistant"],

examples/src/models/chat/cohere/connectors.ts

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages";
33

44
const model = new ChatCohere({
55
apiKey: process.env.COHERE_API_KEY, // Default
6-
model: "command", // Default
76
});
87

98
const response = await model.invoke(

examples/src/models/chat/cohere/rag.ts

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages";
33

44
const model = new ChatCohere({
55
apiKey: process.env.COHERE_API_KEY, // Default
6-
model: "command", // Default
76
});
87

98
const documents = [

examples/src/models/chat/cohere/stateful_conversation.ts

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ import { HumanMessage } from "@langchain/core/messages";
33

44
const model = new ChatCohere({
55
apiKey: process.env.COHERE_API_KEY, // Default
6-
model: "command", // Default
76
});
87

98
const conversationId = `demo_test_id-${Math.random()}`;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import { ChatCohere } from "@langchain/cohere";
2+
import { HumanMessage } from "@langchain/core/messages";
3+
import { z } from "zod";
4+
import { DynamicStructuredTool } from "@langchain/core/tools";
5+
6+
const model = new ChatCohere({
7+
apiKey: process.env.COHERE_API_KEY, // Default
8+
});
9+
10+
const magicFunctionTool = new DynamicStructuredTool({
11+
name: "magic_function",
12+
description: "Apply a magic function to the input number",
13+
schema: z.object({
14+
num: z.number().describe("The number to apply the magic function for"),
15+
}),
16+
func: async ({ num }) => {
17+
return `The magic function of ${num} is ${num + 5}`;
18+
},
19+
});
20+
21+
const tools = [magicFunctionTool];
22+
const modelWithTools = model.bindTools(tools);
23+
24+
const messages = [new HumanMessage("What is the magic function of number 5?")];
25+
const response = await modelWithTools.invoke(messages);
26+
/*
27+
AIMessage {
28+
content: 'I will use the magic_function tool to answer this question.',
29+
name: undefined,
30+
additional_kwargs: {
31+
response_id: 'd0b189e5-3dbf-493c-93f8-99ed4b01d96d',
32+
generationId: '8982a68f-c64c-48f8-bf12-0b4bea0018b6',
33+
chatHistory: [ [Object], [Object] ],
34+
finishReason: 'COMPLETE',
35+
meta: { apiVersion: [Object], billedUnits: [Object], tokens: [Object] },
36+
toolCalls: [ [Object] ]
37+
},
38+
response_metadata: {
39+
estimatedTokenUsage: { completionTokens: 54, promptTokens: 920, totalTokens: 974 },
40+
response_id: 'd0b189e5-3dbf-493c-93f8-99ed4b01d96d',
41+
generationId: '8982a68f-c64c-48f8-bf12-0b4bea0018b6',
42+
chatHistory: [ [Object], [Object] ],
43+
finishReason: 'COMPLETE',
44+
meta: { apiVersion: [Object], billedUnits: [Object], tokens: [Object] },
45+
toolCalls: [ [Object] ]
46+
},
47+
tool_calls: [
48+
{
49+
name: 'magic_function',
50+
args: [Object],
51+
id: '4ec98550-ba9a-4043-adfe-566230e5'
52+
}
53+
],
54+
invalid_tool_calls: [],
55+
usage_metadata: { input_tokens: 920, output_tokens: 54, total_tokens: 974 }
56+
}
57+
*/

libs/langchain-cohere/.eslintrc.cjs

+1
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ module.exports = {
3333
"@typescript-eslint/no-unused-vars": ["warn", { args: "none" }],
3434
"@typescript-eslint/no-floating-promises": "error",
3535
"@typescript-eslint/no-misused-promises": "error",
36+
"arrow-body-style": 0,
3637
camelcase: 0,
3738
"class-methods-use-this": 0,
3839
"import/extensions": [2, "ignorePackages"],

libs/langchain-cohere/package.json

+5-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@
3535
"author": "LangChain",
3636
"license": "MIT",
3737
"dependencies": {
38-
"@langchain/core": ">=0.2.5 <0.3.0",
39-
"cohere-ai": "^7.10.5"
38+
"@langchain/core": ">=0.2.14 <0.3.0",
39+
"cohere-ai": "^7.10.5",
40+
"uuid": "^10.0.0",
41+
"zod": "^3.23.8",
42+
"zod-to-json-schema": "^3.23.1"
4043
},
4144
"devDependencies": {
4245
"@jest/globals": "^29.5.0",

0 commit comments

Comments
 (0)