Skip to content

Commit

Permalink
feat: Enhance Anthropic integration with Thinking
Browse files Browse the repository at this point in the history
- The `thinking` option is added to `AnthropicChatOptions` and `ChatCompletionRequest`.
- The `AnthropicApi` and `AnthropicChatModel` now handle `THINKING` and `REDACTED_THINKING` content blocks in responses.  New tests verify parsing of these blocks.
- Updated method signatures on ChatCompletionRequestBuilder, deprecating old builders with `with*` prefix in favor of those without.

Signed-off-by: Alexandros Pappas <[email protected]>
  • Loading branch information
apappascs committed Feb 27, 2025
1 parent 6cb15e4 commit 9168278
Show file tree
Hide file tree
Showing 7 changed files with 416 additions and 65 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.Base64;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
Expand All @@ -36,6 +37,7 @@
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.lang.Nullable;

import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand Down Expand Up @@ -379,46 +381,51 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage
return new ChatResponse(List.of());
}

List<Generation> generations = chatCompletion.content()
.stream()
.filter(content -> content.type() != ContentBlock.Type.TOOL_USE)
.map(content -> new Generation(new AssistantMessage(content.text(), Map.of()),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()))
.toList();

List<Generation> allGenerations = new ArrayList<>(generations);
List<Generation> generations = new ArrayList<>();
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();
for (ContentBlock content : chatCompletion.content()) {
switch (content.type()) {
case TEXT, TEXT_DELTA:
generations.add(new Generation(new AssistantMessage(content.text(), Map.of()),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
break;
case THINKING, THINKING_DELTA:
System.out.println(
"THINKINGTHINKINGTHINKINGTHINKINGTHINKINGTHINKINGTHINKINGcontent type: " + content.type());
Map<String, Object> thinkingProperties = new HashMap<>();
thinkingProperties.put("signature", content.signature());
generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
break;
case REDACTED_THINKING:
Map<String, Object> redactedProperties = new HashMap<>();
redactedProperties.put("data", content.data());
generations.add(new Generation(new AssistantMessage(null, redactedProperties),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()));
break;
case TOOL_USE:
var functionCallId = content.id();
var functionName = content.name();
var functionArguments = JsonParser.toJson(content.input());
toolCalls.add(
new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
break;
}
}

if (chatCompletion.stopReason() != null && generations.isEmpty()) {
Generation generation = new Generation(new AssistantMessage(null, Map.of()),
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
allGenerations.add(generation);
generations.add(generation);
}

List<ContentBlock> toolToUseList = chatCompletion.content()
.stream()
.filter(c -> c.type() == ContentBlock.Type.TOOL_USE)
.toList();

if (!CollectionUtils.isEmpty(toolToUseList)) {
List<AssistantMessage.ToolCall> toolCalls = new ArrayList<>();

for (ContentBlock toolToUse : toolToUseList) {

var functionCallId = toolToUse.id();
var functionName = toolToUse.name();
var functionArguments = JsonParser.toJson(toolToUse.input());

toolCalls
.add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments));
}

if (!CollectionUtils.isEmpty(toolCalls)) {
AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls);
Generation toolCallGeneration = new Generation(assistantMessage,
ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build());
allGenerations.add(toolCallGeneration);
generations.add(toolCallGeneration);
}

return new ChatResponse(allGenerations, this.from(chatCompletion, usage));
return new ChatResponse(generations, this.from(chatCompletion, usage));
}

private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) {
Expand Down Expand Up @@ -575,7 +582,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
List<ToolDefinition> toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions);
if (!CollectionUtils.isEmpty(toolDefinitions)) {
request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class);
request = ChatCompletionRequest.from(request).withTools(getFunctionTools(toolDefinitions)).build();
request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build();
}

return request;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ public class AnthropicChatOptions implements ToolCallingChatOptions {
private @JsonProperty("temperature") Double temperature;
private @JsonProperty("top_p") Double topP;
private @JsonProperty("top_k") Integer topK;
private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking;

/**
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
Expand Down Expand Up @@ -94,6 +95,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.temperature(fromOptions.getTemperature())
.topP(fromOptions.getTopP())
.topK(fromOptions.getTopK())
.thinking(fromOptions.getThinking())
.toolCallbacks(fromOptions.getToolCallbacks())
.toolNames(fromOptions.getToolNames())
.internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled())
Expand Down Expand Up @@ -163,6 +165,14 @@ public void setTopK(Integer topK) {
this.topK = topK;
}

public ChatCompletionRequest.ThinkingConfig getThinking() {
return this.thinking;
}

public void setThinking(ChatCompletionRequest.ThinkingConfig thinking) {
this.thinking = thinking;
}

@Override
@JsonIgnore
public List<FunctionCallback> getToolCallbacks() {
Expand Down Expand Up @@ -319,6 +329,16 @@ public Builder topK(Integer topK) {
return this;
}

public Builder thinking(ChatCompletionRequest.ThinkingConfig thinking) {
this.options.thinking = thinking;
return this;
}

public Builder thinking(AnthropicApi.ThinkingType type, Integer budgetTokens) {
this.options.thinking = new ChatCompletionRequest.ThinkingConfig(type, budgetTokens);
return this;
}

public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.options.setToolCallbacks(toolCallbacks);
return this;
Expand Down
Loading

0 comments on commit 9168278

Please sign in to comment.