Skip to content

Commit 70d88f2

Browse files
author
sicheng
committed
Fix tests and benches
1 parent d590ff7 commit 70d88f2

File tree

14 files changed

+123
-140
lines changed

14 files changed

+123
-140
lines changed

rust/worker/benches/filter.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread};
44
use chroma_log::test::{upsert_generator, LoadFromGenerator};
55
use chroma_segment::test::TestDistributedSegment;
66
use chroma_system::Operator;
7+
use chroma_types::operator::Filter;
78
use chroma_types::{
89
BooleanOperator, Chunk, CompositeExpression, MetadataComparison, MetadataExpression,
910
MetadataValue, PrimitiveOperator, Where,
1011
};
1112
use criterion::Criterion;
1213
use criterion::{criterion_group, criterion_main};
13-
use worker::execution::operators::filter::{FilterInput, FilterOperator};
14+
use worker::execution::operators::filter::FilterInput;
1415

1516
fn baseline_where_clauses() -> Vec<(&'static str, Option<Where>)> {
1617
use BooleanOperator::*;
@@ -89,12 +90,12 @@ fn bench_filter(criterion: &mut Criterion) {
8990
};
9091

9192
for (op, where_clause) in baseline_where_clauses() {
92-
let filter_operator = FilterOperator {
93+
let filter_operator = Filter {
9394
query_ids: None,
9495
where_clause: where_clause.clone(),
9596
};
9697

97-
let routine = |(op, input): (FilterOperator, FilterInput)| async move {
98+
let routine = |(op, input): (Filter, FilterInput)| async move {
9899
op.run(&input)
99100
.await
100101
.expect("FilterOperator should not fail");

rust/worker/benches/get.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ async fn bench_routine(input: (System, GetOrchestrator, Vec<String>)) {
111111
.expect("Orchestrator should not fail");
112112
assert_eq!(
113113
output
114-
.0
114+
.result
115115
.records
116116
.into_iter()
117117
.map(|record| record.id)

rust/worker/benches/limit.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread};
22
use chroma_log::test::{upsert_generator, LoadFromGenerator};
33
use chroma_segment::test::TestDistributedSegment;
44
use chroma_system::Operator;
5+
use chroma_types::operator::Limit;
56
use chroma_types::{Chunk, SignedRoaringBitmap};
67
use criterion::Criterion;
78
use criterion::{criterion_group, criterion_main};
8-
use worker::execution::operators::limit::{LimitInput, LimitOperator};
9+
use worker::execution::operators::limit::LimitInput;
910

1011
const FETCH: usize = 100;
1112

@@ -30,12 +31,12 @@ fn bench_limit(criterion: &mut Criterion) {
3031
};
3132

3233
for offset in [0, record_count / 2, record_count - FETCH] {
33-
let limit_operator = LimitOperator {
34+
let limit_operator = Limit {
3435
skip: offset as u32,
3536
fetch: Some(FETCH as u32),
3637
};
3738

38-
let routine = |(op, input): (LimitOperator, LimitInput)| async move {
39+
let routine = |(op, input): (Limit, LimitInput)| async move {
3940
op.run(&input).await.expect("LimitOperator should not fail");
4041
};
4142

rust/worker/benches/load.rs

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@ use chroma_benchmark::datasets::sift::Sift1MData;
22
use chroma_log::{in_memory_log::InMemoryLog, test::modulo_metadata, Log};
33
use chroma_segment::test::TestDistributedSegment;
44
use chroma_types::{
5+
operator::{Filter, Limit, Projection},
56
Chunk, CollectionUuid, LogRecord, MetadataComparison, MetadataExpression, MetadataSetValue,
67
Operation, OperationRecord, SetOperator, Where,
78
};
89
use indicatif::ProgressIterator;
9-
use worker::execution::operators::{
10-
fetch_log::FetchLogOperator, filter::FilterOperator, limit::LimitOperator,
11-
projection::ProjectionOperator,
12-
};
10+
use worker::execution::operators::fetch_log::FetchLogOperator;
1311

1412
const DATA_CHUNK_SIZE: usize = 10000;
1513

@@ -62,15 +60,15 @@ pub fn empty_fetch_log(collection_uuid: CollectionUuid) -> FetchLogOperator {
6260
}
6361
}
6462

65-
pub fn trivial_filter() -> FilterOperator {
66-
FilterOperator {
63+
pub fn trivial_filter() -> Filter {
64+
Filter {
6765
query_ids: None,
6866
where_clause: None,
6967
}
7068
}
7169

72-
pub fn always_false_filter_for_modulo_metadata() -> FilterOperator {
73-
FilterOperator {
70+
pub fn always_false_filter_for_modulo_metadata() -> Filter {
71+
Filter {
7472
query_ids: None,
7573
where_clause: Some(Where::disjunction(vec![
7674
Where::Metadata(MetadataExpression {
@@ -91,8 +89,8 @@ pub fn always_false_filter_for_modulo_metadata() -> FilterOperator {
9189
}
9290
}
9391

94-
pub fn always_true_filter_for_modulo_metadata() -> FilterOperator {
95-
FilterOperator {
92+
pub fn always_true_filter_for_modulo_metadata() -> Filter {
93+
Filter {
9694
query_ids: None,
9795
where_clause: Some(Where::conjunction(vec![
9896
Where::Metadata(MetadataExpression {
@@ -113,30 +111,30 @@ pub fn always_true_filter_for_modulo_metadata() -> FilterOperator {
113111
}
114112
}
115113

116-
pub fn trivial_limit() -> LimitOperator {
117-
LimitOperator {
114+
pub fn trivial_limit() -> Limit {
115+
Limit {
118116
skip: 0,
119117
fetch: Some(100),
120118
}
121119
}
122120

123-
pub fn offset_limit() -> LimitOperator {
124-
LimitOperator {
121+
pub fn offset_limit() -> Limit {
122+
Limit {
125123
skip: 100,
126124
fetch: Some(100),
127125
}
128126
}
129127

130-
pub fn trivial_projection() -> ProjectionOperator {
131-
ProjectionOperator {
128+
pub fn trivial_projection() -> Projection {
129+
Projection {
132130
document: false,
133131
embedding: false,
134132
metadata: false,
135133
}
136134
}
137135

138-
pub fn all_projection() -> ProjectionOperator {
139-
ProjectionOperator {
136+
pub fn all_projection() -> Projection {
137+
Projection {
140138
document: true,
141139
embedding: true,
142140
metadata: true,

rust/worker/benches/query.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use chroma_benchmark::{
88
use chroma_config::{registry::Registry, Configurable};
99
use chroma_segment::test::TestDistributedSegment;
1010
use chroma_system::{ComponentHandle, Dispatcher, Orchestrator, System};
11+
use chroma_types::operator::{Knn, KnnProjection};
1112
use criterion::{criterion_group, criterion_main, Criterion};
1213
use futures::{stream, StreamExt, TryStreamExt};
1314
use load::{
@@ -17,12 +18,9 @@ use load::{
1718
use rand::{seq::SliceRandom, thread_rng};
1819
use worker::{
1920
config::RootConfig,
20-
execution::{
21-
operators::{knn::KnnOperator, knn_projection::KnnProjectionOperator},
22-
orchestration::{
23-
knn::KnnOrchestrator,
24-
knn_filter::{KnnFilterOrchestrator, KnnFilterOutput},
25-
},
21+
execution::orchestration::{
22+
knn::KnnOrchestrator,
23+
knn_filter::{KnnFilterOrchestrator, KnnFilterOutput},
2624
},
2725
};
2826

@@ -91,11 +89,11 @@ fn knn(
9189
dispatcher_handle.clone(),
9290
1000,
9391
knn_filter_output.clone(),
94-
KnnOperator {
92+
Knn {
9593
embedding: query,
9694
fetch: Sift1MData::k() as u32,
9795
},
98-
KnnProjectionOperator {
96+
KnnProjection {
9997
projection: all_projection(),
10098
distance: true,
10199
},

rust/worker/benches/regex.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use chroma_benchmark::datasets::wikipedia::WikipediaDataset;
77
use chroma_log::test::{int_as_id, random_embedding};
88
use chroma_segment::test::TestDistributedSegment;
99
use chroma_system::Operator;
10+
use chroma_types::operator::Filter;
1011
use chroma_types::{
1112
Chunk, DocumentExpression, DocumentOperator, LogRecord, Operation, OperationRecord,
1213
ScalarEncoding, SignedRoaringBitmap, Where,
@@ -18,7 +19,7 @@ use indicatif::ProgressIterator;
1819
use regex::Regex;
1920
use roaring::RoaringBitmap;
2021
use tokio::time::Instant;
21-
use worker::execution::operators::filter::{FilterInput, FilterOperator};
22+
use worker::execution::operators::filter::FilterInput;
2223

2324
const LOG_CHUNK_SIZE: usize = 10000;
2425
const DOCUMENT_SIZE: usize = 100000;
@@ -110,7 +111,7 @@ fn bench_regex(criterion: &mut Criterion) {
110111
};
111112

112113
for pattern in REGEX_PATTERNS {
113-
let filter_operator = FilterOperator {
114+
let filter_operator = Filter {
114115
query_ids: None,
115116
where_clause: Some(Where::Document(DocumentExpression {
116117
operator: DocumentOperator::Regex,
@@ -119,7 +120,7 @@ fn bench_regex(criterion: &mut Criterion) {
119120
};
120121

121122
let routine = |(op, input, expected): (
122-
FilterOperator,
123+
Filter,
123124
FilterInput,
124125
HashMap<String, RoaringBitmap>,
125126
)| async move {

rust/worker/benches/spann.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,14 @@ use chroma_index::{
2020
};
2121
use chroma_storage::{local::LocalStorage, Storage};
2222
use chroma_system::Operator;
23-
use chroma_types::{CollectionUuid, InternalSpannConfiguration};
23+
use chroma_types::{operator::KnnMerge, CollectionUuid, InternalSpannConfiguration};
2424
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
2525
use futures::StreamExt;
2626
use rand::seq::SliceRandom;
2727
use roaring::RoaringBitmap;
2828
use worker::execution::operators::{
29+
knn_merge::KnnMergeInput,
2930
spann_bf_pl::{SpannBfPlInput, SpannBfPlOperator},
30-
spann_knn_merge::{SpannKnnMergeInput, SpannKnnMergeOperator},
3131
};
3232

3333
fn get_records(runtime: &tokio::runtime::Runtime) -> Vec<(u32, Vec<f32>)> {
@@ -199,10 +199,10 @@ fn calculate_recall<'a>(
199199
merge_list.push(bf_output.records);
200200
}
201201
// Now merge.
202-
let knn_input = SpannKnnMergeInput {
203-
records: merge_list,
202+
let knn_input = KnnMergeInput {
203+
batch_distances: merge_list,
204204
};
205-
let knn_operator = SpannKnnMergeOperator { k: k as u32 };
205+
let knn_operator = KnnMerge { fetch: k as u32 };
206206
let knn_output = knn_operator
207207
.run(&knn_input)
208208
.await
@@ -233,7 +233,7 @@ fn calculate_recall<'a>(
233233
.expect("Error running operator");
234234
let mut recall = 0;
235235
for bf_record in bf_output.records.iter() {
236-
for spann_record in knn_output.merged_records.iter() {
236+
for spann_record in knn_output.distances.iter() {
237237
if bf_record.offset_id == spann_record.offset_id {
238238
recall += 1;
239239
}

0 commit comments

Comments
 (0)