@@ -3,15 +3,13 @@ use std::{collections::HashMap, sync::Arc};
3
3
use datafusion:: {
4
4
error:: { DataFusionError , Result } ,
5
5
logical_plan:: {
6
- plan:: {
7
- Aggregate , CrossJoin , Distinct , Join , Limit , Projection , Sort , Subquery , Union , Window ,
8
- } ,
6
+ plan:: { Aggregate , Distinct , Limit , Projection , Sort , Subquery , Union , Window } ,
9
7
Column , DFSchema , Expr , Filter , LogicalPlan ,
10
8
} ,
11
9
optimizer:: optimizer:: { OptimizerConfig , OptimizerRule } ,
12
10
} ;
13
11
14
- use super :: utils:: { get_schema_columns , is_column_expr, plan_has_projections, rewrite} ;
12
+ use super :: utils:: { is_column_expr, plan_has_projections, rewrite} ;
15
13
16
14
/// Sort Push Down optimizer rule pushes ORDER BY clauses consisting of specific,
17
15
/// mostly simple, expressions down the plan, all the way to the Projection
@@ -167,97 +165,6 @@ fn sort_push_down(
167
165
optimizer_config,
168
166
)
169
167
}
170
- LogicalPlan :: Join ( Join {
171
- left,
172
- right,
173
- on,
174
- join_type,
175
- join_constraint,
176
- schema,
177
- null_equals_null,
178
- } ) => {
179
- // DataFusion preserves the sorting of the joined plans, prioritizing left side.
180
- // Taking this into account, we can push Sort down the left plan if Sort references
181
- // columns just from the left side.
182
- // TODO: check if this is still the case with multiple target partitions
183
- if let Some ( some_sort_expr) = & sort_expr {
184
- let left_columns = get_schema_columns ( left. schema ( ) ) ;
185
- if some_sort_expr. iter ( ) . all ( |expr| {
186
- if let Expr :: Sort { expr, .. } = expr {
187
- if let Expr :: Column ( column) = expr. as_ref ( ) {
188
- return left_columns. contains ( column) ;
189
- }
190
- }
191
- false
192
- } ) {
193
- return Ok ( LogicalPlan :: Join ( Join {
194
- left : Arc :: new ( sort_push_down (
195
- optimizer,
196
- left,
197
- sort_expr,
198
- optimizer_config,
199
- ) ?) ,
200
- right : Arc :: new ( sort_push_down ( optimizer, right, None , optimizer_config) ?) ,
201
- on : on. clone ( ) ,
202
- join_type : * join_type,
203
- join_constraint : * join_constraint,
204
- schema : schema. clone ( ) ,
205
- null_equals_null : * null_equals_null,
206
- } ) ) ;
207
- }
208
- }
209
-
210
- issue_sort (
211
- sort_expr,
212
- LogicalPlan :: Join ( Join {
213
- left : Arc :: new ( sort_push_down ( optimizer, left, None , optimizer_config) ?) ,
214
- right : Arc :: new ( sort_push_down ( optimizer, right, None , optimizer_config) ?) ,
215
- on : on. clone ( ) ,
216
- join_type : * join_type,
217
- join_constraint : * join_constraint,
218
- schema : schema. clone ( ) ,
219
- null_equals_null : * null_equals_null,
220
- } ) ,
221
- )
222
- }
223
- LogicalPlan :: CrossJoin ( CrossJoin {
224
- left,
225
- right,
226
- schema,
227
- } ) => {
228
- // See `LogicalPlan::Join` notes above.
229
- if let Some ( some_sort_expr) = & sort_expr {
230
- let left_columns = get_schema_columns ( left. schema ( ) ) ;
231
- if some_sort_expr. iter ( ) . all ( |expr| {
232
- if let Expr :: Sort { expr, .. } = expr {
233
- if let Expr :: Column ( column) = expr. as_ref ( ) {
234
- return left_columns. contains ( column) ;
235
- }
236
- }
237
- false
238
- } ) {
239
- return Ok ( LogicalPlan :: CrossJoin ( CrossJoin {
240
- left : Arc :: new ( sort_push_down (
241
- optimizer,
242
- left,
243
- sort_expr,
244
- optimizer_config,
245
- ) ?) ,
246
- right : Arc :: new ( sort_push_down ( optimizer, right, None , optimizer_config) ?) ,
247
- schema : schema. clone ( ) ,
248
- } ) ) ;
249
- }
250
- }
251
-
252
- issue_sort (
253
- sort_expr,
254
- LogicalPlan :: CrossJoin ( CrossJoin {
255
- left : Arc :: new ( sort_push_down ( optimizer, left, None , optimizer_config) ?) ,
256
- right : Arc :: new ( sort_push_down ( optimizer, right, None , optimizer_config) ?) ,
257
- schema : schema. clone ( ) ,
258
- } ) ,
259
- )
260
- }
261
168
LogicalPlan :: Union ( Union {
262
169
inputs,
263
170
schema,
@@ -384,15 +291,10 @@ mod tests {
384
291
} ;
385
292
use datafusion:: logical_plan:: { col, JoinType , LogicalPlanBuilder } ;
386
293
387
- fn optimize ( plan : & LogicalPlan ) -> Result < LogicalPlan > {
294
+ fn optimize ( plan : & LogicalPlan ) -> LogicalPlan {
388
295
let rule = SortPushDown :: new ( ) ;
389
296
rule. optimize ( plan, & OptimizerConfig :: new ( ) )
390
- }
391
-
392
- fn assert_optimized_plan_eq ( plan : LogicalPlan , expected : & str ) {
393
- let optimized_plan = optimize ( & plan) . expect ( "failed to optimize plan" ) ;
394
- let formatted_plan = format ! ( "{:?}" , optimized_plan) ;
395
- assert_eq ! ( formatted_plan, expected) ;
297
+ . expect ( "failed to optimize plan" )
396
298
}
397
299
398
300
fn sort ( expr : Expr , asc : bool , nulls_first : bool ) -> Expr {
@@ -417,14 +319,7 @@ mod tests {
417
319
] ) ?
418
320
. build ( ) ?;
419
321
420
- let expected = "\
421
- Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
422
- \n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
423
- \n Projection: #t1.c1, #t1.c2, #t1.c3\
424
- \n TableScan: t1 projection=None\
425
- ";
426
-
427
- assert_optimized_plan_eq ( plan, expected) ;
322
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
428
323
Ok ( ( ) )
429
324
}
430
325
@@ -450,16 +345,7 @@ mod tests {
450
345
] ) ?
451
346
. build ( ) ?;
452
347
453
- let expected = "\
454
- Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4\
455
- \n Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3\
456
- \n Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
457
- \n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
458
- \n Projection: #t1.c1, #t1.c2, #t1.c3\
459
- \n TableScan: t1 projection=None\
460
- ";
461
-
462
- assert_optimized_plan_eq ( plan, expected) ;
348
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
463
349
Ok ( ( ) )
464
350
}
465
351
@@ -487,21 +373,12 @@ mod tests {
487
373
] ) ?
488
374
. build ( ) ?;
489
375
490
- let expected = "\
491
- Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4\
492
- \n Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3\
493
- \n Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
494
- \n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
495
- \n Projection: #t1.c1, #t1.c2, #t1.c3\
496
- \n TableScan: t1 projection=None\
497
- ";
498
-
499
- assert_optimized_plan_eq ( plan, expected) ;
376
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
500
377
Ok ( ( ) )
501
378
}
502
379
503
380
#[ test]
504
- fn test_sort_down_join ( ) -> Result < ( ) > {
381
+ fn test_sort_down_join_sort_left ( ) -> Result < ( ) > {
505
382
let plan = LogicalPlanBuilder :: from (
506
383
LogicalPlanBuilder :: from ( make_sample_table ( "j1" , vec ! [ "key" , "c1" ] , vec ! [ ] ) ?)
507
384
. project ( vec ! [ col( "key" ) , col( "c1" ) ] ) ?
@@ -521,18 +398,12 @@ mod tests {
521
398
. sort ( vec ! [ sort( col( "j1.c1" ) , true , false ) ] ) ?
522
399
. build ( ) ?;
523
400
524
- let expected = "\
525
- Projection: #j1.c1, #j2.c2\
526
- \n Inner Join: #j1.key = #j2.key\
527
- \n Sort: #j1.c1 ASC NULLS LAST\
528
- \n Projection: #j1.key, #j1.c1\
529
- \n TableScan: j1 projection=None\
530
- \n Projection: #j2.key, #j2.c2\
531
- \n TableScan: j2 projection=None\
532
- ";
533
-
534
- assert_optimized_plan_eq ( plan, expected) ;
401
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
402
+ Ok ( ( ) )
403
+ }
535
404
405
+ #[ test]
406
+ fn test_sort_down_join_sort_right ( ) -> Result < ( ) > {
536
407
let plan = LogicalPlanBuilder :: from (
537
408
LogicalPlanBuilder :: from ( make_sample_table ( "j1" , vec ! [ "key" , "c1" ] , vec ! [ ] ) ?)
538
409
. project ( vec ! [ col( "key" ) , col( "c1" ) ] ) ?
@@ -552,23 +423,12 @@ mod tests {
552
423
. sort ( vec ! [ sort( col( "j2.c2" ) , true , false ) ] ) ?
553
424
. build ( ) ?;
554
425
555
- let expected = "\
556
- Projection: #j1.c1, #j2.c2\
557
- \n Sort: #j2.c2 ASC NULLS LAST\
558
- \n Inner Join: #j1.key = #j2.key\
559
- \n Projection: #j1.key, #j1.c1\
560
- \n TableScan: j1 projection=None\
561
- \n Projection: #j2.key, #j2.c2\
562
- \n TableScan: j2 projection=None\
563
- ";
564
-
565
- assert_optimized_plan_eq ( plan, expected) ;
566
-
426
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
567
427
Ok ( ( ) )
568
428
}
569
429
570
430
#[ test]
571
- fn test_sort_down_cross_join ( ) -> Result < ( ) > {
431
+ fn test_sort_down_cross_join_sort_left ( ) -> Result < ( ) > {
572
432
let plan = LogicalPlanBuilder :: from (
573
433
LogicalPlanBuilder :: from ( make_sample_table ( "j1" , vec ! [ "key" , "c1" ] , vec ! [ ] ) ?)
574
434
. project ( vec ! [ col( "key" ) , col( "c1" ) ] ) ?
@@ -583,18 +443,12 @@ mod tests {
583
443
. sort ( vec ! [ sort( col( "j1.c1" ) , true , false ) ] ) ?
584
444
. build ( ) ?;
585
445
586
- let expected = "\
587
- Projection: #j1.c1, #j2.c2\
588
- \n CrossJoin:\
589
- \n Sort: #j1.c1 ASC NULLS LAST\
590
- \n Projection: #j1.key, #j1.c1\
591
- \n TableScan: j1 projection=None\
592
- \n Projection: #j2.key, #j2.c2\
593
- \n TableScan: j2 projection=None\
594
- ";
595
-
596
- assert_optimized_plan_eq ( plan, expected) ;
446
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
447
+ Ok ( ( ) )
448
+ }
597
449
450
+ #[ test]
451
+ fn test_sort_down_cross_join_sort_right ( ) -> Result < ( ) > {
598
452
let plan = LogicalPlanBuilder :: from (
599
453
LogicalPlanBuilder :: from ( make_sample_table ( "j1" , vec ! [ "key" , "c1" ] , vec ! [ ] ) ?)
600
454
. project ( vec ! [ col( "key" ) , col( "c1" ) ] ) ?
@@ -609,17 +463,7 @@ mod tests {
609
463
. sort ( vec ! [ sort( col( "j2.c2" ) , true , false ) ] ) ?
610
464
. build ( ) ?;
611
465
612
- let expected = "\
613
- Projection: #j1.c1, #j2.c2\
614
- \n Sort: #j2.c2 ASC NULLS LAST\
615
- \n CrossJoin:\
616
- \n Projection: #j1.key, #j1.c1\
617
- \n TableScan: j1 projection=None\
618
- \n Projection: #j2.key, #j2.c2\
619
- \n TableScan: j2 projection=None\
620
- ";
621
-
622
- assert_optimized_plan_eq ( plan, expected) ;
466
+ insta:: assert_debug_snapshot!( optimize( & plan) ) ;
623
467
624
468
Ok ( ( ) )
625
469
}
0 commit comments