Skip to content

.Net Agents - Support streaming for ChatCompletionAgent #6956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 69 additions & 0 deletions dotnet/samples/Concepts/Agents/ChatCompletion_Streaming.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Text;
using Microsoft.SemanticKernel;
using Microsoft.SemanticKernel.Agents;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Agents;

/// <summary>
/// Demonstrate creation of <see cref="ChatCompletionAgent"/> and
/// eliciting its response to three explicit user messages.
/// </summary>
public class ChatCompletion_Streaming(ITestOutputHelper output) : BaseTest(output)
{
private const string ParrotName = "Parrot";
private const string ParrotInstructions = "Repeat the user message in the voice of a pirate and then end with a parrot sound.";

[Fact]
public async Task UseStreamingChatCompletionAgentAsync()
{
// Define the agent
ChatCompletionAgent agent =
new()
{
Name = ParrotName,
Instructions = ParrotInstructions,
Kernel = this.CreateKernelWithChatCompletion(),
};

ChatHistory chat = [];

// Respond to user input
await InvokeAgentAsync("Fortune favors the bold.");
await InvokeAgentAsync("I came, I saw, I conquered.");
await InvokeAgentAsync("Practice makes perfect.");

// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
{
chat.Add(new ChatMessageContent(AuthorRole.User, input));

Console.WriteLine($"# {AuthorRole.User}: '{input}'");

StringBuilder builder = new();
await foreach (StreamingChatMessageContent message in agent.InvokeStreamingAsync(chat))
{
if (string.IsNullOrEmpty(message.Content))
{
continue;
}

if (builder.Length == 0)
{
Console.WriteLine($"# {message.Role} - {message.AuthorName ?? "*"}:");
}

Console.WriteLine($"\t > streamed: '{message.Content}'");
builder.Append(message.Content);
}

if (builder.Length > 0)
{
// Display full response and capture in chat history
Console.WriteLine($"\t > complete: '{builder}'");
chat.Add(new ChatMessageContent(AuthorRole.Assistant, builder.ToString()) { AuthorName = agent.Name });
}
}
}
}
4 changes: 2 additions & 2 deletions dotnet/samples/GettingStartedWithAgents/Step1_Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public async Task UseSingleChatCompletionAgentAsync()
};

/// Create a chat for agent interaction. For more, <see cref="Step3_Chat"/>.
ChatHistory chat = new();
ChatHistory chat = [];

// Respond to user input
await InvokeAgentAsync("Fortune favors the bold.");
Expand All @@ -41,7 +41,7 @@ async Task InvokeAgentAsync(string input)

Console.WriteLine($"# {AuthorRole.User}: '{input}'");

await foreach (var content in agent.InvokeAsync(chat))
await foreach (ChatMessageContent content in agent.InvokeAsync(chat))
{
Console.WriteLine($"# {content.Role} - {content.AuthorName ?? "*"}: '{content.Content}'");
}
Expand Down
6 changes: 3 additions & 3 deletions dotnet/samples/GettingStartedWithAgents/Step2_Plugins.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ public async Task UseChatCompletionWithPluginAgentAsync()
agent.Kernel.Plugins.Add(plugin);

/// Create a chat for agent interaction. For more, <see cref="Step3_Chat"/>.
AgentGroupChat chat = new();
ChatHistory chat = [];

// Respond to user input, invoking functions where appropriate.
await InvokeAgentAsync("Hello");
Expand All @@ -45,10 +45,10 @@ public async Task UseChatCompletionWithPluginAgentAsync()
// Local function to invoke agent and display the conversation messages.
async Task InvokeAgentAsync(string input)
{
chat.AddChatMessage(new ChatMessageContent(AuthorRole.User, input));
chat.Add(new ChatMessageContent(AuthorRole.User, input));
Console.WriteLine($"# {AuthorRole.User}: '{input}'");

await foreach (var content in chat.InvokeAsync(agent))
await foreach (var content in agent.InvokeAsync(chat))
{
Console.WriteLine($"# {content.Role} - {content.AuthorName ?? "*"}: '{content.Content}'");
}
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Agents/Abstractions/ChatHistoryChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ protected internal sealed override async IAsyncEnumerable<ChatMessageContent> In
throw new KernelException($"Invalid channel binding for agent: {agent.Id} ({agent.GetType().FullName})");
}

