Skip to content

Commit

Permalink
Models updates
Browse files Browse the repository at this point in the history
  - Remove AbstractToolCallSupport from the models which use ToolCallingManager
  - Remove deprecated constructors and their usage
  - Remove FunctionCallbackResolver and FunctionCallbacks usage in the models

Signed-off-by: Ilayaperumal Gopinathan <[email protected]>
  • Loading branch information
ilayaperumalg committed Feb 19, 2025
1 parent 0fd136e commit 8df9049
Show file tree
Hide file tree
Showing 25 changed files with 66 additions and 1,040 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@
import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.lang.Nullable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand All @@ -57,7 +50,6 @@
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand All @@ -70,10 +62,12 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.http.ResponseEntity;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;
Expand All @@ -91,7 +85,7 @@
* @author Alexandros Pappas
* @since 1.0.0
*/
public class AnthropicChatModel extends AbstractToolCallSupport implements ChatModel {
public class AnthropicChatModel implements ChatModel {

public static final String DEFAULT_MODEL_NAME = AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getValue();

Expand Down Expand Up @@ -132,111 +126,9 @@ public class AnthropicChatModel extends AbstractToolCallSupport implements ChatM
*/
private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION;

/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi) {
this(anthropicApi,
AnthropicChatOptions.builder()
.model(DEFAULT_MODEL_NAME)
.maxTokens(DEFAULT_MAX_TOKENS)
.temperature(DEFAULT_TEMPERATURE)
.build());
}

/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions) {
this(anthropicApi, defaultOptions, RetryUtils.DEFAULT_RETRY_TEMPLATE);
}

/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate) {
this(anthropicApi, defaultOptions, retryTemplate, null);
}

/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackResolver the function callback resolver used to resolve the
* function by its name.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver) {
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, List.of());
}

/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackResolver the function callback resolver used to resolve the
* function by its name.
* @param toolFunctionCallbacks the tool function callbacks used to handle the tool
* calls.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, FunctionCallbackResolver functionCallbackResolver,
List<FunctionCallback> toolFunctionCallbacks) {
this(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver, toolFunctionCallbacks,
ObservationRegistry.NOOP);
}

/**
* Construct a new {@link AnthropicChatModel} instance.
* @param anthropicApi the lower-level API for the Anthropic service.
* @param defaultOptions the default options used for the chat completion requests.
* @param retryTemplate the retry template used to retry the Anthropic API calls.
* @param functionCallbackResolver the function callback resolver used to resolve the
* function by its name.
* @param toolFunctionCallbacks the tool function callbacks used to handle the tool
* calls.
* @deprecated Use {@link AnthropicChatModel.Builder}.
*/
@Deprecated
public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
RetryTemplate retryTemplate, @Nullable FunctionCallbackResolver functionCallbackResolver,
@Nullable List<FunctionCallback> toolFunctionCallbacks, ObservationRegistry observationRegistry) {
this(anthropicApi, defaultOptions,
LegacyToolCallingManager.builder()
.functionCallbackResolver(functionCallbackResolver)
.functionCallbacks(toolFunctionCallbacks)
.build(),
retryTemplate, observationRegistry);
logger.warn("This constructor is deprecated and will be removed in the next milestone. "
+ "Please use the MistralAiChatModel.Builder or the new constructor accepting ToolCallingManager instead.");
}

