Skip to content

Commit b7dcfc1

Browse files
ThomasVitaletzolov
authored andcommitted
Anthropic - Adopt ToolCallingManager API
- Update AnthropicChatModel to use the new ToolCallingManager API, while ensuring full API backward compatibility. - Introduce Builder to instantiate a new AnthropicChatModel since the number of overloaded constructors is growing too big. - Update documentation about tool calling and Anthropic support for that. Part of the #2207 epic Signed-off-by: Thomas Vitale <[email protected]>
1 parent 8463454 commit b7dcfc1

File tree

10 files changed

+449
-138
lines changed

10 files changed

+449
-138
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java

Lines changed: 238 additions & 68 deletions
Large diffs are not rendered by default.

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 122 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,6 +17,8 @@
1717
package org.springframework.ai.anthropic;
1818

1919
import java.util.ArrayList;
20+
import java.util.Arrays;
21+
import java.util.HashMap;
2022
import java.util.HashSet;
2123
import java.util.List;
2224
import java.util.Map;
@@ -30,7 +32,9 @@
3032
import org.springframework.ai.anthropic.api.AnthropicApi;
3133
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
3234
import org.springframework.ai.model.function.FunctionCallback;
33-
import org.springframework.ai.model.function.FunctionCallingOptions;
35+
import org.springframework.ai.model.tool.ToolCallingChatOptions;
36+
import org.springframework.ai.tool.ToolCallback;
37+
import org.springframework.lang.Nullable;
3438
import org.springframework.util.Assert;
3539

3640
/**
@@ -42,7 +46,7 @@
4246
* @since 1.0.0
4347
*/
4448
@JsonInclude(Include.NON_NULL)
45-
public class AnthropicChatOptions implements FunctionCallingOptions {
49+
public class AnthropicChatOptions implements ToolCallingChatOptions {
4650

4751
// @formatter:off
4852
private @JsonProperty("model") String model;
@@ -54,34 +58,27 @@ public class AnthropicChatOptions implements FunctionCallingOptions {
5458
private @JsonProperty("top_k") Integer topK;
5559

5660
/**
57-
* Tool Function Callbacks to register with the ChatModel. For Prompt
58-
* Options the functionCallbacks are automatically enabled for the duration of the
59-
* prompt execution. For Default Options the functionCallbacks are registered but
60-
* disabled by default. Use the enableFunctions to set the functions from the registry
61-
* to be used by the ChatModel chat completion requests.
61+
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
62+
* completion requests.
6263
*/
6364
@JsonIgnore
64-
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
65+
private List<FunctionCallback> toolCallbacks = new ArrayList<>();
6566

6667
/**
67-
* List of functions, identified by their names, to configure for function calling in
68-
* the chat completion requests. Functions with those names must exist in the
69-
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
70-
* are automatically enabled for the duration of the prompt execution.
71-
*
72-
* Note that function enabled with the default options are enabled for all chat
73-
* completion requests. This could impact the token count and the billing. If the
74-
* functions is set in a prompt options, then the enabled functions are only active
75-
* for the duration of this prompt execution.
68+
* Collection of tool names to be resolved at runtime and used for tool calling in the
69+
* chat completion requests.
7670
*/
7771
@JsonIgnore
78-
private Set<String> functions = new HashSet<>();
72+
private Set<String> toolNames = new HashSet<>();
7973

74+
/**
75+
* Whether to enable the tool execution lifecycle internally in ChatModel.
76+
*/
8077
@JsonIgnore
81-
private Boolean proxyToolCalls;
78+
private Boolean internalToolExecutionEnabled;
8279

8380
@JsonIgnore
84-
private Map<String, Object> toolContext;
81+
private Map<String, Object> toolContext = new HashMap<>();
8582

8683
// @formatter:on
8784

@@ -97,9 +94,9 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
9794
.temperature(fromOptions.getTemperature())
9895
.topP(fromOptions.getTopP())
9996
.topK(fromOptions.getTopK())
100-
.functionCallbacks(fromOptions.getFunctionCallbacks())
101-
.functions(fromOptions.getFunctions())
102-
.proxyToolCalls(fromOptions.getProxyToolCalls())
97+
.toolCallbacks(fromOptions.getToolCallbacks())
98+
.toolNames(fromOptions.getToolNames())
99+
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
103100
.toolContext(fromOptions.getToolContext())
104101
.build();
105102
}
@@ -167,25 +164,73 @@ public void setTopK(Integer topK) {
167164
}
168165

169166
@Override
167+
@JsonIgnore
168+
public List<FunctionCallback> getToolCallbacks() {
169+
return this.toolCallbacks;
170+
}
171+
172+
@Override
173+
@JsonIgnore
174+
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
175+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
176+
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
177+
this.toolCallbacks = toolCallbacks;
178+
}
179+
180+
@Override
181+
@JsonIgnore
182+
public Set<String> getToolNames() {
183+
return this.toolNames;
184+
}
185+
186+
@Override
187+
@JsonIgnore
188+
public void setToolNames(Set<String> toolNames) {
189+
Assert.notNull(toolNames, "toolNames cannot be null");
190+
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
191+
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
192+
this.toolNames = toolNames;
193+
}
194+
195+
@Override
196+
@Nullable
197+
@JsonIgnore
198+
public Boolean isInternalToolExecutionEnabled() {
199+
return internalToolExecutionEnabled;
200+
}
201+
202+
@Override
203+
@JsonIgnore
204+
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
205+
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
206+
}
207+
208+
@Override
209+
@Deprecated
210+
@JsonIgnore
170211
public List<FunctionCallback> getFunctionCallbacks() {
171-
return this.functionCallbacks;
212+
return this.getToolCallbacks();
172213
}
173214

174215
@Override
216+
@Deprecated
217+
@JsonIgnore
175218
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
176-
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
177-
this.functionCallbacks = functionCallbacks;
219+
this.setToolCallbacks(functionCallbacks);
178220
}
179221

180222
@Override
223+
@Deprecated
224+
@JsonIgnore
181225
public Set<String> getFunctions() {
182-
return this.functions;
226+
return this.getToolNames();
183227
}
184228

185229
@Override
186-
public void setFunctions(Set<String> functions) {
187-
Assert.notNull(functions, "Function must not be null");
188-
this.functions = functions;
230+
@Deprecated
231+
@JsonIgnore
232+
public void setFunctions(Set<String> functionNames) {
233+
this.setToolNames(functionNames);
189234
}
190235

191236
@Override
@@ -201,20 +246,26 @@ public Double getPresencePenalty() {
201246
}
202247

203248
@Override
249+
@Deprecated
250+
@JsonIgnore
204251
public Boolean getProxyToolCalls() {
205-
return this.proxyToolCalls;
252+
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
206253
}
207254

255+
@Deprecated
256+
@JsonIgnore
208257
public void setProxyToolCalls(Boolean proxyToolCalls) {
209-
this.proxyToolCalls = proxyToolCalls;
258+
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
210259
}
211260

212261
@Override
262+
@JsonIgnore
213263
public Map<String, Object> getToolContext() {
214264
return this.toolContext;
215265
}
216266

217267
@Override
268+
@JsonIgnore
218269
public void setToolContext(Map<String, Object> toolContext) {
219270
this.toolContext = toolContext;
220271
}
@@ -268,25 +319,54 @@ public Builder topK(Integer topK) {
268319
return this;
269320
}
270321

271-
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
272-
this.options.functionCallbacks = functionCallbacks;
322+
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
323+
this.options.setToolCallbacks(toolCallbacks);
273324
return this;
274325
}
275326

276-
public Builder functions(Set<String> functionNames) {
277-
Assert.notNull(functionNames, "Function names must not be null");
278-
this.options.functions = functionNames;
327+
public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
328+
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
329+
this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks));
279330
return this;
280331
}
281332

