From 8e4f894e986ccd943df9ddf55fc853eb0558886f Mon Sep 17 00:00:00 2001 From: Davies Liu <davies@databricks.com> Date: Wed, 20 Jan 2016 10:02:40 -0800 Subject: [PATCH] [SPARK-12881] [SQL] subexpress elimination in mutable projection Author: Davies Liu <davies@databricks.com> Closes #10814 from davies/mutable_subexpr. --- .../expressions/EquivalentExpressions.scala | 5 ++- .../sql/catalyst/expressions/Expression.scala | 4 +- .../expressions/codegen/CodeGenerator.scala | 8 ++-- .../codegen/GenerateMutableProjection.scala | 43 ++++++++++++++----- .../codegen/GenerateUnsafeProjection.scala | 6 +-- .../SubexpressionEliminationSuite.scala | 13 ++++++ .../spark/sql/execution/SparkPlan.scala | 6 ++- .../apache/spark/sql/execution/Window.scala | 8 +++- .../aggregate/SortBasedAggregate.scala | 3 +- .../aggregate/TungstenAggregate.scala | 3 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 8 ++++ 11 files changed, 80 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index f7162e420d..affd1bdb32 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.mutable +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback + /** * This class is used to compute equality of (sub)expression trees. Expressions can be added * to this class and they subsequently query for expression equality. Expression trees are @@ -67,7 +69,8 @@ class EquivalentExpressions { */ def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = { val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf - if (!skip && !addExpr(root)) { + // the children of CodegenFallback will not be used to generate code (call eval() instead) + if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) { root.children.foreach(addExprTree(_, ignoreLeaf)) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 25cf210c4b..db17ba7c84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -100,8 +100,8 @@ abstract class Expression extends TreeNode[Expression] { ExprCode(code, subExprState.isNull, subExprState.value) }.getOrElse { val isNull = ctx.freshName("isNull") - val primitive = ctx.freshName("primitive") - val ve = ExprCode("", isNull, primitive) + val value = ctx.freshName("value") + val ve = ExprCode("", isNull, value) ve.code = genCode(ctx, ve) // Add `this` in the comment. ve.copy(s"/* ${this.toCommentSafeString} */\n" + ve.code.trim) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 683029ff14..2747c315ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -125,7 +125,7 @@ class CodegenContext { val subExprEliminationExprs = mutable.HashMap.empty[Expression, SubExprEliminationState] // The collection of sub-exression result resetting methods that need to be called on each row. - val subExprResetVariables = mutable.ArrayBuffer.empty[String] + val subexprFunctions = mutable.ArrayBuffer.empty[String] def declareAddedFunctions(): String = { addedFunctions.map { case (funcName, funcCode) => funcCode }.mkString("\n") @@ -424,9 +424,9 @@ class CodegenContext { val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) commonExprs.foreach(e => { val expr = e.head - val isNull = freshName("isNull") - val value = freshName("value") val fnName = freshName("evalExpr") + val isNull = s"${fnName}IsNull" + val value = s"${fnName}Value" // Generate the code for this expression tree and wrap it in a function. val code = expr.gen(this) @@ -461,7 +461,7 @@ class CodegenContext { addMutableState(javaType(expr.dataType), value, s"$value = ${defaultValue(expr.dataType)};") - subExprResetVariables += s"$fnName($INPUT_ROW);" + subexprFunctions += s"$fnName($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 59ef0f5836..d9fe76133c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -38,12 +38,29 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu protected def bind(in: Seq[Expression], inputSchema: Seq[Attribute]): Seq[Expression] = in.map(BindReferences.bindReference(_, inputSchema)) + def generate( + expressions: Seq[Expression], + inputSchema: Seq[Attribute], + useSubexprElimination: Boolean): (() => MutableProjection) = { + create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination) + } + protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + create(expressions, false) + } + + private def create( + expressions: Seq[Expression], + useSubexprElimination: Boolean): (() => MutableProjection) = { val ctx = newCodeGenContext() - val projectionCodes = expressions.zipWithIndex.map { - case (NoOp, _) => "" - case (e, i) => - val evaluationCode = e.gen(ctx) + val (validExpr, index) = expressions.zipWithIndex.filter { + case (NoOp, _) => false + case _ => true + }.unzip + val exprVals = ctx.generateExpressions(validExpr, useSubexprElimination) + val projectionCodes = exprVals.zip(index).map { + case (ev, i) => + val e = expressions(i) if (e.nullable) { val isNull = s"isNull_$i" val value = s"value_$i" @@ -51,22 +68,25 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu ctx.addMutableState(ctx.javaType(e.dataType), value, s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" - ${evaluationCode.code} - this.$isNull = ${evaluationCode.isNull}; - this.$value = ${evaluationCode.value}; + ${ev.code} + this.$isNull = ${ev.isNull}; + this.$value = ${ev.value}; """ } else { val value = s"value_$i" ctx.addMutableState(ctx.javaType(e.dataType), value, s"this.$value = ${ctx.defaultValue(e.dataType)};") s""" - ${evaluationCode.code} - this.$value = ${evaluationCode.value}; + ${ev.code} + this.$value = ${ev.value}; """ } } - val updates = expressions.zipWithIndex.map { - case (NoOp, _) => "" + + // Evaluate all the the subexpressions. + val evalSubexpr = ctx.subexprFunctions.mkString("\n") + + val updates = validExpr.zip(index).map { case (e, i) => if (e.nullable) { if (e.dataType.isInstanceOf[DecimalType]) { @@ -128,6 +148,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu public java.lang.Object apply(java.lang.Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; + $evalSubexpr $allProjections // copy all the results into MutableRow $allUpdates diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index 61e7469ee4..72bf39a039 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -294,13 +294,13 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val holderClass = classOf[BufferHolder].getName ctx.addMutableState(holderClass, bufferHolder, s"this.$bufferHolder = new $holderClass();") - // Reset the subexpression values for each row. - val subexprReset = ctx.subExprResetVariables.mkString("\n") + // Evaluate all the subexpression. + val evalSubexpr = ctx.subexprFunctions.mkString("\n") val code = s""" $bufferHolder.reset(); - $subexprReset + $evalSubexpr ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)} $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize()); diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala index a61297b2c0..43a3eb9dec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/SubexpressionEliminationSuite.scala @@ -154,4 +154,17 @@ class SubexpressionEliminationSuite extends SparkFunSuite { equivalence.addExpr(sum) assert(equivalence.getAllEquivalentExprs.isEmpty) } + + test("Children of CodegenFallback") { + val one = Literal(1) + val two = Add(one, one) + val explode = Explode(two) + val add = Add(two, explode) + + var equivalence = new EquivalentExpressions + equivalence.addExprTree(add, true) + // the `two` inside `explode` should not be added + assert(equivalence.getAllEquivalentExprs.filter(_.size > 1).size == 0) + assert(equivalence.getAllEquivalentExprs.filter(_.size == 1).size == 3) // add, two, explode + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 75101ea0fc..b19b772409 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -196,10 +196,12 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[this] def isTesting: Boolean = sys.props.contains("spark.testing") protected def newMutableProjection( - expressions: Seq[Expression], inputSchema: Seq[Attribute]): () => MutableProjection = { + expressions: Seq[Expression], + inputSchema: Seq[Attribute], + useSubexprElimination: Boolean = false): () => MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") try { - GenerateMutableProjection.generate(expressions, inputSchema) + GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } catch { case e: Exception => if (isTesting) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 168b5ab031..26a7340f1a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -194,7 +194,11 @@ case class Window( val functions = functionSeq.toArray // Construct an aggregate processor if we need one. - def processor = AggregateProcessor(functions, ordinal, child.output, newMutableProjection) + def processor = AggregateProcessor( + functions, + ordinal, + child.output, + (expressions, schema) => newMutableProjection(expressions, schema)) // Create the factory val factory = key match { @@ -206,7 +210,7 @@ case class Window( ordinal, functions, child.output, - newMutableProjection, + (expressions, schema) => newMutableProjection(expressions, schema), offset) // Growing Frame. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala index 1d56592c40..06a3991459 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala @@ -87,7 +87,8 @@ case class SortBasedAggregate( aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), numInputRows, numOutputRows) if (!hasInput && groupingExpressions.isEmpty) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index a9cf04388d..8dcbab4c8c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -94,7 +94,8 @@ case class TungstenAggregate( aggregateAttributes, initialInputBufferOffset, resultExpressions, - newMutableProjection, + (expressions, inputSchema) => + newMutableProjection(expressions, inputSchema, subexpressionEliminationEnabled), child.output, iter, testFallbackStartsAt, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index d7f182352b..b159346bed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.sql.execution.{aggregate, SparkQl} import org.apache.spark.sql.execution.joins.{CartesianProduct, SortMergeJoin} +import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext} import org.apache.spark.sql.test.SQLTestData._ @@ -1968,6 +1969,13 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { verifyCallCount( df.selectExpr("testUdf(a + 1) + testUdf(1 + b)", "testUdf(a + 1)"), Row(4, 2), 2) + val testUdf = functions.udf((x: Int) => { + countAcc.++=(1) + x + }) + verifyCallCount( + df.groupBy().agg(sum(testUdf($"b") + testUdf($"b") + testUdf($"b"))), Row(3.0), 1) + // Would be nice if semantic equals for `+` understood commutative verifyCallCount( df.selectExpr("testUdf(a + 1) + testUdf(1 + a)", "testUdf(a + 1)"), Row(4, 2), 2) -- GitLab