Skip to content

Commit fed08e1

Browse files
authored
fix(cubesql): Fix SortPushDown pushing sort through joins (#9464)
LogicalPlan::Join and CrossJoin do not preserve ordering semantically. When planned as HashJoin, the join will output batches in the same order as they arrive from the right stream. But both Join and CrossJoin will have the same partitioning as the right input (even when repartition_joins is disabled), and these partitions can be collected in arbitrary order by CoalescePartitions. See https://github.com/apache/datafusion/blob/7.0.0/datafusion/src/physical_plan/hash_join.rs#L282-L284 and https://github.com/apache/datafusion/blob/7.0.0/datafusion/src/physical_plan/cross_join.rs#L141-L143 for details. Side note: Substrait says that for both Join and Cross Product, "Orderedness is empty post operation". See https://substrait.io/relations/logical_relations/#join-operation
1 parent f90997c commit fed08e1

8 files changed

+93
-177
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #j1.c1, #j2.c2
6+
Sort: #j1.c1 ASC NULLS LAST
7+
CrossJoin:
8+
Projection: #j1.key, #j1.c1
9+
TableScan: j1 projection=None
10+
Projection: #j2.key, #j2.c2
11+
TableScan: j2 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #j1.c1, #j2.c2
6+
Sort: #j2.c2 ASC NULLS LAST
7+
CrossJoin:
8+
Projection: #j1.key, #j1.c1
9+
TableScan: j1 projection=None
10+
Projection: #j2.key, #j2.c2
11+
TableScan: j2 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #j1.c1, #j2.c2
6+
Sort: #j1.c1 ASC NULLS LAST
7+
Inner Join: #j1.key = #j2.key
8+
Projection: #j1.key, #j1.c1
9+
TableScan: j1 projection=None
10+
Projection: #j2.key, #j2.c2
11+
TableScan: j2 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #j1.c1, #j2.c2
6+
Sort: #j2.c2 ASC NULLS LAST
7+
Inner Join: #j1.key = #j2.key
8+
Projection: #j1.key, #j1.c1
9+
TableScan: j1 projection=None
10+
Projection: #j2.key, #j2.c2
11+
TableScan: j2 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4
6+
Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3
7+
Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2
8+
Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST
9+
Projection: #t1.c1, #t1.c2, #t1.c3
10+
TableScan: t1 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2
6+
Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST
7+
Projection: #t1.c1, #t1.c2, #t1.c3
8+
TableScan: t1 projection=None
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
---
2+
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
3+
expression: optimize(&plan)
4+
---
5+
Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4
6+
Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3
7+
Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2
8+
Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST
9+
Projection: #t1.c1, #t1.c2, #t1.c3
10+
TableScan: t1 projection=None

rust/cubesql/cubesql/src/compile/engine/df/optimizers/sort_push_down.rs

+21-177
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,13 @@ use std::{collections::HashMap, sync::Arc};
33
use datafusion::{
44
error::{DataFusionError, Result},
55
logical_plan::{
6-
plan::{
7-
Aggregate, CrossJoin, Distinct, Join, Limit, Projection, Sort, Subquery, Union, Window,
8-
},
6+
plan::{Aggregate, Distinct, Limit, Projection, Sort, Subquery, Union, Window},
97
Column, DFSchema, Expr, Filter, LogicalPlan,
108
},
119
optimizer::optimizer::{OptimizerConfig, OptimizerRule},
1210
};
1311

14-
use super::utils::{get_schema_columns, is_column_expr, plan_has_projections, rewrite};
12+
use super::utils::{is_column_expr, plan_has_projections, rewrite};
1513

1614
/// Sort Push Down optimizer rule pushes ORDER BY clauses consisting of specific,
1715
/// mostly simple, expressions down the plan, all the way to the Projection
@@ -167,97 +165,6 @@ fn sort_push_down(
167165
optimizer_config,
168166
)
169167
}
170-
LogicalPlan::Join(Join {
171-
left,
172-
right,
173-
on,
174-
join_type,
175-
join_constraint,
176-
schema,
177-
null_equals_null,
178-
}) => {
179-
// DataFusion preserves the sorting of the joined plans, prioritizing left side.
180-
// Taking this into account, we can push Sort down the left plan if Sort references
181-
// columns just from the left side.
182-
// TODO: check if this is still the case with multiple target partitions
183-
if let Some(some_sort_expr) = &sort_expr {
184-
let left_columns = get_schema_columns(left.schema());
185-
if some_sort_expr.iter().all(|expr| {
186-
if let Expr::Sort { expr, .. } = expr {
187-
if let Expr::Column(column) = expr.as_ref() {
188-
return left_columns.contains(column);
189-
}
190-
}
191-
false
192-
}) {
193-
return Ok(LogicalPlan::Join(Join {
194-
left: Arc::new(sort_push_down(
195-
optimizer,
196-
left,
197-
sort_expr,
198-
optimizer_config,
199-
)?),
200-
right: Arc::new(sort_push_down(optimizer, right, None, optimizer_config)?),
201-
on: on.clone(),
202-
join_type: *join_type,
203-
join_constraint: *join_constraint,
204-
schema: schema.clone(),
205-
null_equals_null: *null_equals_null,
206-
}));
207-
}
208-
}
209-
210-
issue_sort(
211-
sort_expr,
212-
LogicalPlan::Join(Join {
213-
left: Arc::new(sort_push_down(optimizer, left, None, optimizer_config)?),
214-
right: Arc::new(sort_push_down(optimizer, right, None, optimizer_config)?),
215-
on: on.clone(),
216-
join_type: *join_type,
217-
join_constraint: *join_constraint,
218-
schema: schema.clone(),
219-
null_equals_null: *null_equals_null,
220-
}),
221-
)
222-
}
223-
LogicalPlan::CrossJoin(CrossJoin {
224-
left,
225-
right,
226-
schema,
227-
}) => {
228-
// See `LogicalPlan::Join` notes above.
229-
if let Some(some_sort_expr) = &sort_expr {
230-
let left_columns = get_schema_columns(left.schema());
231-
if some_sort_expr.iter().all(|expr| {
232-
if let Expr::Sort { expr, .. } = expr {
233-
if let Expr::Column(column) = expr.as_ref() {
234-
return left_columns.contains(column);
235-
}
236-
}
237-
false
238-
}) {
239-
return Ok(LogicalPlan::CrossJoin(CrossJoin {
240-
left: Arc::new(sort_push_down(
241-
optimizer,
242-
left,
243-
sort_expr,
244-
optimizer_config,
245-
)?),
246-
right: Arc::new(sort_push_down(optimizer, right, None, optimizer_config)?),
247-
schema: schema.clone(),
248-
}));
249-
}
250-
}
251-
252-
issue_sort(
253-
sort_expr,
254-
LogicalPlan::CrossJoin(CrossJoin {
255-
left: Arc::new(sort_push_down(optimizer, left, None, optimizer_config)?),
256-
right: Arc::new(sort_push_down(optimizer, right, None, optimizer_config)?),
257-
schema: schema.clone(),
258-
}),
259-
)
260-
}
261168
LogicalPlan::Union(Union {
262169
inputs,
263170
schema,
@@ -384,15 +291,10 @@ mod tests {
384291
};
385292
use datafusion::logical_plan::{col, JoinType, LogicalPlanBuilder};
386293

387-
fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
294+
fn optimize(plan: &LogicalPlan) -> LogicalPlan {
388295
let rule = SortPushDown::new();
389296
rule.optimize(plan, &OptimizerConfig::new())
390-
}
391-
392-
fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) {
393-
let optimized_plan = optimize(&plan).expect("failed to optimize plan");
394-
let formatted_plan = format!("{:?}", optimized_plan);
395-
assert_eq!(formatted_plan, expected);
297+
.expect("failed to optimize plan")
396298
}
397299

398300
fn sort(expr: Expr, asc: bool, nulls_first: bool) -> Expr {
@@ -417,14 +319,7 @@ mod tests {
417319
])?
418320
.build()?;
419321

420-
let expected = "\
421-
Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
422-
\n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
423-
\n Projection: #t1.c1, #t1.c2, #t1.c3\
424-
\n TableScan: t1 projection=None\
425-
";
426-
427-
assert_optimized_plan_eq(plan, expected);
322+
insta::assert_debug_snapshot!(optimize(&plan));
428323
Ok(())
429324
}
430325

