Skip to content

Commit

Permalink
Fix Moonshot Chat model toolcalling token usage
Browse files Browse the repository at this point in the history
 - Accumulate the token usage when toolcalling is invoked
   - Fix both call() and stream() methods
     - Add `usage` field to the Chat completion choice as the usage is returned via Choice
 - Add Moonshot chat model ITs for function-calling tests

Move the tests into MoonshotChatModelFunctionCallingIT
  • Loading branch information
ilayaperumalg committed Jan 21, 2025
1 parent 1c41c6a commit f5761de
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.EmptyUsage;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.ai.chat.metadata.UsageUtils;
import org.springframework.ai.chat.model.AbstractToolCallSupport;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.chat.model.ChatResponse;
Expand Down Expand Up @@ -75,6 +77,7 @@
*
* @author Geng Rong
* @author Alexandros Pappas
* @author Ilayaperumal Gopinathan
*/
public class MoonshotChatModel extends AbstractToolCallSupport implements ChatModel, StreamingChatModel {

Expand Down Expand Up @@ -180,6 +183,10 @@ private static Generation buildGeneration(Choice choice, Map<String, Object> met

/**
 * Runs a blocking chat completion for the given prompt.
 * <p>
 * This is the top-level entry point: it hands the prompt to
 * {@link #internalCall(Prompt, ChatResponse)} with no previous response.
 * @param prompt the prompt to send to the model
 * @return the chat response produced by {@code internalCall}
 */
@Override
public ChatResponse call(Prompt prompt) {
    // A top-level call has no earlier response to carry forward.
    ChatResponse noPreviousResponse = null;
    return this.internalCall(prompt, noPreviousResponse);
}

public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatResponse) {
ChatCompletionRequest request = createRequest(prompt, false);

ChatModelObservationContext observationContext = ChatModelObservationContext.builder()
Expand Down Expand Up @@ -218,8 +225,11 @@ public ChatResponse call(Prompt prompt) {
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();

ChatResponse chatResponse = new ChatResponse(generations, from(completionEntity.getBody()));
MoonshotApi.Usage usage = completionEntity.getBody().usage();
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
from(completionEntity.getBody(), cumulativeUsage));

observationContext.setResponse(chatResponse);

Expand All @@ -232,7 +242,7 @@ && isToolCall(response, Set.of(MoonshotApi.ChatCompletionFinishReason.TOOL_CALLS
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the call method with the tool call message
// conversation that contains the call responses.
return this.call(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalCall(new Prompt(toolCallConversation, prompt.getOptions()), response);
}
return response;
}
Expand All @@ -244,6 +254,10 @@ public ChatOptions getDefaultOptions() {

/**
 * Runs a streaming chat completion for the given prompt.
 * <p>
 * This is the top-level entry point: it hands the prompt to
 * {@link #internalStream(Prompt, ChatResponse)} with no previous response.
 * @param prompt the prompt to send to the model
 * @return the flux of chat responses produced by {@code internalStream}
 */
@Override
public Flux<ChatResponse> stream(Prompt prompt) {
    // A top-level stream has no earlier response to carry forward.
    ChatResponse noPreviousResponse = null;
    return this.internalStream(prompt, noPreviousResponse);
}

public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousChatResponse) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(prompt, true);

Expand Down Expand Up @@ -287,8 +301,11 @@ public Flux<ChatResponse> stream(Prompt prompt) {
// @formatter:on
return buildGeneration(choice, metadata);
}).toList();
MoonshotApi.Usage usage = chatCompletion2.usage();
Usage currentUsage = (usage != null) ? MoonshotUsage.from(usage) : new EmptyUsage();
Usage cumulativeUsage = UsageUtils.getCumulativeUsage(currentUsage, previousChatResponse);

return new ChatResponse(generations, from(chatCompletion2));
return new ChatResponse(generations, from(chatCompletion2, cumulativeUsage));
}
catch (Exception e) {
logger.error("Error processing chat completion", e);
Expand All @@ -303,7 +320,7 @@ public Flux<ChatResponse> stream(Prompt prompt) {
var toolCallConversation = handleToolCalls(prompt, response);
// Recursively call the stream method with the tool call message
// conversation that contains the call responses.
return this.stream(new Prompt(toolCallConversation, prompt.getOptions()));
return this.internalStream(new Prompt(toolCallConversation, prompt.getOptions()), response);
}
return Flux.just(response);
})
Expand All @@ -325,6 +342,16 @@ private ChatResponseMetadata from(ChatCompletion result) {
.build();
}

