diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 338dff40618cc8469f693e092439f88e7913e8b2..f4288a9661d6d3b0b2b45ae2c6bda52459fa58db 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -42,6 +42,7 @@ import spark.rdd.MapPartitionsWithSplitRDD
 import spark.rdd.PipedRDD
 import spark.rdd.SampledRDD
 import spark.rdd.UnionRDD
+import spark.rdd.ZippedRDD
 import spark.storage.StorageLevel
 
 import SparkContext._
@@ -293,6 +294,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
     new MapPartitionsWithSplitRDD(this, sc.clean(f))
 
+  /**
+   * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
+   * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
+   * partitions* and the *same number of elements in each partition* (e.g. one was made through
+   * a map on the other).
+   */
+  def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+
   // Actions (launch a job to return a value to the user program)
 
   /**
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
new file mode 100644
index 0000000000000000000000000000000000000000..80f0150c45c2dd831f0d101434a1607c056b4b9a
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -0,0 +1,54 @@
+package spark.rdd
+
+import spark.Dependency
+import spark.OneToOneDependency
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
+private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
+    idx: Int,
+    rdd1: RDD[T],
+    rdd2: RDD[U],
+    split1: Split,
+    split2: Split)
+  extends Split
+  with Serializable {
+
+  def iterator(): Iterator[(T, U)] = rdd1.iterator(split1).zip(rdd2.iterator(split2))
+
+  def preferredLocations(): Seq[String] =
+    rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+
+  override val index: Int = idx
+}
+
+class ZippedRDD[T: ClassManifest, U: ClassManifest](
+    sc: SparkContext,
+    @transient rdd1: RDD[T],
+    @transient rdd2: RDD[U])
+  extends RDD[(T, U)](sc)
+  with Serializable {
+
+  @transient
+  val splits_ : Array[Split] = {
+    if (rdd1.splits.size != rdd2.splits.size) {
+      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+    }
+    val array = new Array[Split](rdd1.splits.size)
+    for (i <- 0 until rdd1.splits.size) {
+      array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
+    }
+    array
+  }
+
+  override def splits = splits_
+
+  @transient
+  override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
+
+  override def compute(s: Split): Iterator[(T, U)] = s.asInstanceOf[ZippedSplit[T, U]].iterator()
+
+  override def preferredLocations(s: Split): Seq[String] =
+    s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
+}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 37a0ff09477b2c0dfc5909c3d7a0b41975be0eff..b3c820ed94af8b4fa38252745ad51ccb8ac00277 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -114,4 +114,16 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
     assert(coalesced4.glom().collect().map(_.toList).toList ===
       (1 to 10).map(x => List(x)).toList)
   }
+
+  test("zipped RDDs") {
+    sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val zipped = nums.zip(nums.map(_ + 1.0))
+    assert(zipped.glom().map(_.toList).collect().toList ===
+      List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
+
+    intercept[IllegalArgumentException] {
+      nums.zip(sc.parallelize(1 to 4, 1)).collect()
+    }
+  }
 }
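
For context, a minimal usage sketch of the new zip operator, mirroring the test above (the app name "zip-example" and the pairs binding are illustrative, not part of the change). It relies on the documented contract: both RDDs must have the same number of partitions and the same number of elements in each partition, which holds when one RDD is a map of the other:

    import spark.RDD
    import spark.SparkContext

    val sc = new SparkContext("local", "zip-example")
    // Two partitions with two elements each.
    val nums: RDD[Int] = sc.makeRDD(Array(1, 2, 3, 4), 2)
    // map preserves the partition count and the per-partition element counts,
    // so zipping its result back with nums satisfies ZippedRDD's contract.
    val pairs: RDD[(Int, Double)] = nums.zip(nums.map(_ + 1.0))
    pairs.collect()  // Array((1,2.0), (2,3.0), (3,4.0), (4,5.0))
    // Mismatched partition counts fail fast, since splits_ is built eagerly:
    // nums.zip(sc.parallelize(1 to 4, 1))  // IllegalArgumentException

Because each ZippedSplit pairs the i-th split of rdd1 with the i-th split of rdd2, the zip needs no shuffle: both parents are wired in as OneToOneDependency instances, and a zipped split's preferred locations are the intersection of its parents' preferred locations.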