From c402a4a685721d05932bbc578d997f330ff65a49 Mon Sep 17 00:00:00 2001
From: Kan Zhang <kzhang@apache.org>
Date: Tue, 3 Jun 2014 22:47:18 -0700
Subject: [PATCH] [SPARK-1817] RDD.zip() should verify partition sizes for each partition

RDD.zip() will throw an exception if it finds partition sizes are not the same.

Author: Kan Zhang <kzhang@apache.org>

Closes #944 from kanzhang/SPARK-1817 and squashes the following commits:

c073848 [Kan Zhang] [SPARK-1817] Cosmetic updates
524c670 [Kan Zhang] [SPARK-1817] RDD.zip() should verify partition sizes for each partition
---
 .../main/scala/org/apache/spark/rdd/RDD.scala |  14 ++-
 .../org/apache/spark/rdd/ZippedRDD.scala      |  87 -------------------
 .../org/apache/spark/CheckpointSuite.scala    |  26 +++---
 .../scala/org/apache/spark/rdd/RDDSuite.scala |   4 +
 project/MimaExcludes.scala                    |   2 +
 5 files changed, 33 insertions(+), 100 deletions(-)
 delete mode 100644 core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 585b2f76af..54bdc3e7cb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -654,7 +654,19 @@ abstract class RDD[T: ClassTag](
    * partitions* and the *same number of elements in each partition* (e.g. one was made through
    * a map on the other).
    */
-  def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+  def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = {
+    zipPartitions(other, true) { (thisIter, otherIter) =>
+      new Iterator[(T, U)] {
+        def hasNext = (thisIter.hasNext, otherIter.hasNext) match {
+          case (true, true) => true
+          case (false, false) => false
+          case _ => throw new SparkException("Can only zip RDDs with " +
+            "same number of elements in each partition")
+        }
+        def next = (thisIter.next, otherIter.next)
+      }
+    }
+  }
 
   /**
    * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
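Note that the new check is lazy: zip() itself builds the iterator without verifying anything, and the SparkException above surfaces only once a task actually drains a zipped partition, e.g. during an action. A minimal sketch of the observable behavior, assuming a live SparkContext named sc (the element counts are illustrative and match the test added to RDDSuite below):

    // Both RDDs have 2 partitions, but 1 to 4 splits as [1,2][3,4] while
    // 1 to 5 splits as [1,2][3,4,5]; the second pair of partitions differs
    // in length, so the zipped iterator throws when one side runs dry.
    val a = sc.parallelize(1 to 4, 2)
    val b = sc.parallelize(1 to 5, 2)
    a.zip(b).collect()  // SparkException: Can only zip RDDs with same
                        // number of elements in each partition
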
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
deleted file mode 100644
index b8110ffc42..0000000000
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import java.io.{IOException, ObjectOutputStream}
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
-
-private[spark] class ZippedPartition[T: ClassTag, U: ClassTag](
-    idx: Int,
-    @transient rdd1: RDD[T],
-    @transient rdd2: RDD[U]
-  ) extends Partition {
-
-  var partition1 = rdd1.partitions(idx)
-  var partition2 = rdd2.partitions(idx)
-  override val index: Int = idx
-
-  def partitions = (partition1, partition2)
-
-  @throws(classOf[IOException])
-  private def writeObject(oos: ObjectOutputStream) {
-    // Update the reference to parent partition at the time of task serialization
-    partition1 = rdd1.partitions(idx)
-    partition2 = rdd2.partitions(idx)
-    oos.defaultWriteObject()
-  }
-}
-
-private[spark] class ZippedRDD[T: ClassTag, U: ClassTag](
-    sc: SparkContext,
-    var rdd1: RDD[T],
-    var rdd2: RDD[U])
-  extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) {
-
-  override def getPartitions: Array[Partition] = {
-    if (rdd1.partitions.size != rdd2.partitions.size) {
-      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
-    }
-    val array = new Array[Partition](rdd1.partitions.size)
-    for (i <- 0 until rdd1.partitions.size) {
-      array(i) = new ZippedPartition(i, rdd1, rdd2)
-    }
-    array
-  }
-
-  override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = {
-    val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
-    rdd1.iterator(partition1, context).zip(rdd2.iterator(partition2, context))
-  }
-
-  override def getPreferredLocations(s: Partition): Seq[String] = {
-    val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
-    val pref1 = rdd1.preferredLocations(partition1)
-    val pref2 = rdd2.preferredLocations(partition2)
-    // Check whether there are any hosts that match both RDDs; otherwise return the union
-    val exactMatchLocations = pref1.intersect(pref2)
-    if (!exactMatchLocations.isEmpty) {
-      exactMatchLocations
-    } else {
-      (pref1 ++ pref2).distinct
-    }
-  }
-
-  override def clearDependencies() {
-    super.clearDependencies()
-    rdd1 = null
-    rdd2 = null
-  }
-}
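The deleted ZippedRDD compared only partition counts, on the driver (the IllegalArgumentException in getPartitions above), then delegated to Scala's Iterator.zip in compute(). Iterator.zip silently stops at the shorter iterator, so partitions of unequal length lost elements without any error, which is the bug this patch fixes. A plain-Scala illustration of that truncation:

    // Scala's Iterator.zip truncates to the shorter side instead of
    // failing, which is how the old compute() could silently drop data.
    Iterator(1, 2, 3).zip(Iterator("a", "b")).toList
    // => List((1,a), (2,b)); the trailing 3 is discarded

The driver-side partition-count check is not lost by the rewrite: zip() now goes through zipPartitions, and the resulting ZippedPartitionsRDD2 still rejects RDDs with unequal numbers of partitions, which is why the existing IllegalArgumentException assertion in RDDSuite below stays unchanged.
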
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 64933f4b10..f64f3c9036 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -167,26 +167,28 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     })
   }
 
-  test("ZippedRDD") {
-    testRDD(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
-    testRDDPartitions(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
+  test("ZippedPartitionsRDD") {
+    testRDD(rdd => rdd.zip(rdd.map(x => x)))
+    testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)))
 
-    // Test that the ZippedPartition updates parent partitions
-    // after the parent RDD has been checkpointed and parent partitions have been changed.
-    // Note that this test is very specific to the current implementation of ZippedRDD.
+    // Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have
+    // been checkpointed and parent partitions have been changed.
+    // Note that this test is very specific to the implementation of ZippedPartitionsRDD.
     val rdd = generateFatRDD()
-    val zippedRDD = new ZippedRDD(sc, rdd, rdd.map(x => x))
+    val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]]
     zippedRDD.rdd1.checkpoint()
     zippedRDD.rdd2.checkpoint()
     val partitionBeforeCheckpoint =
-      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
     zippedRDD.count()
     val partitionAfterCheckpoint =
-      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
     assert(
-      partitionAfterCheckpoint.partition1.getClass != partitionBeforeCheckpoint.partition1.getClass &&
-      partitionAfterCheckpoint.partition2.getClass != partitionBeforeCheckpoint.partition2.getClass,
-      "ZippedRDD.partition1 and ZippedRDD.partition2 not updated after parent RDD is checkpointed"
+      partitionAfterCheckpoint.partitions(0).getClass !=
+        partitionBeforeCheckpoint.partitions(0).getClass &&
+      partitionAfterCheckpoint.partitions(1).getClass !=
+        partitionBeforeCheckpoint.partitions(1).getClass,
+      "ZippedPartitionsRDD partition 0 (or 1) not updated after parent RDDs are checkpointed"
     )
   }
 
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index bbd0c14178..286e221e33 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -350,6 +350,10 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     intercept[IllegalArgumentException] {
       nums.zip(sc.parallelize(1 to 4, 1)).collect()
     }
+
+    intercept[SparkException] {
+      nums.zip(sc.parallelize(1 to 5, 2)).collect()
+    }
   }
 
   test("partition pruning") {
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fadf6a4d8b..dd7efceb23 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -54,6 +54,8 @@ object MimaExcludes {
         ProblemFilters.exclude[MissingMethodProblem](
           "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1")
       ) ++
+      MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
+      MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
      MimaBuild.excludeSparkClass("util.SerializableHyperLogLog")
     case v if v.startsWith("1.0") =>
       Seq(
-- 
GitLab
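
Taken together, zip() now fails in two distinct ways, mirroring the updated RDDSuite expectations above. A hypothetical REPL session (assuming, as in the suite, that nums holds four ints in two partitions):

    val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)

    // Unequal partition counts are still rejected up front on the driver,
    // before any task runs.
    nums.zip(sc.parallelize(1 to 4, 1)).collect()  // IllegalArgumentException

    // Equal partition counts but unequal elements per partition now fail
    // inside the tasks, via the new per-partition length check.
    nums.zip(sc.parallelize(1 to 5, 2)).collect()  // SparkException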