From e258e5040fa1905a04efcb7b3ca4a6d33e18fa61 Mon Sep 17 00:00:00 2001
From: Egor Pakhomov <pahomov.egor@gmail.com>
Date: Sun, 6 Apr 2014 16:41:23 -0700
Subject: [PATCH] [SPARK-1259] Make RDD locally iterable

Author: Egor Pakhomov <pahomov.egor@gmail.com>

Closes #156 from epahomov/SPARK-1259 and squashes the following commits:

8ec8f24 [Egor Pakhomov] Make toLocalIterator shorter
34aa300 [Egor Pakhomov] Fix toLocalIterator docs
08363ef [Egor Pakhomov] SPARK-1259 Rename toLocallyIterable to toLocalIterator
6a994eb [Egor Pakhomov] SPARK-1259 Make RDD locally iterable
8be3dcf [Egor Pakhomov] SPARK-1259 Make RDD locally iterable
33ecb17 [Egor Pakhomov] SPARK-1259 Make RDD locally iterable
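
A usage sketch for the new API (illustrative only; assumes an existing
SparkContext named sc):

    val rdd = sc.parallelize(1 to 100, 4)
    // collect() materializes the whole RDD on the driver at once, while
    // toLocalIterator fetches one partition at a time, so driver memory
    // use stays bounded by the largest partition.
    rdd.toLocalIterator.foreach(println)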
---
 .../org/apache/spark/api/java/JavaRDDLike.scala    | 13 ++++++++++++-
 core/src/main/scala/org/apache/spark/rdd/RDD.scala | 12 ++++++++++++
 .../test/java/org/apache/spark/JavaAPISuite.java   |  9 +++++++++
 .../test/scala/org/apache/spark/rdd/RDDSuite.scala |  1 +
 4 files changed, 34 insertions(+), 1 deletion(-)

diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index e03b8e78d5..6e8ec8e0c7 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -17,7 +17,8 @@
 
 package org.apache.spark.api.java
 
-import java.util.{Comparator, List => JList}
+import java.util.{Comparator, Iterator => JIterator, List => JList}
+import java.lang.{Iterable => JIterable}
 
 import scala.collection.JavaConversions._
 import scala.reflect.ClassTag
@@ -280,6 +281,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
     new java.util.ArrayList(arr)
   }
 
+  /**
+   * Return an iterator that contains all of the elements in this RDD.
+   *
+   * The iterator will consume as much memory on the driver as the largest partition in this RDD.
+   */
+  def toLocalIterator(): JIterator[T] = {
+    import scala.collection.JavaConversions._
+    rdd.toLocalIterator
+  }
+
   /**
    * Return an array that contains all of the elements in this RDD.
    * @deprecated As of Spark 1.0.0, toArray() is deprecated, use {@link #collect()} instead
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 08c42c5ee8..c43823bd76 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -661,6 +661,18 @@ abstract class RDD[T: ClassTag](
     Array.concat(results: _*)
   }
 
+  /**
+   * Return an iterator that contains all of the elements in this RDD.
+   *
+   * The iterator will consume as much memory on the driver as the largest partition in this RDD.
+   */
+  def toLocalIterator: Iterator[T] = {
+    def collectPartition(p: Int): Array[T] = {
+      sc.runJob(this, (iter: Iterator[T]) => iter.toArray, Seq(p), allowLocal = false).head
+    }
+    (0 until partitions.length).iterator.flatMap(i => collectPartition(i))
+  }
+
   /**
    * Return an array that contains all of the elements in this RDD.
    */
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 2372f2d992..762405be2a 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -22,6 +22,7 @@ import java.util.*;
 
 import scala.Tuple2;
 
+import com.google.common.collect.Lists;
 import com.google.common.base.Optional;
 import com.google.common.base.Charsets;
 import com.google.common.io.Files;
@@ -179,6 +180,14 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(2, foreachCalls);
   }
 
+  @Test
+  public void toLocalIterator() {
+    List<Integer> correct = Arrays.asList(1, 2, 3, 4);
+    JavaRDD<Integer> rdd = sc.parallelize(correct);
+    List<Integer> result = Lists.newArrayList(rdd.toLocalIterator());
+    Assert.assertEquals(correct, result);
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void lookup() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index d6b5fdc798..25973348a7 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -33,6 +33,7 @@ class RDDSuite extends FunSuite with SharedSparkContext {
   test("basic operations") {
     val nums = sc.makeRDD(Array(1, 2, 3, 4), 2)
     assert(nums.collect().toList === List(1, 2, 3, 4))
+    assert(nums.toLocalIterator.toList === List(1, 2, 3, 4))
     val dups = sc.makeRDD(Array(1, 1, 2, 2, 3, 3, 4, 4), 2)
     assert(dups.distinct().count() === 4)
     assert(dups.distinct.count === 4)  // Can distinct and count be called without parentheses?
-- 
GitLab
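
Note on semantics: toLocalIterator submits one Spark job per partition
(sc.runJob is called with a single-element partition list), so fully
consuming the iterator runs partitions.length jobs. A rough equivalent
built from the public API, as a sketch only (assumes an RDD[Int] value
named rdd):

    val localIter: Iterator[Int] =
      (0 until rdd.partitions.length).iterator.flatMap { p =>
        // Collect exactly one partition per job; driver holds at most
        // one partition's elements at a time.
        rdd.sparkContext.runJob(
          rdd,
          (iter: Iterator[Int]) => iter.toArray,
          Seq(p),
          allowLocal = false
        ).head
      }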