public AnthropicChatModel(AnthropicApi anthropicApi, AnthropicChatOptions defaultOptions,
ToolCallingManager toolCallingManager, RetryTemplate retryTemplate,
ObservationRegistry observationRegistry) {
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
// because it modifies them. We are using ToolCallingManager instead,
// so we just pass empty options here.
super(null, AnthropicChatOptions.builder().build(), List.of());

Assert.notNull(anthropicApi, "anthropicApi cannot be null");
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
Expand Down Expand Up @@ -470,10 +362,6 @@ Prompt buildRequestPrompt(Prompt prompt) {
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
AnthropicChatOptions.class);
}
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
AnthropicChatOptions.class);
}
else {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
AnthropicChatOptions.class);
Expand Down Expand Up @@ -621,10 +509,6 @@ public static final class Builder {

private RetryTemplate retryTemplate = RetryUtils.DEFAULT_RETRY_TEMPLATE;

private FunctionCallbackResolver functionCallbackResolver;

private List<FunctionCallback> toolCallbacks;

private ToolCallingManager toolCallingManager;

private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;
Expand Down Expand Up @@ -652,41 +536,16 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
return this;
}

@Deprecated
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}

@Deprecated
public Builder toolCallbacks(List<FunctionCallback> toolCallbacks) {
this.toolCallbacks = toolCallbacks;
return this;
}

public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
return this;
}

