diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc index 0165b9d8bdb..ea509063a17 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs/redis.adoc @@ -9,6 +9,8 @@ link:https://redis.io/docs/interact/search-and-query/[Redis Search and Query] ex * Store vectors and the associated metadata within hashes or JSON documents * Retrieve vectors * Perform vector searches +* Cache chat responses based on semantic similarity +* Store and query conversation history == Prerequisites @@ -152,6 +154,69 @@ is converted into the proprietary Redis filter format: @country:{UK | NL} @year:[2020 inf] ---- +=== Semantic Cache Usage + +The semantic cache stores chat responses and returns them for semantically similar queries, using vector similarity search. It is plugged into the chat client as an advisor: + +[source,java] +---- +// Create semantic cache +SemanticCache semanticCache = DefaultSemanticCache.builder() + .embeddingModel(embeddingModel) + .jedisClient(jedisClient) + .similarityThreshold(0.95) // Optional: defaults to 0.95 + .build(); + +// Create cache advisor +SemanticCacheAdvisor cacheAdvisor = SemanticCacheAdvisor.builder() + .cache(semanticCache) + .build(); + +// Use with chat client +ChatResponse response = ChatClient.builder(chatModel) + .build() + .prompt("What is the capital of France?") + .advisors(cacheAdvisor) + .call() + .chatResponse(); + +// Manually interact with cache +semanticCache.set("query", response); +semanticCache.set("query", response, Duration.ofHours(1)); // With TTL +Optional<ChatResponse> cached = semanticCache.get("similar query"); +---- + +=== Chat Memory Usage + +RedisChatMemory provides persistent storage for conversation history: + +[source,java] +---- +// Create chat memory +RedisChatMemory chatMemory = RedisChatMemory.builder() + .jedisClient(jedisClient) + .timeToLive(Duration.ofHours(24)) // 
Optional: message TTL + .indexName("custom-memory-index") // Optional + .keyPrefix("custom-prefix:") // Optional; keys are built as <prefix><conversationId>:<timestamp> + .build(); + +// Add messages +chatMemory.add("conversation-1", new UserMessage("Hello")); +chatMemory.add("conversation-1", new AssistantMessage("Hi there!")); + +// Add multiple messages +chatMemory.add("conversation-1", List.of( + new UserMessage("How are you?"), + new AssistantMessage("I'm doing well!") +)); + +// Retrieve messages +List<Message> messages = chatMemory.get("conversation-1", 10); // Last 10 messages + +// Clear conversation +chatMemory.clear("conversation-1"); +---- + == Manual Configuration Instead of using the Spring Boot auto-configuration, you can manually configure the Redis vector store. For this you need to add the `spring-ai-redis-store` to your project: diff --git a/vector-stores/spring-ai-redis-store/pom.xml b/vector-stores/spring-ai-redis-store/pom.xml index 34b078cd95f..31fbd73bdfe 100644 --- a/vector-stores/spring-ai-redis-store/pom.xml +++ b/vector-stores/spring-ai-redis-store/pom.xml @@ -101,6 +101,13 @@ <scope>test</scope> </dependency> + <dependency> + <groupId>org.springframework.ai</groupId> + <artifactId>spring-ai-openai</artifactId> + <version>${project.parent.version}</version> + <scope>test</scope> + </dependency> + </dependencies> diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisor.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisor.java new file mode 100644 index 00000000000..3f9efb5972b --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisor.java @@ -0,0 +1,188 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.cache.semantic; + +import org.springframework.ai.chat.client.advisor.api.*; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.vectorstore.redis.cache.semantic.SemanticCache; +import reactor.core.publisher.Flux; + +import java.util.Optional; + +/** + * An advisor implementation that provides semantic caching capabilities for chat + * responses. This advisor intercepts chat requests and checks for semantically similar + * cached responses before allowing the request to proceed to the model. + * + *

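<p>
+ * This advisor implements both {@link CallAroundAdvisor} for synchronous operations and + * {@link StreamAroundAdvisor} for reactive streaming operations.
+ * </p>
+ *
+ * <p>
+ * Key features:
+ * <ul>
+ * <li>Returns a cached response when a semantically similar query has already been answered</li>
+ * <li>Stores the model response in the cache after a cache miss</li>
+ * <li>Supports both synchronous and streaming requests</li>
+ * <li>Configurable position in the advisor chain through its order</li>
+ * </ul>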

+ * + * @author Brian Sam-Bodden + */ +public class SemanticCacheAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { + + /** The underlying semantic cache implementation */ + private final SemanticCache cache; + + /** The order of this advisor in the chain */ + private final int order; + + /** + * Creates a new semantic cache advisor with default order. + * @param cache The semantic cache implementation to use + */ + public SemanticCacheAdvisor(SemanticCache cache) { + this(cache, Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER); + } + + /** + * Creates a new semantic cache advisor with specified order. + * @param cache The semantic cache implementation to use + * @param order The order of this advisor in the chain + */ + public SemanticCacheAdvisor(SemanticCache cache, int order) { + this.cache = cache; + this.order = order; + } + + @Override + public String getName() { + return this.getClass().getSimpleName(); + } + + @Override + public int getOrder() { + return this.order; + } + + /** + * Handles synchronous chat requests by checking the cache before proceeding. If a + * semantically similar response is found in the cache, it is returned immediately. + * Otherwise, the request proceeds through the chain and the response is cached. 
+ * @param request The chat request to process + * @param chain The advisor chain to continue processing if needed + * @return The response, either from cache or from the model + */ + @Override + public AdvisedResponse aroundCall(AdvisedRequest request, CallAroundAdvisorChain chain) { + // Check cache first + Optional cached = cache.get(request.userText()); + + if (cached.isPresent()) { + return new AdvisedResponse(cached.get(), request.adviseContext()); + } + + // Cache miss - call the model + AdvisedResponse response = chain.nextAroundCall(request); + + // Cache the response + if (response.response() != null) { + cache.set(request.userText(), response.response()); + } + + return response; + } + + /** + * Handles streaming chat requests by checking the cache before proceeding. If a + * semantically similar response is found in the cache, it is returned as a single + * item flux. Otherwise, the request proceeds through the chain and the final response + * is cached. + * @param request The chat request to process + * @param chain The advisor chain to continue processing if needed + * @return A Flux of responses, either from cache or from the model + */ + @Override + public Flux aroundStream(AdvisedRequest request, StreamAroundAdvisorChain chain) { + // Check cache first + Optional cached = cache.get(request.userText()); + + if (cached.isPresent()) { + return Flux.just(new AdvisedResponse(cached.get(), request.adviseContext())); + } + + // Cache miss - stream from model + return chain.nextAroundStream(request).collectList().flatMapMany(responses -> { + // Cache the final aggregated response + if (!responses.isEmpty()) { + AdvisedResponse last = responses.get(responses.size() - 1); + if (last.response() != null) { + cache.set(request.userText(), last.response()); + } + } + return Flux.fromIterable(responses); + }); + } + + /** + * Creates a new builder for constructing SemanticCacheAdvisor instances. 
+ * @return A new builder instance + */ + public static Builder builder() { + return new Builder(); + } + + /** + * Builder class for creating SemanticCacheAdvisor instances. Provides a fluent API + * for configuration. + */ + public static class Builder { + + private SemanticCache cache; + + private int order = Advisor.DEFAULT_CHAT_MEMORY_PRECEDENCE_ORDER; + + /** + * Sets the semantic cache implementation. + * @param cache The cache implementation to use + * @return This builder instance + */ + public Builder cache(SemanticCache cache) { + this.cache = cache; + return this; + } + + /** + * Sets the advisor order. + * @param order The order value for this advisor + * @return This builder instance + */ + public Builder order(int order) { + this.order = order; + return this; + } + + /** + * Builds and returns a new SemanticCacheAdvisor instance. + * @return A new SemanticCacheAdvisor configured with this builder's settings + */ + public SemanticCacheAdvisor build() { + return new SemanticCacheAdvisor(cache, order); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemory.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemory.java new file mode 100644 index 00000000000..a0fc4e3418e --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemory.java @@ -0,0 +1,228 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.redis; + +import com.google.gson.Gson; +import com.google.gson.JsonObject; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.memory.ChatMemory; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.MessageType; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.util.Assert; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.Pipeline; +import redis.clients.jedis.json.Path2; +import redis.clients.jedis.search.*; +import redis.clients.jedis.search.schemafields.NumericField; +import redis.clients.jedis.search.schemafields.SchemaField; +import redis.clients.jedis.search.schemafields.TagField; +import redis.clients.jedis.search.schemafields.TextField; + +import java.time.Duration; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Redis implementation of {@link ChatMemory} using Redis Stack (RedisJSON + RediSearch). + * Stores chat messages as JSON documents and uses RediSearch for querying. 
+ * + * @author Brian Sam-Bodden + */ +public final class RedisChatMemory implements ChatMemory { + + private static final Logger logger = LoggerFactory.getLogger(RedisChatMemory.class); + + private static final Gson gson = new Gson(); + + private static final Path2 ROOT_PATH = Path2.of("$"); + + private final RedisChatMemoryConfig config; + + private final JedisPooled jedis; + + public RedisChatMemory(RedisChatMemoryConfig config) { + Assert.notNull(config, "Config must not be null"); + this.config = config; + this.jedis = config.getJedisClient(); + + if (config.isInitializeSchema()) { + initializeSchema(); + } + } + + public static Builder builder() { + return new Builder(); + } + + @Override + public void add(String conversationId, List messages) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + Assert.notNull(messages, "Messages must not be null"); + + final AtomicLong timestampSequence = new AtomicLong(Instant.now().toEpochMilli()); + try (Pipeline pipeline = jedis.pipelined()) { + for (Message message : messages) { + String key = createKey(conversationId, timestampSequence.getAndIncrement()); + String json = gson.toJson(createMessageDocument(conversationId, message)); + pipeline.jsonSet(key, ROOT_PATH, json); + + if (config.getTimeToLiveSeconds() != -1) { + pipeline.expire(key, config.getTimeToLiveSeconds()); + } + } + pipeline.sync(); + } + } + + @Override + public void add(String conversationId, Message message) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + Assert.notNull(message, "Message must not be null"); + + String key = createKey(conversationId, Instant.now().toEpochMilli()); + String json = gson.toJson(createMessageDocument(conversationId, message)); + + jedis.jsonSet(key, ROOT_PATH, json); + if (config.getTimeToLiveSeconds() != -1) { + jedis.expire(key, config.getTimeToLiveSeconds()); + } + } + + @Override + public List get(String conversationId, int lastN) { + Assert.notNull(conversationId, 
"Conversation ID must not be null"); + Assert.isTrue(lastN > 0, "LastN must be greater than 0"); + + String queryStr = String.format("@conversation_id:{%s}", RediSearchUtil.escape(conversationId)); + Query query = new Query(queryStr).setSortBy("timestamp", true).limit(0, lastN); + + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + + List messages = new ArrayList<>(); + result.getDocuments().forEach(doc -> { + if (doc.get("$") != null) { + JsonObject json = gson.fromJson(doc.getString("$"), JsonObject.class); + String type = json.get("type").getAsString(); + String content = json.get("content").getAsString(); + + if (MessageType.ASSISTANT.toString().equals(type)) { + messages.add(new AssistantMessage(content)); + } + else if (MessageType.USER.toString().equals(type)) { + messages.add(new UserMessage(content)); + } + } + }); + + return messages; + } + + @Override + public void clear(String conversationId) { + Assert.notNull(conversationId, "Conversation ID must not be null"); + + String queryStr = String.format("@conversation_id:{%s}", RediSearchUtil.escape(conversationId)); + Query query = new Query(queryStr); + SearchResult result = jedis.ftSearch(config.getIndexName(), query); + + try (Pipeline pipeline = jedis.pipelined()) { + result.getDocuments().forEach(doc -> pipeline.del(doc.getId())); + pipeline.sync(); + } + } + + private void initializeSchema() { + try { + if (!jedis.ftList().contains(config.getIndexName())) { + List schemaFields = new ArrayList<>(); + schemaFields.add(new TextField("$.content").as("content")); + schemaFields.add(new TextField("$.type").as("type")); + schemaFields.add(new TagField("$.conversation_id").as("conversation_id")); + schemaFields.add(new NumericField("$.timestamp").as("timestamp")); + + String response = jedis.ftCreate(config.getIndexName(), + FTCreateParams.createParams().on(IndexDataType.JSON).prefix(config.getKeyPrefix()), + schemaFields.toArray(new SchemaField[0])); + + if (!response.equals("OK")) { + 
throw new IllegalStateException("Failed to create index: " + response); + } + } + } + catch (Exception e) { + logger.error("Failed to initialize Redis schema", e); + throw new IllegalStateException("Could not initialize Redis schema", e); + } + } + + private String createKey(String conversationId, long timestamp) { + return String.format("%s%s:%d", config.getKeyPrefix(), escapeKey(conversationId), timestamp); + } + + private Map createMessageDocument(String conversationId, Message message) { + return Map.of("type", message.getMessageType().toString(), "content", message.getText(), "conversation_id", + conversationId, "timestamp", Instant.now().toEpochMilli()); + } + + private String escapeKey(String key) { + return key.replace(":", "\\:"); + } + + /** + * Builder for RedisChatMemory configuration. + */ + public static class Builder { + + private final RedisChatMemoryConfig.Builder configBuilder = RedisChatMemoryConfig.builder(); + + public Builder jedisClient(JedisPooled jedisClient) { + configBuilder.jedisClient(jedisClient); + return this; + } + + public Builder timeToLive(Duration ttl) { + configBuilder.timeToLive(ttl); + return this; + } + + public Builder indexName(String indexName) { + configBuilder.indexName(indexName); + return this; + } + + public Builder keyPrefix(String keyPrefix) { + configBuilder.keyPrefix(keyPrefix); + return this; + } + + public Builder initializeSchema(boolean initialize) { + configBuilder.initializeSchema(initialize); + return this; + } + + public RedisChatMemory build() { + return new RedisChatMemory(configBuilder.build()); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryConfig.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryConfig.java new file mode 100644 index 00000000000..fe4323d5418 --- /dev/null +++ 
b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryConfig.java @@ -0,0 +1,158 @@
/*
 * Copyright 2023-2025 the original author or authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.springframework.ai.chat.memory.redis;

import java.time.Duration;

import redis.clients.jedis.JedisPooled;

import org.springframework.util.Assert;

/**
 * Configuration class for RedisChatMemory. Instances are immutable and created through
 * {@link #builder()}.
 *
 * @author Brian Sam-Bodden
 */
public class RedisChatMemoryConfig {

	public static final String DEFAULT_INDEX_NAME = "chat-memory-idx";

	public static final String DEFAULT_KEY_PREFIX = "chat-memory:";

	private final JedisPooled jedisClient;

	private final String indexName;

	private final String keyPrefix;

	/** Message time-to-live in seconds; {@code -1} means messages never expire. */
	private final Integer timeToLiveSeconds;

	private final boolean initializeSchema;

	private RedisChatMemoryConfig(Builder builder) {
		Assert.notNull(builder.jedisClient, "JedisPooled client must not be null");
		Assert.hasText(builder.indexName, "Index name must not be empty");
		Assert.hasText(builder.keyPrefix, "Key prefix must not be empty");

		this.jedisClient = builder.jedisClient;
		this.indexName = builder.indexName;
		this.keyPrefix = builder.keyPrefix;
		this.timeToLiveSeconds = builder.timeToLiveSeconds;
		this.initializeSchema = builder.initializeSchema;
	}

	public static Builder builder() {
		return new Builder();
	}

	public JedisPooled getJedisClient() {
		return this.jedisClient;
	}

	public String getIndexName() {
		return this.indexName;
	}

	public String getKeyPrefix() {
		return this.keyPrefix;
	}

	/**
	 * Returns the message time-to-live.
	 * @return the time-to-live in seconds, or {@code -1} when no expiry is configured
	 */
	public Integer getTimeToLiveSeconds() {
		return this.timeToLiveSeconds;
	}

	public boolean isInitializeSchema() {
		return this.initializeSchema;
	}

	/**
	 * Builder for RedisChatMemoryConfig.
	 */
	public static class Builder {

		private JedisPooled jedisClient;

		private String indexName = DEFAULT_INDEX_NAME;

		private String keyPrefix = DEFAULT_KEY_PREFIX;

		private Integer timeToLiveSeconds = -1;

		private boolean initializeSchema = true;

		/**
		 * Sets the Redis client.
		 * @param jedisClient the Redis client to use
		 * @return the builder instance
		 */
		public Builder jedisClient(JedisPooled jedisClient) {
			this.jedisClient = jedisClient;
			return this;
		}

		/**
		 * Sets the index name.
		 * @param indexName the index name to use
		 * @return the builder instance
		 */
		public Builder indexName(String indexName) {
			this.indexName = indexName;
			return this;
		}

		/**
		 * Sets the key prefix.
		 * @param keyPrefix the key prefix to use
		 * @return the builder instance
		 */
		public Builder keyPrefix(String keyPrefix) {
			this.keyPrefix = keyPrefix;
			return this;
		}

		/**
		 * Sets the time-to-live duration. The duration is stored with second
		 * resolution; a {@code null} argument leaves the current setting untouched.
		 * @param ttl the time-to-live duration, at least one second
		 * @return the builder instance
		 * @throws IllegalArgumentException if {@code ttl} is shorter than one second
		 */
		public Builder timeToLive(Duration ttl) {
			if (ttl != null) {
				long seconds = ttl.toSeconds();
				// An expiry of zero or less would make Redis drop each message as soon
				// as it is written, so reject it instead of silently losing data.
				Assert.isTrue(seconds > 0, "Time-to-live must be at least one second");
				// Clamp instead of letting the narrowing cast wrap to a negative value.
				this.timeToLiveSeconds = (int) Math.min(seconds, Integer.MAX_VALUE);
			}
			return this;
		}

		/**
		 * Sets whether to initialize the schema.
		 * @param initialize true to initialize schema, false otherwise
		 * @return the builder instance
		 */
		public Builder initializeSchema(boolean initialize) {
			this.initializeSchema = initialize;
			return this;
		}

		/**
		 * Builds a new RedisChatMemoryConfig instance.
		 * @return the new configuration instance
		 * @throws IllegalArgumentException if the client is missing or the index name
		 * or key prefix is blank
		 */
		public RedisChatMemoryConfig build() {
			return new RedisChatMemoryConfig(this);
		}

	}

}
diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/DefaultSemanticCache.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/DefaultSemanticCache.java new file mode 100644 index 00000000000..1309cb6dab5 --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/DefaultSemanticCache.java @@ -0,0 +1,354 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package org.springframework.ai.vectorstore.redis.cache.semantic; + +import com.google.gson.*; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.SearchRequest; +import org.springframework.ai.vectorstore.VectorStore; +import org.springframework.ai.vectorstore.redis.RedisVectorStore; +import org.springframework.ai.vectorstore.redis.RedisVectorStore.MetadataField; +import redis.clients.jedis.JedisPooled; +import redis.clients.jedis.Pipeline; +import redis.clients.jedis.search.Query; +import redis.clients.jedis.search.SearchResult; + +import java.lang.reflect.Type; +import java.time.Duration; +import java.util.*; + +/** + * Default implementation of SemanticCache using Redis as the backing store. This + * implementation uses vector similarity search to find cached responses for semantically + * similar queries. + * + * @author Brian Sam-Bodden + */ +public class DefaultSemanticCache implements SemanticCache { + + // Default configuration constants + private static final String DEFAULT_INDEX_NAME = "semantic-cache-index"; + + private static final String DEFAULT_PREFIX = "semantic-cache:"; + + private static final Integer DEFAULT_BATCH_SIZE = 100; + + private static final double DEFAULT_SIMILARITY_THRESHOLD = 0.95; + + // Core components + private final VectorStore vectorStore; + + private final EmbeddingModel embeddingModel; + + private final double similarityThreshold; + + private final Gson gson; + + private final String prefix; + + private final String indexName; + + /** + * Private constructor enforcing builder pattern usage. 
+ */ + private DefaultSemanticCache(VectorStore vectorStore, EmbeddingModel embeddingModel, double similarityThreshold, + String indexName, String prefix) { + this.vectorStore = vectorStore; + this.embeddingModel = embeddingModel; + this.similarityThreshold = similarityThreshold; + this.prefix = prefix; + this.indexName = indexName; + this.gson = createGson(); + } + + /** + * Creates a customized Gson instance with type adapters for special types. + */ + private Gson createGson() { + return new GsonBuilder() // + .registerTypeAdapter(Duration.class, new DurationAdapter()) // + .registerTypeAdapter(ChatResponse.class, new ChatResponseAdapter()) // + .create(); + } + + @Override + public VectorStore getStore() { + return this.vectorStore; + } + + @Override + public void set(String query, ChatResponse response) { + // Convert response to JSON for storage + String responseJson = gson.toJson(response); + String responseText = response.getResult().getOutput().getText(); + + // Create metadata map for the document + Map metadata = new HashMap<>(); + metadata.put("response", responseJson); + metadata.put("response_text", responseText); + + // Create document with query as text (for embedding) and response in metadata + Document document = Document.builder().text(query).metadata(metadata).build(); + + // Check for and remove any existing similar documents + List existing = vectorStore.similaritySearch( + SearchRequest.builder().query(query).topK(1).similarityThreshold(similarityThreshold).build()); + + // If similar document exists, delete it first + if (!existing.isEmpty()) { + vectorStore.delete(List.of(existing.get(0).getId())); + } + + // Add new document to vector store + vectorStore.add(List.of(document)); + } + + @Override + public void set(String query, ChatResponse response, Duration ttl) { + // Generate a unique ID for the document + String docId = UUID.randomUUID().toString(); + + // Convert response to JSON + String responseJson = gson.toJson(response); + String 
responseText = response.getResult().getOutput().getText(); + + // Create metadata + Map metadata = new HashMap<>(); + metadata.put("response", responseJson); + metadata.put("response_text", responseText); + + // Create document with generated ID + Document document = Document.builder().id(docId).text(query).metadata(metadata).build(); + + // Remove any existing similar documents + List existing = vectorStore.similaritySearch( + SearchRequest.builder().query(query).topK(1).similarityThreshold(similarityThreshold).build()); + + // If similar document exists, delete it first + if (!existing.isEmpty()) { + vectorStore.delete(List.of(existing.get(0).getId())); + } + + // Add document to vector store + vectorStore.add(List.of(document)); + + // Get access to Redis client and set TTL + if (vectorStore instanceof RedisVectorStore redisStore) { + String key = prefix + docId; + redisStore.getJedis().expire(key, ttl.getSeconds()); + } + } + + @Override + public Optional get(String query) { + // Search for similar documents + List similar = vectorStore.similaritySearch( + SearchRequest.builder().query(query).topK(1).similarityThreshold(similarityThreshold).build()); + + if (similar.isEmpty()) { + return Optional.empty(); + } + + Document mostSimilar = similar.get(0); + + // Get stored response JSON from metadata + String responseJson = (String) mostSimilar.getMetadata().get("response"); + if (responseJson == null) { + return Optional.empty(); + } + + // Attempt to parse stored response + try { + ChatResponse response = gson.fromJson(responseJson, ChatResponse.class); + return Optional.of(response); + } + catch (JsonParseException e) { + return Optional.empty(); + } + } + + @Override + public void clear() { + Optional nativeClient = vectorStore.getNativeClient(); + if (nativeClient.isPresent()) { + JedisPooled jedis = nativeClient.get(); + + // Delete documents in batches to avoid memory issues + boolean moreRecords = true; + while (moreRecords) { + Query query = new 
Query("*"); + query.limit(0, DEFAULT_BATCH_SIZE); // Reasonable batch size + query.setNoContent(); + + SearchResult searchResult = jedis.ftSearch(this.indexName, query); + + if (searchResult.getTotalResults() > 0) { + try (Pipeline pipeline = jedis.pipelined()) { + for (redis.clients.jedis.search.Document doc : searchResult.getDocuments()) { + pipeline.jsonDel(doc.getId()); + } + pipeline.syncAndReturnAll(); + } + } + else { + moreRecords = false; + } + } + } + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for creating DefaultSemanticCache instances. + */ + public static class Builder { + + private VectorStore vectorStore; + + private EmbeddingModel embeddingModel; + + private double similarityThreshold = DEFAULT_SIMILARITY_THRESHOLD; + + private String indexName = DEFAULT_INDEX_NAME; + + private String prefix = DEFAULT_PREFIX; + + private JedisPooled jedisClient; + + // Builder methods with validation + public Builder vectorStore(VectorStore vectorStore) { + this.vectorStore = vectorStore; + return this; + } + + public Builder embeddingModel(EmbeddingModel embeddingModel) { + this.embeddingModel = embeddingModel; + return this; + } + + public Builder similarityThreshold(double threshold) { + this.similarityThreshold = threshold; + return this; + } + + public Builder indexName(String indexName) { + this.indexName = indexName; + return this; + } + + public Builder prefix(String prefix) { + this.prefix = prefix; + return this; + } + + public Builder jedisClient(JedisPooled jedisClient) { + this.jedisClient = jedisClient; + return this; + } + + public DefaultSemanticCache build() { + if (vectorStore == null) { + if (jedisClient == null) { + throw new IllegalStateException("Either vectorStore or jedisClient must be provided"); + } + if (embeddingModel == null) { + throw new IllegalStateException("EmbeddingModel must be provided"); + } + vectorStore = RedisVectorStore.builder(jedisClient, embeddingModel) + .indexName(indexName) 
+ .prefix(prefix) + .metadataFields( // + MetadataField.text("response"), // + MetadataField.text("response_text"), // + MetadataField.numeric("ttl")) // + .initializeSchema(true) + .build(); + if (vectorStore instanceof RedisVectorStore redisStore) { + redisStore.afterPropertiesSet(); + } + } + return new DefaultSemanticCache(vectorStore, embeddingModel, similarityThreshold, indexName, prefix); + } + + } + + /** + * Type adapter for serializing/deserializing Duration objects. + */ + private static class DurationAdapter implements JsonSerializer, JsonDeserializer { + + @Override + public JsonElement serialize(Duration duration, Type type, JsonSerializationContext context) { + return new JsonPrimitive(duration.toSeconds()); + } + + @Override + public Duration deserialize(JsonElement json, Type type, JsonDeserializationContext context) + throws JsonParseException { + return Duration.ofSeconds(json.getAsLong()); + } + + } + + /** + * Type adapter for serializing/deserializing ChatResponse objects. 
+ */ + private static class ChatResponseAdapter implements JsonSerializer, JsonDeserializer { + + @Override + public JsonElement serialize(ChatResponse response, Type type, JsonSerializationContext context) { + JsonObject jsonObject = new JsonObject(); + + // Handle generations + JsonArray generations = new JsonArray(); + for (Generation generation : response.getResults()) { + JsonObject generationObj = new JsonObject(); + Message output = (Message) generation.getOutput(); + generationObj.addProperty("text", output.getText()); + generations.add(generationObj); + } + jsonObject.add("generations", generations); + + return jsonObject; + } + + @Override + public ChatResponse deserialize(JsonElement json, Type type, JsonDeserializationContext context) + throws JsonParseException { + JsonObject jsonObject = json.getAsJsonObject(); + + List generations = new ArrayList<>(); + JsonArray generationsArray = jsonObject.getAsJsonArray("generations"); + for (JsonElement element : generationsArray) { + JsonObject generationObj = element.getAsJsonObject(); + String text = generationObj.get("text").getAsString(); + generations.add(new Generation(new AssistantMessage(text))); + } + + return ChatResponse.builder().generations(generations).build(); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/SemanticCache.java b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/SemanticCache.java new file mode 100644 index 00000000000..d678107a9a7 --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/main/java/org/springframework/ai/vectorstore/redis/cache/semantic/SemanticCache.java @@ -0,0 +1,91 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.vectorstore.redis.cache.semantic; + +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.vectorstore.VectorStore; + +import java.time.Duration; +import java.util.Optional; + +/** + * Interface defining operations for a semantic cache implementation that stores and + * retrieves chat responses based on semantic similarity of queries. This cache uses + * vector embeddings to determine similarity between queries. + * + *

<p>
+ * The semantic cache provides functionality to:
+ * <ul>
+ * <li>Store chat responses with their associated queries</li>
+ * <li>Retrieve responses for semantically similar queries</li>
+ * <li>Support time-based expiration of cached entries</li>
+ * <li>Clear the entire cache</li>
+ * </ul>
+ *
+ * <p>

+ * Implementations should ensure thread-safety and proper resource management. + * + * @author Brian Sam-Bodden + */ +public interface SemanticCache { + + /** + * Stores a query and its corresponding chat response in the cache. Implementations + * should handle vector embedding of the query and proper storage of both the query + * embedding and response. + * @param query The original query text to be cached + * @param response The chat response associated with the query + */ + void set(String query, ChatResponse response); + + /** + * Stores a query and response in the cache with a specified time-to-live duration. + * After the TTL expires, the entry should be automatically removed from the cache. + * @param query The original query text to be cached + * @param response The chat response associated with the query + * @param ttl The duration after which the cache entry should expire + */ + void set(String query, ChatResponse response, Duration ttl); + + /** + * Retrieves a cached response for a semantically similar query. The implementation + * should: + *

<ul>
+ * <li>Convert the input query to a vector embedding</li>
+ * <li>Search for similar query embeddings in the cache</li>
+ * <li>Return the response associated with the most similar query if it meets the
+ * similarity threshold</li>
+ * </ul>
+ * @param query The query to find similar responses for + * @return Optional containing the most similar cached response if found and meets + * similarity threshold, empty Optional otherwise + */ + Optional get(String query); + + /** + * Removes all entries from the cache. This operation should be atomic and + * thread-safe. + */ + void clear(); + + /** + * Returns the underlying vector store used by this cache implementation. This allows + * access to lower-level vector operations if needed. + * @return The VectorStore instance used by this cache + */ + VectorStore getStore(); + +} diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java new file mode 100644 index 00000000000..138e7eb7856 --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/cache/semantic/SemanticCacheAdvisorIT.java @@ -0,0 +1,226 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.chat.cache.semantic; + +import com.redis.testcontainers.RedisStackContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.cache.semantic.SemanticCacheAdvisorIT.TestApplication; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.OpenAiEmbeddingModel; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.vectorstore.redis.cache.semantic.DefaultSemanticCache; +import org.springframework.ai.vectorstore.redis.cache.semantic.SemanticCache; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.data.redis.RedisAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.time.Duration; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import 
java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Test the Redis-based advisor that provides semantic caching capabilities for chat + * responses + * + * @author Brian Sam-Bodden + */ +@Testcontainers +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class SemanticCacheAdvisorIT { + + @Container + static RedisStackContainer redisContainer = new RedisStackContainer( + RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withConfiguration(AutoConfigurations.of(RedisAutoConfiguration.class)) + .withUserConfiguration(TestApplication.class) + .withPropertyValues("spring.data.redis.url=" + redisContainer.getRedisURI()); + + @Autowired + OpenAiChatModel openAiChatModel; + + @Autowired + SemanticCache semanticCache; + + @AfterEach + void tearDown() { + semanticCache.clear(); + } + + @Test + void semanticCacheTest() { + this.contextRunner.run(context -> { + String question = "What is the capital of France?"; + String expectedResponse = "Paris is the capital of France."; + + // First, simulate a cached response + semanticCache.set(question, createMockResponse(expectedResponse)); + + // Create advisor + SemanticCacheAdvisor cacheAdvisor = SemanticCacheAdvisor.builder().cache(semanticCache).build(); + + // Test with a semantically similar question + String similarQuestion = "Tell me which city is France's capital?"; + ChatResponse chatResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(similarQuestion) + .advisors(cacheAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + String response = chatResponse.getResult().getOutput().getText(); + assertThat(response).containsIgnoringCase("Paris"); + + // Test cache miss with a different question + String differentQuestion = "What is the population of Tokyo?"; + 
ChatResponse newResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(differentQuestion) + .advisors(cacheAdvisor) + .call() + .chatResponse(); + + assertThat(newResponse).isNotNull(); + String newResponseText = newResponse.getResult().getOutput().getText(); + assertThat(newResponseText).doesNotContain(expectedResponse); + + // Verify the new response was cached + ChatResponse cachedNewResponse = semanticCache.get(differentQuestion).orElseThrow(); + assertThat(cachedNewResponse.getResult().getOutput().getText()) + .isEqualTo(newResponse.getResult().getOutput().getText()); + }); + } + + @Test + void semanticCacheTTLTest() throws InterruptedException { + this.contextRunner.run(context -> { + String question = "What is the capital of France?"; + String expectedResponse = "Paris is the capital of France."; + + // Set with short TTL + semanticCache.set(question, createMockResponse(expectedResponse), Duration.ofSeconds(2)); + + // Create advisor + SemanticCacheAdvisor cacheAdvisor = SemanticCacheAdvisor.builder().cache(semanticCache).build(); + + // Verify key exists + Optional nativeClient = semanticCache.getStore().getNativeClient(); + assertThat(nativeClient).isPresent(); + JedisPooled jedis = nativeClient.get(); + + Set keys = jedis.keys("semantic-cache:*"); + assertThat(keys).hasSize(1); + String key = keys.iterator().next(); + + // Verify TTL is set + Long ttl = jedis.ttl(key); + assertThat(ttl).isGreaterThan(0); + assertThat(ttl).isLessThanOrEqualTo(2); + + // Test cache hit before expiry + String similarQuestion = "Tell me which city is France's capital?"; + ChatResponse chatResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(similarQuestion) + .advisors(cacheAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + assertThat(chatResponse.getResult().getOutput().getText()).containsIgnoringCase("Paris"); + + // Wait for TTL to expire + Thread.sleep(2100); + + // Verify key is gone + 
assertThat(jedis.exists(key)).isFalse(); + + // Should get a cache miss and new response + ChatResponse newResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(similarQuestion) + .advisors(cacheAdvisor) + .call() + .chatResponse(); + + assertThat(newResponse).isNotNull(); + assertThat(newResponse.getResult().getOutput().getText()).containsIgnoringCase("Paris"); + // Original cached response should be gone, this should be a fresh response + }); + } + + private ChatResponse createMockResponse(String text) { + return ChatResponse.builder().generations(List.of(new Generation(new AssistantMessage(text)))).build(); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + public static class TestApplication { + + @Bean + public SemanticCache semanticCache(EmbeddingModel embeddingModel, + JedisConnectionFactory jedisConnectionFactory) { + JedisPooled jedisPooled = new JedisPooled(Objects.requireNonNull(jedisConnectionFactory.getPoolConfig()), + jedisConnectionFactory.getHostName(), jedisConnectionFactory.getPort()); + + return DefaultSemanticCache.builder().embeddingModel(embeddingModel).jedisClient(jedisPooled).build(); + } + + @Bean(name = "openAiEmbeddingModel") + public EmbeddingModel embeddingModel() { + return new OpenAiEmbeddingModel(new OpenAiApi(System.getenv("OPENAI_API_KEY"))); + } + + @Bean(name = "openAiChatModel") + public OpenAiChatModel openAiChatModel() { + var openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY")); + var openAiChatOptions = OpenAiChatOptions.builder() + .model("gpt-3.5-turbo") + .temperature(0.4) + .maxTokens(200) + .build(); + return new OpenAiChatModel(openAiApi, openAiChatOptions); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryIT.java new file mode 100644 index 
00000000000..dfc9f0c1af8 --- /dev/null +++ b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/chat/memory/redis/RedisChatMemoryIT.java @@ -0,0 +1,227 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.chat.memory.redis; + +import com.redis.testcontainers.RedisStackContainer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.autoconfigure.jdbc.DataSourceAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.time.Duration; +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for RedisChatMemory using Redis Stack TestContainer. 
+ * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisChatMemoryIT { + + @Container + static RedisStackContainer redisContainer = new RedisStackContainer( + RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)); + + private final ApplicationContextRunner contextRunner = new ApplicationContextRunner() + .withUserConfiguration(TestApplication.class); + + private RedisChatMemory chatMemory; + + private JedisPooled jedisClient; + + @BeforeEach + void setUp() { + jedisClient = new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()); + chatMemory = RedisChatMemory.builder() + .jedisClient(jedisClient) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .build(); + + chatMemory.clear("test-conversation"); + } + + @AfterEach + void tearDown() { + if (jedisClient != null) { + jedisClient.close(); + } + } + + @Test + void shouldStoreAndRetrieveMessages() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Add messages + chatMemory.add(conversationId, new UserMessage("Hello")); + chatMemory.add(conversationId, new AssistantMessage("Hi there!")); + chatMemory.add(conversationId, new UserMessage("How are you?")); + + // Retrieve messages + List messages = chatMemory.get(conversationId, 10); + + assertThat(messages).hasSize(3); + assertThat(messages.get(0).getText()).isEqualTo("Hello"); + assertThat(messages.get(1).getText()).isEqualTo("Hi there!"); + assertThat(messages.get(2).getText()).isEqualTo("How are you?"); + }); + } + + @Test + void shouldRespectMessageLimit() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Add messages + chatMemory.add(conversationId, new UserMessage("Message 1")); + chatMemory.add(conversationId, new AssistantMessage("Message 2")); + chatMemory.add(conversationId, new UserMessage("Message 3")); + + // Retrieve limited messages + List messages = chatMemory.get(conversationId, 2); + + 
assertThat(messages).hasSize(2); + }); + } + + @Test + void shouldClearConversation() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + + // Add messages + chatMemory.add(conversationId, new UserMessage("Hello")); + chatMemory.add(conversationId, new AssistantMessage("Hi")); + + // Clear conversation + chatMemory.clear(conversationId); + + // Verify messages are cleared + List messages = chatMemory.get(conversationId, 10); + assertThat(messages).isEmpty(); + }); + } + + @Test + void shouldHandleBatchMessageAddition() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + List messageBatch = List.of(new UserMessage("Message 1"), // + new AssistantMessage("Response 1"), // + new UserMessage("Message 2"), // + new AssistantMessage("Response 2") // + ); + + // Add batch of messages + chatMemory.add(conversationId, messageBatch); + + // Verify all messages were stored + List retrievedMessages = chatMemory.get(conversationId, 10); + assertThat(retrievedMessages).hasSize(4); + }); + } + + @Test + void shouldHandleTimeToLive() throws InterruptedException { + this.contextRunner.run(context -> { + RedisChatMemory shortTtlMemory = RedisChatMemory.builder() + .jedisClient(jedisClient) + .indexName("test-ttl-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .timeToLive(Duration.ofSeconds(2)) + .keyPrefix("short-lived:") + .build(); + + String conversationId = "test-conversation"; + shortTtlMemory.add(conversationId, new UserMessage("This should expire")); + + // Verify message exists + assertThat(shortTtlMemory.get(conversationId, 1)).hasSize(1); + + // Wait for TTL to expire + Thread.sleep(2000); + + // Verify message is gone + assertThat(shortTtlMemory.get(conversationId, 1)).isEmpty(); + }); + } + + @Test + void shouldMaintainMessageOrder() { + this.contextRunner.run(context -> { + String conversationId = "test-conversation"; + // Add messages with minimal delay to test timestamp ordering + 
chatMemory.add(conversationId, new UserMessage("First")); + Thread.sleep(10); + chatMemory.add(conversationId, new AssistantMessage("Second")); + Thread.sleep(10); + chatMemory.add(conversationId, new UserMessage("Third")); + + List messages = chatMemory.get(conversationId, 10); + assertThat(messages).hasSize(3); + assertThat(messages.get(0).getText()).isEqualTo("First"); + assertThat(messages.get(1).getText()).isEqualTo("Second"); + assertThat(messages.get(2).getText()).isEqualTo("Third"); + }); + } + + @Test + void shouldHandleMultipleConversations() { + this.contextRunner.run(context -> { + String conv1 = "conversation-1"; + String conv2 = "conversation-2"; + + chatMemory.add(conv1, new UserMessage("Conv1 Message")); + chatMemory.add(conv2, new UserMessage("Conv2 Message")); + + List conv1Messages = chatMemory.get(conv1, 10); + List conv2Messages = chatMemory.get(conv2, 10); + + assertThat(conv1Messages).hasSize(1); + assertThat(conv2Messages).hasSize(1); + assertThat(conv1Messages.get(0).getText()).isEqualTo("Conv1 Message"); + assertThat(conv2Messages.get(0).getText()).isEqualTo("Conv2 Message"); + }); + } + + @SpringBootConfiguration + @EnableAutoConfiguration(exclude = { DataSourceAutoConfiguration.class }) + static class TestApplication { + + @Bean + RedisChatMemory chatMemory() { + return RedisChatMemory.builder() + .jedisClient(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort())) + .indexName("test-" + RedisChatMemoryConfig.DEFAULT_INDEX_NAME) + .timeToLive(Duration.ofMinutes(5)) + .build(); + } + + } + +} diff --git a/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java new file mode 100644 index 00000000000..34f57a7b96f --- /dev/null +++ 
b/vector-stores/spring-ai-redis-store/src/test/java/org/springframework/ai/vectorstore/redis/RedisVectorStoreWithChatMemoryAdvisorIT.java @@ -0,0 +1,133 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.redis; + +import com.redis.testcontainers.RedisStackContainer; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mockito; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.VectorStoreChatMemoryAdvisor; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.SystemMessage; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.document.Document; +import org.springframework.ai.embedding.EmbeddingModel; +import org.springframework.ai.vectorstore.redis.RedisVectorStore.MetadataField; +import org.testcontainers.junit.jupiter.Container; +import org.testcontainers.junit.jupiter.Testcontainers; +import redis.clients.jedis.JedisPooled; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static 
org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Integration tests for RedisVectorStore using Redis Stack TestContainer. + * + * @author Brian Sam-Bodden + */ +@Testcontainers +class RedisVectorStoreWithChatMemoryAdvisorIT { + + @Container + static RedisStackContainer redisContainer = new RedisStackContainer( + RedisStackContainer.DEFAULT_IMAGE_NAME.withTag(RedisStackContainer.DEFAULT_TAG)); + + float[] embed = { 0.003961659F, -0.0073295482F, 0.02663665F }; + + @Test + @DisplayName("Advised chat should have similar messages from vector store") + void advisedChatShouldHaveSimilarMessagesFromVectorStore() throws Exception { + // Mock chat model + ChatModel chatModel = chatModelAlwaysReturnsTheSameReply(); + // Mock embedding model + EmbeddingModel embeddingModel = embeddingModelShouldAlwaysReturnFakedEmbed(); + + // Create Redis store with dimensions matching our fake embeddings + RedisVectorStore store = RedisVectorStore + .builder(new JedisPooled(redisContainer.getHost(), redisContainer.getFirstMappedPort()), embeddingModel) + .metadataFields(MetadataField.tag("conversationId"), MetadataField.tag("messageType")) + .initializeSchema(true) + .build(); + + store.afterPropertiesSet(); + + // Initialize store with test data + store.add(List.of(new Document("Tell me a good joke", Map.of("conversationId", "default")), + new Document("Tell me a bad joke", Map.of("conversationId", "default", "messageType", "USER")))); + + // Run chat with advisor + ChatClient.builder(chatModel) + .build() + .prompt() + .user("joke") + .advisors(VectorStoreChatMemoryAdvisor.builder(store).build()) + .call() + .chatResponse(); + + verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(chatModel); + } + + private static ChatModel chatModelAlwaysReturnsTheSameReply() { + ChatModel chatModel = mock(ChatModel.class); + ArgumentCaptor argumentCaptor = 
ArgumentCaptor.forClass(Prompt.class); + ChatResponse chatResponse = new ChatResponse(List.of(new Generation(new AssistantMessage(""" + Why don't scientists trust atoms? + Because they make up everything! + """)))); + given(chatModel.call(argumentCaptor.capture())).willReturn(chatResponse); + return chatModel; + } + + private EmbeddingModel embeddingModelShouldAlwaysReturnFakedEmbed() { + EmbeddingModel embeddingModel = mock(EmbeddingModel.class); + Mockito.doAnswer(invocationOnMock -> List.of(this.embed, this.embed)) + .when(embeddingModel) + .embed(any(), any(), any()); + given(embeddingModel.embed(any(String.class))).willReturn(this.embed); + given(embeddingModel.dimensions()).willReturn(3); // Explicit dimensions matching + // embed array + return embeddingModel; + } + + private static void verifyRequestHasBeenAdvisedWithMessagesFromVectorStore(ChatModel chatModel) { + ArgumentCaptor promptCaptor = ArgumentCaptor.forClass(Prompt.class); + verify(chatModel).call(promptCaptor.capture()); + assertThat(promptCaptor.getValue().getInstructions().get(0)).isInstanceOf(SystemMessage.class); + assertThat(promptCaptor.getValue().getInstructions().get(0).getText()).isEqualTo(""" + + Use the long term conversation memory from the LONG_TERM_MEMORY section to provide accurate answers. + + --------------------- + LONG_TERM_MEMORY: + Tell me a good joke + Tell me a bad joke + --------------------- + """); + } + +}