From b9d6783f36d527f5082bf13a4ee6fd108e97795c Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Sun, 28 Jul 2013 23:28:42 -0400
Subject: [PATCH] Optimize Python take() to not compute entire first partition
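
Previously, take() on the Python side called _takePartition on the
unmodified RDD for each partition it tried, so the entire first partition
was computed even when only a few elements were needed. take() now first
applies mapPartitions(takeUpToNum), so each Python worker stops consuming
its partition input once it has produced num elements. Because the Python
process may then stop reading before the Scala side has written out the
whole partition, the stdin writer thread in PythonRDD catches the resulting
IOException and logs it instead of failing the task.

A rough standalone sketch of the per-partition truncation (plain Python,
no Spark involved; num and the sample data below are made up for
illustration):

    num = 3

    def takeUpToNum(iterator):
        # Yield at most num elements, then stop pulling from the
        # partition iterator, letting the worker finish early.
        taken = 0
        while taken < num:
            yield next(iterator)
            taken += 1

    print(list(takeUpToNum(iter([2, 3, 4, 5, 6]))))  # prints [2, 3, 4]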

---
 .../scala/spark/api/python/PythonRDD.scala    | 64 +++++++++++--------
 python/pyspark/rdd.py                         | 15 +++--
 2 files changed, 45 insertions(+), 34 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index af10822dbd..2dd79f7100 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -63,34 +63,42 @@ private[spark] class PythonRDD[T: ClassManifest](
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
-        SparkEnv.set(env)
-        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-        val dataOut = new DataOutputStream(stream)
-        val printOut = new PrintWriter(stream)
-        // Partition index
-        dataOut.writeInt(split.index)
-        // sparkFilesDir
-        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
-        // Broadcast variables
-        dataOut.writeInt(broadcastVars.length)
-        for (broadcast <- broadcastVars) {
-          dataOut.writeLong(broadcast.id)
-          dataOut.writeInt(broadcast.value.length)
-          dataOut.write(broadcast.value)
-        }
-        dataOut.flush()
-        // Serialized user code
-        for (elem <- command) {
-          printOut.println(elem)
-        }
-        printOut.flush()
-        // Data values
-        for (elem <- parent.iterator(split, context)) {
-          PythonRDD.writeAsPickle(elem, dataOut)
+        try {
+          SparkEnv.set(env)
+          val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+          val dataOut = new DataOutputStream(stream)
+          val printOut = new PrintWriter(stream)
+          // Partition index
+          dataOut.writeInt(split.index)
+          // sparkFilesDir
+          PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+          // Broadcast variables
+          dataOut.writeInt(broadcastVars.length)
+          for (broadcast <- broadcastVars) {
+            dataOut.writeLong(broadcast.id)
+            dataOut.writeInt(broadcast.value.length)
+            dataOut.write(broadcast.value)
+          }
+          dataOut.flush()
+          // Serialized user code
+          for (elem <- command) {
+            printOut.println(elem)
+          }
+          printOut.flush()
+          // Data values
+          for (elem <- parent.iterator(split, context)) {
+            PythonRDD.writeAsPickle(elem, dataOut)
+          }
+          dataOut.flush()
+          printOut.flush()
+          worker.shutdownOutput()
+        } catch {
+          case e: IOException =>
+            // This can happen for legitimate reasons if the Python code stops returning data before we are done
+            // passing elements through, e.g., for take(). Just log a message to say it happened.
+            logInfo("stdin writer to Python finished early")
+            logDebug("stdin writer to Python finished early", e)
         }
-        dataOut.flush()
-        printOut.flush()
-        worker.shutdownOutput()
       }
     }.start()
 
@@ -297,7 +305,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
   Utils.checkHost(serverHost, "Expected hostname")
 
   val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
-  
+
   override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
 
   override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index c6a6b24c5a..6efa61aa66 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -386,13 +386,16 @@ class RDD(object):
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
         """
+        def takeUpToNum(iterator):
+            taken = 0
+            while taken < num:
+                yield next(iterator)
+                taken += 1
+        # Take only up to num elements from each partition we try
+        mapped = self.mapPartitions(takeUpToNum)
         items = []
-        for partition in range(self._jrdd.splits().size()):
-            iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
-            # Each item in the iterator is a string, Python object, batch of
-            # Python objects.  Regardless, it is sufficient to take `num`
-            # of these objects in order to collect `num` Python objects:
-            iterator = iterator.take(num)
+        for partition in range(mapped._jrdd.splits().size()):
+            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
             items.extend(self._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
-- 
GitLab