diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index cc59d06fa3518bc70f69b0007ea175f3619fe1db..688082dcce538745cc62d7cda85e304d070f8a6e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance} +import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation} import org.apache.spark.sql.types.{ObjectType, StructField, StructType} @@ -50,8 +50,15 @@ object ExpressionEncoder { val cls = mirror.runtimeClass(tpe) val flat = !ScalaReflection.definedByConstructorParams(tpe) - val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) - val serializer = ScalaReflection.serializerFor[T](inputObject) + val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true) + val nullSafeInput = if (flat) { + inputObject + } else { + // For input object of non-flat type, we can't encode it to row if it's null, as Spark SQL + // doesn't allow top-level row to be null, only its columns can be null. + AssertNotNull(inputObject, Seq("top level non-flat input object")) + } + val serializer = ScalaReflection.serializerFor[T](nullSafeInput) val deserializer = ScalaReflection.deserializerFor[T] val schema = ScalaReflection.schemaFor[T] match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 3c6ae1c5ccd8ac9786f765d04872db6bc5565cd4..6cd7b34ceb88cf208ed4e23796a1f6cc25721f04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -57,8 +57,8 @@ import org.apache.spark.unsafe.types.UTF8String object RowEncoder { def apply(schema: StructType): ExpressionEncoder[Row] = { val cls = classOf[Row] - val inputObject = BoundReference(0, ObjectType(cls), nullable = false) - val serializer = serializerFor(inputObject, schema) + val inputObject = BoundReference(0, ObjectType(cls), nullable = true) + val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema) val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, @@ -153,8 +153,7 @@ object RowEncoder { val fieldValue = serializerFor( GetExternalRowField( inputObject, index, field.name, externalDataTypeForInput(field.dataType)), - field.dataType - ) + field.dataType) val convertedField = if (field.nullable) { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index c2e3ab82ff16cfc4fe21c627b17efdb6f243617b..d4c71bffe86bfec275e71dcaadb85c3c18177d5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -519,7 +519,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) val code = s""" $values = new Object[${children.size}]; $childrenCode - final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); + final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField); """ ev.copy(code = code, isNull = "false") } @@ -675,7 +675,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) ${childGen.code} if (${childGen.isNull}) { - throw new RuntimeException(this.$errMsgField); + throw new RuntimeException($errMsgField); } """ ev.copy(code = code, isNull = "false", value = childGen.value) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 6f1bc80c1cdda371c75f4626fafb8bcb20a25539..16abde064fc4498bd458f0f50d8f6524f5301f41 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -224,6 +224,14 @@ class RowEncoderSuite extends SparkFunSuite { assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null)) } + test("RowEncoder should throw RuntimeException if input row object is null") { + val schema = new StructType().add("int", IntegerType) + val encoder = RowEncoder(schema) + val e = intercept[RuntimeException](encoder.toRow(null)) + assert(e.getMessage.contains("Null value appeared in non-nullable field")) + assert(e.getMessage.contains("top level row object")) + } + private def encodeDecodeTest(schema: StructType): Unit = { test(s"encode/decode: ${schema.simpleString}") { val encoder = RowEncoder(schema).resolveAndBind() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index d1c232974e9cea742847af0bac3e8a58c58e825e..bf2b0a2c7c1b7be7d796f8ff7770278f82a8f351 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -790,6 +790,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { assert(e.getMessage.contains( "`abstract` is a reserved keyword and cannot be used as field name")) } + + test("Dataset should support flat input object to be null") { + checkDataset(Seq("a", null).toDS(), "a", null) + } + + test("Dataset should throw RuntimeException if non-flat input object is null") { + val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS()) + assert(e.getMessage.contains("Null value appeared in non-nullable field")) + assert(e.getMessage.contains("top level non-flat input object")) + } } case class Generic[T](id: T, value: Double)