Skip to content

Commit d2b1d4c

Browse files
author
sicheng
committed
Simplify get orchestrator
1 parent eb439cb commit d2b1d4c

File tree

4 files changed

+89
-149
lines changed

4 files changed

+89
-149
lines changed

rust/types/src/execution/operator.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,12 +435,12 @@ impl TryFrom<ProjectionRecord> for chroma_proto::ProjectionRecord {
435435
}
436436
}
437437

438-
#[derive(Clone, Debug, Eq, PartialEq)]
438+
#[derive(Clone, Debug, Default, Eq, PartialEq)]
439439
pub struct ProjectionOutput {
440440
pub records: Vec<ProjectionRecord>,
441441
}
442442

443-
#[derive(Clone, Debug, Eq, PartialEq)]
443+
#[derive(Clone, Debug, Default, Eq, PartialEq)]
444444
pub struct GetResult {
445445
pub pulled_log_bytes: u64,
446446
pub result: ProjectionOutput,

rust/worker/src/execution/orchestration/get.rs

Lines changed: 57 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,21 @@ use chroma_system::{
55
wrap, ChannelError, ComponentContext, ComponentHandle, Dispatcher, Handler, Orchestrator,
66
PanicError, TaskError, TaskMessage, TaskResult,
77
};
8-
use chroma_types::{
9-
operator::{Filter, GetResult, Limit, Projection, ProjectionOutput},
10-
CollectionAndSegments,
11-
};
8+
use chroma_types::operator::{GetResult, Limit, Projection, ProjectionOutput};
129
use thiserror::Error;
1310
use tokio::sync::oneshot::{error::RecvError, Sender};
1411
use tracing::Span;
1512

1613
use crate::execution::operators::{
17-
fetch_log::{FetchLogError, FetchLogOperator, FetchLogOutput},
18-
filter::{FilterError, FilterInput, FilterOutput},
14+
fetch_log::FetchLogError,
15+
filter::FilterError,
1916
limit::{LimitError, LimitInput, LimitOutput},
2017
prefetch_record::{PrefetchRecordError, PrefetchRecordOperator, PrefetchRecordOutput},
2118
projection::{ProjectionError, ProjectionInput},
2219
};
2320

21+
use super::filter::FilterOrchestratorOutput;
22+
2423
#[derive(Error, Debug)]
2524
pub enum GetError {
2625
#[error("Error sending message through channel: {0}")]
@@ -69,51 +68,47 @@ where
6968
}
7069
}
7170

