Skip to content

Commit cbafc67

Browse files
authored
Merge pull request #390 from weaviate/v6-bm25
v6: add BM25 query with searchOperator support
2 parents 41188ea + 6e036e0 commit cbafc67

19 files changed

+13353
-7700
lines changed

src/it/java/io/weaviate/integration/SearchITest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import io.weaviate.client6.v1.api.collections.data.Reference;
2323
import io.weaviate.client6.v1.api.collections.query.GroupBy;
2424
import io.weaviate.client6.v1.api.collections.query.Metadata;
25+
import io.weaviate.client6.v1.api.collections.query.QueryMetadata;
2526
import io.weaviate.client6.v1.api.collections.query.QueryResponseGroup;
2627
import io.weaviate.client6.v1.api.collections.query.Where;
2728
import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw;
@@ -255,4 +256,29 @@ public void testFetchObjectsWithFilters() throws IOException {
255256
hugeHat.metadata().uuid());
256257

257258
}
259+
260+
@Test
261+
public void testBm25() throws IOException {
262+
var nsWords = ns("Words");
263+
264+
client.collections.create(nsWords,
265+
collection -> collection
266+
.properties(
267+
Property.text("relevant"),
268+
Property.text("irrelevant")));
269+
270+
var words = client.collections.use(nsWords);
271+
272+
/* notWant */ words.data.insert(Map.of("relevant", "elefant", "irrelevant", "dollar bill"));
273+
var want = words.data.insert(Map.of("relevant", "a dime a dollar", "irrelevant", "euro"));
274+
275+
var dollarWorlds = words.query.bm25(
276+
"dollar",
277+
bm25 -> bm25.queryProperties("relevant"));
278+
279+
Assertions.assertThat(dollarWorlds.objects())
280+
.hasSize(1)
281+
.extracting(WeaviateObject::metadata).extracting(QueryMetadata::uuid)
282+
.containsOnly(want.metadata().uuid());
283+
}
258284
}

src/main/java/io/weaviate/client6/v1/api/collections/aggregate/ObjectFilter.java renamed to src/main/java/io/weaviate/client6/v1/api/collections/aggregate/AggregateObjectFilter.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,12 @@
44
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate.AggregateRequest.Builder;
55

