diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
index 993d085c75089e560c68fbee1d824858d25ed378..31d27bb4f0571f1a2c34f07c248fce70d62a99df 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala
@@ -430,7 +430,7 @@ class SchemaRDD(
    * @group schema
    */
   private def applySchema(rdd: RDD[Row]): SchemaRDD = {
-    new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(logicalPlan.output, rdd)))
+    new SchemaRDD(sqlContext, SparkLogicalPlan(ExistingRdd(queryExecution.analyzed.output, rdd)))
   }
 
   // =======================================================================
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index 68dae58728a2a21801be34bbb0cf11a841e88835..c8ea01c4e1b6a438399bd71c71793e877c47f239 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -33,6 +33,12 @@ class DslQuerySuite extends QueryTest {
       testData.collect().toSeq)
   }
 
+  test("repartition") {
+    checkAnswer(
+      testData.select('key).repartition(10).select('key),
+      testData.select('key).collect().toSeq)
+  }
+
   test("agg") {
     checkAnswer(
       testData2.groupBy('a)('a, Sum('b)),