Skip to content

Commit

Permalink
feat: update all *ChatOptions* Classes and add Unit Tests
Browse files Browse the repository at this point in the history
Signed-off-by: Alexandros Pappas <[email protected]>
  • Loading branch information
apappascs committed Feb 14, 2025
1 parent 4135410 commit 1815681
Show file tree
Hide file tree
Showing 18 changed files with 874 additions and 812 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,10 @@

import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
import org.springframework.ai.chat.prompt.AbstractChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.lang.Nullable;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.util.Assert;

/**
Expand All @@ -49,11 +47,16 @@
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class AnthropicChatOptions extends AbstractChatOptions implements ToolCallingChatOptions {
public class AnthropicChatOptions implements ToolCallingChatOptions {

// @formatter:off

private @JsonProperty("model") String model;
private @JsonProperty("max_tokens") Integer maxTokens;
private @JsonProperty("metadata") ChatCompletionRequest.Metadata metadata;
private @JsonProperty("stop_sequences") List<String> stopSequences;
private @JsonProperty("temperature") Double temperature;
private @JsonProperty("top_p") Double topP;
private @JsonProperty("top_k") Integer topK;

/**
* Collection of {@link ToolCallback}s to be used for tool calling in the chat
Expand Down Expand Up @@ -100,10 +103,20 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
.build();
}

@Override
public String getModel() {
return this.model;
}

public void setModel(String model) {
this.model = model;
}

@Override
public Integer getMaxTokens() {
return this.maxTokens;
}

public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}
Expand All @@ -116,18 +129,38 @@ public void setMetadata(ChatCompletionRequest.Metadata metadata) {
this.metadata = metadata;
}

@Override
public List<String> getStopSequences() {
return this.stopSequences;
}

public void setStopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
}

@Override
public Double getTemperature() {
return this.temperature;
}

public void setTemperature(Double temperature) {
this.temperature = temperature;
}

@Override
public Double getTopP() {
return this.topP;
}

public void setTopP(Double topP) {
this.topP = topP;
}

@Override
public Integer getTopK() {
return this.topK;
}

public void setTopK(Integer topK) {
this.topK = topK;
}
Expand Down Expand Up @@ -240,6 +273,7 @@ public void setToolContext(Map<String, Object> toolContext) {
}

@Override
@SuppressWarnings("unchecked")
public AnthropicChatOptions copy() {
return fromOptions(this);
}
Expand All @@ -264,21 +298,8 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + (this.model != null ? this.model.hashCode() : 0);
result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0);
result = prime * result + (this.metadata != null ? this.metadata.hashCode() : 0);
result = prime * result + (this.stopSequences != null ? this.stopSequences.hashCode() : 0);
result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0);
result = prime * result + (this.topP != null ? this.topP.hashCode() : 0);
result = prime * result + (this.topK != null ? this.topK.hashCode() : 0);
result = prime * result + (this.toolCallbacks != null ? this.toolCallbacks.hashCode() : 0);
result = prime * result + (this.toolNames != null ? this.toolNames.hashCode() : 0);
result = prime * result
+ (this.internalToolExecutionEnabled != null ? this.internalToolExecutionEnabled.hashCode() : 0);
result = prime * result + (this.toolContext != null ? this.toolContext.hashCode() : 0);
return result;
return Objects.hash(model, maxTokens, metadata, stopSequences, temperature, topP, topK, toolCallbacks,
toolNames, internalToolExecutionEnabled, toolContext);
}

public static class Builder {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -32,7 +32,6 @@
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;

import org.springframework.ai.chat.prompt.AbstractChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.tool.ToolCallback;
Expand All @@ -51,7 +50,34 @@
* @author Alexandros Pappas
*/
@JsonInclude(Include.NON_NULL)
public class AzureOpenAiChatOptions extends AbstractChatOptions implements ToolCallingChatOptions {
public class AzureOpenAiChatOptions implements ToolCallingChatOptions {

/**
* The maximum number of tokens to generate.
*/
@JsonProperty("max_tokens")
private Integer maxTokens;

/**
* The sampling temperature to use that controls the apparent creativity of generated
* completions. Higher values will make output more random while lower values will
* make results more focused and deterministic. It is not recommended to modify
* temperature and top_p for the same completions request as the interaction of these
* two settings is difficult to predict.
*/
@JsonProperty("temperature")
private Double temperature;

/**
* An alternative to sampling with temperature called nucleus sampling. This value
* causes the model to consider the results of tokens with the provided probability
* mass. As an example, a value of 0.15 will cause only the tokens comprising the top
* 15% of probability mass to be considered. It is not recommended to modify
* temperature and top_p for the same completions request as the interaction of these
* two settings is difficult to predict.
*/
@JsonProperty("top_p")
private Double topP;

/**
* A map between GPT token IDs and bias scores that influences the probability of
Expand Down Expand Up @@ -85,6 +111,24 @@ public class AzureOpenAiChatOptions extends AbstractChatOptions implements ToolC
@JsonProperty("stop")
private List<String> stop;

/**
* A value that influences the probability of generated tokens appearing based on
* their existing presence in generated text. Positive values will make tokens less
* likely to appear when they already exist and increase the model's likelihood to
* output new topics.
*/
@JsonProperty("presence_penalty")
private Double presencePenalty;

/**
* A value that influences the probability of generated tokens appearing based on
* their cumulative frequency in generated text. Positive values will make tokens less
* likely to appear as their frequency increases and decrease the likelihood of the
* model repeating the same statements verbatim.
*/
@JsonProperty("frequency_penalty")
private Double frequencyPenalty;

/**
* The deployment name as defined in Azure Open AI Studio when creating a deployment
* backed by an Azure OpenAI base model.
Expand Down Expand Up @@ -279,10 +323,20 @@ public void setStop(List<String> stop) {
this.stop = stop;
}

@Override
public Double getPresencePenalty() {
return this.presencePenalty;
}

public void setPresencePenalty(Double presencePenalty) {
this.presencePenalty = presencePenalty;
}

@Override
public Double getFrequencyPenalty() {
return this.frequencyPenalty;
}

public void setFrequencyPenalty(Double frequencyPenalty) {
this.frequencyPenalty = frequencyPenalty;
}
Expand All @@ -306,6 +360,11 @@ public void setDeploymentName(String deploymentName) {
this.deploymentName = deploymentName;
}

@Override
public Double getTemperature() {
return this.temperature;
}

public void setTemperature(Double temperature) {
this.temperature = temperature;
}
Expand Down Expand Up @@ -422,7 +481,7 @@ public void setStreamOptions(ChatCompletionStreamOptions streamOptions) {
}

@Override
@SuppressWarnings("")
@SuppressWarnings("unchecked")
public AzureOpenAiChatOptions copy() {
return fromOptions(this);
}
Expand Down Expand Up @@ -454,30 +513,10 @@ public boolean equals(Object o) {

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + (this.logitBias != null ? this.logitBias.hashCode() : 0);
result = prime * result + (this.user != null ? this.user.hashCode() : 0);
result = prime * result + (this.n != null ? this.n.hashCode() : 0);
result = prime * result + (this.stop != null ? this.stop.hashCode() : 0);
result = prime * result + (this.deploymentName != null ? this.deploymentName.hashCode() : 0);
result = prime * result + (this.responseFormat != null ? this.responseFormat.hashCode() : 0);
result = prime * result + (this.toolCallbacks != null ? this.toolCallbacks.hashCode() : 0);
result = prime * result + (this.toolNames != null ? this.toolNames.hashCode() : 0);
result = prime * result
+ (this.internalToolExecutionEnabled != null ? this.internalToolExecutionEnabled.hashCode() : 0);
result = prime * result + (this.seed != null ? this.seed.hashCode() : 0);
result = prime * result + (this.logprobs != null ? this.logprobs.hashCode() : 0);
result = prime * result + (this.topLogProbs != null ? this.topLogProbs.hashCode() : 0);
result = prime * result + (this.enhancements != null ? this.enhancements.hashCode() : 0);
result = prime * result + (this.streamOptions != null ? this.streamOptions.hashCode() : 0);
result = prime * result + (this.toolContext != null ? this.toolContext.hashCode() : 0);
result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0);
result = prime * result + (this.frequencyPenalty != null ? this.frequencyPenalty.hashCode() : 0);
result = prime * result + (this.presencePenalty != null ? this.presencePenalty.hashCode() : 0);
result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0);
result = prime * result + (this.topP != null ? this.topP.hashCode() : 0);
return result;
return Objects.hash(this.logitBias, this.user, this.n, this.stop, this.deploymentName, this.responseFormat,
this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, this.seed, this.logprobs,
this.topLogProbs, this.enhancements, this.streamOptions, this.toolContext, this.maxTokens,
this.frequencyPenalty, this.presencePenalty, this.temperature, this.topP);
}

public static class Builder {
Expand Down
Loading

0 comments on commit 1815681

Please sign in to comment.