From efd0036ec88bdc385f5a9ea568d2e2bbfcda2912 Mon Sep 17 00:00:00 2001
From: GuoChenzhao <chenzhao.guo@intel.com>
Date: Fri, 24 Nov 2017 15:09:43 +0100
Subject: [PATCH] [SPARK-22537][CORE] Aggregation of map output statistics on
 driver faces single point bottleneck

## What changes were proposed in this pull request?

In adaptive execution, the map output statistics of all mappers are aggregated after the previous stage completes successfully. The driver performs this aggregation, and it becomes slow when `number of mappers * shuffle partitions` is large, because only a single thread does the computation. This PR uses multiple threads to remove this single-point bottleneck.
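
The sketch below is a simplified, self-contained illustration of the approach (not the exact code in the diff): reduce partition ids are split into disjoint buckets, and each bucket is summed by its own task on a fixed thread pool. Names such as `ParallelAggregationSketch` and the `Array[Long]`-per-mapper stand-in for `MapStatus.getSizeForBlock` are illustrative, and the bucketing here uses plain `grouped` rather than the patch's `equallyDivide`:

```scala
import java.util.concurrent.Executors
import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

object ParallelAggregationSketch {
  // One Array[Long] of per-reduce-partition sizes per mapper, standing in for MapStatus.
  def aggregate(statuses: Array[Array[Long]], numPartitions: Int, parallelism: Int): Array[Long] = {
    val totalSizes = new Array[Long](numPartitions)
    val pool = Executors.newFixedThreadPool(parallelism)
    implicit val ec: ExecutionContext = ExecutionContext.fromExecutor(pool)
    try {
      // Each task owns a disjoint slice of reduce ids, so the writes never race.
      val tasks = (0 until numPartitions)
        .grouped(math.max(1, numPartitions / parallelism))
        .toSeq
        .map { reduceIds =>
          Future {
            for (s <- statuses; i <- reduceIds) {
              totalSizes(i) += s(i)
            }
          }
        }
      Await.result(Future.sequence(tasks), Duration.Inf)
      totalSizes
    } finally {
      pool.shutdown()
    }
  }
}
```

The actual change in `MapOutputTracker.scala` below additionally falls back to the existing single-threaded loop when the computed parallelism is 1, and uses an `equallyDivide` helper so that bucket sizes differ by at most one.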

## How was this patch tested?

Test cases are added in `MapOutputTrackerSuite.scala`.
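
For example (mirroring one of the new test cases), dividing 17 reduce ids into 5 buckets yields bucket sizes 4, 4, 3, 3, 3:

```scala
// `tracker` stands for a MapOutputTrackerMaster, as built by newTrackerMaster() in the suite.
assert(tracker.equallyDivide(17, 5).map(_.size) == Seq(4, 4, 3, 3, 3))
```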

Author: GuoChenzhao <chenzhao.guo@intel.com>
Author: gczsjdy <gczsjdy1994@gmail.com>

Closes #19763 from gczsjdy/single_point_mapstatistics.
---
 .../org/apache/spark/MapOutputTracker.scala   | 60 ++++++++++++++++++-
 .../spark/internal/config/package.scala       | 11 ++++
 .../apache/spark/MapOutputTrackerSuite.scala  | 23 +++++++
 3 files changed, 91 insertions(+), 3 deletions(-)
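
For context, the effective parallelism introduced in `MapOutputTracker.scala` below is capped by the number of available cores and derived from the new internal threshold config (default 10000000). A minimal worked example; the 10000 mappers / 3000 shuffle partitions figures are illustrative, not from a benchmark:

```scala
// Illustrative numbers only: 10000 mappers, 3000 shuffle partitions, default threshold.
val parallelism = math.min(
  Runtime.getRuntime.availableProcessors(),
  10000L * 3000 / 10000000 + 1).toInt   // = min(availableProcessors, 4)
```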

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 7f760a59bd..195fd4f818 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -23,11 +23,14 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.MetadataFetchFailedException
@@ -472,15 +475,66 @@ private[spark] class MapOutputTrackerMaster(
     shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
   }
 
+  /**
+   * Group a Range into fixed-size sub-Ranges. This avoids traversing every element of the
+   * Range, which IterableLike's grouped function would do.
+   */
+  def rangeGrouped(range: Range, size: Int): Seq[Range] = {
+    val start = range.start
+    val step = range.step
+    val end = range.end
+    for (i <- start.until(end, size * step)) yield {
+      i.until(i + size * step, step)
+    }
+  }
+
+  /**
+   * To equally divide n elements into m buckets, each bucket gets n/m elements; the remaining
+   * n%m elements are distributed by adding one more element to each of the first n%m buckets.
+   */
+  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
+    val elementsPerBucket = numElements / numBuckets
+    val remaining = numElements % numBuckets
+    val splitPoint = (elementsPerBucket + 1) * remaining
+    if (elementsPerBucket == 0) {
+      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
+    } else {
+      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
+        rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
+    }
+  }
+
   /**
    * Return statistics about all of the outputs for a given shuffle.
    */
   def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
     shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
       val totalSizes = new Array[Long](dep.partitioner.numPartitions)
-      for (s <- statuses) {
-        for (i <- 0 until totalSizes.length) {
-          totalSizes(i) += s.getSizeForBlock(i)
+      val parallelAggThreshold = conf.get(
+        SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
+      val parallelism = math.min(
+        Runtime.getRuntime.availableProcessors(),
+        statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt
+      if (parallelism <= 1) {
+        for (s <- statuses) {
+          for (i <- 0 until totalSizes.length) {
+            totalSizes(i) += s.getSizeForBlock(i)
+          }
+        }
+      } else {
+        val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate")
+        try {
+          implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
+          val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map {
+            reduceIds => Future {
+              for (s <- statuses; i <- reduceIds) {
+                totalSizes(i) += s.getSizeForBlock(i)
+              }
+            }
+          }
+          ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf)
+        } finally {
+          threadPool.shutdown()
         }
       }
       new MapOutputStatistics(dep.shuffleId, totalSizes)
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 7a9072736b..8fa25c0281 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -499,4 +499,15 @@ package object config {
         "array in the sorter.")
       .intConf
       .createWithDefault(Integer.MAX_VALUE)
+
+  private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD =
+    ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold")
+      .internal()
+      .doc("Multi-thread is used when the number of mappers * shuffle partitions is greater than " +
+        "or equal to this threshold. Note that the actual parallelism is calculated by number of " +
+        "mappers * shuffle partitions / this threshold + 1, so this threshold should be positive.")
+      .intConf
+      .checkValue(v => v > 0, "The threshold should be positive.")
+      .createWithDefault(10000000)
+
 }
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index ebd826b0ba..50b8ea754d 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -275,4 +275,27 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     }
   }
 
+  test("equally divide map statistics tasks") {
+    val func = newTrackerMaster().equallyDivide _
+    val cases = Seq((0, 5), (4, 5), (15, 5), (16, 5), (17, 5), (18, 5), (19, 5), (20, 5))
+    val expects = Seq(
+      Seq(0, 0, 0, 0, 0),
+      Seq(1, 1, 1, 1, 0),
+      Seq(3, 3, 3, 3, 3),
+      Seq(4, 3, 3, 3, 3),
+      Seq(4, 4, 3, 3, 3),
+      Seq(4, 4, 4, 3, 3),
+      Seq(4, 4, 4, 4, 3),
+      Seq(4, 4, 4, 4, 4))
+    cases.zip(expects).foreach { case ((num, divisor), expect) =>
+      val answer = func(num, divisor).toSeq
+      var wholeSplit = (0 until num)
+      answer.zip(expect).foreach { case (split, expectSplitLength) =>
+        val (currentSplit, rest) = wholeSplit.splitAt(expectSplitLength)
+        assert(currentSplit.toSet == split.toSet)
+        wholeSplit = rest
+      }
+    }
+  }
+
 }
-- 
GitLab