diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index cc59d06fa3518bc70f69b0007ea175f3619fe1db..688082dcce538745cc62d7cda85e304d070f8a6e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaRefle
 import org.apache.spark.sql.catalyst.analysis.{Analyzer, GetColumnByOrdinal, SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.catalyst.expressions.objects.{Invoke, NewInstance}
+import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
 import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
@@ -50,8 +50,15 @@ object ExpressionEncoder {
     val cls = mirror.runtimeClass(tpe)
     val flat = !ScalaReflection.definedByConstructorParams(tpe)
 
-    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false)
-    val serializer = ScalaReflection.serializerFor[T](inputObject)
+    val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = true)
+    val nullSafeInput = if (flat) {
+      inputObject
+    } else {
+      // For an input object of non-flat type, we can't encode it to a row if it's null, as
+      // Spark SQL doesn't allow the top-level row to be null; only its columns can be null.
+      AssertNotNull(inputObject, Seq("top level non-flat input object"))
+    }
+    val serializer = ScalaReflection.serializerFor[T](nullSafeInput)
     val deserializer = ScalaReflection.deserializerFor[T]
 
     val schema = ScalaReflection.schemaFor[T] match {
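
The hunk above makes the encoder's input slot nullable, then distinguishes flat types (encoded as a single nullable column) from non-flat product types (encoded as the top-level row, which Spark SQL forbids to be null). A minimal sketch of the resulting behavior, assuming a Spark shell with spark.implicits._ in scope; ClassData mirrors the case class used in DatasetSuite:

    import spark.implicits._

    case class ClassData(a: String, b: Int)

    // Flat type: the String maps to one nullable column, so a null element
    // simply round-trips as a null field.
    Seq("a", null).toDS().collect()      // Array("a", null)

    // Non-flat type: the object itself is the top-level row, so the new
    // AssertNotNull guard rejects a null element during serialization.
    Seq(ClassData("a", 1), null).toDS()  // throws RuntimeException
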
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
index 3c6ae1c5ccd8ac9786f765d04872db6bc5565cd4..6cd7b34ceb88cf208ed4e23796a1f6cc25721f04 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala
@@ -57,8 +57,8 @@ import org.apache.spark.unsafe.types.UTF8String
 object RowEncoder {
   def apply(schema: StructType): ExpressionEncoder[Row] = {
     val cls = classOf[Row]
-    val inputObject = BoundReference(0, ObjectType(cls), nullable = false)
-    val serializer = serializerFor(inputObject, schema)
+    val inputObject = BoundReference(0, ObjectType(cls), nullable = true)
+    val serializer = serializerFor(AssertNotNull(inputObject, Seq("top level row object")), schema)
     val deserializer = deserializerFor(schema)
     new ExpressionEncoder[Row](
       schema,
@@ -153,8 +153,7 @@ object RowEncoder {
         val fieldValue = serializerFor(
           GetExternalRowField(
             inputObject, index, field.name, externalDataTypeForInput(field.dataType)),
-          field.dataType
-        )
+          field.dataType)
         val convertedField = if (field.nullable) {
           If(
             Invoke(inputObject, "isNullAt", BooleanType, Literal(index) :: Nil),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index c2e3ab82ff16cfc4fe21c627b17efdb6f243617b..d4c71bffe86bfec275e71dcaadb85c3c18177d5c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -519,7 +519,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType)
     val code = s"""
       $values = new Object[${children.size}];
       $childrenCode
-      final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);
+      final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, $schemaField);
       """
     ev.copy(code = code, isNull = "false")
   }
@@ -675,7 +675,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String])
       ${childGen.code}
 
       if (${childGen.isNull}) {
-        throw new RuntimeException(this.$errMsgField);
+        throw new RuntimeException($errMsgField);
       }
      """
     ev.copy(code = code, isNull = "false", value = childGen.value)
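
The two one-line hunks above are a cleanup of the emitted Java: fields registered on the codegen context (the serialized schema and the assertion message) are referenced without the this. qualifier, which is behaviorally identical as long as the field lives on the enclosing generated class. Since AssertNotNull is codegen-only in this revision, its null check can be exercised through a generated projection; a sketch under that assumption, using internal APIs:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.{BoundReference, UnsafeProjection}
    import org.apache.spark.sql.catalyst.expressions.objects.AssertNotNull
    import org.apache.spark.sql.types.StringType

    val guard = AssertNotNull(BoundReference(0, StringType, nullable = true), Seq("demo field"))
    val projection = UnsafeProjection.create(Seq(guard))

    // A null at ordinal 0 takes the generated
    // `throw new RuntimeException(errMsgField)` branch shown above; a
    // non-null value passes through unchanged.
    projection(InternalRow.fromSeq(Seq(null)))  // throws RuntimeException
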
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 6f1bc80c1cdda371c75f4626fafb8bcb20a25539..16abde064fc4498bd458f0f50d8f6524f5301f41 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -224,6 +224,14 @@ class RowEncoderSuite extends SparkFunSuite {
     assert(convertedBack.getSeq(2) == Seq(Seq(Seq(0L, null), null), null))
   }
 
+  test("RowEncoder should throw RuntimeException if input row object is null") {
+    val schema = new StructType().add("int", IntegerType)
+    val encoder = RowEncoder(schema)
+    val e = intercept[RuntimeException](encoder.toRow(null))
+    assert(e.getMessage.contains("Null value appeared in non-nullable field"))
+    assert(e.getMessage.contains("top level row object"))
+  }
+
   private def encodeDecodeTest(schema: StructType): Unit = {
     test(s"encode/decode: ${schema.simpleString}") {
       val encoder = RowEncoder(schema).resolveAndBind()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index d1c232974e9cea742847af0bac3e8a58c58e825e..bf2b0a2c7c1b7be7d796f8ff7770278f82a8f351 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -790,6 +790,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(e.getMessage.contains(
       "`abstract` is a reserved keyword and cannot be used as field name"))
   }
+
+  test("Dataset should support flat input object to be null") {
+    checkDataset(Seq("a", null).toDS(), "a", null)
+  }
+
+  test("Dataset should throw RuntimeException if non-flat input object is null") {
+    val e = intercept[RuntimeException](Seq(ClassData("a", 1), null).toDS())
+    assert(e.getMessage.contains("Null value appeared in non-nullable field"))
+    assert(e.getMessage.contains("top level non-flat input object"))
+  }
 }
 
 case class Generic[T](id: T, value: Double)