diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 4ef66562224550c620ce029cbb63b3f1c0fe943e..3e10b9eee4e2401d297b3870d6e13af3f0d34c67 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -34,6 +34,156 @@ import org.apache.spark.shuffle.MetadataFetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}
 import org.apache.spark.util._
 
+/**
+ * Helper class used by the [[MapOutputTrackerMaster]] to perform bookkeeping for a single
+ * ShuffleMapStage.
+ *
+ * This class maintains a mapping from mapIds to `MapStatus`. It also maintains a cache of
+ * serialized map statuses in order to speed up tasks' requests for map output statuses.
+ *
+ * All public methods of this class are thread-safe.
+ */
+private class ShuffleStatus(numPartitions: Int) {
+
+  // All accesses to the following state must be guarded with `this.synchronized`.
+
+  /**
+   * MapStatus for each partition. The index of the array is the map partition id.
+   * Each value in the array is the MapStatus for a partition, or null if the partition
+   * is not available. Even though in theory a task may run multiple times (due to speculation,
+   * stage retries, etc.), in practice the likelihood of a map output being available at multiple
+   * locations is so small that we choose to ignore that case and store only a single location
+   * for each output.
+   */
+  private[this] val mapStatuses = new Array[MapStatus](numPartitions)
+
+  /**
+   * The cached result of serializing the map statuses array. This cache is lazily populated when
+   * [[serializedMapStatus]] is called and is invalidated whenever map outputs are added or
+   * removed.
+   */
+  private[this] var cachedSerializedMapStatus: Array[Byte] = _
+
+  /**
+   * Broadcast variable holding serialized map output statuses array. When [[serializedMapStatus]]
+   * serializes the map statuses array it may detect that the result is too large to send in a
+   * single RPC, in which case it places the serialized array into a broadcast variable and then
+   * sends a serialized broadcast variable instead. This variable holds a reference to that
+   * broadcast variable in order to keep it from being garbage-collected and to allow it to be
+   * explicitly destroyed later, e.g. when this shuffle is unregistered.
+   */
+  private[this] var cachedSerializedBroadcast: Broadcast[Array[Byte]] = _
+
+  /**
+   * Counter tracking the number of partitions that have output. This is a performance optimization
+   * to avoid having to count the number of non-null entries in the `mapStatuses` array and should
+   * be equivalent to `mapStatuses.count(_ ne null)`.
+   */
+  private[this] var _numAvailableOutputs: Int = 0
+
+  /**
+   * Register a map output. If there is already a registered location for the map output then it
+   * will be replaced by the new location.
+   */
+  def addMapOutput(mapId: Int, status: MapStatus): Unit = synchronized {
+    if (mapStatuses(mapId) == null) {
+      _numAvailableOutputs += 1
+      invalidateSerializedMapOutputStatusCache()
+    }
+    mapStatuses(mapId) = status
+  }
+
+  /**
+   * Remove the map output which was served by the specified block manager.
+   * This is a no-op if there is no registered map output or if the registered output is from a
+   * different block manager.
+   */
+  def removeMapOutput(mapId: Int, bmAddress: BlockManagerId): Unit = synchronized {
+    if (mapStatuses(mapId) != null && mapStatuses(mapId).location == bmAddress) {
+      _numAvailableOutputs -= 1
+      mapStatuses(mapId) = null
+      invalidateSerializedMapOutputStatusCache()
+    }
+  }
+
+  /**
+   * Removes all map outputs associated with the specified executor. Note that this will also
+   * remove outputs which are served by an external shuffle server (if one exists), as they are
+   * still registered with that execId.
+   */
+  def removeOutputsOnExecutor(execId: String): Unit = synchronized {
+    for (mapId <- 0 until mapStatuses.length) {
+      if (mapStatuses(mapId) != null && mapStatuses(mapId).location.executorId == execId) {
+        _numAvailableOutputs -= 1
+        mapStatuses(mapId) = null
+        invalidateSerializedMapOutputStatusCache()
+      }
+    }
+  }
+
+  /**
+   * Number of partitions that have shuffle outputs.
+   */
+  def numAvailableOutputs: Int = synchronized {
+    _numAvailableOutputs
+  }
+
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. need to be computed).
+   */
+  def findMissingPartitions(): Seq[Int] = synchronized {
+    val missing = (0 until numPartitions).filter(id => mapStatuses(id) == null)
+    assert(missing.size == numPartitions - _numAvailableOutputs,
+      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
+    missing
+  }
+
+  /**
+   * Serializes the mapStatuses array into an efficient compressed format. See the comments on
+   * `MapOutputTracker.serializeMapStatuses()` for more details on the serialization format.
+   *
+   * This method is designed to be called multiple times and implements caching in order to speed
+   * up subsequent requests. If the cache is empty and multiple threads concurrently attempt to
+   * serialize the map statuses, then serialization will be performed by only one thread and all
+   * other threads will block until the cache is populated.
+   */
+  def serializedMapStatus(
+      broadcastManager: BroadcastManager,
+      isLocal: Boolean,
+      minBroadcastSize: Int): Array[Byte] = synchronized {
+    if (cachedSerializedMapStatus eq null) {
+      val serResult = MapOutputTracker.serializeMapStatuses(
+          mapStatuses, broadcastManager, isLocal, minBroadcastSize)
+      cachedSerializedMapStatus = serResult._1
+      cachedSerializedBroadcast = serResult._2
+    }
+    cachedSerializedMapStatus
+  }
+
+  // Used in testing.
+  def hasCachedSerializedBroadcast: Boolean = synchronized {
+    cachedSerializedBroadcast != null
+  }
+
+  /**
+   * Helper function which provides thread-safe access to the mapStatuses array.
+   * The function should NOT mutate the array.
+   */
+  def withMapStatuses[T](f: Array[MapStatus] => T): T = synchronized {
+    f(mapStatuses)
+  }
+
+  /**
+   * Clears the cached serialized map output statuses.
+   */
+  def invalidateSerializedMapOutputStatusCache(): Unit = synchronized {
+    if (cachedSerializedBroadcast != null) {
+      cachedSerializedBroadcast.destroy()
+      cachedSerializedBroadcast = null
+    }
+    cachedSerializedMapStatus = null
+  }
+}
+
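
To make the bookkeeping contract concrete, here is a minimal sketch of how a caller might exercise `ShuffleStatus` (the class is private to `org.apache.spark`, so this assumes code living in that package; the executor IDs, hosts, ports, and sizes are made-up illustrations, and the `MapStatus`/`BlockManagerId` factories are the same ones used in `MapOutputTrackerSuite` below):

```scala
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId

// Bookkeeping for a ShuffleMapStage with 3 map partitions.
val status = new ShuffleStatus(numPartitions = 3)
status.addMapOutput(0, MapStatus(BlockManagerId("exec1", "hostA", 7337), Array(1000L)))
status.addMapOutput(2, MapStatus(BlockManagerId("exec2", "hostB", 7337), Array(500L)))
assert(status.numAvailableOutputs == 2)
assert(status.findMissingPartitions() == Seq(1))

// Losing exec1 drops partition 0's output and invalidates the serialized cache.
status.removeOutputsOnExecutor("exec1")
assert(status.findMissingPartitions() == Seq(0, 1))
```
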
 private[spark] sealed trait MapOutputTrackerMessage
 private[spark] case class GetMapOutputStatuses(shuffleId: Int)
   extends MapOutputTrackerMessage
@@ -62,37 +212,26 @@ private[spark] class MapOutputTrackerMasterEndpoint(
 }
 
 /**
- * Class that keeps track of the location of the map output of
- * a stage. This is abstract because different versions of MapOutputTracker
- * (driver and executor) use different HashMap to store its metadata.
- */
+ * Class that keeps track of the location of the map output of a stage. This is abstract because the
+ * driver and executor have different versions of the MapOutputTracker. In principle the driver-
+ * and executor-side classes don't need to share a common base class; the current shared base class
+ * is maintained primarily for backwards-compatibility in order to avoid having to update existing
+ * test code.
+ */
 private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
-
   /** Set to the MapOutputTrackerMasterEndpoint living on the driver. */
   var trackerEndpoint: RpcEndpointRef = _
 
   /**
-   * This HashMap has different behavior for the driver and the executors.
-   *
-   * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks.
-   * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the
-   * driver's corresponding HashMap.
-   *
-   * Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a
-   * thread-safe map.
-   */
-  protected val mapStatuses: Map[Int, Array[MapStatus]]
-
-  /**
-   * Incremented every time a fetch fails so that client nodes know to clear
-   * their cache of map output locations if this happens.
+   * The driver-side counter is incremented every time a map output is lost. This value is sent
+   * to executors as part of tasks, where executors compare the new epoch number to the highest
+   * epoch number that they have seen so far. If the new epoch number is higher, executors will
+   * clear their local caches of map output statuses and re-fetch (possibly updated) statuses
+   * from the driver.
    */
   protected var epoch: Long = 0
   protected val epochLock = new AnyRef
 
-  /** Remembers which map output locations are currently being fetched on an executor. */
-  private val fetching = new HashSet[Int]
-
   /**
    * Send a message to the trackerEndpoint and get its result within a default timeout, or
    * throw a SparkException if this fails.
@@ -116,14 +255,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
     }
   }
 
-  /**
-   * Called from executors to get the server URIs and output sizes for each shuffle block that
-   * needs to be read from a given reduce task.
-   *
-   * @return A sequence of 2-item tuples, where the first item in the tuple is a BlockManagerId,
-   *         and the second item is a sequence of (shuffle block id, shuffle block size) tuples
-   *         describing the shuffle blocks that are stored at that block manager.
-   */
+  // For testing: convenience overload that fetches sizes for a single reduce partition.
   def getMapSizesByExecutorId(shuffleId: Int, reduceId: Int)
       : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
     getMapSizesByExecutorId(shuffleId, reduceId, reduceId + 1)
@@ -139,135 +271,31 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
    *         describing the shuffle blocks that are stored at that block manager.
    */
   def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
-      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
-    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
-    val statuses = getStatuses(shuffleId)
-    // Synchronize on the returned array because, on the driver, it gets mutated in place
-    statuses.synchronized {
-      return MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
-    }
-  }
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])]
 
   /**
-   * Return statistics about all of the outputs for a given shuffle.
+   * Deletes map output status information for the specified shuffle stage.
    */
-  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
-    val statuses = getStatuses(dep.shuffleId)
-    // Synchronize on the returned array because, on the driver, it gets mutated in place
-    statuses.synchronized {
-      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
-      for (s <- statuses) {
-        for (i <- 0 until totalSizes.length) {
-          totalSizes(i) += s.getSizeForBlock(i)
-        }
-      }
-      new MapOutputStatistics(dep.shuffleId, totalSizes)
-    }
-  }
-
-  /**
-   * Get or fetch the array of MapStatuses for a given shuffle ID. NOTE: clients MUST synchronize
-   * on this array when reading it, because on the driver, we may be changing it in place.
-   *
-   * (It would be nice to remove this restriction in the future.)
-   */
-  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses == null) {
-      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
-      val startTime = System.currentTimeMillis
-      var fetchedStatuses: Array[MapStatus] = null
-      fetching.synchronized {
-        // Someone else is fetching it; wait for them to be done
-        while (fetching.contains(shuffleId)) {
-          try {
-            fetching.wait()
-          } catch {
-            case e: InterruptedException =>
-          }
-        }
-
-        // Either while we waited the fetch happened successfully, or
-        // someone fetched it in between the get and the fetching.synchronized.
-        fetchedStatuses = mapStatuses.get(shuffleId).orNull
-        if (fetchedStatuses == null) {
-          // We have to do the fetch, get others to wait for us.
-          fetching += shuffleId
-        }
-      }
+  def unregisterShuffle(shuffleId: Int): Unit
 
-      if (fetchedStatuses == null) {
-        // We won the race to fetch the statuses; do so
-        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
-        // This try-finally prevents hangs due to timeouts:
-        try {
-          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
-          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
-          logInfo("Got the output locations")
-          mapStatuses.put(shuffleId, fetchedStatuses)
-        } finally {
-          fetching.synchronized {
-            fetching -= shuffleId
-            fetching.notifyAll()
-          }
-        }
-      }
-      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
-        s"${System.currentTimeMillis - startTime} ms")
-
-      if (fetchedStatuses != null) {
-        return fetchedStatuses
-      } else {
-        logError("Missing all output locations for shuffle " + shuffleId)
-        throw new MetadataFetchFailedException(
-          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
-      }
-    } else {
-      return statuses
-    }
-  }
-
-  /** Called to get current epoch number. */
-  def getEpoch: Long = {
-    epochLock.synchronized {
-      return epoch
-    }
-  }
-
-  /**
-   * Called from executors to update the epoch number, potentially clearing old outputs
-   * because of a fetch failure. Each executor task calls this with the latest epoch
-   * number on the driver at the time it was created.
-   */
-  def updateEpoch(newEpoch: Long) {
-    epochLock.synchronized {
-      if (newEpoch > epoch) {
-        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
-        epoch = newEpoch
-        mapStatuses.clear()
-      }
-    }
-  }
-
-  /** Unregister shuffle data. */
-  def unregisterShuffle(shuffleId: Int) {
-    mapStatuses.remove(shuffleId)
-  }
-
-  /** Stop the tracker. */
-  def stop() { }
+  def stop() {}
 }
 
 /**
- * MapOutputTracker for the driver.
+ * Driver-side class that keeps track of the location of the map output of a stage.
+ *
+ * The DAGScheduler uses this class to (de)register map output statuses and to look up statistics
+ * for performing locality-aware reduce task scheduling.
+ *
+ * ShuffleMapStage uses this class for tracking available / missing outputs in order to determine
+ * which tasks need to be run.
  */
-private[spark] class MapOutputTrackerMaster(conf: SparkConf,
-    broadcastManager: BroadcastManager, isLocal: Boolean)
+private[spark] class MapOutputTrackerMaster(
+    conf: SparkConf,
+    broadcastManager: BroadcastManager,
+    isLocal: Boolean)
   extends MapOutputTracker(conf) {
 
-  /** Cache a serialized version of the output statuses for each shuffle to send them out faster */
-  private var cacheEpoch = epoch
-
   // The size at which we use Broadcast to send the map output statuses to the executors
   private val minSizeForBroadcast =
     conf.getSizeAsBytes("spark.shuffle.mapOutput.minSizeForBroadcast", "512k").toInt
@@ -287,22 +315,12 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
   // can be read locally, but may lead to more delay in scheduling if those locations are busy.
   private val REDUCER_PREF_LOCS_FRACTION = 0.2
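
As a worked example of the 0.2 threshold (plain strings stand in for `BlockManagerId`s; the real code tallies per-location byte counts inside `getPreferredLocationsForShuffle`):

```scala
// Bytes of one reducer's input that each location already stores locally.
val sizesByLocation = Map("hostA" -> 300L, "hostB" -> 700L)
val total = sizesByLocation.values.sum // 1000 bytes destined for this reducer

// A location is preferred when it holds at least REDUCER_PREF_LOCS_FRACTION of the input.
val preferred = sizesByLocation.collect {
  case (loc, bytes) if bytes.toDouble / total >= 0.2 => loc
}
// hostA holds 30% and hostB holds 70%, so both clear the 20% bar.
```
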
 
-  // HashMaps for storing mapStatuses and cached serialized statuses in the driver.
+  // HashMap for storing shuffleStatuses in the driver.
   // Statuses are dropped only by explicit de-registering.
-  protected val mapStatuses = new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
-  private val cachedSerializedStatuses = new ConcurrentHashMap[Int, Array[Byte]]().asScala
+  private val shuffleStatuses = new ConcurrentHashMap[Int, ShuffleStatus]().asScala
 
   private val maxRpcMessageSize = RpcUtils.maxMessageSizeBytes(conf)
 
-  // Kept in sync with cachedSerializedStatuses explicitly
-  // This is required so that the Broadcast variable remains in scope until we remove
-  // the shuffleId explicitly or implicitly.
-  private val cachedSerializedBroadcast = new HashMap[Int, Broadcast[Array[Byte]]]()
-
-  // This is to prevent multiple serializations of the same shuffle - which happens when
-  // there is a request storm when shuffle start.
-  private val shuffleIdLocks = new ConcurrentHashMap[Int, AnyRef]()
-
   // requests for map output statuses
   private val mapOutputRequests = new LinkedBlockingQueue[GetMapOutputMessage]
 
@@ -348,8 +366,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
             val hostPort = context.senderAddress.hostPort
             logDebug("Handling request to send map output locations for shuffle " + shuffleId +
               " to " + hostPort)
-            val mapOutputStatuses = getSerializedMapOutputStatuses(shuffleId)
-            context.reply(mapOutputStatuses)
+            val shuffleStatus = shuffleStatuses.get(shuffleId).head
+            context.reply(
+              shuffleStatus.serializedMapStatus(broadcastManager, isLocal, minSizeForBroadcast))
           } catch {
             case NonFatal(e) => logError(e.getMessage, e)
           }
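
For context on the reply above: `serializedMapStatus` returns either the compressed statuses themselves or, for large stages, a small serialized handle to a broadcast variable holding them. The standalone sketch below illustrates that size-based transport decision; the GZIP-over-Java-serialization framing follows what `serializeMapStatuses` is documented to do, while the `Transport` wrapper is an illustrative assumption rather than the tracker's actual wire format.

```scala
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import java.util.zip.GZIPOutputStream

object TransportSketch {
  sealed trait Transport
  final case class Direct(bytes: Array[Byte]) extends Transport
  // In the real tracker this would wrap the bytes in a Broadcast[Array[Byte]].
  final case class ViaBroadcast(bytes: Array[Byte]) extends Transport

  /** Serialize and compress once; the resulting size picks the transport. */
  def choose(payload: Serializable, minBroadcastSize: Int): Transport = {
    val bos = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(new GZIPOutputStream(bos))
    try oos.writeObject(payload) finally oos.close()
    val bytes = bos.toByteArray
    if (bytes.length >= minBroadcastSize) ViaBroadcast(bytes) else Direct(bytes)
  }
}
```
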
@@ -363,59 +382,77 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
   /** A poison endpoint that indicates MessageLoop should exit its message loop. */
   private val PoisonPill = new GetMapOutputMessage(-99, null)
 
-  // Exposed for testing
-  private[spark] def getNumCachedSerializedBroadcast = cachedSerializedBroadcast.size
+  // Used only in unit tests.
+  private[spark] def getNumCachedSerializedBroadcast: Int = {
+    shuffleStatuses.valuesIterator.count(_.hasCachedSerializedBroadcast)
+  }
 
   def registerShuffle(shuffleId: Int, numMaps: Int) {
-    if (mapStatuses.put(shuffleId, new Array[MapStatus](numMaps)).isDefined) {
+    if (shuffleStatuses.put(shuffleId, new ShuffleStatus(numMaps)).isDefined) {
       throw new IllegalArgumentException("Shuffle ID " + shuffleId + " registered twice")
     }
-    // add in advance
-    shuffleIdLocks.putIfAbsent(shuffleId, new Object())
   }
 
   def registerMapOutput(shuffleId: Int, mapId: Int, status: MapStatus) {
-    val array = mapStatuses(shuffleId)
-    array.synchronized {
-      array(mapId) = status
-    }
-  }
-
-  /** Register multiple map output information for the given shuffle */
-  def registerMapOutputs(shuffleId: Int, statuses: Array[MapStatus], changeEpoch: Boolean = false) {
-    mapStatuses.put(shuffleId, statuses.clone())
-    if (changeEpoch) {
-      incrementEpoch()
-    }
+    shuffleStatuses(shuffleId).addMapOutput(mapId, status)
   }
 
   /** Unregister map output information of the given shuffle, mapper and block manager */
   def unregisterMapOutput(shuffleId: Int, mapId: Int, bmAddress: BlockManagerId) {
-    val arrayOpt = mapStatuses.get(shuffleId)
-    if (arrayOpt.isDefined && arrayOpt.get != null) {
-      val array = arrayOpt.get
-      array.synchronized {
-        if (array(mapId) != null && array(mapId).location == bmAddress) {
-          array(mapId) = null
-        }
-      }
-      incrementEpoch()
-    } else {
-      throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.removeMapOutput(mapId, bmAddress)
+        incrementEpoch()
+      case None =>
+        throw new SparkException("unregisterMapOutput called for nonexistent shuffle ID")
     }
   }
 
   /** Unregister shuffle data */
-  override def unregisterShuffle(shuffleId: Int) {
-    mapStatuses.remove(shuffleId)
-    cachedSerializedStatuses.remove(shuffleId)
-    cachedSerializedBroadcast.remove(shuffleId).foreach(v => removeBroadcast(v))
-    shuffleIdLocks.remove(shuffleId)
+  def unregisterShuffle(shuffleId: Int) {
+    shuffleStatuses.remove(shuffleId).foreach { shuffleStatus =>
+      shuffleStatus.invalidateSerializedMapOutputStatusCache()
+    }
+  }
+
+  /**
+   * Removes all shuffle outputs associated with the specified executor. Note that this will also
+   * remove outputs which are served by an external shuffle server (if one exists), as they are
+   * still registered with that execId.
+   */
+  def removeOutputsOnExecutor(execId: String): Unit = {
+    shuffleStatuses.valuesIterator.foreach { _.removeOutputsOnExecutor(execId) }
+    incrementEpoch()
   }
 
   /** Check if the given shuffle is being tracked */
-  def containsShuffle(shuffleId: Int): Boolean = {
-    cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
+  def containsShuffle(shuffleId: Int): Boolean = shuffleStatuses.contains(shuffleId)
+
+  def getNumAvailableOutputs(shuffleId: Int): Int = {
+    shuffleStatuses.get(shuffleId).map(_.numAvailableOutputs).getOrElse(0)
+  }
+
+  /**
+   * Returns the sequence of partition ids that are missing (i.e. need to be computed), or None
+   * if the MapOutputTrackerMaster doesn't know about this shuffle.
+   */
+  def findMissingPartitions(shuffleId: Int): Option[Seq[Int]] = {
+    shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
+  }
+
+  /**
+   * Return statistics about all of the outputs for a given shuffle.
+   */
+  def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
+    shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
+      val totalSizes = new Array[Long](dep.partitioner.numPartitions)
+      for (s <- statuses) {
+        for (i <- 0 until totalSizes.length) {
+          totalSizes(i) += s.getSizeForBlock(i)
+        }
+      }
+      new MapOutputStatistics(dep.shuffleId, totalSizes)
+    }
   }
 
   /**
@@ -459,9 +496,9 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
       fractionThreshold: Double)
     : Option[Array[BlockManagerId]] = {
 
-    val statuses = mapStatuses.get(shuffleId).orNull
-    if (statuses != null) {
-      statuses.synchronized {
+    val shuffleStatus = shuffleStatuses.get(shuffleId).orNull
+    if (shuffleStatus != null) {
+      shuffleStatus.withMapStatuses { statuses =>
         if (statuses.nonEmpty) {
           // HashMap to add up sizes of all blocks at the same location
           val locs = new HashMap[BlockManagerId, Long]
@@ -502,77 +539,24 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
     }
   }
 
-  private def removeBroadcast(bcast: Broadcast[_]): Unit = {
-    if (null != bcast) {
-      broadcastManager.unbroadcast(bcast.id,
-        removeFromDriver = true, blocking = false)
+  /** Called to get current epoch number. */
+  def getEpoch: Long = {
+    epochLock.synchronized {
+      return epoch
     }
   }
 
-  private def clearCachedBroadcast(): Unit = {
-    for (cached <- cachedSerializedBroadcast) removeBroadcast(cached._2)
-    cachedSerializedBroadcast.clear()
-  }
-
-  def getSerializedMapOutputStatuses(shuffleId: Int): Array[Byte] = {
-    var statuses: Array[MapStatus] = null
-    var retBytes: Array[Byte] = null
-    var epochGotten: Long = -1
-
-    // Check to see if we have a cached version, returns true if it does
-    // and has side effect of setting retBytes.  If not returns false
-    // with side effect of setting statuses
-    def checkCachedStatuses(): Boolean = {
-      epochLock.synchronized {
-        if (epoch > cacheEpoch) {
-          cachedSerializedStatuses.clear()
-          clearCachedBroadcast()
-          cacheEpoch = epoch
-        }
-        cachedSerializedStatuses.get(shuffleId) match {
-          case Some(bytes) =>
-            retBytes = bytes
-            true
-          case None =>
-            logDebug("cached status not found for : " + shuffleId)
-            statuses = mapStatuses.getOrElse(shuffleId, Array.empty[MapStatus])
-            epochGotten = epoch
-            false
-        }
-      }
-    }
-
-    if (checkCachedStatuses()) return retBytes
-    var shuffleIdLock = shuffleIdLocks.get(shuffleId)
-    if (null == shuffleIdLock) {
-      val newLock = new Object()
-      // in general, this condition should be false - but good to be paranoid
-      val prevLock = shuffleIdLocks.putIfAbsent(shuffleId, newLock)
-      shuffleIdLock = if (null != prevLock) prevLock else newLock
-    }
-    // synchronize so we only serialize/broadcast it once since multiple threads call
-    // in parallel
-    shuffleIdLock.synchronized {
-      // double check to make sure someone else didn't serialize and cache the same
-      // mapstatus while we were waiting on the synchronize
-      if (checkCachedStatuses()) return retBytes
-
-      // If we got here, we failed to find the serialized locations in the cache, so we pulled
-      // out a snapshot of the locations as "statuses"; let's serialize and return that
-      val (bytes, bcast) = MapOutputTracker.serializeMapStatuses(statuses, broadcastManager,
-        isLocal, minSizeForBroadcast)
-      logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
-      // Add them into the table only if the epoch hasn't changed while we were working
-      epochLock.synchronized {
-        if (epoch == epochGotten) {
-          cachedSerializedStatuses(shuffleId) = bytes
-          if (null != bcast) cachedSerializedBroadcast(shuffleId) = bcast
-        } else {
-          logInfo("Epoch changed, not caching!")
-          removeBroadcast(bcast)
+  // This method is only called in local mode.
+  def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
+    shuffleStatuses.get(shuffleId) match {
+      case Some(shuffleStatus) =>
+        shuffleStatus.withMapStatuses { statuses =>
+          MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
         }
-      }
-      bytes
+      case None =>
+        Seq.empty
     }
   }
 
@@ -580,21 +564,121 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf,
     mapOutputRequests.offer(PoisonPill)
     threadpool.shutdown()
     sendTracker(StopMapOutputTracker)
-    mapStatuses.clear()
     trackerEndpoint = null
-    cachedSerializedStatuses.clear()
-    clearCachedBroadcast()
-    shuffleIdLocks.clear()
+    shuffleStatuses.clear()
   }
 }
 
 /**
- * MapOutputTracker for the executors, which fetches map output information from the driver's
- * MapOutputTrackerMaster.
+ * Executor-side client for fetching map output info from the driver's MapOutputTrackerMaster.
+ * Note that this is not used in local mode; instead, local-mode executors access the
+ * MapOutputTrackerMaster directly (which is possible because the master and worker share a
+ * common superclass).
  */
 private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
-  protected val mapStatuses: Map[Int, Array[MapStatus]] =
+
+  val mapStatuses: Map[Int, Array[MapStatus]] =
     new ConcurrentHashMap[Int, Array[MapStatus]]().asScala
+
+  /** Remembers which map output locations are currently being fetched on an executor. */
+  private val fetching = new HashSet[Int]
+
+  override def getMapSizesByExecutorId(shuffleId: Int, startPartition: Int, endPartition: Int)
+      : Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
+    logDebug(s"Fetching outputs for shuffle $shuffleId, partitions $startPartition-$endPartition")
+    val statuses = getStatuses(shuffleId)
+    try {
+      MapOutputTracker.convertMapStatuses(shuffleId, startPartition, endPartition, statuses)
+    } catch {
+      case e: MetadataFetchFailedException =>
+        // We experienced a fetch failure so our mapStatuses cache is outdated; clear it:
+        mapStatuses.clear()
+        throw e
+    }
+  }
+
+  /**
+   * Get or fetch the array of MapStatuses for a given shuffle ID. On executors the returned
+   * array is either served from the local cache or fetched from the driver; it is never mutated
+   * in place here, so callers may read it without additional synchronization.
+   */
+  private def getStatuses(shuffleId: Int): Array[MapStatus] = {
+    val statuses = mapStatuses.get(shuffleId).orNull
+    if (statuses == null) {
+      logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
+      val startTime = System.currentTimeMillis
+      var fetchedStatuses: Array[MapStatus] = null
+      fetching.synchronized {
+        // Someone else is fetching it; wait for them to be done
+        while (fetching.contains(shuffleId)) {
+          try {
+            fetching.wait()
+          } catch {
+            case e: InterruptedException =>
+          }
+        }
+
+        // Either while we waited the fetch happened successfully, or
+        // someone fetched it in between the get and the fetching.synchronized.
+        fetchedStatuses = mapStatuses.get(shuffleId).orNull
+        if (fetchedStatuses == null) {
+          // We have to do the fetch, get others to wait for us.
+          fetching += shuffleId
+        }
+      }
+
+      if (fetchedStatuses == null) {
+        // We won the race to fetch the statuses; do so
+        logInfo("Doing the fetch; tracker endpoint = " + trackerEndpoint)
+        // This try-finally prevents hangs due to timeouts:
+        try {
+          val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId))
+          fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes)
+          logInfo("Got the output locations")
+          mapStatuses.put(shuffleId, fetchedStatuses)
+        } finally {
+          fetching.synchronized {
+            fetching -= shuffleId
+            fetching.notifyAll()
+          }
+        }
+      }
+      logDebug(s"Fetching map output statuses for shuffle $shuffleId took " +
+        s"${System.currentTimeMillis - startTime} ms")
+
+      if (fetchedStatuses != null) {
+        fetchedStatuses
+      } else {
+        logError("Missing all output locations for shuffle " + shuffleId)
+        throw new MetadataFetchFailedException(
+          shuffleId, -1, "Missing all output locations for shuffle " + shuffleId)
+      }
+    } else {
+      statuses
+    }
+  }
+
+  /** Unregister shuffle data. */
+  def unregisterShuffle(shuffleId: Int): Unit = {
+    mapStatuses.remove(shuffleId)
+  }
+
+  /**
+   * Called from executors to update the epoch number, potentially clearing old outputs
+   * because of a fetch failure. Each executor task calls this with the latest epoch
+   * number on the driver at the time it was created.
+   */
+  def updateEpoch(newEpoch: Long): Unit = {
+    epochLock.synchronized {
+      if (newEpoch > epoch) {
+        logInfo("Updating epoch to " + newEpoch + " and clearing cache")
+        epoch = newEpoch
+        mapStatuses.clear()
+      }
+    }
+  }
 }
 
 private[spark] object MapOutputTracker extends Logging {
@@ -683,7 +767,7 @@ private[spark] object MapOutputTracker extends Logging {
    *         and the second item is a sequence of (shuffle block ID, shuffle block size) tuples
    *         describing the shuffle blocks that are stored at that block manager.
    */
-  private def convertMapStatuses(
+  def convertMapStatuses(
       shuffleId: Int,
       startPartition: Int,
       endPartition: Int,
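
Since `convertMapStatuses` is now exposed beyond this object, a hedged sketch of its documented contract may help readers of the new call sites: group each requested (block, size) pair under the `BlockManagerId` that serves it, and fail fast with a `MetadataFetchFailedException` when any map output is missing. This follows the doc comment above; the real method's internals may differ.

```scala
import scala.collection.mutable
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.shuffle.MetadataFetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId}

def convertSketch(
    shuffleId: Int,
    startPartition: Int,
    endPartition: Int,
    statuses: Array[MapStatus]): Seq[(BlockManagerId, Seq[(BlockId, Long)])] = {
  val splitsByAddress = mutable.HashMap.empty[BlockManagerId, mutable.ListBuffer[(BlockId, Long)]]
  for ((status, mapId) <- statuses.zipWithIndex) {
    if (status == null) {
      // A null entry means the map output was lost; readers treat this as a fetch failure.
      throw new MetadataFetchFailedException(
        shuffleId, startPartition, s"Missing an output location for shuffle $shuffleId")
    }
    for (part <- startPartition until endPartition) {
      splitsByAddress.getOrElseUpdate(status.location, mutable.ListBuffer()) +=
        ((ShuffleBlockId(shuffleId, mapId, part), status.getSizeForBlock(part)))
    }
  }
  splitsByAddress.toSeq
}
```
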
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 5b396687dd11a94a44cb39f3df64d8a68ceb4015..19e7eb086f413b6a2c474f3fbbfcc3e8ea3e7852 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -322,8 +322,14 @@ private[spark] class Executor(
           throw new TaskKilledException(killReason.get)
         }
 
-        logDebug("Task " + taskId + "'s epoch is " + task.epoch)
-        env.mapOutputTracker.updateEpoch(task.epoch)
+        // The purpose of updating the epoch here is to invalidate the executor's map output
+        // status cache in case FetchFailures have occurred. In local mode `env.mapOutputTracker`
+        // will be a MapOutputTrackerMaster, whose cache invalidation is not based on epoch
+        // numbers, so we don't need to make any special calls here.
+        if (!isLocal) {
+          logDebug("Task " + taskId + "'s epoch is " + task.epoch)
+          env.mapOutputTracker.asInstanceOf[MapOutputTrackerWorker].updateEpoch(task.epoch)
+        }
 
         // Run the actual task and measure its runtime.
         taskStart = System.currentTimeMillis()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ab2255f8a665462461b88ad52d92623e9b0082ca..932e6c138e1c418b1376883dbd7659dd235add88 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -328,25 +328,14 @@ class DAGScheduler(
     val numTasks = rdd.partitions.length
     val parents = getOrCreateParentStages(rdd, jobId)
     val id = nextStageId.getAndIncrement()
-    val stage = new ShuffleMapStage(id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep)
+    val stage = new ShuffleMapStage(
+      id, rdd, numTasks, parents, jobId, rdd.creationSite, shuffleDep, mapOutputTracker)
 
     stageIdToStage(id) = stage
     shuffleIdToMapStage(shuffleDep.shuffleId) = stage
     updateJobIdStageIdMaps(jobId, stage)
 
-    if (mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
-      // A previously run stage generated partitions for this shuffle, so for each output
-      // that's still available, copy information about that output location to the new stage
-      // (so we don't unnecessarily re-compute that data).
-      val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
-      val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
-      (0 until locs.length).foreach { i =>
-        if (locs(i) ne null) {
-          // locs(i) will be null if missing
-          stage.addOutputLoc(i, locs(i))
-        }
-      }
-    } else {
+    if (!mapOutputTracker.containsShuffle(shuffleDep.shuffleId)) {
       // Kind of ugly: need to register RDDs with the cache and map output tracker here
       // since we can't do it in the RDD constructor because # of partitions is unknown
       logInfo("Registering RDD " + rdd.id + " (" + rdd.getCreationSite + ")")
@@ -1217,7 +1206,8 @@ class DAGScheduler(
               // The epoch of the task is acceptable (i.e., the task was launched after the most
               // recent failure we're aware of for the executor), so mark the task's output as
               // available.
-              shuffleStage.addOutputLoc(smt.partitionId, status)
+              mapOutputTracker.registerMapOutput(
+                shuffleStage.shuffleDep.shuffleId, smt.partitionId, status)
               // Remove the task's partition from pending partitions. This may have already been
               // done above, but will not have been done yet in cases where the task attempt was
               // from an earlier attempt of the stage (i.e., not the attempt that's currently
@@ -1234,16 +1224,14 @@ class DAGScheduler(
               logInfo("waiting: " + waitingStages)
               logInfo("failed: " + failedStages)
 
-              // We supply true to increment the epoch number here in case this is a
-              // recomputation of the map outputs. In that case, some nodes may have cached
-              // locations with holes (from when we detected the error) and will need the
-              // epoch incremented to refetch them.
-              // TODO: Only increment the epoch number if this is not the first time
-              //       we registered these map outputs.
-              mapOutputTracker.registerMapOutputs(
-                shuffleStage.shuffleDep.shuffleId,
-                shuffleStage.outputLocInMapOutputTrackerFormat(),
-                changeEpoch = true)
+              // This call to increment the epoch may not be strictly necessary, but it is
+              // retained for now in order to minimize the change in behavior from earlier
+              // versions of the code. Always incrementing the epoch after a successful shuffle
+              // map stage completion may have the benefit of causing stale cached map output
+              // statuses to be cleaned up earlier on executors. In the future we can consider
+              // removing this call, but doing so will require some extra investigation.
+              // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
+              mapOutputTracker.incrementEpoch()
 
               clearCacheLocs()
 
@@ -1343,7 +1331,6 @@ class DAGScheduler(
           }
           // Mark the map whose fetch failed as broken in the map stage
           if (mapId != -1) {
-            mapStage.removeOutputLoc(mapId, bmAddress)
             mapOutputTracker.unregisterMapOutput(shuffleId, mapId, bmAddress)
           }
 
@@ -1393,17 +1380,7 @@ class DAGScheduler(
 
       if (filesLost || !env.blockManager.externalShuffleServiceEnabled) {
         logInfo("Shuffle files lost for executor: %s (epoch %d)".format(execId, currentEpoch))
-        // TODO: This will be really slow if we keep accumulating shuffle map stages
-        for ((shuffleId, stage) <- shuffleIdToMapStage) {
-          stage.removeOutputsOnExecutor(execId)
-          mapOutputTracker.registerMapOutputs(
-            shuffleId,
-            stage.outputLocInMapOutputTrackerFormat(),
-            changeEpoch = true)
-        }
-        if (shuffleIdToMapStage.isEmpty) {
-          mapOutputTracker.incrementEpoch()
-        }
+        mapOutputTracker.removeOutputsOnExecutor(execId)
         clearCacheLocs()
       }
     } else {
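
Taken together with the tracker changes, the fetch-failure path now reads as one flow. The sketch below mirrors the master/slave tracker setup already used by `MapOutputTrackerSuite` (construction and RPC wiring omitted), so it illustrates the sequence rather than standing alone:

```scala
// Driver: a FetchFailed for (shuffleId = 10, mapId = 0) unregisters the lost output,
// which bumps the epoch as a side effect (see unregisterMapOutput above).
masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))

// Tasks serialized after this point carry the new epoch. On a non-local executor,
// Executor calls updateEpoch before running the task, clearing the stale cache.
slaveTracker.updateEpoch(masterTracker.getEpoch)

// The next shuffle read re-fetches statuses from the driver and surfaces the missing
// output as a fetch failure, triggering recomputation of the map stage.
intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
```
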
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
index db4d9efa2270c80182c6482d1d0cb01536f945cb..05f650fbf5df995950c52f9f0fa62c6796066aa8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapStage.scala
@@ -19,9 +19,8 @@ package org.apache.spark.scheduler
 
 import scala.collection.mutable.HashSet
 
-import org.apache.spark.ShuffleDependency
+import org.apache.spark.{MapOutputTrackerMaster, ShuffleDependency, SparkEnv}
 import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.BlockManagerId
 import org.apache.spark.util.CallSite
 
 /**
@@ -42,13 +41,12 @@ private[spark] class ShuffleMapStage(
     parents: List[Stage],
     firstJobId: Int,
     callSite: CallSite,
-    val shuffleDep: ShuffleDependency[_, _, _])
+    val shuffleDep: ShuffleDependency[_, _, _],
+    mapOutputTrackerMaster: MapOutputTrackerMaster)
   extends Stage(id, rdd, numTasks, parents, firstJobId, callSite) {
 
   private[this] var _mapStageJobs: List[ActiveJob] = Nil
 
-  private[this] var _numAvailableOutputs: Int = 0
-
   /**
    * Partitions that either haven't yet been computed, or that were computed on an executor
    * that has since been lost, so should be re-computed.  This variable is used by the
@@ -60,13 +58,6 @@ private[spark] class ShuffleMapStage(
    */
   val pendingPartitions = new HashSet[Int]
 
-  /**
-   * List of [[MapStatus]] for each partition. The index of the array is the map partition id,
-   * and each value in the array is the list of possible [[MapStatus]] for a partition
-   * (a single task might run multiple times).
-   */
-  private[this] val outputLocs = Array.fill[List[MapStatus]](numPartitions)(Nil)
-
   override def toString: String = "ShuffleMapStage " + id
 
   /**
@@ -88,69 +79,18 @@ private[spark] class ShuffleMapStage(
   /**
    * Number of partitions that have shuffle outputs.
    * When this reaches [[numPartitions]], this map stage is ready.
-   * This should be kept consistent as `outputLocs.filter(!_.isEmpty).size`.
    */
-  def numAvailableOutputs: Int = _numAvailableOutputs
+  def numAvailableOutputs: Int = mapOutputTrackerMaster.getNumAvailableOutputs(shuffleDep.shuffleId)
 
   /**
    * Returns true if the map stage is ready, i.e. all partitions have shuffle outputs.
-   * This should be the same as `outputLocs.contains(Nil)`.
    */
-  def isAvailable: Boolean = _numAvailableOutputs == numPartitions
+  def isAvailable: Boolean = numAvailableOutputs == numPartitions
 
   /** Returns the sequence of partition ids that are missing (i.e. need to be computed). */
   override def findMissingPartitions(): Seq[Int] = {
-    val missing = (0 until numPartitions).filter(id => outputLocs(id).isEmpty)
-    assert(missing.size == numPartitions - _numAvailableOutputs,
-      s"${missing.size} missing, expected ${numPartitions - _numAvailableOutputs}")
-    missing
-  }
-
-  def addOutputLoc(partition: Int, status: MapStatus): Unit = {
-    val prevList = outputLocs(partition)
-    outputLocs(partition) = status :: prevList
-    if (prevList == Nil) {
-      _numAvailableOutputs += 1
-    }
-  }
-
-  def removeOutputLoc(partition: Int, bmAddress: BlockManagerId): Unit = {
-    val prevList = outputLocs(partition)
-    val newList = prevList.filterNot(_.location == bmAddress)
-    outputLocs(partition) = newList
-    if (prevList != Nil && newList == Nil) {
-      _numAvailableOutputs -= 1
-    }
-  }
-
-  /**
-   * Returns an array of [[MapStatus]] (index by partition id). For each partition, the returned
-   * value contains only one (i.e. the first) [[MapStatus]]. If there is no entry for the partition,
-   * that position is filled with null.
-   */
-  def outputLocInMapOutputTrackerFormat(): Array[MapStatus] = {
-    outputLocs.map(_.headOption.orNull)
-  }
-
-  /**
-   * Removes all shuffle outputs associated with this executor. Note that this will also remove
-   * outputs which are served by an external shuffle server (if one exists), as they are still
-   * registered with this execId.
-   */
-  def removeOutputsOnExecutor(execId: String): Unit = {
-    var becameUnavailable = false
-    for (partition <- 0 until numPartitions) {
-      val prevList = outputLocs(partition)
-      val newList = prevList.filterNot(_.location.executorId == execId)
-      outputLocs(partition) = newList
-      if (prevList != Nil && newList == Nil) {
-        becameUnavailable = true
-        _numAvailableOutputs -= 1
-      }
-    }
-    if (becameUnavailable) {
-      logInfo("%s is now unavailable on executor %s (%d/%d, %s)".format(
-        this, execId, _numAvailableOutputs, numPartitions, isAvailable))
-    }
+    mapOutputTrackerMaster
+      .findMissingPartitions(shuffleDep.shuffleId)
+      .getOrElse(0 until numPartitions)
   }
 }
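
The net effect of this file's change is that a stage's availability is now a view over `MapOutputTrackerMaster` state, so the two bookkeeping copies can no longer drift apart. A small sketch, assuming a constructed `MapOutputTrackerMaster` named `tracker` (as in `MapOutputTrackerSuite`):

```scala
tracker.registerShuffle(0, numMaps = 2)
tracker.registerMapOutput(0, 0, MapStatus(BlockManagerId("e1", "hostA", 7337), Array(100L)))

// These are exactly what ShuffleMapStage.numAvailableOutputs and
// findMissingPartitions now delegate to:
assert(tracker.getNumAvailableOutputs(0) == 1)
assert(tracker.findMissingPartitions(0) == Some(Seq(1)))
assert(tracker.findMissingPartitions(99).isEmpty) // unknown shuffle => None
```
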
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index f3033e28b47d0ac3505a2f0c21d56149ffc60aae..629cfc7c7a8ceea949429e8071e52f351112dd44 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -129,7 +129,7 @@ private[spark] class TaskSchedulerImpl private[scheduler](
 
   var backend: SchedulerBackend = null
 
-  val mapOutputTracker = SparkEnv.get.mapOutputTracker
+  val mapOutputTracker = SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
 
   private var schedulableBuilder: SchedulableBuilder = null
   // default scheduler is FIFO
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index 4fe5c5e4fee4a80da9988a6768f4429fe9140414..bc3d23e3fbb294efd71b88968e13cc1986c62736 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -139,21 +139,21 @@ class MapOutputTrackerSuite extends SparkFunSuite {
       slaveRpcEnv.setupEndpointRef(rpcEnv.address, MapOutputTracker.ENDPOINT_NAME)
 
     masterTracker.registerShuffle(10, 1)
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
+    // This is expected to fail because no outputs have been registered for the shuffle.
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
 
     val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L))
     masterTracker.registerMapOutput(10, 0, MapStatus(
       BlockManagerId("a", "hostA", 1000), Array(1000L)))
-    masterTracker.incrementEpoch()
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     assert(slaveTracker.getMapSizesByExecutorId(10, 0) ===
       Seq((BlockManagerId("a", "hostA", 1000), ArrayBuffer((ShuffleBlockId(10, 0, 0), size1000)))))
     assert(0 == masterTracker.getNumCachedSerializedBroadcast)
 
+    val masterTrackerEpochBeforeLossOfMapOutput = masterTracker.getEpoch
     masterTracker.unregisterMapOutput(10, 0, BlockManagerId("a", "hostA", 1000))
-    masterTracker.incrementEpoch()
+    assert(masterTracker.getEpoch > masterTrackerEpochBeforeLossOfMapOutput)
     slaveTracker.updateEpoch(masterTracker.getEpoch)
     intercept[FetchFailedException] { slaveTracker.getMapSizesByExecutorId(10, 0) }
 
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 622f7985ba444ca3250674fb073934b520f83c27..3931d53b4ae0aedc33299c5c8d8ca657a028708d 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -359,6 +359,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
     val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+    mapTrackerMaster.registerShuffle(0, 1)
 
     // first attempt -- its successful
     val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
@@ -393,7 +394,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
 
     // register one of the map outputs -- doesn't matter which one
     mapOutput1.foreach { case mapStatus =>
-      mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+      mapTrackerMaster.registerMapOutput(0, 0, mapStatus)
     }
 
     val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
diff --git a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
index 2b18ebee79a2b4874be7c31f1cb8726b3256938a..571c6bbb4585d811d2be09173b4ef4829303b848 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/BlacklistTrackerSuite.scala
@@ -86,7 +86,8 @@ class BlacklistTrackerSuite extends SparkFunSuite with BeforeAndAfterEach with M
     sc = new SparkContext(conf)
     val scheduler = mock[TaskSchedulerImpl]
     when(scheduler.sc).thenReturn(sc)
-    when(scheduler.mapOutputTracker).thenReturn(SparkEnv.get.mapOutputTracker)
+    when(scheduler.mapOutputTracker).thenReturn(
+      SparkEnv.get.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster])
     scheduler
   }