/**
 * Builds the response metadata for a completion, using the supplied usage
 * instead of reading it from the completion itself.
 * <p>
 * A missing id or model falls back to the empty string, and a missing
 * creation timestamp falls back to {@code 0L}.
 * @param result the completion to describe; must not be null
 * @param usage the usage to store in the metadata
 * @return the assembled metadata
 */
private ChatResponseMetadata from(ChatCompletion result, Usage usage) {
    Assert.notNull(result, "Moonshot ChatCompletionResult must not be null");

    String responseId = (result.id() != null) ? result.id() : "";
    String modelName = (result.model() != null) ? result.model() : "";
    long createdAt = (result.created() != null) ? result.created() : 0L;

    return ChatResponseMetadata.builder()
        .id(responseId)
        .usage(usage)
        .model(modelName)
        .keyValue("created", createdAt)
        .build();
}

/**
* Convert the ChatCompletionChunk into a ChatCompletion. The Usage is taken from the last choice of the chunk.
* @param chunk the ChatCompletionChunk to convert
Expand All @@ -336,10 +363,11 @@ private ChatCompletion chunkToChatCompletion(ChatCompletionChunk chunk) {
if (delta == null) {
delta = new ChatCompletionMessage("", ChatCompletionMessage.Role.ASSISTANT);
}
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason());
return new ChatCompletion.Choice(cc.index(), delta, cc.finishReason(), cc.usage());
}).toList();

return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, null);
// Get the usage from the latest choice
MoonshotApi.Usage usage = choices.get(choices.size() - 1).usage();
return new ChatCompletion(chunk.id(), "chat.completion", chunk.created(), chunk.model(), choices, usage);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,7 +532,8 @@ public record Choice(
// @formatter:off
@JsonProperty("index") Integer index,
@JsonProperty("message") ChatCompletionMessage message,
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason) {
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
@JsonProperty("usage") Usage usage) {
// @formatter:on
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) {
: previous.finishReason());
Integer index = (current.index() != null ? current.index() : previous.index());

MoonshotApi.Usage usage = current.usage() != null ? current.usage() : previous.usage();

ChatCompletionMessage message = merge(previous.delta(), current.delta());
return new ChunkChoice(index, message, finishReason, null);
return new ChunkChoice(index, message, finishReason, usage);
}

private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void beforeEach() {
public void moonshotChatTransientError() {

var choice = new ChatCompletion.Choice(0, new ChatCompletionMessage("Response", Role.ASSISTANT),
ChatCompletionFinishReason.STOP);
ChatCompletionFinishReason.STOP, null);
ChatCompletion expectedChatCompletion = new ChatCompletion("id", "chat.completion", 789L, "model",
List.of(choice), new MoonshotApi.Usage(10, 10, 10));

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -17,6 +17,7 @@
package org.springframework.ai.moonshot.chat;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -53,6 +54,33 @@ class MoonshotChatModelFunctionCallingIT {
@Autowired
ChatModel chatModel;

/**
 * Tool definition shared by the usage tests below and passed to the chat options via
 * {@code tools(...)}.
 * <p>
 * The JSON schema declares four required properties: {@code location} (string),
 * {@code lat} and {@code lon} (numbers) and {@code unit} (one of {@code "C"} or
 * {@code "F"}).
 * <p>
 * NOTE(review): the first string looks like the function description and the second
 * like the function name ({@code getCurrentWeather}) - confirm against the
 * {@code MoonshotApi.FunctionTool.Function} constructor.
 */
