diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 7f760a59bda2fbb73366cfd973b7e4d9a52880e3..195fd4f818b36e24d86ddef15e25845a076d3f4b 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -23,11 +23,14 @@ import java.util.zip.{GZIPInputStream, GZIPOutputStream}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
 import scala.reflect.ClassTag
 import scala.util.control.NonFatal
 
 import org.apache.spark.broadcast.{Broadcast, BroadcastManager}
 import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
 import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.shuffle.MetadataFetchFailedException
@@ -472,15 +475,66 @@ private[spark] class MapOutputTrackerMaster(
     shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
   }
 
+  /**
+   * Grouped function of Range, this is to avoid traverse of all elements of Range using
+   * IterableLike's grouped function.
+   */
+  def rangeGrouped(range: Range, size: Int): Seq[Range] = {
+    val start = range.start
+    val step = range.step
+    val end = range.end
+    for (i <- start.until(end, size * step)) yield {
+      i.until(i + size * step, step)
+    }
+  }
+
+  /**
+   * To equally divide n elements into m buckets, basically each bucket should have n/m elements,
+   * for the remaining n%m elements, add one more element to the first n%m buckets each.
+   */
+  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
+    val elementsPerBucket = numElements / numBuckets
+    val remaining = numElements % numBuckets
+    val splitPoint = (elementsPerBucket + 1) * remaining
+    if (elementsPerBucket == 0) {
+      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
+    } else {
+      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
+        rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
+    }
+  }
+
   /**
    * Return statistics about all of the outputs for a given shuffle.
    */
   def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
     shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
       val totalSizes = new Array[Long](dep.partitioner.numPartitions)
-      for (s <- statuses) {
-        for (i <- 0 until totalSizes.length) {
-          totalSizes(i) += s.getSizeForBlock(i)
+      val parallelAggThreshold = conf.get(
+        SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
+      val parallelism = math.min(
+        Runtime.getRuntime.availableProcessors(),
+        statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1).toInt
+      if (parallelism <= 1) {
+        for (s <- statuses) {
+          for (i <- 0 until totalSizes.length) {
+            totalSizes(i) += s.getSizeForBlock(i)
+          }
+        }
+      } else {
+        val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate")
+        try {
+          implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
+          val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map {
+            reduceIds => Future {
+              for (s <- statuses; i <- reduceIds) {
+                totalSizes(i) += s.getSizeForBlock(i)
+              }
+            }
+          }
+          ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf)
+        } finally {
+          threadPool.shutdown()
         }
       }
       new MapOutputStatistics(dep.shuffleId, totalSizes)
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 7a9072736b9aa3433abda881180fbd73a4f68eb9..8fa25c0281493bcbfa8203a884f5af742860d611 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -499,4 +499,15 @@ package object config {
         "array in the sorter.")
       .intConf
       .createWithDefault(Integer.MAX_VALUE)
+
+  private[spark] val SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD =
+    ConfigBuilder("spark.shuffle.mapOutput.parallelAggregationThreshold")
+      .internal()
+      .doc("Multi-thread is used when the number of mappers * shuffle partitions is greater than " +
+        "or equal to this threshold. Note that the actual parallelism is calculated by number of " +
+        "mappers * shuffle partitions / this threshold + 1, so this threshold should be positive.")
+      .intConf
+      .checkValue(v => v > 0, "The threshold should be positive.")
+      .createWithDefault(10000000)
+
 }
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index ebd826b0ba2f644c16dcdb63a312ee10e9a67326..50b8ea754d8d90747b255d4407cb981b2c214dfa 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -275,4 +275,27 @@ class MapOutputTrackerSuite extends SparkFunSuite {
     }
   }
 
+  test("equally divide map statistics tasks") {
+    val func = newTrackerMaster().equallyDivide _
+    val cases = Seq((0, 5), (4, 5), (15, 5), (16, 5), (17, 5), (18, 5), (19, 5), (20, 5))
+    val expects = Seq(
+      Seq(0, 0, 0, 0, 0),
+      Seq(1, 1, 1, 1, 0),
+      Seq(3, 3, 3, 3, 3),
+      Seq(4, 3, 3, 3, 3),
+      Seq(4, 4, 3, 3, 3),
+      Seq(4, 4, 4, 3, 3),
+      Seq(4, 4, 4, 4, 3),
+      Seq(4, 4, 4, 4, 4))
+    cases.zip(expects).foreach { case ((num, divisor), expect) =>
+      val answer = func(num, divisor).toSeq
+      var wholeSplit = (0 until num)
+      answer.zip(expect).foreach { case (split, expectSplitLength) =>
+        val (currentSplit, rest) = wholeSplit.splitAt(expectSplitLength)
+        assert(currentSplit.toSet == split.toSet)
+        wholeSplit = rest
+      }
+    }
+  }
+
 }