Skip to content

Commit

Permalink
Make the embedding field name configurable for the ElasticSearchVecto…
Browse files Browse the repository at this point in the history
…rStore

Signed-off-by: jonghoon park <[email protected]>
  • Loading branch information
dev-jonghoonpark committed Feb 12, 2025
1 parent b466159 commit a567120
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 35 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -42,6 +42,7 @@
* @author Josh Long
* @author Christian Tzolov
* @author Soby Chacko
* @author Jonghoon Park
* @since 1.0.0
*/
@AutoConfiguration(after = ElasticsearchRestClientAutoConfiguration.class)
Expand Down Expand Up @@ -72,6 +73,9 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti
if (properties.getSimilarity() != null) {
elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity());
}
if (properties.getEmbeddingFieldName() != null) {
elasticsearchVectorStoreOptions.setEmbeddingFieldName(properties.getEmbeddingFieldName());
}

return ElasticsearchVectorStore.builder(restClient, embeddingModel)
.options(elasticsearchVectorStoreOptions)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -26,6 +26,7 @@
* @author Eddú Meléndez
* @author Wei Jiang
* @author Josh Long
* @author Jonghoon Park
* @since 1.0.0
*/
@ConfigurationProperties(prefix = "spring.ai.vectorstore.elasticsearch")
Expand All @@ -46,6 +47,11 @@ public class ElasticsearchVectorStoreProperties extends CommonVectorStorePropert
*/
private SimilarityFunction similarity;

/**
* The name of the index field that holds the embedding vector. Defaults to
* {@code "embedding"}.
*/
private String embeddingFieldName = "embedding";

public String getIndexName() {
return this.indexName;
}
Expand All @@ -70,4 +76,12 @@ public void setSimilarity(SimilarityFunction similarity) {
this.similarity = similarity;
}

/**
 * Returns the configured name of the field that holds the embedding vector.
 * @return the embedding field name
 */
public String getEmbeddingFieldName() {
	// Qualify with "this." to match the other accessors of this class.
	return this.embeddingFieldName;
}

