diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index e92cb9dedcd..980ba5e1c84 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.Base64; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -379,46 +380,49 @@ private ChatResponse toChatResponse(ChatCompletionResponse chatCompletion, Usage return new ChatResponse(List.of()); } - List generations = chatCompletion.content() - .stream() - .filter(content -> content.type() != ContentBlock.Type.TOOL_USE) - .map(content -> new Generation(new AssistantMessage(content.text(), Map.of()), - ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())) - .toList(); - - List allGenerations = new ArrayList<>(generations); + List generations = new ArrayList<>(); + List toolCalls = new ArrayList<>(); + for (ContentBlock content : chatCompletion.content()) { + switch (content.type()) { + case TEXT: + generations.add(new Generation(new AssistantMessage(content.text(), Map.of()), + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); + break; + case THINKING: + Map thinkingProperties = new HashMap<>(); + thinkingProperties.put("signature", content.signature()); + generations.add(new Generation(new AssistantMessage(content.thinking(), thinkingProperties), + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); + break; + case REDACTED_THINKING: + Map redactedProperties = new HashMap<>(); + redactedProperties.put("data", content.data()); + generations.add(new Generation(new AssistantMessage(null, redactedProperties), + ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build())); + break; + case TOOL_USE: + var functionCallId = content.id(); + var functionName = content.name(); + var functionArguments = JsonParser.toJson(content.input()); + toolCalls.add( + new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); + break; + } + } if (chatCompletion.stopReason() != null && generations.isEmpty()) { Generation generation = new Generation(new AssistantMessage(null, Map.of()), ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); - allGenerations.add(generation); + generations.add(generation); } - List toolToUseList = chatCompletion.content() - .stream() - .filter(c -> c.type() == ContentBlock.Type.TOOL_USE) - .toList(); - - if (!CollectionUtils.isEmpty(toolToUseList)) { - List toolCalls = new ArrayList<>(); - - for (ContentBlock toolToUse : toolToUseList) { - - var functionCallId = toolToUse.id(); - var functionName = toolToUse.name(); - var functionArguments = JsonParser.toJson(toolToUse.input()); - - toolCalls - .add(new AssistantMessage.ToolCall(functionCallId, "function", functionName, functionArguments)); - } - + if (!CollectionUtils.isEmpty(toolCalls)) { AssistantMessage assistantMessage = new AssistantMessage("", Map.of(), toolCalls); Generation toolCallGeneration = new Generation(assistantMessage, ChatGenerationMetadata.builder().finishReason(chatCompletion.stopReason()).build()); - allGenerations.add(toolCallGeneration); + generations.add(toolCallGeneration); } - - return new ChatResponse(allGenerations, this.from(chatCompletion, usage)); + return new ChatResponse(generations, this.from(chatCompletion, usage)); } private ChatResponseMetadata from(AnthropicApi.ChatCompletionResponse result) { @@ -575,7 +579,7 @@ else if (message.getMessageType() == MessageType.TOOL) { List toolDefinitions = this.toolCallingManager.resolveToolDefinitions(requestOptions); if (!CollectionUtils.isEmpty(toolDefinitions)) { request = ModelOptionsUtils.merge(request, this.defaultOptions, ChatCompletionRequest.class); - request = ChatCompletionRequest.from(request).withTools(getFunctionTools(toolDefinitions)).build(); + request = ChatCompletionRequest.from(request).tools(getFunctionTools(toolDefinitions)).build(); } return request; diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index c1b319a27ff..4df44a1c6e5 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -56,6 +56,7 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { private @JsonProperty("temperature") Double temperature; private @JsonProperty("top_p") Double topP; private @JsonProperty("top_k") Integer topK; + private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking; /** * Collection of {@link ToolCallback}s to be used for tool calling in the chat @@ -94,6 +95,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .topK(fromOptions.getTopK()) + .thinking(fromOptions.getThinking()) .toolCallbacks(fromOptions.getToolCallbacks()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) @@ -163,6 +165,14 @@ public void setTopK(Integer topK) { this.topK = topK; } + public ChatCompletionRequest.ThinkingConfig getThinking() { + return this.thinking; + } + + public void setThinking(ChatCompletionRequest.ThinkingConfig thinking) { + this.thinking = thinking; + } + @Override @JsonIgnore public List getToolCallbacks() { @@ -319,6 +329,16 @@ public Builder topK(Integer topK) { return this; } + public Builder thinking(ChatCompletionRequest.ThinkingConfig thinking) { + this.options.thinking = thinking; + return this; + } + + public Builder thinking(AnthropicApi.ThinkingType type, Integer budgetTokens) { + this.options.thinking = new ChatCompletionRequest.ThinkingConfig(type, budgetTokens); + return this; + } + public Builder toolCallbacks(List toolCallbacks) { this.options.setToolCallbacks(toolCallbacks); return this; diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index 891a8b77230..d22cd2ea5bb 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -54,6 +54,7 @@ * @author Mariusz Bernacki * @author Thomas Vitale * @author Jihoon Kim + * @author Alexandros Pappas * @since 1.0.0 */ public class AnthropicApi { @@ -231,9 +232,9 @@ public enum ChatModel implements ChatModelDescription { * The claude-3-7-sonnet-latest model. */ CLAUDE_3_7_SONNET("claude-3-7-sonnet-latest"), - + /** - * The claude-3-5-sonnet-20241022 model. + * The claude-3-5-sonnet-latest model. */ CLAUDE_3_5_SONNET("claude-3-5-sonnet-latest"), @@ -316,6 +317,25 @@ public enum Role { } + /** + * The thinking type. + */ + public enum ThinkingType { + + /** + * Enabled thinking type. + */ + @JsonProperty("enabled") + ENABLED, + + /** + * Disabled thinking type. + */ + @JsonProperty("disabled") + DISABLED + + } + /** * The event type of the streamed chunk. */ @@ -339,6 +359,22 @@ public enum EventType { @JsonProperty("message_stop") MESSAGE_STOP, + /** + * When using extended thinking with streaming enabled, you’ll receive thinking + * content via thinking_delta events. These deltas correspond to the thinking + * field of the thinking content blocks. + */ + @JsonProperty("thinking_delta") + THINKING_DELTA, + + /** + * For thinking content, a special signature_delta event is sent just before the + * content_block_stop event. This signature is used to verify the integrity of the + * thinking block. + */ + @JsonProperty("signature_delta") + SIGNATURE_DELTA, + /** * Content block start event. */ @@ -381,11 +417,8 @@ public enum EventType { @JsonSubTypes({ @JsonSubTypes.Type(value = ContentBlockStartEvent.class, name = "content_block_start"), @JsonSubTypes.Type(value = ContentBlockDeltaEvent.class, name = "content_block_delta"), @JsonSubTypes.Type(value = ContentBlockStopEvent.class, name = "content_block_stop"), - @JsonSubTypes.Type(value = PingEvent.class, name = "ping"), - @JsonSubTypes.Type(value = ErrorEvent.class, name = "error"), - @JsonSubTypes.Type(value = MessageStartEvent.class, name = "message_start"), @JsonSubTypes.Type(value = MessageDeltaEvent.class, name = "message_delta"), @JsonSubTypes.Type(value = MessageStopEvent.class, name = "message_stop") }) @@ -437,6 +470,8 @@ public interface StreamEvent { * return tool_use content blocks that represent the model's use of those tools. You * can then run those tools using the tool input generated by the model and then * optionally return results back to the model using tool_result content blocks. + * @param thinking Configuration for the model's thinking mode. When enabled, the + * model can perform more in-depth reasoning before responding to a query. */ @JsonInclude(Include.NON_NULL) public record ChatCompletionRequest( @@ -451,17 +486,18 @@ public record ChatCompletionRequest( @JsonProperty("temperature") Double temperature, @JsonProperty("top_p") Double topP, @JsonProperty("top_k") Integer topK, - @JsonProperty("tools") List tools) { + @JsonProperty("tools") List tools, + @JsonProperty("thinking") ThinkingConfig thinking) { // @formatter:on public ChatCompletionRequest(String model, List messages, String system, Integer maxTokens, Double temperature, Boolean stream) { - this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null); + this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null, null); } public ChatCompletionRequest(String model, List messages, String system, Integer maxTokens, List stopSequences, Double temperature, Boolean stream) { - this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null); + this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null); } public static ChatCompletionRequestBuilder builder() { @@ -485,6 +521,18 @@ public record Metadata(@JsonProperty("user_id") String userId) { } + /** + * Configuration for the model's thinking mode. + * + * @param type The type of thinking mode. Currently, "enabled" is supported. + * @param budgetTokens The token budget available for the thinking process. Must + * be ≥1024 and less than max_tokens. + */ + @JsonInclude(Include.NON_NULL) + public record ThinkingConfig(@JsonProperty("type") ThinkingType type, + @JsonProperty("budget_tokens") Integer budgetTokens) { + } + } public static final class ChatCompletionRequestBuilder { @@ -511,6 +559,8 @@ public static final class ChatCompletionRequestBuilder { private List tools; + private ChatCompletionRequest.ThinkingConfig thinking; + private ChatCompletionRequestBuilder() { } @@ -526,71 +576,209 @@ private ChatCompletionRequestBuilder(ChatCompletionRequest request) { this.topP = request.topP; this.topK = request.topK; this.tools = request.tools; + this.thinking = request.thinking; } + /** + * @deprecated use {@link #model(ChatModel)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withModel(ChatModel model) { this.model = model.getValue(); return this; } + public ChatCompletionRequestBuilder model(ChatModel model) { + this.model = model.getValue(); + return this; + } + + /** + * @deprecated use {@link #model(String)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withModel(String model) { this.model = model; return this; } + public ChatCompletionRequestBuilder model(String model) { + this.model = model; + return this; + } + + /** + * @deprecated use {@link #messages(List)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withMessages(List messages) { this.messages = messages; return this; } + public ChatCompletionRequestBuilder messages(List messages) { + this.messages = messages; + return this; + } + + /** + * @deprecated use {@link #system(String)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withSystem(String system) { this.system = system; return this; } + public ChatCompletionRequestBuilder system(String system) { + this.system = system; + return this; + } + + /** + * @deprecated use {@link #maxTokens(Integer)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; return this; } + public ChatCompletionRequestBuilder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + /** + * @deprecated use {@link #metadata(ChatCompletionRequest.Metadata)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withMetadata(ChatCompletionRequest.Metadata metadata) { this.metadata = metadata; return this; } + public ChatCompletionRequestBuilder metadata(ChatCompletionRequest.Metadata metadata) { + this.metadata = metadata; + return this; + } + + /** + * @deprecated use {@link #stopSequences(List)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withStopSequences(List stopSequences) { this.stopSequences = stopSequences; return this; } + public ChatCompletionRequestBuilder stopSequences(List stopSequences) { + this.stopSequences = stopSequences; + return this; + } + + /** + * @deprecated use {@link #stream(Boolean)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withStream(Boolean stream) { this.stream = stream; return this; } + public ChatCompletionRequestBuilder stream(Boolean stream) { + this.stream = stream; + return this; + } + + /** + * @deprecated use {@link #temperature(Double)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withTemperature(Double temperature) { this.temperature = temperature; return this; } + public ChatCompletionRequestBuilder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + /** + * @deprecated use {@link #topP(Double)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withTopP(Double topP) { this.topP = topP; return this; } + public ChatCompletionRequestBuilder topP(Double topP) { + this.topP = topP; + return this; + } + + /** + * @deprecated use {@link #topK(Integer)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withTopK(Integer topK) { this.topK = topK; return this; } + public ChatCompletionRequestBuilder topK(Integer topK) { + this.topK = topK; + return this; + } + + /** + * @deprecated use {@link #tools(List)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") public ChatCompletionRequestBuilder withTools(List tools) { this.tools = tools; return this; } + public ChatCompletionRequestBuilder tools(List tools) { + this.tools = tools; + return this; + } + + /** + * @deprecated use {@link #thinking(ChatCompletionRequest.ThinkingConfig)} + * instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") + public ChatCompletionRequestBuilder withThinking(ChatCompletionRequest.ThinkingConfig thinking) { + this.thinking = thinking; + return this; + } + + public ChatCompletionRequestBuilder thinking(ChatCompletionRequest.ThinkingConfig thinking) { + this.thinking = thinking; + return this; + } + + /** + * @deprecated use {@link #thinking(ThinkingType, Integer)} instead. + */ + @Deprecated(forRemoval = true, since = "1.0.0-M6") + public ChatCompletionRequestBuilder withThinking(ThinkingType type, Integer budgetTokens) { + this.thinking = new ChatCompletionRequest.ThinkingConfig(type, budgetTokens); + return this; + } + + public ChatCompletionRequestBuilder thinking(ThinkingType type, Integer budgetTokens) { + this.thinking = new ChatCompletionRequest.ThinkingConfig(type, budgetTokens); + return this; + } + public ChatCompletionRequest build() { return new ChatCompletionRequest(this.model, this.messages, this.system, this.maxTokens, this.metadata, - this.stopSequences, this.stream, this.temperature, this.topP, this.topK, this.tools); + this.stopSequences, this.stream, this.temperature, this.topP, this.topK, this.tools, this.thinking); } } @@ -658,7 +846,14 @@ public record ContentBlock( // tool_result response only @JsonProperty("tool_use_id") String toolUseId, - @JsonProperty("content") String content + @JsonProperty("content") String content, + + // Thinking only + @JsonProperty("signature") String signature, + @JsonProperty("thinking") String thinking, + + // Redacted Thinking only + @JsonProperty("data") String data ) { // @formatter:on @@ -677,7 +872,7 @@ public ContentBlock(String mediaType, String data) { * @param source The source of the content. */ public ContentBlock(Type type, Source source) { - this(type, source, null, null, null, null, null, null, null); + this(type, source, null, null, null, null, null, null, null, null, null, null); } /** @@ -685,7 +880,7 @@ public ContentBlock(Type type, Source source) { * @param source The source of the content. */ public ContentBlock(Source source) { - this(Type.IMAGE, source, null, null, null, null, null, null, null); + this(Type.IMAGE, source, null, null, null, null, null, null, null, null, null, null); } /** @@ -693,7 +888,7 @@ public ContentBlock(Source source) { * @param text The text of the content. */ public ContentBlock(String text) { - this(Type.TEXT, null, text, null, null, null, null, null, null); + this(Type.TEXT, null, text, null, null, null, null, null, null, null, null, null); } // Tool result @@ -704,7 +899,7 @@ public ContentBlock(String text) { * @param content The content of the tool result. */ public ContentBlock(Type type, String toolUseId, String content) { - this(type, null, null, null, null, null, null, toolUseId, content); + this(type, null, null, null, null, null, null, toolUseId, content, null, null, null); } /** @@ -715,7 +910,7 @@ public ContentBlock(Type type, String toolUseId, String content) { * @param index The index of the content block. */ public ContentBlock(Type type, Source source, String text, Integer index) { - this(type, source, text, index, null, null, null, null, null); + this(type, source, text, index, null, null, null, null, null, null, null, null); } // Tool use input JSON delta streaming @@ -727,7 +922,7 @@ public ContentBlock(Type type, Source source, String text, Integer index) { * @param input The input of the tool use. */ public ContentBlock(Type type, String id, String name, Map input) { - this(type, null, null, null, id, name, input, null, null); + this(type, null, null, null, id, name, input, null, null, null, null, null); } /** @@ -775,7 +970,19 @@ public enum Type { * Document message. */ @JsonProperty("document") - DOCUMENT("document"); + DOCUMENT("document"), + + /** + * Thinking message. + */ + @JsonProperty("thinking") + THINKING("thinking"), + + /** + * Redacted Thinking message. + */ + @JsonProperty("redacted_thinking") + REDACTED_THINKING("redacted_thinking"); public final String value; @@ -1042,7 +1249,10 @@ public record ContentBlockDeltaEvent( @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", visible = true) @JsonSubTypes({ @JsonSubTypes.Type(value = ContentBlockDeltaText.class, name = "text_delta"), - @JsonSubTypes.Type(value = ContentBlockDeltaJson.class, name = "input_json_delta") }) + @JsonSubTypes.Type(value = ContentBlockDeltaJson.class, name = "input_json_delta"), + @JsonSubTypes.Type(value = ContentBlockDeltaThinking.class, name = "thinking_delta"), + @JsonSubTypes.Type(value = ContentBlockDeltaSignature.class, name = "signature_delta") + }) public interface ContentBlockDeltaBody { String type(); } @@ -1068,6 +1278,26 @@ public record ContentBlockDeltaJson( @JsonProperty("type") String type, @JsonProperty("partial_json") String partialJson) implements ContentBlockDeltaBody { } + + /** + * Thinking content block delta. + * @param type The content block type. + * @param thinking The thinking content. + */ + public record ContentBlockDeltaThinking( + @JsonProperty("type") String type, + @JsonProperty("thinking") String thinking) implements ContentBlockDeltaBody { + } + + /** + * Signature content block delta. + * @param type The content block type. + * @param signature The signature content. + */ + public record ContentBlockDeltaSignature( + @JsonProperty("type") String type, + @JsonProperty("signature") String signature) implements ContentBlockDeltaBody { + } } // @formatter:on diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index f141583486b..90dcfdfa171 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -23,6 +23,7 @@ import java.util.Map; import java.util.stream.Collectors; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; @@ -82,7 +83,6 @@ class AnthropicChatModelIT { private static void validateChatResponseMetadata(ChatResponse response, String model) { assertThat(response.getMetadata().getId()).isNotEmpty(); - assertThat(response.getMetadata().getModel()).containsIgnoringCase(model); assertThat(response.getMetadata().getUsage().getPromptTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getCompletionTokens()).isPositive(); assertThat(response.getMetadata().getUsage().getTotalTokens()).isPositive(); @@ -118,7 +118,7 @@ void testMessageHistory() { SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(this.systemResource); Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate")); Prompt prompt = new Prompt(List.of(userMessage, systemMessage), - AnthropicChatOptions.builder().model("claude-3-sonnet-20240229").build()); + AnthropicChatOptions.builder().model("claude-3-5-sonnet-latest").build()); ChatResponse response = this.chatModel.call(prompt); assertThat(response.getResult().getOutput().getText()).containsAnyOf("Blackbeard", "Bartholomew"); @@ -143,9 +143,6 @@ void streamingWithTokenUsage() { assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0); assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens()); - // assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens()); - // assertThat(streamingTokenUsage.getTotalTokens()).isEqualTo(referenceTokenUsage.getTotalTokens()); - } @Test @@ -357,7 +354,7 @@ void streamFunctionCallUsageTest() { @Test void validateCallResponseMetadata() { - String model = AnthropicApi.ChatModel.CLAUDE_2_1.getName(); + String model = AnthropicApi.ChatModel.CLAUDE_3_7_SONNET.getName(); // @formatter:off ChatResponse response = ChatClient.create(this.chatModel).prompt() .options(AnthropicChatOptions.builder().model(model).build()) @@ -384,13 +381,76 @@ void validateStreamCallResponseMetadata() { logger.info(response.toString()); // Note, brittle test. - validateChatResponseMetadata(response, "claude-3-5-sonnet-20241022"); + validateChatResponseMetadata(response, "claude-3-5-sonnet-latest"); } record ActorsFilmsRecord(String actor, List movies) { } + @Test + void thinkingTest() { + UserMessage userMessage = new UserMessage( + "Are there an infinite number of prime numbers such that n mod 4 == 3?"); + + var promptOptions = AnthropicChatOptions.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_7_SONNET.getName()) + .temperature(1.0) // temperature should be set to 1 when thinking is enabled + .maxTokens(8192) + .thinking(AnthropicApi.ThinkingType.ENABLED, 2048) // Must be ≥1024 && < + // max_tokens + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(List.of(userMessage), promptOptions)); + + logger.info("Response: {}", response); + + for (Generation generation : response.getResults()) { + AssistantMessage message = generation.getOutput(); + if (message.getText() != null) { // text + assertThat(message.getText()).isNotBlank(); + } + else if (message.getMetadata().containsKey("signature")) { // thinking + assertThat(message.getMetadata().get("signature")).isNotNull(); + assertThat(message.getMetadata().get("thinking")).isNotNull(); + } + else if (message.getMetadata().containsKey("data")) { // redacted thinking + assertThat(message.getMetadata().get("data")).isNotNull(); + } + } + } + + @Test + void testToolUseContentBlock() { + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco, Tokyo and Paris? Return the result in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AnthropicChatOptions.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS.getName()) + .functionCallbacks(List.of(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description( + "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") + .inputType(MockWeatherService.Request.class) + .build())) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + for (Generation generation : response.getResults()) { + AssistantMessage message = generation.getOutput(); + if (!message.getToolCalls().isEmpty()) { + assertThat(message.getToolCalls()).isNotEmpty(); + AssistantMessage.ToolCall toolCall = message.getToolCalls().get(0); + assertThat(toolCall.id()).isNotBlank(); + assertThat(toolCall.name()).isNotBlank(); + assertThat(toolCall.arguments()).isNotBlank(); + } + } + } + @SpringBootConfiguration public static class Config { diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java index 752e9247fae..7eb86034c75 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/AnthropicApiIT.java @@ -35,6 +35,7 @@ /** * @author Christian Tzolov * @author Jihoon Kim + * @author Alexandros Pappas */ @EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".+") public class AnthropicApiIT { @@ -55,6 +56,39 @@ void chatCompletionEntity() { assertThat(response.getBody()).isNotNull(); } + @Test + void chatCompletionWithThinking() { + AnthropicMessage chatCompletionMessage = new AnthropicMessage(List.of(new ContentBlock("Tell me a Joke?")), + Role.USER); + + ChatCompletionRequest request = ChatCompletionRequest.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_7_SONNET.getValue()) + .messages(List.of(chatCompletionMessage)) + .maxTokens(8192) + .temperature(1.0) // temperature should be set to 1 when thinking is enabled + .thinking(new ChatCompletionRequest.ThinkingConfig(AnthropicApi.ThinkingType.ENABLED, 2048)) + .build(); + + ResponseEntity response = this.anthropicApi.chatCompletionEntity(request); + + assertThat(response).isNotNull(); + assertThat(response.getBody()).isNotNull(); + + List content = response.getBody().content(); + for (ContentBlock block : content) { + if (block.type() == ContentBlock.Type.THINKING) { + assertThat(block.thinking()).isNotBlank(); + assertThat(block.signature()).isNotBlank(); + } + if (block.type() == ContentBlock.Type.REDACTED_THINKING) { + assertThat(block.data()).isNotBlank(); + } + if (block.type() == ContentBlock.Type.TEXT) { + assertThat(block.text()).isNotBlank(); + } + } + } + @Test void chatCompletionStream() { diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java index 079a01b3caf..6ac322220fd 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/api/tool/AnthropicApiToolIT.java @@ -102,11 +102,11 @@ void toolCalls() { private ResponseEntity doCall(List messageConversation) { ChatCompletionRequest chatCompletionRequest = ChatCompletionRequest.builder() - .withModel(AnthropicApi.ChatModel.CLAUDE_3_OPUS) - .withMessages(messageConversation) - .withMaxTokens(1500) - .withTemperature(0.8) - .withTools(this.tools) + .model(AnthropicApi.ChatModel.CLAUDE_3_OPUS) + .messages(messageConversation) + .maxTokens(1500) + .temperature(0.8) + .tools(this.tools) .build(); ResponseEntity response = this.anthropicApi.chatCompletionEntity(chatCompletionRequest); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java index 0c98773d02e..27c631ffd1e 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientIT.java @@ -284,8 +284,8 @@ void streamFunctionCallTest() { } @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", - "claude-3-5-sonnet-20241022" }) + @ValueSource(strings = { "claude-3-opus-latest", "claude-3-5-sonnet-latest", "claude-3-5-haiku-latest", + "claude-3-7-sonnet-latest" }) void multiModalityEmbeddedImage(String modelName) throws IOException { // @formatter:off @@ -303,8 +303,8 @@ void multiModalityEmbeddedImage(String modelName) throws IOException { @Disabled("Currently Anthropic API does not support external image URLs") @ParameterizedTest(name = "{0} : {displayName} ") - @ValueSource(strings = { "claude-3-opus-20240229", "claude-3-sonnet-20240229", "claude-3-haiku-20240307", - "claude-3-5-sonnet-20241022" }) + @ValueSource(strings = { "claude-3-opus-latest", "claude-3-5-sonnet-latest", "claude-3-haiku-latest", + "claude-3-7-sonnet-latest" }) void multiModalityImageUrl(String modelName) throws IOException { // TODO: add url method that wrapps the checked exception.