diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java index 2f9d6ee9291..6a2da274089 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClient.java @@ -16,27 +16,9 @@ package org.springframework.ai.chat.client; -import java.io.IOException; -import java.net.URL; -import java.nio.charset.Charset; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.concurrent.ConcurrentHashMap; -import java.util.function.Consumer; - import io.micrometer.observation.Observation; import io.micrometer.observation.ObservationRegistry; import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; -import org.springframework.ai.tool.ToolCallbacks; -import reactor.core.publisher.Flux; -import reactor.core.scheduler.Schedulers; - import org.springframework.ai.chat.client.advisor.DefaultAroundAdvisorChain; import org.springframework.ai.chat.client.advisor.api.AdvisedRequest; import org.springframework.ai.chat.client.advisor.api.AdvisedResponse; @@ -65,6 +47,7 @@ import org.springframework.ai.model.Media; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackWrapper; +import org.springframework.ai.tool.ToolCallbacks; import org.springframework.core.Ordered; import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.io.Resource; @@ -73,6 +56,22 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.MimeType; import org.springframework.util.StringUtils; +import reactor.core.publisher.Flux; +import reactor.core.scheduler.Schedulers; + +import java.io.IOException; +import java.net.URL; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; /** * The default implementation of {@link ChatClient} as created by the @@ -393,6 +392,8 @@ public static class DefaultCallResponseSpec implements CallResponseSpec { private final DefaultChatClientRequestSpec request; + private final ThreadLocal> memoizedResponse = ThreadLocal.withInitial(Optional::empty); + public DefaultCallResponseSpec(DefaultChatClientRequestSpec request) { Assert.notNull(request, "request cannot be null"); this.request = request; @@ -506,13 +507,16 @@ private static String getContentFromChatResponse(@Nullable ChatResponse chatResp @Override @Nullable public ChatResponse chatResponse() { - return doGetChatResponse(); + final var chatResponse = memoizedResponse.get().orElseGet(this::doGetChatResponse); + memoizedResponse.set(Optional.ofNullable(chatResponse)); + return chatResponse; } @Override @Nullable public String content() { - ChatResponse chatResponse = doGetChatResponse(); + final var chatResponse = memoizedResponse.get().orElseGet(this::doGetChatResponse); + memoizedResponse.set(Optional.ofNullable(chatResponse)); return getContentFromChatResponse(chatResponse); } @@ -522,6 +526,8 @@ public static class DefaultStreamResponseSpec implements StreamResponseSpec { private final DefaultChatClientRequestSpec request; + private final ThreadLocal>> memoizedFlux = ThreadLocal.withInitial(Optional::empty); + public DefaultStreamResponseSpec(DefaultChatClientRequestSpec request) { Assert.notNull(request, "request cannot be null"); this.request = request; @@ -559,12 +565,18 @@ private Flux doGetObservableFluxChatResponse(DefaultChatClientRequ @Override public Flux chatResponse() { - return doGetObservableFluxChatResponse(this.request); + final var chatResponseFlux = memoizedFlux.get() + .orElseGet(() -> doGetObservableFluxChatResponse(this.request)); + memoizedFlux.set(Optional.of(chatResponseFlux)); + return chatResponseFlux; } @Override public Flux content() { - return doGetObservableFluxChatResponse(this.request).map(r -> { + final var chatResponseFlux = memoizedFlux.get() + .orElseGet(() -> doGetObservableFluxChatResponse(this.request)); + memoizedFlux.set(Optional.of(chatResponseFlux)); + return chatResponseFlux.map(r -> { if (r.getResult() == null || r.getResult().getOutput() == null || r.getResult().getOutput().getText() == null) { return "";