From 2cbe96e64d5f84474b2eb59bed9ce3ab543d8aff Mon Sep 17 00:00:00 2001
From: Takuya UESHIN <ueshin@happy-camper.st>
Date: Fri, 20 May 2016 09:38:34 -0700
Subject: [PATCH] [SPARK-15400][SQL] CreateNamedStruct and
 CreateNamedStructUnsafe should preserve metadata of value expressions if it
 is NamedExpression.

## What changes were proposed in this pull request?

`CreateNamedStruct` and `CreateNamedStructUnsafe` should preserve metadata of value expressions if it is `NamedExpression` like `CreateStruct` or `CreateStructUnsafe` are doing.

## How was this patch tested?

Existing tests.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #13193 from ueshin/issues/SPARK-15400.
---
 .../expressions/complexTypeCreator.scala       | 17 ++++++++++++-----
 .../expressions/ComplexTypeSuite.scala         | 18 ++++++++++++++++++
 2 files changed, 30 insertions(+), 5 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
index d986d9dca6..f60d278c54 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala
@@ -252,9 +252,13 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression {
   private lazy val names = nameExprs.map(_.eval(EmptyRow))
 
   override lazy val dataType: StructType = {
-    val fields = names.zip(valExprs).map { case (name, valExpr) =>
-      StructField(name.asInstanceOf[UTF8String].toString,
-        valExpr.dataType, valExpr.nullable, Metadata.empty)
+    val fields = names.zip(valExprs).map {
+      case (name, valExpr: NamedExpression) =>
+        StructField(name.asInstanceOf[UTF8String].toString,
+          valExpr.dataType, valExpr.nullable, valExpr.metadata)
+      case (name, valExpr) =>
+        StructField(name.asInstanceOf[UTF8String].toString,
+          valExpr.dataType, valExpr.nullable, Metadata.empty)
     }
     StructType(fields)
   }
@@ -365,8 +369,11 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression
   private lazy val names = nameExprs.map(_.eval(EmptyRow).toString)
 
   override lazy val dataType: StructType = {
-    val fields = names.zip(valExprs).map { case (name, valExpr) =>
-      StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
+    val fields = names.zip(valExprs).map {
+      case (name, valExpr: NamedExpression) =>
+        StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata)
+      case (name, valExpr) =>
+        StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty)
     }
     StructType(fields)
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 7c009a7360..ec7be4d4b8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -228,4 +228,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
     checkErrorMessage(otherType, StringType, "Can't extract value from")
   }
+
+  test("ensure to preserve metadata") {
+    val metadata = new MetadataBuilder()
+      .putString("key", "value")
+      .build()
+
+    def checkMetadata(expr: Expression): Unit = {
+      assert(expr.dataType.asInstanceOf[StructType]("a").metadata === metadata)
+      assert(expr.dataType.asInstanceOf[StructType]("b").metadata === Metadata.empty)
+    }
+
+    val a = AttributeReference("a", IntegerType, metadata = metadata)()
+    val b = AttributeReference("b", IntegerType)()
+    checkMetadata(CreateStruct(Seq(a, b)))
+    checkMetadata(CreateNamedStruct(Seq("a", a, "b", b)))
+    checkMetadata(CreateStructUnsafe(Seq(a, b)))
+    checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b)))
+  }
 }
-- 
GitLab