From 46f9c6b858cf9737b7d46b22b75bfc847244331b Mon Sep 17 00:00:00 2001
From: Kay Ousterhout <kayousterhout@gmail.com>
Date: Wed, 13 Nov 2013 15:46:41 -0800
Subject: [PATCH] Fixed naming issues and added back ability to specify max
 task failures.

---
 .../scala/org/apache/spark/SparkContext.scala | 17 +++-
 .../spark/scheduler/ClusterScheduler.scala    | 19 ++--
 .../spark/scheduler/SchedulerBackend.scala    |  2 +-
 .../spark/scheduler/TaskScheduler.scala       | 56 +++++++++++
 .../spark/scheduler/TaskSetManager.scala      |  2 +-
 .../CoarseGrainedSchedulerBackend.scala       |  2 +-
 .../cluster/mesos/MesosSchedulerBackend.scala |  2 +-
 .../scala/org/apache/spark/FailureSuite.scala | 20 ++--
 .../scheduler/ClusterSchedulerSuite.scala     | 48 ++++-----
 .../spark/scheduler/DAGSchedulerSuite.scala   | 97 +++++++++----------
 .../scheduler/TaskResultGetterSuite.scala     | 13 +--
 .../spark/scheduler/TaskSetManagerSuite.scala | 20 ++--
 .../cluster/YarnClusterScheduler.scala        |  6 +-
 13 files changed, 177 insertions(+), 127 deletions(-)
 create mode 100644 core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala

diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 10db2fa7e7..06bea0c535 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -156,6 +156,8 @@ class SparkContext(
   private[spark] var taskScheduler: TaskScheduler = {
     // Regular expression used for local[N] master format
     val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+    // Regular expression for local[N, maxRetries], used in tests with failing tasks
+    val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
     // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
     val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
     // Regular expression for connecting to Spark deploy clusters
@@ -165,19 +167,28 @@ class SparkContext(
     // Regular expression for connection to Simr cluster
     val SIMR_REGEX = """simr://(.*)""".r
 
+    // When running locally, don't try to re-execute tasks on failure.
+    val MAX_LOCAL_TASK_FAILURES = 0
+
     master match {
       case "local" =>
-        val scheduler = new ClusterScheduler(this, isLocal = true)
+        val scheduler = new ClusterScheduler(this, MAX_LOCAL_TASK_FAILURES, isLocal = true)
         val backend = new LocalBackend(scheduler, 1) 
         scheduler.initialize(backend)
         scheduler
 
       case LOCAL_N_REGEX(threads) =>
-        val scheduler = new ClusterScheduler(this, isLocal = true)
+        val scheduler = new ClusterScheduler(this, MAX_LOCAL_TASK_FAILURES, isLocal = true)
         val backend = new LocalBackend(scheduler, threads.toInt) 
         scheduler.initialize(backend)
         scheduler
 
+      case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
+        val scheduler = new ClusterScheduler(this, maxFailures.toInt, isLocal = true)
+        val backend = new LocalBackend(scheduler, threads.toInt)
+        scheduler.initialize(backend)
+        scheduler
+
       case SPARK_REGEX(sparkUrl) =>
         val scheduler = new ClusterScheduler(this)
         val masterUrls = sparkUrl.split(",").map("spark://" + _)
@@ -200,7 +211,7 @@ class SparkContext(
               memoryPerSlaveInt, SparkContext.executorMemoryRequested))
         }
 
-        val scheduler = new ClusterScheduler(this, isLocal = true)
+        val scheduler = new ClusterScheduler(this)
         val localCluster = new LocalSparkCluster(
           numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
         val masterUrls = localCluster.start()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
index c5d7ca0481..37d554715d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ClusterScheduler.scala
@@ -46,8 +46,10 @@ import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
  * acquire a lock on us, so we need to make sure that we don't try to lock the backend while
  * we are holding a lock on ourselves.
  */
-private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = false)
-  extends TaskScheduler with Logging {
+private[spark] class ClusterScheduler(
+  val sc: SparkContext,
+  val maxTaskFailures : Int = System.getProperty("spark.task.maxFailures", "4").toInt,
+  isLocal: Boolean = false) extends TaskScheduler with Logging {
 
   // How often to check for speculative tasks
   val SPECULATION_INTERVAL = System.getProperty("spark.speculation.interval", "100").toLong
@@ -59,15 +61,6 @@ private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = f
   // on this class.
   val activeTaskSets = new HashMap[String, TaskSetManager]
 
-  val MAX_TASK_FAILURES = {
-    if (isLocal) {
-      // No sense in retrying if all tasks run locally!
-      0
-    } else {
-      System.getProperty("spark.task.maxFailures", "4").toInt
-    }
-  }
-
   val taskIdToTaskSetId = new HashMap[Long, String]
   val taskIdToExecutorId = new HashMap[Long, String]
   val taskSetTaskIds = new HashMap[String, HashSet[Long]]
@@ -142,7 +135,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = f
     val tasks = taskSet.tasks
     logInfo("Adding task set " + taskSet.id + " with " + tasks.length + " tasks")
     this.synchronized {
-      val manager = new TaskSetManager(this, taskSet, MAX_TASK_FAILURES)
+      val manager = new TaskSetManager(this, taskSet, maxTaskFailures)
       activeTaskSets(taskSet.id) = manager
       schedulableBuilder.addTaskSetManager(manager, manager.taskSet.properties)
       taskSetTaskIds(taskSet.id) = new HashSet[Long]()
@@ -345,7 +338,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext, isLocal: Boolean = f
         // No task sets are active but we still got an error. Just exit since this
         // must mean the error is during registration.
         // It might be good to do something smarter here in the future.
-        logError("Exiting due to error from task scheduler: " + message)
+        logError("Exiting due to error from cluster scheduler: " + message)
         System.exit(1)
       }
     }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
index 1f0839a0e1..89aa098664 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulerBackend.scala
@@ -21,7 +21,7 @@ import org.apache.spark.SparkContext
 
 /**
  * A backend interface for scheduling systems that allows plugging in different ones under
- * TaskScheduler. We assume a Mesos-like model where the application gets resource offers as
+ * ClusterScheduler. We assume a Mesos-like model where the application gets resource offers as
  * machines become available and can launch tasks on them.
  */
 private[spark] trait SchedulerBackend {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
new file mode 100644
index 0000000000..17b6d97e90
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler
+
+import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
+
+/**
+ * Low-level task scheduler interface, currently implemented exclusively by the ClusterScheduler.
+ * This interface allows plugging in different task schedulers. Each TaskScheduler schedulers tasks
+ * for a single SparkContext. These schedulers get sets of tasks submitted to them from the
+ * DAGScheduler for each stage, and are responsible for sending the tasks to the cluster, running
+ * them, retrying if there are failures, and mitigating stragglers. They return events to the
+ * DAGScheduler.
+ */
+private[spark] trait TaskScheduler {
+
+  def rootPool: Pool
+
+  def schedulingMode: SchedulingMode
+
+  def start(): Unit
+
+  // Invoked after system has successfully initialized (typically in spark context).
+  // Yarn uses this to bootstrap allocation of resources based on preferred locations, wait for slave registerations, etc.
+  def postStartHook() { }
+
+  // Disconnect from the cluster.
+  def stop(): Unit
+
+  // Submit a sequence of tasks to run.
+  def submitTasks(taskSet: TaskSet): Unit
+
+  // Cancel a stage.
+  def cancelTasks(stageId: Int)
+
+  // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called.
+  def setDAGScheduler(dagScheduler: DAGScheduler): Unit
+
+  // Get the default level of parallelism to use in the cluster, as a hint for sizing jobs.
+  def defaultParallelism(): Int
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 8757d7fd2a..bc35e53220 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -32,7 +32,7 @@ import org.apache.spark.util.{SystemClock, Clock}
 
 
 /**
- * Schedules the tasks within a single TaskSet in the TaskScheduler. This class keeps track of
+ * Schedules the tasks within a single TaskSet in the ClusterScheduler. This class keeps track of
  * each task, retries tasks if they fail (up to a limited number of times), and
  * handles locality-aware scheduling for this TaskSet via delay scheduling. The main interfaces
  * to it are resourceOffer, which asks the TaskSet whether it wants to run a task on one node,
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 3bb715e7d0..3af02b42b2 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -29,7 +29,7 @@ import akka.util.Duration
 import akka.util.duration._
 
 import org.apache.spark.{SparkException, Logging, TaskState}
-import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, ClusterScheduler,
+import org.apache.spark.scheduler.{ClusterScheduler, SchedulerBackend, SlaveLost, TaskDescription,
   WorkerOffer}
 import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
 import org.apache.spark.util.Utils
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 3acad1bb46..773b980c53 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -209,7 +209,7 @@ private[spark] class MesosSchedulerBackend(
             getResource(offer.getResourcesList, "cpus").toInt)
         }
 
-        // Call into the TaskScheduler
+        // Call into the ClusterScheduler
         val taskLists = scheduler.resourceOffers(offerableWorkers)
 
         // Build a list of Mesos tasks for each slave
diff --git a/core/src/test/scala/org/apache/spark/FailureSuite.scala b/core/src/test/scala/org/apache/spark/FailureSuite.scala
index 2f7d6dff38..af448fcb37 100644
--- a/core/src/test/scala/org/apache/spark/FailureSuite.scala
+++ b/core/src/test/scala/org/apache/spark/FailureSuite.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark
 
-import org.scalatest.{BeforeAndAfterAll, FunSuite}
+import org.scalatest.FunSuite
 
 import SparkContext._
 import org.apache.spark.util.NonSerializable
@@ -37,20 +37,12 @@ object FailureSuiteState {
   }
 }
 
-class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAll {
-
-  override def beforeAll {
-    System.setProperty("spark.task.maxFailures", "1")
-  }
-
-  override def afterAll {
-    System.clearProperty("spark.task.maxFailures")
-  }
+class FailureSuite extends FunSuite with LocalSparkContext {
 
   // Run a 3-task map job in which task 1 deterministically fails once, and check
   // whether the job completes successfully and we ran 4 tasks in total.
   test("failure in a single-stage job") {
-    sc = new SparkContext("local[1]", "test")
+    sc = new SparkContext("local[1,1]", "test")
     val results = sc.makeRDD(1 to 3, 3).map { x =>
       FailureSuiteState.synchronized {
         FailureSuiteState.tasksRun += 1
@@ -70,7 +62,7 @@ class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAl
 
   // Run a map-reduce job in which a reduce task deterministically fails once.
   test("failure in a two-stage job") {
-    sc = new SparkContext("local[1]", "test")
+    sc = new SparkContext("local[1,1]", "test")
     val results = sc.makeRDD(1 to 3).map(x => (x, x)).groupByKey(3).map {
       case (k, v) =>
         FailureSuiteState.synchronized {
@@ -90,7 +82,7 @@ class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAl
   }
 
   test("failure because task results are not serializable") {
-    sc = new SparkContext("local[1]", "test")
+    sc = new SparkContext("local[1,1]", "test")
     val results = sc.makeRDD(1 to 3).map(x => new NonSerializable)
 
     val thrown = intercept[SparkException] {
@@ -103,7 +95,7 @@ class FailureSuite extends FunSuite with LocalSparkContext with BeforeAndAfterAl
   }
 
   test("failure because task closure is not serializable") {
-    sc = new SparkContext("local[1]", "test")
+    sc = new SparkContext("local[1,1]", "test")
     val a = new NonSerializable
 
     // Non-serializable closure in the final result stage
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
index 96adcf7198..35a06c4875 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ClusterSchedulerSuite.scala
@@ -29,9 +29,9 @@ class FakeTaskSetManager(
     initPriority: Int,
     initStageId: Int,
     initNumTasks: Int,
-    taskScheduler: ClusterScheduler,
+    clusterScheduler: ClusterScheduler,
     taskSet: TaskSet)
-  extends TaskSetManager(taskScheduler, taskSet, 1) {
+  extends TaskSetManager(clusterScheduler, taskSet, 0) {
 
   parent = null
   weight = 1
@@ -130,8 +130,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
   }
 
   test("FIFO Scheduler Test") {
-    sc = new SparkContext("local", "TaskSchedulerSuite")
-    val taskScheduler = new ClusterScheduler(sc)
+    sc = new SparkContext("local", "ClusterSchedulerSuite")
+    val clusterScheduler = new ClusterScheduler(sc)
     var tasks = ArrayBuffer[Task[_]]()
     val task = new FakeTask(0)
     tasks += task
@@ -141,9 +141,9 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
     val schedulableBuilder = new FIFOSchedulableBuilder(rootPool)
     schedulableBuilder.buildPools()
 
-    val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, taskScheduler, taskSet)
-    val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, taskScheduler, taskSet)
-    val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, taskScheduler, taskSet)
+    val taskSetManager0 = createDummyTaskSetManager(0, 0, 2, clusterScheduler, taskSet)
+    val taskSetManager1 = createDummyTaskSetManager(0, 1, 2, clusterScheduler, taskSet)
+    val taskSetManager2 = createDummyTaskSetManager(0, 2, 2, clusterScheduler, taskSet)
     schedulableBuilder.addTaskSetManager(taskSetManager0, null)
     schedulableBuilder.addTaskSetManager(taskSetManager1, null)
     schedulableBuilder.addTaskSetManager(taskSetManager2, null)
@@ -157,8 +157,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
   }
 
   test("Fair Scheduler Test") {
-    sc = new SparkContext("local", "TaskSchedulerSuite")
-    val taskScheduler = new ClusterScheduler(sc)
+    sc = new SparkContext("local", "ClusterSchedulerSuite")
+    val clusterScheduler = new ClusterScheduler(sc)
     var tasks = ArrayBuffer[Task[_]]()
     val task = new FakeTask(0)
     tasks += task
@@ -186,15 +186,15 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
     val properties2 = new Properties()
     properties2.setProperty("spark.scheduler.pool","2")
 
-    val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, taskScheduler, taskSet)
-    val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, taskScheduler, taskSet)
-    val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, taskScheduler, taskSet)
+    val taskSetManager10 = createDummyTaskSetManager(1, 0, 1, clusterScheduler, taskSet)
+    val taskSetManager11 = createDummyTaskSetManager(1, 1, 1, clusterScheduler, taskSet)
+    val taskSetManager12 = createDummyTaskSetManager(1, 2, 2, clusterScheduler, taskSet)
     schedulableBuilder.addTaskSetManager(taskSetManager10, properties1)
     schedulableBuilder.addTaskSetManager(taskSetManager11, properties1)
     schedulableBuilder.addTaskSetManager(taskSetManager12, properties1)
 
-    val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, taskScheduler, taskSet)
-    val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, taskScheduler, taskSet)
+    val taskSetManager23 = createDummyTaskSetManager(2, 3, 2, clusterScheduler, taskSet)
+    val taskSetManager24 = createDummyTaskSetManager(2, 4, 2, clusterScheduler, taskSet)
     schedulableBuilder.addTaskSetManager(taskSetManager23, properties2)
     schedulableBuilder.addTaskSetManager(taskSetManager24, properties2)
 
@@ -214,8 +214,8 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
   }
 
   test("Nested Pool Test") {
-    sc = new SparkContext("local", "TaskSchedulerSuite")
-    val taskScheduler = new ClusterScheduler(sc)
+    sc = new SparkContext("local", "ClusterSchedulerSuite")
+    val clusterScheduler = new ClusterScheduler(sc)
     var tasks = ArrayBuffer[Task[_]]()
     val task = new FakeTask(0)
     tasks += task
@@ -237,23 +237,23 @@ class ClusterSchedulerSuite extends FunSuite with LocalSparkContext with Logging
     pool1.addSchedulable(pool10)
     pool1.addSchedulable(pool11)
 
-    val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, taskScheduler, taskSet)
-    val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, taskScheduler, taskSet)
+    val taskSetManager000 = createDummyTaskSetManager(0, 0, 5, clusterScheduler, taskSet)
+    val taskSetManager001 = createDummyTaskSetManager(0, 1, 5, clusterScheduler, taskSet)
     pool00.addSchedulable(taskSetManager000)
     pool00.addSchedulable(taskSetManager001)
 
-    val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, taskScheduler, taskSet)
-    val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, taskScheduler, taskSet)
+    val taskSetManager010 = createDummyTaskSetManager(1, 2, 5, clusterScheduler, taskSet)
+    val taskSetManager011 = createDummyTaskSetManager(1, 3, 5, clusterScheduler, taskSet)
     pool01.addSchedulable(taskSetManager010)
     pool01.addSchedulable(taskSetManager011)
 
-    val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, taskScheduler, taskSet)
-    val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, taskScheduler, taskSet)
+    val taskSetManager100 = createDummyTaskSetManager(2, 4, 5, clusterScheduler, taskSet)
+    val taskSetManager101 = createDummyTaskSetManager(2, 5, 5, clusterScheduler, taskSet)
     pool10.addSchedulable(taskSetManager100)
     pool10.addSchedulable(taskSetManager101)
 
-    val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, taskScheduler, taskSet)
-    val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, taskScheduler, taskSet)
+    val taskSetManager110 = createDummyTaskSetManager(3, 6, 5, clusterScheduler, taskSet)
+    val taskSetManager111 = createDummyTaskSetManager(3, 7, 5, clusterScheduler, taskSet)
     pool11.addSchedulable(taskSetManager110)
     pool11.addSchedulable(taskSetManager111)
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 24689a7093..00f2fdd657 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -33,25 +33,6 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.scheduler.SchedulingMode.SchedulingMode
 import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster}
 
-/**
- * TaskScheduler that records the task sets that the DAGScheduler requested executed.
- */
-class TaskSetRecordingTaskScheduler(sc: SparkContext,
-  mapOutputTrackerMaster: MapOutputTrackerMaster) extends ClusterScheduler(sc) {
-  /** Set of TaskSets the DAGScheduler has requested executed. */
-  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
-  override def start() = {}
-  override def stop() = {}
-  override def submitTasks(taskSet: TaskSet) = {
-    // normally done by TaskSetManager
-    taskSet.tasks.foreach(_.epoch = mapOutputTrackerMaster.getEpoch)
-    taskSets += taskSet
-  }
-  override def cancelTasks(stageId: Int) {}
-  override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
-  override def defaultParallelism() = 2
-}
-
 /**
  * Tests for DAGScheduler. These tests directly call the event processing functions in DAGScheduler
  * rather than spawning an event loop thread as happens in the real code. They use EasyMock
@@ -65,7 +46,24 @@ class TaskSetRecordingTaskScheduler(sc: SparkContext,
  * and capturing the resulting TaskSets from the mock TaskScheduler.
  */
 class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
-  var taskScheduler: TaskSetRecordingTaskScheduler = null
+
+  /** Set of TaskSets the DAGScheduler has requested executed. */
+  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+  val taskScheduler = new TaskScheduler() {
+    override def rootPool: Pool = null
+    override def schedulingMode: SchedulingMode = SchedulingMode.NONE
+    override def start() = {}
+    override def stop() = {}
+    override def submitTasks(taskSet: TaskSet) = {
+      // normally done by TaskSetManager
+      taskSet.tasks.foreach(_.epoch = mapOutputTracker.getEpoch)
+      taskSets += taskSet
+    }
+    override def cancelTasks(stageId: Int) {}
+    override def setDAGScheduler(dagScheduler: DAGScheduler) = {}
+    override def defaultParallelism() = 2
+  }
+
   var mapOutputTracker: MapOutputTrackerMaster = null
   var scheduler: DAGScheduler = null
 
@@ -98,11 +96,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
 
   before {
     sc = new SparkContext("local", "DAGSchedulerSuite")
-    mapOutputTracker = new MapOutputTrackerMaster()
-    taskScheduler = new TaskSetRecordingTaskScheduler(sc, mapOutputTracker)
-    taskScheduler.taskSets.clear()
+    taskSets.clear()
     cacheLocations.clear()
     results.clear()
+    mapOutputTracker = new MapOutputTrackerMaster()
     scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
       override def runLocally(job: ActiveJob) {
         // don't bother with the thread while unit testing
@@ -207,7 +204,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
   test("run trivial job") {
     val rdd = makeRdd(1, Nil)
     submit(rdd, Array(0))
-    complete(taskScheduler.taskSets(0), List((Success, 42)))
+    complete(taskSets(0), List((Success, 42)))
     assert(results === Map(0 -> 42))
   }
 
@@ -228,7 +225,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     val baseRdd = makeRdd(1, Nil)
     val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
     submit(finalRdd, Array(0))
-    complete(taskScheduler.taskSets(0), Seq((Success, 42)))
+    complete(taskSets(0), Seq((Success, 42)))
     assert(results === Map(0 -> 42))
   }
 
@@ -238,7 +235,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     cacheLocations(baseRdd.id -> 0) =
       Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
     submit(finalRdd, Array(0))
-    val taskSet = taskScheduler.taskSets(0)
+    val taskSet = taskSets(0)
     assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
     complete(taskSet, Seq((Success, 42)))
     assert(results === Map(0 -> 42))
@@ -246,7 +243,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
 
   test("trivial job failure") {
     submit(makeRdd(1, Nil), Array(0))
-    failed(taskScheduler.taskSets(0), "some failure")
+    failed(taskSets(0), "some failure")
     assert(failure.getMessage === "Job aborted: some failure")
   }
 
@@ -256,12 +253,12 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(1, List(shuffleDep))
     submit(reduceRdd, Array(0))
-    complete(taskScheduler.taskSets(0), Seq(
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostB", 1))))
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
            Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
-    complete(taskScheduler.taskSets(1), Seq((Success, 42)))
+    complete(taskSets(1), Seq((Success, 42)))
     assert(results === Map(0 -> 42))
   }
 
@@ -271,11 +268,11 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(2, List(shuffleDep))
     submit(reduceRdd, Array(0, 1))
-    complete(taskScheduler.taskSets(0), Seq(
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostB", 1))))
     // the 2nd ResultTask failed
-    complete(taskScheduler.taskSets(1), Seq(
+    complete(taskSets(1), Seq(
         (Success, 42),
         (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
     // this will get called
@@ -283,10 +280,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     // ask the scheduler to try it again
     scheduler.resubmitFailedStages()
     // have the 2nd attempt pass
-    complete(taskScheduler.taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
     // we can see both result blocks now
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
-    complete(taskScheduler.taskSets(3), Seq((Success, 43)))
+    complete(taskSets(3), Seq((Success, 43)))
     assert(results === Map(0 -> 42, 1 -> 43))
   }
 
@@ -302,7 +299,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     val newEpoch = mapOutputTracker.getEpoch
     assert(newEpoch > oldEpoch)
     val noAccum = Map[Long, Any]()
-    val taskSet = taskScheduler.taskSets(0)
+    val taskSet = taskSets(0)
     // should be ignored for being too old
     runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
     // should work because it's a non-failed host
@@ -314,7 +311,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
            Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
-    complete(taskScheduler.taskSets(1), Seq((Success, 42), (Success, 43)))
+    complete(taskSets(1), Seq((Success, 42), (Success, 43)))
     assert(results === Map(0 -> 42, 1 -> 43))
   }
 
@@ -329,14 +326,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     runEvent(ExecutorLost("exec-hostA"))
     // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
     // rather than marking it is as failed and waiting.
-    complete(taskScheduler.taskSets(0), Seq(
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 1)),
        (Success, makeMapStatus("hostB", 1))))
    // have hostC complete the resubmitted task
-   complete(taskScheduler.taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
+   complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
           Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
-   complete(taskScheduler.taskSets(2), Seq((Success, 42)))
+   complete(taskSets(2), Seq((Success, 42)))
    assert(results === Map(0 -> 42))
  }
 
@@ -348,23 +345,23 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     val finalRdd = makeRdd(1, List(shuffleDepTwo))
     submit(finalRdd, Array(0))
     // have the first stage complete normally
-    complete(taskScheduler.taskSets(0), Seq(
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 2)),
         (Success, makeMapStatus("hostB", 2))))
     // have the second stage complete normally
-    complete(taskScheduler.taskSets(1), Seq(
+    complete(taskSets(1), Seq(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostC", 1))))
     // fail the third stage because hostA went down
-    complete(taskScheduler.taskSets(2), Seq(
+    complete(taskSets(2), Seq(
         (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
     // TODO assert this:
     // blockManagerMaster.removeExecutor("exec-hostA")
     // have DAGScheduler try again
     scheduler.resubmitFailedStages()
-    complete(taskScheduler.taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
-    complete(taskScheduler.taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
-    complete(taskScheduler.taskSets(5), Seq((Success, 42)))
+    complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
+    complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
+    complete(taskSets(5), Seq((Success, 42)))
     assert(results === Map(0 -> 42))
   }
 
@@ -378,24 +375,24 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
     cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
     cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
     // complete stage 2
-    complete(taskScheduler.taskSets(0), Seq(
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 2)),
         (Success, makeMapStatus("hostB", 2))))
     // complete stage 1
-    complete(taskScheduler.taskSets(1), Seq(
+    complete(taskSets(1), Seq(
         (Success, makeMapStatus("hostA", 1)),
         (Success, makeMapStatus("hostB", 1))))
     // pretend stage 0 failed because hostA went down
-    complete(taskScheduler.taskSets(2), Seq(
+    complete(taskSets(2), Seq(
         (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
     // TODO assert this:
     // blockManagerMaster.removeExecutor("exec-hostA")
     // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
     scheduler.resubmitFailedStages()
-    assertLocations(taskScheduler.taskSets(3), Seq(Seq("hostD")))
+    assertLocations(taskSets(3), Seq(Seq("hostD")))
     // allow hostD to recover
-    complete(taskScheduler.taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
-    complete(taskScheduler.taskSets(4), Seq((Success, 42)))
+    complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
+    complete(taskSets(4), Seq((Success, 42)))
     assert(results === Map(0 -> 42))
   }
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
index 2ac2d7a36a..b0d1902c67 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
@@ -64,20 +64,18 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
     System.setProperty("spark.akka.frameSize", "1")
   }
 
-  before {
-    sc = new SparkContext("local", "test")
-  }
-
   override def afterAll {
     System.clearProperty("spark.akka.frameSize")
   }
 
   test("handling results smaller than Akka frame size") {
+    sc = new SparkContext("local", "test")
     val result = sc.parallelize(Seq(1), 1).map(x => 2 * x).reduce((x, y) => x)
     assert(result === 2)
   }
 
-  test("handling results larger than Akka frame size") { 
+  test("handling results larger than Akka frame size") {
+    sc = new SparkContext("local", "test")
     val akkaFrameSize =
       sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
     val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
@@ -89,13 +87,16 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
   }
 
   test("task retried if result missing from block manager") {
+    // Set the maximum number of task failures to > 0, so that the task set isn't aborted
+    // after the result is missing.
+    sc = new SparkContext("local[1,1]", "test")
     // If this test hangs, it's probably because no resource offers were made after the task
     // failed.
     val scheduler: ClusterScheduler = sc.taskScheduler match {
       case clusterScheduler: ClusterScheduler =>
         clusterScheduler
       case _ =>
-        assert(false, "Expect local cluster to use TaskScheduler")
+        assert(false, "Expect local cluster to use ClusterScheduler")
         throw new ClassCastException
     }
     scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 592bb11364..4bbb51532d 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -27,7 +27,7 @@ import org.apache.spark.executor.TaskMetrics
 import java.nio.ByteBuffer
 import org.apache.spark.util.{Utils, FakeClock}
 
-class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(taskScheduler) {
+class FakeDAGScheduler(taskScheduler: FakeClusterScheduler) extends DAGScheduler(taskScheduler) {
   override def taskStarted(task: Task[_], taskInfo: TaskInfo) {
     taskScheduler.startedTasks += taskInfo.index
   }
@@ -52,12 +52,12 @@ class FakeDAGScheduler(taskScheduler: FakeTaskScheduler) extends DAGScheduler(ta
 }
 
 /**
- * A mock TaskScheduler implementation that just remembers information about tasks started and
+ * A mock ClusterScheduler implementation that just remembers information about tasks started and
  * feedback received from the TaskSetManagers. Note that it's important to initialize this with
  * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost
  * to work, and these are required for locality in TaskSetManager.
  */
-class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
+class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
   extends ClusterScheduler(sc)
 {
   val startedTasks = new ArrayBuffer[Long]
@@ -86,7 +86,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("TaskSet with no preferences") {
     sc = new SparkContext("local", "test")
-    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(1)
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
 
@@ -112,7 +112,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("multiple offers with no preferences") {
     sc = new SparkContext("local", "test")
-    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(3)
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES)
 
@@ -143,7 +143,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("basic delay scheduling") {
     sc = new SparkContext("local", "test")
-    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
     val taskSet = createTaskSet(4,
       Seq(TaskLocation("host1", "exec1")),
       Seq(TaskLocation("host2", "exec2")),
@@ -187,7 +187,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("delay scheduling with fallback") {
     sc = new SparkContext("local", "test")
-    val sched = new FakeTaskScheduler(sc,
+    val sched = new FakeClusterScheduler(sc,
       ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
     val taskSet = createTaskSet(5,
       Seq(TaskLocation("host1")),
@@ -227,7 +227,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("delay scheduling with failed hosts") {
     sc = new SparkContext("local", "test")
-    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"), ("exec2", "host2"))
     val taskSet = createTaskSet(3,
       Seq(TaskLocation("host1")),
       Seq(TaskLocation("host2")),
@@ -259,7 +259,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("task result lost") {
     sc = new SparkContext("local", "test")
-    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(1)
     val clock = new FakeClock
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
@@ -276,7 +276,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
 
   test("repeated failures lead to task set abortion") {
     sc = new SparkContext("local", "test")
-    val sched = new FakeTaskScheduler(sc, ("exec1", "host1"))
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
     val taskSet = createTaskSet(1)
     val clock = new FakeClock
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
index e873400680..4e988b8017 100644
--- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClusterScheduler.scala
@@ -21,16 +21,16 @@ import org.apache.hadoop.conf.Configuration
 
 import org.apache.spark._
 import org.apache.spark.deploy.yarn.{ApplicationMaster, YarnAllocationHandler}
-import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.scheduler.ClusterScheduler
 import org.apache.spark.util.Utils
 
 /**
  *
- * This is a simple extension to TaskScheduler - to ensure that appropriate initialization of
+ * This is a simple extension to ClusterScheduler - to ensure that appropriate initialization of
  * ApplicationMaster, etc. is done
  */
 private[spark] class YarnClusterScheduler(sc: SparkContext, conf: Configuration)
-  extends TaskScheduler(sc) {
+  extends ClusterScheduler(sc) {
 
   logInfo("Created YarnClusterScheduler")
 
-- 
GitLab