From 6e730edcde7ca6cbb5727dff7a42f7284b368528 Mon Sep 17 00:00:00 2001 From: Prashant Sharma <prashant.s@imaginea.com> Date: Fri, 7 Mar 2014 18:48:07 -0800 Subject: [PATCH] Spark 1165 rdd.intersection in python and java Author: Prashant Sharma <prashant.s@imaginea.com> Author: Prashant Sharma <scrapcodes@gmail.com> Closes #80 from ScrapCodes/SPARK-1165/RDD.intersection and squashes the following commits: 9b015e9 [Prashant Sharma] Added a note, shuffle is required for intersection. 1fea813 [Prashant Sharma] correct the lines wrapping d0c71f3 [Prashant Sharma] SPARK-1165 RDD.intersection in java d6effee [Prashant Sharma] SPARK-1165 Implemented RDD.intersection in python. --- .../apache/spark/api/java/JavaDoubleRDD.scala | 8 +++++ .../apache/spark/api/java/JavaPairRDD.scala | 10 ++++++ .../org/apache/spark/api/java/JavaRDD.scala | 9 ++++++ .../java/org/apache/spark/JavaAPISuite.java | 31 +++++++++++++++++++ python/pyspark/rdd.py | 17 ++++++++++ 5 files changed, 75 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index d1787061bc..f816bb43a5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -140,6 +140,14 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[JDouble, Ja */ def union(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.union(other.srdd)) + /** + * Return the intersection of this RDD and another one. The output will not contain any duplicate + * elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + */ + def intersection(other: JavaDoubleRDD): JavaDoubleRDD = fromRDD(srdd.intersection(other.srdd)) + // Double RDD functions /** Add up the elements in this RDD. */ diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 857626fe84..0ff428c120 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -126,6 +126,16 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) def union(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.union(other.rdd)) + /** + * Return the intersection of this RDD and another one. The output will not contain any duplicate + * elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + */ + def intersection(other: JavaPairRDD[K, V]): JavaPairRDD[K, V] = + new JavaPairRDD[K, V](rdd.intersection(other.rdd)) + + // first() has to be overridden here so that the generated method has the signature // 'public scala.Tuple2 first()'; if the trait's definition is used, // then the method has the signature 'public java.lang.Object first()', diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index e973c46edd..91bf404631 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -106,6 +106,15 @@ class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) */ def union(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.union(other.rdd)) + + /** + * Return the intersection of this RDD and another one. The output will not contain any duplicate + * elements, even if the input RDDs did. + * + * Note that this method performs a shuffle internally. + */ + def intersection(other: JavaRDD[T]): JavaRDD[T] = wrapRDD(rdd.intersection(other.rdd)) + /** * Return an RDD with the elements from `this` that are not in `other`. * diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java index c7d0e2d577..40e853c39c 100644 --- a/core/src/test/java/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java @@ -110,6 +110,37 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(4, pUnion.count()); } + @SuppressWarnings("unchecked") + @Test + public void intersection() { + List<Integer> ints1 = Arrays.asList(1, 10, 2, 3, 4, 5); + List<Integer> ints2 = Arrays.asList(1, 6, 2, 3, 7, 8); + JavaRDD<Integer> s1 = sc.parallelize(ints1); + JavaRDD<Integer> s2 = sc.parallelize(ints2); + + JavaRDD<Integer> intersections = s1.intersection(s2); + Assert.assertEquals(3, intersections.count()); + + ArrayList<Integer> list = new ArrayList<Integer>(); + JavaRDD<Integer> empty = sc.parallelize(list); + JavaRDD<Integer> emptyIntersection = empty.intersection(s2); + Assert.assertEquals(0, emptyIntersection.count()); + + List<Double> doubles = Arrays.asList(1.0, 2.0); + JavaDoubleRDD d1 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD d2 = sc.parallelizeDoubles(doubles); + JavaDoubleRDD dIntersection = d1.intersection(d2); + Assert.assertEquals(2, dIntersection.count()); + + List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>(); + pairs.add(new Tuple2<Integer, Integer>(1, 2)); + pairs.add(new Tuple2<Integer, Integer>(3, 4)); + JavaPairRDD<Integer, Integer> p1 = sc.parallelizePairs(pairs); + JavaPairRDD<Integer, Integer> p2 = sc.parallelizePairs(pairs); + JavaPairRDD<Integer, Integer> pIntersection = p1.intersection(p2); + Assert.assertEquals(2, pIntersection.count()); + } + @Test public void sortByKey() { List<Tuple2<Integer, Integer>> pairs = new ArrayList<Tuple2<Integer, Integer>>(); diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 097a0a236b..e72f57d9d1 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -326,6 +326,23 @@ class RDD(object): return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, self.ctx.serializer) + def intersection(self, other): + """ + Return the intersection of this RDD and another one. The output will not + contain any duplicate elements, even if the input RDDs did. + + Note that this method performs a shuffle internally. + + >>> rdd1 = sc.parallelize([1, 10, 2, 3, 4, 5]) + >>> rdd2 = sc.parallelize([1, 6, 2, 3, 7, 8]) + >>> rdd1.intersection(rdd2).collect() + [1, 2, 3] + """ + return self.map(lambda v: (v, None)) \ + .cogroup(other.map(lambda v: (v, None))) \ + .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \ + .keys() + def _reserialize(self): if self._jrdd_deserializer == self.ctx.serializer: return self -- GitLab