diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index ccd9d0364ad9582bf319dda96fb57566a5da60ae..09e52ebf3ec008dbe9080b40d07a63755e58aa37 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -35,6 +35,9 @@ import spark.rdd.ShuffledRDD
 import spark.rdd.SubtractedRDD
 import spark.rdd.UnionRDD
 import spark.rdd.ZippedRDD
+import spark.rdd.ZippedPartitionsRDD2
+import spark.rdd.ZippedPartitionsRDD3
+import spark.rdd.ZippedPartitionsRDD4
 import spark.storage.StorageLevel
 
 import SparkContext._
@@ -436,6 +439,31 @@ abstract class RDD[T: ClassManifest](
    */
   def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
 
+  /**
+   * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+   * applying a function to the zipped partitions. Assumes that all the RDDs have the
+   * *same number of partitions*, but does *not* require them to have the same number
+   * of elements in each partition.
+   */
+  def zipPartitions[B: ClassManifest, V: ClassManifest](
+      f: (Iterator[T], Iterator[B]) => Iterator[V],
+      rdd2: RDD[B]): RDD[V] =
+    new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
+
+  def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest](
+      f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V],
+      rdd2: RDD[B],
+      rdd3: RDD[C]): RDD[V] =
+    new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
+
+  def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest](
+      f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+      rdd2: RDD[B],
+      rdd3: RDD[C],
+      rdd4: RDD[D]): RDD[V] =
+    new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
+
+
   // Actions (launch a job to return a value to the user program)
 
   /**
diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
new file mode 100644
index 0000000000000000000000000000000000000000..fc3f29ffcda3ecf40bb6acb468e5627eb7506e35
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
@@ -0,0 +1,120 @@
+package spark.rdd
+
+import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
+import java.io.{ObjectOutputStream, IOException}
+
+private[spark] class ZippedPartitionsPartition(
+    idx: Int,
+    @transient rdds: Seq[RDD[_]])
+  extends Partition {
+
+  override val index: Int = idx
+  var partitionValues = rdds.map(rdd => rdd.partitions(idx))
+  def partitions = partitionValues
+
+  @throws(classOf[IOException])
+  private def writeObject(oos: ObjectOutputStream) {
+    // Update the reference to parent split at the time of task serialization
+    partitionValues = rdds.map(rdd => rdd.partitions(idx))
+    oos.defaultWriteObject()
+  }
+}
+
+abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
+    sc: SparkContext,
+    var rdds: Seq[RDD[_]])
+  extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
+
+  override def getPartitions: Array[Partition] = {
+    val sizes = rdds.map(x => x.partitions.size)
+    if (!sizes.forall(x => x == sizes(0))) {
+      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+    }
+    val array = new Array[Partition](sizes(0))
+    for (i <- 0 until sizes(0)) {
+      array(i) = new ZippedPartitionsPartition(i, rdds)
+    }
+    array
+  }
+
+  override def getPreferredLocations(s: Partition): Seq[String] = {
+    val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions
+    val preferredLocations = rdds.zip(splits).map(x => x._1.preferredLocations(x._2))
+    preferredLocations.reduce((x, y) => x.intersect(y))
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    rdds = null
+  }
+}
+
+class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest](
+    sc: SparkContext,
+    f: (Iterator[A], Iterator[B]) => Iterator[V],
+    var rdd1: RDD[A],
+    var rdd2: RDD[B])
+  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
+
+  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+    f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    rdd1 = null
+    rdd2 = null
+  }
+}
+
+class ZippedPartitionsRDD3
+  [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest](
+    sc: SparkContext,
+    f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
+    var rdd1: RDD[A],
+    var rdd2: RDD[B],
+    var rdd3: RDD[C])
+  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {
+
+  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+    f(rdd1.iterator(partitions(0), context),
+      rdd2.iterator(partitions(1), context),
+      rdd3.iterator(partitions(2), context))
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    rdd1 = null
+    rdd2 = null
+    rdd3 = null
+  }
+}
+
+class ZippedPartitionsRDD4
+  [A: ClassManifest, B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest](
+    sc: SparkContext,
+    f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+    var rdd1: RDD[A],
+    var rdd2: RDD[B],
+    var rdd3: RDD[C],
+    var rdd4: RDD[D])
+  extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {
+
+  override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
+    f(rdd1.iterator(partitions(0), context),
+      rdd2.iterator(partitions(1), context),
+      rdd3.iterator(partitions(2), context),
+      rdd4.iterator(partitions(3), context))
+  }
+
+  override def clearDependencies() {
+    super.clearDependencies()
+    rdd1 = null
+    rdd2 = null
+    rdd3 = null
+    rdd4 = null
+  }
+}
diff --git a/core/src/test/scala/spark/ZippedPartitionsSuite.scala b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5f60aa75d7f0334a4d99662e04abc9d785c47b4d
--- /dev/null
+++ b/core/src/test/scala/spark/ZippedPartitionsSuite.scala
@@ -0,0 +1,34 @@
+package spark
+
+import scala.collection.immutable.NumericRange
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+import org.scalacheck.Arbitrary._
+import org.scalacheck.Gen
+import org.scalacheck.Prop._
+
+import SparkContext._
+
+
+object ZippedPartitionsSuite {
+  def procZippedData(i: Iterator[Int], s: Iterator[String], d: Iterator[Double]): Iterator[Int] = {
+    Iterator(i.toArray.size, s.toArray.size, d.toArray.size)
+  }
+}
+
+class ZippedPartitionsSuite extends FunSuite with LocalSparkContext {
+  test("print sizes") {
+    sc = new SparkContext("local", "test")
+    val data1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val data2 = sc.makeRDD(Array("1", "2", "3", "4", "5", "6"), 2)
+    val data3 = sc.makeRDD(Array(1.0, 2.0), 2)
+
+    val zippedRDD = data1.zipPartitions(ZippedPartitionsSuite.procZippedData, data2, data3)
+
+    val obtainedSizes = zippedRDD.collect()
+    val expectedSizes = Array(2, 3, 1, 2, 3, 1)
+    assert(obtainedSizes.size == 6)
+    assert(obtainedSizes.zip(expectedSizes).forall(x => x._1 == x._2))
+  }
+}
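
Usage sketch (not part of the patch): the new zipPartitions overloads take the per-partition
function first, followed by the other RDD(s); the function receives one iterator per RDD for
the corresponding partition. A minimal example mirroring the test above, assuming a
SparkContext `sc` is already in scope and using illustrative variable names:

    // Two RDDs with the same number of partitions (2), but not necessarily the same
    // number of elements per partition.
    val a = sc.makeRDD(1 to 4, 2)                      // RDD[Int]
    val b = sc.makeRDD(Seq("w", "x", "y", "z"), 2)     // RDD[String]

    // Emit the element count of each partition pair.
    def counts(i: Iterator[Int], s: Iterator[String]): Iterator[(Int, Int)] =
      Iterator((i.size, s.size))

    val zipped = a.zipPartitions(counts _, b)          // RDD[(Int, Int)]
    zipped.collect()                                   // Array((2,2), (2,2))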