Commit 10bcd217 authored by Josh Rosen

Remove mapSideCombine field from Aggregator.

Instead, the presence or absence of a ShuffleDependency's aggregator
will control whether map-side combining is performed.
Parent: 4775c556
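For illustration, a minimal sketch of the caller-visible effect (hypothetical usage, not part of this commit; it assumes the pre-Apache `spark` package API shown in the diffs below and a local SparkContext):

import spark._

object MapSideCombineSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local", "sketch")  // assumed local setup
    val pairs = sc.parallelize(Seq((1, 1), (1, 2), (1, 3), (2, 1)), 2)
    val agg = new Aggregator[Int, Int, Int]((v: Int) => v, _ + _, _ + _)

    // Map-side combining ON: the shuffle dependency carries the aggregator,
    // so values are pre-merged on the map side and the reduce side only
    // merges combiners.
    var sums: RDD[(Int, Int)] =
      new ShuffledRDD[Int, Int, Int](pairs, Some(agg), new HashPartitioner(2))
    sums = sums.mapPartitions(agg.combineCombinersByKey(_))

    // Map-side combining OFF: no aggregator on the dependency; raw values
    // are shuffled and combined entirely on the reduce side.
    var sumsNoMapSide: RDD[(Int, Int)] =
      new ShuffledRDD[Int, Int, Int](pairs, None, new HashPartitioner(2))
    sumsNoMapSide = sumsNoMapSide.mapPartitions(agg.combineValuesByKey(_))

    // Both paths yield Set((1, 6), (2, 1)).
  }
}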
Aggregator.scala
@@ -9,15 +9,11 @@ import scala.collection.JavaConversions._
  * @param createCombiner function to create the initial value of the aggregation.
  * @param mergeValue function to merge a new value into the aggregation result.
  * @param mergeCombiners function to merge outputs from multiple mergeValue function.
- * @param mapSideCombine whether to apply combiners on map partitions, also
- *                       known as map-side aggregations. When set to false,
- *                       mergeCombiners function is not used.
  */
 case class Aggregator[K, V, C] (
     val createCombiner: V => C,
     val mergeValue: (C, V) => C,
-    val mergeCombiners: (C, C) => C,
-    val mapSideCombine: Boolean = true) {
+    val mergeCombiners: (C, C) => C) {
   def combineValuesByKey(iter: Iterator[(K, V)]) : Iterator[(K, C)] = {
     val combiners = new JHashMap[K, C]
...
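To illustrate the three remaining parameters, here is a hypothetical aggregator (not from this commit) whose combiner type C differs from the value type V; it accumulates (sum, count) pairs per key for later averaging:

// Hypothetical example: V = Double, C = (running sum, count).
val avgAgg = new Aggregator[String, Double, (Double, Int)](
  (v: Double) => (v, 1),                                  // createCombiner
  (c: (Double, Int), v: Double) => (c._1 + v, c._2 + 1),  // mergeValue
  (c1: (Double, Int), c2: (Double, Int)) =>               // mergeCombiners
    (c1._1 + c2._1, c1._2 + c2._2))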
Dependency.scala
@@ -22,7 +22,7 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
  * Represents a dependency on the output of a shuffle stage.
  * @param shuffleId the shuffle id
  * @param rdd the parent RDD
- * @param aggregator optional aggregator; this allows for map-side combining
+ * @param aggregator optional aggregator; if provided, map-side combining will be performed
  * @param partitioner partitioner used to partition the shuffle output
  */
 class ShuffleDependency[K, V, C](
...
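A hypothetical construction mirroring the documented parameters (the constructor is assumed to take them in the order the scaladoc lists; `pairs` and `agg` come from the sketch above):

// Some(agg) turns map-side combining on for this shuffle; None leaves it off.
val dep = new ShuffleDependency[Int, Int, Int](
  0, pairs, Some(agg), new HashPartitioner(2))  // shuffleId, rdd, aggregator, partitioner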
ShuffledRDD.scala
@@ -14,6 +14,13 @@ private[spark] class ShuffledRDDSplit(val idx: Int) extends Split {
 /**
  * The resulting RDD from a shuffle (e.g. repartitioning of data).
+ * @param parent the parent RDD.
+ * @param aggregator if provided, this aggregator will be used to perform map-side combining.
+ * @param part the partitioner used to partition the RDD
+ * @tparam K the key class.
+ * @tparam V the value class.
+ * @tparam C if map side combiners are used, then this is the combiner type; otherwise,
+ *           this is the same as V.
  */
 class ShuffledRDD[K, V, C](
     @transient parent: RDD[(K, V)],
...
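Tying the new scaladoc together: with the averaging aggregator sketched earlier, C = (Double, Int), so the shuffle's elements are (K, C) pairs; without an aggregator, C is simply V. A hypothetical sketch, assuming an RDD[(String, Double)] named `readings`:

// With map-side combining, the element type is (K, C) = (String, (Double, Int)).
val partialSums: RDD[(String, (Double, Int))] =
  new ShuffledRDD[String, Double, (Double, Int)](readings, Some(avgAgg), new HashPartitioner(4))
val averages = partialSums
  .mapPartitions(avgAgg.combineCombinersByKey(_))      // merge map-side combiners
  .map { case (k, (sum, count)) => (k, sum / count) }  // finish the average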
ShuffleMapTask.scala
@@ -114,7 +114,7 @@ private[spark] class ShuffleMapTask(
     val partitioner = dep.partitioner
     val bucketIterators =
-      if (dep.aggregator.isDefined && dep.aggregator.get.mapSideCombine) {
+      if (dep.aggregator.isDefined) {
         val aggregator = dep.aggregator.get.asInstanceOf[Aggregator[Any, Any, Any]]
         // Apply combiners (map-side aggregation) to the map output.
         val buckets = Array.tabulate(numOutputSplits)(_ => new JHashMap[Any, Any])
...
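The `if` above is where a presence check replaces the old flag. For reference, a simplified, self-contained sketch of the combining loop it guards (names and signature are illustrative; the real task works with Any and casts, as the diff shows):

import java.util.{HashMap => JHashMap}

// Fold each record into a per-bucket hash map: createCombiner on the first
// sight of a key, mergeValue afterwards (illustrative approximation).
def combineIntoBuckets[K, V, C](
    records: Iterator[(K, V)],
    numOutputSplits: Int,
    getPartition: K => Int,
    agg: Aggregator[K, V, C]): Array[JHashMap[K, C]] = {
  val buckets = Array.fill(numOutputSplits)(new JHashMap[K, C])
  for ((k, v) <- records) {
    val bucket = buckets(getPartition(k))
    val existing = bucket.get(k)
    if (existing == null) bucket.put(k, agg.createCombiner(v))
    else bucket.put(k, agg.mergeValue(existing, v))
  }
  buckets
}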
ShuffleSuite.scala
@@ -228,8 +228,7 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
     val aggregator = new Aggregator[Int, Int, Int](
       (v: Int) => v,
       _+_,
-      _+_,
-      false)
+      _+_)
     // Turn off map-side combine and test the results.
     var shuffledRdd : RDD[(Int, Int)] =
...
@@ -237,22 +236,13 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with BeforeAndAfter {
     shuffledRdd = shuffledRdd.mapPartitions(aggregator.combineValuesByKey(_))
     assert(shuffledRdd.collect().toSet === Set((1,8), (2, 1)))
 
-    // Turn map-side combine off and pass a wrong mergeCombine function. Should
-    // not see an exception because mergeCombine should not have been called.
+    // Run a wrong mergeCombine function with map-side combine on.
+    // We expect to see an exception thrown.
     val aggregatorWithException = new Aggregator[Int, Int, Int](
-      (v: Int) => v, _+_, ShuffleSuite.mergeCombineException, false)
-    var shuffledRdd1 : RDD[(Int, Int)] =
-      new ShuffledRDD[Int, Int, Int](pairs, Some(aggregatorWithException), new HashPartitioner(2))
-    shuffledRdd1 = shuffledRdd1.mapPartitions(aggregatorWithException.combineValuesByKey(_))
-    assert(shuffledRdd1.collect().toSet === Set((1, 8), (2, 1)))
-
-    // Now run the same mergeCombine function with map-side combine on. We
-    // expect to see an exception thrown.
-    val aggregatorWithException1 = new Aggregator[Int, Int, Int](
       (v: Int) => v, _+_, ShuffleSuite.mergeCombineException)
-    var shuffledRdd2 : RDD[(Int, Int)] =
-      new ShuffledRDD[Int, Int, Int](pairs, Some(aggregatorWithException1), new HashPartitioner(2))
-    shuffledRdd2 = shuffledRdd2.mapPartitions(aggregatorWithException1.combineCombinersByKey(_))
+    var shuffledRdd2 : RDD[(Int, Int)] =
+      new ShuffledRDD[Int, Int, Int](pairs, Some(aggregatorWithException), new HashPartitioner(2))
+    shuffledRdd2 = shuffledRdd2.mapPartitions(aggregatorWithException.combineCombinersByKey(_))
     evaluating { shuffledRdd2.collect() } should produce [SparkException]
   }
 }
...