Skip to content

Filter elements with an optional filtering function #402

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Sep 6, 2022
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ if(CMAKE_PROJECT_NAME STREQUAL PROJECT_NAME)
add_executable(searchKnnCloserFirst_test examples/searchKnnCloserFirst_test.cpp)
target_link_libraries(searchKnnCloserFirst_test hnswlib)

add_executable(searchKnnWithFilter_test examples/searchKnnWithFilter_test.cpp)
target_link_libraries(searchKnnWithFilter_test hnswlib)

add_executable(main main.cpp sift_1b.cpp)
target_link_libraries(main hnswlib)
endif()
95 changes: 95 additions & 0 deletions examples/searchKnnWithFilter_test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// This is a test file for testing the filtering feature

#include "../hnswlib/hnswlib.h"

#include <assert.h>

#include <vector>
#include <iostream>

namespace
{

using idx_t = hnswlib::labeltype;

bool pickIdsDivisibleByThree(unsigned int ep_id) {
return ep_id % 3 == 0;
}

bool pickIdsDivisibleBySeven(unsigned int ep_id) {
return ep_id % 7 == 0;
}

template<typename filter_func_t>
void test(filter_func_t filter_func, size_t div_num) {
int d = 4;
idx_t n = 100;
idx_t nq = 10;
size_t k = 10;

std::vector<float> data(n * d);
std::vector<float> query(nq * d);

std::mt19937 rng;
rng.seed(47);
std::uniform_real_distribution<> distrib;

for (idx_t i = 0; i < n * d; ++i) {
data[i] = distrib(rng);
}
for (idx_t i = 0; i < nq * d; ++i) {
query[i] = distrib(rng);
}


hnswlib::L2Space space(d);
hnswlib::AlgorithmInterface<float>* alg_brute = new hnswlib::BruteforceSearch<float,hnswlib::FILTERFUNC>(&space, 2 * n);
hnswlib::AlgorithmInterface<float>* alg_hnsw = new hnswlib::HierarchicalNSW<float,hnswlib::FILTERFUNC>(&space, 2 * n);

for (size_t i = 0; i < n; ++i) {
alg_brute->addPoint(data.data() + d * i, i);
alg_hnsw->addPoint(data.data() + d * i, i);
}

// test searchKnnCloserFirst of BruteforceSearch with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_brute->searchKnn(p, k, filter_func);
auto res = alg_brute->searchKnnCloserFirst(p, k, filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
assert(gd.top() == res[--t]);
assert((gd.top().second % div_num) == 0);
gd.pop();
}
}

// test searchKnnCloserFirst of hnsw with filtering
for (size_t j = 0; j < nq; ++j) {
const void* p = query.data() + j * d;
auto gd = alg_hnsw->searchKnn(p, k, filter_func);
auto res = alg_hnsw->searchKnnCloserFirst(p, k, filter_func);
assert(gd.size() == res.size());
size_t t = gd.size();
while (!gd.empty()) {
assert(gd.top() == res[--t]);
assert((gd.top().second % div_num) == 0);
gd.pop();
}
}

delete alg_brute;
delete alg_hnsw;
}

} // namespace

int main() {
std::cout << "Testing ..." << std::endl;
test(pickIdsDivisibleByThree, 3);
test(pickIdsDivisibleBySeven, 7);
std::cout << "Test ok" << std::endl;

return 0;
}
18 changes: 11 additions & 7 deletions hnswlib/bruteforce.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
#include <algorithm>

