Filter elements with an optional filtering function #402

Merged 12 commits into nmslib:develop on Sep 6, 2022

kishorenc
Contributor

Fixes #366

Summary of the change

An optional filtering function can be specified as a template parameter that determines whether a given ID should be picked. The default is a function that allows all labels by simply returning true. The compiler will optimise this away, so there is no extra cost for those who don't provide a filtering function.

The use of a templated filter function also ensures that the filtering logic can be entirely inlined by the compiler, and is an implementation detail, instead of forcing a std::set<> of allowed IDs as discussed in #366.

I've added a test that asserts on both the brute force and hnsw implementations. Verified that existing search knn tests pass.
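
To make the idea in the summary concrete, here is a minimal sketch of a filter passed as a template parameter with an allow-all default. It is an illustration only, not the code from this PR: `pickLabels` and `allowEvenIds` are hypothetical names, and `labeltype` is assumed here to be `std::size_t`.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical names throughout; this is not the hnswlib code.
using labeltype = std::size_t;

// Default filter: accepts every label.
inline bool allowAllIds(labeltype) { return true; }

// Example user filter: accepts only even labels.
inline bool allowEvenIds(labeltype label) { return label % 2 == 0; }

// The filter is a non-type template parameter, so each instantiation calls a
// function that is known at compile time and can be inlined by the compiler.
template <bool (*isIdAllowed)(labeltype) = allowAllIds>
std::vector<labeltype> pickLabels(const std::vector<labeltype> &labels) {
    std::vector<labeltype> picked;
    for (labeltype label : labels) {
        if (isIdAllowed(label)) picked.push_back(label);
    }
    return picked;
}
```

Calling `pickLabels(labels)` uses the default and keeps everything, while `pickLabels<allowEvenIds>(labels)` keeps only even labels. Note that a plain function like this cannot carry per-query state.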

labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
}
dist_t lastdist = topResults.top().first;
Contributor

Now topResults may be empty, so we need to add a check here:

if (topResults.size() > 0)
{
    lastdist = topResults.top().first;
}
else
{
    lastdist = std::numeric_limits<dist_t>::max();
}
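
The guard suggested above can also be factored into a small helper. The following is a sketch under assumptions: `worstDistanceOrMax` is a hypothetical name that does not exist in hnswlib, and the queue is assumed to be a default max-heap of (distance, label) pairs.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <queue>
#include <utility>

// Hypothetical helper, not part of hnswlib: returns the largest distance in a
// max-heap of (distance, label) pairs, or the maximum representable distance
// when the queue is empty (for example because the filter rejected everything).
template <typename dist_t, typename label_t>
dist_t worstDistanceOrMax(
        const std::priority_queue<std::pair<dist_t, label_t>> &topResults) {
    if (topResults.empty()) {
        // Calling top() on an empty priority_queue is undefined behaviour.
        return std::numeric_limits<dist_t>::max();
    }
    return topResults.top().first;
}
```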

labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
topResults.push(std::pair<dist_t, labeltype>(dist, label));
}
if (topResults.size() > k)
topResults.pop();
lastdist = topResults.top().first;
Contributor

Need to add a check here as well

if (topResults.size() > 0)
{
    lastdist = topResults.top().first;
}

@dyashuni
Contributor

dyashuni commented Aug 18, 2022

Thank you @kishorenc for implementing this feature!
My major concern is that isIdAllowed works with internal IDs that are hidden from users. Does it make sense to modify it to work with external IDs?

@kishorenc
Contributor Author

@dyashuni

Thank you for the review.

  1. Good catch regarding the use of internal ID. I've now fixed the code to send the label to the filtering function and also updated the test to specifically assert for this issue by ensuring that the labels are not the same as the internal ID sequence.

  2. I've also addressed the other issue about topResults empty checks. Added test for this one as well.
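
The difference between internal IDs and labels can be illustrated with a small stand-alone sketch. It assumes a toy "index" in which the position in a vector plays the role of the internal ID and the stored value is the user-facing label; `allowedLabels` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an index: the position in `labels` plays the role
// of the internal ID, and the stored value is the user-facing label.
template <typename Filter>
std::vector<std::size_t> allowedLabels(const std::vector<std::size_t> &labels,
                                       const Filter &isIdAllowed) {
    std::vector<std::size_t> out;
    for (std::size_t internal_id = 0; internal_id < labels.size(); ++internal_id) {
        // Look up the label first; the filter never sees internal_id.
        const std::size_t label = labels[internal_id];
        if (isIdAllowed(label)) out.push_back(label);
    }
    return out;
}
```

With labels `{100, 101, 102, 103}` and a filter that accepts labels of at least 102, the result is `{102, 103}`. Had the filter been applied to the internal IDs 0 to 3 instead, nothing would have passed, which is the kind of mismatch a test with non-sequential labels can catch.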

@@ -247,7 +247,7 @@ namespace hnswlib {
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set;

dist_t lowerBound;
Contributor

Can we add the following micro optimisation to not call isIdAllowed at all if filtering is disabled:

bool is_filter_disabled = isIdAllowed == allowAllIds;
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {

@@ -307,7 +307,7 @@ namespace hnswlib {
_MM_HINT_T0);////////////////////////
#endif

-if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id))
+if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id)))
Contributor

if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {

@kishorenc
Contributor Author

@dyashuni

I just discovered a gotcha with the current filtering function interface: there's no way to pass any kind of application context. For example, if I have a set of IDs to check against, the static filter function will not be able to access that state. I'm going to think about how to pass a "context" to the filtering function.

@@ -92,20 +92,24 @@ namespace hnswlib {


std::priority_queue<std::pair<dist_t, labeltype >>
-searchKnn(const void *query_data, size_t k) const {
+searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const {
Contributor

Wouldn't it be better to pass NULL as the default value, to show that we do not filter by default (see ad3440c#r954851933)?

filter_func_t isIdAllowed=NULL

@dyashuni
Contributor

> @dyashuni
>
> I just discovered a gotcha with the current filtering function interface: there's no way to pass any kind of application context. For example, if I have a set of IDs to check against, the static filter function will not be able to access that state. I'm going to think about how to pass a "context" to the filtering function.

Yes, that is true. I wondered whether C++ has some mechanism for dynamically creating functions with bound parameters (e.g. a variable filtering range) and passing pointers to such functions.

Like in Python
https://stackoverflow.com/questions/803616/passing-functions-with-arguments-to-another-function-in-python
where we can pass functions with parameters using functools.partial or lambda functions.
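
For comparison, the closest C++ counterpart to Python's partial application is a capturing lambda. A lambda that captures state cannot be converted to a plain function pointer, but it can be passed through a template parameter. The sketch below assumes C++14 (for the deduced return type); `countAllowed` and `makeRangeFilter` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical generic scan: Filter may be a function pointer, a functor,
// or a lambda, including one that captures state.
template <typename Filter>
std::size_t countAllowed(const std::vector<std::size_t> &labels,
                         const Filter &isIdAllowed) {
    std::size_t count = 0;
    for (std::size_t label : labels) {
        if (isIdAllowed(label)) ++count;
    }
    return count;
}

// Builds a filter for the half-open range [lo, hi) at run time. The bounds are
// captured by value, so the returned closure carries its own state.
inline auto makeRangeFilter(std::size_t lo, std::size_t hi) {
    return [lo, hi](std::size_t label) { return label >= lo && label < hi; };
}
```

For labels `{1, 5, 10, 15, 20}`, `countAllowed(labels, makeRangeFilter(5, 16))` counts 5, 10 and 15.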

@dyashuni
Contributor

dyashuni commented Aug 25, 2022

Maybe the best way is to pass a pointer to an instance of some class instead of a pointer to a function.

@kishorenc
Contributor Author

@dyashuni

  1. I've refactored the code to use a functor because function objects tend to get inlined very well by compilers (see this).
  2. Added an example in the test to show how a custom filtering function (with state) can be used.
  3. I've also added the micro optimization as a compile time check
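
A minimal sketch of what a functor-based filter with state and a compile-time "filter disabled" check might look like is shown below. It is not the code merged in this PR: `AllowAll`, `AllowListFilter` and `scanLabels` are hypothetical names chosen for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <unordered_set>
#include <vector>

// Hypothetical default functor: accepts every label.
struct AllowAll {
    bool operator()(std::size_t) const { return true; }
};

// Hypothetical stateful functor: accepts only labels stored in `allowed`.
struct AllowListFilter {
    std::unordered_set<std::size_t> allowed;
    bool operator()(std::size_t label) const { return allowed.count(label) != 0; }
};

template <typename Filter = AllowAll>
std::vector<std::size_t> scanLabels(const std::vector<std::size_t> &labels,
                                    const Filter &isIdAllowed = Filter()) {
    // Evaluated at compile time: true only for the default functor type, in
    // which case the filter call is short-circuited and never executed.
    constexpr bool is_filter_disabled = std::is_same<Filter, AllowAll>::value;
    std::vector<std::size_t> out;
    for (std::size_t label : labels) {
        if (is_filter_disabled || isIdAllowed(label)) out.push_back(label);
    }
    return out;
}
```

`scanLabels(labels)` keeps every label, while `scanLabels(labels, AllowListFilter{{101, 103}})` keeps only the listed ones.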

@dyashuni
Contributor

@kishorenc thank you for the updates!

Could you please add the searchKnnWithFilter_test test to the CI?
We have just added cpp tests to the CI. You need to rebase and pick up the latest changes from the develop branch.
Then modify the .github\workflows\build.yml file and add ./searchKnnWithFilter_test to the Test step:

- name: Test
  run: |
    cd build
    ./searchKnnCloserFirst_test
    ./searchKnnWithFilter_test  <----- add this line
    ./test_updates
    ...

@@ -307,7 +308,8 @@ namespace hnswlib {
_MM_HINT_T0);////////////////////////
#endif

if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id)))
is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
Contributor

I think we can remove this line, since a value was already assigned to the is_filter_disabled variable above.

@dyashuni
Contributor

@kishorenc Thank you!

On Windows I get the following error:
searchKnnWithFilter_test.cpp(151,20): error C3615: constexpr function 'CustomFilterFunctor::operator ()' cannot result in a constant expression

Do we need to remove the constexpr keyword?

@kishorenc
Contributor Author

I've removed it now, please check again.

@dyashuni
Contributor

Great, thank you!

@@ -92,23 +92,30 @@ namespace hnswlib {


std::priority_queue<std::pair<dist_t, labeltype >>
-searchKnn(const void *query_data, size_t k) const {
+searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const {
std::priority_queue<std::pair<dist_t, labeltype >> topResults;
Contributor

Could you please add the same flag as in hnswlib/hnswalg.h, along with a check on k:

assert(k <= cur_element_count);
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;

topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
data_size_))));
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
Contributor

@dyashuni dyashuni Sep 5, 2022

if (is_filter_disabled || isIdAllowed(label)) {

topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i +
data_size_))));
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_));
if(isIdAllowed(label)) {
Contributor

if (is_filter_disabled || isIdAllowed(label)) {

@yurymalkov yurymalkov mentioned this pull request Sep 5, 2022
@dyashuni
Contributor

dyashuni commented Sep 6, 2022

Thank you!

@dyashuni dyashuni merged commit 5c14e05 into nmslib:develop Sep 6, 2022
@gtsoukas
Contributor

This merge is great news! Are there plans to include this in the next release? Are there plans to implement a Python interface for the filtering functionality?

@dyashuni
Contributor

@gtsoukas Thank you! We can add filtering to Python. We just need to figure out what the most popular use cases for filtering are and create a proper API. What kind of filtering are you interested in for Python? As I see it, we can filter for elements that are

  1. in the list of allowed elements - include elements that are in the list
  2. not in the list of disallowed elements - include elements that are not in the list

Any other ideas?
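
On the C++ side, both variants listed above could be covered by one functor with a mode switch. This is a sketch only: `IdListFilter` and `ListMode` are hypothetical names, and no such Python or C++ API is confirmed in this thread.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_set>

// Hypothetical mode switch: treat the ID list as an allow list or a deny list.
enum class ListMode { Allow, Deny };

// Hypothetical functor covering both variants with a single hash-set lookup.
struct IdListFilter {
    std::unordered_set<std::size_t> ids;
    ListMode mode;

    bool operator()(std::size_t label) const {
        const bool listed = ids.count(label) != 0;
        // Allow mode keeps listed labels; Deny mode keeps unlisted labels.
        return mode == ListMode::Allow ? listed : !listed;
    }
};
```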

@gtsoukas
Contributor

gtsoukas commented Sep 17, 2022

Thank you @dyashuni! My use case involves a few million vectors and requires filtering varying fractions f of all vectors in under 5 ms per query, where f can be a few hundred vectors but also 50% of all vectors, i.e. millions. A query typically must return a few hundred nearest vectors. One example is filtering for all women's clothing in an e-commerce store.

Here are some preliminary thoughts:

  1. in the list of allowed elements - include elements that are in the list

This is probably only fast if f is small, i.e. a few hundred or a few thousand IDs. For large f, I guess that the time needed to assemble the list and potentially convert it from Python to C++ would hurt performance too much. For very small f, on the other hand, exact nearest-neighbour search, e.g. via BLAS, might be the better option anyway.

  2. not in the list of disallowed elements - include elements that are not in the list

This would work well if f is extremely large, e.g. 99%. For large f, on the other hand, post-filtering might be OK for real-world use cases, i.e. query a few more elements than actually needed and delete the disallowed elements from the results. Still, the proposed filter would be better than post-filtering, since k is guaranteed and efficiency is better.

In summary, (ID) lists of allowed/disallowed elements would probably only work well for very small f and very large f, but would not be efficient across the wide range of f in between.
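
The post-filtering approach mentioned above can be sketched as follows. The function assumes the candidates were over-fetched and arrive ordered nearest first; `postFilter` is a hypothetical name, not an hnswlib function.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical post-filter: walks over-fetched candidates (nearest first),
// drops disallowed labels, and stops once k results have been kept.
template <typename Filter>
std::vector<std::size_t> postFilter(const std::vector<std::size_t> &nearest_first,
                                    std::size_t k, const Filter &isIdAllowed) {
    std::vector<std::size_t> kept;
    for (std::size_t label : nearest_first) {
        if (kept.size() == k) break;
        if (isIdAllowed(label)) kept.push_back(label);
    }
    // May hold fewer than k labels if too many candidates were rejected.
    return kept;
}
```

With six candidates `{1, 2, 3, 4, 5, 6}` and a filter that keeps even labels, asking for k = 2 yields `{2, 4}`, but asking for k = 5 yields only three results. This is the sense in which post-filtering does not guarantee k.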

  1. Another option would be to have an optional filter function in the Python API which takes an ID as a parameter and returns a boolean. My guess is that this would incur too large a performance penalty because of the large number of calls from C++ to Python. What do you think? (I don't have any experience with implementing C++ Python bindings.)

  2. Probably the best option would be to allow tagging the nodes/vectors. This could be done via an optional parameter tags or categories to the add_items() function. Then, instead of passing a list of node IDs to the knn_query() function, API users could pass a very small list of allowed_tag_ids or disallowed_tag_ids. This is probably the best option for efficiently handling the entire range of f. The largest downside is that it is not only a change to the Python API but also requires additional functionality in the core library. The tagging would require maintaining a vector of hash-maps. I would argue that for many real-world use cases this functionality is vital and adds a unique selling point to hnswlib over competing libraries. What do you think?
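
The tagging idea could be prototyped on top of a functor-style filter without changing the core library. The sketch below is one possible shape, under assumptions: `TagStore` and `TagFilter` are hypothetical names, tags are plain integers, and the tag store is a hash map from label to a set of tags kept outside the index.

```cpp
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <unordered_set>

// Hypothetical tag store kept alongside the index: label -> set of tag IDs.
using TagStore = std::unordered_map<std::size_t, std::unordered_set<int>>;

// Hypothetical filter: accepts a label if it carries at least one allowed tag.
struct TagFilter {
    const TagStore &tags_by_label;        // not owned; must outlive the filter
    std::unordered_set<int> allowed_tags; // expected to be a very small set

    bool operator()(std::size_t label) const {
        const auto it = tags_by_label.find(label);
        if (it == tags_by_label.end()) return false;  // untagged labels are rejected
        for (int tag : allowed_tags) {
            if (it->second.count(tag) != 0) return true;
        }
        return false;
    }
};
```

The cost per candidate is one hash lookup plus one lookup per allowed tag, independent of how many vectors carry the tag.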

These are just some preliminary thoughts. I have already started working on a pull request. If I come up with anything useful, I will mention it here within the next 14 days. In any case, I would love to hear your thoughts.
