Commit 27e43abd authored by Matei Zaharia

Added a zip() operation for RDDs with the same shape (number of partitions
and number of elements in each partition)
parent 59c0a9ad
@@ -42,6 +42,7 @@ import spark.rdd.MapPartitionsWithSplitRDD
 import spark.rdd.PipedRDD
 import spark.rdd.SampledRDD
 import spark.rdd.UnionRDD
+import spark.rdd.ZippedRDD
 import spark.storage.StorageLevel
 import SparkContext._
@@ -293,6 +294,14 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
     new MapPartitionsWithSplitRDD(this, sc.clean(f))
 
+  /**
+   * Zips this RDD with another one, returning key-value pairs with the first element in each RDD,
+   * second element in each RDD, etc. Assumes that the two RDDs have the *same number of
+   * partitions* and the *same number of elements in each partition* (e.g. one was made through
+   * a map on the other).
+   */
+  def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+
   // Actions (launch a job to return a value to the user program)
 
   /**
...
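As a quick illustration of the contract described in the scaladoc above (a minimal sketch, not part of the commit; `a` and `b` are illustrative names):

    val a = sc.parallelize(1 to 4, 2)  // two partitions, two elements each
    val b = a.map(_ * 10)              // same shape, since it was made by a map on a
    a.zip(b).collect()                 // Array((1,10), (2,20), (3,30), (4,40))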
package spark.rdd

import spark.Dependency
import spark.OneToOneDependency
import spark.RDD
import spark.SparkContext
import spark.Split

// A split that pairs the i-th split of one parent RDD with the i-th split of the other.
private[spark] class ZippedSplit[T: ClassManifest, U: ClassManifest](
    idx: Int,
    rdd1: RDD[T],
    rdd2: RDD[U],
    split1: Split,
    split2: Split)
  extends Split
  with Serializable {

  // Zip the two parent iterators element by element.
  def iterator(): Iterator[(T, U)] = rdd1.iterator(split1).zip(rdd2.iterator(split2))

  // Prefer only the hosts where both parent splits are local.
  def preferredLocations(): Seq[String] =
    rdd1.preferredLocations(split1).intersect(rdd2.preferredLocations(split2))

  override val index: Int = idx
}

class ZippedRDD[T: ClassManifest, U: ClassManifest](
    sc: SparkContext,
    @transient rdd1: RDD[T],
    @transient rdd2: RDD[U])
  extends RDD[(T, U)](sc)
  with Serializable {

  // Pair up the parents' splits, failing fast if the partition counts differ.
  @transient
  val splits_ : Array[Split] = {
    if (rdd1.splits.size != rdd2.splits.size) {
      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
    }
    val array = new Array[Split](rdd1.splits.size)
    for (i <- 0 until rdd1.splits.size) {
      array(i) = new ZippedSplit(i, rdd1, rdd2, rdd1.splits(i), rdd2.splits(i))
    }
    array
  }

  override def splits = splits_

  // Each output partition depends one-to-one on the matching partition of each parent.
  @transient
  override val dependencies = List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))

  override def compute(s: Split): Iterator[(T, U)] = s.asInstanceOf[ZippedSplit[T, U]].iterator()

  override def preferredLocations(s: Split): Seq[String] =
    s.asInstanceOf[ZippedSplit[T, U]].preferredLocations()
}
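One caveat worth noting: the constructor checks the partition counts eagerly, but the per-partition element counts are never validated. ZippedSplit.iterator delegates to Scala's Iterator.zip, which simply stops at the end of the shorter iterator, so zipping partitions of unequal length would silently drop elements. A sketch of that underlying behavior in plain Scala:

    // Iterator.zip truncates to the shorter input rather than failing:
    Iterator(1, 2, 3).zip(Iterator("a", "b")).toList  // List((1,a), (2,b))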
@@ -114,4 +114,16 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
     assert(coalesced4.glom().collect().map(_.toList).toList ===
       (1 to 10).map(x => List(x)).toList)
   }
+
+  test("zipped RDDs") {
+    sc = new SparkContext("local", "test")
+    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val zipped = nums.zip(nums.map(_ + 1.0))
+    assert(zipped.glom().map(_.toList).collect().toList ===
+      List(List((1, 2.0), (2, 3.0)), List((3, 4.0), (4, 5.0))))
+
+    intercept[IllegalArgumentException] {
+      nums.zip(sc.parallelize(1 to 4, 1)).collect()
+    }
+  }
 }