Skip to content

Commit 13ecfae

Browse files
committed
Added cancellation token to filter context types
1 parent 775994e commit 13ecfae

File tree

11 files changed

+146
-10
lines changed

11 files changed

+146
-10
lines changed

dotnet/src/Connectors/Connectors.MistralAI/Client/MistralClient.cs

+3-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,8 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
177177
Arguments = functionArgs,
178178
RequestSequenceIndex = requestIndex - 1,
179179
FunctionSequenceIndex = toolCallIndex,
180-
FunctionCount = chatChoice.ToolCalls.Count
180+
FunctionCount = chatChoice.ToolCalls.Count,
181+
CancellationToken = cancellationToken
181182
};
182183
s_inflightAutoInvokes.Value++;
183184
try
@@ -409,6 +410,7 @@ internal async IAsyncEnumerable<StreamingChatMessageContent> GetStreamingChatMes
409410
RequestSequenceIndex = requestIndex - 1,
410411
FunctionSequenceIndex = toolCallIndex,
411412
FunctionCount = toolCalls.Count,
413+
CancellationToken = cancellationToken
412414
};
413415
s_inflightAutoInvokes.Value++;
414416
try

dotnet/src/Connectors/Connectors.OpenAI/AzureSdk/ClientCore.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,8 @@ internal async Task<IReadOnlyList<ChatMessageContent>> GetChatMessageContentsAsy
511511
Arguments = functionArgs,
512512
RequestSequenceIndex = requestIndex - 1,
513513
FunctionSequenceIndex = toolCallIndex,
514-
FunctionCount = result.ToolCalls.Count
514+
FunctionCount = result.ToolCalls.Count,
515+
CancellationToken = cancellationToken
515516
};
516517

517518
s_inflightAutoInvokes.Value++;
@@ -809,7 +810,8 @@ internal async IAsyncEnumerable<OpenAIStreamingChatMessageContent> GetStreamingC
809810
Arguments = functionArgs,
810811
RequestSequenceIndex = requestIndex - 1,
811812
FunctionSequenceIndex = toolCallIndex,
812-
FunctionCount = toolCalls.Length
813+
FunctionCount = toolCalls.Length,
814+
CancellationToken = cancellationToken
813815
};
814816

815817
s_inflightAutoInvokes.Value++;

dotnet/src/Connectors/Connectors.UnitTests/OpenAI/FunctionCalling/AutoFunctionInvocationFilterTests.cs

+48
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Linq;
66
using System.Net;
77
using System.Net.Http;
8+
using System.Threading;
89
using System.Threading.Tasks;
910
using Microsoft.Extensions.DependencyInjection;
1011
using Microsoft.SemanticKernel;
@@ -569,6 +570,53 @@ public async Task PostFilterCanTerminateOperationOnStreamingAsync()
569570
Assert.Equal(AuthorRole.Tool, lastMessageContent.Role);
570571
}
571572

