Commit c402a4a6 authored by Kan Zhang, committed by Reynold Xin

[SPARK-1817] RDD.zip() should verify partition sizes for each partition

RDD.zip() will throw an exception if it finds partition sizes are not the same.

Author: Kan Zhang <kzhang@apache.org>

Closes #944 from kanzhang/SPARK-1817 and squashes the following commits:

c073848 [Kan Zhang] [SPARK-1817] Cosmetic updates
524c670 [Kan Zhang] [SPARK-1817] RDD.zip() should verify partition sizes for each partition
parent 4ca06256
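
A minimal sketch of the resulting behavior, mirroring the updated RDDSuite test below. The RDD names and the local SparkContext sc are assumed for illustration; parallelize typically splits 1 to 5 over 2 partitions as 2 and 3 elements.

    val a = sc.parallelize(1 to 4, 2)   // 2 partitions, 2 elements each
    val b = sc.parallelize(1 to 5, 2)   // 2 partitions, but with 2 and 3 elements
    val c = sc.parallelize(1 to 4, 1)   // 1 partition

    a.zip(a.map(_ * 10)).collect()      // ok: Array((1,10), (2,20), (3,30), (4,40))
    a.zip(c).collect()                  // IllegalArgumentException: unequal numbers of
                                        // partitions (existing check, unchanged)
    a.zip(b).collect()                  // SparkException while computing the partitions
                                        // (the new per-partition check): "Can only zip RDDs
                                        // with same number of elements in each partition"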
@@ -654,7 +654,19 @@ abstract class RDD[T: ClassTag](
    * partitions* and the *same number of elements in each partition* (e.g. one was made through
    * a map on the other).
    */
-  def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+  def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = {
+    zipPartitions(other, true) { (thisIter, otherIter) =>
+      new Iterator[(T, U)] {
+        def hasNext = (thisIter.hasNext, otherIter.hasNext) match {
+          case (true, true) => true
+          case (false, false) => false
+          case _ => throw new SparkException("Can only zip RDDs with " +
+            "same number of elements in each partition")
+        }
+        def next = (thisIter.next, otherIter.next)
+      }
+    }
+  }
 
   /**
    * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
......
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.rdd

import java.io.{IOException, ObjectOutputStream}

import scala.reflect.ClassTag

import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}

private[spark] class ZippedPartition[T: ClassTag, U: ClassTag](
    idx: Int,
    @transient rdd1: RDD[T],
    @transient rdd2: RDD[U]
  ) extends Partition {

  var partition1 = rdd1.partitions(idx)
  var partition2 = rdd2.partitions(idx)
  override val index: Int = idx

  def partitions = (partition1, partition2)

  @throws(classOf[IOException])
  private def writeObject(oos: ObjectOutputStream) {
    // Update the reference to parent partition at the time of task serialization
    partition1 = rdd1.partitions(idx)
    partition2 = rdd2.partitions(idx)
    oos.defaultWriteObject()
  }
}

private[spark] class ZippedRDD[T: ClassTag, U: ClassTag](
    sc: SparkContext,
    var rdd1: RDD[T],
    var rdd2: RDD[U])
  extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) {

  override def getPartitions: Array[Partition] = {
    if (rdd1.partitions.size != rdd2.partitions.size) {
      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
    }
    val array = new Array[Partition](rdd1.partitions.size)
    for (i <- 0 until rdd1.partitions.size) {
      array(i) = new ZippedPartition(i, rdd1, rdd2)
    }
    array
  }

  override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = {
    val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
    rdd1.iterator(partition1, context).zip(rdd2.iterator(partition2, context))
  }

  override def getPreferredLocations(s: Partition): Seq[String] = {
    val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
    val pref1 = rdd1.preferredLocations(partition1)
    val pref2 = rdd2.preferredLocations(partition2)
    // Check whether there are any hosts that match both RDDs; otherwise return the union
    val exactMatchLocations = pref1.intersect(pref2)
    if (!exactMatchLocations.isEmpty) {
      exactMatchLocations
    } else {
      (pref1 ++ pref2).distinct
    }
  }

  override def clearDependencies() {
    super.clearDependencies()
    rdd1 = null
    rdd2 = null
  }
}
@@ -167,26 +167,28 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     })
   }
 
-  test("ZippedRDD") {
-    testRDD(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
-    testRDDPartitions(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
+  test("ZippedPartitionsRDD") {
+    testRDD(rdd => rdd.zip(rdd.map(x => x)))
+    testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)))
 
-    // Test that the ZippedPartition updates parent partitions
-    // after the parent RDD has been checkpointed and parent partitions have been changed.
-    // Note that this test is very specific to the current implementation of ZippedRDD.
+    // Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have
+    // been checkpointed and parent partitions have been changed.
+    // Note that this test is very specific to the implementation of ZippedPartitionsRDD.
     val rdd = generateFatRDD()
-    val zippedRDD = new ZippedRDD(sc, rdd, rdd.map(x => x))
+    val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]]
     zippedRDD.rdd1.checkpoint()
     zippedRDD.rdd2.checkpoint()
     val partitionBeforeCheckpoint =
-      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
     zippedRDD.count()
     val partitionAfterCheckpoint =
-      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
     assert(
-      partitionAfterCheckpoint.partition1.getClass != partitionBeforeCheckpoint.partition1.getClass &&
-      partitionAfterCheckpoint.partition2.getClass != partitionBeforeCheckpoint.partition2.getClass,
-      "ZippedRDD.partition1 and ZippedRDD.partition2 not updated after parent RDD is checkpointed"
+      partitionAfterCheckpoint.partitions(0).getClass !=
+        partitionBeforeCheckpoint.partitions(0).getClass &&
+      partitionAfterCheckpoint.partitions(1).getClass !=
+        partitionBeforeCheckpoint.partitions(1).getClass,
+      "ZippedPartitionsRDD partition 0 (or 1) not updated after parent RDDs are checkpointed"
     )
   }
......
@@ -350,6 +350,10 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     intercept[IllegalArgumentException] {
       nums.zip(sc.parallelize(1 to 4, 1)).collect()
     }
+
+    intercept[SparkException] {
+      nums.zip(sc.parallelize(1 to 5, 2)).collect()
+    }
   }
 
   test("partition pruning") {
......
@@ -54,6 +54,8 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingMethodProblem](
              "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1")
           ) ++
+          MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
+          MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
           MimaBuild.excludeSparkClass("util.SerializableHyperLogLog")
         case v if v.startsWith("1.0") =>
           Seq(
......