From 8fbd5380b7f36842297f624bad3a2513f7eca47b Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Sun, 3 Feb 2013 06:44:49 +0000
Subject: [PATCH] Fetch fewer objects in PySpark's take() method.

---
 core/src/main/scala/spark/api/python/PythonRDD.scala | 11 +++++++++--
 python/pyspark/rdd.py                                |  4 ++++
 2 files changed, 13 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index 39758e94f4..ab8351e55e 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -238,6 +238,11 @@ private[spark] object PythonRDD {
   }
 
   def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+    import scala.collection.JavaConverters._
+    writeIteratorToPickleFile(items.asScala, filename)
+  }
+
+  def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
     val file = new DataOutputStream(new FileOutputStream(filename))
     for (item <- items) {
       writeAsPickle(item, file)
@@ -245,8 +250,10 @@ private[spark] object PythonRDD {
     file.close()
   }
 
-  def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
-    rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
+  def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
+    implicit val cm : ClassManifest[T] = rdd.elementClassManifest
+    rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
+  }
 }
 
 private object Pickle {
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index fb144bc45d..4cda6cf661 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -372,6 +372,10 @@ class RDD(object):
         items = []
         for partition in range(self._jrdd.splits().size()):
             iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
+            # Each item in the iterator is a string, Python object, batch of
+            # Python objects.  Regardless, it is sufficient to take `num`
+            # of these objects in order to collect `num` Python objects:
+            iterator = iterator.take(num)
             items.extend(self._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
-- 
GitLab