Skip to content

Commit 7b4c5b8

Browse files
authored
Fix Vector Support and add Samples for DBaas and Dse (#27)
1 parent 99ade34 commit 7b4c5b8

File tree

7 files changed

+420
-2
lines changed

7 files changed

+420
-2
lines changed

Diff for: pom.xml

+22
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,13 @@
6767
<role>developer</role>
6868
</roles>
6969
</contributor>
70+
<contributor>
71+
<name>Cedrick Lunven</name>
72+
<url>https://github.com/clun</url>
73+
<roles>
74+
<role>developer</role>
75+
</roles>
76+
</contributor>
7077
</contributors>
7178

7279
<scm>
@@ -106,6 +113,7 @@
106113
<mockito.version>3.12.4</mockito.version>
107114
<slf4j.version>1.7.36</slf4j.version>
108115
<testcontainers.version>1.18.3</testcontainers.version>
116+
<astra-sdk.version>0.6.11</astra-sdk.version>
109117
<!-- Versions for plugins -->
110118
<maven-checkstyle-plugin.version>3.3.0</maven-checkstyle-plugin.version>
111119
<maven-clean-plugin.version>3.3.1</maven-clean-plugin.version>
@@ -233,6 +241,20 @@
233241
<version>${testcontainers.version}</version>
234242
<scope>test</scope>
235243
</dependency>
244+
<!-- Astra Test instances for integration tests -->
245+
<dependency>
246+
<groupId>com.datastax.astra</groupId>
247+
<artifactId>astra-sdk-devops</artifactId>
248+
<version>${astra-sdk.version}</version>
249+
<scope>test</scope>
250+
</dependency>
251+
<!-- handy to build the Cql Queries -->
252+
<dependency>
253+
<groupId>com.datastax.oss</groupId>
254+
<artifactId>java-driver-query-builder</artifactId>
255+
<version>${datastax.java.driver.version}</version>
256+
<scope>test</scope>
257+
</dependency>
236258
<!-- Logging for tests -->
237259
<dependency>
238260
<groupId>org.slf4j</groupId>

Diff for: src/main/java/com/ing/data/cassandra/jdbc/CassandraResultSet.java

+3
Original file line numberDiff line numberDiff line change
@@ -1593,6 +1593,9 @@ public String getColumnTypeName(final int column) {
15931593
} else {
15941594
dataType = driverResultSet.getColumnDefinitions().get(column - 1).getType();
15951595
}
1596+
if (dataType.toString().contains(DataTypeEnum.VECTOR.cqlType)) {
1597+
return DataTypeEnum.VECTOR.cqlType;
1598+
}
15961599
return dataType.toString();
15971600
}
15981601

Diff for: src/main/java/com/ing/data/cassandra/jdbc/types/DataTypeEnum.java

+14-2
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import com.datastax.oss.driver.api.core.data.UdtValue;
2222
import com.datastax.oss.driver.api.core.type.DataTypes;
2323
import com.datastax.oss.driver.api.core.type.UserDefinedType;
24+
import com.datastax.oss.driver.api.core.type.VectorType;
2425
import com.datastax.oss.protocol.internal.ProtocolConstants.DataType;
2526

