diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala index 58cc1ef185336e4d0037b9e8fa4041d881db033a..80d0c5a5e929ac2c62c2eef41e70f75b9789d841 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -40,6 +40,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* val startedTasks = new ArrayBuffer[Long] val endedTasks = new mutable.HashMap[Long, TaskEndReason] val finishedManagers = new ArrayBuffer[TaskSetManager] + val taskSetsFailed = new ArrayBuffer[String] val executors = new mutable.HashMap[String, String] ++ liveExecutors @@ -63,7 +64,9 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* def executorLost(execId: String) {} - def taskSetFailed(taskSet: TaskSet, reason: String) {} + def taskSetFailed(taskSet: TaskSet, reason: String) { + taskSetsFailed += taskSet.id + } } def removeExecutor(execId: String): Unit = executors -= execId @@ -270,6 +273,30 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0) } + test("repeated failures lead to task set abortion") { + sc = new SparkContext("local", "test") + val sched = new FakeClusterScheduler(sc, ("exec1", "host1")) + val taskSet = createTaskSet(1) + val clock = new FakeClock + val manager = new ClusterTaskSetManager(sched, taskSet, clock) + + // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted + // after the last failure. + (0 until manager.MAX_TASK_FAILURES).foreach { index => + val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY) + assert(offerResult != None, + "Expect resource offer on iteration %s to return a task".format(index)) + assert(offerResult.get.index === 0) + manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost)) + if (index < manager.MAX_TASK_FAILURES) { + assert(!sched.taskSetsFailed.contains(taskSet.id)) + } else { + assert(sched.taskSetsFailed.contains(taskSet.id)) + } + } + } + + /** * Utility method to create a TaskSet, potentially setting a particular sequence of preferred * locations for each task (given as varargs) if this sequence is not empty.