@@ -93,12 +93,14 @@ namespace hnswlib {
93
93
94
94
std::priority_queue<std::pair<dist_t , labeltype >>
95
95
searchKnn (const void *query_data, size_t k, filter_func_t & isIdAllowed=allowAllIds) const {
96
+ assert (k <= cur_element_count);
96
97
std::priority_queue<std::pair<dist_t , labeltype >> topResults;
97
98
if (cur_element_count == 0 ) return topResults;
99
+ bool is_filter_disabled = std::is_same<filter_func_t , decltype (allowAllIds)>::value;
98
100
for (int i = 0 ; i < k; i++) {
99
101
dist_t dist = fstdistfunc_ (query_data, data_ + size_per_element_ * i, dist_func_param_);
100
102
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
101
- if (isIdAllowed (label)) {
103
+ if (is_filter_disabled || isIdAllowed (label)) {
102
104
topResults.push (std::pair<dist_t , labeltype>(dist, label));
103
105
}
104
106
}
@@ -107,7 +109,7 @@ namespace hnswlib {
107
109
dist_t dist = fstdistfunc_ (query_data, data_ + size_per_element_ * i, dist_func_param_);
108
110
if (dist <= lastdist) {
109
111
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
110
- if (isIdAllowed (label)) {
112
+ if (is_filter_disabled || isIdAllowed (label)) {
111
113
topResults.push (std::pair<dist_t , labeltype>(dist, label));
112
114
}
113
115
if (topResults.size () > k)
0 commit comments