Skip to content

Commit

Permalink
feat: Enhance Chat Options
Browse files Browse the repository at this point in the history
Key changes include:
*   **AbstractChatOptions:** Introduced an abstract base class, reducing code duplication.
*   **DefaultChatOptions:**  A concrete implementation of `ChatOptions` built on top of
    `AbstractChatOptions`
*   **Equals and HashCode:** Implemented `equals()` and `hashCode()` methods in all ChatOptions classes and the `DefaultChatOptions` class
*   **Test Updates:**  Comprehensive test updates were made across all affected modules to
    verify the new Builder pattern, copy functionality, and the behavior of the
    `equals()` and `hashCode()` methods.

Signed-off-by: Alexandros Pappas <[email protected]>
  • Loading branch information
apappascs committed Feb 13, 2025
1 parent 171b758 commit 8c0265c
Show file tree
Hide file tree
Showing 24 changed files with 2,021 additions and 754 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,9 +17,11 @@
package org.springframework.ai.anthropic;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;

import com.fasterxml.jackson.annotation.JsonIgnore;
Expand All @@ -29,6 +31,7 @@

import org.springframework.ai.anthropic.api.AnthropicApi;
import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest;
import org.springframework.ai.chat.prompt.AbstractChatOptions;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.util.Assert;
Expand All @@ -42,16 +45,11 @@
* @since 1.0.0
*/
@JsonInclude(Include.NON_NULL)
public class AnthropicChatOptions implements FunctionCallingOptions {
public class AnthropicChatOptions extends AbstractChatOptions implements FunctionCallingOptions {

// @formatter:off
private @JsonProperty("model") String model;
private @JsonProperty("max_tokens") Integer maxTokens;

private @JsonProperty("metadata") ChatCompletionRequest.Metadata metadata;
private @JsonProperty("stop_sequences") List<String> stopSequences;
private @JsonProperty("temperature") Double temperature;
private @JsonProperty("top_p") Double topP;
private @JsonProperty("top_k") Integer topK;

/**
* Tool Function Callbacks to register with the ChatModel. For Prompt
Expand Down Expand Up @@ -81,7 +79,7 @@ public class AnthropicChatOptions implements FunctionCallingOptions {
private Boolean proxyToolCalls;

@JsonIgnore
private Map<String, Object> toolContext;
private Map<String, Object> toolContext = new HashMap<>();

// @formatter:on

Expand All @@ -93,31 +91,22 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions)
return builder().model(fromOptions.getModel())
.maxTokens(fromOptions.getMaxTokens())
.metadata(fromOptions.getMetadata())
.stopSequences(fromOptions.getStopSequences())
.stopSequences(
fromOptions.getStopSequences() != null ? new ArrayList<>(fromOptions.getStopSequences()) : null)
.temperature(fromOptions.getTemperature())
.topP(fromOptions.getTopP())
.topK(fromOptions.getTopK())
.functionCallbacks(fromOptions.getFunctionCallbacks())
.functions(fromOptions.getFunctions())
.proxyToolCalls(fromOptions.getProxyToolCalls())
.toolContext(fromOptions.getToolContext())
.toolContext(fromOptions.getToolContext() != null ? new HashMap<>(fromOptions.getToolContext()) : null)
.build();
}

@Override
public String getModel() {
return this.model;
}

public void setModel(String model) {
this.model = model;
}

@Override
public Integer getMaxTokens() {
return this.maxTokens;
}

public void setMaxTokens(Integer maxTokens) {
this.maxTokens = maxTokens;
}
Expand All @@ -130,38 +119,18 @@ public void setMetadata(ChatCompletionRequest.Metadata metadata) {
this.metadata = metadata;
}

@Override
public List<String> getStopSequences() {
return this.stopSequences;
}

public void setStopSequences(List<String> stopSequences) {
this.stopSequences = stopSequences;
}

@Override
public Double getTemperature() {
return this.temperature;
}

public void setTemperature(Double temperature) {
this.temperature = temperature;
}

@Override
public Double getTopP() {
return this.topP;
}

public void setTopP(Double topP) {
this.topP = topP;
}

@Override
public Integer getTopK() {
return this.topK;
}

public void setTopK(Integer topK) {
this.topK = topK;
}
Expand Down Expand Up @@ -224,6 +193,43 @@ public AnthropicChatOptions copy() {
return fromOptions(this);
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (!(o instanceof AnthropicChatOptions that)) {
return false;
}
return Objects.equals(this.model, that.model) && Objects.equals(this.maxTokens, that.maxTokens)
&& Objects.equals(this.metadata, that.metadata)
&& Objects.equals(this.stopSequences, that.stopSequences)
&& Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP)
&& Objects.equals(this.topK, that.topK)
&& Objects.equals(this.functionCallbacks, that.functionCallbacks)
&& Objects.equals(this.functions, that.functions)
&& Objects.equals(this.proxyToolCalls, that.proxyToolCalls)
&& Objects.equals(this.toolContext, that.toolContext);
}

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + (this.model != null ? this.model.hashCode() : 0);
result = prime * result + (this.maxTokens != null ? this.maxTokens.hashCode() : 0);
result = prime * result + (this.metadata != null ? this.metadata.hashCode() : 0);
result = prime * result + (this.stopSequences != null ? this.stopSequences.hashCode() : 0);
result = prime * result + (this.temperature != null ? this.temperature.hashCode() : 0);
result = prime * result + (this.topP != null ? this.topP.hashCode() : 0);
result = prime * result + (this.topK != null ? this.topK.hashCode() : 0);
result = prime * result + (this.functionCallbacks != null ? this.functionCallbacks.hashCode() : 0);
result = prime * result + (this.functions != null ? this.functions.hashCode() : 0);
result = prime * result + (this.proxyToolCalls != null ? this.proxyToolCalls.hashCode() : 0);
result = prime * result + (this.toolContext != null ? this.toolContext.hashCode() : 0);
return result;
}

public static class Builder {

private final AnthropicChatOptions options = new AnthropicChatOptions();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* Copyright 2025-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.anthropic;

import java.util.List;
import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

import org.junit.jupiter.api.Test;

import org.springframework.ai.anthropic.api.AnthropicApi.ChatCompletionRequest.Metadata;

/**
* Tests for {@link AnthropicChatOptions}.
*
* @author Alexandros Pappas
*/
class AnthropicChatOptionsTests {

@Test
void testBuilderWithAllFields() {
AnthropicChatOptions options = AnthropicChatOptions.builder()
.model("test-model")
.maxTokens(100)
.stopSequences(List.of("stop1", "stop2"))
.temperature(0.7)
.topP(0.8)
.topK(50)
.metadata(new Metadata("userId_123"))
.build();

assertThat(options).extracting("model", "maxTokens", "stopSequences", "temperature", "topP", "topK", "metadata")
.containsExactly("test-model", 100, List.of("stop1", "stop2"), 0.7, 0.8, 50, new Metadata("userId_123"));
}

@Test
void testCopy() {
AnthropicChatOptions original = AnthropicChatOptions.builder()
.model("test-model")
.maxTokens(100)
.stopSequences(List.of("stop1", "stop2"))
.temperature(0.7)
.topP(0.8)
.topK(50)
.metadata(new Metadata("userId_123"))
.toolContext(Map.of("key1", "value1"))
.build();

AnthropicChatOptions copied = original.copy();

assertThat(copied).isNotSameAs(original).isEqualTo(original);
// Ensure deep copy
assertThat(copied.getStopSequences()).isNotSameAs(original.getStopSequences());
assertThat(copied.getToolContext()).isNotSameAs(original.getToolContext());
}

@Test
void testSetters() {
AnthropicChatOptions options = new AnthropicChatOptions();
options.setModel("test-model");
options.setMaxTokens(100);
options.setTemperature(0.7);
options.setTopK(50);
options.setTopP(0.8);
options.setStopSequences(List.of("stop1", "stop2"));
options.setMetadata(new Metadata("userId_123"));

assertThat(options.getModel()).isEqualTo("test-model");
assertThat(options.getMaxTokens()).isEqualTo(100);
assertThat(options.getTemperature()).isEqualTo(0.7);
assertThat(options.getTopK()).isEqualTo(50);
assertThat(options.getTopP()).isEqualTo(0.8);
assertThat(options.getStopSequences()).isEqualTo(List.of("stop1", "stop2"));
assertThat(options.getMetadata()).isEqualTo(new Metadata("userId_123"));
}

@Test
void testDefaultValues() {
AnthropicChatOptions options = new AnthropicChatOptions();
assertThat(options.getModel()).isNull();
assertThat(options.getMaxTokens()).isNull();
assertThat(options.getTemperature()).isNull();
assertThat(options.getTopK()).isNull();
assertThat(options.getTopP()).isNull();
assertThat(options.getStopSequences()).isNull();
assertThat(options.getMetadata()).isNull();
}

}
Loading

0 comments on commit 8c0265c

Please sign in to comment.