diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index c1b319a27ff..ad9e31eb01f 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -22,6 +22,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -90,14 +91,15 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) return builder().model(fromOptions.getModel()) .maxTokens(fromOptions.getMaxTokens()) .metadata(fromOptions.getMetadata()) - .stopSequences(fromOptions.getStopSequences()) + .stopSequences( + fromOptions.getStopSequences() != null ? new ArrayList<>(fromOptions.getStopSequences()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .topK(fromOptions.getTopK()) .toolCallbacks(fromOptions.getToolCallbacks()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .build(); } @@ -271,10 +273,35 @@ public void setToolContext(Map toolContext) { } @Override + @SuppressWarnings("unchecked") public AnthropicChatOptions copy() { return fromOptions(this); } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof AnthropicChatOptions that)) { + return false; + } + return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.metadata, that.metadata) + && Objects.equals(this.stopSequences, that.stopSequences) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP) + && Objects.equals(this.topK, that.topK) && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.toolNames, that.toolNames) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.toolContext, that.toolContext); + } + + @Override + public int hashCode() { + return Objects.hash(model, maxTokens, metadata, stopSequences, temperature, topP, topK, toolCallbacks, + toolNames, internalToolExecutionEnabled, toolContext); + } + public static class Builder { private final AnthropicChatOptions options = new AnthropicChatOptions(); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java new file mode 100644 index 00000000000..fbb8409dffc --- /dev/null +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata; + +/** + * Tests for {@link AnthropicChatOptions}. + * + * @author Alexandros Pappas + */ +class AnthropicChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + AnthropicChatOptions options = AnthropicChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stopSequences(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.8) + .topK(50) + .metadata(new Metadata("userId_123")) + .build(); + + assertThat(options).extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata") + .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123")); + } + + @Test + void testCopy() { + AnthropicChatOptions original = AnthropicChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stopSequences(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.8) + .topK(50) + .metadata(new Metadata("userId_123")) + .toolContext(Map.of("key1", "value1")) + .build(); + + AnthropicChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); + assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + } + + @Test + void testSetters() { + AnthropicChatOptions options = new AnthropicChatOptions(); + options.setModel("test-model"); + options.setMaxTokens(100); + options.setTemperature(0.7); + options.setTopK(50); + options.setTopP(0.8); + options.setStopSequences(List.of("stop1", "stop2")); + options.setMetadata(new Metadata("userId_123")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopK()).isEqualTo(50); + assertThat(options.getTopP()).isEqualTo(0.8); + assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); + assertThat(options.getMetadata()).isEqualTo(new Metadata("userId_123")); + } + + @Test + void testDefaultValues() { + AnthropicChatOptions options = new AnthropicChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopK()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getMetadata()).isNull(); + } + +} diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 887c0ad6e74..e516034d316 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; @@ -46,6 +47,7 @@ * @author Thomas Vitale * @author Soby Chacko * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas */ @JsonInclude(Include.NON_NULL) public class AzureOpenAiChatOptions implements ToolCallingChatOptions { @@ -250,18 +252,18 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .maxTokens(fromOptions.getMaxTokens()) .N(fromOptions.getN()) .presencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .user(fromOptions.getUser()) .functionCallbacks(fromOptions.getFunctionCallbacks()) - .functions(fromOptions.getFunctions()) + .functions(fromOptions.getFunctions() != null ? new HashSet<>(fromOptions.getFunctions()) : null) .responseFormat(fromOptions.getResponseFormat()) .seed(fromOptions.getSeed()) .logprobs(fromOptions.isLogprobs()) .topLogprobs(fromOptions.getTopLogProbs()) .enhancements(fromOptions.getEnhancements()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) .streamOptions(fromOptions.getStreamOptions()) .toolCallbacks(fromOptions.getToolCallbacks()) @@ -479,10 +481,44 @@ public void setStreamOptions(ChatCompletionStreamOptions streamOptions) { } @Override + @SuppressWarnings("unchecked") public AzureOpenAiChatOptions copy() { return fromOptions(this); } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof AzureOpenAiChatOptions that)) { + return false; + } + return Objects.equals(this.logitBias, that.logitBias) && Objects.equals(this.user, that.user) + && Objects.equals(this.n, that.n) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.deploymentName, that.deploymentName) + && Objects.equals(this.responseFormat, that.responseFormat) + + && Objects.equals(this.toolCallbacks, that.toolCallbacks) + && Objects.equals(this.toolNames, that.toolNames) + && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) + && Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.topLogProbs, that.topLogProbs) + && Objects.equals(this.enhancements, that.enhancements) + && Objects.equals(this.streamOptions, that.streamOptions) + && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP); + } + + @Override + public int hashCode() { + return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat, + this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs, + this.topLogProbs, this.enhancements, this.streamOptions, this.toolContext, this.maxTokens, + this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP); + } + public static class Builder { protected AzureOpenAiChatOptions options; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java new file mode 100644 index 00000000000..b3a8bfd6d74 --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java @@ -0,0 +1,182 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.azure.openai; + +import java.util.List; +import java.util.Map; + +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; +import com.azure.ai.openai.models.AzureChatGroundingEnhancementConfiguration; +import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; +import com.azure.ai.openai.models.ChatCompletionStreamOptions; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link AzureOpenAiChatOptions}. + * + * @author Alexandros Pappas + */ +class AzureOpenAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + + AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration(); + enhancements.setOcr(new AzureChatOCREnhancementConfiguration(true)); + enhancements.setGrounding(new AzureChatGroundingEnhancementConfiguration(true)); + + AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder() + .deploymentName("test-deployment") + .frequencyPenalty(0.5) + .logitBias(Map.of("token1", 1, "token2", -1)) + .maxTokens(200) + .N(2) + .presencePenalty(0.8) + .stop(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.9) + .user("test-user") + .responseFormat(responseFormat) + .seed(12345L) + .logprobs(true) + .topLogprobs(5) + .enhancements(enhancements) + .streamOptions(streamOptions) + .build(); + + assertThat(options) + .extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop", + "temperature", "topP", "user", "responseFormat", "seed", "logprobs", "topLogProbs", "enhancements", + "streamOptions") + .containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8, + List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, 12345L, true, 5, enhancements, + streamOptions); + } + + @Test + void testCopy() { + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + + AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration(); + enhancements.setOcr(new AzureChatOCREnhancementConfiguration(true)); + enhancements.setGrounding(new AzureChatGroundingEnhancementConfiguration(true)); + + AzureOpenAiChatOptions originalOptions = AzureOpenAiChatOptions.builder() + .deploymentName("test-deployment") + .frequencyPenalty(0.5) + .logitBias(Map.of("token1", 1, "token2", -1)) + .maxTokens(200) + .N(2) + .presencePenalty(0.8) + .stop(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.9) + .user("test-user") + .responseFormat(responseFormat) + .seed(12345L) + .logprobs(true) + .topLogprobs(5) + .enhancements(enhancements) + .streamOptions(streamOptions) + .build(); + + AzureOpenAiChatOptions copiedOptions = originalOptions.copy(); + + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + // Ensure deep copy + assertThat(copiedOptions.getStop()).isNotSameAs(originalOptions.getStop()); + assertThat(copiedOptions.getToolContext()).isNotSameAs(originalOptions.getToolContext()); + } + + @Test + void testSetters() { + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration(); + + AzureOpenAiChatOptions options = new AzureOpenAiChatOptions(); + options.setDeploymentName("test-deployment"); + options.setFrequencyPenalty(0.5); + options.setLogitBias(Map.of("token1", 1, "token2", -1)); + options.setMaxTokens(200); + options.setN(2); + options.setPresencePenalty(0.8); + options.setStop(List.of("stop1", "stop2")); + options.setTemperature(0.7); + options.setTopP(0.9); + options.setUser("test-user"); + options.setResponseFormat(responseFormat); + options.setSeed(12345L); + options.setLogprobs(true); + options.setTopLogProbs(5); + options.setEnhancements(enhancements); + options.setStreamOptions(streamOptions); + + assertThat(options.getDeploymentName()).isEqualTo("test-deployment"); + options.setModel("test-model"); + assertThat(options.getDeploymentName()).isEqualTo("test-model"); + + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getLogitBias()).isEqualTo(Map.of("token1", 1, "token2", -1)); + assertThat(options.getMaxTokens()).isEqualTo(200); + assertThat(options.getN()).isEqualTo(2); + assertThat(options.getPresencePenalty()).isEqualTo(0.8); + assertThat(options.getStop()).isEqualTo(List.of("stop1", "stop2")); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + assertThat(options.getSeed()).isEqualTo(12345L); + assertThat(options.isLogprobs()).isTrue(); + assertThat(options.getTopLogProbs()).isEqualTo(5); + assertThat(options.getEnhancements()).isEqualTo(enhancements); + assertThat(options.getStreamOptions()).isEqualTo(streamOptions); + assertThat(options.getModel()).isEqualTo("test-model"); + } + + @Test + void testDefaultValues() { + AzureOpenAiChatOptions options = new AzureOpenAiChatOptions(); + + assertThat(options.getDeploymentName()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getLogitBias()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getN()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getSeed()).isNull(); + assertThat(options.isLogprobs()).isNull(); + assertThat(options.getTopLogProbs()).isNull(); + assertThat(options.getEnhancements()).isNull(); + assertThat(options.getStreamOptions()).isNull(); + assertThat(options.getModel()).isNull(); + } + +} diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 244c1fce699..491da85b457 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,11 @@ package org.springframework.ai.minimax; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -43,6 +45,7 @@ * @author Geng Rong * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -146,7 +149,7 @@ public class MiniMaxChatOptions implements FunctionCallingOptions { private Boolean proxyToolCalls; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); // @formatter:on @@ -162,7 +165,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .presencePenalty(fromOptions.getPresencePenalty()) .responseFormat(fromOptions.getResponseFormat()) .seed(fromOptions.getSeed()) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .maskSensitiveInfo(fromOptions.getMaskSensitiveInfo()) @@ -171,7 +174,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .functionCallbacks(fromOptions.getFunctionCallbacks()) .functions(fromOptions.getFunctions()) .proxyToolCalls(fromOptions.getProxyToolCalls()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .build(); } @@ -342,24 +345,8 @@ public void setToolContext(Map toolContext) { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.n == null) ? 0 : this.n.hashCode()); - result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); - result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); - result = prime * result + ((this.seed == null) ? 0 : this.seed.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.maskSensitiveInfo == null) ? 0 : this.maskSensitiveInfo.hashCode()); - result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); - result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); - result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); - result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); - return result; + return Objects.hash(model, frequencyPenalty, maxTokens, n, presencePenalty, responseFormat, seed, stop, + temperature, topP, maskSensitiveInfo, tools, toolChoice, proxyToolCalls, toolContext); } @Override @@ -367,139 +354,25 @@ public boolean equals(Object obj) { if (this == obj) { return true; } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { + if (obj == null || getClass() != obj.getClass()) { return false; } MiniMaxChatOptions other = (MiniMaxChatOptions) obj; - if (this.model == null) { - if (other.model != null) { - return false; - } - } - else if (!this.model.equals(other.model)) { - return false; - } - if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) { - return false; - } - } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { - return false; - } - if (this.maxTokens == null) { - if (other.maxTokens != null) { - return false; - } - } - else if (!this.maxTokens.equals(other.maxTokens)) { - return false; - } - if (this.n == null) { - if (other.n != null) { - return false; - } - } - else if (!this.n.equals(other.n)) { - return false; - } - if (this.presencePenalty == null) { - if (other.presencePenalty != null) { - return false; - } - } - else if (!this.presencePenalty.equals(other.presencePenalty)) { - return false; - } - if (this.responseFormat == null) { - if (other.responseFormat != null) { - return false; - } - } - else if (!this.responseFormat.equals(other.responseFormat)) { - return false; - } - if (this.seed == null) { - if (other.seed != null) { - return false; - } - } - else if (!this.seed.equals(other.seed)) { - return false; - } - if (this.stop == null) { - if (other.stop != null) { - return false; - } - } - else if (!this.stop.equals(other.stop)) { - return false; - } - if (this.temperature == null) { - if (other.temperature != null) { - return false; - } - } - else if (!this.temperature.equals(other.temperature)) { - return false; - } - if (this.topP == null) { - if (other.topP != null) { - return false; - } - } - else if (!this.topP.equals(other.topP)) { - return false; - } - if (this.maskSensitiveInfo == null) { - if (other.maskSensitiveInfo != null) { - return false; - } - } - else if (!this.maskSensitiveInfo.equals(other.maskSensitiveInfo)) { - return false; - } - if (this.tools == null) { - if (other.tools != null) { - return false; - } - } - else if (!this.tools.equals(other.tools)) { - return false; - } - if (this.toolChoice == null) { - if (other.toolChoice != null) { - return false; - } - } - else if (!this.toolChoice.equals(other.toolChoice)) { - return false; - } - if (this.proxyToolCalls == null) { - if (other.proxyToolCalls != null) { - return false; - } - } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { - return false; - } - - if (this.toolContext == null) { - if (other.toolContext != null) { - return false; - } - } - else if (!this.toolContext.equals(other.toolContext)) { - return false; - } - return true; + return Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) + && Objects.equals(this.maxTokens, other.maxTokens) && Objects.equals(this.n, other.n) + && Objects.equals(this.presencePenalty, other.presencePenalty) + && Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.seed, other.seed) + && Objects.equals(this.stop, other.stop) && Objects.equals(this.temperature, other.temperature) + && Objects.equals(this.topP, other.topP) + && Objects.equals(this.maskSensitiveInfo, other.maskSensitiveInfo) + && Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice) + && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.toolContext, other.toolContext); } @Override + @SuppressWarnings("unchecked") public MiniMaxChatOptions copy() { return fromOptions(this); } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxChatOptionsTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxChatOptionsTests.java new file mode 100644 index 00000000000..6fbd605fe48 --- /dev/null +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxChatOptionsTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.minimax; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.minimax.api.MiniMaxApi; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link MiniMaxChatOptions}. + * + * @author Alexandros Pappas + */ +class MiniMaxChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + MiniMaxChatOptions options = MiniMaxChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .N(1) + .presencePenalty(0.5) + .responseFormat(new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text")) + .seed(1) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .maskSensitiveInfo(false) + .toolChoice("test") + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "maxTokens", "N", "presencePenalty", "responseFormat", "seed", + "stopSequences", "temperature", "topP", "maskSensitiveInfo", "toolChoice", "proxyToolCalls", + "toolContext") + .containsExactly("test-model", 0.5, 10, 1, 0.5, new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text"), + 1, List.of("test"), 0.6, 0.6, false, "test", true, Map.of("key1", "value1")); + } + + @Test + void testCopy() { + MiniMaxChatOptions original = MiniMaxChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .N(1) + .presencePenalty(0.5) + .responseFormat(new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text")) + .seed(1) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .maskSensitiveInfo(false) + .toolChoice("test") + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + MiniMaxChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + } + + @Test + void testSetters() { + MiniMaxChatOptions options = new MiniMaxChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.5); + options.setMaxTokens(10); + options.setN(1); + options.setPresencePenalty(0.5); + options.setResponseFormat(new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text")); + options.setSeed(1); + options.setStopSequences(List.of("test")); + options.setTemperature(0.6); + options.setTopP(0.6); + options.setMaskSensitiveInfo(false); + options.setToolChoice("test"); + options.setProxyToolCalls(true); + options.setToolContext(Map.of("key1", "value1")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getMaxTokens()).isEqualTo(10); + assertThat(options.getN()).isEqualTo(1); + assertThat(options.getPresencePenalty()).isEqualTo(0.5); + assertThat(options.getResponseFormat()).isEqualTo(new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text")); + assertThat(options.getSeed()).isEqualTo(1); + assertThat(options.getStopSequences()).isEqualTo(List.of("test")); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.6); + assertThat(options.getMaskSensitiveInfo()).isEqualTo(false); + assertThat(options.getToolChoice()).isEqualTo("test"); + assertThat(options.getProxyToolCalls()).isEqualTo(true); + assertThat(options.getToolContext()).isEqualTo(Map.of("key1", "value1")); + } + + @Test + void testDefaultValues() { + MiniMaxChatOptions options = new MiniMaxChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getN()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getSeed()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getMaskSensitiveInfo()).isNull(); + assertThat(options.getToolChoice()).isNull(); + assertThat(options.getProxyToolCalls()).isNull(); + assertThat(options.getToolContext()).isEqualTo(new java.util.HashMap<>()); + } + +} diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index a59e8a71e58..987c2b35182 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -150,13 +150,13 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .responseFormat(fromOptions.getResponseFormat()) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) .tools(fromOptions.getTools()) .toolChoice(fromOptions.getToolChoice()) .toolCallbacks(fromOptions.getToolCallbacks()) .toolNames(fromOptions.getToolNames()) .internalToolExecutionEnabled(fromOptions.isInternalToolExecutionEnabled()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .build(); } @@ -369,6 +369,7 @@ public void setToolContext(Map toolContext) { } @Override + @SuppressWarnings("unchecked") public MistralAiChatOptions copy() { return fromOptions(this); } diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java new file mode 100644 index 00000000000..2871eb0f576 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mistralai; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; + +import org.springframework.ai.mistralai.api.MistralAiApi; + +/** + * Tests for {@link MistralAiChatOptions}. + * + * @author Alexandros Pappas + */ +class MistralAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + MistralAiChatOptions options = MistralAiChatOptions.builder() + .model("test-model") + .temperature(0.7) + .topP(0.9) + .maxTokens(100) + .safePrompt(true) + .randomSeed(123) + .stop(List.of("stop1", "stop2")) + .responseFormat(new ResponseFormat("json_object")) + .toolChoice(MistralAiApi.ChatCompletionRequest.ToolChoice.AUTO) + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + assertThat(options) + .extracting("model", "temperature", "topP", "maxTokens", "safePrompt", "randomSeed", "stop", + "responseFormat", "toolChoice", "proxyToolCalls", "toolContext") + .containsExactly("test-model", 0.7, 0.9, 100, true, 123, List.of("stop1", "stop2"), + new ResponseFormat("json_object"), MistralAiApi.ChatCompletionRequest.ToolChoice.AUTO, true, + Map.of("key1", "value1")); + } + + @Test + void testBuilderWithEnum() { + MistralAiChatOptions optionsWithEnum = MistralAiChatOptions.builder() + .model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B) + .build(); + assertThat(optionsWithEnum.getModel()).isEqualTo(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()); + } + + @Test + void testCopy() { + MistralAiChatOptions options = MistralAiChatOptions.builder() + .model("test-model") + .temperature(0.7) + .topP(0.9) + .maxTokens(100) + .safePrompt(true) + .randomSeed(123) + .stop(List.of("stop1", "stop2")) + .responseFormat(new ResponseFormat("json_object")) + .toolChoice(MistralAiApi.ChatCompletionRequest.ToolChoice.AUTO) + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + MistralAiChatOptions copiedOptions = options.copy(); + assertThat(copiedOptions).isNotSameAs(options).isEqualTo(options); + // Ensure deep copy + assertThat(copiedOptions.getStop()).isNotSameAs(options.getStop()); + assertThat(copiedOptions.getToolContext()).isNotSameAs(options.getToolContext()); + } + + @Test + void testSetters() { + ResponseFormat responseFormat = new ResponseFormat("json_object"); + MistralAiChatOptions options = new MistralAiChatOptions(); + options.setModel("test-model"); + options.setTemperature(0.7); + options.setTopP(0.9); + options.setMaxTokens(100); + options.setSafePrompt(true); + options.setRandomSeed(123); + options.setResponseFormat(responseFormat); + options.setStopSequences(List.of("stop1", "stop2")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getSafePrompt()).isEqualTo(true); + assertThat(options.getRandomSeed()).isEqualTo(123); + assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + } + + @Test + void testDefaultValues() { + MistralAiChatOptions options = new MistralAiChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getSafePrompt()).isNull(); + assertThat(options.getRandomSeed()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + } + +} diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java index 3efddaf89c3..a707a1c08d0 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,11 @@ package org.springframework.ai.moonshot; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -141,7 +143,7 @@ public class MoonshotChatOptions implements FunctionCallingOptions { private Boolean proxyToolCalls; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); public static Builder builder() { return new Builder(); @@ -281,6 +283,7 @@ public void setToolContext(Map toolContext) { } @Override + @SuppressWarnings("unchecked") public MoonshotChatOptions copy() { return builder().model(this.model) .maxTokens(this.maxTokens) @@ -289,130 +292,38 @@ public MoonshotChatOptions copy() { .N(this.n) .presencePenalty(this.presencePenalty) .frequencyPenalty(this.frequencyPenalty) - .stop(this.stop) + .stop(this.stop != null ? new ArrayList<>(this.stop) : null) .user(this.user) - .tools(this.tools) + .tools(this.tools != null ? new ArrayList<>(this.tools) : null) .toolChoice(this.toolChoice) .functionCallbacks(this.functionCallbacks) .functions(this.functions) .proxyToolCalls(this.proxyToolCalls) - .toolContext(this.toolContext) + .toolContext(this.toolContext != null ? new HashMap<>(this.toolContext) : null) .build(); } @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.n == null) ? 0 : this.n.hashCode()); - result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); - result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); - result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); - return result; + return Objects.hash(model, frequencyPenalty, maxTokens, n, presencePenalty, stop, temperature, topP, user, + proxyToolCalls, toolContext); } @Override public boolean equals(Object obj) { - if (this == obj) { + if (this == obj) return true; - } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { + if (obj == null || getClass() != obj.getClass()) return false; - } + MoonshotChatOptions other = (MoonshotChatOptions) obj; - if (this.model == null) { - if (other.model != null) { - return false; - } - } - else if (!this.model.equals(other.model)) { - return false; - } - if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) { - return false; - } - } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { - return false; - } - if (this.maxTokens == null) { - if (other.maxTokens != null) { - return false; - } - } - else if (!this.maxTokens.equals(other.maxTokens)) { - return false; - } - if (this.n == null) { - if (other.n != null) { - return false; - } - } - else if (!this.n.equals(other.n)) { - return false; - } - if (this.presencePenalty == null) { - if (other.presencePenalty != null) { - return false; - } - } - else if (!this.presencePenalty.equals(other.presencePenalty)) { - return false; - } - if (this.stop == null) { - if (other.stop != null) { - return false; - } - } - else if (!this.stop.equals(other.stop)) { - return false; - } - if (this.temperature == null) { - if (other.temperature != null) { - return false; - } - } - else if (!this.temperature.equals(other.temperature)) { - return false; - } - if (this.topP == null) { - if (other.topP != null) { - return false; - } - } - else if (!this.topP.equals(other.topP)) { - return false; - } - if (this.user == null) { - return other.user == null; - } - else if (!this.user.equals(other.user)) { - return false; - } - if (this.proxyToolCalls == null) { - return other.proxyToolCalls == null; - } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { - return false; - } - if (this.toolContext == null) { - return other.toolContext == null; - } - else if (!this.toolContext.equals(other.toolContext)) { - return false; - } - return true; + + return Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) + && Objects.equals(this.maxTokens, other.maxTokens) && Objects.equals(this.n, other.n) + && Objects.equals(this.presencePenalty, other.presencePenalty) && Objects.equals(this.stop, other.stop) + && Objects.equals(this.temperature, other.temperature) && Objects.equals(this.topP, other.topP) + && Objects.equals(this.user, other.user) && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.toolContext, other.toolContext); } public static class Builder { diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatOptionsTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatOptionsTests.java new file mode 100644 index 00000000000..9486b65b3ba --- /dev/null +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatOptionsTests.java @@ -0,0 +1,126 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.moonshot; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link MoonshotChatOptions}. + * + * @author Alexandros Pappas + */ +class MoonshotChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + MoonshotChatOptions options = MoonshotChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .N(1) + .presencePenalty(0.5) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .toolChoice("test") + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .user("test-user") + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "maxTokens", "N", "presencePenalty", "stop", "temperature", "topP", + "toolChoice", "proxyToolCalls", "toolContext", "user") + .containsExactly("test-model", 0.5, 10, 1, 0.5, List.of("test"), 0.6, 0.6, "test", true, + Map.of("key1", "value1"), "test-user"); + } + + @Test + void testCopy() { + MoonshotChatOptions original = MoonshotChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .N(1) + .presencePenalty(0.5) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .toolChoice("test") + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .user("test-user") + .build(); + + MoonshotChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + } + + @Test + void testSetters() { + MoonshotChatOptions options = new MoonshotChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.5); + options.setMaxTokens(10); + options.setN(1); + options.setPresencePenalty(0.5); + options.setUser("test-user"); + options.setStop(List.of("test")); + options.setTemperature(0.6); + options.setTopP(0.6); + options.setProxyToolCalls(true); + options.setToolContext(Map.of("key1", "value1")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getMaxTokens()).isEqualTo(10); + assertThat(options.getN()).isEqualTo(1); + assertThat(options.getPresencePenalty()).isEqualTo(0.5); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getStopSequences()).isEqualTo(List.of("test")); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.6); + assertThat(options.getProxyToolCalls()).isTrue(); + assertThat(options.getToolContext()).isEqualTo(Map.of("key1", "value1")); + } + + @Test + void testDefaultValues() { + MoonshotChatOptions options = new MoonshotChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getN()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getProxyToolCalls()).isNull(); + assertThat(options.getToolContext()).isEqualTo(new java.util.HashMap<>()); + } + +} diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java index afd17ab456f..1516d88c252 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ package org.springframework.ai.oci.cohere; +import java.util.ArrayList; import java.util.List; +import java.util.Objects; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; @@ -29,6 +31,7 @@ * * @author Anders Swanson * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas */ @JsonInclude(JsonInclude.Include.NON_NULL) public class OCICohereChatOptions implements ChatOptions { @@ -124,11 +127,11 @@ public static OCICohereChatOptions fromOptions(OCICohereChatOptions fromOptions) .temperature(fromOptions.temperature) .topP(fromOptions.topP) .topK(fromOptions.topK) - .stop(fromOptions.stop) + .stop(fromOptions.stop != null ? new ArrayList<>(fromOptions.stop) : null) .frequencyPenalty(fromOptions.frequencyPenalty) .presencePenalty(fromOptions.presencePenalty) - .documents(fromOptions.documents) - .tools(fromOptions.tools) + .documents(fromOptions.documents != null ? new ArrayList<>(fromOptions.documents) : null) + .tools(fromOptions.tools != null ? new ArrayList<>(fromOptions.tools) : null) .build(); } @@ -257,10 +260,37 @@ public Double getTopP() { } @Override + @SuppressWarnings("unchecked") public ChatOptions copy() { return fromOptions(this); } + @Override + public int hashCode() { + return Objects.hash(model, maxTokens, compartment, servingMode, preambleOverride, temperature, topP, topK, stop, + frequencyPenalty, presencePenalty, documents, tools); + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + OCICohereChatOptions that = (OCICohereChatOptions) o; + + return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.compartment, that.compartment) + && Objects.equals(this.servingMode, that.servingMode) + && Objects.equals(this.preambleOverride, that.preambleOverride) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP) + && Objects.equals(this.topK, that.topK) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.documents, that.documents) && Objects.equals(this.tools, that.tools); + } + public static class Builder { protected OCICohereChatOptions chatOptions; diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java new file mode 100644 index 00000000000..4359d920bb5 --- /dev/null +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java @@ -0,0 +1,133 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.oci.cohere; + +import com.oracle.bmc.generativeaiinference.model.CohereTool; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OCICohereChatOptions}. + * + * @author Alexandros Pappas + */ +class OCICohereChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + OCICohereChatOptions options = OCICohereChatOptions.builder() + .model("test-model") + .maxTokens(10) + .compartment("test-compartment") + .servingMode("test-servingMode") + .preambleOverride("test-preambleOverride") + .temperature(0.6) + .topP(0.6) + .topK(50) + .stop(List.of("test")) + .frequencyPenalty(0.5) + .presencePenalty(0.5) + .documents(List.of("doc1", "doc2")) + .build(); + + assertThat(options) + .extracting("model", "maxTokens", "compartment", "servingMode", "preambleOverride", "temperature", "topP", + "topK", "stop", "frequencyPenalty", "presencePenalty", "documents") + .containsExactly("test-model", 10, "test-compartment", "test-servingMode", "test-preambleOverride", 0.6, + 0.6, 50, List.of("test"), 0.5, 0.5, List.of("doc1", "doc2")); + } + + @Test + void testCopy() { + OCICohereChatOptions original = OCICohereChatOptions.builder() + .model("test-model") + .maxTokens(10) + .compartment("test-compartment") + .servingMode("test-servingMode") + .preambleOverride("test-preambleOverride") + .temperature(0.6) + .topP(0.6) + .topK(50) + .stop(List.of("test")) + .frequencyPenalty(0.5) + .presencePenalty(0.5) + .documents(List.of("doc1", "doc2")) + .tools(List.of(new CohereTool("test-tool", "test-context", Map.of()))) + .build(); + + OCICohereChatOptions copied = (OCICohereChatOptions) original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + assertThat(copied.getDocuments()).isNotSameAs(original.getDocuments()); + assertThat(copied.getTools()).isNotSameAs(original.getTools()); + } + + @Test + void testSetters() { + OCICohereChatOptions options = new OCICohereChatOptions(); + options.setModel("test-model"); + options.setMaxTokens(10); + options.setCompartment("test-compartment"); + options.setServingMode("test-servingMode"); + options.setPreambleOverride("test-preambleOverride"); + options.setTemperature(0.6); + options.setTopP(0.6); + options.setTopK(50); + options.setStop(List.of("test")); + options.setFrequencyPenalty(0.5); + options.setPresencePenalty(0.5); + options.setDocuments(List.of("doc1", "doc2")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getMaxTokens()).isEqualTo(10); + assertThat(options.getCompartment()).isEqualTo("test-compartment"); + assertThat(options.getServingMode()).isEqualTo("test-servingMode"); + assertThat(options.getPreambleOverride()).isEqualTo("test-preambleOverride"); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.6); + assertThat(options.getTopK()).isEqualTo(50); + assertThat(options.getStop()).isEqualTo(List.of("test")); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getPresencePenalty()).isEqualTo(0.5); + assertThat(options.getDocuments()).isEqualTo(List.of("doc1", "doc2")); + } + + @Test + void testDefaultValues() { + OCICohereChatOptions options = new OCICohereChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getCompartment()).isNull(); + assertThat(options.getServingMode()).isNull(); + assertThat(options.getPreambleOverride()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getTopK()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getDocuments()).isNull(); + assertThat(options.getTools()).isNull(); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index c1193d4b261..b4d4b9b7220 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -49,6 +49,7 @@ * @author Mariusz Bernacki * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) @@ -588,6 +589,7 @@ public void setReasoningEffort(String reasoningEffort) { } @Override + @SuppressWarnings("unchecked") public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java new file mode 100644 index 00000000000..5fc63a547ca --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java @@ -0,0 +1,262 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.Test; +import static org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters.Voice.ALLOY; + +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions; +import org.springframework.ai.openai.api.ResponseFormat; + +/** + * Tests for {@link OpenAiChatOptions}. + * + * @author Alexandros Pappas + */ +class OpenAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + Map logitBias = new HashMap<>(); + logitBias.put("token1", 1); + logitBias.put("token2", -1); + + List outputModalities = List.of("text", "audio"); + AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3); + ResponseFormat responseFormat = new ResponseFormat(); + StreamOptions streamOptions = StreamOptions.INCLUDE_USAGE; + List stopSequences = List.of("stop1", "stop2"); + List tools = new ArrayList<>(); + Object toolChoice = "auto"; + Map metadata = Map.of("key1", "value1"); + Map toolContext = Map.of("keyA", "valueA"); + + OpenAiChatOptions options = OpenAiChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .logitBias(logitBias) + .logprobs(true) + .topLogprobs(5) + .maxTokens(100) + .maxCompletionTokens(50) + .N(2) + .outputModalities(outputModalities) + .outputAudio(outputAudio) + .presencePenalty(0.8) + .responseFormat(responseFormat) + .streamUsage(true) + .seed(12345) + .stop(stopSequences) + .temperature(0.7) + .topP(0.9) + .tools(tools) + .toolChoice(toolChoice) + .user("test-user") + .parallelToolCalls(true) + .store(false) + .metadata(metadata) + .reasoningEffort("medium") + .proxyToolCalls(false) + .httpHeaders(Map.of("header1", "value1")) + .toolContext(toolContext) + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "logitBias", "logprobs", "topLogprobs", "maxTokens", + "maxCompletionTokens", "n", "outputModalities", "outputAudio", "presencePenalty", "responseFormat", + "streamOptions", "seed", "stop", "temperature", "topP", "tools", "toolChoice", "user", + "parallelToolCalls", "store", "metadata", "reasoningEffort", "proxyToolCalls", "httpHeaders", + "toolContext") + .containsExactly("test-model", 0.5, logitBias, true, 5, 100, 50, 2, outputModalities, outputAudio, 0.8, + responseFormat, streamOptions, 12345, stopSequences, 0.7, 0.9, tools, toolChoice, "test-user", true, + false, metadata, "medium", false, Map.of("header1", "value1"), toolContext); + + assertThat(options.getStreamUsage()).isTrue(); + assertThat(options.getStreamOptions()).isEqualTo(StreamOptions.INCLUDE_USAGE); + + } + + @Test + void testCopy() { + Map logitBias = new HashMap<>(); + logitBias.put("token1", 1); + + List outputModalities = List.of("text"); + AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3); + ResponseFormat responseFormat = new ResponseFormat(); + + List stopSequences = List.of("stop1"); + List tools = new ArrayList<>(); + Object toolChoice = "none"; + Map metadata = Map.of("key1", "value1"); + + OpenAiChatOptions originalOptions = OpenAiChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .logitBias(logitBias) + .logprobs(true) + .topLogprobs(5) + .maxTokens(100) + .maxCompletionTokens(50) + .N(2) + .outputModalities(outputModalities) + .outputAudio(outputAudio) + .presencePenalty(0.8) + .responseFormat(responseFormat) + .streamUsage(false) + .seed(12345) + .stop(stopSequences) + .temperature(0.7) + .topP(0.9) + .tools(tools) + .toolChoice(toolChoice) + .user("test-user") + .parallelToolCalls(false) + .store(true) + .metadata(metadata) + .reasoningEffort("low") + .proxyToolCalls(true) + .httpHeaders(Map.of("header1", "value1")) + .build(); + + OpenAiChatOptions copiedOptions = originalOptions.copy(); + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + } + + @Test + void testSetters() { + Map logitBias = new HashMap<>(); + logitBias.put("token1", 1); + + List outputModalities = List.of("audio"); + AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3); + ResponseFormat responseFormat = new ResponseFormat(); + + StreamOptions streamOptions = StreamOptions.INCLUDE_USAGE; + List stopSequences = List.of("stop1", "stop2"); + List tools = new ArrayList<>(); + Object toolChoice = "auto"; + Map metadata = Map.of("key2", "value2"); + + OpenAiChatOptions options = new OpenAiChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.5); + options.setLogitBias(logitBias); + options.setLogprobs(true); + options.setTopLogprobs(5); + options.setMaxTokens(100); + options.setMaxCompletionTokens(50); + options.setN(2); + options.setOutputModalities(outputModalities); + options.setOutputAudio(outputAudio); + options.setPresencePenalty(0.8); + options.setResponseFormat(responseFormat); + options.setStreamOptions(streamOptions); + options.setSeed(12345); + options.setStop(stopSequences); + options.setTemperature(0.7); + options.setTopP(0.9); + options.setTools(tools); + options.setToolChoice(toolChoice); + options.setUser("test-user"); + options.setParallelToolCalls(true); + options.setStore(false); + options.setMetadata(metadata); + options.setReasoningEffort("high"); + options.setProxyToolCalls(false); + options.setHttpHeaders(Map.of("header2", "value2")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getLogitBias()).isEqualTo(logitBias); + assertThat(options.getLogprobs()).isTrue(); + assertThat(options.getTopLogprobs()).isEqualTo(5); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getMaxCompletionTokens()).isEqualTo(50); + assertThat(options.getN()).isEqualTo(2); + assertThat(options.getOutputModalities()).isEqualTo(outputModalities); + assertThat(options.getOutputAudio()).isEqualTo(outputAudio); + assertThat(options.getPresencePenalty()).isEqualTo(0.8); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + assertThat(options.getStreamOptions()).isEqualTo(streamOptions); + assertThat(options.getSeed()).isEqualTo(12345); + assertThat(options.getStop()).isEqualTo(stopSequences); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getTools()).isEqualTo(tools); + assertThat(options.getToolChoice()).isEqualTo(toolChoice); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getParallelToolCalls()).isTrue(); + assertThat(options.getStore()).isFalse(); + assertThat(options.getMetadata()).isEqualTo(metadata); + assertThat(options.getReasoningEffort()).isEqualTo("high"); + assertThat(options.getProxyToolCalls()).isFalse(); + assertThat(options.getHttpHeaders()).isEqualTo(Map.of("header2", "value2")); + assertThat(options.getStreamUsage()).isTrue(); + options.setStreamUsage(false); + assertThat(options.getStreamUsage()).isFalse(); + assertThat(options.getStreamOptions()).isNull(); + options.setStopSequences(List.of("s1", "s2")); + assertThat(options.getStopSequences()).isEqualTo(List.of("s1", "s2")); + assertThat(options.getStop()).isEqualTo(List.of("s1", "s2")); + } + + @Test + void testDefaultValues() { + OpenAiChatOptions options = new OpenAiChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getLogitBias()).isNull(); + assertThat(options.getLogprobs()).isNull(); + assertThat(options.getTopLogprobs()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getMaxCompletionTokens()).isNull(); + assertThat(options.getN()).isNull(); + assertThat(options.getOutputModalities()).isNull(); + assertThat(options.getOutputAudio()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getStreamOptions()).isNull(); + assertThat(options.getSeed()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getTools()).isNull(); + assertThat(options.getToolChoice()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getParallelToolCalls()).isNull(); + assertThat(options.getStore()).isNull(); + assertThat(options.getMetadata()).isNull(); + assertThat(options.getReasoningEffort()).isNull(); + assertThat(options.getFunctionCallbacks()).isNotNull().isEmpty(); + assertThat(options.getFunctions()).isNotNull().isEmpty(); + assertThat(options.getProxyToolCalls()).isNull(); + assertThat(options.getHttpHeaders()).isNotNull().isEmpty(); + assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); + assertThat(options.getStreamUsage()).isFalse(); + assertThat(options.getStopSequences()).isNull(); + } + +} diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java index 5531e90a88a..22484aa9a3a 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,9 @@ package org.springframework.ai.qianfan; +import java.util.ArrayList; import java.util.List; +import java.util.Objects; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -33,6 +35,7 @@ * * @author Geng Rong * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas * @since 1.0 * @see ChatOptions */ @@ -64,6 +67,7 @@ public class QianFanChatOptions implements ChatOptions { * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. */ private @JsonProperty("response_format") QianFanApi.ChatCompletionRequest.ResponseFormat responseFormat; + /** * Up to 4 sequences where the API will stop generating further tokens. */ @@ -93,7 +97,7 @@ public static QianFanChatOptions fromOptions(QianFanChatOptions fromOptions) { .maxTokens(fromOptions.getMaxTokens()) .presencePenalty(fromOptions.getPresencePenalty()) .responseFormat(fromOptions.getResponseFormat()) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .build(); @@ -187,100 +191,28 @@ public Integer getTopK() { } @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); - result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { + public boolean equals(Object o) { + if (this == o) { return true; } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { - return false; - } - QianFanChatOptions other = (QianFanChatOptions) obj; - if (this.model == null) { - if (other.model != null) { - return false; - } - } - else if (!this.model.equals(other.model)) { + if (!(o instanceof QianFanChatOptions that)) { return false; } - if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) { - return false; - } - } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { - return false; - } - if (this.maxTokens == null) { - if (other.maxTokens != null) { - return false; - } - } - else if (!this.maxTokens.equals(other.maxTokens)) { - return false; - } - if (this.presencePenalty == null) { - if (other.presencePenalty != null) { - return false; - } - } - else if (!this.presencePenalty.equals(other.presencePenalty)) { - return false; - } - if (this.responseFormat == null) { - if (other.responseFormat != null) { - return false; - } - } - else if (!this.responseFormat.equals(other.responseFormat)) { - return false; - } - if (this.stop == null) { - if (other.stop != null) { - return false; - } - } - else if (!this.stop.equals(other.stop)) { - return false; - } - if (this.temperature == null) { - if (other.temperature != null) { - return false; - } - } - else if (!this.temperature.equals(other.temperature)) { - return false; - } - if (this.topP == null) { - if (other.topP != null) { - return false; - } - } - else if (!this.topP.equals(other.topP)) { - return false; - } - return true; + return Objects.equals(this.model, that.model) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.responseFormat, that.responseFormat) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP); + } + + @Override + public int hashCode() { + return Objects.hash(this.model, this.frequencyPenalty, this.maxTokens, this.presencePenalty, + this.responseFormat, this.stop, this.temperature, this.topP); } @Override + @SuppressWarnings("unchecked") public QianFanChatOptions copy() { return fromOptions(this); } diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanChatOptionsTests.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanChatOptionsTests.java new file mode 100644 index 00000000000..0a18ed78590 --- /dev/null +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanChatOptionsTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.qianfan; + +import java.util.List; + +import com.fasterxml.jackson.databind.ObjectMapper; +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.qianfan.api.QianFanApi; + +/** + * Tests for {@link QianFanChatOptions}. + */ +class QianFanChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + QianFanChatOptions options = QianFanChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .presencePenalty(0.5) + .responseFormat(new QianFanApi.ChatCompletionRequest.ResponseFormat("text")) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "responseFormat", "stop", + "temperature", "topP") + .containsExactly("test-model", 0.5, 10, 0.5, new QianFanApi.ChatCompletionRequest.ResponseFormat("text"), + List.of("test"), 0.6, 0.6); + } + + @Test + void testCopy() { + QianFanChatOptions original = QianFanChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .presencePenalty(0.5) + .responseFormat(new QianFanApi.ChatCompletionRequest.ResponseFormat("text")) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .build(); + + QianFanChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + } + + @Test + void testSetters() { + QianFanChatOptions options = new QianFanChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.5); + options.setMaxTokens(10); + options.setPresencePenalty(0.5); + options.setResponseFormat(new QianFanApi.ChatCompletionRequest.ResponseFormat("text")); + options.setStop(List.of("test")); + options.setTemperature(0.6); + options.setTopP(0.6); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getMaxTokens()).isEqualTo(10); + assertThat(options.getPresencePenalty()).isEqualTo(0.5); + assertThat(options.getResponseFormat()).isEqualTo(new QianFanApi.ChatCompletionRequest.ResponseFormat("text")); + assertThat(options.getStop()).isEqualTo(List.of("test")); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.6); + } + + @Test + void testDefaultValues() { + QianFanChatOptions options = new QianFanChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + } + + @Test + void testSerialization() throws Exception { + QianFanChatOptions options = QianFanChatOptions.builder().maxTokens(10).build(); + + ObjectMapper objectMapper = new ObjectMapper(); + String json = objectMapper.writeValueAsString(options); + + assertThat(json).contains("\"max_output_tokens\":10"); + } + + @Test + void testDeserialization() throws Exception { + String json = "{\"max_output_tokens\":10}"; + + ObjectMapper objectMapper = new ObjectMapper(); + QianFanChatOptions options = objectMapper.readValue(json, QianFanChatOptions.class); + + assertThat(options.getMaxTokens()).isEqualTo(10); + } + +} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java index f6413c759cb..440dc718ed5 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,9 +16,11 @@ package org.springframework.ai.watsonx; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonAnyGetter; @@ -162,11 +164,11 @@ public static WatsonxAiChatOptions fromOptions(WatsonxAiChatOptions fromOptions) .decodingMethod(fromOptions.getDecodingMethod()) .maxNewTokens(fromOptions.getMaxNewTokens()) .minNewTokens(fromOptions.getMinNewTokens()) - .stopSequences(fromOptions.getStopSequences()) + .stopSequences(fromOptions.getStopSequences() != null ? new ArrayList<>(fromOptions.getStopSequences()) : null) .repetitionPenalty(fromOptions.getRepetitionPenalty()) .randomSeed(fromOptions.getRandomSeed()) .model(fromOptions.getModel()) - .additionalProperties(fromOptions.getAdditionalProperties()) + .additionalProperties(fromOptions.getAdditionalProperties() != null ? new HashMap<>(fromOptions.getAdditionalProperties()) : null) .build(); } @@ -319,10 +321,38 @@ private String toSnakeCase(String input) { } @Override + @SuppressWarnings("unchecked") public WatsonxAiChatOptions copy() { return fromOptions(this); } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + WatsonxAiChatOptions that = (WatsonxAiChatOptions) o; + + return Objects.equals(this.decodingMethod, that.decodingMethod) && + Objects.equals(this.maxNewTokens, that.maxNewTokens) && + Objects.equals(this.minNewTokens, that.minNewTokens) && + Objects.equals(this.repetitionPenalty, that.repetitionPenalty) && + Objects.equals(this.randomSeed, that.randomSeed) && + Objects.equals(this.temperature, that.temperature) && + Objects.equals(this.topP, that.topP) && + Objects.equals(this.topK, that.topK) && + Objects.equals(this.stopSequences, that.stopSequences) && + Objects.equals(this.model, that.model) && + Objects.equals(this.additional, that.additional); + } + + @Override + public int hashCode() { + return Objects.hash(this.decodingMethod, this.maxNewTokens, this.minNewTokens, + this.repetitionPenalty, this.randomSeed, this.temperature, + this.topP, this.topK, this.stopSequences, this.model, this.additional); + } + public static class Builder { WatsonxAiChatOptions options = new WatsonxAiChatOptions(); diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatOptionsTests.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatOptionsTests.java new file mode 100644 index 00000000000..f4936942d02 --- /dev/null +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatOptionsTests.java @@ -0,0 +1,124 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.watsonx; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link WatsonxAiChatOptions}. + * + * @author Alexandros Pappas + */ +class WatsonxAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + WatsonxAiChatOptions options = WatsonxAiChatOptions.builder() + .temperature(0.6) + .topP(0.6) + .topK(50) + .decodingMethod("greedy") + .maxNewTokens(100) + .minNewTokens(10) + .stopSequences(List.of("test")) + .repetitionPenalty(1.2) + .randomSeed(42) + .model("test-model") + .additionalProperties(Map.of("key1", "value1")) + .build(); + + assertThat(options) + .extracting("temperature", "topP", "topK", "decodingMethod", "maxNewTokens", "minNewTokens", + "stopSequences", "repetitionPenalty", "randomSeed", "model", "additionalProperties") + .containsExactly(0.6, 0.6, 50, "greedy", 100, 10, List.of("test"), 1.2, 42, "test-model", + Map.of("key1", "value1")); + } + + @Test + void testCopy() { + WatsonxAiChatOptions original = WatsonxAiChatOptions.builder() + .temperature(0.6) + .topP(0.6) + .topK(50) + .decodingMethod("greedy") + .maxNewTokens(100) + .minNewTokens(10) + .stopSequences(List.of("test")) + .repetitionPenalty(1.2) + .randomSeed(42) + .model("test-model") + .additionalProperties(Map.of("key1", "value1")) + .build(); + + WatsonxAiChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); + assertThat(copied.getAdditionalProperties()).isNotSameAs(original.getAdditionalProperties()); + } + + @Test + void testSetters() { + WatsonxAiChatOptions options = new WatsonxAiChatOptions(); + options.setTemperature(0.6); + options.setTopP(0.6); + options.setTopK(50); + options.setDecodingMethod("greedy"); + options.setMaxNewTokens(100); + options.setMinNewTokens(10); + options.setStopSequences(List.of("test")); + options.setRepetitionPenalty(1.2); + options.setRandomSeed(42); + options.setModel("test-model"); + options.addAdditionalProperty("key1", "value1"); + + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.6); + assertThat(options.getTopK()).isEqualTo(50); + assertThat(options.getDecodingMethod()).isEqualTo("greedy"); + assertThat(options.getMaxNewTokens()).isEqualTo(100); + assertThat(options.getMinNewTokens()).isEqualTo(10); + assertThat(options.getStopSequences()).isEqualTo(List.of("test")); + assertThat(options.getRepetitionPenalty()).isEqualTo(1.2); + assertThat(options.getRandomSeed()).isEqualTo(42); + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getAdditionalProperties()).isEqualTo(Map.of("key1", "value1")); + } + + @Test + void testDefaultValues() { + WatsonxAiChatOptions options = new WatsonxAiChatOptions(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getTopK()).isNull(); + assertThat(options.getDecodingMethod()).isNull(); + assertThat(options.getMaxNewTokens()).isNull(); + assertThat(options.getMinNewTokens()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getRepetitionPenalty()).isNull(); + assertThat(options.getRandomSeed()).isNull(); + assertThat(options.getModel()).isNull(); + assertThat(options.getAdditionalProperties()).isEqualTo(new HashMap<>()); + } + +} diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index f57303b1138..205f706b91c 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -40,6 +41,7 @@ * @author Geng Rong * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -130,7 +132,7 @@ public class ZhiPuAiChatOptions implements FunctionCallingOptions { private Boolean proxyToolCalls; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); // @formatter:on public static Builder builder() { @@ -141,7 +143,7 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { return ZhiPuAiChatOptions.builder() .model(fromOptions.getModel()) .maxTokens(fromOptions.getMaxTokens()) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .tools(fromOptions.getTools()) @@ -149,10 +151,11 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .user(fromOptions.getUser()) .requestId(fromOptions.getRequestId()) .doSample(fromOptions.getDoSample()) - .functionCallbacks(fromOptions.getFunctionCallbacks()) - .functions(fromOptions.getFunctions()) + .functionCallbacks(fromOptions.getFunctionCallbacks() != null + ? new ArrayList<>(fromOptions.getFunctionCallbacks()) : null) + .functions(fromOptions.getFunctions() != null ? new HashSet<>(fromOptions.getFunctions()) : null) .proxyToolCalls(fromOptions.getProxyToolCalls()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .build(); } @@ -309,19 +312,8 @@ public void setToolContext(Map toolContext) { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); - result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); - result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); - result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); - result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); - return result; + return Objects.hash(this.model, this.maxTokens, this.stop, this.temperature, this.topP, this.tools, + this.toolChoice, this.user, this.requestId, this.doSample, this.proxyToolCalls, this.toolContext); } @Override @@ -329,113 +321,22 @@ public boolean equals(Object obj) { if (this == obj) { return true; } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { + if (obj == null || getClass() != obj.getClass()) { return false; } ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj; - if (this.model == null) { - if (other.model != null) { - return false; - } - } - else if (!this.model.equals(other.model)) { - return false; - } - if (this.maxTokens == null) { - if (other.maxTokens != null) { - return false; - } - } - else if (!this.maxTokens.equals(other.maxTokens)) { - return false; - } - if (this.stop == null) { - if (other.stop != null) { - return false; - } - } - else if (!this.stop.equals(other.stop)) { - return false; - } - if (this.temperature == null) { - if (other.temperature != null) { - return false; - } - } - else if (!this.temperature.equals(other.temperature)) { - return false; - } - if (this.topP == null) { - if (other.topP != null) { - return false; - } - } - else if (!this.topP.equals(other.topP)) { - return false; - } - if (this.tools == null) { - if (other.tools != null) { - return false; - } - } - else if (!this.tools.equals(other.tools)) { - return false; - } - if (this.toolChoice == null) { - if (other.toolChoice != null) { - return false; - } - } - else if (!this.toolChoice.equals(other.toolChoice)) { - return false; - } - if (this.user == null) { - if (other.user != null) { - return false; - } - } - else if (!this.user.equals(other.user)) { - return false; - } - if (this.requestId == null) { - if (other.requestId != null) { - return false; - } - } - else if (!this.requestId.equals(other.requestId)) { - return false; - } - if (this.doSample == null) { - if (other.doSample != null) { - return false; - } - } - else if (!this.doSample.equals(other.doSample)) { - return false; - } - if (this.proxyToolCalls == null) { - if (other.proxyToolCalls != null) { - return false; - } - } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { - return false; - } - if (this.toolContext == null) { - if (other.toolContext != null) { - return false; - } - } - else if (!this.toolContext.equals(other.toolContext)) { - return false; - } - return true; + + return Objects.equals(this.model, other.model) && Objects.equals(this.maxTokens, other.maxTokens) + && Objects.equals(this.stop, other.stop) && Objects.equals(this.temperature, other.temperature) + && Objects.equals(this.topP, other.topP) && Objects.equals(this.tools, other.tools) + && Objects.equals(this.toolChoice, other.toolChoice) && Objects.equals(this.user, other.user) + && Objects.equals(this.requestId, other.requestId) && Objects.equals(this.doSample, other.doSample) + && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.toolContext, other.toolContext); } @Override + @SuppressWarnings("unchecked") public ZhiPuAiChatOptions copy() { return fromOptions(this); } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java new file mode 100644 index 00000000000..8bb115a3483 --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java @@ -0,0 +1,131 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.zhipuai; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link ZhiPuAiChatOptions}. + * + * @author Alexandros Pappas + */ +class ZhiPuAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.7) + .toolChoice("auto") + .user("test-user") + .requestId("12345") + .doSample(true) + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + assertThat(options) + .extracting("model", "maxTokens", "stop", "temperature", "topP", "toolChoice", "user", "requestId", + "doSample", "proxyToolCalls", "toolContext") + .containsExactly("test-model", 100, List.of("test"), 0.6, 0.7, "auto", "test-user", "12345", true, true, + Map.of("key1", "value1")); + } + + @Test + void testCopy() { + ZhiPuAiChatOptions original = ZhiPuAiChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.7) + .toolChoice("auto") + .user("test-user") + .requestId("12345") + .doSample(true) + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + ZhiPuAiChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + assertThat(copied.getFunctionCallbacks()).isNotSameAs(original.getFunctionCallbacks()); + assertThat(copied.getFunctions()).isNotSameAs(original.getFunctions()); + assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + } + + @Test + void testSetters() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + options.setModel("test-model"); + options.setMaxTokens(100); + options.setStop(List.of("test")); + options.setTemperature(0.6); + options.setTopP(0.7); + options.setToolChoice("auto"); + options.setUser("test-user"); + options.setRequestId("12345"); + options.setDoSample(true); + options.setProxyToolCalls(true); + options.setToolContext(Map.of("key1", "value1")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getStop()).isEqualTo(List.of("test")); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.7); + assertThat(options.getToolChoice()).isEqualTo("auto"); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getRequestId()).isEqualTo("12345"); + assertThat(options.getDoSample()).isTrue(); + assertThat(options.getProxyToolCalls()).isTrue(); + assertThat(options.getToolContext()).isEqualTo(Map.of("key1", "value1")); + } + + @Test + void testDefaultValues() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getTools()).isNull(); + assertThat(options.getToolChoice()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getRequestId()).isNull(); + assertThat(options.getDoSample()).isNull(); + assertThat(options.getFunctionCallbacks()).isEqualTo(new ArrayList<>()); + assertThat(options.getFunctions()).isEqualTo(new HashSet<>()); + assertThat(options.getProxyToolCalls()).isNull(); + assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java index 1af33bf3467..02c9356754b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,6 +19,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Objects; /** * Default implementation for the {@link ChatOptions}. @@ -128,4 +129,27 @@ public T copy() { return (T) copy; } + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + DefaultChatOptions that = (DefaultChatOptions) o; + + return Objects.equals(this.model, that.model) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.stopSequences, that.stopSequences) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topK, that.topK) + && Objects.equals(this.topP, that.topP); + } + + @Override + public int hashCode() { + return Objects.hash(model, maxTokens, frequencyPenalty, presencePenalty, stopSequences, temperature, topP, + topK); + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java new file mode 100644 index 00000000000..11a096a535c --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java @@ -0,0 +1,106 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.prompt; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link DefaultChatOptions}. + * + * @author Alexandros Pappas + */ +class DefaultChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + ChatOptions options = ChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(100) + .presencePenalty(0.6) + .stopSequences(List.of("stop1", "stop2")) + .temperature(0.7) + .topK(50) + .topP(0.8) + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "stopSequences", "temperature", + "topK", "topP") + .containsExactly("test-model", 0.5, 100, 0.6, List.of("stop1", "stop2"), 0.7, 50, 0.8); + } + + @Test + void testCopy() { + ChatOptions original = ChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(100) + .presencePenalty(0.6) + .stopSequences(List.of("stop1", "stop2")) + .temperature(0.7) + .topK(50) + .topP(0.8) + .build(); + + DefaultChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); + } + + @Test + void testSetters() { + DefaultChatOptions options = new DefaultChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.5); + options.setMaxTokens(100); + options.setPresencePenalty(0.6); + options.setStopSequences(List.of("stop1", "stop2")); + options.setTemperature(0.7); + options.setTopK(50); + options.setTopP(0.8); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getPresencePenalty()).isEqualTo(0.6); + assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopK()).isEqualTo(50); + assertThat(options.getTopP()).isEqualTo(0.8); + } + + @Test + void testDefaultValues() { + DefaultChatOptions options = new DefaultChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopK()).isNull(); + assertThat(options.getTopP()).isNull(); + } + +}