diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index a392b5341244fce7a8cfcae811dbd5902aebfb74..010ed7f5008ebaffb1105addb0be9efb55230c3d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -219,48 +219,62 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
   }
 
   /**
-   * Runs this query returning the result as an array.
+   * Packs the UnsafeRows into byte arrays for faster serialization.
+   * The byte arrays are in the following format:
+   * [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
+   *
+   * UnsafeRow is highly compressible (at least 8 bytes for any column), so the byte arrays are
+   * also compressed.
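+   *
+   * If `n` is non-negative, at most `n` rows are written out from each partition.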
    */
-  def executeCollect(): Array[InternalRow] = {
-    // Packing the UnsafeRows into byte array for faster serialization.
-    // The byte arrays are in the following format:
-    // [size] [bytes of UnsafeRow] [size] [bytes of UnsafeRow] ... [-1]
-    //
-    // UnsafeRow is highly compressible (at least 8 bytes for any column), the byte array is also
-    // compressed.
-    val byteArrayRdd = execute().mapPartitionsInternal { iter =>
+  private def getByteArrayRdd(n: Int = -1): RDD[Array[Byte]] = {
+    execute().mapPartitionsInternal { iter =>
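+      // Number of rows written from this partition so far, used to enforce the optional limit `n`.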
+      var count = 0
       val buffer = new Array[Byte](4 << 10)  // 4K
       val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
       val bos = new ByteArrayOutputStream()
       val out = new DataOutputStream(codec.compressedOutputStream(bos))
-      while (iter.hasNext) {
+      while (iter.hasNext && (n < 0 || count < n)) {
         val row = iter.next().asInstanceOf[UnsafeRow]
         out.writeInt(row.getSizeInBytes)
         row.writeToStream(out, buffer)
+        count += 1
       }
       out.writeInt(-1)
       out.flush()
       out.close()
       Iterator(bos.toByteArray)
     }
+  }
 
-    // Collect the byte arrays back to driver, then decode them as UnsafeRows.
+  /**
+   * Decodes a byte array produced by getByteArrayRdd() back into UnsafeRows and appends them
+   * to the given buffer.
+   */
+  private def decodeUnsafeRows(bytes: Array[Byte], buffer: ArrayBuffer[InternalRow]): Unit = {
     val nFields = schema.length
-    val results = ArrayBuffer[InternalRow]()
 
+    val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
+    val bis = new ByteArrayInputStream(bytes)
+    val ins = new DataInputStream(codec.compressedInputStream(bis))
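+    // Each row is length-prefixed; the -1 written by getByteArrayRdd() marks the end of stream.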
+    var sizeOfNextRow = ins.readInt()
+    while (sizeOfNextRow >= 0) {
+      val bs = new Array[Byte](sizeOfNextRow)
+      ins.readFully(bs)
+      val row = new UnsafeRow(nFields)
+      row.pointTo(bs, sizeOfNextRow)
+      buffer += row
+      sizeOfNextRow = ins.readInt()
+    }
+  }
+
+  /**
+   * Runs this query returning the result as an array.
+   */
+  def executeCollect(): Array[InternalRow] = {
+    val byteArrayRdd = getByteArrayRdd()
+
+    val results = ArrayBuffer[InternalRow]()
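+    // Collect the compressed byte arrays back to the driver, then decode them as UnsafeRows.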
     byteArrayRdd.collect().foreach { bytes =>
-      val codec = CompressionCodec.createCodec(SparkEnv.get.conf)
-      val bis = new ByteArrayInputStream(bytes)
-      val ins = new DataInputStream(codec.compressedInputStream(bis))
-      var sizeOfNextRow = ins.readInt()
-      while (sizeOfNextRow >= 0) {
-        val bs = new Array[Byte](sizeOfNextRow)
-        ins.readFully(bs)
-        val row = new UnsafeRow(nFields)
-        row.pointTo(bs, sizeOfNextRow)
-        results += row
-        sizeOfNextRow = ins.readInt()
-      }
+      decodeUnsafeRows(bytes, results)
     }
     results.toArray
   }
@@ -283,7 +297,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       return new Array[InternalRow](0)
     }
 
-    val childRDD = execute().map(_.copy())
+    val childRDD = getByteArrayRdd(n)
 
     val buf = new ArrayBuffer[InternalRow]
     val totalParts = childRDD.partitions.length
@@ -307,13 +321,21 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
       val left = n - buf.size
       val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
       val sc = sqlContext.sparkContext
-      val res = sc.runJob(childRDD, (it: Iterator[InternalRow]) => it.take(left).toArray, p)
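+      // Each partition of the byte-array RDD holds a single element (the compressed bytes of at
+      // most n rows), so fetching the first element of each partition's iterator is sufficient.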
+      val res = sc.runJob(childRDD,
+        (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty, p)
+
+      res.foreach { r =>
+        decodeUnsafeRows(r.asInstanceOf[Array[Byte]], buf)
+      }
 
-      res.foreach(buf ++= _.take(n - buf.size))
       partsScanned += p.size
     }
 
-    buf.toArray
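+    // Since each scanned partition may return up to n rows, the buffer can hold more than n rows.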
+    if (buf.size > n) {
+      buf.take(n).toArray
+    } else {
+      buf.toArray
+    }
   }
 
   private[this] def isTesting: Boolean = sys.props.contains("spark.testing")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index b6051b07c8093ac94a422b852ac3ef1d9c1c44bb..d293ff66fbc6bb2adfdd1973d472041503217c59 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -465,4 +465,25 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
     collect 4 millions                       3193 / 3895          0.3        3044.7       0.1X
      */
   }
+
+  ignore("collect limit") {
+    val N = 1 << 20
+
+    val benchmark = new Benchmark("collect limit", N)
+    benchmark.addCase("collect limit 1 million") { iter =>
+      sqlContext.range(N * 4).limit(N).collect()
+    }
+    benchmark.addCase("collect limit 2 millions") { iter =>
+      sqlContext.range(N * 4).limit(N * 2).collect()
+    }
+    benchmark.run()
+
+    /**
+    model name      : Westmere E56xx/L56xx/X56xx (Nehalem-C)
+    collect limit:                      Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
+    -------------------------------------------------------------------------------------------
+    collect limit 1 million                   833 / 1284          1.3         794.4       1.0X
+    collect limit 2 millions                 3348 / 4005          0.3        3193.3       0.2X
+     */
+  }
 }