diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
index 165cd412fcfb85f23089719205e53d49c4af12bc..2738a0089409a7c1bc39deef85ed4592f8df0d3f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
@@ -34,10 +34,12 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
 
   @transient
   val partitions: Array[Partition] = rdd.partitions.zipWithIndex
-    .filter(s => partitionFilterFunc(s._2))
+    .filter(s => partitionFilterFunc(s._2)).map(_._1).zipWithIndex
     .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
 
-  override def getParents(partitionId: Int) = List(partitions(partitionId).index)
+  override def getParents(partitionId: Int) = {
+    List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index)
+  }
 }
 
 
diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
index 21f16ef2c6ececa8db7126f8897152ff5b523b28..28e71e835fa2cf2bc92e8a8e4e4e600e392c8ea5 100644
--- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
@@ -19,27 +19,75 @@ package org.apache.spark
 
 import org.scalatest.FunSuite
 import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
+import org.apache.spark.rdd.{PartitionPruningRDDPartition, RDD, PartitionPruningRDD}
 
 
 class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
 
+
   test("Pruned Partitions inherit locality prefs correctly") {
-    class TestPartition(i: Int) extends Partition {
-      def index = i
-    }
+
     val rdd = new RDD[Int](sc, Nil) {
       override protected def getPartitions = {
         Array[Partition](
-            new TestPartition(1),
-            new TestPartition(2), 
-            new TestPartition(3))
+          new TestPartition(0, 1),
+          new TestPartition(1, 1),
+          new TestPartition(2, 1))
+      }
+
+      def compute(split: Partition, context: TaskContext) = {
+        Iterator()
       }
-      def compute(split: Partition, context: TaskContext) = {Iterator()}
     }
-    val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false})
-    val p = prunedRDD.partitions(0)
-    assert(p.index == 2)
+    val prunedRDD = PartitionPruningRDD.create(rdd, {
+      x => if (x == 2) true else false
+    })
     assert(prunedRDD.partitions.length == 1)
+    val p = prunedRDD.partitions(0)
+    assert(p.index == 0)
+    assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)
   }
+
+
+  test("Pruned Partitions can be merged ") {
+
+    val rdd = new RDD[Int](sc, Nil) {
+      override protected def getPartitions = {
+        Array[Partition](
+          new TestPartition(0, 4),
+          new TestPartition(1, 5),
+          new TestPartition(2, 6))
+      }
+
+      def compute(split: Partition, context: TaskContext) = {
+        List(split.asInstanceOf[TestPartition].testValue).iterator
+      }
+    }
+    val prunedRDD1 = PartitionPruningRDD.create(rdd, {
+      x => if (x == 0) true else false
+    })
+
+    val prunedRDD2 = PartitionPruningRDD.create(rdd, {
+      x => if (x == 2) true else false
+    })
+
+    val merged = prunedRDD1 ++ prunedRDD2
+
+    assert(merged.count() == 2)
+    val take = merged.take(2)
+
+    assert(take.apply(0) == 4)
+
+    assert(take.apply(1) == 6)
+
+
+  }
+
 }
+
+class TestPartition(i: Int, value: Int) extends Partition with Serializable {
+  def index = i
+
+  def testValue = this.value
+
+}
\ No newline at end of file