From 33beba39656fc64984db09a82fc69ca4edcc02d4 Mon Sep 17 00:00:00 2001
From: Josh Rosen <joshrosen@eecs.berkeley.edu>
Date: Thu, 3 Jan 2013 14:52:21 -0800
Subject: [PATCH] Change PySpark RDD.take() to not call iterator().

---
 core/src/main/scala/spark/api/python/PythonRDD.scala |  4 ++++
 python/pyspark/context.py                            |  1 +
 python/pyspark/rdd.py                                | 11 +++++------
 3 files changed, 10 insertions(+), 6 deletions(-)

diff --git a/core/src/main/scala/spark/api/python/PythonRDD.scala b/core/src/main/scala/spark/api/python/PythonRDD.scala
index cf60d14f03..79d824d494 100644
--- a/core/src/main/scala/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/spark/api/python/PythonRDD.scala
@@ -10,6 +10,7 @@ import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import spark.broadcast.Broadcast
 import spark._
 import spark.rdd.PipedRDD
+import java.util
 
 
 private[spark] class PythonRDD[T: ClassManifest](
@@ -216,6 +217,9 @@ private[spark] object PythonRDD {
     }
     file.close()
   }
+
+  def takePartition[T](rdd: RDD[T], partition: Int): java.util.Iterator[T] =
+    rdd.context.runJob(rdd, ((x: Iterator[T]) => x), Seq(partition), true).head
 }
 
 private object Pickle {
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 6172d69dcf..4439356c1f 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -21,6 +21,7 @@ class SparkContext(object):
     jvm = gateway.jvm
     _readRDDFromPickleFile = jvm.PythonRDD.readRDDFromPickleFile
     _writeIteratorToPickleFile = jvm.PythonRDD.writeIteratorToPickleFile
+    _takePartition = jvm.PythonRDD.takePartition
 
     def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
         environment=None, batchSize=1024):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index cbffb6cc1f..4ba417b2a2 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -328,18 +328,17 @@ class RDD(object):
         a lot of partitions are required. In that case, use L{collect} to get
         the whole RDD instead.
 
-        >>> sc.parallelize([2, 3, 4, 5, 6]).take(2)
+        >>> sc.parallelize([2, 3, 4, 5, 6]).cache().take(2)
         [2, 3]
         >>> sc.parallelize([2, 3, 4, 5, 6]).take(10)
         [2, 3, 4, 5, 6]
         """
         items = []
-        splits = self._jrdd.splits()
-        taskContext = self.ctx.jvm.spark.TaskContext(0, 0, 0)
-        while len(items) < num and splits:
-            split = splits.pop(0)
-            iterator = self._jrdd.iterator(split, taskContext)
+        for partition in range(self._jrdd.splits().size()):
+            iterator = self.ctx._takePartition(self._jrdd.rdd(), partition)
             items.extend(self._collect_iterator_through_file(iterator))
+            if len(items) >= num:
+                break
         return items[:num]
 
     def first(self):
-- 
GitLab
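
Illustrative note (not part of the patch): the new take() path stops calling iterator() from Python and instead runs one job per partition through PythonRDD.takePartition, breaking out as soon as enough elements have been collected. Below is a minimal pure-Python sketch of that control flow under stated assumptions; take_items, partitions, and collect_partition are hypothetical stand-ins for the RDD, its splits(), and the _takePartition JVM call, and are not names from the patch.

    def collect_partition(partition):
        # Hypothetical stand-in for ctx._takePartition(...): materialize one
        # partition's elements as a list (in the patch this triggers a runJob
        # on the JVM side for just that partition).
        return list(partition)


    def take_items(partitions, num):
        # Collect up to `num` items, fetching one partition at a time and
        # stopping early, mirroring the loop added to RDD.take().
        items = []
        for partition in partitions:
            items.extend(collect_partition(partition))
            if len(items) >= num:
                break
        return items[:num]


    if __name__ == "__main__":
        # Two partitions of a small dataset; take(2) only touches the first one.
        print(take_items([[2, 3, 4], [5, 6]], 2))   # [2, 3]
        print(take_items([[2, 3, 4], [5, 6]], 10))  # [2, 3, 4, 5, 6]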