diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
index be7110ad6bbf0773fc45928a341ff3f436ae0efa..8b439e6b7a0174ba64fe5b4ab6fa187d0e4827be 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
@@ -29,7 +29,7 @@ object UDTSerializationBenchmark {
     val iters = 1e2.toInt
     val numRows = 1e3.toInt
 
-    val encoder = ExpressionEncoder[Vector].defaultBinding
+    val encoder = ExpressionEncoder[Vector].resolveAndBind()
 
     val vectors = (1 to numRows).map { i =>
       Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index fa96f8223d179f67d7fe097a34e2c1371dba0cf6..673c587b18325506acedf4822f78e94fef54bfbc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -23,6 +23,7 @@ import scala.reflect.{classTag, ClassTag}
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer}
 import org.apache.spark.sql.catalyst.expressions.BoundReference
@@ -208,7 +209,7 @@ object Encoders {
           BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
       deserializer =
         DecodeUsingSerializer[T](
-          BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
+          GetColumnByOrdinal(0, BinaryType), classTag[T], kryo = useKryo),
       clsTag = classTag[T]
     )
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 1fe143494abad0ed19f035933789017a5b8b8818..b3a233ae390ab4137d4ce0bad4fa928454397775 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -25,7 +25,7 @@ import scala.language.existentials
 
 import com.google.common.reflect.TypeToken
 
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -177,8 +177,8 @@ object JavaTypeInference {
         .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
         .getOrElse(UnresolvedAttribute(part))
 
-    /** Returns the current path or `BoundReference`. */
-    def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true))
+    /** Returns the current path or `GetColumnByOrdinal`. */
+    def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1))
 
     typeToken.getRawType match {
       case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 47508618178ea58acbbb6fb267fa17f833089716..78c145d4fd93668f0647f29dd58fb43e1030e56e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -156,17 +156,17 @@ object ScalaReflection extends ScalaReflection {
         walkedTypePath: Seq[String]): Expression = {
       val newPath = path
         .map(p => GetStructField(p, ordinal))
-        .getOrElse(BoundReference(ordinal, dataType, false))
+        .getOrElse(GetColumnByOrdinal(ordinal, dataType))
       upCastToExpectedType(newPath, dataType, walkedTypePath)
     }
 
-    /** Returns the current path or `BoundReference`. */
+    /** Returns the current path or `GetColumnByOrdinal`. */
    def getPath: Expression = {
       val dataType = schemaFor(tpe).dataType
       if (path.isDefined) {
         path.get
       } else {
-        upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath)
+        upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
       }
     }
 
@@ -421,7 +421,7 @@ object ScalaReflection extends ScalaReflection {
   def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
     val tpe = localTypeOf[T]
     val clsName = getClassNameFromType(tpe)
-    val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+    val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
     serializerFor(inputObject, tpe, walkedTypePath) match {
       case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s
       case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
@@ -449,157 +449,156 @@ object ScalaReflection extends ScalaReflection {
       }
     }
 
-    if (!inputObject.dataType.isInstanceOf[ObjectType]) {
-      inputObject
-    } else {
-      tpe match {
-        case t if t <:< localTypeOf[Option[_]] =>
-          val TypeRef(_, _, Seq(optType)) = t
-          val className = getClassNameFromType(optType)
-          val newPath = s"""- option value class: "$className"""" +: walkedTypePath
-          val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
-          serializerFor(unwrapped, optType, newPath)
-
-        // Since List[_] also belongs to localTypeOf[Product], we put this case before
-        // "case t if definedByConstructorParams(t)" to make sure it will match to the
-        // case "localTypeOf[Seq[_]]"
-        case t if t <:< localTypeOf[Seq[_]] =>
-          val TypeRef(_, _, Seq(elementType)) = t
-          toCatalystArray(inputObject, elementType)
-
-        case t if t <:< localTypeOf[Array[_]] =>
-          val TypeRef(_, _, Seq(elementType)) = t
-          toCatalystArray(inputObject, elementType)
-
-        case t if t <:< localTypeOf[Map[_, _]] =>
-          val TypeRef(_, _, Seq(keyType, valueType)) = t
-
-          val keys =
-            Invoke(
-              Invoke(inputObject, "keysIterator",
-                ObjectType(classOf[scala.collection.Iterator[_]])),
-              "toSeq",
-              ObjectType(classOf[scala.collection.Seq[_]]))
-          val convertedKeys = toCatalystArray(keys, keyType)
-
-          val values =
-            Invoke(
-              Invoke(inputObject, "valuesIterator",
-                ObjectType(classOf[scala.collection.Iterator[_]])),
-              "toSeq",
-              ObjectType(classOf[scala.collection.Seq[_]]))
-          val convertedValues = toCatalystArray(values, valueType)
-
-          val Schema(keyDataType, _) = schemaFor(keyType)
-          val Schema(valueDataType, valueNullable) = schemaFor(valueType)
-          NewInstance(
-            classOf[ArrayBasedMapData],
-            convertedKeys :: convertedValues :: Nil,
-            dataType = MapType(keyDataType, valueDataType, valueNullable))
-
-        case t if t <:< localTypeOf[String] =>
-          StaticInvoke(
-            classOf[UTF8String],
-            StringType,
-            "fromString",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.sql.Timestamp] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            TimestampType,
-            "fromJavaTimestamp",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.sql.Date] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            DateType,
-            "fromJavaDate",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[BigDecimal] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.SYSTEM_DEFAULT,
-            "apply",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.math.BigDecimal] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.SYSTEM_DEFAULT,
-            "apply",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.math.BigInteger] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.BigIntDecimal,
-            "apply",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[scala.math.BigInt] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.BigIntDecimal,
-            "apply",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.lang.Integer] =>
-          Invoke(inputObject, "intValue", IntegerType)
-        case t if t <:< localTypeOf[java.lang.Long] =>
-          Invoke(inputObject, "longValue", LongType)
-        case t if t <:< localTypeOf[java.lang.Double] =>
-          Invoke(inputObject, "doubleValue", DoubleType)
-        case t if t <:< localTypeOf[java.lang.Float] =>
-          Invoke(inputObject, "floatValue", FloatType)
-        case t if t <:< localTypeOf[java.lang.Short] =>
-          Invoke(inputObject, "shortValue", ShortType)
-        case t if t <:< localTypeOf[java.lang.Byte] =>
-          Invoke(inputObject, "byteValue", ByteType)
-        case t if t <:< localTypeOf[java.lang.Boolean] =>
-          Invoke(inputObject, "booleanValue", BooleanType)
-
-        case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
-          val udt = getClassFromType(t)
-            .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
-          val obj = NewInstance(
-            udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
-            Nil,
-            dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-          Invoke(obj, "serialize", udt, inputObject :: Nil)
-
-        case t if UDTRegistration.exists(getClassNameFromType(t)) =>
-          val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
-            .asInstanceOf[UserDefinedType[_]]
-          val obj = NewInstance(
-            udt.getClass,
-            Nil,
-            dataType = ObjectType(udt.getClass))
-          Invoke(obj, "serialize", udt, inputObject :: Nil)
-
-        case t if definedByConstructorParams(t) =>
-          val params = getConstructorParameters(t)
-          val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
-            if (javaKeywords.contains(fieldName)) {
-              throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
-                "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
-            }
+    tpe match {
+      case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject
-            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
-            val clsName = getClassNameFromType(fieldType)
-            val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
-          })
-          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
-
-        case other =>
-          throw new UnsupportedOperationException(
-            s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
-      }
+      case t if t <:< localTypeOf[Option[_]] =>
+        val TypeRef(_, _, Seq(optType)) = t
+        val className = getClassNameFromType(optType)
+        val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+        val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
+        serializerFor(unwrapped, optType, newPath)
+
+      // Since List[_] also belongs to localTypeOf[Product], we put this case before
+      // "case t if definedByConstructorParams(t)" to make sure it matches the case
+      // "localTypeOf[Seq[_]]" first.
+      case t if t <:< localTypeOf[Seq[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        toCatalystArray(inputObject, elementType)
+
+      case t if t <:< localTypeOf[Array[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        toCatalystArray(inputObject, elementType)
+
+      case t if t <:< localTypeOf[Map[_, _]] =>
+        val TypeRef(_, _, Seq(keyType, valueType)) = t
+
+        val keys =
+          Invoke(
+            Invoke(inputObject, "keysIterator",
+              ObjectType(classOf[scala.collection.Iterator[_]])),
+            "toSeq",
+            ObjectType(classOf[scala.collection.Seq[_]]))
+        val convertedKeys = toCatalystArray(keys, keyType)
+
+        val values =
+          Invoke(
+            Invoke(inputObject, "valuesIterator",
+              ObjectType(classOf[scala.collection.Iterator[_]])),
+            "toSeq",
+            ObjectType(classOf[scala.collection.Seq[_]]))
+        val convertedValues = toCatalystArray(values, valueType)
+
+        val Schema(keyDataType, _) = schemaFor(keyType)
+        val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+        NewInstance(
+          classOf[ArrayBasedMapData],
+          convertedKeys :: convertedValues :: Nil,
+          dataType = MapType(keyDataType, valueDataType, valueNullable))
+
+      case t if t <:< localTypeOf[String] =>
+        StaticInvoke(
+          classOf[UTF8String],
+          StringType,
+          "fromString",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.sql.Timestamp] =>
+        StaticInvoke(
+          DateTimeUtils.getClass,
+          TimestampType,
+          "fromJavaTimestamp",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.sql.Date] =>
+        StaticInvoke(
+          DateTimeUtils.getClass,
+          DateType,
+          "fromJavaDate",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[BigDecimal] =>
+        StaticInvoke(
+          Decimal.getClass,
+          DecimalType.SYSTEM_DEFAULT,
+          "apply",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+        StaticInvoke(
+          Decimal.getClass,
+          DecimalType.SYSTEM_DEFAULT,
+          "apply",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.math.BigInteger] =>
+        StaticInvoke(
+          Decimal.getClass,
+          DecimalType.BigIntDecimal,
+          "apply",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[scala.math.BigInt] =>
+        StaticInvoke(
+          Decimal.getClass,
+          DecimalType.BigIntDecimal,
+          "apply",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.lang.Integer] =>
+        Invoke(inputObject, "intValue", IntegerType)
+      case t if t <:< localTypeOf[java.lang.Long] =>
+        Invoke(inputObject, "longValue", LongType)
+      case t if t <:< localTypeOf[java.lang.Double] =>
+        Invoke(inputObject, "doubleValue", DoubleType)
+      case t if t <:< localTypeOf[java.lang.Float] =>
+        Invoke(inputObject, "floatValue", FloatType)
+      case t if t <:< localTypeOf[java.lang.Short] =>
+        Invoke(inputObject, "shortValue", ShortType)
+      case t if t <:< localTypeOf[java.lang.Byte] =>
+        Invoke(inputObject, "byteValue", ByteType)
+      case t if t <:< localTypeOf[java.lang.Boolean] =>
+        Invoke(inputObject, "booleanValue", BooleanType)
+
+      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+        val udt = getClassFromType(t)
+          .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+        val obj = NewInstance(
+          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+          Nil,
+          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        Invoke(obj, "serialize", udt, inputObject :: Nil)
+
+      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+          .asInstanceOf[UserDefinedType[_]]
+        val obj = NewInstance(
+          udt.getClass,
+          Nil,
+          dataType = ObjectType(udt.getClass))
+        Invoke(obj, "serialize", udt, inputObject :: Nil)
+
+      case t if definedByConstructorParams(t) =>
+        val params = getConstructorParameters(t)
+        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+          if (javaKeywords.contains(fieldName)) {
+            throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
+              "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
+          }
+
+          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+          val clsName = getClassNameFromType(fieldType)
+          val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+        })
+        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
+      case other =>
+        throw new UnsupportedOperationException(
+          s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
     }
+  }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 02966796afdd78ed7012f8232fe346d181b2f1cd..4f6b4830cd6f6eee642294547820a76dad1bfcc4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -722,6 +722,7 @@ class Analyzer(
     // Else, throw exception.
     try {
       expr transformUp {
+        case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
         case u @ UnresolvedAttribute(nameParts) =>
           withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
         case UnresolvedExtractValue(child, fieldName) if child.resolved =>
@@ -1924,10 +1925,54 @@
       } else {
         inputAttributes
       }
-      val unbound = deserializer transform {
-        case b: BoundReference => inputs(b.ordinal)
-      }
-      resolveExpression(unbound, LocalRelation(inputs), throws = true)
+
+      validateTopLevelTupleFields(deserializer, inputs)
+      val resolved = resolveExpression(
+        deserializer, LocalRelation(inputs), throws = true)
+      validateNestedTupleFields(resolved)
+      resolved
+    }
+  }
+
+  private def fail(schema: StructType, maxOrdinal: Int): Unit = {
+    throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " +
+      "but failed as the number of fields does not line up.")
+  }
+
+  /**
+   * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column
+   * by position. However, the actual number of columns may differ from the number of Tuple
+   * fields. This method checks that the number of columns matches the number of Tuple fields,
+   * and throws an exception if they do not.
+   */
+  private def validateTopLevelTupleFields(
+      deserializer: Expression, inputs: Seq[Attribute]): Unit = {
+    val ordinals = deserializer.collect {
+      case GetColumnByOrdinal(ordinal, _) => ordinal
+    }.distinct.sorted
+
+    if (ordinals.nonEmpty && ordinals != inputs.indices) {
+      fail(inputs.toStructType, ordinals.last)
+    }
+  }
+
+  /**
+   * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field
+   * by position. However, the actual number of struct fields may differ from the number of
+   * nested Tuple fields. This method checks that the number of struct fields matches the number
+   * of nested Tuple fields, and throws an exception if they do not.
+   */
+  private def validateNestedTupleFields(deserializer: Expression): Unit = {
+    val structChildToOrdinals = deserializer
+      .collect { case g: GetStructField => g }
+      .groupBy(_.child)
+      .mapValues(_.map(_.ordinal).distinct.sorted)
+
+    structChildToOrdinals.foreach { case (expr, ordinals) =>
+      val schema = expr.dataType.asInstanceOf[StructType]
+      if (ordinals != schema.indices) {
+        fail(schema, ordinals.last)
+      }
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index e953eda7843c96974d7763665a5ef997acc84df0..b883546135f07a5dcffae4f5ec35125b4463192f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -366,3 +366,10 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false
 }
+
+case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression
+  with Unevaluable with NonSQLExpression {
+  override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override lazy val resolved = false
+}
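For readers outside the analyzer context, a minimal sketch of how the new leaf node behaves. Like the other placeholders defined in unresolved.scala, it only records a position and an expected type until the `resolveExpression` rule above swaps it for `plan.output(ordinal)`; the assertion scaffolding below is illustrative, not part of the patch:

import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
import org.apache.spark.sql.types.LongType

// "Column #0 of whatever plan this deserializer eventually attaches to". The
// expected data type is carried along so the analyzer can insert an upcast,
// but the node itself is Unevaluable and reports resolved = false, exactly
// like UnresolvedAttribute.
val col0 = GetColumnByOrdinal(0, LongType)
assert(!col0.resolved)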
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 2296946cd7c5e4a1764000d302689591a886d5b5..cc59d06fa3518bc70f69b0007ea175f3619fe1db 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,19 +17,17 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-import java.util.concurrent.ConcurrentMap
-
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
-import org.apache.spark.sql.{AnalysisException, Encoder}
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
 import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -121,15 +119,15 @@ object ExpressionEncoder {
     val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
       if (enc.flat) {
         enc.deserializer.transform {
-          case b: BoundReference => b.copy(ordinal = index)
+          case g: GetColumnByOrdinal => g.copy(ordinal = index)
         }
       } else {
-        val input = BoundReference(index, enc.schema, nullable = true)
+        val input = GetColumnByOrdinal(index, enc.schema)
         val deserialized = enc.deserializer.transformUp {
           case UnresolvedAttribute(nameParts) =>
             assert(nameParts.length == 1)
             UnresolvedExtractValue(input, Literal(nameParts.head))
-          case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
+          case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
         }
         If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
       }
@@ -192,6 +190,26 @@ case class ExpressionEncoder[T](
 
   if (flat) require(serializer.size == 1)
 
+  /**
+   * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the
+   * given schema.
+   *
+   * Note that, ideally an encoder is just a container of serde expressions, and the resolution
+   * and binding should happen inside the query framework. However, in some cases we need to use
+   * an encoder as a function to deserialize internal rows directly (e.g. `Dataset.collect`), and
+   * then this method can do the resolution and binding outside of the query framework.
+   */
+  def resolveAndBind(
+      attrs: Seq[Attribute] = schema.toAttributes,
+      analyzer: Analyzer = SimpleAnalyzer): ExpressionEncoder[T] = {
+    val dummyPlan = CatalystSerde.deserialize(LocalRelation(attrs))(this)
+    val analyzedPlan = analyzer.execute(dummyPlan)
+    analyzer.checkAnalysis(analyzedPlan)
+    val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer
+    val bound = BindReferences.bindReference(resolved, attrs)
+    copy(deserializer = bound)
+  }
+
   @transient
   private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)
 
@@ -201,16 +219,6 @@ case class ExpressionEncoder[T](
   @transient
   private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
 
-  /**
-   * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns
-   * is performed).
-   */
-  def defaultBinding: ExpressionEncoder[T] = {
-    val attrs = schema.toAttributes
-    resolve(attrs, OuterScopes.outerScopes).bind(attrs)
-  }
-
   /**
    * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
    * of this object.
@@ -236,7 +244,7 @@ case class ExpressionEncoder[T](
 
   /**
    * Returns an object of type `T`, extracting the required values from the provided row. Note that
-   * you must `resolve` and `bind` an encoder to a specific schema before you can call this
+   * you must `resolveAndBind` an encoder to a specific schema before you can call this
    * function.
    */
   def fromRow(row: InternalRow): T = try {
@@ -259,94 +267,6 @@ case class ExpressionEncoder[T](
     })
   }
 
-  /**
-   * Validates `deserializer` to make sure it can be resolved by given schema, and produce
-   * friendly error messages to explain why it fails to resolve if there is something wrong.
-   */
-  def validate(schema: Seq[Attribute]): Unit = {
-    def fail(st: StructType, maxOrdinal: Int): Unit = {
-      throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
-        "but failed as the number of fields does not line up.\n" +
-        " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" +
-        " - Target schema: " + this.schema.simpleString)
-    }
-
-    // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all
-    // `BoundReference`, make sure their ordinals are all valid.
-    var maxOrdinal = -1
-    deserializer.foreach {
-      case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
-      case _ =>
-    }
-    if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) {
-      fail(StructType.fromAttributes(schema), maxOrdinal)
-    }
-
-    // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of
-    // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid.
-    // Note that, `BoundReference` contains the expected type, but here we need the actual type, so
-    // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after
-    // we resolve the `fromRowExpression`.
-    val resolved = SimpleAnalyzer.resolveExpression(
-      deserializer,
-      LocalRelation(schema),
-      throws = true)
-
-    val unbound = resolved transform {
-      case b: BoundReference => schema(b.ordinal)
-    }
-
-    val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int]
-    unbound.foreach {
-      case g: GetStructField =>
-        val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1)
-        if (maxOrdinal < g.ordinal) {
-          exprToMaxOrdinal.update(g.child, g.ordinal)
-        }
-      case _ =>
-    }
-    exprToMaxOrdinal.foreach {
-      case (expr, maxOrdinal) =>
-        val schema = expr.dataType.asInstanceOf[StructType]
-        if (maxOrdinal != schema.length - 1) {
-          fail(schema, maxOrdinal)
-        }
-    }
-  }
-
-  /**
-   * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema.
-   */
-  def resolve(
-      schema: Seq[Attribute],
-      outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
-    // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
-    // analysis, go through optimizer, etc.
-    val plan = Project(
-      Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
-      LocalRelation(schema))
-    val analyzedPlan = SimpleAnalyzer.execute(plan)
-    SimpleAnalyzer.checkAnalysis(analyzedPlan)
-    copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
-  }
-
-  /**
-   * Returns a copy of this encoder where the `deserializer` has been bound to the
-   * ordinals of the given schema. Note that you need to first call resolve before bind.
-   */
-  def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    copy(deserializer = BindReferences.bindReference(deserializer, schema))
-  }
-
-  /**
-   * Returns a new encoder with input columns shifted by `delta` ordinals
-   */
-  def shift(delta: Int): ExpressionEncoder[T] = {
-    copy(deserializer = deserializer transform {
-      case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
-    })
-  }
-
   protected val attrs = serializer.flatMap(_.collect {
     case _: UnresolvedAttribute => ""
     case a: Attribute => s"#${a.exprId}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 0de9166aa29aaec87ed8cb09ea2842839c9b712a..3c6ae1c5ccd8ac9786f765d04872db6bc5565cd4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -210,12 +211,7 @@ object RowEncoder {
         case p: PythonUserDefinedType => p.sqlType
         case other => other
       }
-      val field = BoundReference(i, dt, f.nullable)
-      If(
-        IsNull(field),
-        Literal.create(null, externalDataTypeFor(dt)),
-        deserializerFor(field)
-      )
+      deserializerFor(GetColumnByOrdinal(i, dt))
     }
     CreateExternalRow(fields, schema)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 98ce5dd2efd91ba1ea79d07ac8480619a9933bac..55d8adf0408fd83bb23513860e39733de84d96f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -30,26 +30,13 @@ object CatalystSerde {
     DeserializeToObject(deserializer, generateObjAttr[T], child)
   }
 
-  def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
-    val deserializer = UnresolvedDeserializer(encoder.deserializer)
-    DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
-  }
-
   def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
     SerializeFromObject(encoderFor[T].namedExpressions, child)
   }
 
-  def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
-    SerializeFromObject(encoder.namedExpressions, child)
-  }
-
   def generateObjAttr[T : Encoder]: Attribute = {
     AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
   }
-
-  def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
-    AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
-  }
 }
 
 /**
@@ -128,16 +115,16 @@ object MapPartitionsInR {
       schema: StructType,
       encoder: ExpressionEncoder[Row],
       child: LogicalPlan): LogicalPlan = {
-    val deserialized = CatalystSerde.deserialize(child, encoder)
+    val deserialized = CatalystSerde.deserialize(child)(encoder)
     val mapped = MapPartitionsInR(
       func,
       packageNames,
       broadcastVars,
       encoder.schema,
       schema,
-      CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
+      CatalystSerde.generateObjAttr(RowEncoder(schema)),
       deserialized)
-    CatalystSerde.serialize(mapped, RowEncoder(schema))
+    CatalystSerde.serialize(mapped)(RowEncoder(schema))
   }
 }
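A quick round-trip sketch of the new `resolveAndBind` entry point, mirroring what the updated suites below do (the schema and values here are illustrative):

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.encoders.RowEncoder
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

val schema = new StructType().add("i", IntegerType).add("s", StringType)

// Serialization needs no resolution, so toRow works on a raw encoder. The
// deserializer, however, now contains GetColumnByOrdinal leaves, which
// resolveAndBind() resolves against the schema (by running the analyzer over
// a dummy DeserializeToObject plan) and binds to ordinals in a single step.
val encoder = RowEncoder(schema).resolveAndBind()
val roundTripped = encoder.fromRow(encoder.toRow(Row(1, "a")))
assert(roundTripped == Row(1, "a"))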
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 3ad0dae767be3a16422a160c4336bf8423f692eb..7251202c7bd5828f3aa6094424791b05faf07721 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -41,17 +41,17 @@ class EncoderResolutionSuite extends PlanTest {
 
     // int type can be up cast to long type
     val attrs1 = Seq('a.string, 'b.int)
-    encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
+    encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1))
 
     // int type can be up cast to string type
     val attrs2 = Seq('a.int, 'b.long)
-    encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
+    encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L))
   }
 
   test("real type doesn't match encoder schema but they are compatible: nested product") {
     val encoder = ExpressionEncoder[ComplexClass]
     val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
-    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
+    encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
   }
 
   test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
@@ -59,7 +59,7 @@ class EncoderResolutionSuite extends PlanTest {
       ExpressionEncoder[StringLongClass],
       ExpressionEncoder[Long])
     val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
-    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+    encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
   }
 
   test("nullability of array type element should not fail analysis") {
@@ -67,7 +67,7 @@ class EncoderResolutionSuite extends PlanTest {
     val attrs = 'a.array(IntegerType) :: Nil
 
     // It should pass analysis
-    val bound = encoder.resolve(attrs, null).bind(attrs)
+    val bound = encoder.resolveAndBind(attrs)
 
     // If no null values appear, it should works fine
     bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
@@ -84,20 +84,16 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string, 'b.long, 'c.int)
-      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
         "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
-          "but failed as the number of fields does not line up.\n" +
-          " - Input schema: struct<a:string,b:bigint,c:int>\n" +
-          " - Target schema: struct<_1:string,_2:bigint>")
+          "but failed as the number of fields does not line up.")
     }
 
     {
       val attrs = Seq('a.string)
-      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
         "Try to map struct<a:string> to Tuple2, " +
-          "but failed as the number of fields does not line up.\n" +
-          " - Input schema: struct<a:string>\n" +
-          " - Target schema: struct<_1:string,_2:bigint>")
+          "but failed as the number of fields does not line up.")
     }
   }
 
@@ -106,26 +102,22 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
-      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
         "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
-          "but failed as the number of fields does not line up.\n" +
-          " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" +
-          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+          "but failed as the number of fields does not line up.")
     }
 
     {
       val attrs = Seq('a.string, 'b.struct('x.long))
-      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
         "Try to map struct<x:bigint> to Tuple2, " +
-          "but failed as the number of fields does not line up.\n" +
-          " - Input schema: struct<a:string,b:struct<x:bigint>>\n" +
-          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+          "but failed as the number of fields does not line up.")
     }
   }
 
   test("throw exception if real type is not compatible with encoder schema") {
     val msg1 = intercept[AnalysisException] {
-      ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
+      ExpressionEncoder[StringIntClass].resolveAndBind(Seq('a.string, 'b.long))
     }.message
     assert(msg1 ==
       s"""
@@ -138,7 +130,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     val msg2 = intercept[AnalysisException] {
       val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT)
-      ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null)
+      ExpressionEncoder[ComplexClass].resolveAndBind(Seq('a.long, 'b.struct(structType)))
     }.message
     assert(msg2 ==
       s"""
@@ -171,7 +163,7 @@ class EncoderResolutionSuite extends PlanTest {
     val to = ExpressionEncoder[U]
     val catalystType = from.schema.head.dataType.simpleString
     test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") {
-      to.resolve(from.schema.toAttributes, null)
+      to.resolveAndBind(from.schema.toAttributes)
     }
   }
 
@@ -180,7 +172,7 @@ class EncoderResolutionSuite extends PlanTest {
     val to = ExpressionEncoder[U]
     val catalystType = from.schema.head.dataType.simpleString
     test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") {
-      intercept[AnalysisException](to.resolve(from.schema.toAttributes, null))
+      intercept[AnalysisException](to.resolveAndBind(from.schema.toAttributes))
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 232dcc9ee51ca1bb77247c0ba7705cc810a19931..a1f9259f139edab305b8c157f7bbcf4c761b0393 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.TypeTag
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
@@ -334,7 +335,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
       val encoder = implicitly[ExpressionEncoder[T]]
       val row = encoder.toRow(input)
       val schema = encoder.schema.toAttributes
-      val boundEncoder = encoder.defaultBinding
+      val boundEncoder = encoder.resolveAndBind()
       val convertedBack = try boundEncoder.fromRow(row) catch {
         case e: Exception =>
           fail(
@@ -350,12 +351,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
       }
 
       // Test the correct resolution of serialization / deserialization.
-      val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
-      val inputPlan = LocalRelation(attr)
-      val plan =
-        Project(Alias(encoder.deserializer, "obj")() :: Nil,
-          Project(encoder.namedExpressions,
-            inputPlan))
+      val attr = AttributeReference("obj", encoder.deserializer.dataType)()
+      val plan = LocalRelation(attr).serialize[T].deserialize[T]
       assertAnalysisSuccess(plan)
 
       val isCorrect = (input, convertedBack) match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 39fcc7225b37e9b02a48c687ad4bc49cbda467c6..6f1bc80c1cdda371c75f4626fafb8bcb20a25539 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -135,7 +135,7 @@ class RowEncoderSuite extends SparkFunSuite {
         .add("string", StringType)
         .add("double", DoubleType))
 
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     val input: Row = Row((100, "test", 0.123))
     val row = encoder.toRow(input)
@@ -152,7 +152,7 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("scala_decimal", DecimalType.SYSTEM_DEFAULT)
       .add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT)
 
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
 
     val javaDecimal = new java.math.BigDecimal("1234.5678")
     val scalaDecimal = BigDecimal("1234.5678")
@@ -169,7 +169,7 @@ class RowEncoderSuite extends SparkFunSuite {
 
   test("RowEncoder should preserve decimal precision and scale") {
     val schema = new StructType().add("decimal", DecimalType(10, 5), false)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     val decimal = Decimal("67123.45")
     val input = Row(decimal)
     val row = encoder.toRow(input)
@@ -179,7 +179,7 @@ class RowEncoderSuite extends SparkFunSuite {
 
   test("RowEncoder should preserve schema nullability") {
     val schema = new StructType().add("int", IntegerType, nullable = false)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     assert(encoder.serializer.length == 1)
     assert(encoder.serializer.head.dataType == IntegerType)
     assert(encoder.serializer.head.nullable == false)
@@ -195,7 +195,7 @@ class RowEncoderSuite extends SparkFunSuite {
           new StructType().add("int", IntegerType, nullable = false),
           nullable = false),
         nullable = false)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     assert(encoder.serializer.length == 1)
     assert(encoder.serializer.head.dataType ==
       new StructType()
@@ -212,7 +212,7 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("array", ArrayType(IntegerType))
       .add("nestedArray", ArrayType(ArrayType(StringType)))
       .add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType))))
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     val input = Row(
       Array(1, 2, null),
       Array(Array("abc", null), null),
@@ -226,7 +226,7 @@ class RowEncoderSuite extends SparkFunSuite {
 
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
-      val encoder = RowEncoder(schema)
+      val encoder = RowEncoder(schema).resolveAndBind()
       val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get
 
       var input: Row = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 369b772d322c01f7b1af4142228f9a5fa29bb1fe..96c871d0343551a71c0b13f48de62703ced71cdf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -192,24 +192,24 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
-   * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
-   * same object type (that will be possibly resolved to a different schema).
+   * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], so here we
+   * explicitly convert the passed-in encoder to an [[ExpressionEncoder]], and mark it implicit so
+   * that we can use it when constructing new [[Dataset]] objects that have the same object type
+   * (which will possibly be resolved to a different schema).
    */
-  private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder)
-  unresolvedTEncoder.validate(logicalPlan.output)
-
-  /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
-  private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
-    unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+  private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
 
   /**
-   * The encoder where the expressions used to construct an object from an input row have been
-   * bound to the ordinals of this [[Dataset]]'s output schema.
+   * An encoder is used mostly as a container of serde expressions in [[Dataset]]: we build
+   * logical plans from these serde expressions and execute them within the query framework.
+   * However, for performance reasons we may want to use the encoder as a function to deserialize
+   * internal rows to custom objects, e.g. in `collect`. Here we resolve and bind the encoder so
+   * that we can call its `fromRow` method later.
    */
-  private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
+  private val boundEnc =
+    exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
 
-  private implicit def classTag = unresolvedTEncoder.clsTag
+  private implicit def classTag = exprEnc.clsTag
 
   // sqlContext must be val because a stable identifier is expected when you import implicits
   @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
@@ -761,7 +761,7 @@ class Dataset[T] private[sql](
     // Note that we do this before joining them, to enable the join operator to return null for one
     // side, in cases like outer-join.
     val left = {
-      val combined = if (this.unresolvedTEncoder.flat) {
+      val combined = if (this.exprEnc.flat) {
         assert(joined.left.output.length == 1)
         Alias(joined.left.output.head, "_1")()
       } else {
@@ -771,7 +771,7 @@ class Dataset[T] private[sql](
     }
 
     val right = {
-      val combined = if (other.unresolvedTEncoder.flat) {
+      val combined = if (other.exprEnc.flat) {
         assert(joined.right.output.length == 1)
         Alias(joined.right.output.head, "_2")()
       } else {
@@ -784,14 +784,14 @@ class Dataset[T] private[sql](
     // combine the outputs of each join side.
     val conditionExpr = joined.condition.get transformUp {
       case a: Attribute if joined.left.outputSet.contains(a) =>
-        if (this.unresolvedTEncoder.flat) {
+        if (this.exprEnc.flat) {
           left.output.head
         } else {
           val index = joined.left.output.indexWhere(_.exprId == a.exprId)
           GetStructField(left.output.head, index)
         }
       case a: Attribute if joined.right.outputSet.contains(a) =>
-        if (other.unresolvedTEncoder.flat) {
+        if (other.exprEnc.flat) {
           right.output.head
         } else {
           val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@@ -800,7 +800,7 @@ class Dataset[T] private[sql](
     }
 
     implicit val tuple2Encoder: Encoder[(T, U)] =
-      ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
+      ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
 
     withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr)))
   }
@@ -1024,7 +1024,7 @@ class Dataset[T] private[sql](
       sparkSession,
       Project(
         c1.withInputType(
-          unresolvedTEncoder.deserializer,
+          exprEnc.deserializer,
           logicalPlan.output).named :: Nil,
         logicalPlan),
       implicitly[Encoder[U1]])
@@ -1038,7 +1038,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named)
+      columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named)
     val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
     new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
   }
@@ -2153,14 +2153,14 @@ class Dataset[T] private[sql](
    */
   def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
     withNewExecutionId {
-      val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
+      val values = queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
       java.util.Arrays.asList(values : _*)
     }
   }
 
   private def collect(needCallback: Boolean): Array[T] = {
     def execute(): Array[T] = withNewExecutionId {
-      queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
+      queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
     }
 
     if (needCallback) {
@@ -2184,7 +2184,7 @@ class Dataset[T] private[sql](
    */
   def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
     withNewExecutionId {
-      queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+      queryExecution.executedPlan.executeToIterator().map(boundEnc.fromRow).asJava
     }
   }
 
@@ -2322,7 +2322,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   lazy val rdd: RDD[T] = {
-    val objectType = unresolvedTEncoder.deserializer.dataType
+    val objectType = exprEnc.deserializer.dataType
     val deserialized = CatalystSerde.deserialize[T](logicalPlan)
     sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows =>
       rows.map(_.get(0, objectType).asInstanceOf[T])
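The user-visible effect of moving this validation into the analyzer, sketched against the public API (assumes a spark-shell style session where `spark.implicits._` is in scope; the column names are illustrative):

import spark.implicits._

val df = Seq(("a", 1)).toDF("a", "b")

// Mapping two columns onto Tuple3 fails during analysis with the shortened
// message asserted in DatasetSuite below:
//   Try to map struct<a:string,b:int> to Tuple3, but failed as the number of
//   fields does not line up.
df.as[(String, Int, Long)]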
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 53f4ea647c85322cc22ab6b0a189be8963362a14..a6867a67eeade4b7a5e575d584dbfe67e0aa0307 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -42,17 +42,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
 
-  // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders
-  // when constructing new logical plans that will operate on the output of the current
-  // queryexecution.
-
-  private implicit val unresolvedKEncoder = encoderFor(kEncoder)
-  private implicit val unresolvedVEncoder = encoderFor(vEncoder)
-
-  private val resolvedKEncoder =
-    unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes)
-  private val resolvedVEncoder =
-    unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
+  // Similar to [[Dataset]], we convert the passed-in encoders to `ExpressionEncoder`s explicitly.
+  private implicit val kExprEnc = encoderFor(kEncoder)
+  private implicit val vExprEnc = encoderFor(vEncoder)
 
   private def logicalPlan = queryExecution.analyzed
   private def sparkSession = queryExecution.sparkSession
@@ -67,7 +59,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] =
     new KeyValueGroupedDataset(
       encoderFor[L],
-      unresolvedVEncoder,
+      vExprEnc,
       queryExecution,
       dataAttributes,
       groupingAttributes)
@@ -187,7 +179,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
     val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
 
-    implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder)
+    implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc)
     flatMapGroups(func)
   }
 
@@ -209,8 +201,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named)
-    val keyColumn = if (resolvedKEncoder.flat) {
+      columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named)
+    val keyColumn = if (kExprEnc.flat) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
     } else {
@@ -222,7 +214,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
     new Dataset(
       sparkSession,
       execution,
-      ExpressionEncoder.tuple(unresolvedKEncoder +: encoders))
+      ExpressionEncoder.tuple(kExprEnc +: encoders))
   }
 
   /**
@@ -287,7 +279,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   def cogroup[U, R : Encoder](
       other: KeyValueGroupedDataset[K, U])(
       f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
-    implicit val uEncoder = other.unresolvedVEncoder
+    implicit val uEncoder = other.vExprEnc
     Dataset[R](
       sparkSession,
       CoGroup(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 58850a7d4b49f6b9d22a68ebb1e2eb064d614cc6..49b6eab8db5b0d2080ad702d8796004b7615b6af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -215,7 +215,7 @@ class RelationalGroupedDataset protected[sql](
   def agg(expr: Column, exprs: Column*): DataFrame = {
     toDF((expr +: exprs).map {
       case typed: TypedColumn[_, _] =>
-        typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr
+        typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr
       case c => c.expr
     })
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 8f94184764c0fc0b939090d0fefdf58f13a6ee27..ecb56e2a2848aa3970b1eaeb7de97f13733470e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -33,9 +33,9 @@ object TypedAggregateExpression {
       aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
     val bufferEncoder = encoderFor[BUF]
     val bufferSerializer = bufferEncoder.namedExpressions
-    val bufferDeserializer = bufferEncoder.deserializer.transform {
-      case b: BoundReference => bufferSerializer(b.ordinal).toAttribute
-    }
+    val bufferDeserializer = UnresolvedDeserializer(
+      bufferEncoder.deserializer,
+      bufferSerializer.map(_.toAttribute))
 
     val outputEncoder = encoderFor[OUT]
     val outputType = if (outputEncoder.flat) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d89e98645b3670a0d8a60a18806cf8463f70b2bb..4dbd1665e4bbc45f09aae6154ac58d6208621a8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -924,7 +924,7 @@ object functions {
    * @since 1.5.0
    */
   def broadcast[T](df: Dataset[T]): Dataset[T] = {
-    Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.unresolvedTEncoder)
+    Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index df8f4b0610af284164c1586e3824b315fd161982..d1c232974e9cea742847af0bac3e8a58c58e825e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -566,18 +566,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     }.message
     assert(message ==
       "Try to map struct<a:string,b:int> to Tuple3, " +
-        "but failed as the number of fields does not line up.\n" +
-        " - Input schema: struct<a:string,b:int>\n" +
-        " - Target schema: struct<_1:string,_2:int,_3:bigint>")
+        "but failed as the number of fields does not line up.")
 
     val message2 = intercept[AnalysisException] {
       ds.as[Tuple1[String]]
     }.message
     assert(message2 ==
       "Try to map struct<a:string,b:int> to Tuple1, " +
-        "but failed as the number of fields does not line up.\n" +
-        " - Input schema: struct<a:string,b:int>\n" +
-        " - Target schema: struct<_1:string>")
+        "but failed as the number of fields does not line up.")
   }
 
   test("SPARK-13440: Resolving option fields") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index a1a9b66c1feec6dd963f1dcfee52001be1b5cc98..9c044f4e8fd6e6770d182e5ea4ffeebde5f65ba4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest {
       expectedAnswer: T*): Unit = {
     checkAnswer(
       ds.toDF(),
-      spark.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
+      spark.createDataset(expectedAnswer)(ds.exprEnc).toDF().collect().toSeq)
 
     checkDecoding(ds, expectedAnswer: _*)
   }
@@ -94,8 +94,8 @@ abstract class QueryTest extends PlanTest {
         fail(
           s"""
              |Exception collecting dataset as objects
-             |${ds.resolvedTEncoder}
-             |${ds.resolvedTEncoder.deserializer.treeString}
+             |${ds.exprEnc}
+             |${ds.exprEnc.deserializer.treeString}
              |${ds.queryExecution}
            """.stripMargin, e)
     }
@@ -114,7 +114,7 @@ abstract class QueryTest extends PlanTest {
       fail(
         s"""Decoded objects do not match expected objects:
             |$comparison
-            |${ds.resolvedTEncoder.deserializer.treeString}
+            |${ds.exprEnc.deserializer.treeString}
          """.stripMargin)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
index 6f10e4b80577ab1cf9011d855efb74ddb8e6be48..80340b5552c6d01f4f3b06dee3924b9857cd8290 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
@@ -27,7 +27,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
 
   test("basic") {
     val schema = new StructType().add("i", IntegerType).add("s", StringType)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
     val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
       Seq('i.int.at(0)), schema.toAttributes)
@@ -45,7 +45,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
 
   test("group by 2 columns") {
     val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
 
     val input = Seq(
       Row(1, 2L, "a"),
@@ -72,7 +72,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
 
   test("do nothing to the value iterator") {
     val schema = new StructType().add("i", IntegerType).add("s", StringType)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
     val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
       Seq('i.int.at(0)), schema.toAttributes)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index dd8672aa641d291d90a0e880d0d09b8dc0aaab36..194c3e7307255c14cab1bb38a38355cb8793604b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -110,7 +110,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
   object CheckAnswer {
     def apply[A : Encoder](data: A*): CheckAnswerRows = {
       val encoder = encoderFor[A]
-      val toExternalRow = RowEncoder(encoder.schema)
+      val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
       CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false)
     }
 
@@ -124,7 +124,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
   object CheckLastBatch {
     def apply[A : Encoder](data: A*): CheckAnswerRows = {
       val encoder = encoderFor[A]
-      val toExternalRow = RowEncoder(encoder.schema)
+      val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
       CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true)
     }
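Finally, a migration summary for the removed encoder API, as applied throughout this patch (a sketch for call sites outside this diff, not an exhaustive list):

// Old (removed in this patch)                       New
// encoder.defaultBinding                         -> encoder.resolveAndBind()
// encoder.resolve(attrs, OuterScopes.outerScopes)
//        .bind(attrs)                            -> encoder.resolveAndBind(attrs)
// encoder.validate(attrs)                        -> the arity checks now run inside
//                                                   resolveAndBind, via the analyzer
// encoder.shift(delta)                           -> removed without a direct replacement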