diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 53384e7373252c5fdf68aff4201b5db8b415caca..b78ae1f3fc150584e8085f509769655c090c85a5 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -367,7 +367,9 @@ private[deploy] class Master(
             drivers.find(_.id == driverId).foreach { driver =>
               driver.worker = Some(worker)
               driver.state = DriverState.RUNNING
-              worker.drivers(driverId) = driver
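+              // addDriver also charges the driver's memory and cores to the
+              // worker's usage counters, which the bare map update left stale.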
+              worker.addDriver(driver)
             }
           }
         case None =>
@@ -547,6 +549,9 @@ private[deploy] class Master(
     workers.filter(_.state == WorkerState.UNKNOWN).foreach(removeWorker)
     apps.filter(_.state == ApplicationState.UNKNOWN).foreach(finishApplication)
 
+    // Update recovered apps to RUNNING; MasterChangeAcknowledged left them in WAITING
+    apps.filter(_.state == ApplicationState.WAITING).foreach(_.state = ApplicationState.RUNNING)
+
     // Reschedule drivers which were not claimed by any workers
     drivers.filter(_.worker.isEmpty).foreach { d =>
       logWarning(s"Driver ${d.id} was not found after master recovery")
diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
index 539264652d7d54286a13a4a6df53fde6c3fa8a9b..4f432e4cf21c79c3c2986f57a7b9e5e13eb2585a 100644
--- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala
@@ -21,12 +21,15 @@ import java.util.Date
 import java.util.concurrent.ConcurrentLinkedQueue
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable.{HashMap, HashSet}
 import scala.concurrent.duration._
 import scala.io.Source
 import scala.language.postfixOps
+import scala.reflect.ClassTag
 
 import org.json4s._
 import org.json4s.jackson.JsonMethods._
+import org.mockito.Mockito.{mock, when}
 import org.scalatest.{BeforeAndAfter, Matchers, PrivateMethodTester}
 import org.scalatest.concurrent.Eventually
 import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
@@ -34,7 +37,8 @@ import other.supplier.{CustomPersistenceEngine, CustomRecoveryModeFactory}
 import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite}
 import org.apache.spark.deploy._
 import org.apache.spark.deploy.DeployMessages._
-import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEnv}
+import org.apache.spark.rpc.{RpcAddress, RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.serializer
 
 class MasterSuite extends SparkFunSuite
   with Matchers with Eventually with PrivateMethodTester with BeforeAndAfter {
@@ -134,6 +138,85 @@ class MasterSuite
     CustomRecoveryModeFactory.instantiationAttempts should be > instantiationAttempts
   }
 
+  test("master correctly recover the application") {
+    val conf = new SparkConf(loadDefaults = false)
+    conf.set("spark.deploy.recoveryMode", "CUSTOM")
+    conf.set("spark.deploy.recoveryMode.factory",
+      classOf[FakeRecoveryModeFactory].getCanonicalName)
+    conf.set("spark.master.rest.enabled", "false")
+
+    val fakeAppInfo = makeAppInfo(1024)
+    val fakeWorkerInfo = makeWorkerInfo(8192, 16)
+    val fakeDriverInfo = new DriverInfo(
+      startTime = 0,
+      id = "test_driver",
+      desc = new DriverDescription(
+        jarUrl = "",
+        mem = 1024,
+        cores = 1,
+        supervise = false,
+        command = new Command("", Nil, Map.empty, Nil, Nil, Nil)),
+      submitDate = new Date())
+
+    // Build the fake recovery data
+    FakeRecoveryModeFactory.persistentData.put(s"app_${fakeAppInfo.id}", fakeAppInfo)
+    FakeRecoveryModeFactory.persistentData.put(s"driver_${fakeDriverInfo.id}", fakeDriverInfo)
+    FakeRecoveryModeFactory.persistentData.put(s"worker_${fakeWorkerInfo.id}", fakeWorkerInfo)
+
+    var master: Master = null
+    try {
+      master = makeMaster(conf)
+      master.rpcEnv.setupEndpoint(Master.ENDPOINT_NAME, master)
+      // Wait until the Master recovers from the checkpoint data.
+      eventually(timeout(5 seconds), interval(100 milliseconds)) {
+        master.idToApp.size should be(1)
+      }
+
+      master.idToApp.keySet should be(Set(fakeAppInfo.id))
+      getDrivers(master) should be(Set(fakeDriverInfo))
+      master.workers should be(Set(fakeWorkerInfo))
+
+      // Notify the Master of the executor and driver info so that it can finish recovery.
+      val fakeExecutors = List(
+        new ExecutorDescription(fakeAppInfo.id, 0, 8, ExecutorState.RUNNING),
+        new ExecutorDescription(fakeAppInfo.id, 0, 7, ExecutorState.RUNNING))
+
+      fakeAppInfo.state should be(ApplicationState.UNKNOWN)
+      fakeWorkerInfo.coresFree should be(16)
+      fakeWorkerInfo.coresUsed should be(0)
+
+      master.self.send(MasterChangeAcknowledged(fakeAppInfo.id))
+      eventually(timeout(1 second), interval(10 milliseconds)) {
+        // The application state should be WAITING after "MasterChangeAcknowledged" is processed.
+        fakeAppInfo.state should be(ApplicationState.WAITING)
+      }
+
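+      // A reconnecting worker reports its executors and drivers; once no worker
+      // or app is still UNKNOWN, the Master can complete recovery.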
+      master.self.send(
+        WorkerSchedulerStateResponse(fakeWorkerInfo.id, fakeExecutors, Seq(fakeDriverInfo.id)))
+
+      eventually(timeout(5 seconds), interval(100 milliseconds)) {
+        getState(master) should be(RecoveryState.ALIVE)
+      }
+
+      // With the driver's 1 core counted on top of the executors' 8 + 7 cores,
+      // all 16 of the worker's cores should be in use.
+      fakeWorkerInfo.coresFree should be(0)
+      fakeWorkerInfo.coresUsed should be(16)
+      // The application state should be RUNNING once recovery completes.
+      fakeAppInfo.state should be(ApplicationState.RUNNING)
+    } finally {
+      if (master != null) {
+        master.rpcEnv.shutdown()
+        master.rpcEnv.awaitTermination()
+        master = null
+        FakeRecoveryModeFactory.persistentData.clear()
+      }
+    }
+  }
+
   test("master/worker web ui available") {
     implicit val formats = org.json4s.DefaultFormats
     val conf = new SparkConf()
@@ -394,6 +477,10 @@
   // ==========================================
 
   private val _scheduleExecutorsOnWorkers = PrivateMethod[Array[Int]]('scheduleExecutorsOnWorkers)
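+  // Handles for accessing Master's private 'drivers' and 'state' via PrivateMethodTester.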
+  private val _drivers = PrivateMethod[HashSet[DriverInfo]]('drivers)
+  private val _state = PrivateMethod[RecoveryState.Value]('state)
+
   private val workerInfo = makeWorkerInfo(4096, 10)
   private val workerInfos = Array(workerInfo, workerInfo, workerInfo)
 
@@ -412,12 +499,20 @@
     val desc = new ApplicationDescription(
       "test", maxCores, memoryPerExecutorMb, null, "", None, None, coresPerExecutor)
     val appId = System.currentTimeMillis.toString
-    new ApplicationInfo(0, appId, desc, new Date, null, Int.MaxValue)
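+    // The recovery path reads the application endpoint's address, so mock the
+    // RpcEndpointRef rather than passing null.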
+    val endpointRef = mock(classOf[RpcEndpointRef])
+    val mockAddress = mock(classOf[RpcAddress])
+    when(endpointRef.address).thenReturn(mockAddress)
+    new ApplicationInfo(0, appId, desc, new Date, endpointRef, Int.MaxValue)
   }
 
   private def makeWorkerInfo(memoryMb: Int, cores: Int): WorkerInfo = {
     val workerId = System.currentTimeMillis.toString
-    new WorkerInfo(workerId, "host", 100, cores, memoryMb, null, "http://localhost:80")
+    val endpointRef = mock(classOf[RpcEndpointRef])
+    val mockAddress = mock(classOf[RpcAddress])
+    when(endpointRef.address).thenReturn(mockAddress)
+    new WorkerInfo(workerId, "host", 100, cores, memoryMb, endpointRef, "http://localhost:80")
   }
 
   private def scheduleExecutorsOnWorkers(
@@ -499,4 +594,44 @@
       assert(receivedMasterAddress === RpcAddress("localhost2", 10000))
     }
   }
+
+  private def getDrivers(master: Master): HashSet[DriverInfo] = {
+    master.invokePrivate(_drivers())
+  }
+
+  private def getState(master: Master): RecoveryState.Value = {
+    master.invokePrivate(_state())
+  }
+}
+
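+// An in-memory recovery backend for tests: persisted objects live in a shared
+// map, so a restarted Master in the same JVM can read them back.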
+private class FakeRecoveryModeFactory(conf: SparkConf, ser: serializer.Serializer)
+    extends StandaloneRecoveryModeFactory(conf, ser) {
+  import FakeRecoveryModeFactory.persistentData
+
+  override def createPersistenceEngine(): PersistenceEngine = new PersistenceEngine {
+
+    override def unpersist(name: String): Unit = {
+      persistentData.remove(name)
+    }
+
+    override def persist(name: String, obj: Object): Unit = {
+      persistentData(name) = obj
+    }
+
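+    // The Master persists state under "app_", "driver_" and "worker_" key
+    // prefixes, so read() recovers each kind of object by prefix.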
+    override def read[T: ClassTag](prefix: String): Seq[T] = {
+      persistentData.filter(_._1.startsWith(prefix)).map(_._2.asInstanceOf[T]).toSeq
+    }
+  }
+
+  override def createLeaderElectionAgent(master: LeaderElectable): LeaderElectionAgent = {
+    new MonarchyLeaderAgent(master)
+  }
+}
+
+private object FakeRecoveryModeFactory {
+  val persistentData = new HashMap[String, Object]()
 }