Skip to content

Commit 39e24aa

Browse files
committed
feat: add BM25 query
Bm25 does not implement ObjectFilter because Aggregation API does not support bm25.
1 parent 8121c4d commit 39e24aa

File tree

3 files changed

+111
-0
lines changed

3 files changed

+111
-0
lines changed

src/it/java/io/weaviate/integration/SearchITest.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import io.weaviate.client6.v1.api.collections.data.Reference;
2323
import io.weaviate.client6.v1.api.collections.query.GroupBy;
2424
import io.weaviate.client6.v1.api.collections.query.Metadata;
25+
import io.weaviate.client6.v1.api.collections.query.QueryMetadata;
2526
import io.weaviate.client6.v1.api.collections.query.QueryResponseGroup;
2627
import io.weaviate.client6.v1.api.collections.query.Where;
2728
import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw;
@@ -255,4 +256,29 @@ public void testFetchObjectsWithFilters() throws IOException {
255256
hugeHat.metadata().uuid());
256257

257258
}
259+
260+
@Test
261+
public void testBm25() throws IOException {
262+
var nsWords = ns("Words");
263+
264+
client.collections.create(nsWords,
265+
collection -> collection
266+
.properties(
267+
Property.text("relevant"),
268+
Property.text("irrelevant")));
269+
270+
var words = client.collections.use(nsWords);
271+
272+
/* notWant */ words.data.insert(Map.of("relevant", "elefant", "irrelevant", "dollar bill"));
273+
var want = words.data.insert(Map.of("relevant", "a dime a dollar", "irrelevant", "euro"));
274+
275+
var dollarWorlds = words.query.bm25(
276+
"dollar",
277+
bm25 -> bm25.queryProperties("relevant"));
278+
279+
Assertions.assertThat(dollarWorlds.objects())
280+
.hasSize(1)
281+
.extracting(WeaviateObject::metadata).extracting(QueryMetadata::uuid)
282+
.containsOnly(want.metadata().uuid());
283+
}
258284
}

src/main/java/io/weaviate/client6/v1/api/collections/query/AbstractQueryClient.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,32 @@ public GroupedResponseT fetchObjects(FetchObjects query, GroupBy groupBy) {
5656
return performRequest(query, groupBy);
5757
}
5858

59+
// BM25 queries -------------------------------------------------------------
60+
61+
public ResponseT bm25(String query) {
62+
return bm25(Bm25.of(query));
63+
}
64+
65+
public ResponseT bm25(String query, Function<Bm25.Builder, ObjectBuilder<Bm25>> fn) {
66+
return bm25(Bm25.of(query, fn));
67+
}
68+
69+
public ResponseT bm25(Bm25 query) {
70+
return performRequest(query);
71+
}
72+
73+
public GroupedResponseT bm25(String query, GroupBy groupBy) {
74+
return bm25(Bm25.of(query), groupBy);
75+
}
76+
77+
public GroupedResponseT bm25(String query, Function<Bm25.Builder, ObjectBuilder<Bm25>> fn, GroupBy groupBy) {
78+
return bm25(Bm25.of(query, fn), groupBy);
79+
}
80+
81+
public GroupedResponseT bm25(Bm25 query, GroupBy groupBy) {
82+
return performRequest(query, groupBy);
83+
}
84+
5985
// NearVector queries -------------------------------------------------------
6086

6187
public ResponseT nearVector(Float[] vector) {
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
package io.weaviate.client6.v1.api.collections.query;
2+
3+
import java.util.Arrays;
4+
import java.util.List;
5+
import java.util.function.Function;
6+
7+
import io.weaviate.client6.v1.internal.ObjectBuilder;
8+
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
9+
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
10+
11+
public record Bm25(String query, List<String> queryProperties, BaseQueryOptions common)
12+
implements SearchOperator {
13+
14+
public static final Bm25 of(String query) {
15+
return of(query, ObjectBuilder.identity());
16+
}
17+
18+
public static final Bm25 of(String query, Function<Builder, ObjectBuilder<Bm25>> fn) {
19+
return fn.apply(new Builder(query)).build();
20+
}
21+
22+
public Bm25(Builder builder) {
23+
this(builder.query, builder.queryProperties, builder.baseOptions());
24+
}
25+
26+
public static class Builder extends BaseQueryOptions.Builder<Builder, Bm25> {
27+
// Required query parameters.
28+
private final String query;
29+
30+
// Optional query parameters.
31+
List<String> queryProperties;
32+
33+
public Builder(String query) {
34+
this.query = query;
35+
}
36+
37+
public Builder queryProperties(String... properties) {
38+
return queryProperties(Arrays.asList(properties));
39+
}
40+
41+
public Builder queryProperties(List<String> properties) {
42+
this.queryProperties = properties;
43+
return this;
44+
}
45+
46+
@Override
47+
public final Bm25 build() {
48+
return new Bm25(this);
49+
}
50+
}
51+
52+
@Override
53+
public final void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req) {
54+
common.appendTo(req);
55+
req.setBm25Search(WeaviateProtoBaseSearch.BM25.newBuilder()
56+
.setQuery(query)
57+
.addAllProperties(queryProperties));
58+
}
59+
}

0 commit comments

Comments
 (0)