From 190ff274fd71662023a804cf98400c71f9f7da4f Mon Sep 17 00:00:00 2001
From: Wenchen Fan <wenchen@databricks.com>
Date: Fri, 3 Jun 2016 00:43:02 -0700
Subject: [PATCH] [SPARK-15494][SQL] encoder code cleanup

## What changes were proposed in this pull request?

Our encoder framework has been evolved a lot, this PR tries to clean up the code to make it more readable and emphasise the concept that encoder should be used as a container of serde expressions.

1. move validation logic to analyzer instead of encoder
2. only have a `resolveAndBind` method in encoder instead of `resolve` and `bind`, as we don't have the encoder life cycle concept anymore.
3. `Dataset` don't need to keep a resolved encoder, as there is no such concept anymore. bound encoder is still needed to do serialization outside of query framework.
4. Using `BoundReference` to represent an unresolved field in deserializer expression is kind of weird, this PR adds a `GetColumnByOrdinal` for this purpose. (serializer expression still use `BoundReference`, we can replace it with `GetColumnByOrdinal` in follow-ups)

## How was this patch tested?

existing test

Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <lian@databricks.com>

Closes #13269 from cloud-fan/clean-encoder.
---
 .../linalg/UDTSerializationBenchmark.scala    |   2 +-
 .../scala/org/apache/spark/sql/Encoders.scala |   3 +-
 .../sql/catalyst/JavaTypeInference.scala      |   6 +-
 .../spark/sql/catalyst/ScalaReflection.scala  | 307 +++++++++---------
 .../sql/catalyst/analysis/Analyzer.scala      |  53 ++-
 .../sql/catalyst/analysis/unresolved.scala    |   7 +
 .../catalyst/encoders/ExpressionEncoder.scala | 134 ++------
 .../sql/catalyst/encoders/RowEncoder.scala    |   8 +-
 .../sql/catalyst/plans/logical/object.scala   |  19 +-
 .../encoders/EncoderResolutionSuite.scala     |  42 +--
 .../encoders/ExpressionEncoderSuite.scala     |  11 +-
 .../catalyst/encoders/RowEncoderSuite.scala   |  14 +-
 .../scala/org/apache/spark/sql/Dataset.scala  |  48 +--
 .../spark/sql/KeyValueGroupedDataset.scala    |  26 +-
 .../spark/sql/RelationalGroupedDataset.scala  |   2 +-
 .../aggregate/TypedAggregateExpression.scala  |   6 +-
 .../org/apache/spark/sql/functions.scala      |   2 +-
 .../org/apache/spark/sql/DatasetSuite.scala   |   8 +-
 .../org/apache/spark/sql/QueryTest.scala      |   8 +-
 .../sql/execution/GroupedIteratorSuite.scala  |   6 +-
 .../spark/sql/streaming/StreamTest.scala      |   4 +-
 21 files changed, 324 insertions(+), 392 deletions(-)

diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
index be7110ad6b..8b439e6b7a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/UDTSerializationBenchmark.scala
@@ -29,7 +29,7 @@ object UDTSerializationBenchmark {
     val iters = 1e2.toInt
     val numRows = 1e3.toInt
 
-    val encoder = ExpressionEncoder[Vector].defaultBinding
+    val encoder = ExpressionEncoder[Vector].resolveAndBind()
 
     val vectors = (1 to numRows).map { i =>
       Vectors.dense(Array.fill(1e5.toInt)(1.0 * i))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
index fa96f8223d..673c587b18 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala
@@ -23,6 +23,7 @@ import scala.reflect.{classTag, ClassTag}
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
 import org.apache.spark.sql.catalyst.expressions.objects.{DecodeUsingSerializer, EncodeUsingSerializer}
 import org.apache.spark.sql.catalyst.expressions.BoundReference
@@ -208,7 +209,7 @@ object Encoders {
           BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)),
       deserializer =
         DecodeUsingSerializer[T](
-          BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo),
+          GetColumnByOrdinal(0, BinaryType), classTag[T], kryo = useKryo),
       clsTag = classTag[T]
     )
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 1fe143494a..b3a233ae39 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -25,7 +25,7 @@ import scala.language.existentials
 
 import com.google.common.reflect.TypeToken
 
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -177,8 +177,8 @@ object JavaTypeInference {
       .map(p => UnresolvedExtractValue(p, expressions.Literal(part)))
       .getOrElse(UnresolvedAttribute(part))
 
-    /** Returns the current path or `BoundReference`. */
-    def getPath: Expression = path.getOrElse(BoundReference(0, inferDataType(typeToken)._1, true))
+    /** Returns the current path or `GetColumnByOrdinal`. */
+    def getPath: Expression = path.getOrElse(GetColumnByOrdinal(0, inferDataType(typeToken)._1))
 
     typeToken.getRawType match {
       case c if !inferExternalType(c).isInstanceOf[ObjectType] => getPath
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 4750861817..78c145d4fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.catalyst
 
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{GetColumnByOrdinal, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
@@ -156,17 +156,17 @@ object ScalaReflection extends ScalaReflection {
         walkedTypePath: Seq[String]): Expression = {
       val newPath = path
         .map(p => GetStructField(p, ordinal))
-        .getOrElse(BoundReference(ordinal, dataType, false))
+        .getOrElse(GetColumnByOrdinal(ordinal, dataType))
       upCastToExpectedType(newPath, dataType, walkedTypePath)
     }
 
-    /** Returns the current path or `BoundReference`. */
+    /** Returns the current path or `GetColumnByOrdinal`. */
     def getPath: Expression = {
       val dataType = schemaFor(tpe).dataType
       if (path.isDefined) {
         path.get
       } else {
-        upCastToExpectedType(BoundReference(0, dataType, true), dataType, walkedTypePath)
+        upCastToExpectedType(GetColumnByOrdinal(0, dataType), dataType, walkedTypePath)
       }
     }
 
@@ -421,7 +421,7 @@ object ScalaReflection extends ScalaReflection {
   def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = {
     val tpe = localTypeOf[T]
     val clsName = getClassNameFromType(tpe)
-    val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil
+    val walkedTypePath = s"""- root class: "$clsName"""" :: Nil
     serializerFor(inputObject, tpe, walkedTypePath) match {
       case expressions.If(_, _, s: CreateNamedStruct) if definedByConstructorParams(tpe) => s
       case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil)
@@ -449,157 +449,156 @@ object ScalaReflection extends ScalaReflection {
       }
     }
 
-    if (!inputObject.dataType.isInstanceOf[ObjectType]) {
-      inputObject
-    } else {
-      tpe match {
-        case t if t <:< localTypeOf[Option[_]] =>
-          val TypeRef(_, _, Seq(optType)) = t
-          val className = getClassNameFromType(optType)
-          val newPath = s"""- option value class: "$className"""" +: walkedTypePath
-          val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
-          serializerFor(unwrapped, optType, newPath)
-
-        // Since List[_] also belongs to localTypeOf[Product], we put this case before
-        // "case t if definedByConstructorParams(t)" to make sure it will match to the
-        // case "localTypeOf[Seq[_]]"
-        case t if t <:< localTypeOf[Seq[_]] =>
-          val TypeRef(_, _, Seq(elementType)) = t
-          toCatalystArray(inputObject, elementType)
-
-        case t if t <:< localTypeOf[Array[_]] =>
-          val TypeRef(_, _, Seq(elementType)) = t
-          toCatalystArray(inputObject, elementType)
-
-        case t if t <:< localTypeOf[Map[_, _]] =>
-          val TypeRef(_, _, Seq(keyType, valueType)) = t
-
-          val keys =
-            Invoke(
-              Invoke(inputObject, "keysIterator",
-                ObjectType(classOf[scala.collection.Iterator[_]])),
-              "toSeq",
-              ObjectType(classOf[scala.collection.Seq[_]]))
-          val convertedKeys = toCatalystArray(keys, keyType)
-
-          val values =
-            Invoke(
-              Invoke(inputObject, "valuesIterator",
-                ObjectType(classOf[scala.collection.Iterator[_]])),
-              "toSeq",
-              ObjectType(classOf[scala.collection.Seq[_]]))
-          val convertedValues = toCatalystArray(values, valueType)
-
-          val Schema(keyDataType, _) = schemaFor(keyType)
-          val Schema(valueDataType, valueNullable) = schemaFor(valueType)
-          NewInstance(
-            classOf[ArrayBasedMapData],
-            convertedKeys :: convertedValues :: Nil,
-            dataType = MapType(keyDataType, valueDataType, valueNullable))
-
-        case t if t <:< localTypeOf[String] =>
-          StaticInvoke(
-            classOf[UTF8String],
-            StringType,
-            "fromString",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.sql.Timestamp] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            TimestampType,
-            "fromJavaTimestamp",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.sql.Date] =>
-          StaticInvoke(
-            DateTimeUtils.getClass,
-            DateType,
-            "fromJavaDate",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[BigDecimal] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.SYSTEM_DEFAULT,
-            "apply",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.math.BigDecimal] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.SYSTEM_DEFAULT,
-            "apply",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.math.BigInteger] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.BigIntDecimal,
-            "apply",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[scala.math.BigInt] =>
-          StaticInvoke(
-            Decimal.getClass,
-            DecimalType.BigIntDecimal,
-            "apply",
-            inputObject :: Nil)
-
-        case t if t <:< localTypeOf[java.lang.Integer] =>
-          Invoke(inputObject, "intValue", IntegerType)
-        case t if t <:< localTypeOf[java.lang.Long] =>
-          Invoke(inputObject, "longValue", LongType)
-        case t if t <:< localTypeOf[java.lang.Double] =>
-          Invoke(inputObject, "doubleValue", DoubleType)
-        case t if t <:< localTypeOf[java.lang.Float] =>
-          Invoke(inputObject, "floatValue", FloatType)
-        case t if t <:< localTypeOf[java.lang.Short] =>
-          Invoke(inputObject, "shortValue", ShortType)
-        case t if t <:< localTypeOf[java.lang.Byte] =>
-          Invoke(inputObject, "byteValue", ByteType)
-        case t if t <:< localTypeOf[java.lang.Boolean] =>
-          Invoke(inputObject, "booleanValue", BooleanType)
-
-        case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
-          val udt = getClassFromType(t)
-            .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
-          val obj = NewInstance(
-            udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
-            Nil,
-            dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-          Invoke(obj, "serialize", udt, inputObject :: Nil)
-
-        case t if UDTRegistration.exists(getClassNameFromType(t)) =>
-          val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
-            .asInstanceOf[UserDefinedType[_]]
-          val obj = NewInstance(
-            udt.getClass,
-            Nil,
-            dataType = ObjectType(udt.getClass))
-          Invoke(obj, "serialize", udt, inputObject :: Nil)
-
-        case t if definedByConstructorParams(t) =>
-          val params = getConstructorParameters(t)
-          val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
-            if (javaKeywords.contains(fieldName)) {
-              throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
-                "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
-            }
+    tpe match {
+      case _ if !inputObject.dataType.isInstanceOf[ObjectType] => inputObject
 
-            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
-            val clsName = getClassNameFromType(fieldType)
-            val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
-          })
-          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
-
-        case other =>
-          throw new UnsupportedOperationException(
-            s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
-      }
+      case t if t <:< localTypeOf[Option[_]] =>
+        val TypeRef(_, _, Seq(optType)) = t
+        val className = getClassNameFromType(optType)
+        val newPath = s"""- option value class: "$className"""" +: walkedTypePath
+        val unwrapped = UnwrapOption(dataTypeFor(optType), inputObject)
+        serializerFor(unwrapped, optType, newPath)
+
+      // Since List[_] also belongs to localTypeOf[Product], we put this case before
+      // "case t if definedByConstructorParams(t)" to make sure it will match to the
+      // case "localTypeOf[Seq[_]]"
+      case t if t <:< localTypeOf[Seq[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        toCatalystArray(inputObject, elementType)
+
+      case t if t <:< localTypeOf[Array[_]] =>
+        val TypeRef(_, _, Seq(elementType)) = t
+        toCatalystArray(inputObject, elementType)
+
+      case t if t <:< localTypeOf[Map[_, _]] =>
+        val TypeRef(_, _, Seq(keyType, valueType)) = t
+
+        val keys =
+          Invoke(
+            Invoke(inputObject, "keysIterator",
+              ObjectType(classOf[scala.collection.Iterator[_]])),
+            "toSeq",
+            ObjectType(classOf[scala.collection.Seq[_]]))
+        val convertedKeys = toCatalystArray(keys, keyType)
+
+        val values =
+          Invoke(
+            Invoke(inputObject, "valuesIterator",
+              ObjectType(classOf[scala.collection.Iterator[_]])),
+            "toSeq",
+            ObjectType(classOf[scala.collection.Seq[_]]))
+        val convertedValues = toCatalystArray(values, valueType)
+
+        val Schema(keyDataType, _) = schemaFor(keyType)
+        val Schema(valueDataType, valueNullable) = schemaFor(valueType)
+        NewInstance(
+          classOf[ArrayBasedMapData],
+          convertedKeys :: convertedValues :: Nil,
+          dataType = MapType(keyDataType, valueDataType, valueNullable))
+
+      case t if t <:< localTypeOf[String] =>
+        StaticInvoke(
+          classOf[UTF8String],
+          StringType,
+          "fromString",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.sql.Timestamp] =>
+        StaticInvoke(
+          DateTimeUtils.getClass,
+          TimestampType,
+          "fromJavaTimestamp",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.sql.Date] =>
+        StaticInvoke(
+          DateTimeUtils.getClass,
+          DateType,
+          "fromJavaDate",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[BigDecimal] =>
+        StaticInvoke(
+          Decimal.getClass,
+          DecimalType.SYSTEM_DEFAULT,
+          "apply",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.math.BigDecimal] =>
+        StaticInvoke(
+          Decimal.getClass,
+          DecimalType.SYSTEM_DEFAULT,
+          "apply",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.math.BigInteger] =>
+        StaticInvoke(
+          Decimal.getClass,
+          DecimalType.BigIntDecimal,
+          "apply",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[scala.math.BigInt] =>
+        StaticInvoke(
+          Decimal.getClass,
+          DecimalType.BigIntDecimal,
+          "apply",
+          inputObject :: Nil)
+
+      case t if t <:< localTypeOf[java.lang.Integer] =>
+        Invoke(inputObject, "intValue", IntegerType)
+      case t if t <:< localTypeOf[java.lang.Long] =>
+        Invoke(inputObject, "longValue", LongType)
+      case t if t <:< localTypeOf[java.lang.Double] =>
+        Invoke(inputObject, "doubleValue", DoubleType)
+      case t if t <:< localTypeOf[java.lang.Float] =>
+        Invoke(inputObject, "floatValue", FloatType)
+      case t if t <:< localTypeOf[java.lang.Short] =>
+        Invoke(inputObject, "shortValue", ShortType)
+      case t if t <:< localTypeOf[java.lang.Byte] =>
+        Invoke(inputObject, "byteValue", ByteType)
+      case t if t <:< localTypeOf[java.lang.Boolean] =>
+        Invoke(inputObject, "booleanValue", BooleanType)
+
+      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+        val udt = getClassFromType(t)
+          .getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+        val obj = NewInstance(
+          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+          Nil,
+          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        Invoke(obj, "serialize", udt, inputObject :: Nil)
+
+      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+          .asInstanceOf[UserDefinedType[_]]
+        val obj = NewInstance(
+          udt.getClass,
+          Nil,
+          dataType = ObjectType(udt.getClass))
+        Invoke(obj, "serialize", udt, inputObject :: Nil)
+
+      case t if definedByConstructorParams(t) =>
+        val params = getConstructorParameters(t)
+        val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+          if (javaKeywords.contains(fieldName)) {
+            throw new UnsupportedOperationException(s"`$fieldName` is a reserved keyword and " +
+              "cannot be used as field name\n" + walkedTypePath.mkString("\n"))
+          }
+
+          val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+          val clsName = getClassNameFromType(fieldType)
+          val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+          expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+        })
+        val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+        expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
+      case other =>
+        throw new UnsupportedOperationException(
+          s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
     }
+
   }
 
   /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 02966796af..4f6b4830cd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -722,6 +722,7 @@ class Analyzer(
     // Else, throw exception.
     try {
       expr transformUp {
+        case GetColumnByOrdinal(ordinal, _) => plan.output(ordinal)
         case u @ UnresolvedAttribute(nameParts) =>
           withPosition(u) { plan.resolve(nameParts, resolver).getOrElse(u) }
         case UnresolvedExtractValue(child, fieldName) if child.resolved =>
@@ -1924,10 +1925,54 @@ class Analyzer(
           } else {
             inputAttributes
           }
-          val unbound = deserializer transform {
-            case b: BoundReference => inputs(b.ordinal)
-          }
-          resolveExpression(unbound, LocalRelation(inputs), throws = true)
+
+          validateTopLevelTupleFields(deserializer, inputs)
+          val resolved = resolveExpression(
+            deserializer, LocalRelation(inputs), throws = true)
+          validateNestedTupleFields(resolved)
+          resolved
+      }
+    }
+
+    private def fail(schema: StructType, maxOrdinal: Int): Unit = {
+      throw new AnalysisException(s"Try to map ${schema.simpleString} to Tuple${maxOrdinal + 1}, " +
+        "but failed as the number of fields does not line up.")
+    }
+
+    /**
+     * For each top-level Tuple field, we use [[GetColumnByOrdinal]] to get its corresponding column
+     * by position.  However, the actual number of columns may be different from the number of Tuple
+     * fields.  This method is used to check the number of columns and fields, and throw an
+     * exception if they do not match.
+     */
+    private def validateTopLevelTupleFields(
+        deserializer: Expression, inputs: Seq[Attribute]): Unit = {
+      val ordinals = deserializer.collect {
+        case GetColumnByOrdinal(ordinal, _) => ordinal
+      }.distinct.sorted
+
+      if (ordinals.nonEmpty && ordinals != inputs.indices) {
+        fail(inputs.toStructType, ordinals.last)
+      }
+    }
+
+    /**
+     * For each nested Tuple field, we use [[GetStructField]] to get its corresponding struct field
+     * by position.  However, the actual number of struct fields may be different from the number
+     * of nested Tuple fields.  This method is used to check the number of struct fields and nested
+     * Tuple fields, and throw an exception if they do not match.
+     */
+    private def validateNestedTupleFields(deserializer: Expression): Unit = {
+      val structChildToOrdinals = deserializer
+        .collect { case g: GetStructField => g }
+        .groupBy(_.child)
+        .mapValues(_.map(_.ordinal).distinct.sorted)
+
+      structChildToOrdinals.foreach { case (expr, ordinals) =>
+        val schema = expr.dataType.asInstanceOf[StructType]
+        if (ordinals != schema.indices) {
+          fail(schema, ordinals.last)
+        }
       }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
index e953eda784..b883546135 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala
@@ -366,3 +366,10 @@ case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq
   override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
   override lazy val resolved = false
 }
+
+case class GetColumnByOrdinal(ordinal: Int, dataType: DataType) extends LeafExpression
+  with Unevaluable with NonSQLExpression {
+  override def foldable: Boolean = throw new UnresolvedException(this, "foldable")
+  override def nullable: Boolean = throw new UnresolvedException(this, "nullable")
+  override lazy val resolved = false
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 2296946cd7..cc59d06fa3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -17,19 +17,17 @@
 
 package org.apache.spark.sql.catalyst.encoders
 
-import java.util.concurrent.ConcurrentMap
-
 import scala.reflect.ClassTag
 import scala.reflect.runtime.universe.{typeTag, TypeTag}
 
-import org.apache.spark.sql.{AnalysisException, Encoder}
+import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection}
-import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue}
+import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
 import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
 import org.apache.spark.util.Utils
 
@@ -121,15 +119,15 @@ object ExpressionEncoder {
     val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) =>
       if (enc.flat) {
         enc.deserializer.transform {
-          case b: BoundReference => b.copy(ordinal = index)
+          case g: GetColumnByOrdinal => g.copy(ordinal = index)
         }
       } else {
-        val input = BoundReference(index, enc.schema, nullable = true)
+        val input = GetColumnByOrdinal(index, enc.schema)
         val deserialized = enc.deserializer.transformUp {
           case UnresolvedAttribute(nameParts) =>
             assert(nameParts.length == 1)
             UnresolvedExtractValue(input, Literal(nameParts.head))
-          case BoundReference(ordinal, dt, _) => GetStructField(input, ordinal)
+          case GetColumnByOrdinal(ordinal, _) => GetStructField(input, ordinal)
         }
         If(IsNull(input), Literal.create(null, deserialized.dataType), deserialized)
       }
@@ -192,6 +190,26 @@ case class ExpressionEncoder[T](
 
   if (flat) require(serializer.size == 1)
 
+  /**
+   * Returns a new copy of this encoder, where the `deserializer` is resolved and bound to the
+   * given schema.
+   *
+   * Note that, ideally encoder is used as a container of serde expressions, the resolution and
+   * binding stuff should happen inside query framework.  However, in some cases we need to
+   * use encoder as a function to do serialization directly(e.g. Dataset.collect), then we can use
+   * this method to do resolution and binding outside of query framework.
+   */
+  def resolveAndBind(
+      attrs: Seq[Attribute] = schema.toAttributes,
+      analyzer: Analyzer = SimpleAnalyzer): ExpressionEncoder[T] = {
+    val dummyPlan = CatalystSerde.deserialize(LocalRelation(attrs))(this)
+    val analyzedPlan = analyzer.execute(dummyPlan)
+    analyzer.checkAnalysis(analyzedPlan)
+    val resolved = SimplifyCasts(analyzedPlan).asInstanceOf[DeserializeToObject].deserializer
+    val bound = BindReferences.bindReference(resolved, attrs)
+    copy(deserializer = bound)
+  }
+
   @transient
   private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer)
 
@@ -201,16 +219,6 @@ case class ExpressionEncoder[T](
   @transient
   private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil)
 
-  /**
-   * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns
-   * is performed).
-   */
-  def defaultBinding: ExpressionEncoder[T] = {
-    val attrs = schema.toAttributes
-    resolve(attrs, OuterScopes.outerScopes).bind(attrs)
-  }
-
-
   /**
    * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form
    * of this object.
@@ -236,7 +244,7 @@ case class ExpressionEncoder[T](
 
   /**
    * Returns an object of type `T`, extracting the required values from the provided row.  Note that
-   * you must `resolve` and `bind` an encoder to a specific schema before you can call this
+   * you must `resolveAndBind` an encoder to a specific schema before you can call this
    * function.
    */
   def fromRow(row: InternalRow): T = try {
@@ -259,94 +267,6 @@ case class ExpressionEncoder[T](
     })
   }
 
-  /**
-   * Validates `deserializer` to make sure it can be resolved by given schema, and produce
-   * friendly error messages to explain why it fails to resolve if there is something wrong.
-   */
-  def validate(schema: Seq[Attribute]): Unit = {
-    def fail(st: StructType, maxOrdinal: Int): Unit = {
-      throw new AnalysisException(s"Try to map ${st.simpleString} to Tuple${maxOrdinal + 1}, " +
-        "but failed as the number of fields does not line up.\n" +
-        " - Input schema: " + StructType.fromAttributes(schema).simpleString + "\n" +
-        " - Target schema: " + this.schema.simpleString)
-    }
-
-    // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all
-    // `BoundReference`, make sure their ordinals are all valid.
-    var maxOrdinal = -1
-    deserializer.foreach {
-      case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal
-      case _ =>
-    }
-    if (maxOrdinal >= 0 && maxOrdinal != schema.length - 1) {
-      fail(StructType.fromAttributes(schema), maxOrdinal)
-    }
-
-    // If we have nested tuple, the `fromRowExpression` will contains `GetStructField` instead of
-    // `UnresolvedExtractValue`, so we need to check if their ordinals are all valid.
-    // Note that, `BoundReference` contains the expected type, but here we need the actual type, so
-    // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after
-    // we resolve the `fromRowExpression`.
-    val resolved = SimpleAnalyzer.resolveExpression(
-      deserializer,
-      LocalRelation(schema),
-      throws = true)
-
-    val unbound = resolved transform {
-      case b: BoundReference => schema(b.ordinal)
-    }
-
-    val exprToMaxOrdinal = scala.collection.mutable.HashMap.empty[Expression, Int]
-    unbound.foreach {
-      case g: GetStructField =>
-        val maxOrdinal = exprToMaxOrdinal.getOrElse(g.child, -1)
-        if (maxOrdinal < g.ordinal) {
-          exprToMaxOrdinal.update(g.child, g.ordinal)
-        }
-      case _ =>
-    }
-    exprToMaxOrdinal.foreach {
-      case (expr, maxOrdinal) =>
-        val schema = expr.dataType.asInstanceOf[StructType]
-        if (maxOrdinal != schema.length - 1) {
-          fail(schema, maxOrdinal)
-        }
-    }
-  }
-
-  /**
-   * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema.
-   */
-  def resolve(
-      schema: Seq[Attribute],
-      outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = {
-    // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check
-    // analysis, go through optimizer, etc.
-    val plan = Project(
-      Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil,
-      LocalRelation(schema))
-    val analyzedPlan = SimpleAnalyzer.execute(plan)
-    SimpleAnalyzer.checkAnalysis(analyzedPlan)
-    copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head)
-  }
-
-  /**
-   * Returns a copy of this encoder where the `deserializer` has been bound to the
-   * ordinals of the given schema.  Note that you need to first call resolve before bind.
-   */
-  def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = {
-    copy(deserializer = BindReferences.bindReference(deserializer, schema))
-  }
-
-  /**
-   * Returns a new encoder with input columns shifted by `delta` ordinals
-   */
-  def shift(delta: Int): ExpressionEncoder[T] = {
-    copy(deserializer = deserializer transform {
-      case r: BoundReference => r.copy(ordinal = r.ordinal + delta)
-    })
-  }
-
   protected val attrs = serializer.flatMap(_.collect {
     case _: UnresolvedAttribute => ""
     case a: Attribute => s"#${a.exprId}"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 0de9166aa2..3c6ae1c5cc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.catalyst.analysis.GetColumnByOrdinal
 import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -210,12 +211,7 @@ object RowEncoder {
         case p: PythonUserDefinedType => p.sqlType
         case other => other
       }
-      val field = BoundReference(i, dt, f.nullable)
-      If(
-        IsNull(field),
-        Literal.create(null, externalDataTypeFor(dt)),
-        deserializerFor(field)
-      )
+      deserializerFor(GetColumnByOrdinal(i, dt))
     }
     CreateExternalRow(fields, schema)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 98ce5dd2ef..55d8adf040 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -30,26 +30,13 @@ object CatalystSerde {
     DeserializeToObject(deserializer, generateObjAttr[T], child)
   }
 
-  def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
-    val deserializer = UnresolvedDeserializer(encoder.deserializer)
-    DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
-  }
-
   def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
     SerializeFromObject(encoderFor[T].namedExpressions, child)
   }
 
-  def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
-    SerializeFromObject(encoder.namedExpressions, child)
-  }
-
   def generateObjAttr[T : Encoder]: Attribute = {
     AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
   }
-
-  def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
-    AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
-  }
 }
 
 /**
@@ -128,16 +115,16 @@ object MapPartitionsInR {
       schema: StructType,
       encoder: ExpressionEncoder[Row],
       child: LogicalPlan): LogicalPlan = {
-    val deserialized = CatalystSerde.deserialize(child, encoder)
+    val deserialized = CatalystSerde.deserialize(child)(encoder)
     val mapped = MapPartitionsInR(
       func,
       packageNames,
       broadcastVars,
       encoder.schema,
       schema,
-      CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
+      CatalystSerde.generateObjAttr(RowEncoder(schema)),
       deserialized)
-    CatalystSerde.serialize(mapped, RowEncoder(schema))
+    CatalystSerde.serialize(mapped)(RowEncoder(schema))
   }
 }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
index 3ad0dae767..7251202c7b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/EncoderResolutionSuite.scala
@@ -41,17 +41,17 @@ class EncoderResolutionSuite extends PlanTest {
 
     // int type can be up cast to long type
     val attrs1 = Seq('a.string, 'b.int)
-    encoder.resolve(attrs1, null).bind(attrs1).fromRow(InternalRow(str, 1))
+    encoder.resolveAndBind(attrs1).fromRow(InternalRow(str, 1))
 
     // int type can be up cast to string type
     val attrs2 = Seq('a.int, 'b.long)
-    encoder.resolve(attrs2, null).bind(attrs2).fromRow(InternalRow(1, 2L))
+    encoder.resolveAndBind(attrs2).fromRow(InternalRow(1, 2L))
   }
 
   test("real type doesn't match encoder schema but they are compatible: nested product") {
     val encoder = ExpressionEncoder[ComplexClass]
     val attrs = Seq('a.int, 'b.struct('a.int, 'b.long))
-    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
+    encoder.resolveAndBind(attrs).fromRow(InternalRow(1, InternalRow(2, 3L)))
   }
 
   test("real type doesn't match encoder schema but they are compatible: tupled encoder") {
@@ -59,7 +59,7 @@ class EncoderResolutionSuite extends PlanTest {
       ExpressionEncoder[StringLongClass],
       ExpressionEncoder[Long])
     val attrs = Seq('a.struct('a.string, 'b.byte), 'b.int)
-    encoder.resolve(attrs, null).bind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
+    encoder.resolveAndBind(attrs).fromRow(InternalRow(InternalRow(str, 1.toByte), 2))
   }
 
   test("nullability of array type element should not fail analysis") {
@@ -67,7 +67,7 @@ class EncoderResolutionSuite extends PlanTest {
     val attrs = 'a.array(IntegerType) :: Nil
 
     // It should pass analysis
-    val bound = encoder.resolve(attrs, null).bind(attrs)
+    val bound = encoder.resolveAndBind(attrs)
 
     // If no null values appear, it should works fine
     bound.fromRow(InternalRow(new GenericArrayData(Array(1, 2))))
@@ -84,20 +84,16 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string, 'b.long, 'c.int)
-      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
         "Try to map struct<a:string,b:bigint,c:int> to Tuple2, " +
-          "but failed as the number of fields does not line up.\n" +
-          " - Input schema: struct<a:string,b:bigint,c:int>\n" +
-          " - Target schema: struct<_1:string,_2:bigint>")
+          "but failed as the number of fields does not line up.")
     }
 
     {
       val attrs = Seq('a.string)
-      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
         "Try to map struct<a:string> to Tuple2, " +
-          "but failed as the number of fields does not line up.\n" +
-          " - Input schema: struct<a:string>\n" +
-          " - Target schema: struct<_1:string,_2:bigint>")
+          "but failed as the number of fields does not line up.")
     }
   }
 
@@ -106,26 +102,22 @@ class EncoderResolutionSuite extends PlanTest {
 
     {
       val attrs = Seq('a.string, 'b.struct('x.long, 'y.string, 'z.int))
-      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
         "Try to map struct<x:bigint,y:string,z:int> to Tuple2, " +
-          "but failed as the number of fields does not line up.\n" +
-          " - Input schema: struct<a:string,b:struct<x:bigint,y:string,z:int>>\n" +
-          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+          "but failed as the number of fields does not line up.")
     }
 
     {
       val attrs = Seq('a.string, 'b.struct('x.long))
-      assert(intercept[AnalysisException](encoder.validate(attrs)).message ==
+      assert(intercept[AnalysisException](encoder.resolveAndBind(attrs)).message ==
         "Try to map struct<x:bigint> to Tuple2, " +
-          "but failed as the number of fields does not line up.\n" +
-          " - Input schema: struct<a:string,b:struct<x:bigint>>\n" +
-          " - Target schema: struct<_1:string,_2:struct<_1:bigint,_2:string>>")
+          "but failed as the number of fields does not line up.")
     }
   }
 
   test("throw exception if real type is not compatible with encoder schema") {
     val msg1 = intercept[AnalysisException] {
-      ExpressionEncoder[StringIntClass].resolve(Seq('a.string, 'b.long), null)
+      ExpressionEncoder[StringIntClass].resolveAndBind(Seq('a.string, 'b.long))
     }.message
     assert(msg1 ==
       s"""
@@ -138,7 +130,7 @@ class EncoderResolutionSuite extends PlanTest {
 
     val msg2 = intercept[AnalysisException] {
       val structType = new StructType().add("a", StringType).add("b", DecimalType.SYSTEM_DEFAULT)
-      ExpressionEncoder[ComplexClass].resolve(Seq('a.long, 'b.struct(structType)), null)
+      ExpressionEncoder[ComplexClass].resolveAndBind(Seq('a.long, 'b.struct(structType)))
     }.message
     assert(msg2 ==
       s"""
@@ -171,7 +163,7 @@ class EncoderResolutionSuite extends PlanTest {
     val to = ExpressionEncoder[U]
     val catalystType = from.schema.head.dataType.simpleString
     test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should success") {
-      to.resolve(from.schema.toAttributes, null)
+      to.resolveAndBind(from.schema.toAttributes)
     }
   }
 
@@ -180,7 +172,7 @@ class EncoderResolutionSuite extends PlanTest {
     val to = ExpressionEncoder[U]
     val catalystType = from.schema.head.dataType.simpleString
     test(s"cast from $catalystType to ${implicitly[TypeTag[U]].tpe} should fail") {
-      intercept[AnalysisException](to.resolve(from.schema.toAttributes, null))
+      intercept[AnalysisException](to.resolveAndBind(from.schema.toAttributes))
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index 232dcc9ee5..a1f9259f13 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -27,6 +27,7 @@ import scala.reflect.runtime.universe.TypeTag
 import org.apache.spark.sql.Encoders
 import org.apache.spark.sql.catalyst.{OptionalData, PrimitiveData}
 import org.apache.spark.sql.catalyst.analysis.AnalysisTest
+import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
@@ -334,7 +335,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
       val encoder = implicitly[ExpressionEncoder[T]]
       val row = encoder.toRow(input)
       val schema = encoder.schema.toAttributes
-      val boundEncoder = encoder.defaultBinding
+      val boundEncoder = encoder.resolveAndBind()
       val convertedBack = try boundEncoder.fromRow(row) catch {
         case e: Exception =>
           fail(
@@ -350,12 +351,8 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
       }
 
       // Test the correct resolution of serialization / deserialization.
-      val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))()
-      val inputPlan = LocalRelation(attr)
-      val plan =
-        Project(Alias(encoder.deserializer, "obj")() :: Nil,
-          Project(encoder.namedExpressions,
-            inputPlan))
+      val attr = AttributeReference("obj", encoder.deserializer.dataType)()
+      val plan = LocalRelation(attr).serialize[T].deserialize[T]
       assertAnalysisSuccess(plan)
 
       val isCorrect = (input, convertedBack) match {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 39fcc7225b..6f1bc80c1c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -135,7 +135,7 @@ class RowEncoderSuite extends SparkFunSuite {
           .add("string", StringType)
           .add("double", DoubleType))
 
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
 
     val input: Row = Row((100, "test", 0.123))
     val row = encoder.toRow(input)
@@ -152,7 +152,7 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("scala_decimal", DecimalType.SYSTEM_DEFAULT)
       .add("catalyst_decimal", DecimalType.SYSTEM_DEFAULT)
 
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
 
     val javaDecimal = new java.math.BigDecimal("1234.5678")
     val scalaDecimal = BigDecimal("1234.5678")
@@ -169,7 +169,7 @@ class RowEncoderSuite extends SparkFunSuite {
 
   test("RowEncoder should preserve decimal precision and scale") {
     val schema = new StructType().add("decimal", DecimalType(10, 5), false)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     val decimal = Decimal("67123.45")
     val input = Row(decimal)
     val row = encoder.toRow(input)
@@ -179,7 +179,7 @@ class RowEncoderSuite extends SparkFunSuite {
 
   test("RowEncoder should preserve schema nullability") {
     val schema = new StructType().add("int", IntegerType, nullable = false)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     assert(encoder.serializer.length == 1)
     assert(encoder.serializer.head.dataType == IntegerType)
     assert(encoder.serializer.head.nullable == false)
@@ -195,7 +195,7 @@ class RowEncoderSuite extends SparkFunSuite {
           new StructType().add("int", IntegerType, nullable = false),
           nullable = false),
       nullable = false)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     assert(encoder.serializer.length == 1)
     assert(encoder.serializer.head.dataType ==
       new StructType()
@@ -212,7 +212,7 @@ class RowEncoderSuite extends SparkFunSuite {
       .add("array", ArrayType(IntegerType))
       .add("nestedArray", ArrayType(ArrayType(StringType)))
       .add("deepNestedArray", ArrayType(ArrayType(ArrayType(LongType))))
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     val input = Row(
       Array(1, 2, null),
       Array(Array("abc", null), null),
@@ -226,7 +226,7 @@ class RowEncoderSuite extends SparkFunSuite {
 
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
-      val encoder = RowEncoder(schema)
+      val encoder = RowEncoder(schema).resolveAndBind()
       val inputGenerator = RandomDataGenerator.forType(schema, nullable = false).get
 
       var input: Row = null
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 369b772d32..96c871d034 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -192,24 +192,24 @@ class Dataset[T] private[sql](
   }
 
   /**
-   * An unresolved version of the internal encoder for the type of this [[Dataset]]. This one is
-   * marked implicit so that we can use it when constructing new [[Dataset]] objects that have the
-   * same object type (that will be possibly resolved to a different schema).
+   * Currently [[ExpressionEncoder]] is the only implementation of [[Encoder]], here we turn the
+   * passed in encoder to [[ExpressionEncoder]] explicitly, and mark it implicit so that we can use
+   * it when constructing new [[Dataset]] objects that have the same object type (that will be
+   * possibly resolved to a different schema).
    */
-  private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(encoder)
-  unresolvedTEncoder.validate(logicalPlan.output)
-
-  /** The encoder for this [[Dataset]] that has been resolved to its output schema. */
-  private[sql] val resolvedTEncoder: ExpressionEncoder[T] =
-    unresolvedTEncoder.resolve(logicalPlan.output, OuterScopes.outerScopes)
+  private[sql] implicit val exprEnc: ExpressionEncoder[T] = encoderFor(encoder)
 
   /**
-   * The encoder where the expressions used to construct an object from an input row have been
-   * bound to the ordinals of this [[Dataset]]'s output schema.
+   * Encoder is used mostly as a container of serde expressions in [[Dataset]].  We build logical
+   * plans by these serde expressions and execute it within the query framework.  However, for
+   * performance reasons we may want to use encoder as a function to deserialize internal rows to
+   * custom objects, e.g. collect.  Here we resolve and bind the encoder so that we can call its
+   * `fromRow` method later.
    */
-  private[sql] val boundTEncoder = resolvedTEncoder.bind(logicalPlan.output)
+  private val boundEnc =
+    exprEnc.resolveAndBind(logicalPlan.output, sparkSession.sessionState.analyzer)
 
-  private implicit def classTag = unresolvedTEncoder.clsTag
+  private implicit def classTag = exprEnc.clsTag
 
   // sqlContext must be val because a stable identifier is expected when you import implicits
   @transient lazy val sqlContext: SQLContext = sparkSession.sqlContext
@@ -761,7 +761,7 @@ class Dataset[T] private[sql](
     // Note that we do this before joining them, to enable the join operator to return null for one
     // side, in cases like outer-join.
     val left = {
-      val combined = if (this.unresolvedTEncoder.flat) {
+      val combined = if (this.exprEnc.flat) {
         assert(joined.left.output.length == 1)
         Alias(joined.left.output.head, "_1")()
       } else {
@@ -771,7 +771,7 @@ class Dataset[T] private[sql](
     }
 
     val right = {
-      val combined = if (other.unresolvedTEncoder.flat) {
+      val combined = if (other.exprEnc.flat) {
         assert(joined.right.output.length == 1)
         Alias(joined.right.output.head, "_2")()
       } else {
@@ -784,14 +784,14 @@ class Dataset[T] private[sql](
     // combine the outputs of each join side.
     val conditionExpr = joined.condition.get transformUp {
       case a: Attribute if joined.left.outputSet.contains(a) =>
-        if (this.unresolvedTEncoder.flat) {
+        if (this.exprEnc.flat) {
           left.output.head
         } else {
           val index = joined.left.output.indexWhere(_.exprId == a.exprId)
           GetStructField(left.output.head, index)
         }
       case a: Attribute if joined.right.outputSet.contains(a) =>
-        if (other.unresolvedTEncoder.flat) {
+        if (other.exprEnc.flat) {
           right.output.head
         } else {
           val index = joined.right.output.indexWhere(_.exprId == a.exprId)
@@ -800,7 +800,7 @@ class Dataset[T] private[sql](
     }
 
     implicit val tuple2Encoder: Encoder[(T, U)] =
-      ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder)
+      ExpressionEncoder.tuple(this.exprEnc, other.exprEnc)
 
     withTypedPlan(Join(left, right, joined.joinType, Some(conditionExpr)))
   }
@@ -1024,7 +1024,7 @@ class Dataset[T] private[sql](
       sparkSession,
       Project(
         c1.withInputType(
-          unresolvedTEncoder.deserializer,
+          exprEnc.deserializer,
           logicalPlan.output).named :: Nil,
         logicalPlan),
       implicitly[Encoder[U1]])
@@ -1038,7 +1038,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(unresolvedTEncoder.deserializer, logicalPlan.output).named)
+      columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named)
     val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
     new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
   }
@@ -2153,14 +2153,14 @@ class Dataset[T] private[sql](
    */
   def collectAsList(): java.util.List[T] = withCallback("collectAsList", toDF()) { _ =>
     withNewExecutionId {
-      val values = queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
+      val values = queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
       java.util.Arrays.asList(values : _*)
     }
   }
 
   private def collect(needCallback: Boolean): Array[T] = {
     def execute(): Array[T] = withNewExecutionId {
-      queryExecution.executedPlan.executeCollect().map(boundTEncoder.fromRow)
+      queryExecution.executedPlan.executeCollect().map(boundEnc.fromRow)
     }
 
     if (needCallback) {
@@ -2184,7 +2184,7 @@ class Dataset[T] private[sql](
    */
   def toLocalIterator(): java.util.Iterator[T] = withCallback("toLocalIterator", toDF()) { _ =>
     withNewExecutionId {
-      queryExecution.executedPlan.executeToIterator().map(boundTEncoder.fromRow).asJava
+      queryExecution.executedPlan.executeToIterator().map(boundEnc.fromRow).asJava
     }
   }
 
@@ -2322,7 +2322,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   lazy val rdd: RDD[T] = {
-    val objectType = unresolvedTEncoder.deserializer.dataType
+    val objectType = exprEnc.deserializer.dataType
     val deserialized = CatalystSerde.deserialize[T](logicalPlan)
     sparkSession.sessionState.executePlan(deserialized).toRdd.mapPartitions { rows =>
       rows.map(_.get(0, objectType).asInstanceOf[T])
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index 53f4ea647c..a6867a67ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -42,17 +42,9 @@ class KeyValueGroupedDataset[K, V] private[sql](
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
 
-  // Similar to [[Dataset]], we use unresolved encoders for later composition and resolved encoders
-  // when constructing new logical plans that will operate on the output of the current
-  // queryexecution.
-
-  private implicit val unresolvedKEncoder = encoderFor(kEncoder)
-  private implicit val unresolvedVEncoder = encoderFor(vEncoder)
-
-  private val resolvedKEncoder =
-    unresolvedKEncoder.resolve(groupingAttributes, OuterScopes.outerScopes)
-  private val resolvedVEncoder =
-    unresolvedVEncoder.resolve(dataAttributes, OuterScopes.outerScopes)
+  // Similar to [[Dataset]], we turn the passed in encoder to `ExpressionEncoder` explicitly.
+  private implicit val kExprEnc = encoderFor(kEncoder)
+  private implicit val vExprEnc = encoderFor(vEncoder)
 
   private def logicalPlan = queryExecution.analyzed
   private def sparkSession = queryExecution.sparkSession
@@ -67,7 +59,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   def keyAs[L : Encoder]: KeyValueGroupedDataset[L, V] =
     new KeyValueGroupedDataset(
       encoderFor[L],
-      unresolvedVEncoder,
+      vExprEnc,
       queryExecution,
       dataAttributes,
       groupingAttributes)
@@ -187,7 +179,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   def reduceGroups(f: (V, V) => V): Dataset[(K, V)] = {
     val func = (key: K, it: Iterator[V]) => Iterator((key, it.reduce(f)))
 
-    implicit val resultEncoder = ExpressionEncoder.tuple(unresolvedKEncoder, unresolvedVEncoder)
+    implicit val resultEncoder = ExpressionEncoder.tuple(kExprEnc, vExprEnc)
     flatMapGroups(func)
   }
 
@@ -209,8 +201,8 @@ class KeyValueGroupedDataset[K, V] private[sql](
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(unresolvedVEncoder.deserializer, dataAttributes).named)
-    val keyColumn = if (resolvedKEncoder.flat) {
+      columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named)
+    val keyColumn = if (kExprEnc.flat) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
     } else {
@@ -222,7 +214,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
     new Dataset(
       sparkSession,
       execution,
-      ExpressionEncoder.tuple(unresolvedKEncoder +: encoders))
+      ExpressionEncoder.tuple(kExprEnc +: encoders))
   }
 
   /**
@@ -287,7 +279,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   def cogroup[U, R : Encoder](
       other: KeyValueGroupedDataset[K, U])(
       f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
-    implicit val uEncoder = other.unresolvedVEncoder
+    implicit val uEncoder = other.vExprEnc
     Dataset[R](
       sparkSession,
       CoGroup(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 58850a7d4b..49b6eab8db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -215,7 +215,7 @@ class RelationalGroupedDataset protected[sql](
   def agg(expr: Column, exprs: Column*): DataFrame = {
     toDF((expr +: exprs).map {
       case typed: TypedColumn[_, _] =>
-        typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr
+        typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr
       case c => c.expr
     })
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 8f94184764..ecb56e2a28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -33,9 +33,9 @@ object TypedAggregateExpression {
       aggregator: Aggregator[_, BUF, OUT]): TypedAggregateExpression = {
     val bufferEncoder = encoderFor[BUF]
     val bufferSerializer = bufferEncoder.namedExpressions
-    val bufferDeserializer = bufferEncoder.deserializer.transform {
-      case b: BoundReference => bufferSerializer(b.ordinal).toAttribute
-    }
+    val bufferDeserializer = UnresolvedDeserializer(
+      bufferEncoder.deserializer,
+      bufferSerializer.map(_.toAttribute))
 
     val outputEncoder = encoderFor[OUT]
     val outputType = if (outputEncoder.flat) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d89e98645b..4dbd1665e4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -924,7 +924,7 @@ object functions {
    * @since 1.5.0
    */
   def broadcast[T](df: Dataset[T]): Dataset[T] = {
-    Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.unresolvedTEncoder)
+    Dataset[T](df.sparkSession, BroadcastHint(df.logicalPlan))(df.exprEnc)
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index df8f4b0610..d1c232974e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -566,18 +566,14 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     }.message
     assert(message ==
       "Try to map struct<a:string,b:int> to Tuple3, " +
-        "but failed as the number of fields does not line up.\n" +
-        " - Input schema: struct<a:string,b:int>\n" +
-        " - Target schema: struct<_1:string,_2:int,_3:bigint>")
+        "but failed as the number of fields does not line up.")
 
     val message2 = intercept[AnalysisException] {
       ds.as[Tuple1[String]]
     }.message
     assert(message2 ==
       "Try to map struct<a:string,b:int> to Tuple1, " +
-        "but failed as the number of fields does not line up.\n" +
-        " - Input schema: struct<a:string,b:int>\n" +
-        " - Target schema: struct<_1:string>")
+        "but failed as the number of fields does not line up.")
   }
 
   test("SPARK-13440: Resolving option fields") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index a1a9b66c1f..9c044f4e8f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -81,7 +81,7 @@ abstract class QueryTest extends PlanTest {
       expectedAnswer: T*): Unit = {
     checkAnswer(
       ds.toDF(),
-      spark.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq)
+      spark.createDataset(expectedAnswer)(ds.exprEnc).toDF().collect().toSeq)
 
     checkDecoding(ds, expectedAnswer: _*)
   }
@@ -94,8 +94,8 @@ abstract class QueryTest extends PlanTest {
         fail(
           s"""
              |Exception collecting dataset as objects
-             |${ds.resolvedTEncoder}
-             |${ds.resolvedTEncoder.deserializer.treeString}
+             |${ds.exprEnc}
+             |${ds.exprEnc.deserializer.treeString}
              |${ds.queryExecution}
            """.stripMargin, e)
     }
@@ -114,7 +114,7 @@ abstract class QueryTest extends PlanTest {
       fail(
         s"""Decoded objects do not match expected objects:
             |$comparison
-            |${ds.resolvedTEncoder.deserializer.treeString}
+            |${ds.exprEnc.deserializer.treeString}
          """.stripMargin)
     }
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
index 6f10e4b805..80340b5552 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/GroupedIteratorSuite.scala
@@ -27,7 +27,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
 
   test("basic") {
     val schema = new StructType().add("i", IntegerType).add("s", StringType)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
     val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
       Seq('i.int.at(0)), schema.toAttributes)
@@ -45,7 +45,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
 
   test("group by 2 columns") {
     val schema = new StructType().add("i", IntegerType).add("l", LongType).add("s", StringType)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
 
     val input = Seq(
       Row(1, 2L, "a"),
@@ -72,7 +72,7 @@ class GroupedIteratorSuite extends SparkFunSuite {
 
   test("do nothing to the value iterator") {
     val schema = new StructType().add("i", IntegerType).add("s", StringType)
-    val encoder = RowEncoder(schema)
+    val encoder = RowEncoder(schema).resolveAndBind()
     val input = Seq(Row(1, "a"), Row(1, "b"), Row(2, "c"))
     val grouped = GroupedIterator(input.iterator.map(encoder.toRow),
       Seq('i.int.at(0)), schema.toAttributes)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index dd8672aa64..194c3e7307 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -110,7 +110,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
   object CheckAnswer {
     def apply[A : Encoder](data: A*): CheckAnswerRows = {
       val encoder = encoderFor[A]
-      val toExternalRow = RowEncoder(encoder.schema)
+      val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
       CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false)
     }
 
@@ -124,7 +124,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with Timeouts {
   object CheckLastBatch {
     def apply[A : Encoder](data: A*): CheckAnswerRows = {
       val encoder = encoderFor[A]
-      val toExternalRow = RowEncoder(encoder.schema)
+      val toExternalRow = RowEncoder(encoder.schema).resolveAndBind()
       CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true)
     }
 
-- 
GitLab