Skip to content
Snippets Groups Projects
Commit 13b9bf49 authored by Matthew Taylor's avatar Matthew Taylor
Browse files

PartitionPruningRDD is using index from parent

parent e2ebc3a9
No related branches found
No related tags found
No related merge requests found
......@@ -34,10 +34,12 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
@transient
val partitions: Array[Partition] = rdd.partitions.zipWithIndex
.filter(s => partitionFilterFunc(s._2))
.filter(s => partitionFilterFunc(s._2)).map(_._1).zipWithIndex
.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
override def getParents(partitionId: Int) = List(partitions(partitionId).index)
override def getParents(partitionId: Int) = {
List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index)
}
}
......
......@@ -19,27 +19,75 @@ package org.apache.spark
import org.scalatest.FunSuite
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
import org.apache.spark.rdd.{PartitionPruningRDDPartition, RDD, PartitionPruningRDD}
class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
test("Pruned Partitions inherit locality prefs correctly") {
class TestPartition(i: Int) extends Partition {
def index = i
}
val rdd = new RDD[Int](sc, Nil) {
override protected def getPartitions = {
Array[Partition](
new TestPartition(1),
new TestPartition(2),
new TestPartition(3))
new TestPartition(0, 1),
new TestPartition(1, 1),
new TestPartition(2, 1))
}
def compute(split: Partition, context: TaskContext) = {
Iterator()
}
def compute(split: Partition, context: TaskContext) = {Iterator()}
}
val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false})
val p = prunedRDD.partitions(0)
assert(p.index == 2)
val prunedRDD = PartitionPruningRDD.create(rdd, {
x => if (x == 2) true else false
})
assert(prunedRDD.partitions.length == 1)
val p = prunedRDD.partitions(0)
assert(p.index == 0)
assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)
}
test("Pruned Partitions can be merged ") {
val rdd = new RDD[Int](sc, Nil) {
override protected def getPartitions = {
Array[Partition](
new TestPartition(0, 4),
new TestPartition(1, 5),
new TestPartition(2, 6))
}
def compute(split: Partition, context: TaskContext) = {
List(split.asInstanceOf[TestPartition].testValue).iterator
}
}
val prunedRDD1 = PartitionPruningRDD.create(rdd, {
x => if (x == 0) true else false
})
val prunedRDD2 = PartitionPruningRDD.create(rdd, {
x => if (x == 2) true else false
})
val merged = prunedRDD1 ++ prunedRDD2
assert(merged.count() == 2)
val take = merged.take(2)
assert(take.apply(0) == 4)
assert(take.apply(1) == 6)
}
}
class TestPartition(i: Int, value: Int) extends Partition with Serializable {
def index = i
def testValue = this.value
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment