diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala
index 473520758557dc835d848c4a624247347ae439d5..866d630a6d27b8b41b2965c97146d8087d9f450b 100644
--- a/core/src/main/scala/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/spark/MapOutputTracker.scala
@@ -38,9 +38,10 @@ private[spark] class MapOutputTrackerActor(tracker: MapOutputTracker) extends Ac
   }
 }
 
-private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolean) extends Logging {
+private[spark] class MapOutputTracker extends Logging {
 
-  val timeout = 10.seconds
+  // Set to the MapOutputTrackerActor living on the driver
+  var trackerActor: ActorRef = _
 
   var mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
 
@@ -53,24 +54,13 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isDriver: Boolea
   var cacheGeneration = generation
   val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]
 
-  val actorName: String = "MapOutputTracker"
-  var trackerActor: ActorRef = if (isDriver) {
-    val actor = actorSystem.actorOf(Props(new MapOutputTrackerActor(this)), name = actorName)
-    logInfo("Registered MapOutputTrackerActor actor")
-    actor
-  } else {
-    val ip = System.getProperty("spark.driver.host", "localhost")
-    val port = System.getProperty("spark.driver.port", "7077").toInt
-    val url = "akka://spark@%s:%s/user/%s".format(ip, port, actorName)
-    actorSystem.actorFor(url)
-  }
-
   val metadataCleaner = new MetadataCleaner("MapOutputTracker", this.cleanup)
 
   // Send a message to the trackerActor and get its result within a default timeout, or
   // throw a SparkException if this fails.
   def askTracker(message: Any): Any = {
     try {
+      val timeout = 10.seconds
       val future = trackerActor.ask(message)(timeout)
       return Await.result(future, timeout)
     } catch {
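
Callers now build the tracker first and wire its actor in afterwards. A minimal sketch of the two wiring paths, mirroring what SparkEnv.registerOrLookup and MapOutputTrackerSuite do later in this patch; actorSystem is assumed to already exist, and the host/port in the URL are placeholders:

    import akka.actor.Props
    import spark.{MapOutputTracker, MapOutputTrackerActor}

    // Driver side: register the actor under a well-known name.
    val driverTracker = new MapOutputTracker()
    driverTracker.trackerActor = actorSystem.actorOf(
      Props(new MapOutputTrackerActor(driverTracker)), name = "MapOutputTracker")

    // Executor side: look the driver's actor up by URL instead.
    val executorTracker = new MapOutputTracker()
    executorTracker.trackerActor =
      actorSystem.actorFor("akka://spark@driver-host:7077/user/MapOutputTracker")
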
diff --git a/core/src/main/scala/spark/SparkEnv.scala b/core/src/main/scala/spark/SparkEnv.scala
index d2193ae72b3d09180421a72096a9bd5f9553b3c2..7157fd26883d3a3f7b29fb71fc272886a92ecfd5 100644
--- a/core/src/main/scala/spark/SparkEnv.scala
+++ b/core/src/main/scala/spark/SparkEnv.scala
@@ -1,7 +1,6 @@
 package spark
 
-import akka.actor.ActorSystem
-import akka.actor.ActorSystemImpl
+import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
 import akka.remote.RemoteActorRefProvider
 
 import serializer.Serializer
@@ -83,11 +82,23 @@ object SparkEnv extends Logging {
     }
 
     val serializer = instantiateClass[Serializer]("spark.serializer", "spark.JavaSerializer")
+
+    def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
+      if (isDriver) {
+        logInfo("Registering " + name)
+        actorSystem.actorOf(Props(newActor), name = name)
+      } else {
+        val driverIp: String = System.getProperty("spark.driver.host", "localhost")
+        val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
+        val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, name)
+        logInfo("Connecting to " + name + ": " + url)
+        actorSystem.actorFor(url)
+      }
+    }
 
-    val driverIp: String = System.getProperty("spark.driver.host", "localhost")
-    val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
-    val blockManagerMaster = new BlockManagerMaster(
-      actorSystem, isDriver, isLocal, driverIp, driverPort)
+    val blockManagerMaster = new BlockManagerMaster(registerOrLookup(
+      "BlockManagerMaster",
+      new spark.storage.BlockManagerMasterActor(isLocal)))
     val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, serializer)
 
     val connectionManager = blockManager.connectionManager
@@ -99,7 +110,12 @@ object SparkEnv extends Logging {
 
     val cacheManager = new CacheManager(blockManager)
 
-    val mapOutputTracker = new MapOutputTracker(actorSystem, isDriver)
+    // Have to assign trackerActor after the tracker is constructed, because the
+    // MapOutputTrackerActor constructor needs a reference to the tracker itself
+    val mapOutputTracker = new MapOutputTracker()
+    mapOutputTracker.trackerActor = registerOrLookup(
+      "MapOutputTracker",
+      new MapOutputTrackerActor(mapOutputTracker))
 
     val shuffleFetcher = instantiateClass[ShuffleFetcher](
       "spark.shuffle.fetcher", "spark.BlockStoreShuffleFetcher")
@@ -137,4 +153,5 @@ object SparkEnv extends Logging {
       httpFileServer,
       sparkFilesDir)
   }
+
 }
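
registerOrLookup centralizes the driver-versus-executor decision, so any additional driver-side actor can reuse it. A sketch under stated assumptions: PingActor is a made-up actor, shown purely to illustrate the call shape inside SparkEnv.create alongside the existing registrations:

    import akka.actor.Actor

    // Hypothetical actor, for illustration only.
    class PingActor extends Actor {
      def receive = { case "ping" => sender ! "pong" }
    }

    val pingActor = registerOrLookup("Ping", new PingActor)
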
diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
index 1bf5054f4d7947cd439d434d2a13eb4a5c11dda1..c54dce51d783969e09f6b924db09aefc1352e6a4 100644
--- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala
@@ -385,29 +385,34 @@ class DAGScheduler(
    * We run the operation in a separate thread just in case it takes a bunch of time, so that we
    * don't block the DAGScheduler event loop or other concurrent jobs.
    */
-  private def runLocally(job: ActiveJob) {
+  protected def runLocally(job: ActiveJob) {
     logInfo("Computing the requested partition locally")
     new Thread("Local computation of job " + job.runId) {
       override def run() {
-        try {
-          SparkEnv.set(env)
-          val rdd = job.finalStage.rdd
-          val split = rdd.partitions(job.partitions(0))
-          val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
-          try {
-            val result = job.func(taskContext, rdd.iterator(split, taskContext))
-            job.listener.taskSucceeded(0, result)
-          } finally {
-            taskContext.executeOnCompleteCallbacks()
-          }
-        } catch {
-          case e: Exception =>
-            job.listener.jobFailed(e)
-        }
+        runLocallyWithinThread(job)
       }
     }.start()
   }
 
+  // Broken out for easier testing in DAGSchedulerSuite.
+  protected def runLocallyWithinThread(job: ActiveJob) {
+    try {
+      SparkEnv.set(env)
+      val rdd = job.finalStage.rdd
+      val split = rdd.partitions(job.partitions(0))
+      val taskContext = new TaskContext(job.finalStage.id, job.partitions(0), 0)
+      try {
+        val result = job.func(taskContext, rdd.iterator(split, taskContext))
+        job.listener.taskSucceeded(0, result)
+      } finally {
+        taskContext.executeOnCompleteCallbacks()
+      }
+    } catch {
+      case e: Exception =>
+        job.listener.jobFailed(e)
+    }
+  }
+
   /** Submits stage, but first recursively submits any missing parents. */
   private def submitStage(stage: Stage) {
     logDebug("submitStage(" + stage + ")")
diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 5849045a55f505b50b2933a068bba7ba72aeab94..3118d3d412b2f207f502c0cbbbd8479837217fb3 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -103,7 +103,7 @@ class BlockManager(
 
   val host = System.getProperty("spark.hostname", Utils.localHostName())
 
-  val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
+  val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
     name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
 
   // Pending reregistration action being executed asynchronously or null if none
@@ -853,7 +853,7 @@ class BlockManager(
       heartBeatTask.cancel()
     }
     connectionManager.stop()
-    master.actorSystem.stop(slaveActor)
+    actorSystem.stop(slaveActor)
     blockInfo.clear()
     memoryStore.clear()
     diskStore.clear()
diff --git a/core/src/main/scala/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
index d93cfc48d03fc563dd3bf91ed8c8a3c45ea24b14..036fdc3480119307f9c01094bfd47fd5e75c06e2 100644
--- a/core/src/main/scala/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/spark/storage/BlockManagerMaster.scala
@@ -15,13 +15,7 @@ import akka.util.duration._
 
 import spark.{Logging, SparkException, Utils}
 
-private[spark] class BlockManagerMaster(
-    val actorSystem: ActorSystem,
-    isDriver: Boolean,
-    isLocal: Boolean,
-    driverIp: String,
-    driverPort: Int)
-  extends Logging {
+private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
 
   val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
   val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
@@ -29,18 +23,6 @@ private[spark] class BlockManagerMaster(
   val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster"
 
   val timeout = 10.seconds
-  var driverActor: ActorRef = {
-    if (isDriver) {
-      val driverActor = actorSystem.actorOf(Props(new BlockManagerMasterActor(isLocal)),
-        name = DRIVER_AKKA_ACTOR_NAME)
-      logInfo("Registered BlockManagerMaster Actor")
-      driverActor
-    } else {
-      val url = "akka://spark@%s:%s/user/%s".format(driverIp, driverPort, DRIVER_AKKA_ACTOR_NAME)
-      logInfo("Connecting to BlockManagerMaster: " + url)
-      actorSystem.actorFor(url)
-    }
-  }
 
   /** Remove a dead executor from the driver actor. This is only called on the driver side. */
   def removeExecutor(execId: String) {
@@ -59,7 +41,7 @@ private[spark] class BlockManagerMaster(
 
   /** Register the BlockManager's id with the driver. */
   def registerBlockManager(
-    blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
+      blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) {
     logInfo("Trying to register BlockManager")
     tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor))
     logInfo("Registered BlockManager")
diff --git a/core/src/main/scala/spark/storage/ThreadingTest.scala b/core/src/main/scala/spark/storage/ThreadingTest.scala
index a70d1c8e78e109523552f1b9cdb1a7b47d2200fd..5c406e68cb2ac5b061dcaebd101aa8c5bcb6083b 100644
--- a/core/src/main/scala/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/spark/storage/ThreadingTest.scala
@@ -75,9 +75,8 @@ private[spark] object ThreadingTest {
     System.setProperty("spark.kryoserializer.buffer.mb", "1")
     val actorSystem = ActorSystem("test")
     val serializer = new KryoSerializer
-    val driverIp: String = System.getProperty("spark.driver.host", "localhost")
-    val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
-    val blockManagerMaster = new BlockManagerMaster(actorSystem, true, true, driverIp, driverPort)
+    val blockManagerMaster = new BlockManagerMaster(
+      actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
     val blockManager = new BlockManager(
       "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
     val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
index dd19442dcb4ad656cb73e1ccbe1e10c523617142..3abc584b6a177e0159ddecea2dc7e85ac87eb2cc 100644
--- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala
@@ -31,13 +31,15 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
 
   test("master start and stop") {
     val actorSystem = ActorSystem("test")
-    val tracker = new MapOutputTracker(actorSystem, true)
+    val tracker = new MapOutputTracker()
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
     tracker.stop()
   }
 
   test("master register and fetch") {
     val actorSystem = ActorSystem("test")
-    val tracker = new MapOutputTracker(actorSystem, true)
+    val tracker = new MapOutputTracker()
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -55,7 +57,8 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
 
   test("master register and unregister and fetch") {
     val actorSystem = ActorSystem("test")
-    val tracker = new MapOutputTracker(actorSystem, true)
+    val tracker = new MapOutputTracker()
+    tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerActor(tracker)))
     tracker.registerShuffle(10, 2)
     val compressedSize1000 = MapOutputTracker.compressSize(1000L)
     val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -77,35 +80,36 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
   }
 
   test("remote fetch") {
-    try {
-      System.clearProperty("spark.driver.host")  // In case some previous test had set it
-      val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
-      System.setProperty("spark.driver.port", boundPort.toString)
-      val masterTracker = new MapOutputTracker(actorSystem, true)
-      val slaveTracker = new MapOutputTracker(actorSystem, false)
-      masterTracker.registerShuffle(10, 1)
-      masterTracker.incrementGeneration()
-      slaveTracker.updateGeneration(masterTracker.getGeneration)
-      intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+    val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", "localhost", 0)
+    val masterTracker = new MapOutputTracker()
+    masterTracker.trackerActor = actorSystem.actorOf(
+        Props(new MapOutputTrackerActor(masterTracker)), "MapOutputTracker")
+
+    val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", "localhost", 0)
+    val slaveTracker = new MapOutputTracker()
+    slaveTracker.trackerActor = slaveSystem.actorFor(
+        "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
+
+    masterTracker.registerShuffle(10, 1)
+    masterTracker.incrementGeneration()
+    slaveTracker.updateGeneration(masterTracker.getGeneration)
+    intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
 
-      val compressedSize1000 = MapOutputTracker.compressSize(1000L)
-      val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
-      masterTracker.registerMapOutput(10, 0, new MapStatus(
-        BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
-      masterTracker.incrementGeneration()
-      slaveTracker.updateGeneration(masterTracker.getGeneration)
-      assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
-             Seq((BlockManagerId("a", "hostA", 1000), size1000)))
+    val compressedSize1000 = MapOutputTracker.compressSize(1000L)
+    val size1000 = MapOutputTracker.decompressSize(compressedSize1000)
+    masterTracker.registerMapOutput(10, 0, new MapStatus(
+      BlockManagerId("a", "hostA", 1000), Array(compressedSize1000)))
+    masterTracker.incrementGeneration()
+    slaveTracker.updateGeneration(masterTracker.getGeneration)
+    assert(slaveTracker.getServerStatuses(10, 0).toSeq ===
+           Seq((BlockManagerId("a", "hostA", 1000), size1000)))
 
-      masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
-      masterTracker.incrementGeneration()
-      slaveTracker.updateGeneration(masterTracker.getGeneration)
-      intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
+    masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
+    masterTracker.incrementGeneration()
+    slaveTracker.updateGeneration(masterTracker.getGeneration)
+    intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
 
-      // failure should be cached
-      intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
-    } finally {
-      System.clearProperty("spark.driver.port")
-    }
+    // failure should be cached
+    intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) }
   }
 }
diff --git a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
index b3e6ab4c0f45e4357ebc1ba4fcc204115b3d2450..710df929f6afe8802357a5437176df5d485c1103 100644
--- a/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/spark/scheduler/DAGSchedulerSuite.scala
@@ -4,16 +4,6 @@ import scala.collection.mutable.{Map, HashMap}
 
 import org.scalatest.FunSuite
 import org.scalatest.BeforeAndAfter
-import org.scalatest.concurrent.TimeLimitedTests
-import org.scalatest.mock.EasyMockSugar
-import org.scalatest.time.{Span, Seconds}
-
-import org.easymock.EasyMock._
-import org.easymock.Capture
-import org.easymock.EasyMock
-import org.easymock.{IAnswer, IArgumentMatcher}
-
-import akka.actor.ActorSystem
 
 import spark.storage.BlockManager
 import spark.storage.BlockManagerId
@@ -42,27 +32,26 @@ import spark.{FetchFailed, Success}
  * DAGScheduler#submitWaitingStages (via test utility functions like runEvent or respondToTaskSet)
  * and capturing the resulting TaskSets from the mock TaskScheduler.
  */
-class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar with TimeLimitedTests {
+class DAGSchedulerSuite extends FunSuite with BeforeAndAfter {
+
+  val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
 
-  // impose a time limit on this test in case we don't let the job finish, in which case
-  // JobWaiter#getResult will hang.
-  override val timeLimit = Span(5, Seconds)
+  /** TaskSets that the DAGScheduler has requested to be executed, in submission order. */
+  val taskSets = scala.collection.mutable.Buffer[TaskSet]()
+  val taskScheduler = new TaskScheduler() {
+    override def start() = {}
+    override def stop() = {}
+    override def submitTasks(taskSet: TaskSet) = {
+      // normally done by TaskSetManager
+      taskSet.tasks.foreach(_.generation = mapOutputTracker.getGeneration)
+      taskSets += taskSet
+    }
+    override def setListener(listener: TaskSchedulerListener) = {}
+    override def defaultParallelism() = 2
+  }
 
-  val sc: SparkContext = new SparkContext("local", "DAGSchedulerSuite")
-  var scheduler: DAGScheduler = null
-  val taskScheduler = mock[TaskScheduler]
-  val blockManagerMaster = mock[BlockManagerMaster]
   var mapOutputTracker: MapOutputTracker = null
-  var schedulerThread: Thread = null
-  var schedulerException: Throwable = null
-
-  /**
-   * Set of EasyMock argument matchers that match a TaskSet for a given RDD.
-   * We cache these so we do not create duplicate matchers for the same RDD.
-   * This allows us to easily setup a sequence of expectations for task sets for
-   * that RDD.
-   */
-  val taskSetMatchers = new HashMap[MyRDD, IArgumentMatcher]
+  var scheduler: DAGScheduler = null
 
   /**
    * Set of cache locations to return from our mock BlockManagerMaster.
@@ -70,68 +59,51 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
    * list of cache locations silently.
    */
   val cacheLocations = new HashMap[(Int, Int), Seq[BlockManagerId]]
-
-  /**
-   * JobWaiter for the last JobSubmitted event we pushed. To keep tests (most of which
-   * will only submit one job) from needing to explicitly track it.
-   */
-  var lastJobWaiter: JobWaiter[Int] = null
-
-  /**
-   * Array into which we are accumulating the results from the last job asynchronously.
-   */
-  var lastJobResult: Array[Int] = null
-
-  /**
-   * Tell EasyMockSugar what mock objects we want to be configured by expecting {...}
-   * and whenExecuting {...} */
-  implicit val mocks = MockObjects(taskScheduler, blockManagerMaster)
-
-  /**
-   * Utility function to reset mocks and set expectations on them. EasyMock wants mock objects
-   * to be reset after each time their expectations are set, and we tend to check mock object
-   * calls over a single call to DAGScheduler.
-   *
-   * We also set a default expectation here that blockManagerMaster.getLocations can be called
-   * and will return values from cacheLocations.
-   */
-  def resetExpecting(f: => Unit) {
-    reset(taskScheduler)
-    reset(blockManagerMaster)
-    expecting {
-      expectGetLocations()
-      f
+  // stub out BlockManagerMaster.getLocations to use our cacheLocations
+  val blockManagerMaster = new BlockManagerMaster(null) {
+      override def getLocations(blockIds: Array[String]): Seq[Seq[BlockManagerId]] = {
+        blockIds.map { name =>
+          val pieces = name.split("_")
+          if (pieces(0) == "rdd") {
+            val key = pieces(1).toInt -> pieces(2).toInt
+            cacheLocations.getOrElse(key, Seq())
+          } else {
+            Seq()
+          }
+        }.toSeq
+      }
+      override def removeExecutor(execId: String) {
+        // don't need to propagate to the driver, which we don't have
+      }
     }
+
+  /** Results collected by the DAGScheduler, keyed by partition index. */
+  val results = new HashMap[Int, Any]()
+  var failure: Exception = _
+  val listener = new JobListener() {
+    override def taskSucceeded(index: Int, result: Any) = results.put(index, result)
+    override def jobFailed(exception: Exception) = { failure = exception }
   }
 
   before {
-    taskSetMatchers.clear()
+    taskSets.clear()
     cacheLocations.clear()
-    val actorSystem = ActorSystem("test")
-    mapOutputTracker = new MapOutputTracker(actorSystem, true)
-    resetExpecting {
-      taskScheduler.setListener(anyObject())
-    }
-    whenExecuting {
-      scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null)
+    results.clear()
+    mapOutputTracker = new MapOutputTracker()
+    scheduler = new DAGScheduler(taskScheduler, mapOutputTracker, blockManagerMaster, null) {
+      override def runLocally(job: ActiveJob) {
+        // don't bother with the thread while unit testing
+        runLocallyWithinThread(job)
+      }
     }
   }
 
   after {
-    assert(scheduler.processEvent(StopDAGScheduler))
-    resetExpecting {
-      taskScheduler.stop()
-    }
-    whenExecuting {
-      scheduler.stop()
-    }
+    scheduler.stop()
     sc.stop()
     System.clearProperty("spark.master.port")
   }
 
-  def makeBlockManagerId(host: String): BlockManagerId =
-    BlockManagerId("exec-" + host, host, 12345)
-
   /**
    * Type of RDD we use for testing. Note that we should never call the real RDD compute methods.
    * This is a pair RDD type so it can always be used in ShuffleDependencies.
@@ -143,7 +115,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
    * preferredLocations (if any) that are passed to them. They are deliberately not executable
    * so we can test that DAGScheduler does not try to execute RDDs locally.
    */
-  def makeRdd(
+  private def makeRdd(
         numPartitions: Int,
         dependencies: List[Dependency[_]],
         locations: Seq[Seq[String]] = Nil
@@ -164,55 +136,6 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     }
   }
 
-  /**
-   * EasyMock matcher method. For use as an argument matcher for a TaskSet whose first task
-   * is from a particular RDD.
-   */
-  def taskSetForRdd(rdd: MyRDD): TaskSet = {
-    val matcher = taskSetMatchers.getOrElseUpdate(rdd,
-      new IArgumentMatcher {
-        override def matches(actual: Any): Boolean = {
-          val taskSet = actual.asInstanceOf[TaskSet]
-          taskSet.tasks(0) match {
-            case rt: ResultTask[_, _] => rt.rdd.id == rdd.id
-            case smt: ShuffleMapTask => smt.rdd.id == rdd.id
-            case _ => false
-          }
-        }
-        override def appendTo(buf: StringBuffer) {
-          buf.append("taskSetForRdd(" + rdd + ")")
-        }
-      })
-    EasyMock.reportMatcher(matcher)
-    return null
-  }
-
-  /**
-   * Setup an EasyMock expectation to repsond to blockManagerMaster.getLocations() called from
-   * cacheLocations.
-   */
-  def expectGetLocations(): Unit = {
-    EasyMock.expect(blockManagerMaster.getLocations(anyObject().asInstanceOf[Array[String]])).
-        andAnswer(new IAnswer[Seq[Seq[BlockManagerId]]] {
-      override def answer(): Seq[Seq[BlockManagerId]] = {
-        val blocks = getCurrentArguments()(0).asInstanceOf[Array[String]]
-        return blocks.map { name =>
-          val pieces = name.split("_")
-          if (pieces(0) == "rdd") {
-            val key = pieces(1).toInt -> pieces(2).toInt
-            if (cacheLocations.contains(key)) {
-              cacheLocations(key)
-            } else {
-              Seq[BlockManagerId]()
-            }
-          } else {
-            Seq[BlockManagerId]()
-          }
-        }.toSeq
-      }
-    }).anyTimes()
-  }
-
   /**
    * Process the supplied event as if it were the top of the DAGScheduler event queue, expecting
    * the scheduler not to exit.
@@ -220,48 +143,21 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
    * After processing the event, submit waiting stages as is done on most iterations of the
    * DAGScheduler event loop.
    */
-  def runEvent(event: DAGSchedulerEvent) {
+  private def runEvent(event: DAGSchedulerEvent) {
     assert(!scheduler.processEvent(event))
     scheduler.submitWaitingStages()
   }
 
   /**
-   * Expect a TaskSet for the specified RDD to be submitted to the TaskScheduler. Should be
-   * called from a resetExpecting { ... } block.
-   *
-   * Returns a easymock Capture that will contain the task set after the stage is submitted.
-   * Most tests should use interceptStage() instead of this directly.
-   */
-  def expectStage(rdd: MyRDD): Capture[TaskSet] = {
-    val taskSetCapture = new Capture[TaskSet]
-    taskScheduler.submitTasks(and(capture(taskSetCapture), taskSetForRdd(rdd)))
-    return taskSetCapture
-  }
-
-  /**
-   * Expect the supplied code snippet to submit a stage for the specified RDD.
-   * Return the resulting TaskSet. First marks all the tasks are belonging to the
-   * current MapOutputTracker generation.
+   * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
+   * below, we do not expect this function to ever be executed; instead, we will return results
+   * directly through CompletionEvents.
    */
-  def interceptStage(rdd: MyRDD)(f: => Unit): TaskSet = {
-    var capture: Capture[TaskSet] = null
-    resetExpecting {
-      capture = expectStage(rdd)
-    }
-    whenExecuting {
-      f
-    }
-    val taskSet = capture.getValue
-    for (task <- taskSet.tasks) {
-      task.generation = mapOutputTracker.getGeneration
-    }
-    return taskSet
-  }
+  private val jobComputeFunc = (context: TaskContext, it: Iterator[(_)]) =>
+    it.next.asInstanceOf[Tuple2[_, _]]._1
 
-  /**
-   * Send the given CompletionEvent messages for the tasks in the TaskSet.
-   */
-  def respondToTaskSet(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
+  /** Send the given CompletionEvent messages for the tasks in the TaskSet. */
+  private def complete(taskSet: TaskSet, results: Seq[(TaskEndReason, Any)]) {
     assert(taskSet.tasks.size >= results.size)
     for ((result, i) <- results.zipWithIndex) {
       if (i < taskSet.tasks.size) {
@@ -269,88 +165,38 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       }
     }
   }
-
-  /**
-   * Assert that the supplied TaskSet has exactly the given preferredLocations.
-   */
-  def expectTaskSetLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
-    assert(locations.size === taskSet.tasks.size)
-    for ((expectLocs, taskLocs) <-
-            taskSet.tasks.map(_.preferredLocations).zip(locations)) {
-      assert(expectLocs === taskLocs)
-    }
-  }
-
-  /**
-   * When we submit dummy Jobs, this is the compute function we supply. Except in a local test
-   * below, we do not expect this function to ever be executed; instead, we will return results
-   * directly through CompletionEvents.
-   */
-  def jobComputeFunc(context: TaskContext, it: Iterator[(Int, Int)]): Int =
-     it.next._1.asInstanceOf[Int]
-
-
-  /**
-   * Start a job to compute the given RDD. Returns the JobWaiter that will
-   * collect the result of the job via callbacks from DAGScheduler.
-   */
-  def submitRdd(rdd: MyRDD, allowLocal: Boolean = false): (JobWaiter[Int], Array[Int]) = {
-    val resultArray = new Array[Int](rdd.partitions.size)
-    val (toSubmit, waiter) = scheduler.prepareJob[(Int, Int), Int](
-        rdd,
-        jobComputeFunc,
-        (0 to (rdd.partitions.size - 1)),
-        "test-site",
-        allowLocal,
-        (i: Int, value: Int) => resultArray(i) = value
-    )
-    lastJobWaiter = waiter
-    lastJobResult = resultArray
-    runEvent(toSubmit)
-    return (waiter, resultArray)
-  }
-
-  /**
-   * Assert that a job we started has failed.
-   */
-  def expectJobException(waiter: JobWaiter[Int] = lastJobWaiter) {
-    waiter.awaitResult() match {
-      case JobSucceeded => fail()
-      case JobFailed(_) => return
-    }
+
+  /** Submits the given RDD as a job to the scheduler. */
+  private def submit(
+      rdd: RDD[_],
+      partitions: Array[Int],
+      func: (TaskContext, Iterator[_]) => _ = jobComputeFunc,
+      allowLocal: Boolean = false,
+      listener: JobListener = listener) {
+    runEvent(JobSubmitted(rdd, func, partitions, allowLocal, null, listener))
   }
-
-  /**
-   * Assert that a job we started has succeeded and has the given result.
-   */
-  def expectJobResult(expected: Array[Int], waiter: JobWaiter[Int] = lastJobWaiter,
-                      result: Array[Int] = lastJobResult) {
-    waiter.awaitResult match {
-      case JobSucceeded =>
-        assert(expected === result)
-      case JobFailed(_) =>
-        fail()
-    }
+
+  /** Sends TaskSetFailed to the scheduler. */
+  private def failed(taskSet: TaskSet, message: String) {
+    runEvent(TaskSetFailed(taskSet, message))
   }
 
-  def makeMapStatus(host: String, reduces: Int): MapStatus =
-    new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
-
   test("zero split job") {
     val rdd = makeRdd(0, Nil)
     var numResults = 0
-    def accumulateResult(partition: Int, value: Int) {
-      numResults += 1
+    val fakeListener = new JobListener() {
+      override def taskSucceeded(partition: Int, value: Any) = numResults += 1
+      override def jobFailed(exception: Exception) = throw exception
     }
-    scheduler.runJob(rdd, jobComputeFunc, Seq(), "test-site", false, accumulateResult)
+    submit(rdd, Array(), listener = fakeListener)
     assert(numResults === 0)
   }
 
   test("run trivial job") {
     val rdd = makeRdd(1, Nil)
-    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
-    respondToTaskSet(taskSet, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    submit(rdd, Array(0))
+    complete(taskSets(0), List((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("local job") {
@@ -361,16 +207,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
       override def getPreferredLocations(split: Partition) = Nil
       override def toString = "DAGSchedulerSuite Local RDD"
     }
-    submitRdd(rdd, true)
-    expectJobResult(Array(42))
+    runEvent(JobSubmitted(rdd, jobComputeFunc, Array(0), true, null, listener))
+    assert(results === Map(0 -> 42))
   }
-
+
   test("run trivial job w/ dependency") {
     val baseRdd = makeRdd(1, Nil)
     val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
-    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
-    respondToTaskSet(taskSet, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    submit(finalRdd, Array(0))
+    complete(taskSets(0), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("cache location preferences w/ dependency") {
@@ -378,17 +224,17 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val finalRdd = makeRdd(1, List(new OneToOneDependency(baseRdd)))
     cacheLocations(baseRdd.id -> 0) =
       Seq(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))
-    val taskSet = interceptStage(finalRdd) { submitRdd(finalRdd) }
-    expectTaskSetLocations(taskSet, List(Seq("hostA", "hostB")))
-    respondToTaskSet(taskSet, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    submit(finalRdd, Array(0))
+    val taskSet = taskSets(0)
+    assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
+    complete(taskSet, Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("trivial job failure") {
-    val rdd = makeRdd(1, Nil)
-    val taskSet = interceptStage(rdd) { submitRdd(rdd) }
-    runEvent(TaskSetFailed(taskSet, "test failure"))
-    expectJobException()
+    submit(makeRdd(1, Nil), Array(0))
+    failed(taskSets(0), "some failure")
+    assert(failure.getMessage === "Job failed: some failure")
   }
 
   test("run trivial shuffle") {
@@ -396,52 +242,39 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(1, List(shuffleDep))
-
-    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
-    val secondStage = interceptStage(reduceRdd) {
-      respondToTaskSet(firstStage, List(
+    submit(reduceRdd, Array(0))
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
+        (Success, makeMapStatus("hostB", 1))))
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
            Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
-    respondToTaskSet(secondStage, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    complete(taskSets(1), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
-
+
   test("run trivial shuffle with fetch failure") {
     val shuffleMapRdd = makeRdd(2, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(2, List(shuffleDep))
-
-    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
-    val secondStage = interceptStage(reduceRdd) {
-      respondToTaskSet(firstStage, List(
+    submit(reduceRdd, Array(0, 1))
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      respondToTaskSet(secondStage, List(
+        (Success, makeMapStatus("hostB", 1))))
+    // the 2nd ResultTask failed
+    complete(taskSets(1), Seq(
         (Success, 42),
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)
-      ))
-    }
-    val thirdStage = interceptStage(shuffleMapRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    val fourthStage = interceptStage(reduceRdd) {
-      respondToTaskSet(thirdStage, List( (Success, makeMapStatus("hostA", 1)) ))
-    }
-    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
-                   Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
-    respondToTaskSet(fourthStage, List( (Success, 43) ))
-    expectJobResult(Array(42, 43))
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleId, 0, 0), null)))
+    // the scheduler will call blockManagerMaster.removeExecutor("exec-hostA") here
+    // ask the scheduler to try it again
+    scheduler.resubmitFailedStages()
+    // have the 2nd attempt pass
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+    // both map output locations should now be visible
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.ip) ===
+           Array("hostA", "hostB"))
+    complete(taskSets(3), Seq((Success, 43)))
+    assert(results === Map(0 -> 42, 1 -> 43))
   }
 
   test("ignore late map task completions") {
@@ -449,33 +282,27 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(2, List(shuffleDep))
-
-    val taskSet = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
+    submit(reduceRdd, Array(0, 1))
+    // pretend we were told hostA went away
     val oldGeneration = mapOutputTracker.getGeneration
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      runEvent(ExecutorLost("exec-hostA"))
-    }
+    runEvent(ExecutorLost("exec-hostA"))
     val newGeneration = mapOutputTracker.getGeneration
     assert(newGeneration > oldGeneration)
     val noAccum = Map[Long, Any]()
-    // We rely on the event queue being ordered and increasing the generation number by 1
+    val taskSet = taskSets(0)
     // should be ignored for being too old
     runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
     // should work because it's a non-failed host
     runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), noAccum, null, null))
     // should be ignored for being too old
     runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), noAccum, null, null))
+    // should work because it's a new generation
     taskSet.tasks(1).generation = newGeneration
-    val secondStage = interceptStage(reduceRdd) {
-      runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
-    }
+    runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), noAccum, null, null))
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
            Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
-    respondToTaskSet(secondStage, List( (Success, 42), (Success, 43) ))
-    expectJobResult(Array(42, 43))
+    complete(taskSets(1), Seq((Success, 42), (Success, 43)))
+    assert(results === Map(0 -> 42, 1 -> 43))
   }
 
   test("run trivial shuffle with out-of-band failure and retry") {
@@ -483,76 +310,49 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
     val shuffleId = shuffleDep.shuffleId
     val reduceRdd = makeRdd(1, List(shuffleDep))
-
-    val firstStage = interceptStage(shuffleMapRdd) { submitRdd(reduceRdd) }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      runEvent(ExecutorLost("exec-hostA"))
-    }
+    submit(reduceRdd, Array(0))
+    // pretend we were told hostA went away; the scheduler will call
+    // blockManagerMaster.removeExecutor("exec-hostA") in response
+    runEvent(ExecutorLost("exec-hostA"))
     // DAGScheduler will immediately resubmit the stage after it appears to have no pending tasks
     // rather than marking it is as failed and waiting.
-    val secondStage = interceptStage(shuffleMapRdd) {
-      respondToTaskSet(firstStage, List(
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
-    val thirdStage = interceptStage(reduceRdd) {
-      respondToTaskSet(secondStage, List(
-        (Success, makeMapStatus("hostC", 1))
-      ))
-    }
-    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
-           Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
-    respondToTaskSet(thirdStage, List( (Success, 42) ))
-    expectJobResult(Array(42))
-  }
-
-  test("recursive shuffle failures") {
+        (Success, makeMapStatus("hostB", 1))))
+    // have hostC complete the resubmitted task
+    complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
+    assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+           Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+    complete(taskSets(2), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
+  }
+
+  test("recursive shuffle failures") {
     val shuffleOneRdd = makeRdd(2, Nil)
     val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
     val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
     val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
     val finalRdd = makeRdd(1, List(shuffleDepTwo))
-
-    val firstStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
-    val secondStage = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(firstStage, List(
+    submit(finalRdd, Array(0))
+    // have the first stage complete normally
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))
-      ))
-    }
-    val thirdStage = interceptStage(finalRdd) {
-      respondToTaskSet(secondStage, List(
+        (Success, makeMapStatus("hostB", 2))))
+    // have the second stage complete normally
+    complete(taskSets(1), Seq(
         (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostC", 1))
-      ))
-    }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      respondToTaskSet(thirdStage, List(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
-      ))
-    }
-    val recomputeOne = interceptStage(shuffleOneRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    val recomputeTwo = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(recomputeOne, List(
-        (Success, makeMapStatus("hostA", 2))
-      ))
-    }
-    val finalStage = interceptStage(finalRdd) {
-      respondToTaskSet(recomputeTwo, List(
-        (Success, makeMapStatus("hostA", 1))
-      ))
-    }
-    respondToTaskSet(finalStage, List( (Success, 42) ))
-    expectJobResult(Array(42))
+        (Success, makeMapStatus("hostC", 1))))
+    // fail the third stage because hostA went down
+    complete(taskSets(2), Seq(
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+    // TODO assert this:
+    // blockManagerMaster.removeExecutor("exec-hostA")
+    // have DAGScheduler try again
+    scheduler.resubmitFailedStages()
+    complete(taskSets(3), Seq((Success, makeMapStatus("hostA", 2))))
+    complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
+    complete(taskSets(5), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
   test("cached post-shuffle") {
@@ -561,103 +361,44 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with EasyMockSugar
     val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
     val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
     val finalRdd = makeRdd(1, List(shuffleDepTwo))
-
-    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
+    submit(finalRdd, Array(0))
     cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
     cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
-    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(firstShuffleStage, List(
+    // complete stage 2
+    complete(taskSets(0), Seq(
         (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))
-      ))
-    }
-    val reduceStage = interceptStage(finalRdd) {
-      respondToTaskSet(secondShuffleStage, List(
+        (Success, makeMapStatus("hostB", 2))))
+    // complete stage 1
+    complete(taskSets(1), Seq(
         (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      respondToTaskSet(reduceStage, List(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
-      ))
-    }
+        (Success, makeMapStatus("hostB", 1))))
+    // pretend stage 0 failed because hostA went down
+    complete(taskSets(2), Seq(
+        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)))
+    // TODO assert this:
+    // blockManagerMaster.removeExecutor("exec-hostA")
     // DAGScheduler should notice the cached copy of the second shuffle and try to get it rerun.
-    val recomputeTwo = interceptStage(shuffleTwoRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    expectTaskSetLocations(recomputeTwo, Seq(Seq("hostD")))
-    val finalRetry = interceptStage(finalRdd) {
-      respondToTaskSet(recomputeTwo, List(
-        (Success, makeMapStatus("hostD", 1))
-      ))
-    }
-    respondToTaskSet(finalRetry, List( (Success, 42) ))
-    expectJobResult(Array(42))
+    scheduler.resubmitFailedStages()
+    assertLocations(taskSets(3), Seq(Seq("hostD")))
+    // allow hostD to recover
+    complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
+    complete(taskSets(4), Seq((Success, 42)))
+    assert(results === Map(0 -> 42))
   }
 
-  test("cached post-shuffle but fails") {
-    val shuffleOneRdd = makeRdd(2, Nil)
-    val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
-    val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
-    val shuffleDepTwo = new ShuffleDependency(shuffleTwoRdd, null)
-    val finalRdd = makeRdd(1, List(shuffleDepTwo))
-
-    val firstShuffleStage = interceptStage(shuffleOneRdd) { submitRdd(finalRdd) }
-    cacheLocations(shuffleTwoRdd.id -> 0) = Seq(makeBlockManagerId("hostD"))
-    cacheLocations(shuffleTwoRdd.id -> 1) = Seq(makeBlockManagerId("hostC"))
-    val secondShuffleStage = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(firstShuffleStage, List(
-        (Success, makeMapStatus("hostA", 2)),
-        (Success, makeMapStatus("hostB", 2))
-      ))
-    }
-    val reduceStage = interceptStage(finalRdd) {
-      respondToTaskSet(secondShuffleStage, List(
-        (Success, makeMapStatus("hostA", 1)),
-        (Success, makeMapStatus("hostB", 1))
-      ))
-    }
-    resetExpecting {
-      blockManagerMaster.removeExecutor("exec-hostA")
-    }
-    whenExecuting {
-      respondToTaskSet(reduceStage, List(
-        (FetchFailed(makeBlockManagerId("hostA"), shuffleDepTwo.shuffleId, 0, 0), null)
-      ))
-    }
-    val recomputeTwoCached = interceptStage(shuffleTwoRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    expectTaskSetLocations(recomputeTwoCached, Seq(Seq("hostD")))
-    intercept[FetchFailedException]{
-      mapOutputTracker.getServerStatuses(shuffleDepOne.shuffleId, 0)
+  /** Assert that the supplied TaskSet has exactly the given preferredLocations. */
+  private def assertLocations(taskSet: TaskSet, locations: Seq[Seq[String]]) {
+    assert(locations.size === taskSet.tasks.size)
+    for ((expectLocs, taskLocs) <-
+            taskSet.tasks.map(_.preferredLocations).zip(locations)) {
+      assert(expectLocs === taskLocs)
     }
+  }
 
-    // Simulate the shuffle input data failing to be cached.
-    cacheLocations.remove(shuffleTwoRdd.id -> 0)
-    respondToTaskSet(recomputeTwoCached, List(
-      (FetchFailed(null, shuffleDepOne.shuffleId, 0, 0), null)
-    ))
+  private def makeMapStatus(host: String, reduces: Int): MapStatus =
+    new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
 
-    // After the fetch failure, DAGScheduler should recheck the cache and decide to resubmit
-    // everything.
-    val recomputeOne = interceptStage(shuffleOneRdd) {
-      scheduler.resubmitFailedStages()
-    }
-    // We use hostA here to make sure DAGScheduler doesn't think it's still dead.
-    val recomputeTwoUncached = interceptStage(shuffleTwoRdd) {
-      respondToTaskSet(recomputeOne, List( (Success, makeMapStatus("hostA", 1)) ))
-    }
-    expectTaskSetLocations(recomputeTwoUncached, Seq(Seq[String]()))
-    val finalRetry = interceptStage(finalRdd) {
-      respondToTaskSet(recomputeTwoUncached, List( (Success, makeMapStatus("hostA", 1)) ))
+  private def makeBlockManagerId(host: String): BlockManagerId =
+    BlockManagerId("exec-" + host, host, 12345)
 
-    }
-    respondToTaskSet(finalRetry, List( (Success, 42) ))
-    expectJobResult(Array(42))
-  }
 }
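
With EasyMock gone, new DAGScheduler tests follow the same submit / complete / assert shape as the ones above. A hypothetical extra test, shown only to illustrate how the helpers compose (the name and values are invented):

    test("hypothetical: trivial two-partition job") {
      val rdd = makeRdd(2, Nil)
      submit(rdd, Array(0, 1))
      complete(taskSets(0), Seq((Success, 1), (Success, 2)))
      assert(results === Map(0 -> 1, 1 -> 2))
    }
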
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index 61e793b31f4d4426e188c10af78ec5c5386316bb..b8c0f6fb763a9d58251988d12708c7f7776f7454 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -32,7 +32,8 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
 
   before {
     actorSystem = ActorSystem("test")
-    master = new BlockManagerMaster(actorSystem, true, true, "localhost", 7077)
+    master = new BlockManagerMaster(
+      actorSystem.actorOf(Props(new spark.storage.BlockManagerMasterActor(true))))
 
     // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
     oldArch = System.setProperty("os.arch", "amd64")