diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index c4af284f73d16cb106dd6625926e398f4778977d..1c7720afe1ca354374943a7fa74a1c03eb76d956 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -307,54 +307,11 @@ object ScalaReflection extends ScalaReflection {
           }
         }
 
-        val array = Invoke(
-          MapObjects(mapFunction, getPath, dataType),
-          "array",
-          ObjectType(classOf[Array[Any]]))
-
-        val wrappedArray = StaticInvoke(
-          scala.collection.mutable.WrappedArray.getClass,
-          ObjectType(classOf[Seq[_]]),
-          "make",
-          array :: Nil)
-
-        if (localTypeOf[scala.collection.mutable.WrappedArray[_]] <:< t.erasure) {
-          wrappedArray
-        } else {
-          // Convert to another type using `to`
-          val cls = mirror.runtimeClass(t.typeSymbol.asClass)
-          import scala.collection.generic.CanBuildFrom
-          import scala.reflect.ClassTag
-
-          // Some canBuildFrom methods take an implicit ClassTag parameter
-          val cbfParams = try {
-            cls.getDeclaredMethod("canBuildFrom", classOf[ClassTag[_]])
-            StaticInvoke(
-              ClassTag.getClass,
-              ObjectType(classOf[ClassTag[_]]),
-              "apply",
-              StaticInvoke(
-                cls,
-                ObjectType(classOf[Class[_]]),
-                "getClass"
-              ) :: Nil
-            ) :: Nil
-          } catch {
-            case _: NoSuchMethodException => Nil
-          }
-
-          Invoke(
-            wrappedArray,
-            "to",
-            ObjectType(cls),
-            StaticInvoke(
-              cls,
-              ObjectType(classOf[CanBuildFrom[_, _, _]]),
-              "canBuildFrom",
-              cbfParams
-            ) :: Nil
-          )
+        val cls = t.dealias.companion.decl(TermName("newBuilder")) match {
+          case NoSymbol => classOf[Seq[_]]
+          case _ => mirror.runtimeClass(t.typeSymbol.asClass)
         }
+        MapObjects(mapFunction, getPath, dataType, Some(cls))
 
       case t if t <:< localTypeOf[Map[_, _]] =>
         // TODO: add walked type path for map
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 771ac28e5107aa762c346af7fb1ff3d149f7a760..bb584f7d087e8708b148d37e321972004f5055a7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions.objects
 
 import java.lang.reflect.Modifier
 
+import scala.collection.mutable.Builder
 import scala.language.existentials
 import scala.reflect.ClassTag
 
@@ -429,24 +430,34 @@ object MapObjects {
    * @param function The function applied on the collection elements.
    * @param inputData An expression that when evaluated returns a collection object.
    * @param elementType The data type of elements in the collection.
+   * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+   *                            or None (returning ArrayType)
    */
   def apply(
       function: Expression => Expression,
       inputData: Expression,
-      elementType: DataType): MapObjects = {
-    val loopValue = "MapObjects_loopValue" + curId.getAndIncrement()
-    val loopIsNull = "MapObjects_loopIsNull" + curId.getAndIncrement()
+      elementType: DataType,
+      customCollectionCls: Option[Class[_]] = None): MapObjects = {
+    val id = curId.getAndIncrement()
+    val loopValue = s"MapObjects_loopValue$id"
+    val loopIsNull = s"MapObjects_loopIsNull$id"
     val loopVar = LambdaVariable(loopValue, loopIsNull, elementType)
-    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData)
+    val builderValue = s"MapObjects_builderValue$id"
+    MapObjects(loopValue, loopIsNull, elementType, function(loopVar), inputData,
+      customCollectionCls, builderValue)
   }
 }
 
 /**
  * Applies the given expression to every element of a collection of items, returning the result
- * as an ArrayType. This is similar to a typical map operation, but where the lambda function
- * is expressed using catalyst expressions.
+ * as an ArrayType or ObjectType. This is similar to a typical map operation, but where the lambda
+ * function is expressed using catalyst expressions.
+ *
+ * The type of the result is determined as follows:
+ * - ArrayType - when customCollectionCls is None
+ * - ObjectType(collection) - when customCollectionCls contains a collection class
  *
- * The following collection ObjectTypes are currently supported:
+ * The following collection ObjectTypes are currently supported on input:
  *   Seq, Array, ArrayData, java.util.List
  *
  * @param loopValue the name of the loop variable that used when iterate the collection, and used
@@ -458,13 +469,19 @@ object MapObjects {
  * @param lambdaFunction A function that take the `loopVar` as input, and used as lambda function
  *                       to handle collection elements.
  * @param inputData An expression that when evaluated returns a collection object.
+ * @param customCollectionCls Class of the resulting collection (returning ObjectType)
+ *                            or None (returning ArrayType)
+ * @param builderValue The name of the builder variable used to construct the resulting collection
+ *                     (used only when returning ObjectType)
  */
 case class MapObjects private(
     loopValue: String,
     loopIsNull: String,
     loopVarDataType: DataType,
     lambdaFunction: Expression,
-    inputData: Expression) extends Expression with NonSQLExpression {
+    inputData: Expression,
+    customCollectionCls: Option[Class[_]],
+    builderValue: String) extends Expression with NonSQLExpression {
 
   override def nullable: Boolean = inputData.nullable
 
@@ -474,7 +491,8 @@ case class MapObjects private(
     throw new UnsupportedOperationException("Only code-generated evaluation is supported")
 
   override def dataType: DataType =
-    ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable)
+    customCollectionCls.map(ObjectType.apply).getOrElse(
+      ArrayType(lambdaFunction.dataType, containsNull = lambdaFunction.nullable))
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val elementJavaType = ctx.javaType(loopVarDataType)
@@ -557,15 +575,33 @@ case class MapObjects private(
       case _ => s"$loopIsNull = $loopValue == null;"
     }
 
+    val (initCollection, addElement, getResult): (String, String => String, String) =
+      customCollectionCls match {
+        case Some(cls) =>
+          // collection
+          val collObjectName = s"${cls.getName}$$.MODULE$$"
+          val getBuilderVar = s"$collObjectName.newBuilder()"
+
+          (s"""${classOf[Builder[_, _]].getName} $builderValue = $getBuilderVar;
+               $builderValue.sizeHint($dataLength);""",
+            genValue => s"$builderValue.$$plus$$eq($genValue);",
+            s"(${cls.getName}) $builderValue.result();")
+        case None =>
+          // array
+          (s"""$convertedType[] $convertedArray = null;
+               $convertedArray = $arrayConstructor;""",
+            genValue => s"$convertedArray[$loopIndex] = $genValue;",
+            s"new ${classOf[GenericArrayData].getName}($convertedArray);")
+      }
+
     val code = s"""
       ${genInputData.code}
       ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
 
       if (!${genInputData.isNull}) {
         $determineCollectionType
-        $convertedType[] $convertedArray = null;
         int $dataLength = $getLength;
-        $convertedArray = $arrayConstructor;
+        $initCollection
 
         int $loopIndex = 0;
         while ($loopIndex < $dataLength) {
@@ -574,15 +610,15 @@ case class MapObjects private(
 
           ${genFunction.code}
           if (${genFunction.isNull}) {
-            $convertedArray[$loopIndex] = null;
+            ${addElement("null")}
           } else {
-            $convertedArray[$loopIndex] = $genFunctionValue;
+            ${addElement(genFunctionValue)}
          }
 
           $loopIndex += 1;
         }
 
-        ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray);
+        ${ev.value} = $getResult
       }
     """
     ev.copy(code = code, isNull = genInputData.isNull)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 650a35398f3e8e7b03ca643e3479616e9d98eecd..70ad064f93ebc178863d32ca67388bb1ff9d8875 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -312,14 +312,6 @@ class ScalaReflectionSuite extends SparkFunSuite {
       ArrayType(IntegerType, containsNull = false))
     val arrayBufferDeserializer = deserializerFor[ArrayBuffer[Int]]
     assert(arrayBufferDeserializer.dataType == ObjectType(classOf[ArrayBuffer[_]]))
-
-    // Check whether conversion is skipped when using WrappedArray[_] supertype
-    // (would otherwise needlessly add overhead)
-    import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
-    val seqDeserializer = deserializerFor[Seq[Int]]
-    assert(seqDeserializer.asInstanceOf[StaticInvoke].staticObject ==
-      scala.collection.mutable.WrappedArray.getClass)
-    assert(seqDeserializer.asInstanceOf[StaticInvoke].functionName == "make")
   }
 
   private val dataTypeForComplexData = dataTypeFor[ComplexData]
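
Note on the ScalaReflection change: the deserializer now picks the target class by probing, via Scala reflection, whether the requested Seq subtype's companion object declares a newBuilder method. The probe can be exercised standalone; the following is a minimal sketch assuming Scala 2.11+ runtime reflection, with BuilderProbe and targetClass as illustrative names, not Spark identifiers:

    import scala.reflect.runtime.universe._

    object BuilderProbe {
      // Mirrors the match added in ScalaReflection.scala: if the companion
      // object of T declares `newBuilder`, use the concrete collection class
      // as MapObjects' customCollectionCls; otherwise fall back to Seq.
      def targetClass[T: TypeTag]: Class[_] = {
        val mirror = runtimeMirror(getClass.getClassLoader)
        val t = typeOf[T]
        t.dealias.companion.decl(TermName("newBuilder")) match {
          case NoSymbol => classOf[Seq[_]]
          case _ => mirror.runtimeClass(t.typeSymbol.asClass)
        }
      }
    }

For example, targetClass[List[Int]] should yield classOf[List[_]], since object List declares its own newBuilder; decl (unlike member) sees only declarations on the companion itself, not inherited ones.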
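Note on the codegen change: when customCollectionCls is set, the array-filling loop is replaced by a builder-based one. For intuition, the generated Java roughly follows the shape of this hand-written Scala analogue (a sketch only; mapToList and f are not Spark identifiers):

    import scala.collection.mutable.Builder

    // Analogue of the builder path the generated code takes: obtain the
    // companion's builder (the generated Java calls Cls$.MODULE$.newBuilder()),
    // sizeHint it with the input length, append each converted element
    // ($plus$eq in the generated Java), then call result() and cast.
    def mapToList[A, B](input: Seq[A], f: A => B): List[B] = {
      val builder: Builder[B, List[B]] = List.newBuilder[B]
      builder.sizeHint(input.length)
      var i = 0
      while (i < input.length) {
        builder += f(input(i))
        i += 1
      }
      builder.result()
    }

Building the target collection directly avoids the intermediate WrappedArray plus to(canBuildFrom) conversion that the removed ScalaReflection code had to emit.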