Skip to content

Commit

Permalink
Anthropic - Adopt ToolCallingManager API
Browse files Browse the repository at this point in the history
- Update AnthropicChatModel to use the new ToolCallingManager API, while ensuring full API backward compatibility.
- Introduce Builder to instantiate a new AnthropicChatModel since the number of overloaded constructors is growing too big.
- Update documentation about tool calling and Anthropic support for that.

Part of the #2207 epic

Signed-off-by: Thomas Vitale <[email protected]>
  • Loading branch information
ThomasVitale authored and tzolov committed Feb 13, 2025
1 parent 8463454 commit b7dcfc1
Show file tree
Hide file tree
Showing 10 changed files with 449 additions and 138 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,8 @@
package org.springframework.ai.anthropic;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -30,7 +32,9 @@
import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;

/**
Expand All @@ -42,7 +46,7 @@
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class AnthropicChatOptions implements FunctionCallingOptions {
public class AnthropicChatOptions implements ToolCallingChatOptions {

// @formatter:off
private @JsonProperty("model") String model;
Expand All @@ -54,34 +58,27 @@ public class AnthropicChatOptions implements FunctionCallingOptions {
private @JsonProperty("top_k") Integer topK;

/**
* Tool Function Callbacks to register with the ChatModel. For Prompt
* Options the functionCallbacks are automatically enabled for the duration of the
* prompt execution. For Default Options the functionCallbacks are registered but
* disabled by default. Use the enableFunctions to set the functions from the registry
* to be used by the ChatModel chat completion requests.
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
* completion requests.
*/
@JsonIgnore
private List<FunctionCallback> functionCallbacks = new ArrayList<>();
private List<FunctionCallback> toolCallbacks = new ArrayList<>();

/**
* List of functions, identified by their names, to configure for function calling in
* the chat completion requests. Functions with those names must exist in the
* functionCallbacks registry. The {@link #functionCallbacks} from the PromptOptions
* are automatically enabled for the duration of the prompt execution.
*
* Note that function enabled with the default options are enabled for all chat
* completion requests. This could impact the token count and the billing. If the
* functions is set in a prompt options, then the enabled functions are only active
* for the duration of this prompt execution.
* Collection of tool names to be resolved at runtime and used for tool calling in the
* chat completion requests.
*/
@JsonIgnore
private Set<String> functions = new HashSet<>();
private Set<String> toolNames = new HashSet<>();

/**
* Whether to enable the tool execution lifecycle internally in ChatModel.
*/
@JsonIgnore
private Boolean proxyToolCalls;
private Boolean internalToolExecutionEnabled;

@JsonIgnore
private Map<String, Object> toolContext;
private Map<String, Object> toolContext = new HashMap<>();

// @formatter:on

Expand All @@ -97,9 +94,9 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.temperature(fromOptions.getTemperature())
.topP(fromOptions.getTopP())
.topK(fromOptions.getTopK())
.functionCallbacks(fromOptions.getFunctionCallbacks())
.functions(fromOptions.getFunctions())
.proxyToolCalls(fromOptions.getProxyToolCalls())
.toolCallbacks(fromOptions.getToolCallbacks())
.toolNames(fromOptions.getToolNames())
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
.toolContext(fromOptions.getToolContext())
.build();
}
Expand Down Expand Up @@ -167,25 +164,73 @@ public void setTopK(Integer topK) {
}

@Override
@JsonIgnore
public List<FunctionCallback> getToolCallbacks() {
return this.toolCallbacks;
}

@Override
@JsonIgnore
public void setToolCallbacks(List<FunctionCallback> toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
this.toolCallbacks = toolCallbacks;
}

@Override
@JsonIgnore
public Set<String> getToolNames() {
return this.toolNames;
}

@Override
@JsonIgnore
public void setToolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
Assert.noNullElements(toolNames, "toolNames cannot contain null elements");
toolNames.forEach(tool -> Assert.hasText(tool, "toolNames cannot contain empty elements"));
this.toolNames = toolNames;
}

@Override
@Nullable
@JsonIgnore
public Boolean isInternalToolExecutionEnabled() {
return internalToolExecutionEnabled;
}

@Override
@JsonIgnore
public void setInternalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
this.internalToolExecutionEnabled = internalToolExecutionEnabled;
}

@Override
@Deprecated
@JsonIgnore
public List<FunctionCallback> getFunctionCallbacks() {
return this.functionCallbacks;
return this.getToolCallbacks();
}

@Override
@Deprecated
@JsonIgnore
public void setFunctionCallbacks(List<FunctionCallback> functionCallbacks) {
Assert.notNull(functionCallbacks, "FunctionCallbacks must not be null");
this.functionCallbacks = functionCallbacks;
this.setToolCallbacks(functionCallbacks);
}

@Override
@Deprecated
@JsonIgnore
public Set<String> getFunctions() {
return this.functions;
return this.getToolNames();
}

