Skip to content

Commit

Permalink
Add reasoningContent support to OpenAiChatModel and related classes (…
Browse files Browse the repository at this point in the history
…For deepseek-reasoner https://api-docs.deepseek.com/api/create-chat-completion)

- Added reasoningContent field to metadata in OpenAiChatModel
- Updated ChatCompletionMessage to include reasoningContent
- Modified OpenAiStreamFunctionCallingHelper to handle reasoningContent
- Updated tests to verify reasoningContent functionality

Signed-off-by: Alexandros Pappas <[email protected]>
  • Loading branch information
apappascs committed Feb 7, 2025
1 parent 171b758 commit e3947c2
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -258,7 +258,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
"role", choice.message().role() != null ? choice.message().role().name() : "",
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
"reasoningContent", StringUtils.hasText(choice.message().reasoningContent()) ? choice.message().reasoningContent() : "");
// @formatter:on
return buildGeneration(choice, metadata, request);
}).toList();
Expand Down Expand Up @@ -346,7 +347,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha
"role", roleMap.getOrDefault(id, ""),
"index", choice.index(),
"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
"reasoningContent", StringUtils.hasText(choice.message().reasoningContent()) ? choice.message().reasoningContent() : "");

return buildGeneration(choice, metadata, request);
}).toList();
Expand Down Expand Up @@ -543,7 +545,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {

}
return List.of(new ChatCompletionMessage(assistantMessage.getText(),
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput));
ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null));
}
else if (message.getMessageType() == MessageType.TOOL) {
ToolResponseMessage toolMessage = (ToolResponseMessage) message;
Expand All @@ -553,7 +555,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
return toolMessage.getResponses()
.stream()
.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
tr.id(), null, null, null))
tr.id(), null, null, null, null))
.toList();
}
else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1138,8 +1138,11 @@ public record StreamOptions(
* {@link Role#ASSISTANT} role and null otherwise.
* @param audioOutput Audio response from the model. >>>>>>> bdb66e577 (OpenAI -
* Support audio input modality)
* @param reasoningContent For deepseek-reasoner model only. The reasoning contents of
* the assistant message, before the final answer.
*/
@JsonInclude(Include.NON_NULL)
@JsonIgnoreProperties(ignoreUnknown = true)
public record ChatCompletionMessage(// @formatter:off
@JsonProperty("content") Object rawContent,
@JsonProperty("role") Role role,
Expand All @@ -1148,7 +1151,8 @@ public record ChatCompletionMessage(// @formatter:off
@JsonProperty("tool_calls")
@JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
@JsonProperty("refusal") String refusal,
@JsonProperty("audio") AudioOutput audioOutput) { // @formatter:on
@JsonProperty("audio") AudioOutput audioOutput,
@JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on

/**
* Create a chat completion message with the given content and role. All other
Expand All @@ -1157,7 +1161,7 @@ public record ChatCompletionMessage(// @formatter:off
* @param role The role of the author of this message.
*/
public ChatCompletionMessage(Object content, Role role) {
this(content, role, null, null, null, null, null);
this(content, role, null, null, null, null, null, null);

}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
* @since 0.8.1
*/
public class OpenAiStreamFunctionCallingHelper {
Expand Down Expand Up @@ -97,6 +98,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
String refusal = (current.refusal() != null ? current.refusal() : previous.refusal());
ChatCompletionMessage.AudioOutput audioOutput = (current.audioOutput() != null ? current.audioOutput()
: previous.audioOutput());
String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent()
: previous.reasoningContent());

List<ToolCall> toolCalls = new ArrayList<>();
ToolCall lastPreviousTooCall = null;
Expand Down Expand Up @@ -126,7 +129,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
toolCalls.add(lastPreviousTooCall);
}
}
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput);
return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput,
reasoningContent);
}

