Skip to content

fix(cubesql): Fix SortPushDown pushing sort through joins #9464

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
expression: optimize(&plan)
---
Projection: #j1.c1, #j2.c2
Sort: #j1.c1 ASC NULLS LAST
CrossJoin:
Projection: #j1.key, #j1.c1
TableScan: j1 projection=None
Projection: #j2.key, #j2.c2
TableScan: j2 projection=None
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
expression: optimize(&plan)
---
Projection: #j1.c1, #j2.c2
Sort: #j2.c2 ASC NULLS LAST
CrossJoin:
Projection: #j1.key, #j1.c1
TableScan: j1 projection=None
Projection: #j2.key, #j2.c2
TableScan: j2 projection=None
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
expression: optimize(&plan)
---
Projection: #j1.c1, #j2.c2
Sort: #j1.c1 ASC NULLS LAST
Inner Join: #j1.key = #j2.key
Projection: #j1.key, #j1.c1
TableScan: j1 projection=None
Projection: #j2.key, #j2.c2
TableScan: j2 projection=None
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
expression: optimize(&plan)
---
Projection: #j1.c1, #j2.c2
Sort: #j2.c2 ASC NULLS LAST
Inner Join: #j1.key = #j2.key
Projection: #j1.key, #j1.c1
TableScan: j1 projection=None
Projection: #j2.key, #j2.c2
TableScan: j2 projection=None
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
expression: optimize(&plan)
---
Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4
Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3
Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2
Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST
Projection: #t1.c1, #t1.c2, #t1.c3
TableScan: t1 projection=None
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
---
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
expression: optimize(&plan)
---
Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2
Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST
Projection: #t1.c1, #t1.c2, #t1.c3
TableScan: t1 projection=None
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
source: cubesql/src/compile/engine/df/optimizers/sort_push_down.rs
expression: optimize(&plan)
---
Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4
Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3
Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2
Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST
Projection: #t1.c1, #t1.c2, #t1.c3
TableScan: t1 projection=None
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@ use std::{collections::HashMap, sync::Arc};
use datafusion::{
error::{DataFusionError, Result},
logical_plan::{
plan::{
Aggregate, CrossJoin, Distinct, Join, Limit, Projection, Sort, Subquery, Union, Window,
},
plan::{Aggregate, Distinct, Limit, Projection, Sort, Subquery, Union, Window},
Column, DFSchema, Expr, Filter, LogicalPlan,
},
optimizer::optimizer::{OptimizerConfig, OptimizerRule},
};

use super::utils::{get_schema_columns, is_column_expr, plan_has_projections, rewrite};
use super::utils::{is_column_expr, plan_has_projections, rewrite};

