From 76e3ba4264c4a0bc2c33ae6ac862fc40bc302d83 Mon Sep 17 00:00:00 2001
From: Michael Armbrust <michael@databricks.com>
Date: Thu, 28 Aug 2014 00:15:23 -0700
Subject: [PATCH] [SPARK-3230][SQL] Fix udfs that return structs

We need to convert the case classes into Rows.

Author: Michael Armbrust <michael@databricks.com>

Closes #2133 from marmbrus/structUdfs and squashes the following commits:

189722f [Michael Armbrust] Merge remote-tracking branch 'origin/master' into structUdfs
8e29b1c [Michael Armbrust] Use existing function
d8d0b76 [Michael Armbrust] Fix udfs that return structs
---
 .../apache/spark/sql/catalyst/ScalaReflection.scala  | 12 ++++++++++--
 .../spark/sql/catalyst/expressions/ScalaUdf.scala    |  7 ++++++-
 .../apache/spark/sql/execution/basicOperators.scala  | 11 ++---------
 .../test/scala/org/apache/spark/sql/UDFSuite.scala   | 12 ++++++++++++
 4 files changed, 30 insertions(+), 12 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 6b6b636cd9..88a8fa7c28 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst
 
 import java.sql.Timestamp
 
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference}
 import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
 import org.apache.spark.sql.catalyst.types._
 
@@ -32,6 +31,15 @@ object ScalaReflection {
 
   case class Schema(dataType: DataType, nullable: Boolean)
 
+  /** Converts Scala objects to catalyst rows / types */
+  def convertToCatalyst(a: Any): Any = a match {
+    case o: Option[_] => o.orNull
+    case s: Seq[_] => s.map(convertToCatalyst)
+    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
+    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
+    case other => other
+  }
+
   /** Returns a Sequence of attributes for the given case class type. */
   def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match {
     case Schema(s: StructType, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
index 589816ccec..1b687a443e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.types.DataType
 import org.apache.spark.util.ClosureCleaner
 
@@ -27,6 +28,8 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
 
   def nullable = true
 
+  override def toString = s"scalaUDF(${children.mkString(",")})"
+
   /** This method has been generated by this script
 
     (1 to 22).map { x =>
@@ -44,7 +47,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
 
   // scalastyle:off
   override def eval(input: Row): Any = {
-    children.size match {
+    val result = children.size match {
       case 0 => function.asInstanceOf[() => Any]()
       case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input))
       case 2 =>
@@ -343,5 +346,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi
           children(21).eval(input))
     }
     // scalastyle:on
+
+    ScalaReflection.convertToCatalyst(result)
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 374af48b82..4abda21ffe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -204,14 +204,6 @@ case class Sort(
  */
 @DeveloperApi
 object ExistingRdd {
-  def convertToCatalyst(a: Any): Any = a match {
-    case o: Option[_] => o.orNull
-    case s: Seq[_] => s.map(convertToCatalyst)
-    case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) }
-    case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray)
-    case other => other
-  }
-
   def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = {
     data.mapPartitions { iterator =>
       if (iterator.isEmpty) {
@@ -223,7 +215,7 @@ object ExistingRdd {
         bufferedIterator.map { r =>
           var i = 0
           while (i < mutableRow.length) {
-            mutableRow(i) = convertToCatalyst(r.productElement(i))
+            mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i))
             i += 1
           }
 
@@ -245,6 +237,7 @@ object ExistingRdd {
 case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
   override def execute() = rdd
 }
+
 /**
  * :: DeveloperApi ::
  * Computes the set of distinct input rows using a HashSet.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 76aa9b0081..ef9b76b1e2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -22,6 +22,8 @@ import org.apache.spark.sql.test._
 /* Implicits */
 import TestSQLContext._
 
+case class FunctionResult(f1: String, f2: String)
+
 class UDFSuite extends QueryTest {
 
   test("Simple UDF") {
@@ -33,4 +35,14 @@ class UDFSuite extends QueryTest {
     registerFunction("strLenScala", (_: String).length + (_:Int))
     assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
   }
+
+
+  test("struct UDF") {
+    registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
+
+    val result=
+      sql("SELECT returnStruct('test', 'test2') as ret")
+        .select("ret.f1".attr).first().getString(0)
+    assert(result == "test")
+  }
 }
-- 
GitLab