diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index c15fabab805a78f43047b80bca103eb879e5ba64..57f4945de980455620f794491fcb4cd3a6981ea3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -264,12 +264,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
   */
 private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
 
+  private def supportCodegen(e: Expression): Boolean = e match {
+    case e: LeafExpression => true
+    // CodegenFallback requires the input to be an InternalRow
+    case e: CodegenFallback => false
+    case _ => true
+  }
+
   private def supportCodegen(plan: SparkPlan): Boolean = plan match {
     case plan: CodegenSupport if plan.supportCodegen =>
-      // Non-leaf with CodegenFallback does not work with whole stage codegen
-      val willFallback = plan.expressions.exists(
-        _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined
-      )
+      val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
       // the generated code will be huge if there are too many columns
       val haveManyColumns = plan.output.length > 200
       !willFallback && !haveManyColumns
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 8dcbab4c8cfbc6ca5c9204eb64ca2a246951d2d8..23e54f344d2526b5e55b5248324dfa6580924d79 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 
@@ -35,7 +36,7 @@ case class TungstenAggregate(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -113,6 +114,86 @@ case class TungstenAggregate(
     }
   }
 
+  override def supportCodegen: Boolean = {
+    groupingExpressions.isEmpty &&
+      // ImperativeAggregate is not supported right now
+      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
+      // final aggregation only have one row, do not need to codegen
+      !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+  }
+
+  // The variables used as aggregation buffer
+  private var bufVars: Seq[ExprCode] = _
+
+  private val modes = aggregateExpressions.map(_.mode).distinct
+
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+    val initAgg = ctx.freshName("initAgg")
+    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+    // generate variables for aggregation buffer
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val initExpr = functions.flatMap(f => f.initialValues)
+    bufVars = initExpr.map { e =>
+      val isNull = ctx.freshName("bufIsNull")
+      val value = ctx.freshName("bufValue")
+      // The initial expression should not access any column
+      val ev = e.gen(ctx)
+      val initVars = s"""
+         | boolean $isNull = ${ev.isNull};
+         | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+       """.stripMargin
+      ExprCode(ev.code + initVars, isNull, value)
+    }
+
+    val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+    val source =
+      s"""
+         | if (!$initAgg) {
+         |   $initAgg = true;
+         |
+         |   // initialize aggregation buffer
+         |   ${bufVars.map(_.code).mkString("\n")}
+         |
+         |   $childSource
+         |
+         |   // output the result
+         |   ${consume(ctx, bufVars)}
+         | }
+       """.stripMargin
+
+    (rdd, source)
+  }
+
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+    // only have DeclarativeAggregate
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    // the mode could be only Partial or PartialMerge
+    val updateExpr = if (modes.contains(Partial)) {
+      functions.flatMap(_.updateExpressions)
+    } else {
+      functions.flatMap(_.mergeExpressions)
+    }
+
+    val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
+    ctx.currentVars = bufVars ++ input
+    // TODO: support subexpression elimination
+    val codes = boundExpr.zipWithIndex.map { case (e, i) =>
+      val ev = e.gen(ctx)
+      s"""
+         | ${ev.code}
+         | ${bufVars(i).isNull} = ${ev.isNull};
+         | ${bufVars(i).value} = ${ev.value};
+       """.stripMargin
+    }
+
+    s"""
+       | // do aggregate and update aggregation buffer
+       | ${codes.mkString("")}
+     """.stripMargin
+  }
+
   override def simpleString: String = {
     val allAggregateExpressions = aggregateExpressions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 788b04fcf8c2e9cb7be22c0c87124f6f739bdd00..c4aad398bfa5466381a8263c8d9c4f5161ef1ab7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -46,10 +46,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
     /*
       Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-      Single Int Column Scan:      Avg Time(ms)    Avg Rate(M/s)  Relative Rate
-      -------------------------------------------------------------------------
-      Without whole stage codegen       6725.52            31.18         1.00 X
-      With whole stage codegen          2233.05            93.91         3.01 X
+      Single Int Column Scan:            Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      Without whole stage codegen             7775.53            26.97         1.00 X
+      With whole stage codegen                 342.15           612.94        22.73 X
     */
     benchmark.run()
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index c54fc6ba2de3d8b44b02d4ea25bdbb6c23b3c4bd..300788c88ab2f6059efa0c161e363310c79a4da8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.functions.{avg, col, max}
 import org.apache.spark.sql.test.SharedSQLContext
 
 class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
@@ -35,4 +38,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       sortAnswers = false
     )
   }
+
+  test("Aggregate should be included in WholeStageCodegen") {
+    val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+    assert(df.collect() === Array(Row(9, 4.5)))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 4339f7260dcb9d732ac9bf8718ce39c1e018aceb..51285431a47ed03fa8354f0674593381a8d95080 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -71,7 +71,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
       expectedNumOfJobs: Int,
       expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
     val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
-    df.collect()
+    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+      df.collect()
+    }
    sparkContext.listenerBus.waitUntilEmpty(10000)
     val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
     assert(executionIds.size === 1)
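
Conceptually, the produce/consume templates added above collapse a no-grouping-key partial aggregation into a single pass: the aggregation buffer lives in local variables initialized from initialValues, the child's loop updates them through the bound update expressions, and exactly one row is emitted after the loop finishes. That is also why supportCodegen requires groupingExpressions.isEmpty and excludes Final/Complete modes. Below is a minimal Scala sketch of that shape for max(id) plus avg(id)'s partial state (sum and count); it is illustrative only, and the object and variable names are not part of the patch.

// Illustrative sketch, not code from this patch: the shape of the generated
// single-pass aggregation loop for max(id), avg(id) with no grouping keys.
object AggregateWithoutKeySketch {
  def main(args: Array[String]): Unit = {
    val input: Iterator[Long] = (0L until 10L).iterator

    // "initialize aggregation buffer" (from the aggregates' initialValues)
    var maxIsNull = true
    var maxValue = Long.MinValue
    var sumValue = 0.0
    var countValue = 0L

    // child's produce(): iterate the input; doConsume() updates the buffer vars
    while (input.hasNext) {
      val id = input.next()
      if (maxIsNull || id > maxValue) { maxIsNull = false; maxValue = id }
      sumValue += id
      countValue += 1
    }

    // "output the result": one row carrying the partial aggregation buffer
    println((maxValue, sumValue, countValue)) // partial state: max = 9, sum = 45.0, count = 10
  }
}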