diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index d986d9dca6ed47b49cc256c2e89649aa78331c0a..f60d278c54873605f502a1f0092847dbb414c4a6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -252,9 +252,13 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { private lazy val names = nameExprs.map(_.eval(EmptyRow)) override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name.asInstanceOf[UTF8String].toString, - valExpr.dataType, valExpr.nullable, Metadata.empty) + val fields = names.zip(valExprs).map { + case (name, valExpr: NamedExpression) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name.asInstanceOf[UTF8String].toString, + valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } @@ -365,8 +369,11 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression private lazy val names = nameExprs.map(_.eval(EmptyRow).toString) override lazy val dataType: StructType = { - val fields = names.zip(valExprs).map { case (name, valExpr) => - StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) + val fields = names.zip(valExprs).map { + case (name, valExpr: NamedExpression) => + StructField(name, valExpr.dataType, valExpr.nullable, valExpr.metadata) + case (name, valExpr) => + StructField(name, valExpr.dataType, valExpr.nullable, Metadata.empty) } StructType(fields) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 7c009a7360b6f3c5d47fb2cbd034f2feeece88bd..ec7be4d4b849d959c34b32479d3a53656716e041 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -228,4 +228,22 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { checkErrorMessage(structType, IntegerType, "Field name should be String Literal") checkErrorMessage(otherType, StringType, "Can't extract value from") } + + test("ensure to preserve metadata") { + val metadata = new MetadataBuilder() + .putString("key", "value") + .build() + + def checkMetadata(expr: Expression): Unit = { + assert(expr.dataType.asInstanceOf[StructType]("a").metadata === metadata) + assert(expr.dataType.asInstanceOf[StructType]("b").metadata === Metadata.empty) + } + + val a = AttributeReference("a", IntegerType, metadata = metadata)() + val b = AttributeReference("b", IntegerType)() + checkMetadata(CreateStruct(Seq(a, b))) + checkMetadata(CreateNamedStruct(Seq("a", a, "b", b))) + checkMetadata(CreateStructUnsafe(Seq(a, b))) + checkMetadata(CreateNamedStructUnsafe(Seq("a", a, "b", b))) + } }