diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala
index 0a901a251d886bea79c23f561437349c28b84c56..2ad11bc604daf804bf5c4219b58081fd55876618 100644
--- a/core/src/main/scala/spark/RDD.scala
+++ b/core/src/main/scala/spark/RDD.scala
@@ -365,60 +365,59 @@ abstract class RDD[T: ClassManifest](
     new MapPartitionsWithIndexRDD(this, sc.clean(f), preservesPartitioning)
 
   /**
-   * Maps f over this RDD where f takes an additional parameter of type A. This
-   * additional parameter is produced by a factory method T => A which is called
-   * on each invocation of f. This factory method is produced by the factoryBuilder,
-   * an instance of which is constructed in each partition from the partition index
-   * and a seed value of type B.
-   */
-  def mapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
-    factoryBuilder: (Int, B) => (T => A),
-    factorySeed: B,
-    preservesPartitioning: Boolean = false)
+   * Maps f over this RDD, where f takes an additional parameter of type A. This
+   * additional parameter is produced by constructorOfA, which is called in each
+   * partition with the index of that partition.
+   */
+  def mapWith[A: ClassManifest, U: ClassManifest](constructorOfA: Int => A, preservesPartitioning: Boolean = false)
     (f:(A, T) => U): RDD[U] = {
       def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
-        val factory = factoryBuilder(index, factorySeed)
-        iter.map(t => f(factory(t), t))
+        val a = constructorOfA(index)
+        iter.map(t => f(a, t))
       }
     new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
   }
 
-  /**
-   * FlatMaps f over this RDD where f takes an additional parameter of type A. This
-   * additional parameter is produced by a factory method T => A which is called
-   * on each invocation of f. This factory method is produced by the factoryBuilder,
-   * an instance of which is constructed in each partition from the partition index
-   * and a seed value of type B.
+  /** 
+   * FlatMaps f over this RDD, where f takes an additional parameter of type A. This
+   * additional parameter is produced by constructorOfA, which is called in each
+   * partition with the index of that partition.
    */
-  def flatMapWith[A: ClassManifest, B: ClassManifest, U: ClassManifest](
-    factoryBuilder: (Int, B) => (T => A),
-    factorySeed: B,
-    preservesPartitioning: Boolean = false)
+  def flatMapWith[A: ClassManifest, U: ClassManifest](constructorOfA: Int => A, preservesPartitioning: Boolean = false)
     (f:(A, T) => Seq[U]): RDD[U] = {
       def iterF(index: Int, iter: Iterator[T]): Iterator[U] = {
-        val factory = factoryBuilder(index, factorySeed)
-        iter.flatMap(t => f(factory(t), t))
+        val a = constructorOfA(index)
+        iter.flatMap(t => f(a, t))
      }
     new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
   }
 
+  /**
+   * Applies f to each element of this RDD, where f takes an additional parameter of type A.
+   * This additional parameter is produced by constructorOfA, which is called in each
+   * partition with the index of that partition.
+   */
+  def foreachWith[A: ClassManifest](constructorOfA: Int => A)
+    (f:(A, T) => Unit) {
+      def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
+        val a = constructorOfA(index)
+        iter.map(t => {f(a, t); t})
+      }
+    (new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)).foreach(_ => {})
+  }
+
   /**
    * Filters this RDD with p, where p takes an additional parameter of type A. This
-   * additional parameter is produced by a factory method T => A which is called
-   * on each invocation of p. This factory method is produced by the factoryBuilder,
-   * an instance of which is constructed in each partition from the partition index
-   * and a seed value of type B.
-   */
-  def filterWith[A: ClassManifest, B: ClassManifest](
-    factoryBuilder: (Int, B) => (T => A),
-    factorySeed: B,
-    preservesPartitioning: Boolean = false)
+   * additional parameter is produced by constructorOfA, which is called in each
+   * partition with the index of that partition.
+   */
+  def filterWith[A: ClassManifest](constructorOfA: Int => A)
     (p:(A, T) => Boolean): RDD[T] = {
       def iterF(index: Int, iter: Iterator[T]): Iterator[T] = {
-        val factory = factoryBuilder(index, factorySeed)
-        iter.filter(t => p(factory(t), t))
+        val a = constructorOfA(index)
+        iter.filter(t => p(a, t))
       }
-    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), preservesPartitioning)
+    new MapPartitionsWithIndexRDD(this, sc.clean(iterF _), true)
   }
 
   /**
@@ -439,6 +438,14 @@ abstract class RDD[T: ClassManifest](
     sc.runJob(this, (iter: Iterator[T]) => iter.foreach(cleanF))
   }
 
+  /**
+   * Applies a function f to each partition of this RDD.
+   */
+  def foreachPartition(f: Iterator[T] => Unit) {
+    val cleanF = sc.clean(f)
+    sc.runJob(this, (iter: Iterator[T]) => cleanF(iter))
+  }
+
   /**
    * Return an array that contains all of the elements in this RDD.
    */
diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala
index 2a182e0d6c3c99b6a4f0eeb67f3602a078af7b84..d260191dd706fe1963faa0288db06ab1133060df 100644
--- a/core/src/test/scala/spark/RDDSuite.scala
+++ b/core/src/test/scala/spark/RDDSuite.scala
@@ -180,21 +180,18 @@ class RDDSuite extends FunSuite with LocalSparkContext {
   }
 
   test("mapWith") {
+    import java.util.Random
     sc = new SparkContext("local", "test")
     val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
     val randoms = ones.mapWith(
-      (index: Int, seed: Int) => {
-        val prng = new java.util.Random(index + seed)
-        (_ => prng.nextDouble)},
-      42)
-      {(random: Double, t: Int) => random * t}.
-      collect()
+      (index: Int) => new Random(index + 42))
+      {(prng: Random, t: Int) => prng.nextDouble * t}.collect()
     val prn42_3 = {
-      val prng42 = new java.util.Random(42)
+      val prng42 = new Random(42)
       prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
     }
     val prn43_3 = {
-      val prng43 = new java.util.Random(43)
+      val prng43 = new Random(43)
       prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
     }
     assert(randoms(2) === prn42_3)
@@ -202,21 +199,21 @@ class RDDSuite extends FunSuite with LocalSparkContext {
   }
 
   test("flatMapWith") {
+    import java.util.Random
     sc = new SparkContext("local", "test")
     val ones = sc.makeRDD(Array(1, 1, 1, 1, 1, 1), 2)
     val randoms = ones.flatMapWith(
-      (index: Int, seed: Int) => {
-        val prng = new java.util.Random(index + seed)
-        (_ => prng.nextDouble)},
-      42)
-      {(random: Double, t: Int) => Seq(random * t, random * t * 10)}.
+      (index: Int) => new Random(index + 42))
+      {(prng: Random, t: Int) => {
+        val random = prng.nextDouble()
+        Seq(random * t, random * t * 10)}}.
       collect()
     val prn42_3 = {
-      val prng42 = new java.util.Random(42)
+      val prng42 = new Random(42)
       prng42.nextDouble(); prng42.nextDouble(); prng42.nextDouble()
     }
     val prn43_3 = {
-      val prng43 = new java.util.Random(43)
+      val prng43 = new Random(43)
       prng43.nextDouble(); prng43.nextDouble(); prng43.nextDouble()
     }
     assert(randoms(5) === prn42_3 * 10)
@@ -228,11 +225,8 @@
     sc = new SparkContext("local", "test")
     val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)
     val sample = ints.filterWith(
-      (index: Int, seed: Int) => {
-        val prng = new Random(index + seed)
-        (_ => prng.nextInt(3))},
-      42)
-      {(random: Int, t: Int) => random == 0}.
+      (index: Int) => new Random(index + 42))
+      {(prng: Random, t: Int) => prng.nextInt(3) == 0}.
       collect()
     val checkSample = {
       val prng42 = new Random(42)
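
For reference, below is a minimal usage sketch of the two methods added above that the updated test suite does not exercise, foreachWith and foreachPartition. It is illustrative only and not part of the patch; it assumes an already-constructed SparkContext named sc, mirroring the tests above, and all variable names are made up.

    import java.util.Random

    // Assumes `sc` is an existing SparkContext, as in the tests above.
    val ints = sc.makeRDD(Array(1, 2, 3, 4, 5, 6), 2)

    // foreachWith: constructorOfA builds one Random per partition, seeded by the
    // partition index; that Random is then passed to the body for every element.
    ints.foreachWith((index: Int) => new Random(index + 42)) {
      (prng: Random, x: Int) => println("drew " + prng.nextInt(3) + " for element " + x)
    }

    // foreachPartition: the side effect runs once per partition over its iterator,
    // e.g. to set up one resource per partition rather than one per element.
    ints.foreachPartition { iter =>
      println("partition has " + iter.size + " elements")
    }

Like mapWith, flatMapWith and filterWith, foreachWith takes the per-partition constructor in its first parameter list and the per-element function in its second, so partition-local state such as a seeded Random is built exactly once per partition.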