Skip to content
Snippets Groups Projects
Commit f376c372 authored by 蒋星博's avatar 蒋星博 Committed by Cheng Lian
Browse files

[SPARK-16343][SQL] Improve the PushDownPredicate rule to pushdown predicates...

[SPARK-16343][SQL] Improve the PushDownPredicate rule to pushdown predicates correctly in non-deterministic condition.

## What changes were proposed in this pull request?

Currently our Optimizer may reorder the predicates to run them more efficient, but in non-deterministic condition, change the order between deterministic parts and non-deterministic parts may change the number of input rows. For example:
```SELECT a FROM t WHERE rand() < 0.1 AND a = 1```
And
```SELECT a FROM t WHERE a = 1 AND rand() < 0.1```
may call rand() for different times and therefore the output rows differ.

This PR improved this condition by checking whether the predicate is placed before any non-deterministic predicates.

## How was this patch tested?

Expanded related testcases in FilterPushdownSuite.

Author: 蒋星博 <jiangxingbo@meituan.com>

Closes #14012 from jiangxb1987/ppd.
parent ea06e4ef
No related branches found
No related tags found
No related merge requests found
...@@ -1128,19 +1128,23 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { ...@@ -1128,19 +1128,23 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild))
// Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be // Push [[Filter]] operators through [[Window]] operators. Parts of the predicate that can be
// pushed beneath must satisfy the following two conditions: // pushed beneath must satisfy the following conditions:
// 1. All the expressions are part of window partitioning key. The expressions can be compound. // 1. All the expressions are part of window partitioning key. The expressions can be compound.
// 2. Deterministic // 2. Deterministic.
// 3. Placed before any non-deterministic predicates.
case filter @ Filter(condition, w: Window) case filter @ Filter(condition, w: Window)
if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) => if w.partitionSpec.forall(_.isInstanceOf[AttributeReference]) =>
val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references)) val partitionAttrs = AttributeSet(w.partitionSpec.flatMap(_.references))
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond =>
cond.references.subsetOf(partitionAttrs) && cond.deterministic && val (candidates, containingNonDeterministic) =
// This is for ensuring all the partitioning expressions have been converted to alias splitConjunctivePredicates(condition).span(_.deterministic)
// in Analyzer. Thus, we do not need to check if the expressions in conditions are
// the same as the expressions used in partitioning columns. val (pushDown, rest) = candidates.partition { cond =>
partitionAttrs.forall(_.isInstanceOf[Attribute]) cond.references.subsetOf(partitionAttrs)
} }
val stayUp = rest ++ containingNonDeterministic
if (pushDown.nonEmpty) { if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And) val pushDownPredicate = pushDown.reduce(And)
val newWindow = w.copy(child = Filter(pushDownPredicate, w.child)) val newWindow = w.copy(child = Filter(pushDownPredicate, w.child))
...@@ -1159,11 +1163,16 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { ...@@ -1159,11 +1163,16 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
// For each filter, expand the alias and check if the filter can be evaluated using // For each filter, expand the alias and check if the filter can be evaluated using
// attributes produced by the aggregate operator's child operator. // attributes produced by the aggregate operator's child operator.
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => val (candidates, containingNonDeterministic) =
splitConjunctivePredicates(condition).span(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
val replaced = replaceAlias(cond, aliasMap) val replaced = replaceAlias(cond, aliasMap)
replaced.references.subsetOf(aggregate.child.outputSet) && replaced.deterministic replaced.references.subsetOf(aggregate.child.outputSet)
} }
val stayUp = rest ++ containingNonDeterministic
if (pushDown.nonEmpty) { if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And) val pushDownPredicate = pushDown.reduce(And)
val replaced = replaceAlias(pushDownPredicate, aliasMap) val replaced = replaceAlias(pushDownPredicate, aliasMap)
...@@ -1177,9 +1186,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { ...@@ -1177,9 +1186,8 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
case filter @ Filter(condition, union: Union) => case filter @ Filter(condition, union: Union) =>
// Union could change the rows, so non-deterministic predicate can't be pushed down // Union could change the rows, so non-deterministic predicate can't be pushed down
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => val (pushDown, stayUp) = splitConjunctivePredicates(condition).span(_.deterministic)
cond.deterministic
}
if (pushDown.nonEmpty) { if (pushDown.nonEmpty) {
val pushDownCond = pushDown.reduceLeft(And) val pushDownCond = pushDown.reduceLeft(And)
val output = union.output val output = union.output
...@@ -1219,9 +1227,15 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { ...@@ -1219,9 +1227,15 @@ object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper {
// come from grandchild. // come from grandchild.
// TODO: non-deterministic predicates could be pushed through some operators that do not change // TODO: non-deterministic predicates could be pushed through some operators that do not change
// the rows. // the rows.
val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond => val (candidates, containingNonDeterministic) =
cond.deterministic && cond.references.subsetOf(grandchild.outputSet) splitConjunctivePredicates(filter.condition).span(_.deterministic)
val (pushDown, rest) = candidates.partition { cond =>
cond.references.subsetOf(grandchild.outputSet)
} }
val stayUp = rest ++ containingNonDeterministic
if (pushDown.nonEmpty) { if (pushDown.nonEmpty) {
val newChild = insertFilter(pushDown.reduceLeft(And)) val newChild = insertFilter(pushDown.reduceLeft(And))
if (stayUp.nonEmpty) { if (stayUp.nonEmpty) {
......
...@@ -531,14 +531,14 @@ class FilterPushdownSuite extends PlanTest { ...@@ -531,14 +531,14 @@ class FilterPushdownSuite extends PlanTest {
val originalQuery = { val originalQuery = {
testRelationWithArrayType testRelationWithArrayType
.generate(Explode('c_arr), true, false, Some("arr")) .generate(Explode('c_arr), true, false, Some("arr"))
.where(('b >= 5) && ('a + Rand(10).as("rnd") > 6)) .where(('b >= 5) && ('a + Rand(10).as("rnd") > 6) && ('c > 6))
} }
val optimized = Optimize.execute(originalQuery.analyze) val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = { val correctAnswer = {
testRelationWithArrayType testRelationWithArrayType
.where('b >= 5) .where('b >= 5)
.generate(Explode('c_arr), true, false, Some("arr")) .generate(Explode('c_arr), true, false, Some("arr"))
.where('a + Rand(10).as("rnd") > 6) .where('a + Rand(10).as("rnd") > 6 && 'c > 6)
.analyze .analyze
} }
...@@ -715,14 +715,14 @@ class FilterPushdownSuite extends PlanTest { ...@@ -715,14 +715,14 @@ class FilterPushdownSuite extends PlanTest {
val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int)
val originalQuery = Union(Seq(testRelation, testRelation2)) val originalQuery = Union(Seq(testRelation, testRelation2))
.where('a === 2L && 'b + Rand(10).as("rnd") === 3) .where('a === 2L && 'b + Rand(10).as("rnd") === 3 && 'c > 5L)
val optimized = Optimize.execute(originalQuery.analyze) val optimized = Optimize.execute(originalQuery.analyze)
val correctAnswer = Union(Seq( val correctAnswer = Union(Seq(
testRelation.where('a === 2L), testRelation.where('a === 2L),
testRelation2.where('d === 2L))) testRelation2.where('d === 2L)))
.where('b + Rand(10).as("rnd") === 3) .where('b + Rand(10).as("rnd") === 3 && 'c > 5L)
.analyze .analyze
comparePlans(optimized, correctAnswer) comparePlans(optimized, correctAnswer)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment