diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 5fac955286ff7a1cdb2647be136ac6ef08728cb1..cce0ea21836fb69da1610515d1c877570e4a8af0 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -196,6 +196,9 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
   def mapPartitions[U: ClassManifest](f: Iterator[T] => Iterator[U]): RDD[U] =
     new MapPartitionsRDD(this, sc.clean(f))
 
+  def mapPartitionsWithSplit[U: ClassManifest](f: (Int, Iterator[T]) => Iterator[U]): RDD[U] =
+    new MapPartitionsWithSplitRDD(this, sc.clean(f))
+
   // Actions (launch a job to return a value to the user program)
 
   def foreach(f: T => Unit) {
@@ -417,3 +420,18 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
   override val dependencies = List(new OneToOneDependency(prev))
   override def compute(split: Split) = f(prev.iterator(split))
 }
+
+/**
+ * A variant of the MapPartitionsRDD that passes the split index into the
+ * closure. This can be used to generate or collect partition specific
+ * information such as the number of tuples in a partition.
+ */
+class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest](
+    prev: RDD[T],
+    f: (Int, Iterator[T]) => Iterator[U])
+  extends RDD[U](prev.context) {
+
+  override def splits = prev.splits
+  override val dependencies = List(new OneToOneDependency(prev))
+  override def compute(split: Split) = f(split.index, prev.iterator(split))
+}
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index ba9b36adb7cdc53ba5750482c1d7d3a2ecbfadef..04dbe3a3e4e389cf452d9298c6696a46c2f3d1e1 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -29,6 +29,11 @@ class RDDSuite extends FunSuite with BeforeAndAfter {
     assert(nums.glom().map(_.toList).collect().toList === List(List(1, 2), List(3, 4)))
     val partitionSums = nums.mapPartitions(iter => Iterator(iter.reduceLeft(_ + _)))
     assert(partitionSums.collect().toList === List(3, 7))
+
+    val partitionSumsWithSplit = nums.mapPartitionsWithSplit {
+      case (split, iter) => Iterator((split, iter.reduceLeft(_ + _)))
+    }
+    assert(partitionSumsWithSplit.collect().toList === List((0, 3), (1, 7)))
   }
 
   test("SparkContext.union") {
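
A minimal usage sketch (not part of the patch) of the "number of tuples in a
partition" case the Scaladoc mentions. The RDD name `counts` and the
parallelize arguments are illustrative assumptions; `sc` is a SparkContext:

    val rdd = sc.parallelize(1 to 100, 4)   // 4 partitions of 25 elements each
    // Pair each partition's index with the number of elements it holds.
    val counts = rdd.mapPartitionsWithSplit { (split, iter) =>
      Iterator((split, iter.size))          // iter.size consumes the iterator
    }
    counts.collect()   // e.g. Array((0, 25), (1, 25), (2, 25), (3, 25))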