Skip to content

Commit 4ea1d48

Browse files
.Net: Represent all embeddings as ReadOnlyMemory<float> (#2419)
### Motivation and Context Across AI-related libraries, we'd like a consistent representation of an embedding. It needs to support efficient handling of the data, with the ability to get a span over the data so that the data can be manipulated efficiently, which rules out interfaces like `IEnumerable<float>` and `IList<float>`. It should ideally encourage immutability but not enforce it, such that developers have escape hatches that don't require them to continually allocate and copy around large objects. It should be usable in both synchronous and asynchronous contexts. And it should be a ubiquitous core type that's available everywhere. Based on that, `ReadOnlyMemory<float>` is the winning candidate. ### Description This PR overhauls SK to use `ReadOnlyMemory<float>` for embeddings. In addition to the advantages listed above, I think it also cleans things up nicely. Note that in the System.Text.Json v8.0.0 nuget package, ReadOnlyMemory is implicitly supported. Once SK moves from having a v6 to a v8 dependency, we can delete the ReadOnlyMemoryConverter and all mentions of it. ### Contribution Checklist <!-- Before submitting this PR, please make sure: --> - [x] The code builds clean without any errors or warnings - [x] The PR follows the [SK Contribution Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md) and the [pre-submission formatting script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts) raises no violations - [x] All unit tests pass, and I have added new tests where possible - [ ] I didn't break anyone 😄 Co-authored-by: Shawn Callegari <[email protected]>
1 parent 3ec5953 commit 4ea1d48

File tree

82 files changed

+661
-761
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+661
-761
lines changed

dotnet/samples/KernelSyntaxExamples/Example25_ReadOnlyMemoryStore.cs

+10-10
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
using System.Collections.Generic;
55
using System.Linq;
66
using System.Runtime.CompilerServices;
7+
using System.Runtime.InteropServices;
78
using System.Text.Json;
89
using System.Threading;
910
using System.Threading.Tasks;
10-
using Microsoft.SemanticKernel.AI.Embeddings;
1111
using Microsoft.SemanticKernel.AI.Embeddings.VectorOperations;
1212
using Microsoft.SemanticKernel.Memory;
1313
using Microsoft.SemanticKernel.Memory.Collections;
@@ -30,20 +30,20 @@ public static async Task RunAsync()
3030
{
3131
var store = new ReadOnlyMemoryStore(s_jsonVectorEntries);
3232

33-
var embedding = new Embedding<float>(new float[] { 22, 4, 6 });
33+
var embedding = new ReadOnlyMemory<float>(new float[] { 22, 4, 6 });
3434

3535
Console.WriteLine("Reading data from custom read-only memory store");
3636
var memoryRecord = await store.GetAsync("collection", "key3");
3737
if (memoryRecord != null)
3838
{
39-
Console.WriteLine("ID = {0}, Embedding = {1}", memoryRecord.Metadata.Id, string.Join(", ", memoryRecord.Embedding.Vector));
39+
Console.WriteLine("ID = {0}, Embedding = {1}", memoryRecord.Metadata.Id, string.Join(", ", MemoryMarshal.ToEnumerable(memoryRecord.Embedding)));
4040
}
4141

42-
Console.WriteLine("Getting most similar vector to {0}", string.Join(", ", embedding.Vector));
42+
Console.WriteLine("Getting most similar vector to {0}", string.Join(", ", MemoryMarshal.ToEnumerable(embedding)));
4343
var result = await store.GetNearestMatchAsync("collection", embedding, 0.0);
4444
if (result.HasValue)
4545
{
46-
Console.WriteLine("Embedding = {0}, Similarity = {1}", string.Join(", ", result.Value.Item1.Embedding.Vector), result.Value.Item2);
46+
Console.WriteLine("Embedding = {0}, Similarity = {1}", string.Join(", ", MemoryMarshal.ToEnumerable(result.Value.Item1.Embedding)), result.Value.Item2);
4747
}
4848
}
4949

@@ -105,7 +105,7 @@ public IAsyncEnumerable<string> GetCollectionsAsync(CancellationToken cancellati
105105
throw new System.NotImplementedException();
106106
}
107107

108-
public async Task<(MemoryRecord, double)?> GetNearestMatchAsync(string collectionName, Embedding<float> embedding, double minRelevanceScore = 0,
108+
public async Task<(MemoryRecord, double)?> GetNearestMatchAsync(string collectionName, ReadOnlyMemory<float> embedding, double minRelevanceScore = 0,
109109
bool withEmbedding = false, CancellationToken cancellationToken = default)
110110
{
111111
// Note: with this simple implementation, the MemoryRecord will always contain the embedding.
@@ -123,7 +123,7 @@ public IAsyncEnumerable<string> GetCollectionsAsync(CancellationToken cancellati
123123
return default;
124124
}
125125

126-
public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(string collectionName, Embedding<float> embedding, int limit,
126+
public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(string collectionName, ReadOnlyMemory<float> embedding, int limit,
127127
double minRelevanceScore = 0, bool withEmbeddings = false, [EnumeratorCancellation] CancellationToken cancellationToken = default)
128128
{
129129
// Note: with this simple implementation, the MemoryRecord will always contain the embedding.
@@ -132,16 +132,16 @@ public IAsyncEnumerable<string> GetCollectionsAsync(CancellationToken cancellati
132132
yield break;
133133
}
134134

135-
if (embedding.Count != this._vectorSize)
135+
if (embedding.Length != this._vectorSize)
136136
{
137-
throw new Exception($"Embedding vector size {embedding.Count} does not match expected size of {this._vectorSize}");
137+
throw new Exception($"Embedding vector size {embedding.Length} does not match expected size of {this._vectorSize}");
138138
}
139139

140140
TopNCollection<MemoryRecord> embeddings = new(limit);
141141

142142
foreach (var item in this._memoryRecords)
143143
{
144-
double similarity = embedding.AsReadOnlySpan().CosineSimilarity(item.Embedding.AsReadOnlySpan());
144+
double similarity = embedding.Span.CosineSimilarity(item.Embedding.Span);
145145
if (similarity >= minRelevanceScore)
146146
{
147147
embeddings.Add(new(item, similarity));

dotnet/src/Connectors/Connectors.AI.HuggingFace/TextEmbedding/HuggingFaceTextEmbeddingGeneration.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ public HuggingFaceTextEmbeddingGeneration(string model, HttpClient httpClient, s
8181
}
8282

8383
/// <inheritdoc/>
84-
public async Task<IList<Embedding<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
84+
public async Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(IList<string> data, CancellationToken cancellationToken = default)
8585
{
8686
return await this.ExecuteEmbeddingRequestAsync(data, cancellationToken).ConfigureAwait(false);
8787
}
@@ -95,7 +95,7 @@ public async Task<IList<Embedding<float>>> GenerateEmbeddingsAsync(IList<string>
9595
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
9696
/// <returns>List of generated embeddings.</returns>
9797
/// <exception cref="AIException">Exception when backend didn't respond with generated embeddings.</exception>
98-
private async Task<IList<Embedding<float>>> ExecuteEmbeddingRequestAsync(IList<string> data, CancellationToken cancellationToken)
98+
private async Task<IList<ReadOnlyMemory<float>>> ExecuteEmbeddingRequestAsync(IList<string> data, CancellationToken cancellationToken)
9999
{
100100
try
101101
{
@@ -113,7 +113,7 @@ private async Task<IList<Embedding<float>>> ExecuteEmbeddingRequestAsync(IList<s
113113

114114
var embeddingResponse = JsonSerializer.Deserialize<TextEmbeddingResponse>(body);
115115

116-
return embeddingResponse?.Embeddings?.Select(l => new Embedding<float>(l.Embedding!, transferOwnership: true)).ToList()!;
116+
return embeddingResponse?.Embeddings?.Select(l => l.Embedding).ToList()!;
117117
}
118118
catch (Exception e) when (e is not AIException && !e.IsCriticalException())
119119
{

dotnet/src/Connectors/Connectors.AI.HuggingFace/TextEmbedding/TextEmbeddingResponse.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using System;
34
using System.Collections.Generic;
45
using System.Text.Json.Serialization;
6+
using Microsoft.SemanticKernel.Text;
57

68
namespace Microsoft.SemanticKernel.Connectors.AI.HuggingFace.TextEmbedding;
79

@@ -16,7 +18,8 @@ public sealed class TextEmbeddingResponse
1618
public sealed class EmbeddingVector
1719
{
1820
[JsonPropertyName("embedding")]
19-
public IList<float>? Embedding { get; set; }
21+
[JsonConverter(typeof(ReadOnlyMemoryConverter))]
22+
public ReadOnlyMemory<float> Embedding { get; set; }
2023
}
2124

2225
/// <summary>

dotnet/src/Connectors/Connectors.AI.OpenAI/AzureSdk/ClientBase.cs

+3-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
using Microsoft.Extensions.Logging.Abstractions;
1313
using Microsoft.SemanticKernel.AI;
1414
using Microsoft.SemanticKernel.AI.ChatCompletion;
15-
using Microsoft.SemanticKernel.AI.Embeddings;
1615
using Microsoft.SemanticKernel.AI.TextCompletion;
1716
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ChatCompletion;
1817
using Microsoft.SemanticKernel.Diagnostics;
@@ -115,11 +114,11 @@ private protected async IAsyncEnumerable<TextStreamingResult> InternalGetTextStr
115114
/// <param name="data">List of strings to generate embeddings for</param>
116115
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
117116
/// <returns>List of embeddings</returns>
118-
private protected async Task<IList<Embedding<float>>> InternalGetEmbeddingsAsync(
117+
private protected async Task<IList<ReadOnlyMemory<float>>> InternalGetEmbeddingsAsync(
119118
IList<string> data,
120119
CancellationToken cancellationToken = default)
121120
{
122-
var result = new List<Embedding<float>>();
121+
var result = new List<ReadOnlyMemory<float>>(data.Count);
123122
foreach (string text in data)
124123
{
125124
var options = new EmbeddingsOptions(text);
@@ -137,9 +136,7 @@ private protected async Task<IList<Embedding<float>>> InternalGetEmbeddingsAsync
137136
throw new OpenAIInvalidResponseException<Embeddings>(response.Value, "Text embedding not found");
138137
}
139138

140-
EmbeddingItem x = response.Value.Data[0];
141-
142-
result.Add(new Embedding<float>(x.Embedding, transferOwnership: true));
139+
result.Add(response.Value.Data[0].Embedding.ToArray());
143140
}
144141

145142
return result;

dotnet/src/Connectors/Connectors.AI.OpenAI/CustomClient/OpenAIClientBase.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
using Microsoft.Extensions.Logging;
1313
using Microsoft.Extensions.Logging.Abstractions;
1414
using Microsoft.SemanticKernel.AI;
15-
using Microsoft.SemanticKernel.AI.Embeddings;
1615
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.ImageGeneration;
1716
using Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
1817
using Microsoft.SemanticKernel.Diagnostics;
@@ -48,7 +47,7 @@ private protected virtual void AddRequestHeaders(HttpRequestMessage request)
4847
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
4948
/// <returns>List of text embeddings</returns>
5049
/// <exception cref="AIException">AIException thrown during the request.</exception>
51-
private protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingRequestAsync(
50+
private protected async Task<IList<ReadOnlyMemory<float>>> ExecuteTextEmbeddingRequestAsync(
5251
string url,
5352
string requestBody,
5453
CancellationToken cancellationToken = default)
@@ -61,7 +60,7 @@ private protected async Task<IList<Embedding<float>>> ExecuteTextEmbeddingReques
6160
"Embeddings not found");
6261
}
6362

64-
return result.Embeddings.Select(e => new Embedding<float>(e.Values, transferOwnership: true)).ToList();
63+
return result.Embeddings.Select(e => e.Values).ToList();
6564
}
6665

6766
/// <summary>

dotnet/src/Connectors/Connectors.AI.OpenAI/TextEmbedding/AzureTextEmbeddingGeneration.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using System;
34
using System.Collections.Generic;
45
using System.Net.Http;
56
using System.Threading;
@@ -56,7 +57,7 @@ public AzureTextEmbeddingGeneration(
5657
/// <param name="data">List of strings to generate embeddings for</param>
5758
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
5859
/// <returns>List of embeddings</returns>
59-
public Task<IList<Embedding<float>>> GenerateEmbeddingsAsync(
60+
public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
6061
IList<string> data,
6162
CancellationToken cancellationToken = default)
6263
{

dotnet/src/Connectors/Connectors.AI.OpenAI/TextEmbedding/OpenAITextEmbeddingGeneration.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using System;
34
using System.Collections.Generic;
45
using System.Net.Http;
56
using System.Threading;
@@ -39,7 +40,7 @@ public OpenAITextEmbeddingGeneration(
3940
/// <param name="data">List of strings to generate embeddings for</param>
4041
/// <param name="cancellationToken">The <see cref="CancellationToken"/> to monitor for cancellation requests. The default is <see cref="CancellationToken.None"/>.</param>
4142
/// <returns>List of embeddings</returns>
42-
public Task<IList<Embedding<float>>> GenerateEmbeddingsAsync(
43+
public Task<IList<ReadOnlyMemory<float>>> GenerateEmbeddingsAsync(
4344
IList<string> data,
4445
CancellationToken cancellationToken = default)
4546
{

dotnet/src/Connectors/Connectors.AI.OpenAI/TextEmbedding/TextEmbeddingResponse.cs

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using System;
34
using System.Collections.Generic;
45
using System.Text.Json.Serialization;
6+
using Microsoft.SemanticKernel.Text;
57

68
namespace Microsoft.SemanticKernel.Connectors.AI.OpenAI.TextEmbedding;
79

@@ -19,7 +21,8 @@ public sealed class EmbeddingResponseIndex
1921
/// The embedding vector
2022
/// </summary>
2123
[JsonPropertyName("embedding")]
22-
public IList<float> Values { get; set; } = new List<float>();
24+
[JsonConverter(typeof(ReadOnlyMemoryConverter))]
25+
public ReadOnlyMemory<float> Values { get; set; }
2326

2427
/// <summary>
2528
/// Index of the embedding vector

dotnet/src/Connectors/Connectors.Memory.AzureCognitiveSearch/AzureCognitiveSearchMemoryRecord.cs

+6-7
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

33
using System;
4-
using System.Collections.Generic;
5-
using System.Linq;
64
using System.Text;
75
using System.Text.Json.Serialization;
8-
using Microsoft.SemanticKernel.AI.Embeddings;
96
using Microsoft.SemanticKernel.Memory;
7+
using Microsoft.SemanticKernel.Text;
108

119
namespace Microsoft.SemanticKernel.Connectors.Memory.AzureCognitiveSearch;
1210

@@ -41,7 +39,8 @@ public class AzureCognitiveSearchMemoryRecord
4139
/// Content embedding
4240
/// </summary>
4341
[JsonPropertyName(EmbeddingField)]
44-
public List<float> Embedding { get; set; } = Array.Empty<float>().ToList();
42+
[JsonConverter(typeof(ReadOnlyMemoryConverter))]
43+
public ReadOnlyMemory<float> Embedding { get; set; }
4544

4645
/// <summary>
4746
/// Optional description of the content, e.g. a title. This can be useful when
@@ -87,13 +86,13 @@ public AzureCognitiveSearchMemoryRecord(
8786
string text,
8887
string externalSourceName,
8988
bool isReference,
90-
Embedding<float> embedding,
89+
ReadOnlyMemory<float> embedding,
9190
string? description = null,
9291
string? additionalMetadata = null)
9392
{
9493
this.Id = EncodeId(id);
9594
this.IsReference = isReference;
96-
this.Embedding = embedding.Vector.ToList();
95+
this.Embedding = embedding;
9796
this.Text = text;
9897
this.ExternalSourceName = externalSourceName;
9998
this.Description = description;
@@ -128,7 +127,7 @@ public MemoryRecord ToMemoryRecord(bool withEmbeddings = true)
128127
{
129128
return new MemoryRecord(
130129
metadata: this.ToMemoryRecordMetadata(),
131-
embedding: new Embedding<float>(withEmbeddings ? this.Embedding : Array.Empty<float>()),
130+
embedding: withEmbeddings ? this.Embedding : default,
132131
key: this.Id);
133132
}
134133

dotnet/src/Connectors/Connectors.Memory.AzureCognitiveSearch/AzureCognitiveSearchMemoryStore.cs

+5-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Generic;
66
using System.Linq;
77
using System.Runtime.CompilerServices;
8+
using System.Runtime.InteropServices;
89
using System.Text.RegularExpressions;
910
using System.Threading;
1011
using System.Threading.Tasks;
@@ -14,7 +15,6 @@
1415
using Azure.Search.Documents.Indexes;
1516
using Azure.Search.Documents.Indexes.Models;
1617
using Azure.Search.Documents.Models;
17-
using Microsoft.SemanticKernel.AI.Embeddings;
1818
using Microsoft.SemanticKernel.Diagnostics;
1919
using Microsoft.SemanticKernel.Memory;
2020

@@ -134,7 +134,7 @@ public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(
134134
/// <inheritdoc />
135135
public async Task<(MemoryRecord, double)?> GetNearestMatchAsync(
136136
string collectionName,
137-
Embedding<float> embedding,
137+
ReadOnlyMemory<float> embedding,
138138
double minRelevanceScore = 0,
139139
bool withEmbedding = false,
140140
CancellationToken cancellationToken = default)
@@ -147,7 +147,7 @@ public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(
147147
/// <inheritdoc />
148148
public async IAsyncEnumerable<(MemoryRecord, double)> GetNearestMatchesAsync(
149149
string collectionName,
150-
Embedding<float> embedding,
150+
ReadOnlyMemory<float> embedding,
151151
int limit,
152152
double minRelevanceScore = 0,
153153
bool withEmbeddings = false,
@@ -161,7 +161,7 @@ public async IAsyncEnumerable<MemoryRecord> GetBatchAsync(
161161
{
162162
KNearestNeighborsCount = limit,
163163
Fields = AzureCognitiveSearchMemoryRecord.EmbeddingField,
164-
Value = embedding.Vector.ToList()
164+
Value = MemoryMarshal.TryGetArray(embedding, out var array) && array.Count == embedding.Length ? array.Array! : embedding.ToArray(),
165165
};
166166

167167
SearchOptions options = new() { Vector = vectorQuery };
@@ -305,7 +305,7 @@ private async Task<List<string>> UpsertBatchAsync(
305305

306306
if (records.Count < 1) { return keys; }
307307

308-
var embeddingSize = records[0].Embedding.Count;
308+
var embeddingSize = records[0].Embedding.Length;
309309

310310
var client = this.GetSearchClient(indexName);
311311

dotnet/src/Connectors/Connectors.Memory.Chroma/ChromaClient.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ public async IAsyncEnumerable<string> ListCollectionsAsync([EnumeratorCancellati
107107
}
108108

109109
/// <inheritdoc />
110-
public async Task UpsertEmbeddingsAsync(string collectionId, string[] ids, float[][] embeddings, object[]? metadatas = null, CancellationToken cancellationToken = default)
110+
public async Task UpsertEmbeddingsAsync(string collectionId, string[] ids, ReadOnlyMemory<float>[] embeddings, object[]? metadatas = null, CancellationToken cancellationToken = default)
111111
{
112112
this._logger.LogDebug("Upserting embeddings to collection with id: {0}", collectionId);
113113

@@ -141,7 +141,7 @@ public async Task DeleteEmbeddingsAsync(string collectionId, string[] ids, Cance
141141
}
142142

143143
/// <inheritdoc />
144-
public async Task<ChromaQueryResultModel> QueryEmbeddingsAsync(string collectionId, float[][] queryEmbeddings, int nResults, string[]? include = null, CancellationToken cancellationToken = default)
144+
public async Task<ChromaQueryResultModel> QueryEmbeddingsAsync(string collectionId, ReadOnlyMemory<float>[] queryEmbeddings, int nResults, string[]? include = null, CancellationToken cancellationToken = default)
145145
{
146146
this._logger.LogDebug("Query embeddings in collection with id: {0}", collectionId);
147147

0 commit comments

Comments
 (0)