diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 4310f745f37d492d447d4f4054196b76a20ffa80..09e52ebf3ec008dbe9080b40d07a63755e58aa37 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -439,6 +439,12 @@ abstract class RDD[T: ClassManifest](
    */
   def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)

+  /**
+   * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
+   * applying a function to the zipped partitions. Assumes that all the RDDs have the
+   * *same number of partitions*, but does *not* require them to have the same number
+   * of elements in each partition.
+   */
   def zipPartitions[B: ClassManifest, V: ClassManifest](
       f: (Iterator[T], Iterator[B]) => Iterator[V],
       rdd2: RDD[B]): RDD[V] =
diff --git a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
index b3113c1969182d55b6655e0e0b78c765ba7fef3e..fc3f29ffcda3ecf40bb6acb468e5627eb7506e35 100644
--- a/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/spark/rdd/ZippedPartitionsRDD.scala
@@ -3,7 +3,7 @@ package spark.rdd
 import spark.{OneToOneDependency, RDD, SparkContext, Partition, TaskContext}
 import java.io.{ObjectOutputStream, IOException}

-private[spark] class ZippedPartitions(
+private[spark] class ZippedPartitionsPartition(
     idx: Int,
     @transient rdds: Seq[RDD[_]])
   extends Partition {
@@ -32,13 +32,13 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
     }
     val array = new Array[Partition](sizes(0))
     for (i <- 0 until sizes(0)) {
-      array(i) = new ZippedPartitions(i, rdds)
+      array(i) = new ZippedPartitionsPartition(i, rdds)
     }
     array
   }

   override def getPreferredLocations(s: Partition): Seq[String] = {
-    val splits = s.asInstanceOf[ZippedPartitions].partitions
+    val splits = s.asInstanceOf[ZippedPartitionsPartition].partitions
     val preferredLocations = rdds.zip(splits).map(x => x._1.preferredLocations(x._2))
     preferredLocations.reduce((x, y) => x.intersect(y))
   }
@@ -57,7 +57,7 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]
   extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {

   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
-    val partitions = s.asInstanceOf[ZippedPartitions].partitions
+    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
   }

@@ -78,7 +78,7 @@ class ZippedPartitionsRDD3
   extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {

   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
-    val partitions = s.asInstanceOf[ZippedPartitions].partitions
+    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     f(rdd1.iterator(partitions(0), context),
       rdd2.iterator(partitions(1), context),
       rdd3.iterator(partitions(2), context))
@@ -103,7 +103,7 @@ class ZippedPartitionsRDD4
   extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {

   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
-    val partitions = s.asInstanceOf[ZippedPartitions].partitions
+    val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     f(rdd1.iterator(partitions(0), context),
       rdd2.iterator(partitions(1), context),
       rdd3.iterator(partitions(2), context),
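
For context, a minimal usage sketch of the zipPartitions API documented above. This example is hypothetical and not part of this patch: the SparkContext setup, RDD contents, and partition counts are assumptions chosen to show the case the new doc comment describes (same number of partitions, different element counts per partition), using the two-RDD overload from this diff where the function argument precedes rdd2.

    // Hypothetical sketch, not part of this patch: exercises zipPartitions
    // with two RDDs that share a partition count but not per-partition sizes.
    import spark.SparkContext

    object ZipPartitionsExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "ZipPartitionsExample")

        // Both RDDs have 4 partitions, but different totals, so their
        // per-partition element counts differ -- allowed by zipPartitions,
        // unlike zip, which requires equal sizes in each partition.
        val nums = sc.parallelize(1 to 100, 4)
        val strs = sc.parallelize(Seq("a", "b", "c"), 4)

        // Combine corresponding partitions; here we just report how many
        // elements each side contributed to the zipped partition.
        val sizes = nums.zipPartitions(
          (numIter: Iterator[Int], strIter: Iterator[String]) =>
            Iterator((numIter.size, strIter.size)),
          strs)

        // Each output partition holds one (numsCount, strsCount) pair; how
        // the 3 strings spread over 4 partitions depends on parallelize.
        println(sizes.collect().mkString(", "))
        sc.stop()
      }
    }

The per-partition function receives the raw iterators for the corresponding partitions of each input, so it can consume them at different rates, which is exactly what the relaxed size requirement in the new scaladoc permits.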