573+
[Fact]
574+
public async Task FilterContextHasCancellationTokenAsync()
575+
{
576+
// Arrange
577+
using var cancellationTokenSource = new CancellationTokenSource();
578+
int firstFunctionInvocations = 0;
579+
int secondFunctionInvocations = 0;
580+
581+
var function1 = KernelFunctionFactory.CreateFromMethod((string parameter) =>
582+
{
583+
cancellationTokenSource.Cancel();
584+
firstFunctionInvocations++;
585+
return parameter;
586+
}, "Function1");
587+
588+
var function2 = KernelFunctionFactory.CreateFromMethod((string parameter) =>
589+
{
590+
secondFunctionInvocations++;
591+
return parameter;
592+
}, "Function2");
593+
594+
var plugin = KernelPluginFactory.CreateFromFunctions("MyPlugin", [function1, function2]);
595+
596+
var kernel = this.GetKernelWithFilter(plugin, async (context, next) =>
597+
{
598+
Assert.Equal(cancellationTokenSource.Token, context.CancellationToken);
599+
600+
await next(context);
601+
602+
context.CancellationToken.ThrowIfCancellationRequested();
603+
});
604+
605+
using var response1 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("filters_multiple_function_calls_test_response.json")) };
606+
using var response2 = new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(OpenAITestHelper.GetTestResponse("chat_completion_test_response.json")) };
607+
608+
this._messageHandlerStub.ResponsesToReturn = [response1, response2];
609+
610+
var arguments = new KernelArguments(new OpenAIPromptExecutionSettings { ToolCallBehavior = ToolCallBehavior.AutoInvokeKernelFunctions });
611+
612+
// Act & Assert
613+
var exception = await Assert.ThrowsAsync<KernelFunctionCanceledException>(()
614+
=> kernel.InvokePromptAsync("Test prompt", arguments, cancellationToken: cancellationTokenSource.Token));
615+
616+
Assert.Equal(1, firstFunctionInvocations);
617+
Assert.Equal(0, secondFunctionInvocations);
618+
}
619+
572620
public void Dispose()
573621
{
574622
this._httpClient.Dispose();

dotnet/src/SemanticKernel.Abstractions/Filters/AutoFunctionInvocation/AutoFunctionInvocationContext.cs

+7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System.Diagnostics.CodeAnalysis;
4+
using System.Threading;
45
using Microsoft.SemanticKernel.ChatCompletion;
56

67
namespace Microsoft.SemanticKernel;
@@ -35,6 +36,12 @@ public AutoFunctionInvocationContext(
3536
this.ChatHistory = chatHistory;
3637
}
3738

39+
/// <summary>
40+
/// The <see cref="System.Threading.CancellationToken"/> to monitor for cancellation requests.
41+
/// The default is <see cref="CancellationToken.None"/>.
42+
/// </summary>
43+
public CancellationToken CancellationToken { get; init; }
44+
3845
/// <summary>
3946
/// Gets the arguments associated with the operation.
4047
/// </summary>

dotnet/src/SemanticKernel.Abstractions/Filters/Function/FunctionInvocationContext.cs

+7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System.Diagnostics.CodeAnalysis;
4+
using System.Threading;
45

56
namespace Microsoft.SemanticKernel;
67

@@ -29,6 +30,12 @@ internal FunctionInvocationContext(Kernel kernel, KernelFunction function, Kerne
2930
this.Result = result;
3031
}
3132

33+
/// <summary>
34+
/// The <see cref="System.Threading.CancellationToken"/> to monitor for cancellation requests.
35+
/// The default is <see cref="CancellationToken.None"/>.
36+
/// </summary>
37+
public CancellationToken CancellationToken { get; init; }
38+
3239
/// <summary>
3340
/// Gets the <see cref="Microsoft.SemanticKernel.Kernel"/> containing services, plugins, and other state for use throughout the operation.
3441
/// </summary>

dotnet/src/SemanticKernel.Abstractions/Filters/Prompt/PromptRenderContext.cs

+7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System.Diagnostics.CodeAnalysis;
4+
using System.Threading;
45

56
namespace Microsoft.SemanticKernel;
67

@@ -29,6 +30,12 @@ internal PromptRenderContext(Kernel kernel, KernelFunction function, KernelArgum
2930
this.Arguments = arguments;
3031
}
3132

33+
/// <summary>
34+
/// The <see cref="System.Threading.CancellationToken"/> to monitor for cancellation requests.
35+
/// The default is <see cref="CancellationToken.None"/>.
36+
/// </summary>
37+
public CancellationToken CancellationToken { get; init; }
38+
3239
/// <summary>
3340
/// Gets the <see cref="Microsoft.SemanticKernel.Kernel"/> containing services, plugins, and other state for use throughout the operation.
3441
/// </summary>

dotnet/src/SemanticKernel.Abstractions/Functions/KernelFunction.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ public async Task<FunctionResult> InvokeAsync(
186186
{
187187
// Invoking the function and updating context with result.
188188
context.Result = functionResult = await this.InvokeCoreAsync(kernel, context.Arguments, cancellationToken).ConfigureAwait(false);
189-
}).ConfigureAwait(false);
189+
}, cancellationToken).ConfigureAwait(false);
190190

191191
// Apply any changes from the function filters context to final result.
192192
functionResult = invocationContext.Result;
@@ -321,7 +321,7 @@ public async IAsyncEnumerable<TResult> InvokeStreamingAsync<TResult>(
321321
context.Result = new FunctionResult(this, enumerable, kernel.Culture);
322322

323323
return Task.CompletedTask;
324-
}).ConfigureAwait(false);
324+
}, cancellationToken).ConfigureAwait(false);
325325

326326
// Apply changes from the function filters to final result.
327327
var enumerable = invocationContext.Result.GetValue<IAsyncEnumerable<TResult>>() ?? AsyncEnumerable.Empty<TResult>();

dotnet/src/SemanticKernel.Abstractions/Kernel.cs

