Skip to content

v6: add BM25 query #390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions src/it/java/io/weaviate/integration/SearchITest.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.weaviate.client6.v1.api.collections.data.Reference;
import io.weaviate.client6.v1.api.collections.query.GroupBy;
import io.weaviate.client6.v1.api.collections.query.Metadata;
import io.weaviate.client6.v1.api.collections.query.QueryMetadata;
import io.weaviate.client6.v1.api.collections.query.QueryResponseGroup;
import io.weaviate.client6.v1.api.collections.query.Where;
import io.weaviate.client6.v1.api.collections.vectorindex.Hnsw;
Expand Down Expand Up @@ -255,4 +256,29 @@ public void testFetchObjectsWithFilters() throws IOException {
hugeHat.metadata().uuid());

}

@Test
public void testBm25() throws IOException {
  // Arrange: a collection with two text properties; only "relevant" is searched below.
  var collectionName = ns("Words");
  client.collections.create(collectionName,
      c -> c.properties(
          Property.text("relevant"),
          Property.text("irrelevant")));
  var collection = client.collections.use(collectionName);

  // This object mentions "dollar" only in the property that is NOT queried,
  // so it must not appear in the result.
  collection.data.insert(Map.of("relevant", "elefant", "irrelevant", "dollar bill"));
  var expected = collection.data.insert(Map.of("relevant", "a dime a dollar", "irrelevant", "euro"));

  // Act: BM25 search for "dollar" restricted to the "relevant" property.
  var response = collection.query.bm25("dollar", q -> q.queryProperties("relevant"));

  // Assert: exactly one hit, and it is the expected object.
  Assertions.assertThat(response.objects())
      .hasSize(1)
      .extracting(WeaviateObject::metadata).extracting(QueryMetadata::uuid)
      .containsOnly(expected.metadata().uuid());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate.AggregateRequest.Builder;

// TODO: move Near-, Hybrid, BM25 under client.collection.operators? With them implementing query.QueryOperator and aggregate.AggregateObjectFilter
public interface ObjectFilter {
public interface AggregateObjectFilter {
void appendTo(WeaviateProtoAggregate.AggregateRequest.Builder req);

static ObjectFilter NONE = new ObjectFilter() {
static AggregateObjectFilter NONE = new AggregateObjectFilter() {
@Override
public void appendTo(Builder req) {
return;
}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate;

public record Aggregation(
ObjectFilter filter,
AggregateObjectFilter filter,
Integer objectLimit,
boolean includeTotalCount,
List<PropertyAggregation> returnMetrics) {

public static Aggregation of(Function<Builder, ObjectBuilder<Aggregation>> fn) {
return of(ObjectFilter.NONE, fn);
return of(AggregateObjectFilter.NONE, fn);
}

public static Aggregation of(ObjectFilter objectFilter, Function<Builder, ObjectBuilder<Aggregation>> fn) {
public static Aggregation of(AggregateObjectFilter objectFilter, Function<Builder, ObjectBuilder<Aggregation>> fn) {
return fn.apply(new Builder(objectFilter)).build();
}

Expand All @@ -31,9 +31,9 @@ public Aggregation(Builder builder) {
}

public static class Builder implements ObjectBuilder<Aggregation> {
private final ObjectFilter objectFilter;
private final AggregateObjectFilter objectFilter;

public Builder(ObjectFilter objectFilter) {
public Builder(AggregateObjectFilter objectFilter) {
this.objectFilter = objectFilter;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ abstract class AbstractQueryClient<PropertiesT, SingleT, ResponseT, GroupedRespo

protected abstract SingleT byId(ById byId);

protected abstract ResponseT performRequest(SearchOperator operator);
protected abstract ResponseT performRequest(QueryOperator operator);

protected abstract GroupedResponseT performRequest(SearchOperator operator, GroupBy groupBy);
protected abstract GroupedResponseT performRequest(QueryOperator operator, GroupBy groupBy);

// Fetch by ID --------------------------------------------------------------

Expand Down Expand Up @@ -56,6 +56,32 @@ public GroupedResponseT fetchObjects(FetchObjects query, GroupBy groupBy) {
return performRequest(query, groupBy);
}

// BM25 queries -------------------------------------------------------------

/**
 * Runs a BM25 query built from {@code query} alone, with no further options.
 *
 * @param query the BM25 query string
 * @return the result of {@link #bm25(Bm25)} for the constructed operator
 */
public ResponseT bm25(String query) {
  final Bm25 operator = Bm25.of(query);
  return bm25(operator);
}

/**
 * Runs a BM25 query built from {@code query} and customized by {@code fn}.
 *
 * @param query the BM25 query string
 * @param fn    callback that configures the {@link Bm25.Builder}
 * @return the result of {@link #bm25(Bm25)} for the constructed operator
 */
public ResponseT bm25(String query, Function<Bm25.Builder, ObjectBuilder<Bm25>> fn) {
  final Bm25 operator = Bm25.of(query, fn);
  return bm25(operator);
}

/**
 * Runs a pre-built BM25 operator.
 *
 * @param query the BM25 operator to execute
 * @return whatever {@code performRequest} yields for this operator
 */
public ResponseT bm25(Bm25 query) {
  final ResponseT response = performRequest(query);
  return response;
}

/**
 * Runs a grouped BM25 query built from {@code query} alone, with no further options.
 *
 * @param query   the BM25 query string
 * @param groupBy grouping to apply to the results
 * @return the result of {@link #bm25(Bm25, GroupBy)} for the constructed operator
 */
public GroupedResponseT bm25(String query, GroupBy groupBy) {
  final Bm25 operator = Bm25.of(query);
  return bm25(operator, groupBy);
}

/**
 * Runs a grouped BM25 query built from {@code query} and customized by {@code fn}.
 *
 * @param query   the BM25 query string
 * @param fn      callback that configures the {@link Bm25.Builder}
 * @param groupBy grouping to apply to the results
 * @return the result of {@link #bm25(Bm25, GroupBy)} for the constructed operator
 */
public GroupedResponseT bm25(String query, Function<Bm25.Builder, ObjectBuilder<Bm25>> fn, GroupBy groupBy) {
  final Bm25 operator = Bm25.of(query, fn);
  return bm25(operator, groupBy);
}

/**
 * Runs a pre-built BM25 operator with grouping.
 *
 * @param query   the BM25 operator to execute
 * @param groupBy grouping to apply to the results
 * @return whatever the grouped {@code performRequest} yields for this operator
 */
public GroupedResponseT bm25(Bm25 query, GroupBy groupBy) {
  final GroupedResponseT response = performRequest(query, groupBy);
  return response;
}

// NearVector queries -------------------------------------------------------

public ResponseT nearVector(Float[] vector) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
package io.weaviate.client6.v1.api.collections.query;

import java.util.Arrays;
import java.util.List;
import java.util.function.Function;

import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;

/**
 * BM25 query operator: carries the query string, the properties to search in,
 * and the common query options, and writes them into a gRPC search request.
 *
 * @param query           the BM25 query string
 * @param queryProperties properties the query is restricted to; may be empty
 *                        (or {@code null} when the canonical constructor is
 *                        called directly), in which case no properties are
 *                        added to the request
 * @param common          options shared by all query operators
 */
public record Bm25(String query, List<String> queryProperties, BaseQueryOptions common)
    implements QueryOperator {

  /** Creates a BM25 operator for {@code query} with no additional options. */
  public static final Bm25 of(String query) {
    return of(query, ObjectBuilder.identity());
  }

  /**
   * Creates a BM25 operator for {@code query}, letting {@code fn} configure
   * the builder before it is built.
   */
  public static final Bm25 of(String query, Function<Builder, ObjectBuilder<Bm25>> fn) {
    return fn.apply(new Builder(query)).build();
  }

  /** Copies the query, query properties and base options out of {@code builder}. */
  public Bm25(Builder builder) {
    this(builder.query, builder.queryProperties, builder.baseOptions());
  }

  public static class Builder extends BaseQueryOptions.Builder<Builder, Bm25> {
    // Required query parameters.
    private final String query;

    // Optional query parameters.
    // Defaults to an empty list so that Bm25.of(query) yields a usable operator;
    // a null list would make appendTo fail when adding properties to the request.
    List<String> queryProperties = List.of();
    // NOTE(review): this value is stored but never copied into the Bm25 record
    // nor written to the request, so searchOperator(...) currently has no effect.
    // Wiring it through needs an extra record component — confirm intended design.
    SearchOperator searchOperator;

    public Builder(String query) {
      this.query = query;
    }

    /** Restricts the query to the given properties. */
    public Builder queryProperties(String... properties) {
      return queryProperties(Arrays.asList(properties));
    }

    /** Restricts the query to the given properties. */
    public Builder queryProperties(List<String> properties) {
      this.queryProperties = properties;
      return this;
    }

    /** Sets the search operator (see the review note on the field: not yet applied). */
    public Builder searchOperator(SearchOperator searchOperator) {
      this.searchOperator = searchOperator;
      return this;
    }

    @Override
    public final Bm25 build() {
      return new Bm25(this);
    }
  }

  /**
   * Appends the common options and the BM25 search clause to {@code req}.
   */
  @Override
  public final void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req) {
    common.appendTo(req);

    var bm25 = WeaviateProtoBaseSearch.BM25.newBuilder().setQuery(query);
    // Generated protobuf addAll* methods reject a null Iterable, so skip the
    // call when no properties were provided.
    if (queryProperties != null) {
      bm25.addAllProperties(queryProperties);
    }
    req.setBm25Search(bm25);
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public record ById(
String uuid,
List<String> returnProperties,
List<QueryReference> returnReferences,
List<Metadata> returnMetadata) implements SearchOperator {
List<Metadata> returnMetadata) implements QueryOperator {

private static final String ID_PROPERTY = "_id";

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;

public record FetchObjects(BaseQueryOptions common) implements SearchOperator {
public record FetchObjects(BaseQueryOptions common) implements QueryOperator {

public static FetchObjects of(Function<Builder, ObjectBuilder<FetchObjects>> fn) {
return fn.apply(new Builder()).build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import java.util.function.Function;

import io.weaviate.client6.v1.api.collections.aggregate.ObjectFilter;
import io.weaviate.client6.v1.api.collections.aggregate.AggregateObjectFilter;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;

public record NearImage(String image, Float distance, Float certainty, BaseQueryOptions common)
implements SearchOperator, ObjectFilter {
implements QueryOperator, AggregateObjectFilter {

public static NearImage of(String image) {
return of(image, ObjectBuilder.identity());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
import java.util.List;
import java.util.function.Function;

import io.weaviate.client6.v1.api.collections.aggregate.ObjectFilter;
import io.weaviate.client6.v1.api.collections.aggregate.AggregateObjectFilter;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;

public record NearText(List<String> concepts, Float distance, Float certainty, Move moveTo, Move moveAway,
BaseQueryOptions common) implements SearchOperator, ObjectFilter {
BaseQueryOptions common) implements QueryOperator, AggregateObjectFilter {

public static NearText of(String... concepts) {
return of(Arrays.asList(concepts), ObjectBuilder.identity());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import java.util.function.Function;

import io.weaviate.client6.v1.api.collections.aggregate.ObjectFilter;
import io.weaviate.client6.v1.api.collections.aggregate.AggregateObjectFilter;
import io.weaviate.client6.v1.internal.ObjectBuilder;
import io.weaviate.client6.v1.internal.grpc.GRPC;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoAggregate;
Expand All @@ -11,7 +11,7 @@
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;

public record NearVector(Float[] vector, Float distance, Float certainty, BaseQueryOptions common)
implements SearchOperator, ObjectFilter {
implements QueryOperator, AggregateObjectFilter {

public static final NearVector of(Float[] vector) {
return of(vector, ObjectBuilder.identity());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package io.weaviate.client6.v1.api.collections.query;

import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;

/**
 * Package-private contract for anything that can contribute to a gRPC search
 * request. Implementations visible in this package include {@code Bm25},
 * {@code ById}, {@code FetchObjects} and the Near* operators.
 */
interface QueryOperator {
/**
 * Writes this operator's parameters into the given search request builder.
 *
 * @param req the request builder to populate
 */
void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req);
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import io.weaviate.client6.v1.internal.orm.CollectionDescriptor;
import io.weaviate.client6.v1.internal.orm.PropertiesBuilder;

public record QueryRequest(SearchOperator operator, GroupBy groupBy) {
public record QueryRequest(QueryOperator operator, GroupBy groupBy) {

static <T> Rpc<QueryRequest, WeaviateProtoSearchGet.SearchRequest, QueryResponse<T>, WeaviateProtoSearchGet.SearchReply> rpc(
CollectionDescriptor<T> collection) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,29 @@
package io.weaviate.client6.v1.api.collections.query;

import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoSearchGet;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch;
import io.weaviate.client6.v1.internal.grpc.protocol.WeaviateProtoBaseSearch.SearchOperatorOptions.Operator;

interface SearchOperator {
void appendTo(WeaviateProtoSearchGet.SearchRequest.Builder req);
/**
 * Describes how query tokens are combined: either all of them ("and") or a
 * minimum number of them ("or"). Instances are created through {@link #or(int)}
 * and {@link #and()}.
 */
public class SearchOperator {
  // Stored as the protobuf enum directly, so no string comparison is needed
  // when the value is written to the request.
  private final Operator operator;
  private final Integer minimumOrTokensMatch;

  /**
   * Creates an "or" operator.
   *
   * @param minimumOrTokensMatch value written to the request's
   *                             minimum-or-tokens-match option
   */
  public static final SearchOperator or(int minimumOrTokensMatch) {
    return new SearchOperator(Operator.OPERATOR_OR, minimumOrTokensMatch);
  }

  /** Creates an "and" operator; the minimum-or-tokens-match option is sent as 0. */
  public static final SearchOperator and() {
    return new SearchOperator(Operator.OPERATOR_AND, 0);
  }

  private SearchOperator(Operator operator, Integer minimumOrTokensMatch) {
    this.operator = operator;
    this.minimumOrTokensMatch = minimumOrTokensMatch;
  }

  /**
   * Writes the operator and, when present, the minimum-or-tokens-match value
   * into {@code options}.
   */
  void appendTo(WeaviateProtoBaseSearch.SearchOperatorOptions.Builder options) {
    options.setOperator(operator);
    if (minimumOrTokensMatch != null) {
      options.setMinimumOrTokensMatch(minimumOrTokensMatch);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ protected Optional<WeaviateObject<T, Object, QueryMetadata>> byId(ById byId) {
}

@Override
protected final QueryResponse<T> performRequest(SearchOperator operator) {
protected final QueryResponse<T> performRequest(QueryOperator operator) {
var request = new QueryRequest(operator, null);
return this.transport.performRequest(request, QueryRequest.rpc(collection));
}

@Override
protected final QueryResponseGrouped<T> performRequest(SearchOperator operator, GroupBy groupBy) {
protected final QueryResponseGrouped<T> performRequest(QueryOperator operator, GroupBy groupBy) {
var request = new QueryRequest(operator, groupBy);
return this.transport.performRequest(request, QueryRequest.grouped(collection));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,13 @@ protected CompletableFuture<Optional<WeaviateObject<T, Object, QueryMetadata>>>
}

@Override
protected final CompletableFuture<QueryResponse<T>> performRequest(SearchOperator operator) {
protected final CompletableFuture<QueryResponse<T>> performRequest(QueryOperator operator) {
var request = new QueryRequest(operator, null);
return this.transport.performRequestAsync(request, QueryRequest.rpc(collection));
}

@Override
protected final CompletableFuture<QueryResponseGrouped<T>> performRequest(SearchOperator operator, GroupBy groupBy) {
protected final CompletableFuture<QueryResponseGrouped<T>> performRequest(QueryOperator operator, GroupBy groupBy) {
var request = new QueryRequest(operator, groupBy);
return this.transport.performRequestAsync(request, QueryRequest.grouped(collection));
}
Expand Down
Loading