diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index cf60d14f0350ff35bd3d0ebe031a64f8bdbc20aa..79d824d494a8d09b2106fda2825d0c83757b4e59 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -10,6 +10,7 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import spark.broadcast.Broadcast
 import spark._
 import spark.rdd.PipedRDD
+import java.util
 
 
 private[spark] class PythonRDD[T: ClassManifest](
@@ -216,6 +217,9 @@ private[spark] object PythonRDD {
     }
     file.close()
   }
+
+  def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
+    rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
 }
 
 private object Pickle {
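
The new `takePartition` helper runs a real job over a single partition via `SparkContext.runJob` and returns that partition's iterator to the driver, replacing the by-hand iteration that PySpark's `take()` used to do on the driver side. A hedged end-to-end illustration of the behaviour this enables, assuming a PySpark build that includes this patch (the `.cache()` mirrors the updated doctest in the rdd.py hunk below):

```python
from pyspark.context import SparkContext

sc = SparkContext("local", "takePartition demo")

# Ten partitions, cached; partition 0 holds 0..9.
rdd = sc.parallelize(range(100), 10).cache()

# With takePartition on the JVM side, take(3) only has to fetch partition 0
# as a single-partition job; the other nine partitions are left untouched.
print(rdd.take(3))   # [0, 1, 2]
```
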
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6172d69dcff97fba5546494732f0199f2c12e021..4439356c1f10faa0a81c05e4213f41db93fa7e44 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -21,6 +21,7 @@ class SparkContext(object):
     jvm = gateway.jvm
     _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
+    _takePartition = jvm.PythonRDD.takePartition
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
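
Like the two pickle-file helpers above it, `_takePartition` is a class-level Py4J handle on the JVM-side `PythonRDD.takePartition` helper; `take()` in the next hunk calls it with the underlying Scala RDD and a partition index. A rough sketch of reaching the same binding by hand, assuming the `launch_gateway` helper from `pyspark.java_gateway` that `SparkContext` itself uses:

```python
from pyspark.java_gateway import launch_gateway

gateway = launch_gateway()
jvm = gateway.jvm

# The same object SparkContext caches as its _takePartition class attribute.
take_partition = jvm.PythonRDD.takePartition

# Calling it with a Scala RDD and a partition index yields a
# java.util.Iterator over that partition's (pickled) elements, which
# RDD._collect_iterator_through_file turns back into Python objects.
```
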
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cbffb6cc1f2995a14ee53e04281a807cab32c335..4ba417b2a2bf8b1c0909c26e5259d44c197b807a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -328,18 +328,17 @@ class RDD(object):
         a lot of partitions are required. In that case, use L{collect} to get
         the whole RDD instead.
 
-        >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
         [2, 3]
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
         """
         items = []
-        splits = self._jrdd.splits()
-        taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0)
-        while len(items) < num and splits:
-            split = splits.pop(0)
-            iterator = self._jrdd.iterator(split, taskContext)
+        for partition in range(self._jrdd.splits().size()):
+            iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
             items.extend(self._collect_iterator_through_file(iterator))
+            if len(items) >= num:
+                break
         return items[:num]
 
     def first(self):
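
The rewritten loop replaces the old driver-side trick of iterating each split under a dummy `TaskContext(0, 0, 0)`: it now asks the JVM for one partition at a time through `_takePartition` and stops as soon as `num` items have been collected. A minimal pure-Python model of that control flow (no Spark involved; the hypothetical `fetch_partition` stands in for the `_takePartition` plus `_collect_iterator_through_file` round trip):

```python
# Plain-Python model of the new take() loop; partitions are fetched lazily
# and we record which ones actually get computed.
computed = []

def fetch_partition(idx):
    # Stand-in for ctx._takePartition(...) followed by
    # _collect_iterator_through_file(...): returns one partition's elements.
    computed.append(idx)
    return list(range(idx * 3, idx * 3 + 3))

num_partitions = 4

def take(num):
    items = []
    for partition in range(num_partitions):   # mirrors range(splits().size())
        items.extend(fetch_partition(partition))
        if len(items) >= num:                 # stop once we have enough
            break
    return items[:num]

print(take(2))    # [0, 1]
print(computed)   # [0] -- only the first partition was ever fetched
```

If `num` exceeds the total element count, the loop simply walks every partition and the final slice returns whatever was found, which is the case the second doctest (`take(10)` on a five-element RDD) exercises.
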