From 29246744061ee96afd5f57e113ad69c354e4ba4a Mon Sep 17 00:00:00 2001
From: Li Yichao <lyc@zhihu.com>
Date: Thu, 15 Jun 2017 08:08:26 +0800
Subject: [PATCH] [SPARK-19900][CORE] Remove driver when relaunching.

This is https://github.com/apache/spark/pull/17888 .

Below are some Spark UI screenshots.

Master, after worker disconnects:

<img width="1433" alt="master_disconnect" src="https://cloud.githubusercontent.com/assets/2576762/26398687/d0ee228e-40ac-11e7-986d-d3b57b87029f.png">

Master, after worker reconnects (note the `running drivers` section):

<img width="1412" alt="master_reconnects" src="https://cloud.githubusercontent.com/assets/2576762/26398697/d50735a4-40ac-11e7-80d8-6e9e1cf0b62f.png">

This patch, after worker disconnects:
<img width="1412" alt="patch_disconnect" src="https://cloud.githubusercontent.com/assets/2576762/26398009/c015d3dc-40aa-11e7-8bb4-df11a1f66645.png">

This patch, after worker reconnects:
![image](https://cloud.githubusercontent.com/assets/2576762/26398037/d313769c-40aa-11e7-8613-5f157d193150.png)

cc cloud-fan jiangxb1987

Author: Li Yichao <lyc@zhihu.com>

Closes #18084 from liyichao/SPARK-19900-1.
---
 .../apache/spark/deploy/master/Master.scala   |  16 ++-
 .../spark/deploy/master/MasterSuite.scala     | 109 ++++++++++++++++++
 2 files changed, 122 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index b78ae1f3fc..f10a41286c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -799,9 +799,19 @@ private[deploy] class Master(
   }
 
   private def relaunchDriver(driver: DriverInfo) {
-    driver.worker = None
-    driver.state = DriverState.RELAUNCHING
-    waitingDrivers += driver
+    // We must set up a new driver with a new driver id here, because the original driver may
+    // still be running. Consider this scenario: a worker is network-partitioned from the master,
+    // the master relaunches driver driverID1 under the id driverID2, and then the worker
+    // reconnects. From this point on, if driverID2 were equal to driverID1, the master could
+    // not tell status updates of the original driver from those of the relaunched one. For
+    // example, when DriverStateChanged(driverID1, KILLED) arrives, the master would remove
+    // driverID1, and the relaunched driver would disappear too. See SPARK-19900 for details.
+    removeDriver(driver.id, DriverState.RELAUNCHING, None)
+    val newDriver = createDriver(driver.desc)
+    persistenceEngine.addDriver(newDriver)
+    drivers.add(newDriver)
+    waitingDrivers += newDriver
+
     schedule()
   }
 
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 4f432e4cf2..6bb0eec040 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -19,8 +19,10 @@ package org.apache.spark.deploy.master
 
 import java.util.Date
 import java.util.concurrent.ConcurrentLinkedQueue
+import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 import scala.collection.mutable.{HashMap, HashSet}
 import scala.concurrent.duration._
 import scala.io.Source
@@ -40,6 +42,49 @@ import org.apache.spark.deploy.DeployMessages._
 import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointRef, RpcEnv}
 import org.apache.spark.serializer
 