namespace hnswlib {
template<typename dist_t>
class BruteforceSearch : public AlgorithmInterface<dist_t> {
template<typename dist_t, typename filter_func_t=FILTERFUNC>
class BruteforceSearch : public AlgorithmInterface<dist_t,filter_func_t> {
public:
BruteforceSearch(SpaceInterface <dist_t> *s) : data_(nullptr), maxelements_(0),
cur_element_count(0), size_per_element_(0), data_size_(0),
Expand Down Expand Up @@ -92,20 +92,24 @@ namespace hnswlib {


std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k) const {
searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would not it better to pass NULL as a default value to show that we do not do filtering by default (see ad3440c#r954851933)?

filter_func_t isIdAllowed=NULL

std::priority_queue<std::pair<dist_t, labeltype >> topResults;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you please add the same flag as in hnswlib/hnswalg.h and check of k

assert(k <= cur_element_count);
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;

if (cur_element_count == 0) return topResults;
for (int i = 0; i < k; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
data_size_))));
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
Copy link
Contributor

@dyashuni dyashuni Sep 5, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (is_filter_disabled || isIdAllowed(label)) {

topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
}
dist_t lastdist = topResults.top().first;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

now topResults may be empty. Need to add a check here.

if (topResults.size() > 0)
{
    lastdist = topResults.top().first;
}
else
{
    lastdist = std::numeric_limits<dist_t>::max();
}

for (int i = k; i < cur_element_count; i++) {
dist_t dist = fstdistfunc_(query_data, data_ + size_per_element_ * i, dist_func_param_);
if (dist <= lastdist) {
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
data_size_))));
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (is_filter_disabled || isIdAllowed(label)) {

topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
if (topResults.size() > k)
topResults.pop();
lastdist = topResults.top().first;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to add a check here as well

if (topResults.size() > 0)
{
    lastdist = topResults.top().first;
}

Expand Down
16 changes: 8 additions & 8 deletions hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ namespace hnswlib {
typedef unsigned int tableint;
typedef unsigned int linklistsizeint;

template<typename dist_t>
class HierarchicalNSW : public AlgorithmInterface<dist_t> {
template<typename dist_t, typename filter_func_t=FILTERFUNC>
class HierarchicalNSW : public AlgorithmInterface<dist_t,filter_func_t> {
public:
static const tableint max_update_element_locks = 65536;
HierarchicalNSW(SpaceInterface<dist_t> *s) {
Expand Down Expand Up @@ -238,7 +238,7 @@ namespace hnswlib {

template <bool has_deletions, bool collect_metrics=false>
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst>
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef) const {
searchBaseLayerST(tableint ep_id, const void *data_point, size_t ef, filter_func_t isIdAllowed) const {
VisitedList *vl = visited_list_pool_->getFreeVisitedList();
vl_type *visited_array = vl->mass;
vl_type visited_array_tag = vl->curV;
Expand All @@ -247,7 +247,7 @@ namespace hnswlib {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add the following micro optimisation to not call isIdAllowed at all if filtering is disabled:

bool is_filter_disabled = isIdAllowed == allowAllIds;
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {

if (!has_deletions || !isMarkedDeleted(ep_id)) {
if ((!has_deletions || !isMarkedDeleted(ep_id)) && isIdAllowed(ep_id)) {
dist_t dist = fstdistfunc_(data_point, getDataByInternalId(ep_id), dist_func_param_);
lowerBound = dist;
top_candidates.emplace(dist, ep_id);
Expand Down Expand Up @@ -307,7 +307,7 @@ namespace hnswlib {
_MM_HINT_T0);////////////////////////
#endif

if (!has_deletions || !isMarkedDeleted(candidate_id))
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id))
top_candidates.emplace(dist, candidate_id);

if (top_candidates.size() > ef)
Expand Down Expand Up @@ -1111,7 +1111,7 @@ namespace hnswlib {
};

std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void *query_data, size_t k) const {
searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const {
std::priority_queue<std::pair<dist_t, labeltype >> result;
if (cur_element_count == 0) return result;

Expand Down Expand Up @@ -1148,11 +1148,11 @@ namespace hnswlib {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> top_candidates;
if (num_deleted_) {
top_candidates=searchBaseLayerST<true,true>(
currObj, query_data, std::max(ef_, k));
currObj, query_data, std::max(ef_, k), isIdAllowed);
}
else{
top_candidates=searchBaseLayerST<false,true>(
currObj, query_data, std::max(ef_, k));
currObj, query_data, std::max(ef_, k), isIdAllowed);
}

while (top_candidates.size() > k) {
Expand Down
20 changes: 14 additions & 6 deletions hnswlib/hnswlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,10 @@ static bool AVX512Capable() {
namespace hnswlib {
typedef size_t labeltype;

bool allowAllIds(unsigned int ep_id) {
return true;
}

template <typename T>
class pairGreater {
public:
Expand All @@ -137,6 +141,7 @@ namespace hnswlib {
template<typename MTYPE>
using DISTFUNC = MTYPE(*)(const void *, const void *, const void *);

using FILTERFUNC = bool(*)(unsigned int);

template<typename MTYPE>
class SpaceInterface {
Expand All @@ -151,28 +156,31 @@ namespace hnswlib {
virtual ~SpaceInterface() {}
};

template<typename dist_t>
template<typename dist_t, typename filter_func_t=FILTERFUNC>
class AlgorithmInterface {
public:
virtual void addPoint(const void *datapoint, labeltype label)=0;
virtual std::priority_queue<std::pair<dist_t, labeltype >> searchKnn(const void *, size_t) const = 0;

virtual std::priority_queue<std::pair<dist_t, labeltype >>
searchKnn(const void*, size_t, filter_func_t isIdAllowed=allowAllIds) const = 0;

// Return k nearest neighbor in the order of closer fist
virtual std::vector<std::pair<dist_t, labeltype>>
searchKnnCloserFirst(const void* query_data, size_t k) const;
searchKnnCloserFirst(const void* query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const;

virtual void saveIndex(const std::string &location)=0;
virtual ~AlgorithmInterface(){
}
};

template<typename dist_t>
template<typename dist_t, typename filter_func_t>
std::vector<std::pair<dist_t, labeltype>>
AlgorithmInterface<dist_t>::searchKnnCloserFirst(const void* query_data, size_t k) const {
AlgorithmInterface<dist_t, filter_func_t>::searchKnnCloserFirst(const void* query_data, size_t k,
filter_func_t isIdAllowed) const {
std::vector<std::pair<dist_t, labeltype>> result;

// here searchKnn returns the result in the order of further first
auto ret = searchKnn(query_data, k);
auto ret = searchKnn(query_data, k, isIdAllowed);
{
size_t sz = ret.size();
result.resize(sz);
Expand Down