From 1815681f4c0f0c81ca71ae35396345ad49d4a8c8 Mon Sep 17 00:00:00 2001
From: Alexandros Pappas
Date: Fri, 14 Feb 2025 13:19:53 +0100
Subject: [PATCH] feat: update all *ChatOptions* Classes and add Unit Tests

Signed-off-by: Alexandros Pappas
---
 .../ai/anthropic/AnthropicChatOptions.java    |  59 +++--
 .../azure/openai/AzureOpenAiChatOptions.java  |  95 +++++---
 .../ai/minimax/MiniMaxChatOptions.java        | 224 ++++++------------
 .../ai/mistralai/MistralAiChatOptions.java    |  49 +++-
 .../ai/moonshot/MoonshotChatOptions.java      | 191 +++++++--------
 .../ai/oci/cohere/OCICohereChatOptions.java   | 147 ++++++++----
 .../oci/cohere/OCICohereChatOptionsTests.java |   2 +-
 .../ai/openai/OpenAiChatOptions.java          |  60 ++++-
 .../ai/qianfan/QianFanChatOptions.java        |  14 +-
 .../ai/watsonx/WatsonxAiChatOptions.java      | 100 +++++---
 .../ai/zhipuai/ZhiPuAiChatOptions.java        | 141 ++---------
 .../ai/zhipuai/ZhiPuAiChatOptionsTests.java   | 131 ++++++++++
 .../ai/chat/prompt/AbstractChatOptions.java   | 223 -----------------
 .../ai/chat/prompt/ChatOptions.java           |   6 +-
 .../ai/chat/prompt/DefaultChatOptions.java    | 144 +++++------
 .../prompt/DefaultChatOptionsBuilder.java     |  80 +++++++
 .../chat/prompt/ChatOptionsBuilderTests.java  |  16 +-
 .../chat/prompt/DefaultChatOptionsTests.java  |   4 +-
 18 files changed, 874 insertions(+), 812 deletions(-)
 create mode 100644 models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java
 delete mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AbstractChatOptions.java
 create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java

diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java
index d17face00bd..ad9e31eb01f 100644
--- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java
+++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java
@@ -32,12 +32,10 @@
 import org.springframework.ai.anthropic.api.AnthropicApi;
 import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
-import org.springframework.ai.chat.prompt.AbstractChatOptions;
 import org.springframework.ai.model.function.FunctionCallback;
 import org.springframework.ai.model.tool.ToolCallingChatOptions;
 import org.springframework.ai.tool.ToolCallback;
 import org.springframework.lang.Nullable;
-import org.springframework.ai.model.tool.ToolCallingChatOptions;
 import org.springframework.util.Assert;
 
 /**
@@ -49,11 +47,16 @@
  * @since 1.0.0
  */
 @JsonInclude(Include.NON_NULL)
-public class AnthropicChatOptions extends AbstractChatOptions implements ToolCallingChatOptions {
+public class AnthropicChatOptions implements ToolCallingChatOptions {
 
 	// @formatter:off
-
+	private @JsonProperty("model") String model;
+	private @JsonProperty("max_tokens") Integer maxTokens;
 	private @JsonProperty("metadata") ChatCompletionRequest.Metadata metadata;
+	private @JsonProperty("stop_sequences") List<String> stopSequences;
+	private @JsonProperty("temperature") Double temperature;
+	private @JsonProperty("top_p") Double topP;
+	private @JsonProperty("top_k") Integer topK;
 
 	/**
 	 * Collection of {@link ToolCallback}s to be used for tool calling in the chat
@@ -100,10 +103,20 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
 			.build();
 	}
 
+	@Override
+	public String getModel() {
+		return this.model;
+	}
+
 	public void setModel(String model) {
 		this.model = model;
 	}
 
+	@Override
+	public Integer getMaxTokens() {
+		return this.maxTokens;
+	}
+
 	public void setMaxTokens(Integer maxTokens) {
 		this.maxTokens = maxTokens;
 	}
@@ -116,18 +129,38 @@ public void setMetadata(ChatCompletionRequest.Metadata metadata) {
 		this.metadata = metadata;
 	}
 
+	@Override
+	public List<String> getStopSequences() {
+		return this.stopSequences;
+	}
+
 	public void setStopSequences(List<String> stopSequences) {
 		this.stopSequences = stopSequences;
 	}
 
+	@Override
+	public Double getTemperature() {
+		return this.temperature;
+	}
+
 	public void setTemperature(Double temperature) {
 		this.temperature = temperature;
 	}
 
+	@Override
+	public Double getTopP() {
+		return this.topP;
+	}
+
 	public void setTopP(Double topP) {
 		this.topP = topP;
 	}
 
+	@Override
+	public Integer getTopK() {
+		return this.topK;
+	}
+
 	public void setTopK(Integer topK) {
 		this.topK = topK;
 	}
@@ -240,6 +273,7 @@ public void setToolContext(Map<String, Object> toolContext) {
 	}
 
 	@Override
+	@SuppressWarnings("unchecked")
 	public AnthropicChatOptions copy() {
 		return fromOptions(this);
 	}
@@ -264,21 +298,8 @@ public boolean equals(Object o) {
 
 	@Override
 	public int hashCode() {
-		final int prime = 31;
-		int result = 1;
-		result = prime * result + (this.model != null ? this.model.hashCode() : 0);
-		result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0);
-		result = prime * result + (this.metadata != null ? this.metadata.hashCode() : 0);
-		result = prime * result + (this.stopSequences != null ? this.stopSequences.hashCode() : 0);
-		result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0);
-		result = prime * result + (this.topP != null ? this.topP.hashCode() : 0);
-		result = prime * result + (this.topK != null ? this.topK.hashCode() : 0);
-		result = prime * result + (this.toolCallbacks != null ? this.toolCallbacks.hashCode() : 0);
-		result = prime * result + (this.toolNames != null ? this.toolNames.hashCode() : 0);
-		result = prime * result
-				+ (this.internalToolExecutionEnabled != null ? this.internalToolExecutionEnabled.hashCode() : 0);
-		result = prime * result + (this.toolContext != null ? this.toolContext.hashCode() : 0);
-		return result;
+		return Objects.hash(model, maxTokens, metadata, stopSequences, temperature, topP, topK, toolCallbacks,
+				toolNames, internalToolExecutionEnabled, toolContext);
 	}
 
 	public static class Builder {
diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java
index 0c6ab51d0fa..e516034d316 100644
--- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java
+++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java
@@ -1,5 +1,5 @@
 /*
- * Copyright 2023-2024 the original author or authors.
+ * Copyright 2023-2025 the original author or authors.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -32,7 +32,6 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.tool.ToolCallback; @@ -51,7 +50,34 @@ * @author Alexandros Pappas */ @JsonInclude(Include.NON_NULL) -public class AzureOpenAiChatOptions extends AbstractChatOptions implements ToolCallingChatOptions { +public class AzureOpenAiChatOptions implements ToolCallingChatOptions { + + /** + * The maximum number of tokens to generate. + */ + @JsonProperty("max_tokens") + private Integer maxTokens; + + /** + * The sampling temperature to use that controls the apparent creativity of generated + * completions. Higher values will make output more random while lower values will + * make results more focused and deterministic. It is not recommended to modify + * temperature and top_p for the same completions request as the interaction of these + * two settings is difficult to predict. + */ + @JsonProperty("temperature") + private Double temperature; + + /** + * An alternative to sampling with temperature called nucleus sampling. This value + * causes the model to consider the results of tokens with the provided probability + * mass. As an example, a value of 0.15 will cause only the tokens comprising the top + * 15% of probability mass to be considered. It is not recommended to modify + * temperature and top_p for the same completions request as the interaction of these + * two settings is difficult to predict. + */ + @JsonProperty("top_p") + private Double topP; /** * A map between GPT token IDs and bias scores that influences the probability of @@ -85,6 +111,24 @@ public class AzureOpenAiChatOptions extends AbstractChatOptions implements ToolC @JsonProperty("stop") private List stop; + /** + * A value that influences the probability of generated tokens appearing based on + * their existing presence in generated text. Positive values will make tokens less + * likely to appear when they already exist and increase the model's likelihood to + * output new topics. + */ + @JsonProperty("presence_penalty") + private Double presencePenalty; + + /** + * A value that influences the probability of generated tokens appearing based on + * their cumulative frequency in generated text. Positive values will make tokens less + * likely to appear as their frequency increases and decrease the likelihood of the + * model repeating the same statements verbatim. + */ + @JsonProperty("frequency_penalty") + private Double frequencyPenalty; + /** * The deployment name as defined in Azure Open AI Studio when creating a deployment * backed by an Azure OpenAI base model. 
@@ -279,10 +323,20 @@ public void setStop(List stop) { this.stop = stop; } + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + public void setPresencePenalty(Double presencePenalty) { this.presencePenalty = presencePenalty; } + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + public void setFrequencyPenalty(Double frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } @@ -306,6 +360,11 @@ public void setDeploymentName(String deploymentName) { this.deploymentName = deploymentName; } + @Override + public Double getTemperature() { + return this.temperature; + } + public void setTemperature(Double temperature) { this.temperature = temperature; } @@ -422,7 +481,7 @@ public void setStreamOptions(ChatCompletionStreamOptions streamOptions) { } @Override - @SuppressWarnings("") + @SuppressWarnings("unchecked") public AzureOpenAiChatOptions copy() { return fromOptions(this); } @@ -454,30 +513,10 @@ public boolean equals(Object o) { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + (this.logitBias != null ? this.logitBias.hashCode() : 0); - result = prime * result + (this.user != null ? this.user.hashCode() : 0); - result = prime * result + (this.n != null ? this.n.hashCode() : 0); - result = prime * result + (this.stop != null ? this.stop.hashCode() : 0); - result = prime * result + (this.deploymentName != null ? this.deploymentName.hashCode() : 0); - result = prime * result + (this.responseFormat != null ? this.responseFormat.hashCode() : 0); - result = prime * result + (this.toolCallbacks != null ? this.toolCallbacks.hashCode() : 0); - result = prime * result + (this.toolNames != null ? this.toolNames.hashCode() : 0); - result = prime * result - + (this.internalToolExecutionEnabled != null ? this.internalToolExecutionEnabled.hashCode() : 0); - result = prime * result + (this.seed != null ? this.seed.hashCode() : 0); - result = prime * result + (this.logprobs != null ? this.logprobs.hashCode() : 0); - result = prime * result + (this.topLogProbs != null ? this.topLogProbs.hashCode() : 0); - result = prime * result + (this.enhancements != null ? this.enhancements.hashCode() : 0); - result = prime * result + (this.streamOptions != null ? this.streamOptions.hashCode() : 0); - result = prime * result + (this.toolContext != null ? this.toolContext.hashCode() : 0); - result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0); - result = prime * result + (this.frequencyPenalty != null ? this.frequencyPenalty.hashCode() : 0); - result = prime * result + (this.presencePenalty != null ? this.presencePenalty.hashCode() : 0); - result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0); - result = prime * result + (this.topP != null ? 
this.topP.hashCode() : 0); - return result; + return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat, + this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs, + this.topLogProbs, this.enhancements, this.streamOptions, this.toolContext, this.maxTokens, + this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP); } public static class Builder { diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java index 8e16213fdd2..491da85b457 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatOptions.java @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -28,7 +29,6 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.minimax.api.MiniMaxApi; import org.springframework.ai.model.function.FunctionCallback; @@ -49,14 +49,33 @@ * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) -public class MiniMaxChatOptions extends AbstractChatOptions implements FunctionCallingOptions { +public class MiniMaxChatOptions implements FunctionCallingOptions { // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + private @JsonProperty("frequency_penalty") Double frequencyPenalty; + /** + * The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. + */ + private @JsonProperty("max_tokens") Integer maxTokens; /** * How many chat completion choices to generate for each input message. Note that you will be charged based * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. */ private @JsonProperty("n") Integer n; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + private @JsonProperty("presence_penalty") Double presencePenalty; /** * An object specifying the format that the model must output. Setting to { "type": * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. @@ -73,7 +92,18 @@ public class MiniMaxChatOptions extends AbstractChatOptions implements FunctionC * Up to 4 sequences where the API will stop generating further tokens. */ private @JsonProperty("stop") List stop; - + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. 
+ */ + private @JsonProperty("temperature") Double temperature; + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. + */ + private @JsonProperty("top_p") Double topP; /** * Mask the text information in the output that is easy to involve privacy issues, * including but not limited to email, domain name, link, ID number, home address, etc. @@ -148,14 +178,29 @@ public static MiniMaxChatOptions fromOptions(MiniMaxChatOptions fromOptions) { .build(); } + @Override + public String getModel() { + return this.model; + } + public void setModel(String model) { this.model = model; } + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + public void setFrequencyPenalty(Double frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } @@ -168,6 +213,11 @@ public void setN(Integer n) { this.n = n; } + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + public void setPresencePenalty(Double presencePenalty) { this.presencePenalty = presencePenalty; } @@ -207,10 +257,20 @@ public void setStop(List stop) { this.stop = stop; } + @Override + public Double getTemperature() { + return this.temperature; + } + public void setTemperature(Double temperature) { this.temperature = temperature; } + @Override + public Double getTopP() { + return this.topP; + } + public void setTopP(Double topP) { this.topP = topP; } @@ -285,24 +345,8 @@ public void setToolContext(Map toolContext) { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.n == null) ? 0 : this.n.hashCode()); - result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); - result = prime * result + ((this.responseFormat == null) ? 0 : this.responseFormat.hashCode()); - result = prime * result + ((this.seed == null) ? 0 : this.seed.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.maskSensitiveInfo == null) ? 0 : this.maskSensitiveInfo.hashCode()); - result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); - result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); - result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); - result = prime * result + ((this.toolContext == null) ? 
0 : this.toolContext.hashCode()); - return result; + return Objects.hash(model, frequencyPenalty, maxTokens, n, presencePenalty, responseFormat, seed, stop, + temperature, topP, maskSensitiveInfo, tools, toolChoice, proxyToolCalls, toolContext); } @Override @@ -310,139 +354,25 @@ public boolean equals(Object obj) { if (this == obj) { return true; } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { + if (obj == null || getClass() != obj.getClass()) { return false; } MiniMaxChatOptions other = (MiniMaxChatOptions) obj; - if (this.model == null) { - if (other.model != null) { - return false; - } - } - else if (!this.model.equals(other.model)) { - return false; - } - if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) { - return false; - } - } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { - return false; - } - if (this.maxTokens == null) { - if (other.maxTokens != null) { - return false; - } - } - else if (!this.maxTokens.equals(other.maxTokens)) { - return false; - } - if (this.n == null) { - if (other.n != null) { - return false; - } - } - else if (!this.n.equals(other.n)) { - return false; - } - if (this.presencePenalty == null) { - if (other.presencePenalty != null) { - return false; - } - } - else if (!this.presencePenalty.equals(other.presencePenalty)) { - return false; - } - if (this.responseFormat == null) { - if (other.responseFormat != null) { - return false; - } - } - else if (!this.responseFormat.equals(other.responseFormat)) { - return false; - } - if (this.seed == null) { - if (other.seed != null) { - return false; - } - } - else if (!this.seed.equals(other.seed)) { - return false; - } - if (this.stop == null) { - if (other.stop != null) { - return false; - } - } - else if (!this.stop.equals(other.stop)) { - return false; - } - if (this.temperature == null) { - if (other.temperature != null) { - return false; - } - } - else if (!this.temperature.equals(other.temperature)) { - return false; - } - if (this.topP == null) { - if (other.topP != null) { - return false; - } - } - else if (!this.topP.equals(other.topP)) { - return false; - } - if (this.maskSensitiveInfo == null) { - if (other.maskSensitiveInfo != null) { - return false; - } - } - else if (!this.maskSensitiveInfo.equals(other.maskSensitiveInfo)) { - return false; - } - if (this.tools == null) { - if (other.tools != null) { - return false; - } - } - else if (!this.tools.equals(other.tools)) { - return false; - } - if (this.toolChoice == null) { - if (other.toolChoice != null) { - return false; - } - } - else if (!this.toolChoice.equals(other.toolChoice)) { - return false; - } - if (this.proxyToolCalls == null) { - if (other.proxyToolCalls != null) { - return false; - } - } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { - return false; - } - - if (this.toolContext == null) { - if (other.toolContext != null) { - return false; - } - } - else if (!this.toolContext.equals(other.toolContext)) { - return false; - } - return true; + return Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) + && Objects.equals(this.maxTokens, other.maxTokens) && Objects.equals(this.n, other.n) + && Objects.equals(this.presencePenalty, other.presencePenalty) + && Objects.equals(this.responseFormat, other.responseFormat) && Objects.equals(this.seed, other.seed) + && Objects.equals(this.stop, other.stop) && Objects.equals(this.temperature, other.temperature) + && Objects.equals(this.topP, other.topP) + && 
Objects.equals(this.maskSensitiveInfo, other.maskSensitiveInfo) + && Objects.equals(this.tools, other.tools) && Objects.equals(this.toolChoice, other.toolChoice) + && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.toolContext, other.toolContext); } @Override + @SuppressWarnings("unchecked") public MiniMaxChatOptions copy() { return fromOptions(this); } diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java index 3882a1faf76..987c2b35182 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatOptions.java @@ -29,7 +29,6 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ResponseFormat; import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest.ToolChoice; @@ -50,7 +49,33 @@ * @since 0.8.1 */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class MistralAiChatOptions extends AbstractChatOptions implements ToolCallingChatOptions { +public class MistralAiChatOptions implements ToolCallingChatOptions { + + /** + * ID of the model to use + */ + private @JsonProperty("model") String model; + + /** + * What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will + * make the output more random, while lower values like 0.2 will make it more focused + * and deterministic. We generally recommend altering this or top_p but not both. + */ + private @JsonProperty("temperature") Double temperature; + + /** + * Nucleus sampling, where the model considers the results of the tokens with top_p + * probability mass. So 0.1 means only the tokens comprising the top 10% probability + * mass are considered. We generally recommend altering this or temperature but not + * both. + */ + private @JsonProperty("top_p") Double topP; + + /** + * The maximum number of tokens to generate in the completion. The token count of your + * prompt plus max_tokens cannot exceed the model's context length. + */ + private @JsonProperty("max_tokens") Integer maxTokens; /** * Whether to inject a safety prompt before all conversations. 
@@ -135,10 +160,20 @@ public static MistralAiChatOptions fromOptions(MistralAiChatOptions fromOptions) .build(); } + @Override + public String getModel() { + return this.model; + } + public void setModel(String model) { this.model = model; } + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } @@ -202,10 +237,20 @@ public void setToolChoice(ToolChoice toolChoice) { this.toolChoice = toolChoice; } + @Override + public Double getTemperature() { + return this.temperature; + } + public void setTemperature(Double temperature) { this.temperature = temperature; } + @Override + public Double getTopP() { + return this.topP; + } + public void setTopP(Double topP) { this.topP = topP; } diff --git a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java index 8806e03e9e6..a707a1c08d0 100644 --- a/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java +++ b/models/spring-ai-moonshot/src/main/java/org/springframework/ai/moonshot/MoonshotChatOptions.java @@ -21,13 +21,13 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallingOptions; import org.springframework.ai.moonshot.api.MoonshotApi; @@ -41,7 +41,33 @@ * @author Alexandros Pappas */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class MoonshotChatOptions extends AbstractChatOptions implements FunctionCallingOptions { +public class MoonshotChatOptions implements FunctionCallingOptions { + + /** + * ID of the model to use + */ + private @JsonProperty("model") String model; + + /** + * The maximum number of tokens to generate in the chat completion. The total length + * of input tokens and generated tokens is limited by the model's context length. + */ + private @JsonProperty("max_tokens") Integer maxTokens; + + /** + * What sampling temperature to use, between 0.0 and 1.0. Higher values like 0.8 will + * make the output more random, while lower values like 0.2 will make it more focused + * and deterministic. We generally recommend altering this or top_p but not both. + */ + private @JsonProperty("temperature") Double temperature; + + /** + * An alternative to sampling with temperature, called nucleus sampling, where the + * model considers the results of the tokens with top_p probability mass. So 0.1 means + * only the tokens comprising the top 10% probability mass are considered. We + * generally recommend altering this or temperature but not both. + */ + private @JsonProperty("top_p") Double topP; /** * How many chat completion choices to generate for each input message. Note that you @@ -50,6 +76,20 @@ public class MoonshotChatOptions extends AbstractChatOptions implements Function */ private @JsonProperty("n") Integer n; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether + * they appear in the text so far, increasing the model's likelihood to talk about new + * topics. 
+ */ + private @JsonProperty("presence_penalty") Double presencePenalty; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their + * existing frequency in the text so far, decreasing the model's likelihood to repeat + * the same line verbatim. + */ + private @JsonProperty("frequency_penalty") Double frequencyPenalty; + /** * Up to 5 sequences where the API will stop generating further tokens. */ @@ -128,14 +168,29 @@ public void setFunctions(Set functionNames) { this.functions = functionNames; } + @Override + public String getModel() { + return this.model; + } + public void setModel(String model) { this.model = model; } + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + public void setFrequencyPenalty(Double frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } @@ -148,6 +203,11 @@ public void setN(Integer n) { this.n = n; } + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + public void setPresencePenalty(Double presencePenalty) { this.presencePenalty = presencePenalty; } @@ -171,10 +231,20 @@ public void setStop(List stop) { this.stop = stop; } + @Override + public Double getTemperature() { + return this.temperature; + } + public void setTemperature(Double temperature) { this.temperature = temperature; } + @Override + public Double getTopP() { + return this.topP; + } + public void setTopP(Double topP) { this.topP = topP; } @@ -213,6 +283,7 @@ public void setToolContext(Map toolContext) { } @Override + @SuppressWarnings("unchecked") public MoonshotChatOptions copy() { return builder().model(this.model) .maxTokens(this.maxTokens) @@ -234,117 +305,25 @@ public MoonshotChatOptions copy() { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.n == null) ? 0 : this.n.hashCode()); - result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); - result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); - result = prime * result + ((this.toolContext == null) ? 
0 : this.toolContext.hashCode()); - return result; + return Objects.hash(model, frequencyPenalty, maxTokens, n, presencePenalty, stop, temperature, topP, user, + proxyToolCalls, toolContext); } @Override public boolean equals(Object obj) { - if (this == obj) { + if (this == obj) return true; - } - if (obj == null) { + if (obj == null || getClass() != obj.getClass()) return false; - } - if (getClass() != obj.getClass()) { - return false; - } + MoonshotChatOptions other = (MoonshotChatOptions) obj; - if (this.model == null) { - if (other.model != null) { - return false; - } - } - else if (!this.model.equals(other.model)) { - return false; - } - if (this.frequencyPenalty == null) { - if (other.frequencyPenalty != null) { - return false; - } - } - else if (!this.frequencyPenalty.equals(other.frequencyPenalty)) { - return false; - } - if (this.maxTokens == null) { - if (other.maxTokens != null) { - return false; - } - } - else if (!this.maxTokens.equals(other.maxTokens)) { - return false; - } - if (this.n == null) { - if (other.n != null) { - return false; - } - } - else if (!this.n.equals(other.n)) { - return false; - } - if (this.presencePenalty == null) { - if (other.presencePenalty != null) { - return false; - } - } - else if (!this.presencePenalty.equals(other.presencePenalty)) { - return false; - } - if (this.stop == null) { - if (other.stop != null) { - return false; - } - } - else if (!this.stop.equals(other.stop)) { - return false; - } - if (this.temperature == null) { - if (other.temperature != null) { - return false; - } - } - else if (!this.temperature.equals(other.temperature)) { - return false; - } - if (this.topP == null) { - if (other.topP != null) { - return false; - } - } - else if (!this.topP.equals(other.topP)) { - return false; - } - if (this.user == null) { - return other.user == null; - } - else if (!this.user.equals(other.user)) { - return false; - } - if (this.proxyToolCalls == null) { - return other.proxyToolCalls == null; - } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { - return false; - } - if (this.toolContext == null) { - return other.toolContext == null; - } - else if (!this.toolContext.equals(other.toolContext)) { - return false; - } - return true; + + return Objects.equals(this.model, other.model) && Objects.equals(this.frequencyPenalty, other.frequencyPenalty) + && Objects.equals(this.maxTokens, other.maxTokens) && Objects.equals(this.n, other.n) + && Objects.equals(this.presencePenalty, other.presencePenalty) && Objects.equals(this.stop, other.stop) + && Objects.equals(this.temperature, other.temperature) && Objects.equals(this.topP, other.topP) + && Objects.equals(this.user, other.user) && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.toolContext, other.toolContext); } public static class Builder { diff --git a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java index 1ae88b3eac8..1516d88c252 100644 --- a/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java +++ b/models/spring-ai-oci-genai/src/main/java/org/springframework/ai/oci/cohere/OCICohereChatOptions.java @@ -18,12 +18,12 @@ import java.util.ArrayList; import java.util.List; +import java.util.Objects; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; import 
com.oracle.bmc.generativeaiinference.model.CohereTool; -import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.chat.prompt.ChatOptions; /** @@ -34,7 +34,16 @@ * @author Alexandros Pappas */ @JsonInclude(JsonInclude.Include.NON_NULL) -public class OCICohereChatOptions extends AbstractChatOptions implements ChatOptions { +public class OCICohereChatOptions implements ChatOptions { + + @JsonProperty("model") + private String model; + + /** + * The maximum number of tokens to generate per request. + */ + @JsonProperty("maxTokens") + private Integer maxTokens; /** * The OCI Compartment to run chat requests in. @@ -54,6 +63,43 @@ public class OCICohereChatOptions extends AbstractChatOptions implements ChatOpt @JsonProperty("preambleOverride") private String preambleOverride; + /** + * The sample temperature, where higher values are more random, and lower values are + * more deterministic. + */ + @JsonProperty("temperature") + private Double temperature; + + /** + * The Top P parameter modifies the probability of tokens sampled. E.g., a value of + * 0.25 means only tokens from the top 25% probability mass will be considered. + */ + @JsonProperty("topP") + private Double topP; + + /** + * The Top K parameter limits the number of potential tokens considered at each step + * of text generation. E.g., a value of 5 means only the top 5 most probable tokens + * will be considered during each step of text generation. + */ + @JsonProperty("topK") + private Integer topK; + + /** + * The frequency penalty assigns a penalty to repeated tokens depending on how many + * times it has already appeared in the prompt or output. Higher values will reduce + * repeated tokens and outputs will be more random. + */ + @JsonProperty("frequencyPenalty") + private Double frequencyPenalty; + + /** + * The presence penalty assigns a penalty to each token when it appears in the output + * to encourage generating outputs with tokens that haven't been used. + */ + @JsonProperty("presencePenalty") + private Double presencePenalty; + /** * A collection of textual sequences that will end completions generation. */ @@ -173,17 +219,58 @@ public void setTools(List tools) { * ChatModel overrides. */ + @Override + public String getModel() { + return this.model; + } + + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + @Override public List getStopSequences() { return this.stop; } + @Override + public Double getTemperature() { + return this.temperature; + } + + @Override + public Integer getTopK() { + return this.topK; + } + + @Override + public Double getTopP() { + return this.topP; + } + @Override @SuppressWarnings("unchecked") - public OCICohereChatOptions copy() { + public ChatOptions copy() { return fromOptions(this); } + @Override + public int hashCode() { + return Objects.hash(model, maxTokens, compartment, servingMode, preambleOverride, temperature, topP, topK, stop, + frequencyPenalty, presencePenalty, documents, tools); + } + @Override public boolean equals(Object o) { if (this == o) @@ -193,51 +280,15 @@ public boolean equals(Object o) { OCICohereChatOptions that = (OCICohereChatOptions) o; - if (model != null ? !model.equals(that.model) : that.model != null) - return false; - if (maxTokens != null ? 
!maxTokens.equals(that.maxTokens) : that.maxTokens != null) - return false; - if (compartment != null ? !compartment.equals(that.compartment) : that.compartment != null) - return false; - if (servingMode != null ? !servingMode.equals(that.servingMode) : that.servingMode != null) - return false; - if (preambleOverride != null ? !preambleOverride.equals(that.preambleOverride) : that.preambleOverride != null) - return false; - if (temperature != null ? !temperature.equals(that.temperature) : that.temperature != null) - return false; - if (topP != null ? !topP.equals(that.topP) : that.topP != null) - return false; - if (topK != null ? !topK.equals(that.topK) : that.topK != null) - return false; - if (stop != null ? !stop.equals(that.stop) : that.stop != null) - return false; - if (frequencyPenalty != null ? !frequencyPenalty.equals(that.frequencyPenalty) : that.frequencyPenalty != null) - return false; - if (presencePenalty != null ? !presencePenalty.equals(that.presencePenalty) : that.presencePenalty != null) - return false; - if (documents != null ? !documents.equals(that.documents) : that.documents != null) - return false; - return tools != null ? tools.equals(that.tools) : that.tools == null; - } - - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.compartment == null) ? 0 : this.compartment.hashCode()); - result = prime * result + ((this.servingMode == null) ? 0 : this.servingMode.hashCode()); - result = prime * result + ((this.preambleOverride == null) ? 0 : this.preambleOverride.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.topK == null) ? 0 : this.topK.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); - result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); - result = prime * result + ((this.documents == null) ? 0 : this.documents.hashCode()); - result = prime * result + ((this.tools == null) ? 
0 : this.tools.hashCode()); - return result; + return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.compartment, that.compartment) + && Objects.equals(this.servingMode, that.servingMode) + && Objects.equals(this.preambleOverride, that.preambleOverride) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP) + && Objects.equals(this.topK, that.topK) && Objects.equals(this.stop, that.stop) + && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.documents, that.documents) && Objects.equals(this.tools, that.tools); } public static class Builder { diff --git a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java index ea57e37608c..4359d920bb5 100644 --- a/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java +++ b/models/spring-ai-oci-genai/src/test/java/org/springframework/ai/oci/cohere/OCICohereChatOptionsTests.java @@ -73,7 +73,7 @@ void testCopy() { .tools(List.of(new CohereTool("test-tool", "test-context", Map.of()))) .build(); - OCICohereChatOptions copied = original.copy(); + OCICohereChatOptions copied = (OCICohereChatOptions) original.copy(); assertThat(copied).isNotSameAs(original).isEqualTo(original); // Ensure deep copy diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 0ee45eb1d7c..b4d4b9b7220 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -30,7 +30,6 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.tool.ToolCallingChatOptions; @@ -54,9 +53,18 @@ * @since 0.8.0 */ @JsonInclude(Include.NON_NULL) -public class OpenAiChatOptions extends AbstractChatOptions implements ToolCallingChatOptions { +public class OpenAiChatOptions implements ToolCallingChatOptions { // @formatter:off + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + private @JsonProperty("frequency_penalty") Double frequencyPenalty; /** * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. @@ -75,6 +83,11 @@ public class OpenAiChatOptions extends AbstractChatOptions implements ToolCallin * each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used. */ private @JsonProperty("top_logprobs") Integer topLogprobs; + /** + * The maximum number of tokens to generate in the chat completion. 
The total length of input + * tokens and generated tokens is limited by the model's context length. + */ + private @JsonProperty("max_tokens") Integer maxTokens; /** * An upper bound for the number of tokens that can be generated for a completion, * including visible output tokens and reasoning tokens. @@ -105,6 +118,12 @@ public class OpenAiChatOptions extends AbstractChatOptions implements ToolCallin */ private @JsonProperty("audio") AudioParameters outputAudio; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + private @JsonProperty("presence_penalty") Double presencePenalty; /** * An object specifying the format that the model must output. Setting to { "type": * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. @@ -125,6 +144,18 @@ public class OpenAiChatOptions extends AbstractChatOptions implements ToolCallin * Up to 4 sequences where the API will stop generating further tokens. */ private @JsonProperty("stop") List stop; + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + */ + private @JsonProperty("temperature") Double temperature; + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. + */ + private @JsonProperty("top_p") Double topP; /** * A list of tools the model may call. Currently, only functions are supported as a tool. Use this to * provide a list of functions the model may generate JSON inputs for. @@ -239,10 +270,20 @@ public void setStreamUsage(Boolean enableStreamUsage) { this.streamOptions = (enableStreamUsage) ? 
StreamOptions.INCLUDE_USAGE : null; } + @Override + public String getModel() { + return this.model; + } + public void setModel(String model) { this.model = model; } + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + public void setFrequencyPenalty(Double frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } @@ -271,6 +312,11 @@ public void setTopLogprobs(Integer topLogprobs) { this.topLogprobs = topLogprobs; } + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } @@ -307,6 +353,11 @@ public void setOutputAudio(AudioParameters audio) { this.outputAudio = audio; } + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + public void setPresencePenalty(Double presencePenalty) { this.presencePenalty = presencePenalty; } @@ -363,6 +414,11 @@ public void setTemperature(Double temperature) { this.temperature = temperature; } + @Override + public Double getTopP() { + return this.topP; + } + public void setTopP(Double topP) { this.topP = topP; } diff --git a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java index a6b9ce7d31f..22484aa9a3a 100644 --- a/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java +++ b/models/spring-ai-qianfan/src/main/java/org/springframework/ai/qianfan/QianFanChatOptions.java @@ -25,7 +25,6 @@ import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.qianfan.api.QianFanApi; @@ -208,17 +207,8 @@ public boolean equals(Object o) { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + (this.model != null ? this.model.hashCode() : 0); - result = prime * result + (this.frequencyPenalty != null ? this.frequencyPenalty.hashCode() : 0); - result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0); - result = prime * result + (this.presencePenalty != null ? this.presencePenalty.hashCode() : 0); - result = prime * result + (this.responseFormat != null ? this.responseFormat.hashCode() : 0); - result = prime * result + (this.stop != null ? this.stop.hashCode() : 0); - result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0); - result = prime * result + (this.topP != null ? 
this.topP.hashCode() : 0); - return result; + return Objects.hash(this.model, this.frequencyPenalty, this.maxTokens, this.presencePenalty, + this.responseFormat, this.stop, this.temperature, this.topP); } @Override diff --git a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java index bce72b63c5e..440dc718ed5 100644 --- a/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java +++ b/models/spring-ai-watsonx-ai/src/main/java/org/springframework/ai/watsonx/WatsonxAiChatOptions.java @@ -20,6 +20,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonAnyGetter; @@ -30,7 +31,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import org.springframework.ai.chat.prompt.AbstractChatOptions; import org.springframework.ai.chat.prompt.ChatOptions; /** @@ -47,11 +47,34 @@ */ // @formatter:off -public class WatsonxAiChatOptions extends AbstractChatOptions implements ChatOptions { +public class WatsonxAiChatOptions implements ChatOptions { @JsonIgnore private final ObjectMapper mapper = new ObjectMapper(); + /** + * The temperature of the model. Increasing the temperature will + * make the model answer more creatively. (Default: 0.7) + */ + @JsonProperty("temperature") + private Double temperature; + + /** + * Works together with top-k. A higher value (e.g., 0.95) will lead to + * more diverse text, while a lower value (e.g., 0.2) will generate more focused and + * conservative text. (Default: 1.0) + */ + @JsonProperty("top_p") + private Double topP; + + /** + * Reduces the probability of generating nonsense. A higher value (e.g. + * 100) will give more diverse answers, while a lower value (e.g. 10) will be more + * conservative. (Default: 50) + */ + @JsonProperty("top_k") + private Integer topK; + /** * Decoding is the process that a model uses to choose the tokens in the generated output. * Choose one of the following decoding options: @@ -83,6 +106,14 @@ public class WatsonxAiChatOptions extends AbstractChatOptions implements ChatOpt @JsonProperty("min_new_tokens") private Integer minNewTokens; + /** + * Sets when the LLM should stop. + * (e.g., ["\n\n\n"]) then when the LLM generates three consecutive line breaks it will terminate. + * Stop sequences are ignored until after the number of tokens that are specified in the Min tokens parameter are generated. + */ + @JsonProperty("stop_sequences") + private List stopSequences; + /** * Sets how strongly to penalize repetitions. 
A higher value * (e.g., 1.8) will penalize repetitions more strongly, while a lower value (e.g., @@ -97,6 +128,12 @@ public class WatsonxAiChatOptions extends AbstractChatOptions implements ChatOpt @JsonProperty("random_seed") private Integer randomSeed; + /** + * Model is the identifier of the LLM Model to be used + */ + @JsonProperty("model") + private String model; + /** * Set additional request params (some model have non-predefined options) */ @@ -135,14 +172,29 @@ public static WatsonxAiChatOptions fromOptions(WatsonxAiChatOptions fromOptions) .build(); } + @Override + public Double getTemperature() { + return this.temperature; + } + public void setTemperature(Double temperature) { this.temperature = temperature; } + @Override + public Double getTopP() { + return this.topP; + } + public void setTopP(Double topP) { this.topP = topP; } + @Override + public Integer getTopK() { + return this.topK; + } + public void setTopK(Integer topK) { this.topK = topK; } @@ -182,6 +234,11 @@ public void setMinNewTokens(Integer minNewTokens) { this.minNewTokens = minNewTokens; } + @Override + public List getStopSequences() { + return this.stopSequences; + } + public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } @@ -276,35 +333,24 @@ public boolean equals(Object o) { WatsonxAiChatOptions that = (WatsonxAiChatOptions) o; - if (decodingMethod != null ? !decodingMethod.equals(that.decodingMethod) : that.decodingMethod != null) return false; - if (maxNewTokens != null ? !maxNewTokens.equals(that.maxNewTokens) : that.maxNewTokens != null) return false; - if (minNewTokens != null ? !minNewTokens.equals(that.minNewTokens) : that.minNewTokens != null) return false; - if (repetitionPenalty != null ? !repetitionPenalty.equals(that.repetitionPenalty) : that.repetitionPenalty != null) return false; - if (randomSeed != null ? !randomSeed.equals(that.randomSeed) : that.randomSeed != null) return false; - if (temperature != null ? !temperature.equals(that.temperature) : that.temperature != null) return false; - if (topP != null ? !topP.equals(that.topP) : that.topP != null) return false; - if (topK != null ? !topK.equals(that.topK) : that.topK != null) return false; - if (stopSequences != null ? !stopSequences.equals(that.stopSequences) : that.stopSequences != null) return false; - if (model != null ? !model.equals(that.model) : that.model != null) return false; - return additional != null ? additional.equals(that.additional) : that.additional == null; + return Objects.equals(this.decodingMethod, that.decodingMethod) && + Objects.equals(this.maxNewTokens, that.maxNewTokens) && + Objects.equals(this.minNewTokens, that.minNewTokens) && + Objects.equals(this.repetitionPenalty, that.repetitionPenalty) && + Objects.equals(this.randomSeed, that.randomSeed) && + Objects.equals(this.temperature, that.temperature) && + Objects.equals(this.topP, that.topP) && + Objects.equals(this.topK, that.topK) && + Objects.equals(this.stopSequences, that.stopSequences) && + Objects.equals(this.model, that.model) && + Objects.equals(this.additional, that.additional); } @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.decodingMethod == null) ? 0 : this.decodingMethod.hashCode()); - result = prime * result + ((this.maxNewTokens == null) ? 0 : this.maxNewTokens.hashCode()); - result = prime * result + ((this.minNewTokens == null) ? 0 : this.minNewTokens.hashCode()); - result = prime * result + ((this.repetitionPenalty == null) ? 
0 : this.repetitionPenalty.hashCode()); - result = prime * result + ((this.randomSeed == null) ? 0 : this.randomSeed.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.topK == null) ? 0 : this.topK.hashCode()); - result = prime * result + ((this.stopSequences == null) ? 0 : this.stopSequences.hashCode()); - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.additional == null) ? 0 : this.additional.hashCode()); - return result; + return Objects.hash(this.decodingMethod, this.maxNewTokens, this.minNewTokens, + this.repetitionPenalty, this.randomSeed, this.temperature, + this.topP, this.topK, this.stopSequences, this.model, this.additional); } public static class Builder { diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java index f57303b1138..205f706b91c 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.Set; import com.fasterxml.jackson.annotation.JsonIgnore; @@ -40,6 +41,7 @@ * @author Geng Rong * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @author Alexandros Pappas * @since 1.0.0 M1 */ @JsonInclude(Include.NON_NULL) @@ -130,7 +132,7 @@ public class ZhiPuAiChatOptions implements FunctionCallingOptions { private Boolean proxyToolCalls; @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); // @formatter:on public static Builder builder() { @@ -141,7 +143,7 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { return ZhiPuAiChatOptions.builder() .model(fromOptions.getModel()) .maxTokens(fromOptions.getMaxTokens()) - .stop(fromOptions.getStop()) + .stop(fromOptions.getStop() != null ? new ArrayList<>(fromOptions.getStop()) : null) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .tools(fromOptions.getTools()) @@ -149,10 +151,11 @@ public static ZhiPuAiChatOptions fromOptions(ZhiPuAiChatOptions fromOptions) { .user(fromOptions.getUser()) .requestId(fromOptions.getRequestId()) .doSample(fromOptions.getDoSample()) - .functionCallbacks(fromOptions.getFunctionCallbacks()) - .functions(fromOptions.getFunctions()) + .functionCallbacks(fromOptions.getFunctionCallbacks() != null + ? new ArrayList<>(fromOptions.getFunctionCallbacks()) : null) + .functions(fromOptions.getFunctions() != null ? new HashSet<>(fromOptions.getFunctions()) : null) .proxyToolCalls(fromOptions.getProxyToolCalls()) - .toolContext(fromOptions.getToolContext()) + .toolContext(fromOptions.getToolContext() != null ? 
new HashMap<>(fromOptions.getToolContext()) : null) .build(); } @@ -309,19 +312,8 @@ public void setToolContext(Map toolContext) { @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.stop == null) ? 0 : this.stop.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.tools == null) ? 0 : this.tools.hashCode()); - result = prime * result + ((this.toolChoice == null) ? 0 : this.toolChoice.hashCode()); - result = prime * result + ((this.user == null) ? 0 : this.user.hashCode()); - result = prime * result + ((this.proxyToolCalls == null) ? 0 : this.proxyToolCalls.hashCode()); - result = prime * result + ((this.toolContext == null) ? 0 : this.toolContext.hashCode()); - return result; + return Objects.hash(this.model, this.maxTokens, this.stop, this.temperature, this.topP, this.tools, + this.toolChoice, this.user, this.requestId, this.doSample, this.proxyToolCalls, this.toolContext); } @Override @@ -329,113 +321,22 @@ public boolean equals(Object obj) { if (this == obj) { return true; } - if (obj == null) { - return false; - } - if (getClass() != obj.getClass()) { + if (obj == null || getClass() != obj.getClass()) { return false; } ZhiPuAiChatOptions other = (ZhiPuAiChatOptions) obj; - if (this.model == null) { - if (other.model != null) { - return false; - } - } - else if (!this.model.equals(other.model)) { - return false; - } - if (this.maxTokens == null) { - if (other.maxTokens != null) { - return false; - } - } - else if (!this.maxTokens.equals(other.maxTokens)) { - return false; - } - if (this.stop == null) { - if (other.stop != null) { - return false; - } - } - else if (!this.stop.equals(other.stop)) { - return false; - } - if (this.temperature == null) { - if (other.temperature != null) { - return false; - } - } - else if (!this.temperature.equals(other.temperature)) { - return false; - } - if (this.topP == null) { - if (other.topP != null) { - return false; - } - } - else if (!this.topP.equals(other.topP)) { - return false; - } - if (this.tools == null) { - if (other.tools != null) { - return false; - } - } - else if (!this.tools.equals(other.tools)) { - return false; - } - if (this.toolChoice == null) { - if (other.toolChoice != null) { - return false; - } - } - else if (!this.toolChoice.equals(other.toolChoice)) { - return false; - } - if (this.user == null) { - if (other.user != null) { - return false; - } - } - else if (!this.user.equals(other.user)) { - return false; - } - if (this.requestId == null) { - if (other.requestId != null) { - return false; - } - } - else if (!this.requestId.equals(other.requestId)) { - return false; - } - if (this.doSample == null) { - if (other.doSample != null) { - return false; - } - } - else if (!this.doSample.equals(other.doSample)) { - return false; - } - if (this.proxyToolCalls == null) { - if (other.proxyToolCalls != null) { - return false; - } - } - else if (!this.proxyToolCalls.equals(other.proxyToolCalls)) { - return false; - } - if (this.toolContext == null) { - if (other.toolContext != null) { - return false; - } - } - else if (!this.toolContext.equals(other.toolContext)) { - return false; - } - return true; + + return Objects.equals(this.model, other.model) 
&& Objects.equals(this.maxTokens, other.maxTokens) + && Objects.equals(this.stop, other.stop) && Objects.equals(this.temperature, other.temperature) + && Objects.equals(this.topP, other.topP) && Objects.equals(this.tools, other.tools) + && Objects.equals(this.toolChoice, other.toolChoice) && Objects.equals(this.user, other.user) + && Objects.equals(this.requestId, other.requestId) && Objects.equals(this.doSample, other.doSample) + && Objects.equals(this.proxyToolCalls, other.proxyToolCalls) + && Objects.equals(this.toolContext, other.toolContext); } @Override + @SuppressWarnings("unchecked") public ZhiPuAiChatOptions copy() { return fromOptions(this); } diff --git a/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java new file mode 100644 index 00000000000..8bb115a3483 --- /dev/null +++ b/models/spring-ai-zhipuai/src/test/java/org/springframework/ai/zhipuai/ZhiPuAiChatOptionsTests.java @@ -0,0 +1,131 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.zhipuai; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Tests for {@link ZhiPuAiChatOptions}. 
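Reviewer note on the fromOptions(..) change above: because the copy now duplicates the mutable collections, later changes to collections handed to the original no longer leak into a copy. A minimal sketch of that behaviour, with placeholder values and AssertJ assertions as used in the tests; this snippet is illustrative only and not part of the patch:

    Map<String, Object> toolContext = new HashMap<>(Map.of("tenant", "acme"));
    List<String> stop = new ArrayList<>(List.of("END"));

    ZhiPuAiChatOptions original = ZhiPuAiChatOptions.builder()
        .model("test-model")
        .stop(stop)
        .toolContext(toolContext)
        .build();

    ZhiPuAiChatOptions copy = original.copy();

    // Mutate the collections that were passed to the original, after copying.
    toolContext.put("tenant", "other");
    stop.add("STOP");

    // The copy keeps its own snapshots taken at copy() time.
    assertThat(copy.getToolContext()).isEqualTo(Map.of("tenant", "acme"));
    assertThat(copy.getStop()).containsExactly("END");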
+ * + * @author Alexandros Pappas + */ +class ZhiPuAiChatOptionsTests { + + @Test + void testBuilderWithAllFields() { + ZhiPuAiChatOptions options = ZhiPuAiChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.7) + .toolChoice("auto") + .user("test-user") + .requestId("12345") + .doSample(true) + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + assertThat(options) + .extracting("model", "maxTokens", "stop", "temperature", "topP", "toolChoice", "user", "requestId", + "doSample", "proxyToolCalls", "toolContext") + .containsExactly("test-model", 100, List.of("test"), 0.6, 0.7, "auto", "test-user", "12345", true, true, + Map.of("key1", "value1")); + } + + @Test + void testCopy() { + ZhiPuAiChatOptions original = ZhiPuAiChatOptions.builder() + .model("test-model") + .maxTokens(100) + .stop(List.of("test")) + .temperature(0.6) + .topP(0.7) + .toolChoice("auto") + .user("test-user") + .requestId("12345") + .doSample(true) + .proxyToolCalls(true) + .toolContext(Map.of("key1", "value1")) + .build(); + + ZhiPuAiChatOptions copied = original.copy(); + + assertThat(copied).isNotSameAs(original).isEqualTo(original); + assertThat(copied.getStop()).isNotSameAs(original.getStop()); + assertThat(copied.getFunctionCallbacks()).isNotSameAs(original.getFunctionCallbacks()); + assertThat(copied.getFunctions()).isNotSameAs(original.getFunctions()); + assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext()); + } + + @Test + void testSetters() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + options.setModel("test-model"); + options.setMaxTokens(100); + options.setStop(List.of("test")); + options.setTemperature(0.6); + options.setTopP(0.7); + options.setToolChoice("auto"); + options.setUser("test-user"); + options.setRequestId("12345"); + options.setDoSample(true); + options.setProxyToolCalls(true); + options.setToolContext(Map.of("key1", "value1")); + + assertThat(options.getModel()).isEqualTo("test-model"); + assertThat(options.getMaxTokens()).isEqualTo(100); + assertThat(options.getStop()).isEqualTo(List.of("test")); + assertThat(options.getTemperature()).isEqualTo(0.6); + assertThat(options.getTopP()).isEqualTo(0.7); + assertThat(options.getToolChoice()).isEqualTo("auto"); + assertThat(options.getUser()).isEqualTo("test-user"); + assertThat(options.getRequestId()).isEqualTo("12345"); + assertThat(options.getDoSample()).isTrue(); + assertThat(options.getProxyToolCalls()).isTrue(); + assertThat(options.getToolContext()).isEqualTo(Map.of("key1", "value1")); + } + + @Test + void testDefaultValues() { + ZhiPuAiChatOptions options = new ZhiPuAiChatOptions(); + assertThat(options.getModel()).isNull(); + assertThat(options.getMaxTokens()).isNull(); + assertThat(options.getStop()).isNull(); + assertThat(options.getTemperature()).isNull(); + assertThat(options.getTopP()).isNull(); + assertThat(options.getTools()).isNull(); + assertThat(options.getToolChoice()).isNull(); + assertThat(options.getUser()).isNull(); + assertThat(options.getRequestId()).isNull(); + assertThat(options.getDoSample()).isNull(); + assertThat(options.getFunctionCallbacks()).isEqualTo(new ArrayList<>()); + assertThat(options.getFunctions()).isEqualTo(new HashSet<>()); + assertThat(options.getProxyToolCalls()).isNull(); + assertThat(options.getToolContext()).isEqualTo(new HashMap<>()); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AbstractChatOptions.java 
b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AbstractChatOptions.java deleted file mode 100644 index 8dd4c82857a..00000000000 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/AbstractChatOptions.java +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Copyright 2025-2025 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * https://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.springframework.ai.chat.prompt; - -import java.util.List; - -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonProperty; - -import org.springframework.lang.Nullable; - -/** - * Abstract base class for {@link ChatOptions}, providing common implementation for its - * methods. - * - * @author Alexandros Pappas - */ -public abstract class AbstractChatOptions implements ChatOptions { - - /** - * ID of the model to use. - */ - @JsonProperty("model") - protected String model; - - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on their - * existing frequency in the text so far, decreasing the model's likelihood to repeat - * the same line verbatim. - */ - @JsonProperty("frequency_penalty") - protected Double frequencyPenalty; - - /** - * The maximum number of tokens to generate in the chat completion. The total length - * of input tokens and generated tokens is limited by the model's context length. - */ - @JsonProperty("max_tokens") - protected Integer maxTokens; - - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether - * they appear in the text so far, increasing the model's likelihood to talk about new - * topics. - */ - @JsonProperty("presence_penalty") - protected Double presencePenalty; - - /** - * Sets when the LLM should stop. (e.g., ["\n\n\n"]) then when the LLM generates three - * consecutive line breaks it will terminate. Stop sequences are ignored until after - * the number of tokens that are specified in the Min tokens parameter are generated. - */ - @JsonProperty("stop_sequences") - protected List stopSequences; - - /** - * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make - * the output more random, while lower values like 0.2 will make it more focused and - * deterministic. We generally recommend altering this or top_p but not both. - */ - @JsonProperty("temperature") - protected Double temperature; - - /** - * Reduces the probability of generating nonsense. A higher value (e.g. 100) will give - * more diverse answers, while a lower value (e.g. 10) will be more conservative. - * (Default: 40) - */ - @JsonProperty("top_k") - protected Integer topK; - - /** - * An alternative to sampling with temperature, called nucleus sampling, where the - * model considers the results of the tokens with top_p probability mass. So 0.1 means - * only the tokens comprising the top 10% probability mass are considered. We - * generally recommend altering this or temperature but not both. 
- */ - @JsonProperty("top_p") - protected Double topP; - - @Override - public String getModel() { - return this.model; - } - - @Override - public Double getFrequencyPenalty() { - return this.frequencyPenalty; - } - - @Override - @Nullable - public Integer getMaxTokens() { - return this.maxTokens; - } - - @Override - public Double getPresencePenalty() { - return this.presencePenalty; - } - - @Override - @JsonIgnore - public List getStopSequences() { - return this.stopSequences; - } - - @Override - public Double getTemperature() { - return this.temperature; - } - - @Override - public Integer getTopK() { - return this.topK; - } - - @Override - @Nullable - public Double getTopP() { - return this.topP; - } - - /** - * Generic Abstract Builder for {@link AbstractChatOptions}. Ensures fluent API in - * subclasses with proper typing. - */ - public abstract static class Builder> - implements ChatOptions.Builder { - - protected T options; - - protected String model; - - protected Double frequencyPenalty; - - protected Integer maxTokens; - - protected Double presencePenalty; - - protected List stopSequences; - - protected Double temperature; - - protected Integer topK; - - protected Double topP; - - protected abstract B self(); - - public Builder(T options) { - this.options = options; - } - - @Override - public B model(String model) { - this.model = model; - return this.self(); - } - - @Override - public B frequencyPenalty(Double frequencyPenalty) { - this.frequencyPenalty = frequencyPenalty; - return this.self(); - } - - @Override - public B maxTokens(Integer maxTokens) { - this.maxTokens = maxTokens; - return this.self(); - } - - @Override - public B presencePenalty(Double presencePenalty) { - this.presencePenalty = presencePenalty; - return this.self(); - } - - @Override - public B stopSequences(List stopSequences) { - this.stopSequences = stopSequences; - return this.self(); - } - - @Override - public B temperature(Double temperature) { - this.temperature = temperature; - return this.self(); - } - - @Override - public B topK(Integer topK) { - this.topK = topK; - return this.self(); - } - - @Override - public B topP(Double topP) { - this.topP = topP; - return this.self(); - } - - @Override - public abstract T build(); - - } - -} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java index cad9af8a3eb..19cb98a3a6b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/ChatOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2025 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,8 +24,6 @@ /** * {@link ModelOptions} representing the common options that are portable across different * chat models. - * - * @author Alexandros Pappas */ public interface ChatOptions extends ModelOptions { @@ -97,7 +95,7 @@ public interface ChatOptions extends ModelOptions { * @return Returns a new {@link ChatOptions.Builder}. 
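For reviewers: typical use of the portable builder is unchanged by the switch to DefaultChatOptionsBuilder below. A short sketch with placeholder values, not part of the patch:

    ChatOptions options = ChatOptions.builder()
        .model("some-model")            // placeholder model id
        .temperature(0.7)
        .maxTokens(256)
        .stopSequences(List.of("###"))
        .build();
    // The returned instance is a DefaultChatOptions copy, detached from the builder state.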
*/ static ChatOptions.Builder builder() { - return new DefaultChatOptions.Builder(); + return new DefaultChatOptionsBuilder(); } /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java index 2f15f748e2e..02c9356754b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptions.java @@ -17,65 +17,116 @@ package org.springframework.ai.chat.prompt; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; /** * Default implementation for the {@link ChatOptions}. - * - * @author Alexandros Pappas */ -public class DefaultChatOptions extends AbstractChatOptions { +public class DefaultChatOptions implements ChatOptions { + + private String model; + + private Double frequencyPenalty; + + private Integer maxTokens; + + private Double presencePenalty; - public static Builder builder() { - return new Builder(); + private List stopSequences; + + private Double temperature; + + private Integer topK; + + private Double topP; + + @Override + public String getModel() { + return this.model; } public void setModel(String model) { this.model = model; } + @Override + public Double getFrequencyPenalty() { + return this.frequencyPenalty; + } + public void setFrequencyPenalty(Double frequencyPenalty) { this.frequencyPenalty = frequencyPenalty; } + @Override + public Integer getMaxTokens() { + return this.maxTokens; + } + public void setMaxTokens(Integer maxTokens) { this.maxTokens = maxTokens; } + @Override + public Double getPresencePenalty() { + return this.presencePenalty; + } + public void setPresencePenalty(Double presencePenalty) { this.presencePenalty = presencePenalty; } + @Override + public List getStopSequences() { + return this.stopSequences != null ? Collections.unmodifiableList(this.stopSequences) : null; + } + public void setStopSequences(List stopSequences) { this.stopSequences = stopSequences; } + @Override + public Double getTemperature() { + return this.temperature; + } + public void setTemperature(Double temperature) { this.temperature = temperature; } + @Override + public Integer getTopK() { + return this.topK; + } + public void setTopK(Integer topK) { this.topK = topK; } + @Override + public Double getTopP() { + return this.topP; + } + public void setTopP(Double topP) { this.topP = topP; } @Override @SuppressWarnings("unchecked") - public DefaultChatOptions copy() { - return DefaultChatOptions.builder() - .model(this.getModel()) - .frequencyPenalty(this.getFrequencyPenalty()) - .maxTokens(this.getMaxTokens()) - .presencePenalty(this.getPresencePenalty()) - .stopSequences(this.getStopSequences() == null ? null : new ArrayList<>(this.getStopSequences())) - .temperature(this.getTemperature()) - .topK(this.getTopK()) - .topP(this.getTopP()) - .build(); + public T copy() { + DefaultChatOptions copy = new DefaultChatOptions(); + copy.setModel(this.getModel()); + copy.setFrequencyPenalty(this.getFrequencyPenalty()); + copy.setMaxTokens(this.getMaxTokens()); + copy.setPresencePenalty(this.getPresencePenalty()); + copy.setStopSequences(this.getStopSequences() != null ? 
new ArrayList<>(this.getStopSequences()) : null); + copy.setTemperature(this.getTemperature()); + copy.setTopK(this.getTopK()); + copy.setTopP(this.getTopP()); + return (T) copy; } @Override @@ -87,63 +138,18 @@ public boolean equals(Object o) { DefaultChatOptions that = (DefaultChatOptions) o; - if (!Objects.equals(model, that.model)) - return false; - if (!Objects.equals(frequencyPenalty, that.frequencyPenalty)) - return false; - if (!Objects.equals(maxTokens, that.maxTokens)) - return false; - if (!Objects.equals(presencePenalty, that.presencePenalty)) - return false; - if (!Objects.equals(stopSequences, that.stopSequences)) - return false; - if (!Objects.equals(temperature, that.temperature)) - return false; - if (!Objects.equals(topK, that.topK)) - return false; - return Objects.equals(topP, that.topP); + return Objects.equals(this.model, that.model) && Objects.equals(this.frequencyPenalty, that.frequencyPenalty) + && Objects.equals(this.maxTokens, that.maxTokens) + && Objects.equals(this.presencePenalty, that.presencePenalty) + && Objects.equals(this.stopSequences, that.stopSequences) + && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topK, that.topK) + && Objects.equals(this.topP, that.topP); } @Override public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((this.model == null) ? 0 : this.model.hashCode()); - result = prime * result + ((this.maxTokens == null) ? 0 : this.maxTokens.hashCode()); - result = prime * result + ((this.frequencyPenalty == null) ? 0 : this.frequencyPenalty.hashCode()); - result = prime * result + ((this.presencePenalty == null) ? 0 : this.presencePenalty.hashCode()); - result = prime * result + ((this.stopSequences == null) ? 0 : this.stopSequences.hashCode()); - result = prime * result + ((this.temperature == null) ? 0 : this.temperature.hashCode()); - result = prime * result + ((this.topP == null) ? 0 : this.topP.hashCode()); - result = prime * result + ((this.topK == null) ? 0 : this.topK.hashCode()); - return result; - } - - public static class Builder extends AbstractChatOptions.Builder { - - public Builder() { - super(new DefaultChatOptions()); - } - - @Override - protected Builder self() { - return this; - } - - @Override - public DefaultChatOptions build() { - DefaultChatOptions optionsToBuild = new DefaultChatOptions(); - optionsToBuild.setModel(this.model); - optionsToBuild.setFrequencyPenalty(this.frequencyPenalty); - optionsToBuild.setMaxTokens(this.maxTokens); - optionsToBuild.setPresencePenalty(this.presencePenalty); - optionsToBuild.setStopSequences(this.stopSequences); - optionsToBuild.setTemperature(this.temperature); - optionsToBuild.setTopK(this.topK); - optionsToBuild.setTopP(this.topP); - return optionsToBuild; - } - + return Objects.hash(model, maxTokens, frequencyPenalty, presencePenalty, stopSequences, temperature, topP, + topK); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java new file mode 100644 index 00000000000..47ba5840109 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/prompt/DefaultChatOptionsBuilder.java @@ -0,0 +1,80 @@ +/* + * Copyright 2024-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
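Aside on the reworked DefaultChatOptions above: getStopSequences() now returns an unmodifiable view and copy() snapshots the list, so copied instances do not share mutable state with the source. A rough, illustrative-only sketch:

    DefaultChatOptions options = new DefaultChatOptions();
    options.setStopSequences(new ArrayList<>(List.of("a", "b")));

    // Read access is an unmodifiable view; mutating it would throw.
    // options.getStopSequences().add("c");  // UnsupportedOperationException

    // copy() takes its own snapshot of the stop sequences.
    DefaultChatOptions copy = options.copy();
    options.setStopSequences(List.of("x"));
    assertThat(copy.getStopSequences()).containsExactly("a", "b");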
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.chat.prompt; + +import java.util.List; + +/** + * Implementation of {@link ChatOptions.Builder} to create {@link DefaultChatOptions}. + */ +public class DefaultChatOptionsBuilder implements ChatOptions.Builder { + + protected DefaultChatOptions options; + + public DefaultChatOptionsBuilder() { + this.options = new DefaultChatOptions(); + } + + protected DefaultChatOptionsBuilder(DefaultChatOptions options) { + this.options = options; + } + + public DefaultChatOptionsBuilder model(String model) { + this.options.setModel(model); + return this; + } + + public DefaultChatOptionsBuilder frequencyPenalty(Double frequencyPenalty) { + this.options.setFrequencyPenalty(frequencyPenalty); + return this; + } + + public DefaultChatOptionsBuilder maxTokens(Integer maxTokens) { + this.options.setMaxTokens(maxTokens); + return this; + } + + public DefaultChatOptionsBuilder presencePenalty(Double presencePenalty) { + this.options.setPresencePenalty(presencePenalty); + return this; + } + + public DefaultChatOptionsBuilder stopSequences(List stop) { + this.options.setStopSequences(stop); + return this; + } + + public DefaultChatOptionsBuilder temperature(Double temperature) { + this.options.setTemperature(temperature); + return this; + } + + public DefaultChatOptionsBuilder topK(Integer topK) { + this.options.setTopK(topK); + return this; + } + + public DefaultChatOptionsBuilder topP(Double topP) { + this.options.setTopP(topP); + return this; + } + + public ChatOptions build() { + return this.options.copy(); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java index ca9a2a10db0..247e82b6f00 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/ChatOptionsBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2025 the original author or authors. + * Copyright 2023-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ package org.springframework.ai.chat.prompt; +import java.util.ArrayList; import java.util.List; import java.util.Set; @@ -26,13 +27,13 @@ import org.springframework.ai.model.function.FunctionCallingOptions; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; /** * Unit Tests for {@link ChatOptions} builder. 
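One more property worth a test here, if desired: since DefaultChatOptionsBuilder.build() hands out a copy, reusing the builder does not mutate options that were already built. A possible sketch, not included in this patch:

    @Test
    void builtOptionsAreDetachedFromBuilder() {
        ChatOptions.Builder builder = ChatOptions.builder();
        ChatOptions first = builder.model("model-a").temperature(0.2).build();

        // Further builder mutations must not leak into already-built options.
        ChatOptions second = builder.temperature(0.9).build();

        assertThat(first.getTemperature()).isEqualTo(0.2);
        assertThat(second.getTemperature()).isEqualTo(0.9);
        assertThat(second.getModel()).isEqualTo("model-a");
    }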
* * @author youngmon * @author Mark Pollack - * @author Alexandros Pappas * @since 1.0.0 */ public class ChatOptionsBuilderTests { @@ -162,4 +163,15 @@ void shouldHaveExpectedDefaultValues() { assertThat(options.getStopSequences()).isNull(); } + @Test + void shouldBeImmutableAfterBuild() { + // Given + List stopSequences = new ArrayList<>(List.of("stop1", "stop2")); + ChatOptions options = this.builder.stopSequences(stopSequences).build(); + + // Then + assertThatThrownBy(() -> options.getStopSequences().add("stop3")) + .isInstanceOf(UnsupportedOperationException.class); + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java index ece5853aa95..11a096a535c 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/prompt/DefaultChatOptionsTests.java @@ -31,7 +31,7 @@ class DefaultChatOptionsTests { @Test void testBuilderWithAllFields() { - DefaultChatOptions options = DefaultChatOptions.builder() + ChatOptions options = ChatOptions.builder() .model("test-model") .frequencyPenalty(0.5) .maxTokens(100) @@ -50,7 +50,7 @@ void testBuilderWithAllFields() { @Test void testCopy() { - DefaultChatOptions original = DefaultChatOptions.builder() + ChatOptions original = ChatOptions.builder() .model("test-model") .frequencyPenalty(0.5) .maxTokens(100)