Skip to content
Snippets Groups Projects
Commit 10bcd217 authored by Josh Rosen's avatar Josh Rosen
Browse files

Remove mapSideCombine field from Aggregator.

Instead, the presence or absense of a ShuffleDependency's aggregator
will control whether map-side combining is performed.
parent 4775c556
No related branches found
No related tags found
No related merge requests found
......@@ -9,15 +9,11 @@ import scala.collection.JavaConversions._
* @param createCombiner function to create the initial value of the aggregation.
* @param mergeValue function to merge a new value into the aggregation result.
* @param mergeCombiners function to merge outputs from multiple mergeValue function.
* @param mapSideCombine whether to apply combiners on map partitions, also
* known as map-side aggregations. When set to false,
* mergeCombiners function is not used.
*/
case class Aggregator[K, V, C] (
val createCombiner: V => C,
val mergeValue: (C, V) => C,
val mergeCombiners: (C, C) => C,
val mapSideCombine: Boolean = true) {
val mergeCombiners: (C, C) => C) {
def combineValuesByKey(iter: Iterator[(K, V)]) : Iterator[(K, C)] = {
val combiners = new JHashMap[K, C]
......
......@@ -22,7 +22,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
* Represents a dependency on the output of a shuffle stage.
* @param shuffleId the shuffle id
* @param rdd the parent RDD
* @param aggregator optional aggregator; this allows for map-side combining
* @param aggregator optional aggregator; if provided, map-side combining will be performed
* @param partitioner partitioner used to partition the shuffle output
*/
class ShuffleDependency[K, V, C](
......
......@@ -14,6 +14,13 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
/**
* The resulting RDD from a shuffle (e.g. repartitioning of data).
* @param parent the parent RDD.
* @param aggregator if provided, this aggregator will be used to perform map-side combining.
* @param part the partitioner used to partition the RDD
* @tparam K the key class.
* @tparam V the value class.
* @tparam C if map side combiners are used, then this is the combiner type; otherwise,
* this is the same as V.
*/
class ShuffledRDD[K, V, C](
@transient parent: RDD[(K, V)],
......
......@@ -114,7 +114,7 @@ private[spark] class ShuffleMapTask(
val partitioner = dep.partitioner
val bucketIterators =
if (dep.aggregator.isDefined && dep.aggregator.get.mapSideCombine) {
if (dep.aggregator.isDefined) {
val aggregator = dep.aggregator.get.asInstanceOf[Aggregator[Any, Any, Any]]
// Apply combiners (map-side aggregation) to the map output.
val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
......
......@@ -228,8 +228,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
val aggregator = new Aggregator[Int, Int, Int](
(v: Int) => v,
_+_,
_+_,
false)
_+_)
// Turn off map-side combine and test the results.
var shuffledRdd : RDD[(Int, Int)] =
......@@ -237,22 +236,13 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
shuffledRdd = shuffledRdd.mapPartitions(aggregator.combineValuesByKey(_))
assert(shuffledRdd.collect().toSet === Set((1,8), (2, 1)))
// Turn map-side combine off and pass a wrong mergeCombine function. Should
// not see an exception because mergeCombine should not have been called.
// Run a wrong mergeCombine function with map-side combine on.
// We expect to see an exception thrown.
val aggregatorWithException = new Aggregator[Int, Int, Int](
(v: Int) => v, _+_, ShuffleSuite.mergeCombineException, false)
var shuffledRdd1 : RDD[(Int, Int)] =
new ShuffledRDD[Int, Int, Int](pairs, Some(aggregatorWithException), new HashPartitioner(2))
shuffledRdd1 = shuffledRdd1.mapPartitions(aggregatorWithException.combineValuesByKey(_))
assert(shuffledRdd1.collect().toSet === Set((1, 8), (2, 1)))
// Now run the same mergeCombine function with map-side combine on. We
// expect to see an exception thrown.
val aggregatorWithException1 = new Aggregator[Int, Int, Int](
(v: Int) => v, _+_, ShuffleSuite.mergeCombineException)
var shuffledRdd2 : RDD[(Int, Int)] =
new ShuffledRDD[Int, Int, Int](pairs, Some(aggregatorWithException1), new HashPartitioner(2))
shuffledRdd2 = shuffledRdd2.mapPartitions(aggregatorWithException1.combineCombinersByKey(_))
new ShuffledRDD[Int, Int, Int](pairs, Some(aggregatorWithException), new HashPartitioner(2))
shuffledRdd2 = shuffledRdd2.mapPartitions(aggregatorWithException.combineCombinersByKey(_))
evaluating { shuffledRdd2.collect() } should produce [SparkException]
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment