diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 36989a20cb5d5e71e0ab8a7b60ecbf01d6fb4ccc..bdd40f340235b17821f5b74e5c23823f951db0bc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.objects._ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** @@ -72,6 +72,7 @@ object ScalaReflection extends ScalaReflection { case t if t <:< definitions.ByteTpe => ByteType case t if t <:< definitions.BooleanTpe => BooleanType case t if t <:< localTypeOf[Array[Byte]] => BinaryType + case t if t <:< localTypeOf[CalendarInterval] => CalendarIntervalType case t if t <:< localTypeOf[Decimal] => DecimalType.SYSTEM_DEFAULT case _ => val className = getClassNameFromType(tpe) @@ -189,7 +190,6 @@ object ScalaReflection extends ScalaReflection { case _ => UpCast(expr, expected, walkedTypePath) } - val className = getClassNameFromType(tpe) tpe match { case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath @@ -239,16 +239,14 @@ object ScalaReflection extends ScalaReflection { DateTimeUtils.getClass, ObjectType(classOf[java.sql.Date]), "toJavaDate", - getPath :: Nil, - propagateNull = true) + getPath :: Nil) case t if t <:< localTypeOf[java.sql.Timestamp] => StaticInvoke( DateTimeUtils.getClass, ObjectType(classOf[java.sql.Timestamp]), "toJavaTimestamp", - getPath :: Nil, - propagateNull = true) + getPath :: Nil) case t if t <:< localTypeOf[java.lang.String] => Invoke(getPath, "toString", ObjectType(classOf[String])) @@ -437,17 +435,17 @@ object ScalaReflection extends ScalaReflection { walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { def toCatalystArray(input: Expression, elementType: `Type`): Expression = { - val externalDataType = dataTypeFor(elementType) - val Schema(catalystType, nullable) = silentSchemaFor(elementType) - if (isNativeType(externalDataType)) { - NewInstance( - classOf[GenericArrayData], - input :: Nil, - dataType = ArrayType(catalystType, nullable)) - } else { - val clsName = getClassNameFromType(elementType) - val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(serializerFor(_, elementType, newPath), input, externalDataType) + dataTypeFor(elementType) match { + case dt: ObjectType => + val clsName = getClassNameFromType(elementType) + val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath + MapObjects(serializerFor(_, elementType, newPath), input, dt) + + case dt => + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(dt, schemaFor(elementType).nullable)) } } @@ -457,63 +455,10 @@ object ScalaReflection extends ScalaReflection { tpe match { case t if t <:< localTypeOf[Option[_]] => val TypeRef(_, _, Seq(optType)) = t - optType match { - // For primitive types we must manually unbox the value of the object. - case t if t <:< definitions.IntTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), - "intValue", - IntegerType) - case t if t <:< definitions.LongTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), - "longValue", - LongType) - case t if t <:< definitions.DoubleTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), - "doubleValue", - DoubleType) - case t if t <:< definitions.FloatTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), - "floatValue", - FloatType) - case t if t <:< definitions.ShortTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), - "shortValue", - ShortType) - case t if t <:< definitions.ByteTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), - "byteValue", - ByteType) - case t if t <:< definitions.BooleanTpe => - Invoke( - UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), - "booleanValue", - BooleanType) - - // For non-primitives, we can just extract the object from the Option and then recurse. - case other => - val className = getClassNameFromType(optType) - val newPath = s"""- option value class: "$className"""" +: walkedTypePath - - val optionObjectType: DataType = other match { - // Special handling is required for arrays, as getClassFromType(<Array>) will fail - // since Scala Arrays map to native Java constructs. E.g. "Array[Int]" will map to - // the Java type "[I". - case arr if arr <:< localTypeOf[Array[_]] => arrayClassFor(t) - case cls => ObjectType(getClassFromType(cls)) - } - val unwrapped = UnwrapOption(optionObjectType, inputObject) - - expressions.If( - IsNull(unwrapped), - expressions.Literal.create(null, silentSchemaFor(optType).dataType), - serializerFor(unwrapped, optType, newPath)) - } + val className = getClassNameFromType(optType) + val newPath = s"""- option value class: "$className"""" +: walkedTypePath + val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject) + serializerFor(unwrapped, optType, newPath) // Since List[_] also belongs to localTypeOf[Product], we put this case before // "case t if definedByConstructorParams(t)" to make sure it will match to the @@ -704,18 +649,6 @@ object ScalaReflection extends ScalaReflection { /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor[T: TypeTag]: Schema = schemaFor(localTypeOf[T]) - /** - * Returns a catalyst DataType and its nullability for the given Scala Type using reflection. - * - * Unlike `schemaFor`, this method won't throw exception for un-supported type, it will return - * `NullType` silently instead. - */ - def silentSchemaFor(tpe: `Type`): Schema = try { - schemaFor(tpe) - } catch { - case _: UnsupportedOperationException => Schema(NullType, nullable = true) - } - /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection. */ def schemaFor(tpe: `Type`): Schema = ScalaReflectionLock.synchronized { tpe match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 5e17f8920901a4e3142e5283cb914fac3227a853..2f2323fa3a25f52b4e992c5c76dd620ab9ef25fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -289,8 +289,8 @@ case class UnwrapOption( ${inputObject.code} final boolean ${ev.isNull} = ${inputObject.isNull} || ${inputObject.value}.isEmpty(); - $javaType ${ev.value} = - ${ev.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) ${inputObject.value}.get(); + $javaType ${ev.value} = ${ev.isNull} ? + ${ctx.defaultValue(javaType)} : (${ctx.boxedType(javaType)}) ${inputObject.value}.get(); """ ev.copy(code = code) }