-
Notifications
You must be signed in to change notification settings - Fork 705
Filter elements with an optional filtering function #402
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
765c4ab
ad3440c
4f6dcc3
1c833a7
aaee13a
b87f623
f0dedf3
de22860
e4705fd
7f419ea
1fe7baf
c9897b0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
// This is a test file for testing the filtering feature | ||
|
||
#include "../hnswlib/hnswlib.h"

#include <assert.h>

#include <iostream>
#include <memory>
#include <random>
#include <vector>
|
||
namespace | ||
{ | ||
|
||
using idx_t = hnswlib::labeltype; | ||
|
||
// Filter predicate: accepts an id only when it is an exact multiple of three.
bool pickIdsDivisibleByThree(unsigned int ep_id) {
    const unsigned int remainder = ep_id % 3u;
    return remainder == 0u;
}
|
||
// Filter predicate: accepts an id only when it is an exact multiple of seven.
bool pickIdsDivisibleBySeven(unsigned int ep_id) {
    const unsigned int remainder = ep_id % 7u;
    return remainder == 0u;
}
|
||
template<typename filter_func_t> | ||
void test(filter_func_t filter_func, size_t div_num) { | ||
int d = 4; | ||
idx_t n = 100; | ||
idx_t nq = 10; | ||
size_t k = 10; | ||
|
||
std::vector<float> data(n * d); | ||
std::vector<float> query(nq * d); | ||
|
||
std::mt19937 rng; | ||
rng.seed(47); | ||
std::uniform_real_distribution<> distrib; | ||
|
||
for (idx_t i = 0; i < n * d; ++i) { | ||
data[i] = distrib(rng); | ||
} | ||
for (idx_t i = 0; i < nq * d; ++i) { | ||
query[i] = distrib(rng); | ||
} | ||
|
||
|
||
hnswlib::L2Space space(d); | ||
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float,hnswlib::FILTERFUNC>(&space, 2 * n); | ||
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float,hnswlib::FILTERFUNC>(&space, 2 * n); | ||
|
||
for (size_t i = 0; i < n; ++i) { | ||
alg_brute->addPoint(data.data() + d * i, i); | ||
alg_hnsw->addPoint(data.data() + d * i, i); | ||
} | ||
|
||
// test searchKnnCloserFirst of BruteforceSearch with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_brute->searchKnn(p, k, filter_func); | ||
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
size_t t = gd.size(); | ||
while (!gd.empty()) { | ||
assert(gd.top() == res[--t]); | ||
assert((gd.top().second % div_num) == 0); | ||
gd.pop(); | ||
} | ||
} | ||
|
||
// test searchKnnCloserFirst of hnsw with filtering | ||
for (size_t j = 0; j < nq; ++j) { | ||
const void* p = query.data() + j * d; | ||
auto gd = alg_hnsw->searchKnn(p, k, filter_func); | ||
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func); | ||
assert(gd.size() == res.size()); | ||
size_t t = gd.size(); | ||
while (!gd.empty()) { | ||
assert(gd.top() == res[--t]); | ||
assert((gd.top().second % div_num) == 0); | ||
gd.pop(); | ||
} | ||
} | ||
|
||
delete alg_brute; | ||
delete alg_hnsw; | ||
} | ||
|
||
} // namespace | ||
|
||
// Entry point: runs the filtering test once per predicate and reports success.
int main() {
    std::cout << "Testing ..." << std::endl;

    // Each call pairs a filter predicate with the divisor it accepts.
    test(pickIdsDivisibleByThree, 3);
    test(pickIdsDivisibleBySeven, 7);

    std::cout << "Test ok" << std::endl;
    return 0;
}
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,8 +5,8 @@ | |
#include <algorithm> | ||
|
||
namespace hnswlib { | ||
template<typename dist_t> | ||
class BruteforceSearch : public AlgorithmInterface<dist_t> { | ||
template<typename dist_t, typename filter_func_t=FILTERFUNC> | ||
class BruteforceSearch : public AlgorithmInterface<dist_t,filter_func_t> { | ||
public: | ||
BruteforceSearch(SpaceInterface <dist_t> *s) : data_(nullptr), maxelements_(0), | ||
cur_element_count(0), size_per_element_(0), data_size_(0), | ||
|
@@ -92,20 +92,24 @@ namespace hnswlib { | |
|
||
|
||
std::priority_queue<std::pair<dist_t, labeltype >> | ||
searchKnn(const void *query_data, size_t k) const { | ||
searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { | ||
std::priority_queue<std::pair<dist_t, labeltype >> topResults; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you please add the same flag as in
|
||
if (cur_element_count == 0) return topResults; | ||
for (int i = 0; i < k; i++) { | ||
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); | ||
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i + | ||
data_size_)))); | ||
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); | ||
if(isIdAllowed(label)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
topResults.push(std::pair<dist_t, labeltype>(dist, label)); | ||
} | ||
} | ||
dist_t lastdist = topResults.top().first; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Now that elements can be filtered out, `topResults` may be empty at this point, so an emptiness check is needed before calling `topResults.top()`.
|
||
for (int i = k; i < cur_element_count; i++) { | ||
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_); | ||
if (dist <= lastdist) { | ||
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i + | ||
data_size_)))); | ||
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); | ||
if(isIdAllowed(label)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
topResults.push(std::pair<dist_t, labeltype>(dist, label)); | ||
} | ||
if (topResults.size() > k) | ||
topResults.pop(); | ||
lastdist = topResults.top().first; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. An emptiness check needs to be added here as well.
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,8 +13,8 @@ namespace hnswlib { | |
typedef unsigned int tableint; | ||
typedef unsigned int linklistsizeint; | ||
|
||
template<typename dist_t> | ||
class HierarchicalNSW : public AlgorithmInterface<dist_t> { | ||
template<typename dist_t, typename filter_func_t=FILTERFUNC> | ||
class HierarchicalNSW : public AlgorithmInterface<dist_t,filter_func_t> { | ||
public: | ||
static const tableint max_update_element_locks = 65536; | ||
HierarchicalNSW(SpaceInterface<dist_t> *s) { | ||
|
@@ -238,7 +238,7 @@ namespace hnswlib { | |
|
||
template <bool has_deletions, bool collect_metrics=false> | ||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> | ||
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const { | ||
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t isIdAllowed) const { | ||
VisitedList *vl = visited_list_pool_->getFreeVisitedList(); | ||
vl_type *visited_array = vl->mass; | ||
vl_type visited_array_tag = vl->curV; | ||
|
@@ -247,7 +247,7 @@ namespace hnswlib { | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set; | ||
|
||
dist_t lowerBound; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Can we add the following micro-optimisation so that `isIdAllowed` is not called at all when filtering is disabled:
|
||
if (!has_deletions || !isMarkedDeleted(ep_id)) { | ||
if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) { | ||
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_); | ||
lowerBound = dist; | ||
top_candidates.emplace(dist, ep_id); | ||
|
@@ -307,7 +307,7 @@ namespace hnswlib { | |
_MM_HINT_T0);//////////////////////// | ||
#endif | ||
|
||
if (!has_deletions || !isMarkedDeleted(candidate_id)) | ||
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id)) | ||
top_candidates.emplace(dist, candidate_id); | ||
|
||
if (top_candidates.size() > ef) | ||
|
@@ -1111,7 +1111,7 @@ namespace hnswlib { | |
}; | ||
|
||
std::priority_queue<std::pair<dist_t, labeltype >> | ||
searchKnn(const void *query_data, size_t k) const { | ||
searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { | ||
std::priority_queue<std::pair<dist_t, labeltype >> result; | ||
if (cur_element_count == 0) return result; | ||
|
||
|
@@ -1148,11 +1148,11 @@ namespace hnswlib { | |
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates; | ||
if (num_deleted_) { | ||
top_candidates=searchBaseLayerST<true,true>( | ||
currObj, query_data, std::max(ef_, k)); | ||
currObj, query_data, std::max(ef_, k), isIdAllowed); | ||
} | ||
else{ | ||
top_candidates=searchBaseLayerST<false,true>( | ||
currObj, query_data, std::max(ef_, k)); | ||
currObj, query_data, std::max(ef_, k), isIdAllowed); | ||
} | ||
|
||
while (top_candidates.size() > k) { | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't it be better to pass NULL as the default value, to show that no filtering is done by default (see ad3440c#r954851933)?
filter_func_t isIdAllowed=NULL