Skip to content
Snippets Groups Projects
Commit 571af313 authored by Matei Zaharia's avatar Matei Zaharia
Browse files

Merge pull request #433 from rxin/master

Changed PartitionPruningRDD's split to make sure it returns the correct split index.
parents 5ce5efec f9af9cee
No related branches found
No related tags found
No related merge requests found
......@@ -61,17 +61,3 @@ class RangeDependency[T](rdd: RDD[T], inStart: Int, outStart: Int, length: Int)
}
}
}
/**
* Represents a dependency between the PartitionPruningRDD and its parent. In this
* case, the child RDD contains a subset of partitions of the parents'.
*/
class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
extends NarrowDependency[T](rdd) {
@transient
val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
override def getParents(partitionId: Int) = List(partitions(partitionId).index)
}
package spark.rdd
import spark.{PruneDependency, RDD, SparkEnv, Split, TaskContext}
import spark.{NarrowDependency, RDD, SparkEnv, Split, TaskContext}
class PartitionPruningRDDSplit(idx: Int, val parentSplit: Split) extends Split {
override val index = idx
}
/**
* Represents a dependency between the PartitionPruningRDD and its parent. In this
* case, the child RDD contains a subset of partitions of the parents'.
*/
class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boolean)
extends NarrowDependency[T](rdd) {
@transient
val partitions: Array[Split] = rdd.splits.filter(s => partitionFilterFunc(s.index))
.zipWithIndex.map { case(split, idx) => new PartitionPruningRDDSplit(idx, split) : Split }
override def getParents(partitionId: Int) = List(partitions(partitionId).index)
}
/**
* A RDD used to prune RDD partitions/splits so we can avoid launching tasks on
* all partitions. An example use case: If we know the RDD is partitioned by range,
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
*
* TODO: This currently doesn't give partition IDs properly!
*/
class PartitionPruningRDD[T: ClassManifest](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(split, context)
override def compute(split: Split, context: TaskContext) = firstParent[T].iterator(
split.asInstanceOf[PartitionPruningRDDSplit].parentSplit, context)
override protected def getSplits =
getDependencies.head.asInstanceOf[PruneDependency[T]].partitions
override val partitioner = firstParent[T].partitioner
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment