From 8fef5b9c5f595b621b5b0218d659f6a5392b3250 Mon Sep 17 00:00:00 2001
From: Imran Rashid <imran@quantifind.com>
Date: Sun, 3 Mar 2013 16:34:04 -0800
Subject: [PATCH] refactoring of TaskMetrics

---
 .../spark/BlockStoreShuffleFetcher.scala      | 18 ++--
 .../scala/spark/executor/TaskMetrics.scala    | 95 +++++++++++++------
 .../main/scala/spark/rdd/SubtractedRDD.scala  |  2 +-
 .../spark/scheduler/ShuffleMapTask.scala      |  5 +-
 .../scala/spark/scheduler/SparkListener.scala | 44 +++++----
 .../scheduler/cluster/TaskSetManager.scala    |  5 +-
 6 files changed, 110 insertions(+), 59 deletions(-)

diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index 9f5ebe3fd1..45fc8cbf7e 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,6 +1,6 @@
 package spark
 
-import executor.TaskMetrics
+import executor.{ShuffleReadMetrics, TaskMetrics}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
@@ -52,13 +52,15 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
     val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker
     itr.setDelegate(blockFetcherItr)
     CleanupIterator[(K,V), Iterator[(K,V)]](itr, {
-      metrics.shuffleReadMillis = Some(itr.getNetMillis)
-      metrics.remoteFetchTime = Some(itr.remoteFetchTime)
-      metrics.remoteFetchWaitTime = Some(itr.remoteFetchWaitTime)
-      metrics.remoteBytesRead = Some(itr.remoteBytesRead)
-      metrics.totalBlocksFetched = Some(itr.totalBlocks)
-      metrics.localBlocksFetched = Some(itr.numLocalBlocks)
-      metrics.remoteBlocksFetched = Some(itr.numRemoteBlocks)
+      val shuffleMetrics = new ShuffleReadMetrics
+      shuffleMetrics.shuffleReadMillis = itr.getNetMillis
+      shuffleMetrics.remoteFetchTime = itr.remoteFetchTime
+      shuffleMetrics.remoteFetchWaitTime = itr.remoteFetchWaitTime
+      shuffleMetrics.remoteBytesRead = itr.remoteBytesRead
+      shuffleMetrics.totalBlocksFetched = itr.totalBlocks
+      shuffleMetrics.localBlocksFetched = itr.numLocalBlocks
+      shuffleMetrics.remoteBlocksFetched = itr.numRemoteBlocks
+      metrics.shuffleReadMetrics = Some(shuffleMetrics)
     })
   }
 }
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
index c66abdf2ca..800305cd6c 100644
--- a/core/src/main/scala/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -1,34 +1,71 @@
 package spark.executor
 
-/**
- *
- * @param totalBlocksFetched total number of blocks fetched in a shuffle (remote or local)
- * @param remoteBlocksFetched number of remote blocks fetched in a shuffle
- * @param localBlocksFetched local blocks fetched in a shuffle
- * @param shuffleReadMillis total time to read shuffle data
- * @param remoteFetchWaitTime total time that is spent blocked waiting for shuffle to fetch remote data
- * @param remoteFetchTime the total amount of time for all the shuffle fetches. This adds up time from overlapping
- *   shuffles, so can be longer than task time
- * @param remoteBytesRead total number of remote bytes read from a shuffle
- * @param shuffleBytesWritten number of bytes written for a shuffle
- * @param executorDeserializeTime time taken on the executor to deserialize this task
- * @param executorRunTime time the executor spends actually running the task (including fetching shuffle data)
- * @param resultSize the number of bytes this task transmitted back to the driver as the TaskResult
- */
-case class TaskMetrics(
-  var totalBlocksFetched : Option[Int],
-  var remoteBlocksFetched: Option[Int],
-  var localBlocksFetched: Option[Int],
-  var shuffleReadMillis: Option[Long],
-  var remoteFetchWaitTime: Option[Long],
-  var remoteFetchTime: Option[Long],
-  var remoteBytesRead: Option[Long],
-  var shuffleBytesWritten: Option[Long],
-  var executorDeserializeTime: Int,
-  var executorRunTime:Int,
-  var resultSize: Long
-)
+class TaskMetrics{
+  /**
+   * Time taken on the executor to deserialize this task
+   */
+  var executorDeserializeTime: Int = _
+  /**
+   * Time the executor spends actually running the task (including fetching shuffle data)
+   */
+  var executorRunTime:Int = _
+  /**
+   * The number of bytes this task transmitted back to the driver as the TaskResult
+   */
+  var resultSize: Long = _
+
+  /**
+   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here
+   */
+  var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
+
+  /**
+   * If this task writes to shuffle output, metrics on the written shuffle data will be collected here
+   */
+  var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
+
+}
 
 object TaskMetrics {
-  private[spark] def empty() : TaskMetrics = TaskMetrics(None,None,None,None,None,None,None,None, -1, -1, -1)
+  private[spark] def empty() : TaskMetrics = new TaskMetrics
+}
+
+
+class ShuffleReadMetrics {
+  /**
+   * Total number of blocks fetched in a shuffle (remote or local)
+   */
+  var totalBlocksFetched : Int = _
+  /**
+   * Number of remote blocks fetched in a shuffle
+   */
+  var remoteBlocksFetched: Int = _
+  /**
+   * Local blocks fetched in a shuffle
+   */
+  var localBlocksFetched: Int = _
+  /**
+   * Total time to read shuffle data
+   */
+  var shuffleReadMillis: Long = _
+  /**
+   * Total time that is spent blocked waiting for shuffle to fetch remote data
+   */
+  var remoteFetchWaitTime: Long = _
+  /**
+   * The total amount of time for all the shuffle fetches. This adds up time from overlapping
+   * shuffles, so can be longer than task time
+   */
+  var remoteFetchTime: Long = _
+  /**
+   * Total number of remote bytes read from a shuffle
+   */
+  var remoteBytesRead: Long = _
+}
+
+class ShuffleWriteMetrics {
+  /**
+   * Number of bytes written for a shuffle
+   */
+  var shuffleBytesWritten: Long = _
 }
\ No newline at end of file
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index daf9cc993c..43ec90cac5 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -89,7 +89,7 @@ private[spark] class SubtractedRDD[T: ClassManifest](
         for (k <- rdd.iterator(itsSplit, context))
           op(k.asInstanceOf[T])
       case ShuffleCoGroupSplitDep(shuffleId) =>
-        for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index))
+        for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
           op(k.asInstanceOf[T])
     }
     // the first dep is rdd1; add all keys to the set
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 0b567d1312..36d087a4d0 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -13,6 +13,7 @@ import com.ning.compress.lzf.LZFInputStream
 import com.ning.compress.lzf.LZFOutputStream
 
 import spark._
+import executor.ShuffleWriteMetrics
 import spark.storage._
 import util.{TimeStampedHashMap, MetadataCleaner}
 
@@ -142,7 +143,9 @@ private[spark] class ShuffleMapTask(
         totalBytes += size
         compressedSizes(i) = MapOutputTracker.compressSize(size)
       }
-      metrics.get.shuffleBytesWritten = Some(totalBytes)
+      val shuffleMetrics = new ShuffleWriteMetrics
+      shuffleMetrics.shuffleBytesWritten = totalBytes
+      metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
 
       return new MapStatus(blockManager.blockManagerId, compressedSizes)
     } finally {
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala
index 0915b3eb5b..21185227ab 100644
--- a/core/src/main/scala/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/spark/scheduler/SparkListener.scala
@@ -26,11 +26,13 @@ class StatsReportListener extends SparkListener with Logging {
     implicit val sc = stageCompleted
     this.logInfo("Finished stage: " + stageCompleted.stageInfo)
     showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
 
-    showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleBytesWritten)
-    //fetch & some io info
-    showMillisDistribution("fetch wait time:",(_, metric) => metric.remoteFetchWaitTime)
-    showBytesDistribution("remote bytes read:", (_, metric) => metric.remoteBytesRead)
+    //shuffle write
+    showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleWriteMetrics.map{_.shuffleBytesWritten})
+
+    //fetch & io
+    showMillisDistribution("fetch wait time:",(_, metric) => metric.shuffleReadMetrics.map{_.remoteFetchWaitTime})
+    showBytesDistribution("remote bytes read:", (_, metric) => metric.shuffleReadMetrics.map{_.remoteBytesRead})
     showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize))
 
     //runtime breakdown
@@ -61,6 +63,18 @@ object StatsReportListener extends Logging {
     extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble})
   }
+  def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
+    val stats = d.statCounter
+    logInfo(heading + stats)
+    val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+    logInfo(percentilesHeader)
+    logInfo("\t" + quantiles.mkString("\t"))
+  }
+
+  def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String) {
+    dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
+  }
+
   def showDistribution(heading: String, dOpt: Option[Distribution], format:String) {
     def f(d:Double) = format.format(d)
     showDistribution(heading, dOpt, f _)
   }
@@ -77,11 +91,15 @@ object StatsReportListener extends Logging {
   }
 
   def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
-    showDistribution(heading, dOpt, d => Utils.memoryBytesToString(d.toLong))
+    dOpt.foreach{dist => showBytesDistribution(heading, dist)}
+  }
+
+  def showBytesDistribution(heading: String, dist: Distribution) {
+    showDistribution(heading, dist, (d => Utils.memoryBytesToString(d.toLong)): Double => String)
   }
 
   def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
-    showDistribution(heading, dOpt, d => StatsReportListener.millisToString(d.toLong))
+    showDistribution(heading, dOpt, (d => StatsReportListener.millisToString(d.toLong)): Double => String)
   }
 
   def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long])
   {
     showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
   }
 
-  def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String) {
-    dOpt.foreach { d =>
-      val stats = d.statCounter
-      logInfo(heading + stats)
-      val quantiles = d.getQuantiles(probabilities).map{formatNumber}
-      logInfo(percentilesHeader)
-      logInfo("\t" + quantiles.mkString("\t"))
-    }
-  }
 
 
   val seconds = 1000L
@@ -128,8 +137,9 @@ case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], othe
 object RuntimePercentage {
   def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
     val denom = totalTime.toDouble
-    val fetch = metrics.remoteFetchWaitTime.map{_ / denom}
-    val exec = (metrics.executorRunTime - metrics.remoteFetchWaitTime.getOrElse(0l)) / denom
+    val fetchTime = metrics.shuffleReadMetrics.map{_.remoteFetchWaitTime}
+    val fetch = fetchTime.map{_ / denom}
+    val exec = (metrics.executorRunTime - fetchTime.getOrElse(0l)) / denom
     val other = 1.0 - (exec + fetch.getOrElse(0d))
     RuntimePercentage(exec, fetch, other)
   }
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 236f81bb9f..c9f2c48804 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -259,9 +259,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
         tid, info.duration, tasksFinished, numTasks))
      // Deserialize task result and pass it to the scheduler
      val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
-      //lame way to get size into final metrics
-      val metricsWithSize = result.metrics.copy(resultSize = serializedData.limit())
-      sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, metricsWithSize)
+      result.metrics.resultSize = serializedData.limit()
+      sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
      // Mark finished and stop if we've finished all the tasks
      finished(index) = true
      if (tasksFinished == numTasks) {
-- 
GitLab
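
For readers skimming the patch, here is a minimal usage sketch (not part of the commit) of how the refactored classes fit together: producers fill in a ShuffleReadMetrics or ShuffleWriteMetrics and attach it to a TaskMetrics, and consumers unwrap one Option per metrics group instead of one Option per field, as StatsReportListener and RuntimePercentage now do. Only the spark.executor classes introduced above are assumed; the object name and all values are made up for illustration.

import spark.executor.{ShuffleReadMetrics, ShuffleWriteMetrics, TaskMetrics}

object TaskMetricsSketch {
  def main(args: Array[String]) {
    val metrics = new TaskMetrics
    metrics.executorRunTime = 1200                       // illustrative value, in milliseconds

    // Producer side, mirroring BlockStoreShuffleFetcher: fill a ShuffleReadMetrics
    // once the fetch iterator has been consumed, then attach it to the task metrics.
    val readMetrics = new ShuffleReadMetrics
    readMetrics.remoteBytesRead = 4L * 1024 * 1024
    readMetrics.remoteFetchWaitTime = 250
    metrics.shuffleReadMetrics = Some(readMetrics)

    // Producer side, mirroring ShuffleMapTask: record the bytes written for the map output.
    val writeMetrics = new ShuffleWriteMetrics
    writeMetrics.shuffleBytesWritten = 2L * 1024 * 1024
    metrics.shuffleWriteMetrics = Some(writeMetrics)

    // Consumer side, as StatsReportListener and RuntimePercentage do after this patch:
    // unwrap the Option for the whole group, then read plain fields inside it.
    val fetchWait: Option[Long] = metrics.shuffleReadMetrics.map{_.remoteFetchWaitTime}
    val bytesWritten: Option[Long] = metrics.shuffleWriteMetrics.map{_.shuffleBytesWritten}
    println("fetch wait: " + fetchWait + ", shuffle bytes written: " + bytesWritten)
  }
}

One effect of grouping the fields this way is that a task that never touches shuffle data carries two Nones rather than eight, and new per-group fields can be added without widening TaskMetrics itself.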