diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 437397187356ccb9c067865f6ccafa1bcd82ac28..f10d36862770785b9f20c70b495b81da9e25cdba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -242,6 +242,9 @@ class CodegenContext {
   private val classFunctions: mutable.Map[String, mutable.Map[String, String]] =
     mutable.Map(outerClassName -> mutable.Map.empty[String, String])
 
+  // Verbatim extra code to be added to the OuterClass.
+  private val extraCode: mutable.ListBuffer[String] = mutable.ListBuffer[String]()
+
   // Returns the size of the most recently added class.
   private def currClassSize(): Int = classSize(classes.head._1)
 
@@ -328,6 +331,22 @@ class CodegenContext {
     (inlinedFunctions ++ initNestedClasses ++ declareNestedClasses).mkString("\n")
   }
 
+  /**
+   * Emits any source code added with addExtraCode.
+   */
+  def emitExtraCode(): String = {
+    extraCode.mkString("\n")
+  }
+
+  /**
+   * Add extra source code to the outermost generated class.
+   * @param code verbatim source code to be added.
+   */
+  def addExtraCode(code: String): Unit = {
+    extraCode.append(code)
+    classSize(outerClassName) += code.length
+  }
+
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
index a41a7ca56a0a1ae1a7fa7a8223e22a1cf0158467..268ccfa4edfa052a64825880f1e81406c61c6544 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -197,11 +197,14 @@ trait CodegenSupport extends SparkPlan {
    *
    * This should be override by subclass to support codegen.
    *
-   * For example, Filter will generate the code like this:
+   * Note: The operator should not assume the existence of an outer processing loop
+   * that it can jump out of with "continue;".
    *
+   * For example, filter could generate this:
    *   # code to evaluate the predicate expression, result is isNull1 and value2
-   *   if (isNull1 || !value2) continue;
-   *   # call consume(), which will call parent.doConsume()
+   *   if (!isNull1 && value2) {
+   *     # call consume(), which will call parent.doConsume()
+   *   }
    *
    * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
    */
@@ -329,6 +332,15 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co
   def doCodeGen(): (CodegenContext, CodeAndComment) = {
     val ctx = new CodegenContext
     val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+
+    // The main processNext() function.
+ ctx.addNewFunction("processNext", + s""" + protected void processNext() throws java.io.IOException { + ${code.trim} + } + """, inlineToOuterClass = true) + val source = s""" public Object generate(Object[] references) { return new GeneratedIterator(references); @@ -352,9 +364,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with Co ${ctx.initPartition()} } - protected void processNext() throws java.io.IOException { - ${code.trim} - } + ${ctx.emitExtraCode()} ${ctx.declareAddedFunctions()} } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index d77405c559c584bc60319e4eefd77f461f87e3d3..abdf9530c6c7b52f966b6a8c6370bb8a8b27f4dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -425,12 +425,14 @@ case class HashAggregateExec( /** * Generate the code for output. + * @return function name for the result code. */ - private def generateResultCode( - ctx: CodegenContext, - keyTerm: String, - bufferTerm: String, - plan: String): String = { + private def generateResultFunction(ctx: CodegenContext): String = { + val funcName = ctx.freshName("doAggregateWithKeysOutput") + val keyTerm = ctx.freshName("keyTerm") + val bufferTerm = ctx.freshName("bufferTerm") + + val body = if (modes.contains(Final) || modes.contains(Complete)) { // generate output using resultExpressions ctx.currentVars = null @@ -462,18 +464,36 @@ case class HashAggregateExec( $evaluateAggResults ${consume(ctx, resultVars)} """ - } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { - // This should be the last operator in a stage, we should output UnsafeRow directly - val joinerTerm = ctx.freshName("unsafeRowJoiner") - ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, - s"$joinerTerm = $plan.createUnsafeJoiner();") - val resultRow = ctx.freshName("resultRow") + // resultExpressions are Attributes of groupingExpressions and aggregateBufferAttributes. 
+      assert(resultExpressions.forall(_.isInstanceOf[Attribute]))
+      assert(resultExpressions.length ==
+        groupingExpressions.length + aggregateBufferAttributes.length)
+
+      ctx.currentVars = null
+
+      ctx.INPUT_ROW = keyTerm
+      val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateKeyVars = evaluateVariables(keyVars)
+
+      ctx.INPUT_ROW = bufferTerm
+      val resultBufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateResultBufferVars = evaluateVariables(resultBufferVars)
+
+      ctx.currentVars = keyVars ++ resultBufferVars
+      val inputAttrs = resultExpressions.map(_.toAttribute)
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
+      }
       s"""
-       UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
-       ${consume(ctx, null, resultRow)}
+       $evaluateKeyVars
+       $evaluateResultBufferVars
+       ${consume(ctx, resultVars)}
       """
     } else {
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
@@ -483,6 +503,13 @@
       }
       consume(ctx, eval)
     }
+    ctx.addNewFunction(funcName,
+      s"""
+        private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
+            throws java.io.IOException {
+          $body
+        }
+      """)
   }
 
   /**
@@ -581,11 +608,6 @@ case class HashAggregateExec(
     val iterTerm = ctx.freshName("mapIter")
     ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
 
-    val doAgg = ctx.freshName("doAggregateWithKeys")
-    val peakMemory = metricTerm(ctx, "peakMemory")
-    val spillSize = metricTerm(ctx, "spillSize")
-    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
-
     def generateGenerateCode(): String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
@@ -599,10 +621,14 @@ case class HashAggregateExec(
         }
       } else ""
     }
+    ctx.addExtraCode(generateGenerateCode())
 
+    val doAgg = ctx.freshName("doAggregateWithKeys")
+    val peakMemory = metricTerm(ctx, "peakMemory")
+    val spillSize = metricTerm(ctx, "spillSize")
+    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
     val doAggFuncName = ctx.addNewFunction(doAgg,
       s"""
-        ${generateGenerateCode}
         private void $doAgg() throws java.io.IOException {
           $hashMapTerm = $thisPlan.createHashMap();
           ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
@@ -618,7 +644,7 @@
     // generate code for output
     val keyTerm = ctx.freshName("aggKey")
     val bufferTerm = ctx.freshName("aggBuffer")
-    val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
+    val outputFunc = generateResultFunction(ctx)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     // The child could change `copyResult` to true, but we had already consumed all the rows,
@@ -641,7 +667,7 @@
         $numOutput.add(1);
         UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
         UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
-        $outputCode
+        $outputFunc($keyTerm, $bufferTerm);
 
         if (shouldStop()) return;
       }
@@ -654,18 +680,23 @@
       val row = ctx.freshName("fastHashMapRow")
       ctx.currentVars = null
       ctx.INPUT_ROW = row
-      var schema: StructType = groupingKeySchema
-      bufferSchema.foreach(i => schema = schema.add(i))
-      val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex
-        .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) })
+      val generateKeyRow =
+        GenerateUnsafeProjection.createCode(ctx,
+          groupingKeySchema.toAttributes.zipWithIndex
+            .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }
+        )
+      val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
+        bufferSchema.toAttributes.zipWithIndex
+          .map { case (attr, i) =>
+            BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) })
       s"""
          | while ($iterTermForFastHashMap.hasNext()) {
          |   $numOutput.add(1);
          |   org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
          |     (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
          |     $iterTermForFastHashMap.next();
-         |   ${generateRow.code}
-         |   ${consume(ctx, Seq.empty, {generateRow.value})}
+         |   ${generateKeyRow.code}
+         |   ${generateBufferRow.code}
+         |   $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
          |
          |   if (shouldStop()) return;
          | }
@@ -692,7 +723,7 @@
         $numOutput.add(1);
         UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
         UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
-        $outputCode
+        $outputFunc($keyTerm, $bufferTerm);
 
         if (shouldStop()) return;
       }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index e4e9372447f7cde86b2f9d472e04f1afda51a346..18142c44f0295a5af801447174d1498b99602067 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -201,11 +201,14 @@ case class FilterExec(condition: Expression, child: SparkPlan)
       ev
     }
 
+    // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
     s"""
-       |$generated
-       |$nullChecks
-       |$numOutput.add(1);
-       |${consume(ctx, resultVars)}
+       |do {
+       |  $generated
+       |  $nullChecks
+       |  $numOutput.add(1);
+       |  ${consume(ctx, resultVars)}
+       |} while(false);
     """.stripMargin
   }
 
@@ -316,9 +319,10 @@ case class SampleExec(
         """.stripMargin.trim)
 
       s"""
         | if ($sampler.sample() != 0) {
         |   $numOutput.add(1);
         |   ${consume(ctx, input)}
         | }
       """.stripMargin.trim
     }
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
index ab7bb8ab9d87ad7a33c483c1e9dda68712484c6b..b09da9bdacb99b2d7d7ec51a3781cb200bfe27da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoinExec.scala
@@ -186,8 +186,7 @@ case class BroadcastHashJoinExec(
    */
   private def getJoinCondition(
       ctx: CodegenContext,
-      input: Seq[ExprCode],
-      anti: Boolean = false): (String, String, Seq[ExprCode]) = {
+      input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = {
     val matched = ctx.freshName("matched")
     val buildVars = genBuildSideVars(ctx, matched)
     val checkCondition = if (condition.isDefined) {
@@ -198,18 +197,12 @@ case class BroadcastHashJoinExec(
       ctx.currentVars = input ++ buildVars
       val ev =
         BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).genCode(ctx)
-      val skipRow = if (!anti) {
-        s"${ev.isNull} || !${ev.value}"
-      } else {
-        s"!${ev.isNull} && ${ev.value}"
-      }
+      val skipRow = s"${ev.isNull} || !${ev.value}"
       s"""
         |$eval
        |${ev.code}
-        |if ($skipRow) continue;
+        |if (!($skipRow))
""".stripMargin - } else if (anti) { - "continue;" } else { "" } @@ -235,10 +228,12 @@ case class BroadcastHashJoinExec( |${keyEv.code} |// find matches from HashedRelation |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); - |if ($matched == null) continue; - |$checkCondition - |$numOutput.add(1); - |${consume(ctx, resultVars)} + |if ($matched != null) { + | $checkCondition { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + |} """.stripMargin } else { @@ -250,12 +245,14 @@ case class BroadcastHashJoinExec( |${keyEv.code} |// find matches from HashRelation |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); - |if ($matches == null) continue; - |while ($matches.hasNext()) { - | UnsafeRow $matched = (UnsafeRow) $matches.next(); - | $checkCondition - | $numOutput.add(1); - | ${consume(ctx, resultVars)} + |if ($matches != null) { + | while ($matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition { + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } + | } |} """.stripMargin } @@ -328,10 +325,11 @@ case class BroadcastHashJoinExec( | UnsafeRow $matched = $matches != null && $matches.hasNext() ? | (UnsafeRow) $matches.next() : null; | ${checkCondition.trim} - | if (!$conditionPassed) continue; - | $found = true; - | $numOutput.add(1); - | ${consume(ctx, resultVars)} + | if ($conditionPassed) { + | $found = true; + | $numOutput.add(1); + | ${consume(ctx, resultVars)} + | } |} """.stripMargin } @@ -351,10 +349,12 @@ case class BroadcastHashJoinExec( |${keyEv.code} |// find matches from HashedRelation |UnsafeRow $matched = $anyNull ? null: (UnsafeRow)$relationTerm.getValue(${keyEv.value}); - |if ($matched == null) continue; - |$checkCondition - |$numOutput.add(1); - |${consume(ctx, input)} + |if ($matched != null) { + | $checkCondition { + | $numOutput.add(1); + | ${consume(ctx, input)} + | } + |} """.stripMargin } else { val matches = ctx.freshName("matches") @@ -365,16 +365,19 @@ case class BroadcastHashJoinExec( |${keyEv.code} |// find matches from HashRelation |$iteratorCls $matches = $anyNull ? null : ($iteratorCls)$relationTerm.get(${keyEv.value}); - |if ($matches == null) continue; - |boolean $found = false; - |while (!$found && $matches.hasNext()) { - | UnsafeRow $matched = (UnsafeRow) $matches.next(); - | $checkCondition - | $found = true; + |if ($matches != null) { + | boolean $found = false; + | while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition { + | $found = true; + | } + | } + | if ($found) { + | $numOutput.add(1); + | ${consume(ctx, input)} + | } |} - |if (!$found) continue; - |$numOutput.add(1); - |${consume(ctx, input)} """.stripMargin } } @@ -386,11 +389,13 @@ case class BroadcastHashJoinExec( val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val uniqueKeyCodePath = broadcastRelation.value.keyIsUnique val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val (matched, checkCondition, _) = getJoinCondition(ctx, input, uniqueKeyCodePath) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) val numOutput = metricTerm(ctx, "numOutputRows") if (uniqueKeyCodePath) { + val found = ctx.freshName("found") s""" + |boolean $found = false; |// generate join key for stream side |${keyEv.code} |// Check if the key has nulls. 
@@ -399,17 +404,22 @@
        |  UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value});
        |  if ($matched != null) {
        |    // Evaluate the condition.
-       |    $checkCondition
+       |    $checkCondition {
+       |      $found = true;
+       |    }
        |  }
        |}
-       |$numOutput.add(1);
-       |${consume(ctx, input)}
+       |if (!$found) {
+       |  $numOutput.add(1);
+       |  ${consume(ctx, input)}
+       |}
      """.stripMargin
     } else {
       val matches = ctx.freshName("matches")
       val iteratorCls = classOf[Iterator[UnsafeRow]].getName
       val found = ctx.freshName("found")
      s"""
+       |boolean $found = false;
       |// generate join key for stream side
       |${keyEv.code}
       |// Check if the key has nulls.
@@ -418,17 +428,18 @@
       |  $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value});
       |  if ($matches != null) {
       |    // Evaluate the condition.
-      |    boolean $found = false;
       |    while (!$found && $matches.hasNext()) {
       |      UnsafeRow $matched = (UnsafeRow) $matches.next();
-      |      $checkCondition
-      |      $found = true;
+      |      $checkCondition {
+      |        $found = true;
+      |      }
       |    }
-      |    if ($found) continue;
       |  }
       |}
-      |$numOutput.add(1);
-      |${consume(ctx, input)}
+      |if (!$found) {
+      |  $numOutput.add(1);
+      |  ${consume(ctx, input)}
+      |}
      """.stripMargin
     }
   }
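
To make the new doConsume() contract concrete, here is a sketch of the shape FilterExec now emits (the fragment values are illustrative placeholders, not code from this patch). Downstream consume() code may still contain "continue;", so the filter provides a bounded local jump target instead of assuming an outer row-processing loop:

    // Placeholder strings standing in for the codegen'd fragments; in the
    // patch these come from the bound predicate expressions.
    val generated  = "boolean isNull1 = ...; boolean value2 = ...;"
    val nullChecks = "if (isNull1) continue;"  // downstream code may still emit "continue;"

    // The do { } while(false) gives those statements a legal, bounded target:
    // "continue;" now means "skip this row", not "jump to an assumed outer loop".
    val filterSource =
      s"""
         |do {
         |  $generated
         |  $nullChecks
         |  numOutputRows.add(1);
         |  // consume(ctx, resultVars) splices the parent's code here
         |} while (false);
       """.stripMargin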
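The class-level code path works the same way. A minimal sketch of the new CodegenContext API, with a made-up class body standing in for the generated fast hash map class that HashAggregateExec now routes through addExtraCode instead of pasting into the doAggregateWithKeys body:

    import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

    val ctx = new CodegenContext

    // Register verbatim class-level code once; classSize accounting is updated too.
    ctx.addExtraCode(
      """
        |public class agg_FastHashMap {
        |  // ... generated key/value storage and lookup methods ...
        |}
      """.stripMargin)

    // WholeStageCodegenExec later splices this at class scope through
    // ${ctx.emitExtraCode()} in the GeneratedIterator template.
    val classLevelCode = ctx.emitExtraCode()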
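Finally, the join-condition pattern: getJoinCondition now returns checkCondition as an unterminated "if (!(...))" header (or an empty string when there is no condition), and each call site completes it with a braced block, written "$checkCondition { ... }" in the templates above. A sketch with hypothetical fragment values:

    // Header only, no body or semicolon, as produced when a condition is defined.
    val checkCondition = "if (!(isNull1 || !value2))"

    // A call site supplies the block, so a failing condition simply falls
    // through the if; when checkCondition is "", the braces degenerate to a
    // harmless bare Java block.
    val joinLoop =
      s"""
         |while (matches.hasNext()) {
         |  UnsafeRow matched = (UnsafeRow) matches.next();
         |  $checkCondition {
         |    numOutputRows.add(1);
         |    // consume(...) output goes here
         |  }
         |}
       """.stripMargin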