Skip to content

Commit 1815681

Browse files
committed
feat: update all *ChatOptions* Classes and add Unit Tests
Signed-off-by: Alexandros Pappas <[email protected]>
1 parent 4135410 commit 1815681

File tree

18 files changed

+874
-812
lines changed

18 files changed

+874
-812
lines changed

models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,10 @@
3232

3333
import org.springframework.ai.anthropic.api.AnthropicApi;
3434
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
35-
import org.springframework.ai.chat.prompt.AbstractChatOptions;
3635
import org.springframework.ai.model.function.FunctionCallback;
3736
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3837
import org.springframework.ai.tool.ToolCallback;
3938
import org.springframework.lang.Nullable;
40-
import org.springframework.ai.model.tool.ToolCallingChatOptions;
4139
import org.springframework.util.Assert;
4240

4341
/**
@@ -49,11 +47,16 @@
4947
* @since 1.0.0
5048
*/
5149
@JsonInclude(Include.NON_NULL)
52-
public class AnthropicChatOptions extends AbstractChatOptions implements ToolCallingChatOptions {
50+
public class AnthropicChatOptions implements ToolCallingChatOptions {
5351

5452
// @formatter:off
55-
53+
private @JsonProperty("model") String model;
54+
private @JsonProperty("max_tokens") Integer maxTokens;
5655
private @JsonProperty("metadata") ChatCompletionRequest.Metadata metadata;
56+
private @JsonProperty("stop_sequences") List<String> stopSequences;
57+
private @JsonProperty("temperature") Double temperature;
58+
private @JsonProperty("top_p") Double topP;
59+
private @JsonProperty("top_k") Integer topK;
5760

5861
/**
5962
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
@@ -100,10 +103,20 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
100103
.build();
101104
}
102105

106+
@Override
107+
public String getModel() {
108+
return this.model;
109+
}
110+
103111
public void setModel(String model) {
104112
this.model = model;
105113
}
106114

115+
@Override
116+
public Integer getMaxTokens() {
117+
return this.maxTokens;
118+
}
119+
107120
public void setMaxTokens(Integer maxTokens) {
108121
this.maxTokens = maxTokens;
109122
}
@@ -116,18 +129,38 @@ public void setMetadata(ChatCompletionRequest.Metadata metadata) {
116129
this.metadata = metadata;
117130
}
118131

132+
@Override
133+
public List<String> getStopSequences() {
134+
return this.stopSequences;
135+
}
136+
119137
public void setStopSequences(List<String> stopSequences) {
120138
this.stopSequences = stopSequences;
121139
}
122140

141+
@Override
142+
public Double getTemperature() {
143+
return this.temperature;
144+
}
145+
123146
public void setTemperature(Double temperature) {
124147
this.temperature = temperature;
125148
}
126149

150+
@Override
151+
public Double getTopP() {
152+
return this.topP;
153+
}
154+
127155
public void setTopP(Double topP) {
128156
this.topP = topP;
129157
}
130158

159+
@Override
160+
public Integer getTopK() {
161+
return this.topK;
162+
}
163+
131164
public void setTopK(Integer topK) {
132165
this.topK = topK;
133166
}
@@ -240,6 +273,7 @@ public void setToolContext(Map<String, Object> toolContext) {
240273
}
241274

242275
@Override
276+
@SuppressWarnings("unchecked")
243277
public AnthropicChatOptions copy() {
244278
return fromOptions(this);
245279
}
@@ -264,21 +298,8 @@ public boolean equals(Object o) {
264298

265299
@Override
266300
public int hashCode() {
267-
final int prime = 31;
268-
int result = 1;
269-
result = prime * result + (this.model != null ? this.model.hashCode() : 0);
270-
result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0);
271-
result = prime * result + (this.metadata != null ? this.metadata.hashCode() : 0);
272-
result = prime * result + (this.stopSequences != null ? this.stopSequences.hashCode() : 0);
273-
result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0);
274-
result = prime * result + (this.topP != null ? this.topP.hashCode() : 0);
275-
result = prime * result + (this.topK != null ? this.topK.hashCode() : 0);
276-
result = prime * result + (this.toolCallbacks != null ? this.toolCallbacks.hashCode() : 0);
277-
result = prime * result + (this.toolNames != null ? this.toolNames.hashCode() : 0);
278-
result = prime * result
279-
+ (this.internalToolExecutionEnabled != null ? this.internalToolExecutionEnabled.hashCode() : 0);
280-
result = prime * result + (this.toolContext != null ? this.toolContext.hashCode() : 0);
281-
return result;
301+
return Objects.hash(model, maxTokens, metadata, stopSequences, temperature, topP, topK, toolCallbacks,
302+
toolNames, internalToolExecutionEnabled, toolContext);
282303
}
283304

284305
public static class Builder {

models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatOptions.java

Lines changed: 67 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2023-2024 the original author or authors.
2+
* Copyright 2023-2025 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -32,7 +32,6 @@
3232
import com.fasterxml.jackson.annotation.JsonInclude.Include;
3333
import com.fasterxml.jackson.annotation.JsonProperty;
3434

35-
import org.springframework.ai.chat.prompt.AbstractChatOptions;
3635
import org.springframework.ai.model.function.FunctionCallback;
3736
import org.springframework.ai.model.tool.ToolCallingChatOptions;
3837
import org.springframework.ai.tool.ToolCallback;
@@ -51,7 +50,34 @@
5150
* @author Alexandros Pappas
5251
*/
5352
@JsonInclude(Include.NON_NULL)
54-
public class AzureOpenAiChatOptions extends AbstractChatOptions implements ToolCallingChatOptions {
53+
public class AzureOpenAiChatOptions implements ToolCallingChatOptions {
54+
55+
/**
56+
* The maximum number of tokens to generate.
57+
*/
58+
@JsonProperty("max_tokens")
59+
private Integer maxTokens;
60+
61+
/**
62+
* The sampling temperature to use that controls the apparent creativity of generated
63+
* completions. Higher values will make output more random while lower values will
64+
* make results more focused and deterministic. It is not recommended to modify
65+
* temperature and top_p for the same completions request as the interaction of these
66+
* two settings is difficult to predict.
67+
*/
68+
@JsonProperty("temperature")
69+
private Double temperature;
70+
71+
/**
72+
* An alternative to sampling with temperature called nucleus sampling. This value
73+
* causes the model to consider the results of tokens with the provided probability
74+
* mass. As an example, a value of 0.15 will cause only the tokens comprising the top
75+
* 15% of probability mass to be considered. It is not recommended to modify
76+
* temperature and top_p for the same completions request as the interaction of these
77+
* two settings is difficult to predict.
78+
*/
79+
@JsonProperty("top_p")
80+
private Double topP;
5581

5682
/**
5783
* A map between GPT token IDs and bias scores that influences the probability of
@@ -85,6 +111,24 @@ public class AzureOpenAiChatOptions extends AbstractChatOptions implements ToolC
85111
@JsonProperty("stop")
86112
private List<String> stop;
87113

114+
/**
115+
* A value that influences the probability of generated tokens appearing based on
116+
* their existing presence in generated text. Positive values will make tokens less
117+
* likely to appear when they already exist and increase the model's likelihood to
118+
* output new topics.
119+
*/
120+
@JsonProperty("presence_penalty")
121+
private Double presencePenalty;
122+
123+
/**
124+
* A value that influences the probability of generated tokens appearing based on
125+
* their cumulative frequency in generated text. Positive values will make tokens less
126+
* likely to appear as their frequency increases and decrease the likelihood of the
127+
* model repeating the same statements verbatim.
128+
*/
129+
@JsonProperty("frequency_penalty")
130+
private Double frequencyPenalty;
131+
88132
/**
89133
* The deployment name as defined in Azure Open AI Studio when creating a deployment
90134
* backed by an Azure OpenAI base model.
@@ -279,10 +323,20 @@ public void setStop(List<String> stop) {
279323
this.stop = stop;
280324
}
281325

326+
@Override
327+
public Double getPresencePenalty() {
328+
return this.presencePenalty;
329+
}
330+
282331
public void setPresencePenalty(Double presencePenalty) {
283332
this.presencePenalty = presencePenalty;
284333
}
285334

335+
@Override
336+
public Double getFrequencyPenalty() {
337+
return this.frequencyPenalty;
338+
}
339+
286340
public void setFrequencyPenalty(Double frequencyPenalty) {
287341
this.frequencyPenalty = frequencyPenalty;
288342
}
@@ -306,6 +360,11 @@ public void setDeploymentName(String deploymentName) {
306360
this.deploymentName = deploymentName;
307361
}
308362

363+
@Override
364+
public Double getTemperature() {
365+
return this.temperature;
366+
}
367+
309368
public void setTemperature(Double temperature) {
310369
this.temperature = temperature;
311370
}
@@ -422,7 +481,7 @@ public void setStreamOptions(ChatCompletionStreamOptions streamOptions) {
422481
}
423482

424483
@Override
425-
@SuppressWarnings("")
484+
@SuppressWarnings("unchecked")
426485
public AzureOpenAiChatOptions copy() {
427486
return fromOptions(this);
428487
}
@@ -454,30 +513,10 @@ public boolean equals(Object o) {
454513

455514
@Override
456515
public int hashCode() {
457-
final int prime = 31;
458-
int result = 1;
459-
result = prime * result + (this.logitBias != null ? this.logitBias.hashCode() : 0);
460-
result = prime * result + (this.user != null ? this.user.hashCode() : 0);
461-
result = prime * result + (this.n != null ? this.n.hashCode() : 0);
462-
result = prime * result + (this.stop != null ? this.stop.hashCode() : 0);
463-
result = prime * result + (this.deploymentName != null ? this.deploymentName.hashCode() : 0);
464-
result = prime * result + (this.responseFormat != null ? this.responseFormat.hashCode() : 0);
465-
result = prime * result + (this.toolCallbacks != null ? this.toolCallbacks.hashCode() : 0);
466-
result = prime * result + (this.toolNames != null ? this.toolNames.hashCode() : 0);
467-
result = prime * result
468-
+ (this.internalToolExecutionEnabled != null ? this.internalToolExecutionEnabled.hashCode() : 0);
469-
result = prime * result + (this.seed != null ? this.seed.hashCode() : 0);
470-
result = prime * result + (this.logprobs != null ? this.logprobs.hashCode() : 0);
471-
result = prime * result + (this.topLogProbs != null ? this.topLogProbs.hashCode() : 0);
472-
result = prime * result + (this.enhancements != null ? this.enhancements.hashCode() : 0);
473-
result = prime * result + (this.streamOptions != null ? this.streamOptions.hashCode() : 0);
474-
result = prime * result + (this.toolContext != null ? this.toolContext.hashCode() : 0);
475-
result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0);
476-
result = prime * result + (this.frequencyPenalty != null ? this.frequencyPenalty.hashCode() : 0);
477-
result = prime * result + (this.presencePenalty != null ? this.presencePenalty.hashCode() : 0);
478-
result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0);
479-
result = prime * result + (this.topP != null ? this.topP.hashCode() : 0);
480-
return result;
516+
return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat,
517+
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs,
518+
this.topLogProbs, this.enhancements, this.streamOptions, this.toolContext, this.maxTokens,
519+
this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP);
481520
}
482521

483522
public static class Builder {

0 commit comments

Comments
 (0)