Skip to content

Commit 1fe7baf

Browse files
committed
Add check for is_filter_disabled.
1 parent 7f419ea commit 1fe7baf

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

hnswlib/bruteforce.h

+4-2
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,14 @@ namespace hnswlib {
9393

9494
std::priority_queue<std::pair<dist_t, labeltype >>
9595
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const {
96+
assert(k <= cur_element_count);
9697
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
9798
if (cur_element_count == 0) return topResults;
99+
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
98100
for (int i = 0; i < k; i++) {
99101
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
100102
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
101-
if(isIdAllowed(label)) {
103+
if(is_filter_disabled || isIdAllowed(label)) {
102104
topResults.push(std::pair<dist_t, labeltype>(dist, label));
103105
}
104106
}
@@ -107,7 +109,7 @@ namespace hnswlib {
107109
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
108110
if (dist <= lastdist) {
109111
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
110-
if(isIdAllowed(label)) {
112+
if(is_filter_disabled || isIdAllowed(label)) {
111113
topResults.push(std::pair<dist_t, labeltype>(dist, label));
112114
}
113115
if (topResults.size() > k)

0 commit comments

Comments
 (0)