/// Sort Push Down optimizer rule pushes ORDER BY clauses consisting of specific,
/// mostly simple, expressions down the plan, all the way to the Projection
Expand Down Expand Up @@ -167,97 +165,6 @@ fn sort_push_down(
optimizer_config,
)
}
LogicalPlan::Join(Join {
left,
right,
on,
join_type,
join_constraint,
schema,
null_equals_null,
}) => {
// DataFusion preserves the sort order of the joined plans, prioritizing the left side.
// Taking this into account, we can push Sort down into the left plan if Sort references
// only columns from the left side.
// TODO: check whether this is still the case with multiple target partitions
if let Some(some_sort_expr) = &sort_expr {
let left_columns = get_schema_columns(left.schema());
if some_sort_expr.iter().all(|expr| {
if let Expr::Sort { expr, .. } = expr {
if let Expr::Column(column) = expr.as_ref() {
return left_columns.contains(column);
}
}
false
}) {
return Ok(LogicalPlan::Join(Join {
left: Arc::new(sort_push_down(
optimizer,
left,
sort_expr,
optimizer_config,
)?),
right: Arc::new(sort_push_down(optimizer, right, None, optimizer_config)?),
on: on.clone(),
join_type: *join_type,
join_constraint: *join_constraint,
schema: schema.clone(),
null_equals_null: *null_equals_null,
}));
}
}

issue_sort(
sort_expr,
LogicalPlan::Join(Join {
left: Arc::new(sort_push_down(optimizer, left, None, optimizer_config)?),
right: Arc::new(sort_push_down(optimizer, right, None, optimizer_config)?),
on: on.clone(),
join_type: *join_type,
join_constraint: *join_constraint,
schema: schema.clone(),
null_equals_null: *null_equals_null,
}),
)
}
LogicalPlan::CrossJoin(CrossJoin {
left,
right,
schema,
}) => {
// See `LogicalPlan::Join` notes above.
if let Some(some_sort_expr) = &sort_expr {
let left_columns = get_schema_columns(left.schema());
if some_sort_expr.iter().all(|expr| {
if let Expr::Sort { expr, .. } = expr {
if let Expr::Column(column) = expr.as_ref() {
return left_columns.contains(column);
}
}
false
}) {
return Ok(LogicalPlan::CrossJoin(CrossJoin {
left: Arc::new(sort_push_down(
optimizer,
left,
sort_expr,
optimizer_config,
)?),
right: Arc::new(sort_push_down(optimizer, right, None, optimizer_config)?),
schema: schema.clone(),
}));
}
}

issue_sort(
sort_expr,
LogicalPlan::CrossJoin(CrossJoin {
left: Arc::new(sort_push_down(optimizer, left, None, optimizer_config)?),
right: Arc::new(sort_push_down(optimizer, right, None, optimizer_config)?),
schema: schema.clone(),
}),
)
}
LogicalPlan::Union(Union {
inputs,
schema,
Expand Down Expand Up @@ -384,15 +291,10 @@ mod tests {
};
use datafusion::logical_plan::{col, JoinType, LogicalPlanBuilder};

fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
fn optimize(plan: &LogicalPlan) -> LogicalPlan {
let rule = SortPushDown::new();
rule.optimize(plan, &OptimizerConfig::new())
}

fn assert_optimized_plan_eq(plan: LogicalPlan, expected: &str) {
let optimized_plan = optimize(&plan).expect("failed to optimize plan");
let formatted_plan = format!("{:?}", optimized_plan);
assert_eq!(formatted_plan, expected);
.expect("failed to optimize plan")
}

fn sort(expr: Expr, asc: bool, nulls_first: bool) -> Expr {
Expand All @@ -417,14 +319,7 @@ mod tests {
])?
.build()?;

let expected = "\
Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
\n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
\n Projection: #t1.c1, #t1.c2, #t1.c3\
\n TableScan: t1 projection=None\
";

assert_optimized_plan_eq(plan, expected);
insta::assert_debug_snapshot!(optimize(&plan));
Ok(())
}

Expand All @@ -450,16 +345,7 @@ mod tests {
])?
.build()?;

let expected = "\
Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4\
\n Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3\
\n Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
\n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
\n Projection: #t1.c1, #t1.c2, #t1.c3\
\n TableScan: t1 projection=None\
";

assert_optimized_plan_eq(plan, expected);
insta::assert_debug_snapshot!(optimize(&plan));
Ok(())
}

Expand Down Expand Up @@ -487,21 +373,12 @@ mod tests {
])?
.build()?;

let expected = "\
Projection: #t3.n3, #t3.n4, #t3.n2, alias=t4\
\n Projection: #t2.n1 AS n3, #t2.c2 AS n4, #t2.n2, alias=t3\
\n Projection: #t1.c1 AS n1, #t1.c2, #t1.c3 AS n2, alias=t2\
\n Sort: #t1.c2 ASC NULLS LAST, #t1.c3 DESC NULLS FIRST\
\n Projection: #t1.c1, #t1.c2, #t1.c3\
\n TableScan: t1 projection=None\
";

