From c402a4a685721d05932bbc578d997f330ff65a49 Mon Sep 17 00:00:00 2001
From: Kan Zhang <kzhang@apache.org>
Date: Tue, 3 Jun 2014 22:47:18 -0700
Subject: [PATCH] [SPARK-1817] RDD.zip() should verify partition sizes for each
 partition

RDD.zip() will throw an exception if it finds partition sizes are not the same.

Author: Kan Zhang <kzhang@apache.org>

Closes #944 from kanzhang/SPARK-1817 and squashes the following commits:

c073848 [Kan Zhang] [SPARK-1817] Cosmetic updates
524c670 [Kan Zhang] [SPARK-1817] RDD.zip() should verify partition sizes for each partition
---
 .../main/scala/org/apache/spark/rdd/RDD.scala | 14 ++-
 .../org/apache/spark/rdd/ZippedRDD.scala      | 87 -------------------
 .../org/apache/spark/CheckpointSuite.scala    | 26 +++---
 .../scala/org/apache/spark/rdd/RDDSuite.scala |  4 +
 project/MimaExcludes.scala                    |  2 +
 5 files changed, 33 insertions(+), 100 deletions(-)
 delete mode 100644 core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala

diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 585b2f76af..54bdc3e7cb 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -654,7 +654,19 @@ abstract class RDD[T: ClassTag](
    * partitions* and the *same number of elements in each partition* (e.g. one was made through
    * a map on the other).
    */
-  def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+  def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = {
+    zipPartitions(other, true) { (thisIter, otherIter) =>
+      new Iterator[(T, U)] {
+        def hasNext = (thisIter.hasNext, otherIter.hasNext) match {
+          case (true, true) => true
+          case (false, false) => false
+          case _ => throw new SparkException("Can only zip RDDs with " +
+            "same number of elements in each partition")
+        }
+        def next = (thisIter.next, otherIter.next)
+      }
+    }
+  }
 
   /**
    * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
deleted file mode 100644
index b8110ffc42..0000000000
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
+++ /dev/null
@@ -1,87 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import java.io.{IOException, ObjectOutputStream}
-
-import scala.reflect.ClassTag
-
-import org.apache.spark.{OneToOneDependency, Partition, SparkContext, TaskContext}
-
-private[spark] class ZippedPartition[T: ClassTag, U: ClassTag](
-    idx: Int,
-    @transient rdd1: RDD[T],
-    @transient rdd2: RDD[U]
-  ) extends Partition {
-
-  var partition1 = rdd1.partitions(idx)
-  var partition2 = rdd2.partitions(idx)
-  override val index: Int = idx
-
-  def partitions = (partition1, partition2)
-
-  @throws(classOf[IOException])
-  private def writeObject(oos: ObjectOutputStream) {
-    // Update the reference to parent partition at the time of task serialization
-    partition1 = rdd1.partitions(idx)
-    partition2 = rdd2.partitions(idx)
-    oos.defaultWriteObject()
-  }
-}
-
-private[spark] class ZippedRDD[T: ClassTag, U: ClassTag](
-    sc: SparkContext,
-    var rdd1: RDD[T],
-    var rdd2: RDD[U])
-  extends RDD[(T, U)](sc, List(new OneToOneDependency(rdd1), new OneToOneDependency(rdd2))) {
-
-  override def getPartitions: Array[Partition] = {
-    if (rdd1.partitions.size != rdd2.partitions.size) {
-      throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
-    }
-    val array = new Array[Partition](rdd1.partitions.size)
-    for (i <- 0 until rdd1.partitions.size) {
-      array(i) = new ZippedPartition(i, rdd1, rdd2)
-    }
-    array
-  }
-
-  override def compute(s: Partition, context: TaskContext): Iterator[(T, U)] = {
-    val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
-    rdd1.iterator(partition1, context).zip(rdd2.iterator(partition2, context))
-  }
-
-  override def getPreferredLocations(s: Partition): Seq[String] = {
-    val (partition1, partition2) = s.asInstanceOf[ZippedPartition[T, U]].partitions
-    val pref1 = rdd1.preferredLocations(partition1)
-    val pref2 = rdd2.preferredLocations(partition2)
-    // Check whether there are any hosts that match both RDDs; otherwise return the union
-    val exactMatchLocations = pref1.intersect(pref2)
-    if (!exactMatchLocations.isEmpty) {
-      exactMatchLocations
-    } else {
-      (pref1 ++ pref2).distinct
-    }
-  }
-
-  override def clearDependencies() {
-    super.clearDependencies()
-    rdd1 = null
-    rdd2 = null
-  }
-}
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 64933f4b10..f64f3c9036 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -167,26 +167,28 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     })
   }
 
-  test("ZippedRDD") {
-    testRDD(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
-    testRDDPartitions(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
+  test("ZippedPartitionsRDD") {
+    testRDD(rdd => rdd.zip(rdd.map(x => x)))
+    testRDDPartitions(rdd => rdd.zip(rdd.map(x => x)))
 
-    // Test that the ZippedPartition updates parent partitions
-    // after the parent RDD has been checkpointed and parent partitions have been changed.
-    // Note that this test is very specific to the current implementation of ZippedRDD.
+    // Test that ZippedPartitionsRDD updates parent partitions after parent RDDs have
+    // been checkpointed and parent partitions have been changed.
+    // Note that this test is very specific to the implementation of ZippedPartitionsRDD.
     val rdd = generateFatRDD()
-    val zippedRDD = new ZippedRDD(sc, rdd, rdd.map(x => x))
+    val zippedRDD = rdd.zip(rdd.map(x => x)).asInstanceOf[ZippedPartitionsRDD2[_, _, _]]
     zippedRDD.rdd1.checkpoint()
     zippedRDD.rdd2.checkpoint()
     val partitionBeforeCheckpoint =
-      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
     zippedRDD.count()
     val partitionAfterCheckpoint =
-      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+      serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartitionsPartition])
     assert(
-      partitionAfterCheckpoint.partition1.getClass != partitionBeforeCheckpoint.partition1.getClass &&
-        partitionAfterCheckpoint.partition2.getClass != partitionBeforeCheckpoint.partition2.getClass,
-      "ZippedRDD.partition1 and ZippedRDD.partition2 not updated after parent RDD is checkpointed"
+      partitionAfterCheckpoint.partitions(0).getClass !=
+        partitionBeforeCheckpoint.partitions(0).getClass &&
+      partitionAfterCheckpoint.partitions(1).getClass !=
+        partitionBeforeCheckpoint.partitions(1).getClass,
+      "ZippedPartitionsRDD partition 0 (or 1) not updated after parent RDDs are checkpointed"
     )
   }
 
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index bbd0c14178..286e221e33 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -350,6 +350,10 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     intercept[IllegalArgumentException] {
       nums.zip(sc.parallelize(1 to 4, 1)).collect()
     }
+
+    intercept[SparkException] {
+      nums.zip(sc.parallelize(1 to 5, 2)).collect()
+    }
   }
 
   test("partition pruning") {
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index fadf6a4d8b..dd7efceb23 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -54,6 +54,8 @@ object MimaExcludes {
             ProblemFilters.exclude[MissingMethodProblem](
               "org.apache.spark.api.java.JavaDoubleRDD.countApproxDistinct$default$1")
           ) ++
+          MimaBuild.excludeSparkClass("rdd.ZippedRDD") ++
+          MimaBuild.excludeSparkClass("rdd.ZippedPartition") ++
           MimaBuild.excludeSparkClass("util.SerializableHyperLogLog")
         case v if v.startsWith("1.0") =>
           Seq(
-- 
GitLab