diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
index 490d45d12b8e3c49ac3496f7aef722d8de638f9a..3db59837fbebd4e0c1daeb52e8f95f4ef9a6417f 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala
@@ -371,6 +371,12 @@ private[storage] class BlockInfoManager extends Logging {
     blocksWithReleasedLocks
   }
 
+  /** Returns the number of locks held by the given task attempt. Used only for testing. */
+  private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = {
+    readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) +
+      writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0)
+  }
+
   /**
    * Returns the number of blocks tracked.
    */
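
For context, the new helper simply sums the sizes of the per-task read-lock and write-lock collections for the given task attempt. A minimal usage sketch, assuming test code that lives in the org.apache.spark.storage package so the private[storage] members are visible (the setup shown is illustrative, not part of this patch):

  import org.apache.spark.storage.{BlockInfo, BlockInfoManager}

  // Illustrative only: a freshly created manager holds no locks, so the count
  // for the pseudo task id used by non-task threads should be zero.
  val infoManager = new BlockInfoManager
  assert(infoManager.getTaskLockCount(BlockInfo.NON_TASK_WRITER) == 0)
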
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 245d94ac4f8b11f42285bb012eb2f2de8d0c900f..991346a40af4e29bbe300fd7db97b7a3fb404ff8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -1187,7 +1187,7 @@ private[spark] class BlockManager(
       blockId: BlockId,
       existingReplicas: Set[BlockManagerId],
       maxReplicas: Int): Unit = {
-    logInfo(s"Pro-actively replicating $blockId")
+    logInfo(s"Using $blockManagerId to pro-actively replicate $blockId")
     blockInfoManager.lockForReading(blockId).foreach { info =>
       val data = doGetLocalBytes(blockId, info)
       val storageLevel = StorageLevel(
@@ -1196,9 +1196,13 @@ private[spark] class BlockManager(
         useOffHeap = info.level.useOffHeap,
         deserialized = info.level.deserialized,
         replication = maxReplicas)
+      // We know we are called as a result of an executor removal, so refresh the peer cache
+      // to make sure we don't try to replicate to a missing executor with a stale reference.
+      getPeers(forceFetch = true)
       try {
         replicate(blockId, data, storageLevel, info.classTag, existingReplicas)
       } finally {
+        logDebug(s"Releasing lock for $blockId")
         releaseLock(blockId)
       }
     }
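
The getPeers(forceFetch = true) call makes the BlockManager re-fetch its peer list from the master instead of reusing the cached copy, which may still list the executor that was just removed. A rough sketch of that caching pattern, with illustrative names only (this is not BlockManager's actual implementation):

  import org.apache.spark.storage.BlockManagerId

  // Simplified sketch of a force-refreshable peer cache.
  class PeerCache(fetchFromMaster: () => Seq[BlockManagerId]) {
    private var cachedPeers: Seq[BlockManagerId] = Seq.empty

    def getPeers(forceFetch: Boolean): Seq[BlockManagerId] = synchronized {
      if (forceFetch || cachedPeers.isEmpty) {
        // A stale cache may still contain a removed executor, so force a
        // fetch before pro-actively replicating to avoid targeting a dead peer.
        cachedPeers = fetchFromMaster()
      }
      cachedPeers
    }
  }
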
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index d907add920c8a15ce695a390f78f854acb22b77e..d5715f8469f717bbb91c584705008637e38634c9 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -493,27 +493,34 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav
     assert(blockLocations.size === replicationFactor)
 
     // remove a random blockManager
-    val executorsToRemove = blockLocations.take(replicationFactor - 1)
+    val executorsToRemove = blockLocations.take(replicationFactor - 1).toSet
     logInfo(s"Removing $executorsToRemove")
-    executorsToRemove.foreach{exec =>
-      master.removeExecutor(exec.executorId)
+    initialStores.filter(bm => executorsToRemove.contains(bm.blockManagerId)).foreach { bm =>
+      master.removeExecutor(bm.blockManagerId.executorId)
+      bm.stop()
       // wait for replication to happen and the new block location to be reported to the master
-      Thread.sleep(200)
+      eventually(timeout(5 seconds), interval(100 millis)) {
+        val newLocations = master.getLocations(blockId).toSet
+        assert(newLocations.size === replicationFactor)
+      }
     }
 
-    val newLocations = eventually(timeout(5 seconds), interval(10 millis)) {
+    val newLocations = eventually(timeout(5 seconds), interval(100 millis)) {
       val _newLocations = master.getLocations(blockId).toSet
       assert(_newLocations.size === replicationFactor)
       _newLocations
     }
     logInfo(s"New locations : $newLocations")
-    // there should only be one common block manager between initial and new locations
-    assert(newLocations.intersect(blockLocations.toSet).size === 1)
 
-    // check if all the read locks have been released
-    initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm =>
-      val locks = bm.releaseAllLocksForTask(BlockInfo.NON_TASK_WRITER)
-      assert(locks.size === 0, "Read locks unreleased!")
+    // new locations should not contain stopped block managers
+    assert(newLocations.forall(bmId => !executorsToRemove.contains(bmId)),
+      "New locations contain stopped block managers.")
+
+    // Make sure all locks have been released.
+    eventually(timeout(1000 milliseconds), interval(10 milliseconds)) {
+      initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm =>
+        assert(bm.blockInfoManager.getTaskLockCount(BlockInfo.NON_TASK_WRITER) === 0)
+      }
     }
   }
 }
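
The eventually blocks above use ScalaTest's Eventually support together with the time-span DSL to poll a condition until it holds, instead of sleeping for a fixed 200 ms and hoping replication has finished. A minimal sketch of the pattern with the imports it needs; master, blockId, and replicationFactor are the suite's own fixtures and are shown only as placeholders:

  import org.scalatest.concurrent.Eventually._
  import org.scalatest.time.SpanSugar._

  // Retry the assertion every 100 ms until it passes, or fail after 5 seconds.
  eventually(timeout(5 seconds), interval(100 millis)) {
    assert(master.getLocations(blockId).toSet.size === replicationFactor)
  }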