Add Reasoning Content support to OpenAiChatModel and related classes (For deepseek-reasoner) #2192

Open · wants to merge 1 commit into base: main
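
For context, a minimal usage sketch of what this change enables at the ChatModel level. It is illustrative only and not part of the diff: the class name ReasoningContentSketch and the reasoningAndAnswer method are hypothetical, and the openAiApi argument is assumed to be an OpenAiApi instance already configured against a DeepSeek-compatible endpoint. The access pattern simply mirrors the new integration tests further down.

import java.util.Objects;

import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.api.OpenAiApi;

// Hypothetical illustration; not part of this pull request.
class ReasoningContentSketch {

    // openAiApi is assumed to be configured for a DeepSeek-compatible endpoint,
    // e.g. wired the same way the integration test configuration below does.
    String reasoningAndAnswer(OpenAiApi openAiApi) {
        OpenAiChatModel deepReasoner = new OpenAiChatModel(openAiApi,
                OpenAiChatOptions.builder().model("deepseek-reasoner").build());

        ChatResponse response = deepReasoner.call(new Prompt("Explain the theory of relativity"));

        // With this change, the reasoning trace is exposed as generation metadata
        // under the "reasoningContent" key, alongside the regular assistant text.
        Object reasoning = response.getResults().get(0).getOutput().getMetadata().get("reasoningContent");
        String answer = response.getResults().get(0).getOutput().getText();

        return Objects.toString(reasoning, "") + "\n---\n" + answer;
    }
}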
@@ -258,7 +258,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
"role", choice.message().role() != null ? choice.message().role().name() : "",
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
"reasoningContent", StringUtils.hasText(choice.message().reasoningContent()) ? choice.message().reasoningContent() : "");
// @formatter:on
return buildGeneration(choice, metadata, request);
}).toList();
@@ -346,7 +347,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
"role", roleMap.getOrDefault(id, ""),
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
"reasoningContent", StringUtils.hasText(choice.message().reasoningContent()) ? choice.message().reasoningContent() : "");

return buildGeneration(choice, metadata, request);
}).toList();
@@ -543,7 +545,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {

}
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput));
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
@@ -553,7 +555,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
return toolMessage.getResponses()
.stream()
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
tr.id(), null, null, null))
tr.id(), null, null, null, null))
.toList();
}
else {
@@ -1138,8 +1138,11 @@ public record StreamOptions(
* {@link Role#ASSISTANT} role and null otherwise.
* @param audioOutput Audio response from the model. >>>>>>> bdb66e577 (OpenAI -
* Support audio input modality)
* @param reasoningContent For deepseek-reasoner model only. The reasoning contents of
* the assistant message, before the final answer.
*/
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record ChatCompletionMessage(// @formatter:off
@JsonProperty("content") Object rawContent,
@JsonProperty("role") Role role,
@@ -1148,7 +1151,8 @@ public record ChatCompletionMessage(// @formatter:off
@JsonProperty("tool_calls")
@JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
@JsonProperty("refusal") String refusal,
@JsonProperty("audio") AudioOutput audioOutput) { // @formatter:on
@JsonProperty("audio") AudioOutput audioOutput,
@JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on

/**
* Create a chat completion message with the given content and role. All other
@@ -1157,7 +1161,7 @@
* @param role The role of the author of this message.
*/
public ChatCompletionMessage(Object content, Role role) {
this(content, role, null, null, null, null, null);
this(content, role, null, null, null, null, null, null);

}
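
As a companion illustration (again not part of the diff), the same field can be read through the low-level client, mirroring the chatCompletionEntityWithReasoning test added in the integration test class below. The openAiApi variable is assumed to be an OpenAiApi instance configured for a DeepSeek-compatible endpoint, and java.util.List plus org.springframework.http.ResponseEntity are assumed to be imported.

// Hypothetical snippet exercising the extended ChatCompletionMessage record.
OpenAiApi.ChatCompletionMessage userMessage = new OpenAiApi.ChatCompletionMessage(
        "Explain the theory of relativity", OpenAiApi.ChatCompletionMessage.Role.USER);
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(
        List.of(userMessage), "deepseek-reasoner", 0.8, false);

ResponseEntity<OpenAiApi.ChatCompletion> entity = openAiApi.chatCompletionEntity(request);

OpenAiApi.ChatCompletionMessage message = entity.getBody().choices().get(0).message();
String reasoning = message.reasoningContent(); // new accessor mapped from "reasoning_content"
Object finalAnswer = message.rawContent();     // the regular assistant content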

@@ -39,6 +39,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
* @since 0.8.1
*/
public class OpenAiStreamFunctionCallingHelper {
@@ -97,6 +98,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
String refusal = (current.refusal() != null ? current.refusal() : previous.refusal());
ChatCompletionMessage.AudioOutput audioOutput = (current.audioOutput() != null ? current.audioOutput()
: previous.audioOutput());
String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent()
: previous.reasoningContent());

List<ToolCall> toolCalls = new ArrayList<>();
ToolCall lastPreviousTooCall = null;
@@ -126,7 +129,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
toolCalls.add(lastPreviousTooCall);
}
}
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput);
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput,
reasoningContent);
}

private ToolCall merge(ToolCall previous, ToolCall current) {
@@ -44,6 +44,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
*/
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiApiToolFunctionCallIT {
@@ -129,7 +130,7 @@ public void toolFunctionCall() {

// extend conversation with function response.
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
functionName, toolCall.id(), null, null, null));
functionName, toolCall.id(), null, null, null, null));
}
}

@@ -56,6 +56,8 @@
import org.springframework.context.annotation.Bean;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.util.StringUtils;

import static org.assertj.core.api.Assertions.assertThat;

@@ -82,6 +84,9 @@ class DeepSeekWithOpenAiChatModelIT {
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;

@Autowired
private OpenAiApi openAiApi;

@Autowired
private OpenAiChatModel chatModel;

@@ -128,9 +133,9 @@ void streamingWithTokenUsage() {
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();

assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage.getPromptTokens()).isPositive();
assertThat(streamingTokenUsage.getCompletionTokens()).isPositive();
assertThat(streamingTokenUsage.getTotalTokens()).isPositive();

assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens());
@@ -325,6 +330,56 @@ record ActorsFilmsRecord(String actor, List<String> movies) {

}

@Test
void chatCompletionEntityWithReasoning() {
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage(
"Explain the theory of relativity", OpenAiApi.ChatCompletionMessage.Role.USER);
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(List.of(chatCompletionMessage),
"deepseek-reasoner", 0.8, false);
ResponseEntity<OpenAiApi.ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);

assertThat(response).isNotNull();
assertThat(response.getBody().choices().get(0).message().reasoningContent()).isNotBlank();
}

@Test
void chatCompletionStreamWithReasoning() {
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage(
"Explain the theory of relativity", OpenAiApi.ChatCompletionMessage.Role.USER);
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(List.of(chatCompletionMessage),
"deepseek-reasoner", 0.8, true);
Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);

assertThat(response).isNotNull();
List<OpenAiApi.ChatCompletionChunk> chunks = response.collectList().block();
assertThat(chunks).isNotNull();
assertThat(chunks.stream().anyMatch(chunk -> !chunk.choices().get(0).delta().reasoningContent().isBlank()))
.isTrue();
}

@Test
void chatModelCallWithReasoning() {
OpenAiChatModel deepReasoner = new OpenAiChatModel(this.openAiApi,
OpenAiChatOptions.builder().model("deepseek-reasoner").build());
ChatResponse chatResponse = deepReasoner.call(new Prompt("Explain the theory of relativity"));
assertThat(chatResponse.getResults()).isNotEmpty();
assertThat(chatResponse.getResults().get(0).getOutput().getMetadata().get("reasoningContent").toString())
.isNotBlank();
}

@Test
void chatModelStreamWithReasoning() {
OpenAiChatModel deepReasoner = new OpenAiChatModel(this.openAiApi,
OpenAiChatOptions.builder().model("deepseek-reasoner").build());
Flux<ChatResponse> flux = deepReasoner.stream(new Prompt("Explain the theory of relativity"));
List<ChatResponse> responses = flux.collectList().block();
assertThat(responses).isNotEmpty();
assertThat(responses.stream()
.flatMap(response -> response.getResults().stream())
.map(result -> result.getOutput().getMetadata().get("reasoningContent").toString())
.anyMatch(StringUtils::hasText)).isTrue();
}

@SpringBootConfiguration
static class Config {