66
// TODO: move Near-, Hybrid, BM25 under client.collection.operators? With them implementing query.SearchOperator and aggregate.ObjectFilter
7-
public interface ObjectFilter {
7+
public interface AggregateObjectFilter {
88
void appendTo(WeaviateProtoAggregate.AggregateRequest.Builder req);
99

10-
static ObjectFilter NONE = new ObjectFilter() {
10+
static AggregateObjectFilter NONE = new AggregateObjectFilter() {
1111
@Override
1212
public void appendTo(Builder req) {
13-
return;
1413
}
1514
};
1615
}

src/main/java/io/weaviate/client6/v1/api/collections/aggregate/Aggregation.java

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,16 @@
99
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate;
1010

1111
public record Aggregation(
12-
ObjectFilter filter,
12+
AggregateObjectFilter filter,
1313
Integer objectLimit,
1414
boolean includeTotalCount,
1515
List<PropertyAggregation> returnMetrics) {
1616

1717
public static Aggregation of(Function<Builder, ObjectBuilder<Aggregation>> fn) {
18-
return of(ObjectFilter.NONE, fn);
18+
return of(AggregateObjectFilter.NONE, fn);
1919
}
2020

21-
public static Aggregation of(ObjectFilter objectFilter, Function<Builder, ObjectBuilder<Aggregation>> fn) {
21+
public static Aggregation of(AggregateObjectFilter objectFilter, Function<Builder, ObjectBuilder<Aggregation>> fn) {
2222
return fn.apply(new Builder(objectFilter)).build();
2323
}
2424

@@ -31,9 +31,9 @@ public Aggregation(Builder builder) {
3131
}
3232

3333
public static class Builder implements ObjectBuilder<Aggregation> {
34-
private final ObjectFilter objectFilter;
34+
private final AggregateObjectFilter objectFilter;
3535

36-
public Builder(ObjectFilter objectFilter) {
36+
public Builder(AggregateObjectFilter objectFilter) {
3737
this.objectFilter = objectFilter;
3838
}
3939

src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@ abstract class AbstractQueryClient<PropertiesT, SingleT, ResponseT, GroupedRespo
1919

2020
protected abstract SingleT byId(ById byId);
2121

22-
protected abstract ResponseT performRequest(SearchOperator operator);
22+
protected abstract ResponseT performRequest(QueryOperator operator);
2323

24-
protected abstract GroupedResponseT performRequest(SearchOperator operator, GroupBy groupBy);
24+
protected abstract GroupedResponseT performRequest(QueryOperator operator, GroupBy groupBy);
2525

2626
// Fetch by ID --------------------------------------------------------------
2727

@@ -56,6 +56,32 @@ public GroupedResponseT fetchObjects(FetchObjects query, GroupBy groupBy) {
5656
return performRequest(query, groupBy);
5757
}
5858

59+
// BM25 queries -------------------------------------------------------------
60+
61+
public ResponseT bm25(String query) {
62+
return bm25(Bm25.of(query));
63+
}
64+
65+
public ResponseT bm25(String query, Function<Bm25.Builder, ObjectBuilder<Bm25>> fn) {
66+
return bm25(Bm25.of(query, fn));
67+
}
68+
69+
public ResponseT bm25(Bm25 query) {
70+
return performRequest(query);
71+
}
72+
73+
public GroupedResponseT bm25(String query, GroupBy groupBy) {
74+
return bm25(Bm25.of(query), groupBy);
75+
}
76+
77+
public GroupedResponseT bm25(String query, Function<Bm25.Builder, ObjectBuilder<Bm25>> fn, GroupBy groupBy) {
78+
return bm25(Bm25.of(query, fn), groupBy);
79+
}
80+
81+
public GroupedResponseT bm25(Bm25 query, GroupBy groupBy) {
82+
return performRequest(query, groupBy);
83+
}
84+
5985
// NearVector queries -------------------------------------------------------
6086

6187
public ResponseT nearVector(Float[] vector) {
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package io.weaviate.client6.v1.api.collections.query;
2+
3+
import java.util.Arrays;
4+
import java.util.List;
5+
import java.util.function.Function;
6+
7+
import io.weaviate.client6.v1.internal.ObjectBuilder;
8+
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
9+
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
10+
11+
public record Bm25(String query, List<String> queryProperties, BaseQueryOptions common)
12+
implements QueryOperator {
13+
14+
public static final Bm25 of(String query) {
15+
return of(query, ObjectBuilder.identity());
16+
}
17+
18+
public static final Bm25 of(String query, Function<Builder, ObjectBuilder<Bm25>> fn) {
19+
return fn.apply(new Builder(query)).build();
20+
}
21+
22+
public Bm25(Builder builder) {
23+
this(builder.query, builder.queryProperties, builder.baseOptions());
24+
}
25+
26+
public static class Builder extends BaseQueryOptions.Builder<Builder, Bm25> {
27+
// Required query parameters.
28+
private final String query;
29+
30+
// Optional query parameters.
31+
List<String> queryProperties;
32+
SearchOperator searchOperator;
33+
34+
public Builder(String query) {
35+
this.query = query;
36+
}
37+
38+
public Builder queryProperties(String... properties) {
39+
return queryProperties(Arrays.asList(properties));
40+
}
41+
42+
public Builder queryProperties(List<String> properties) {
43+
this.queryProperties = properties;
44+
return this;
45+
}
46+
47+
public Builder searchOperator(SearchOperator searchOperator) {
48+
this.searchOperator = searchOperator;
49+
return this;
50+
}
51+
52+
@Override
53+
public final Bm25 build() {
54+
return new Bm25(this);
55+
}
56+
}
57+
58+
@Override
59+
public final void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req) {
60+
common.appendTo(req);
61+
req.setBm25Search(WeaviateProtoBaseSearch.BM25.newBuilder()
62+
.setQuery(query)
63+
.addAllProperties(queryProperties));
64+
}
65+
}

src/main/java/io/weaviate/client6/v1/api/collections/query/ById.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public record ById(
1313
String uuid,
1414
List<String> returnProperties,
1515
List<QueryReference> returnReferences,
16-
List<Metadata> returnMetadata) implements SearchOperator {
16+
List<Metadata> returnMetadata) implements QueryOperator {
1717

1818
private static final String ID_PROPERTY = "_id";
1919

src/main/java/io/weaviate/client6/v1/api/collections/query/FetchObjects.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import io.weaviate.client6.v1.internal.ObjectBuilder;
66
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
77

8-
public record FetchObjects(BaseQueryOptions common) implements SearchOperator {
8+
public record FetchObjects(BaseQueryOptions common) implements QueryOperator {
99

1010
public static FetchObjects of(Function<Builder, ObjectBuilder<FetchObjects>> fn) {
1111
return fn.apply(new Builder()).build();

src/main/java/io/weaviate/client6/v1/api/collections/query/NearImage.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
import java.util.function.Function;
44

5-
import io.weaviate.client6.v1.api.collections.aggregate.ObjectFilter;
5+
import io.weaviate.client6.v1.api.collections.aggregate.AggregateObjectFilter;
66
import io.weaviate.client6.v1.internal.ObjectBuilder;
77
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate;
88
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
99
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
1010

1111
public record NearImage(String image, Float distance, Float certainty, BaseQueryOptions common)
12-
implements SearchOperator, ObjectFilter {
12+
implements QueryOperator, AggregateObjectFilter {
1313

1414
public static NearImage of(String image) {
1515
return of(image, ObjectBuilder.identity());

src/main/java/io/weaviate/client6/v1/api/collections/query/NearText.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55
import java.util.List;
66
import java.util.function.Function;
77

8-
import io.weaviate.client6.v1.api.collections.aggregate.ObjectFilter;
8+
import io.weaviate.client6.v1.api.collections.aggregate.AggregateObjectFilter;
99
import io.weaviate.client6.v1.internal.ObjectBuilder;
1010
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate;
1111
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
1212
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
1313

1414
public record NearText(List<String> concepts, Float distance, Float certainty, Move moveTo, Move moveAway,
15-
BaseQueryOptions common) implements SearchOperator, ObjectFilter {
15+
BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter {
1616

1717
public static NearText of(String... concepts) {
1818
return of(Arrays.asList(concepts), ObjectBuilder.identity());

src/main/java/io/weaviate/client6/v1/api/collections/query/NearVector.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import java.util.function.Function;
44

5-
import io.weaviate.client6.v1.api.collections.aggregate.ObjectFilter;
5+
import io.weaviate.client6.v1.api.collections.aggregate.AggregateObjectFilter;
66
import io.weaviate.client6.v1.internal.ObjectBuilder;
77
import io.weaviate.client6.v1.internal.grpc.GRPC;
88
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate;
@@ -11,7 +11,7 @@
1111
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
1212

1313
public record NearVector(Float[] vector, Float distance, Float certainty, BaseQueryOptions common)
14-
implements SearchOperator, ObjectFilter {
14+
implements QueryOperator, AggregateObjectFilter {
1515

1616
public static final NearVector of(Float[] vector) {
1717
return of(vector, ObjectBuilder.identity());
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
package io.weaviate.client6.v1.api.collections.query;
2+
3+
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
4+
5+
interface QueryOperator {
6+
void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req);
7+
}

src/main/java/io/weaviate/client6/v1/api/collections/query/QueryRequest.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import io.weaviate.client6.v1.internal.orm.CollectionDescriptor;
2222
import io.weaviate.client6.v1.internal.orm.PropertiesBuilder;
2323

24-
public record QueryRequest(SearchOperator operator, GroupBy groupBy) {
24+
public record QueryRequest(QueryOperator operator, GroupBy groupBy) {
2525

2626
static <T> Rpc<QueryRequest, WeaviateProtoSearchGet.SearchRequest, QueryResponse<T>, WeaviateProtoSearchGet.SearchReply> rpc(
2727
CollectionDescriptor<T> collection) {
Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,29 @@
11
package io.weaviate.client6.v1.api.collections.query;
22

3-
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
3+
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
4+
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch.SearchOperatorOptions.Operator;
45

5-
interface SearchOperator {
6-
void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req);
6+
public class SearchOperator {
7+
private final String operator;
8+
private final Integer minimumOrTokensMatch;
9+
10+
public static final SearchOperator or(int minimumOrTokensMatch) {
11+
return new SearchOperator("Or", minimumOrTokensMatch);
12+
}
13+
14+
public static final SearchOperator and() {
15+
return new SearchOperator("And", 0);
16+
}
17+
18+
private SearchOperator(String operator, Integer minimumOrTokensMatch) {
19+
this.operator = operator;
20+
this.minimumOrTokensMatch = minimumOrTokensMatch;
21+
}
22+
23+
void appendTo(WeaviateProtoBaseSearch.SearchOperatorOptions.Builder options) {
24+
options.setOperator(operator == "And" ? Operator.OPERATOR_AND : Operator.OPERATOR_OR);
25+
if (minimumOrTokensMatch != null) {
26+
options.setMinimumOrTokensMatch(minimumOrTokensMatch);
27+
}
28+
}
729
}

src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClient.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ protected Optional<WeaviateObject<T, Object, QueryMetadata>> byId(ById byId) {
2323
}
2424

2525
@Override
26-
protected final QueryResponse<T> performRequest(SearchOperator operator) {
26+
protected final QueryResponse<T> performRequest(QueryOperator operator) {
2727
var request = new QueryRequest(operator, null);
2828
return this.transport.performRequest(request, QueryRequest.rpc(collection));
2929
}
3030

3131
@Override
32-
protected final QueryResponseGrouped<T> performRequest(SearchOperator operator, GroupBy groupBy) {
32+
protected final QueryResponseGrouped<T> performRequest(QueryOperator operator, GroupBy groupBy) {
3333
var request = new QueryRequest(operator, groupBy);
3434
return this.transport.performRequest(request, QueryRequest.grouped(collection));
3535
}

src/main/java/io/weaviate/client6/v1/api/collections/query/WeaviateQueryClientAsync.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@ protected CompletableFuture<Optional<WeaviateObject<T, Object, QueryMetadata>>>
2424
}
2525

2626
@Override
27-
protected final CompletableFuture<QueryResponse<T>> performRequest(SearchOperator operator) {
27+
protected final CompletableFuture<QueryResponse<T>> performRequest(QueryOperator operator) {
2828
var request = new QueryRequest(operator, null);
2929
return this.transport.performRequestAsync(request, QueryRequest.rpc(collection));
3030
}
3131

3232
@Override
33-
protected final CompletableFuture<QueryResponseGrouped<T>> performRequest(SearchOperator operator, GroupBy groupBy) {
33+
protected final CompletableFuture<QueryResponseGrouped<T>> performRequest(QueryOperator operator, GroupBy groupBy) {
3434
var request = new QueryRequest(operator, groupBy);
3535
return this.transport.performRequestAsync(request, QueryRequest.grouped(collection));
3636
}

0 commit comments

Comments
 (0)