diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index b3698ffa44d57f192c052e03d5c6551c1b9a2fdf..4c95c989b53676c54e9c941b0892ff08b189a734 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -12,7 +12,7 @@ import spark.storage.StorageLevel
 import com.google.common.base.Optional
 
 
-trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
+trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] {
   def wrapRDD(rdd: RDD[T]): This
 
   implicit val classManifest: ClassManifest[T]
@@ -82,10 +82,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   }
 
   /**
-   *  Return a new RDD by first applying a function to all elements of this
-   *  RDD, and then flattening the results.
+   * Part of the workaround for SPARK-668; called from PairFlatMapWorkaround.java.
    */
-  def flatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
+  private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
     import scala.collection.JavaConverters._
     def fn = (x: T) => f.apply(x).asScala
     def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
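A note on visibility: `private[spark]` is a Scala-level restriction with no JVM equivalent, so `doFlatMap` compiles to a public JVM method, which is what allows the plain-Java superclass introduced below to call it. A minimal reflective check of that fact, assuming the compiled core classes are on the classpath (`spark.api.java.JavaRDDLike` is the real trait from this patch; the checker class itself is hypothetical):

```java
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

// Hypothetical one-off check (not part of the patch): verifies that
// doFlatMap, though private[spark] in Scala, is a public JVM method.
// Assumes the compiled spark core classes are on the classpath.
public class DoFlatMapVisibilityCheck {
    public static void main(String[] args) throws Exception {
        Class<?> rddLike = Class.forName("spark.api.java.JavaRDDLike");
        for (Method m : rddLike.getDeclaredMethods()) {
            if (m.getName().equals("doFlatMap")) {
                // Expected to print "doFlatMap public: true"
                System.out.println(m.getName() + " public: "
                        + Modifier.isPublic(m.getModifiers()));
            }
        }
    }
}
```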
diff --git a/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
new file mode 100644
index 0000000000000000000000000000000000000000..68b6fd6622742148761363045c6a06d0e8afeb74
--- /dev/null
+++ b/core/src/main/scala/spark/api/java/PairFlatMapWorkaround.java
@@ -0,0 +1,20 @@
+package spark.api.java;
+
+import spark.api.java.function.PairFlatMapFunction;
+
+import java.io.Serializable;
+
+/**
+ * Workaround for SPARK-668: the flatMap(PairFlatMapFunction) overload is
+ * declared here, in Java, instead of in the JavaRDDLike Scala trait, so
+ * that javac sees a generic signature it can resolve.
+ */
+class PairFlatMapWorkaround<T> implements Serializable {
+    /**
+     *  Return a new RDD by first applying a function to all elements of this
+     *  RDD, and then flattening the results.
+     */
+    public <K, V> JavaPairRDD<K, V> flatMap(PairFlatMapFunction<T, K, V> f) {
+        return ((JavaRDDLike<T, ?>) this).doFlatMap(f);
+    }
+}
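The shape of the workaround, seen in isolation: the user-facing overload lives in a plain Java superclass and dispatches back to the subtype through an unchecked cast of `this`, which is safe as long as `JavaRDDLike` remains the only subtype, as this patch arranges. A self-contained sketch of the same pattern follows; every name in it (`ListFunction`, `Workaround`, `Box`, `flatMapLike`, `doFlatMapLike`) is hypothetical, not Spark API:

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;

// Hypothetical mini-version of the flatMap/doFlatMap split; none of these
// names are Spark API.
interface ListFunction<T, R> extends Serializable {
    List<R> call(T t);
}

// The public entry point is declared in plain Java, like
// PairFlatMapWorkaround.flatMap above...
class Workaround<T> implements Serializable {
    @SuppressWarnings("unchecked")
    public <R> List<R> flatMapLike(ListFunction<T, R> f, T input) {
        // ...and it forwards to the implementation on the subtype,
        // mirroring the ((JavaRDDLike<T, ?>) this).doFlatMap(f) cast.
        return ((Box<T>) this).doFlatMapLike(f, input);
    }
}

// Stands in for JavaRDDLike, which supplies the actual implementation.
class Box<T> extends Workaround<T> {
    <R> List<R> doFlatMapLike(ListFunction<T, R> f, T input) {
        return f.call(input);
    }

    public static void main(String[] args) {
        Box<Integer> box = new Box<Integer>();
        List<Integer> out = box.flatMapLike(new ListFunction<Integer, Integer>() {
            public List<Integer> call(Integer x) {
                return Arrays.asList(x, x + 1);
            }
        }, 41);
        System.out.println(out); // prints [41, 42]
    }
}
```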
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 01351de4aeb1a4a7a33794cc517047800de20a63..f50ba093e90ae36abf004dc5975d831e7e4d364d 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -355,6 +355,34 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(11, pairs.count());
   }
 
+  @Test
+  public void mapsFromPairsToPairs() {
+    List<Tuple2<Integer, String>> pairs = Arrays.asList(
+        new Tuple2<Integer, String>(1, "a"),
+        new Tuple2<Integer, String>(2, "aa"),
+        new Tuple2<Integer, String>(3, "aaa")
+    );
+    JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+
+    // Regression test for SPARK-668:
+    JavaPairRDD<String, Integer> swapped = pairRDD.flatMap(
+        new PairFlatMapFunction<Tuple2<Integer, String>, String, Integer>() {
+          @Override
+          public Iterable<Tuple2<String, Integer>> call(Tuple2<Integer, String> item) throws Exception {
+            return Collections.singletonList(item.swap());
+          }
+        });
+    swapped.collect();
+
+    // There was never a bug here, but it's worth testing:
+    pairRDD.map(new PairFunction<Tuple2<Integer, String>, String, Integer>() {
+      @Override
+      public Tuple2<String, Integer> call(Tuple2<Integer, String> item) throws Exception {
+        return item.swap();
+      }
+    }).collect();
+  }
+
   @Test
   public void mapPartitions() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);