+object MockWorker {
+  val counter = new AtomicInteger(10000)  // source of a distinct seq number per MockWorker
+}
+
+/** Fake worker endpoint: registers an app per launched driver and unregisters it on kill. */
+class MockWorker(master: RpcEndpointRef, conf: SparkConf = new SparkConf) extends RpcEndpoint {
+  val seq = MockWorker.counter.incrementAndGet()
+  val id = seq.toString
+  override val rpcEnv: RpcEnv = RpcEnv.create("worker", "localhost", seq,
+    conf, new SecurityManager(conf))
+  var apps = new mutable.HashMap[String, String]()
+  val driverIdToAppId = new mutable.HashMap[String, String]()
+  def newDriver(driverId: String): RpcEndpointRef = {
+    val name = s"driver_${drivers.size}"
+    rpcEnv.setupEndpoint(name, new RpcEndpoint {
+      override val rpcEnv: RpcEnv = MockWorker.this.rpcEnv
+      override def receive: PartialFunction[Any, Unit] = {
+        case RegisteredApplication(appId, _) =>
+          apps(appId) = appId
+          driverIdToAppId(driverId) = appId
+      }
+    })
+  }
+
+  val appDesc = DeployTestUtils.createAppDesc()
+  val drivers = mutable.HashSet[String]()
+  override def receive: PartialFunction[Any, Unit] = {
+    case RegisteredWorker(masterRef, _, _) =>
+      masterRef.send(WorkerLatestState(id, Nil, drivers.toSeq))
+    case LaunchDriver(driverId, desc) =>
+      drivers += driverId
+      master.send(RegisterApplication(appDesc, newDriver(driverId)))
+    case KillDriver(driverId) =>
+      master.send(DriverStateChanged(driverId, DriverState.KILLED, None))
+      drivers -= driverId
+      // Only unregister an app if one was recorded for this driver; none may exist yet.
+      driverIdToAppId.remove(driverId).foreach { appId =>
+        apps.remove(appId)
+        master.send(UnregisterApplication(appId))
+      }
+  }
+}
+
 class MasterSuite extends SparkFunSuite
   with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter {
 
@@ -588,6 +633,70 @@ class MasterSuite extends SparkFunSuite
     }
   }
 
+  test("SPARK-19900: there should be a corresponding driver for the app after relaunching driver") {
+    val conf = new SparkConf().set("spark.worker.timeout", "1")
+    val master = makeMaster(conf)
+    master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master)
+    eventually(timeout(10.seconds)) {
+      val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
+      assert(masterState.status === RecoveryState.ALIVE, "Master is not alive")
+    }
+    val worker1 = new MockWorker(master.self)
+    worker1.rpcEnv.setupEndpoint("worker", worker1)
+    val worker1Reg = RegisterWorker(
+      worker1.id,
+      "localhost",
+      9998,
+      worker1.self,
+      10,
+      1024,
+      "http://localhost:8080",
+      RpcAddress("localhost2", 10000))
+    master.self.send(worker1Reg)
+    val driver = DeployTestUtils.createDriverDesc().copy(supervise = true)
+    master.self.askSync[SubmitDriverResponse](RequestSubmitDriver(driver))
+    // Wait for the launched driver's application to register through worker1.
+    eventually(timeout(10.seconds)) {
+      assert(worker1.apps.nonEmpty)
+    }
+    // MockWorker sends no heartbeats; wait for the master to time worker1 out and mark it DEAD.
+    eventually(timeout(10.seconds)) {
+      val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
+      assert(masterState.workers(0).state == WorkerState.DEAD)
+    }
+    // Bring up a second worker; the supervised driver is expected to be relaunched on it.
+    val worker2 = new MockWorker(master.self)
+    worker2.rpcEnv.setupEndpoint("worker", worker2)
+    master.self.send(RegisterWorker(
+      worker2.id,
+      "localhost",
+      9999,
+      worker2.self,
+      10,
+      1024,
+      "http://localhost:8081",
+      RpcAddress("localhost", 10001)))
+    eventually(timeout(10.seconds)) {
+      assert(worker2.apps.nonEmpty)
+    }
+    // Re-register worker1: its stale driver should be killed, leaving only the relaunched one.
+    master.self.send(worker1Reg)
+    eventually(timeout(10.seconds)) {
+      val masterState = master.self.askSync[MasterStateResponse](RequestMasterState)
+
+      val worker = masterState.workers.filter(w => w.id == worker1.id)
+      assert(worker.length == 1)
+      // make sure the `DriverStateChanged` arrives at Master.
+      assert(worker(0).drivers.isEmpty)
+      assert(worker1.apps.isEmpty)
+      assert(worker1.drivers.isEmpty)
+      assert(worker2.apps.size == 1)
+      assert(worker2.drivers.size == 1)
+      assert(masterState.activeDrivers.length == 1)
+      assert(masterState.activeApps.length == 1)
+    }
+  }
+
   private def getDrivers(master: Master): HashSet[DriverInfo] = {
     master.invokePrivate(_drivers())
   }
-- 
GitLab