From eedc542a0276a5248c81446ee84f56d691e5f488 Mon Sep 17 00:00:00 2001 From: Reynold Xin <rxin@cs.berkeley.edu> Date: Wed, 23 Jan 2013 22:14:23 -0800 Subject: [PATCH] Removed pruneSplits method in RDD and renamed SplitsPruningRDD to PartitionPruningRDD. --- core/src/main/scala/spark/RDD.scala | 10 -------- .../scala/spark/rdd/PartitionPruningRDD.scala | 24 +++++++++++++++++++ .../scala/spark/rdd/SplitsPruningRDD.scala | 24 ------------------- core/src/test/scala/spark/RDDSuite.scala | 6 ++--- 4 files changed, 27 insertions(+), 37 deletions(-) create mode 100644 core/src/main/scala/spark/rdd/PartitionPruningRDD.scala delete mode 100644 core/src/main/scala/spark/rdd/SplitsPruningRDD.scala diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 3d93ff33bb..e0d2eabb1d 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -40,7 +40,6 @@ import spark.rdd.MapPartitionsRDD import spark.rdd.MapPartitionsWithSplitRDD import spark.rdd.PipedRDD import spark.rdd.SampledRDD -import spark.rdd.SplitsPruningRDD import spark.rdd.UnionRDD import spark.rdd.ZippedRDD import spark.storage.StorageLevel @@ -544,15 +543,6 @@ abstract class RDD[T: ClassManifest]( map(x => (f(x), x)) } - /** - * Prune splits (partitions) so Spark can avoid launching tasks on - * all splits. An example use case: If we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on splits that don't have the range covering the key. - */ - def pruneSplits(splitsFilterFunc: Int => Boolean): RDD[T] = - new SplitsPruningRDD(this, splitsFilterFunc) - /** A private method for tests, to look at the contents of each partition */ private[spark] def collectPartitions(): Array[Array[T]] = { sc.runJob(this, (iter: Iterator[T]) => iter.toArray) diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala new file mode 100644 index 0000000000..3048949ef2 --- /dev/null +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -0,0 +1,24 @@ +package spark.rdd + +import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} + +/** + * A RDD used to prune RDD partitions/splits so we can avoid launching tasks on + * all partitions. An example use case: If we know the RDD is partitioned by range, + * and the execution DAG has a filter on the key, we can avoid launching tasks + * on partitions that don't have the range covering the key. + */ +class PartitionPruningRDD[T: ClassManifest]( + @transient prev: RDD[T], + @transient partitionFilterFunc: Int => Boolean) + extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { + + @transient + val partitions_ : Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits + + override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) + + override protected def getSplits = partitions_ + + override val partitioner = firstParent[T].partitioner +} diff --git a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala b/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala deleted file mode 100644 index 9b1a210ba3..0000000000 --- a/core/src/main/scala/spark/rdd/SplitsPruningRDD.scala +++ /dev/null @@ -1,24 +0,0 @@ -package spark.rdd - -import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext} - -/** - * A RDD used to prune RDD splits so we can avoid launching tasks on - * all splits. An example use case: If we know the RDD is partitioned by range, - * and the execution DAG has a filter on the key, we can avoid launching tasks - * on splits that don't have the range covering the key. - */ -class SplitsPruningRDD[T: ClassManifest]( - @transient prev: RDD[T], - @transient splitsFilterFunc: Int => Boolean) - extends RDD[T](prev.context, List(new PruneDependency(prev, splitsFilterFunc))) { - - @transient - val _splits: Array[Split] = dependencies_.head.asInstanceOf[PruneDependency[T]].splits - - override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context) - - override protected def getSplits = _splits - - override val partitioner = firstParent[T].partitioner -} diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 5a3a12dfff..73846131a9 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -3,7 +3,7 @@ package spark import scala.collection.mutable.HashMap import org.scalatest.{BeforeAndAfter, FunSuite} import spark.SparkContext._ -import spark.rdd.CoalescedRDD +import spark.rdd.{CoalescedRDD, PartitionPruningRDD} class RDDSuite extends FunSuite with BeforeAndAfter { @@ -169,11 +169,11 @@ class RDDSuite extends FunSuite with BeforeAndAfter { } } - test("split pruning") { + test("partition pruning") { sc = new SparkContext("local", "test") val data = sc.parallelize(1 to 10, 10) // Note that split number starts from 0, so > 8 means only 10th partition left. - val prunedRdd = data.pruneSplits(splitNum => splitNum > 8) + val prunedRdd = new PartitionPruningRDD(data, splitNum => splitNum > 8) assert(prunedRdd.splits.size === 1) val prunedData = prunedRdd.collect assert(prunedData.size === 1) -- GitLab