@Override
public void setFunctions(Set<String> functions) {
Assert.notNull(functions, "Function must not be null");
this.functions = functions;
@Deprecated
@JsonIgnore
public void setFunctions(Set<String> functionNames) {
this.setToolNames(functionNames);
}

@Override
Expand All @@ -201,20 +246,26 @@ public Double getPresencePenalty() {
}

@Override
@Deprecated
@JsonIgnore
public Boolean getProxyToolCalls() {
return this.proxyToolCalls;
return this.internalToolExecutionEnabled != null ? !this.internalToolExecutionEnabled : null;
}

@Deprecated
@JsonIgnore
public void setProxyToolCalls(Boolean proxyToolCalls) {
this.proxyToolCalls = proxyToolCalls;
this.internalToolExecutionEnabled = proxyToolCalls != null ? !proxyToolCalls : null;
}

@Override
@JsonIgnore
public Map<String, Object> getToolContext() {
return this.toolContext;
}

@Override
@JsonIgnore
public void setToolContext(Map<String, Object> toolContext) {
this.toolContext = toolContext;
}
Expand Down Expand Up @@ -268,25 +319,54 @@ public Builder topK(Integer topK) {
return this;
}

public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
this.options.functionCallbacks = functionCallbacks;
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.options.setToolCallbacks(toolCallbacks);
return this;
}

public Builder functions(Set<String> functionNames) {
Assert.notNull(functionNames, "Function names must not be null");
this.options.functions = functionNames;
public Builder toolCallbacks(FunctionCallback... toolCallbacks) {
Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
this.options.toolCallbacks.addAll(Arrays.asList(toolCallbacks));
return this;
}

public Builder function(String functionName) {
Assert.hasText(functionName, "Function name must not be empty");
this.options.functions.add(functionName);
public Builder toolNames(Set<String> toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.setToolNames(toolNames);
return this;
}

public Builder toolNames(String... toolNames) {
Assert.notNull(toolNames, "toolNames cannot be null");
this.options.toolNames.addAll(Set.of(toolNames));
return this;
}

public Builder internalToolExecutionEnabled(@Nullable Boolean internalToolExecutionEnabled) {
this.options.setInternalToolExecutionEnabled(internalToolExecutionEnabled);
return this;
}

@Deprecated
public Builder functionCallbacks(List<FunctionCallback> functionCallbacks) {
return toolCallbacks(functionCallbacks);
}

@Deprecated
public Builder functions(Set<String> functionNames) {
return toolNames(functionNames);
}

@Deprecated
public Builder function(String functionName) {
return toolNames(functionName);
}

@Deprecated
public Builder proxyToolCalls(Boolean proxyToolCalls) {
this.options.proxyToolCalls = proxyToolCalls;
if (proxyToolCalls != null) {
this.options.setInternalToolExecutionEnabled(!proxyToolCalls);
}
return this;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ public enum ChatModel implements ChatModelDescription {
CLAUDE_3_OPUS("claude-3-opus-latest"),

/**
* The CLAUDE_3_SONNET
* The CLAUDE_3_SONNET (Deprecated. To be removed on July 21, 2025)
*/
CLAUDE_3_SONNET("claude-3-sonnet-20240229"),

Expand All @@ -254,12 +254,12 @@ public enum ChatModel implements ChatModelDescription {

// Legacy models
/**
* The CLAUDE_2_1
* The CLAUDE_2_1 (Deprecated. To be removed on July 21, 2025)
*/
CLAUDE_2_1("claude-2.1"),

/**
* The CLAUDE_2_0
* The CLAUDE_2_0 (Deprecated. To be removed on July 21, 2025)
*/
CLAUDE_2("claude-2.0");

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,6 +26,7 @@
/**
* @author Christian Tzolov
* @author Alexandros Pappas
* @author Thomas Vitale
*/
public class ChatCompletionRequestTests {

Expand All @@ -35,16 +36,20 @@ public void createRequestWithChatOptions() {
var client = new AnthropicChatModel(new AnthropicApi("TEST"),
AnthropicChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build());

var request = client.createRequest(new Prompt("Test message content"), false);
var prompt = client.buildRequestPrompt(new Prompt("Test message content"));

var request = client.createRequest(prompt, false);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isFalse();

assertThat(request.model()).isEqualTo("DEFAULT_MODEL");
assertThat(request.temperature()).isEqualTo(66.6);

request = client.createRequest(new Prompt("Test message content",
AnthropicChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build()), true);
prompt = client.buildRequestPrompt(new Prompt("Test message content",
AnthropicChatOptions.builder().model("PROMPT_MODEL").temperature(99.9).build()));

request = client.createRequest(prompt, true);

assertThat(request.messages()).hasSize(1);
assertThat(request.stream()).isTrue();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.model.ModelResponse;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
Expand Down Expand Up @@ -111,6 +112,21 @@ public boolean hasToolCalls() {
return generations.stream().anyMatch(generation -> generation.getOutput().hasToolCalls());
}

/**
* Whether the model has finished with any of the given finish reasons.
*/
public boolean hasFinishReasons(Set<String> finishReasons) {
Assert.notNull(finishReasons, "finishReasons cannot be null");
if (CollectionUtils.isEmpty(generations)) {
return false;
}
return generations.stream().anyMatch(generation -> {
var finishReason = (generation.getMetadata().getFinishReason() != null)
? generation.getMetadata().getFinishReason() : "";
return finishReasons.stream().map(String::toLowerCase).toList().contains(finishReason.toLowerCase());
});
}

@Override
public String toString() {
return "ChatResponse [metadata=" + this.chatResponseMetadata + ", generations=" + this.generations + "]";
Expand Down
Loading

0 comments on commit b7dcfc1

Please sign in to comment.