diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala index b3698ffa44d57f192c052e03d5c6551c1b9a2fdf..4c95c989b53676c54e9c941b0892ff08b189a734 100644 --- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala @@ -12,7 +12,7 @@ import spark.storage.StorageLevel import com.google.common.base.Optional -trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { +trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] { def wrapRDD(rdd: RDD[T]): This implicit val classManifest: ClassManifest[T] @@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { } /** - * Return a new RDD by first applying a function to all elements of this - * RDD, and then flattening the results. + * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java. */ - def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { + private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java new file mode 100644 index 0000000000000000000000000000000000000000..68b6fd6622742148761363045c6a06d0e8afeb74 --- /dev/null +++ b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java @@ -0,0 +1,20 @@ +package spark.api.java; + +import spark.api.java.JavaPairRDD; +import spark.api.java.JavaRDDLike; +import spark.api.java.function.PairFlatMapFunction; + +import java.io.Serializable; + +/** + * Workaround for SPARK-668. + */ +class PairFlatMapWorkaround<T> implements Serializable { + /** + * Return a new RDD by first applying a function to all elements of this + * RDD, and then flattening the results. + */ + public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) { + return ((JavaRDDLike <T, ?>) this).doFlatMap(f); + } +} diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java index 01351de4aeb1a4a7a33794cc517047800de20a63..f50ba093e90ae36abf004dc5975d831e7e4d364d 100644 --- a/core/src/test/scala/spark/JavaAPISuite.java +++ b/core/src/test/scala/spark/JavaAPISuite.java @@ -355,6 +355,34 @@ public class JavaAPISuite implements Serializable { Assert.assertEquals(11, pairs.count()); } + @Test + public void mapsFromPairsToPairs() { + List<Tuple2<Integer, String>> pairs = Arrays.asList( + new Tuple2<Integer, String>(1, "a"), + new Tuple2<Integer, String>(2, "aa"), + new Tuple2<Integer, String>(3, "aaa") + ); + JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs); + + // Regression test for SPARK-668: + JavaPairRDD<String, Integer> swapped = pairRDD.flatMap( + new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() { + @Override + public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception { + return Collections.singletonList(item.swap()); + } + }); + swapped.collect(); + + // There was never a bug here, but it's worth testing: + pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() { + @Override + public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception { + return item.swap(); + } + }).collect(); + } + @Test public void mapPartitions() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);