72-
/// The `GetOrchestrator` chains a sequence of operators in sequence to evaluate
73-
/// a `<collection>.get(...)` query from the user
71+
/// The `GetOrchestrator` chains a sequence of operators to get data for the user.
72+
/// When used together with `FilterOrchestrator`, the two evaluate a `<collection>.get(...)` query.
7473
///
7574
/// # Pipeline
7675
/// ```text
77-
/// ┌────────────┐
78-
/// │ │
79-
/// │ on_start │
80-
/// │ │
81-
/// └──────┬─────┘
82-
/// │
83-
/// ▼
84-
/// ┌────────────────────┐
85-
/// │ │
86-
/// │ FetchLogOperator │
87-
/// │ │
88-
/// └─────────┬──────────┘
89-
/// │
90-
/// ▼
91-
/// ┌───────────────────┐
92-
/// │ │
93-
/// │ FilterOperator │
94-
/// │ │
95-
/// └─────────┬─────────┘
9676
/// │
97-
/// ▼
98-
/// ┌─────────────────┐
99-
/// │ │
100-
/// │ LimitOperator │
101-
/// │ │
102-
/// └────────┬────────┘
10377
/// │
10478
/// ▼
105-
/// ──────────────────────┐
106-
///
107-
/// ProjectionOperator
108-
///
109-
/// ──────────┬───────────┘
79+
/// ┌───────────────────────┐
80+
///
81+
/// │ FilterOrchestrator
82+
///
83+
/// └───────────┬───────────┘
11084
/// │
85+
/// ┌────────── │ ────────────┐
86+
/// │ │ Get │
87+
/// │ │ Orchestrator │
88+
/// │ ▼ │
89+
/// │ ┌─────────────────┐ │
90+
/// │ │ │ │
91+
/// │ │ LimitOperator │ │
92+
/// │ │ │ │
93+
/// │ └────────┬────────┘ │
94+
/// │ │ │
95+
/// │ ▼ │
96+
/// │ ┌──────────────────────┐ │
97+
/// │ │ │ │
98+
/// │ │ ProjectionOperator │ │
99+
/// │ │ │ │
100+
/// │ └──────────┬───────────┘ │
101+
/// │ │ │
102+
/// │ ▼ │
103+
/// │ ┌──────────────────┐ │
104+
/// │ │ │ │
105+
/// │ │ result_channel │ │
106+
/// │ │ │ │
107+
/// │ └──────────────────┘ │
108+
/// │ │ │
109+
/// └────────── │ ────────────┘
111110
/// ▼
112-
/// ┌──────────────────┐
113-
/// │ │
114-
/// │ result_channel │
115-
/// │ │
116-
/// └──────────────────┘
111+
///
117112
/// ```
118113
#[derive(Debug)]
119114
pub struct GetOrchestrator {
@@ -122,17 +117,10 @@ pub struct GetOrchestrator {
122117
dispatcher: ComponentHandle<Dispatcher>,
123118
queue: usize,
124119

125-
// Collection segments
126-
collection_and_segments: CollectionAndSegments,
127-
128-
// Fetch logs
129-
fetch_log: FetchLogOperator,
130-
131-
// Fetched logs
132-
fetched_logs: Option<FetchLogOutput>,
120+
// Output from FilterOrchestrator
121+
filter_output: FilterOrchestratorOutput,
133122

134123
// Pipelined operators
135-
filter: Filter,
136124
limit: Limit,
137125
projection: Projection,
138126

@@ -146,20 +134,15 @@ impl GetOrchestrator {
146134
blockfile_provider: BlockfileProvider,
147135
dispatcher: ComponentHandle<Dispatcher>,
148136
queue: usize,
149-
collection_and_segments: CollectionAndSegments,
150-
fetch_log: FetchLogOperator,
151-
filter: Filter,
137+
filter_output: FilterOrchestratorOutput,
152138
limit: Limit,
153139
projection: Projection,
154140
) -> Self {
155141
Self {
156142
blockfile_provider,
157143
dispatcher,
158144
queue,
159-
collection_and_segments,
160-
fetch_log,
161-
fetched_logs: None,
162-
filter,
145+
filter_output,
163146
limit,
164147
projection,
165148
result_channel: None,
@@ -181,7 +164,17 @@ impl Orchestrator for GetOrchestrator {
181164
ctx: &ComponentContext<Self>,
182165
) -> Vec<(TaskMessage, Option<Span>)> {
183166
vec![(
184-
wrap(Box::new(self.fetch_log.clone()), (), ctx.receiver()),
167+
wrap(
168+
Box::new(self.limit.clone()),
169+
LimitInput {
170+
logs: self.filter_output.logs.clone(),
171+
blockfile_provider: self.blockfile_provider.clone(),
172+
record_segment: self.filter_output.record_segment.clone(),
173+
log_offset_ids: self.filter_output.filter_output.log_offset_ids.clone(),
174+
compact_offset_ids: self.filter_output.filter_output.compact_offset_ids.clone(),
175+
},
176+
ctx.receiver(),
177+
),
185178
Some(Span::current()),
186179
)]
187180
}
@@ -201,68 +194,6 @@ impl Orchestrator for GetOrchestrator {
201194
}
202195
}
203196

204-
#[async_trait]
205-
impl Handler<TaskResult<FetchLogOutput, FetchLogError>> for GetOrchestrator {
206-
type Result = ();
207-
208-
async fn handle(
209-
&mut self,
210-
message: TaskResult<FetchLogOutput, FetchLogError>,
211-
ctx: &ComponentContext<Self>,
212-
) {
213-
let output = match self.ok_or_terminate(message.into_inner(), ctx).await {
214-
Some(output) => output,
215-
None => return,
216-
};
217-
218-
self.fetched_logs = Some(output.clone());
219-
220-
let task = wrap(
221-
Box::new(self.filter.clone()),
222-
FilterInput {
223-
logs: output,
224-
blockfile_provider: self.blockfile_provider.clone(),
225-
metadata_segment: self.collection_and_segments.metadata_segment.clone(),
226-
record_segment: self.collection_and_segments.record_segment.clone(),
227-
},
228-
ctx.receiver(),
229-
);
230-
self.send(task, ctx, Some(Span::current())).await;
231-
}
232-
}
233-
234-
#[async_trait]
235-
impl Handler<TaskResult<FilterOutput, FilterError>> for GetOrchestrator {
236-
type Result = ();
237-
238-
async fn handle(
239-
&mut self,
240-
message: TaskResult<FilterOutput, FilterError>,
241-
ctx: &ComponentContext<Self>,
242-
) {
243-
let output = match self.ok_or_terminate(message.into_inner(), ctx).await {
244-
Some(output) => output,
245-
None => return,
246-
};
247-
let task = wrap(
248-
Box::new(self.limit.clone()),
249-
LimitInput {
250-
logs: self
251-
.fetched_logs
252-
.as_ref()
253-
.expect("FetchLogOperator should have finished already")
254-
.clone(),
255-
blockfile_provider: self.blockfile_provider.clone(),
256-
record_segment: self.collection_and_segments.record_segment.clone(),
257-
log_offset_ids: output.log_offset_ids,
258-
compact_offset_ids: output.compact_offset_ids,
259-
},
260-
ctx.receiver(),
261-
);
262-
self.send(task, ctx, Some(Span::current())).await;
263-
}
264-
}
265-
266197
#[async_trait]
267198
impl Handler<TaskResult<LimitOutput, LimitError>> for GetOrchestrator {
268199
type Result = ();
@@ -278,13 +209,9 @@ impl Handler<TaskResult<LimitOutput, LimitError>> for GetOrchestrator {
278209
};
279210

280211
let input = ProjectionInput {
281-
logs: self
282-
.fetched_logs
283-
.as_ref()
284-
.expect("FetchLogOperator should have finished already")
285-
.clone(),
212+
logs: self.filter_output.logs.clone(),
286213
blockfile_provider: self.blockfile_provider.clone(),
287-
record_segment: self.collection_and_segments.record_segment.clone(),
214+
record_segment: self.filter_output.record_segment.clone(),
288215
offset_ids: output.offset_ids.iter().collect(),
289216
};
290217

@@ -332,9 +259,8 @@ impl Handler<TaskResult<ProjectionOutput, ProjectionError>> for GetOrchestrator
332259
};
333260

334261
let pulled_log_bytes = self
335-
.fetched_logs
336-
.as_ref()
337-
.expect("FetchLogOperator should have finished already")
262+
.filter_output
263+
.logs
338264
.iter()
339265
.map(|(l, _)| l.size_bytes())
340266
.sum();

rust/worker/src/execution/orchestration/knn_hnsw.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ pub struct KnnHnswOrchestrator {
163163
dispatcher: ComponentHandle<Dispatcher>,
164164
queue: usize,
165165

166-
// Output from KnnFilterOrchestrator
166+
// Output from FilterOrchestrator
167167
filter_output: FilterOrchestratorOutput,
168168

169169
// Knn operator shared between log and segments

rust/worker/src/server.rs

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -211,30 +211,44 @@ impl WorkerServer {
211211
.projection
212212
.ok_or(Status::invalid_argument("Invalid Projection Operator"))?;
213213

214-
let get_orchestrator = GetOrchestrator::new(
214+
// If the dimension is not set and the vector segment is uninitialized, we assume
215+
// this is a get on an empty collection, so we return early here
216+
if collection_and_segments.collection.dimension.is_none()
217+
&& collection_and_segments.vector_segment.file_path.is_empty()
218+
{
219+
return Ok(Response::new(GetResult::default().try_into()?));
220+
}
221+
222+
let filter_orchestrator = FilterOrchestrator::new(
215223
self.blockfile_provider.clone(),
216224
self.clone_dispatcher()?,
225+
self.hnsw_index_provider.clone(),
217226
// TODO: Make this configurable
218227
1000,
219-
collection_and_segments,
228+
collection_and_segments.clone(),
220229
fetch_log,
221230
filter.try_into()?,
231+
);
232+
233+
let matching_records = match filter_orchestrator.run(self.system.clone()).await {
234+
Ok(output) => output,
235+
Err(e) => {
236+
return Err(Status::new(e.code().into(), e.to_string()));
237+
}
238+
};
239+
240+
let get_orchestrator = GetOrchestrator::new(
241+
self.blockfile_provider.clone(),
242+
self.clone_dispatcher()?,
243+
// TODO: Make this configurable
244+
1000,
245+
matching_records,
222246
limit.into(),
223247
projection.into(),
224248
);
225249

226250
match get_orchestrator.run(self.system.clone()).await {
227-
Ok(GetResult {
228-
pulled_log_bytes,
229-
result,
230-
}) => Ok(Response::new(chroma_proto::GetResult {
231-
records: result
232-
.records
233-
.into_iter()
234-
.map(TryInto::try_into)
235-
.collect::<Result<_, _>>()?,
236-
pulled_log_bytes,
237-
})),
251+
Ok(result) => Ok(Response::new(result.try_into()?)),
238252
Err(err) => Err(Status::new(err.code().into(), err.to_string())),
239253
}
240254
}
@@ -289,7 +303,7 @@ impl WorkerServer {
289303
}
290304

291305
let vector_segment_type = collection_and_segments.vector_segment.r#type;
292-
let knn_filter_orchestrator = FilterOrchestrator::new(
306+
let filter_orchestrator = FilterOrchestrator::new(
293307
self.blockfile_provider.clone(),
294308
dispatcher.clone(),
295309
self.hnsw_index_provider.clone(),
@@ -300,7 +314,7 @@ impl WorkerServer {
300314
filter.try_into()?,
301315
);
302316

303-
let matching_records = match knn_filter_orchestrator.run(system.clone()).await {
317+
let matching_records = match filter_orchestrator.run(system.clone()).await {
304318
Ok(output) => output,
305319
Err(e) => {
306320
return Err(Status::new(e.code().into(), e.to_string()));

0 commit comments

Comments
 (0)