diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6d822261b050a290727bbd717a344966fcbe2f1f..0b3dd351e38e8d398362bbaed2f86b394410ac86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -75,7 +75,7 @@ trait ScalaReflection { * * @see SPARK-5281 */ - private def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe + def localTypeOf[T: TypeTag]: `Type` = typeTag[T].in(mirror).tpe /** * Returns the Spark SQL DataType for a given scala type. Where this is not an exact mapping diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala new file mode 100644 index 0000000000000000000000000000000000000000..6d307ab13a9fc19bf24af50eb99c97234fea15d7 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoder.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.reflect.ClassTag +import scala.reflect.runtime.universe.{typeTag, TypeTag} + +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.catalyst.expressions.{Literal, CreateNamedStruct, BoundReference} +import org.apache.spark.sql.catalyst.ScalaReflection + +object FlatEncoder { + import ScalaReflection.schemaFor + import ScalaReflection.dataTypeFor + + def apply[T : TypeTag]: ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val tpe = typeTag[T].tpe + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(tpe) + assert(!schemaFor(tpe).dataType.isInstanceOf[StructType]) + + val input = BoundReference(0, dataTypeFor(tpe), nullable = true) + val toRowExpression = CreateNamedStruct( + Literal("value") :: ProductEncoder.extractorFor(input, tpe) :: Nil) + val fromRowExpression = ProductEncoder.constructorFor(tpe) + + new ExpressionEncoder[T]( + toRowExpression.dataType, + flat = true, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](cls)) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala new file mode 100644 index 0000000000000000000000000000000000000000..414adb21168ed5424e6acb101a5ebb3dc32587e2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoder.scala @@ -0,0 +1,452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import org.apache.spark.util.Utils +import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ +import org.apache.spark.sql.catalyst.ScalaReflectionLock +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, ArrayBasedMapData, GenericArrayData} + +import scala.reflect.ClassTag + +object ProductEncoder { + import ScalaReflection.universe._ + import ScalaReflection.localTypeOf + import ScalaReflection.dataTypeFor + import ScalaReflection.Schema + import ScalaReflection.schemaFor + import ScalaReflection.arrayClassFor + + def apply[T <: Product : TypeTag]: ExpressionEncoder[T] = { + // We convert the not-serializable TypeTag into StructType and ClassTag. + val tpe = typeTag[T].tpe + val mirror = typeTag[T].mirror + val cls = mirror.runtimeClass(tpe) + + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val toRowExpression = extractorFor(inputObject, tpe).asInstanceOf[CreateNamedStruct] + val fromRowExpression = constructorFor(tpe) + + new ExpressionEncoder[T]( + toRowExpression.dataType, + flat = false, + toRowExpression.flatten, + fromRowExpression, + ClassTag[T](cls)) + } + + // The Predef.Map is scala.collection.immutable.Map. + // Since the map values can be mutable, we explicitly import scala.collection.Map at here. + import scala.collection.Map + + def extractorFor( + inputObject: Expression, + tpe: `Type`): Expression = ScalaReflectionLock.synchronized { + if (!inputObject.dataType.isInstanceOf[ObjectType]) { + inputObject + } else { + tpe match { + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + optType match { + // For primitive types we must manually unbox the value of the object. + case t if t <:< definitions.IntTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Integer]), inputObject), + "intValue", + IntegerType) + case t if t <:< definitions.LongTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Long]), inputObject), + "longValue", + LongType) + case t if t <:< definitions.DoubleTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Double]), inputObject), + "doubleValue", + DoubleType) + case t if t <:< definitions.FloatTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Float]), inputObject), + "floatValue", + FloatType) + case t if t <:< definitions.ShortTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Short]), inputObject), + "shortValue", + ShortType) + case t if t <:< definitions.ByteTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Byte]), inputObject), + "byteValue", + ByteType) + case t if t <:< definitions.BooleanTpe => + Invoke( + UnwrapOption(ObjectType(classOf[java.lang.Boolean]), inputObject), + "booleanValue", + BooleanType) + + // For non-primitives, we can just extract the object from the Option and then recurse. + case other => + val className: String = optType.erasure.typeSymbol.asClass.fullName + val classObj = Utils.classForName(className) + val optionObjectType = ObjectType(classObj) + + val unwrapped = UnwrapOption(optionObjectType, inputObject) + expressions.If( + IsNull(unwrapped), + expressions.Literal.create(null, schemaFor(optType).dataType), + extractorFor(unwrapped, optType)) + } + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + CreateNamedStruct(params.head.flatMap { p => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) + expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + }) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + toCatalystArray(inputObject, elementType) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keys = + Invoke( + Invoke(inputObject, "keysIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedKeys = toCatalystArray(keys, keyType) + + val values = + Invoke( + Invoke(inputObject, "valuesIterator", + ObjectType(classOf[scala.collection.Iterator[_]])), + "toSeq", + ObjectType(classOf[scala.collection.Seq[_]])) + val convertedValues = toCatalystArray(values, valueType) + + val Schema(keyDataType, _) = schemaFor(keyType) + val Schema(valueDataType, valueNullable) = schemaFor(valueType) + NewInstance( + classOf[ArrayBasedMapData], + convertedKeys :: convertedValues :: Nil, + dataType = MapType(keyDataType, valueDataType, valueNullable)) + + case t if t <:< localTypeOf[String] => + StaticInvoke( + classOf[UTF8String], + StringType, + "fromString", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + TimestampType, + "fromJavaTimestamp", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + DateType, + "fromJavaDate", + inputObject :: Nil) + + case t if t <:< localTypeOf[BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + StaticInvoke( + Decimal, + DecimalType.SYSTEM_DEFAULT, + "apply", + inputObject :: Nil) + + case t if t <:< localTypeOf[java.lang.Integer] => + Invoke(inputObject, "intValue", IntegerType) + case t if t <:< localTypeOf[java.lang.Long] => + Invoke(inputObject, "longValue", LongType) + case t if t <:< localTypeOf[java.lang.Double] => + Invoke(inputObject, "doubleValue", DoubleType) + case t if t <:< localTypeOf[java.lang.Float] => + Invoke(inputObject, "floatValue", FloatType) + case t if t <:< localTypeOf[java.lang.Short] => + Invoke(inputObject, "shortValue", ShortType) + case t if t <:< localTypeOf[java.lang.Byte] => + Invoke(inputObject, "byteValue", ByteType) + case t if t <:< localTypeOf[java.lang.Boolean] => + Invoke(inputObject, "booleanValue", BooleanType) + + case other => + throw new UnsupportedOperationException(s"Extractor for type $other is not supported") + } + } + } + + private def toCatalystArray(input: Expression, elementType: `Type`): Expression = { + val externalDataType = dataTypeFor(elementType) + val Schema(catalystType, nullable) = schemaFor(elementType) + if (RowEncoder.isNativeType(catalystType)) { + NewInstance( + classOf[GenericArrayData], + input :: Nil, + dataType = ArrayType(catalystType, nullable)) + } else { + MapObjects(extractorFor(_, elementType), input, externalDataType) + } + } + + def constructorFor( + tpe: `Type`, + path: Option[Expression] = None): Expression = ScalaReflectionLock.synchronized { + + /** Returns the current path with a sub-field extracted. */ + def addToPath(part: String): Expression = path + .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) + .getOrElse(UnresolvedAttribute(part)) + + /** Returns the current path with a field at ordinal extracted. */ + def addToPathOrdinal(ordinal: Int, dataType: DataType): Expression = path + .map(p => GetInternalRowField(p, ordinal, dataType)) + .getOrElse(BoundReference(ordinal, dataType, false)) + + /** Returns the current path or `BoundReference`. */ + def getPath: Expression = path.getOrElse(BoundReference(0, schemaFor(tpe).dataType, true)) + + tpe match { + case t if !dataTypeFor(t).isInstanceOf[ObjectType] => getPath + + case t if t <:< localTypeOf[Option[_]] => + val TypeRef(_, _, Seq(optType)) = t + WrapOption(null, constructorFor(optType, path)) + + case t if t <:< localTypeOf[java.lang.Integer] => + val boxedType = classOf[java.lang.Integer] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Long] => + val boxedType = classOf[java.lang.Long] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Double] => + val boxedType = classOf[java.lang.Double] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Float] => + val boxedType = classOf[java.lang.Float] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Short] => + val boxedType = classOf[java.lang.Short] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Byte] => + val boxedType = classOf[java.lang.Byte] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.lang.Boolean] => + val boxedType = classOf[java.lang.Boolean] + val objectType = ObjectType(boxedType) + NewInstance(boxedType, getPath :: Nil, propagateNull = true, objectType) + + case t if t <:< localTypeOf[java.sql.Date] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Date]), + "toJavaDate", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.sql.Timestamp] => + StaticInvoke( + DateTimeUtils, + ObjectType(classOf[java.sql.Timestamp]), + "toJavaTimestamp", + getPath :: Nil, + propagateNull = true) + + case t if t <:< localTypeOf[java.lang.String] => + Invoke(getPath, "toString", ObjectType(classOf[String])) + + case t if t <:< localTypeOf[java.math.BigDecimal] => + Invoke(getPath, "toJavaBigDecimal", ObjectType(classOf[java.math.BigDecimal])) + + case t if t <:< localTypeOf[BigDecimal] => + Invoke(getPath, "toBigDecimal", ObjectType(classOf[BigDecimal])) + + case t if t <:< localTypeOf[Array[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val primitiveMethod = elementType match { + case t if t <:< definitions.IntTpe => Some("toIntArray") + case t if t <:< definitions.LongTpe => Some("toLongArray") + case t if t <:< definitions.DoubleTpe => Some("toDoubleArray") + case t if t <:< definitions.FloatTpe => Some("toFloatArray") + case t if t <:< definitions.ShortTpe => Some("toShortArray") + case t if t <:< definitions.ByteTpe => Some("toByteArray") + case t if t <:< definitions.BooleanTpe => Some("toBooleanArray") + case _ => None + } + + primitiveMethod.map { method => + Invoke(getPath, method, arrayClassFor(elementType)) + }.getOrElse { + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + arrayClassFor(elementType)) + } + + case t if t <:< localTypeOf[Seq[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val arrayData = + Invoke( + MapObjects( + p => constructorFor(elementType, Some(p)), + getPath, + schemaFor(elementType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + scala.collection.mutable.WrappedArray, + ObjectType(classOf[Seq[_]]), + "make", + arrayData :: Nil) + + case t if t <:< localTypeOf[Map[_, _]] => + val TypeRef(_, _, Seq(keyType, valueType)) = t + + val keyData = + Invoke( + MapObjects( + p => constructorFor(keyType, Some(p)), + Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), + schemaFor(keyType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + val valueData = + Invoke( + MapObjects( + p => constructorFor(valueType, Some(p)), + Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), + schemaFor(valueType).dataType), + "array", + ObjectType(classOf[Array[Any]])) + + StaticInvoke( + ArrayBasedMapData, + ObjectType(classOf[Map[_, _]]), + "toScalaMap", + keyData :: valueData :: Nil) + + case t if t <:< localTypeOf[Product] => + val formalTypeArgs = t.typeSymbol.asClass.typeParams + val TypeRef(_, _, actualTypeArgs) = t + val constructorSymbol = t.member(nme.CONSTRUCTOR) + val params = if (constructorSymbol.isMethod) { + constructorSymbol.asMethod.paramss + } else { + // Find the primary constructor, and use its parameter ordering. + val primaryConstructorSymbol: Option[Symbol] = + constructorSymbol.asTerm.alternatives.find(s => + s.isMethod && s.asMethod.isPrimaryConstructor) + + if (primaryConstructorSymbol.isEmpty) { + sys.error("Internal SQL error: Product object did not have a primary constructor.") + } else { + primaryConstructorSymbol.get.asMethod.paramss + } + } + + val className: String = t.erasure.typeSymbol.asClass.fullName + val cls = Utils.classForName(className) + + val arguments = params.head.zipWithIndex.map { case (p, i) => + val fieldName = p.name.toString + val fieldType = p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs) + val dataType = schemaFor(fieldType).dataType + + // For tuples, we based grab the inner fields by ordinal instead of name. + if (className startsWith "scala.Tuple") { + constructorFor(fieldType, Some(addToPathOrdinal(i, dataType))) + } else { + constructorFor(fieldType, Some(addToPath(fieldName))) + } + } + + val newInstance = NewInstance(cls, arguments, propagateNull = false, ObjectType(cls)) + + if (path.nonEmpty) { + expressions.If( + IsNull(getPath), + expressions.Literal.create(null, ObjectType(cls)), + newInstance + ) + } else { + newInstance + } + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 0b42130a013b2575959d6169cafdeb3c06cd1523..e0be896bb3548feac07e41bb00a6992d8a89fbe0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -119,9 +119,17 @@ object RowEncoder { CreateStruct(convertedFields) } - private def externalDataTypeFor(dt: DataType): DataType = dt match { + /** + * Returns true if the value of this data type is same between internal and external. + */ + def isNativeType(dt: DataType): Boolean = dt match { case BooleanType | ByteType | ShortType | IntegerType | LongType | - FloatType | DoubleType | BinaryType => dt + FloatType | DoubleType | BinaryType => true + case _ => false + } + + private def externalDataTypeFor(dt: DataType): DataType = dt match { + case _ if isNativeType(dt) => dt case TimestampType => ObjectType(classOf[java.sql.Timestamp]) case DateType => ObjectType(classOf[java.sql.Date]) case _: DecimalType => ObjectType(classOf[java.math.BigDecimal]) @@ -137,13 +145,13 @@ object RowEncoder { If( IsNull(field), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(BoundReference(i, f.dataType, f.nullable), f.dataType) + constructorFor(BoundReference(i, f.dataType, f.nullable)) ) } CreateExternalRow(fields) } - private def constructorFor(input: Expression, dataType: DataType): Expression = dataType match { + private def constructorFor(input: Expression): Expression = input.dataType match { case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType => input @@ -170,7 +178,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(constructorFor(_, et), input, et), + MapObjects(constructorFor, input, et), "array", ObjectType(classOf[Array[_]])) StaticInvoke( @@ -181,10 +189,10 @@ object RowEncoder { case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) - val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType), keyArrayType) + val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType)) val valueArrayType = ArrayType(vt, valueNullable) - val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType), valueArrayType) + val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( ArrayBasedMapData, @@ -197,42 +205,8 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(getField(input, i, f.dataType), f.dataType)) + constructorFor(GetInternalRowField(input, i, f.dataType))) } CreateExternalRow(convertedFields) } - - private def getField( - row: Expression, - ordinal: Int, - dataType: DataType): Expression = dataType match { - case BooleanType => - Invoke(row, "getBoolean", dataType, Literal(ordinal) :: Nil) - case ByteType => - Invoke(row, "getByte", dataType, Literal(ordinal) :: Nil) - case ShortType => - Invoke(row, "getShort", dataType, Literal(ordinal) :: Nil) - case IntegerType | DateType => - Invoke(row, "getInt", dataType, Literal(ordinal) :: Nil) - case LongType | TimestampType => - Invoke(row, "getLong", dataType, Literal(ordinal) :: Nil) - case FloatType => - Invoke(row, "getFloat", dataType, Literal(ordinal) :: Nil) - case DoubleType => - Invoke(row, "getDouble", dataType, Literal(ordinal) :: Nil) - case t: DecimalType => - Invoke(row, "getDecimal", dataType, Seq(ordinal, t.precision, t.scale).map(Literal(_))) - case StringType => - Invoke(row, "getUTF8String", dataType, Literal(ordinal) :: Nil) - case BinaryType => - Invoke(row, "getBinary", dataType, Literal(ordinal) :: Nil) - case CalendarIntervalType => - Invoke(row, "getInterval", dataType, Literal(ordinal) :: Nil) - case t: StructType => - Invoke(row, "getStruct", dataType, Literal(ordinal) :: Literal(t.size) :: Nil) - case _: ArrayType => - Invoke(row, "getArray", dataType, Literal(ordinal) :: Nil) - case _: MapType => - Invoke(row, "getMap", dataType, Literal(ordinal) :: Nil) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f5fff90e5a5424672192c46dbefa0a7940928715..deff8a5378b925f38e7c1c99a760adc875cad526 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -110,7 +110,7 @@ object DateTimeUtils { } def stringToTime(s: String): java.util.Date = { - var indexOfGMT = s.indexOf("GMT"); + val indexOfGMT = s.indexOf("GMT") if (indexOfGMT != -1) { // ISO8601 with a weird time zone specifier (2000-01-01T00:00GMT+01:00) val s0 = s.substring(0, indexOfGMT) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala index e9bf7b33e35be54135aad9e7a89c900287d8a3bc..96588bb5dc1bc033cae775bdba4416514806898e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/GenericArrayData.scala @@ -23,7 +23,7 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} class GenericArrayData(val array: Array[Any]) extends ArrayData { - def this(seq: scala.collection.GenIterable[Any]) = this(seq.toArray) + def this(seq: Seq[Any]) = this(seq.toArray) // TODO: This is boxing. We should specialize. def this(primitiveArray: Array[Int]) = this(primitiveArray.toSeq) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index b0dacf7f555e0be62c28d7ee09707976f1e16f66..9fe64b4cf10e4ce1552010c5d14573e48dbe4b71 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -17,232 +17,27 @@ package org.apache.spark.sql.catalyst.encoders -import scala.collection.mutable.ArrayBuffer -import scala.reflect.runtime.universe._ +import java.util.Arrays import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.expressions.AttributeReference import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.types.{StructField, ArrayType} - -case class RepeatedStruct(s: Seq[PrimitiveData]) - -case class NestedArray(a: Array[Array[Int]]) - -case class BoxedData( - intField: java.lang.Integer, - longField: java.lang.Long, - doubleField: java.lang.Double, - floatField: java.lang.Float, - shortField: java.lang.Short, - byteField: java.lang.Byte, - booleanField: java.lang.Boolean) - -case class RepeatedData( - arrayField: Seq[Int], - arrayFieldContainsNull: Seq[java.lang.Integer], - mapField: scala.collection.Map[Int, Long], - mapFieldNull: scala.collection.Map[Int, java.lang.Long], - structField: PrimitiveData) - -case class SpecificCollection(l: List[Int]) - -class ExpressionEncoderSuite extends SparkFunSuite { - - encodeDecodeTest(1) - encodeDecodeTest(1L) - encodeDecodeTest(1.toDouble) - encodeDecodeTest(1.toFloat) - encodeDecodeTest(true) - encodeDecodeTest(false) - encodeDecodeTest(1.toShort) - encodeDecodeTest(1.toByte) - encodeDecodeTest("hello") - - encodeDecodeTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) - - // TODO: Support creating specific subclasses of Seq. - ignore("Specific collection types") { encodeDecodeTest(SpecificCollection(1 :: Nil)) } - - encodeDecodeTest( - OptionalData( - Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), - Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) - - encodeDecodeTest(OptionalData(None, None, None, None, None, None, None, None)) - - encodeDecodeTest( - BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) - - encodeDecodeTest( - BoxedData(null, null, null, null, null, null, null)) - - encodeDecodeTest( - RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) - - encodeDecodeTest( - RepeatedData( - Seq(1, 2), - Seq(new Integer(1), null, new Integer(2)), - Map(1 -> 2L), - Map(1 -> null), - PrimitiveData(1, 1, 1, 1, 1, 1, true))) - - encodeDecodeTest(("nullable Seq[Integer]", Seq[Integer](1, null))) - - encodeDecodeTest(("Seq[(String, String)]", - Seq(("a", "b")))) - encodeDecodeTest(("Seq[(Int, Int)]", - Seq((1, 2)))) - encodeDecodeTest(("Seq[(Long, Long)]", - Seq((1L, 2L)))) - encodeDecodeTest(("Seq[(Float, Float)]", - Seq((1.toFloat, 2.toFloat)))) - encodeDecodeTest(("Seq[(Double, Double)]", - Seq((1.toDouble, 2.toDouble)))) - encodeDecodeTest(("Seq[(Short, Short)]", - Seq((1.toShort, 2.toShort)))) - encodeDecodeTest(("Seq[(Byte, Byte)]", - Seq((1.toByte, 2.toByte)))) - encodeDecodeTest(("Seq[(Boolean, Boolean)]", - Seq((true, false)))) - - // TODO: Decoding/encoding of complex maps. - ignore("complex maps") { - encodeDecodeTest(("Map[Int, (String, String)]", - Map(1 ->("a", "b")))) - } - - encodeDecodeTest(("ArrayBuffer[(String, String)]", - ArrayBuffer(("a", "b")))) - encodeDecodeTest(("ArrayBuffer[(Int, Int)]", - ArrayBuffer((1, 2)))) - encodeDecodeTest(("ArrayBuffer[(Long, Long)]", - ArrayBuffer((1L, 2L)))) - encodeDecodeTest(("ArrayBuffer[(Float, Float)]", - ArrayBuffer((1.toFloat, 2.toFloat)))) - encodeDecodeTest(("ArrayBuffer[(Double, Double)]", - ArrayBuffer((1.toDouble, 2.toDouble)))) - encodeDecodeTest(("ArrayBuffer[(Short, Short)]", - ArrayBuffer((1.toShort, 2.toShort)))) - encodeDecodeTest(("ArrayBuffer[(Byte, Byte)]", - ArrayBuffer((1.toByte, 2.toByte)))) - encodeDecodeTest(("ArrayBuffer[(Boolean, Boolean)]", - ArrayBuffer((true, false)))) - - encodeDecodeTest(("Seq[Seq[(Int, Int)]]", - Seq(Seq((1, 2))))) - - encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", - Array(Array((1, 2))))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[(Int, Int)]]", - Array(Array(Array((1, 2)))))) - { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[(Int, Int)]]]", - Array(Array(Array(Array((1, 2))))))) - { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Array[(Int, Int)]]]]", - Array(Array(Array(Array(Array((1, 2)))))))) - { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } - - - encodeDecodeTestCustom(("Array[Array[Integer]]", - Array(Array[Integer](1)))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Int]]", - Array(Array(1)))) - { (l, r) => l._2(0)(0) == r._2(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Int]]", - Array(Array(Array(1))))) - { (l, r) => l._2(0)(0)(0) == r._2(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Int]]]", - Array(Array(Array(Array(1)))))) - { (l, r) => l._2(0)(0)(0)(0) == r._2(0)(0)(0)(0) } - - encodeDecodeTestCustom(("Array[Array[Array[Array[Int]]]]", - Array(Array(Array(Array(Array(1))))))) - { (l, r) => l._2(0)(0)(0)(0)(0) == r._2(0)(0)(0)(0)(0) } - - encodeDecodeTest(("Array[Byte] null", - null: Array[Byte])) - encodeDecodeTestCustom(("Array[Byte]", - Array[Byte](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Int] null", - null: Array[Int])) - encodeDecodeTestCustom(("Array[Int]", - Array[Int](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Long] null", - null: Array[Long])) - encodeDecodeTestCustom(("Array[Long]", - Array[Long](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Double] null", - null: Array[Double])) - encodeDecodeTestCustom(("Array[Double]", - Array[Double](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Float] null", - null: Array[Float])) - encodeDecodeTestCustom(("Array[Float]", - Array[Float](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Boolean] null", - null: Array[Boolean])) - encodeDecodeTestCustom(("Array[Boolean]", - Array[Boolean](true, false))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTest(("Array[Short] null", - null: Array[Short])) - encodeDecodeTestCustom(("Array[Short]", - Array[Short](1, 2, 3))) - { (l, r) => java.util.Arrays.equals(l._2, r._2) } - - encodeDecodeTestCustom(("java.sql.Timestamp", - new java.sql.Timestamp(1))) - { (l, r) => l._2.toString == r._2.toString } - - encodeDecodeTestCustom(("java.sql.Date", new java.sql.Date(1))) - { (l, r) => l._2.toString == r._2.toString } - - /** Simplified encodeDecodeTestCustom, where the comparison function can be `Object.equals`. */ - protected def encodeDecodeTest[T : TypeTag](inputData: T) = - encodeDecodeTestCustom[T](inputData)((l, r) => l == r) - - /** - * Constructs a test that round-trips `t` through an encoder, checking the results to ensure it - * matches the original. - */ - protected def encodeDecodeTestCustom[T : TypeTag]( - inputData: T)( - c: (T, T) => Boolean) = { - test(s"encode/decode: $inputData - ${inputData.getClass.getName}") { - val encoder = try ExpressionEncoder[T]() catch { - case e: Exception => - fail(s"Exception thrown generating encoder", e) - } - val convertedData = encoder.toRow(inputData) +import org.apache.spark.sql.types.ArrayType + +abstract class ExpressionEncoderSuite extends SparkFunSuite { + protected def encodeDecodeTest[T]( + input: T, + encoder: ExpressionEncoder[T], + testName: String): Unit = { + test(s"encode/decode for $testName: $input") { + val row = encoder.toRow(input) val schema = encoder.schema.toAttributes val boundEncoder = encoder.resolve(schema).bind(schema) - val convertedBack = try boundEncoder.fromRow(convertedData) catch { + val convertedBack = try boundEncoder.fromRow(row) catch { case e: Exception => fail( s"""Exception thrown while decoding - |Converted: $convertedData + |Converted: $row |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | @@ -252,18 +47,27 @@ class ExpressionEncoderSuite extends SparkFunSuite { """.stripMargin, e) } - if (!c(inputData, convertedBack)) { + val isCorrect = (input, convertedBack) match { + case (b1: Array[Byte], b2: Array[Byte]) => Arrays.equals(b1, b2) + case (b1: Array[Int], b2: Array[Int]) => Arrays.equals(b1, b2) + case (b1: Array[Array[_]], b2: Array[Array[_]]) => + Arrays.deepEquals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case (b1: Array[_], b2: Array[_]) => + Arrays.equals(b1.asInstanceOf[Array[AnyRef]], b2.asInstanceOf[Array[AnyRef]]) + case _ => input == convertedBack + } + + if (!isCorrect) { val types = convertedBack match { case c: Product => c.productIterator.filter(_ != null).map(_.getClass.getName).mkString(",") case other => other.getClass.getName } - val encodedData = try { - convertedData.toSeq(encoder.schema).zip(encoder.schema).map { - case (a: ArrayData, StructField(_, at: ArrayType, _, _)) => - a.toArray[Any](at.elementType).toSeq + row.toSeq(encoder.schema).zip(schema).map { + case (a: ArrayData, AttributeReference(_, ArrayType(et, _), _, _)) => + a.toArray[Any](et).toSeq case (other, _) => other }.mkString("[", ",", "]") @@ -274,7 +78,7 @@ class ExpressionEncoderSuite extends SparkFunSuite { fail( s"""Encoded/Decoded data does not match input data | - |in: $inputData + |in: $input |out: $convertedBack |types: $types | @@ -282,11 +86,10 @@ class ExpressionEncoderSuite extends SparkFunSuite { |Schema: ${schema.mkString(",")} |${encoder.schema.treeString} | - |Extract Expressions: - |$boundEncoder + |fromRow Expressions: + |${boundEncoder.fromRowExpression.treeString} """.stripMargin) - } } - + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..55821c43706847746fa8601d1caeec9e9d2579de --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/FlatEncoderSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import java.sql.{Date, Timestamp} + +class FlatEncoderSuite extends ExpressionEncoderSuite { + encodeDecodeTest(false, FlatEncoder[Boolean], "primitive boolean") + encodeDecodeTest(-3.toByte, FlatEncoder[Byte], "primitive byte") + encodeDecodeTest(-3.toShort, FlatEncoder[Short], "primitive short") + encodeDecodeTest(-3, FlatEncoder[Int], "primitive int") + encodeDecodeTest(-3L, FlatEncoder[Long], "primitive long") + encodeDecodeTest(-3.7f, FlatEncoder[Float], "primitive float") + encodeDecodeTest(-3.7, FlatEncoder[Double], "primitive double") + + encodeDecodeTest(new java.lang.Boolean(false), FlatEncoder[java.lang.Boolean], "boxed boolean") + encodeDecodeTest(new java.lang.Byte(-3.toByte), FlatEncoder[java.lang.Byte], "boxed byte") + encodeDecodeTest(new java.lang.Short(-3.toShort), FlatEncoder[java.lang.Short], "boxed short") + encodeDecodeTest(new java.lang.Integer(-3), FlatEncoder[java.lang.Integer], "boxed int") + encodeDecodeTest(new java.lang.Long(-3L), FlatEncoder[java.lang.Long], "boxed long") + encodeDecodeTest(new java.lang.Float(-3.7f), FlatEncoder[java.lang.Float], "boxed float") + encodeDecodeTest(new java.lang.Double(-3.7), FlatEncoder[java.lang.Double], "boxed double") + + encodeDecodeTest(BigDecimal("32131413.211321313"), FlatEncoder[BigDecimal], "scala decimal") + type JDecimal = java.math.BigDecimal + // encodeDecodeTest(new JDecimal("231341.23123"), FlatEncoder[JDecimal], "java decimal") + + encodeDecodeTest("hello", FlatEncoder[String], "string") + encodeDecodeTest(Date.valueOf("2012-12-23"), FlatEncoder[Date], "date") + encodeDecodeTest(Timestamp.valueOf("2016-01-29 10:00:00"), FlatEncoder[Timestamp], "timestamp") + encodeDecodeTest(Array[Byte](13, 21, -23), FlatEncoder[Array[Byte]], "binary") + + encodeDecodeTest(Seq(31, -123, 4), FlatEncoder[Seq[Int]], "seq of int") + encodeDecodeTest(Seq("abc", "xyz"), FlatEncoder[Seq[String]], "seq of string") + encodeDecodeTest(Seq("abc", null, "xyz"), FlatEncoder[Seq[String]], "seq of string with null") + encodeDecodeTest(Seq.empty[Int], FlatEncoder[Seq[Int]], "empty seq of int") + encodeDecodeTest(Seq.empty[String], FlatEncoder[Seq[String]], "empty seq of string") + + encodeDecodeTest(Seq(Seq(31, -123), null, Seq(4, 67)), + FlatEncoder[Seq[Seq[Int]]], "seq of seq of int") + encodeDecodeTest(Seq(Seq("abc", "xyz"), Seq[String](null), null, Seq("1", null, "2")), + FlatEncoder[Seq[Seq[String]]], "seq of seq of string") + + encodeDecodeTest(Array(31, -123, 4), FlatEncoder[Array[Int]], "array of int") + encodeDecodeTest(Array("abc", "xyz"), FlatEncoder[Array[String]], "array of string") + encodeDecodeTest(Array("a", null, "x"), FlatEncoder[Array[String]], "array of string with null") + encodeDecodeTest(Array.empty[Int], FlatEncoder[Array[Int]], "empty array of int") + encodeDecodeTest(Array.empty[String], FlatEncoder[Array[String]], "empty array of string") + + encodeDecodeTest(Array(Array(31, -123), null, Array(4, 67)), + FlatEncoder[Array[Array[Int]]], "array of array of int") + encodeDecodeTest(Array(Array("abc", "xyz"), Array[String](null), null, Array("1", null, "2")), + FlatEncoder[Array[Array[String]]], "array of array of string") + + encodeDecodeTest(Map(1 -> "a", 2 -> "b"), FlatEncoder[Map[Int, String]], "map") + encodeDecodeTest(Map(1 -> "a", 2 -> null), FlatEncoder[Map[Int, String]], "map with null") + encodeDecodeTest(Map(1 -> Map("a" -> 1), 2 -> Map("b" -> 2)), + FlatEncoder[Map[Int, Map[String, Int]]], "map of map") +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala new file mode 100644 index 0000000000000000000000000000000000000000..fda978e7055ea6ecd35f88b469a4f80f935119d2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.encoders + +import scala.collection.mutable.ArrayBuffer +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData} + +case class RepeatedStruct(s: Seq[PrimitiveData]) + +case class NestedArray(a: Array[Array[Int]]) { + override def equals(other: Any): Boolean = other match { + case NestedArray(otherArray) => + java.util.Arrays.deepEquals( + a.asInstanceOf[Array[AnyRef]], + otherArray.asInstanceOf[Array[AnyRef]]) + case _ => false + } +} + +case class BoxedData( + intField: java.lang.Integer, + longField: java.lang.Long, + doubleField: java.lang.Double, + floatField: java.lang.Float, + shortField: java.lang.Short, + byteField: java.lang.Byte, + booleanField: java.lang.Boolean) + +case class RepeatedData( + arrayField: Seq[Int], + arrayFieldContainsNull: Seq[java.lang.Integer], + mapField: scala.collection.Map[Int, Long], + mapFieldNull: scala.collection.Map[Int, java.lang.Long], + structField: PrimitiveData) + +case class SpecificCollection(l: List[Int]) + +class ProductEncoderSuite extends ExpressionEncoderSuite { + + productTest(PrimitiveData(1, 1, 1, 1, 1, 1, true)) + + productTest( + OptionalData(Some(2), Some(2), Some(2), Some(2), Some(2), Some(2), Some(true), + Some(PrimitiveData(1, 1, 1, 1, 1, 1, true)))) + + productTest(OptionalData(None, None, None, None, None, None, None, None)) + + productTest(BoxedData(1, 1L, 1.0, 1.0f, 1.toShort, 1.toByte, true)) + + productTest(BoxedData(null, null, null, null, null, null, null)) + + productTest(RepeatedStruct(PrimitiveData(1, 1, 1, 1, 1, 1, true) :: Nil)) + + productTest((1, "test", PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest( + RepeatedData( + Seq(1, 2), + Seq(new Integer(1), null, new Integer(2)), + Map(1 -> 2L), + Map(1 -> null), + PrimitiveData(1, 1, 1, 1, 1, 1, true))) + + productTest(NestedArray(Array(Array(1, -2, 3), null, Array(4, 5, -6)))) + + productTest(("Seq[(String, String)]", + Seq(("a", "b")))) + productTest(("Seq[(Int, Int)]", + Seq((1, 2)))) + productTest(("Seq[(Long, Long)]", + Seq((1L, 2L)))) + productTest(("Seq[(Float, Float)]", + Seq((1.toFloat, 2.toFloat)))) + productTest(("Seq[(Double, Double)]", + Seq((1.toDouble, 2.toDouble)))) + productTest(("Seq[(Short, Short)]", + Seq((1.toShort, 2.toShort)))) + productTest(("Seq[(Byte, Byte)]", + Seq((1.toByte, 2.toByte)))) + productTest(("Seq[(Boolean, Boolean)]", + Seq((true, false)))) + + productTest(("ArrayBuffer[(String, String)]", + ArrayBuffer(("a", "b")))) + productTest(("ArrayBuffer[(Int, Int)]", + ArrayBuffer((1, 2)))) + productTest(("ArrayBuffer[(Long, Long)]", + ArrayBuffer((1L, 2L)))) + productTest(("ArrayBuffer[(Float, Float)]", + ArrayBuffer((1.toFloat, 2.toFloat)))) + productTest(("ArrayBuffer[(Double, Double)]", + ArrayBuffer((1.toDouble, 2.toDouble)))) + productTest(("ArrayBuffer[(Short, Short)]", + ArrayBuffer((1.toShort, 2.toShort)))) + productTest(("ArrayBuffer[(Byte, Byte)]", + ArrayBuffer((1.toByte, 2.toByte)))) + productTest(("ArrayBuffer[(Boolean, Boolean)]", + ArrayBuffer((true, false)))) + + productTest(("Seq[Seq[(Int, Int)]]", + Seq(Seq((1, 2))))) + + private def productTest[T <: Product : TypeTag](input: T): Unit = { + encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 9c16940707de97c030a4ff7fd82f8f8e004e787f..ebcf4c8bfe7e694deca3d6749cda2db3f7e256f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.api.java.function._ -import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor} +import org.apache.spark.sql.catalyst.encoders.{FlatEncoder, ExpressionEncoder, encoderFor} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.QueryExecution @@ -56,9 +56,6 @@ class GroupedDataset[K, T] private[sql]( private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes) private val resolvedTEncoder = unresolvedTEncoder.resolve(dataAttributes) - /** Encoders for built in aggregations. */ - private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) - private def logicalPlan = queryExecution.analyzed private def sqlContext = queryExecution.sqlContext @@ -211,7 +208,7 @@ class GroupedDataset[K, T] private[sql]( * Returns a [[Dataset]] that contains a tuple with each key and the number of items present * for that key. */ - def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long]) + def count(): Dataset[(K, Long)] = agg(functions.count("*").as(FlatEncoder[Long])) /** * Applies the given function to each cogrouped data. For each unique group, the function will diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 6da46a5f7ef9ad5d93327642128020c66260f8f5..8471eea1b7d9c5a45111f9d841e45391616c8912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -37,17 +37,21 @@ import org.apache.spark.unsafe.types.UTF8String abstract class SQLImplicits { protected def _sqlContext: SQLContext - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder[T]() + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ProductEncoder[T] - implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder[Int](flat = true) - implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true) - implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder[Double](flat = true) - implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder[Float](flat = true) - implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder[Byte](flat = true) - implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder[Short](flat = true) - implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder[Boolean](flat = true) - implicit def newStringEncoder: Encoder[String] = ExpressionEncoder[String](flat = true) + implicit def newIntEncoder: Encoder[Int] = FlatEncoder[Int] + implicit def newLongEncoder: Encoder[Long] = FlatEncoder[Long] + implicit def newDoubleEncoder: Encoder[Double] = FlatEncoder[Double] + implicit def newFloatEncoder: Encoder[Float] = FlatEncoder[Float] + implicit def newByteEncoder: Encoder[Byte] = FlatEncoder[Byte] + implicit def newShortEncoder: Encoder[Short] = FlatEncoder[Short] + implicit def newBooleanEncoder: Encoder[Boolean] = FlatEncoder[Boolean] + implicit def newStringEncoder: Encoder[String] = FlatEncoder[String] + /** + * Creates a [[Dataset]] from an RDD. + * @since 1.6.0 + */ implicit def rddToDatasetHolder[T : Encoder](rdd: RDD[T]): DatasetHolder[T] = { DatasetHolder(_sqlContext.createDataset(rdd)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 53cc6e0cda11053e16e9c66628fbf367617598cb..95158de710acf063df64cdc984f32c76d643d569 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -26,7 +26,7 @@ import scala.util.Try import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.{SqlParser, ScalaReflection} import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.catalyst.encoders.FlatEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint @@ -267,7 +267,7 @@ object functions extends LegacyFunctions { * @since 1.3.0 */ def count(columnName: String): TypedColumn[Any, Long] = - count(Column(columnName)).as(ExpressionEncoder[Long](flat = true)) + count(Column(columnName)).as(FlatEncoder[Long]) /** * Aggregate function: returns the number of distinct items in a group.