diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index 203179c4ea823efb7d1b48cbf3e612f289729ae9..ae70d559511c9dbac0a013bc62bd2981081642de 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -20,18 +20,16 @@ package org.apache.spark.rdd
 import org.apache.spark.{Partition, TaskContext}
 
 
-private[spark]
-class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
     prev: RDD[T],
-    f: Iterator[T] => Iterator[U],
+    f: (TaskContext, Int, Iterator[T]) => Iterator[U],  // (TaskContext, partition index, iterator)
     preservesPartitioning: Boolean = false)
   extends RDD[U](prev) {
 
-  override val partitioner =
-    if (preservesPartitioning) firstParent[T].partitioner else None
+  override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
 
   override def getPartitions: Array[Partition] = firstParent[T].partitions
 
   override def compute(split: Partition, context: TaskContext) =
-    f(firstParent[T].iterator(split, context))
+    f(context, split.index, firstParent[T].iterator(split, context))
 }
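
For reference, a hedged sketch (not part of the patch) of constructing the consolidated RDD directly with the new closure shape, in the same spirit as the direct-construction test case removed from CheckpointSuite at the end of this diff. The object name and local setup are illustrative; since the class is private[spark], the sketch lives in the org.apache.spark.rdd package.

    package org.apache.spark.rdd

    import org.apache.spark.{SparkContext, TaskContext}

    // Illustrative only: exercises the new (TaskContext, partition index, iterator)
    // closure by building a MapPartitionsRDD by hand, much as the removed test
    // used to do for MapPartitionsWithContextRDD.
    object DirectMapPartitionsSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "DirectMapPartitionsSketch")
        val prev = sc.parallelize(1 to 10, 2)
        val f = (context: TaskContext, index: Int, iter: Iterator[Int]) =>
          iter.map(x => "partition " + index + ": " + x)
        // preservesPartitioning defaults to false, matching the old behavior.
        val mapped = new MapPartitionsRDD(prev, f)
        mapped.collect().foreach(println)
        sc.stop()
      }
    }
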
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
deleted file mode 100644
index aea08ff81bfdb9aa0f5b4158d8b7dd6fb5d90f9f..0000000000000000000000000000000000000000
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{Partition, TaskContext}
-
-
-/**
- * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
- * TaskContext, the closure can either get access to the interruptible flag or get the index
- * of the partition in the RDD.
- */
-private[spark]
-class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
-    prev: RDD[T],
-    f: (TaskContext, Iterator[T]) => Iterator[U],
-    preservesPartitioning: Boolean
-  ) extends RDD[U](prev) {
-
-  override def getPartitions: Array[Partition] = firstParent[T].partitions
-
-  override val partitioner = if (preservesPartitioning) prev.partitioner else None
-
-  override def compute(split: Partition, context: TaskContext) =
-    f(context, firstParent[T].iterator(split, context))
-}
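
The capability described in the deleted class's doc comment is not lost: the TaskContext now reaches user closures as the first argument of the consolidated MapPartitionsRDD function. A hedged sketch of the kind of closure the deleted class existed to run, rewritten against the new three-argument shape (names are illustrative):

    import org.apache.spark.TaskContext

    object ContextClosureSketch {
      // Illustrative only: per-task state such as stageId and partitionId is
      // still reachable through the TaskContext argument, as the deleted doc
      // comment describes for the old class.
      def contextAware(context: TaskContext, index: Int, iter: Iterator[String]): Iterator[String] =
        iter.map(line => "stage=" + context.stageId + " partition=" + context.partitionId + " " + line)
    }
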
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 7623c44d8856c4a0f414a89d264de085d43f54c7..5b1285307d11ee34cc98b9ca1626f0d3e80b760f 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -408,7 +408,6 @@ abstract class RDD[T: ClassManifest](
   def pipe(command: String, env: Map[String, String]): RDD[String] =
     new PipedRDD(this, command, env)
 
-
   /**
    * Return an RDD created by piping elements to a forked external process.
    * The print behavior can be customized by providing two functions.
@@ -442,7 +441,8 @@ abstract class RDD[T: ClassManifest](
    */
   def mapPartitions[U: ClassManifest](
       f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
-    new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
   /**
@@ -451,8 +451,8 @@ abstract class RDD[T: ClassManifest](
    */
   def mapPartitionsWithIndex[U: ClassManifest](
       f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
-    val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
-    new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
   /**
@@ -462,7 +462,8 @@ abstract class RDD[T: ClassManifest](
   def mapPartitionsWithContext[U: ClassManifest](
       f: (TaskContext, Iterator[T]) => Iterator[U],
       preservesPartitioning: Boolean = false): RDD[U] = {
-    new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+    val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+    new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
   }
 
   /**
@@ -483,11 +484,10 @@ abstract class RDD[T: ClassManifest](
   def mapWith[A: ClassManifest, U: ClassManifest]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => U): RDD[U] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.map(t => f(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+    }, preservesPartitioning)
   }
 
   /**
@@ -498,11 +498,10 @@ abstract class RDD[T: ClassManifest](
   def flatMapWith[A: ClassManifest, U: ClassManifest]
       (constructA: Int => A, preservesPartitioning: Boolean = false)
       (f: (T, A) => Seq[U]): RDD[U] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.flatMap(t => f(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+    }, preservesPartitioning)
   }
 
   /**
@@ -511,11 +510,10 @@ abstract class RDD[T: ClassManifest](
    * partition with the index of that partition.
    */
   def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex { (index, iter) =>
+      val a = constructA(index)
       iter.map(t => {f(t, a); t})
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
+    }.foreach(_ => {})
   }
 
   /**
@@ -524,11 +522,10 @@ abstract class RDD[T: ClassManifest](
    * partition with the index of that partition.
    */
   def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
-    def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
-      val a = constructA(context.partitionId)
+    mapPartitionsWithIndex((index, iter) => {
+      val a = constructA(index)
       iter.filter(t => p(t, a))
-    }
-    new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
+    }, preservesPartitioning = true)
   }
 
   /**
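
A hedged usage sketch (not part of the patch) of the public operators touched above: their signatures are unchanged, only the plumbing now funnels into the single MapPartitionsRDD. All calls are existing RDD API; the object name and data are illustrative.

    import org.apache.spark.SparkContext

    object ConsolidatedOperatorsSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "ConsolidatedOperatorsSketch")
        val rdd = sc.parallelize(1 to 6, 3)

        // Iterator-only variant: the adapter ignores the TaskContext and the index.
        val strings = rdd.mapPartitions(iter => iter.map(_.toString))

        // Index variant: the index is passed straight through instead of being
        // read from TaskContext.partitionId as before.
        val tagged = rdd.mapPartitionsWithIndex((index, iter) => iter.map(x => (index, x)))

        // Context variant: the TaskContext is forwarded as the first argument.
        val byTask = rdd.mapPartitionsWithContext((context, iter) =>
          iter.map(x => (context.partitionId, x)))

        // The *With helpers now delegate to mapPartitionsWithIndex as well.
        val filtered = rdd.filterWith(index => index % 2)((x, parity) => x % 2 == parity)

        println(strings.collect().toSeq)
        println(tagged.collect().toSeq)
        println(byTask.collect().toSeq)
        println(filtered.collect().toSeq)
        sc.stop()
      }
    }
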
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index f26c44d3e76ff84e8088e2a90249eaf4beaf44ea..d2226aa5a566356231d82266c726e0a544f7c10e 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -62,8 +62,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
     testCheckpointing(_.sample(false, 0.5, 0))
     testCheckpointing(_.glom())
     testCheckpointing(_.mapPartitions(_.map(_.toString)))
-    testCheckpointing(r => new MapPartitionsWithContextRDD(r,
-      (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
     testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
     testCheckpointing(_.pipe(Seq("cat")))
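
If equivalent checkpointing coverage is still wanted, a hedged sketch of a replacement test line that goes through the public operator instead of constructing the now-deleted class directly (testCheckpointing is the suite's existing helper, and the line would sit alongside the others inside CheckpointSuite):

    // Illustrative only: mirrors the removed test line via the public API.
    testCheckpointing(_.mapPartitionsWithContext(
      (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString)))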