diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 1167c120225c30c295c56e48e8e6b79905dab325..363667fa863534ee9ad26105595f30d41a364d02 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -588,6 +588,20 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K
     fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending))
   }
 
+  /**
+   * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+   * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+   * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+   * order of the keys).
+   */
+  def sortByKey(comp: Comparator[K], ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = {
+    class KeyOrdering(val a: K) extends Ordered[K] {
+      override def compare(b: K) = comp.compare(a, b)
+    }
+    implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
+    fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending, numPartitions))
+  }
+
   /**
    * Return an RDD with the keys of each tuple.
    */