2627
import javax.annotation.Nonnull;
@@ -169,7 +170,7 @@ public enum DataTypeEnum {
169170
* {@code vector} CQL type (type {@value DataType#LIST} in CQL native protocol) mapped to {@link CqlVector} Java
170171
* type.
171172
*/
172-
VECTOR(DataType.LIST, CqlVector.class, "vector");
173+
VECTOR(DataType.LIST, CqlVector.class, "Vector");
173174

174175
private static final Map<String, DataTypeEnum> CQL_DATATYPE_TO_DATATYPE;
175176

@@ -184,6 +185,8 @@ public enum DataTypeEnum {
184185

185186
final int protocolId;
186187

188+
static final String VECTOR_CLASSNAME = "org.apache.cassandra.db.marshal.VectorType";
189+
187190
static {
188191
CQL_DATATYPE_TO_DATATYPE = new HashMap<>();
189192
for (final DataTypeEnum dataType : DataTypeEnum.values()) {
@@ -217,6 +220,9 @@ public static DataTypeEnum fromCqlTypeName(final String cqlTypeName) {
217220
if (cqlTypeName.startsWith(UDT.cqlType)) {
218221
return UDT;
219222
}
223+
if (cqlTypeName.contains(VECTOR_CLASSNAME)) {
224+
return VECTOR;
225+
}
220226
// Manage collection types (e.g. "list<varchar>")
221227
final int collectionTypeCharPos = cqlTypeName.indexOf("<");
222228
String cqlDataType = cqlTypeName;
@@ -236,6 +242,9 @@ public static DataTypeEnum fromDataType(final com.datastax.oss.driver.api.core.t
236242
if (dataType instanceof UserDefinedType) {
237243
return UDT;
238244
}
245+
if (dataType instanceof VectorType) {
246+
return VECTOR;
247+
}
239248
return fromCqlTypeName(dataType.asCql(false, false));
240249
}
241250

@@ -320,12 +329,15 @@ public String toString() {
320329

321330
/**
322331
* Gets the CQL name from a given {@link com.datastax.oss.driver.api.core.type.DataType} instance.
332+
* For vectors, the value returned by dataType.asCql looks like 'org.apache.cassandra.db.marshal.VectorType(n)' where n is
333+
* the dimension of the vector. In this specific case, return a common name not including the dimension.
323334
*
324335
* @param dataType The data type.
325336
* @return The CQL name of the type.
326337
*/
327338
public static String cqlName(@Nonnull final com.datastax.oss.driver.api.core.type.DataType dataType) {
328-
return dataType.asCql(false, false);
339+
final String rawCql = dataType.asCql(false, false);
340+
return rawCql.contains(VECTOR_CLASSNAME) ? VECTOR.cqlType : rawCql;
329341
}
330342
}
331343

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
/*
2+
* Licensed under the Apache License, Version 2.0 (the "License");
3+
* you may not use this file except in compliance with the License.
4+
* You may obtain a copy of the License at
5+
*
6+
* http://www.apache.org/licenses/LICENSE-2.0
7+
*
8+
* Unless required by applicable law or agreed to in writing, software
9+
* distributed under the License is distributed on an "AS IS" BASIS,
10+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11+
* See the License for the specific language governing permissions and
12+
* limitations under the License.
13+
*/
14+
package com.ing.data.cassandra.jdbc;
15+
16+
import com.datastax.oss.driver.api.core.type.DataTypes;
17+
import com.datastax.oss.driver.api.querybuilder.SchemaBuilder;
18+
import com.dtsx.astra.sdk.db.AstraDbClient;
19+
import com.dtsx.astra.sdk.db.domain.DatabaseStatusType;
20+
import com.dtsx.astra.sdk.utils.TestUtils;
21+
import org.junit.jupiter.api.AfterAll;
22+
import org.junit.jupiter.api.Assertions;
23+
import org.junit.jupiter.api.BeforeAll;
24+
import org.junit.jupiter.api.Order;
25+
import org.junit.jupiter.api.Test;
26+
import org.junit.jupiter.api.TestMethodOrder;
27+
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
28+
import org.slf4j.Logger;
29+
import org.slf4j.LoggerFactory;
30+
31+
import java.sql.DriverManager;
32+
import java.sql.ResultSet;
33+
import java.sql.SQLException;
34+
35+
/**
36+
* Test JDBC Driver against DbAAS Astra.
37+
* To run this test define environment variable ASTRA_DB_APPLICATION_TOKEN
38+
* but not having any token does not block the build.
39+
*/
40+
@TestMethodOrder(org.junit.jupiter.api.MethodOrderer.OrderAnnotation.class)
41+
class DbaasAstraIntegrationTest {
42+
43+
private static final Logger log = LoggerFactory.getLogger(DbaasAstraIntegrationTest.class);
44+
private static final String DATABASE_NAME = "test_cassandra_jdbc";
45+
private static final String KEYSPACE_NAME = "test";
46+
static CassandraConnection sqlConnection = null;
47+
48+
@BeforeAll
49+
static void setupAstra() throws Exception {
50+
if (System.getenv("ASTRA_DB_APPLICATION_TOKEN") != null) {
51+
log.debug("ASTRA_DB_APPLICATION_TOKEN is provided, Astra Test is executed");
52+
53+
54+
/*
55+
* Devops API Client (create database, resume, delete)
56+
*/
57+
AstraDbClient astraDbClient = new AstraDbClient(TestUtils.getAstraToken());
58+
log.debug("Connected the dbaas API");
59+
60+
/*
61+
* Set up a Database in Astra : create if not exist, resume if needed
62+
* Vector Database is Cassandra DB with vector support enabled.
63+
* It can take up to 1 min to create the database if not exists
64+
*/
65+
String dbId = TestUtils.setupVectorDatabase(DATABASE_NAME, KEYSPACE_NAME);
66+
Assertions.assertTrue(astraDbClient.findById(dbId).isPresent());
67+
Assertions.assertEquals(DatabaseStatusType.ACTIVE, astraDbClient.findById(dbId).get().getStatus());
68+
log.debug("Database ready");
69+
70+
/*
71+
* Download cloud secure bundle to connect to the database.
72+
* - Saved in /tmp
73+
* - Single region = we can use default region
74+
*/
75+
astraDbClient
76+
.database(dbId)
77+
.downloadDefaultSecureConnectBundle("/tmp/" + DATABASE_NAME + "_scb.zip");
78+
log.debug("Connection bundle downloaded.");
79+
80+
/*
81+
* Building jdbcUrl and sqlConnection.
82+
* Note: Astra can be access with only a token (username='token')
83+
*/
84+
sqlConnection = (CassandraConnection) DriverManager.getConnection(
85+
"jdbc:cassandra://dbaas/" + KEYSPACE_NAME +
86+
"?user=" + "token" +
87+
"&password=" + TestUtils.getAstraToken() + // env var ASTRA_DB_APPLICATION_TOKEN
88+
"&consistency=" + "LOCAL_QUORUM" +
89+
"&secureconnectbundle=/tmp/" + DATABASE_NAME + "_scb.zip");
90+
} else {
91+
log.debug("ASTRA_DB_APPLICATION_TOKEN is not defined, skipping ASTRA test");
92+
}
93+
}
94+
95+
@Test
96+
@Order(1)
97+
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
98+
void givenConnection_whenCreateTable_shouldTableExist() throws SQLException {
99+
// Given
100+
Assertions.assertNotNull(sqlConnection);
101+
// When
102+
sqlConnection.createStatement().execute(SchemaBuilder
103+
.createTable("simple_table")
104+
.ifNotExists()
105+
.withPartitionKey("email", DataTypes.TEXT)
106+
.withColumn("firstname", DataTypes.TEXT)
107+
.withColumn("lastname", DataTypes.TEXT)
108+
.build().getQuery());
109+
// Then
110+
Assertions.assertTrue(tableExist("simple_table"));
111+
}
112+
113+
@Test
114+
@Order(2)
115+
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
116+
void givenTable_whenInsert_shouldRetrieveData() throws Exception {
117+
// Given
118+
Assertions.assertTrue(tableExist("simple_table"));
119+
// When
120+
String insertSimpleCQL = "INSERT INTO simple_table (email, firstname, lastname) VALUES(?,?,?)";
121+
final CassandraPreparedStatement prepStatement = sqlConnection.prepareStatement(insertSimpleCQL);
122+
prepStatement.setString(1, "[email protected]");
123+
prepStatement.setString(2, "pierre");
124+
prepStatement.setString(2, "feuille");
125+
prepStatement.execute();
126+
// Then (warning on Cassandra expected)
127+
Assertions.assertEquals(1, countRecords("simple_table"));
128+
}
129+
130+
@Test
131+
@Order(3)
132+
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
133+
void givenConnection_whenCreateTableVector_shouldTableExist() throws Exception {
134+
// When
135+
sqlConnection.createStatement().execute("" +
136+
"CREATE TABLE IF NOT EXISTS pet_supply_vectors (" +
137+
" product_id TEXT PRIMARY KEY," +
138+
" product_name TEXT," +
139+
" product_vector vector<float, 14>)");
140+
// Then
141+
Assertions.assertTrue(tableExist("pet_supply_vectors"));
142+
sqlConnection.createStatement().execute("" +
143+
"CREATE CUSTOM INDEX IF NOT EXISTS idx_vector " +
144+
"ON pet_supply_vectors(product_vector) " +
145+
"USING 'StorageAttachedIndex'");
146+
// When
147+
sqlConnection.createStatement().execute("" +
148+
"INSERT INTO pet_supply_vectors (product_id, product_name, product_vector) " +
149+
"VALUES ('pf1843','HealthyFresh - Chicken raw dog food',[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])");
150+
sqlConnection.createStatement().execute("" +
151+
"INSERT INTO pet_supply_vectors (product_id, product_name, product_vector) " +
152+
"VALUES ('pf1844','HealthyFresh - Beef raw dog food',[1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0])");
153+
sqlConnection.createStatement().execute("" +
154+
"INSERT INTO pet_supply_vectors (product_id, product_name, product_vector) " +
155+
"VALUES ('pt0021','Dog Tennis Ball Toy',[0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0])");
156+
sqlConnection.createStatement().execute("" +
157+
"INSERT INTO pet_supply_vectors (product_id, product_name, product_vector) " +
158+
"VALUES ('pt0041','Dog Ring Chew Toy',[0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 0])");
159+
sqlConnection.createStatement().execute("" +
160+
"INSERT INTO pet_supply_vectors (product_id, product_name, product_vector) " +
161+
"VALUES ('pf7043','PupperSausage Bacon dog Treats',[0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1])");
162+
sqlConnection.createStatement().execute("" +
163+
"INSERT INTO pet_supply_vectors (product_id, product_name, product_vector) " +
164+
"VALUES ('pf7044','PupperSausage Beef dog Treats',[0, 0, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0])");
165+
// Then (warning on Cassandra expected)
166+
Assertions.assertEquals(6, countRecords("pet_supply_vectors"));
167+
}
168+
169+
@Test
170+
@Order(4)
171+
@EnabledIfEnvironmentVariable(named = "ASTRA_DB_APPLICATION_TOKEN", matches = "Astra.*")
172+
void givenVectorTable_whenSimilaritySearch_shouldReturnResults() throws Exception {
173+
// Given
174+
Assertions.assertTrue(tableExist("pet_supply_vectors"));
175+
Assertions.assertEquals(6, countRecords("pet_supply_vectors"));
176+
// When
177+
final CassandraPreparedStatement prepStatement = sqlConnection.prepareStatement("" +
178+
"SELECT\n" +
179+
" product_id, product_vector,\n" +
180+
" similarity_dot_product(product_vector,[1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]) as similarity\n" +
181+
"FROM pet_supply_vectors\n" +
182+
"ORDER BY product_vector\n" +
183+
"ANN OF [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0]\n" +
184+
"LIMIT 2;");
185+
java.sql.ResultSet rs = prepStatement.executeQuery();
186+
// A result has been found
187+
Assertions.assertTrue(rs.next());
188+
// Parsing Results
189+
Assertions.assertNotNull(rs.getObject("product_vector"));
190+
Assertions.assertEquals(3.0d, rs.getDouble("similarity"));
191+
}
192+
193+
private boolean tableExist(String tableName) throws SQLException {
194+
String existTableCql = "select table_name,keyspace_name from system_schema.tables where keyspace_name=? and table_name=?";
195+
final CassandraPreparedStatement prepStatement = sqlConnection.prepareStatement(existTableCql);
196+
prepStatement.setString(1, KEYSPACE_NAME);
197+
prepStatement.setString(2, tableName);
198+
return prepStatement.executeQuery().next();
199+
}
200+
201+
private int countRecords(String tablename) throws SQLException {
202+
String countRecordsCql = "select count(*) from " + tablename;
203+
final CassandraPreparedStatement prepStatement = sqlConnection.prepareStatement(countRecordsCql);
204+
final ResultSet resultSet = prepStatement.executeQuery();
205+
resultSet.next();
206+
return resultSet.getInt(1);
207+
}
208+
209+
@AfterAll
210+
static void closeSql() throws SQLException {
211+
if (sqlConnection != null) {
212+
sqlConnection.close();
213+
}
214+
}
215+
216+
}
217+
218+

0 commit comments

Comments
 (0)