private ToolCall merge(ToolCall previous, ToolCall current) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
*
* @author Christian Tzolov
* @author Thomas Vitale
* @author Alexandros Pappas
*/
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
public class OpenAiApiToolFunctionCallIT {
Expand Down Expand Up @@ -129,7 +130,7 @@ public void toolFunctionCall() {

// extend conversation with function response.
messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
functionName, toolCall.id(), null, null, null));
functionName, toolCall.id(), null, null, null, null));
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@
import org.springframework.context.annotation.Bean;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.core.io.Resource;
import org.springframework.http.ResponseEntity;
import org.springframework.util.StringUtils;

import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -70,7 +72,6 @@
*/
@SpringBootTest(classes = DeepSeekWithOpenAiChatModelIT.Config.class)
@EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+")
@Disabled("Requires DeepSeek credits")
class DeepSeekWithOpenAiChatModelIT {

private static final Logger logger = LoggerFactory.getLogger(DeepSeekWithOpenAiChatModelIT.class);
Expand All @@ -82,6 +83,9 @@ class DeepSeekWithOpenAiChatModelIT {
@Value("classpath:/prompts/system-message.st")
private Resource systemResource;

@Autowired
private OpenAiApi openAiApi;

@Autowired
private OpenAiChatModel chatModel;

Expand Down Expand Up @@ -128,9 +132,9 @@ void streamingWithTokenUsage() {
var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();

assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
assertThat(streamingTokenUsage.getPromptTokens()).isPositive();
assertThat(streamingTokenUsage.getCompletionTokens()).isPositive();
assertThat(streamingTokenUsage.getTotalTokens()).isPositive();

assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens());
Expand Down Expand Up @@ -325,6 +329,56 @@ record ActorsFilmsRecord(String actor, List<String> movies) {

}

@Test
void chatCompletionEntityWithReasoning() {
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage(
"Explain the theory of relativity", OpenAiApi.ChatCompletionMessage.Role.USER);
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(List.of(chatCompletionMessage),
"deepseek-reasoner", 0.8, false);
ResponseEntity<OpenAiApi.ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);

assertThat(response).isNotNull();
assertThat(response.getBody().choices().get(0).message().reasoningContent()).isNotBlank();
}

@Test
void chatCompletionStreamWithReasoning() {
OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage(
"Explain the theory of relativity", OpenAiApi.ChatCompletionMessage.Role.USER);
OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(List.of(chatCompletionMessage),
"deepseek-reasoner", 0.8, true);
Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);

assertThat(response).isNotNull();
List<OpenAiApi.ChatCompletionChunk> chunks = response.collectList().block();
assertThat(chunks).isNotNull();
assertThat(chunks.stream().anyMatch(chunk -> !chunk.choices().get(0).delta().reasoningContent().isBlank()))
.isTrue();
}

@Test
void chatModelCallWithReasoning() {
OpenAiChatModel deepReasoner = new OpenAiChatModel(this.openAiApi,
OpenAiChatOptions.builder().model("deepseek-reasoner").build());
ChatResponse chatResponse = deepReasoner.call(new Prompt("Explain the theory of relativity"));
assertThat(chatResponse.getResults()).isNotEmpty();
assertThat(chatResponse.getResults().get(0).getOutput().getMetadata().get("reasoningContent").toString())
.isNotBlank();
}

@Test
void chatModelStreamWithReasoning() {
OpenAiChatModel deepReasoner = new OpenAiChatModel(this.openAiApi,
OpenAiChatOptions.builder().model("deepseek-reasoner").build());
Flux<ChatResponse> flux = deepReasoner.stream(new Prompt("Explain the theory of relativity"));
List<ChatResponse> responses = flux.collectList().block();
assertThat(responses).isNotEmpty();
assertThat(responses.stream()
.flatMap(response -> response.getResults().stream())
.map(result -> result.getOutput().getMetadata().get("reasoningContent").toString())
.anyMatch(StringUtils::hasText)).isTrue();
}

@SpringBootConfiguration
static class Config {

Expand Down

0 comments on commit e3947c2

Please sign in to comment.