diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 5fac955286ff7a1cdb2647be136ac6ef08728cb1..cce0ea21836fb69da1610515d1c877570e4a8af0 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -196,6 +196,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
     new MapPartitionsRDD(this, sc.clean(f))
 
+  def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
+    new MapPartitionsWithSplitRDD(this, sc.clean(f))
+
   // Actions (launch a job to return a value to the user program)
   
   def foreach(f: T => Unit) {
@@ -417,3 +420,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
   override val dependencies = List(new OneToOneDependency(prev))
   override def compute(split: Split) = f(prev.iterator(split))
 }
+
+/**
+ * A variant of MapPartitionsRDD that passes the split index into the
+ * closure. This can be used to generate or collect partition-specific
+ * information, such as the number of tuples in a partition.
+ */
+class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
+    prev: RDD[T],
+    f: (Int, Iterator[T]) => Iterator[U])
+  extends RDD[U](prev.context) {
+
+  override def splits = prev.splits
+  override val dependencies = List(new OneToOneDependency(prev))
+  override def compute(split: Split) = f(split.index, prev.iterator(split))
+}
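
For context, a minimal usage sketch of the new primitive (not part of the
patch): it counts the tuples in each partition, the use case named in the
doc comment above. The SparkContext handle `sc` and the input data are
assumptions for illustration.

    // Hypothetical driver snippet, assuming an existing SparkContext `sc`.
    val rdd = sc.parallelize(1 to 100, 4)  // 4 splits of 25 elements each
    val countsPerSplit = rdd.mapPartitionsWithSplit { (split, iter) =>
      // Emit one (split index, element count) pair per partition.
      Iterator((split, iter.size))
    }
    // Expected: Array((0,25), (1,25), (2,25), (3,25))
    countsPerSplit.collect().foreach(println)
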
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index ba9b36adb7cdc53ba5750482c1d7d3a2ecbfadef..04dbe3a3e4e389cf452d9298c6696a46c2f3d1e1 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -29,6 +29,11 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
     assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
     val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
     assert(partitionSums.collect().toList === List(3, 7))
+
+    val partitionSumsWithSplit = nums.mapPartitionsWithSplit {
+      case (split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
+    }
+    assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
   }
 
   test("SparkContext.union") {