await foreach (var message in historyHandler.InvokeAsync(this._history, cancellationToken).ConfigureAwait(false))
await foreach (ChatMessageContent message in historyHandler.InvokeAsync(this._history, cancellationToken).ConfigureAwait(false))
{
this._history.Add(message);

Expand Down
8 changes: 7 additions & 1 deletion dotnet/src/Agents/Abstractions/ChatHistoryKernelAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Agents;

Expand Down Expand Up @@ -31,6 +32,11 @@ protected internal sealed override Task<AgentChannel> CreateChannelAsync(Cancell

/// <inheritdoc/>
public abstract IAsyncEnumerable<ChatMessageContent> InvokeAsync(
IReadOnlyList<ChatMessageContent> history,
ChatHistory history,
CancellationToken cancellationToken = default);

/// <inheritdoc/>
public abstract IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
CancellationToken cancellationToken = default);
}
15 changes: 13 additions & 2 deletions dotnet/src/Agents/Abstractions/IChatHistoryHandler.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Copyright (c) Microsoft. All rights reserved.
using System.Collections.Generic;
using System.Threading;
using Microsoft.SemanticKernel.ChatCompletion;

namespace Microsoft.SemanticKernel.Agents;

Expand All @@ -10,12 +11,22 @@ namespace Microsoft.SemanticKernel.Agents;
public interface IChatHistoryHandler
{
/// <summary>
/// Entry point for calling into an agent from a a <see cref="ChatHistoryChannel"/>.
/// Entry point for calling into an agent from a <see cref="ChatHistoryChannel"/>.
/// </summary>
/// <param name="history">The chat history at the point the channel is created.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Asynchronous enumeration of messages.</returns>
IAsyncEnumerable<ChatMessageContent> InvokeAsync(
IReadOnlyList<ChatMessageContent> history,
ChatHistory history,
CancellationToken cancellationToken = default);

/// <summary>
/// Entry point for calling into an agent from a <see cref="ChatHistoryChannel"/> for streaming content.
/// </summary>
/// <param name="history">The chat history at the point the channel is created.</param>
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
/// <returns>Asynchronous enumeration of streaming content.</returns>
public abstract IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
CancellationToken cancellationToken = default);
}
72 changes: 63 additions & 9 deletions dotnet/src/Agents/Core/ChatCompletionAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.SemanticKernel.ChatCompletion;

Expand All @@ -23,17 +24,12 @@ public sealed class ChatCompletionAgent : ChatHistoryKernelAgent

/// <inheritdoc/>
public override async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
IReadOnlyList<ChatMessageContent> history,
ChatHistory history,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
var chatCompletionService = this.Kernel.GetRequiredService<IChatCompletionService>();
IChatCompletionService chatCompletionService = this.Kernel.GetRequiredService<IChatCompletionService>();

ChatHistory chat = [];
if (!string.IsNullOrWhiteSpace(this.Instructions))
{
chat.Add(new ChatMessageContent(AuthorRole.System, this.Instructions) { AuthorName = this.Name });
}
chat.AddRange(history);
ChatHistory chat = this.SetupAgentChatHistory(history);

int messageCount = chat.Count;

Expand All @@ -58,7 +54,7 @@ await chatCompletionService.GetChatMessageContentsAsync(

message.AuthorName = this.Name;

yield return message;
history.Add(message);
}

foreach (ChatMessageContent message in messages ?? [])
Expand All @@ -69,4 +65,62 @@ await chatCompletionService.GetChatMessageContentsAsync(
yield return message;
}
}

/// <inheritdoc/>
public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
IChatCompletionService chatCompletionService = this.Kernel.GetRequiredService<IChatCompletionService>();

ChatHistory chat = this.SetupAgentChatHistory(history);

int messageCount = chat.Count;

this.Logger.LogDebug("[{MethodName}] Invoking {ServiceType}.", nameof(InvokeAsync), chatCompletionService.GetType());

IAsyncEnumerable<StreamingChatMessageContent> messages =
chatCompletionService.GetStreamingChatMessageContentsAsync(
chat,
this.ExecutionSettings,
this.Kernel,
cancellationToken);

if (this.Logger.IsEnabled(LogLevel.Information))
{
this.Logger.LogInformation("[{MethodName}] Invoked {ServiceType} with streaming messages.", nameof(InvokeAsync), chatCompletionService.GetType());
}

// Capture mutated messages related function calling / tools
for (int messageIndex = messageCount; messageIndex < chat.Count; messageIndex++)
{
ChatMessageContent message = chat[messageIndex];

message.AuthorName = this.Name;

history.Add(message);
}

await foreach (StreamingChatMessageContent message in messages.ConfigureAwait(false))
{
// TODO: MESSAGE SOURCE - ISSUE #5731
message.AuthorName = this.Name;

yield return message;
}
}

