From dfa61f7b136ae060bbe04e3c0da1148da41018c7 Mon Sep 17 00:00:00 2001
From: Shixiong Zhu <shixiong@databricks.com>
Date: Fri, 20 May 2016 12:38:46 -0700
Subject: [PATCH] [SPARK-15190][SQL] Support using SQLUserDefinedType for case
 classes

## What changes were proposed in this pull request?

Right now, inferring the schema for a case class happens before looking for the `SQLUserDefinedType` annotation, so a `SQLUserDefinedType` annotation on a case class has no effect.

This PR simply changes the matching order so that the `SQLUserDefinedType` and `UDTRegistration` cases are checked before the generic constructor-parameters (case class) case. I also re-enabled the `java.math.BigDecimal` test and added two tests for `List`.
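
For illustration, here is the pattern this change enables, mirroring the test UDT added to `ExpressionEncoderSuite` in this patch: a case class annotated with `@SQLUserDefinedType` whose UDT stores the otherwise unsupported `java.net.URI` as a single string column.

```scala
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

// A case class holding a type Catalyst cannot encode natively (java.net.URI),
// backed by a UDT that maps it to a string column.
@SQLUserDefinedType(udt = classOf[UDTForCaseClass])
case class UDTCaseClass(uri: java.net.URI)

class UDTForCaseClass extends UserDefinedType[UDTCaseClass] {
  override def sqlType: DataType = StringType
  override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass]

  override def serialize(obj: UDTCaseClass): UTF8String =
    UTF8String.fromString(obj.uri.toString)

  override def deserialize(datum: Any): UDTCaseClass = datum match {
    case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString))
  }
}
```

Before this change, `ScalaReflection` matched the constructor-parameters case first and encoded `UDTCaseClass` as a plain struct, silently ignoring the annotation; with the UDT cases checked first, the annotation takes effect.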

## How was this patch tested?

`encodeDecodeTest(UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")`
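
To make the before/after behavior concrete, here is a small sketch against `ScalaReflection.schemaFor` (one of the functions reordered below), reusing the `UDTCaseClass`/`UDTForCaseClass` pair from the test; the exact shape of the pre-fix `StructType` is inferred from the constructor-parameters branch, so treat it as illustrative:

```scala
import org.apache.spark.sql.catalyst.ScalaReflection

val schema = ScalaReflection.schemaFor[UDTCaseClass]
// Before this patch the constructor-parameters branch matched first, so
// dataType came back as a StructType with one field per constructor parameter
// and the @SQLUserDefinedType annotation was ignored.
// After the reorder, the annotation branch matches and the UDT itself is returned.
assert(schema.dataType.isInstanceOf[UDTForCaseClass])
```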

Author: Shixiong Zhu <shixiong@databricks.com>

Closes #12965 from zsxwing/SPARK-15190.
---
 .../spark/sql/catalyst/ScalaReflection.scala  | 70 +++++++++----------
 .../encoders/ExpressionEncoderSuite.scala     | 28 +++++++-
 2 files changed, 62 insertions(+), 36 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 58df651da2..36989a20cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -348,6 +348,23 @@ object ScalaReflection extends ScalaReflection {
           "toScalaMap",
           keyData :: valueData :: Nil)
 
+      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
+        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
+        val obj = NewInstance(
+          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
+          Nil,
+          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
+      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
+        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
+          .asInstanceOf[UserDefinedType[_]]
+        val obj = NewInstance(
+          udt.getClass,
+          Nil,
+          dataType = ObjectType(udt.getClass))
+        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
+
       case t if definedByConstructorParams(t) =>
         val params = getConstructorParameters(t)
 
@@ -388,23 +405,6 @@ object ScalaReflection extends ScalaReflection {
         } else {
           newInstance
         }
-
-      case t if t.typeSymbol.annotations.exists(_.tpe =:= typeOf[SQLUserDefinedType]) =>
-        val udt = getClassFromType(t).getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance()
-        val obj = NewInstance(
-          udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt(),
-          Nil,
-          dataType = ObjectType(udt.userClass.getAnnotation(classOf[SQLUserDefinedType]).udt()))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
-
-      case t if UDTRegistration.exists(getClassNameFromType(t)) =>
-        val udt = UDTRegistration.getUDTFor(getClassNameFromType(t)).get.newInstance()
-          .asInstanceOf[UserDefinedType[_]]
-        val obj = NewInstance(
-          udt.getClass,
-          Nil,
-          dataType = ObjectType(udt.getClass))
-        Invoke(obj, "deserialize", ObjectType(udt.userClass), getPath :: Nil)
     }
   }
 
@@ -522,17 +522,6 @@ object ScalaReflection extends ScalaReflection {
           val TypeRef(_, _, Seq(elementType)) = t
           toCatalystArray(inputObject, elementType)
 
-        case t if definedByConstructorParams(t) =>
-          val params = getConstructorParameters(t)
-          val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
-            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
-            val clsName = getClassNameFromType(fieldType)
-            val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
-            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
-          })
-          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
-          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
-
         case t if t <:< localTypeOf[Array[_]] =>
           val TypeRef(_, _, Seq(elementType)) = t
           toCatalystArray(inputObject, elementType)
@@ -645,6 +634,17 @@ object ScalaReflection extends ScalaReflection {
             dataType = ObjectType(udt.getClass))
           Invoke(obj, "serialize", udt.sqlType, inputObject :: Nil)
 
+        case t if definedByConstructorParams(t) =>
+          val params = getConstructorParameters(t)
+          val nonNullOutput = CreateNamedStruct(params.flatMap { case (fieldName, fieldType) =>
+            val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType))
+            val clsName = getClassNameFromType(fieldType)
+            val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath
+            expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil
+          })
+          val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType)
+          expressions.If(IsNull(inputObject), nullOutput, nonNullOutput)
+
         case other =>
           throw new UnsupportedOperationException(
             s"No Encoder found for $tpe\n" + walkedTypePath.mkString("\n"))
@@ -743,13 +743,6 @@ object ScalaReflection extends ScalaReflection {
         val Schema(valueDataType, valueNullable) = schemaFor(valueType)
         Schema(MapType(schemaFor(keyType).dataType,
           valueDataType, valueContainsNull = valueNullable), nullable = true)
-      case t if definedByConstructorParams(t) =>
-        val params = getConstructorParameters(t)
-        Schema(StructType(
-          params.map { case (fieldName, fieldType) =>
-            val Schema(dataType, nullable) = schemaFor(fieldType)
-            StructField(fieldName, dataType, nullable)
-          }), nullable = true)
       case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true)
       case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true)
@@ -775,6 +768,13 @@ object ScalaReflection extends ScalaReflection {
       case t if t <:< definitions.ShortTpe => Schema(ShortType, nullable = false)
       case t if t <:< definitions.ByteTpe => Schema(ByteType, nullable = false)
       case t if t <:< definitions.BooleanTpe => Schema(BooleanType, nullable = false)
+      case t if definedByConstructorParams(t) =>
+        val params = getConstructorParameters(t)
+        Schema(StructType(
+          params.map { case (fieldName, fieldType) =>
+            val Schema(dataType, nullable) = schemaFor(fieldType)
+            StructField(fieldName, dataType, nullable)
+          }), nullable = true)
       case other =>
         throw new UnsupportedOperationException(s"Schema for type $other is not supported")
     }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index d4387890b4..3d97113b52 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Project}
 import org.apache.spark.sql.catalyst.util.ArrayData
-import org.apache.spark.sql.types.{ArrayType, Decimal, ObjectType, StructType}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 case class RepeatedStruct(s: Seq[PrimitiveData])
 
@@ -86,6 +87,25 @@ class JavaSerializable(val value: Int) extends Serializable {
   }
 }
 
+/** For testing UDT for a case class */
+@SQLUserDefinedType(udt = classOf[UDTForCaseClass])
+case class UDTCaseClass(uri: java.net.URI)
+
+class UDTForCaseClass extends UserDefinedType[UDTCaseClass] {
+
+  override def sqlType: DataType = StringType
+
+  override def serialize(obj: UDTCaseClass): UTF8String = {
+    UTF8String.fromString(obj.uri.toString)
+  }
+
+  override def userClass: Class[UDTCaseClass] = classOf[UDTCaseClass]
+
+  override def deserialize(datum: Any): UDTCaseClass = datum match {
+    case uri: UTF8String => UDTCaseClass(new java.net.URI(uri.toString))
+  }
+}
+
 class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
 
@@ -147,6 +167,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
   encodeDecodeTest(Tuple1[Seq[Int]](null), "null seq in tuple")
   encodeDecodeTest(Tuple1[Map[String, String]](null), "null map in tuple")
 
+  encodeDecodeTest(List(1, 2), "list of int")
+  encodeDecodeTest(List("a", null), "list with String and null")
+
+  encodeDecodeTest(
+    UDTCaseClass(new java.net.URI("http://spark.apache.org/")), "udt with case class")
+
   // Kryo encoders
   encodeDecodeTest("hello", "kryo string")(encoderFor(Encoders.kryo[String]))
   encodeDecodeTest(new KryoSerializable(15), "kryo object")(
-- 
GitLab