Skip to content

Filter elements with an optional filtering function #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 6, 2022
12 changes: 6 additions & 6 deletions examples/searchKnnWithFilter_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ bool pickIdsDivisibleBySeven(unsigned int ep_id) {
}

template<typename filter_func_t>
void test(filter_func_t filter_func, size_t div_num) {
void test(filter_func_t filter_func, size_t div_num, size_t label_id_start) {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
Expand All @@ -40,15 +40,15 @@ void test(filter_func_t filter_func, size_t div_num) {
for (idx_t i = 0; i < nq * d; ++i) {
query[i] = distrib(rng);
}


hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float,hnswlib::FILTERFUNC>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float,hnswlib::FILTERFUNC>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
alg_brute->addPoint(data.data() + d * i, i);
alg_hnsw->addPoint(data.data() + d * i, i);
// `label_id_start` is used to ensure that the returned IDs are labels and not internal IDs
alg_brute->addPoint(data.data() + d * i, label_id_start + i);
alg_hnsw->addPoint(data.data() + d * i, label_id_start + i);
}

// test searchKnnCloserFirst of BruteforceSearch with filtering
Expand Down Expand Up @@ -87,8 +87,8 @@ void test(filter_func_t filter_func, size_t div_num) {

int main() {
std::cout << "Testing ..." << std::endl;
test(pickIdsDivisibleByThree, 3);
test(pickIdsDivisibleBySeven, 7);
test(pickIdsDivisibleByThree, 3, 17);
test(pickIdsDivisibleBySeven, 7, 17);
std::cout << "Test ok" << std::endl;

return 0;
Expand Down
4 changes: 2 additions & 2 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ namespace hnswlib {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add the following micro optimisation to not call isIdAllowed at all if filtering is disabled:

bool is_filter_disabled = isIdAllowed == allowAllIds;
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {

if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) {
if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(getExternalLabel(ep_id))) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
Expand Down Expand Up @@ -307,7 +307,7 @@ namespace hnswlib {
_MM_HINT_T0);////////////////////////
#endif

if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id))
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {

top_candidates.emplace(dist, candidate_id);

if (top_candidates.size() > ef)
Expand Down