282-
public Builder function(String functionName) {
283-
Assert.hasText(functionName, "Function name must not be empty");
284-
this.options.functions.add(functionName);
333+
public Builder toolNames(Set<String> toolNames) {
334+
Assert.notNull(toolNames, "toolNames cannot be null");
335+
this.options.setToolNames(toolNames);
336+
return this;
337+
}
338+
339+
public Builder toolNames(String... toolNames) {
340+
Assert.notNull(toolNames, "toolNames cannot be null");
341+
this.options.toolNames.addAll(Set.of(toolNames));
285342
return this;
286343
}
287344

345+
public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
346+
this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
347+
return this;
348+
}
349+
350+
@Deprecated
351+
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
352+
return toolCallbacks(functionCallbacks);
353+
}
354+
355+
@Deprecated
356+
public Builder functions(Set<String> functionNames) {
357+
return toolNames(functionNames);
358+
}
359+
360+
@Deprecated
361+
public Builder function(String functionName) {
362+
return toolNames(functionName);
363+
}
364+
365+
@Deprecated
288366
public Builder proxyToolCalls(Boolean proxyToolCalls) {
289-
this.options.proxyToolCalls = proxyToolCalls;
367+
if (proxyToolCalls != null) {
368+
this.options.setInternalToolExecutionEnabled(!proxyToolCalls);
369+
}
290370
return this;
291371
}
292372

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ public enum ChatModel implements ChatModelDescription {
238238
CLAUDE_3_OPUS("claude-3-opus-latest"),
239239

240240
/**
241-
* The CLAUDE_3_SONNET
241+
* The CLAUDE_3_SONNET (Deprecated. To be removed on July 21, 2025)
242242
*/
243243
CLAUDE_3_SONNET("claude-3-sonnet-20240229"),
244244

@@ -254,12 +254,12 @@ public enum ChatModel implements ChatModelDescription {
254254

255255
// Legacy models
256256
/**
257-
* The CLAUDE_2_1
257+
* The CLAUDE_2_1 (Deprecated. To be removed on July 21, 2025)
258258
*/
259259
CLAUDE_2_1("claude-2.1"),
260260

261261
/**
262-
* The CLAUDE_2_0
262+
* The CLAUDE_2_0 (Deprecated. To be removed on July 21, 2025)
263263
*/
264264
CLAUDE_2("claude-2.0");
265265

models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -26,6 +26,7 @@
2626
/**
2727
* @author Christian Tzolov
2828
* @author Alexandros Pappas
29+
* @author Thomas Vitale
2930
*/
3031
public class ChatCompletionRequestTests {
3132

@@ -35,16 +36,20 @@ public void createRequestWithChatOptions() {
3536
var client = new AnthropicChatModel(new AnthropicApi("TEST"),
3637
AnthropicChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build());
3738

38-
var request = client.createRequest(new Prompt("Test message content"), false);
39+
var prompt = client.buildRequestPrompt(new Prompt("Test message content"));
40+
41+
var request = client.createRequest(prompt, false);
3942

4043
assertThat(request.messages()).hasSize(1);
4144
assertThat(request.stream()).isFalse();
4245

4346
assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
4447
assertThat(request.temperature()).isEqualTo(66.6);
4548

46-
request = client.createRequest(new Prompt("Test message content",
47-
AnthropicChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build()), true);
49+
prompt = client.buildRequestPrompt(new Prompt("Test message content",
50+
AnthropicChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build()));
51+
52+
request = client.createRequest(prompt, true);
4853

4954
assertThat(request.messages()).hasSize(1);
5055
assertThat(request.stream()).isTrue();

spring-ai-core/src/main/java/org/springframework/ai/chat/model/ChatResponse.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
2525
import org.springframework.ai.model.ModelResponse;
26+
import org.springframework.util.Assert;
2627
import org.springframework.util.CollectionUtils;
2728

2829
/**
@@ -111,6 +112,21 @@ public boolean hasToolCalls() {
111112
return generations.stream().anyMatch(generation -> generation.getOutput().hasToolCalls());
112113
}
113114

115+
/**
116+
* Whether the model has finished with any of the given finish reasons.
117+
*/
118+
public boolean hasFinishReasons(Set<String> finishReasons) {
119+
Assert.notNull(finishReasons, "finishReasons cannot be null");
120+
if (CollectionUtils.isEmpty(generations)) {
121+
return false;
122+
}
123+
return generations.stream().anyMatch(generation -> {
124+
var finishReason = (generation.getMetadata().getFinishReason() != null)
125+
? generation.getMetadata().getFinishReason() : "";
126+
return finishReasons.stream().map(String::toLowerCase).toList().contains(finishReason.toLowerCase());
127+
});
128+
}
129+
114130
@Override
115131
public String toString() {
116132
return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]";

0 commit comments

Comments
 (0)