From 27e43abd192440de5b10a5cc022fd5705362b276 Mon Sep 17 00:00:00 2001
From: Matei Zaharia <matei@eecs.berkeley.edu>
Date: Tue, 27 Nov 2012 22:27:47 -0800
Subject: [PATCH] Added a zip() operation for RDDs with the same shape (number
 of partitions and number of elements in each partition)

---
 core/src/main/scala/spark/RDD.scala           |  9 ++++
 core/src/main/scala/spark/rdd/ZippedRDD.scala | 54 +++++++++++++++++++
 core/src/test/scala/spark/RDDSuite.scala      | 12 +++++
 3 files changed, 75 insertions(+)
 create mode 100644 core/src/main/scala/spark/rdd/ZippedRDD.scala

diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 338dff4061..f4288a9661 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -42,6 +42,7 @@ import spark.rdd.MapPartitionsWithSplitRDD
 import spark.rdd.PipedRDD
 import spark.rdd.SampledRDD
 import spark.rdd.UnionRDD
+import spark.rdd.ZippedRDD
 import spark.storage.StorageLevel
 
 import SparkContext._
@@ -293,6 +294,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
     new MapPartitionsWithSplitRDD(this, sc.clean(f))
 
+  /**
+   * Zips this RDD with another one, returning key-value pairs consisting of the first element
+   * in each RDD, the second element in each RDD, and so on. Assumes that the two RDDs have the
+   * *same number of partitions* and the *same number of elements in each partition* (e.g. one
+   * was made through a map on the other).
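+   *
+   * For example (illustrative names; behavior matches the test in this patch):
+   * {{{
+   *   val a = sc.parallelize(List(1, 2, 3, 4), 2)
+   *   val b = a.map(_ * 10)             // same shape as `a`
+   *   a.zip(b).collect()                // Array((1,10), (2,20), (3,30), (4,40))
+   * }}}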
+   */
+  def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+
   // Actions (launch a job to return a value to the user program)
 
   /**
diff --git a/core/src/main/scala/spark/rdd/ZippedRDD.scala b/core/src/main/scala/spark/rdd/ZippedRDD.scala
new file mode 100644
index 0000000000..80f0150c45
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/ZippedRDD.scala
@@ -0,0 +1,54 @@
+package spark.rdd
+
+import spark.Dependency
+import spark.OneToOneDependency
+import spark.RDD
+import spark.SparkContext
+import spark.Split
+
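+/**
+ * A split that pairs up the i-th partition of two parent RDDs. It keeps references to both
+ * parents so that a task can iterate over the two corresponding partitions in lockstep.
+ */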
+private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
+    idx: Int,
+    rdd1: RDD[T],
+    rdd2: RDD[U],
+    split1: Split,
+    split2: Split)
+  extends Split
+  with Serializable {
+
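+  // Iterator.zip stops at the end of the shorter iterator, so a mismatch in the number of
+  // elements per partition is silently truncated rather than reported as an error.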
+  def iterator(): Iterator[(T, U)] = rdd1.iterator(split1).zip(rdd2.iterator(split2))
+
+  def preferredLocations(): Seq[String] =
+    rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))
+
+  override val index: Int = idx
+}
+
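+/**
+ * An RDD formed by pairing the corresponding partitions of two parent RDDs. Both parents must
+ * have the same number of partitions and the same number of elements in each partition.
+ */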
+class ZippedRDD[T: ClassManifest, U: ClassManifest](
+    sc: SparkContext,
+    @transient rdd1: RDD[T],
+    @transient rdd2: RDD[U])
+  extends RDD[(T, U)](sc)
+  with Serializable {
+
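+  // Build one ZippedSplit per pair of corresponding parent partitions. Only the partition
+  // counts can be checked eagerly here; equal element counts per partition are assumed.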
+  @transient
+  val splits_ : Array[Split] = {
+    if (rdd1.splits.size != rdd2.splits.size) {
+      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
+    }
+    val array = new Array[Split](rdd1.splits.size)
+    for (i <- 0 until rdd1.splits.size) {
+      array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
+    }
+    array
+  }
+
+  override def splits = splits_
+
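+  // Each output partition depends on exactly one partition of each parent, so both
+  // dependencies are one-to-one (narrow).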
+  @transient
+  override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))
+
+  override def compute(s: Split): Iterator[(T, U)] = s.asInstanceOf[ZippedSplit[T, U]].iterator()
+
+  override def preferredLocations(s: Split): Seq[String] =
+    s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
+}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 37a0ff0947..b3c820ed94 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -114,4 +114,16 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
     assert(coalesced4.glom().collect().map(_.toList).toList ===
       (1 to 10).map(x => List(x)).toList)
   }
+
+  test("zipped RDDs") {
+    sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val zipped = nums.zip(nums.map(_ + 1.0))
+    assert(zipped.glom().map(_.toList).collect().toList ===
+      List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
+
+    intercept[IllegalArgumentException] {
+      nums.zip(sc.parallelize(1 to 4, 1)).collect()
+    }
+  }
 }
-- 
GitLab