Skip to content
Snippets Groups Projects
Commit d2e1aa97 authored by Takuya UESHIN, committed by Reynold Xin
Browse files

[SPARK-15308][SQL] RowEncoder should preserve nested column name.

## What changes were proposed in this pull request?

The following code generates the wrong schema:

```
val schema = new StructType().add(
  "struct",
  new StructType()
    .add("i", IntegerType, nullable = false)
    .add(
      "s",
      new StructType().add("int", IntegerType, nullable = false),
      nullable = false),
  nullable = false)
val ds = sqlContext.range(10).map(l => Row(l, Row(l)))(RowEncoder(schema))
ds.printSchema()
```

This should print as follows:

```
root
 |-- struct: struct (nullable = false)
 |    |-- i: integer (nullable = false)
 |    |-- s: struct (nullable = false)
 |    |    |-- int: integer (nullable = false)
```

but the result is:

```
root
 |-- struct: struct (nullable = false)
 |    |-- col1: integer (nullable = false)
 |    |-- col2: struct (nullable = false)
 |    |    |-- col1: integer (nullable = false)
```

This PR fixes `RowEncoder` so that it preserves nested column names.

## How was this patch tested?

Verified with the existing tests, plus a new test that checks that `RowEncoder` preserves nested column names.

Author: Takuya UESHIN <ueshin@happy-camper.st>

Closes #13090 from ueshin/issues/SPARK-15308.
parent 9a9c6f5c
No related branches found
No related tags found
No related merge requests found
...@@ -62,7 +62,7 @@ object RowEncoder { ...@@ -62,7 +62,7 @@ object RowEncoder {
new ExpressionEncoder[Row]( new ExpressionEncoder[Row](
schema, schema,
flat = false, flat = false,
serializer.asInstanceOf[CreateStruct].children, serializer.asInstanceOf[CreateNamedStruct].flatten,
deserializer, deserializer,
ClassTag(cls)) ClassTag(cls))
} }
...@@ -148,28 +148,30 @@ object RowEncoder { ...@@ -148,28 +148,30 @@ object RowEncoder {
dataType = t) dataType = t)
case StructType(fields) => case StructType(fields) =>
val convertedFields = fields.zipWithIndex.map { case (f, i) => val nonNullOutput = CreateNamedStruct(fields.zipWithIndex.flatMap { case (field, index) =>
val fieldValue = serializerFor( val fieldValue = serializerFor(
GetExternalRowField(inputObject, i, f.name, externalDataTypeForInput(f.dataType)), GetExternalRowField(
f.dataType inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
field.dataType
) )
if (f.nullable) { val convertedField = if (field.nullable) {
If( If(
Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil),
Literal.create(null, f.dataType), Literal.create(null, field.dataType),
fieldValue fieldValue
) )
} else { } else {
fieldValue fieldValue
} }
} Literal(field.name) :: convertedField :: Nil
})
if (inputObject.nullable) { if (inputObject.nullable) {
If(IsNull(inputObject), If(IsNull(inputObject),
Literal.create(null, inputType), Literal.create(null, inputType),
CreateStruct(convertedFields)) nonNullOutput)
} else { } else {
CreateStruct(convertedFields) nonNullOutput
} }
} }
......
...@@ -185,6 +185,28 @@ class RowEncoderSuite extends SparkFunSuite { ...@@ -185,6 +185,28 @@ class RowEncoderSuite extends SparkFunSuite {
assert(encoder.serializer.head.nullable == false) assert(encoder.serializer.head.nullable == false)
} }
test("RowEncoder should preserve nested column name") {
val schema = new StructType().add(
"struct",
new StructType()
.add("i", IntegerType, nullable = false)
.add(
"s",
new StructType().add("int", IntegerType, nullable = false),
nullable = false),
nullable = false)
val encoder = RowEncoder(schema)
assert(encoder.serializer.length == 1)
assert(encoder.serializer.head.dataType ==
new StructType()
.add("i", IntegerType, nullable = false)
.add(
"s",
new StructType().add("int", IntegerType, nullable = false),
nullable = false))
assert(encoder.serializer.head.nullable == false)
}
test("RowEncoder should support array as the external type for ArrayType") { test("RowEncoder should support array as the external type for ArrayType") {
val schema = new StructType() val schema = new StructType()
.add("array", ArrayType(IntegerType)) .add("array", ArrayType(IntegerType))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment