diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala index faf6dcd6186237b5d58a4612d78dc81bda2b54da..3fd6f5eb472f48ec9ddf4958ef50340e001102f5 100644 --- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala +++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala @@ -53,3 +53,16 @@ private[spark] case class ExceptionFailure( private[spark] case object TaskResultLost extends TaskEndReason private[spark] case object TaskKilled extends TaskEndReason + +/** + * The task failed because the executor that it was running on was lost. This may happen because + * the task crashed the JVM. + */ +private[spark] case object ExecutorLostFailure extends TaskEndReason + +/** + * We don't know why the task ended -- for example, because of a ClassNotFound exception when + * deserializing the task result. + */ +private[spark] case object UnknownReason extends TaskEndReason + diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 237cbf4c0c942c139f6405407ae24affa03e5c99..821241508ea32805b2e8f69368363392db985755 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -954,8 +954,8 @@ class DAGScheduler( // Do nothing here; the TaskScheduler handles these failures and resubmits the task. case other => - // Unrecognized failure - abort all jobs depending on this stage - abortStage(stageIdToStage(task.stageId), task + " failed: " + other) + // Unrecognized failure - also do nothing. If the task fails repeatedly, the TaskScheduler + // will abort the job. } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index e9f2198a007e526a237f7190e542761bc191c8af..c4d1ad5733b4ce8708244046f10d5fac0fb9b031 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -21,6 +21,12 @@ import scala.collection._ import org.apache.spark.executor.TaskMetrics +/** + * Stores information about a stage to pass from the scheduler to SparkListeners. + * + * taskInfos stores the metrics for all tasks that have completed, including redundant, speculated + * tasks. + */ class StageInfo( stage: Stage, val taskInfos: mutable.Buffer[(TaskInfo, TaskMetrics)] = mutable.Buffer[(TaskInfo, TaskMetrics)]() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 35e9544718eb2ea9e5842d58d0d572977151421d..bdec08e968a4528a87ffb8d1b7641f4d5af8bdb8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -57,7 +57,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul * between when the task ended and when we tried to fetch the result, or if the * block manager had to flush the result. 
*/ scheduler.handleFailedTask( - taskSetManager, tid, TaskState.FINISHED, Some(TaskResultLost)) + taskSetManager, tid, TaskState.FINISHED, TaskResultLost) return } val deserializedResult = serializer.get().deserialize[DirectTaskResult[_]]( @@ -80,13 +80,13 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul def enqueueFailedTask(taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, serializedData: ByteBuffer) { - var reason: Option[TaskEndReason] = None + var reason : TaskEndReason = UnknownReason getTaskResultExecutor.execute(new Runnable { override def run() { try { if (serializedData != null && serializedData.limit() > 0) { - reason = Some(serializer.get().deserialize[TaskEndReason]( - serializedData, getClass.getClassLoader)) + reason = serializer.get().deserialize[TaskEndReason]( + serializedData, getClass.getClassLoader) } } catch { case cnd: ClassNotFoundException => diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 83ba5840155fb6a0a87cfec4b456e461b02fe937..5b525155e9f6229f81b2bddab892e9ad65ba1653 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -67,7 +67,6 @@ private[spark] class TaskSchedulerImpl( val taskIdToTaskSetId = new HashMap[Long, String] val taskIdToExecutorId = new HashMap[Long, String] - val taskSetTaskIds = new HashMap[String, HashSet[Long]] @volatile private var hasReceivedTask = false @volatile private var hasLaunchedTask = false @@ -142,7 +141,6 @@ private[spark] class TaskSchedulerImpl( val manager = new TaskSetManager(this, taskSet, maxTaskFailures) activeTaskSets(taskSet.id) = manager schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties) - taskSetTaskIds(taskSet.id) = new HashSet[Long]() if (!isLocal && !hasReceivedTask) { starvationTimer.scheduleAtFixedRate(new TimerTask() { @@ -171,31 +169,25 @@ private[spark] class TaskSchedulerImpl( // the stage. // 2. The task set manager has been created but no tasks has been scheduled. In this case, // simply abort the stage. - val taskIds = taskSetTaskIds(tsm.taskSet.id) - if (taskIds.size > 0) { - taskIds.foreach { tid => - val execId = taskIdToExecutorId(tid) - backend.killTask(tid, execId) - } + tsm.runningTasksSet.foreach { tid => + val execId = taskIdToExecutorId(tid) + backend.killTask(tid, execId) } + tsm.abort("Stage %s cancelled".format(stageId)) logInfo("Stage %d was cancelled".format(stageId)) - tsm.removeAllRunningTasks() - taskSetFinished(tsm) } } + /** + * Called to indicate that all task attempts (including speculated tasks) associated with the + * given TaskSetManager have completed, so state associated with the TaskSetManager should be + * cleaned up. + */ def taskSetFinished(manager: TaskSetManager): Unit = synchronized { - // Check to see if the given task set has been removed. This is possible in the case of - // multiple unrecoverable task failures (e.g. if the entire task set is killed when it has - // more than one running tasks). 
- if (activeTaskSets.contains(manager.taskSet.id)) { - activeTaskSets -= manager.taskSet.id - manager.parent.removeSchedulable(manager) - logInfo("Remove TaskSet %s from pool %s".format(manager.taskSet.id, manager.parent.name)) - taskIdToTaskSetId --= taskSetTaskIds(manager.taskSet.id) - taskIdToExecutorId --= taskSetTaskIds(manager.taskSet.id) - taskSetTaskIds.remove(manager.taskSet.id) - } + activeTaskSets -= manager.taskSet.id + manager.parent.removeSchedulable(manager) + logInfo("Removed TaskSet %s, whose tasks have all completed, from pool %s" + .format(manager.taskSet.id, manager.parent.name)) } /** @@ -237,7 +229,6 @@ private[spark] class TaskSchedulerImpl( tasks(i) += task val tid = task.taskId taskIdToTaskSetId(tid) = taskSet.taskSet.id - taskSetTaskIds(taskSet.taskSet.id) += tid taskIdToExecutorId(tid) = execId activeExecutorIds += execId executorsByHost(host) += execId @@ -270,9 +261,6 @@ private[spark] class TaskSchedulerImpl( case Some(taskSetId) => if (TaskState.isFinished(state)) { taskIdToTaskSetId.remove(tid) - if (taskSetTaskIds.contains(taskSetId)) { - taskSetTaskIds(taskSetId) -= tid - } taskIdToExecutorId.remove(tid) } activeTaskSets.get(taskSetId).foreach { taskSet => @@ -285,7 +273,9 @@ private[spark] class TaskSchedulerImpl( } } case None => - logInfo("Ignoring update with state %s from TID %s because its task set is gone" + logError( + ("Ignoring update with state %s for TID %s because its task set is gone (this is " + + "likely the result of receiving duplicate task finished status updates)") .format(state, tid)) } } catch { @@ -314,9 +304,9 @@ private[spark] class TaskSchedulerImpl( taskSetManager: TaskSetManager, tid: Long, taskState: TaskState, - reason: Option[TaskEndReason]) = synchronized { + reason: TaskEndReason) = synchronized { taskSetManager.handleFailedTask(tid, taskState, reason) - if (taskState != TaskState.KILLED) { + if (!taskSetManager.isZombie && taskState != TaskState.KILLED) { // Need to revive offers again now that the task set manager state has been updated to // reflect failed tasks that need to be re-run. backend.reviveOffers() diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 777f31dc5e05340a1decd162d4f537f309132649..3f0ee7a6d48cb5dc4dd02c5a0b1af67d906a56a4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -26,9 +26,10 @@ import scala.collection.mutable.HashSet import scala.math.max import scala.math.min -import org.apache.spark.{ExceptionFailure, FetchFailed, Logging, Resubmitted, SparkEnv, - Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} +import org.apache.spark.{ExceptionFailure, ExecutorLostFailure, FetchFailed, Logging, Resubmitted, + SparkEnv, Success, TaskEndReason, TaskKilled, TaskResultLost, TaskState} import org.apache.spark.TaskState.TaskState +import org.apache.spark.executor.TaskMetrics import org.apache.spark.util.{Clock, SystemClock} @@ -82,8 +83,16 @@ private[spark] class TaskSetManager( var name = "TaskSet_"+taskSet.stageId.toString var parent: Pool = null - var runningTasks = 0 - private val runningTasksSet = new HashSet[Long] + val runningTasksSet = new HashSet[Long] + override def runningTasks = runningTasksSet.size + + // True once no more tasks should be launched for this task set manager. 
TaskSetManagers enter + // the zombie state once at least one attempt of each task has completed successfully, or if the + // task set is aborted (for example, because it was killed). TaskSetManagers remain in the zombie + // state until all tasks have finished running; we keep TaskSetManagers that are in the zombie + // state in order to continue to track and account for the running tasks. + // TODO: We should kill any running task attempts when the task set manager becomes a zombie. + var isZombie = false // Set of pending tasks for each executor. These collections are actually // treated as stacks, in which new tasks are added to the end of the @@ -345,7 +354,7 @@ private[spark] class TaskSetManager( maxLocality: TaskLocality.TaskLocality) : Option[TaskDescription] = { - if (tasksSuccessful < numTasks && availableCpus >= CPUS_PER_TASK) { + if (!isZombie && availableCpus >= CPUS_PER_TASK) { val curTime = clock.getTime() var allowedLocality = getAllowedLocalityLevel(curTime) @@ -380,8 +389,7 @@ private[spark] class TaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) - if (taskAttempts(index).size == 1) - taskStarted(task,info) + sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) } case _ => @@ -390,6 +398,12 @@ private[spark] class TaskSetManager( None } + private def maybeFinishTaskSet() { + if (isZombie && runningTasks == 0) { + sched.taskSetFinished(this) + } + } + /** * Get the level we can launch tasks according to delay scheduling, based on current wait time. */ @@ -418,10 +432,6 @@ private[spark] class TaskSetManager( index } - private def taskStarted(task: Task[_], info: TaskInfo) { - sched.dagScheduler.taskStarted(task, info) - } - def handleTaskGettingResult(tid: Long) = { val info = taskInfos(tid) info.markGettingResult() @@ -436,123 +446,116 @@ private[spark] class TaskSetManager( val index = info.index info.markSuccessful() removeRunningTask(tid) + sched.dagScheduler.taskEnded( + tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) if (!successful(index)) { tasksSuccessful += 1 logInfo("Finished TID %s in %d ms on %s (progress: %d/%d)".format( tid, info.duration, info.host, tasksSuccessful, numTasks)) - sched.dagScheduler.taskEnded( - tasks(index), Success, result.value, result.accumUpdates, info, result.metrics) - // Mark successful and stop if all the tasks have succeeded. successful(index) = true if (tasksSuccessful == numTasks) { - sched.taskSetFinished(this) + isZombie = true } } else { logInfo("Ignorning task-finished event for TID " + tid + " because task " + index + " has already completed successfully") } + maybeFinishTaskSet() } /** * Marks the task as failed, re-adds it to the list of pending tasks, and notifies the * DAG Scheduler. */ - def handleFailedTask(tid: Long, state: TaskState, reason: Option[TaskEndReason]) { + def handleFailedTask(tid: Long, state: TaskState, reason: TaskEndReason) { val info = taskInfos(tid) if (info.failed) { return } removeRunningTask(tid) - val index = info.index info.markFailed() - var failureReason = "unknown" - if (!successful(index)) { + val index = info.index + copiesRunning(index) -= 1 + if (!isZombie) { logWarning("Lost TID %s (task %s:%d)".format(tid, taskSet.id, index)) - copiesRunning(index) -= 1 - // Check if the problem is a map output fetch failure. 
In that case, this - // task will never succeed on any node, so tell the scheduler about it. - reason.foreach { - case fetchFailed: FetchFailed => - logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) - sched.dagScheduler.taskEnded(tasks(index), fetchFailed, null, null, info, null) + } + var taskMetrics : TaskMetrics = null + var failureReason = "unknown" + reason match { + case fetchFailed: FetchFailed => + logWarning("Loss was due to fetch failure from " + fetchFailed.bmAddress) + if (!successful(index)) { successful(index) = true tasksSuccessful += 1 - sched.taskSetFinished(this) - removeAllRunningTasks() - return - - case TaskKilled => - logWarning("Task %d was killed.".format(tid)) - sched.dagScheduler.taskEnded(tasks(index), reason.get, null, null, info, null) + } + isZombie = true + + case TaskKilled => + logWarning("Task %d was killed.".format(tid)) + + case ef: ExceptionFailure => + taskMetrics = ef.metrics.getOrElse(null) + if (ef.className == classOf[NotSerializableException].getName()) { + // If the task result wasn't serializable, there's no point in trying to re-execute it. + logError("Task %s:%s had a not serializable result: %s; not retrying".format( + taskSet.id, index, ef.description)) + abort("Task %s:%s had a not serializable result: %s".format( + taskSet.id, index, ef.description)) return - - case ef: ExceptionFailure => - sched.dagScheduler.taskEnded( - tasks(index), ef, null, null, info, ef.metrics.getOrElse(null)) - if (ef.className == classOf[NotSerializableException].getName()) { - // If the task result wasn't rerializable, there's no point in trying to re-execute it. - logError("Task %s:%s had a not serializable result: %s; not retrying".format( - taskSet.id, index, ef.description)) - abort("Task %s:%s had a not serializable result: %s".format( - taskSet.id, index, ef.description)) - return - } - val key = ef.description - failureReason = "Exception failure: %s".format(ef.description) - val now = clock.getTime() - val (printFull, dupCount) = { - if (recentExceptions.contains(key)) { - val (dupCount, printTime) = recentExceptions(key) - if (now - printTime > EXCEPTION_PRINT_INTERVAL) { - recentExceptions(key) = (0, now) - (true, 0) - } else { - recentExceptions(key) = (dupCount + 1, printTime) - (false, dupCount + 1) - } - } else { + } + val key = ef.description + failureReason = "Exception failure: %s".format(ef.description) + val now = clock.getTime() + val (printFull, dupCount) = { + if (recentExceptions.contains(key)) { + val (dupCount, printTime) = recentExceptions(key) + if (now - printTime > EXCEPTION_PRINT_INTERVAL) { recentExceptions(key) = (0, now) (true, 0) + } else { + recentExceptions(key) = (dupCount + 1, printTime) + (false, dupCount + 1) } - } - if (printFull) { - val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) - logWarning("Loss was due to %s\n%s\n%s".format( - ef.className, ef.description, locs.mkString("\n"))) } else { - logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + recentExceptions(key) = (0, now) + (true, 0) } + } + if (printFull) { + val locs = ef.stackTrace.map(loc => "\tat %s".format(loc.toString)) + logWarning("Loss was due to %s\n%s\n%s".format( + ef.className, ef.description, locs.mkString("\n"))) + } else { + logInfo("Loss was due to %s [duplicate %d]".format(ef.description, dupCount)) + } - case TaskResultLost => - failureReason = "Lost result for TID %s on host %s".format(tid, info.host) - logWarning(failureReason) - sched.dagScheduler.taskEnded(tasks(index), 
TaskResultLost, null, null, info, null) + case TaskResultLost => + failureReason = "Lost result for TID %s on host %s".format(tid, info.host) + logWarning(failureReason) - case _ => {} - } - // On non-fetch failures, re-enqueue the task as pending for a max number of retries - addPendingTask(index) - if (state != TaskState.KILLED) { - numFailures(index) += 1 - if (numFailures(index) >= maxTaskFailures) { - logError("Task %s:%d failed %d times; aborting job".format( - taskSet.id, index, maxTaskFailures)) - abort("Task %s:%d failed %d times (most recent failure: %s)".format( - taskSet.id, index, maxTaskFailures, failureReason)) - } + case _ => {} + } + sched.dagScheduler.taskEnded(tasks(index), reason, null, null, info, taskMetrics) + addPendingTask(index) + if (!isZombie && state != TaskState.KILLED) { + numFailures(index) += 1 + if (numFailures(index) >= maxTaskFailures) { + logError("Task %s:%d failed %d times; aborting job".format( + taskSet.id, index, maxTaskFailures)) + abort("Task %s:%d failed %d times (most recent failure: %s)".format( + taskSet.id, index, maxTaskFailures, failureReason)) + return } - } else { - logInfo("Ignoring task-lost event for TID " + tid + - " because task " + index + " is already finished") } + maybeFinishTaskSet() } def abort(message: String) { // TODO: Kill running tasks if we were not terminated due to a Mesos error sched.dagScheduler.taskSetFailed(taskSet, message) - removeAllRunningTasks() - sched.taskSetFinished(this) + isZombie = true + maybeFinishTaskSet() } /** If the given task ID is not in the set of running tasks, adds it. @@ -563,7 +566,6 @@ private[spark] class TaskSetManager( if (runningTasksSet.add(tid) && parent != null) { parent.increaseRunningTasks(1) } - runningTasks = runningTasksSet.size } /** If the given task ID is in the set of running tasks, removes it. */ @@ -571,16 +573,6 @@ private[spark] class TaskSetManager( if (runningTasksSet.remove(tid) && parent != null) { parent.decreaseRunningTasks(1) } - runningTasks = runningTasksSet.size - } - - private[scheduler] def removeAllRunningTasks() { - val numRunningTasks = runningTasksSet.size - runningTasksSet.clear() - if (parent != null) { - parent.decreaseRunningTasks(numRunningTasks) - } - runningTasks = 0 } override def getSchedulableByName(name: String): Schedulable = { @@ -629,7 +621,7 @@ private[spark] class TaskSetManager( } // Also re-enqueue any tasks that were running on the node for ((tid, info) <- taskInfos if info.running && info.executorId == execId) { - handleFailedTask(tid, TaskState.FAILED, None) + handleFailedTask(tid, TaskState.FAILED, ExecutorLostFailure) } } @@ -641,8 +633,9 @@ private[spark] class TaskSetManager( * we don't scan the whole task set. It might also help to make this sorted by launch time. */ override def checkSpeculatableTasks(): Boolean = { - // Can't speculate if we only have one task, or if all tasks have finished. - if (numTasks == 1 || tasksSuccessful == numTasks) { + // Can't speculate if we only have one task, and no need to speculate if the task set is a + // zombie. 
+ if (isZombie || numTasks == 1) { return false } var foundTasks = false diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala index 235d31709af2b69d3965a8fa262cc1febee6e01e..98ea4cb5612ecabe68c6d332686a8fb1a41460e0 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala @@ -36,22 +36,24 @@ class FakeTaskSetManager( parent = null weight = 1 minShare = 2 - runningTasks = 0 priority = initPriority stageId = initStageId name = "TaskSet_"+stageId override val numTasks = initNumTasks tasksSuccessful = 0 + var numRunningTasks = 0 + override def runningTasks = numRunningTasks + def increaseRunningTasks(taskNum: Int) { - runningTasks += taskNum + numRunningTasks += taskNum if (parent != null) { parent.increaseRunningTasks(taskNum) } } def decreaseRunningTasks(taskNum: Int) { - runningTasks -= taskNum + numRunningTasks -= taskNum if (parent != null) { parent.decreaseRunningTasks(taskNum) } @@ -77,7 +79,7 @@ class FakeTaskSetManager( maxLocality: TaskLocality.TaskLocality) : Option[TaskDescription] = { - if (tasksSuccessful + runningTasks < numTasks) { + if (tasksSuccessful + numRunningTasks < numTasks) { increaseRunningTasks(1) Some(new TaskDescription(0, execId, "task 0:0", 0, null)) } else { @@ -98,7 +100,7 @@ class FakeTaskSetManager( } def abort() { - decreaseRunningTasks(runningTasks) + decreaseRunningTasks(numRunningTasks) parent.removeSchedulable(this) } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1a16e438c43d8e543c701829ae2d8ed2ec17d360..368c5154ea3b909968f637e13b731ba9f3f949d6 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -168,6 +168,39 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc assert(listener.endedTasks.contains(TASK_INDEX)) } + test("onTaskEnd() should be called for all started tasks, even after job has been killed") { + val WAIT_TIMEOUT_MILLIS = 10000 + val listener = new SaveTaskEvents + sc.addSparkListener(listener) + + val numTasks = 10 + val f = sc.parallelize(1 to 10000, numTasks).map { i => Thread.sleep(10); i }.countAsync() + // Wait until one task has started (because we want to make sure that any tasks that are started + // have corresponding end events sent to the listener). + var finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS + listener.synchronized { + var remainingWait = finishTime - System.currentTimeMillis + while (listener.startedTasks.isEmpty && remainingWait > 0) { + listener.wait(remainingWait) + remainingWait = finishTime - System.currentTimeMillis + } + assert(!listener.startedTasks.isEmpty) + } + + f.cancel() + + // Ensure that onTaskEnd is called for all started tasks. 
+ finishTime = System.currentTimeMillis + WAIT_TIMEOUT_MILLIS + listener.synchronized { + var remainingWait = finishTime - System.currentTimeMillis + while (listener.endedTasks.size < listener.startedTasks.size && remainingWait > 0) { + listener.wait(finishTime - System.currentTimeMillis) + remainingWait = finishTime - System.currentTimeMillis + } + assert(listener.endedTasks.size === listener.startedTasks.size) + } + } + def checkNonZeroAvg(m: Traversable[Long], msg: String) { assert(m.sum / m.size.toDouble > 0.0, msg) } @@ -184,12 +217,14 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc val startedGettingResultTasks = new HashSet[Int]() val endedTasks = new HashSet[Int]() - override def onTaskStart(taskStart: SparkListenerTaskStart) { + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { startedTasks += taskStart.taskInfo.index + notify() } - override def onTaskEnd(taskEnd: SparkListenerTaskEnd) { - endedTasks += taskEnd.taskInfo.index + override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized { + endedTasks += taskEnd.taskInfo.index + notify() } override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index ecac2f79a25e26b7f62cabebeca4cd28695e7f3b..de321c45b547ca2510060b996bcc67436e91f388 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -269,7 +269,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) // Tell it the task has finished but the result was lost. - manager.handleFailedTask(0, TaskState.FINISHED, Some(TaskResultLost)) + manager.handleFailedTask(0, TaskState.FINISHED, TaskResultLost) assert(sched.endedTasks(0) === TaskResultLost) // Re-offer the host -- now we should get task 0 again. @@ -290,7 +290,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { assert(offerResult.isDefined, "Expect resource offer on iteration %s to return a task".format(index)) assert(offerResult.get.index === 0) - manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost)) + manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, TaskResultLost) if (index < MAX_TASK_FAILURES) { assert(!sched.taskSetsFailed.contains(taskSet.id)) } else {
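Reviewer note: the comments added to TaskSetManager.scala above describe the new "zombie" lifecycle: a TaskSetManager stops launching new task attempts once every task has at least one successful attempt (or once the set is aborted), but it is only deregistered from the scheduler after all of its running attempts, including speculative copies, have finished, so that listeners still receive an end event for every attempt that was started. The sketch below is a minimal, self-contained Scala illustration of that lifecycle only; TaskSetTracker and all of its member names are invented for this sketch and are not part of Spark's API.

import scala.collection.mutable

// Illustrative only: a toy analogue of the zombie-state bookkeeping, not Spark's TaskSetManager.
class TaskSetTracker(numTasks: Int, onCleanup: () => Unit) {
  private val successful = Array.fill(numTasks)(false)
  private val runningAttempts = mutable.HashSet[Long]()      // analogue of runningTasksSet
  private val attemptIndex = mutable.HashMap[Long, Int]()    // task attempt id -> task index
  private var nextTid = 0L
  private var tasksSuccessful = 0

  // True once no more attempts should be launched: every task has a successful attempt,
  // or the set was aborted. State is kept until the already-running attempts finish.
  var isZombie = false

  // Analogue of resourceOffer: hand out an attempt of task `index` unless we are a zombie.
  def offer(index: Int): Option[Long] = {
    if (isZombie) {
      None
    } else {
      val tid = nextTid
      nextTid += 1
      runningAttempts += tid
      attemptIndex(tid) = index
      Some(tid)
    }
  }

  // Analogue of handleSuccessfulTask.
  def attemptSucceeded(tid: Long): Unit = {
    val index = attemptIndex.remove(tid).get
    runningAttempts -= tid
    if (!successful(index)) {
      successful(index) = true
      tasksSuccessful += 1
      if (tasksSuccessful == numTasks) {
        isZombie = true                     // stop launching, but defer cleanup
      }
    }
    maybeFinish()
  }

  // Analogue of handleFailedTask (re-enqueueing and failure counting are omitted here).
  def attemptFailed(tid: Long): Unit = {
    attemptIndex.remove(tid)
    runningAttempts -= tid
    maybeFinish()
  }

  // Analogue of abort(): become a zombie, then clean up once nothing is running.
  def abortSet(): Unit = {
    isZombie = true
    maybeFinish()
  }

  // Analogue of maybeFinishTaskSet(): cleanup fires only once we are a zombie AND no
  // attempts are still running, so late speculative attempts are still accounted for.
  private def maybeFinish(): Unit = {
    if (isZombie && runningAttempts.isEmpty) onCleanup()
  }
}

object ZombieLifecycleDemo {
  def main(args: Array[String]): Unit = {
    val tracker = new TaskSetTracker(numTasks = 2, onCleanup = () => println("task set cleaned up"))
    val t0  = tracker.offer(0).get
    val t1a = tracker.offer(1).get
    val t1b = tracker.offer(1).get              // speculative copy of task 1
    tracker.attemptSucceeded(t0)
    tracker.attemptSucceeded(t1a)               // every task succeeded -> zombie, but t1b still runs
    assert(tracker.isZombie && tracker.offer(0).isEmpty)
    tracker.attemptSucceeded(t1b)               // last running attempt finishes -> cleanup fires
  }
}

In the real patch, maybeFinishTaskSet() plays the role of maybeFinish() above and TaskSchedulerImpl.taskSetFinished() is the cleanup hook, which is why taskSetFinished() can now assume the task set is still registered and the old defensive activeTaskSets.contains() check was dropped.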