diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
index b5b8a5706deb39e0125fffb479451ba46086c158..a637d6f15b7e5ed546160e073303efdece29ab36 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionwiseSampledRDD.scala
@@ -39,6 +39,7 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
  *
  * @param prev RDD to be sampled
  * @param sampler a random sampler
+ * @param preservesPartitioning whether the sampler preserves the partitioner of the parent RDD
  * @param seed random seed
  * @tparam T input RDD item type
  * @tparam U sampled RDD item type
@@ -46,9 +47,12 @@ class PartitionwiseSampledRDDPartition(val prev: Partition, val seed: Long)
 private[spark] class PartitionwiseSampledRDD[T: ClassTag, U: ClassTag](
     prev: RDD[T],
     sampler: RandomSampler[T, U],
+    @transient preservesPartitioning: Boolean,
     @transient seed: Long = Utils.random.nextLong)
   extends RDD[U](prev) {
 
+  @transient override val partitioner = if (preservesPartitioning) prev.partitioner else None
+
   override def getPartitions: Array[Partition] = {
     val random = new Random(seed)
     firstParent[T].partitions.map(x => new PartitionwiseSampledRDDPartition(x, random.nextLong()))
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index a1f28272488918c880a60d505b9fa4441adff23c..c1bafab3e7491b38df6022433eadaa1cd29f225d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -356,9 +356,9 @@ abstract class RDD[T: ClassTag](
       seed: Long = Utils.random.nextLong): RDD[T] = {
     require(fraction >= 0.0, "Invalid fraction value: " + fraction)
     if (withReplacement) {
-      new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), seed)
+      new PartitionwiseSampledRDD[T, T](this, new PoissonSampler[T](fraction), true, seed)
     } else {
-      new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), seed)
+      new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](fraction), true, seed)
     }
   }
 
@@ -374,7 +374,7 @@ abstract class RDD[T: ClassTag](
     val sum = weights.sum
     val normalizedCumWeights = weights.map(_ / sum).scanLeft(0.0d)(_ + _)
     normalizedCumWeights.sliding(2).map { x =>
-      new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), seed)
+      new PartitionwiseSampledRDD[T, T](this, new BernoulliSampler[T](x(0), x(1)), true, seed)
     }.toArray
   }
 
