diff --git a/core/src/main/scala/spark/MapOutputTracker.scala b/core/src/main/scala/spark/MapOutputTracker.scala index 70eb9f702e689852d7f9bd21c55ac1750bd3e6cb..9f2aa76830ff4d457b85afb3cf146d7660697b3d 100644 --- a/core/src/main/scala/spark/MapOutputTracker.scala +++ b/core/src/main/scala/spark/MapOutputTracker.scala @@ -139,8 +139,8 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea case e: InterruptedException => } } - return mapStatuses.get(shuffleId).map(status => - (status.address, MapOutputTracker.decompressSize(status.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, + mapStatuses.get(shuffleId)) } else { fetching += shuffleId } @@ -156,21 +156,15 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea fetchedStatuses = deserializeStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) - if (fetchedStatuses.contains(null)) { - throw new FetchFailedException(null, shuffleId, -1, reduceId, - new Exception("Missing an output location for shuffle " + shuffleId)) - } } finally { fetching.synchronized { fetching -= shuffleId fetching.notifyAll() } } - return fetchedStatuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, fetchedStatuses) } else { - return statuses.map(s => - (s.address, MapOutputTracker.decompressSize(s.compressedSizes(reduceId)))) + return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) } } @@ -258,6 +252,28 @@ private[spark] class MapOutputTracker(actorSystem: ActorSystem, isMaster: Boolea private[spark] object MapOutputTracker { private val LOG_BASE = 1.1 + // Convert an array of MapStatuses to locations and sizes for a given reduce ID. If + // any of the statuses is null (indicating a missing location due to a failed mapper), + // throw a FetchFailedException. + def convertMapStatuses( + shuffleId: Int, + reduceId: Int, + statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { + if (statuses == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing all output locations for shuffle " + shuffleId)) + } + statuses.map { + status => + if (status == null) { + throw new FetchFailedException(null, shuffleId, -1, reduceId, + new Exception("Missing an output location for shuffle " + shuffleId)) + } else { + (status.address, decompressSize(status.compressedSizes(reduceId))) + } + } + } + /** * Compress a size in bytes to 8 bits for efficient reporting of map output sizes. * We do this by encoding the log base 1.1 of the size as an integer, which can support diff --git a/core/src/test/scala/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/spark/MapOutputTrackerSuite.scala index 5b4b19896046d204451ad33ad9fc8b0662c6d082..d3dd3a8fa4930cdc5dcf2cd8b656d5ce03cba272 100644 --- a/core/src/test/scala/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/spark/MapOutputTrackerSuite.scala @@ -1,12 +1,18 @@ package spark import org.scalatest.FunSuite +import org.scalatest.BeforeAndAfter import akka.actor._ import spark.scheduler.MapStatus import spark.storage.BlockManagerId +import spark.util.AkkaUtils -class MapOutputTrackerSuite extends FunSuite { +class MapOutputTrackerSuite extends FunSuite with BeforeAndAfter { + after { + System.clearProperty("spark.master.port") + } + test("compressSize") { assert(MapOutputTracker.compressSize(0L) === 0) assert(MapOutputTracker.compressSize(1L) === 1) @@ -71,6 +77,36 @@ class MapOutputTrackerSuite extends FunSuite { // The remaining reduce task might try to grab the output dispite the shuffle failure; // this should cause it to fail, and the scheduler will ignore the failure due to the // stage already being aborted. - intercept[Exception] { tracker.getServerStatuses(10, 1) } + intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } + } + + test("remote fetch") { + System.clearProperty("spark.master.host") + val (actorSystem, boundPort) = + AkkaUtils.createActorSystem("test", "localhost", 0) + System.setProperty("spark.master.port", boundPort.toString) + val masterTracker = new MapOutputTracker(actorSystem, true) + val slaveTracker = new MapOutputTracker(actorSystem, false) + masterTracker.registerShuffle(10, 1) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + val compressedSize1000 = MapOutputTracker.compressSize(1000L) + val size1000 = MapOutputTracker.decompressSize(compressedSize1000) + masterTracker.registerMapOutput(10, 0, new MapStatus( + new BlockManagerId("hostA", 1000), Array(compressedSize1000))) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + assert(slaveTracker.getServerStatuses(10, 0).toSeq === + Seq((new BlockManagerId("hostA", 1000), size1000))) + + masterTracker.unregisterMapOutput(10, 0, new BlockManagerId("hostA", 1000)) + masterTracker.incrementGeneration() + slaveTracker.updateGeneration(masterTracker.getGeneration) + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } + + // failure should be cached + intercept[FetchFailedException] { slaveTracker.getServerStatuses(10, 0) } } }