diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index faf6dcd6186237b5d58a4612d78dc81bda2b54da..3fd6f5eb472f48ec9ddf4958ef50340e001102f5 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -53,3 +53,16 @@ private[spark] case class ExceptionFailure( private[spark] case object TaskResultLost extends TaskEndReason private[spark] case object TaskKilled extends TaskEndReason + +/** + * The task failed because the executor that it was running on was lost. This may happen because + * the task crashed the JVM. + */ +private[spark] case object ExecutorLostFailure extends TaskEndReason + +/** + * We don't know why the task ended -- for example, because of a ClassNotFound exception when + * deserializing the task result. + */ +private[spark] case object UnknownReason extends TaskEndReason + diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 237cbf4c0c942c139f6405407ae24affa03e5c99..821241508ea32805b2e8f69368363392db985755 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -954,8 +954,8 @@ class DAGScheduler( // Do nothing here; the TaskScheduler handles these failures and resubmits the task. case other => - // Unrecognized failure - abort all jobs depending on this stage - abortStage(stageIdToStage(task.stageId), task + " failed: " + other) + // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler + // will abort the job. } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index e9f2198a007e526a237f7190e542761bc191c8af..c4d1ad5733b4ce8708244046f10d5fac0fb9b031 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -21,6 +21,12 @@ import scala.collection._ import org.apache.spark.executor.TaskMetrics +/** + * Stores information about a stage to pass from the scheduler to SparkListeners. + * + * taskInfos stores the metrics for all tasks that have completed, including redundant, speculated + * tasks. + */ class StageInfo( stage: Stage, val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 35e9544718eb2ea9e5842d58d0d572977151421d..bdec08e968a4528a87ffb8d1b7641f4d5af8bdb8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -57,7 +57,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul * between when the task ended and when we tried to fetch the result, or if the * block manager had to flush the result. 
*/ scheduler.handleFailedTask( - taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost)) + taskSetManager, tid, TaskState.FINISHED, TaskResultLost) return } val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( @@ -80,13 +80,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, serializedData: ByteBuffer) { - var reason: Option[TaskEndReason] = None + var reason : TaskEndReason = UnknownReason getTaskResultExecutor.execute(new Runnable { override def run() { try { if (serializedData != null && serializedData.limit() > 0) { - reason = Some(serializer.get().deserialize[TaskEndReason]( - serializedData, getClass.getClassLoader)) + reason = serializer.get().deserialize[TaskEndReason]( + serializedData, getClass.getClassLoader) } } catch { case cnd: ClassNotFoundException => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 83ba5840155fb6a0a87cfec4b456e461b02fe937..5b525155e9f6229f81b2bddab892e9ad65ba1653 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -67,7 +67,6 @@ private[spark] class TaskSchedulerImpl( val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] @volatile private var hasReceivedTask = false @volatile private var hasLaunchedTask = false @@ -142,7 +141,6 @@ private[spark] class TaskSchedulerImpl( val manager = new TaskSetManager(this, taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - taskSetTaskIds(taskSet.id) = new HashSet[Long]() if (!isLocal && !hasReceivedTask) { starvationTimer.scheduleAtFixedRate(new TimerTask() { @@ -171,31 +169,25 @@ private[spark] class TaskSchedulerImpl( // the stage. // 2. The task set manager has been created but no tasks has been scheduled. In this case, // simply abort the stage. - val taskIds = taskSetTaskIds(tsm.taskSet.id) - if (taskIds.size > 0) { - taskIds.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId) - } + tsm.runningTasksSet.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) } + tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) - tsm.removeAllRunningTasks() - taskSetFinished(tsm) } } + /** + * Called to indicate that all task attempts (including speculated tasks) associated with the + * given TaskSetManager have completed, so state associated with the TaskSetManager should be + * cleaned up. + */ def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - // Check to see if the given task set has been removed. This is possible in the case of - // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has - // more than one running tasks). 
- if (activeTaskSets.contains(manager.taskSet.id)) { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) - } + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" + .format(manager.taskSet.id, manager.parent.name)) } /** @@ -237,7 +229,6 @@ private[spark] class TaskSchedulerImpl( tasks(i) += task val tid = task.taskId taskIdToTaskSetId(tid) = taskSet.taskSet.id - taskSetTaskIds(taskSet.taskSet.id) += tid taskIdToExecutorId(tid) = execId activeExecutorIds += execId executorsByHost(host) += execId @@ -270,9 +261,6 @@ private[spark] class TaskSchedulerImpl( case Some(taskSetId) => if (TaskState.isFinished(state)) { taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid - } taskIdToExecutorId.remove(tid) } activeTaskSets.get(taskSetId).foreach { taskSet => @@ -285,7 +273,9 @@ private[spark] class TaskSchedulerImpl( } } case None => - logInfo("Ignoring update with state %s from TID %s because its task set is gone" + logError( + ("Ignoring update with state %s for TID %s because its task set is gone (this is " + + "likely the result of receiving duplicate task finished status updates)") .format(state, tid)) } } catch { @@ -314,9 +304,9 @@ private[spark] class TaskSchedulerImpl( taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, - reason: Option[TaskEndReason]) = synchronized { + reason: TaskEndReason) = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) - if (taskState != TaskState.KILLED) { + if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { // Need to revive offers again now that the task set manager state has been updated to // reflect failed tasks that need to be re-run. backend.reviveOffers() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 777f31dc5e05340a1decd162d4f537f309132649..3f0ee7a6d48cb5dc4dd02c5a0b1af67d906a56a4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -26,9 +26,10 @@ import scala.collection.mutable.HashSet import scala.math.max import scala.math.min -import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, - Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} +import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted, + SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.{Clock, SystemClock} @@ -82,8 +83,16 @@ private[spark] class TaskSetManager( var name = "TaskSet_"+taskSet.stageId.toString var parent: Pool = null - var runningTasks = 0 - private val runningTasksSet = new HashSet[Long] + val runningTasksSet = new HashSet[Long] + override def runningTasks = runningTasksSet.size + + // True once no more tasks should be launched for this task set manager. 
TaskSetManagers enter + // the zombie state once at least one attempt of each task has completed successfully, or if the + // task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie + // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie + // state in order to continue to track and account for the running tasks. + // TODO: We should kill any running task attempts when the task set manager becomes a zombie. + var isZombie = false // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the @@ -345,7 +354,7 @@ private[spark] class TaskSetManager( maxLocality: TaskLocality.TaskLocality) : Option[TaskDescription] = { - if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) { + if (!isZombie && availableCpus >= CPUS_PER_TASK) { val curTime = clock.getTime() var allowedLocality = getAllowedLocalityLevel(curTime) @@ -380,8 +389,7 @@ private[spark] class TaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) - if (taskAttempts(index).size == 1) - taskStarted(task,info) + sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) } case _ => @@ -390,6 +398,12 @@ private[spark] class TaskSetManager( None } + private def maybeFinishTaskSet() { + if (isZombie && runningTasks == 0) { + sched.taskSetFinished(this) + } + } + /** * Get the level we can launch tasks according to delay scheduling, based on current wait time. */ @@ -418,10 +432,6 @@ private[spark] class TaskSetManager( index } - private def taskStarted(task: Task[_], info: TaskInfo) { - sched.dagScheduler.taskStarted(task, info) - } - def handleTaskGettingResult(tid: Long) = { val info = taskInfos(tid) info.markGettingResult() @@ -436,123 +446,116 @@ private[spark] class TaskSetManager( val index = info.index info.markSuccessful() removeRunningTask(tid) + sched.dagScheduler.taskEnded( + tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) if (!successful(index)) { tasksSuccessful += 1 logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( tid, info.duration, info.host, tasksSuccessful, numTasks)) - sched.dagScheduler.taskEnded( - tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - // Mark successful and stop if all the tasks have succeeded. successful(index) = true if (tasksSuccessful == numTasks) { - sched.taskSetFinished(this) + isZombie = true } } else { logInfo("Ignorning task-finished event for TID " + tid + " because task " + index + " has already completed successfully") } + maybeFinishTaskSet() } /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. */ - def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { + def handleFailedTask(tid: Long, state: TaskState, reason: TaskEndReason) { val info = taskInfos(tid) if (info.failed) { return } removeRunningTask(tid) - val index = info.index info.markFailed() - var failureReason = "unknown" - if (!successful(index)) { + val index = info.index + copiesRunning(index) -= 1 + if (!isZombie) { logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. 
In that case, this - // task will never succeed on any node, so tell the scheduler about it. - reason.foreach { - case fetchFailed: FetchFailed => - logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) + } + var taskMetrics : TaskMetrics = null + var failureReason = "unknown" + reason match { + case fetchFailed: FetchFailed => + logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + if (!successful(index)) { successful(index) = true tasksSuccessful += 1 - sched.taskSetFinished(this) - removeAllRunningTasks() - return - - case TaskKilled => - logWarning("Task %d was killed.".format(tid)) - sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) + } + isZombie = true + + case TaskKilled => + logWarning("Task %d was killed.".format(tid)) + + case ef: ExceptionFailure => + taskMetrics = ef.metrics.getOrElse(null) + if (ef.className == classOf[NotSerializableException].getName()) { + // If the task result wasn't serializable, there's no point in trying to re-execute it. + logError("Task %s:%s had a not serializable result: %s; not retrying".format( + taskSet.id, index, ef.description)) + abort("Task %s:%s had a not serializable result: %s".format( + taskSet.id, index, ef.description)) return - - case ef: ExceptionFailure => - sched.dagScheduler.taskEnded( - tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - if (ef.className == classOf[NotSerializableException].getName()) { - // If the task result wasn't rerializable, there's no point in trying to re-execute it. - logError("Task %s:%s had a not serializable result: %s; not retrying".format( - taskSet.id, index, ef.description)) - abort("Task %s:%s had a not serializable result: %s".format( - taskSet.id, index, ef.description)) - return - } - val key = ef.description - failureReason = "Exception failure: %s".format(ef.description) - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { + } + val key = ef.description + failureReason = "Exception failure: %s".format(ef.description) + val now = clock.getTime() + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { recentExceptions(key) = (0, now) (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logWarning("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + recentExceptions(key) = (0, now) + (true, 0) } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logWarning("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } - case TaskResultLost => - failureReason = "Lost result for TID %s on host %s".format(tid, info.host) - logWarning(failureReason) - sched.dagScheduler.taskEnded(tasks(index), 
TaskResultLost, null, null, info, null) + case TaskResultLost => + failureReason = "Lost result for TID %s on host %s".format(tid, info.host) + logWarning(failureReason) - case _ => {} - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - if (state != TaskState.KILLED) { - numFailures(index) += 1 - if (numFailures(index) >= maxTaskFailures) { - logError("Task %s:%d failed %d times; aborting job".format( - taskSet.id, index, maxTaskFailures)) - abort("Task %s:%d failed %d times (most recent failure: %s)".format( - taskSet.id, index, maxTaskFailures, failureReason)) - } + case _ => {} + } + sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) + addPendingTask(index) + if (!isZombie && state != TaskState.KILLED) { + numFailures(index) += 1 + if (numFailures(index) >= maxTaskFailures) { + logError("Task %s:%d failed %d times; aborting job".format( + taskSet.id, index, maxTaskFailures)) + abort("Task %s:%d failed %d times (most recent failure: %s)".format( + taskSet.id, index, maxTaskFailures, failureReason)) + return } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") } + maybeFinishTaskSet() } def abort(message: String) { // TODO: Kill running tasks if we were not terminated due to a Mesos error sched.dagScheduler.taskSetFailed(taskSet, message) - removeAllRunningTasks() - sched.taskSetFinished(this) + isZombie = true + maybeFinishTaskSet() } /** If the given task ID is not in the set of running tasks, adds it. @@ -563,7 +566,6 @@ private[spark] class TaskSetManager( if (runningTasksSet.add(tid) && parent != null) { parent.increaseRunningTasks(1) } - runningTasks = runningTasksSet.size } /** If the given task ID is in the set of running tasks, removes it. */ @@ -571,16 +573,6 @@ private[spark] class TaskSetManager( if (runningTasksSet.remove(tid) && parent != null) { parent.decreaseRunningTasks(1) } - runningTasks = runningTasksSet.size - } - - private[scheduler] def removeAllRunningTasks() { - val numRunningTasks = runningTasksSet.size - runningTasksSet.clear() - if (parent != null) { - parent.decreaseRunningTasks(numRunningTasks) - } - runningTasks = 0 } override def getSchedulableByName(name: String): Schedulable = { @@ -629,7 +621,7 @@ private[spark] class TaskSetManager( } // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, None) + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) } } @@ -641,8 +633,9 @@ private[spark] class TaskSetManager( * we don't scan the whole task set. It might also help to make this sorted by launch time. */ override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. - if (numTasks == 1 || tasksSuccessful == numTasks) { + // Can't speculate if we only have one task, and no need to speculate if the task set is a + // zombie. 
+ if (isZombie || numTasks == 1) { return false } var foundTasks = false diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala index 235d31709af2b69d3965a8fa262cc1febee6e01e..98ea4cb5612ecabe68c6d332686a8fb1a41460e0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala @@ -36,22 +36,24 @@ class FakeTaskSetManager( parent = null weight = 1 minShare = 2 - runningTasks = 0 priority = initPriority stageId = initStageId name = "TaskSet_"+stageId override val numTasks = initNumTasks tasksSuccessful = 0 + var numRunningTasks = 0 + override def runningTasks = numRunningTasks + def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum + numRunningTasks += taskNum if (parent != null) { parent.increaseRunningTasks(taskNum) } } def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum + numRunningTasks -= taskNum if (parent != null) { parent.decreaseRunningTasks(taskNum) } @@ -77,7 +79,7 @@ class FakeTaskSetManager( maxLocality: TaskLocality.TaskLocality) : Option[TaskDescription] = { - if (tasksSuccessful + runningTasks < numTasks) { + if (tasksSuccessful + numRunningTasks < numTasks) { increaseRunningTasks(1) Some(new TaskDescription(0, execId, "task 0:0", 0, null)) } else { @@ -98,7 +100,7 @@ class FakeTaskSetManager( } def abort() { - decreaseRunningTasks(runningTasks) + decreaseRunningTasks(numRunningTasks) parent.removeSchedulable(this) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1a16e438c43d8e543c701829ae2d8ed2ec17d360..368c5154ea3b909968f637e13b731ba9f3f949d6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -168,6 +168,39 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc assert(listener.endedTasks.contains(TASK_INDEX)) } + test("onTaskEnd() should be called for all started tasks, even after job has been killed") { + val WAIT_TIMEOUT_MILLIS = 10000 + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + val numTasks = 10 + val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync() + // Wait until one task has started (because we want to make sure that any tasks that are started + // have corresponding end events sent to the listener). + var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS + listener.synchronized { + var remainingWait = finishTime - System.currentTimeMillis + while (listener.startedTasks.isEmpty && remainingWait > 0) { + listener.wait(remainingWait) + remainingWait = finishTime - System.currentTimeMillis + } + assert(!listener.startedTasks.isEmpty) + } + + f.cancel() + + // Ensure that onTaskEnd is called for all started tasks. 
+ finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS + listener.synchronized { + var remainingWait = finishTime - System.currentTimeMillis + while (listener.endedTasks.size < listener.startedTasks.size && remainingWait > 0) { + listener.wait(finishTime - System.currentTimeMillis) + remainingWait = finishTime - System.currentTimeMillis + } + assert(listener.endedTasks.size === listener.startedTasks.size) + } + } + def checkNonZeroAvg(m: Traversable[Long], msg: String) { assert(m.sum / m.size.toDouble > 0.0, msg) } @@ -184,12 +217,14 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc val startedGettingResultTasks = new HashSet[Int]() val endedTasks = new HashSet[Int]() - override def onTaskStart(taskStart: SparkListenerTaskStart) { + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { startedTasks += taskStart.taskInfo.index + notify() } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - endedTasks += taskEnd.taskInfo.index + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + endedTasks += taskEnd.taskInfo.index + notify() } override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ecac2f79a25e26b7f62cabebeca4cd28695e7f3b..de321c45b547ca2510060b996bcc67436e91f388 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -269,7 +269,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) // Tell it the task has finished but the result was lost. - manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost)) + manager.handleFailedTask(0, TaskState.FINISHED, TaskResultLost) assert(sched.endedTasks(0) === TaskResultLost) // Re-offer the host -- now we should get task 0 again. @@ -290,7 +290,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(offerResult.isDefined, "Expect resource offer on iteration %s to return a task".format(index)) assert(offerResult.get.index === 0) - manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost)) + manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) if (index < MAX_TASK_FAILURES) { assert(!sched.taskSetsFailed.contains(taskSet.id)) } else {
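Reviewer note: the comments added to TaskSetManager.scala above describe the new "zombie" lifecycle: a TaskSetManager stops launching new task attempts once every task has at least one successful attempt (or once the set is aborted), but it is only deregistered from the scheduler after all of its running attempts, including speculative copies, have finished, so that listeners still receive an end event for every attempt that was started. The sketch below is a minimal, self-contained Scala illustration of that lifecycle only; TaskSetTracker and all of its member names are invented for this sketch and are not part of Spark's API.

import scala.collection.mutable

// Illustrative only: a toy analogue of the zombie-state bookkeeping, not Spark's TaskSetManager.
class TaskSetTracker(numTasks: Int, onCleanup: () => Unit) {
  private val successful = Array.fill(numTasks)(false)
  private val runningAttempts = mutable.HashSet[Long]()      // analogue of runningTasksSet
  private val attemptIndex = mutable.HashMap[Long, Int]()    // task attempt id -> task index
  private var nextTid = 0L
  private var tasksSuccessful = 0

  // True once no more attempts should be launched: every task has a successful attempt,
  // or the set was aborted. State is kept until the already-running attempts finish.
  var isZombie = false

  // Analogue of resourceOffer: hand out an attempt of task `index` unless we are a zombie.
  def offer(index: Int): Option[Long] = {
    if (isZombie) {
      None
    } else {
      val tid = nextTid
      nextTid += 1
      runningAttempts += tid
      attemptIndex(tid) = index
      Some(tid)
    }
  }

  // Analogue of handleSuccessfulTask.
  def attemptSucceeded(tid: Long): Unit = {
    val index = attemptIndex.remove(tid).get
    runningAttempts -= tid
    if (!successful(index)) {
      successful(index) = true
      tasksSuccessful += 1
      if (tasksSuccessful == numTasks) {
        isZombie = true                     // stop launching, but defer cleanup
      }
    }
    maybeFinish()
  }

  // Analogue of handleFailedTask (re-enqueueing and failure counting are omitted here).
  def attemptFailed(tid: Long): Unit = {
    attemptIndex.remove(tid)
    runningAttempts -= tid
    maybeFinish()
  }

  // Analogue of abort(): become a zombie, then clean up once nothing is running.
  def abortSet(): Unit = {
    isZombie = true
    maybeFinish()
  }

  // Analogue of maybeFinishTaskSet(): cleanup fires only once we are a zombie AND no
  // attempts are still running, so late speculative attempts are still accounted for.
  private def maybeFinish(): Unit = {
    if (isZombie && runningAttempts.isEmpty) onCleanup()
  }
}

object ZombieLifecycleDemo {
  def main(args: Array[String]): Unit = {
    val tracker = new TaskSetTracker(numTasks = 2, onCleanup = () => println("task set cleaned up"))
    val t0  = tracker.offer(0).get
    val t1a = tracker.offer(1).get
    val t1b = tracker.offer(1).get              // speculative copy of task 1
    tracker.attemptSucceeded(t0)
    tracker.attemptSucceeded(t1a)               // every task succeeded -> zombie, but t1b still runs
    assert(tracker.isZombie && tracker.offer(0).isEmpty)
    tracker.attemptSucceeded(t1b)               // last running attempt finishes -> cleanup fires
  }
}

In the real patch, maybeFinishTaskSet() plays the role of maybeFinish() above and TaskSchedulerImpl.taskSetFinished() is the cleanup hook, which is why taskSetFinished() can now assume the task set is still registered and the old defensive activeTaskSets.contains() check was dropped.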