assert_optimized_plan_eq(plan, expected);
insta::assert_debug_snapshot!(optimize(&plan));
Ok(())
}

#[test]
fn test_sort_down_join() -> Result<()> {
fn test_sort_down_join_sort_left() -> Result<()> {
let plan = LogicalPlanBuilder::from(
LogicalPlanBuilder::from(make_sample_table("j1", vec!["key", "c1"], vec![])?)
.project(vec![col("key"), col("c1")])?
Expand All @@ -521,18 +398,12 @@ mod tests {
.sort(vec![sort(col("j1.c1"), true, false)])?
.build()?;

let expected = "\
Projection: #j1.c1, #j2.c2\
\n Inner Join: #j1.key = #j2.key\
\n Sort: #j1.c1 ASC NULLS LAST\
\n Projection: #j1.key, #j1.c1\
\n TableScan: j1 projection=None\
\n Projection: #j2.key, #j2.c2\
\n TableScan: j2 projection=None\
";

assert_optimized_plan_eq(plan, expected);
insta::assert_debug_snapshot!(optimize(&plan));
Ok(())
}

#[test]
fn test_sort_down_join_sort_right() -> Result<()> {
let plan = LogicalPlanBuilder::from(
LogicalPlanBuilder::from(make_sample_table("j1", vec!["key", "c1"], vec![])?)
.project(vec![col("key"), col("c1")])?
Expand All @@ -552,23 +423,12 @@ mod tests {
.sort(vec![sort(col("j2.c2"), true, false)])?
.build()?;

let expected = "\
Projection: #j1.c1, #j2.c2\
\n Sort: #j2.c2 ASC NULLS LAST\
\n Inner Join: #j1.key = #j2.key\
\n Projection: #j1.key, #j1.c1\
\n TableScan: j1 projection=None\
\n Projection: #j2.key, #j2.c2\
\n TableScan: j2 projection=None\
";

assert_optimized_plan_eq(plan, expected);

insta::assert_debug_snapshot!(optimize(&plan));
Ok(())
}

#[test]
fn test_sort_down_cross_join() -> Result<()> {
fn test_sort_down_cross_join_sort_left() -> Result<()> {
let plan = LogicalPlanBuilder::from(
LogicalPlanBuilder::from(make_sample_table("j1", vec!["key", "c1"], vec![])?)
.project(vec![col("key"), col("c1")])?
Expand All @@ -583,18 +443,12 @@ mod tests {
.sort(vec![sort(col("j1.c1"), true, false)])?
.build()?;

let expected = "\
Projection: #j1.c1, #j2.c2\
\n CrossJoin:\
\n Sort: #j1.c1 ASC NULLS LAST\
\n Projection: #j1.key, #j1.c1\
\n TableScan: j1 projection=None\
\n Projection: #j2.key, #j2.c2\
\n TableScan: j2 projection=None\
";

assert_optimized_plan_eq(plan, expected);
insta::assert_debug_snapshot!(optimize(&plan));
Ok(())
}

#[test]
fn test_sort_down_cross_join_sort_right() -> Result<()> {
let plan = LogicalPlanBuilder::from(
LogicalPlanBuilder::from(make_sample_table("j1", vec!["key", "c1"], vec![])?)
.project(vec![col("key"), col("c1")])?
Expand All @@ -609,17 +463,7 @@ mod tests {
.sort(vec![sort(col("j2.c2"), true, false)])?
.build()?;

let expected = "\
Projection: #j1.c1, #j2.c2\
\n Sort: #j2.c2 ASC NULLS LAST\
\n CrossJoin:\
\n Projection: #j1.key, #j1.c1\
\n TableScan: j1 projection=None\
\n Projection: #j2.key, #j2.c2\
\n TableScan: j2 projection=None\
";

assert_optimized_plan_eq(plan, expected);
insta::assert_debug_snapshot!(optimize(&plan));

Ok(())
}
Expand Down
Loading