public AnthropicChatModel build() {
if (toolCallingManager != null) {
Assert.isNull(functionCallbackResolver,
"functionCallbackResolver cannot be set when toolCallingManager is set");
Assert.isNull(toolCallbacks, "toolCallbacks cannot be set when toolCallingManager is set");

return new AnthropicChatModel(anthropicApi, defaultOptions, toolCallingManager, retryTemplate,
observationRegistry);
}
if (functionCallbackResolver != null) {
Assert.isNull(toolCallingManager,
"toolCallingManager cannot be set when functionCallbackResolver is set");
List<FunctionCallback> toolCallbacks = this.toolCallbacks != null ? this.toolCallbacks : List.of();

return new AnthropicChatModel(anthropicApi, defaultOptions, retryTemplate, functionCallbackResolver,
toolCallbacks, observationRegistry);
}

return new AnthropicChatModel(anthropicApi, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER, retryTemplate,
observationRegistry);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
import org.springframework.ai.chat.observation.DefaultChatModelObservationConvention;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.function.DefaultFunctionCallbackResolver;
import org.springframework.ai.model.tool.DefaultToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.observation.conventions.AiOperationType;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.beans.factory.annotation.Autowired;
Expand Down Expand Up @@ -171,8 +173,7 @@ public AnthropicApi anthropicApi() {
public AnthropicChatModel anthropicChatModel(AnthropicApi anthropicApi,
TestObservationRegistry observationRegistry) {
return new AnthropicChatModel(anthropicApi, AnthropicChatOptions.builder().build(),
RetryTemplate.defaultInstance(), new DefaultFunctionCallbackResolver(), List.of(),
observationRegistry);
ToolCallingManager.builder().build(), RetryTemplate.defaultInstance(), observationRegistry);
}

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ public class ChatCompletionRequestTests {
@Test
public void createRequestWithChatOptions() {

var client = new AnthropicChatModel(new AnthropicApi("TEST"),
AnthropicChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build());
var client = AnthropicChatModel.builder()
.anthropicApi(new AnthropicApi("TEST"))
.defaultOptions(AnthropicChatOptions.builder().model("DEFAULT_MODEL").temperature(66.6).build())
.build();

var prompt = client.buildRequestPrompt(new Prompt("Test message content"));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@
import org.springframework.ai.chat.metadata.PromptMetadata.PromptFilterMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
Expand All @@ -85,16 +84,12 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.Media;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackResolver;
import org.springframework.ai.model.function.FunctionCallingOptions;
import org.springframework.ai.model.tool.LegacyToolCallingManager;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.model.tool.ToolCallingManager;
import org.springframework.ai.model.tool.ToolExecutionResult;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

Expand All @@ -119,7 +114,7 @@
* @see com.azure.ai.openai.OpenAIClient
* @since 1.0.0
*/
public class AzureOpenAiChatModel extends AbstractToolCallSupport implements ChatModel {
public class AzureOpenAiChatModel implements ChatModel {

private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatModel.class);

Expand Down Expand Up @@ -163,10 +158,6 @@ public class AzureOpenAiChatModel extends AbstractToolCallSupport implements Cha

public AzureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatOptions defaultOptions,
ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry) {
// We do not pass the 'defaultOptions' to the AbstractToolSupport,
// because it modifies them. We are using ToolCallingManager instead,
// so we just pass empty options here.
super(null, AzureOpenAiChatOptions.builder().build(), List.of());
Assert.notNull(openAIClientBuilder, "com.azure.ai.openai.OpenAIClient must not be null");
Assert.notNull(defaultOptions, "defaultOptions cannot be null");
Assert.notNull(toolCallingManager, "toolCallingManager cannot be null");
Expand Down Expand Up @@ -488,10 +479,6 @@ ChatCompletionsOptions toAzureChatCompletionsOptions(Prompt prompt) {

if (prompt.getOptions() != null) {
AzureOpenAiChatOptions updatedRuntimeOptions;
if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions,
FunctionCallingOptions.class, AzureOpenAiChatOptions.class);
}
if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
updatedRuntimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions,
ToolCallingChatOptions.class, AzureOpenAiChatOptions.class);
Expand Down Expand Up @@ -622,10 +609,6 @@ Prompt buildRequestPrompt(Prompt prompt) {
runtimeOptions = ModelOptionsUtils.copyToTarget(toolCallingChatOptions, ToolCallingChatOptions.class,
AzureOpenAiChatOptions.class);
}
else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOptions) {
runtimeOptions = ModelOptionsUtils.copyToTarget(functionCallingOptions, FunctionCallingOptions.class,
AzureOpenAiChatOptions.class);
}
else {
runtimeOptions = ModelOptionsUtils.copyToTarget(prompt.getOptions(), ChatOptions.class,
AzureOpenAiChatOptions.class);
Expand Down Expand Up @@ -932,10 +915,6 @@ public static class Builder {

private ToolCallingManager toolCallingManager;

private FunctionCallbackResolver functionCallbackResolver;

private List<FunctionCallback> toolFunctionCallbacks;

private ObservationRegistry observationRegistry = ObservationRegistry.NOOP;

private Builder() {
Expand All @@ -956,48 +935,16 @@ public Builder toolCallingManager(ToolCallingManager toolCallingManager) {
return this;
}

@Deprecated
public Builder functionCallbackResolver(FunctionCallbackResolver functionCallbackResolver) {
this.functionCallbackResolver = functionCallbackResolver;
return this;
}

@Deprecated
public Builder toolFunctionCallbacks(List<FunctionCallback> toolFunctionCallbacks) {
this.toolFunctionCallbacks = toolFunctionCallbacks;
return this;
}

public Builder observationRegistry(ObservationRegistry observationRegistry) {
this.observationRegistry = observationRegistry;
return this;
}

public AzureOpenAiChatModel build() {
if (toolCallingManager != null) {
Assert.isNull(functionCallbackResolver,
"functionCallbackResolver cannot be set when toolCallingManager is set");
Assert.isNull(toolFunctionCallbacks,
"toolFunctionCallbacks cannot be set when toolCallingManager is set");

return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, toolCallingManager,
observationRegistry);
}

if (functionCallbackResolver != null) {
Assert.isNull(toolCallingManager,
"toolCallingManager cannot be set when functionCallbackResolver is set");
List<FunctionCallback> toolCallbacks = this.toolFunctionCallbacks != null ? this.toolFunctionCallbacks
: List.of();

return new Builder().openAIClientBuilder(openAIClientBuilder)
.defaultOptions(defaultOptions)
.functionCallbackResolver(functionCallbackResolver)
.toolFunctionCallbacks(toolCallbacks)
.observationRegistry(observationRegistry)
.build();
}

return new AzureOpenAiChatModel(openAIClientBuilder, defaultOptions, DEFAULT_TOOL_CALLING_MANAGER,
observationRegistry);
}
Expand Down
Loading

0 comments on commit 8df9049

Please sign in to comment.