/**
* Sets the name of the field that holds the embedding vector. The value is stored
* as given; no validation is performed here.
* @param embeddingFieldName the embedding field name
*/
public void setEmbeddingFieldName(String embeddingFieldName) {
this.embeddingFieldName = embeddingFieldName;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch.core.BulkRequest;
import co.elastic.clients.elasticsearch.core.BulkResponse;
import co.elastic.clients.elasticsearch.core.DeleteByQueryResponse;
import co.elastic.clients.elasticsearch.core.SearchResponse;
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
import co.elastic.clients.elasticsearch.core.search.Hit;
Expand All @@ -41,10 +40,8 @@

import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentMetadata;
import org.springframework.ai.embedding.BatchingStrategy;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.embedding.EmbeddingOptionsBuilder;
import org.springframework.ai.embedding.TokenCountBatchingStrategy;
import org.springframework.ai.model.EmbeddingUtils;
import org.springframework.ai.observation.conventions.VectorStoreProvider;
import org.springframework.ai.observation.conventions.VectorStoreSimilarityMetric;
Expand Down Expand Up @@ -145,6 +142,7 @@
* @author Christian Tzolov
* @author Thomas Vitale
* @author Ilayaperumal Gopinathan
* @author Jonghoon Park
* @since 1.0.0
*/
public class ElasticsearchVectorStore extends AbstractObservationVectorStore implements InitializingBean {
Expand Down Expand Up @@ -191,11 +189,12 @@ public void doAdd(List<Document> documents) {
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
this.batchingStrategy);

for (Document document : documents) {
ElasticSearchDocument doc = new ElasticSearchDocument(document.getId(), document.getText(),
document.getMetadata(), embeddings.get(documents.indexOf(document)));
bulkRequestBuilder.operations(
op -> op.index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(doc)));
for (int i = 0; i < embeddings.size(); i++) {
Document document = documents.get(i);
float[] embedding = embeddings.get(i);
bulkRequestBuilder.operations(op -> op.index(idx -> idx.index(this.options.getIndexName())
.id(document.getId())
.document(getDocument(document, embedding, this.options.getEmbeddingFieldName()))));
}
BulkResponse bulkRequest = bulkRequest(bulkRequestBuilder.build());
if (bulkRequest.errors()) {
Expand All @@ -208,6 +207,13 @@ public void doAdd(List<Document> documents) {
}
}

/**
 * Builds the source object that is indexed for a single document: the fixed
 * {@code id}, {@code content} and {@code metadata} entries plus the embedding stored
 * under the configurable field name.
 * @param document the document to index; its text must not be null
 * @param embedding the embedding vector computed for the document
 * @param embeddingFieldName the name of the field the embedding is stored under; must
 * contain text and must not be one of the fixed field names
 * @return an immutable map representing the document to index
 * @throws IllegalArgumentException if the document text is null, or the embedding
 * field name is empty or clashes with a fixed field name
 */
private Object getDocument(Document document, float[] embedding, String embeddingFieldName) {
	Assert.notNull(document.getText(), "document's text must not be null");
	Assert.hasText(embeddingFieldName, "embeddingFieldName must not be empty");

	// Map.of rejects duplicate keys with an opaque "duplicate key" error, so a field
	// name that clashes with a fixed entry is reported explicitly instead.
	boolean clashesWithFixedField = "id".equals(embeddingFieldName) || "content".equals(embeddingFieldName)
			|| "metadata".equals(embeddingFieldName);
	Assert.isTrue(!clashesWithFixedField,
			"embeddingFieldName must not be one of the reserved field names [id, content, metadata]: "
					+ embeddingFieldName);

	return Map.of("id", document.getId(), "content", document.getText(), "metadata", document.getMetadata(),
			embeddingFieldName, embedding);
}

@Override
public Optional<Boolean> doDelete(List<String> idList) {
BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder();
Expand Down Expand Up @@ -264,7 +270,7 @@ public List<Document> doSimilaritySearch(SearchRequest searchRequest) {
.knn(knn -> knn.queryVector(EmbeddingUtils.toList(vectors))
.similarity(finalThreshold)
.k(searchRequest.getTopK())
.field("embedding")
.field(this.options.getEmbeddingFieldName())
.numCandidates((int) (1.5 * searchRequest.getTopK()))
.filter(fl -> fl
.queryString(qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression())))))
Expand Down Expand Up @@ -322,7 +328,7 @@ private void createIndexMapping() {
try {
this.elasticsearchClient.indices()
.create(cr -> cr.index(this.options.getIndexName())
.mappings(map -> map.properties("embedding",
.mappings(map -> map.properties(this.options.getEmbeddingFieldName(),
p -> p.denseVector(dv -> dv.similarity(this.options.getSimilarity().toString())
.dims(this.options.getDimensions())))));
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -21,6 +21,7 @@
* https://www.elastic.co/guide/en/elasticsearch/reference/current/dense-vector.html
*
* @author Wei Jiang
* @author Jonghoon Park
* @since 1.0.0
*/
public class ElasticsearchVectorStoreOptions {
Expand All @@ -40,6 +41,11 @@ public class ElasticsearchVectorStoreOptions {
*/
private SimilarityFunction similarity = SimilarityFunction.cosine;

/**
* The name of the index field that holds the embedding vector. Defaults to
* {@code "embedding"}.
*/
private String embeddingFieldName = "embedding";

public String getIndexName() {
return this.indexName;
}
Expand All @@ -64,4 +70,12 @@ public void setSimilarity(SimilarityFunction similarity) {
this.similarity = similarity;
}

/**
 * Returns the name of the field that holds the embedding vector.
 * @return the embedding field name
 */
public String getEmbeddingFieldName() {
	// Qualify with "this." to match the other accessors of this class.
	return this.embeddingFieldName;
}

/**
* Sets the name of the field that holds the embedding vector. The value is stored
* as given; no validation is performed here.
* @param embeddingFieldName the embedding field name
*/
public void setEmbeddingFieldName(String embeddingFieldName) {
this.embeddingFieldName = embeddingFieldName;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
import org.springframework.ai.openai.OpenAiEmbeddingModel;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.Filter.ExpressionType;
import org.springframework.ai.vectorstore.filter.Filter.Key;
Expand Down Expand Up @@ -117,10 +116,11 @@ void cleanDatabase() {
});
}

@Test
public void addAndDeleteDocumentsTest() {
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "cosine", "custom_field" })
public void addAndDeleteDocumentsTest(String vectorStoreBeanName) {
getContextRunner().run(context -> {
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine",
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName,
ElasticsearchVectorStore.class);
ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);

Expand Down Expand Up @@ -149,10 +149,11 @@ public void addAndDeleteDocumentsTest() {
});
}

@Test
public void deleteDocumentsByFilterExpressionTest() {
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "cosine", "custom_field" })
public void deleteDocumentsByFilterExpressionTest(String vectorStoreBeanName) {
getContextRunner().run(context -> {
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine",
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName,
ElasticsearchVectorStore.class);
ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);

Expand Down Expand Up @@ -202,10 +203,11 @@ public void deleteDocumentsByFilterExpressionTest() {
});
}

@Test
public void deleteWithStringFilterExpressionTest() {
@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "cosine", "custom_field" })
public void deleteWithStringFilterExpressionTest(String vectorStoreBeanName) {
getContextRunner().run(context -> {
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_cosine",
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName,
ElasticsearchVectorStore.class);
ElasticsearchClient elasticsearchClient = context.getBean(ElasticsearchClient.class);

Expand Down Expand Up @@ -234,12 +236,12 @@ public void deleteWithStringFilterExpressionTest() {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "cosine", "l2_norm", "dot_product" })
public void addAndSearchTest(String similarityFunction) {
@ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_field" })
public void addAndSearchTest(String vectorStoreBeanName) {

getContextRunner().run(context -> {

ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction,
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName,
ElasticsearchVectorStore.class);

vectorStore.add(this.documents);
Expand Down Expand Up @@ -271,11 +273,11 @@ public void addAndSearchTest(String similarityFunction) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "cosine", "l2_norm", "dot_product" })
public void searchWithFilters(String similarityFunction) {
@ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_field" })
public void searchWithFilters(String vectorStoreBeanName) {

getContextRunner().run(context -> {
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction,
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName,
ElasticsearchVectorStore.class);

var bgDocument = new Document("1", "The World is Big and Salvation Lurks Around the Corner",
Expand Down Expand Up @@ -385,11 +387,11 @@ public void searchWithFilters(String similarityFunction) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "cosine", "l2_norm", "dot_product" })
public void documentUpdateTest(String similarityFunction) {
@ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_field" })
public void documentUpdateTest(String vectorStoreBeanName) {

getContextRunner().run(context -> {
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction,
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName,
ElasticsearchVectorStore.class);

Document document = new Document(UUID.randomUUID().toString(), "Spring AI rocks!!",
Expand Down Expand Up @@ -443,10 +445,10 @@ public void documentUpdateTest(String similarityFunction) {
}

@ParameterizedTest(name = "{0} : {displayName} ")
@ValueSource(strings = { "cosine", "l2_norm", "dot_product" })
public void searchThresholdTest(String similarityFunction) {
@ValueSource(strings = { "cosine", "l2_norm", "dot_product", "custom_field" })
public void searchThresholdTest(String vectorStoreBeanName) {
getContextRunner().run(context -> {
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + similarityFunction,
ElasticsearchVectorStore vectorStore = context.getBean("vectorStore_" + vectorStoreBeanName,
ElasticsearchVectorStore.class);

vectorStore.add(this.documents);
Expand Down Expand Up @@ -581,6 +583,16 @@ public ElasticsearchVectorStore vectorStoreDotProduct(EmbeddingModel embeddingMo
.build();
}

/**
 * Vector store bean whose embeddings are stored under the non-default field name
 * {@code custom_field}; all other options keep their defaults.
 */
@Bean("vectorStore_custom_field")
public ElasticsearchVectorStore vectorStoreCustomField(EmbeddingModel embeddingModel, RestClient restClient) {
	var customFieldOptions = new ElasticsearchVectorStoreOptions();
	customFieldOptions.setEmbeddingFieldName("custom_field");

	var storeBuilder = ElasticsearchVectorStore.builder(restClient, embeddingModel)
		.initializeSchema(true)
		.options(customFieldOptions);
	return storeBuilder.build();
}

@Bean
public EmbeddingModel embeddingModel() {
return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY")));
Expand Down

0 comments on commit a567120

Please sign in to comment.