diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
index fdb7f2020c4..f23038772d9 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java
@@ -258,7 +258,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
 					"role", choice.message().role() != null ? choice.message().role().name() : "",
 					"index", choice.index(),
 					"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
-					"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
+					"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
+					"reasoningContent", StringUtils.hasText(choice.message().reasoningContent()) ? choice.message().reasoningContent() : "");
 			// @formatter:on
 			return buildGeneration(choice, metadata, request);
 		}).toList();
@@ -346,7 +347,8 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha
 						"role", roleMap.getOrDefault(id, ""),
 						"index", choice.index(),
 						"finishReason", choice.finishReason() != null ? choice.finishReason().name() : "",
-						"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "");
+						"refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "",
+						"reasoningContent", StringUtils.hasText(choice.message().reasoningContent()) ? choice.message().reasoningContent() : "");
 					return buildGeneration(choice, metadata, request);
 				}).toList();
@@ -543,7 +545,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) {
 				}
 				return List.of(new ChatCompletionMessage(assistantMessage.getText(),
-						ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput));
+						ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null));
 			}
 			else if (message.getMessageType() == MessageType.TOOL) {
 				ToolResponseMessage toolMessage = (ToolResponseMessage) message;
@@ -553,7 +555,7 @@ else if (message.getMessageType() == MessageType.TOOL) {
 			return toolMessage.getResponses()
 				.stream()
 				.map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(),
-						tr.id(), null, null, null))
+						tr.id(), null, null, null, null))
 				.toList();
 		}
 		else {
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
index 7827f034d55..702fca27193 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
@@ -1138,8 +1138,11 @@ public record StreamOptions(
 	 * {@link Role#ASSISTANT} role and null otherwise.
 	 * @param audioOutput Audio response from the model. >>>>>>> bdb66e577 (OpenAI -
 	 * Support audio input modality)
+	 * @param reasoningContent For deepseek-reasoner model only. The reasoning contents of
+	 * the assistant message, before the final answer.
 	 */
 	@JsonInclude(Include.NON_NULL)
+	@JsonIgnoreProperties(ignoreUnknown = true)
 	public record ChatCompletionMessage(// @formatter:off
 		@JsonProperty("content") Object rawContent,
 		@JsonProperty("role") Role role,
@@ -1148,7 +1151,8 @@ public record ChatCompletionMessage(// @formatter:off
 		@JsonProperty("tool_calls")
 		@JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List<ToolCall> toolCalls,
 		@JsonProperty("refusal") String refusal,
-		@JsonProperty("audio") AudioOutput audioOutput) { // @formatter:on
+		@JsonProperty("audio") AudioOutput audioOutput,
+		@JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on
 
 		/**
 		 * Create a chat completion message with the given content and role. All other
@@ -1157,7 +1161,7 @@ public record ChatCompletionMessage(// @formatter:off
 		 * @param role The role of the author of this message.
 		 */
 		public ChatCompletionMessage(Object content, Role role) {
-			this(content, role, null, null, null, null, null);
+			this(content, role, null, null, null, null, null, null);
 		}
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java
index a65f9fbcc0d..c40c2f3fed5 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java
@@ -39,6 +39,7 @@
  *
  * @author Christian Tzolov
  * @author Thomas Vitale
+ * @author Alexandros Pappas
  * @since 0.8.1
  */
 public class OpenAiStreamFunctionCallingHelper {
@@ -97,6 +98,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
 		String refusal = (current.refusal() != null ? current.refusal() : previous.refusal());
 		ChatCompletionMessage.AudioOutput audioOutput = (current.audioOutput() != null ? current.audioOutput()
 				: previous.audioOutput());
+		String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent()
+				: previous.reasoningContent());
 
 		List<ToolCall> toolCalls = new ArrayList<>();
 		ToolCall lastPreviousTooCall = null;
@@ -126,7 +129,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti
 				toolCalls.add(lastPreviousTooCall);
 			}
 		}
-		return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput);
+		return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput,
+				reasoningContent);
 	}
 
 	private ToolCall merge(ToolCall previous, ToolCall current) {
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java
index 33f50b1f4d7..dcdf73e0d68 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java
@@ -44,6 +44,7 @@
 *
 * @author Christian Tzolov
 * @author Thomas Vitale
+ * @author Alexandros Pappas
 */
 @EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
 public class OpenAiApiToolFunctionCallIT {
@@ -129,7 +130,7 @@ public void toolFunctionCall() {
 					// extend conversation with function response.
 					messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL,
-							functionName, toolCall.id(), null, null, null));
+							functionName, toolCall.id(), null, null, null, null));
 				}
 			}
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java
index 461b8cdafda..4c4abca1501 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/proxy/DeepSeekWithOpenAiChatModelIT.java
@@ -56,6 +56,8 @@
 import org.springframework.context.annotation.Bean;
 import org.springframework.core.convert.support.DefaultConversionService;
 import org.springframework.core.io.Resource;
+import org.springframework.http.ResponseEntity;
+import org.springframework.util.StringUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
@@ -70,7 +72,6 @@
 */
 @SpringBootTest(classes = DeepSeekWithOpenAiChatModelIT.Config.class)
 @EnabledIfEnvironmentVariable(named = "DEEPSEEK_API_KEY", matches = ".+")
-@Disabled("Requires DeepSeek credits")
 class DeepSeekWithOpenAiChatModelIT {
 
 	private static final Logger logger = LoggerFactory.getLogger(DeepSeekWithOpenAiChatModelIT.class);
@@ -82,6 +83,9 @@ class DeepSeekWithOpenAiChatModelIT {
 	@Value("classpath:/prompts/system-message.st")
 	private Resource systemResource;
 
+	@Autowired
+	private OpenAiApi openAiApi;
+
 	@Autowired
 	private OpenAiChatModel chatModel;
 
@@ -128,9 +132,9 @@ void streamingWithTokenUsage() {
 		var streamingTokenUsage = this.chatModel.stream(prompt).blockLast().getMetadata().getUsage();
 		var referenceTokenUsage = this.chatModel.call(prompt).getMetadata().getUsage();
 
-		assertThat(streamingTokenUsage.getPromptTokens()).isGreaterThan(0);
-		assertThat(streamingTokenUsage.getCompletionTokens()).isGreaterThan(0);
-		assertThat(streamingTokenUsage.getTotalTokens()).isGreaterThan(0);
+		assertThat(streamingTokenUsage.getPromptTokens()).isPositive();
+		assertThat(streamingTokenUsage.getCompletionTokens()).isPositive();
+		assertThat(streamingTokenUsage.getTotalTokens()).isPositive();
 
 		assertThat(streamingTokenUsage.getPromptTokens()).isEqualTo(referenceTokenUsage.getPromptTokens());
 		assertThat(streamingTokenUsage.getCompletionTokens()).isEqualTo(referenceTokenUsage.getCompletionTokens());
@@ -325,6 +329,56 @@ record ActorsFilmsRecord(String actor, List movies) {
 
 	}
 
+	@Test
+	void chatCompletionEntityWithReasoning() {
+		OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage(
+				"Explain the theory of relativity", OpenAiApi.ChatCompletionMessage.Role.USER);
+		OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(List.of(chatCompletionMessage),
+				"deepseek-reasoner", 0.8, false);
+		ResponseEntity<OpenAiApi.ChatCompletion> response = this.openAiApi.chatCompletionEntity(request);
+
+		assertThat(response).isNotNull();
+		assertThat(response.getBody().choices().get(0).message().reasoningContent()).isNotBlank();
+	}
+
+	@Test
+	void chatCompletionStreamWithReasoning() {
+		OpenAiApi.ChatCompletionMessage chatCompletionMessage = new OpenAiApi.ChatCompletionMessage(
+				"Explain the theory of relativity", OpenAiApi.ChatCompletionMessage.Role.USER);
+		OpenAiApi.ChatCompletionRequest request = new OpenAiApi.ChatCompletionRequest(List.of(chatCompletionMessage),
+				"deepseek-reasoner", 0.8, true);
+		Flux<OpenAiApi.ChatCompletionChunk> response = this.openAiApi.chatCompletionStream(request);
+
+		assertThat(response).isNotNull();
+		List<OpenAiApi.ChatCompletionChunk> chunks = response.collectList().block();
+		assertThat(chunks).isNotNull();
+		assertThat(chunks.stream().anyMatch(chunk -> !chunk.choices().get(0).delta().reasoningContent().isBlank()))
+			.isTrue();
+	}
+
+	@Test
+	void chatModelCallWithReasoning() {
+		OpenAiChatModel deepReasoner = new OpenAiChatModel(this.openAiApi,
+				OpenAiChatOptions.builder().model("deepseek-reasoner").build());
+		ChatResponse chatResponse = deepReasoner.call(new Prompt("Explain the theory of relativity"));
+		assertThat(chatResponse.getResults()).isNotEmpty();
+		assertThat(chatResponse.getResults().get(0).getOutput().getMetadata().get("reasoningContent").toString())
+			.isNotBlank();
+	}
+
+	@Test
+	void chatModelStreamWithReasoning() {
+		OpenAiChatModel deepReasoner = new OpenAiChatModel(this.openAiApi,
+				OpenAiChatOptions.builder().model("deepseek-reasoner").build());
+		Flux<ChatResponse> flux = deepReasoner.stream(new Prompt("Explain the theory of relativity"));
+		List<ChatResponse> responses = flux.collectList().block();
+		assertThat(responses).isNotEmpty();
+		assertThat(responses.stream()
+			.flatMap(response -> response.getResults().stream())
+			.map(result -> result.getOutput().getMetadata().get("reasoningContent").toString())
+			.anyMatch(StringUtils::hasText)).isTrue();
+	}
+
 	@SpringBootConfiguration
 	static class Config {
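
With these changes, DeepSeek's reasoning trace reaches callers through the assistant message metadata under the "reasoningContent" key, in both the blocking and streaming paths. A minimal caller-side sketch of reading it, assuming an OpenAiApi instance already wired against the DeepSeek endpoint as in the test Config above (the variable names here are illustrative, not part of this diff):

    // Mirrors the chatModelCallWithReasoning test: build a model targeting deepseek-reasoner.
    OpenAiChatModel reasoner = new OpenAiChatModel(openAiApi,
            OpenAiChatOptions.builder().model("deepseek-reasoner").build());

    ChatResponse response = reasoner.call(new Prompt("Explain the theory of relativity"));

    // "reasoningContent" is mapped to an empty string when the model returns no
    // reasoning_content, so treat it as optional before using it.
    String reasoning = String.valueOf(
            response.getResult().getOutput().getMetadata().getOrDefault("reasoningContent", ""));
    String answer = response.getResult().getOutput().getText();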