@@ -586,6 +586,9 @@ abstract class RDD[T: ClassTag](
 
   /**
    * Return a new RDD by applying a function to each partition of this RDD.
+   *
+   * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+   * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
    */
   def mapPartitions[U: ClassTag](
       f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
@@ -596,6 +599,9 @@ abstract class RDD[T: ClassTag](
   /**
    * Return a new RDD by applying a function to each partition of this RDD, while tracking the index
    * of the original partition.
+   *
+   * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+   * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
    */
   def mapPartitionsWithIndex[U: ClassTag](
       f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
@@ -607,6 +613,9 @@ abstract class RDD[T: ClassTag](
    * :: DeveloperApi ::
    * Return a new RDD by applying a function to each partition of this RDD. This is a variant of
    * mapPartitions that also passes the TaskContext into the closure.
+   *
+   * `preservesPartitioning` indicates whether the input function preserves the partitioner, which
+   * should be `false` unless this is a pair RDD and the input function doesn't modify the keys.
    */
   @DeveloperApi
   def mapPartitionsWithContext[U: ClassTag](
@@ -689,7 +698,7 @@ abstract class RDD[T: ClassTag](
    * a map on the other).
    */
   def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = {
-    zipPartitions(other, true) { (thisIter, otherIter) =>
+    zipPartitions(other, preservesPartitioning = false) { (thisIter, otherIter) =>
       new Iterator[(T, U)] {
         def hasNext = (thisIter.hasNext, otherIter.hasNext) match {
           case (true, true) => true
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
index 5dd8de319a654443cf0e18b837a03aa188185855..a0483886f8db3fcd8309af7677db5789d51d52f0 100644
--- a/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionwiseSampledRDDSuite.scala
@@ -43,7 +43,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
   test("seed distribution") {
     val rdd = sc.makeRDD(Array(1L, 2L, 3L, 4L), 2)
     val sampler = new MockSampler
-    val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, 0L)
+    val sample = new PartitionwiseSampledRDD[Long, Long](rdd, sampler, false, 0L)
     assert(sample.distinct().count == 2, "Seeds must be different.")
   }
 
@@ -52,7 +52,7 @@ class PartitionwiseSampledRDDSuite extends FunSuite with SharedSparkContext {
     // We want to make sure there are no concurrency issues.
     val rdd = sc.parallelize(0 until 111, 10)
     for (sampler <- Seq(new BernoulliSampler[Int](0.5), new PoissonSampler[Int](0.5))) {
-      val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler)
+      val sampled = new PartitionwiseSampledRDD[Int, Int](rdd, sampler, true)
       sampled.zip(sampled).count()
     }
   }
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 2924de112934c3a5ad07604d356965fafcaa196c..6654ec2d7c656fba1872ece36d00372907ee366f 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -523,6 +523,15 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(sortedTopK === nums.sorted(ord).take(5))
   }
 
+  test("sample preserves partitioner") {
+    val partitioner = new HashPartitioner(2)
+    val rdd = sc.parallelize(Seq((0, 1), (2, 3))).partitionBy(partitioner)
+    for (withReplacement <- Seq(true, false)) {
+      val sampled = rdd.sample(withReplacement, 1.0)
+      assert(sampled.partitioner === rdd.partitioner)
+    }
+  }
+
   test("takeSample") {
     val n = 1000000
     val data = sc.parallelize(1 to n, 2)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index 079743742d86d6fc18882b2c0868f4fe7f6fe458..1af40de2c7fcf9ad14c6f73188e9456110b0b31c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -103,11 +103,11 @@ class BinaryClassificationMetrics(scoreAndLabels: RDD[(Double, Double)]) extends
       mergeValue = (c: BinaryLabelCounter, label: Double) => c += label,
       mergeCombiners = (c1: BinaryLabelCounter, c2: BinaryLabelCounter) => c1 += c2
     ).sortByKey(ascending = false)
-    val agg = counts.values.mapPartitions({ iter =>
+    val agg = counts.values.mapPartitions { iter =>
       val agg = new BinaryLabelCounter()
       iter.foreach(agg += _)
       Iterator(agg)
-    }, preservesPartitioning = true).collect()
+    }.collect()
     val partitionwiseCumulativeCounts =
       agg.scanLeft(new BinaryLabelCounter())(
         (agg: BinaryLabelCounter, c: BinaryLabelCounter) => agg.clone() += c)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
index f4c403bc7861ce345c35d2cece596ce33096a01e..8c2b044ea73f215f2ab205f0f3f56fb7f245dbaf 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/distributed/RowMatrix.scala
@@ -377,9 +377,9 @@ class RowMatrix(
       s"Only support dense matrix at this time but found ${B.getClass.getName}.")
 
     val Bb = rows.context.broadcast(B.toBreeze.asInstanceOf[BDM[Double]].toDenseVector.toArray)
-    val AB = rows.mapPartitions({ iter =>
+    val AB = rows.mapPartitions { iter =>
       val Bi = Bb.value
-      iter.map(row => {
+      iter.map { row =>
         val v = BDV.zeros[Double](k)
         var i = 0
         while (i < k) {
@@ -387,8 +387,8 @@ class RowMatrix(
           i += 1
         }
         Vectors.fromBreeze(v)
-      })
-    }, preservesPartitioning = true)
+      }
+    }
 
     new RowMatrix(AB, nRows, B.numCols)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 15e8855db6ca756635c74e01ac73cfc1ecd7f9f7..5356790cb53393f3381b9ec41fcd450490048d13 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -430,7 +430,7 @@ class ALS private (
       val inLinkBlock = makeInLinkBlock(numProductBlocks, ratings, productPartitioner)
       val outLinkBlock = makeOutLinkBlock(numProductBlocks, ratings, productPartitioner)
       Iterator.single((blockId, (inLinkBlock, outLinkBlock)))
-    }, true)
+    }, preservesPartitioning = true)
     val inLinks = links.mapValues(_._1)
     val outLinks = links.mapValues(_._2)
     inLinks.persist(StorageLevel.MEMORY_AND_DISK)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index aaf92a1a8869aa8fe51489d64d8e33e29f1ebf77..30de24ad89f9835453a76fc2fa195f0f4bf59fce 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -264,8 +264,8 @@ object MLUtils {
     (1 to numFolds).map { fold =>
       val sampler = new BernoulliSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,
         complement = false)
-      val validation = new PartitionwiseSampledRDD(rdd, sampler, seed)
-      val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), seed)
+      val validation = new PartitionwiseSampledRDD(rdd, sampler, true, seed)
+      val training = new PartitionwiseSampledRDD(rdd, sampler.cloneComplement(), true, seed)
       (training, validation)
     }.toArray
   }