diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 7f760a59bda2fbb73366cfd973b7e4d9a52880e3..195fd4f818b36e24d86ddef15e25845a076d3f4b 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -23,11 +23,14 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.MetadataFetchFailedException
@@ -472,15 +475,66 @@ private[spark] class MapOutputTrackerMaster(
     shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
   }
 
+  /**
+   * Group a Range into sub-Ranges of the given size. Unlike IterableLike's grouped function,
+   * this does not traverse every element of the Range.
+   */
+  def rangeGrouped(range: Range, size: Int): Seq[Range] = {
+    val start = range.start
+    val step = range.step
+    val end = range.end
+    for (i <- start.until(end, size * step)) yield {
+      i.until(i + size * step, step)
+    }
+  }
+
+  /**
+   * Divide n elements into m buckets as evenly as possible: each bucket gets n/m elements,
+   * and each of the first n%m buckets gets one extra element.
+   */
+  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
+    val elementsPerBucket = numElements / numBuckets
+    val remaining = numElements % numBuckets
+    val splitPoint = (elementsPerBucket + 1) * remaining
+    if (elementsPerBucket == 0) {
+      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
+    } else {
+      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
+        rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
+    }
+  }
+
   /**
    * Return statistics about all of the outputs for a given shuffle.
    */
   def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
     shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
       val totalSizes = new Array[Long](dep.partitioner.numPartitions)
-      for (s <- statuses) {
-        for (i <- 0 until totalSizes.length) {
-          totalSizes(i) += s.getSizeForBlock(i)
+      val parallelAggThreshold = conf.get(
+        SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
+      val parallelism = math.min(
+        Runtime.getRuntime.availableProcessors(),
+        statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt
+      if (parallelism <= 1) {
+        for (s <- statuses) {
+          for (i <- 0 until totalSizes.length) {
+            totalSizes(i) += s.getSizeForBlock(i)
+          }
+        }
+      } else {
+        val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate")
+        try {
+          implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
+          val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map {
+            reduceIds => Future {
+              for (s <- statuses; i <- reduceIds) {
+                totalSizes(i) += s.getSizeForBlock(i)
+              }
+            }
+          }
+          ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf)
+        } finally {
+          threadPool.shutdown()
         }
       }
       new MapOutputStatistics(dep.shuffleId, totalSizes)
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 7a9072736b9aa3433abda881180fbd73a4f68eb9..8fa25c0281493bcbfa8203a884f5af742860d611 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -499,4 +499,15 @@ package object config {
       "array in the sorter.")
     .intConf
     .createWithDefault(Integer.MAX_VALUE)
+
+  private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD =
+    ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold")
+      .internal()
+      .doc("Multiple threads are used when the number of mappers * shuffle partitions is " +
+        "greater than or equal to this threshold. Note that the actual parallelism is " +
+        "calculated as number of mappers * shuffle partitions / this threshold + 1, so it " +
+        "must be positive.")
+      .intConf
+      .checkValue(v => v > 0, "The threshold should be positive.")
+      .createWithDefault(10000000)
+
 }
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index ebd826b0ba2f644c16dcdb63a312ee10e9a67326..50b8ea754d8d90747b255d4407cb981b2c214dfa 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -275,4 +275,27 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     }
   }
 
+  test("equally divide map statistics tasks") {
+    val func = newTrackerMaster().equallyDivide _
+    val cases = Seq((0, 5), (4, 5), (15, 5), (16, 5), (17, 5), (18, 5), (19, 5), (20, 5))
+    val expects = Seq(
+      Seq(0, 0, 0, 0, 0),
+      Seq(1, 1, 1, 1, 0),
+      Seq(3, 3, 3, 3, 3),
+      Seq(4, 3, 3, 3, 3),
+      Seq(4, 4, 3, 3, 3),
+      Seq(4, 4, 4, 3, 3),
+      Seq(4, 4, 4, 4, 3),
+      Seq(4, 4, 4, 4, 4))
+    cases.zip(expects).foreach { case ((num, divisor), expect) =>
+      val answer = func(num, divisor).toSeq
+      var wholeSplit = (0 until num)
+      answer.zip(expect).foreach { case (split, expectSplitLength) =>
+        val (currentSplit, rest) = wholeSplit.splitAt(expectSplitLength)
+        assert(currentSplit.toSet == split.toSet)
+        wholeSplit = rest
+      }
+    }
+  }
+
 }
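
Note (not part of the patch): the snippet below is a standalone sketch that mirrors the rangeGrouped/equallyDivide helpers and the parallelism formula added above, so the bucket sizes and the computed thread count can be checked outside of a Spark build. The object name and the example mapper/reducer counts are hypothetical, chosen only for illustration.

// Standalone illustration only; mirrors the logic added to MapOutputTrackerMaster above.
object MapStatsAggregationSketch {

  // Group a Range into sub-Ranges of `size` elements without materializing each element.
  def rangeGrouped(range: Range, size: Int): Seq[Range] = {
    val step = range.step
    for (i <- range.start.until(range.end, size * step)) yield {
      i.until(i + size * step, step)
    }
  }

  // Give each of numBuckets buckets numElements / numBuckets indices; the first
  // numElements % numBuckets buckets each receive one extra index.
  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
    val perBucket = numElements / numBuckets
    val remaining = numElements % numBuckets
    val splitPoint = (perBucket + 1) * remaining
    if (perBucket == 0) {
      rangeGrouped(0.until(splitPoint), perBucket + 1)
    } else {
      rangeGrouped(0.until(splitPoint), perBucket + 1) ++
        rangeGrouped(splitPoint.until(numElements), perBucket)
    }
  }

  def main(args: Array[String]): Unit = {
    // 17 reduce partitions split across 5 threads -> bucket sizes 4, 4, 3, 3, 3,
    // matching the (17, 5) case in the test above.
    println(equallyDivide(17, 5).map(_.size))

    // Parallelism formula from getStatistics with the default threshold of 10,000,000:
    // 5,000 mappers * 4,000 shuffle partitions = 20,000,000 size entries,
    // so 20,000,000 / 10,000,000 + 1 = 3 threads, capped at the number of available cores.
    val numMappers = 5000
    val numPartitions = 4000
    val threshold = 10000000L
    val parallelism = math.min(
      Runtime.getRuntime.availableProcessors(),
      numMappers.toLong * numPartitions / threshold + 1).toInt
    println(parallelism)
  }
}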