diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index b78ae1f3fc150584e8085f509769655c090c85a5..f10a41286c52fa4047ad922b9abc6b47c51c1cbc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -799,9 +799,19 @@ private[deploy] class Master( } private def relaunchDriver(driver: DriverInfo) { - driver.worker = None - driver.state = DriverState.RELAUNCHING - waitingDrivers += driver + // We must setup a new driver with a new driver id here, because the original driver may + // be still running. Consider this scenario: a worker is network partitioned with master, + // the master then relaunches driver driverID1 with a driver id driverID2, then the worker + // reconnects to master. From this point on, if driverID2 is equal to driverID1, then master + // can not distinguish the statusUpdate of the original driver and the newly relaunched one, + // for example, when DriverStateChanged(driverID1, KILLED) arrives at master, master will + // remove driverID1, so the newly relaunched driver disappears too. See SPARK-19900 for details. + removeDriver(driver.id, DriverState.RELAUNCHING, None) + val newDriver = createDriver(driver.desc) + persistenceEngine.addDriver(newDriver) + drivers.add(newDriver) + waitingDrivers += newDriver + schedule() } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index 4f432e4cf21c79c3c2986f57a7b9e5e13eb2585a..6bb0eec0407871d879b6de865a2d6c3001fc3c7f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -19,8 +19,10 @@ package org.apache.spark.deploy.master import java.util.Date import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.collection.mutable.{HashMap, HashSet} import scala.concurrent.duration._ import scala.io.Source @@ -40,6 +42,49 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointRef, RpcEnv} import org.apache.spark.serializer +object MockWorker { + val counter = new AtomicInteger(10000) +} + +class MockWorker(master: RpcEndpointRef, conf: SparkConf = new SparkConf) extends RpcEndpoint { + val seq = MockWorker.counter.incrementAndGet() + val id = seq.toString + override val rpcEnv: RpcEnv = RpcEnv.create("worker", "localhost", seq, + conf, new SecurityManager(conf)) + var apps = new mutable.HashMap[String, String]() + val driverIdToAppId = new mutable.HashMap[String, String]() + def newDriver(driverId: String): RpcEndpointRef = { + val name = s"driver_${drivers.size}" + rpcEnv.setupEndpoint(name, new RpcEndpoint { + override val rpcEnv: RpcEnv = MockWorker.this.rpcEnv + override def receive: PartialFunction[Any, Unit] = { + case RegisteredApplication(appId, _) => + apps(appId) = appId + driverIdToAppId(driverId) = appId + } + }) + } + + val appDesc = DeployTestUtils.createAppDesc() + val drivers = mutable.HashSet[String]() + override def receive: PartialFunction[Any, Unit] = { + case RegisteredWorker(masterRef, _, _) => + masterRef.send(WorkerLatestState(id, Nil, drivers.toSeq)) + case LaunchDriver(driverId, desc) => + drivers += driverId + master.send(RegisterApplication(appDesc, newDriver(driverId))) + case KillDriver(driverId) => + master.send(DriverStateChanged(driverId, DriverState.KILLED, None)) + drivers -= driverId + driverIdToAppId.get(driverId) match { + case Some(appId) => + apps.remove(appId) + master.send(UnregisterApplication(appId)) + } + driverIdToAppId.remove(driverId) + } +} + class MasterSuite extends SparkFunSuite with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter { @@ -588,6 +633,70 @@ class MasterSuite extends SparkFunSuite } } + test("SPARK-19900: there should be a corresponding driver for the app after relaunching driver") { + val conf = new SparkConf().set("spark.worker.timeout", "1") + val master = makeMaster(conf) + master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.status === RecoveryState.ALIVE, "Master is not alive") + } + val worker1 = new MockWorker(master.self) + worker1.rpcEnv.setupEndpoint("worker", worker1) + val worker1Reg = RegisterWorker( + worker1.id, + "localhost", + 9998, + worker1.self, + 10, + 1024, + "http://localhost:8080", + RpcAddress("localhost2", 10000)) + master.self.send(worker1Reg) + val driver = DeployTestUtils.createDriverDesc().copy(supervise = true) + master.self.askSync[SubmitDriverResponse](RequestSubmitDriver(driver)) + + eventually(timeout(10.seconds)) { + assert(worker1.apps.nonEmpty) + } + + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + assert(masterState.workers(0).state == WorkerState.DEAD) + } + + val worker2 = new MockWorker(master.self) + worker2.rpcEnv.setupEndpoint("worker", worker2) + master.self.send(RegisterWorker( + worker2.id, + "localhost", + 9999, + worker2.self, + 10, + 1024, + "http://localhost:8081", + RpcAddress("localhost", 10001))) + eventually(timeout(10.seconds)) { + assert(worker2.apps.nonEmpty) + } + + master.self.send(worker1Reg) + eventually(timeout(10.seconds)) { + val masterState = master.self.askSync[MasterStateResponse](RequestMasterState) + + val worker = masterState.workers.filter(w => w.id == worker1.id) + assert(worker.length == 1) + // make sure the `DriverStateChanged` arrives at Master. + assert(worker(0).drivers.isEmpty) + assert(worker1.apps.isEmpty) + assert(worker1.drivers.isEmpty) + assert(worker2.apps.size == 1) + assert(worker2.drivers.size == 1) + assert(masterState.activeDrivers.length == 1) + assert(masterState.activeApps.length == 1) + } + } + private def getDrivers(master: Master): HashSet[DriverInfo] = { master.invokePrivate(_drivers()) }