diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index 229b5113126..87bbf6e9ef9 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,9 +17,11 @@ package org.springframework.ai.anthropic; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -29,6 +31,7 @@ import org.springframework.ai.anthropic.api.AnthropicApi; import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest; +import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.util.Assert; @@ -42,16 +45,11 @@ * @since 1.0.0 */ @JsonInclude(Include.NON_NULL) -public class AnthropicChatOptions implements FunctionCallingOptions { +public class AnthropicChatOptions extends AbstractChatOptions implements FunctionCallingOptions { // @formatter:off - private @JsonProperty("model") String model; - private @JsonProperty("max_tokens") Integer maxTokens; + private @JsonProperty("metadata") ChatCompletionRequest.Metadata metadata; - private @JsonProperty("stop_sequences") List stopSequences; - private @JsonProperty("temperature") Double temperature; - private @JsonProperty("top_p") Double topP; - private @JsonProperty("top_k") Integer topK; /** * Tool Function Callbacks to register with the ChatModel. For Prompt @@ -81,7 +79,7 @@ public class AnthropicChatOptions implements FunctionCallingOptions { private Boolean proxyToolCalls; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); // @formatter:on @@ -93,31 +91,22 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) return builder().model(fromOptions.getModel()) .maxTokens(fromOptions.getMaxTokens()) .metadata(fromOptions.getMetadata()) - .stopSequences(fromOptions.getStopSequences()) + .stopSequences( + fromOptions.getStopSequences() != null ? new ArrayList<>(fromOptions.getStopSequences()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .topK(fromOptions.getTopK()) .functionCallbacks(fromOptions.getFunctionCallbacks()) .functions(fromOptions.getFunctions()) .proxyToolCalls(fromOptions.getProxyToolCalls()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? 
new HashMap<>(fromOptions.getToolContext()) : null) .build(); } - @Override - public String getModel() { - return this.model; - } - public void setModel(String model) { this.model = model; } - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } @@ -130,38 +119,18 @@ public void setMetadata(ChatCompletionRequest.Metadata metadata) { this.metadata = metadata; } - @Override - public List getStopSequences() { - return this.stopSequences; - } - public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } - @Override - public Double getTemperature() { - return this.temperature; - } - public void setTemperature(Double temperature) { this.temperature = temperature; } - @Override - public Double getTopP() { - return this.topP; - } - public void setTopP(Double topP) { this.topP = topP; } - @Override - public Integer getTopK() { - return this.topK; - } - public void setTopK(Integer topK) { this.topK = topK; } @@ -224,6 +193,43 @@ public AnthropicChatOptions copy() { return fromOptions(this); } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof AnthropicChatOptions that)) { + return false; + } + return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.metadata, that.metadata) + && Objects.equals(this.stopSequences, that.stopSequences) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP) + && Objects.equals(this.topK, that.topK) + && Objects.equals(this.functionCallbacks, that.functionCallbacks) + && Objects.equals(this.functions, that.functions) + && Objects.equals(this.proxyToolCalls, that.proxyToolCalls) + && Objects.equals(this.toolContext, that.toolContext); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + (this.model != null ? this.model.hashCode() : 0); + result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0); + result = prime * result + (this.metadata != null ? this.metadata.hashCode() : 0); + result = prime * result + (this.stopSequences != null ? this.stopSequences.hashCode() : 0); + result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0); + result = prime * result + (this.topP != null ? this.topP.hashCode() : 0); + result = prime * result + (this.topK != null ? this.topK.hashCode() : 0); + result = prime * result + (this.functionCallbacks != null ? this.functionCallbacks.hashCode() : 0); + result = prime * result + (this.functions != null ? this.functions.hashCode() : 0); + result = prime * result + (this.proxyToolCalls != null ? this.proxyToolCalls.hashCode() : 0); + result = prime * result + (this.toolContext != null ? this.toolContext.hashCode() : 0); + return result; + } + public static class Builder { private final AnthropicChatOptions options = new AnthropicChatOptions(); diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java new file mode 100644 index 00000000000..fbb8409dffc --- /dev/null +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatOptionsTests.java @@ -0,0 +1,104 @@ +/* + * Copyright 2025-2025 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.anthropic; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata; + +/** + * Tests for {@link AnthropicChatOptions}. + * + * @author Alexandros Pappas + */ +class AnthropicChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + AnthropicChatOptions options = AnthropicChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stopSequences(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.8) + .topK(50) + .metadata(new Metadata("userId_123")) + .build(); + + assertThat(options).extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata") + .containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123")); + } + + @Test + void testCopy() { + AnthropicChatOptions original = AnthropicChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stopSequences(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.8) + .topK(50) + .metadata(new Metadata("userId_123")) + .toolContext(Map.of("key1", "value1")) + .build(); + + AnthropicChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); + assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + } + + @Test + void testSetters() { + AnthropicChatOptions options = new AnthropicChatOptions(); + options.setModel("test-model"); + options.setMaxTokens(100); + options.setTemperature(0.7); + options.setTopK(50); + options.setTopP(0.8); + options.setStopSequences(List.of("stop1", "stop2")); + options.setMetadata(new Metadata("userId_123")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopK()).isEqualTo(50); + assertThat(options.getTopP()).isEqualTo(0.8); + assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); + assertThat(options.getMetadata()).isEqualTo(new Metadata("userId_123")); + } + + @Test + void testDefaultValues() { + AnthropicChatOptions options = new AnthropicChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopK()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getMetadata()).isNull(); + } + +} diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java index 
d34ca5dfbd1..d3c8be71353 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java @@ -17,9 +17,11 @@ package org.springframework.ai.azure.openai; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; @@ -29,6 +31,7 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.util.Assert; @@ -42,36 +45,10 @@ * @author Thomas Vitale * @author Soby Chacko * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas */ @JsonInclude(Include.NON_NULL) -public class AzureOpenAiChatOptions implements FunctionCallingOptions { - - /** - * The maximum number of tokens to generate. - */ - @JsonProperty("max_tokens") - private Integer maxTokens; - - /** - * The sampling temperature to use that controls the apparent creativity of generated - * completions. Higher values will make output more random while lower values will - * make results more focused and deterministic. It is not recommended to modify - * temperature and top_p for the same completions request as the interaction of these - * two settings is difficult to predict. - */ - @JsonProperty("temperature") - private Double temperature; - - /** - * An alternative to sampling with temperature called nucleus sampling. This value - * causes the model to consider the results of tokens with the provided probability - * mass. As an example, a value of 0.15 will cause only the tokens comprising the top - * 15% of probability mass to be considered. It is not recommended to modify - * temperature and top_p for the same completions request as the interaction of these - * two settings is difficult to predict. - */ - @JsonProperty("top_p") - private Double topP; +public class AzureOpenAiChatOptions extends AbstractChatOptions implements FunctionCallingOptions { /** * A map between GPT token IDs and bias scores that influences the probability of @@ -105,24 +82,6 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions { @JsonProperty("stop") private List stop; - /** - * A value that influences the probability of generated tokens appearing based on - * their existing presence in generated text. Positive values will make tokens less - * likely to appear when they already exist and increase the model's likelihood to - * output new topics. - */ - @JsonProperty("presence_penalty") - private Double presencePenalty; - - /** - * A value that influences the probability of generated tokens appearing based on - * their cumulative frequency in generated text. Positive values will make tokens less - * likely to appear as their frequency increases and decrease the likelihood of the - * model repeating the same statements verbatim. - */ - @JsonProperty("frequency_penalty") - private Double frequencyPenalty; - /** * The deployment name as defined in Azure Open AI Studio when creating a deployment * backed by an Azure OpenAI base model. 
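Each option class in this patch now extends org.springframework.ai.chat.prompt.AbstractChatOptions, which is not itself included in the diff. A minimal sketch of what that base class presumably centralises, inferred only from the fields and getters removed here (field names, visibility and the absence of Jackson annotations are assumptions; the real class may differ):

package org.springframework.ai.chat.prompt;

import java.util.List;

// Hypothetical sketch, not the actual spring-ai-core source.
public abstract class AbstractChatOptions implements ChatOptions {

	protected String model;
	protected Integer maxTokens;
	protected Double temperature;
	protected Double topP;
	protected Integer topK;
	protected Double frequencyPenalty;
	protected Double presencePenalty;
	protected List<String> stopSequences;

	@Override
	public String getModel() {
		return this.model;
	}

	@Override
	public Integer getMaxTokens() {
		return this.maxTokens;
	}

	@Override
	public Double getTemperature() {
		return this.temperature;
	}

	// ... the remaining ChatOptions getters (topP, topK, frequencyPenalty,
	// presencePenalty, stopSequences) would follow the same pattern.
}

This explains why the per-provider classes below can drop their duplicated getters while equals()/hashCode() still reference fields such as temperature and topP.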
@@ -199,7 +158,7 @@ public class AzureOpenAiChatOptions implements FunctionCallingOptions { private ChatCompletionStreamOptions streamOptions; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); public static Builder builder() { return new Builder(); @@ -212,18 +171,18 @@ public static AzureOpenAiChatOptions fromOptions(AzureOpenAiChatOptions fromOpti .maxTokens(fromOptions.getMaxTokens()) .N(fromOptions.getN()) .presencePenalty(fromOptions.getPresencePenalty() != null ? fromOptions.getPresencePenalty() : null) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .user(fromOptions.getUser()) .functionCallbacks(fromOptions.getFunctionCallbacks()) - .functions(fromOptions.getFunctions()) + .functions(fromOptions.getFunctions() != null ? new HashSet<>(fromOptions.getFunctions()) : null) .responseFormat(fromOptions.getResponseFormat()) .seed(fromOptions.getSeed()) .logprobs(fromOptions.isLogprobs()) .topLogprobs(fromOptions.getTopLogProbs()) .enhancements(fromOptions.getEnhancements()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .streamOptions(fromOptions.getStreamOptions()) .build(); } @@ -280,20 +239,10 @@ public void setStop(List stop) { this.stop = stop; } - @Override - public Double getPresencePenalty() { - return this.presencePenalty; - } - public void setPresencePenalty(Double presencePenalty) { this.presencePenalty = presencePenalty; } - @Override - public Double getFrequencyPenalty() { - return this.frequencyPenalty; - } - public void setFrequencyPenalty(Double frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } @@ -317,11 +266,6 @@ public void setDeploymentName(String deploymentName) { this.deploymentName = deploymentName; } - @Override - public Double getTemperature() { - return this.temperature; - } - public void setTemperature(Double temperature) { this.temperature = temperature; } @@ -427,10 +371,62 @@ public void setStreamOptions(ChatCompletionStreamOptions streamOptions) { } @Override + @SuppressWarnings("") public AzureOpenAiChatOptions copy() { return fromOptions(this); } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof AzureOpenAiChatOptions that)) { + return false; + } + return Objects.equals(this.logitBias, that.logitBias) && Objects.equals(this.user, that.user) + && Objects.equals(this.n, that.n) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.deploymentName, that.deploymentName) + && Objects.equals(this.responseFormat, that.responseFormat) + && Objects.equals(this.functionCallbacks, that.functionCallbacks) + && Objects.equals(this.functions, that.functions) + && Objects.equals(this.proxyToolCalls, that.proxyToolCalls) && Objects.equals(this.seed, that.seed) + && Objects.equals(this.logprobs, that.logprobs) && Objects.equals(this.topLogProbs, that.topLogProbs) + && Objects.equals(this.enhancements, that.enhancements) + && Objects.equals(this.streamOptions, that.streamOptions) + && Objects.equals(this.toolContext, that.toolContext) && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP); + } + + 
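The hashCode() that follows accumulates field hashes manually with the 31-multiplier idiom. MistralAiChatOptions further down in this patch had been using Objects.hash for the same job; for comparison only (not part of the patch), a compact equivalent over the same field order would be:

@Override
public int hashCode() {
	// Objects.hash delegates to Arrays.hashCode, which applies the same 31-based
	// accumulation with null mapped to 0, so this yields identical values for the
	// same field order, at the cost of a varargs array allocation and boxing.
	return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName,
			this.responseFormat, this.functionCallbacks, this.functions, this.proxyToolCalls, this.seed,
			this.logprobs, this.topLogProbs, this.enhancements, this.streamOptions, this.toolContext,
			this.maxTokens, this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP);
}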
@Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + (this.logitBias != null ? this.logitBias.hashCode() : 0); + result = prime * result + (this.user != null ? this.user.hashCode() : 0); + result = prime * result + (this.n != null ? this.n.hashCode() : 0); + result = prime * result + (this.stop != null ? this.stop.hashCode() : 0); + result = prime * result + (this.deploymentName != null ? this.deploymentName.hashCode() : 0); + result = prime * result + (this.responseFormat != null ? this.responseFormat.hashCode() : 0); + result = prime * result + (this.functionCallbacks != null ? this.functionCallbacks.hashCode() : 0); + result = prime * result + (this.functions != null ? this.functions.hashCode() : 0); + result = prime * result + (this.proxyToolCalls != null ? this.proxyToolCalls.hashCode() : 0); + result = prime * result + (this.seed != null ? this.seed.hashCode() : 0); + result = prime * result + (this.logprobs != null ? this.logprobs.hashCode() : 0); + result = prime * result + (this.topLogProbs != null ? this.topLogProbs.hashCode() : 0); + result = prime * result + (this.enhancements != null ? this.enhancements.hashCode() : 0); + result = prime * result + (this.streamOptions != null ? this.streamOptions.hashCode() : 0); + result = prime * result + (this.toolContext != null ? this.toolContext.hashCode() : 0); + result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0); + result = prime * result + (this.frequencyPenalty != null ? this.frequencyPenalty.hashCode() : 0); + result = prime * result + (this.presencePenalty != null ? this.presencePenalty.hashCode() : 0); + result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0); + result = prime * result + (this.topP != null ? this.topP.hashCode() : 0); + return result; + } + public static class Builder { protected AzureOpenAiChatOptions options; diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java new file mode 100644 index 00000000000..b3a8bfd6d74 --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptionsTests.java @@ -0,0 +1,182 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.azure.openai; + +import java.util.List; +import java.util.Map; + +import com.azure.ai.openai.models.AzureChatEnhancementConfiguration; +import com.azure.ai.openai.models.AzureChatGroundingEnhancementConfiguration; +import com.azure.ai.openai.models.AzureChatOCREnhancementConfiguration; +import com.azure.ai.openai.models.ChatCompletionStreamOptions; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link AzureOpenAiChatOptions}. 
+ * + * @author Alexandros Pappas + */ +class AzureOpenAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + + AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration(); + enhancements.setOcr(new AzureChatOCREnhancementConfiguration(true)); + enhancements.setGrounding(new AzureChatGroundingEnhancementConfiguration(true)); + + AzureOpenAiChatOptions options = AzureOpenAiChatOptions.builder() + .deploymentName("test-deployment") + .frequencyPenalty(0.5) + .logitBias(Map.of("token1", 1, "token2", -1)) + .maxTokens(200) + .N(2) + .presencePenalty(0.8) + .stop(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.9) + .user("test-user") + .responseFormat(responseFormat) + .seed(12345L) + .logprobs(true) + .topLogprobs(5) + .enhancements(enhancements) + .streamOptions(streamOptions) + .build(); + + assertThat(options) + .extracting("deploymentName", "frequencyPenalty", "logitBias", "maxTokens", "n", "presencePenalty", "stop", + "temperature", "topP", "user", "responseFormat", "seed", "logprobs", "topLogProbs", "enhancements", + "streamOptions") + .containsExactly("test-deployment", 0.5, Map.of("token1", 1, "token2", -1), 200, 2, 0.8, + List.of("stop1", "stop2"), 0.7, 0.9, "test-user", responseFormat, 12345L, true, 5, enhancements, + streamOptions); + } + + @Test + void testCopy() { + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + + AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration(); + enhancements.setOcr(new AzureChatOCREnhancementConfiguration(true)); + enhancements.setGrounding(new AzureChatGroundingEnhancementConfiguration(true)); + + AzureOpenAiChatOptions originalOptions = AzureOpenAiChatOptions.builder() + .deploymentName("test-deployment") + .frequencyPenalty(0.5) + .logitBias(Map.of("token1", 1, "token2", -1)) + .maxTokens(200) + .N(2) + .presencePenalty(0.8) + .stop(List.of("stop1", "stop2")) + .temperature(0.7) + .topP(0.9) + .user("test-user") + .responseFormat(responseFormat) + .seed(12345L) + .logprobs(true) + .topLogprobs(5) + .enhancements(enhancements) + .streamOptions(streamOptions) + .build(); + + AzureOpenAiChatOptions copiedOptions = originalOptions.copy(); + + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + // Ensure deep copy + assertThat(copiedOptions.getStop()).isNotSameAs(originalOptions.getStop()); + assertThat(copiedOptions.getToolContext()).isNotSameAs(originalOptions.getToolContext()); + } + + @Test + void testSetters() { + AzureOpenAiResponseFormat responseFormat = AzureOpenAiResponseFormat.TEXT; + ChatCompletionStreamOptions streamOptions = new ChatCompletionStreamOptions(); + streamOptions.setIncludeUsage(true); + AzureChatEnhancementConfiguration enhancements = new AzureChatEnhancementConfiguration(); + + AzureOpenAiChatOptions options = new AzureOpenAiChatOptions(); + options.setDeploymentName("test-deployment"); + options.setFrequencyPenalty(0.5); + options.setLogitBias(Map.of("token1", 1, "token2", -1)); + options.setMaxTokens(200); + options.setN(2); + options.setPresencePenalty(0.8); + options.setStop(List.of("stop1", "stop2")); + options.setTemperature(0.7); + options.setTopP(0.9); + 
options.setUser("test-user"); + options.setResponseFormat(responseFormat); + options.setSeed(12345L); + options.setLogprobs(true); + options.setTopLogProbs(5); + options.setEnhancements(enhancements); + options.setStreamOptions(streamOptions); + + assertThat(options.getDeploymentName()).isEqualTo("test-deployment"); + options.setModel("test-model"); + assertThat(options.getDeploymentName()).isEqualTo("test-model"); + + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getLogitBias()).isEqualTo(Map.of("token1", 1, "token2", -1)); + assertThat(options.getMaxTokens()).isEqualTo(200); + assertThat(options.getN()).isEqualTo(2); + assertThat(options.getPresencePenalty()).isEqualTo(0.8); + assertThat(options.getStop()).isEqualTo(List.of("stop1", "stop2")); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + assertThat(options.getSeed()).isEqualTo(12345L); + assertThat(options.isLogprobs()).isTrue(); + assertThat(options.getTopLogProbs()).isEqualTo(5); + assertThat(options.getEnhancements()).isEqualTo(enhancements); + assertThat(options.getStreamOptions()).isEqualTo(streamOptions); + assertThat(options.getModel()).isEqualTo("test-model"); + } + + @Test + void testDefaultValues() { + AzureOpenAiChatOptions options = new AzureOpenAiChatOptions(); + + assertThat(options.getDeploymentName()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getLogitBias()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getN()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getSeed()).isNull(); + assertThat(options.isLogprobs()).isNull(); + assertThat(options.getTopLogProbs()).isNull(); + assertThat(options.getEnhancements()).isNull(); + assertThat(options.getStreamOptions()).isNull(); + assertThat(options.getModel()).isNull(); + } + +} diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 244c1fce699..8e16213fdd2 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
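MiniMaxChatOptions receives the same treatment in the hunks that follow: fromOptions() now wraps the stop list and tool context in fresh collections. A short sketch of the behaviour this buys, using only methods visible in this diff and assuming the builder stores the caller's list by reference (only fromOptions() does the copying); the test scaffolding itself is illustrative:

@Test
void copyDoesNotShareMutableState() {
	List<String> stops = new ArrayList<>(List.of("stop1"));
	MiniMaxChatOptions original = MiniMaxChatOptions.builder().stop(stops).build();

	MiniMaxChatOptions copy = original.copy();
	stops.add("stop2");

	// copy() routes through fromOptions(), which now wraps the list in
	// new ArrayList<>(...), so the copy is isolated from later mutation.
	assertThat(copy.getStop()).containsExactly("stop1");
	assertThat(original.getStop()).containsExactly("stop1", "stop2");
}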
@@ -17,6 +17,7 @@ package org.springframework.ai.minimax; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -27,6 +28,7 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.minimax.api.MiniMaxApi; import org.springframework.ai.model.function.FunctionCallback; @@ -43,36 +45,18 @@ * @author Geng Rong * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) -public class MiniMaxChatOptions implements FunctionCallingOptions { +public class MiniMaxChatOptions extends AbstractChatOptions implements FunctionCallingOptions { // @formatter:off - /** - * ID of the model to use. - */ - private @JsonProperty("model") String model; - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing - * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - */ - private @JsonProperty("frequency_penalty") Double frequencyPenalty; - /** - * The maximum number of tokens to generate in the chat completion. The total length of input - * tokens and generated tokens is limited by the model's context length. - */ - private @JsonProperty("max_tokens") Integer maxTokens; /** * How many chat completion choices to generate for each input message. Note that you will be charged based * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. */ private @JsonProperty("n") Integer n; - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they - * appear in the text so far, increasing the model's likelihood to talk about new topics. - */ - private @JsonProperty("presence_penalty") Double presencePenalty; /** * An object specifying the format that the model must output. Setting to { "type": * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. @@ -89,18 +73,7 @@ public class MiniMaxChatOptions implements FunctionCallingOptions { * Up to 4 sequences where the API will stop generating further tokens. */ private @JsonProperty("stop") List stop; - /** - * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output - * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend - * altering this or top_p but not both. - */ - private @JsonProperty("temperature") Double temperature; - /** - * An alternative to sampling with temperature, called nucleus sampling, where the model considers the - * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - * probability mass are considered. We generally recommend altering this or temperature but not both. - */ - private @JsonProperty("top_p") Double topP; + /** * Mask the text information in the output that is easy to involve privacy issues, * including but not limited to email, domain name, link, ID number, home address, etc. 
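The next hunk changes the toolContext default from null to an empty HashMap, which is why testDefaultValues() later in this patch asserts an empty map rather than null. Practically, callers can populate the context without a null guard; a minimal illustration, assuming getToolContext() hands back the backing map as the field initialiser suggests (the key and value are made up):

MiniMaxChatOptions options = new MiniMaxChatOptions();

// With the old null default this line would have thrown a NullPointerException;
// with toolContext = new HashMap<>() it is safe.
options.getToolContext().put("tenantId", "acme"); // illustrative entry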
@@ -146,7 +119,7 @@ public class MiniMaxChatOptions implements FunctionCallingOptions { private Boolean proxyToolCalls; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); // @formatter:on @@ -162,7 +135,7 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .presencePenalty(fromOptions.getPresencePenalty()) .responseFormat(fromOptions.getResponseFormat()) .seed(fromOptions.getSeed()) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .maskSensitiveInfo(fromOptions.getMaskSensitiveInfo()) @@ -171,33 +144,18 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .functionCallbacks(fromOptions.getFunctionCallbacks()) .functions(fromOptions.getFunctions()) .proxyToolCalls(fromOptions.getProxyToolCalls()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .build(); } - @Override - public String getModel() { - return this.model; - } - public void setModel(String model) { this.model = model; } - @Override - public Double getFrequencyPenalty() { - return this.frequencyPenalty; - } - public void setFrequencyPenalty(Double frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } @@ -210,11 +168,6 @@ public void setN(Integer n) { this.n = n; } - @Override - public Double getPresencePenalty() { - return this.presencePenalty; - } - public void setPresencePenalty(Double presencePenalty) { this.presencePenalty = presencePenalty; } @@ -254,20 +207,10 @@ public void setStop(List stop) { this.stop = stop; } - @Override - public Double getTemperature() { - return this.temperature; - } - public void setTemperature(Double temperature) { this.temperature = temperature; } - @Override - public Double getTopP() { - return this.topP; - } - public void setTopP(Double topP) { this.topP = topP; } diff --git a/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxChatOptionsTests.java b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxChatOptionsTests.java new file mode 100644 index 00000000000..6fbd605fe48 --- /dev/null +++ b/models/spring-ai-minimax/src/test/java/org/springframework/ai/minimax/MiniMaxChatOptionsTests.java @@ -0,0 +1,141 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.minimax; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.minimax.api.MiniMaxApi; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link MiniMaxChatOptions}. 
+ * + * @author Alexandros Pappas + */ +class MiniMaxChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + MiniMaxChatOptions options = MiniMaxChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .N(1) + .presencePenalty(0.5) + .responseFormat(new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text")) + .seed(1) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .maskSensitiveInfo(false) + .toolChoice("test") + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "maxTokens", "N", "presencePenalty", "responseFormat", "seed", + "stopSequences", "temperature", "topP", "maskSensitiveInfo", "toolChoice", "proxyToolCalls", + "toolContext") + .containsExactly("test-model", 0.5, 10, 1, 0.5, new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text"), + 1, List.of("test"), 0.6, 0.6, false, "test", true, Map.of("key1", "value1")); + } + + @Test + void testCopy() { + MiniMaxChatOptions original = MiniMaxChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .N(1) + .presencePenalty(0.5) + .responseFormat(new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text")) + .seed(1) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .maskSensitiveInfo(false) + .toolChoice("test") + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + MiniMaxChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + } + + @Test + void testSetters() { + MiniMaxChatOptions options = new MiniMaxChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.5); + options.setMaxTokens(10); + options.setN(1); + options.setPresencePenalty(0.5); + options.setResponseFormat(new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text")); + options.setSeed(1); + options.setStopSequences(List.of("test")); + options.setTemperature(0.6); + options.setTopP(0.6); + options.setMaskSensitiveInfo(false); + options.setToolChoice("test"); + options.setProxyToolCalls(true); + options.setToolContext(Map.of("key1", "value1")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getMaxTokens()).isEqualTo(10); + assertThat(options.getN()).isEqualTo(1); + assertThat(options.getPresencePenalty()).isEqualTo(0.5); + assertThat(options.getResponseFormat()).isEqualTo(new MiniMaxApi.ChatCompletionRequest.ResponseFormat("text")); + assertThat(options.getSeed()).isEqualTo(1); + assertThat(options.getStopSequences()).isEqualTo(List.of("test")); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.6); + assertThat(options.getMaskSensitiveInfo()).isEqualTo(false); + assertThat(options.getToolChoice()).isEqualTo("test"); + assertThat(options.getProxyToolCalls()).isEqualTo(true); + assertThat(options.getToolContext()).isEqualTo(Map.of("key1", "value1")); + } + + @Test + void testDefaultValues() { + MiniMaxChatOptions options = new MiniMaxChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getN()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + 
assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getSeed()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getMaskSensitiveInfo()).isNull(); + assertThat(options.getToolChoice()).isNull(); + assertThat(options.getProxyToolCalls()).isNull(); + assertThat(options.getToolContext()).isEqualTo(new java.util.HashMap<>()); + } + +} diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 2524836e001..5e4583c01d3 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -17,6 +17,7 @@ package org.springframework.ai.mistralai; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -27,6 +28,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; @@ -45,33 +47,7 @@ * @since 0.8.1 */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class MistralAiChatOptions implements FunctionCallingOptions { - - /** - * ID of the model to use - */ - private @JsonProperty("model") String model; - - /** - * What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will - * make the output more random, while lower values like 0.2 will make it more focused - * and deterministic. We generally recommend altering this or top_p but not both. - */ - private @JsonProperty("temperature") Double temperature; - - /** - * Nucleus sampling, where the model considers the results of the tokens with top_p - * probability mass. So 0.1 means only the tokens comprising the top 10% probability - * mass are considered. We generally recommend altering this or temperature but not - * both. - */ - private @JsonProperty("top_p") Double topP; - - /** - * The maximum number of tokens to generate in the completion. The token count of your - * prompt plus max_tokens cannot exceed the model's context length. - */ - private @JsonProperty("max_tokens") Integer maxTokens; +public class MistralAiChatOptions extends AbstractChatOptions implements FunctionCallingOptions { /** * Whether to inject a safety prompt before all conversations. @@ -139,7 +115,7 @@ public class MistralAiChatOptions implements FunctionCallingOptions { private Boolean proxyToolCalls; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); public static Builder builder() { return new Builder(); @@ -153,30 +129,20 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .responseFormat(fromOptions.getResponseFormat()) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? 
new ArrayList<>(fromOptions.getStop()) : null) .tools(fromOptions.getTools()) .toolChoice(fromOptions.getToolChoice()) .functionCallbacks(fromOptions.getFunctionCallbacks()) .functions(fromOptions.getFunctions()) .proxyToolCalls(fromOptions.getProxyToolCalls()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null) .build(); } - @Override - public String getModel() { - return this.model; - } - public void setModel(String model) { this.model = model; } - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } @@ -240,20 +206,10 @@ public void setToolChoice(ToolChoice toolChoice) { this.toolChoice = toolChoice; } - @Override - public Double getTemperature() { - return this.temperature; - } - public void setTemperature(Double temperature) { this.temperature = temperature; } - @Override - public Double getTopP() { - return this.topP; - } - public void setTopP(Double topP) { this.topP = topP; } @@ -318,13 +274,13 @@ public void setToolContext(Map toolContext) { } @Override + @SuppressWarnings("unchecked") public MistralAiChatOptions copy() { return fromOptions(this); } @Override public int hashCode() { - return Objects.hash(this.model, this.temperature, this.topP, this.maxTokens, this.safePrompt, this.randomSeed, this.responseFormat, this.stop, this.tools, this.toolChoice, this.functionCallbacks, this.functions, this.proxyToolCalls, this.toolContext); diff --git a/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java new file mode 100644 index 00000000000..2871eb0f576 --- /dev/null +++ b/models/spring-ai-mistral-ai/src/test/java/org/springframework/ai/mistralai/MistralAiChatOptionsTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mistralai; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; + +import org.springframework.ai.mistralai.api.MistralAiApi; + +/** + * Tests for {@link MistralAiChatOptions}. 
+ * + * @author Alexandros Pappas + */ +class MistralAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + MistralAiChatOptions options = MistralAiChatOptions.builder() + .model("test-model") + .temperature(0.7) + .topP(0.9) + .maxTokens(100) + .safePrompt(true) + .randomSeed(123) + .stop(List.of("stop1", "stop2")) + .responseFormat(new ResponseFormat("json_object")) + .toolChoice(MistralAiApi.ChatCompletionRequest.ToolChoice.AUTO) + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + assertThat(options) + .extracting("model", "temperature", "topP", "maxTokens", "safePrompt", "randomSeed", "stop", + "responseFormat", "toolChoice", "proxyToolCalls", "toolContext") + .containsExactly("test-model", 0.7, 0.9, 100, true, 123, List.of("stop1", "stop2"), + new ResponseFormat("json_object"), MistralAiApi.ChatCompletionRequest.ToolChoice.AUTO, true, + Map.of("key1", "value1")); + } + + @Test + void testBuilderWithEnum() { + MistralAiChatOptions optionsWithEnum = MistralAiChatOptions.builder() + .model(MistralAiApi.ChatModel.OPEN_MISTRAL_7B) + .build(); + assertThat(optionsWithEnum.getModel()).isEqualTo(MistralAiApi.ChatModel.OPEN_MISTRAL_7B.getValue()); + } + + @Test + void testCopy() { + MistralAiChatOptions options = MistralAiChatOptions.builder() + .model("test-model") + .temperature(0.7) + .topP(0.9) + .maxTokens(100) + .safePrompt(true) + .randomSeed(123) + .stop(List.of("stop1", "stop2")) + .responseFormat(new ResponseFormat("json_object")) + .toolChoice(MistralAiApi.ChatCompletionRequest.ToolChoice.AUTO) + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + MistralAiChatOptions copiedOptions = options.copy(); + assertThat(copiedOptions).isNotSameAs(options).isEqualTo(options); + // Ensure deep copy + assertThat(copiedOptions.getStop()).isNotSameAs(options.getStop()); + assertThat(copiedOptions.getToolContext()).isNotSameAs(options.getToolContext()); + } + + @Test + void testSetters() { + ResponseFormat responseFormat = new ResponseFormat("json_object"); + MistralAiChatOptions options = new MistralAiChatOptions(); + options.setModel("test-model"); + options.setTemperature(0.7); + options.setTopP(0.9); + options.setMaxTokens(100); + options.setSafePrompt(true); + options.setRandomSeed(123); + options.setResponseFormat(responseFormat); + options.setStopSequences(List.of("stop1", "stop2")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getSafePrompt()).isEqualTo(true); + assertThat(options.getRandomSeed()).isEqualTo(123); + assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2")); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + } + + @Test + void testDefaultValues() { + MistralAiChatOptions options = new MistralAiChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getSafePrompt()).isNull(); + assertThat(options.getRandomSeed()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + } + +} diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java 
b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java index 3efddaf89c3..8806e03e9e6 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.ai.moonshot; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -26,6 +27,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.moonshot.api.MoonshotApi; @@ -39,33 +41,7 @@ * @author Alexandros Pappas */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class MoonshotChatOptions implements FunctionCallingOptions { - - /** - * ID of the model to use - */ - private @JsonProperty("model") String model; - - /** - * The maximum number of tokens to generate in the chat completion. The total length - * of input tokens and generated tokens is limited by the model's context length. - */ - private @JsonProperty("max_tokens") Integer maxTokens; - - /** - * What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will - * make the output more random, while lower values like 0.2 will make it more focused - * and deterministic. We generally recommend altering this or top_p but not both. - */ - private @JsonProperty("temperature") Double temperature; - - /** - * An alternative to sampling with temperature, called nucleus sampling, where the - * model considers the results of the tokens with top_p probability mass. So 0.1 means - * only the tokens comprising the top 10% probability mass are considered. We - * generally recommend altering this or temperature but not both. - */ - private @JsonProperty("top_p") Double topP; +public class MoonshotChatOptions extends AbstractChatOptions implements FunctionCallingOptions { /** * How many chat completion choices to generate for each input message. Note that you @@ -74,20 +50,6 @@ public class MoonshotChatOptions implements FunctionCallingOptions { */ private @JsonProperty("n") Integer n; - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether - * they appear in the text so far, increasing the model's likelihood to talk about new - * topics. - */ - private @JsonProperty("presence_penalty") Double presencePenalty; - - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on their - * existing frequency in the text so far, decreasing the model's likelihood to repeat - * the same line verbatim. - */ - private @JsonProperty("frequency_penalty") Double frequencyPenalty; - /** * Up to 5 sequences where the API will stop generating further tokens. 
*/ @@ -141,7 +103,7 @@ public class MoonshotChatOptions implements FunctionCallingOptions { private Boolean proxyToolCalls; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); public static Builder builder() { return new Builder(); @@ -166,29 +128,14 @@ public void setFunctions(Set functionNames) { this.functions = functionNames; } - @Override - public String getModel() { - return this.model; - } - public void setModel(String model) { this.model = model; } - @Override - public Double getFrequencyPenalty() { - return this.frequencyPenalty; - } - public void setFrequencyPenalty(Double frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } @@ -201,11 +148,6 @@ public void setN(Integer n) { this.n = n; } - @Override - public Double getPresencePenalty() { - return this.presencePenalty; - } - public void setPresencePenalty(Double presencePenalty) { this.presencePenalty = presencePenalty; } @@ -229,20 +171,10 @@ public void setStop(List stop) { this.stop = stop; } - @Override - public Double getTemperature() { - return this.temperature; - } - public void setTemperature(Double temperature) { this.temperature = temperature; } - @Override - public Double getTopP() { - return this.topP; - } - public void setTopP(Double topP) { this.topP = topP; } @@ -289,14 +221,14 @@ public MoonshotChatOptions copy() { .N(this.n) .presencePenalty(this.presencePenalty) .frequencyPenalty(this.frequencyPenalty) - .stop(this.stop) + .stop(this.stop != null ? new ArrayList<>(this.stop) : null) .user(this.user) - .tools(this.tools) + .tools(this.tools != null ? new ArrayList<>(this.tools) : null) .toolChoice(this.toolChoice) .functionCallbacks(this.functionCallbacks) .functions(this.functions) .proxyToolCalls(this.proxyToolCalls) - .toolContext(this.toolContext) + .toolContext(this.toolContext != null ? new HashMap<>(this.toolContext) : null) .build(); } diff --git a/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatOptionsTests.java b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatOptionsTests.java new file mode 100644 index 00000000000..9486b65b3ba --- /dev/null +++ b/models/spring-ai-moonshot/src/test/java/org/springframework/ai/moonshot/MoonshotChatOptionsTests.java @@ -0,0 +1,126 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.moonshot; + +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link MoonshotChatOptions}. 
+ * + * @author Alexandros Pappas + */ +class MoonshotChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + MoonshotChatOptions options = MoonshotChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .N(1) + .presencePenalty(0.5) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .toolChoice("test") + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .user("test-user") + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "maxTokens", "N", "presencePenalty", "stop", "temperature", "topP", + "toolChoice", "proxyToolCalls", "toolContext", "user") + .containsExactly("test-model", 0.5, 10, 1, 0.5, List.of("test"), 0.6, 0.6, "test", true, + Map.of("key1", "value1"), "test-user"); + } + + @Test + void testCopy() { + MoonshotChatOptions original = MoonshotChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .N(1) + .presencePenalty(0.5) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .toolChoice("test") + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .user("test-user") + .build(); + + MoonshotChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + } + + @Test + void testSetters() { + MoonshotChatOptions options = new MoonshotChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.5); + options.setMaxTokens(10); + options.setN(1); + options.setPresencePenalty(0.5); + options.setUser("test-user"); + options.setStop(List.of("test")); + options.setTemperature(0.6); + options.setTopP(0.6); + options.setProxyToolCalls(true); + options.setToolContext(Map.of("key1", "value1")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getMaxTokens()).isEqualTo(10); + assertThat(options.getN()).isEqualTo(1); + assertThat(options.getPresencePenalty()).isEqualTo(0.5); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getStopSequences()).isEqualTo(List.of("test")); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.6); + assertThat(options.getProxyToolCalls()).isTrue(); + assertThat(options.getToolContext()).isEqualTo(Map.of("key1", "value1")); + } + + @Test + void testDefaultValues() { + MoonshotChatOptions options = new MoonshotChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getN()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getProxyToolCalls()).isNull(); + assertThat(options.getToolContext()).isEqualTo(new java.util.HashMap<>()); + } + +} diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java index afd17ab456f..1ae88b3eac8 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java +++ 
b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,12 +16,14 @@ package org.springframework.ai.oci.cohere; +import java.util.ArrayList; import java.util.List; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import com.oracle.bmc.generativeaiinference.model.CohereTool; +import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.chat.prompt.ChatOptions; /** @@ -29,18 +31,10 @@ * * @author Anders Swanson * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class OCICohereChatOptions implements ChatOptions { - - @JsonProperty("model") - private String model; - - /** - * The maximum number of tokens to generate per request. - */ - @JsonProperty("maxTokens") - private Integer maxTokens; +public class OCICohereChatOptions extends AbstractChatOptions implements ChatOptions { /** * The OCI Compartment to run chat requests in. @@ -60,43 +54,6 @@ public class OCICohereChatOptions implements ChatOptions { @JsonProperty("preambleOverride") private String preambleOverride; - /** - * The sample temperature, where higher values are more random, and lower values are - * more deterministic. - */ - @JsonProperty("temperature") - private Double temperature; - - /** - * The Top P parameter modifies the probability of tokens sampled. E.g., a value of - * 0.25 means only tokens from the top 25% probability mass will be considered. - */ - @JsonProperty("topP") - private Double topP; - - /** - * The Top K parameter limits the number of potential tokens considered at each step - * of text generation. E.g., a value of 5 means only the top 5 most probable tokens - * will be considered during each step of text generation. - */ - @JsonProperty("topK") - private Integer topK; - - /** - * The frequency penalty assigns a penalty to repeated tokens depending on how many - * times it has already appeared in the prompt or output. Higher values will reduce - * repeated tokens and outputs will be more random. - */ - @JsonProperty("frequencyPenalty") - private Double frequencyPenalty; - - /** - * The presence penalty assigns a penalty to each token when it appears in the output - * to encourage generating outputs with tokens that haven't been used. - */ - @JsonProperty("presencePenalty") - private Double presencePenalty; - /** * A collection of textual sequences that will end completions generation. */ @@ -124,11 +81,11 @@ public static OCICohereChatOptions fromOptions(OCICohereChatOptions fromOptions) .temperature(fromOptions.temperature) .topP(fromOptions.topP) .topK(fromOptions.topK) - .stop(fromOptions.stop) + .stop(fromOptions.stop != null ? new ArrayList<>(fromOptions.stop) : null) .frequencyPenalty(fromOptions.frequencyPenalty) .presencePenalty(fromOptions.presencePenalty) - .documents(fromOptions.documents) - .tools(fromOptions.tools) + .documents(fromOptions.documents != null ? new ArrayList<>(fromOptions.documents) : null) + .tools(fromOptions.tools != null ? new ArrayList<>(fromOptions.tools) : null) .build(); } @@ -216,49 +173,71 @@ public void setTools(List tools) { * ChatModel overrides. 
*/ - @Override - public String getModel() { - return this.model; - } - - @Override - public Double getFrequencyPenalty() { - return this.frequencyPenalty; - } - - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - - @Override - public Double getPresencePenalty() { - return this.presencePenalty; - } - @Override public List getStopSequences() { return this.stop; } @Override - public Double getTemperature() { - return this.temperature; - } - - @Override - public Integer getTopK() { - return this.topK; + @SuppressWarnings("unchecked") + public OCICohereChatOptions copy() { + return fromOptions(this); } @Override - public Double getTopP() { - return this.topP; + public boolean equals(Object o) { + if (this == o) + return true; + if (o == null || getClass() != o.getClass()) + return false; + + OCICohereChatOptions that = (OCICohereChatOptions) o; + + if (model != null ? !model.equals(that.model) : that.model != null) + return false; + if (maxTokens != null ? !maxTokens.equals(that.maxTokens) : that.maxTokens != null) + return false; + if (compartment != null ? !compartment.equals(that.compartment) : that.compartment != null) + return false; + if (servingMode != null ? !servingMode.equals(that.servingMode) : that.servingMode != null) + return false; + if (preambleOverride != null ? !preambleOverride.equals(that.preambleOverride) : that.preambleOverride != null) + return false; + if (temperature != null ? !temperature.equals(that.temperature) : that.temperature != null) + return false; + if (topP != null ? !topP.equals(that.topP) : that.topP != null) + return false; + if (topK != null ? !topK.equals(that.topK) : that.topK != null) + return false; + if (stop != null ? !stop.equals(that.stop) : that.stop != null) + return false; + if (frequencyPenalty != null ? !frequencyPenalty.equals(that.frequencyPenalty) : that.frequencyPenalty != null) + return false; + if (presencePenalty != null ? !presencePenalty.equals(that.presencePenalty) : that.presencePenalty != null) + return false; + if (documents != null ? !documents.equals(that.documents) : that.documents != null) + return false; + return tools != null ? tools.equals(that.tools) : that.tools == null; } @Override - public ChatOptions copy() { - return fromOptions(this); + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); + result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); + result = prime * result + ((this.compartment == null) ? 0 : this.compartment.hashCode()); + result = prime * result + ((this.servingMode == null) ? 0 : this.servingMode.hashCode()); + result = prime * result + ((this.preambleOverride == null) ? 0 : this.preambleOverride.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + result = prime * result + ((this.topK == null) ? 0 : this.topK.hashCode()); + result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); + result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); + result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); + result = prime * result + ((this.documents == null) ? 0 : this.documents.hashCode()); + result = prime * result + ((this.tools == null) ? 
0 : this.tools.hashCode()); + return result; } public static class Builder { diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java new file mode 100644 index 00000000000..ea57e37608c --- /dev/null +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java @@ -0,0 +1,133 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.oci.cohere; + +import com.oracle.bmc.generativeaiinference.model.CohereTool; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link OCICohereChatOptions}. + * + * @author Alexandros Pappas + */ +class OCICohereChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + OCICohereChatOptions options = OCICohereChatOptions.builder() + .model("test-model") + .maxTokens(10) + .compartment("test-compartment") + .servingMode("test-servingMode") + .preambleOverride("test-preambleOverride") + .temperature(0.6) + .topP(0.6) + .topK(50) + .stop(List.of("test")) + .frequencyPenalty(0.5) + .presencePenalty(0.5) + .documents(List.of("doc1", "doc2")) + .build(); + + assertThat(options) + .extracting("model", "maxTokens", "compartment", "servingMode", "preambleOverride", "temperature", "topP", + "topK", "stop", "frequencyPenalty", "presencePenalty", "documents") + .containsExactly("test-model", 10, "test-compartment", "test-servingMode", "test-preambleOverride", 0.6, + 0.6, 50, List.of("test"), 0.5, 0.5, List.of("doc1", "doc2")); + } + + @Test + void testCopy() { + OCICohereChatOptions original = OCICohereChatOptions.builder() + .model("test-model") + .maxTokens(10) + .compartment("test-compartment") + .servingMode("test-servingMode") + .preambleOverride("test-preambleOverride") + .temperature(0.6) + .topP(0.6) + .topK(50) + .stop(List.of("test")) + .frequencyPenalty(0.5) + .presencePenalty(0.5) + .documents(List.of("doc1", "doc2")) + .tools(List.of(new CohereTool("test-tool", "test-context", Map.of()))) + .build(); + + OCICohereChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + assertThat(copied.getDocuments()).isNotSameAs(original.getDocuments()); + assertThat(copied.getTools()).isNotSameAs(original.getTools()); + } + + @Test + void testSetters() { + OCICohereChatOptions options = new OCICohereChatOptions(); + options.setModel("test-model"); + options.setMaxTokens(10); + options.setCompartment("test-compartment"); + options.setServingMode("test-servingMode"); + options.setPreambleOverride("test-preambleOverride"); + options.setTemperature(0.6); + options.setTopP(0.6); + options.setTopK(50); + 
options.setStop(List.of("test")); + options.setFrequencyPenalty(0.5); + options.setPresencePenalty(0.5); + options.setDocuments(List.of("doc1", "doc2")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getMaxTokens()).isEqualTo(10); + assertThat(options.getCompartment()).isEqualTo("test-compartment"); + assertThat(options.getServingMode()).isEqualTo("test-servingMode"); + assertThat(options.getPreambleOverride()).isEqualTo("test-preambleOverride"); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.6); + assertThat(options.getTopK()).isEqualTo(50); + assertThat(options.getStop()).isEqualTo(List.of("test")); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getPresencePenalty()).isEqualTo(0.5); + assertThat(options.getDocuments()).isEqualTo(List.of("doc1", "doc2")); + } + + @Test + void testDefaultValues() { + OCICohereChatOptions options = new OCICohereChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getCompartment()).isNull(); + assertThat(options.getServingMode()).isNull(); + assertThat(options.getPreambleOverride()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getTopK()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getDocuments()).isNull(); + assertThat(options.getTools()).isNull(); + } + +} diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 6c3e1246ca2..10ef06e1bf2 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -29,6 +29,7 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; @@ -46,21 +47,13 @@ * @author Mariusz Bernacki * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) -public class OpenAiChatOptions implements FunctionCallingOptions { +public class OpenAiChatOptions extends AbstractChatOptions implements FunctionCallingOptions { // @formatter:off - /** - * ID of the model to use. - */ - private @JsonProperty("model") String model; - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing - * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - */ - private @JsonProperty("frequency_penalty") Double frequencyPenalty; /** * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. @@ -79,11 +72,6 @@ public class OpenAiChatOptions implements FunctionCallingOptions { * each with an associated log probability. 
'logprobs' must be set to 'true' if this parameter is used. */ private @JsonProperty("top_logprobs") Integer topLogprobs; - /** - * The maximum number of tokens to generate in the chat completion. The total length of input - * tokens and generated tokens is limited by the model's context length. - */ - private @JsonProperty("max_tokens") Integer maxTokens; /** * An upper bound for the number of tokens that can be generated for a completion, * including visible output tokens and reasoning tokens. @@ -114,12 +102,6 @@ public class OpenAiChatOptions implements FunctionCallingOptions { */ private @JsonProperty("audio") AudioParameters outputAudio; - - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they - * appear in the text so far, increasing the model's likelihood to talk about new topics. - */ - private @JsonProperty("presence_penalty") Double presencePenalty; /** * An object specifying the format that the model must output. Setting to { "type": * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. @@ -140,18 +122,6 @@ public class OpenAiChatOptions implements FunctionCallingOptions { * Up to 4 sequences where the API will stop generating further tokens. */ private @JsonProperty("stop") List stop; - /** - * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output - * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend - * altering this or top_p but not both. - */ - private @JsonProperty("temperature") Double temperature; - /** - * An alternative to sampling with temperature, called nucleus sampling, where the model considers the - * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - * probability mass are considered. We generally recommend altering this or temperature but not both. - */ - private @JsonProperty("top_p") Double topP; /** * A list of tools the model may call. Currently, only functions are supported as a tool. Use this to * provide a list of functions the model may generate JSON inputs for. @@ -227,7 +197,7 @@ public class OpenAiChatOptions implements FunctionCallingOptions { private Map httpHeaders = new HashMap<>(); @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); // @formatter:on @@ -277,20 +247,10 @@ public void setStreamUsage(Boolean enableStreamUsage) { this.streamOptions = (enableStreamUsage) ? 
StreamOptions.INCLUDE_USAGE : null; } - @Override - public String getModel() { - return this.model; - } - public void setModel(String model) { this.model = model; } - @Override - public Double getFrequencyPenalty() { - return this.frequencyPenalty; - } - public void setFrequencyPenalty(Double frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } @@ -319,11 +279,6 @@ public void setTopLogprobs(Integer topLogprobs) { this.topLogprobs = topLogprobs; } - @Override - public Integer getMaxTokens() { - return this.maxTokens; - } - public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } @@ -360,11 +315,6 @@ public void setOutputAudio(AudioParameters audio) { this.outputAudio = audio; } - @Override - public Double getPresencePenalty() { - return this.presencePenalty; - } - public void setPresencePenalty(Double presencePenalty) { this.presencePenalty = presencePenalty; } @@ -421,11 +371,6 @@ public void setTemperature(Double temperature) { this.temperature = temperature; } - @Override - public Double getTopP() { - return this.topP; - } - public void setTopP(Double topP) { this.topP = topP; } @@ -539,6 +484,7 @@ public void setReasoningEffort(String reasoningEffort) { } @Override + @SuppressWarnings("unchecked") public OpenAiChatOptions copy() { return OpenAiChatOptions.fromOptions(this); } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java new file mode 100644 index 00000000000..5fc63a547ca --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiChatOptionsTests.java @@ -0,0 +1,262 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.Test; +import static org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters.Voice.ALLOY; + +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.AudioParameters; +import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.StreamOptions; +import org.springframework.ai.openai.api.ResponseFormat; + +/** + * Tests for {@link OpenAiChatOptions}. 
+ * + * @author Alexandros Pappas + */ +class OpenAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + Map logitBias = new HashMap<>(); + logitBias.put("token1", 1); + logitBias.put("token2", -1); + + List outputModalities = List.of("text", "audio"); + AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3); + ResponseFormat responseFormat = new ResponseFormat(); + StreamOptions streamOptions = StreamOptions.INCLUDE_USAGE; + List stopSequences = List.of("stop1", "stop2"); + List tools = new ArrayList<>(); + Object toolChoice = "auto"; + Map metadata = Map.of("key1", "value1"); + Map toolContext = Map.of("keyA", "valueA"); + + OpenAiChatOptions options = OpenAiChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .logitBias(logitBias) + .logprobs(true) + .topLogprobs(5) + .maxTokens(100) + .maxCompletionTokens(50) + .N(2) + .outputModalities(outputModalities) + .outputAudio(outputAudio) + .presencePenalty(0.8) + .responseFormat(responseFormat) + .streamUsage(true) + .seed(12345) + .stop(stopSequences) + .temperature(0.7) + .topP(0.9) + .tools(tools) + .toolChoice(toolChoice) + .user("test-user") + .parallelToolCalls(true) + .store(false) + .metadata(metadata) + .reasoningEffort("medium") + .proxyToolCalls(false) + .httpHeaders(Map.of("header1", "value1")) + .toolContext(toolContext) + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "logitBias", "logprobs", "topLogprobs", "maxTokens", + "maxCompletionTokens", "n", "outputModalities", "outputAudio", "presencePenalty", "responseFormat", + "streamOptions", "seed", "stop", "temperature", "topP", "tools", "toolChoice", "user", + "parallelToolCalls", "store", "metadata", "reasoningEffort", "proxyToolCalls", "httpHeaders", + "toolContext") + .containsExactly("test-model", 0.5, logitBias, true, 5, 100, 50, 2, outputModalities, outputAudio, 0.8, + responseFormat, streamOptions, 12345, stopSequences, 0.7, 0.9, tools, toolChoice, "test-user", true, + false, metadata, "medium", false, Map.of("header1", "value1"), toolContext); + + assertThat(options.getStreamUsage()).isTrue(); + assertThat(options.getStreamOptions()).isEqualTo(StreamOptions.INCLUDE_USAGE); + + } + + @Test + void testCopy() { + Map logitBias = new HashMap<>(); + logitBias.put("token1", 1); + + List outputModalities = List.of("text"); + AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3); + ResponseFormat responseFormat = new ResponseFormat(); + + List stopSequences = List.of("stop1"); + List tools = new ArrayList<>(); + Object toolChoice = "none"; + Map metadata = Map.of("key1", "value1"); + + OpenAiChatOptions originalOptions = OpenAiChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .logitBias(logitBias) + .logprobs(true) + .topLogprobs(5) + .maxTokens(100) + .maxCompletionTokens(50) + .N(2) + .outputModalities(outputModalities) + .outputAudio(outputAudio) + .presencePenalty(0.8) + .responseFormat(responseFormat) + .streamUsage(false) + .seed(12345) + .stop(stopSequences) + .temperature(0.7) + .topP(0.9) + .tools(tools) + .toolChoice(toolChoice) + .user("test-user") + .parallelToolCalls(false) + .store(true) + .metadata(metadata) + .reasoningEffort("low") + .proxyToolCalls(true) + .httpHeaders(Map.of("header1", "value1")) + .build(); + + OpenAiChatOptions copiedOptions = originalOptions.copy(); + assertThat(copiedOptions).isNotSameAs(originalOptions).isEqualTo(originalOptions); + } + + @Test + void testSetters() { + 
Map logitBias = new HashMap<>(); + logitBias.put("token1", 1); + + List outputModalities = List.of("audio"); + AudioParameters outputAudio = new AudioParameters(ALLOY, AudioParameters.AudioResponseFormat.MP3); + ResponseFormat responseFormat = new ResponseFormat(); + + StreamOptions streamOptions = StreamOptions.INCLUDE_USAGE; + List stopSequences = List.of("stop1", "stop2"); + List tools = new ArrayList<>(); + Object toolChoice = "auto"; + Map metadata = Map.of("key2", "value2"); + + OpenAiChatOptions options = new OpenAiChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.5); + options.setLogitBias(logitBias); + options.setLogprobs(true); + options.setTopLogprobs(5); + options.setMaxTokens(100); + options.setMaxCompletionTokens(50); + options.setN(2); + options.setOutputModalities(outputModalities); + options.setOutputAudio(outputAudio); + options.setPresencePenalty(0.8); + options.setResponseFormat(responseFormat); + options.setStreamOptions(streamOptions); + options.setSeed(12345); + options.setStop(stopSequences); + options.setTemperature(0.7); + options.setTopP(0.9); + options.setTools(tools); + options.setToolChoice(toolChoice); + options.setUser("test-user"); + options.setParallelToolCalls(true); + options.setStore(false); + options.setMetadata(metadata); + options.setReasoningEffort("high"); + options.setProxyToolCalls(false); + options.setHttpHeaders(Map.of("header2", "value2")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getLogitBias()).isEqualTo(logitBias); + assertThat(options.getLogprobs()).isTrue(); + assertThat(options.getTopLogprobs()).isEqualTo(5); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getMaxCompletionTokens()).isEqualTo(50); + assertThat(options.getN()).isEqualTo(2); + assertThat(options.getOutputModalities()).isEqualTo(outputModalities); + assertThat(options.getOutputAudio()).isEqualTo(outputAudio); + assertThat(options.getPresencePenalty()).isEqualTo(0.8); + assertThat(options.getResponseFormat()).isEqualTo(responseFormat); + assertThat(options.getStreamOptions()).isEqualTo(streamOptions); + assertThat(options.getSeed()).isEqualTo(12345); + assertThat(options.getStop()).isEqualTo(stopSequences); + assertThat(options.getTemperature()).isEqualTo(0.7); + assertThat(options.getTopP()).isEqualTo(0.9); + assertThat(options.getTools()).isEqualTo(tools); + assertThat(options.getToolChoice()).isEqualTo(toolChoice); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getParallelToolCalls()).isTrue(); + assertThat(options.getStore()).isFalse(); + assertThat(options.getMetadata()).isEqualTo(metadata); + assertThat(options.getReasoningEffort()).isEqualTo("high"); + assertThat(options.getProxyToolCalls()).isFalse(); + assertThat(options.getHttpHeaders()).isEqualTo(Map.of("header2", "value2")); + assertThat(options.getStreamUsage()).isTrue(); + options.setStreamUsage(false); + assertThat(options.getStreamUsage()).isFalse(); + assertThat(options.getStreamOptions()).isNull(); + options.setStopSequences(List.of("s1", "s2")); + assertThat(options.getStopSequences()).isEqualTo(List.of("s1", "s2")); + assertThat(options.getStop()).isEqualTo(List.of("s1", "s2")); + } + + @Test + void testDefaultValues() { + OpenAiChatOptions options = new OpenAiChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getLogitBias()).isNull(); + 
assertThat(options.getLogprobs()).isNull(); + assertThat(options.getTopLogprobs()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getMaxCompletionTokens()).isNull(); + assertThat(options.getN()).isNull(); + assertThat(options.getOutputModalities()).isNull(); + assertThat(options.getOutputAudio()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getStreamOptions()).isNull(); + assertThat(options.getSeed()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getTools()).isNull(); + assertThat(options.getToolChoice()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getParallelToolCalls()).isNull(); + assertThat(options.getStore()).isNull(); + assertThat(options.getMetadata()).isNull(); + assertThat(options.getReasoningEffort()).isNull(); + assertThat(options.getFunctionCallbacks()).isNotNull().isEmpty(); + assertThat(options.getFunctions()).isNotNull().isEmpty(); + assertThat(options.getProxyToolCalls()).isNull(); + assertThat(options.getHttpHeaders()).isNotNull().isEmpty(); + assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); + assertThat(options.getStreamUsage()).isFalse(); + assertThat(options.getStopSequences()).isNull(); + } + +} diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java index 5531e90a88a..a6b9ce7d31f 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,13 +16,16 @@ package org.springframework.ai.qianfan; +import java.util.ArrayList; import java.util.List; +import java.util.Objects; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.qianfan.api.QianFanApi; @@ -33,6 +36,7 @@ * * @author Geng Rong * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas * @since 1.0 * @see ChatOptions */ @@ -64,6 +68,7 @@ public class QianFanChatOptions implements ChatOptions { * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. */ private @JsonProperty("response_format") QianFanApi.ChatCompletionRequest.ResponseFormat responseFormat; + /** * Up to 4 sequences where the API will stop generating further tokens. */ @@ -93,7 +98,7 @@ public static QianFanChatOptions fromOptions(QianFanChatOptions fromOptions) { .maxTokens(fromOptions.getMaxTokens()) .presencePenalty(fromOptions.getPresencePenalty()) .responseFormat(fromOptions.getResponseFormat()) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? 
new ArrayList<>(fromOptions.getStop()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .build(); @@ -187,100 +192,37 @@ public Integer getTopK() { } @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); - result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - return result; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) { + public boolean equals(Object o) { + if (this == o) { return true; } - if (obj == null) { + if (!(o instanceof QianFanChatOptions that)) { return false; } - if (getClass() != obj.getClass()) { - return false; - } - QianFanChatOptions other = (QianFanChatOptions) obj; - if (this.model == null) { - if (other.model != null) { - return false; - } - } - else if (!this.model.equals(other.model)) { - return false; - } - if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) { - return false; - } - } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { - return false; - } - if (this.maxTokens == null) { - if (other.maxTokens != null) { - return false; - } - } - else if (!this.maxTokens.equals(other.maxTokens)) { - return false; - } - if (this.presencePenalty == null) { - if (other.presencePenalty != null) { - return false; - } - } - else if (!this.presencePenalty.equals(other.presencePenalty)) { - return false; - } - if (this.responseFormat == null) { - if (other.responseFormat != null) { - return false; - } - } - else if (!this.responseFormat.equals(other.responseFormat)) { - return false; - } - if (this.stop == null) { - if (other.stop != null) { - return false; - } - } - else if (!this.stop.equals(other.stop)) { - return false; - } - if (this.temperature == null) { - if (other.temperature != null) { - return false; - } - } - else if (!this.temperature.equals(other.temperature)) { - return false; - } - if (this.topP == null) { - if (other.topP != null) { - return false; - } - } - else if (!this.topP.equals(other.topP)) { - return false; - } - return true; + return Objects.equals(this.model, that.model) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.responseFormat, that.responseFormat) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + (this.model != null ? this.model.hashCode() : 0); + result = prime * result + (this.frequencyPenalty != null ? this.frequencyPenalty.hashCode() : 0); + result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0); + result = prime * result + (this.presencePenalty != null ? 
this.presencePenalty.hashCode() : 0); + result = prime * result + (this.responseFormat != null ? this.responseFormat.hashCode() : 0); + result = prime * result + (this.stop != null ? this.stop.hashCode() : 0); + result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0); + result = prime * result + (this.topP != null ? this.topP.hashCode() : 0); + return result; } @Override + @SuppressWarnings("unchecked") public QianFanChatOptions copy() { return fromOptions(this); } diff --git a/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanChatOptionsTests.java b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanChatOptionsTests.java new file mode 100644 index 00000000000..0a18ed78590 --- /dev/null +++ b/models/spring-ai-qianfan/src/test/java/org/springframework/ai/qianfan/QianFanChatOptionsTests.java @@ -0,0 +1,127 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.qianfan; + +import java.util.List; + +import com.fasterxml.jackson.databind.ObjectMapper; +import static org.assertj.core.api.Assertions.assertThat; +import org.junit.jupiter.api.Test; + +import org.springframework.ai.qianfan.api.QianFanApi; + +/** + * Tests for {@link QianFanChatOptions}. 
+ */ +class QianFanChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + QianFanChatOptions options = QianFanChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .presencePenalty(0.5) + .responseFormat(new QianFanApi.ChatCompletionRequest.ResponseFormat("text")) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .build(); + + assertThat(options) + .extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "responseFormat", "stop", + "temperature", "topP") + .containsExactly("test-model", 0.5, 10, 0.5, new QianFanApi.ChatCompletionRequest.ResponseFormat("text"), + List.of("test"), 0.6, 0.6); + } + + @Test + void testCopy() { + QianFanChatOptions original = QianFanChatOptions.builder() + .model("test-model") + .frequencyPenalty(0.5) + .maxTokens(10) + .presencePenalty(0.5) + .responseFormat(new QianFanApi.ChatCompletionRequest.ResponseFormat("text")) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.6) + .build(); + + QianFanChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + // Ensure deep copy + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + } + + @Test + void testSetters() { + QianFanChatOptions options = new QianFanChatOptions(); + options.setModel("test-model"); + options.setFrequencyPenalty(0.5); + options.setMaxTokens(10); + options.setPresencePenalty(0.5); + options.setResponseFormat(new QianFanApi.ChatCompletionRequest.ResponseFormat("text")); + options.setStop(List.of("test")); + options.setTemperature(0.6); + options.setTopP(0.6); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getFrequencyPenalty()).isEqualTo(0.5); + assertThat(options.getMaxTokens()).isEqualTo(10); + assertThat(options.getPresencePenalty()).isEqualTo(0.5); + assertThat(options.getResponseFormat()).isEqualTo(new QianFanApi.ChatCompletionRequest.ResponseFormat("text")); + assertThat(options.getStop()).isEqualTo(List.of("test")); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.6); + } + + @Test + void testDefaultValues() { + QianFanChatOptions options = new QianFanChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getFrequencyPenalty()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getPresencePenalty()).isNull(); + assertThat(options.getResponseFormat()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + } + + @Test + void testSerialization() throws Exception { + QianFanChatOptions options = QianFanChatOptions.builder().maxTokens(10).build(); + + ObjectMapper objectMapper = new ObjectMapper(); + String json = objectMapper.writeValueAsString(options); + + assertThat(json).contains("\"max_output_tokens\":10"); + } + + @Test + void testDeserialization() throws Exception { + String json = "{\"max_output_tokens\":10}"; + + ObjectMapper objectMapper = new ObjectMapper(); + QianFanChatOptions options = objectMapper.readValue(json, QianFanChatOptions.class); + + assertThat(options.getMaxTokens()).isEqualTo(10); + } + +} diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java index f6413c759cb..bce72b63c5e 100644 --- 
a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.ai.watsonx; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -29,6 +30,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; +import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.chat.prompt.ChatOptions; /** @@ -45,34 +47,11 @@ */ // @formatter:off -public class WatsonxAiChatOptions implements ChatOptions { +public class WatsonxAiChatOptions extends AbstractChatOptions implements ChatOptions { @JsonIgnore private final ObjectMapper mapper = new ObjectMapper(); - /** - * The temperature of the model. Increasing the temperature will - * make the model answer more creatively. (Default: 0.7) - */ - @JsonProperty("temperature") - private Double temperature; - - /** - * Works together with top-k. A higher value (e.g., 0.95) will lead to - * more diverse text, while a lower value (e.g., 0.2) will generate more focused and - * conservative text. (Default: 1.0) - */ - @JsonProperty("top_p") - private Double topP; - - /** - * Reduces the probability of generating nonsense. A higher value (e.g. - * 100) will give more diverse answers, while a lower value (e.g. 10) will be more - * conservative. (Default: 50) - */ - @JsonProperty("top_k") - private Integer topK; - /** * Decoding is the process that a model uses to choose the tokens in the generated output. * Choose one of the following decoding options: @@ -104,14 +83,6 @@ public class WatsonxAiChatOptions implements ChatOptions { @JsonProperty("min_new_tokens") private Integer minNewTokens; - /** - * Sets when the LLM should stop. - * (e.g., ["\n\n\n"]) then when the LLM generates three consecutive line breaks it will terminate. - * Stop sequences are ignored until after the number of tokens that are specified in the Min tokens parameter are generated. - */ - @JsonProperty("stop_sequences") - private List stopSequences; - /** * Sets how strongly to penalize repetitions. A higher value * (e.g., 1.8) will penalize repetitions more strongly, while a lower value (e.g., @@ -126,12 +97,6 @@ public class WatsonxAiChatOptions implements ChatOptions { @JsonProperty("random_seed") private Integer randomSeed; - /** - * Model is the identifier of the LLM Model to be used - */ - @JsonProperty("model") - private String model; - /** * Set additional request params (some model have non-predefined options) */ @@ -162,37 +127,22 @@ public static WatsonxAiChatOptions fromOptions(WatsonxAiChatOptions fromOptions) .decodingMethod(fromOptions.getDecodingMethod()) .maxNewTokens(fromOptions.getMaxNewTokens()) .minNewTokens(fromOptions.getMinNewTokens()) - .stopSequences(fromOptions.getStopSequences()) + .stopSequences(fromOptions.getStopSequences() != null ? 
new ArrayList<>(fromOptions.getStopSequences()) : null) .repetitionPenalty(fromOptions.getRepetitionPenalty()) .randomSeed(fromOptions.getRandomSeed()) .model(fromOptions.getModel()) - .additionalProperties(fromOptions.getAdditionalProperties()) + .additionalProperties(fromOptions.getAdditionalProperties() != null ? new HashMap<>(fromOptions.getAdditionalProperties()) : null) .build(); } - @Override - public Double getTemperature() { - return this.temperature; - } - public void setTemperature(Double temperature) { this.temperature = temperature; } - @Override - public Double getTopP() { - return this.topP; - } - public void setTopP(Double topP) { this.topP = topP; } - @Override - public Integer getTopK() { - return this.topK; - } - public void setTopK(Integer topK) { this.topK = topK; } @@ -232,11 +182,6 @@ public void setMinNewTokens(Integer minNewTokens) { this.minNewTokens = minNewTokens; } - @Override - public List getStopSequences() { - return this.stopSequences; - } - public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } @@ -319,10 +264,49 @@ private String toSnakeCase(String input) { } @Override + @SuppressWarnings("unchecked") public WatsonxAiChatOptions copy() { return fromOptions(this); } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + WatsonxAiChatOptions that = (WatsonxAiChatOptions) o; + + if (decodingMethod != null ? !decodingMethod.equals(that.decodingMethod) : that.decodingMethod != null) return false; + if (maxNewTokens != null ? !maxNewTokens.equals(that.maxNewTokens) : that.maxNewTokens != null) return false; + if (minNewTokens != null ? !minNewTokens.equals(that.minNewTokens) : that.minNewTokens != null) return false; + if (repetitionPenalty != null ? !repetitionPenalty.equals(that.repetitionPenalty) : that.repetitionPenalty != null) return false; + if (randomSeed != null ? !randomSeed.equals(that.randomSeed) : that.randomSeed != null) return false; + if (temperature != null ? !temperature.equals(that.temperature) : that.temperature != null) return false; + if (topP != null ? !topP.equals(that.topP) : that.topP != null) return false; + if (topK != null ? !topK.equals(that.topK) : that.topK != null) return false; + if (stopSequences != null ? !stopSequences.equals(that.stopSequences) : that.stopSequences != null) return false; + if (model != null ? !model.equals(that.model) : that.model != null) return false; + return additional != null ? additional.equals(that.additional) : that.additional == null; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.decodingMethod == null) ? 0 : this.decodingMethod.hashCode()); + result = prime * result + ((this.maxNewTokens == null) ? 0 : this.maxNewTokens.hashCode()); + result = prime * result + ((this.minNewTokens == null) ? 0 : this.minNewTokens.hashCode()); + result = prime * result + ((this.repetitionPenalty == null) ? 0 : this.repetitionPenalty.hashCode()); + result = prime * result + ((this.randomSeed == null) ? 0 : this.randomSeed.hashCode()); + result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); + result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); + result = prime * result + ((this.topK == null) ? 0 : this.topK.hashCode()); + result = prime * result + ((this.stopSequences == null) ? 0 : this.stopSequences.hashCode()); + result = prime * result + ((this.model == null) ? 
0 : this.model.hashCode()); + result = prime * result + ((this.additional == null) ? 0 : this.additional.hashCode()); + return result; + } + public static class Builder { WatsonxAiChatOptions options = new WatsonxAiChatOptions(); diff --git a/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatOptionsTests.java b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatOptionsTests.java new file mode 100644 index 00000000000..f4936942d02 --- /dev/null +++ b/models/spring-ai-watsonx-ai/src/test/java/org/springframework/ai/watsonx/WatsonxAiChatOptionsTests.java @@ -0,0 +1,124 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.watsonx; + +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link WatsonxAiChatOptions}. + * + * @author Alexandros Pappas + */ +class WatsonxAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + WatsonxAiChatOptions options = WatsonxAiChatOptions.builder() + .temperature(0.6) + .topP(0.6) + .topK(50) + .decodingMethod("greedy") + .maxNewTokens(100) + .minNewTokens(10) + .stopSequences(List.of("test")) + .repetitionPenalty(1.2) + .randomSeed(42) + .model("test-model") + .additionalProperties(Map.of("key1", "value1")) + .build(); + + assertThat(options) + .extracting("temperature", "topP", "topK", "decodingMethod", "maxNewTokens", "minNewTokens", + "stopSequences", "repetitionPenalty", "randomSeed", "model", "additionalProperties") + .containsExactly(0.6, 0.6, 50, "greedy", 100, 10, List.of("test"), 1.2, 42, "test-model", + Map.of("key1", "value1")); + } + + @Test + void testCopy() { + WatsonxAiChatOptions original = WatsonxAiChatOptions.builder() + .temperature(0.6) + .topP(0.6) + .topK(50) + .decodingMethod("greedy") + .maxNewTokens(100) + .minNewTokens(10) + .stopSequences(List.of("test")) + .repetitionPenalty(1.2) + .randomSeed(42) + .model("test-model") + .additionalProperties(Map.of("key1", "value1")) + .build(); + + WatsonxAiChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences()); + assertThat(copied.getAdditionalProperties()).isNotSameAs(original.getAdditionalProperties()); + } + + @Test + void testSetters() { + WatsonxAiChatOptions options = new WatsonxAiChatOptions(); + options.setTemperature(0.6); + options.setTopP(0.6); + options.setTopK(50); + options.setDecodingMethod("greedy"); + options.setMaxNewTokens(100); + options.setMinNewTokens(10); + options.setStopSequences(List.of("test")); + options.setRepetitionPenalty(1.2); + options.setRandomSeed(42); + options.setModel("test-model"); + options.addAdditionalProperty("key1", "value1"); + + assertThat(options.getTemperature()).isEqualTo(0.6); + 
assertThat(options.getTopP()).isEqualTo(0.6); + assertThat(options.getTopK()).isEqualTo(50); + assertThat(options.getDecodingMethod()).isEqualTo("greedy"); + assertThat(options.getMaxNewTokens()).isEqualTo(100); + assertThat(options.getMinNewTokens()).isEqualTo(10); + assertThat(options.getStopSequences()).isEqualTo(List.of("test")); + assertThat(options.getRepetitionPenalty()).isEqualTo(1.2); + assertThat(options.getRandomSeed()).isEqualTo(42); + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getAdditionalProperties()).isEqualTo(Map.of("key1", "value1")); + } + + @Test + void testDefaultValues() { + WatsonxAiChatOptions options = new WatsonxAiChatOptions(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getTopK()).isNull(); + assertThat(options.getDecodingMethod()).isNull(); + assertThat(options.getMaxNewTokens()).isNull(); + assertThat(options.getMinNewTokens()).isNull(); + assertThat(options.getStopSequences()).isNull(); + assertThat(options.getRepetitionPenalty()).isNull(); + assertThat(options.getRandomSeed()).isNull(); + assertThat(options.getModel()).isNull(); + assertThat(options.getAdditionalProperties()).isEqualTo(new HashMap<>()); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AbstractChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AbstractChatOptions.java new file mode 100644 index 00000000000..8dd4c82857a --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AbstractChatOptions.java @@ -0,0 +1,223 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.prompt; + +import java.util.List; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.lang.Nullable; + +/** + * Abstract base class for {@link ChatOptions}, providing common implementation for its + * methods. + * + * @author Alexandros Pappas + */ +public abstract class AbstractChatOptions implements ChatOptions { + + /** + * ID of the model to use. + */ + @JsonProperty("model") + protected String model; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their + * existing frequency in the text so far, decreasing the model's likelihood to repeat + * the same line verbatim. + */ + @JsonProperty("frequency_penalty") + protected Double frequencyPenalty; + + /** + * The maximum number of tokens to generate in the chat completion. The total length + * of input tokens and generated tokens is limited by the model's context length. + */ + @JsonProperty("max_tokens") + protected Integer maxTokens; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether + * they appear in the text so far, increasing the model's likelihood to talk about new + * topics. 
+ */ + @JsonProperty("presence_penalty") + protected Double presencePenalty; + + /** + * Sets when the LLM should stop. (e.g., ["\n\n\n"]) then when the LLM generates three + * consecutive line breaks it will terminate. Stop sequences are ignored until after + * the number of tokens that are specified in the Min tokens parameter are generated. + */ + @JsonProperty("stop_sequences") + protected List stopSequences; + + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make + * the output more random, while lower values like 0.2 will make it more focused and + * deterministic. We generally recommend altering this or top_p but not both. + */ + @JsonProperty("temperature") + protected Double temperature; + + /** + * Reduces the probability of generating nonsense. A higher value (e.g. 100) will give + * more diverse answers, while a lower value (e.g. 10) will be more conservative. + * (Default: 40) + */ + @JsonProperty("top_k") + protected Integer topK; + + /** + * An alternative to sampling with temperature, called nucleus sampling, where the + * model considers the results of the tokens with top_p probability mass. So 0.1 means + * only the tokens comprising the top 10% probability mass are considered. We + * generally recommend altering this or temperature but not both. + */ + @JsonProperty("top_p") + protected Double topP; + + @Override + public String getModel() { + return this.model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + @Override + @Nullable + public Integer getMaxTokens() { + return this.maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + + @Override + @JsonIgnore + public List getStopSequences() { + return this.stopSequences; + } + + @Override + public Double getTemperature() { + return this.temperature; + } + + @Override + public Integer getTopK() { + return this.topK; + } + + @Override + @Nullable + public Double getTopP() { + return this.topP; + } + + /** + * Generic Abstract Builder for {@link AbstractChatOptions}. Ensures fluent API in + * subclasses with proper typing. 
+ */ + public abstract static class Builder<T extends AbstractChatOptions, B extends Builder<T, B>> + implements ChatOptions.Builder { + + protected T options; + + protected String model; + + protected Double frequencyPenalty; + + protected Integer maxTokens; + + protected Double presencePenalty; + + protected List<String> stopSequences; + + protected Double temperature; + + protected Integer topK; + + protected Double topP; + + protected abstract B self(); + + public Builder(T options) { + this.options = options; + } + + @Override + public B model(String model) { + this.model = model; + return this.self(); + } + + @Override + public B frequencyPenalty(Double frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this.self(); + } + + @Override + public B maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this.self(); + } + + @Override + public B presencePenalty(Double presencePenalty) { + this.presencePenalty = presencePenalty; + return this.self(); + } + + @Override + public B stopSequences(List<String> stopSequences) { + this.stopSequences = stopSequences; + return this.self(); + } + + @Override + public B temperature(Double temperature) { + this.temperature = temperature; + return this.self(); + } + + @Override + public B topK(Integer topK) { + this.topK = topK; + return this.self(); + } + + @Override + public B topP(Double topP) { + this.topP = topP; + return this.self(); + } + + @Override + public abstract T build(); + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index 19cb98a3a6b..cad9af8a3eb 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,8 @@ /** * {@link ModelOptions} representing the common options that are portable across different * chat models. + * + * @author Alexandros Pappas */ public interface ChatOptions extends ModelOptions { @@ -95,7 +97,7 @@ public interface ChatOptions extends ModelOptions { * @return Returns a new {@link ChatOptions.Builder}. */ static ChatOptions.Builder builder() { - return new DefaultChatOptionsBuilder(); + return new DefaultChatOptions.Builder(); } /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java index 1af33bf3467..2f15f748e2e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,115 +17,133 @@ package org.springframework.ai.chat.prompt; import java.util.ArrayList; -import java.util.Collections; import java.util.List; +import java.util.Objects; /** * Default implementation for the {@link ChatOptions}.
+ *
+ * @author Alexandros Pappas
  */
-public class DefaultChatOptions implements ChatOptions {
-
-	private String model;
-
-	private Double frequencyPenalty;
-
-	private Integer maxTokens;
-
-	private Double presencePenalty;
+public class DefaultChatOptions extends AbstractChatOptions {
 
-	private List<String> stopSequences;
-
-	private Double temperature;
-
-	private Integer topK;
-
-	private Double topP;
-
-	@Override
-	public String getModel() {
-		return this.model;
+	public static Builder builder() {
+		return new Builder();
 	}
 
 	public void setModel(String model) {
 		this.model = model;
 	}
 
-	@Override
-	public Double getFrequencyPenalty() {
-		return this.frequencyPenalty;
-	}
-
 	public void setFrequencyPenalty(Double frequencyPenalty) {
 		this.frequencyPenalty = frequencyPenalty;
 	}
 
-	@Override
-	public Integer getMaxTokens() {
-		return this.maxTokens;
-	}
-
 	public void setMaxTokens(Integer maxTokens) {
 		this.maxTokens = maxTokens;
 	}
 
-	@Override
-	public Double getPresencePenalty() {
-		return this.presencePenalty;
-	}
-
 	public void setPresencePenalty(Double presencePenalty) {
 		this.presencePenalty = presencePenalty;
 	}
 
-	@Override
-	public List<String> getStopSequences() {
-		return this.stopSequences != null ? Collections.unmodifiableList(this.stopSequences) : null;
-	}
-
 	public void setStopSequences(List<String> stopSequences) {
 		this.stopSequences = stopSequences;
 	}
 
-	@Override
-	public Double getTemperature() {
-		return this.temperature;
-	}
-
 	public void setTemperature(Double temperature) {
 		this.temperature = temperature;
 	}
 
-	@Override
-	public Integer getTopK() {
-		return this.topK;
-	}
-
 	public void setTopK(Integer topK) {
 		this.topK = topK;
 	}
 
-	@Override
-	public Double getTopP() {
-		return this.topP;
-	}
-
 	public void setTopP(Double topP) {
 		this.topP = topP;
 	}
 
 	@Override
 	@SuppressWarnings("unchecked")
-	public <T extends ChatOptions> T copy() {
-		DefaultChatOptions copy = new DefaultChatOptions();
-		copy.setModel(this.getModel());
-		copy.setFrequencyPenalty(this.getFrequencyPenalty());
-		copy.setMaxTokens(this.getMaxTokens());
-		copy.setPresencePenalty(this.getPresencePenalty());
-		copy.setStopSequences(this.getStopSequences() != null ? new ArrayList<>(this.getStopSequences()) : null);
-		copy.setTemperature(this.getTemperature());
-		copy.setTopK(this.getTopK());
-		copy.setTopP(this.getTopP());
-		return (T) copy;
+	public DefaultChatOptions copy() {
+		return DefaultChatOptions.builder()
+			.model(this.getModel())
+			.frequencyPenalty(this.getFrequencyPenalty())
+			.maxTokens(this.getMaxTokens())
+			.presencePenalty(this.getPresencePenalty())
+			.stopSequences(this.getStopSequences() == null ? null : new ArrayList<>(this.getStopSequences()))
+			.temperature(this.getTemperature())
+			.topK(this.getTopK())
+			.topP(this.getTopP())
+			.build();
+	}
+
+	@Override
+	public boolean equals(Object o) {
+		if (this == o)
+			return true;
+		if (o == null || getClass() != o.getClass())
+			return false;
+
+		DefaultChatOptions that = (DefaultChatOptions) o;
+
+		if (!Objects.equals(model, that.model))
+			return false;
+		if (!Objects.equals(frequencyPenalty, that.frequencyPenalty))
+			return false;
+		if (!Objects.equals(maxTokens, that.maxTokens))
+			return false;
+		if (!Objects.equals(presencePenalty, that.presencePenalty))
+			return false;
+		if (!Objects.equals(stopSequences, that.stopSequences))
+			return false;
+		if (!Objects.equals(temperature, that.temperature))
+			return false;
+		if (!Objects.equals(topK, that.topK))
+			return false;
+		return Objects.equals(topP, that.topP);
+	}
+
+	@Override
+	public int hashCode() {
+		final int prime = 31;
+		int result = 1;
+		result = prime * result + ((this.model == null) ? 0 : this.model.hashCode());
+		result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode());
+		result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode());
+		result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode());
+		result = prime * result + ((this.stopSequences == null) ? 0 : this.stopSequences.hashCode());
+		result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode());
+		result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode());
+		result = prime * result + ((this.topK == null) ? 0 : this.topK.hashCode());
+		return result;
+	}
+
+	public static class Builder extends AbstractChatOptions.Builder<DefaultChatOptions, Builder> {
+
+		public Builder() {
+			super(new DefaultChatOptions());
+		}
+
+		@Override
+		protected Builder self() {
+			return this;
+		}
+
+		@Override
+		public DefaultChatOptions build() {
+			DefaultChatOptions optionsToBuild = new DefaultChatOptions();
+			optionsToBuild.setModel(this.model);
+			optionsToBuild.setFrequencyPenalty(this.frequencyPenalty);
+			optionsToBuild.setMaxTokens(this.maxTokens);
+			optionsToBuild.setPresencePenalty(this.presencePenalty);
+			optionsToBuild.setStopSequences(this.stopSequences);
+			optionsToBuild.setTemperature(this.temperature);
+			optionsToBuild.setTopK(this.topK);
+			optionsToBuild.setTopP(this.topP);
+			return optionsToBuild;
+		}
+
+	}
+
 }
diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java
deleted file mode 100644
index 47ba5840109..00000000000
--- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java
+++ /dev/null
@@ -1,80 +0,0 @@
-/*
- * Copyright 2024-2024 the original author or authors.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *      https://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.springframework.ai.chat.prompt;
-
-import java.util.List;
-
-/**
- * Implementation of {@link ChatOptions.Builder} to create {@link DefaultChatOptions}.
- */
-public class DefaultChatOptionsBuilder implements ChatOptions.Builder {
-
-	protected DefaultChatOptions options;
-
-	public DefaultChatOptionsBuilder() {
-		this.options = new DefaultChatOptions();
-	}
-
-	protected DefaultChatOptionsBuilder(DefaultChatOptions options) {
-		this.options = options;
-	}
-
-	public DefaultChatOptionsBuilder model(String model) {
-		this.options.setModel(model);
-		return this;
-	}
-
-	public DefaultChatOptionsBuilder frequencyPenalty(Double frequencyPenalty) {
-		this.options.setFrequencyPenalty(frequencyPenalty);
-		return this;
-	}
-
-	public DefaultChatOptionsBuilder maxTokens(Integer maxTokens) {
-		this.options.setMaxTokens(maxTokens);
-		return this;
-	}
-
-	public DefaultChatOptionsBuilder presencePenalty(Double presencePenalty) {
-		this.options.setPresencePenalty(presencePenalty);
-		return this;
-	}
-
-	public DefaultChatOptionsBuilder stopSequences(List<String> stop) {
-		this.options.setStopSequences(stop);
-		return this;
-	}
-
-	public DefaultChatOptionsBuilder temperature(Double temperature) {
-		this.options.setTemperature(temperature);
-		return this;
-	}
-
-	public DefaultChatOptionsBuilder topK(Integer topK) {
-		this.options.setTopK(topK);
-		return this;
-	}
-
-	public DefaultChatOptionsBuilder topP(Double topP) {
-		this.options.setTopP(topP);
-		return this;
-	}
-
-	public ChatOptions build() {
-		return this.options.copy();
-	}
-
-}
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java
index 247e82b6f00..ca9a2a10db0 100644
--- a/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java
+++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java
@@ -1,5 +1,5 @@
 /*
- * Copyright 2023-2024 the original author or authors.
+ * Copyright 2023-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -16,7 +16,6 @@
 package org.springframework.ai.chat.prompt;
 
-import java.util.ArrayList;
 import java.util.List;
 import java.util.Set;
@@ -27,13 +26,13 @@
 import org.springframework.ai.model.function.FunctionCallingOptions;
 
 import static org.assertj.core.api.Assertions.assertThat;
-import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy;
 
 /**
  * Unit Tests for {@link ChatOptions} builder.
 *
 * @author youngmon
 * @author Mark Pollack
+ * @author Alexandros Pappas
 * @since 1.0.0
 */
 public class ChatOptionsBuilderTests {
@@ -163,15 +162,4 @@ void shouldHaveExpectedDefaultValues() {
 		assertThat(options.getStopSequences()).isNull();
 	}
 
-	@Test
-	void shouldBeImmutableAfterBuild() {
-		// Given
-		List<String> stopSequences = new ArrayList<>(List.of("stop1", "stop2"));
-		ChatOptions options = this.builder.stopSequences(stopSequences).build();
-
-		// Then
-		assertThatThrownBy(() -> options.getStopSequences().add("stop3"))
-			.isInstanceOf(UnsupportedOperationException.class);
-	}
-
 }
diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java
new file mode 100644
index 00000000000..ece5853aa95
--- /dev/null
+++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java
@@ -0,0 +1,106 @@
+/*
+ * Copyright 2025-2025 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.springframework.ai.chat.prompt;
+
+import java.util.List;
+
+import org.junit.jupiter.api.Test;
+
+import static org.assertj.core.api.Assertions.assertThat;
+
+/**
+ * Tests for {@link DefaultChatOptions}.
+ *
+ * @author Alexandros Pappas
+ */
+class DefaultChatOptionsTests {
+
+	@Test
+	void testBuilderWithAllFields() {
+		DefaultChatOptions options = DefaultChatOptions.builder()
+			.model("test-model")
+			.frequencyPenalty(0.5)
+			.maxTokens(100)
+			.presencePenalty(0.6)
+			.stopSequences(List.of("stop1", "stop2"))
+			.temperature(0.7)
+			.topK(50)
+			.topP(0.8)
+			.build();
+
+		assertThat(options)
+			.extracting("model", "frequencyPenalty", "maxTokens", "presencePenalty", "stopSequences", "temperature",
+					"topK", "topP")
+			.containsExactly("test-model", 0.5, 100, 0.6, List.of("stop1", "stop2"), 0.7, 50, 0.8);
+	}
+
+	@Test
+	void testCopy() {
+		DefaultChatOptions original = DefaultChatOptions.builder()
+			.model("test-model")
+			.frequencyPenalty(0.5)
+			.maxTokens(100)
+			.presencePenalty(0.6)
+			.stopSequences(List.of("stop1", "stop2"))
+			.temperature(0.7)
+			.topK(50)
+			.topP(0.8)
+			.build();
+
+		DefaultChatOptions copied = original.copy();
+
+		assertThat(copied).isNotSameAs(original).isEqualTo(original);
+		// Ensure deep copy
+		assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences());
+	}
+
+	@Test
+	void testSetters() {
+		DefaultChatOptions options = new DefaultChatOptions();
+		options.setModel("test-model");
+		options.setFrequencyPenalty(0.5);
+		options.setMaxTokens(100);
+		options.setPresencePenalty(0.6);
+		options.setStopSequences(List.of("stop1", "stop2"));
+		options.setTemperature(0.7);
+		options.setTopK(50);
+		options.setTopP(0.8);
+
+		assertThat(options.getModel()).isEqualTo("test-model");
+		assertThat(options.getFrequencyPenalty()).isEqualTo(0.5);
+		assertThat(options.getMaxTokens()).isEqualTo(100);
+		assertThat(options.getPresencePenalty()).isEqualTo(0.6);
+		assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2"));
+		assertThat(options.getTemperature()).isEqualTo(0.7);
+		assertThat(options.getTopK()).isEqualTo(50);
+		assertThat(options.getTopP()).isEqualTo(0.8);
+	}
+
+	@Test
+	void testDefaultValues() {
+		DefaultChatOptions options = new DefaultChatOptions();
+		assertThat(options.getModel()).isNull();
+		assertThat(options.getFrequencyPenalty()).isNull();
+		assertThat(options.getMaxTokens()).isNull();
+		assertThat(options.getPresencePenalty()).isNull();
+		assertThat(options.getStopSequences()).isNull();
+		assertThat(options.getTemperature()).isNull();
+		assertThat(options.getTopK()).isNull();
+		assertThat(options.getTopP()).isNull();
+	}
+
+}
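
Reviewer note: to make the intent of the new self-referential builder hierarchy easier to see end to end, here is a minimal sketch of how a provider-specific options class could plug into AbstractChatOptions.Builder, mirroring what DefaultChatOptions does in this diff. The class name MyProviderChatOptions is hypothetical, and the sketch assumes (for brevity only) that it sits in the org.springframework.ai.chat.prompt package so it can assign the protected fields inherited from AbstractChatOptions directly instead of declaring its own setters.

package org.springframework.ai.chat.prompt;

import java.util.ArrayList;

// Hypothetical subclass, shown only to illustrate the Builder<T, B> pattern.
public class MyProviderChatOptions extends AbstractChatOptions {

	public static Builder builder() {
		return new Builder();
	}

	@Override
	@SuppressWarnings("unchecked") // same unchecked-override situation as DefaultChatOptions.copy()
	public MyProviderChatOptions copy() {
		return builder().model(this.getModel())
			.frequencyPenalty(this.getFrequencyPenalty())
			.maxTokens(this.getMaxTokens())
			.presencePenalty(this.getPresencePenalty())
			.stopSequences(this.getStopSequences() == null ? null : new ArrayList<>(this.getStopSequences()))
			.temperature(this.getTemperature())
			.topK(this.getTopK())
			.topP(this.getTopP())
			.build();
	}

	public static class Builder extends AbstractChatOptions.Builder<MyProviderChatOptions, Builder> {

		public Builder() {
			super(new MyProviderChatOptions());
		}

		@Override
		protected Builder self() {
			// B resolves to this concrete Builder, so chained calls keep their type.
			return this;
		}

		@Override
		public MyProviderChatOptions build() {
			MyProviderChatOptions built = new MyProviderChatOptions();
			built.model = this.model;
			built.frequencyPenalty = this.frequencyPenalty;
			built.maxTokens = this.maxTokens;
			built.presencePenalty = this.presencePenalty;
			built.stopSequences = this.stopSequences;
			built.temperature = this.temperature;
			built.topK = this.topK;
			built.topP = this.topP;
			return built;
		}

	}

}

With this shape, a caller writes MyProviderChatOptions.builder().model("m").temperature(0.7).build() and gets the concrete options type back without casting, which is the main benefit of the self() indirection over the previous non-generic DefaultChatOptionsBuilder.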