-
Notifications
You must be signed in to change notification settings - Fork 704
Filter elements with an optional filtering function #402
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
hnswlib/bruteforce.h
Outdated
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); | ||
if(isIdAllowed(label)) { | ||
topResults.push(std::pair<dist_t, labeltype>(dist, label)); | ||
} | ||
} | ||
dist_t lastdist = topResults.top().first; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
now topResults may be empty. Need to add a check here.
if (topResults.size() > 0)
{
lastdist = topResults.top().first;
}
else
{
lastdist = std::numeric_limits<dist_t>::max();
}
hnswlib/bruteforce.h
Outdated
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); | ||
if(isIdAllowed(label)) { | ||
topResults.push(std::pair<dist_t, labeltype>(dist, label)); | ||
} | ||
if (topResults.size() > k) | ||
topResults.pop(); | ||
lastdist = topResults.top().first; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need to add a check here as well
if (topResults.size() > 0)
{
lastdist = topResults.top().first;
}
Thank you @kishorenc for implementing this feature! |
Thank you for the review.
|
@@ -247,7 +247,7 @@ namespace hnswlib { | |||
std::priority_queue<std::pair<dist_t, tableint>, std::vector<std::pair<dist_t, tableint>>, CompareByFirst> candidate_set; | |||
|
|||
dist_t lowerBound; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we add the following micro optimisation to not call isIdAllowed at all if filtering is disabled:
bool is_filter_disabled = isIdAllowed == allowAllIds;
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {
hnswlib/hnswalg.h
Outdated
@@ -307,7 +307,7 @@ namespace hnswlib { | |||
_MM_HINT_T0);//////////////////////// | |||
#endif | |||
|
|||
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(candidate_id)) | |||
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if ((!has_deletions || !isMarkedDeleted(ep_id)) && (is_filter_disabled || isIdAllowed(getExternalLabel(ep_id)))) {
I just discovered a gotcha with the current filtering function interface: there's no way to pass some kind of application context. For e.g. if I have a set of IDs to check against, the static filter function will not be able to access that state. I'm going to think about how to pass a "context" to the filtering function. |
hnswlib/bruteforce.h
Outdated
@@ -92,20 +92,24 @@ namespace hnswlib { | |||
|
|||
|
|||
std::priority_queue<std::pair<dist_t, labeltype >> | |||
searchKnn(const void *query_data, size_t k) const { | |||
searchKnn(const void *query_data, size_t k, filter_func_t isIdAllowed=allowAllIds) const { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would not it better to pass NULL as a default value to show that we do not do filtering by default (see ad3440c#r954851933)?
filter_func_t isIdAllowed=NULL
Yes it is true. I thought that maybe we have some instruments in C++ to dynamically create functions with required parameters (e.g. variable filtering range) and pass pointers to such functions Like in Python |
Maybe the best way is to pass a pointer to an instance of some class instead of pointer to a function. |
92799ec
to
b87f623
Compare
@kishorenc thank you for the updates! Could you please add the
|
hnswlib/hnswalg.h
Outdated
@@ -307,7 +308,8 @@ namespace hnswlib { | |||
_MM_HINT_T0);//////////////////////// | |||
#endif | |||
|
|||
if ((!has_deletions || !isMarkedDeleted(candidate_id)) && isIdAllowed(getExternalLabel(candidate_id))) | |||
is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can remove this line as above we assigned value to the is_filter_disabled
variable.
@kishorenc Thank you! In windows I get the following error: Do we need to remove the |
I've removed it now, please check again. |
Great, thank you! |
@@ -92,23 +92,30 @@ namespace hnswlib { | |||
|
|||
|
|||
std::priority_queue<std::pair<dist_t, labeltype >> | |||
searchKnn(const void *query_data, size_t k) const { | |||
searchKnn(const void *query_data, size_t k, filter_func_t& isIdAllowed=allowAllIds) const { | |||
std::priority_queue<std::pair<dist_t, labeltype >> topResults; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please add the same flag as in hnswlib/hnswalg.h
and check of k
assert(k <= cur_element_count);
bool is_filter_disabled = std::is_same<filter_func_t, decltype(allowAllIds)>::value;
hnswlib/bruteforce.h
Outdated
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i + | ||
data_size_)))); | ||
labeltype label = *((labeltype*) (data_ + size_per_element_ * i + data_size_)); | ||
if(isIdAllowed(label)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (is_filter_disabled || isIdAllowed(label)) {
hnswlib/bruteforce.h
Outdated
topResults.push(std::pair<dist_t, labeltype>(dist, *((labeltype *) (data_ + size_per_element_ * i + | ||
data_size_)))); | ||
labeltype label = *((labeltype *) (data_ + size_per_element_ * i + data_size_)); | ||
if(isIdAllowed(label)) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if (is_filter_disabled || isIdAllowed(label)) {
da84469
to
1fe7baf
Compare
Thank you! |
This merge is great news! Are there plans to include this in the next release? Are there plans for implementing the Python-interface for the filtering functionality? |
@gtsoukas Thank you! We can add filtering to Python. Just need to figure out what are the most popular use cases for filtering and create a proper API. What kind of filtering are you interested in Python? In my head we can filter elements that are
Any other ideas? |
Thank you @dyashuni! My use case is about a few millions of vectors and it requires filtering varying fractions f of all vectors in sub 5ms per query, where f can be a few hundred vectors but also 50% of all vectors i.e. millions. A query typically must return a few hundred nearest vectors. One example is filtering for all female clothes in an e-commerce store. Here are some preliminary thoughts:
This is probably only fast if f is small, i.e. a few hundred/thousand ids. For large f I guess that the time for assembling the list and potentially transform from python to C++ would have a too big negative effects on performance. For very small f on the other hand, exact ANN e.g. via BLAS might be the better option anyway.
This would work well if f is extremely large, e.g. 99%. For large f on the other hand post-filtering might be OK for real world use cases i.e. query a few more elements than actually needed and delete the disallowed elements from the results. But still the proposed filter would be better than post filtering, since k is guaranteed and efficiency is better. In summary (id) lists of allowed/disallowed elements would probably only work well for very small f and very large f but would not be efficient for a very large range of f.
These are just some preliminary thoughts. I have already started working on a pull request. If I will come up with anything useful, I will mention it here within the next 14 days. Anyway would love to hear your thoughts. |
Fixes #366
Summary of the change
An optional filtering function can be specified as a template parameter that determines if a given ID should be picked. The default is a function that allows all labels via a
return true
-- this will be optimised away by the compiler, so there will be no extra cost for those who don't provide a filtering function.The use of a templated filter function also ensures that the filtering logic can be entirely inlined by the compiler, and is an implementation detail, instead of forcing a
std::set<>
of allowed IDs as discussed in #366.I've added a test that asserts on both the brute force and hnsw implementations. Verified that existing search knn tests pass.