private ChatHistory SetupAgentChatHistory(IReadOnlyList<ChatMessageContent> history)
{
ChatHistory chat = [];

if (!string.IsNullOrWhiteSpace(this.Instructions))
{
chat.Add(new ChatMessageContent(AuthorRole.System, this.Instructions) { AuthorName = this.Name });
}

chat.AddRange(history);

return chat;
}
}
13 changes: 12 additions & 1 deletion dotnet/src/Agents/UnitTests/AgentChatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ private sealed class TestAgent : ChatHistoryKernelAgent
public int InvokeCount { get; private set; }

public override async IAsyncEnumerable<ChatMessageContent> InvokeAsync(
IReadOnlyList<ChatMessageContent> history,
ChatHistory history,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await Task.Delay(0, cancellationToken);
Expand All @@ -144,5 +144,16 @@ public override async IAsyncEnumerable<ChatMessageContent> InvokeAsync(

yield return new ChatMessageContent(AuthorRole.Assistant, "sup");
}

public override async IAsyncEnumerable<StreamingChatMessageContent> InvokeStreamingAsync(
ChatHistory history,
[EnumeratorCancellation] CancellationToken cancellationToken = default)
{
await Task.Delay(0, cancellationToken);

this.InvokeCount++;

yield return new StreamingChatMessageContent(AuthorRole.Assistant, "sup");
}
}
}
2 changes: 1 addition & 1 deletion dotnet/src/Agents/UnitTests/AggregatorAgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ private static Mock<ChatHistoryKernelAgent> CreateMockAgent()
Mock<ChatHistoryKernelAgent> agent = new();

ChatMessageContent[] messages = [new ChatMessageContent(AuthorRole.Assistant, "test agent")];
agent.Setup(a => a.InvokeAsync(It.IsAny<IReadOnlyList<ChatMessageContent>>(), It.IsAny<CancellationToken>())).Returns(() => messages.ToAsyncEnumerable());
agent.Setup(a => a.InvokeAsync(It.IsAny<ChatHistory>(), It.IsAny<CancellationToken>())).Returns(() => messages.ToAsyncEnumerable());

return agent;
}
Expand Down
2 changes: 1 addition & 1 deletion dotnet/src/Agents/UnitTests/Core/AgentGroupChatTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ private static Mock<ChatHistoryKernelAgent> CreateMockAgent()
Mock<ChatHistoryKernelAgent> agent = new();

ChatMessageContent[] messages = [new ChatMessageContent(AuthorRole.Assistant, "test")];
agent.Setup(a => a.InvokeAsync(It.IsAny<IReadOnlyList<ChatMessageContent>>(), It.IsAny<CancellationToken>())).Returns(() => messages.ToAsyncEnumerable());
agent.Setup(a => a.InvokeAsync(It.IsAny<ChatHistory>(), It.IsAny<CancellationToken>())).Returns(() => messages.ToAsyncEnumerable());

return agent;
}
Expand Down
42 changes: 42 additions & 0 deletions dotnet/src/Agents/UnitTests/Core/ChatCompletionAgentTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,48 @@ public async Task VerifyChatCompletionAgentInvocationAsync()
Times.Once);
}

/// <summary>
/// Verify the streaming invocation and response of <see cref="ChatCompletionAgent"/>.
/// </summary>
[Fact]
public async Task VerifyChatCompletionAgentStreamingAsync()
{
StreamingChatMessageContent[] returnContent =
[
new(AuthorRole.Assistant, "wh"),
new(AuthorRole.Assistant, "at?"),
];

var mockService = new Mock<IChatCompletionService>();
mockService.Setup(
s => s.GetStreamingChatMessageContentsAsync(
It.IsAny<ChatHistory>(),
It.IsAny<PromptExecutionSettings>(),
It.IsAny<Kernel>(),
It.IsAny<CancellationToken>())).Returns(returnContent.ToAsyncEnumerable());

var agent =
new ChatCompletionAgent()
{
Instructions = "test instructions",
Kernel = CreateKernel(mockService.Object),
ExecutionSettings = new(),
};

var result = await agent.InvokeStreamingAsync([]).ToArrayAsync();

Assert.Equal(2, result.Length);

mockService.Verify(
x =>
x.GetStreamingChatMessageContentsAsync(
It.IsAny<ChatHistory>(),
It.IsAny<PromptExecutionSettings>(),
It.IsAny<Kernel>(),
It.IsAny<CancellationToken>()),
Times.Once);
}

private static Kernel CreateKernel(IChatCompletionService chatCompletionService)
{
var builder = Kernel.CreateBuilder();
Expand Down
Loading