+12-4
Original file line numberDiff line numberDiff line change
@@ -314,9 +314,13 @@ internal async Task<FunctionInvocationContext> OnFunctionInvocationAsync(
314314
KernelFunction function,
315315
KernelArguments arguments,
316316
FunctionResult functionResult,
317-
Func<FunctionInvocationContext, Task> functionCallback)
317+
Func<FunctionInvocationContext, Task> functionCallback,
318+
CancellationToken cancellationToken)
318319
{
319-
FunctionInvocationContext context = new(this, function, arguments, functionResult);
320+
FunctionInvocationContext context = new(this, function, arguments, functionResult)
321+
{
322+
CancellationToken = cancellationToken
323+
};
320324

321325
await InvokeFilterOrFunctionAsync(this._functionInvocationFilters, functionCallback, context).ConfigureAwait(false);
322326

@@ -351,9 +355,13 @@ await functionFilters[index].OnFunctionInvocationAsync(context,
351355
internal async Task<PromptRenderContext> OnPromptRenderAsync(
352356
KernelFunction function,
353357
KernelArguments arguments,
354-
Func<PromptRenderContext, Task> renderCallback)
358+
Func<PromptRenderContext, Task> renderCallback,
359+
CancellationToken cancellationToken)
355360
{
356-
PromptRenderContext context = new(this, function, arguments);
361+
PromptRenderContext context = new(this, function, arguments)
362+
{
363+
CancellationToken = cancellationToken
364+
};
357365

358366
await InvokeFilterOrPromptRenderAsync(this._promptRenderFilters, renderCallback, context).ConfigureAwait(false);
359367

dotnet/src/SemanticKernel.Core/Functions/KernelFunctionFromPrompt.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ private async Task<PromptRenderingResult> RenderPromptAsync(Kernel kernel, Kerne
335335
}
336336

337337
context.RenderedPrompt = renderedPrompt;
338-
}).ConfigureAwait(false);
338+
}, cancellationToken).ConfigureAwait(false);
339339

340340
if (!string.IsNullOrWhiteSpace(renderingContext.RenderedPrompt) &&
341341
!string.Equals(renderingContext.RenderedPrompt, renderedPrompt, StringComparison.OrdinalIgnoreCase))

dotnet/src/SemanticKernel.UnitTests/Filters/FunctionInvocationFilterTests.cs

+30
Original file line numberDiff line numberDiff line change
@@ -1022,4 +1022,34 @@ public async Task InsertFilterInMiddleOfPipelineTriggersFiltersInCorrectOrderAsy
10221022
Assert.Equal("FunctionFilter3-Invoked", executionOrder[4]);
10231023
Assert.Equal("FunctionFilter1-Invoked", executionOrder[5]);
10241024
}
1025+
1026+
[Fact]
1027+
public async Task FilterContextHasCancellationTokenAsync()
1028+
{
1029+
// Arrange
1030+
using var cancellationTokenSource = new CancellationTokenSource();
1031+
var function = KernelFunctionFactory.CreateFromMethod(() =>
1032+
{
1033+
cancellationTokenSource.Cancel();
1034+
return "Result";
1035+
});
1036+
1037+
var kernel = this.GetKernelWithFilters(onFunctionInvocation: async (context, next) =>
1038+
{
1039+
Assert.Equal(cancellationTokenSource.Token, context.CancellationToken);
1040+
Assert.False(context.CancellationToken.IsCancellationRequested);
1041+
1042+
await next(context);
1043+
1044+
Assert.True(context.CancellationToken.IsCancellationRequested);
1045+
context.CancellationToken.ThrowIfCancellationRequested();
1046+
});
1047+
1048+
// Act & Assert
1049+
var exception = await Assert.ThrowsAsync<KernelFunctionCanceledException>(()
1050+
=> kernel.InvokeAsync(function, cancellationToken: cancellationTokenSource.Token));
1051+
1052+
Assert.NotNull(exception.FunctionResult);
1053+
Assert.Equal("Result", exception.FunctionResult.ToString());
1054+
}
10251055
}

dotnet/src/SemanticKernel.UnitTests/Filters/PromptRenderFilterTests.cs

+25
Original file line numberDiff line numberDiff line change
@@ -264,4 +264,29 @@ public async Task PromptFilterCanOverrideFunctionResultAsync()
264264

265265
Assert.Equal("Result from prompt filter", result.ToString());
266266
}
267+
268+
[Fact]
269+
public async Task FilterContextHasCancellationTokenAsync()
270+
{
271+
// Arrange
272+
using var cancellationTokenSource = new CancellationTokenSource();
273+
var mockTextGeneration = this.GetMockTextGeneration();
274+
var function = KernelFunctionFactory.CreateFromPrompt("Prompt");
275+
276+
var kernel = this.GetKernelWithFilters(onPromptRender: async (context, next) =>
277+
{
278+
Assert.Equal(cancellationTokenSource.Token, context.CancellationToken);
279+
Assert.True(context.CancellationToken.IsCancellationRequested);
280+
281+
context.CancellationToken.ThrowIfCancellationRequested();
282+
283+
await next(context);
284+
});
285+
286+
// Act & Assert
287+
cancellationTokenSource.Cancel();
288+
289+
await Assert.ThrowsAsync<KernelFunctionCanceledException>(()
290+
=> kernel.InvokeAsync(function, cancellationToken: cancellationTokenSource.Token));
291+
}
267292
}

0 commit comments

Comments
 (0)