Commit b9d6783f authored by Matei Zaharia

Optimize Python take() to not compute entire first partition

parent 72ff62a3
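
The change itself is small: take(num) in PySpark now wraps the RDD with a per-partition limiter before asking the JVM for each partition, so a tried partition is only computed up to num elements instead of in full. For context, a minimal usage sketch of the call this affects (illustrative only, not part of the commit; assumes a local PySpark installation, and the app name and sizes are made up):

from pyspark import SparkContext

sc = SparkContext("local", "take-example")
# An expensive-looking transformation spread over a few partitions.
rdd = sc.parallelize(range(100000), 4).map(lambda x: x * x)
# take(5) only needs five results, so computing a whole partition of squares
# on its behalf is wasted work; this commit caps the per-partition effort.
print(rdd.take(5))   # [0, 1, 4, 9, 16]
sc.stop()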
@@ -63,34 +63,42 @@ private[spark] class PythonRDD[T: ClassManifest](
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
-        SparkEnv.set(env)
-        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-        val dataOut = new DataOutputStream(stream)
-        val printOut = new PrintWriter(stream)
-        // Partition index
-        dataOut.writeInt(split.index)
-        // sparkFilesDir
-        PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
-        // Broadcast variables
-        dataOut.writeInt(broadcastVars.length)
-        for (broadcast <- broadcastVars) {
-          dataOut.writeLong(broadcast.id)
-          dataOut.writeInt(broadcast.value.length)
-          dataOut.write(broadcast.value)
-        }
-        dataOut.flush()
-        // Serialized user code
-        for (elem <- command) {
-          printOut.println(elem)
-        }
-        printOut.flush()
-        // Data values
-        for (elem <- parent.iterator(split, context)) {
-          PythonRDD.writeAsPickle(elem, dataOut)
+        try {
+          SparkEnv.set(env)
+          val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+          val dataOut = new DataOutputStream(stream)
+          val printOut = new PrintWriter(stream)
+          // Partition index
+          dataOut.writeInt(split.index)
+          // sparkFilesDir
+          PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+          // Broadcast variables
+          dataOut.writeInt(broadcastVars.length)
+          for (broadcast <- broadcastVars) {
+            dataOut.writeLong(broadcast.id)
+            dataOut.writeInt(broadcast.value.length)
+            dataOut.write(broadcast.value)
+          }
+          dataOut.flush()
+          // Serialized user code
+          for (elem <- command) {
+            printOut.println(elem)
+          }
+          printOut.flush()
+          // Data values
+          for (elem <- parent.iterator(split, context)) {
+            PythonRDD.writeAsPickle(elem, dataOut)
+          }
+          dataOut.flush()
+          printOut.flush()
+          worker.shutdownOutput()
+        } catch {
+          case e: IOException =>
+            // This can happen for legitimate reasons if the Python code stops returning data before we are done
+            // passing elements through, e.g., for take(). Just log a message to say it happened.
+            logInfo("stdin writer to Python finished early")
+            logDebug("stdin writer to Python finished early", e)
         }
-        dataOut.flush()
-        printOut.flush()
-        worker.shutdownOutput()
       }
     }.start()
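
The new try/catch exists because the Python worker may now legitimately stop consuming its input early, for example once take() has the elements it needs; the stdin writer thread's next write then fails with an IOException, which should be logged rather than treated as a task failure. A rough stand-alone analogue of that situation in plain Python (not Spark code; 'head' is just a hypothetical consumer that quits early, and this assumes a Unix-like system):

import subprocess

# The consumer reads three lines and exits, closing its end of the pipe.
proc = subprocess.Popen(["head", "-n", "3"], stdin=subprocess.PIPE)
try:
    for i in range(1000000):
        proc.stdin.write(("%d\n" % i).encode())
    proc.stdin.close()
except OSError:
    # BrokenPipeError in practice: the reader finished early, which is expected
    # here, so just note it instead of failing -- the same idea as logInfo above.
    print("stdin writer finished early")
proc.wait()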
@@ -297,7 +305,7 @@ class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int)
   Utils.checkHost(serverHost, "Expected hostname")
   val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
   override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
   override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
...
@@ -386,13 +386,16 @@ class RDD(object):
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
         """
+        def takeUpToNum(iterator):
+            taken = 0
+            while taken < num:
+                yield next(iterator)
+                taken += 1
+        # Take only up to num elements from each partition we try
+        mapped = self.mapPartitions(takeUpToNum)
         items = []
-        for partition in range(self._jrdd.splits().size()):
-            iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
-            # Each item in the iterator is a string, Python object, batch of
-            # Python objects. Regardless, it is sufficient to take `num`
-            # of these objects in order to collect `num` Python objects:
-            iterator = iterator.take(num)
+        for partition in range(mapped._jrdd.splits().size()):
+            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
             items.extend(self._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
...
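
The takeUpToNum generator is the heart of the change: each partition we try is truncated to at most num elements before its results are collected, and a partition with fewer than num elements simply ends the generator early. A standalone sketch of that behaviour (not the committed code; the committed version relies on Python 2 semantics, where a next() call raising StopIteration quietly ends the generator, which PEP 479 / Python 3.7+ would turn into a RuntimeError, hence the explicit return below):

def take_up_to_num(iterator, num):
    # Yield at most `num` items; stop quietly if the source runs out first.
    taken = 0
    while taken < num:
        try:
            yield next(iterator)
        except StopIteration:
            return
        taken += 1

print(list(take_up_to_num(iter([2, 3, 4, 5, 6]), 10)))  # [2, 3, 4, 5, 6]
print(list(take_up_to_num(iter(range(100)), 3)))        # [0, 1, 2]; the rest is never pulled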