-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add DocumentStore, MultiVectorRetriever
Signed-off-by: seungy0 <[email protected]>
- Loading branch information
Showing 2 changed files with 160 additions and 0 deletions.
There are no files selected for viewing
18 changes: 18 additions & 0 deletions
18
spring-ai-core/src/main/java/org/springframework/ai/document/DocumentStore.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
package org.springframework.ai.document; | ||
|
||
import java.util.List; | ||
import java.util.Map; | ||
|
||
/** | ||
* Interface for a document storage system that can retrieve documents by their IDs. | ||
*/ | ||
public interface DocumentStore extends Map<String, Document> { | ||
|
||
/** | ||
* Retrieves a list of documents by their IDs. | ||
* | ||
* @param ids The list of document IDs to retrieve. | ||
* @return The list of retrieved documents. | ||
*/ | ||
List<Document> get(List<String> ids); | ||
} |
142 changes: 142 additions & 0 deletions
142
...-core/src/main/java/org/springframework/ai/rag/retrieval/search/MultiVectorRetriever.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
package org.springframework.ai.rag.retrieval.search; | ||
|
||
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Supplier;

import org.springframework.ai.document.Document;
import org.springframework.ai.document.DocumentStore;
import org.springframework.ai.rag.Query;
import org.springframework.ai.vectorstore.SearchRequest;
import org.springframework.ai.vectorstore.VectorStore;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
|
||
/** | ||
* Retrieves from a set of multiple embeddings for the same document. | ||
* <p> | ||
* Example usage: | ||
* <pre>{@code | ||
* MultiVectorRetriever retriever = MultiVectorRetriever.builder() | ||
* .vectorStore(vectorStore) | ||
* .docStore(docStore) | ||
* .similarityThreshold(0.75) | ||
* .topK(5) | ||
* .filterExpression(filterExpression) | ||
* .build(); | ||
* List<Document> documents = retriever.retrieve(new Query("example query")); | ||
* }</pre> | ||
* | ||
* @author Seunggyu Lee | ||
* @since 1.0.0 | ||
*/ | ||
public class MultiVectorRetriever implements DocumentRetriever { | ||
|
||
private final VectorStore vectorStore; | ||
private final DocumentStore docStore; | ||
private final Double similarityThreshold; | ||
private final Integer topK; | ||
private final Supplier<Filter.Expression> filterExpression; | ||
private final String parentIdKey; | ||
|
||
private MultiVectorRetriever(VectorStore vectorStore, DocumentStore docStore, | ||
@Nullable Double similarityThreshold, @Nullable Integer topK, | ||
@Nullable Supplier<Filter.Expression> filterExpression, String parentIdKey) { | ||
Assert.notNull(vectorStore, "vectorStore cannot be null"); | ||
Assert.notNull(docStore, "docStore cannot be null"); | ||
Assert.isTrue(similarityThreshold == null || similarityThreshold >= 0.0, | ||
"similarityThreshold must be >= 0.0"); | ||
Assert.isTrue(topK == null || topK > 0, "topK must be > 0"); | ||
Assert.hasText(parentIdKey, "parentIdKey must not be empty"); | ||
this.vectorStore = vectorStore; | ||
this.docStore = docStore; | ||
this.similarityThreshold = (similarityThreshold != null) | ||
? similarityThreshold | ||
: SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL; | ||
this.topK = (topK != null) ? topK : SearchRequest.DEFAULT_TOP_K; | ||
this.filterExpression = (filterExpression != null) ? filterExpression : () -> null; | ||
this.parentIdKey = parentIdKey; | ||
} | ||
|
||
@Override | ||
public List<Document> retrieve(Query query) { | ||
Assert.notNull(query, "query cannot be null"); | ||
SearchRequest searchRequest = SearchRequest.builder() | ||
.query(query.text()) | ||
.filterExpression(this.filterExpression.get()) | ||
.similarityThreshold(this.similarityThreshold) | ||
.topK(this.topK) | ||
.build(); | ||
|
||
List<Document> subDocs = this.vectorStore.similaritySearch(searchRequest); | ||
if (subDocs == null || subDocs.isEmpty()) { | ||
return subDocs == null ? new ArrayList<>() : subDocs; | ||
} | ||
|
||
List<String> parentIds = new ArrayList<>(); | ||
for (Document chunk : subDocs) { | ||
String pid = (String) chunk.getMetadata().get(this.parentIdKey); | ||
if (pid != null && !parentIds.contains(pid)) { | ||
parentIds.add(pid); | ||
} | ||
} | ||
List<Document> parentDocs = this.docStore.get(parentIds); | ||
parentDocs.removeIf(Objects::isNull); | ||
return parentDocs; | ||
} | ||
|
||
public static Builder builder() { | ||
return new Builder(); | ||
} | ||
|
||
public static final class Builder { | ||
private VectorStore vectorStore; | ||
private DocumentStore docStore; | ||
private Double similarityThreshold; | ||
private Integer topK; | ||
private Supplier<Filter.Expression> filterExpression; | ||
private String parentIdKey = "doc_id"; | ||
|
||
public Builder vectorStore(VectorStore vectorStore) { | ||
this.vectorStore = vectorStore; | ||
return this; | ||
} | ||
|
||
public Builder docStore(DocumentStore docStore) { | ||
this.docStore = docStore; | ||
return this; | ||
} | ||
|
||
public Builder similarityThreshold(Double similarityThreshold) { | ||
this.similarityThreshold = similarityThreshold; | ||
return this; | ||
} | ||
|
||
public Builder topK(Integer topK) { | ||
this.topK = topK; | ||
return this; | ||
} | ||
|
||
public Builder filterExpression(Filter.Expression filterExpression) { | ||
this.filterExpression = () -> filterExpression; | ||
return this; | ||
} | ||
|
||
public Builder filterExpression(Supplier<Filter.Expression> filterExpression) { | ||
this.filterExpression = filterExpression; | ||
return this; | ||
} | ||
|
||
public Builder parentIdKey(String parentIdKey) { | ||
this.parentIdKey = parentIdKey; | ||
return this; | ||
} | ||
|
||
public MultiVectorRetriever build() { | ||
return new MultiVectorRetriever(this.vectorStore, this.docStore, | ||
this.similarityThreshold, this.topK, this.filterExpression, this.parentIdKey); | ||
} | ||
} | ||
} |