diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 63a87e7f09d85aae021c1d8b18ac69db659aa919..2985c901194683216ac2887588c7a79989c463a7 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -1118,9 +1118,9 @@ abstract class RDD[T: ClassTag]( /** * Aggregates the elements of this RDD in a multi-level tree pattern. + * This method is semantically identical to [[org.apache.spark.rdd.RDD#aggregate]]. * * @param depth suggested depth of the tree (default: 2) - * @see [[org.apache.spark.rdd.RDD#aggregate]] */ def treeAggregate[U: ClassTag](zeroValue: U)( seqOp: (U, T) => U, @@ -1134,7 +1134,7 @@ abstract class RDD[T: ClassTag]( val cleanCombOp = context.clean(combOp) val aggregatePartition = (it: Iterator[T]) => it.aggregate(zeroValue)(cleanSeqOp, cleanCombOp) - var partiallyAggregated = mapPartitions(it => Iterator(aggregatePartition(it))) + var partiallyAggregated: RDD[U] = mapPartitions(it => Iterator(aggregatePartition(it))) var numPartitions = partiallyAggregated.partitions.length val scale = math.max(math.ceil(math.pow(numPartitions, 1.0 / depth)).toInt, 2) // If creating an extra level doesn't help reduce @@ -1146,9 +1146,10 @@ abstract class RDD[T: ClassTag]( val curNumPartitions = numPartitions partiallyAggregated = partiallyAggregated.mapPartitionsWithIndex { (i, iter) => iter.map((i % curNumPartitions, _)) - }.reduceByKey(new HashPartitioner(curNumPartitions), cleanCombOp).values + }.foldByKey(zeroValue, new HashPartitioner(curNumPartitions))(cleanCombOp).values } - partiallyAggregated.reduce(cleanCombOp) + val copiedZeroValue = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) + partiallyAggregated.fold(copiedZeroValue)(cleanCombOp) } } diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 8d06f5468f4f1e50e6ee5ab7a9e7a93109463743..386c0060f9c41ffbd62da5de2b3dde8b32813ce0 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -192,6 +192,23 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { assert(ser.serialize(union.partitions.head).limit() < 2000) } + test("fold") { + val rdd = sc.makeRDD(-1000 until 1000, 10) + def op: (Int, Int) => Int = (c: Int, x: Int) => c + x + val sum = rdd.fold(0)(op) + assert(sum === -1000) + } + + test("fold with op modifying first arg") { + val rdd = sc.makeRDD(-1000 until 1000, 10).map(x => Array(x)) + def op: (Array[Int], Array[Int]) => Array[Int] = { (c: Array[Int], x: Array[Int]) => + c(0) += x(0) + c + } + val sum = rdd.fold(Array(0))(op) + assert(sum(0) === -1000) + } + test("aggregate") { val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3))) type StringMap = HashMap[String, Int] @@ -218,7 +235,19 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext { def combOp: (Long, Long) => Long = (c1: Long, c2: Long) => c1 + c2 for (depth <- 1 until 10) { val sum = rdd.treeAggregate(0L)(seqOp, combOp, depth) - assert(sum === -1000L) + assert(sum === -1000) + } + } + + test("treeAggregate with ops modifying first args") { + val rdd = sc.makeRDD(-1000 until 1000, 10).map(x => Array(x)) + def op: (Array[Int], Array[Int]) => Array[Int] = { (c: Array[Int], x: Array[Int]) => + c(0) += x(0) + c + } + for (depth <- 1 until 10) { + val sum = rdd.treeAggregate(Array(0))(op, op, depth) + assert(sum(0) === -1000) } }