Skip to content

Commit 4cd7f07

Browse files
authored
.Net Agents - Fix Aggregator Streaming for Nested Mode (#9669)
### Motivation and Context <!-- Thank you for your contribution to the semantic-kernel repo! Please help reviewers and future users, providing the following information: 1. Why is this change required? 2. What problem does it solve? 3. What scenario does it contribute to? 4. If it fixes an open issue, please link to the issue here. --> Fixes: #8677 Aggretator agent not yielding content for streamed response when Mode=Nested. ### Description <!-- Describe your changes, the overall approach, the underlying design. These notes will help understanding how your code works. Thanks! --> Analyzed and addressed behavior and added integration tests. ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [X] The code builds clean without any errors or warnings - [X] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [X] All unit tests pass, and I have added new tests where possible - [X] I didn't break anyone 😄
1 parent 63d1dc7 commit 4cd7f07

File tree

2 files changed

+192
-6
lines changed

2 files changed

+192
-6
lines changed

dotnet/src/Agents/Abstractions/AggregatorChannel.cs

+8-6
Original file line numberDiff line numberDiff line change
@@ -54,29 +54,31 @@ protected internal override IAsyncEnumerable<ChatMessageContent> GetHistoryAsync
5454
/// <inheritdoc/>
5555
protected internal override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(AggregatorAgent agent, IList<ChatMessageContent> messages, [EnumeratorCancellation] CancellationToken cancellationToken = default)
5656
{
57-
int messageCount = await this._chat.GetChatMessagesAsync(cancellationToken).CountAsync(cancellationToken).ConfigureAwait(false);
57+
int initialCount = await this._chat.GetChatMessagesAsync(cancellationToken).CountAsync(cancellationToken).ConfigureAwait(false);
5858

59-
if (agent.Mode == AggregatorMode.Flat)
59+
await foreach (StreamingChatMessageContent message in this._chat.InvokeStreamingAsync(cancellationToken).ConfigureAwait(false))
6060
{
61-
await foreach (StreamingChatMessageContent message in this._chat.InvokeStreamingAsync(cancellationToken).ConfigureAwait(false))
61+
if (agent.Mode == AggregatorMode.Flat)
6262
{
6363
yield return message;
6464
}
6565
}
6666

6767
ChatMessageContent[] history = await this._chat.GetChatMessagesAsync(cancellationToken).ToArrayAsync(cancellationToken).ConfigureAwait(false);
68-
if (history.Length > messageCount)
68+
if (history.Length > initialCount)
6969
{
7070
if (agent.Mode == AggregatorMode.Flat)
7171
{
72-
for (int index = messageCount; index < messages.Count; ++index)
72+
for (int index = history.Length - 1; index >= initialCount; --index)
7373
{
7474
messages.Add(history[index]);
7575
}
7676
}
7777
else if (agent.Mode == AggregatorMode.Nested)
7878
{
79-
messages.Add(history[history.Length - 1]);
79+
ChatMessageContent finalMessage = history[0]; // Order descending
80+
yield return new StreamingChatMessageContent(finalMessage.Role, finalMessage.Content) { AuthorName = finalMessage.AuthorName };
81+
messages.Add(finalMessage);
8082
}
8183
}
8284
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
// Copyright (c) Microsoft. All rights reserved.
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Threading;
6+
using System.Threading.Tasks;
7+
using Azure.Identity;
8+
using Microsoft.Extensions.Configuration;
9+
using Microsoft.SemanticKernel;
10+
using Microsoft.SemanticKernel.Agents;
11+
using Microsoft.SemanticKernel.Agents.Chat;
12+
using Microsoft.SemanticKernel.ChatCompletion;
13+
using SemanticKernel.IntegrationTests.TestSettings;
14+
using xRetry;
15+
using Xunit;
16+
17+
namespace SemanticKernel.IntegrationTests.Agents;
18+
19+
#pragma warning disable xUnit1004 // Contains test methods used in manual verification. Disable warning for this file only.
20+
21+
public sealed class AggregatorAgentTests()
22+
{
23+
private readonly IKernelBuilder _kernelBuilder = Kernel.CreateBuilder();
24+
private readonly IConfigurationRoot _configuration = new ConfigurationBuilder()
25+
.AddJsonFile(path: "testsettings.json", optional: true, reloadOnChange: true)
26+
.AddJsonFile(path: "testsettings.development.json", optional: true, reloadOnChange: true)
27+
.AddEnvironmentVariables()
28+
.AddUserSecrets<OpenAIAssistantAgentTests>()
29+
.Build();
30+
31+
/// <summary>
32+
/// Integration test for <see cref="AggregatorAgent"/> non-streamed nested response.
33+
/// </summary>
34+
[RetryFact(typeof(HttpOperationException))]
35+
public async Task AggregatorAgentFlatResponseAsync()
36+
{
37+
// Arrange
38+
AggregatorAgent aggregatorAgent = new(() => this.CreateChatProvider())
39+
{
40+
Mode = AggregatorMode.Flat,
41+
};
42+
43+
AgentGroupChat chat = new();
44+
chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, "1"));
45+
46+
// Act
47+
ChatMessageContent[] responses = await chat.InvokeAsync(aggregatorAgent).ToArrayAsync();
48+
49+
// Assert
50+
ChatMessageContent[] innerHistory = await chat.GetChatMessagesAsync(aggregatorAgent).ToArrayAsync();
51+
Assert.Equal(6, innerHistory.Length);
52+
Assert.Equal(5, responses.Length);
53+
Assert.NotNull(responses[4].Content);
54+
AssertResponseContent(responses[4]);
55+
}
56+
57+
/// <summary>
58+
/// Integration test for <see cref="AggregatorAgent"/> non-streamed nested response.
59+
/// </summary>
60+
[RetryFact(typeof(HttpOperationException))]
61+
public async Task AggregatorAgentNestedResponseAsync()
62+
{
63+
// Arrange
64+
AggregatorAgent aggregatorAgent = new(() => this.CreateChatProvider())
65+
{
66+
Mode = AggregatorMode.Nested,
67+
};
68+
69+
AgentGroupChat chat = new();
70+
chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, "1"));
71+
72+
// Act
73+
ChatMessageContent[] responses = await chat.InvokeAsync(aggregatorAgent).ToArrayAsync();
74+
75+
// Assert
76+
ChatMessageContent[] innerHistory = await chat.GetChatMessagesAsync(aggregatorAgent).ToArrayAsync();
77+
Assert.Equal(6, innerHistory.Length);
78+
Assert.Single(responses);
79+
Assert.NotNull(responses[0].Content);
80+
AssertResponseContent(responses[0]);
81+
}
82+
83+
/// <summary>
84+
/// Integration test for <see cref="AggregatorAgent"/> non-streamed response.
85+
/// </summary>
86+
[RetryFact(typeof(HttpOperationException))]
87+
public async Task AggregatorAgentFlatStreamAsync()
88+
{
89+
// Arrange
90+
AggregatorAgent aggregatorAgent = new(() => this.CreateChatProvider())
91+
{
92+
Mode = AggregatorMode.Flat,
93+
};
94+
95+
AgentGroupChat chat = new();
96+
chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, "1"));
97+
98+
// Act
99+
StreamingChatMessageContent[] streamedResponse = await chat.InvokeStreamingAsync(aggregatorAgent).ToArrayAsync();
100+
101+
// Assert
102+
ChatMessageContent[] fullResponses = await chat.GetChatMessagesAsync().ToArrayAsync();
103+
ChatMessageContent[] innerHistory = await chat.GetChatMessagesAsync(aggregatorAgent).ToArrayAsync();
104+
Assert.NotEmpty(streamedResponse);
105+
Assert.Equal(6, innerHistory.Length);
106+
Assert.Equal(6, fullResponses.Length);
107+
Assert.NotNull(fullResponses[0].Content);
108+
AssertResponseContent(fullResponses[0]);
109+
}
110+
111+
/// <summary>
112+
/// Integration test for <see cref="AggregatorAgent"/> non-streamed response.
113+
/// </summary>
114+
[RetryFact(typeof(HttpOperationException))]
115+
public async Task AggregatorAgentNestedStreamAsync()
116+
{
117+
// Arrange
118+
AggregatorAgent aggregatorAgent = new(() => this.CreateChatProvider())
119+
{
120+
Mode = AggregatorMode.Nested,
121+
};
122+
123+
AgentGroupChat chat = new();
124+
chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, "1"));
125+
126+
// Act
127+
StreamingChatMessageContent[] streamedResponse = await chat.InvokeStreamingAsync(aggregatorAgent).ToArrayAsync();
128+
129+
// Assert
130+
ChatMessageContent[] fullResponses = await chat.GetChatMessagesAsync().ToArrayAsync();
131+
ChatMessageContent[] innerHistory = await chat.GetChatMessagesAsync(aggregatorAgent).ToArrayAsync();
132+
Assert.NotEmpty(streamedResponse);
133+
Assert.Equal(6, innerHistory.Length);
134+
Assert.Equal(2, fullResponses.Length);
135+
Assert.NotNull(fullResponses[0].Content);
136+
AssertResponseContent(fullResponses[0]);
137+
}
138+
139+
private static void AssertResponseContent(ChatMessageContent response)
140+
{
141+
// Counting is hard
142+
Assert.True(
143+
response.Content!.Contains("five", StringComparison.OrdinalIgnoreCase) ||
144+
response.Content!.Contains("six", StringComparison.OrdinalIgnoreCase) ||
145+
response.Content!.Contains("seven", StringComparison.OrdinalIgnoreCase) ||
146+
response.Content!.Contains("eight", StringComparison.OrdinalIgnoreCase),
147+
$"Content: {response}");
148+
}
149+
150+
private AgentGroupChat CreateChatProvider()
151+
{
152+
// Arrange
153+
AzureOpenAIConfiguration configuration = this._configuration.GetSection("AzureOpenAI").Get<AzureOpenAIConfiguration>()!;
154+
155+
this._kernelBuilder.AddAzureOpenAIChatCompletion(
156+
configuration.ChatDeploymentName!,
157+
configuration.Endpoint,
158+
new AzureCliCredential());
159+
160+
Kernel kernel = this._kernelBuilder.Build();
161+
162+
ChatCompletionAgent agent =
163+
new()
164+
{
165+
Kernel = kernel,
166+
Instructions = "Your job is to count. Always add one to the previous number and respond using the english word for that number, without explanation.",
167+
};
168+
169+
return new AgentGroupChat(agent)
170+
{
171+
ExecutionSettings = new()
172+
{
173+
TerminationStrategy = new CountTerminationStrategy(5)
174+
}
175+
};
176+
}
177+
178+
private sealed class CountTerminationStrategy(int maximumResponseCount) : TerminationStrategy
179+
{
180+
// Terminate when the assistant has responded N times.
181+
protected override Task<bool> ShouldAgentTerminateAsync(Agent agent, IReadOnlyList<ChatMessageContent> history, CancellationToken cancellationToken)
182+
=> Task.FromResult(history.Count(message => message.Role == AuthorRole.Assistant) >= maximumResponseCount);
183+
}
184+
}

0 commit comments

Comments
 (0)