diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 57b28b9972366d47dee404539e46a5412cf42822..d1df99300c5b1c1055be38fb6607d4f87eed2fd6 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -269,6 +269,26 @@ private object SpecialLengths {
 private[spark] object PythonRDD {
   val UTF8 = Charset.forName("UTF-8")
 
+  /**
+   * Adapter for calling SparkContext#runJob from Python.
+   *
+   * This method returns an iterator over an array that contains all of the elements in the
+   * RDD (effectively a collect()), but it allows the caller to run the job on only a subset
+   * of the partitions, or to enable local execution.
+   */
+  def runJob(
+      sc: SparkContext,
+      rdd: JavaRDD[Array[Byte]],
+      partitions: JArrayList[Int],
+      allowLocal: Boolean): Iterator[Array[Byte]] = {
+    type ByteArray = Array[Byte]
+    type UnrolledPartition = Array[ByteArray]
+    val allPartitions: Array[UnrolledPartition] =
+      sc.runJob(rdd, (x: Iterator[ByteArray]) => x.toArray, partitions, allowLocal)
+    val flattenedPartition: UnrolledPartition = Array.concat(allPartitions: _*)
+    flattenedPartition.iterator
+  }
+
   def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
   JavaRDD[Array[Byte]] = {
     val file = new DataInputStream(new FileInputStream(filename))
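
A minimal pure-Python sketch (not part of the patch; the helper name is illustrative) of the
unroll-and-flatten behaviour that PythonRDD.runJob implements above: each requested partition is
collected into an array, and the arrays are concatenated in partition order before being exposed
as a single iterator.

```python
import itertools

def unroll_and_flatten(per_partition_results):
    """Mimic the shape of PythonRDD.runJob's result: given per-partition
    results (each an iterable of serialized byte strings), concatenate them
    in partition order and return one flat iterator."""
    unrolled = [list(part) for part in per_partition_results]  # one array per partition
    return itertools.chain.from_iterable(unrolled)             # flattened iterator

# Three "partitions" of byte strings, flattened in order:
print(list(unroll_and_flatten([[b"a", b"b"], [b"c"], [b"d", b"e"]])))
# [b'a', b'b', b'c', b'd', b'e']
```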
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 56746cb7aab3d97e85f63ab3951c3989dd2ecfe1..9ae9305d4f02ea9a0461bf4be416f9d09ac78a8e 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -537,6 +537,32 @@ class SparkContext(object):
         """
         self._jsc.sc().cancelAllJobs()
 
+    def runJob(self, rdd, partitionFunc, partitions=None, allowLocal=False):
+        """
+        Executes the given partitionFunc on the specified set of partitions,
+        returning the result as a list of elements.
+
+        If 'partitions' is not specified, this will run over all partitions.
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part])
+        [0, 1, 4, 9, 16, 25]
+
+        >>> myRDD = sc.parallelize(range(6), 3)
+        >>> sc.runJob(myRDD, lambda part: [x * x for x in part], [0, 2], True)
+        [0, 1, 16, 25]
+        """
+        if partitions is None:
+            partitions = range(rdd._jrdd.splits().size())
+        javaPartitions = ListConverter().convert(partitions, self._gateway._gateway_client)
+
+        # Implementation note: This is implemented as a mapPartitions followed
+        # by runJob() in order to avoid having to pass a Python lambda into
+        # SparkContext#runJob.
+        mappedRDD = rdd.mapPartitions(partitionFunc)
+        it = self._jvm.PythonRDD.runJob(self._jsc.sc(), mappedRDD._jrdd, javaPartitions, allowLocal)
+        return list(mappedRDD._collect_iterator_through_file(it))
+
 def _test():
     import atexit
     import doctest
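
A usage sketch for the new SparkContext.runJob (not part of the patch; the RDD contents and
partition counts are illustrative). It evaluates a function over a chosen subset of partitions
without collecting the whole RDD; `sc` is the SparkContext provided by the pyspark shell.

```python
rdd = sc.parallelize(range(1000), 10)

# Count the elements in each of the first three partitions only.
counts = sc.runJob(rdd, lambda it: [sum(1 for _ in it)], partitions=[0, 1, 2])
print(counts)  # [100, 100, 100] with an even 10-way split

# Internally this is rdd.mapPartitions(partitionFunc) followed by a call to
# PythonRDD.runJob over Py4J, so the Python function itself never has to be
# passed into the JVM.
```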
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 07578b8d937fc7433059a9bd1f4ccdf295611424..f3b1f1a665e5a48e92f86653d18f1043ec3ed894 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -841,34 +841,51 @@ class RDD(object):
         """
         Take the first num elements of the RDD.
 
-        This currently scans the partitions *one by one*, so it will be slow if
-        a lot of partitions are required. In that case, use L{collect} to get
-        the whole RDD instead.
+        It works by first scanning one partition, and then using the results
+        from that partition to estimate the number of additional partitions
+        needed to satisfy the limit.
+
+        Translated from the Scala implementation in RDD#take().
 
         >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
         [2, 3]
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
+        >>> sc.parallelize(range(100), 100).filter(lambda x: x > 90).take(3)
+        [91, 92, 93]
         """
-        def takeUpToNum(iterator):
-            taken = 0
-            while taken < num:
-                yield next(iterator)
-                taken += 1
-        # Take only up to num elements from each partition we try
-        mapped = self.mapPartitions(takeUpToNum)
         items = []
-        # TODO(shivaram): Similar to the scala implementation, update the take 
-        # method to scan multiple splits based on an estimate of how many elements 
-        # we have per-split.
-        with _JavaStackTrace(self.context) as st:
-            for partition in range(mapped._jrdd.splits().size()):
-                partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
-                partitionsToTake[0] = partition
-                iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
-                items.extend(mapped._collect_iterator_through_file(iterator))
-                if len(items) >= num:
-                    break
+        totalParts = self._jrdd.splits().size()
+        partsScanned = 0
+
+        while len(items) < num and partsScanned < totalParts:
+            # The number of partitions to try in this iteration.
+            # It is ok for this number to be greater than totalParts because
+            # we cap it at totalParts when computing the partition range below.
+            numPartsToTry = 1
+            if partsScanned > 0:
+                # If we didn't find any rows after the first iteration, just
+                # try all partitions next. Otherwise, interpolate the number
+                # of partitions we need to try, but overestimate it by 50%.
+                if len(items) == 0:
+                    numPartsToTry = totalParts - 1
+                else:
+                    numPartsToTry = int(1.5 * num * partsScanned / len(items))
+
+            left = num - len(items)
+
+            def takeUpToNumLeft(iterator):
+                taken = 0
+                while taken < left:
+                    yield next(iterator)
+                    taken += 1
+
+            p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
+            res = self.context.runJob(self, takeUpToNumLeft, p, True)
+
+            items += res
+            partsScanned += numPartsToTry
+
         return items[:num]
 
     def first(self):
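
To make the scanning schedule concrete, the sketch below (not part of the patch; the function
name and the assumption of a fixed number of matching items per partition are illustrative)
replays take()'s numPartsToTry heuristic in isolation.

```python
def simulate_take_schedule(num, total_parts, items_per_part):
    """Replay take()'s partition-scan schedule: scan one partition first,
    then either jump to all remaining partitions (if nothing was found) or
    interpolate the number needed and overestimate it by 50%."""
    items_found = 0
    parts_scanned = 0
    schedule = []
    while items_found < num and parts_scanned < total_parts:
        num_parts_to_try = 1
        if parts_scanned > 0:
            if items_found == 0:
                num_parts_to_try = total_parts - 1
            else:
                num_parts_to_try = int(1.5 * num * parts_scanned / items_found)
        p = range(parts_scanned, min(parts_scanned + num_parts_to_try, total_parts))
        schedule.append(list(p))
        items_found += items_per_part * len(p)
        parts_scanned += num_parts_to_try
    return schedule

# With 100 partitions, one matching item per partition, and take(3): the first
# pass scans partition 0, then the 50% overestimate scans four more partitions.
print(simulate_take_schedule(3, 100, 1))  # [[0], [1, 2, 3, 4]]
```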