Skip to content

Commit 7189655

Browse files
author
sicheng
committed
Fix KnnMerge
1 parent c643f08 commit 7189655

File tree

2 files changed

+11
-10
lines changed

2 files changed

+11
-10
lines changed

rust/types/src/execution/operator.rs

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -283,17 +283,12 @@ pub struct KnnOutput {
283283
}
284284

285285
/// The `KnnMerge` operator selects the records nearest to target from the batch vectors of records
286-
/// which are all sorted by distance in ascending order
286+
/// which are all sorted by distance in ascending order. If the same record occurs multiple times
287+
/// only one copy will remain in the final result.
287288
///
288289
/// # Parameters
289290
/// - `fetch`: The total number of records to fetch
290291
///
291-
/// # Inputs
292-
/// - `batch_distances`: The batch vector of records, each sorted by distance in ascending order
293-
///
294-
/// # Outputs
295-
/// - `distances`: The nearest records in either vectors, sorted by distance in ascending order
296-
///
297292
/// # Usage
298293
/// It can be used to merge the query results from different operators
299294
#[derive(Clone, Debug)]
@@ -312,10 +307,16 @@ impl KnnMerge {
312307
.filter_map(|(idx, itr)| itr.next().map(|rec| Reverse((rec, idx))))
313308
.collect::<BinaryHeap<_>>();
314309

315-
let mut distances = Vec::new();
310+
let mut distances = Vec::<RecordDistance>::with_capacity(self.fetch as usize);
316311
while distances.len() < self.fetch as usize {
317312
if let Some(Reverse((rec, idx))) = heap_dist.pop() {
318-
distances.push(rec);
313+
if distances.last().is_none()
314+
|| distances
315+
.last()
316+
.is_some_and(|last_rec| last_rec.offset_id != rec.offset_id)
317+
{
318+
distances.push(rec);
319+
}
319320
if let Some(next_rec) = batch_iters
320321
.get_mut(idx)
321322
.expect("Enumerated index should be valid")

rust/worker/src/execution/operators/knn_merge.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ mod tests {
100100
.iter()
101101
.map(|record| record.offset_id)
102102
.collect::<Vec<_>>(),
103-
vec![1, 2, 3, 4, 7, 7, 10, 10, 12, 13]
103+
vec![1, 2, 3, 4, 7, 10, 12, 13, 16, 17]
104104
);
105105
}
106106
}

0 commit comments

Comments
 (0)