diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index a3b722a47d688400ce7edeac6de73d3b47e07bfd..743782a6453e9641113592b21a932b293715273a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -104,16 +104,48 @@ abstract class Expression extends TreeNode[Expression] {
     }.getOrElse {
       val isNull = ctx.freshName("isNull")
       val value = ctx.freshName("value")
-      val ve = doGenCode(ctx, ExprCode("", isNull, value))
-      if (ve.code.nonEmpty) {
+      val eval = doGenCode(ctx, ExprCode("", isNull, value))
+      reduceCodeSize(ctx, eval)
+      if (eval.code.nonEmpty) {
         // Add `this` in the comment.
-        ve.copy(code = s"${ctx.registerComment(this.toString)}\n" + ve.code.trim)
+        eval.copy(code = s"${ctx.registerComment(this.toString)}\n" + eval.code.trim)
       } else {
-        ve
+        eval
       }
     }
   }
 
+  private def reduceCodeSize(ctx: CodegenContext, eval: ExprCode): Unit = {
+    // TODO: support whole stage codegen too
+    if (eval.code.trim.length > 1024 && ctx.INPUT_ROW != null && ctx.currentVars == null) {
+      val setIsNull = if (eval.isNull != "false" && eval.isNull != "true") {
+        val globalIsNull = ctx.freshName("globalIsNull")
+        ctx.addMutableState(ctx.JAVA_BOOLEAN, globalIsNull)
+        val localIsNull = eval.isNull
+        eval.isNull = globalIsNull
+        s"$globalIsNull = $localIsNull;"
+      } else {
+        ""
+      }
+
+      val javaType = ctx.javaType(dataType)
+      val newValue = ctx.freshName("value")
+
+      val funcName = ctx.freshName(nodeName)
+      val funcFullName = ctx.addNewFunction(funcName,
+        s"""
+           |private $javaType $funcName(InternalRow ${ctx.INPUT_ROW}) {
+           |  ${eval.code.trim}
+           |  $setIsNull
+           |  return ${eval.value};
+           |}
+         """.stripMargin)
+
+      eval.value = newValue
+      eval.code = s"$javaType $newValue = $funcFullName(${ctx.INPUT_ROW});"
+    }
+  }
+
   /**
    * Returns Java source code that can be compiled to evaluate this expression.
    * The default behavior is to call the eval method of the expression. Concrete expression
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 78617194e47d5c755830dfb9f140347135150cfd..9df8a8d6f6609e70cd34d284c8f52376ae7a0aeb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -930,36 +930,6 @@ class CodegenContext {
     }
   }
 
-  /**
-   * Wrap the generated code of expression, which was created from a row object in INPUT_ROW,
-   * by a function. ev.isNull and ev.value are passed by global variables
-   *
-   * @param ev the code to evaluate expressions.
-   * @param dataType the data type of ev.value.
-   * @param baseFuncName the split function name base.
-   */
-  def createAndAddFunction(
-      ev: ExprCode,
-      dataType: DataType,
-      baseFuncName: String): (String, String, String) = {
-    val globalIsNull = freshName("isNull")
-    addMutableState(JAVA_BOOLEAN, globalIsNull, s"$globalIsNull = false;")
-    val globalValue = freshName("value")
-    addMutableState(javaType(dataType), globalValue,
-      s"$globalValue = ${defaultValue(dataType)};")
-    val funcName = freshName(baseFuncName)
-    val funcBody =
-      s"""
-         |private void $funcName(InternalRow ${INPUT_ROW}) {
-         |  ${ev.code.trim}
-         |  $globalIsNull = ${ev.isNull};
-         |  $globalValue = ${ev.value};
-         |}
-       """.stripMargin
-    val fullFuncName = addNewFunction(funcName, funcBody)
-    (fullFuncName, globalIsNull, globalValue)
-  }
-
   /**
    * Perform a function which generates a sequence of ExprCodes with a given mapping between
    * expressions and common expressions, instead of using the mapping in current context.
@@ -1065,7 +1035,8 @@ class CodegenContext {
    * elimination will be performed. Subexpression elimination assumes that the code for each
    * expression will be combined in the `expressions` order.
    */
-  def generateExpressions(expressions: Seq[Expression],
+  def generateExpressions(
+      expressions: Seq[Expression],
       doSubexpressionElimination: Boolean = false): Seq[ExprCode] = {
     if (doSubexpressionElimination) subexpressionElimination(expressions)
     expressions.map(e => e.genCode(this))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index c41a10c7b0f875e57d59c3d6deca11248269f723..6195be3a258c4344b962e7642564949892b622ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -64,52 +64,22 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
     val trueEval = trueValue.genCode(ctx)
     val falseEval = falseValue.genCode(ctx)
 
-    // place generated code of condition, true value and false value in separate methods if
-    // their code combined is large
-    val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length
-    val generatedCode = if (combinedLength > 1024 &&
-      // Split these expressions only if they are created from a row object
-      (ctx.INPUT_ROW != null && ctx.currentVars == null)) {
-
-      val (condFuncName, condGlobalIsNull, condGlobalValue) =
-        ctx.createAndAddFunction(condEval, predicate.dataType, "evalIfCondExpr")
-      val (trueFuncName, trueGlobalIsNull, trueGlobalValue) =
-        ctx.createAndAddFunction(trueEval, trueValue.dataType, "evalIfTrueExpr")
-      val (falseFuncName, falseGlobalIsNull, falseGlobalValue) =
-        ctx.createAndAddFunction(falseEval, falseValue.dataType, "evalIfFalseExpr")
+    val code =
       s"""
-        $condFuncName(${ctx.INPUT_ROW});
-        boolean ${ev.isNull} = false;
-        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        if (!$condGlobalIsNull && $condGlobalValue) {
-          $trueFuncName(${ctx.INPUT_ROW});
-          ${ev.isNull} = $trueGlobalIsNull;
-          ${ev.value} = $trueGlobalValue;
-        } else {
-          $falseFuncName(${ctx.INPUT_ROW});
-          ${ev.isNull} = $falseGlobalIsNull;
-          ${ev.value} = $falseGlobalValue;
-        }
-      """
-    }
-    else {
-      s"""
-        ${condEval.code}
-        boolean ${ev.isNull} = false;
-        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
-        if (!${condEval.isNull} && ${condEval.value}) {
-          ${trueEval.code}
-          ${ev.isNull} = ${trueEval.isNull};
-          ${ev.value} = ${trueEval.value};
-        } else {
-          ${falseEval.code}
-          ${ev.isNull} = ${falseEval.isNull};
-          ${ev.value} = ${falseEval.value};
-        }
-      """
-    }
-
-    ev.copy(code = generatedCode)
+         |${condEval.code}
+         |boolean ${ev.isNull} = false;
+         |${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+         |if (!${condEval.isNull} && ${condEval.value}) {
+         |  ${trueEval.code}
+         |  ${ev.isNull} = ${trueEval.isNull};
+         |  ${ev.value} = ${trueEval.value};
+         |} else {
+         |  ${falseEval.code}
+         |  ${ev.isNull} = ${falseEval.isNull};
+         |  ${ev.value} = ${falseEval.value};
+         |}
+       """.stripMargin
+    ev.copy(code = code)
   }
 
   override def toString: String = s"if ($predicate) $trueValue else $falseValue"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index e518e73cba5493b5b3ecccfc7378256007e84a9e..8df870468c2ad946e8f8777bc500db2f508fc35b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -140,7 +140,9 @@ case class Alias(child: Expression, name: String)(
   /** Just a simple passthrough for code generation. */
   override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
 
-  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("")
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    throw new IllegalStateException("Alias.doGenCode should not be called.")
+  }
 
   override def dataType: DataType = child.dataType
   override def nullable: Boolean = child.nullable
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index c0084af320689baabb37911fcc78f0cc86c66616..eb7475354b104efdbf077e08bbd42a4e086e4041 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -378,46 +378,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with
     val eval2 = right.genCode(ctx)
 
     // The result should be `false`, if any of them is `false` whenever the other is null or not.
-
-    // place generated code of eval1 and eval2 in separate methods if their code combined is large
-    val combinedLength = eval1.code.length + eval2.code.length
-    if (combinedLength > 1024 &&
-      // Split these expressions only if they are created from a row object
-      (ctx.INPUT_ROW != null && ctx.currentVars == null)) {
-
-      val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) =
-        ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr")
-      val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) =
-        ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr")
-      if (!left.nullable && !right.nullable) {
-        val generatedCode = s"""
-          $eval1FuncName(${ctx.INPUT_ROW});
-          boolean ${ev.value} = false;
-          if (${eval1GlobalValue}) {
-            $eval2FuncName(${ctx.INPUT_ROW});
-            ${ev.value} = ${eval2GlobalValue};
-          }
-        """
-        ev.copy(code = generatedCode, isNull = "false")
-      } else {
-        val generatedCode = s"""
-          $eval1FuncName(${ctx.INPUT_ROW});
-          boolean ${ev.isNull} = false;
-          boolean ${ev.value} = false;
-          if (!${eval1GlobalIsNull} && !${eval1GlobalValue}) {
-          } else {
-            $eval2FuncName(${ctx.INPUT_ROW});
-            if (!${eval2GlobalIsNull} && !${eval2GlobalValue}) {
-            } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) {
-              ${ev.value} = true;
-            } else {
-              ${ev.isNull} = true;
-            }
-          }
-        """
-        ev.copy(code = generatedCode)
-      }
-    } else if (!left.nullable && !right.nullable) {
+    if (!left.nullable && !right.nullable) {
       ev.copy(code = s"""
         ${eval1.code}
         boolean ${ev.value} = false;
@@ -480,46 +441,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P
     val eval2 = right.genCode(ctx)
 
     // The result should be `true`, if any of them is `true` whenever the other is null or not.
-
-    // place generated code of eval1 and eval2 in separate methods if their code combined is large
-    val combinedLength = eval1.code.length + eval2.code.length
-    if (combinedLength > 1024 &&
-      // Split these expressions only if they are created from a row object
-      (ctx.INPUT_ROW != null && ctx.currentVars == null)) {
-
-      val (eval1FuncName, eval1GlobalIsNull, eval1GlobalValue) =
-        ctx.createAndAddFunction(eval1, BooleanType, "eval1Expr")
-      val (eval2FuncName, eval2GlobalIsNull, eval2GlobalValue) =
-        ctx.createAndAddFunction(eval2, BooleanType, "eval2Expr")
-      if (!left.nullable && !right.nullable) {
-        val generatedCode = s"""
-          $eval1FuncName(${ctx.INPUT_ROW});
-          boolean ${ev.value} = true;
-          if (!${eval1GlobalValue}) {
-            $eval2FuncName(${ctx.INPUT_ROW});
-            ${ev.value} = ${eval2GlobalValue};
-          }
-        """
-        ev.copy(code = generatedCode, isNull = "false")
-      } else {
-        val generatedCode = s"""
-          $eval1FuncName(${ctx.INPUT_ROW});
-          boolean ${ev.isNull} = false;
-          boolean ${ev.value} = true;
-          if (!${eval1GlobalIsNull} && ${eval1GlobalValue}) {
-          } else {
-            $eval2FuncName(${ctx.INPUT_ROW});
-            if (!${eval2GlobalIsNull} && ${eval2GlobalValue}) {
-            } else if (!${eval1GlobalIsNull} && !${eval2GlobalIsNull}) {
-              ${ev.value} = false;
-            } else {
-              ${ev.isNull} = true;
-            }
-          }
-        """
-        ev.copy(code = generatedCode)
-      }
-    } else if (!left.nullable && !right.nullable) {
+    if (!left.nullable && !right.nullable) {
       ev.isNull = "false"
      ev.copy(code = s"""
        ${eval1.code}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 8f6289f00571cc335c014c607bf621ab85cbcdad..6e33087b4c6c84b7941553bab409ef6a65a409cd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -97,7 +97,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     assert(actual(0) == cases)
   }
 
-  test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") {
+  test("SPARK-22543: split large if expressions into blocks due to JVM code size limit") {
     var strExpr: Expression = Literal("abc")
     for (_ <- 1 to 150) {
       strExpr = Decode(Encode(strExpr, "utf-8"), "utf-8")
@@ -342,7 +342,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
     projection(row)
   }
 
-  test("SPARK-21720: split large predications into blocks due to JVM code size limit") {
+  test("SPARK-22543: split large predicates into blocks due to JVM code size limit") {
     val length = 600
 
     val input = new GenericInternalRow(length)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index 19c793e45a57d0ee819461581aa0f2144c2b65b0..dc8aecf185a96e3b2d3929d85da16f9b27b24a2c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -179,6 +179,8 @@ case class HashAggregateExec(
   private def doProduceWithoutKeys(ctx: CodegenContext): String = {
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState(ctx.JAVA_BOOLEAN, initAgg, s"$initAgg = false;")
+    // The generated function doesn't have input row in the code context.
+    ctx.INPUT_ROW = null
 
     // generate variables for aggregation buffer
     val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
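
For context, and not part of the patch itself: a minimal Scala sketch of how the new `Expression.reduceCodeSize` path is exercised end to end, mirroring the updated `CodeGenerationSuite` tests above. The object name `ReduceCodeSizeSketch` is hypothetical; the Catalyst APIs used (`Literal`, `Encode`/`Decode`, `If`, `GenerateMutableProjection`) are existing ones. Whenever an expression's generated code exceeds 1024 characters and is produced against an input row (non-whole-stage codegen, per the TODO above), `genCode` now moves it into its own private method, keeping each generated method small.

```scala
import org.apache.spark.sql.catalyst.expressions.{Decode, Encode, EqualTo, Expression, If, Literal}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection

// Hypothetical driver object; the expression shape mirrors the existing test.
object ReduceCodeSizeSketch {
  def main(args: Array[String]): Unit = {
    // Build a deeply nested expression whose generated Java code is far larger than 1024 chars.
    var strExpr: Expression = Literal("abc")
    for (_ <- 1 to 150) {
      strExpr = Decode(Encode(strExpr, Literal("utf-8")), Literal("utf-8"))
    }

    // Before this patch the splitting was done ad hoc inside If/And/Or; with the patch,
    // Expression.genCode wraps every oversized sub-expression in its own method.
    val projection = GenerateMutableProjection.generate(
      Seq(If(EqualTo(strExpr, strExpr), strExpr, strExpr)))

    // No input columns are referenced, so evaluating against a null row is fine here.
    val result = projection(null)
    assert(result.getUTF8String(0).toString == "abc")
  }
}
```

Without any splitting, an expression tree like this would be compiled into a single generated method past the JVM's 64KB bytecode limit and fail to compile, which is what the original SPARK-18091/SPARK-21720 workarounds and this more general fix are guarding against.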