Skip to content

Commit 3dbfb44

Browse files
.Net: Add cancellation token & add headers to request (#11900)
### Motivation, Context and Description This PR: 1. Propagates the cancellation token from all kernel functions of the `SessionsPythonPlugin` plugin down to the APIs they call. 2. Refactors the way the plugin adds headers to the request. Instead of adding them as default headers of the HTTP client, they are now added as request headers. The result is the same - the headers are sent each call of any kernel function of the plugin. This change may save time later by preventing bug troubleshooting if/when an HTTP client is injected into the plugin and it sets default headers to the client it does not own. 3. Replaces the usage of HTTP client methods for getting content, such as `ReadAsStringAsync`, with the SK wrappers `ReadAsStringWithExceptionMappingAsync` to align the plugin's behavior with that of other SK components. Contributes to: #10070
1 parent 1a187b0 commit 3dbfb44

File tree

5 files changed

+90
-45
lines changed

5 files changed

+90
-45
lines changed

dotnet/samples/Demos/CodeInterpreterPlugin/Program.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@
3535
/// <summary>
3636
/// Acquire a token for the Azure Container Apps service
3737
/// </summary>
38-
async Task<string> TokenProvider()
38+
async Task<string> TokenProvider(CancellationToken cancellationToken)
3939
{
4040
if (cachedToken is null)
4141
{
4242
string resource = "https://acasessions.io/.default";
4343
var credential = new InteractiveBrowserCredential();
4444

4545
// Attempt to get the token
46-
var accessToken = await credential.GetTokenAsync(new Azure.Core.TokenRequestContext([resource])).ConfigureAwait(false);
46+
var accessToken = await credential.GetTokenAsync(new Azure.Core.TokenRequestContext([resource]), cancellationToken).ConfigureAwait(false);
4747
if (logger.IsEnabled(LogLevel.Information))
4848
{
4949
logger.LogInformation("Access token obtained successfully");

dotnet/src/IntegrationTests/Plugins/Core/SessionsPythonPluginTests.cs

+11-10
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,20 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System;
4-
using System.Threading.Tasks;
5-
using Xunit;
6-
using Microsoft.SemanticKernel.Plugins.Core.CodeInterpreter;
7-
using Microsoft.Extensions.Configuration;
8-
using SemanticKernel.IntegrationTests.TestSettings;
4+
using System.Collections.Generic;
95
using System.Net.Http;
10-
using Azure.Identity;
6+
using System.Threading;
7+
using System.Threading.Tasks;
118
using Azure.Core;
12-
using System.Collections.Generic;
13-
using Microsoft.SemanticKernel;
9+
using Azure.Identity;
10+
using Microsoft.Extensions.Configuration;
1411
using Microsoft.Extensions.DependencyInjection;
12+
using Microsoft.SemanticKernel;
1513
using Microsoft.SemanticKernel.ChatCompletion;
1614
using Microsoft.SemanticKernel.Connectors.AzureOpenAI;
15+
using Microsoft.SemanticKernel.Plugins.Core.CodeInterpreter;
16+
using SemanticKernel.IntegrationTests.TestSettings;
17+
using Xunit;
1718

1819
namespace SemanticKernel.IntegrationTests.Plugins.Core;
1920

@@ -143,13 +144,13 @@ public async Task LlmShouldUploadFileAndAccessItFromCodeInterpreterAsync()
143144
/// <summary>
144145
/// Acquires authentication token for the Azure Container App Session pool.
145146
/// </summary>
146-
private static async Task<string> GetAuthTokenAsync()
147+
private static async Task<string> GetAuthTokenAsync(CancellationToken cancellationToken)
147148
{
148149
string resource = "https://acasessions.io/.default";
149150

150151
var credential = new AzureCliCredential();
151152

152-
AccessToken token = await credential.GetTokenAsync(new Azure.Core.TokenRequestContext([resource])).ConfigureAwait(false);
153+
AccessToken token = await credential.GetTokenAsync(new Azure.Core.TokenRequestContext([resource]), cancellationToken).ConfigureAwait(false);
153154

154155
return token.Token;
155156
}

dotnet/src/Plugins/Plugins.Core/CodeInterpreter/SessionsPythonPlugin.cs

+38-30
Original file line numberDiff line numberDiff line change
@@ -26,21 +26,21 @@ public sealed partial class SessionsPythonPlugin
2626
private const string ApiVersion = "2024-10-02-preview";
2727
private readonly Uri _poolManagementEndpoint;
2828
private readonly SessionsPythonSettings _settings;
29-
private readonly Func<Task<string>>? _authTokenProvider;
29+
private readonly Func<CancellationToken, Task<string>>? _authTokenProvider;
3030
private readonly IHttpClientFactory _httpClientFactory;
3131
private readonly ILogger _logger;
3232

3333
/// <summary>
3434
/// Initializes a new instance of the SessionsPythonTool class.
3535
/// </summary>
36-
/// <param name="settings">The settings for the Python tool plugin. </param>
37-
/// <param name="httpClientFactory">The HTTP client factory. </param>
38-
/// <param name="authTokenProvider"> Optional provider for auth token generation. </param>
39-
/// <param name="loggerFactory">The logger factory. </param>
36+
/// <param name="settings">The settings for the Python tool plugin.</param>
37+
/// <param name="httpClientFactory">The HTTP client factory.</param>
38+
/// <param name="authTokenProvider">Optional provider for auth token generation.</param>
39+
/// <param name="loggerFactory">The logger factory.</param>
4040
public SessionsPythonPlugin(
4141
SessionsPythonSettings settings,
4242
IHttpClientFactory httpClientFactory,
43-
Func<Task<string>>? authTokenProvider = null,
43+
Func<CancellationToken, Task<string>>? authTokenProvider = null,
4444
ILoggerFactory? loggerFactory = null)
4545
{
4646
Verify.NotNull(settings, nameof(settings));
@@ -66,7 +66,8 @@ public SessionsPythonPlugin(
6666
/// Keep everything in a single line; the \n sequences will represent line breaks
6767
/// when the string is processed or displayed.
6868
/// </summary>
69-
/// <param name="code"> The valid Python code to execute. </param>
69+
/// <param name="code"> The valid Python code to execute.</param>
70+
/// <param name="cancellationToken">The cancellation token.</param>
7071
/// <returns> The result of the Python code execution. </returns>
7172
/// <exception cref="ArgumentNullException"></exception>
7273
/// <exception cref="HttpRequestException"></exception>
@@ -79,7 +80,9 @@ Add spaces directly after \n sequences to replicate indentation.
7980
Keep everything in a single line; the \n sequences will represent line breaks
8081
when the string is processed or displayed.
8182
""")]
82-
public async Task<string> ExecuteCodeAsync([Description("The valid Python code to execute.")] string code)
83+
public async Task<string> ExecuteCodeAsync(
84+
[Description("The valid Python code to execute.")] string code,
85+
CancellationToken cancellationToken = default)
8386
{
8487
Verify.NotNullOrWhiteSpace(code, nameof(code));
8588

@@ -91,15 +94,14 @@ public async Task<string> ExecuteCodeAsync([Description("The valid Python code t
9194
this._logger.LogTrace("Executing Python code: {Code}", code);
9295

9396
using var httpClient = this._httpClientFactory.CreateClient();
94-
await this.AddHeadersAsync(httpClient).ConfigureAwait(false);
9597

9698
var requestBody = new SessionsPythonCodeExecutionProperties(this._settings, code);
9799

98100
using var content = new StringContent(JsonSerializer.Serialize(requestBody), Encoding.UTF8, "application/json");
99101

100-
using var response = await this.SendAsync(httpClient, HttpMethod.Post, "executions", content).ConfigureAwait(false);
102+
using var response = await this.SendAsync(httpClient, HttpMethod.Post, "executions", cancellationToken, content).ConfigureAwait(false);
101103

102-
var responseContent = JsonSerializer.Deserialize<JsonElement>(await response.Content.ReadAsStringAsync().ConfigureAwait(false));
104+
var responseContent = JsonSerializer.Deserialize<JsonElement>(await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false));
103105

104106
var result = responseContent.GetProperty("result");
105107

@@ -120,21 +122,22 @@ public async Task<string> ExecuteCodeAsync([Description("The valid Python code t
120122
/// </summary>
121123
/// <param name="remoteFileName">The name of the remote file, relative to `/mnt/data`.</param>
122124
/// <param name="localFilePath">The path to the file on the local machine.</param>
125+
/// <param name="cancellationToken">The cancellation token.</param>
123126
/// <returns>The metadata of the uploaded file.</returns>
124127
/// <exception cref="ArgumentNullException"></exception>
125128
/// <exception cref="HttpRequestException"></exception>
126129
[KernelFunction, Description("Uploads a file to the `/mnt/data` directory of the current session.")]
127130
public async Task<SessionsRemoteFileMetadata> UploadFileAsync(
128131
[Description("The name of the remote file, relative to `/mnt/data`.")] string remoteFileName,
129-
[Description("The path to the file on the local machine.")] string localFilePath)
132+
[Description("The path to the file on the local machine.")] string localFilePath,
133+
CancellationToken cancellationToken = default)
130134
{
131135
Verify.NotNullOrWhiteSpace(remoteFileName, nameof(remoteFileName));
132136
Verify.NotNullOrWhiteSpace(localFilePath, nameof(localFilePath));
133137

134138
this._logger.LogInformation("Uploading file: {LocalFilePath} to {RemoteFileName}", localFilePath, remoteFileName);
135139

136140
using var httpClient = this._httpClientFactory.CreateClient();
137-
await this.AddHeadersAsync(httpClient).ConfigureAwait(false);
138141

139142
using var fileContent = new ByteArrayContent(File.ReadAllBytes(localFilePath));
140143

@@ -143,9 +146,9 @@ public async Task<SessionsRemoteFileMetadata> UploadFileAsync(
143146
{ fileContent, "file", remoteFileName },
144147
};
145148

146-
using var response = await this.SendAsync(httpClient, HttpMethod.Post, "files", multipartFormDataContent).ConfigureAwait(false);
149+
using var response = await this.SendAsync(httpClient, HttpMethod.Post, "files", cancellationToken, multipartFormDataContent).ConfigureAwait(false);
147150

148-
var stringContent = await response.Content.ReadAsStringAsync().ConfigureAwait(false);
151+
var stringContent = await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false);
149152

150153
return JsonSerializer.Deserialize<SessionsRemoteFileMetadata>(stringContent)!;
151154
}
@@ -155,22 +158,23 @@ public async Task<SessionsRemoteFileMetadata> UploadFileAsync(
155158
/// </summary>
156159
/// <param name="remoteFileName">The name of the remote file to download, relative to `/mnt/data`.</param>
157160
/// <param name="localFilePath">The path to save the downloaded file to. If not provided won't save it in the disk.</param>
161+
/// <param name="cancellationToken">The cancellation token.</param>
158162
/// <returns>The data of the downloaded file as byte array.</returns>
159163
[KernelFunction, Description("Downloads a file from the `/mnt/data` directory of the current session.")]
160164
public async Task<byte[]> DownloadFileAsync(
161165
[Description("The name of the remote file to download, relative to `/mnt/data`.")] string remoteFileName,
162-
[Description("The path to save the downloaded file to. If not provided won't save it in the disk.")] string? localFilePath = null)
166+
[Description("The path to save the downloaded file to. If not provided won't save it in the disk.")] string? localFilePath = null,
167+
CancellationToken cancellationToken = default)
163168
{
164169
Verify.NotNullOrWhiteSpace(remoteFileName, nameof(remoteFileName));
165170

166171
this._logger.LogTrace("Downloading file: {RemoteFileName} to {LocalFileName}", remoteFileName, localFilePath);
167172

168173
using var httpClient = this._httpClientFactory.CreateClient();
169-
await this.AddHeadersAsync(httpClient).ConfigureAwait(false);
170174

171-
using var response = await this.SendAsync(httpClient, HttpMethod.Get, $"files/{Uri.EscapeDataString(remoteFileName)}/content").ConfigureAwait(false);
175+
using var response = await this.SendAsync(httpClient, HttpMethod.Get, $"files/{Uri.EscapeDataString(remoteFileName)}/content", cancellationToken).ConfigureAwait(false);
172176

173-
var fileContent = await response.Content.ReadAsByteArrayAsync().ConfigureAwait(false);
177+
var fileContent = await response.Content.ReadAsByteArrayAndTranslateExceptionAsync(cancellationToken).ConfigureAwait(false);
174178

175179
if (!string.IsNullOrWhiteSpace(localFilePath))
176180
{
@@ -190,18 +194,18 @@ public async Task<byte[]> DownloadFileAsync(
190194
/// <summary>
191195
/// Lists all entities: files or directories in the `/mnt/data` directory of the current session.
192196
/// </summary>
197+
/// <param name="cancellationToken">The cancellation token.</param>
193198
/// <returns>The list of files in the session.</returns>
194199
[KernelFunction, Description("Lists all entities: files or directories in the `/mnt/data` directory of the current session.")]
195-
public async Task<IReadOnlyList<SessionsRemoteFileMetadata>> ListFilesAsync()
200+
public async Task<IReadOnlyList<SessionsRemoteFileMetadata>> ListFilesAsync(CancellationToken cancellationToken = default)
196201
{
197202
this._logger.LogTrace("Listing files for Session ID: {SessionId}", this._settings.SessionId);
198203

199204
using var httpClient = this._httpClientFactory.CreateClient();
200-
await this.AddHeadersAsync(httpClient).ConfigureAwait(false);
201205

202-
using var response = await this.SendAsync(httpClient, HttpMethod.Get, "files").ConfigureAwait(false);
206+
using var response = await this.SendAsync(httpClient, HttpMethod.Get, "files", cancellationToken).ConfigureAwait(false);
203207

204-
var jsonElementResult = JsonSerializer.Deserialize<JsonElement>(await response.Content.ReadAsStringAsync().ConfigureAwait(false));
208+
var jsonElementResult = JsonSerializer.Deserialize<JsonElement>(await response.Content.ReadAsStringWithExceptionMappingAsync(cancellationToken).ConfigureAwait(false));
205209

206210
var files = jsonElementResult.GetProperty("value");
207211

@@ -241,16 +245,17 @@ private static string SanitizeCodeInput(string code)
241245
}
242246

243247
/// <summary>
244-
/// Add headers to the HTTP client.
248+
/// Add headers to the HTTP request.
245249
/// </summary>
246-
/// <param name="httpClient">The HTTP client to add headers to.</param>
247-
private async Task AddHeadersAsync(HttpClient httpClient)
250+
/// <param name="request">The HTTP request to add headers to.</param>
251+
/// <param name="cancellationToken">The cancellation token.</param>
252+
private async Task AddHeadersAsync(HttpRequestMessage request, CancellationToken cancellationToken)
248253
{
249-
httpClient.DefaultRequestHeaders.Add("User-Agent", $"{HttpHeaderConstant.Values.UserAgent}/{s_assemblyVersion} (Language=dotnet)");
254+
request.Headers.Add("User-Agent", $"{HttpHeaderConstant.Values.UserAgent}/{s_assemblyVersion} (Language=dotnet)");
250255

251256
if (this._authTokenProvider is not null)
252257
{
253-
httpClient.DefaultRequestHeaders.Add("Authorization", $"Bearer {(await this._authTokenProvider().ConfigureAwait(false))}");
258+
request.Headers.Add("Authorization", $"Bearer {(await this._authTokenProvider(cancellationToken).ConfigureAwait(false))}");
254259
}
255260
}
256261

@@ -260,9 +265,10 @@ private async Task AddHeadersAsync(HttpClient httpClient)
260265
/// <param name="httpClient">The HTTP client to use.</param>
261266
/// <param name="method">The HTTP method to use.</param>
262267
/// <param name="path">The path to send the request to.</param>
268+
/// <param name="cancellationToken">The cancellation token.</param>
263269
/// <param name="httpContent">The content to send with the request.</param>
264270
/// <returns>The HTTP response message.</returns>
265-
private async Task<HttpResponseMessage> SendAsync(HttpClient httpClient, HttpMethod method, string path, HttpContent? httpContent = null)
271+
private async Task<HttpResponseMessage> SendAsync(HttpClient httpClient, HttpMethod method, string path, CancellationToken cancellationToken, HttpContent? httpContent = null)
266272
{
267273
// The query string is the same for all operations
268274
var pathWithQueryString = $"{path}?identifier={this._settings.SessionId}&api-version={ApiVersion}";
@@ -281,7 +287,9 @@ private async Task<HttpResponseMessage> SendAsync(HttpClient httpClient, HttpMet
281287
Content = httpContent,
282288
};
283289

284-
return await httpClient.SendWithSuccessCheckAsync(request, CancellationToken.None).ConfigureAwait(false);
290+
await this.AddHeadersAsync(request, cancellationToken).ConfigureAwait(false);
291+
292+
return await httpClient.SendWithSuccessCheckAsync(request, cancellationToken).ConfigureAwait(false);
285293
}
286294

287295
#if NET

dotnet/src/Plugins/Plugins.UnitTests/Core/SessionsPythonPluginTests.cs

+33-1
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,14 @@
33
using System;
44
using System.Collections.Generic;
55
using System.IO;
6+
using System.Linq;
67
using System.Net;
78
using System.Net.Http;
89
using System.Text.Json;
10+
using System.Threading;
911
using System.Threading.Tasks;
1012
using Microsoft.SemanticKernel;
13+
using Microsoft.SemanticKernel.Http;
1114
using Microsoft.SemanticKernel.Plugins.Core.CodeInterpreter;
1215
using Moq;
1316
using Xunit;
@@ -22,6 +25,7 @@ public sealed class SessionsPythonPluginTests : IDisposable
2225
private const string ListFilesTestDataFilePath = "./TestData/sessions_python_plugin_file_list.json";
2326
private const string UpdaloadFileTestDataFilePath = "./TestData/sessions_python_plugin_file_upload.json";
2427
private const string FileTestDataFilePath = "./TestData/sessions_python_plugin_file.txt";
28+
private readonly static string s_assemblyVersion = typeof(Kernel).Assembly.GetName().Version!.ToString();
2529

2630
private readonly SessionsPythonSettings _defaultSettings = new(
2731
sessionId: Guid.NewGuid().ToString(),
@@ -97,7 +101,7 @@ public async Task ItShouldCallTokenProviderWhenProvidedAsync(string methodName)
97101
// Arrange
98102
var tokenProviderCalled = false;
99103

100-
Task<string> tokenProviderAsync()
104+
Task<string> tokenProviderAsync(CancellationToken _)
101105
{
102106
tokenProviderCalled = true;
103107
return Task.FromResult("token");
@@ -332,6 +336,34 @@ public async Task ItShouldRespectAllowedDomainsAsync(string allowedDomain, strin
332336
#pragma warning restore CA1031 // Do not catch general exception types
333337
}
334338

339+
[Fact]
340+
public async Task ItShouldAddHeadersAsync()
341+
{
342+
// Arrange
343+
var responseContent = await File.ReadAllTextAsync(UpdaloadFileTestDataFilePath);
344+
345+
this._messageHandlerStub.ResponseToReturn = new HttpResponseMessage(HttpStatusCode.OK)
346+
{
347+
Content = new StringContent(responseContent),
348+
};
349+
350+
var plugin = new SessionsPythonPlugin(this._defaultSettings, this._httpClientFactory, (_) => Task.FromResult("test-auth-token"));
351+
352+
// Act
353+
var result = await plugin.UploadFileAsync("test-file.txt", FileTestDataFilePath);
354+
355+
// Assert
356+
Assert.NotNull(this._messageHandlerStub.RequestHeaders);
357+
358+
var userAgentHeaderValues = this._messageHandlerStub.RequestHeaders.GetValues("User-Agent").ToArray();
359+
Assert.Equal(2, userAgentHeaderValues.Length);
360+
Assert.Equal($"{HttpHeaderConstant.Values.UserAgent}/{s_assemblyVersion}", userAgentHeaderValues[0]);
361+
Assert.Equal("(Language=dotnet)", userAgentHeaderValues[1]);
362+
363+
var authorizationHeaderValues = this._messageHandlerStub.RequestHeaders.GetValues("Authorization");
364+
Assert.Single(authorizationHeaderValues, value => value == "Bearer test-auth-token");
365+
}
366+
335367
public void Dispose()
336368
{
337369
this._httpClient.Dispose();

dotnet/src/Plugins/Plugins.UnitTests/Plugins.UnitTests.csproj

+6-2
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,17 @@
3838
<ProjectReference Include="..\Plugins.Web\Plugins.Web.csproj" />
3939
<ProjectReference Include="..\Plugins.Memory\Plugins.Memory.csproj" />
4040
</ItemGroup>
41-
41+
42+
<ItemGroup>
43+
<Compile Include="$(RepoRoot)/dotnet/src/InternalUtilities/src/Http/HttpHeaderConstant.cs" Link="%(RecursiveDir)%(Filename)%(Extension)" />
44+
</ItemGroup>
45+
4246
<ItemGroup>
4347
<None Update="TestData\*">
4448
<CopyToOutputDirectory>Always</CopyToOutputDirectory>
4549
</None>
4650
</ItemGroup>
47-
51+
4852
<ItemGroup>
4953
<Folder Include="Web\Tavily\" />
5054
</ItemGroup>

0 commit comments

Comments
 (0)