diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 74ac97091fd0b46601a4e4047a9266cbb3ec113c..e1c49e35abecd3624ef6971c743fd601c1e18a38 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1236,6 +1236,27 @@ abstract class RDD[T: ClassTag](
   /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
   def context = sc
 
+  /**
+   * Private API for changing an RDD's ClassTag from a runtime Class.
+   * Used for internal Java <-> Scala API compatibility: RDDs created through
+   * the Java API carry an erased ClassTag, so methods such as collect() need
+   * a retag with the real element class to build a correctly-typed array.
+   */
+  private[spark] def retag(cls: Class[T]): RDD[T] = {
+    val classTag: ClassTag[T] = ClassTag(cls)
+    this.retag(classTag)
+  }
+
+  /**
+   * Private API for changing an RDD's ClassTag to one supplied implicitly.
+   * Used for internal Java <-> Scala API compatibility.
+   */
+  private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = {
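+    // An identity mapPartitions yields a new RDD that carries the supplied
+    // ClassTag without recomputing or moving any data.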
+    this.mapPartitions(identity, preservesPartitioning = true)(classTag)
+  }
+
   // Avoid handling doCheckpoint multiple times to prevent excessive recursion
   @transient private var doCheckpointCalled = false
 
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e8bd65f8e45070d2ae31e1f317c8b47f315bc0e1..fab64a54e24797b75ec709f9fc04140abefaa7f6 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1245,4 +1245,26 @@ public class JavaAPISuite implements Serializable {
     Assert.assertTrue(worExactCounts.get(0) == 2);
     Assert.assertTrue(worExactCounts.get(1) == 4);
   }
+
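+  // A user-defined Serializable class, used to check that a retagged RDD can
+  // build correctly-typed arrays for application classes.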
+  private static class SomeCustomClass implements Serializable {
+    public SomeCustomClass() {
+      // Intentionally left blank
+    }
+  }
+
+  @Test
+  public void collectUnderlyingScalaRDD() {
+    List<SomeCustomClass> data = new ArrayList<SomeCustomClass>();
+    for (int i = 0; i < 100; i++) {
+      data.add(new SomeCustomClass());
+    }
+    JavaRDD<SomeCustomClass> rdd = sc.parallelize(data);
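+    // The Java-created RDD carries an erased ClassTag, so collect() on the
+    // underlying Scala RDD would return an Object[] that cannot be cast to
+    // SomeCustomClass[]; retagging with the element class restores the type.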
+    SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect();
+    Assert.assertEquals(data.size(), collected.length);
+  }
 }
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ae6e52587584f7df2f8c92a1d8de0c04b669f12d..b31e3a09e5b9c706dd19ac6672181051204a25fc 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.rdd
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 
 import org.scalatest.FunSuite
@@ -26,6 +27,7 @@ import org.apache.spark._
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.Utils
 
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDDSuiteUtils._
 
 class RDDSuite extends FunSuite with SharedSparkContext {
@@ -718,6 +720,14 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(ids.length === n)
   }
 
+  test("retag with implicit ClassTag") {
+    val jsc: JavaSparkContext = new JavaSparkContext(sc)
+    val jrdd: JavaRDD[String] = jsc.parallelize(Seq("A", "B", "C").asJava)
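+    // The RDD comes from the Java API and so starts with an erased ClassTag;
+    // this exercises the implicit-ClassTag overload of retag.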
+    assert(jrdd.rdd.retag.collect() === Array("A", "B", "C"))
+  }
+
   test("getNarrowAncestors") {
     val rdd1 = sc.parallelize(1 to 100, 4)
     val rdd2 = rdd1.filter(_ % 2 == 0).map(_ + 1)