private static final MoonshotApi.FunctionTool FUNCTION_TOOL = new MoonshotApi.FunctionTool(
MoonshotApi.FunctionTool.Type.FUNCTION, new MoonshotApi.FunctionTool.Function(
"Get the weather in location. Return temperature in 30°F or 30°C format.", "getCurrentWeather", """
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"lat": {
"type": "number",
"description": "The city latitude"
},
"lon": {
"type": "number",
"description": "The city longitude"
},
"unit": {
"type": "string",
"enum": ["C", "F"]
}
},
"required": ["location", "lat", "lon", "unit"]
}
"""));

@Test
void functionCallTest() {

Expand Down Expand Up @@ -89,6 +117,7 @@ void streamFunctionCallTest() {
.functionCallbacks(List.of(FunctionCallback.builder()
.function("getCurrentWeather", new MockWeatherService())
.description("Get the weather in location")
.inputType(MockWeatherService.Request.class)
.build()))
.build();

Expand All @@ -108,4 +137,47 @@ void streamFunctionCallTest() {
assertThat(content).contains("30", "10", "15");
}

/**
 * Verifies that a blocking call which uses the {@code getCurrentWeather} tool returns
 * the final answer together with token usage in the response metadata.
 * <p>
 * The total-token window (greater than 280, less than 450) is an empirical range.
 * NOTE(review): presumably it only holds when usage from the tool-call round trip is
 * accumulated - confirm against the model under test.
 */
@Test
public void toolFunctionCallWithUsage() {
    var promptOptions = MoonshotChatOptions.builder()
        .model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
        .tools(Arrays.asList(FUNCTION_TOOL))
        .functionCallbacks(List.of(FunctionCallback.builder()
            .function("getCurrentWeather", new MockWeatherService())
            .description("Get the weather in location. Return temperature in 36°F or 36°C format.")
            .inputType(MockWeatherService.Request.class)
            .build()))
        .build();
    Prompt prompt = new Prompt("What's the weather like in San Francisco? Return the temperature in Celsius.",
            promptOptions);

    ChatResponse chatResponse = this.chatModel.call(prompt);
    assertThat(chatResponse).isNotNull();
    // An assertThat(...) without a terminal assertion verifies nothing, so the
    // check is spelled out explicitly.
    assertThat(chatResponse.getResult().getOutput()).isNotNull();
    assertThat(chatResponse.getResult().getOutput().getText()).contains("San Francisco");
    assertThat(chatResponse.getResult().getOutput().getText()).contains("30.0");
    // Fail with a clear assertion message instead of a NullPointerException when
    // metadata or usage is missing (mirrors the streaming usage test).
    assertThat(chatResponse.getMetadata()).isNotNull();
    assertThat(chatResponse.getMetadata().getUsage()).isNotNull();
    assertThat(chatResponse.getMetadata().getUsage().getTotalTokens()).isLessThan(450).isGreaterThan(280);
}

/**
 * Verifies that the last response of a streamed call using the
 * {@code getCurrentWeather} tool carries token usage in its metadata, and that the
 * total token count is greater than 280 and less than 450.
 */
@Test
public void testStreamFunctionCallUsage() {
    var weatherCallback = FunctionCallback.builder()
        .function("getCurrentWeather", new MockWeatherService())
        .description("Get the weather in location. Return temperature in 36°F or 36°C format.")
        .inputType(MockWeatherService.Request.class)
        .build();
    var options = MoonshotChatOptions.builder()
        .model(MoonshotApi.ChatModel.MOONSHOT_V1_8K.getValue())
        .tools(Arrays.asList(FUNCTION_TOOL))
        .functionCallbacks(List.of(weatherCallback))
        .build();
    Prompt weatherPrompt = new Prompt(
            "What's the weather like in San Francisco? Return the temperature in Celsius.", options);

    // Only the final streamed element is inspected.
    ChatResponse lastResponse = this.chatModel.stream(weatherPrompt).blockLast();
    assertThat(lastResponse).isNotNull();

    var metadata = lastResponse.getMetadata();
    assertThat(metadata).isNotNull();

    var usage = metadata.getUsage();
    assertThat(usage).isNotNull();
    assertThat(usage.getTotalTokens()).isLessThan(450).isGreaterThan(280);
}

}

0 comments on commit f5761de

Please sign in to comment.