Skip to content

Commit

Permalink
feat: JDBC implementation of ChatMemory
Browse files Browse the repository at this point in the history
Signed-off-by: leijendary <[email protected]>

feat: JDBC implementation of ChatMemory

Signed-off-by: leijendary <[email protected]>

feat: JDBC implementation of ChatMemory

Signed-off-by: leijendary <[email protected]>
  • Loading branch information
leijendary committed Jan 25, 2025
1 parent 822576b commit af61678
Show file tree
Hide file tree
Showing 26 changed files with 1,073 additions and 6 deletions.
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ One way to run integration tests on part of the code is to first do a quick comp
```shell
./mvnw clean install -DskipTests -Dmaven.javadoc.skip=true
```
Then run the integration test for a specifi module using the `-pl` option
Then run the integration test for a specific module using the `-pl` option
```shell
./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure
./mvnw verify -Pintegration-tests -pl spring-ai-spring-boot-autoconfigure
```

### Documentation
Expand Down Expand Up @@ -134,4 +134,4 @@ To build with checkstyles enabled.
Checkstyles are currently disabled, but you can enable them by doing the following:
```shell
./mvnw clean package -DskipTests -Ddisable.checks=false
```
```
1 change: 1 addition & 0 deletions chat-memory/spring-ai-chat-memory-jdbc/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[Chat Memory Documentation](https://docs.spring.io/spring-ai/reference/api/chatclient.html#_chat_memory)
107 changes: 107 additions & 0 deletions chat-memory/spring-ai-chat-memory-jdbc/pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
~ Copyright 2023-2024 the original author or authors.
~
~ Licensed under the Apache License, Version 2.0 (the "License");
~ you may not use this file except in compliance with the License.
~ You may obtain a copy of the License at
~
~ https://www.apache.org/licenses/LICENSE-2.0
~
~ Unless required by applicable law or agreed to in writing, software
~ distributed under the License is distributed on an "AS IS" BASIS,
~ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
~ See the License for the specific language governing permissions and
~ limitations under the License.
-->

<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<modelVersion>4.0.0</modelVersion>
<parent>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai</artifactId>
<version>1.0.0-SNAPSHOT</version>
<relativePath>../../pom.xml</relativePath>
</parent>
<artifactId>spring-ai-chat-memory-jdbc</artifactId>
<packaging>jar</packaging>
<name>Spring AI Chat Memory JDBC</name>
<description>Spring AI Chat Memory implementation with JDBC</description>
<url>https://github.com/spring-projects/spring-ai</url>

<scm>
<url>https://github.com/spring-projects/spring-ai</url>
<connection>git://github.com/spring-projects/spring-ai.git</connection>
<developerConnection>[email protected]:spring-projects/spring-ai.git</developerConnection>
</scm>

<properties>
<maven.compiler.source>17</maven.compiler.source>
<maven.compiler.target>17</maven.compiler.target>
</properties>

<dependencies>
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-core</artifactId>
<version>${project.parent.version}</version>
</dependency>

<dependency>
<groupId>com.zaxxer</groupId>
<artifactId>HikariCP</artifactId>
</dependency>

<dependency>
<groupId>org.springframework</groupId>
<artifactId>spring-jdbc</artifactId>
</dependency>

<dependency>
<groupId>org.postgresql</groupId>
<artifactId>postgresql</artifactId>
<version>${postgresql.version}</version>
<optional>true</optional>
</dependency>

<dependency>
<groupId>org.mariadb.jdbc</groupId>
<artifactId>mariadb-java-client</artifactId>
<version>${mariadb.version}</version>
<optional>true</optional>
</dependency>

<!-- TESTING -->
<dependency>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-test</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>testcontainers</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>postgresql</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>mariadb</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.testcontainers</groupId>
<artifactId>junit-jupiter</artifactId>
<scope>test</scope>
</dependency>
</dependencies>
</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.memory.jdbc;

import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.List;

import org.springframework.ai.chat.memory.ChatMemory;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.MessageType;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.jdbc.core.BatchPreparedStatementSetter;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.RowMapper;

/**
* An implementation of {@link ChatMemory} for JDBC. Creating an instance of
* JdbcChatMemory example:
* <code>JdbcChatMemory.create(JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build());</code>
*
* @author Jonathan Leijendekker
* @since 1.0.0
*/
public class JdbcChatMemory implements ChatMemory {

private static final String QUERY_ADD = """
INSERT INTO ai_chat_memory (conversation_id, content, type) VALUES (?, ?, ?)""";

private static final String QUERY_GET = """
SELECT content, type FROM ai_chat_memory WHERE conversation_id = ? ORDER BY "timestamp" DESC LIMIT ?""";

private static final String QUERY_CLEAR = "DELETE FROM ai_chat_memory WHERE conversation_id = ?";

private final JdbcTemplate jdbcTemplate;

public JdbcChatMemory(JdbcChatMemoryConfig config) {
this.jdbcTemplate = config.getJdbcTemplate();
}

public static JdbcChatMemory create(JdbcChatMemoryConfig config) {
return new JdbcChatMemory(config);
}

@Override
public void add(String conversationId, List<Message> messages) {
this.jdbcTemplate.batchUpdate(QUERY_ADD, new AddBatchPreparedStatement(conversationId, messages));
}

@Override
public List<Message> get(String conversationId, int lastN) {
return this.jdbcTemplate.query(QUERY_GET, new MessageRowMapper(), conversationId, lastN);
}

@Override
public void clear(String conversationId) {
this.jdbcTemplate.update(QUERY_CLEAR, conversationId);
}

private record AddBatchPreparedStatement(String conversationId,
List<Message> messages) implements BatchPreparedStatementSetter {
@Override
public void setValues(PreparedStatement ps, int i) throws SQLException {
var message = this.messages.get(i);

ps.setString(1, this.conversationId);
ps.setString(2, message.getText());
ps.setString(3, message.getMessageType().name());
}

@Override
public int getBatchSize() {
return this.messages.size();
}
}

private static class MessageRowMapper implements RowMapper<Message> {

@Override
public Message mapRow(ResultSet rs, int i) throws SQLException {
var content = rs.getString(1);
var type = MessageType.valueOf(rs.getString(2));

return switch (type) {
case USER -> new UserMessage(content);
case ASSISTANT -> new AssistantMessage(content);
default -> null;
};
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright 2024-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.chat.memory.jdbc;

import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.util.Assert;

/**
* Configuration for {@link JdbcChatMemory}.
*
* @author Jonathan Leijendekker
* @since 1.0.0
*/
public final class JdbcChatMemoryConfig {

private final JdbcTemplate jdbcTemplate;

private JdbcChatMemoryConfig(Builder builder) {
this.jdbcTemplate = builder.jdbcTemplate;
}

public static Builder builder() {
return new Builder();
}

JdbcTemplate getJdbcTemplate() {
return this.jdbcTemplate;
}

public static final class Builder {

private JdbcTemplate jdbcTemplate;

private Builder() {
}

public Builder jdbcTemplate(JdbcTemplate jdbcTemplate) {
Assert.notNull(jdbcTemplate, "jdbc template must not be null");

this.jdbcTemplate = jdbcTemplate;
return this;
}

public JdbcChatMemoryConfig build() {
Assert.notNull(this.jdbcTemplate, "jdbc template must not be null");

return new JdbcChatMemoryConfig(this);
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package org.springframework.ai.chat.memory.jdbc.aot.hint;

import javax.sql.DataSource;

import org.springframework.aot.hint.MemberCategory;
import org.springframework.aot.hint.RuntimeHints;
import org.springframework.aot.hint.RuntimeHintsRegistrar;

/**
* A {@link RuntimeHintsRegistrar} for JDBC Chat Memory hints
*
* @author Jonathan Leijendekker
*/
class JdbcChatMemoryRuntimeHints implements RuntimeHintsRegistrar {

@Override
public void registerHints(RuntimeHints hints, ClassLoader classLoader) {
hints.reflection()
.registerType(DataSource.class, (hint) -> hint.withMembers(MemberCategory.INVOKE_DECLARED_METHODS));

hints.resources()
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-drop-mariadb.sql")
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-drop-postgresql.sql")
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-mariadb.sql")
.registerPattern("org/springframework/ai/chat/memory/jdbc/schema-postgresql.sql");
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
org.springframework.aot.hint.RuntimeHintsRegistrar=\
org.springframework.ai.chat.memory.jdbc.aot.hint.JdbcChatMemoryRuntimeHints
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS ai_chat_memory;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
DROP TABLE IF EXISTS ai_chat_memory;
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
CREATE TABLE IF NOT EXISTS ai_chat_memory (
conversation_id VARCHAR(36) NOT NULL,
content TEXT NOT NULL,
type VARCHAR(10) NOT NULL,
`timestamp` TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
CONSTRAINT type_check CHECK (type IN ('USER', 'ASSISTANT'))
);

CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx
ON ai_chat_memory(conversation_id, `timestamp` DESC);
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
CREATE TABLE IF NOT EXISTS ai_chat_memory (
conversation_id VARCHAR(36) NOT NULL,
content TEXT NOT NULL,
type VARCHAR(10) NOT NULL CHECK (type IN ('USER', 'ASSISTANT')),
"timestamp" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
);

CREATE INDEX IF NOT EXISTS ai_chat_memory_conversation_id_timestamp_idx
ON ai_chat_memory(conversation_id, "timestamp" DESC);
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package org.springframework.ai.chat.memory.jdbc;

import org.junit.jupiter.api.Test;

import org.springframework.jdbc.core.JdbcTemplate;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.Mockito.mock;

/**
* @author Jonathan Leijendekker
*/
class JdbcChatMemoryConfigTest {

@Test
void setValues() {
var jdbcTemplate = mock(JdbcTemplate.class);
var config = JdbcChatMemoryConfig.builder().jdbcTemplate(jdbcTemplate).build();

assertThat(config.getJdbcTemplate()).isEqualTo(jdbcTemplate);
}

@Test
void setJdbcTemplateToNull_shouldThrow() {
assertThatThrownBy(() -> JdbcChatMemoryConfig.builder().jdbcTemplate(null));
}

@Test
void buildWithNullJdbcTemplate_shouldThrow() {
assertThatThrownBy(() -> JdbcChatMemoryConfig.builder().build());
}

}
Loading

0 comments on commit af61678

Please sign in to comment.