diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 53384e7373252c5fdf68aff4201b5db8b415caca..b78ae1f3fc150584e8085f509769655c090c85a5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -367,7 +367,7 @@ private[deploy] class Master(
             drivers.find(_.id == driverId).foreach { driver =>
               driver.worker = Some(worker)
               driver.state = DriverState.RUNNING
-              worker.drivers(driverId) = driver
+              worker.addDriver(driver)
             }
           }
         case None =>
@@ -547,6 +547,9 @@ private[deploy] class Master(
     workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
     apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
 
+    // Update the state of recovered apps to RUNNING
+    apps.filter(_.state == ApplicationState.WAITING).foreach(_.state = ApplicationState.RUNNING)
+
     // Reschedule drivers which were not claimed by any workers
     drivers.filter(_.worker.isEmpty).foreach { d =>
       logWarning(s"Driver ${d.id} was not found after master recovery")
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 539264652d7d54286a13a4a6df53fde6c3fa8a9b..4f432e4cf21c79c3c2986f57a7b9e5e13eb2585a 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -21,12 +21,15 @@ import java.util.Date
 import java.util.concurrent.ConcurrentLinkedQueue
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.{HashMap, HashSet}
 import scala.concurrent.duration._
 import scala.io.Source
 import scala.language.postfixOps
+import scala.reflect.ClassTag
 
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
+import org.mockito.Mockito.{mock, when}
 import org.scalatest.{BeforeAndAfter, Matchers, PrivateMethodTester}
 import org.scalatest.concurrent.Eventually
 import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
@@ -34,7 +37,8 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
 import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
 import org.apache.spark.deploy._
 import org.apache.spark.deploy.DeployMessages._
-import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.serializer
 
 class MasterSuite extends SparkFunSuite
   with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter {
@@ -134,6 +138,81 @@ class MasterSuite extends SparkFunSuite
     CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts
   }
 
+  test("master correctly recovers the application") {
+    val conf = new SparkConf(loadDefaults = false)
+    conf.set("spark.deploy.recoveryMode", "CUSTOM")
+    conf.set("spark.deploy.recoveryMode.factory",
+      classOf[FakeRecoveryModeFactory].getCanonicalName)
+    conf.set("spark.master.rest.enabled", "false")
+
+    val fakeAppInfo = makeAppInfo(1024)
+    val fakeWorkerInfo = makeWorkerInfo(8192, 16)
+    val fakeDriverInfo = new DriverInfo(
+      startTime = 0,
+      id = "test_driver",
+      desc = new DriverDescription(
+        jarUrl = "",
+        mem = 1024,
+        cores = 1,
+        supervise = false,
+        command = new Command("", Nil, Map.empty, Nil, Nil, Nil)),
+      submitDate = new Date())
+
+    // Build the fake recovery data
+    FakeRecoveryModeFactory.persistentData.put(s"app_${fakeAppInfo.id}", fakeAppInfo)
+    FakeRecoveryModeFactory.persistentData.put(s"driver_${fakeDriverInfo.id}", fakeDriverInfo)
+    FakeRecoveryModeFactory.persistentData.put(s"worker_${fakeWorkerInfo.id}", fakeWorkerInfo)
+
+    var master: Master = null
+    try {
+      master = makeMaster(conf)
+      master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master)
+      // Wait until Master recovers from the checkpoint data.
+      eventually(timeout(5 seconds), interval(100 milliseconds)) {
+        master.idToApp.size should be(1)
+      }
+
+      master.idToApp.keySet should be(Set(fakeAppInfo.id))
+      getDrivers(master) should be(Set(fakeDriverInfo))
+      master.workers should be(Set(fakeWorkerInfo))
+
+      // Notify Master about the executor and driver info so that recovery completes correctly.
+      val fakeExecutors = List(
+        new ExecutorDescription(fakeAppInfo.id, 0, 8, ExecutorState.RUNNING),
+        new ExecutorDescription(fakeAppInfo.id, 0, 7, ExecutorState.RUNNING))
+
+      fakeAppInfo.state should be(ApplicationState.UNKNOWN)
+      fakeWorkerInfo.coresFree should be(16)
+      fakeWorkerInfo.coresUsed should be(0)
+
+      master.self.send(MasterChangeAcknowledged(fakeAppInfo.id))
+      eventually(timeout(1 second), interval(10 milliseconds)) {
+        // Application state should be WAITING after the "MasterChangeAcknowledged" event is processed.
+        fakeAppInfo.state should be(ApplicationState.WAITING)
+      }
+
+      master.self.send(
+        WorkerSchedulerStateResponse(fakeWorkerInfo.id, fakeExecutors, Seq(fakeDriverInfo.id)))
+
+      eventually(timeout(5 seconds), interval(100 milliseconds)) {
+        getState(master) should be(RecoveryState.ALIVE)
+      }
+
+      // If the driver's resources are also counted, free cores should be 0
+      fakeWorkerInfo.coresFree should be(0)
+      fakeWorkerInfo.coresUsed should be(16)
+      // State of the application should be RUNNING
+      fakeAppInfo.state should be(ApplicationState.RUNNING)
+    } finally {
+      if (master != null) {
+        master.rpcEnv.shutdown()
+        master.rpcEnv.awaitTermination()
+        master = null
+        FakeRecoveryModeFactory.persistentData.clear()
+      }
+    }
+  }
+
   test("master/worker web ui available") {
     implicit val formats = org.json4s.DefaultFormats
     val conf = new SparkConf()
@@ -394,6 +473,9 @@ class MasterSuite extends SparkFunSuite
   // ==========================================
 
   private val _scheduleExecutorsOnWorkers = PrivateMethod[Array[Int]]('scheduleExecutorsOnWorkers)
+  private val _drivers = PrivateMethod[HashSet[DriverInfo]]('drivers)
+  private val _state = PrivateMethod[RecoveryState.Value]('state)
+
   private val workerInfo = makeWorkerInfo(4096, 10)
   private val workerInfos = Array(workerInfo, workerInfo, workerInfo)
 
@@ -412,12 +494,18 @@ class MasterSuite extends SparkFunSuite
     val desc = new ApplicationDescription(
       "test", maxCores, memoryPerExecutorMb, null, "", None, None, coresPerExecutor)
     val appId = System.currentTimeMillis.toString
-    new ApplicationInfo(0, appId, desc, new Date, null, Int.MaxValue)
+    val endpointRef = mock(classOf[RpcEndpointRef])
+    val mockAddress = mock(classOf[RpcAddress])
+    when(endpointRef.address).thenReturn(mockAddress)
+    new ApplicationInfo(0, appId, desc, new Date, endpointRef, Int.MaxValue)
   }
 
   private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = {
     val workerId = System.currentTimeMillis.toString
-    new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, "http://localhost:80")
+    val endpointRef = mock(classOf[RpcEndpointRef])
+    val mockAddress = mock(classOf[RpcAddress])
+    when(endpointRef.address).thenReturn(mockAddress)
"http://localhost:80") } private def scheduleExecutorsOnWorkers( @@ -499,4 +587,40 @@ class MasterSuite extends SparkFunSuite assert(receivedMasterAddress === RpcAddress("localhost2", 10000)) } } + + private def getDrivers(master: Master): HashSet[DriverInfo] = { + master.invokePrivate(_drivers()) + } + + private def getState(master: Master): RecoveryState.Value = { + master.invokePrivate(_state()) + } +} + +private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer) + extends StandaloneRecoveryModeFactory(conf, ser) { + import FakeRecoveryModeFactory.persistentData + + override def createPersistenceEngine(): PersistenceEngine = new PersistenceEngine { + + override def unpersist(name: String): Unit = { + persistentData.remove(name) + } + + override def persist(name: String, obj: Object): Unit = { + persistentData(name) = obj + } + + override def read[T: ClassTag](prefix: String): Seq[T] = { + persistentData.filter(_._1.startsWith(prefix)).map(_._2.asInstanceOf[T]).toSeq + } + } + + override def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = { + new MonarchyLeaderAgent(master) + } +} + +private object FakeRecoveryModeFactory { + val persistentData = new HashMap[String, Object]() }