Skip to content

Commit 9af4da8

Browse files
author
sicheng
committed
Fix test and benches
1 parent d2b1d4c commit 9af4da8

File tree

2 files changed

+91
-103
lines changed

2 files changed

+91
-103
lines changed

rust/worker/benches/get.rs

Lines changed: 59 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -5,110 +5,52 @@ use chroma_benchmark::benchmark::{bench_run, tokio_multi_thread};
55
use chroma_config::{registry::Registry, Configurable};
66
use chroma_segment::test::TestDistributedSegment;
77
use chroma_system::{ComponentHandle, Dispatcher, Orchestrator, System};
8+
use chroma_types::operator::{Filter, Limit, Projection};
89
use criterion::{criterion_group, criterion_main, Criterion};
910
use load::{
1011
all_projection, always_false_filter_for_modulo_metadata,
1112
always_true_filter_for_modulo_metadata, empty_fetch_log, offset_limit, sift1m_segments,
1213
trivial_filter, trivial_limit, trivial_projection,
1314
};
14-
use worker::{config::RootConfig, execution::orchestration::get::GetOrchestrator};
15-
16-
fn trivial_get(
17-
test_segments: &TestDistributedSegment,
18-
dispatcher_handle: ComponentHandle<Dispatcher>,
19-
) -> GetOrchestrator {
20-
let blockfile_provider = test_segments.blockfile_provider.clone();
21-
let collection_uuid = test_segments.collection.collection_id;
22-
GetOrchestrator::new(
23-
blockfile_provider,
24-
dispatcher_handle,
25-
1000,
26-
test_segments.into(),
27-
empty_fetch_log(collection_uuid),
28-
trivial_filter(),
29-
trivial_limit(),
30-
trivial_projection(),
31-
)
32-
}
33-
34-
fn get_false_filter(
35-
test_segments: &TestDistributedSegment,
36-
dispatcher_handle: ComponentHandle<Dispatcher>,
37-
) -> GetOrchestrator {
38-
let blockfile_provider = test_segments.blockfile_provider.clone();
39-
let collection_uuid = test_segments.collection.collection_id;
40-
GetOrchestrator::new(
41-
blockfile_provider,
42-
dispatcher_handle,
43-
1000,
44-
test_segments.into(),
45-
empty_fetch_log(collection_uuid),
46-
always_false_filter_for_modulo_metadata(),
47-
trivial_limit(),
48-
trivial_projection(),
49-
)
50-
}
51-
52-
fn get_true_filter(
53-
test_segments: &TestDistributedSegment,
54-
dispatcher_handle: ComponentHandle<Dispatcher>,
55-
) -> GetOrchestrator {
56-
let blockfile_provider = test_segments.blockfile_provider.clone();
57-
let collection_uuid = test_segments.collection.collection_id;
58-
GetOrchestrator::new(
59-
blockfile_provider,
60-
dispatcher_handle,
61-
1000,
62-
test_segments.into(),
63-
empty_fetch_log(collection_uuid),
64-
always_true_filter_for_modulo_metadata(),
65-
trivial_limit(),
66-
trivial_projection(),
67-
)
68-
}
15+
use worker::{
16+
config::RootConfig,
17+
execution::orchestration::{filter::FilterOrchestrator, get::GetOrchestrator},
18+
};
6919

