diff --git a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 445d520bc2fa5ce144cf8e25b822be0deb3193fe..eb31e901235729d65b3ae11872618cb58ba35206 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -133,7 +133,7 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial val cleanF = sc.clean(f) val reducePartition: Iterator[T] => Option[T] = iter => { if (iter.hasNext) - Some(iter.reduceLeft(f)) + Some(iter.reduceLeft(cleanF)) else None } @@ -144,7 +144,36 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial if (results.size == 0) throw new UnsupportedOperationException("empty collection") else - return results.reduceLeft(f) + return results.reduceLeft(cleanF) + } + + /** + * Aggregate the elements of each partition, and then the results for all the + * partitions, using a given associative function and a neutral "zero value". + * The function op(t1, t2) is allowed to modify t1 and return it as its result + * value to avoid object allocation; however, it should not modify t2. + */ + def fold(zeroValue: T)(op: (T, T) => T): T = { + val cleanOp = sc.clean(op) + val results = sc.runJob(this, (iter: Iterator[T]) => iter.fold(zeroValue)(cleanOp)) + return results.fold(zeroValue)(cleanOp) + } + + /** + * Aggregate the elements of each partition, and then the results for all the + * partitions, using given combine functions and a neutral "zero value". This + * function can return a different result type, U, than the type of this RDD, T. + * Thus, we need one operation for merging a T into an U and one operation for + * merging two U's, as in scala.TraversableOnce. Both of these functions are + * allowed to modify and return their first argument instead of creating a new U + * to avoid memory allocation. + */ + def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + val cleanSeqOp = sc.clean(seqOp) + val cleanCombOp = sc.clean(combOp) + val results = sc.runJob(this, + (iter: Iterator[T]) => iter.aggregate(zeroValue)(cleanSeqOp, cleanCombOp)) + return results.fold(zeroValue)(cleanCombOp) } def count(): Long = { diff --git a/core/src/test/scala/spark/RDDSuite.scala b/core/src/test/scala/spark/RDDSuite.scala index 06d438d9e22264bb1b4092f4ee5a62038d3b898a..7199b634b764a79463d450a881523f52492125c1 100644 --- a/core/src/test/scala/spark/RDDSuite.scala +++ b/core/src/test/scala/spark/RDDSuite.scala @@ -1,5 +1,6 @@ package spark +import scala.collection.mutable.HashMap import org.scalatest.FunSuite import SparkContext._ @@ -9,6 +10,7 @@ class RDDSuite extends FunSuite { val nums = sc.makeRDD(Array(1, 2, 3, 4), 2) assert(nums.collect().toList === List(1, 2, 3, 4)) assert(nums.reduce(_ + _) === 10) + assert(nums.fold(0)(_ + _) === 10) assert(nums.map(_.toString).collect().toList === List("1", "2", "3", "4")) assert(nums.filter(_ > 2).collect().toList === List(3, 4)) assert(nums.flatMap(x => 1 to x).collect().toList === List(1, 1, 2, 1, 2, 3, 1, 2, 3, 4)) @@ -18,4 +20,26 @@ class RDDSuite extends FunSuite { assert(partitionSums.collect().toList === List(3, 7)) sc.stop() } + + test("aggregate") { + val sc = new SparkContext("local", "test") + val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) + type StringMap = HashMap[String, Int] + val emptyMap = new StringMap { + override def default(key: String): Int = 0 + } + val mergeElement: (StringMap, (String, Int)) => StringMap = (map, pair) => { + map(pair._1) += pair._2 + map + } + val mergeMaps: (StringMap, StringMap) => StringMap = (map1, map2) => { + for ((key, value) <- map2) { + map1(key) += value + } + map1 + } + val result = pairs.aggregate(emptyMap)(mergeElement, mergeMaps) + assert(result.toSet === Set(("a", 6), ("b", 2), ("c", 5))) + sc.stop() + } }