diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bad115d22f1ae5ace96096eaa5b75c8fbbd94a80..438cbabdbb8a8e92b7ef752c01012b91dda74b35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -66,9 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ReorderJoin, OuterJoinElimination, PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, + PushDownPredicate, LimitPushDown, ColumnPruning, InferFiltersFromConstraints, @@ -917,12 +915,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { } /** - * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]] - * that were defined in the projection. + * Pushes [[Filter]] operators through many operators iff: + * 1) the operator is deterministic + * 2) the predicate is deterministic and the operator will not change any of rows. * * This heuristic is valid assuming the expression evaluation cost is minimal. */ -object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { +object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // SPARK-13473: We can't push the predicate down when the underlying projection output non- // deterministic field(s). Non-deterministic expressions are essentially stateful. This @@ -939,41 +938,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe }) project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) - } - -} - -/** - * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference - * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath. - */ -object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, g: Generate) => - // Predicates that reference attributes produced by the `Generate` operator cannot - // be pushed below the operator. - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => - cond.references.subsetOf(g.child.outputSet) && cond.deterministic - } - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newGenerate = Generate(g.generator, join = g.join, outer = g.outer, - g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate) - } else { - filter - } - } -} -/** - * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only - * non-aggregate attributes (typically literals or grouping expressions). - */ -object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, aggregate: Aggregate) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression @@ -999,6 +964,72 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel } else { filter } + + case filter @ Filter(condition, child) + if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] => + // Union/Intersect could change the rows, so non-deterministic predicate can't be pushed down + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + cond.deterministic + } + if (pushDown.nonEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = child.output + val newGrandChildren = child.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet)) + Filter(newCond, grandchild) + } + val newChild = child.withNewChildren(newGrandChildren) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + filter + } + + case filter @ Filter(condition, e @ Except(left, _)) => + pushDownPredicate(filter, e.left) { predicate => + e.copy(left = Filter(predicate, left)) + } + + // two filters should be combine together by other rules + case filter @ Filter(_, f: Filter) => filter + // should not push predicates through sample, or will generate different results. + case filter @ Filter(_, s: Sample) => filter + // TODO: push predicates through expand + case filter @ Filter(_, e: Expand) => filter + + case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => + pushDownPredicate(filter, u.child) { predicate => + u.withNewChildren(Seq(Filter(predicate, u.child))) + } + } + + private def pushDownPredicate( + filter: Filter, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + // Only push down the predicates that is deterministic and all the referenced attributes + // come from grandchild. + // TODO: non-deterministic predicates could be pushed through some operators that do not change + // the rows. + val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond => + cond.deterministic && cond.references.subsetOf(grandchild.outputSet) + } + if (pushDown.nonEmpty) { + val newChild = insertFilter(pushDown.reduceLeft(And)) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + filter + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 6f35d87ebbd95d947ee41132fdbd44651e04ecc6..00656191354f24f663eee59e57ca3e905c64473d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -69,6 +69,9 @@ object PhysicalOperation extends PredicateHelper { val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) + case BroadcastHint(child) => + collectProjectsAndFilters(child) + case other => (None, Nil, other, Map.empty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 2248e03b2fc5892a68ead9f15c0ed4bc0d273f2f..52b574c0e63c93feefe430c4011aae8363e0728f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -34,7 +34,7 @@ class ColumnPruningSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Column pruning", FixedPoint(100), - PushPredicateThroughProject, + PushDownPredicate, ColumnPruning, CollapseProject) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b84ae7c5bb6ad5d899b25512b642e31166f47627..df7529d83f7c829955f16d733844fb1723dc487f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -33,14 +33,12 @@ class FilterPushdownSuite extends PlanTest { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: - Batch("Filter Pushdown", Once, + Batch("Filter Pushdown", FixedPoint(10), SamplePushDown, CombineFilters, - PushPredicateThroughProject, + PushDownPredicate, BooleanSimplification, PushPredicateThroughJoin, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, CollapseProject) :: Nil } @@ -620,8 +618,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where('a === 3) + .select('a, 'b) .groupBy('a)('a, count('b) as 'c) .where('c === 2L) .analyze @@ -638,8 +636,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where('a + 1 < 3) + .select('a, 'b) .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) .where('c === 2L || 'aa > 4) .analyze @@ -656,8 +654,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where("s" === "s") + .select('a, 'b) .groupBy('a)('a, count('b) as 'c, "s" as 'd) .where('c === 2L) .analyze @@ -681,4 +679,68 @@ class FilterPushdownSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("broadcast hint") { + val originalQuery = BroadcastHint(testRelation) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = BroadcastHint(testRelation.where('a === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("union") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Union(Seq(testRelation, testRelation2)) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Union(Seq( + testRelation.where('a === 2L), + testRelation2.where('d === 2L))) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("intersect") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Intersect(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Intersect( + testRelation.where('a === 2L), + testRelation2.where('d === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("except") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Except(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Except( + testRelation.where('a === 2L), + testRelation2) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index e2f8146beee7bc0fe8409fec4301f9a77b12fc44..c1ebf8b09e08d49ca3a45f59d87dc2f9829e09f1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -36,12 +36,10 @@ class JoinOptimizationSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown", FixedPoint(100), CombineFilters, - PushPredicateThroughProject, + PushDownPredicate, BooleanSimplification, ReorderJoin, PushPredicateThroughJoin, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 14fb72a8a343907c4d919f56af9e042d111c3155..d8cfec5391497605caa3dd272efd1fcd6981b5ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest { Batch("Filter Pushdown and Pruning", Once, CombineFilters, PruneFilters, - PushPredicateThroughProject, + PushDownPredicate, PushPredicateThroughJoin) :: Nil }