From 27e43abd192440de5b10a5cc022fd5705362b276 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Tue, 27 Nov 2012 22:27:47 -0800
Subject: [PATCH] Added a zip() operation for RDDs with the same shape (number
 of partitions and number of elements in each partition)

---
 core/src/main/scala/spark/RDD.scala           |  9 ++++
 core/src/main/scala/spark/rdd/ZippedRDD.scala | 54 +++++++++++++++++++
 core/src/test/scala/spark/RDDSuite.scala      | 12 +++++
 3 files changed, 75 insertions(+)
 create mode 100644 core/src/main/scala/spark/rdd/ZippedRDD.scala

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 338dff4061..f4288a9661 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -42,6 +42,7 @@ import spark.rdd.MapPartitionsWithSplitRDD
 import spark.rdd.PipedRDD
 import spark.rdd.SampledRDD
 import spark.rdd.UnionRDD
+import spark.rdd.ZippedRDD
 import spark.storage.StorageLevel
 
 import SparkContext._
@@ -293,6 +294,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
     new MapPartitionsWithSplitRDD(this, sc.clean(f))
 
+  /**
+   * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
+   * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
+   * partitions* and the *same number of elements in each partition* (e.g. one was made through
+   * a map on the other).
+   */
+  def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+
   // Actions (launch a job to return a value to the user program)
 
   /**
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
new file mode 100644
index 0000000000..80f0150c45
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -0,0 +1,54 @@
+package spark.rdd
+
+import spark.Dependency
+import spark.OneToOneDependency
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
+private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
+    idx: Int,
+    rdd1: RDD[T],
+    rdd2: RDD[U],
+    split1: Split,
+    split2: Split)
+  extends Split
+  with Serializable {
+
+  def iterator(): Iterator[(T, U)] = rdd1.iterator(split1).zip(rdd2.iterator(split2))
+
+  def preferredLocations(): Seq[String] =
+    rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+
+  override val index: Int = idx
+}
+
+class ZippedRDD[T: ClassManifest, U: ClassManifest](
+    sc: SparkContext,
+    @transient rdd1: RDD[T],
+    @transient rdd2: RDD[U])
+  extends RDD[(T, U)](sc)
+  with Serializable {
+
+  @transient
+  val splits_ : Array[Split] = {
+    if (rdd1.splits.size != rdd2.splits.size) {
+      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+    }
+    val array = new Array[Split](rdd1.splits.size)
+    for (i <- 0 until rdd1.splits.size) {
+      array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
+    }
+    array
+  }
+
+  override def splits = splits_
+
+  @transient
+  override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
+
+  override def compute(s: Split): Iterator[(T, U)] = s.asInstanceOf[ZippedSplit[T, U]].iterator()
+
+  override def preferredLocations(s: Split): Seq[String] =
+    s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
+}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 37a0ff0947..b3c820ed94 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -114,4 +114,16 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
     assert(coalesced4.glom().collect().map(_.toList).toList ===
       (1 to 10).map(x => List(x)).toList)
   }
+
+  test("zipped RDDs") {
+    sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val zipped = nums.zip(nums.map(_ + 1.0))
+    assert(zipped.glom().map(_.toList).collect().toList ===
+      List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
+
+    intercept[IllegalArgumentException] {
+      nums.zip(sc.parallelize(1 to 4, 1)).collect()
+    }
+  }
 }
--
GitLab
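
Addendum, not part of the patch above: a minimal driver-program sketch of the
new zip() API, as a usage illustration only. The object name ZipExample is
hypothetical, and the sketch assumes the spark package of this era (the
ClassManifest-based RDD API) is on the classpath.

import spark.SparkContext

object ZipExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "ZipExample")
    // Two RDDs with the same shape: two partitions each, with the same
    // number of elements per partition, since one is a map of the other.
    val nums = sc.parallelize(1 to 4, 2)
    val squares = nums.map(x => x * x)
    // zip() pairs elements positionally within each partition, printing
    // List((1,1), (2,4), (3,9), (4,16))
    println(nums.zip(squares).collect().toList)
    sc.stop()
  }
}

One design note grounded in the patch itself: ZippedRDD checks the partition
counts eagerly in its splits_ field, so zipping RDDs with unequal numbers of
partitions fails with IllegalArgumentException at the zip() call, while the
assumption of equal per-partition element counts is never verified.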