From ad4232b4dadc6290d3c4696d3cc007d3f01cb236 Mon Sep 17 00:00:00 2001
From: Charles Reiss <charles@eecs.berkeley.edu>
Date: Sat, 26 Jan 2013 18:07:14 -0800
Subject: [PATCH] Fix deadlock in BlockManager reregistration triggered by
 failed updates.

---
 .../scala/spark/storage/BlockManager.scala    | 35 +++++++++++++++-
 .../spark/storage/BlockManagerSuite.scala     | 40 ++++++++++++++++++-
 2 files changed, 72 insertions(+), 3 deletions(-)

diff --git a/core/src/main/scala/spark/storage/BlockManager.scala b/core/src/main/scala/spark/storage/BlockManager.scala
index 19cdaaa984..19d35b8667 100644
--- a/core/src/main/scala/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/spark/storage/BlockManager.scala
@@ -90,7 +90,10 @@ class BlockManager(
   val slaveActor = master.actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)),
     name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next)
 
-  @volatile private var shuttingDown = false
+  // Pending reregistration action being executed asynchronously or null if none
+  // is pending. Accesses should synchronize on asyncReregisterLock.
+  var asyncReregisterTask: Future[Unit] = null
+  val asyncReregisterLock = new Object
 
   private def heartBeat() {
     if (!master.sendHeartBeat(blockManagerId)) {
@@ -147,6 +150,8 @@ class BlockManager(
   /**
    * Reregister with the master and report all blocks to it. This will be called by the heart beat
    * thread if our heartbeat to the block amnager indicates that we were not registered.
+   *
+   * Note that this method must be called without any BlockInfo locks held.
    */
   def reregister() {
     // TODO: We might need to rate limit reregistering.
@@ -155,6 +160,32 @@ class BlockManager(
     reportAllBlocks()
   }
 
+  /**
+   * Reregister with the master sometime soon.
+   */
+  def asyncReregister() {
+    asyncReregisterLock.synchronized {
+      if (asyncReregisterTask == null) {
+        asyncReregisterTask = Future[Unit] {
+          reregister()
+          asyncReregisterLock.synchronized {
+            asyncReregisterTask = null
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * For testing. Wait for any pending asynchronous reregistration; otherwise, do nothing.
+   */
+  def waitForAsyncReregister() {
+    val task = asyncReregisterTask
+    if (task != null) {
+      Await.ready(task, Duration.Inf)
+    }
+  }
+
   /**
    * Get storage level of local block. If no info exists for the block, then returns null.
    */
@@ -170,7 +201,7 @@ class BlockManager(
     if (needReregister) {
       logInfo("Got told to reregister updating block " + blockId)
       // Reregistering will report our new block for free.
-      reregister()
+      asyncReregister()
     }
     logDebug("Told master about block " + blockId)
   }
diff --git a/core/src/test/scala/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
index a1aeb12f25..2165744689 100644
--- a/core/src/test/scala/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/spark/storage/BlockManagerSuite.scala
@@ -219,18 +219,56 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
     val a2 = new Array[Byte](400)
 
     store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
-
     assert(master.getLocations("a1").size > 0, "master was not told about a1")
 
     master.notifyADeadHost(store.blockManagerId.ip)
     assert(master.getLocations("a1").size == 0, "a1 was not removed from master")
 
     store.putSingle("a2", a1, StorageLevel.MEMORY_ONLY)
+    store.waitForAsyncReregister()
 
     assert(master.getLocations("a1").size > 0, "a1 was not reregistered with master")
     assert(master.getLocations("a2").size > 0, "master was not told about a2")
   }
 
+  test("reregistration doesn't dead lock") {
+    val heartBeat = PrivateMethod[Unit]('heartBeat)
+    store = new BlockManager(actorSystem, master, serializer, 2000)
+    val a1 = new Array[Byte](400)
+    val a2 = List(new Array[Byte](400))
+
+    // try many times to trigger any deadlocks
+    for (i <- 1 to 100) {
+      master.notifyADeadHost(store.blockManagerId.ip)
+      val t1 = new Thread {
+        override def run = {
+          store.put("a2", a2.iterator, StorageLevel.MEMORY_ONLY, true)
+        }
+      }
+      val t2 = new Thread {
+        override def run = {
+          store.putSingle("a1", a1, StorageLevel.MEMORY_ONLY)
+        }
+      }
+      val t3 = new Thread {
+        override def run = {
+          store invokePrivate heartBeat()
+        }
+      }
+
+      t1.start
+      t2.start
+      t3.start
+      t1.join
+      t2.join
+      t3.join
+ 
+      store.dropFromMemory("a1", null)
+      store.dropFromMemory("a2", null)
+      store.waitForAsyncReregister()
+    }
+  }
+
   test("in-memory LRU storage") {
     store = new BlockManager(actorSystem, master, serializer, 1200)
     val a1 = new Array[Byte](400)
-- 
GitLab