diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/generation/augmentation/ContextualQueryAugmenter.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/generation/augmentation/ContextualQueryAugmenter.java index 42ab33d4d4f..eb6c92a8977 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/generation/augmentation/ContextualQueryAugmenter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/generation/augmentation/ContextualQueryAugmenter.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import org.slf4j.Logger; @@ -75,18 +76,30 @@ public final class ContextualQueryAugmenter implements QueryAugmenter { private static final boolean DEFAULT_ALLOW_EMPTY_CONTEXT = false; + /** + * Default document formatter that just joins document text with newlines + */ + private static final Function, String> DEFAULT_DOCUMENT_FORMATTER = documents -> + documents.stream() + .map(Document::getText) + .collect(Collectors.joining(System.lineSeparator())); + private final PromptTemplate promptTemplate; private final PromptTemplate emptyContextPromptTemplate; private final boolean allowEmptyContext; + private final Function, String> documentFormatter; + public ContextualQueryAugmenter(@Nullable PromptTemplate promptTemplate, - @Nullable PromptTemplate emptyContextPromptTemplate, @Nullable Boolean allowEmptyContext) { + @Nullable PromptTemplate emptyContextPromptTemplate, @Nullable Boolean allowEmptyContext, + @Nullable Function, String> documentFormatter) { this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; this.emptyContextPromptTemplate = emptyContextPromptTemplate != null ? emptyContextPromptTemplate : DEFAULT_EMPTY_CONTEXT_PROMPT_TEMPLATE; this.allowEmptyContext = allowEmptyContext != null ? allowEmptyContext : DEFAULT_ALLOW_EMPTY_CONTEXT; + this.documentFormatter = documentFormatter != null ? documentFormatter : DEFAULT_DOCUMENT_FORMATTER; PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "query", "context"); } @@ -102,9 +115,7 @@ public Query augment(Query query, List documents) { } // 1. Collect content from documents. - String documentContext = documents.stream() - .map(Document::getText) - .collect(Collectors.joining(System.lineSeparator())); + String documentContext = this.documentFormatter.apply(documents); // 2. Define prompt parameters. Map promptParameters = Map.of("query", query.text(), "context", documentContext); @@ -134,6 +145,8 @@ public static class Builder { private Boolean allowEmptyContext; + private Function, String> documentFormatter; + public Builder promptTemplate(PromptTemplate promptTemplate) { this.promptTemplate = promptTemplate; return this; @@ -149,9 +162,14 @@ public Builder allowEmptyContext(Boolean allowEmptyContext) { return this; } + public Builder documentFormatter(Function, String> documentFormatter) { + this.documentFormatter = documentFormatter; + return this; + } + public ContextualQueryAugmenter build() { return new ContextualQueryAugmenter(this.promptTemplate, this.emptyContextPromptTemplate, - this.allowEmptyContext); + this.allowEmptyContext, this.documentFormatter); } }