diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index 58cc1ef185336e4d0037b9e8fa4041d881db033a..80d0c5a5e929ac2c62c2eef41e70f75b9789d841 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -40,6 +40,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
   val startedTasks = new ArrayBuffer[Long]
   val endedTasks = new mutable.HashMap[Long, TaskEndReason]
   val finishedManagers = new ArrayBuffer[TaskSetManager]
+  val taskSetsFailed = new ArrayBuffer[String]
 
   val executors = new mutable.HashMap[String, String] ++ liveExecutors
 
@@ -63,7 +64,9 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /*
 
     def executorLost(execId: String) {}
 
-    def taskSetFailed(taskSet: TaskSet, reason: String) {}
+    def taskSetFailed(taskSet: TaskSet, reason: String) {
+      taskSetsFailed += taskSet.id
+    }
   }
 
   def removeExecutor(execId: String): Unit = executors -= execId
@@ -270,6 +273,30 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
     assert(manager.resourceOffer("exec1", "host1", 1, ANY).get.index === 0)
   }
 
+  test("repeated failures lead to task set abortion") {
+    sc = new SparkContext("local", "test")
+    val sched = new FakeClusterScheduler(sc, ("exec1", "host1"))
+    val taskSet = createTaskSet(1)
+    val clock = new FakeClock
+    val manager = new ClusterTaskSetManager(sched, taskSet, clock)
+
+    // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
+    // after the last failure.
+    (0 until manager.MAX_TASK_FAILURES).foreach { index =>
+      val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
+      assert(offerResult != None,
+        "Expect resource offer on iteration %s to return a task".format(index))
+      assert(offerResult.get.index === 0)
+      manager.handleFailedTask(offerResult.get.taskId, TaskState.FINISHED, Some(TaskResultLost))
+      if (index < manager.MAX_TASK_FAILURES) {
+        assert(!sched.taskSetsFailed.contains(taskSet.id))
+      } else {
+        assert(sched.taskSetsFailed.contains(taskSet.id))
+      }
+    }
+  }
+
+
   /**
    * Utility method to create a TaskSet, potentially setting a particular sequence of preferred
    * locations for each task (given as varargs) if this sequence is not empty.