Skip to content

Commit ad3440c

Browse files
committed
Filter function should be sent the label and not the internal ID.
1 parent 765c4ab commit ad3440c

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

examples/searchKnnWithFilter_test.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ bool pickIdsDivisibleBySeven(unsigned int ep_id) {
2121
}
2222

2323
template<typename filter_func_t>
24-
void test(filter_func_t filter_func, size_t div_num) {
24+
void test(filter_func_t filter_func, size_t div_num, size_t label_id_start) {
2525
int d = 4;
2626
idx_t n = 100;
2727
idx_t nq = 10;
@@ -40,15 +40,15 @@ void test(filter_func_t filter_func, size_t div_num) {
4040
for (idx_t i = 0; i < nq * d; ++i) {
4141
query[i] = distrib(rng);
4242
}
43-
4443

4544
hnswlib::L2Space space(d);
4645
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float,hnswlib::FILTERFUNC>(&space, 2 * n);
4746
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float,hnswlib::FILTERFUNC>(&space, 2 * n);
4847

4948
for (size_t i = 0; i < n; ++i) {
50-
alg_brute->addPoint(data.data() + d * i, i);
51-
alg_hnsw->addPoint(data.data() + d * i, i);
49+
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
50+
alg_brute->addPoint(data.data() + d * i, label_id_start + i);
51+
alg_hnsw->addPoint(data.data() + d * i, label_id_start + i);
5252
}
5353

5454
// test searchKnnCloserFirst of BruteforceSearch with filtering
@@ -87,8 +87,8 @@ void test(filter_func_t filter_func, size_t div_num) {
8787

8888
int main() {
8989
std::cout << "Testing ..." << std::endl;
90-
test(pickIdsDivisibleByThree, 3);
91-
test(pickIdsDivisibleBySeven, 7);
90+
test(pickIdsDivisibleByThree, 3, 17);
91+
test(pickIdsDivisibleBySeven, 7, 17);
9292
std::cout << "Test ok" << std::endl;
9393

9494
return 0;

hnswlib/hnswalg.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ namespace hnswlib {
247247
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;
248248

249249
dist_t lowerBound;
250-
if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) {
250+
if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(getExternalLabel(ep_id))) {
251251
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
252252
lowerBound = dist;
253253
top_candidates.emplace(dist, ep_id);
@@ -307,7 +307,7 @@ namespace hnswlib {
307307
_MM_HINT_T0);////////////////////////
308308
#endif
309309

310-
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id))
310+
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id)))
311311
top_candidates.emplace(dist, candidate_id);
312312

313313
if (top_candidates.size() > ef)

0 commit comments

Comments
 (0)