diff --git a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala index 490d45d12b8e3c49ac3496f7aef722d8de638f9a..3db59837fbebd4e0c1daeb52e8f95f4ef9a6417f 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala @@ -371,6 +371,12 @@ private[storage] class BlockInfoManager extends Logging { blocksWithReleasedLocks } + /** Returns the number of locks held by the given task. Used only for testing. */ + private[storage] def getTaskLockCount(taskAttemptId: TaskAttemptId): Int = { + readLocksByTask.get(taskAttemptId).map(_.size()).getOrElse(0) + + writeLocksByTask.get(taskAttemptId).map(_.size).getOrElse(0) + } + /** * Returns the number of blocks tracked. */ diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 245d94ac4f8b11f42285bb012eb2f2de8d0c900f..991346a40af4e29bbe300fd7db97b7a3fb404ff8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -1187,7 +1187,7 @@ private[spark] class BlockManager( blockId: BlockId, existingReplicas: Set[BlockManagerId], maxReplicas: Int): Unit = { - logInfo(s"Pro-actively replicating $blockId") + logInfo(s"Using $blockManagerId to pro-actively replicate $blockId") blockInfoManager.lockForReading(blockId).foreach { info => val data = doGetLocalBytes(blockId, info) val storageLevel = StorageLevel( @@ -1196,9 +1196,13 @@ private[spark] class BlockManager( useOffHeap = info.level.useOffHeap, deserialized = info.level.deserialized, replication = maxReplicas) + // we know we are called as a result of an executor removal, so we refresh peer cache + // this way, we won't try to replicate to a missing executor with a stale reference + getPeers(forceFetch = true) try { replicate(blockId, data, storageLevel, info.classTag, existingReplicas) } finally { + logDebug(s"Releasing lock for $blockId") releaseLock(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index d907add920c8a15ce695a390f78f854acb22b77e..d5715f8469f717bbb91c584705008637e38634c9 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -493,27 +493,34 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav assert(blockLocations.size === replicationFactor) // remove a random blockManager - val executorsToRemove = blockLocations.take(replicationFactor - 1) + val executorsToRemove = blockLocations.take(replicationFactor - 1).toSet logInfo(s"Removing $executorsToRemove") - executorsToRemove.foreach{exec => - master.removeExecutor(exec.executorId) + initialStores.filter(bm => executorsToRemove.contains(bm.blockManagerId)).foreach { bm => + master.removeExecutor(bm.blockManagerId.executorId) + bm.stop() // giving enough time for replication to happen and new block be reported to master - Thread.sleep(200) + eventually(timeout(5 seconds), interval(100 millis)) { + val newLocations = master.getLocations(blockId).toSet + assert(newLocations.size === replicationFactor) + } } - val newLocations = eventually(timeout(5 seconds), interval(10 millis)) { + val newLocations = eventually(timeout(5 seconds), interval(100 millis)) { val _newLocations = master.getLocations(blockId).toSet assert(_newLocations.size === replicationFactor) _newLocations } logInfo(s"New locations : $newLocations") - // there should only be one common block manager between initial and new locations - assert(newLocations.intersect(blockLocations.toSet).size === 1) - // check if all the read locks have been released - initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => - val locks = bm.releaseAllLocksForTask(BlockInfo.NON_TASK_WRITER) - assert(locks.size === 0, "Read locks unreleased!") + // new locations should not contain stopped block managers + assert(newLocations.forall(bmId => !executorsToRemove.contains(bmId)), + "New locations contain stopped block managers.") + + // Make sure all locks have been released. + eventually(timeout(1000 milliseconds), interval(10 milliseconds)) { + initialStores.filter(bm => newLocations.contains(bm.blockManagerId)).foreach { bm => + assert(bm.blockInfoManager.getTaskLockCount(BlockInfo.NON_TASK_WRITER) === 0) + } } } }