diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala index fc3ecf5451426de6a36ecb2ccd0b13dc396fcfa5..71f8ea09f0770e924a3d9847e5492d1e5475ccc6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala @@ -116,154 +116,6 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro throw new UnsupportedOperationException(s"Not supported DataType: $fieldType") } - /** - * Generates the code to create an [[UnsafeRow]] object based on the input expressions. - * @param ctx context for code generation - * @param ev specifies the name of the variable for the output [[UnsafeRow]] object - * @param expressions input expressions - * @return generated code to put the expression output into an [[UnsafeRow]] - */ - def createCode(ctx: CodeGenContext, ev: GeneratedExpressionCode, expressions: Seq[Expression]) - : String = { - - val ret = ev.primitive - ctx.addMutableState("UnsafeRow", ret, s"$ret = new UnsafeRow();") - val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val cursor = ctx.freshName("cursor") - val numBytes = ctx.freshName("numBytes") - - val exprs = expressions.map { e => e.dataType match { - case st: StructType => createCodeForStruct(ctx, e.gen(ctx), st) - case _ => e.gen(ctx) - }} - val allExprs = exprs.map(_.code).mkString("\n") - - val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = expressions.zipWithIndex.map { - case (e, i) => genAdditionalSize(e.dataType, exprs(i)) - }.mkString("") - - val writers = expressions.zipWithIndex.map { case (e, i) => - val update = genFieldWriter(ctx, e.dataType, exprs(i), ret, i, cursor) - s"""if (${exprs(i).isNull}) { - $ret.setNullAt($i); - } else { - $update; - }""" - }.mkString("\n ") - - s""" - $allExprs - int $numBytes = $fixedSize $additionalSize; - if ($numBytes > $buffer.length) { - $buffer = new byte[$numBytes]; - } - - $ret.pointTo( - $buffer, - $PlatformDependent.BYTE_ARRAY_OFFSET, - ${expressions.size}, - $numBytes); - int $cursor = $fixedSize; - - $writers - boolean ${ev.isNull} = false; - """ - } - - /** - * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. - * - * This function also handles nested structs by recursively generating the code to do conversion. - * - * @param ctx code generation context - * @param input the input struct, identified by a [[GeneratedExpressionCode]] - * @param schema schema of the struct field - */ - // TODO: refactor createCode and this function to reduce code duplication. - private def createCodeForStruct( - ctx: CodeGenContext, - input: GeneratedExpressionCode, - schema: StructType): GeneratedExpressionCode = { - - val isNull = input.isNull - val primitive = ctx.freshName("structConvert") - ctx.addMutableState("UnsafeRow", primitive, s"$primitive = new UnsafeRow();") - val buffer = ctx.freshName("buffer") - ctx.addMutableState("byte[]", buffer, s"$buffer = new byte[64];") - val cursor = ctx.freshName("cursor") - - val exprs: Seq[GeneratedExpressionCode] = schema.map(_.dataType).zipWithIndex.map { - case (dt, i) => dt match { - case st: StructType => - val nestedStructEv = GeneratedExpressionCode( - code = "", - isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" - ) - createCodeForStruct(ctx, nestedStructEv, st) - case _ => - GeneratedExpressionCode( - code = "", - isNull = s"${input.primitive}.isNullAt($i)", - primitive = s"${ctx.getValue(input.primitive, dt, i.toString)}" - ) - } - } - val allExprs = exprs.map(_.code).mkString("\n") - - val fixedSize = 8 * exprs.length + UnsafeRow.calculateBitSetWidthInBytes(exprs.length) - val additionalSize = schema.toSeq.map(_.dataType).zip(exprs).map { case (dt, ev) => - genAdditionalSize(dt, ev) - }.mkString("") - - val writers = schema.toSeq.map(_.dataType).zip(exprs).zipWithIndex.map { case ((dt, ev), i) => - val update = genFieldWriter(ctx, dt, ev, primitive, i, cursor) - s""" - if (${exprs(i).isNull}) { - $primitive.setNullAt($i); - } else { - $update; - } - """ - }.mkString("\n ") - - // Note that we add a shortcut here for performance: if the input is already an UnsafeRow, - // just copy the bytes directly into our buffer space without running any conversion. - // We also had to use a hack to introduce a "tmp" variable, to avoid the Java compiler from - // complaining that a GenericMutableRow (generated by expressions) cannot be cast to UnsafeRow. - val tmp = ctx.freshName("tmp") - val numBytes = ctx.freshName("numBytes") - val code = s""" - |${input.code} - |if (!${input.isNull}) { - | Object $tmp = (Object) ${input.primitive}; - | if ($tmp instanceof UnsafeRow) { - | $primitive = (UnsafeRow) $tmp; - | } else { - | $allExprs - | - | int $numBytes = $fixedSize $additionalSize; - | if ($numBytes > $buffer.length) { - | $buffer = new byte[$numBytes]; - | } - | - | $primitive.pointTo( - | $buffer, - | $PlatformDependent.BYTE_ARRAY_OFFSET, - | ${exprs.size}, - | $numBytes); - | int $cursor = $fixedSize; - | - | $writers - | } - |} - """.stripMargin - - GeneratedExpressionCode(code, isNull, primitive) - } - /** * Generates the Java code to convert a struct (backed by InternalRow) to UnsafeRow. * @@ -271,7 +123,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro * @param inputs could be the codes for expressions or input struct fields. * @param inputTypes types of the inputs */ - private def createCodeForStruct2( + private def createCodeForStruct( ctx: CodeGenContext, inputs: Seq[GeneratedExpressionCode], inputTypes: Seq[DataType]): GeneratedExpressionCode = { @@ -537,7 +389,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro val fieldIsNull = s"$tmp.isNullAt($i)" GeneratedExpressionCode("", fieldIsNull, getFieldCode) } - val converter = createCodeForStruct2(ctx, fieldEvals, fieldTypes) + val converter = createCodeForStruct(ctx, fieldEvals, fieldTypes) val code = s""" ${input.code} UnsafeRow $output = null; @@ -561,6 +413,12 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro case _ => input } + def createCode(ctx: CodeGenContext, expressions: Seq[Expression]): GeneratedExpressionCode = { + val exprEvals = expressions.map(e => e.gen(ctx)) + val exprTypes = expressions.map(_.dataType) + createCodeForStruct(ctx, exprEvals, exprTypes) + } + protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -570,8 +428,7 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro protected def create(expressions: Seq[Expression]): UnsafeProjection = { val ctx = newCodeGenContext() - val exprEvals = expressions.map(e => e.gen(ctx)) - val eval = createCodeForStruct2(ctx, exprEvals, expressions.map(_.dataType)) + val eval = createCode(ctx, expressions) val code = s""" public Object generate($exprType[] exprs) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a145dfb4bbf082b89f9d54c9cf0e4ca7f1ec8436..4a071e663e0d1f6d264b819d328b6cdbc9076341 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -211,7 +211,10 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - GenerateUnsafeProjection.createCode(ctx, ev, children) + val eval = GenerateUnsafeProjection.createCode(ctx, children) + ev.isNull = eval.isNull + ev.primitive = eval.primitive + eval.code } override def prettyName: String = "struct_unsafe" @@ -246,7 +249,10 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression override def eval(input: InternalRow): Any = throw new UnsupportedOperationException override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - GenerateUnsafeProjection.createCode(ctx, ev, valExprs) + val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) + ev.isNull = eval.isNull + ev.primitive = eval.primitive + eval.code } override def prettyName: String = "named_struct_unsafe"