diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 6b6b636cd96dcb2e0f25e23f36359c08dee9a358..88a8fa7c28e0f01ab6792493b618f09b6ac2a105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst import java.sql.Timestamp -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.{GenericRow, Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.catalyst.types._ @@ -32,6 +31,15 @@ object ScalaReflection { case class Schema(dataType: DataType, nullable: Boolean) + /** Converts Scala objects to catalyst rows / types */ + def convertToCatalyst(a: Any): Any = a match { + case o: Option[_] => o.orNull + case s: Seq[_] => s.map(convertToCatalyst) + case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } + case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) + case other => other + } + /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala index 589816ccec0d52e6aed31015a29319a81d17e079..1b687a443ef8b7a85c7c1e06431235ecd244381f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUdf.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.expressions +import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.util.ClosureCleaner @@ -27,6 +28,8 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi def nullable = true + override def toString = s"scalaUDF(${children.mkString(",")})" + /** This method has been generated by this script (1 to 22).map { x => @@ -44,7 +47,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi // scalastyle:off override def eval(input: Row): Any = { - children.size match { + val result = children.size match { case 0 => function.asInstanceOf[() => Any]() case 1 => function.asInstanceOf[(Any) => Any](children(0).eval(input)) case 2 => @@ -343,5 +346,7 @@ case class ScalaUdf(function: AnyRef, dataType: DataType, children: Seq[Expressi children(21).eval(input)) } // scalastyle:on + + ScalaReflection.convertToCatalyst(result) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 374af48b820c2a8f7d8e561841fb0dc6d466dde0..4abda21ffec96889325a1432c44646878b3f22a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -204,14 +204,6 @@ case class Sort( */ @DeveloperApi object ExistingRdd { - def convertToCatalyst(a: Any): Any = a match { - case o: Option[_] => o.orNull - case s: Seq[_] => s.map(convertToCatalyst) - case m: Map[_, _] => m.map { case (k, v) => convertToCatalyst(k) -> convertToCatalyst(v) } - case p: Product => new GenericRow(p.productIterator.map(convertToCatalyst).toArray) - case other => other - } - def productToRowRdd[A <: Product](data: RDD[A]): RDD[Row] = { data.mapPartitions { iterator => if (iterator.isEmpty) { @@ -223,7 +215,7 @@ object ExistingRdd { bufferedIterator.map { r => var i = 0 while (i < mutableRow.length) { - mutableRow(i) = convertToCatalyst(r.productElement(i)) + mutableRow(i) = ScalaReflection.convertToCatalyst(r.productElement(i)) i += 1 } @@ -245,6 +237,7 @@ object ExistingRdd { case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode { override def execute() = rdd } + /** * :: DeveloperApi :: * Computes the set of distinct input rows using a HashSet. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 76aa9b0081d7e0cf9b69194a9a059be2b160787f..ef9b76b1e251ef17ceca2eebeb7aedc2380e7fa1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.test._ /* Implicits */ import TestSQLContext._ +case class FunctionResult(f1: String, f2: String) + class UDFSuite extends QueryTest { test("Simple UDF") { @@ -33,4 +35,14 @@ class UDFSuite extends QueryTest { registerFunction("strLenScala", (_: String).length + (_:Int)) assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) } + + + test("struct UDF") { + registerFunction("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + + val result= + sql("SELECT returnStruct('test', 'test2') as ret") + .select("ret.f1".attr).first().getString(0) + assert(result == "test") + } }