From 4c51098f320f164eb66f92ff0f26b0b595a58f38 Mon Sep 17 00:00:00 2001
From: Sandy Ryza <sandy@cloudera.com>
Date: Thu, 7 Aug 2014 18:09:03 -0700
Subject: [PATCH] SPARK-2565. Update ShuffleReadMetrics as blocks are fetched

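Shuffle read metrics for a task were previously only recorded once a shuffle
fetch completed (in BlockStoreShuffleFetcher's completion callback), so metrics
reported for in-progress tasks showed no shuffle read activity. With this change,
each shuffle dependency gets its own ShuffleReadMetrics, created through
TaskMetrics.createShuffleReadMetricsForDependency(), which the block fetcher
iterator updates as blocks are fetched; the per-dependency objects are then merged
into the task-level metrics with updateShuffleReadMetrics().

A rough sketch of the resulting flow, using the names introduced in the diff below
(not an exact excerpt from the changed files):

    // In a shuffle reader, before fetching blocks for one dependency:
    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
    // BlockFetcherIterator increments readMetrics.remoteBytesRead,
    // remoteBlocksFetched, localBlocksFetched and fetchWaitTime as blocks arrive.
    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, readMetrics)
    // Later (e.g. when reporting metrics for a still-running task), merge all
    // per-dependency metrics into taskMetrics.shuffleReadMetrics:
    context.taskMetrics.updateShuffleReadMetrics()
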
Author: Sandy Ryza <sandy@cloudera.com>

Closes #1507 from sryza/sandy-spark-2565 and squashes the following commits:

74dad41 [Sandy Ryza] SPARK-2565. Update ShuffleReadMetrics as blocks are fetched
---
 .../org/apache/spark/executor/Executor.scala  |  1 +
 .../apache/spark/executor/TaskMetrics.scala   | 55 ++++++++++++++-----
 .../hash/BlockStoreShuffleFetcher.scala       | 13 ++---
 .../shuffle/hash/HashShuffleReader.scala      |  4 +-
 .../spark/storage/BlockFetcherIterator.scala  | 40 +++++---------
 .../apache/spark/storage/BlockManager.scala   | 11 ++--
 .../org/apache/spark/util/JsonProtocol.scala  |  5 +-
 .../storage/BlockFetcherIteratorSuite.scala   | 13 +++--
 .../ui/jobs/JobProgressListenerSuite.scala    |  4 +-
 .../apache/spark/util/JsonProtocolSuite.scala |  2 +-
 10 files changed, 84 insertions(+), 64 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index c2b9c660dd..eac1f2326a 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -374,6 +374,7 @@ private[spark] class Executor(
           for (taskRunner <- runningTasks.values()) {
             if (!taskRunner.attemptedTask.isEmpty) {
               Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
+                metrics.updateShuffleReadMetrics()
                 tasksMetrics += ((taskRunner.taskId, metrics))
               }
             }
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 11a6e10243..99a88c1345 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.executor
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.storage.{BlockId, BlockStatus}
 
@@ -81,12 +83,27 @@ class TaskMetrics extends Serializable {
   var inputMetrics: Option[InputMetrics] = None
 
   /**
-   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here
+   * If this task reads from shuffle output, metrics on fetching shuffle data are collected here,
+   * aggregated across all of the task's shuffle dependencies.
    */
   private var _shuffleReadMetrics: Option[ShuffleReadMetrics] = None
 
   def shuffleReadMetrics = _shuffleReadMetrics
 
+  /**
+   * This should only be used when recreating TaskMetrics, not when updating read metrics in
+   * executors.
+   */
+  private[spark] def setShuffleReadMetrics(shuffleReadMetrics: Option[ShuffleReadMetrics]) {
+    _shuffleReadMetrics = shuffleReadMetrics
+  }
+
+  /**
+   * ShuffleReadMetrics per dependency, for collecting independently while the task is in progress.
+   */
+  @transient private lazy val depsShuffleReadMetrics: ArrayBuffer[ShuffleReadMetrics] =
+    new ArrayBuffer[ShuffleReadMetrics]()
+
   /**
    * If this task writes to shuffle output, metrics on the written shuffle data will be collected
    * here
@@ -98,19 +115,31 @@ class TaskMetrics extends Serializable {
    */
   var updatedBlocks: Option[Seq[(BlockId, BlockStatus)]] = None
 
-  /** Adds the given ShuffleReadMetrics to any existing shuffle metrics for this task. */
-  def updateShuffleReadMetrics(newMetrics: ShuffleReadMetrics) = synchronized {
-    _shuffleReadMetrics match {
-      case Some(existingMetrics) =>
-        existingMetrics.shuffleFinishTime = math.max(
-          existingMetrics.shuffleFinishTime, newMetrics.shuffleFinishTime)
-        existingMetrics.fetchWaitTime += newMetrics.fetchWaitTime
-        existingMetrics.localBlocksFetched += newMetrics.localBlocksFetched
-        existingMetrics.remoteBlocksFetched += newMetrics.remoteBlocksFetched
-        existingMetrics.remoteBytesRead += newMetrics.remoteBytesRead
-      case None =>
-        _shuffleReadMetrics = Some(newMetrics)
+  /**
+   * A task may have multiple shuffle readers for multiple dependencies. To avoid synchronization
+   * issues from readers in different threads, in-progress tasks use a ShuffleReadMetrics for each
+   * dependency, and merge these metrics before reporting them to the driver. This method returns
+   * a ShuffleReadMetrics for a dependency and registers it for merging later.
+   */
+  private[spark] def createShuffleReadMetricsForDependency(): ShuffleReadMetrics = synchronized {
+    val readMetrics = new ShuffleReadMetrics()
+    depsShuffleReadMetrics += readMetrics
+    readMetrics
+  }
+
+  /**
+   * Aggregates shuffle read metrics for all registered dependencies into shuffleReadMetrics.
+   */
+  private[spark] def updateShuffleReadMetrics() = synchronized {
+    val merged = new ShuffleReadMetrics()
+    for (depMetrics <- depsShuffleReadMetrics) {
+      merged.fetchWaitTime += depMetrics.fetchWaitTime
+      merged.localBlocksFetched += depMetrics.localBlocksFetched
+      merged.remoteBlocksFetched += depMetrics.remoteBlocksFetched
+      merged.remoteBytesRead += depMetrics.remoteBytesRead
+      merged.shuffleFinishTime = math.max(merged.shuffleFinishTime, depMetrics.shuffleFinishTime)
     }
+    _shuffleReadMetrics = Some(merged)
   }
 }
 
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 9978882898..12b475658e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -32,7 +32,8 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       shuffleId: Int,
       reduceId: Int,
       context: TaskContext,
-      serializer: Serializer)
+      serializer: Serializer,
+      shuffleMetrics: ShuffleReadMetrics)
     : Iterator[T] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
@@ -73,17 +74,11 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       }
     }
 
-    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer)
+    val blockFetcherItr = blockManager.getMultiple(blocksByAddress, serializer, shuffleMetrics)
     val itr = blockFetcherItr.flatMap(unpackBlock)
 
     val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      val shuffleMetrics = new ShuffleReadMetrics
-      shuffleMetrics.shuffleFinishTime = System.currentTimeMillis
-      shuffleMetrics.fetchWaitTime = blockFetcherItr.fetchWaitTime
-      shuffleMetrics.remoteBytesRead = blockFetcherItr.remoteBytesRead
-      shuffleMetrics.localBlocksFetched = blockFetcherItr.numLocalBlocks
-      shuffleMetrics.remoteBlocksFetched = blockFetcherItr.numRemoteBlocks
-      context.taskMetrics.updateShuffleReadMetrics(shuffleMetrics)
+      context.taskMetrics.updateShuffleReadMetrics()
     })
 
     new InterruptibleIterator[T](context, completionIter)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 88a5f1e5dd..7bed97a63f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -36,8 +36,10 @@ private[spark] class HashShuffleReader[K, C](
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser,
+      readMetrics)
 
     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
index 938af6f5b9..5f44f5f319 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala
@@ -27,6 +27,7 @@ import scala.util.{Failure, Success}
 import io.netty.buffer.ByteBuf
 
 import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.executor.ShuffleReadMetrics
 import org.apache.spark.network.BufferMessage
 import org.apache.spark.network.ConnectionManagerId
 import org.apache.spark.network.netty.ShuffleCopier
@@ -47,10 +48,6 @@ import org.apache.spark.util.Utils
 private[storage]
 trait BlockFetcherIterator extends Iterator[(BlockId, Option[Iterator[Any]])] with Logging {
   def initialize()
-  def numLocalBlocks: Int
-  def numRemoteBlocks: Int
-  def fetchWaitTime: Long
-  def remoteBytesRead: Long
 }
 
 
@@ -72,14 +69,12 @@ object BlockFetcherIterator {
   class BasicBlockFetcherIterator(
       private val blockManager: BlockManager,
       val blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer)
+      serializer: Serializer,
+      readMetrics: ShuffleReadMetrics)
     extends BlockFetcherIterator {
 
     import blockManager._
 
-    private var _remoteBytesRead = 0L
-    private var _fetchWaitTime = 0L
-
     if (blocksByAddress == null) {
       throw new IllegalArgumentException("BlocksByAddress is null")
     }
@@ -89,13 +84,9 @@ object BlockFetcherIterator {
 
     protected var startTime = System.currentTimeMillis
 
-    // This represents the number of local blocks, also counting zero-sized blocks
-    private var numLocal = 0
     // BlockIds for local blocks that need to be fetched. Excludes zero-sized blocks
     protected val localBlocksToFetch = new ArrayBuffer[BlockId]()
 
-    // This represents the number of remote blocks, also counting zero-sized blocks
-    private var numRemote = 0
     // BlockIds for remote blocks that need to be fetched. Excludes zero-sized blocks
     protected val remoteBlocksToFetch = new HashSet[BlockId]()
 
@@ -132,7 +123,10 @@ object BlockFetcherIterator {
             val networkSize = blockMessage.getData.limit()
             results.put(new FetchResult(blockId, sizeMap(blockId),
               () => dataDeserialize(blockId, blockMessage.getData, serializer)))
-            _remoteBytesRead += networkSize
+            // TODO: NettyBlockFetcherIterator has some race conditions where multiple threads can
+            // be incrementing bytes read at the same time (SPARK-2625).
+            readMetrics.remoteBytesRead += networkSize
+            readMetrics.remoteBlocksFetched += 1
             logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
           }
         }
@@ -155,14 +149,14 @@ object BlockFetcherIterator {
       // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
       // at most maxBytesInFlight in order to limit the amount of data in flight.
       val remoteRequests = new ArrayBuffer[FetchRequest]
+      var totalBlocks = 0
       for ((address, blockInfos) <- blocksByAddress) {
+        totalBlocks += blockInfos.size
         if (address == blockManagerId) {
-          numLocal = blockInfos.size
           // Filter out zero-sized blocks
           localBlocksToFetch ++= blockInfos.filter(_._2 != 0).map(_._1)
           _numBlocksToFetch += localBlocksToFetch.size
         } else {
-          numRemote += blockInfos.size
           val iterator = blockInfos.iterator
           var curRequestSize = 0L
           var curBlocks = new ArrayBuffer[(BlockId, Long)]
@@ -192,7 +186,7 @@ object BlockFetcherIterator {
         }
       }
       logInfo("Getting " + _numBlocksToFetch + " non-empty blocks out of " +
-        (numLocal + numRemote) + " blocks")
+        totalBlocks + " blocks")
       remoteRequests
     }
 
@@ -205,6 +199,7 @@ object BlockFetcherIterator {
           // getLocalFromDisk never return None but throws BlockException
           val iter = getLocalFromDisk(id, serializer).get
           // Pass 0 as size since it's not in flight
+          readMetrics.localBlocksFetched += 1
           results.put(new FetchResult(id, 0, () => iter))
           logDebug("Got local block " + id)
         } catch {
@@ -238,12 +233,6 @@ object BlockFetcherIterator {
       logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
     }
 
-    override def numLocalBlocks: Int = numLocal
-    override def numRemoteBlocks: Int = numRemote
-    override def fetchWaitTime: Long = _fetchWaitTime
-    override def remoteBytesRead: Long = _remoteBytesRead
-
-
     // Implementing the Iterator methods with an iterator that reads fetched blocks off the queue
     // as they arrive.
     @volatile protected var resultsGotten = 0
@@ -255,7 +244,7 @@ object BlockFetcherIterator {
       val startFetchWait = System.currentTimeMillis()
       val result = results.take()
       val stopFetchWait = System.currentTimeMillis()
-      _fetchWaitTime += (stopFetchWait - startFetchWait)
+      readMetrics.fetchWaitTime += (stopFetchWait - startFetchWait)
       if (! result.failed) bytesInFlight -= result.size
       while (!fetchRequests.isEmpty &&
         (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
@@ -269,8 +258,9 @@ object BlockFetcherIterator {
   class NettyBlockFetcherIterator(
       blockManager: BlockManager,
       blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer)
-    extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer) {
+      serializer: Serializer,
+      readMetrics: ShuffleReadMetrics)
+    extends BasicBlockFetcherIterator(blockManager, blocksByAddress, serializer, readMetrics) {
 
     import blockManager._
 
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 8d21b02b74..e8bbd298c6 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
 import sun.nio.ch.DirectBuffer
 
 import org.apache.spark._
-import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics}
+import org.apache.spark.executor._
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.serializer.Serializer
@@ -539,12 +539,15 @@ private[spark] class BlockManager(
    */
   def getMultiple(
       blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-      serializer: Serializer): BlockFetcherIterator = {
+      serializer: Serializer,
+      readMetrics: ShuffleReadMetrics): BlockFetcherIterator = {
     val iter =
       if (conf.getBoolean("spark.shuffle.use.netty", false)) {
-        new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer)
+        new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer,
+          readMetrics)
       } else {
-        new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer)
+        new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer,
+          readMetrics)
       }
     iter.initialize()
     iter
diff --git a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
index b112b35936..6f8eb1ee12 100644
--- a/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
+++ b/core/src/main/scala/org/apache/spark/util/JsonProtocol.scala
@@ -560,9 +560,8 @@ private[spark] object JsonProtocol {
     metrics.resultSerializationTime = (json \ "Result Serialization Time").extract[Long]
     metrics.memoryBytesSpilled = (json \ "Memory Bytes Spilled").extract[Long]
     metrics.diskBytesSpilled = (json \ "Disk Bytes Spilled").extract[Long]
-    Utils.jsonOption(json \ "Shuffle Read Metrics").map { shuffleReadMetrics =>
-      metrics.updateShuffleReadMetrics(shuffleReadMetricsFromJson(shuffleReadMetrics))
-    }
+    metrics.setShuffleReadMetrics(
+      Utils.jsonOption(json \ "Shuffle Read Metrics").map(shuffleReadMetricsFromJson))
     metrics.shuffleWriteMetrics =
       Utils.jsonOption(json \ "Shuffle Write Metrics").map(shuffleWriteMetricsFromJson)
     metrics.inputMetrics =
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
index 1538995a6b..bcbfe8baf3 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockFetcherIteratorSuite.scala
@@ -33,6 +33,7 @@ import org.mockito.invocation.InvocationOnMock
 
 import org.apache.spark.storage.BlockFetcherIterator._
 import org.apache.spark.network.{ConnectionManager, Message}
+import org.apache.spark.executor.ShuffleReadMetrics
 
 class BlockFetcherIteratorSuite extends FunSuite with Matchers {
 
@@ -70,8 +71,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
       (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
     )
 
-    val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+    val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
+      new ShuffleReadMetrics())
 
     iterator.initialize()
 
@@ -121,8 +122,8 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
       (bmId, blIds.map(blId => (blId, 1.asInstanceOf[Long])).toSeq)
     )
 
-    val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+    val iterator = new BasicBlockFetcherIterator(blockManager, blocksByAddress, null,
+      new ShuffleReadMetrics())
 
     iterator.initialize()
 
@@ -165,7 +166,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
     )
 
     val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+      blocksByAddress, null, new ShuffleReadMetrics())
 
     iterator.initialize()
     iterator.foreach{
@@ -219,7 +220,7 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
     )
 
     val iterator = new BasicBlockFetcherIterator(blockManager,
-      blocksByAddress, null)
+      blocksByAddress, null, new ShuffleReadMetrics())
     iterator.initialize()
     iterator.foreach{
       case (_, r) => {
diff --git a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
index cb82525152..f5ba31c309 100644
--- a/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/jobs/JobProgressListenerSuite.scala
@@ -65,7 +65,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
 
     // finish this task, should get updated shuffleRead
     shuffleReadMetrics.remoteBytesRead = 1000
-    taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
+    taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
     var taskInfo = new TaskInfo(1234L, 0, 1, 0L, "exe-1", "host1", TaskLocality.NODE_LOCAL, false)
     taskInfo.finishTime = 1
     var task = new ShuffleMapTask(0)
@@ -142,7 +142,7 @@ class JobProgressListenerSuite extends FunSuite with LocalSparkContext with Matc
       val taskMetrics = new TaskMetrics()
       val shuffleReadMetrics = new ShuffleReadMetrics()
       val shuffleWriteMetrics = new ShuffleWriteMetrics()
-      taskMetrics.updateShuffleReadMetrics(shuffleReadMetrics)
+      taskMetrics.setShuffleReadMetrics(Some(shuffleReadMetrics))
       taskMetrics.shuffleWriteMetrics = Some(shuffleWriteMetrics)
       shuffleReadMetrics.remoteBytesRead = base + 1
       shuffleReadMetrics.remoteBlocksFetched = base + 2
diff --git a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
index 2002a817d9..97ffb07662 100644
--- a/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/JsonProtocolSuite.scala
@@ -539,7 +539,7 @@ class JsonProtocolSuite extends FunSuite {
       sr.localBlocksFetched = e
       sr.fetchWaitTime = a + d
       sr.remoteBlocksFetched = f
-      t.updateShuffleReadMetrics(sr)
+      t.setShuffleReadMetrics(Some(sr))
     }
     sw.shuffleBytesWritten = a + b + c
     sw.shuffleWriteTime = b + c + d
-- 
GitLab