@@ -450,16 +345,7 @@ mod tests {
450345
])?
451346
.build()?;
452347

453-
let expected = "\
454-
Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4\
455-
\n Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3\
456-
\n Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
457-
\n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
458-
\n Projection: #t1.c1, #t1.c2, #t1.c3\
459-
\n TableScan: t1 projection=None\
460-
";
461-
462-
assert_optimized_plan_eq(plan, expected);
348+
insta::assert_debug_snapshot!(optimize(&plan));
463349
Ok(())
464350
}
465351

@@ -487,21 +373,12 @@ mod tests {
487373
])?
488374
.build()?;
489375

490-
let expected = "\
491-
Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4\
492-
\n Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3\
493-
\n Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
494-
\n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
495-
\n Projection: #t1.c1, #t1.c2, #t1.c3\
496-
\n TableScan: t1 projection=None\
497-
";
498-
499-
assert_optimized_plan_eq(plan, expected);
376+
insta::assert_debug_snapshot!(optimize(&plan));
500377
Ok(())
501378
}
502379

503380
#[test]
504-
fn test_sort_down_join() -> Result<()> {
381+
fn test_sort_down_join_sort_left() -> Result<()> {
505382
let plan = LogicalPlanBuilder::from(
506383
LogicalPlanBuilder::from(make_sample_table("j1", vec!["key", "c1"], vec![])?)
507384
.project(vec![col("key"), col("c1")])?
@@ -521,18 +398,12 @@ mod tests {
521398
.sort(vec![sort(col("j1.c1"), true, false)])?
522399
.build()?;
523400

524-
let expected = "\
525-
Projection: #j1.c1, #j2.c2\
526-
\n Inner Join: #j1.key = #j2.key\
527-
\n Sort: #j1.c1 ASC NULLS LAST\
528-
\n Projection: #j1.key, #j1.c1\
529-
\n TableScan: j1 projection=None\
530-
\n Projection: #j2.key, #j2.c2\
531-
\n TableScan: j2 projection=None\
532-
";
533-
534-
assert_optimized_plan_eq(plan, expected);
401+
insta::assert_debug_snapshot!(optimize(&plan));
402+
Ok(())
403+
}
535404

405+
#[test]
406+
fn test_sort_down_join_sort_right() -> Result<()> {
536407
let plan = LogicalPlanBuilder::from(
537408
LogicalPlanBuilder::from(make_sample_table("j1", vec!["key", "c1"], vec![])?)
538409
.project(vec![col("key"), col("c1")])?
@@ -552,23 +423,12 @@ mod tests {
552423
.sort(vec![sort(col("j2.c2"), true, false)])?
553424
.build()?;
554425

555-
let expected = "\
556-
Projection: #j1.c1, #j2.c2\
557-
\n Sort: #j2.c2 ASC NULLS LAST\
558-
\n Inner Join: #j1.key = #j2.key\
559-
\n Projection: #j1.key, #j1.c1\
560-
\n TableScan: j1 projection=None\
561-
\n Projection: #j2.key, #j2.c2\
562-
\n TableScan: j2 projection=None\
563-
";
564-
565-
assert_optimized_plan_eq(plan, expected);
566-
426+
insta::assert_debug_snapshot!(optimize(&plan));
567427
Ok(())
568428
}
569429

570430
#[test]
571-
fn test_sort_down_cross_join() -> Result<()> {
431+
fn test_sort_down_cross_join_sort_left() -> Result<()> {
572432
let plan = LogicalPlanBuilder::from(
573433
LogicalPlanBuilder::from(make_sample_table("j1", vec!["key", "c1"], vec![])?)
574434
.project(vec![col("key"), col("c1")])?
@@ -583,18 +443,12 @@ mod tests {
583443
.sort(vec![sort(col("j1.c1"), true, false)])?
584444
.build()?;
585445

586-
let expected = "\
587-
Projection: #j1.c1, #j2.c2\
588-
\n CrossJoin:\
589-
\n Sort: #j1.c1 ASC NULLS LAST\
590-
\n Projection: #j1.key, #j1.c1\
591-
\n TableScan: j1 projection=None\
592-
\n Projection: #j2.key, #j2.c2\
593-
\n TableScan: j2 projection=None\
594-
";
595-
596-
assert_optimized_plan_eq(plan, expected);
446+
insta::assert_debug_snapshot!(optimize(&plan));
447+
Ok(())
448+
}
597449

450+
#[test]
451+
fn test_sort_down_cross_join_sort_right() -> Result<()> {
598452
let plan = LogicalPlanBuilder::from(
599453
LogicalPlanBuilder::from(make_sample_table("j1", vec!["key", "c1"], vec![])?)
600454
.project(vec![col("key"), col("c1")])?
@@ -609,17 +463,7 @@ mod tests {
609463
.sort(vec![sort(col("j2.c2"), true, false)])?
610464
.build()?;
611465

612-
let expected = "\
613-
Projection: #j1.c1, #j2.c2\
614-
\n Sort: #j2.c2 ASC NULLS LAST\
615-
\n CrossJoin:\
616-
\n Projection: #j1.key, #j1.c1\
617-
\n TableScan: j1 projection=None\
618-
\n Projection: #j2.key, #j2.c2\
619-
\n TableScan: j2 projection=None\
620-
";
621-
622-
assert_optimized_plan_eq(plan, expected);
466+
insta::assert_debug_snapshot!(optimize(&plan));
623467

624468
Ok(())
625469
}

0 commit comments

Comments
 (0)