diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index a8c827030a1ef99a1ee6241f90886d93c1c50a1b..6a187b40628a25b0039f1b5c715702e180a161c3 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -32,8 +32,19 @@ import org.apache.spark.annotation.DeveloperApi
  */
 @DeveloperApi
 trait BroadcastFactory {
+
   def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
+
+  /**
+   * Creates a new broadcast variable.
+   *
+   * @param value value to broadcast
+   * @param isLocal whether we are in local mode (single JVM process)
+   * @param id unique id representing this broadcast variable
+   */
   def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+
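+  /**
+   * Removes the broadcast variable with the given id from executors, and from the driver as
+   * well if removeFromDriver is true, optionally blocking until removal completes. (Description
+   * inferred from the call sites in this patch, e.g. TorrentBroadcast.unpersist.)
+   */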
   def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
+
   def stop(): Unit
 }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index d8be649f96e5f990d18b2b13a4531a75bc99f4c2..6173fd3a69fc7545a700c4048a43d3458bcb02f5 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -18,50 +18,116 @@
 package org.apache.spark.broadcast
 
 import java.io._
+import java.nio.ByteBuffer
 
+import scala.collection.JavaConversions.asJavaEnumeration
 import scala.reflect.ClassTag
 import scala.util.Random
 
 import org.apache.spark.{Logging, SparkConf, SparkEnv, SparkException}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
+import org.apache.spark.util.ByteBufferInputStream
 
 /**
- *  A [[org.apache.spark.broadcast.Broadcast]] implementation that uses a BitTorrent-like
- *  protocol to do a distributed transfer of the broadcasted data to the executors.
- *  The mechanism is as follows. The driver divides the serializes the broadcasted data,
- *  divides it into smaller chunks, and stores them in the BlockManager of the driver.
- *  These chunks are reported to the BlockManagerMaster so that all the executors can
- *  learn the location of those chunks. The first time the broadcast variable (sent as
- *  part of task) is deserialized at a executor, all the chunks are fetched using
- *  the BlockManager. When all the chunks are fetched (initially from the driver's
- *  BlockManager), they are combined and deserialized to recreate the broadcasted data.
- *  However, the chunks are also stored in the BlockManager and reported to the
- *  BlockManagerMaster. As more executors fetch the chunks, BlockManagerMaster learns
- *  multiple locations for each chunk. Hence, subsequent fetches of each chunk will be
- *  made to other executors who already have those chunks, resulting in a distributed
- *  fetching. This prevents the driver from being the bottleneck in sending out multiple
- *  copies of the broadcast data (one per executor) as done by the
- *  [[org.apache.spark.broadcast.HttpBroadcast]].
+ * A BitTorrent-like implementation of [[org.apache.spark.broadcast.Broadcast]].
+ *
+ * The mechanism is as follows:
+ *
+ * The driver divides the serialized object into small chunks and
+ * stores those chunks in the BlockManager of the driver.
+ *
+ * On each executor, the executor first attempts to fetch the object from its local BlockManager.
+ * If the object is not present, it fetches the small chunks remotely from the driver and/or
+ * other executors, whichever have them. Once it has all the chunks, it stores them in its own
+ * BlockManager, ready for other executors to fetch from.
+ *
+ * This prevents the driver from being the bottleneck in sending out multiple copies of the
+ * broadcast data (one per executor) as done by the [[org.apache.spark.broadcast.HttpBroadcast]].
+ *
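+ * Illustrative usage (instances are created for the user by SparkContext.broadcast; this is a
+ * sketch assuming the torrent factory is the configured broadcast implementation):
+ * {{{
+ *   val bc = sc.broadcast(Seq(1, 2, 3))   // driver chunks the value into its BlockManager
+ *   rdd.map(x => x + bc.value.sum)        // deserializing bc on executors fetches the chunks
+ * }}}
+ *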
+ * @param obj object to broadcast
+ * @param isLocal whether Spark is running in local mode (single JVM process)
+ * @param id unique identifier for the broadcast variable
  */
 private[spark] class TorrentBroadcast[T: ClassTag](
-    @transient var value_ : T, isLocal: Boolean, id: Long)
+    obj: T,
+    @transient private val isLocal: Boolean,
+    id: Long)
   extends Broadcast[T](id) with Logging with Serializable {
 
-  override protected def getValue() = value_
+  /**
+   * Value of the broadcast object. On the driver, this is set directly by the constructor.
+   * On executors, this is reconstructed by [[readObject]], which builds this value by reading
+   * blocks from the driver and/or other executors.
+   */
+  @transient private var _value: T = obj
 
   private val broadcastId = BroadcastBlockId(id)
 
-  SparkEnv.get.blockManager.putSingle(
-    broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+  /** Total number of blocks this broadcast variable contains (computed eagerly on the driver). */
+  private val numBlocks: Int = writeBlocks()
+
+  override protected def getValue() = _value
+
+  /**
+   * Divide the object into multiple blocks and put those blocks in the block manager.
+   *
+   * @return number of blocks this broadcast variable is divided into
+   */
+  private def writeBlocks(): Int = {
+    // Store the full object in the local BlockManager so it can be read back directly later;
+    // in local mode this is the only copy needed, since no chunk transfer happens.
+    SparkEnv.get.blockManager.putSingle(
+      broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
+
+    if (!isLocal) {
+      val blocks = TorrentBroadcast.blockifyObject(_value)
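+      // Store each chunk under a deterministic "piece<i>" block id with tellMaster = true, so
+      // the BlockManagerMaster learns every location of every chunk as executors cache them.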
+      blocks.zipWithIndex.foreach { case (block, i) =>
+        SparkEnv.get.blockManager.putBytes(
+          BroadcastBlockId(id, "piece" + i),
+          block,
+          StorageLevel.MEMORY_AND_DISK_SER,
+          tellMaster = true)
+      }
+      blocks.length
+    } else {
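+      // In local mode there is nothing to transfer, so no pieces are written.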
+      0
+    }
+  }
+
+  /** Fetch torrent blocks from the driver and/or other executors. */
+  private def readBlocks(): Array[ByteBuffer] = {
+    // Fetch chunks of data. Note that all these chunks are stored in the BlockManager and reported
+    // to the driver, so other executors can pull these chunks from this executor as well.
+    val blocks = new Array[ByteBuffer](numBlocks)
+    val bm = SparkEnv.get.blockManager
 
-  @transient private var arrayOfBlocks: Array[TorrentBlock] = null
-  @transient private var totalBlocks = -1
-  @transient private var totalBytes = -1
-  @transient private var hasBlocks = 0
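+    // Fetch the pieces in random order so that concurrent readers spread their requests across
+    // all executors holding a given piece, rather than everyone requesting piece 0 first.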
+    for (pid <- Random.shuffle(Seq.range(0, numBlocks))) {
+      val pieceId = BroadcastBlockId(id, "piece" + pid)
 
-  if (!isLocal) {
-    sendBroadcast()
+      // First try getLocalBytes because a previous attempt to fetch the broadcast blocks may
+      // have already fetched some of them, in which case those blocks are available locally
+      // on this executor.
+      var blockOpt = bm.getLocalBytes(pieceId)
+      if (blockOpt.isEmpty) {
+        blockOpt = bm.getRemoteBytes(pieceId)
+        blockOpt match {
+          case Some(block) =>
+            // If we found the block in a remote executor's or the driver's BlockManager, cache
+            // the block in this executor's BlockManager.
+            SparkEnv.get.blockManager.putBytes(
+              pieceId,
+              block,
+              StorageLevel.MEMORY_AND_DISK_SER,
+              tellMaster = true)
+
+          case None =>
+            throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+        }
+      }
+      // If we get here, the option is defined.
+      blocks(pid) = blockOpt.get
+    }
+    blocks
   }
 
   /**
@@ -79,26 +145,6 @@ private[spark] class TorrentBroadcast[T: ClassTag](
     TorrentBroadcast.unpersist(id, removeFromDriver = true, blocking)
   }
 
-  private def sendBroadcast() {
-    val tInfo = TorrentBroadcast.blockifyObject(value_)
-    totalBlocks = tInfo.totalBlocks
-    totalBytes = tInfo.totalBytes
-    hasBlocks = tInfo.totalBlocks
-
-    // Store meta-info
-    val metaId = BroadcastBlockId(id, "meta")
-    val metaInfo = TorrentInfo(null, totalBlocks, totalBytes)
-    SparkEnv.get.blockManager.putSingle(
-      metaId, metaInfo, StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
-    // Store individual pieces
-    for (i <- 0 until totalBlocks) {
-      val pieceId = BroadcastBlockId(id, "piece" + i)
-      SparkEnv.get.blockManager.putSingle(
-        pieceId, tInfo.arrayOfBlocks(i), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-    }
-  }
-
   /** Used by the JVM when serializing this object. */
   private def writeObject(out: ObjectOutputStream) {
     assertValid()
@@ -109,99 +155,30 @@ private[spark] class TorrentBroadcast[T: ClassTag](
   private def readObject(in: ObjectInputStream) {
     in.defaultReadObject()
     TorrentBroadcast.synchronized {
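+      // Look for the fully assembled value in the local BlockManager first; fall back to
+      // fetching and reassembling the chunks only if this executor has not cached it yet.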
-      SparkEnv.get.blockManager.getSingle(broadcastId) match {
+      SparkEnv.get.blockManager.getLocal(broadcastId).map(_.data.next()) match {
         case Some(x) =>
-          value_ = x.asInstanceOf[T]
+          _value = x.asInstanceOf[T]
 
         case None =>
-          val start = System.nanoTime
           logInfo("Started reading broadcast variable " + id)
-
-          // Initialize @transient variables that will receive garbage values from the master.
-          resetWorkerVariables()
-
-          if (receiveBroadcast()) {
-            value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks)
-
-            /* Store the merged copy in cache so that the next worker doesn't need to rebuild it.
-             * This creates a trade-off between memory usage and latency. Storing copy doubles
-             * the memory footprint; not storing doubles deserialization cost. Also,
-             * this does not need to be reported to BlockManagerMaster since other executors
-             * does not need to access this block (they only need to fetch the chunks,
-             * which are reported).
-             */
-            SparkEnv.get.blockManager.putSingle(
-              broadcastId, value_, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
-
-            // Remove arrayOfBlocks from memory once value_ is on local cache
-            resetWorkerVariables()
-          } else {
-            logError("Reading broadcast variable " + id + " failed")
-          }
-
-          val time = (System.nanoTime - start) / 1e9
+          val start = System.nanoTime()
+          val blocks = readBlocks()
+          val time = (System.nanoTime() - start) / 1e9
           logInfo("Reading broadcast variable " + id + " took " + time + " s")
-      }
-    }
-  }
-
-  private def resetWorkerVariables() {
-    arrayOfBlocks = null
-    totalBytes = -1
-    totalBlocks = -1
-    hasBlocks = 0
-  }
-
-  private def receiveBroadcast(): Boolean = {
-    // Receive meta-info about the size of broadcast data,
-    // the number of chunks it is divided into, etc.
-    val metaId = BroadcastBlockId(id, "meta")
-    var attemptId = 10
-    while (attemptId > 0 && totalBlocks == -1) {
-      SparkEnv.get.blockManager.getSingle(metaId) match {
-        case Some(x) =>
-          val tInfo = x.asInstanceOf[TorrentInfo]
-          totalBlocks = tInfo.totalBlocks
-          totalBytes = tInfo.totalBytes
-          arrayOfBlocks = new Array[TorrentBlock](totalBlocks)
-          hasBlocks = 0
 
-        case None =>
-          Thread.sleep(500)
-      }
-      attemptId -= 1
-    }
-
-    if (totalBlocks == -1) {
-      return false
-    }
-
-    /*
-     * Fetch actual chunks of data. Note that all these chunks are stored in
-     * the BlockManager and reported to the master, so that other executors
-     * can find out and pull the chunks from this executor.
-     */
-    val recvOrder = new Random().shuffle(Array.iterate(0, totalBlocks)(_ + 1).toList)
-    for (pid <- recvOrder) {
-      val pieceId = BroadcastBlockId(id, "piece" + pid)
-      SparkEnv.get.blockManager.getSingle(pieceId) match {
-        case Some(x) =>
-          arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock]
-          hasBlocks += 1
+          _value = TorrentBroadcast.unBlockifyObject[T](blocks)
+          // Store the merged copy in BlockManager so other tasks on this executor don't
+          // need to re-fetch it.
           SparkEnv.get.blockManager.putSingle(
-            pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, tellMaster = true)
-
-        case None =>
-          throw new SparkException("Failed to get " + pieceId + " of " + broadcastId)
+            broadcastId, _value, StorageLevel.MEMORY_AND_DISK, tellMaster = false)
       }
     }
-
-    hasBlocks == totalBlocks
   }
-
 }
 
-private[broadcast] object TorrentBroadcast extends Logging {
+
+private object TorrentBroadcast extends Logging {
+  /** Size of each block in bytes; spark.broadcast.blockSize is in KB (default 4096 KB = 4MB). */
   private lazy val BLOCK_SIZE = conf.getInt("spark.broadcast.blockSize", 4096) * 1024
   private var initialized = false
   private var conf: SparkConf = null
@@ -223,7 +200,9 @@ private[broadcast] object TorrentBroadcast extends Logging {
     initialized = false
   }
 
-  def blockifyObject[T: ClassTag](obj: T): TorrentInfo = {
+  def blockifyObject[T: ClassTag](obj: T): Array[ByteBuffer] = {
+    // TODO: Create a special ByteArrayOutputStream that splits the output directly into chunks
+    // so we don't need to do the extra memory copy.
     val bos = new ByteArrayOutputStream()
     val out: OutputStream = if (compress) compressionCodec.compressedOutputStream(bos) else bos
     val ser = SparkEnv.get.serializer.newInstance()
@@ -231,44 +210,27 @@ private[broadcast] object TorrentBroadcast extends Logging {
     serOut.writeObject[T](obj).close()
     val byteArray = bos.toByteArray
     val bais = new ByteArrayInputStream(byteArray)
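+    // Round up so that a trailing partial chunk still gets its own block,
+    // e.g. a 10MB object with 4MB blocks -> ceil(10 / 4) = 3 blocks.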
+    val numBlocks = math.ceil(byteArray.length.toDouble / BLOCK_SIZE).toInt
+    val blocks = new Array[ByteBuffer](numBlocks)
 
-    var blockNum = byteArray.length / BLOCK_SIZE
-    if (byteArray.length % BLOCK_SIZE != 0) {
-      blockNum += 1
-    }
-
-    val blocks = new Array[TorrentBlock](blockNum)
     var blockId = 0
-
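+    // (0 until (end, step)) iterates the block start offsets: 0, BLOCK_SIZE, 2 * BLOCK_SIZE, ...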
     for (i <- 0 until (byteArray.length, BLOCK_SIZE)) {
       val thisBlockSize = math.min(BLOCK_SIZE, byteArray.length - i)
       val tempByteArray = new Array[Byte](thisBlockSize)
       bais.read(tempByteArray, 0, thisBlockSize)
 
-      blocks(blockId) = new TorrentBlock(blockId, tempByteArray)
+      blocks(blockId) = ByteBuffer.wrap(tempByteArray)
       blockId += 1
     }
     bais.close()
-
-    val info = TorrentInfo(blocks, blockNum, byteArray.length)
-    info.hasBlocks = blockNum
-    info
+    blocks
   }
 
-  def unBlockifyObject[T: ClassTag](
-      arrayOfBlocks: Array[TorrentBlock],
-      totalBytes: Int,
-      totalBlocks: Int): T = {
-    val retByteArray = new Array[Byte](totalBytes)
-    for (i <- 0 until totalBlocks) {
-      System.arraycopy(arrayOfBlocks(i).byteArray, 0, retByteArray,
-        i * BLOCK_SIZE, arrayOfBlocks(i).byteArray.length)
-    }
+  def unBlockifyObject[T: ClassTag](blocks: Array[ByteBuffer]): T = {
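+    // Chain a ByteBufferInputStream per chunk through a SequenceInputStream so the chunks are
+    // read back-to-back as one stream, without copying them into a single array first.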
+    val is = new SequenceInputStream(
+      asJavaEnumeration(blocks.iterator.map(block => new ByteBufferInputStream(block))))
+    val in: InputStream = if (compress) compressionCodec.compressedInputStream(is) else is
 
-    val in: InputStream = {
-      val arrIn = new ByteArrayInputStream(retByteArray)
-      if (compress) compressionCodec.compressedInputStream(arrIn) else arrIn
-    }
     val ser = SparkEnv.get.serializer.newInstance()
     val serIn = ser.deserializeStream(in)
     val obj = serIn.readObject[T]()
@@ -284,17 +246,3 @@ private[broadcast] object TorrentBroadcast extends Logging {
     SparkEnv.get.blockManager.master.removeBroadcast(id, removeFromDriver, blocking)
   }
 }
-
-private[broadcast] case class TorrentBlock(
-    blockID: Int,
-    byteArray: Array[Byte])
-  extends Serializable
-
-private[broadcast] case class TorrentInfo(
-    @transient arrayOfBlocks: Array[TorrentBlock],
-    totalBlocks: Int,
-    totalBytes: Int)
-  extends Serializable {
-
-  @transient var hasBlocks = 0
-}
diff --git a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
index 17c64455b2429f570e40517ca6809e54257ae83f..978a6ded808297ce1a2325e39a1b782a7c698b97 100644
--- a/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
+++ b/core/src/test/scala/org/apache/spark/broadcast/BroadcastSuite.scala
@@ -17,10 +17,12 @@
 
 package org.apache.spark.broadcast
 
-import org.apache.spark.storage.{BroadcastBlockId, _}
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
 import org.scalatest.FunSuite
 
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkException}
+import org.apache.spark.storage._
+
 class BroadcastSuite extends FunSuite with LocalSparkContext {
 
   private val httpConf = broadcastConf("HttpBroadcastFactory")
@@ -124,12 +126,10 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
   private def testUnpersistHttpBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
     val numSlaves = if (distributed) 2 else 0
 
-    def getBlockIds(id: Long) = Seq[BroadcastBlockId](BroadcastBlockId(id))
-
     // Verify that the broadcast file is created, and blocks are persisted only on the driver
-    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
-      assert(blockIds.size === 1)
-      val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+    def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+      val blockId = BroadcastBlockId(broadcastId)
+      val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
       assert(statuses.size === 1)
       statuses.head match { case (bm, status) =>
         assert(bm.executorId === "<driver>", "Block should only be on the driver")
@@ -139,14 +139,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
       }
       if (distributed) {
         // this file is only generated in distributed mode
-        assert(HttpBroadcast.getFile(blockIds.head.broadcastId).exists, "Broadcast file not found!")
+        assert(HttpBroadcast.getFile(blockId.broadcastId).exists, "Broadcast file not found!")
       }
     }
 
     // Verify that blocks are persisted in both the executors and the driver
-    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
-      assert(blockIds.size === 1)
-      val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+    def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+      val blockId = BroadcastBlockId(broadcastId)
+      val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
       assert(statuses.size === numSlaves + 1)
       statuses.foreach { case (_, status) =>
         assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
@@ -157,21 +157,21 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
     // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
     // is true. In the latter case, also verify that the broadcast file is deleted on the driver.
-    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
-      assert(blockIds.size === 1)
-      val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+    def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+      val blockId = BroadcastBlockId(broadcastId)
+      val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
       val expectedNumBlocks = if (removeFromDriver) 0 else 1
       val possiblyNot = if (removeFromDriver) "" else " not"
       assert(statuses.size === expectedNumBlocks,
         "Block should%s be unpersisted on the driver".format(possiblyNot))
       if (distributed && removeFromDriver) {
         // this file is only generated in distributed mode
-        assert(!HttpBroadcast.getFile(blockIds.head.broadcastId).exists,
+        assert(!HttpBroadcast.getFile(blockId.broadcastId).exists,
           "Broadcast file should%s be deleted".format(possiblyNot))
       }
     }
 
-    testUnpersistBroadcast(distributed, numSlaves, httpConf, getBlockIds, afterCreation,
+    testUnpersistBroadcast(distributed, numSlaves, httpConf, afterCreation,
       afterUsingBroadcast, afterUnpersist, removeFromDriver)
   }
 
@@ -185,67 +185,51 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
   private def testUnpersistTorrentBroadcast(distributed: Boolean, removeFromDriver: Boolean) {
     val numSlaves = if (distributed) 2 else 0
 
-    def getBlockIds(id: Long) = {
-      val broadcastBlockId = BroadcastBlockId(id)
-      val metaBlockId = BroadcastBlockId(id, "meta")
-      // Assume broadcast value is small enough to fit into 1 piece
-      val pieceBlockId = BroadcastBlockId(id, "piece0")
-      if (distributed) {
-        // the metadata and piece blocks are generated only in distributed mode
-        Seq[BroadcastBlockId](broadcastBlockId, metaBlockId, pieceBlockId)
-      } else {
-        Seq[BroadcastBlockId](broadcastBlockId)
-      }
+    // Verify that blocks are persisted only on the driver
+    def afterCreation(broadcastId: Long, bmm: BlockManagerMaster) {
+      var blockId = BroadcastBlockId(broadcastId)
+      var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+      assert(statuses.size === 1)
+
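+      // The test's broadcast value is assumed small enough to fit into a single piece (piece0).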
+      blockId = BroadcastBlockId(broadcastId, "piece0")
+      statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+      assert(statuses.size === (if (distributed) 1 else 0))
     }
 
-    // Verify that blocks are persisted only on the driver
-    def afterCreation(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
-      blockIds.foreach { blockId =>
-        val statuses = bmm.getBlockStatus(blockIds.head, askSlaves = true)
+    // Verify that blocks are persisted in both the executors and the driver
+    def afterUsingBroadcast(broadcastId: Long, bmm: BlockManagerMaster) {
+      var blockId = BroadcastBlockId(broadcastId)
+      var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+      if (distributed) {
+        assert(statuses.size === numSlaves + 1)
+      } else {
         assert(statuses.size === 1)
-        statuses.head match { case (bm, status) =>
-          assert(bm.executorId === "<driver>", "Block should only be on the driver")
-          assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
-          assert(status.memSize > 0, "Block should be in memory store on the driver")
-          assert(status.diskSize === 0, "Block should not be in disk store on the driver")
-        }
       }
-    }
 
-    // Verify that blocks are persisted in both the executors and the driver
-    def afterUsingBroadcast(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
-      blockIds.foreach { blockId =>
-        val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
-        if (blockId.field == "meta") {
-          // Meta data is only on the driver
-          assert(statuses.size === 1)
-          statuses.head match { case (bm, _) => assert(bm.executorId === "<driver>") }
-        } else {
-          // Other blocks are on both the executors and the driver
-          assert(statuses.size === numSlaves + 1,
-            blockId + " has " + statuses.size + " statuses: " + statuses.mkString(","))
-          statuses.foreach { case (_, status) =>
-            assert(status.storageLevel === StorageLevel.MEMORY_AND_DISK)
-            assert(status.memSize > 0, "Block should be in memory store")
-            assert(status.diskSize === 0, "Block should not be in disk store")
-          }
-        }
+      blockId = BroadcastBlockId(broadcastId, "piece0")
+      statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+      if (distributed) {
+        assert(statuses.size === numSlaves + 1)
+      } else {
+        assert(statuses.size === 0)
       }
     }
 
     // Verify that blocks are unpersisted on all executors, and on all nodes if removeFromDriver
     // is true.
-    def afterUnpersist(blockIds: Seq[BroadcastBlockId], bmm: BlockManagerMaster) {
-      val expectedNumBlocks = if (removeFromDriver) 0 else 1
-      val possiblyNot = if (removeFromDriver) "" else " not"
-      blockIds.foreach { blockId =>
-        val statuses = bmm.getBlockStatus(blockId, askSlaves = true)
-        assert(statuses.size === expectedNumBlocks,
-          "Block should%s be unpersisted on the driver".format(possiblyNot))
-      }
+    def afterUnpersist(broadcastId: Long, bmm: BlockManagerMaster) {
+      var blockId = BroadcastBlockId(broadcastId)
+      var expectedNumBlocks = if (removeFromDriver) 0 else 1
+      var statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+      assert(statuses.size === expectedNumBlocks)
+
+      blockId = BroadcastBlockId(broadcastId, "piece0")
+      expectedNumBlocks = if (removeFromDriver || !distributed) 0 else 1
+      statuses = bmm.getBlockStatus(blockId, askSlaves = true)
+      assert(statuses.size === expectedNumBlocks)
     }
 
-    testUnpersistBroadcast(distributed, numSlaves,  torrentConf, getBlockIds, afterCreation,
+    testUnpersistBroadcast(distributed, numSlaves, torrentConf, afterCreation,
       afterUsingBroadcast, afterUnpersist, removeFromDriver)
   }
 
@@ -262,10 +246,9 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
       distributed: Boolean,
       numSlaves: Int,  // used only when distributed = true
       broadcastConf: SparkConf,
-      getBlockIds: Long => Seq[BroadcastBlockId],
-      afterCreation: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
-      afterUsingBroadcast: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
-      afterUnpersist: (Seq[BroadcastBlockId], BlockManagerMaster) => Unit,
+      afterCreation: (Long, BlockManagerMaster) => Unit,
+      afterUsingBroadcast: (Long, BlockManagerMaster) => Unit,
+      afterUnpersist: (Long, BlockManagerMaster) => Unit,
       removeFromDriver: Boolean) {
 
     sc = if (distributed) {
@@ -278,15 +261,14 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
 
     // Create broadcast variable
     val broadcast = sc.broadcast(list)
-    val blocks = getBlockIds(broadcast.id)
-    afterCreation(blocks, blockManagerMaster)
+    afterCreation(broadcast.id, blockManagerMaster)
 
     // Use broadcast variable on all executors
     val partitions = 10
     assert(partitions > numSlaves)
     val results = sc.parallelize(1 to partitions, partitions).map(x => (x, broadcast.value.sum))
     assert(results.collect().toSet === (1 to partitions).map(x => (x, list.sum)).toSet)
-    afterUsingBroadcast(blocks, blockManagerMaster)
+    afterUsingBroadcast(broadcast.id, blockManagerMaster)
 
     // Unpersist broadcast
     if (removeFromDriver) {
@@ -294,7 +276,7 @@ class BroadcastSuite extends FunSuite with LocalSparkContext {
     } else {
       broadcast.unpersist(blocking = true)
     }
-    afterUnpersist(blocks, blockManagerMaster)
+    afterUnpersist(broadcast.id, blockManagerMaster)
 
     // If the broadcast is removed from driver, all subsequent uses of the broadcast variable
     // should throw SparkExceptions. Otherwise, the result should be the same as before.