
Commit 3a8899b

.Net: Bugfix for AddMessageFromStreaming (#9619)
### Motivation and Context

- Fixes #9458
- Fixes #6153
1 parent ffac88a commit 3a8899b

File tree

4 files changed: +109, -3 lines

dotnet/src/Connectors/Connectors.OpenAI.UnitTests/Extensions/ChatHistoryExtensionsTests.cs (+32)
@@ -1,11 +1,14 @@
 // Copyright (c) Microsoft. All rights reserved.
 
+using System;
+using System.ClientModel.Primitives;
 using System.Collections.Generic;
 using System.Linq;
 using System.Threading.Tasks;
 using Microsoft.SemanticKernel;
 using Microsoft.SemanticKernel.ChatCompletion;
 using Microsoft.SemanticKernel.Connectors.OpenAI;
+using OpenAI.Chat;
 using Xunit;
 
 namespace SemanticKernel.Connectors.OpenAI.UnitTests.Extensions;
@@ -43,4 +46,33 @@ public async Task ItCanAddMessageFromStreamingChatContentsAsync()
         Assert.Equal(AuthorRole.User, chatHistory[0].Role);
         Assert.Equal(metadata["message"], chatHistory[0].Metadata!["message"]);
     }
+
+    [Theory]
+    [InlineData(true)]
+    [InlineData(false)]
+    public async Task ItKeepsOrNotToolCallsCorrectlyForStreamingChatContentsAsync(bool includeToolcalls)
+    {
+        var chatHistoryStreamingContents = new List<OpenAIStreamingChatMessageContent>
+        {
+            new(AuthorRole.User, "Hello ", [ModelReaderWriter.Read<StreamingChatToolCallUpdate>(BinaryData.FromString("{\"index\":0,\"id\":\"call_123\",\"type\":\"function\",\"function\":{\"name\":\"FakePlugin_CreateSpecialPoem\",\"arguments\":\"\"}}"))!]),
+            new(null, "! ", [ModelReaderWriter.Read<StreamingChatToolCallUpdate>(BinaryData.FromString("{\"index\":0,\"function\":{\"arguments\":\"{}\"}}"))!]),
+        }.ToAsyncEnumerable();
+        var chatHistory = new ChatHistory();
+        await foreach (var chatMessageChunk in chatHistory.AddStreamingMessageAsync(chatHistoryStreamingContents, includeToolcalls))
+        {
+        }
+
+        Assert.Single(chatHistory);
+        var lastMessage = chatHistory.Last();
+        Assert.IsType<OpenAIChatMessageContent>(lastMessage);
+        var openAIChatMessageContent = (OpenAIChatMessageContent)lastMessage;
+        if (includeToolcalls)
+        {
+            Assert.NotEmpty(openAIChatMessageContent.ToolCalls);
+        }
+        else
+        {
+            Assert.Empty(openAIChatMessageContent.ToolCalls);
+        }
+    }
 }
dotnet/src/Connectors/Connectors.OpenAI/CompatibilitySuppressions.xml (new file, +18)

@@ -0,0 +1,18 @@
+<?xml version="1.0" encoding="utf-8"?>
+<!-- https://learn.microsoft.com/en-us/dotnet/fundamentals/package-validation/diagnostic-ids -->
+<Suppressions xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xmlns:xsd="http://www.w3.org/2001/XMLSchema">
+  <Suppression>
+    <DiagnosticId>CP0002</DiagnosticId>
+    <Target>M:Microsoft.SemanticKernel.OpenAIChatHistoryExtensions.AddStreamingMessageAsync(Microsoft.SemanticKernel.ChatCompletion.ChatHistory,System.Collections.Generic.IAsyncEnumerable{Microsoft.SemanticKernel.Connectors.OpenAI.OpenAIStreamingChatMessageContent})</Target>
+    <Left>lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Left>
+    <Right>lib/net8.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
+    <IsBaselineSuppression>true</IsBaselineSuppression>
+  </Suppression>
+  <Suppression>
+    <DiagnosticId>CP0002</DiagnosticId>
+    <Target>M:Microsoft.SemanticKernel.OpenAIChatHistoryExtensions.AddStreamingMessageAsync(Microsoft.SemanticKernel.ChatCompletion.ChatHistory,System.Collections.Generic.IAsyncEnumerable{Microsoft.SemanticKernel.Connectors.OpenAI.OpenAIStreamingChatMessageContent})</Target>
+    <Left>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Left>
+    <Right>lib/netstandard2.0/Microsoft.SemanticKernel.Connectors.OpenAI.dll</Right>
+    <IsBaselineSuppression>true</IsBaselineSuppression>
+  </Suppression>
+</Suppressions>

dotnet/src/Connectors/Connectors.OpenAI/Extensions/ChatHistoryExtensions.cs (+17, -3)
@@ -19,9 +19,18 @@ public static class OpenAIChatHistoryExtensions
     /// </summary>
     /// <param name="chatHistory">Target chat history</param>
    /// <param name="streamingMessageContents"><see cref="IAsyncEnumerator{T}"/> list of streaming message contents</param>
+    /// <param name="includeToolCalls">Whether to keep tool call information in the appended message; <c>false</c> (ignored) by default.</param>
+    /// <remarks>
+    /// Setting <c>includeToolCalls</c> to <c>true</c> should be done only for manual tool calling scenarios; otherwise it
+    /// may result in the error below. See <a href="https://github.com/microsoft/semantic-kernel/issues/9458">Issue 9458</a>
+    /// <code>An assistant message with 'tool_calls' must be followed by tool messages</code>
+    /// </remarks>
     /// <returns>Returns the original streaming results with some message processing</returns>
     [Experimental("SKEXP0010")]
-    public static async IAsyncEnumerable<StreamingChatMessageContent> AddStreamingMessageAsync(this ChatHistory chatHistory, IAsyncEnumerable<OpenAIStreamingChatMessageContent> streamingMessageContents)
+    public static async IAsyncEnumerable<StreamingChatMessageContent> AddStreamingMessageAsync(
+        this ChatHistory chatHistory,
+        IAsyncEnumerable<OpenAIStreamingChatMessageContent> streamingMessageContents,
+        bool includeToolCalls = false)
     {
         List<StreamingChatMessageContent> messageContents = [];

@@ -43,7 +52,10 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> AddStreamingMe
                 (contentBuilder ??= new()).Append(contentUpdate);
             }
 
-            OpenAIFunctionToolCall.TrackStreamingToolingUpdate(chatMessage.ToolCallUpdates, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
+            if (includeToolCalls)
+            {
+                OpenAIFunctionToolCall.TrackStreamingToolingUpdate(chatMessage.ToolCallUpdates, ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex);
+            }
 
             // Is always expected to have at least one chunk with the role provided from a streaming message
             streamedRole ??= chatMessage.Role;
@@ -62,7 +74,9 @@ public static async IAsyncEnumerable<StreamingChatMessageContent> AddStreamingMe
                 role,
                 contentBuilder?.ToString() ?? string.Empty,
                 messageContents[0].ModelId!,
-                OpenAIFunctionToolCall.ConvertToolCallUpdatesToFunctionToolCalls(ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex),
+                includeToolCalls
+                    ? OpenAIFunctionToolCall.ConvertToolCallUpdatesToFunctionToolCalls(ref toolCallIdsByIndex, ref functionNamesByIndex, ref functionArgumentBuildersByIndex)
+                    : [],
                 metadata)
             { AuthorName = streamedName });
     }
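
For context, here is a minimal usage sketch of the changed extension method. It is not part of the commit; the chatHistory, chatService, settings, and kernel variables are assumed to be set up as in the integration test added further down.

// Minimal usage sketch (not from the commit).
// Default behavior (includeToolCalls = false): the appended assistant message carries no
// tool calls, so streaming again against the same history cannot trigger the
// "An assistant message with 'tool_calls' must be followed by tool messages" error.
await foreach (var chunk in chatHistory.AddStreamingMessageAsync(
    chatService.GetStreamingChatMessageContentsAsync(chatHistory, settings, kernel)
               .Cast<OpenAIStreamingChatMessageContent>()))
{
    Console.Write(chunk.Content);
}

// Manual tool calling only: keep the tool calls on the appended assistant message and add the
// matching tool result messages to the history yourself before the next request.
await foreach (var chunk in chatHistory.AddStreamingMessageAsync(
    chatService.GetStreamingChatMessageContentsAsync(chatHistory, settings, kernel)
               .Cast<OpenAIStreamingChatMessageContent>(),
    includeToolCalls: true))
{
    Console.Write(chunk.Content);
}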

dotnet/src/IntegrationTests/Connectors/OpenAI/OpenAIChatCompletion_StreamingTests.cs (+42)
@@ -1,6 +1,8 @@
 // Copyright (c) Microsoft. All rights reserved.
 
+using System;
 using System.Collections.Generic;
+using System.Linq;
 using System.Text;
 using System.Threading.Tasks;
 using Microsoft.Extensions.Configuration;
@@ -147,6 +149,46 @@ public async Task TextGenerationShouldReturnMetadataAsync()
         Assert.Equal("Stop", finishReason);
     }
 
+    [Fact]
+    public async Task RepeatedChatHistoryAddStreamingMessageWorksAsExpectedAsync()
+    {
+        // Arrange
+        var kernel = this.CreateAndInitializeKernel();
+        var chatCompletion = kernel.Services.GetRequiredService<IChatCompletionService>();
+
+        kernel.ImportPluginFromFunctions("TestFunctions",
+        [
+            kernel.CreateFunctionFromMethod((string input) => Task.FromResult(input), "Test", "Test executed.")
+        ]);
+
+        // Prepare Chat
+        var chatService = kernel.GetRequiredService<IChatCompletionService>();
+
+        OpenAIPromptExecutionSettings settings = new()
+        {
+            FunctionChoiceBehavior = FunctionChoiceBehavior.Auto()
+        };
+
+        ChatHistory chatHistory = new("You are to test the system");
+
+        for (int i = 0; i < 2; i++)
+        {
+            chatHistory.AddUserMessage("Please test the system");
+
+            var results = chatHistory.AddStreamingMessageAsync(chatService
+                .GetStreamingChatMessageContentsAsync(chatHistory, settings, kernel)
+                .Cast<OpenAIStreamingChatMessageContent>()
+            );
+
+            await foreach (var result in results)
+            {
+                Console.Write(result.ToString());
+            }
+
+            Console.WriteLine($"Call #{i} OK");
+        }
+    }
+
     #region internals
 
     private Kernel CreateAndInitializeKernel()
