From c75eb14fe52f6789430983471974e5ddf73aacbf Mon Sep 17 00:00:00 2001
From: Kay Ousterhout <kayo@yahoo-inc.com>
Date: Sun, 12 May 2013 15:30:02 -0700
Subject: [PATCH] Send task results through the block manager when larger than
 the Akka frame size.

Results small enough to fit in an Akka frame are still returned directly
in the task status update; larger results are now written to the
executor's block manager and fetched by the driver using their block ID,
instead of failing the job with TaskResultTooBigFailure.

This change adds an extra failure mode: a task can complete
successfully, but its result may be lost or flushed from the block
manager before the driver fetches it. Such a task is reported as
TaskResultLost and resubmitted.
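
For illustration, here is a minimal standalone Scala sketch of the
executor-side routing decision (BlockStoreSketch and the frame-size
argument are hypothetical stand-ins for SparkEnv.blockManager and the
Akka configuration; only the threshold check and the "taskresult_<id>"
block naming mirror this patch):

    import java.nio.ByteBuffer
    import scala.collection.mutable

    // Hypothetical stand-in for the executor's BlockManager; the real patch
    // stores the bytes with StorageLevel.MEMORY_AND_DISK_SER.
    class BlockStoreSketch {
      private val blocks = mutable.Map.empty[String, ByteBuffer]
      def putBytes(id: String, data: ByteBuffer): Unit = blocks(id) = data
      def getBytes(id: String): Option[ByteBuffer] = blocks.get(id)
    }

    sealed trait SentResult
    case class DirectResult(bytes: ByteBuffer) extends SentResult
    case class IndirectResult(blockId: String) extends SentResult

    object ResultRoutingSketch {
      // Results within the frame budget go back in the Akka status update;
      // larger ones are stored locally and only the block ID is sent.
      def route(taskId: Long, serialized: ByteBuffer, akkaFrameSize: Int,
          store: BlockStoreSketch): SentResult = {
        if (serialized.limit() >= akkaFrameSize - 1024) {
          val blockId = "taskresult_" + taskId
          store.putBytes(blockId, serialized)
          IndirectResult(blockId)
        } else {
          DirectResult(serialized)
        }
      }

      def main(args: Array[String]): Unit = {
        val store = new BlockStoreSketch
        println(route(0L, ByteBuffer.allocate(16), 128 * 1024, store))      // direct
        println(route(1L, ByteBuffer.allocate(1 << 20), 128 * 1024, store)) // indirect
      }
    }
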
---
 .../scala/org/apache/spark/SparkContext.scala |   2 +-
 .../org/apache/spark/TaskEndReason.scala      |   8 +-
 .../org/apache/spark/executor/Executor.scala  |  26 ++-
 .../apache/spark/scheduler/DAGScheduler.scala |   5 +-
 .../apache/spark/scheduler/TaskResult.scala   |  14 +-
 .../scheduler/cluster/ClusterScheduler.scala  |  56 ++++--
 .../cluster/ClusterTaskSetManager.scala       | 169 ++++++++----------
 .../apache/spark/scheduler/cluster/Pool.scala |   6 +-
 .../spark/scheduler/cluster/Schedulable.scala |   4 +-
 .../cluster/TaskResultResolver.scala          | 125 +++++++++++++
 .../scheduler/cluster/TaskSetManager.scala    |   2 -
 .../scheduler/local/LocalScheduler.scala      |   5 +-
 .../scheduler/local/LocalTaskSetManager.scala |  22 ++-
 .../apache/spark/storage/BlockManager.scala   |  27 ++-
 .../org/apache/spark/DistributedSuite.scala   |  13 --
 .../scheduler/TaskResultResolverSuite.scala   | 106 +++++++++++
 .../cluster/ClusterSchedulerSuite.scala       |  15 +-
 .../cluster/ClusterTaskSetManagerSuite.scala  |  12 +-
 18 files changed, 452 insertions(+), 165 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultResolver.scala
 create mode 100644 core/src/test/scala/org/apache/spark/scheduler/TaskResultResolverSuite.scala
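
A second standalone sketch (again with hypothetical helper names; only the
control flow mirrors the new TaskResultResolver.enqueueSuccessfulTask) shows
how the driver side turns a missing result block into the TaskResultLost
failure described in the commit message, so the scheduler can resubmit the
task:

    import java.nio.ByteBuffer
    import scala.collection.mutable

    object ResultResolutionSketch {
      sealed trait Outcome
      case class Succeeded(bytes: ByteBuffer) extends Outcome
      case object ResultLost extends Outcome // corresponds to TaskResultLost

      // "store" stands in for the remote BlockManager lookup. A block can be
      // missing if the executor died or its block manager evicted the result
      // before the driver fetched it.
      def resolveIndirect(blockId: String,
          store: mutable.Map[String, ByteBuffer]): Outcome = {
        store.get(blockId) match {
          case Some(bytes) =>
            store -= blockId // remove the result block once it has been fetched
            Succeeded(bytes)
          case None =>
            ResultLost // the scheduler marks the attempt failed and retries it
        }
      }

      def main(args: Array[String]): Unit = {
        val store = mutable.Map("taskresult_1" -> ByteBuffer.allocate(1 << 20))
        println(resolveIndirect("taskresult_1", store)) // Succeeded(...)
        println(resolveIndirect("taskresult_1", store)) // ResultLost (block gone)
      }
    }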

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 72540c712a..d9be6f71f2 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -145,7 +145,7 @@ class SparkContext(
   }
 
   // Create and start the scheduler
-  private var taskScheduler: TaskScheduler = {
+  private[spark] var taskScheduler: TaskScheduler = {
     // Regular expression used for local[N] master format
     val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
     // Regular expression for local[N, maxRetries], used in tests with failing tasks
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index 03bf268863..8466c2a004 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -46,6 +46,10 @@ private[spark] case class ExceptionFailure(
     metrics: Option[TaskMetrics])
   extends TaskEndReason
 
-private[spark] case class OtherFailure(message: String) extends TaskEndReason
+/**
+ * The task finished successfully, but the result was lost from the executor's block manager before
+ * it was fetched.
+ */
+private[spark] case object TaskResultLost extends TaskEndReason
 
-private[spark] case class TaskResultTooBigFailure() extends TaskEndReason
+private[spark] case class OtherFailure(message: String) extends TaskEndReason
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index ceae3b8289..acdb8d0343 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.executor
 
-import java.io.{File}
+import java.io.File
 import java.lang.management.ManagementFactory
 import java.nio.ByteBuffer
 import java.util.concurrent._
@@ -27,11 +27,11 @@ import scala.collection.mutable.HashMap
 
 import org.apache.spark.scheduler._
 import org.apache.spark._
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
-
 /**
- * The Mesos executor for Spark.
+ * Spark executor used with Mesos and the standalone scheduler.
  */
 private[spark] class Executor(
     executorId: String,
@@ -167,12 +167,20 @@ private[spark] class Executor(
         // we need to serialize the task metrics first.  If TaskMetrics had a custom serialized format, we could
         // just change the relevants bytes in the byte buffer
         val accumUpdates = Accumulators.values
-        val result = new TaskResult(value, accumUpdates, task.metrics.getOrElse(null))
-        val serializedResult = ser.serialize(result)
-        logInfo("Serialized size of result for " + taskId + " is " + serializedResult.limit)
-        if (serializedResult.limit >= (akkaFrameSize - 1024)) {
-          context.statusUpdate(taskId, TaskState.FAILED, ser.serialize(TaskResultTooBigFailure()))
-          return
+        val directResult = new DirectTaskResult(value, accumUpdates, task.metrics.getOrElse(null))
+        val serializedDirectResult = ser.serialize(directResult)
+        logInfo("Serialized size of result for " + taskId + " is " + serializedDirectResult.limit)
+        val serializedResult = {
+          if (serializedDirectResult.limit >= akkaFrameSize - 1024) {
+            logInfo("Storing result for " + taskId + " in local BlockManager")
+            val blockId = "taskresult_" + taskId
+            env.blockManager.putBytes(
+              blockId, serializedDirectResult, StorageLevel.MEMORY_AND_DISK_SER)
+            ser.serialize(new IndirectTaskResult[Any](blockId))
+          } else {
+            logInfo("Sending result for " + taskId + " directly to driver")
+            serializedDirectResult
+          }
         }
         context.statusUpdate(taskId, TaskState.FINISHED, serializedResult)
         logInfo("Finished task ID " + taskId)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 3e3f04f087..db998e499a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -553,7 +553,7 @@ class DAGScheduler(
         SparkEnv.get.closureSerializer.newInstance().serialize(tasks.head)
       } catch {
         case e: NotSerializableException =>
-          abortStage(stage, e.toString)
+          abortStage(stage, "Task not serializable: " + e.toString)
           running -= stage
           return
       }
@@ -705,6 +705,9 @@ class DAGScheduler(
       case ExceptionFailure(className, description, stackTrace, metrics) =>
         // Do nothing here, left up to the TaskScheduler to decide how to handle user failures
 
+      case TaskResultLost =>
+        // Do nothing here; the TaskScheduler handles these failures and resubmits the task.
+
       case other =>
         // Unrecognized failure - abort all jobs depending on this stage
         abortStage(stageIdToStage(task.stageId), task + " failed: " + other)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
index 5c7e5bb977..25a61b3115 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResult.scala
@@ -29,9 +29,17 @@ import org.apache.spark.util.Utils
 // TODO: Use of distributed cache to return result is a hack to get around
 // what seems to be a bug with messages over 60KB in libprocess; fix it
 private[spark]
-class TaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
-  extends Externalizable
-{
+sealed abstract class TaskResult[T]
+
+/** A reference to a DirectTaskResult that has been stored in the worker's BlockManager. */
+private[spark]
+case class IndirectTaskResult[T](blockId: String) extends TaskResult[T] with Serializable
+
+/** A TaskResult that contains the task's return value and accumulator updates. */
+private[spark]
+class DirectTaskResult[T](var value: T, var accumUpdates: Map[Long, Any], var metrics: TaskMetrics)
+  extends TaskResult[T] with Externalizable {
+
   def this() = this(null.asInstanceOf[T], null, null)
 
   override def writeExternal(out: ObjectOutput) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index 919acce828..db7c6001f1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -18,6 +18,9 @@
 package org.apache.spark.scheduler.cluster
 
 import java.lang.{Boolean => JBoolean}
+import java.nio.ByteBuffer
+import java.util.concurrent.atomic.AtomicLong
+import java.util.{TimerTask, Timer}
 
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
@@ -27,9 +30,7 @@ import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.scheduler._
 import org.apache.spark.scheduler.cluster.SchedulingMode.SchedulingMode
-import java.nio.ByteBuffer
-import java.util.concurrent.atomic.AtomicLong
-import java.util.{TimerTask, Timer}
+
 
 /**
  * The main TaskScheduler implementation, for running tasks on a cluster. Clients should first call
@@ -55,7 +56,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
   // Threshold above which we warn user initial TaskSet may be starved
   val STARVATION_TIMEOUT = System.getProperty("spark.starvation.timeout", "15000").toLong
 
-  val activeTaskSets = new HashMap[String, TaskSetManager]
+  // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized
+  // on this class.
+  val activeTaskSets = new HashMap[String, ClusterTaskSetManager]
 
   val taskIdToTaskSetId = new HashMap[Long, String]
   val taskIdToExecutorId = new HashMap[Long, String]
@@ -65,7 +68,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
   @volatile private var hasLaunchedTask = false
   private val starvationTimer = new Timer(true)
 
-  // Incrementing Mesos task IDs
+  // Incrementing task IDs
   val nextTaskId = new AtomicLong(0)
 
   // Which executor IDs we have executors on
@@ -96,6 +99,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
   val schedulingMode: SchedulingMode = SchedulingMode.withName(
     System.getProperty("spark.scheduler.mode", "FIFO"))
 
+  // This is a var so that we can reset it for testing purposes.
+  private[spark] var taskResultResolver = new TaskResultResolver(sc.env, this)
+
   override def setListener(listener: TaskSchedulerListener) {
     this.listener = listener
   }
@@ -234,7 +240,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
   }
 
   def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    var taskSetToUpdate: Option[TaskSetManager] = None
     var failedExecutor: Option[String] = None
     var taskFailed = false
     synchronized {
@@ -249,9 +254,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
         }
         taskIdToTaskSetId.get(tid) match {
           case Some(taskSetId) =>
-            if (activeTaskSets.contains(taskSetId)) {
-              taskSetToUpdate = Some(activeTaskSets(taskSetId))
-            }
             if (TaskState.isFinished(state)) {
               taskIdToTaskSetId.remove(tid)
               if (taskSetTaskIds.contains(taskSetId)) {
@@ -262,6 +264,15 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
             if (state == TaskState.FAILED) {
               taskFailed = true
             }
+            activeTaskSets.get(taskSetId).foreach { taskSet =>
+              if (state == TaskState.FINISHED) {
+                taskSet.removeRunningTask(tid)
+                taskResultResolver.enqueueSuccessfulTask(taskSet, tid, serializedData)
+              } else if (Set(TaskState.FAILED, TaskState.KILLED, TaskState.LOST).contains(state)) {
+                taskSet.removeRunningTask(tid)
+                taskResultResolver.enqueueFailedTask(taskSet, tid, state, serializedData)
+              }
+            }
           case None =>
             logInfo("Ignoring update from TID " + tid + " because its task set is gone")
         }
@@ -269,10 +280,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
         case e: Exception => logError("Exception in statusUpdate", e)
       }
     }
-    // Update the task set and DAGScheduler without holding a lock on this, since that can deadlock
-    if (taskSetToUpdate != None) {
-      taskSetToUpdate.get.statusUpdate(tid, state, serializedData)
-    }
+    // Update the DAGScheduler without holding a lock on this, since that can deadlock
     if (failedExecutor != None) {
       listener.executorLost(failedExecutor.get)
       backend.reviveOffers()
@@ -283,6 +291,25 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     }
   }
 
+  def handleSuccessfulTask(
+    taskSetManager: ClusterTaskSetManager,
+    tid: Long,
+    taskResult: DirectTaskResult[_]) = synchronized {
+    taskSetManager.handleSuccessfulTask(tid, taskResult)
+  }
+
+  def handleFailedTask(
+    taskSetManager: ClusterTaskSetManager,
+    tid: Long,
+    taskState: TaskState,
+    reason: Option[TaskEndReason]) = synchronized {
+    taskSetManager.handleFailedTask(tid, taskState, reason)
+    if (taskState == TaskState.FINISHED) {
+      // The task finished successfully but the result was lost, so we should revive offers.
+      backend.reviveOffers()
+    }
+  }
+
   def error(message: String) {
     synchronized {
       if (activeTaskSets.size > 0) {
@@ -311,6 +338,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
     if (jarServer != null) {
       jarServer.stop()
     }
+    if (taskResultResolver != null) {
+      taskResultResolver.stop()
+    }
 
     // sleeping for an arbitrary 5 seconds : to ensure that messages are sent out.
     // TODO: Do something better !
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 0ac3d7bcfd..25e6f0a3ac 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -25,15 +25,13 @@ import scala.collection.mutable.HashMap
 import scala.collection.mutable.HashSet
 import scala.math.max
 import scala.math.min
+import scala.Some
 
-import org.apache.spark.{FetchFailed, Logging, Resubmitted, SparkEnv, Success, TaskEndReason, TaskState}
-import org.apache.spark.{ExceptionFailure, SparkException, TaskResultTooBigFailure}
+import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.scheduler._
-import scala.Some
 import org.apache.spark.FetchFailed
 import org.apache.spark.ExceptionFailure
-import org.apache.spark.TaskResultTooBigFailure
 import org.apache.spark.util.{SystemClock, Clock}
 
 
@@ -71,18 +69,20 @@ private[spark] class ClusterTaskSetManager(
   val tasks = taskSet.tasks
   val numTasks = tasks.length
   val copiesRunning = new Array[Int](numTasks)
-  val finished = new Array[Boolean](numTasks)
+  val successful = new Array[Boolean](numTasks)
   val numFailures = new Array[Int](numTasks)
   val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
-  var tasksFinished = 0
+  var tasksSuccessful = 0
 
   var weight = 1
   var minShare = 0
-  var runningTasks = 0
   var priority = taskSet.priority
   var stageId = taskSet.stageId
   var name = "TaskSet_"+taskSet.stageId.toString
-  var parent: Schedulable = null
+  var parent: Pool = null
+
+  var runningTasks = 0
+  private val runningTasksSet = new HashSet[Long]
 
   // Set of pending tasks for each executor. These collections are actually
   // treated as stacks, in which new tasks are added to the end of the
@@ -223,7 +223,7 @@ private[spark] class ClusterTaskSetManager(
     while (!list.isEmpty) {
       val index = list.last
       list.trimEnd(1)
-      if (copiesRunning(index) == 0 && !finished(index)) {
+      if (copiesRunning(index) == 0 && !successful(index)) {
         return Some(index)
       }
     }
@@ -243,7 +243,7 @@ private[spark] class ClusterTaskSetManager(
   private def findSpeculativeTask(execId: String, host: String, locality: TaskLocality.Value)
     : Option[(Int, TaskLocality.Value)] =
   {
-    speculatableTasks.retain(index => !finished(index)) // Remove finished tasks from set
+    speculatableTasks.retain(index => !successful(index)) // Remove finished tasks from set
 
     if (!speculatableTasks.isEmpty) {
       // Check for process-local or preference-less tasks; note that tasks can be process-local
@@ -344,7 +344,7 @@ private[spark] class ClusterTaskSetManager(
       maxLocality: TaskLocality.TaskLocality)
     : Option[TaskDescription] =
   {
-    if (tasksFinished < numTasks && availableCpus >= CPUS_PER_TASK) {
+    if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) {
       val curTime = clock.getTime()
 
       var allowedLocality = getAllowedLocalityLevel(curTime)
@@ -375,7 +375,7 @@ private[spark] class ClusterTaskSetManager(
           val serializedTask = Task.serializeWithDependencies(
             task, sched.sc.addedFiles, sched.sc.addedJars, ser)
           val timeTaken = clock.getTime() - startTime
-          increaseRunningTasks(1)
+          addRunningTask(taskId)
           logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
             taskSet.id, index, serializedTask.limit, timeTaken))
           val taskName = "task %s:%d".format(taskSet.id, index)
@@ -417,94 +417,63 @@ private[spark] class ClusterTaskSetManager(
     index
   }
 
-  /** Called by cluster scheduler when one of our tasks changes state */
-  override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
-    SparkEnv.set(env)
-    state match {
-      case TaskState.FINISHED =>
-        taskFinished(tid, state, serializedData)
-      case TaskState.LOST =>
-        taskLost(tid, state, serializedData)
-      case TaskState.FAILED =>
-        taskLost(tid, state, serializedData)
-      case TaskState.KILLED =>
-        taskLost(tid, state, serializedData)
-      case _ =>
-    }
-  }
-
-  def taskStarted(task: Task[_], info: TaskInfo) {
+  private def taskStarted(task: Task[_], info: TaskInfo) {
     sched.listener.taskStarted(task, info)
   }
 
-  def taskFinished(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+  /**
+   * Marks the task as successful and notifies the listener that a task has ended.
+   */
+  def handleSuccessfulTask(tid: Long, result: DirectTaskResult[_]) = {
     val info = taskInfos(tid)
-    if (info.failed) {
-      // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
-      // or even from Mesos itself when acks get delayed.
-      return
-    }
     val index = info.index
     info.markSuccessful()
-    decreaseRunningTasks(1)
-    if (!finished(index)) {
-      tasksFinished += 1
+    removeRunningTask(tid)
+    if (!successful(index)) {
       logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format(
-        tid, info.duration, info.host, tasksFinished, numTasks))
-      // Deserialize task result and pass it to the scheduler
-      try {
-        val result = ser.deserialize[TaskResult[_]](serializedData)
-        result.metrics.resultSize = serializedData.limit()
-        sched.listener.taskEnded(
-          tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
-      } catch {
-        case cnf: ClassNotFoundException =>
-          val loader = Thread.currentThread().getContextClassLoader
-          throw new SparkException("ClassNotFound with classloader: " + loader, cnf)
-        case ex => throw ex
-      }
-      // Mark finished and stop if we've finished all the tasks
-      finished(index) = true
-      if (tasksFinished == numTasks) {
+        tid, info.duration, info.host, tasksSuccessful, numTasks))
+      sched.listener.taskEnded(
+        tasks(index), Success, result.value, result.accumUpdates, info, result.metrics)
+
+      // Mark successful and stop if all the tasks have succeeded.
+      tasksSuccessful += 1
+      successful(index) = true
+      if (tasksSuccessful == numTasks) {
         sched.taskSetFinished(this)
       }
     } else {
-      logInfo("Ignoring task-finished event for TID " + tid +
-        " because task " + index + " is already finished")
+      logInfo("Ignorning task-finished event for TID " + tid + " because task " +
+        index + " has already completed successfully")
     }
   }
 
-  def taskLost(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+  /**
+   * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the listener.
+   */
+  def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) {
     val info = taskInfos(tid)
     if (info.failed) {
-      // We might get two task-lost messages for the same task in coarse-grained Mesos mode,
-      // or even from Mesos itself when acks get delayed.
       return
     }
+    removeRunningTask(tid)
     val index = info.index
     info.markFailed()
-    decreaseRunningTasks(1)
-    if (!finished(index)) {
+    // Count failed attempts only on FAILED and LOST state (not on KILLED)
+    var countFailedTaskAttempt = (state == TaskState.FAILED || state == TaskState.LOST)
+    if (!successful(index)) {
       logInfo("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index))
       copiesRunning(index) -= 1
       // Check if the problem is a map output fetch failure. In that case, this
       // task will never succeed on any node, so tell the scheduler about it.
-      if (serializedData != null && serializedData.limit() > 0) {
-        val reason = ser.deserialize[TaskEndReason](serializedData, getClass.getClassLoader)
-        reason match {
+      reason.foreach {
+        _ match {
           case fetchFailed: FetchFailed =>
             logInfo("Loss was due to fetch failure from " + fetchFailed.bmAddress)
             sched.listener.taskEnded(tasks(index), fetchFailed, null, null, info, null)
-            finished(index) = true
-            tasksFinished += 1
+            successful(index) = true
+            tasksSuccessful += 1
             sched.taskSetFinished(this)
-            decreaseRunningTasks(runningTasks)
-            return
-
-          case taskResultTooBig: TaskResultTooBigFailure =>
-            logInfo("Loss was due to task %s result exceeding Akka frame size; aborting job".format(
-              tid))
-            abort("Task %s result exceeded Akka frame size".format(tid))
+            removeAllRunningTasks()
             return
 
           case ef: ExceptionFailure =>
@@ -534,13 +503,17 @@ private[spark] class ClusterTaskSetManager(
               logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount))
             }
 
+          case TaskResultLost =>
+            logInfo("Lost result for TID %s on host %s".format(tid, info.host))
+            countFailedTaskAttempt = true
+            sched.listener.taskEnded(tasks(index), TaskResultLost, null, null, info, null)
+
           case _ => {}
         }
       }
       // On non-fetch failures, re-enqueue the task as pending for a max number of retries
       addPendingTask(index)
-      // Count failed attempts only on FAILED and LOST state (not on KILLED)
-      if (state == TaskState.FAILED || state == TaskState.LOST) {
+      if (countFailedTaskAttempt) {
         numFailures(index) += 1
         if (numFailures(index) > MAX_TASK_FAILURES) {
           logError("Task %s:%d failed more than %d times; aborting job".format(
@@ -564,22 +537,36 @@ private[spark] class ClusterTaskSetManager(
     causeOfFailure = message
     // TODO: Kill running tasks if we were not terminated due to a Mesos error
     sched.listener.taskSetFailed(taskSet, message)
-    decreaseRunningTasks(runningTasks)
+    removeAllRunningTasks()
     sched.taskSetFinished(this)
   }
 
-  override def increaseRunningTasks(taskNum: Int) {
-    runningTasks += taskNum
-    if (parent != null) {
-      parent.increaseRunningTasks(taskNum)
+  /** If the given task ID is not in the set of running tasks, adds it.
+   *
+   * Used to keep track of the number of running tasks, for enforcing scheduling policies.
+   */
+  def addRunningTask(tid: Long) {
+    if (runningTasksSet.add(tid) && parent != null) {
+      parent.increaseRunningTasks(1)
+    }
+    runningTasks = runningTasksSet.size
+  }
+
+  /** If the given task ID is in the set of running tasks, removes it. */
+  def removeRunningTask(tid: Long) {
+    if (runningTasksSet.remove(tid) && parent != null) {
+      parent.decreaseRunningTasks(1)
     }
+    runningTasks = runningTasksSet.size
   }
 
-  override def decreaseRunningTasks(taskNum: Int) {
-    runningTasks -= taskNum
+  private def removeAllRunningTasks() {
+    val numRunningTasks = runningTasksSet.size
+    runningTasksSet.clear()
     if (parent != null) {
-      parent.decreaseRunningTasks(taskNum)
+      parent.decreaseRunningTasks(numRunningTasks)
     }
+    runningTasks = 0
   }
 
   override def getSchedulableByName(name: String): Schedulable = {
@@ -615,10 +602,10 @@ private[spark] class ClusterTaskSetManager(
     if (tasks(0).isInstanceOf[ShuffleMapTask]) {
       for ((tid, info) <- taskInfos if info.executorId == execId) {
         val index = taskInfos(tid).index
-        if (finished(index)) {
-          finished(index) = false
+        if (successful(index)) {
+          successful(index) = false
           copiesRunning(index) -= 1
-          tasksFinished -= 1
+          tasksSuccessful -= 1
           addPendingTask(index)
           // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our
           // stage finishes when a total of tasks.size tasks finish.
@@ -628,7 +615,7 @@ private[spark] class ClusterTaskSetManager(
     }
     // Also re-enqueue any tasks that were running on the node
     for ((tid, info) <- taskInfos if info.running && info.executorId == execId) {
-      taskLost(tid, TaskState.KILLED, null)
+      handleFailedTask(tid, TaskState.KILLED, None)
     }
   }
 
@@ -641,13 +628,13 @@ private[spark] class ClusterTaskSetManager(
    */
   override def checkSpeculatableTasks(): Boolean = {
     // Can't speculate if we only have one task, or if all tasks have finished.
-    if (numTasks == 1 || tasksFinished == numTasks) {
+    if (numTasks == 1 || tasksSuccessful == numTasks) {
       return false
     }
     var foundTasks = false
     val minFinishedForSpeculation = (SPECULATION_QUANTILE * numTasks).floor.toInt
     logDebug("Checking for speculative tasks: minFinished = " + minFinishedForSpeculation)
-    if (tasksFinished >= minFinishedForSpeculation) {
+    if (tasksSuccessful >= minFinishedForSpeculation) {
       val time = clock.getTime()
       val durations = taskInfos.values.filter(_.successful).map(_.duration).toArray
       Arrays.sort(durations)
@@ -658,7 +645,7 @@ private[spark] class ClusterTaskSetManager(
       logDebug("Task length threshold for speculation: " + threshold)
       for ((tid, info) <- taskInfos) {
         val index = info.index
-        if (!finished(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
+        if (!successful(index) && copiesRunning(index) == 1 && info.timeRunning(time) > threshold &&
           !speculatableTasks.contains(index)) {
           logInfo(
             "Marking task %s:%d (on %s) as speculatable because it ran more than %.0f ms".format(
@@ -672,7 +659,7 @@ private[spark] class ClusterTaskSetManager(
   }
 
   override def hasPendingTasks(): Boolean = {
-    numTasks > 0 && tasksFinished < numTasks
+    numTasks > 0 && tasksSuccessful < numTasks
   }
 
   private def getLocalityWait(level: TaskLocality.TaskLocality): Long = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala
index 35b32600da..199a0521ff 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/Pool.scala
@@ -45,7 +45,7 @@ private[spark] class Pool(
   var priority = 0
   var stageId = 0
   var name = poolName
-  var parent:Schedulable = null
+  var parent: Pool = null
 
   var taskSetSchedulingAlgorithm: SchedulingAlgorithm = {
     schedulingMode match {
@@ -101,14 +101,14 @@ private[spark] class Pool(
     return sortedTaskSetQueue
   }
 
-  override def increaseRunningTasks(taskNum: Int) {
+  def increaseRunningTasks(taskNum: Int) {
     runningTasks += taskNum
     if (parent != null) {
       parent.increaseRunningTasks(taskNum)
     }
   }
 
-  override def decreaseRunningTasks(taskNum: Int) {
+  def decreaseRunningTasks(taskNum: Int) {
     runningTasks -= taskNum
     if (parent != null) {
       parent.decreaseRunningTasks(taskNum)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala
index f4726450ec..171549fbd9 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/Schedulable.scala
@@ -25,7 +25,7 @@ import scala.collection.mutable.ArrayBuffer
  * there are two type of Schedulable entities(Pools and TaskSetManagers)
  */
 private[spark] trait Schedulable {
-  var parent: Schedulable
+  var parent: Pool
   // child queues
   def schedulableQueue: ArrayBuffer[Schedulable]
   def schedulingMode: SchedulingMode
@@ -36,8 +36,6 @@ private[spark] trait Schedulable {
   def stageId: Int
   def name: String
 
-  def increaseRunningTasks(taskNum: Int): Unit
-  def decreaseRunningTasks(taskNum: Int): Unit
   def addSchedulable(schedulable: Schedulable): Unit
   def removeSchedulable(schedulable: Schedulable): Unit
   def getSchedulableByName(name: String): Schedulable
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultResolver.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultResolver.scala
new file mode 100644
index 0000000000..812a9cf695
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultResolver.scala
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import java.nio.ByteBuffer
+import java.util.concurrent.{LinkedBlockingDeque, ThreadFactory, ThreadPoolExecutor, TimeUnit}
+
+import org.apache.spark._
+import org.apache.spark.TaskState.TaskState
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, TaskResult}
+import org.apache.spark.serializer.SerializerInstance
+
+/**
+ * Runs a thread pool that deserializes and remotely fetches (if necessary) task results.
+ */
+private[spark] class TaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+  extends Logging {
+  private val MIN_THREADS = 20
+  private val MAX_THREADS = 60
+  private val KEEP_ALIVE_SECONDS = 60
+  private val getTaskResultExecutor = new ThreadPoolExecutor(
+    MIN_THREADS,
+    MAX_THREADS,
+    KEEP_ALIVE_SECONDS,
+    TimeUnit.SECONDS,
+    new LinkedBlockingDeque[Runnable],
+    new ResultResolverThreadFactory)
+
+  class ResultResolverThreadFactory extends ThreadFactory {
+    private var counter = 0
+    private val PREFIX = "Result resolver thread"
+
+    override def newThread(r: Runnable): Thread = {
+      val thread = new Thread(r, "%s-%s".format(PREFIX, counter))
+      counter += 1
+      thread.setDaemon(true)
+      return thread
+    }
+  }
+
+  protected val serializer = new ThreadLocal[SerializerInstance] {
+    override def initialValue(): SerializerInstance = {
+      return sparkEnv.closureSerializer.newInstance()
+    }
+  }
+
+  def enqueueSuccessfulTask(
+    taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+    getTaskResultExecutor.execute(new Runnable {
+      override def run() {
+        try {
+          val result = serializer.get().deserialize[TaskResult[_]](serializedData) match {
+            case directResult: DirectTaskResult[_] => directResult
+            case IndirectTaskResult(blockId) =>
+              logDebug("Fetching indirect task result for TID %s".format(tid))
+              val serializedTaskResult = sparkEnv.blockManager.getRemoteBytes(blockId)
+              if (!serializedTaskResult.isDefined) {
+                /* We won't be able to get the task result if the machine that ran the task failed
+                 * between when the task ended and when we tried to fetch the result, or if the
+                 * block manager had to flush the result. */
+                scheduler.handleFailedTask(
+                  taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost))
+                return
+              }
+              val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]](
+                serializedTaskResult.get)
+              sparkEnv.blockManager.master.removeBlock(blockId)
+              deserializedResult
+          }
+          result.metrics.resultSize = serializedData.limit()
+          scheduler.handleSuccessfulTask(taskSetManager, tid, result)
+        } catch {
+          case cnf: ClassNotFoundException =>
+            val loader = Thread.currentThread.getContextClassLoader
+            taskSetManager.abort("ClassNotFound with classloader: " + loader)
+          case ex =>
+            taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
+        }
+      }
+    })
+  }
+
+  def enqueueFailedTask(taskSetManager: ClusterTaskSetManager, tid: Long, taskState: TaskState,
+    serializedData: ByteBuffer) {
+    var reason: Option[TaskEndReason] = None
+    getTaskResultExecutor.execute(new Runnable {
+      override def run() {
+        try {
+          if (serializedData != null && serializedData.limit() > 0) {
+            reason = Some(serializer.get().deserialize[TaskEndReason](
+              serializedData, getClass.getClassLoader))
+          }
+        } catch {
+          case cnd: ClassNotFoundException =>
+            // Log an error but keep going here -- the task failed, so not catastrophic if we can't
+            // deserialize the reason.
+            val loader = Thread.currentThread.getContextClassLoader
+            logError(
+              "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
+          case ex => {}
+        }
+        scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
+      }
+    })
+  }
+
+  def stop() {
+    getTaskResultExecutor.shutdownNow()
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala
index 648a3ef922..a0f3758a24 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskSetManager.scala
@@ -45,7 +45,5 @@ private[spark] trait TaskSetManager extends Schedulable {
       maxLocality: TaskLocality.TaskLocality)
     : Option[TaskDescription]
 
-  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer)
-
   def error(message: String)
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index 8cb4d1396f..bcf9e1baf2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -92,7 +92,7 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
   var rootPool: Pool = null
   val schedulingMode: SchedulingMode = SchedulingMode.withName(
     System.getProperty("spark.scheduler.mode", "FIFO"))
-  val activeTaskSets = new HashMap[String, TaskSetManager]
+  val activeTaskSets = new HashMap[String, LocalTaskSetManager]
   val taskIdToTaskSetId = new HashMap[Long, String]
   val taskSetTaskIds = new HashMap[String, HashSet[Long]]
 
@@ -211,7 +211,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
       deserializedTask.metrics.get.executorRunTime = serviceTime.toInt
       deserializedTask.metrics.get.jvmGCTime = getTotalGCTime - startGCTime
       deserializedTask.metrics.get.executorDeserializeTime = deserTime.toInt
-      val taskResult = new TaskResult(result, accumUpdates, deserializedTask.metrics.getOrElse(null))
+      val taskResult = new DirectTaskResult(
+        result, accumUpdates, deserializedTask.metrics.getOrElse(null))
       val serializedResult = ser.serialize(taskResult)
       localActor ! LocalStatusUpdate(taskId, TaskState.FINISHED, serializedResult)
     } catch {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
index e52cb998bd..de0fd5a528 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalTaskSetManager.scala
@@ -21,16 +21,17 @@ import java.nio.ByteBuffer
 import scala.collection.mutable.ArrayBuffer
 import scala.collection.mutable.HashMap
 
-import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, Success, TaskState}
+import org.apache.spark.{ExceptionFailure, Logging, SparkEnv, SparkException, Success, TaskState}
 import org.apache.spark.TaskState.TaskState
-import org.apache.spark.scheduler.{Task, TaskResult, TaskSet}
-import org.apache.spark.scheduler.cluster.{Schedulable, TaskDescription, TaskInfo, TaskLocality, TaskSetManager}
+import org.apache.spark.scheduler.{DirectTaskResult, IndirectTaskResult, Task, TaskResult, TaskSet}
+import org.apache.spark.scheduler.cluster.{Pool, Schedulable, TaskDescription, TaskInfo}
+import org.apache.spark.scheduler.cluster.{TaskLocality, TaskSetManager}
 
 
 private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: TaskSet)
   extends TaskSetManager with Logging {
 
-  var parent: Schedulable = null
+  var parent: Pool = null
   var weight: Int = 1
   var minShare: Int = 0
   var runningTasks: Int = 0
@@ -49,14 +50,14 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
   val numFailures = new Array[Int](numTasks)
   val MAX_TASK_FAILURES = sched.maxFailures
 
-  override def increaseRunningTasks(taskNum: Int): Unit = {
+  def increaseRunningTasks(taskNum: Int): Unit = {
     runningTasks += taskNum
     if (parent != null) {
      parent.increaseRunningTasks(taskNum)
     }
   }
 
-  override def decreaseRunningTasks(taskNum: Int): Unit = {
+  def decreaseRunningTasks(taskNum: Int): Unit = {
     runningTasks -= taskNum
     if (parent != null) {
       parent.decreaseRunningTasks(taskNum)
@@ -132,7 +133,7 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     return None
   }
 
-  override def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
+  def statusUpdate(tid: Long, state: TaskState, serializedData: ByteBuffer) {
     SparkEnv.set(env)
     state match {
       case TaskState.FINISHED =>
@@ -152,7 +153,12 @@ private[spark] class LocalTaskSetManager(sched: LocalScheduler, val taskSet: Tas
     val index = info.index
     val task = taskSet.tasks(index)
     info.markSuccessful()
-    val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader)
+    val result = ser.deserialize[TaskResult[_]](serializedData, getClass.getClassLoader) match {
+      case directResult: DirectTaskResult[_] => directResult
+      case IndirectTaskResult(blockId) => {
+        throw new SparkException("Expect only DirectTaskResults when using LocalScheduler")
+      }
+    }
     result.metrics.resultSize = serializedData.limit()
     sched.listener.taskEnded(task, Success, result.value, result.accumUpdates, info, result.metrics)
     numFinished += 1
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 60fdc5f2ee..495a72db69 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -484,7 +484,7 @@ private[spark] class BlockManager(
     for (loc <- locations) {
       logDebug("Getting remote block " + blockId + " from " + loc)
       val data = BlockManagerWorker.syncGetBlock(
-          GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
+        GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
       if (data != null) {
         return Some(dataDeserialize(blockId, data))
       }
@@ -494,6 +494,31 @@ private[spark] class BlockManager(
     return None
   }
 
+  /**
+   * Get block from remote block managers as serialized bytes.
+   */
+  def getRemoteBytes(blockId: String): Option[ByteBuffer] = {
+    // TODO: As with getLocalBytes, this is very similar to getRemote and perhaps should be
+    // refactored.
+    if (blockId == null) {
+      throw new IllegalArgumentException("Block Id is null")
+    }
+    logDebug("Getting remote block " + blockId + " as bytes")
+
+    val locations = master.getLocations(blockId)
+    for (loc <- locations) {
+      logDebug("Getting remote block " + blockId + " from " + loc)
+      val data = BlockManagerWorker.syncGetBlock(
+        GetBlock(blockId), ConnectionManagerId(loc.host, loc.port))
+      if (data != null) {
+        return Some(data)
+      }
+      logDebug("The value of block " + blockId + " is null")
+    }
+    logDebug("Block " + blockId + " not found")
+    return None
+  }
+
   /**
    * Get a block from the block manager (either local or remote).
    */
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 7a856d4081..cd2bf9a8ff 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -319,19 +319,6 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
       }
     }
   }
-
-  test("job should fail if TaskResult exceeds Akka frame size") {
-    // We must use local-cluster mode since results are returned differently
-    // when running under LocalScheduler:
-    sc = new SparkContext("local-cluster[1,1,512]", "test")
-    val akkaFrameSize =
-      sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
-    val rdd = sc.parallelize(Seq(1)).map{x => new Array[Byte](akkaFrameSize)}
-    val exception = intercept[SparkException] {
-      rdd.reduce((x, y) => x)
-    }
-    exception.getMessage should endWith("result exceeded Akka frame size")
-  }
 }
 
 object DistributedSuite {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultResolverSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultResolverSuite.scala
new file mode 100644
index 0000000000..ff058c13ab
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultResolverSuite.scala
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import java.nio.ByteBuffer
+
+import org.scalatest.BeforeAndAfter
+import org.scalatest.FunSuite
+
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkEnv}
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, ClusterTaskSetManager, TaskResultResolver}
+
+/**
+ * Removes the TaskResult from the BlockManager before delegating to a normal TaskResultResolver.
+ *
+ * Used to test the case where a BlockManager evicts the task result (or dies) before the
+ * TaskResult is retrieved.
+ */
+class ResultDeletingTaskResultResolver(sparkEnv: SparkEnv, scheduler: ClusterScheduler)
+  extends TaskResultResolver(sparkEnv, scheduler) {
+  var removedResult = false
+
+  override def enqueueSuccessfulTask(
+    taskSetManager: ClusterTaskSetManager, tid: Long, serializedData: ByteBuffer) {
+    if (!removedResult) {
+      // Only remove the result once, since we'd like to test the case where the task eventually
+      // succeeds.
+      serializer.get().deserialize[TaskResult[_]](serializedData) match {
+        case IndirectTaskResult(blockId) =>
+          sparkEnv.blockManager.master.removeBlock(blockId)
+        case directResult: DirectTaskResult[_] =>
+          taskSetManager.abort("Expect only indirect results")
+      }
+      serializedData.rewind()
+      removedResult = true
+    }
+    super.enqueueSuccessfulTask(taskSetManager, tid, serializedData)
+  }
+}
+
+/**
+ * Tests related to handling task results (both direct and indirect).
+ */
+class TaskResultResolverSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
+
+  before {
+    // Set the Akka frame size to be as small as possible (it must be an integer, so 1 is as small
+    // as we can make it) so the tests don't take too long.
+    System.setProperty("spark.akka.frameSize", "1")
+    // Use local-cluster mode because results are returned differently when running with the
+    // LocalScheduler.
+    sc = new SparkContext("local-cluster[1,1,512]", "test")
+  }
+
+  test("handling results smaller than Akka frame size") {
+    val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
+    assert(result === 2)
+  }
+
+  test("handling results larger than Akka frame size") { 
+    val akkaFrameSize =
+      sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+    val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
+    assert(result === 1.to(akkaFrameSize).toArray)
+
+    val RESULT_BLOCK_ID = "taskresult_0"
+    assert(sc.env.blockManager.master.getLocations(RESULT_BLOCK_ID).size === 0,
+      "Expect result to be removed from the block manager.")
+  }
+
+  test("task retried if result missing from block manager") {
+    // If this test hangs, it's probably because no resource offers were made after the task
+    // failed.
+    val scheduler: ClusterScheduler = sc.taskScheduler match {
+      case clusterScheduler: ClusterScheduler =>
+        clusterScheduler
+      case _ =>
+        assert(false, "Expect local cluster to use ClusterScheduler")
+        throw new ClassCastException
+    }
+    scheduler.taskResultResolver = new ResultDeletingTaskResultResolver(sc.env, scheduler)
+    val akkaFrameSize =
+      sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+    val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
+    assert(result === 1.to(akkaFrameSize).toArray)
+
+    // Make sure two tasks were run (one failed one, and a second retried one).
+    assert(scheduler.nextTaskId.get() === 2)
+  }
+}
+
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
index 1b50ce06b3..95d3553d91 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterSchedulerSuite.scala
@@ -43,16 +43,16 @@ class FakeTaskSetManager(
   stageId = initStageId
   name = "TaskSet_"+stageId
   override val numTasks = initNumTasks
-  tasksFinished = 0
+  tasksSuccessful = 0
 
-  override def increaseRunningTasks(taskNum: Int) {
+  def increaseRunningTasks(taskNum: Int) {
     runningTasks += taskNum
     if (parent != null) {
       parent.increaseRunningTasks(taskNum)
     }
   }
 
-  override def decreaseRunningTasks(taskNum: Int) {
+  def decreaseRunningTasks(taskNum: Int) {
     runningTasks -= taskNum
     if (parent != null) {
       parent.decreaseRunningTasks(taskNum)
@@ -79,7 +79,7 @@ class FakeTaskSetManager(
       maxLocality: TaskLocality.TaskLocality)
     : Option[TaskDescription] =
   {
-    if (tasksFinished + runningTasks < numTasks) {
+    if (tasksSuccessful + runningTasks < numTasks) {
       increaseRunningTasks(1)
       return Some(new TaskDescription(0, execId, "task 0:0", 0, null))
     }
@@ -92,8 +92,8 @@ class FakeTaskSetManager(
 
   def taskFinished() {
     decreaseRunningTasks(1)
-    tasksFinished +=1
-    if (tasksFinished == numTasks) {
+    tasksSuccessful +=1
+    if (tasksSuccessful == numTasks) {
       parent.removeSchedulable(this)
     }
   }
@@ -114,7 +114,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
     val taskSetQueue = rootPool.getSortedTaskSetQueue()
     /* Just for Test*/
     for (manager <- taskSetQueue) {
-       logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
+       logInfo("parentName:%s, parent running tasks:%d, name:%s,runningTasks:%d".format(
+         manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
     }
     for (taskSet <- taskSetQueue) {
       taskSet.resourceOffer("execId_1", "hostname_1", 1, TaskLocality.ANY) match {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index ff70a2cdf0..ef99651b80 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -101,7 +101,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
     assert(manager.resourceOffer("exec1", "host1", 2, PROCESS_LOCAL) === None)
 
     // Tell it the task has finished
-    manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
+    manager.handleSuccessfulTask(0, createTaskResult(0))
     assert(sched.endedTasks(0) === Success)
     assert(sched.finishedManagers.contains(manager))
   }
@@ -125,14 +125,14 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
     assert(manager.resourceOffer("exec1", "host1", 1, PROCESS_LOCAL) === None)
 
     // Finish the first two tasks
-    manager.statusUpdate(0, TaskState.FINISHED, createTaskResult(0))
-    manager.statusUpdate(1, TaskState.FINISHED, createTaskResult(1))
+    manager.handleSuccessfulTask(0, createTaskResult(0))
+    manager.handleSuccessfulTask(1, createTaskResult(1))
     assert(sched.endedTasks(0) === Success)
     assert(sched.endedTasks(1) === Success)
     assert(!sched.finishedManagers.contains(manager))
 
     // Finish the last task
-    manager.statusUpdate(2, TaskState.FINISHED, createTaskResult(2))
+    manager.handleSuccessfulTask(2, createTaskResult(2))
     assert(sched.endedTasks(2) === Success)
     assert(sched.finishedManagers.contains(manager))
   }
@@ -267,7 +267,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
     new TaskSet(tasks, 0, 0, 0, null)
   }
 
-  def createTaskResult(id: Int): ByteBuffer = {
-    ByteBuffer.wrap(Utils.serialize(new TaskResult[Int](id, mutable.Map.empty, new TaskMetrics)))
+  def createTaskResult(id: Int): DirectTaskResult[Int] = {
+    new DirectTaskResult[Int](id, mutable.Map.empty, new TaskMetrics)
   }
 }
-- 
GitLab