diff --git a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
index 86432d0127e3e9a75d24177ebc8558d453e0ecc8..53b0389c3a67373394e8a30a9a12a2401718c45a 100644
--- a/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/spark/BlockStoreShuffleFetcher.scala
@@ -1,20 +1,22 @@
 package spark
 
+import executor.{ShuffleReadMetrics, TaskMetrics}
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 
-import spark.storage.BlockManagerId
+import spark.storage.{DelegateBlockFetchTracker, BlockManagerId}
+import util.{CompletionIterator, TimedIterator}
 
 private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Logging {
-  override def fetch[K, V](shuffleId: Int, reduceId: Int) = {
+  override def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics) = {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
     val blockManager = SparkEnv.get.blockManager
-    
+
     val startTime = System.currentTimeMillis
     val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))
-    
+
     val splitsByAddress = new HashMap[BlockManagerId, ArrayBuffer[(Int, Long)]]
     for (((address, size), index) <- statuses.zipWithIndex) {
       splitsByAddress.getOrElseUpdate(address, ArrayBuffer()) += ((index, size))
@@ -45,6 +47,20 @@ private[spark] class BlockStoreShuffleFetcher extends ShuffleFetcher with Loggin
         }
       }
     }
-    blockManager.getMultiple(blocksByAddress).flatMap(unpackBlock)
+
+    val blockFetcherItr = blockManager.getMultiple(blocksByAddress)
+    val itr = new TimedIterator(blockFetcherItr.flatMap(unpackBlock)) with DelegateBlockFetchTracker
+    itr.setDelegate(blockFetcherItr)
+    CompletionIterator[(K,V), Iterator[(K,V)]](itr, {
+      val shuffleMetrics = new ShuffleReadMetrics
+      shuffleMetrics.shuffleReadMillis = itr.getNetMillis
+      shuffleMetrics.remoteFetchTime = itr.remoteFetchTime
+      shuffleMetrics.remoteFetchWaitTime = itr.remoteFetchWaitTime
+      shuffleMetrics.remoteBytesRead = itr.remoteBytesRead
+      shuffleMetrics.totalBlocksFetched = itr.totalBlocks
+      shuffleMetrics.localBlocksFetched = itr.numLocalBlocks
+      shuffleMetrics.remoteBlocksFetched = itr.numRemoteBlocks
+      metrics.shuffleReadMetrics = Some(shuffleMetrics)
+    })
   }
 }
diff --git a/core/src/main/scala/spark/ShuffleFetcher.scala b/core/src/main/scala/spark/ShuffleFetcher.scala
index d9a94d4021ee325d57aac710f4b0858883f3e63b..442e9f0269dc48284c8e20411e1f72427f627b96 100644
--- a/core/src/main/scala/spark/ShuffleFetcher.scala
+++ b/core/src/main/scala/spark/ShuffleFetcher.scala
@@ -1,11 +1,13 @@
 package spark
 
+import executor.TaskMetrics
+
 private[spark] abstract class ShuffleFetcher {
   /**
    * Fetch the shuffle outputs for a given ShuffleDependency.
    * @return An iterator over the elements of the fetched shuffle outputs.
    */
-  def fetch[K, V](shuffleId: Int, reduceId: Int) : Iterator[(K, V)]
+  def fetch[K, V](shuffleId: Int, reduceId: Int, metrics: TaskMetrics): Iterator[(K, V)]
 
   /** Stop the fetcher */
   def stop() {}
diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala
index 7503b1a5ea30ec89f85433c18c18b18ed68d23b6..4957a54c1b8af5c199a3f7ffd235c1ff492e0033 100644
--- a/core/src/main/scala/spark/SparkContext.scala
+++ b/core/src/main/scala/spark/SparkContext.scala
@@ -1,19 +1,15 @@
 package spark
 
 import java.io._
-import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicInteger
-import java.net.{URI, URLClassLoader}
-import java.lang.ref.WeakReference
+import java.net.URI
 
 import scala.collection.Map
 import scala.collection.generic.Growable
-import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.mutable.HashMap
 import scala.collection.JavaConversions._
 
-import akka.actor.Actor
-import akka.actor.Actor._
-import org.apache.hadoop.fs.{FileUtil, Path}
+import org.apache.hadoop.fs.Path
 import org.apache.hadoop.conf.Configuration
 import org.apache.hadoop.mapred.InputFormat
 import org.apache.hadoop.mapred.SequenceFileInputFormat
@@ -33,20 +29,19 @@ import org.apache.hadoop.mapred.TextInputFormat
 import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
 import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat}
 import org.apache.hadoop.mapreduce.{Job => NewHadoopJob}
-import org.apache.mesos.{Scheduler, MesosNativeLibrary}
+import org.apache.mesos.MesosNativeLibrary
 
-import spark.broadcast._
 import spark.deploy.LocalSparkCluster
 import spark.partial.ApproximateEvaluator
 import spark.partial.PartialResult
-import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
-import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler}
+import spark.rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD, ParallelCollectionRDD}
+import spark.scheduler._
 import spark.scheduler.local.LocalScheduler
 import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler}
 import spark.scheduler.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
-import storage.BlockManagerUI
-import util.{MetadataCleaner, TimeStampedHashMap}
-import storage.{StorageStatus, StorageUtils, RDDInfo}
+import spark.storage.BlockManagerUI
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
+import spark.storage.{StorageStatus, StorageUtils, RDDInfo}
 
 /**
  * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
@@ -466,6 +461,10 @@ class SparkContext(
     logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key))
   }
 
+  def addSparkListener(listener: SparkListener) {
+    dagScheduler.sparkListeners += listener
+  }
+
   /**
    * Return a map from the slave to the max memory available for caching and the remaining
    * memory available for caching.
@@ -484,6 +483,10 @@ class SparkContext(
     StorageUtils.rddInfoFromStorageStatus(getExecutorStorageStatus, this)
   }
 
+  def getStageInfo: Map[Stage,StageInfo] = {
+    dagScheduler.stageToInfos
+  }
+
   /**
    * Return information about blocks stored in all of the slaves
    */
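
// Usage sketch for the new addSparkListener hook: register the StatsReportListener added
// later in this patch so every completed stage logs runtime and shuffle distributions.
// The object name, master string and toy job below are illustrative, not part of the patch.
import spark.SparkContext
import spark.SparkContext._
import spark.scheduler.StatsReportListener

object ListenerExample {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "listener-example")
    sc.addSparkListener(new StatsReportListener)
    // a small job with a shuffle, so both shuffle read and write metrics show up
    sc.parallelize(1 to 1000, 4).map(i => (i % 10, i)).groupByKey().count()
    sc.stop()
  }
}
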
diff --git a/core/src/main/scala/spark/TaskContext.scala b/core/src/main/scala/spark/TaskContext.scala
index eab85f85a262b146ee9013354f9ba45686fdd2b0..dd0609026ace36a9bd616c0e783ea672ebdcc78d 100644
--- a/core/src/main/scala/spark/TaskContext.scala
+++ b/core/src/main/scala/spark/TaskContext.scala
@@ -1,9 +1,14 @@
 package spark
 
+import executor.TaskMetrics
 import scala.collection.mutable.ArrayBuffer
 
-
-class TaskContext(val stageId: Int, val splitId: Int, val attemptId: Long) extends Serializable {
+class TaskContext(
+  val stageId: Int,
+  val splitId: Int,
+  val attemptId: Long,
+  val taskMetrics: TaskMetrics = TaskMetrics.empty()
+) extends Serializable {
 
   @transient val onCompleteCallbacks = new ArrayBuffer[() => Unit]
 
diff --git a/core/src/main/scala/spark/executor/Executor.scala b/core/src/main/scala/spark/executor/Executor.scala
index 5de09030aa1b3b4318c1eb6051eed3e1a3fb23ef..4474ef4593703423215a172d103d7b98a3952b36 100644
--- a/core/src/main/scala/spark/executor/Executor.scala
+++ b/core/src/main/scala/spark/executor/Executor.scala
@@ -85,6 +85,7 @@ private[spark] class Executor extends Logging {
     extends Runnable {
 
     override def run() {
+      val startTime = System.currentTimeMillis()
       SparkEnv.set(env)
       Thread.currentThread.setContextClassLoader(urlClassLoader)
       val ser = SparkEnv.get.closureSerializer.newInstance()
@@ -98,9 +99,18 @@ private[spark] class Executor extends Logging {
         val task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
         logInfo("Its generation is " + task.generation)
         env.mapOutputTracker.updateGeneration(task.generation)
+        val taskStart = System.currentTimeMillis()
         val value = task.run(taskId.toInt)
+        val taskFinish = System.currentTimeMillis()
+        task.metrics.foreach{ m =>
+          m.executorDeserializeTime = (taskStart - startTime).toInt
+          m.executorRunTime = (taskFinish - taskStart).toInt
+        }
+        // TODO: I'd also like to track the time it takes to serialize the task results, but that is a big
+        // headache, because we need to serialize the task metrics first. If TaskMetrics had a custom
+        // serialized format, we could just change the relevant bytes in the byte buffer.
         val accumUpdates = Accumulators.values
-        val result = new TaskResult(value, accumUpdates)
+        val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
         val serializedResult = ser.serialize(result)
         logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
         context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
diff --git a/core/src/main/scala/spark/executor/TaskMetrics.scala b/core/src/main/scala/spark/executor/TaskMetrics.scala
new file mode 100644
index 0000000000000000000000000000000000000000..800305cd6c2673109a47636064255ac0e91c5750
--- /dev/null
+++ b/core/src/main/scala/spark/executor/TaskMetrics.scala
@@ -0,0 +1,71 @@
+package spark.executor
+
+class TaskMetrics {
+  /**
+   * Time taken on the executor to deserialize this task
+   */
+  var executorDeserializeTime: Int = _
+  /**
+   * Time the executor spends actually running the task (including fetching shuffle data)
+   */
+  var executorRunTime: Int = _
+  /**
+   * The number of bytes this task transmitted back to the driver as the TaskResult
+   */
+  var resultSize: Long = _
+
+  /**
+   * If this task reads from shuffle output, metrics on getting shuffle data will be collected here
+   */
+  var shuffleReadMetrics: Option[ShuffleReadMetrics] = None
+
+  /**
+   * If this task writes to shuffle output, metrics on the written shuffle data will be collected here
+   */
+  var shuffleWriteMetrics: Option[ShuffleWriteMetrics] = None
+
+}
+
+object TaskMetrics {
+  private[spark] def empty(): TaskMetrics = new TaskMetrics
+}
+
+
+class ShuffleReadMetrics {
+  /**
+   * Total number of blocks fetched in a shuffle (remote or local)
+   */
+  var totalBlocksFetched: Int = _
+  /**
+   * Number of remote blocks fetched in a shuffle
+   */
+  var remoteBlocksFetched: Int = _
+  /**
+   * Number of local blocks fetched in a shuffle
+   */
+  var localBlocksFetched: Int = _
+  /**
+   * Total time to read shuffle data
+   */
+  var shuffleReadMillis: Long = _
+  /**
+   * Total time the task spent blocked waiting for remote shuffle data to be fetched
+   */
+  var remoteFetchWaitTime: Long = _
+  /**
+   * Total time taken by all the remote shuffle fetches. This adds up the time of overlapping
+   * fetches, so it can be longer than the task's run time.
+   */
+  var remoteFetchTime: Long = _
+  /**
+   * Total number of remote bytes read from a shuffle
+   */
+  var remoteBytesRead: Long = _
+}
+
+class ShuffleWriteMetrics {
+  /**
+   * Number of bytes written for a shuffle
+   */
+  var shuffleBytesWritten: Long = _
+}
\ No newline at end of file
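
// Sketch of a consumer for these metrics: a SparkListener (the trait added later in this
// patch) that sums a few TaskMetrics fields per completed stage. The class name and the
// choice of fields are illustrative; metrics can be null for tasks reported by the local
// scheduler, hence the filter.
import spark.scheduler.{SparkListener, StageCompleted}

class ShuffleReadSummary extends SparkListener {
  def onStageCompleted(stageCompleted: StageCompleted) {
    val metrics = stageCompleted.stageInfo.taskInfos.map(_._2).filter(_ != null)
    val execTime = metrics.map(_.executorRunTime.toLong).sum
    val remoteBytes = metrics.flatMap(_.shuffleReadMetrics).map(_.remoteBytesRead).sum
    println(stageCompleted.stageInfo + ": " + metrics.size + " tasks, " +
      execTime + " ms executor time, " + remoteBytes + " remote shuffle bytes")
  }
}
// Registered on the driver via the new SparkContext.addSparkListener(new ShuffleReadSummary).
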
diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
index 5200fb6b656ade45d1a8ca682f9f736ce1d769c5..65b4621b87ed0370a82e8a73769a7b3e28ef33c2 100644
--- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala
@@ -102,7 +102,8 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[(K, _)]], part: Partitioner)
       case ShuffleCoGroupSplitDep(shuffleId) => {
         // Read map outputs of shuffle
         val fetcher = SparkEnv.get.shuffleFetcher
-        for ((k, vs) <- fetcher.fetch[K, Seq[Any]](shuffleId, split.index)) {
+        val fetchItr = fetcher.fetch[K, Seq[Any]](shuffleId, split.index, context.taskMetrics)
+        for ((k, vs) <- fetchItr) {
           getSeq(k)(depNum) ++= vs
         }
       }
diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
index c2f118305f33f260b17af5bf49dbd1d5fd11b538..51f02409b6a75d689159970e59af52f887fc8626 100644
--- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala
@@ -28,6 +28,6 @@ class ShuffledRDD[K, V](
 
   override def compute(split: Partition, context: TaskContext): Iterator[(K, V)] = {
     val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId
-    SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index)
+    SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index, context.taskMetrics)
   }
 }
diff --git a/core/src/main/scala/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
index daf9cc993cf42e9e963e986d73cdad0d6d708059..43ec90cac5a95111534a7e15d9638c438b71dde5 100644
--- a/core/src/main/scala/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/spark/rdd/SubtractedRDD.scala
@@ -89,7 +89,7 @@ private[spark] class SubtractedRDD[T: ClassManifest](
         for (k <- rdd.iterator(itsSplit, context))
           op(k.asInstanceOf[T])
       case ShuffleCoGroupSplitDep(shuffleId) =>
-        for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index))
+        for ((k, _) <- SparkEnv.get.shuffleFetcher.fetch(shuffleId, partition.index, context.taskMetrics))
           op(k.asInstanceOf[T])
     }
     // the first dep is rdd1; add all keys to the set
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index bf0837c0660cb2090a86a4aa75ec34bcd8b41b32..1bf5054f4d7947cd439d434d2a13eb4a5c11dda1 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -1,20 +1,19 @@
 package spark.scheduler
 
-import java.net.URI
+import cluster.TaskInfo
 import java.util.concurrent.atomic.AtomicInteger
-import java.util.concurrent.Future
 import java.util.concurrent.LinkedBlockingQueue
 import java.util.concurrent.TimeUnit
 
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue, Map}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
 
 import spark._
+import spark.executor.TaskMetrics
 import spark.partial.ApproximateActionListener
 import spark.partial.ApproximateEvaluator
 import spark.partial.PartialResult
 import spark.storage.BlockManagerMaster
-import spark.storage.BlockManagerId
-import util.{MetadataCleaner, TimeStampedHashMap}
+import spark.util.{MetadataCleaner, TimeStampedHashMap}
 
 /**
  * A Scheduler subclass that implements stage-oriented scheduling. It computes a DAG of stages for
@@ -40,8 +39,10 @@ class DAGScheduler(
       task: Task[_],
       reason: TaskEndReason,
       result: Any,
-      accumUpdates: Map[Long, Any]) {
-    eventQueue.put(CompletionEvent(task, reason, result, accumUpdates))
+      accumUpdates: Map[Long, Any],
+      taskInfo: TaskInfo,
+      taskMetrics: TaskMetrics) {
+    eventQueue.put(CompletionEvent(task, reason, result, accumUpdates, taskInfo, taskMetrics))
   }
 
   // Called by TaskScheduler when an executor fails.
@@ -73,6 +74,10 @@ class DAGScheduler(
 
   val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
 
+  private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
+
+  private[spark] val sparkListeners = ArrayBuffer[SparkListener]()
+
   var cacheLocs = new HashMap[Int, Array[List[String]]]
 
   // For tracking failed nodes, we use the MapOutputTracker's generation number, which is
@@ -148,6 +153,7 @@ class DAGScheduler(
     val id = nextStageId.getAndIncrement()
     val stage = new Stage(id, rdd, shuffleDep, getParentStages(rdd, priority), priority)
     idToStage(id) = stage
+    stageToInfos(stage) = StageInfo(stage)
     stage
   }
 
@@ -472,6 +478,8 @@ class DAGScheduler(
         case _ => "Unkown"
       }
       logInfo("%s (%s) finished in %s s".format(stage, stage.origin, serviceTime))
+      val stageComp = StageCompleted(stageToInfos(stage))
+      sparkListeners.foreach{_.onStageCompleted(stageComp)}
       running -= stage
     }
     event.reason match {
@@ -481,6 +489,7 @@ class DAGScheduler(
           Accumulators.add(event.accumUpdates) // TODO: do this only if task wasn't resubmitted
         }
         pendingTasks(stage) -= task
+        stageToInfos(stage).taskInfos += event.taskInfo -> event.taskMetrics
         task match {
           case rt: ResultTask[_, _] =>
             resultStageToJob.get(stage) match {
@@ -501,7 +510,6 @@ class DAGScheduler(
             }
 
           case smt: ShuffleMapTask =>
-            val stage = idToStage(smt.stageId)
             val status = event.result.asInstanceOf[MapStatus]
             val execId = status.location.executorId
             logDebug("ShuffleMapTask finished on " + execId)
diff --git a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
index b34fa78c072c0ccdb2c304aec72fd3464169ae72..ed0b9bf178a89c80f5709efd09d288c18fc7bbac 100644
--- a/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/spark/scheduler/DAGSchedulerEvent.scala
@@ -1,8 +1,10 @@
 package spark.scheduler
 
+import spark.scheduler.cluster.TaskInfo
 import scala.collection.mutable.Map
 
 import spark._
+import spark.executor.TaskMetrics
 
 /**
  * Types of events that can be handled by the DAGScheduler. The DAGScheduler uses an event queue
@@ -25,7 +27,9 @@ private[spark] case class CompletionEvent(
     task: Task[_],
     reason: TaskEndReason,
     result: Any,
-    accumUpdates: Map[Long, Any])
+    accumUpdates: Map[Long, Any],
+    taskInfo: TaskInfo,
+    taskMetrics: TaskMetrics)
   extends DAGSchedulerEvent
 
 private[spark] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
diff --git a/core/src/main/scala/spark/scheduler/ResultTask.scala b/core/src/main/scala/spark/scheduler/ResultTask.scala
index 1721f78f483cf9f8b274e81ef6e2c01327a8b277..beb21a76fe5c8247b3373e05cb09355141d49ed8 100644
--- a/core/src/main/scala/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/spark/scheduler/ResultTask.scala
@@ -72,6 +72,7 @@ private[spark] class ResultTask[T, U](
 
   override def run(attemptId: Long): U = {
     val context = new TaskContext(stageId, partition, attemptId)
+    metrics = Some(context.taskMetrics)
     try {
       func(context, rdd.iterator(split, context))
     } finally {
diff --git a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
index 59ee3c0a095acc3eab7bedc5c4189538ac3f91cc..36d087a4d009c8e6a4fdf85bfc5d6e6e6e3926bf 100644
--- a/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/spark/scheduler/ShuffleMapTask.scala
@@ -13,6 +13,7 @@ import com.ning.compress.lzf.LZFInputStream
 import com.ning.compress.lzf.LZFOutputStream
 
 import spark._
+import executor.ShuffleWriteMetrics
 import spark.storage._
 import util.{TimeStampedHashMap, MetadataCleaner}
 
@@ -119,6 +120,7 @@ private[spark] class ShuffleMapTask(
     val numOutputSplits = dep.partitioner.numPartitions
 
     val taskContext = new TaskContext(stageId, partition, attemptId)
+    metrics = Some(taskContext.taskMetrics)
     try {
       // Partition the map output.
       val buckets = Array.fill(numOutputSplits)(new ArrayBuffer[(Any, Any)])
@@ -130,14 +132,20 @@ private[spark] class ShuffleMapTask(
 
       val compressedSizes = new Array[Byte](numOutputSplits)
 
+      var totalBytes = 0L
+
       val blockManager = SparkEnv.get.blockManager
       for (i <- 0 until numOutputSplits) {
         val blockId = "shuffle_" + dep.shuffleId + "_" + partition + "_" + i
         // Get a Scala iterator from Java map
         val iter: Iterator[(Any, Any)] = buckets(i).iterator
         val size = blockManager.put(blockId, iter, StorageLevel.DISK_ONLY, false)
+        totalBytes += size
         compressedSizes(i) = MapOutputTracker.compressSize(size)
       }
+      val shuffleMetrics = new ShuffleWriteMetrics
+      shuffleMetrics.shuffleBytesWritten = totalBytes
+      metrics.get.shuffleWriteMetrics = Some(shuffleMetrics)
 
       return new MapStatus(blockManager.blockManagerId, compressedSizes)
     } finally {
diff --git a/core/src/main/scala/spark/scheduler/SparkListener.scala b/core/src/main/scala/spark/scheduler/SparkListener.scala
new file mode 100644
index 0000000000000000000000000000000000000000..21185227ab0c6ff822c73e4e91c135434a92d726
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/SparkListener.scala
@@ -0,0 +1,146 @@
+package spark.scheduler
+
+import spark.scheduler.cluster.TaskInfo
+import spark.util.Distribution
+import spark.{Utils, Logging}
+import spark.executor.TaskMetrics
+
+trait SparkListener {
+  /**
+   * Called when a stage completes, with information on the completed stage.
+   */
+  def onStageCompleted(stageCompleted: StageCompleted)
+}
+
+sealed trait SparkListenerEvents
+
+case class StageCompleted(val stageInfo: StageInfo) extends SparkListenerEvents
+
+
+/**
+ * Simple SparkListener that logs a few summary statistics when each stage completes
+ */
+class StatsReportListener extends SparkListener with Logging {
+  def onStageCompleted(stageCompleted: StageCompleted) {
+    import spark.scheduler.StatsReportListener._
+    implicit val sc = stageCompleted
+    this.logInfo("Finished stage: " + stageCompleted.stageInfo)
+    showMillisDistribution("task runtime:", (info, _) => Some(info.duration))
+
+    //shuffle write
+    showBytesDistribution("shuffle bytes written:",(_,metric) => metric.shuffleWriteMetrics.map{_.shuffleBytesWritten})
+
+    //fetch & io
+    showMillisDistribution("fetch wait time:",(_, metric) => metric.shuffleReadMetrics.map{_.remoteFetchWaitTime})
+    showBytesDistribution("remote bytes read:", (_, metric) => metric.shuffleReadMetrics.map{_.remoteBytesRead})
+    showBytesDistribution("task result size:", (_, metric) => Some(metric.resultSize))
+
+    //runtime breakdown
+    val runtimePcts = stageCompleted.stageInfo.taskInfos.map{
+      case (info, metrics) => RuntimePercentage(info.duration, metrics)
+    }
+    showDistribution("executor (non-fetch) time pct: ", Distribution(runtimePcts.map{_.executorPct * 100}), "%2.0f %%")
+    showDistribution("fetch wait time pct: ", Distribution(runtimePcts.flatMap{_.fetchPct.map{_ * 100}}), "%2.0f %%")
+    showDistribution("other time pct: ", Distribution(runtimePcts.map{_.other * 100}), "%2.0f %%")
+  }
+
+}
+
+object StatsReportListener extends Logging {
+
+  //for profiling, the extremes are more interesting
+  val percentiles = Array[Int](0,5,10,25,50,75,90,95,100)
+  val probabilities = percentiles.map{_ / 100.0}
+  val percentilesHeader = "\t" + percentiles.mkString("%\t") + "%"
+
+  def extractDoubleDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Double]): Option[Distribution] = {
+    Distribution(stage.stageInfo.taskInfos.flatMap{
+      case ((info,metric)) => getMetric(info, metric)})
+  }
+
+  // Is there some way to set up the types so that I can get rid of this method entirely?
+  def extractLongDistribution(stage:StageCompleted, getMetric: (TaskInfo,TaskMetrics) => Option[Long]): Option[Distribution] = {
+    extractDoubleDistribution(stage, (info, metric) => getMetric(info,metric).map{_.toDouble})
+  }
+
+  def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
+    val stats = d.statCounter
+    logInfo(heading + stats)
+    val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+    logInfo(percentilesHeader)
+    logInfo("\t" + quantiles.mkString("\t"))
+  }
+
+  def showDistribution(heading: String, dOpt: Option[Distribution], formatNumber: Double => String) {
+    dOpt.foreach { d => showDistribution(heading, d, formatNumber)}
+  }
+
+  def showDistribution(heading: String, dOpt: Option[Distribution], format:String) {
+    def f(d:Double) = format.format(d)
+    showDistribution(heading, dOpt, f _)
+  }
+
+  def showDistribution(heading:String, format: String, getMetric: (TaskInfo,TaskMetrics) => Option[Double])
+    (implicit stage: StageCompleted) {
+    showDistribution(heading, extractDoubleDistribution(stage, getMetric), format)
+  }
+
+  def showBytesDistribution(heading:String, getMetric: (TaskInfo,TaskMetrics) => Option[Long])
+    (implicit stage: StageCompleted) {
+    showBytesDistribution(heading, extractLongDistribution(stage, getMetric))
+  }
+
+  def showBytesDistribution(heading: String, dOpt: Option[Distribution]) {
+    dOpt.foreach{dist => showBytesDistribution(heading, dist)}
+  }
+
+  def showBytesDistribution(heading: String, dist: Distribution) {
+    showDistribution(heading, dist, (d => Utils.memoryBytesToString(d.toLong)): Double => String)
+  }
+
+  def showMillisDistribution(heading: String, dOpt: Option[Distribution]) {
+    showDistribution(heading, dOpt, (d => StatsReportListener.millisToString(d.toLong)): Double => String)
+  }
+
+  def showMillisDistribution(heading: String, getMetric: (TaskInfo, TaskMetrics) => Option[Long])
+    (implicit stage: StageCompleted) {
+    showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
+  }
+
+
+
+  val seconds = 1000L
+  val minutes = seconds * 60
+  val hours = minutes * 60
+
+  /**
+   * Reformat a time interval in milliseconds into a human-readable string for output.
+   */
+  def millisToString(ms: Long) = {
+    val (size, units) =
+      if (ms > hours) {
+        (ms.toDouble / hours, "hours")
+      } else if (ms > minutes) {
+        (ms.toDouble / minutes, "min")
+      } else if (ms > seconds) {
+        (ms.toDouble / seconds, "s")
+      } else {
+        (ms.toDouble, "ms")
+      }
+    "%.1f %s".format(size, units)
+  }
+}
+
+
+
+case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
+object RuntimePercentage {
+  def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
+    val denom = totalTime.toDouble
+    val fetchTime = metrics.shuffleReadMetrics.map{_.remoteFetchWaitTime}
+    val fetch = fetchTime.map{_ / denom}
+    val exec = (metrics.executorRunTime - fetchTime.getOrElse(0L)) / denom
+    val other = 1.0 - (exec + fetch.getOrElse(0d))
+    RuntimePercentage(exec, fetch, other)
+  }
+}
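
// Quick exercise of the helpers above on synthetic values (nothing here comes from a real
// stage; it only shows the formatting). The object name is illustrative.
import spark.scheduler.StatsReportListener
import spark.util.Distribution

object StatsHelpersDemo {
  def main(args: Array[String]) {
    val taskRuntimesMs = Seq(120.0, 340.0, 95.0, 2100.0, 780.0)
    // logs the stat counter plus the 0/5/10/25/50/75/90/95/100 percentiles
    StatsReportListener.showMillisDistribution("task runtime:", Distribution(taskRuntimesMs))
    println(StatsReportListener.millisToString(125000))  // "2.1 min"
  }
}
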
diff --git a/core/src/main/scala/spark/scheduler/StageInfo.scala b/core/src/main/scala/spark/scheduler/StageInfo.scala
new file mode 100644
index 0000000000000000000000000000000000000000..8d83ff10c420d97bd23697ceee64774e8dd93bb3
--- /dev/null
+++ b/core/src/main/scala/spark/scheduler/StageInfo.scala
@@ -0,0 +1,12 @@
+package spark.scheduler
+
+import spark.scheduler.cluster.TaskInfo
+import scala.collection._
+import spark.executor.TaskMetrics
+
+case class StageInfo(
+    val stage: Stage,
+    val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]()
+) {
+  override def toString = stage.rdd.toString
+}
\ No newline at end of file
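
// Driver-side sketch: after a job has run, walk the StageInfo map exposed by the new
// SparkContext.getStageInfo and print per-stage shuffle write totals. The aggregation is
// only an example; metrics can be null for some tasks, hence the filter.
import spark.SparkContext

object StageInfoDump {
  def dump(sc: SparkContext) {
    for (info <- sc.getStageInfo.values) {
      val metrics = info.taskInfos.map(_._2).filter(_ != null)
      val written = metrics.flatMap(_.shuffleWriteMetrics).map(_.shuffleBytesWritten).sum
      println(info + ": " + metrics.size + " tasks, " + written + " shuffle bytes written")
    }
  }
}
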
diff --git a/core/src/main/scala/spark/scheduler/Task.scala b/core/src/main/scala/spark/scheduler/Task.scala
index ef987fdeb696bbe96f11868ad6b8930cf2331f41..a6462c6968b67492399304c2e4d7887986e349ac 100644
--- a/core/src/main/scala/spark/scheduler/Task.scala
+++ b/core/src/main/scala/spark/scheduler/Task.scala
@@ -1,12 +1,12 @@
 package spark.scheduler
 
-import scala.collection.mutable.HashMap
-import spark.serializer.{SerializerInstance, Serializer}
+import spark.serializer.SerializerInstance
 import java.io.{DataInputStream, DataOutputStream}
 import java.nio.ByteBuffer
 import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream
 import spark.util.ByteBufferInputStream
 import scala.collection.mutable.HashMap
+import spark.executor.TaskMetrics
 
 /**
  * A task to execute on a worker node.
@@ -16,6 +16,9 @@ private[spark] abstract class Task[T](val stageId: Int) extends Serializable {
   def preferredLocations: Seq[String] = Nil
 
   var generation: Long = -1   // Map output tracker generation. Will be set by TaskScheduler.
+
+  var metrics: Option[TaskMetrics] = None
+
 }
 
 /**
diff --git a/core/src/main/scala/spark/scheduler/TaskResult.scala b/core/src/main/scala/spark/scheduler/TaskResult.scala
index 9a54d0e8541eb46381b2387e2d578222f60c0d6a..6de0aa7adf8ac782895b53d568286e2f2c4a1cdf 100644
--- a/core/src/main/scala/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/spark/scheduler/TaskResult.scala
@@ -3,13 +3,14 @@ package spark.scheduler
 import java.io._
 
 import scala.collection.mutable.Map
+import spark.executor.TaskMetrics
 
 // Task result. Also contains updates to accumulator variables.
 // TODO: Use of distributed cache to return result is a hack to get around
 // what seems to be a bug with messages over 60KB in libprocess; fix it
 private[spark]
-class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Externalizable {
-  def this() = this(null.asInstanceOf[T], null)
+class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics) extends Externalizable {
+  def this() = this(null.asInstanceOf[T], null, null)
 
   override def writeExternal(out: ObjectOutput) {
     out.writeObject(value)
@@ -18,6 +19,7 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Exte
       out.writeLong(key)
       out.writeObject(value)
     }
+    out.writeObject(metrics)
   }
 
   override def readExternal(in: ObjectInput) {
@@ -31,5 +33,6 @@ class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any]) extends Exte
         accumUpdates(in.readLong()) = in.readObject()
       }
     }
+    metrics = in.readObject().asInstanceOf[TaskMetrics]
   }
 }
diff --git a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
index 9fcef86e46a29673133f2884defdde81b8caf9b9..771518dddfacaaf2916a1f6cd834983725cbc533 100644
--- a/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
+++ b/core/src/main/scala/spark/scheduler/TaskSchedulerListener.scala
@@ -1,15 +1,18 @@
 package spark.scheduler
 
+import spark.scheduler.cluster.TaskInfo
 import scala.collection.mutable.Map
 
 import spark.TaskEndReason
+import spark.executor.TaskMetrics
 
 /**
  * Interface for getting events back from the TaskScheduler.
  */
 private[spark] trait TaskSchedulerListener {
   // A task has finished or failed.
-  def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any]): Unit
+  def taskEnded(task: Task[_], reason: TaskEndReason, result: Any, accumUpdates: Map[Long, Any],
+                taskInfo: TaskInfo, taskMetrics: TaskMetrics): Unit
 
   // A node was lost from the cluster.
   def executorLost(execId: String): Unit
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
index 0f975ce1eb4192caa024b279c16328b2e312d806..dfe3c5a85bc25f47a85a0da31a4649b27ef2c3be 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskInfo.scala
@@ -9,7 +9,8 @@ class TaskInfo(
     val index: Int,
     val launchTime: Long,
     val executorId: String,
-    val host: String) {
+    val host: String,
+    val preferred: Boolean) {
   var finishTime: Long = 0
   var failed = false
 
diff --git a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
index 3dabdd76b1aa55404ef660199d05e5e699775038..c9f2c488048ca2628387165ac498d7346da45627 100644
--- a/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/spark/scheduler/cluster/TaskSetManager.scala
@@ -208,7 +208,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
             taskSet.id, index, taskId, execId, host, prefStr))
           // Do various bookkeeping
           copiesRunning(index) += 1
-          val info = new TaskInfo(taskId, index, time, execId, host)
+          val info = new TaskInfo(taskId, index, time, execId, host, preferred)
           taskInfos(taskId) = info
           taskAttempts(index) = info :: taskAttempts(index)
           if (preferred) {
@@ -259,7 +259,8 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
         tid, info.duration, tasksFinished, numTasks))
       // Deserialize task result and pass it to the scheduler
       val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
-      sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates)
+      result.metrics.resultSize = serializedData.limit()
+      sched.listener.taskEnded(tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
       // Mark finished and stop if we've finished all the tasks
       finished(index) = true
       if (tasksFinished == numTasks) {
@@ -290,7 +291,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
         reason match {
           case fetchFailed: FetchFailed =>
             logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
-            sched.listener.taskEnded(tasks(index), fetchFailed, null, null)
+            sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
             finished(index) = true
             tasksFinished += 1
             sched.taskSetFinished(this)
@@ -378,7 +379,7 @@ private[spark] class TaskSetManager(sched: ClusterScheduler, val taskSet: TaskSe
           addPendingTask(index)
           // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
           // stage finishes when a total of tasks.size tasks finish.
-          sched.listener.taskEnded(tasks(index), Resubmitted, null, null)
+          sched.listener.taskEnded(tasks(index), Resubmitted, null, null, info, null)
         }
       }
     }
diff --git a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
index 482d1cc85343a97852193b693ee9f7e054944b21..a76253ea14d3e108f251501c948383d143b666f6 100644
--- a/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/local/LocalScheduler.scala
@@ -1,14 +1,13 @@
 package spark.scheduler.local
 
 import java.io.File
-import java.net.URLClassLoader
-import java.util.concurrent.Executors
 import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.mutable.HashMap
 
 import spark._
-import executor.ExecutorURLClassLoader
+import spark.executor.ExecutorURLClassLoader
 import spark.scheduler._
+import spark.scheduler.cluster.TaskInfo
 
 /**
  * A simple TaskScheduler implementation that runs tasks locally in a thread pool. Optionally
@@ -54,6 +53,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
 
     def runTask(task: Task[_], idInJob: Int, attemptId: Int) {
       logInfo("Running " + task)
+      val info = new TaskInfo(attemptId, idInJob, System.currentTimeMillis(), "local", "local", true)
       // Set the Spark execution environment for the worker thread
       SparkEnv.set(env)
       try {
@@ -81,10 +81,11 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
         val accumUpdates = ser.deserialize[collection.mutable.Map[Long, Any]](
           ser.serialize(Accumulators.values))
         logInfo("Finished " + task)
+        info.markSuccessful()
 
         // If the threadpool has not already been shutdown, notify DAGScheduler
         if (!Thread.currentThread().isInterrupted)
-          listener.taskEnded(task, Success, resultToReturn, accumUpdates)
+          listener.taskEnded(task, Success, resultToReturn, accumUpdates, info, null)
       } catch {
         case t: Throwable => {
           logError("Exception in task " + idInJob, t)
@@ -95,7 +96,7 @@ private[spark] class LocalScheduler(threads: Int, maxFailures: Int, sc: SparkCon
             } else {
               // TODO: Do something nicer here to return all the way to the user
               if (!Thread.currentThread().isInterrupted)
-                listener.taskEnded(task, new ExceptionFailure(t), null, null)
+                listener.taskEnded(task, new ExceptionFailure(t), null, null, info, null)
             }
           }
         }
diff --git a/core/src/main/scala/spark/storage/BlockFetchTracker.scala b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
new file mode 100644
index 0000000000000000000000000000000000000000..ababb04305cde82f18665b6b65298af2857faf0c
--- /dev/null
+++ b/core/src/main/scala/spark/storage/BlockFetchTracker.scala
@@ -0,0 +1,10 @@
+package spark.storage
+
+private[spark] trait BlockFetchTracker {
+  def totalBlocks: Int
+  def numLocalBlocks: Int
+  def numRemoteBlocks: Int
+  def remoteFetchTime: Long
+  def remoteFetchWaitTime: Long
+  def remoteBytesRead: Long
+}
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 2462721fb844307f41677477abe0d26512f57db2..4964060b1c46c4fa24e90795d9d88b3368592086 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -446,152 +446,8 @@ class BlockManager(
    * so that we can control the maxMegabytesInFlight for the fetch.
    */
   def getMultiple(blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])])
-      : Iterator[(String, Option[Iterator[Any]])] = {
-
-    if (blocksByAddress == null) {
-      throw new IllegalArgumentException("BlocksByAddress is null")
-    }
-    val totalBlocks = blocksByAddress.map(_._2.size).sum
-    logDebug("Getting " + totalBlocks + " blocks")
-    var startTime = System.currentTimeMillis
-    val localBlockIds = new ArrayBuffer[String]()
-    val remoteBlockIds = new HashSet[String]()
-
-    // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
-    // the block (since we want all deserializaton to happen in the calling thread); can also
-    // represent a fetch failure if size == -1.
-    class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
-      def failed: Boolean = size == -1
-    }
-
-    // A queue to hold our results.
-    val results = new LinkedBlockingQueue[FetchResult]
-
-    // A request to fetch one or more blocks, complete with their sizes
-    class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
-      val size = blocks.map(_._2).sum
-    }
-
-    // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
-    // the number of bytes in flight is limited to maxBytesInFlight
-    val fetchRequests = new Queue[FetchRequest]
-
-    // Current bytes in flight from our requests
-    var bytesInFlight = 0L
-
-    def sendRequest(req: FetchRequest) {
-      logDebug("Sending request for %d blocks (%s) from %s".format(
-        req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
-      val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
-      val blockMessageArray = new BlockMessageArray(req.blocks.map {
-        case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
-      })
-      bytesInFlight += req.size
-      val sizeMap = req.blocks.toMap  // so we can look up the size of each blockID
-      val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
-      future.onSuccess {
-        case Some(message) => {
-          val bufferMessage = message.asInstanceOf[BufferMessage]
-          val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
-          for (blockMessage <- blockMessageArray) {
-            if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
-              throw new SparkException(
-                "Unexpected message " + blockMessage.getType + " received from " + cmId)
-            }
-            val blockId = blockMessage.getId
-            results.put(new FetchResult(
-              blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
-            logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
-          }
-        }
-        case None => {
-          logError("Could not get block(s) from " + cmId)
-          for ((blockId, size) <- req.blocks) {
-            results.put(new FetchResult(blockId, -1, null))
-          }
-        }
-      }
-    }
-
-    // Partition local and remote blocks. Remote blocks are further split into FetchRequests of size
-    // at most maxBytesInFlight in order to limit the amount of data in flight.
-    val remoteRequests = new ArrayBuffer[FetchRequest]
-    for ((address, blockInfos) <- blocksByAddress) {
-      if (address == blockManagerId) {
-        localBlockIds ++= blockInfos.map(_._1)
-      } else {
-        remoteBlockIds ++= blockInfos.map(_._1)
-        // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
-        // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
-        // nodes, rather than blocking on reading output from one node.
-        val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
-        logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
-        val iterator = blockInfos.iterator
-        var curRequestSize = 0L
-        var curBlocks = new ArrayBuffer[(String, Long)]
-        while (iterator.hasNext) {
-          val (blockId, size) = iterator.next()
-          curBlocks += ((blockId, size))
-          curRequestSize += size
-          if (curRequestSize >= minRequestSize) {
-            // Add this FetchRequest
-            remoteRequests += new FetchRequest(address, curBlocks)
-            curRequestSize = 0
-            curBlocks = new ArrayBuffer[(String, Long)]
-          }
-        }
-        // Add in the final request
-        if (!curBlocks.isEmpty) {
-          remoteRequests += new FetchRequest(address, curBlocks)
-        }
-      }
-    }
-    // Add the remote requests into our queue in a random order
-    fetchRequests ++= Utils.randomize(remoteRequests)
-
-    // Send out initial requests for blocks, up to our maxBytesInFlight
-    while (!fetchRequests.isEmpty &&
-        (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
-      sendRequest(fetchRequests.dequeue())
-    }
-
-    val numGets = remoteBlockIds.size - fetchRequests.size
-    logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
-
-    // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
-    // these all at once because they will just memory-map some files, so they won't consume
-    // any memory that might exceed our maxBytesInFlight
-    startTime = System.currentTimeMillis
-    for (id <- localBlockIds) {
-      getLocal(id) match {
-        case Some(iter) => {
-          results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
-          logDebug("Got local block " + id)
-        }
-        case None => {
-          throw new BlockException(id, "Could not get block " + id + " from local machine")
-        }
-      }
-    }
-    logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
-
-    // Return an iterator that will read fetched blocks off the queue as they arrive.
-    return new Iterator[(String, Option[Iterator[Any]])] {
-      var resultsGotten = 0
-
-      def hasNext: Boolean = resultsGotten < totalBlocks
-
-      def next(): (String, Option[Iterator[Any]]) = {
-        resultsGotten += 1
-        val result = results.take()
-        bytesInFlight -= result.size
-        while (!fetchRequests.isEmpty &&
-            (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
-          sendRequest(fetchRequests.dequeue())
-        }
-        (result.blockId, if (result.failed) None else Some(result.deserialize()))
-      }
-    }
+      : BlockFetcherIterator = {
+    return new BlockFetcherIterator(this, blocksByAddress)
   }
 
   def put(blockId: String, values: Iterator[Any], level: StorageLevel, tellMaster: Boolean)
@@ -986,3 +842,176 @@ object BlockManager extends Logging {
     }
   }
 }
+
+class BlockFetcherIterator(
+    private val blockManager: BlockManager,
+    val blocksByAddress: Seq[(BlockManagerId, Seq[(String, Long)])]
+) extends Iterator[(String, Option[Iterator[Any]])] with Logging with BlockFetchTracker {
+
+  import blockManager._
+
+  private var _remoteBytesRead = 0L
+  private var _remoteFetchTime = 0L
+  private var _remoteFetchWaitTime = 0L
+
+  if (blocksByAddress == null) {
+    throw new IllegalArgumentException("BlocksByAddress is null")
+  }
+  val totalBlocks = blocksByAddress.map(_._2.size).sum
+  logDebug("Getting " + totalBlocks + " blocks")
+  var startTime = System.currentTimeMillis
+  val localBlockIds = new ArrayBuffer[String]()
+  val remoteBlockIds = new HashSet[String]()
+
+  // A result of a fetch. Includes the block ID, size in bytes, and a function to deserialize
+  // the block (since we want all deserialization to happen in the calling thread); can also
+  // represent a fetch failure if size == -1.
+  class FetchResult(val blockId: String, val size: Long, val deserialize: () => Iterator[Any]) {
+    def failed: Boolean = size == -1
+  }
+
+  // A queue to hold our results.
+  val results = new LinkedBlockingQueue[FetchResult]
+
+  // A request to fetch one or more blocks, complete with their sizes
+  class FetchRequest(val address: BlockManagerId, val blocks: Seq[(String, Long)]) {
+    val size = blocks.map(_._2).sum
+  }
+
+  // Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
+  // the number of bytes in flight is limited to maxBytesInFlight
+  val fetchRequests = new Queue[FetchRequest]
+
+  // Current bytes in flight from our requests
+  var bytesInFlight = 0L
+
+  def sendRequest(req: FetchRequest) {
+    logDebug("Sending request for %d blocks (%s) from %s".format(
+      req.blocks.size, Utils.memoryBytesToString(req.size), req.address.ip))
+    val cmId = new ConnectionManagerId(req.address.ip, req.address.port)
+    val blockMessageArray = new BlockMessageArray(req.blocks.map {
+      case (blockId, size) => BlockMessage.fromGetBlock(GetBlock(blockId))
+    })
+    bytesInFlight += req.size
+    val sizeMap = req.blocks.toMap  // so we can look up the size of each blockID
+    val fetchStart = System.currentTimeMillis()
+    val future = connectionManager.sendMessageReliably(cmId, blockMessageArray.toBufferMessage)
+    future.onSuccess {
+      case Some(message) => {
+        val fetchDone = System.currentTimeMillis()
+        _remoteFetchTime += fetchDone - fetchStart
+        val bufferMessage = message.asInstanceOf[BufferMessage]
+        val blockMessageArray = BlockMessageArray.fromBufferMessage(bufferMessage)
+        for (blockMessage <- blockMessageArray) {
+          if (blockMessage.getType != BlockMessage.TYPE_GOT_BLOCK) {
+            throw new SparkException(
+              "Unexpected message " + blockMessage.getType + " received from " + cmId)
+          }
+          val blockId = blockMessage.getId
+          results.put(new FetchResult(
+            blockId, sizeMap(blockId), () => dataDeserialize(blockId, blockMessage.getData)))
+          _remoteBytesRead += sizeMap(blockId)  // count each block once, not the whole request per block
+          logDebug("Got remote block " + blockId + " after " + Utils.getUsedTimeMs(startTime))
+        }
+      }
+      case None => {
+        logError("Could not get block(s) from " + cmId)
+        for ((blockId, size) <- req.blocks) {
+          results.put(new FetchResult(blockId, -1, null))
+        }
+      }
+    }
+  }
+
+  // Split local and remote blocks. Remote blocks are further split into FetchRequests of size
+  // at most maxBytesInFlight in order to limit the amount of data in flight.
+  val remoteRequests = new ArrayBuffer[FetchRequest]
+  for ((address, blockInfos) <- blocksByAddress) {
+    if (address == blockManagerId) {
+      localBlockIds ++= blockInfos.map(_._1)
+    } else {
+      remoteBlockIds ++= blockInfos.map(_._1)
+      // Make our requests at least maxBytesInFlight / 5 in length; the reason to keep them
+      // smaller than maxBytesInFlight is to allow multiple, parallel fetches from up to 5
+      // nodes, rather than blocking on reading output from one node.
+      val minRequestSize = math.max(maxBytesInFlight / 5, 1L)
+      logInfo("maxBytesInFlight: " + maxBytesInFlight + ", minRequest: " + minRequestSize)
+      val iterator = blockInfos.iterator
+      var curRequestSize = 0L
+      var curBlocks = new ArrayBuffer[(String, Long)]
+      while (iterator.hasNext) {
+        val (blockId, size) = iterator.next()
+        curBlocks += ((blockId, size))
+        curRequestSize += size
+        if (curRequestSize >= minRequestSize) {
+          // Add this FetchRequest
+          remoteRequests += new FetchRequest(address, curBlocks)
+          curRequestSize = 0
+          curBlocks = new ArrayBuffer[(String, Long)]
+        }
+      }
+      // Add in the final request
+      if (!curBlocks.isEmpty) {
+        remoteRequests += new FetchRequest(address, curBlocks)
+      }
+    }
+  }
+  // Add the remote requests into our queue in a random order
+  fetchRequests ++= Utils.randomize(remoteRequests)
+
+  // Send out initial requests for blocks, up to our maxBytesInFlight
+  while (!fetchRequests.isEmpty &&
+    (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+    sendRequest(fetchRequests.dequeue())
+  }
+
+  val numGets = remoteBlockIds.size - fetchRequests.size
+  logInfo("Started " + numGets + " remote gets in " + Utils.getUsedTimeMs(startTime))
+
+  // Get the local blocks while remote blocks are being fetched. Note that it's okay to do
+  // these all at once because they will just memory-map some files, so they won't consume
+  // any memory that might exceed our maxBytesInFlight
+  startTime = System.currentTimeMillis
+  for (id <- localBlockIds) {
+    getLocal(id) match {
+      case Some(iter) => {
+        results.put(new FetchResult(id, 0, () => iter)) // Pass 0 as size since it's not in flight
+        logDebug("Got local block " + id)
+      }
+      case None => {
+        throw new BlockException(id, "Could not get block " + id + " from local machine")
+      }
+    }
+  }
+  logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+
+  // Implement the Iterator interface: read fetched blocks off the results queue as they arrive.
+  var resultsGotten = 0
+
+  def hasNext: Boolean = resultsGotten < totalBlocks
+
+  def next(): (String, Option[Iterator[Any]]) = {
+    resultsGotten += 1
+    val startFetchWait = System.currentTimeMillis()
+    val result = results.take()
+    val stopFetchWait = System.currentTimeMillis()
+    _remoteFetchWaitTime += (stopFetchWait - startFetchWait)
+    bytesInFlight -= result.size
+    while (!fetchRequests.isEmpty &&
+      (bytesInFlight == 0 || bytesInFlight + fetchRequests.front.size <= maxBytesInFlight)) {
+      sendRequest(fetchRequests.dequeue())
+    }
+    (result.blockId, if (result.failed) None else Some(result.deserialize()))
+  }
+
+
+  // BlockFetchTracker accessors, used to profile the block fetching
+  def numLocalBlocks = localBlockIds.size
+  def numRemoteBlocks = remoteBlockIds.size
+
+  def remoteFetchTime = _remoteFetchTime
+  def remoteFetchWaitTime = _remoteFetchWaitTime
+
+  def remoteBytesRead = _remoteBytesRead
+
+}
diff --git a/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
new file mode 100644
index 0000000000000000000000000000000000000000..5c491877bad1929cbda15eb414b252f99e00e382
--- /dev/null
+++ b/core/src/main/scala/spark/storage/DelegateBlockFetchTracker.scala
@@ -0,0 +1,12 @@
+package spark.storage
+
+private[spark] trait DelegateBlockFetchTracker extends BlockFetchTracker {
+  var delegate: BlockFetchTracker = _
+  def setDelegate(d: BlockFetchTracker) { delegate = d }
+  def totalBlocks = delegate.totalBlocks
+  def numLocalBlocks = delegate.numLocalBlocks
+  def numRemoteBlocks = delegate.numRemoteBlocks
+  def remoteFetchTime = delegate.remoteFetchTime
+  def remoteFetchWaitTime = delegate.remoteFetchWaitTime
+  def remoteBytesRead = delegate.remoteBytesRead
+}
diff --git a/core/src/main/scala/spark/util/CompletionIterator.scala b/core/src/main/scala/spark/util/CompletionIterator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..81391837805967141fcadf678e414a0dea7c7db6
--- /dev/null
+++ b/core/src/main/scala/spark/util/CompletionIterator.scala
@@ -0,0 +1,25 @@
+package spark.util
+
+/**
+ * Wrapper around an iterator which calls a completion method after it successfully iterates through all the elements
+ */
+abstract class CompletionIterator[+A, +I <: Iterator[A]](sub: I) extends Iterator[A] {
+  def next = sub.next
+  def hasNext = {
+    val r = sub.hasNext
+    if (!r) {
+      completion
+    }
+    r
+  }
+
+  def completion()
+}
+
+object CompletionIterator {
+  def apply[A, I <: Iterator[A]](sub: I, completionFunction: => Unit) : CompletionIterator[A,I] = {
+    new CompletionIterator[A,I](sub) {
+      def completion() = completionFunction
+    }
+  }
+}
\ No newline at end of file
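
// Minimal sketch of CompletionIterator in isolation: the completion callback runs once the
// wrapped iterator reports it has no more elements. The values and the println are
// illustrative; in this patch the callback is used to record shuffle read metrics in
// BlockStoreShuffleFetcher.
import spark.util.CompletionIterator

object CompletionIteratorDemo {
  def main(args: Array[String]) {
    val wrapped = CompletionIterator[Int, Iterator[Int]](Iterator(1, 2, 3), {
      println("done iterating")   // e.g. publish metrics or release a resource
    })
    wrapped.foreach(println)      // prints 1, 2, 3, then "done iterating"
  }
}
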
diff --git a/core/src/main/scala/spark/util/Distribution.scala b/core/src/main/scala/spark/util/Distribution.scala
new file mode 100644
index 0000000000000000000000000000000000000000..24738b43078740537b86dc41339d1ae159ed07bf
--- /dev/null
+++ b/core/src/main/scala/spark/util/Distribution.scala
@@ -0,0 +1,65 @@
+package spark.util
+
+import java.io.PrintStream
+
+/**
+ * Util for getting some stats from a small sample of numeric values, with some handy summary functions.
+ *
+ * Entirely in memory, not intended as a good way to compute stats over large data sets.
+ *
+ * Assumes you are giving it a non-empty set of data
+ */
+class Distribution(val data: Array[Double], val startIdx: Int, val endIdx: Int) {
+  require(startIdx < endIdx)
+  def this(data: Traversable[Double]) = this(data.toArray, 0, data.size)
+  java.util.Arrays.sort(data, startIdx, endIdx)
+  val length = endIdx - startIdx
+
+  val defaultProbabilities = Array(0,0.25,0.5,0.75,1.0)
+
+  /**
+   * Get the value of the distribution at the given probabilities.  Probabilities should be
+   * given from 0 to 1
+   * @param probabilities
+   */
+  def getQuantiles(probabilities: Traversable[Double] = defaultProbabilities) = {
+    probabilities.toIndexedSeq.map{p:Double => data(closestIndex(p))}
+  }
+
+  private def closestIndex(p: Double) = {
+    math.min((p * length).toInt + startIdx, endIdx - 1)
+  }
+
+  def showQuantiles(out: PrintStream = System.out) = {
+    out.println("min\t25%\t50%\t75%\tmax")
+    getQuantiles(defaultProbabilities).foreach{q => out.print(q + "\t")}
+    out.println
+  }
+
+  def statCounter = StatCounter(data.slice(startIdx, endIdx))
+
+  /**
+   * print a summary of this distribution to the given PrintStream.
+   * @param out
+   */
+  def summary(out: PrintStream = System.out) {
+    out.println(statCounter)
+    showQuantiles(out)
+  }
+}
+
+object Distribution {
+
+  def apply(data: Traversable[Double]): Option[Distribution] = {
+    if (data.size > 0)
+      Some(new Distribution(data))
+    else
+      None
+  }
+
+  def showQuantiles(out: PrintStream = System.out, quantiles: Traversable[Double]) {
+    out.println("min\t25%\t50%\t75%\tmax")
+    quantiles.foreach{q => out.print(q + "\t")}
+    out.println
+  }
+}
\ No newline at end of file
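
// Toy example of Distribution (the sample values are synthetic, the object name illustrative):
import spark.util.Distribution

object DistributionDemo {
  def main(args: Array[String]) {
    Distribution(Seq(3.0, 1.0, 4.0, 1.0, 5.0, 9.0)).foreach { d =>
      d.summary()   // prints the stat counter, then the min/25%/50%/75%/max quantiles
    }
  }
}
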
diff --git a/core/src/main/scala/spark/util/TimedIterator.scala b/core/src/main/scala/spark/util/TimedIterator.scala
new file mode 100644
index 0000000000000000000000000000000000000000..539b01f4ce47d3ff7237ca619d220aded7b04ee1
--- /dev/null
+++ b/core/src/main/scala/spark/util/TimedIterator.scala
@@ -0,0 +1,32 @@
+package spark.util
+
+/**
+ * A utility for tracking the total time an iterator takes to iterate through its elements.
+ *
+ * In general, this should only be used if you expect it to take a considerable amount of time
+ * (e.g. milliseconds) to get each element -- otherwise the timing won't be very accurate,
+ * and you are probably just adding overhead.
+ */
+class TimedIterator[+A](val sub: Iterator[A]) extends Iterator[A] {
+  private var netMillis = 0L
+  private var nElems = 0
+  def hasNext = {
+    val start = System.currentTimeMillis()
+    val r = sub.hasNext
+    val end = System.currentTimeMillis()
+    netMillis += (end - start)
+    r
+  }
+  def next = {
+    val start = System.currentTimeMillis()
+    val r = sub.next
+    val end = System.currentTimeMillis()
+    netMillis += (end - start)
+    nElems += 1
+    r
+  }
+
+  def getNetMillis = netMillis
+  def getAverageTimePerItem = netMillis / nElems.toDouble
+
+}
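
// Sketch of TimedIterator on a deliberately slow source; the sleep only makes the timing
// visible, and the object name is illustrative.
import spark.util.TimedIterator

object TimedIteratorDemo {
  def main(args: Array[String]) {
    val slow = (1 to 5).iterator.map { i => Thread.sleep(20); i }
    val timed = new TimedIterator(slow)
    timed.foreach(_ => ())
    println("net ms: " + timed.getNetMillis)
    println("avg ms per element: " + timed.getAverageTimePerItem)
  }
}
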
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index 9ffe7c5f992b6dfe82e77d596bcc81bea164be16..26e3ab72c0c2b233c31b9f2ecf5dd5ea292af678 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -423,7 +423,7 @@ public class JavaAPISuite implements Serializable {
   @Test
   public void iterator() {
     JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
-    TaskContext context = new TaskContext(0, 0, 0);
+    TaskContext context = new TaskContext(0, 0, 0, null);
     Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0), context).next().intValue());
   }
 
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index 8de490eb86f34c1393baea0e8422effae92129af..b3e6ab4c0f45e4357ebc1ba4fcc204115b3d2450 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -265,7 +265,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     assert(taskSet.tasks.size >= results.size)
     for ((result, i) <- results.zipWithIndex) {
       if (i < taskSet.tasks.size) {
-        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any]()))
+        runEvent(CompletionEvent(taskSet.tasks(i), result._1, result._2, Map[Long, Any](), null, null))
       }
     }
   }
@@ -463,14 +463,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val noAccum = Map[Long, Any]()
     // We rely on the event queue being ordered and increasing the generation number by 1
     // should be ignored for being too old
-    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
     // should work because it's a non-failed host
-    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum))
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
     // should be ignored for being too old
-    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum))
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
     taskSet.tasks(1).generation = newGeneration
     val secondStage = interceptStage(reduceRdd) {
-      runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum))
+      runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
     }
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
            Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
diff --git a/core/src/test/scala/spark/util/DistributionSuite.scala b/core/src/test/scala/spark/util/DistributionSuite.scala
new file mode 100644
index 0000000000000000000000000000000000000000..cc6249b1dda8de4d183e4786f6981bb09d609904
--- /dev/null
+++ b/core/src/test/scala/spark/util/DistributionSuite.scala
@@ -0,0 +1,25 @@
+package spark.util
+
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+/**
+ * Sanity checks for Distribution's summary statistics and quantiles.
+ */
+
+class DistributionSuite extends FunSuite with ShouldMatchers {
+  test("summary") {
+    val d = new Distribution((1 to 100).toArray.map{_.toDouble})
+    val stats = d.statCounter
+    stats.count should be (100)
+    stats.mean should be (50.5)
+    stats.sum should be (50 * 101)
+
+    val quantiles = d.getQuantiles()
+    quantiles(0) should be (1)
+    quantiles(1) should be (26)
+    quantiles(2) should be (51)
+    quantiles(3) should be (76)
+    quantiles(4) should be (100)
+  }
+}