diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 47708cb2e78bd62c9f68604387168ea2f529daea..76d4193e96aea0e87142b002d9a9f06006d9f3d6 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -783,6 +783,17 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
     sortByKey(comp, ascending)
   }
 
+  /**
+   * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+   * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+   * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+   * order of the keys).
+   */
+  def sortByKey(ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = {
+    val comp = com.google.common.collect.Ordering.natural().asInstanceOf[Comparator[K]]
+    sortByKey(comp, ascending, numPartitions)
+  }
+
   /**
    * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
    * `collect` or `save` on the resulting RDD will return or output an ordered list of records