diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index eb2201d142ffb097cf4628ecb934f7c8c5170294..651e9c7b2ab6173861e3d8b555ca0b69fe1382ad 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.{UNROLL_MEMORY_CHECK_PERIOD, UNROLL_MEMORY_GROWTH_FACTOR} import org.apache.spark.memory.{MemoryManager, MemoryMode} import org.apache.spark.serializer.{SerializationStream, SerializerManager} -import org.apache.spark.storage.{BlockId, BlockInfoManager, StorageLevel, StreamBlockId} +import org.apache.spark.storage._ import org.apache.spark.unsafe.Platform import org.apache.spark.util.{SizeEstimator, Utils} import org.apache.spark.util.collection.SizeTrackingVector @@ -544,20 +544,38 @@ private[spark] class MemoryStore( } if (freedMemory >= space) { - logInfo(s"${selectedBlocks.size} blocks selected for dropping " + - s"(${Utils.bytesToString(freedMemory)} bytes)") - for (blockId <- selectedBlocks) { - val entry = entries.synchronized { entries.get(blockId) } - // This should never be null as only one task should be dropping - // blocks and removing entries. However the check is still here for - // future safety. - if (entry != null) { - dropBlock(blockId, entry) + var lastSuccessfulBlock = -1 + try { + logInfo(s"${selectedBlocks.size} blocks selected for dropping " + + s"(${Utils.bytesToString(freedMemory)} bytes)") + (0 until selectedBlocks.size).foreach { idx => + val blockId = selectedBlocks(idx) + val entry = entries.synchronized { + entries.get(blockId) + } + // This should never be null as only one task should be dropping + // blocks and removing entries. However the check is still here for + // future safety. + if (entry != null) { + dropBlock(blockId, entry) + afterDropAction(blockId) + } + lastSuccessfulBlock = idx + } + logInfo(s"After dropping ${selectedBlocks.size} blocks, " + + s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}") + freedMemory + } finally { + // like BlockManager.doPut, we use a finally rather than a catch to avoid having to deal + // with InterruptedException + if (lastSuccessfulBlock != selectedBlocks.size - 1) { + // the blocks we didn't process successfully are still locked, so we have to unlock them + (lastSuccessfulBlock + 1 until selectedBlocks.size).foreach { idx => + val blockId = selectedBlocks(idx) + blockInfoManager.unlock(blockId) + } } } - logInfo(s"After dropping ${selectedBlocks.size} blocks, " + - s"free memory is ${Utils.bytesToString(maxMemory - blocksMemoryUsed)}") - freedMemory } else { blockId.foreach { id => logInfo(s"Will not store $id") @@ -570,6 +588,9 @@ private[spark] class MemoryStore( } } + // hook for testing, so we can simulate a race + protected def afterDropAction(blockId: BlockId): Unit = {} + def contains(blockId: BlockId): Boolean = { entries.synchronized { entries.containsKey(blockId) } } diff --git a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala index 9929ea033a99f3996cb8bb057d0c15b71184cc27..7274072e5049a26ade0f2bc7c1ac935aea5ea281 100644 --- a/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/MemoryStoreSuite.scala @@ -407,4 +407,123 @@ class MemoryStoreSuite }) assert(memoryStore.getSize(blockId) === 10000) } + + test("SPARK-22083: Release all locks in evictBlocksToFreeSpace") { + // Setup a memory store with many blocks cached, and then one request which leads to multiple + // blocks getting evicted. We'll make the eviction throw an exception, and make sure that + // all locks are released. + val ct = implicitly[ClassTag[Array[Byte]]] + val numInitialBlocks = 10 + val memStoreSize = 100 + val bytesPerSmallBlock = memStoreSize / numInitialBlocks + def testFailureOnNthDrop(numValidBlocks: Int, readLockAfterDrop: Boolean): Unit = { + val tc = TaskContext.empty() + val memManager = new StaticMemoryManager(conf, Long.MaxValue, memStoreSize, numCores = 1) + val blockInfoManager = new BlockInfoManager + blockInfoManager.registerTask(tc.taskAttemptId) + var droppedSoFar = 0 + val blockEvictionHandler = new BlockEvictionHandler { + var memoryStore: MemoryStore = _ + + override private[storage] def dropFromMemory[T: ClassTag]( + blockId: BlockId, + data: () => Either[Array[T], ChunkedByteBuffer]): StorageLevel = { + if (droppedSoFar < numValidBlocks) { + droppedSoFar += 1 + memoryStore.remove(blockId) + if (readLockAfterDrop) { + // for testing purposes, we act like another thread gets the read lock on the new + // block + StorageLevel.DISK_ONLY + } else { + StorageLevel.NONE + } + } else { + throw new RuntimeException(s"Mock error dropping block $droppedSoFar") + } + } + } + val memoryStore = new MemoryStore(conf, blockInfoManager, serializerManager, memManager, + blockEvictionHandler) { + override def afterDropAction(blockId: BlockId): Unit = { + if (readLockAfterDrop) { + // pretend that we get a read lock on the block (now on disk) in another thread + TaskContext.setTaskContext(tc) + blockInfoManager.lockForReading(blockId) + TaskContext.unset() + } + } + } + + blockEvictionHandler.memoryStore = memoryStore + memManager.setMemoryStore(memoryStore) + + // Put in some small blocks to fill up the memory store + val initialBlocks = (1 to numInitialBlocks).map { id => + val blockId = BlockId(s"rdd_1_$id") + val blockInfo = new BlockInfo(StorageLevel.MEMORY_ONLY, ct, tellMaster = false) + val initialWriteLock = blockInfoManager.lockNewBlockForWriting(blockId, blockInfo) + assert(initialWriteLock) + val success = memoryStore.putBytes(blockId, bytesPerSmallBlock, MemoryMode.ON_HEAP, () => { + new ChunkedByteBuffer(ByteBuffer.allocate(bytesPerSmallBlock)) + }) + assert(success) + blockInfoManager.unlock(blockId, None) + } + assert(blockInfoManager.size === numInitialBlocks) + + + // Add one big block, which will require evicting everything in the memorystore. However our + // mock BlockEvictionHandler will throw an exception -- make sure all locks are cleared. + val largeBlockId = BlockId(s"rdd_2_1") + val largeBlockInfo = new BlockInfo(StorageLevel.MEMORY_ONLY, ct, tellMaster = false) + val initialWriteLock = blockInfoManager.lockNewBlockForWriting(largeBlockId, largeBlockInfo) + assert(initialWriteLock) + if (numValidBlocks < numInitialBlocks) { + val exc = intercept[RuntimeException] { + memoryStore.putBytes(largeBlockId, memStoreSize, MemoryMode.ON_HEAP, () => { + new ChunkedByteBuffer(ByteBuffer.allocate(memStoreSize)) + }) + } + assert(exc.getMessage().startsWith("Mock error dropping block"), exc) + // BlockManager.doPut takes care of releasing the lock for the newly written block -- not + // testing that here, so do it manually + blockInfoManager.removeBlock(largeBlockId) + } else { + memoryStore.putBytes(largeBlockId, memStoreSize, MemoryMode.ON_HEAP, () => { + new ChunkedByteBuffer(ByteBuffer.allocate(memStoreSize)) + }) + // BlockManager.doPut takes care of releasing the lock for the newly written block -- not + // testing that here, so do it manually + blockInfoManager.unlock(largeBlockId) + } + + val largeBlockInMemory = if (numValidBlocks == numInitialBlocks) 1 else 0 + val expBlocks = numInitialBlocks + + (if (readLockAfterDrop) 0 else -numValidBlocks) + + largeBlockInMemory + assert(blockInfoManager.size === expBlocks) + + val blocksStillInMemory = blockInfoManager.entries.filter { case (id, info) => + assert(info.writerTask === BlockInfo.NO_WRITER, id) + // in this test, all the blocks in memory have no reader, but everything dropped to disk + // had another thread read the block. We shouldn't lose the other thread's reader lock. + if (memoryStore.contains(id)) { + assert(info.readerCount === 0, id) + true + } else { + assert(info.readerCount === 1, id) + false + } + } + assert(blocksStillInMemory.size === + (numInitialBlocks - numValidBlocks + largeBlockInMemory)) + } + + Seq(0, 3, numInitialBlocks).foreach { failAfterDropping => + Seq(true, false).foreach { readLockAfterDropping => + testFailureOnNthDrop(failAfterDropping, readLockAfterDropping) + } + } + } }