diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 1082cbae3e5193a36ed42d43a60d57c5f1f40e57..1893627ee2ce3360d0b5dfd78f00c9af2e38e54d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -753,24 +753,44 @@ abstract class RDD[T: ClassManifest](
   }
 
   /**
-   * Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
-   * it will be slow if a lot of partitions are required. In that case, use collect() to get the
-   * whole RDD instead.
+   * Take the first num elements of the RDD. It works by first scanning one partition, and
+   * using the results from that partition to estimate the number of additional partitions
+   * needed to satisfy the limit.
    */
   def take(num: Int): Array[T] = {
     if (num == 0) {
       return new Array[T](0)
     }
+
     val buf = new ArrayBuffer[T]
-    var p = 0
-    while (buf.size < num && p < partitions.size) {
+    val totalParts = this.partitions.length
+    var partsScanned = 0
+    while (buf.size < num && partsScanned < totalParts) {
+      // The number of partitions to try in this iteration. It is okay for this number to be
+      // greater than totalParts because the partition range below caps it at totalParts.
+      var numPartsToTry = 1
+      if (partsScanned > 0) {
+        // If we didn't find any rows after the first iteration, just try all partitions next.
+        // Otherwise, interpolate the number of partitions we need to try, but overestimate it
+        // by 50%.
+        if (buf.size == 0) {
+          numPartsToTry = totalParts - 1
+        } else {
+          numPartsToTry = (1.5 * num * partsScanned / buf.size).toInt
+        }
+      }
+      numPartsToTry = math.max(0, numPartsToTry)  // guard against a negative number of partitions
+
       val left = num - buf.size
-      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, Array(p), true)
-      buf ++= res(0)
-      if (buf.size == num)
-        return buf.toArray
-      p += 1
+      val p = partsScanned until math.min(partsScanned + numPartsToTry, totalParts)
+      val res = sc.runJob(this, (it: Iterator[T]) => it.take(left).toArray, p, allowLocal = true)
+
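+      // Each task takes up to `left` elements, so this round may collectively return more
+      // rows than we still need; trim while appending so buf never grows past num.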
+      res.foreach(buf ++= _.take(num - buf.size))
+      partsScanned += numPartsToTry
     }
+
     return buf.toArray
   }
 
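The interpolation above is the heart of the change: after a single-partition probe, each subsequent job widens the scan in proportion to how sparse the results were so far. Below is a minimal standalone sketch of that heuristic, outside of Spark; the object name, `estimatePartsToTry`, and the simulated row counts are illustrative, not part of the patch.

    object TakeEstimateSketch {
      // Mirrors the interpolation in take(): the first pass probes one partition; an empty
      // result triggers a scan of everything else; otherwise scale linearly, overshooting
      // by 50% to reduce the chance of needing yet another job.
      def estimatePartsToTry(num: Int, partsScanned: Int, bufSize: Int, totalParts: Int): Int =
        if (partsScanned == 0) 1
        else if (bufSize == 0) totalParts - 1
        else math.max(0, (1.5 * num * partsScanned / bufSize).toInt)

      def main(args: Array[String]): Unit = {
        // Simulate take(500) on 100 partitions holding roughly 10 rows each.
        val num = 500
        val totalParts = 100
        val rowsPerPart = 10
        var partsScanned = 0
        var collected = 0
        while (collected < num && partsScanned < totalParts) {
          partsScanned += estimatePartsToTry(num, partsScanned, collected, totalParts)
          collected = math.min(num, math.min(partsScanned, totalParts) * rowsPerPart)
          println(s"partsScanned=$partsScanned, collected=$collected")
        }
      }
    }

Under these assumptions the loop converges in two jobs (1 partition, then 75 more), where the old one-partition-at-a-time scan would have launched 50.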
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 016db8d57e4e7a4ce7c0701fec174d64e622f53d..6d1bc5e296e06beb137673088229da3750c0579c 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -320,6 +320,44 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     for (i <- 0 until sample.size) assert(sample(i) === checkSample(i))
   }
 
+  test("take") {
+    var nums = sc.makeRDD(Range(1, 1000), 1)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 2)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 100)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+
+    nums = sc.makeRDD(Range(1, 1000), 1000)
+    assert(nums.take(0).size === 0)
+    assert(nums.take(1) === Array(1))
+    assert(nums.take(3) === Array(1, 2, 3))
+    assert(nums.take(500) === (1 to 500).toArray)
+    assert(nums.take(501) === (1 to 501).toArray)
+    assert(nums.take(999) === (1 to 999).toArray)
+    assert(nums.take(1000) === (1 to 999).toArray)
+  }
+
   test("top with predefined ordering") {
     val nums = Array.range(1, 100000)
     val ints = sc.makeRDD(scala.util.Random.shuffle(nums), 2)
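
For context, a hedged usage sketch of the new behavior, assuming a Spark build of this vintage on the classpath (the `local` master string and app name are arbitrary): fetching 500 of 999 elements spread across 100 partitions now typically costs two jobs rather than ~50 single-partition ones.

    import org.apache.spark.SparkContext

    object TakeDemo {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "take-demo")
        val nums = sc.makeRDD(1 to 999, 100)
        println(nums.take(500).length)  // prints 500: one probe job plus one widened job
        sc.stop()
      }
    }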