70-
fn get_true_filter_limit(
71-
test_segments: &TestDistributedSegment,
72-
dispatcher_handle: ComponentHandle<Dispatcher>,
73-
) -> GetOrchestrator {
74-
let blockfile_provider = test_segments.blockfile_provider.clone();
75-
let collection_uuid = test_segments.collection.collection_id;
76-
GetOrchestrator::new(
77-
blockfile_provider,
78-
dispatcher_handle,
20+
async fn bench_routine(
21+
(system, dispatcher, test_segments, filter, limit, projection, expected_ids): (
22+
System,
23+
ComponentHandle<Dispatcher>,
24+
&TestDistributedSegment,
25+
Filter,
26+
Limit,
27+
Projection,
28+
Vec<String>,
29+
),
30+
) {
31+
let matching_records = FilterOrchestrator::new(
32+
test_segments.blockfile_provider.clone(),
33+
dispatcher.clone(),
34+
test_segments.hnsw_provider.clone(),
7935
1000,
8036
test_segments.into(),
81-
empty_fetch_log(collection_uuid),
82-
always_true_filter_for_modulo_metadata(),
83-
offset_limit(),
84-
trivial_projection(),
37+
empty_fetch_log(test_segments.collection.collection_id),
38+
filter,
8539
)
86-
}
87-
88-
fn get_true_filter_limit_projection(
89-
test_segments: &TestDistributedSegment,
90-
dispatcher_handle: ComponentHandle<Dispatcher>,
91-
) -> GetOrchestrator {
92-
let blockfile_provider = test_segments.blockfile_provider.clone();
93-
let collection_uuid = test_segments.collection.collection_id;
94-
GetOrchestrator::new(
95-
blockfile_provider,
96-
dispatcher_handle,
40+
.run(system.clone())
41+
.await
42+
.expect("Filter orchestrator should not fail");
43+
let output = GetOrchestrator::new(
44+
test_segments.blockfile_provider.clone(),
45+
dispatcher,
9746
1000,
98-
test_segments.into(),
99-
empty_fetch_log(collection_uuid),
100-
always_true_filter_for_modulo_metadata(),
101-
offset_limit(),
102-
all_projection(),
47+
matching_records,
48+
limit,
49+
projection,
10350
)
104-
}
105-
106-
async fn bench_routine(input: (System, GetOrchestrator, Vec<String>)) {
107-
let (system, orchestrator, expected_ids) = input;
108-
let output = orchestrator
109-
.run(system)
110-
.await
111-
.expect("Orchestrator should not fail");
51+
.run(system)
52+
.await
53+
.expect("Get orchestrator should not fail");
11254
assert_eq!(
11355
output
11456
.result
@@ -138,35 +80,55 @@ fn bench_get(criterion: &mut Criterion) {
13880
let trivial_get_setup = || {
13981
(
14082
system.clone(),
141-
trivial_get(&test_segments, dispatcher_handle.clone()),
83+
dispatcher_handle.clone(),
84+
&test_segments,
85+
trivial_filter(),
86+
trivial_limit(),
87+
trivial_projection(),
14288
(0..100).map(|id| id.to_string()).collect(),
14389
)
14490
};
14591
let get_false_filter_setup = || {
14692
(
14793
system.clone(),
148-
get_false_filter(&test_segments, dispatcher_handle.clone()),
94+
dispatcher_handle.clone(),
95+
&test_segments,
96+
always_false_filter_for_modulo_metadata(),
97+
trivial_limit(),
98+
trivial_projection(),
14999
Vec::new(),
150100
)
151101
};
152102
let get_true_filter_setup = || {
153103
(
154104
system.clone(),
155-
get_true_filter(&test_segments, dispatcher_handle.clone()),
105+
dispatcher_handle.clone(),
106+
&test_segments,
107+
always_true_filter_for_modulo_metadata(),
108+
trivial_limit(),
109+
trivial_projection(),
156110
(0..100).map(|id| id.to_string()).collect(),
157111
)
158112
};
159113
let get_true_filter_limit_setup = || {
160114
(
161115
system.clone(),
162-
get_true_filter_limit(&test_segments, dispatcher_handle.clone()),
116+
dispatcher_handle.clone(),
117+
&test_segments,
118+
always_true_filter_for_modulo_metadata(),
119+
offset_limit(),
120+
trivial_projection(),
163121
(100..200).map(|id| id.to_string()).collect(),
164122
)
165123
};
166124
let get_true_filter_limit_projection_setup = || {
167125
(
168126
system.clone(),
169-
get_true_filter_limit_projection(&test_segments, dispatcher_handle.clone()),
127+
dispatcher_handle.clone(),
128+
&test_segments,
129+
always_true_filter_for_modulo_metadata(),
130+
offset_limit(),
131+
all_projection(),
170132
(100..200).map(|id| id.to_string()).collect(),
171133
)
172134
};

rust/worker/src/execution/orchestration/compact.rs

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,7 +1058,10 @@ mod tests {
10581058

10591059
use crate::{
10601060
config::RootConfig,
1061-
execution::{operators::fetch_log::FetchLogOperator, orchestration::get::GetOrchestrator},
1061+
execution::{
1062+
operators::fetch_log::FetchLogOperator,
1063+
orchestration::{filter::FilterOrchestrator, get::GetOrchestrator},
1064+
},
10621065
};
10631066

10641067
use super::CompactOrchestrator;
@@ -1156,6 +1159,19 @@ mod tests {
11561159
}),
11571160
])),
11581161
};
1162+
let filter_orchestrator = FilterOrchestrator::new(
1163+
test_segments.blockfile_provider.clone(),
1164+
dispatcher_handle.clone(),
1165+
test_segments.hnsw_provider.clone(),
1166+
1000,
1167+
old_cas.clone(),
1168+
fetch_log.clone(),
1169+
filter.clone(),
1170+
);
1171+
let matching_records = filter_orchestrator
1172+
.run(system.clone())
1173+
.await
1174+
.expect("Filter orchestrator should not fail");
11591175
let limit = Limit {
11601176
skip: 0,
11611177
fetch: None,
@@ -1169,9 +1185,7 @@ mod tests {
11691185
test_segments.blockfile_provider.clone(),
11701186
dispatcher_handle.clone(),
11711187
1000,
1172-
old_cas.clone(),
1173-
fetch_log.clone(),
1174-
filter.clone(),
1188+
matching_records,
11751189
limit.clone(),
11761190
project.clone(),
11771191
);
@@ -1231,13 +1245,25 @@ mod tests {
12311245
old_cas.vector_segment.file_path
12321246
);
12331247

1234-
let get_orchestrator = GetOrchestrator::new(
1248+
let filter_orchestrator = FilterOrchestrator::new(
12351249
test_segments.blockfile_provider.clone(),
1236-
dispatcher_handle,
1250+
dispatcher_handle.clone(),
1251+
test_segments.hnsw_provider.clone(),
12371252
1000,
12381253
new_cas,
12391254
fetch_log,
12401255
filter,
1256+
);
1257+
let matching_records = filter_orchestrator
1258+
.run(system.clone())
1259+
.await
1260+
.expect("Filter orchestrator should not fail");
1261+
1262+
let get_orchestrator = GetOrchestrator::new(
1263+
test_segments.blockfile_provider.clone(),
1264+
dispatcher_handle,
1265+
1000,
1266+
matching_records,
12411267
limit,
12421268
project,
12431269
);

0 commit comments

Comments
 (0)