diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 09a60571238ea8d1dea24faeb98e0f6d1a011170..3935c8772252e3ccbe96b27d9498315c583caacd 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,7 +27,9 @@ import org.apache.spark.shuffle.ShuffleHandle
  * Base class for dependencies.
  */
 @DeveloperApi
-abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
+abstract class Dependency[T] extends Serializable {
+  def rdd: RDD[T]
+}
 
 
 /**
@@ -36,20 +38,24 @@ abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
  * partition of the child RDD.  Narrow dependencies allow for pipelined execution.
  */
 @DeveloperApi
-abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
+abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
   /**
    * Get the parent partitions for a child partition.
    * @param partitionId a partition of the child RDD
    * @return the partitions of the parent RDD that the child partition depends upon
    */
   def getParents(partitionId: Int): Seq[Int]
+
+  override def rdd: RDD[T] = _rdd
 }
 
 
 /**
  * :: DeveloperApi ::
- * Represents a dependency on the output of a shuffle stage.
- * @param rdd the parent RDD
+ * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
+ * the RDD is transient since we don't need it on the executor side.
+ *
+ * @param _rdd the parent RDD
  * @param partitioner partitioner used to partition the shuffle output
  * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
  *                   the default serializer, as specified by `spark.serializer` config option, will
@@ -57,20 +63,22 @@ abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
  */
 @DeveloperApi
 class ShuffleDependency[K, V, C](
-    @transient rdd: RDD[_ <: Product2[K, V]],
+    @transient _rdd: RDD[_ <: Product2[K, V]],
     val partitioner: Partitioner,
     val serializer: Option[Serializer] = None,
     val keyOrdering: Option[Ordering[K]] = None,
     val aggregator: Option[Aggregator[K, V, C]] = None,
     val mapSideCombine: Boolean = false)
-  extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
+  extends Dependency[Product2[K, V]] {
+
+  override def rdd: RDD[Product2[K, V]] = _rdd.asInstanceOf[RDD[Product2[K, V]]]
 
-  val shuffleId: Int = rdd.context.newShuffleId()
+  val shuffleId: Int = _rdd.context.newShuffleId()
 
-  val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
-    shuffleId, rdd.partitions.size, this)
+  val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
+    shuffleId, _rdd.partitions.size, this)
 
-  rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
+  _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
 }
 
 
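Making `rdd` an abstract `def` lets each subclass choose how the parent reference is held: `NarrowDependency` keeps it as an ordinary constructor parameter, while `ShuffleDependency` above marks it `@transient` so the parent RDD never ships with the dependency. A minimal sketch of a custom narrow dependency against the new API (the class name and blocking logic are illustrative, not part of this patch):

```scala
import org.apache.spark.NarrowDependency
import org.apache.spark.rdd.RDD

// Illustrative only: child partition i depends on parent partitions
// [i * n, (i + 1) * n). Under the new API only getParents must be
// implemented; the parent RDD reference is stored by NarrowDependency's
// constructor and exposed through its rdd method.
class BlockedDependency[T](parent: RDD[T], n: Int) extends NarrowDependency[T](parent) {
  override def getParents(partitionId: Int): Seq[Int] =
    (partitionId * n) until ((partitionId + 1) * n)
}
```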
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8052499ab7526a0007fdabb88c0c5c87c29e7004..48a09657fde267b66de18ba5c8bc605eb9a4597d 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,8 +997,6 @@ class SparkContext(config: SparkConf) extends Logging {
       // TODO: Cache.stop()?
       env.stop()
       SparkEnv.set(null)
-      ShuffleMapTask.clearCache()
-      ResultTask.clearCache()
       listenerBus.stop()
       eventLogger.foreach(_.stop())
       logInfo("Successfully stopped SparkContext")
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 88a918aebf76335caf744619a9c8870b6fec175c..2ee9a8f1a8e0d2b0f6c3895854cd4afb429d816b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,12 +35,13 @@ import org.apache.spark.Partitioner._
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
 
@@ -1195,21 +1196,36 @@ abstract class RDD[T: ClassTag](
   /**
    * Return whether this RDD has been checkpointed or not
    */
-  def isCheckpointed: Boolean = {
-    checkpointData.map(_.isCheckpointed).getOrElse(false)
-  }
+  def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
 
   /**
    * Gets the name of the file to which this RDD was checkpointed
    */
-  def getCheckpointFile: Option[String] = {
-    checkpointData.flatMap(_.getCheckpointFile)
-  }
+  def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
 
   // =======================================================================
   // Other internal methods and fields
   // =======================================================================
 
+  /**
+   * Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that we broadcast
+   * the serialized copy of the RDD, and each task deserializes it, which means each task gets
+   * a different copy of the RDD. This provides stronger isolation between tasks that might
+   * modify the state of objects referenced in their closures. This is necessary in Hadoop,
+   * where the JobConf/Configuration object is not thread-safe.
+   */
+  @transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = {
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val bytes = ser.serialize(this).array()
+    val size = Utils.bytesToString(bytes.length)
+    if (bytes.length > (1L << 20)) {
+      logWarning(s"Broadcasting RDD $id ($size), which contains large objects")
+    } else {
+      logDebug(s"Broadcasting RDD $id ($size)")
+    }
+    sc.broadcast(bytes)
+  }
+
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /** User code that created this RDD (e.g. `textFile`, `parallelize`). */
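The `broadcasted` field is the core of this patch: the RDD graph for a stage is serialized once on the driver and shipped to executors as a broadcast variable, and each task deserializes its own copy. A hedged sketch of that round trip, using only calls that appear elsewhere in this patch (`SparkEnv.get.closureSerializer`, `Broadcast.value`); the helper object and method names are invented for illustration:

```scala
package org.apache.spark.rdd

import java.nio.ByteBuffer
import scala.reflect.ClassTag

import org.apache.spark.SparkEnv

// Illustrative helper (not part of this patch): the driver serializes the
// RDD once via the lazy val above and broadcasts the bytes; each task then
// rebuilds its own copy, so mutations to closure state (e.g. a Hadoop
// JobConf) stay task-local. Declared in org.apache.spark.rdd only so the
// private[spark] member `broadcasted` is accessible.
object BroadcastRddExample {
  def rebuildFromBroadcast[T: ClassTag](rdd: RDD[T]): RDD[T] = {
    val ser = SparkEnv.get.closureSerializer.newInstance()
    ser.deserialize[RDD[T]](ByteBuffer.wrap(rdd.broadcasted.value),
      Thread.currentThread.getContextClassLoader)
  }
}
```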
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index c3b2a33fb54d095f32d988070115ae735ebf5c66..f67e5f1857979fa8244838c4dcf26e9a86343451 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,7 +106,6 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       cpRDD = Some(newRDD)
       rdd.markCheckpointed(newRDD)   // Update the RDD's dependencies and partitions
       cpState = Checkpointed
-      RDDCheckpointData.clearTaskCaches()
     }
     logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
   }
@@ -131,9 +130,5 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
   }
 }
 
-private[spark] object RDDCheckpointData {
-  def clearTaskCaches() {
-    ShuffleMapTask.clearCache()
-    ResultTask.clearCache()
-  }
-}
+// Used for synchronization
+private[spark] object RDDCheckpointData
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index ede3c7d9f01ae3b924bf5945ec4f9e38d312e031..88cb5feaaff2a85cf36f335ab337f4ae2470da4b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -376,9 +376,6 @@ class DAGScheduler(
               stageIdToStage -= stageId
               stageIdToJobIds -= stageId
 
-              ShuffleMapTask.removeStage(stageId)
-              ResultTask.removeStage(stageId)
-
               logDebug("After removal of stage %d, remaining stages = %d"
                 .format(stageId, stageIdToStage.size))
             }
@@ -723,7 +720,6 @@ class DAGScheduler(
     }
   }
 
-
   /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index bbf9f7388b074582a56d1875c764fe375eeb7270..62beb0d02a9c3602e56bf96c0965fd5744fca635 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,134 +17,68 @@
 
 package org.apache.spark.scheduler
 
-import scala.language.existentials
+import java.nio.ByteBuffer
 
 import java.io._
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-
-import scala.collection.mutable.HashMap
 
 import org.apache.spark._
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
-
-private[spark] object ResultTask {
-
-  // A simple map between the stage id to the serialized byte array of a task.
-  // Served as a cache for task serialization because serialization can be
-  // expensive on the master node if it needs to launch thousands of tasks.
-  private val serializedInfoCache = new HashMap[Int, Array[Byte]]
-
-  def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
-  {
-    synchronized {
-      val old = serializedInfoCache.get(stageId).orNull
-      if (old != null) {
-        old
-      } else {
-        val out = new ByteArrayOutputStream
-        val ser = SparkEnv.get.closureSerializer.newInstance()
-        val objOut = ser.serializeStream(new GZIPOutputStream(out))
-        objOut.writeObject(rdd)
-        objOut.writeObject(func)
-        objOut.close()
-        val bytes = out.toByteArray
-        serializedInfoCache.put(stageId, bytes)
-        bytes
-      }
-    }
-  }
-
-  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
-  {
-    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val objIn = ser.deserializeStream(in)
-    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
-    val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
-    (rdd, func)
-  }
-
-  def removeStage(stageId: Int) {
-    serializedInfoCache.remove(stageId)
-  }
-
-  def clearCache() {
-    synchronized {
-      serializedInfoCache.clear()
-    }
-  }
-}
-
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
 
 /**
  * A task that sends back the output to the driver application.
  *
- * See [[org.apache.spark.scheduler.Task]] for more information.
+ * See [[Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param rdd input to func
+ * @param rddBinary broadcast version of the serialized RDD
  * @param func a function to apply on a partition of the RDD
- * @param _partitionId index of the number in the RDD
+ * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
  * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
  *                 input RDD's partitions).
  */
 private[spark] class ResultTask[T, U](
     stageId: Int,
-    var rdd: RDD[T],
-    var func: (TaskContext, Iterator[T]) => U,
-    _partitionId: Int,
+    val rddBinary: Broadcast[Array[Byte]],
+    val func: (TaskContext, Iterator[T]) => U,
+    val partition: Partition,
     @transient locs: Seq[TaskLocation],
-    var outputId: Int)
-  extends Task[U](stageId, _partitionId) with Externalizable {
-
-  def this() = this(0, null, null, 0, null, 0)
-
-  var split = if (rdd == null) null else rdd.partitions(partitionId)
+    val outputId: Int)
+  extends Task[U](stageId, partition.index) with Serializable {
+
+  // TODO: Should we also broadcast func? For that we would need a place to
+  // keep a reference to it (perhaps in DAGScheduler's job object).
+
+  def this(
+      stageId: Int,
+      rdd: RDD[T],
+      func: (TaskContext, Iterator[T]) => U,
+      partitionId: Int,
+      locs: Seq[TaskLocation],
+      outputId: Int) = {
+    this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId)
+  }
 
-  @transient private val preferredLocs: Seq[TaskLocation] = {
+  @transient private[this] val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
   override def runTask(context: TaskContext): U = {
+    // Deserialize the RDD using the broadcast variable.
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value),
+      Thread.currentThread.getContextClassLoader)
     metrics = Some(context.taskMetrics)
     try {
-      func(context, rdd.iterator(split, context))
+      func(context, rdd.iterator(partition, context))
     } finally {
       context.executeOnCompleteCallbacks()
     }
   }
 
+  // This is only callable on the driver side.
   override def preferredLocations: Seq[TaskLocation] = preferredLocs
 
   override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
-
-  override def writeExternal(out: ObjectOutput) {
-    RDDCheckpointData.synchronized {
-      split = rdd.partitions(partitionId)
-      out.writeInt(stageId)
-      val bytes = ResultTask.serializeInfo(
-        stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
-      out.writeInt(bytes.length)
-      out.write(bytes)
-      out.writeInt(partitionId)
-      out.writeInt(outputId)
-      out.writeLong(epoch)
-      out.writeObject(split)
-    }
-  }
-
-  override def readExternal(in: ObjectInput) {
-    val stageId = in.readInt()
-    val numBytes = in.readInt()
-    val bytes = new Array[Byte](numBytes)
-    in.readFully(bytes)
-    val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
-    rdd = rdd_.asInstanceOf[RDD[T]]
-    func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
-    partitionId = in.readInt()
-    outputId = in.readInt()
-    epoch = in.readLong()
-    split = in.readObject().asInstanceOf[Partition]
-  }
 }
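Scheduler call sites keep their old shape because the secondary constructor accepts an RDD and a partition id and routes them through `rdd.broadcasted` and `rdd.partitions(partitionId)`; `ShuffleMapTask` below follows the same pattern. A hedged driver-side sketch (the object and method names are invented, and stage id 0 plus the empty preferred-location list are placeholders):

```scala
package org.apache.spark.scheduler

import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD

// Illustrative only: building the tasks for a result stage through the new
// secondary constructor. All tasks share a single Broadcast[Array[Byte]]
// because RDD.broadcasted is a lazy val, so the RDD is serialized and
// broadcast once per RDD rather than once per task.
object ResultTaskExample {
  def makeResultTasks[T, U](
      rdd: RDD[T],
      func: (TaskContext, Iterator[T]) => U): Seq[ResultTask[T, U]] = {
    (0 until rdd.partitions.size).map { pid =>
      new ResultTask(0, rdd, func, pid, Seq.empty[TaskLocation], pid)
    }
  }
}
```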
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index fdaf1de83f051ce07294c58b1ee329ec25685405..033c6e52861e0623a76d50307e76db39c3d087aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,71 +17,13 @@
 
 package org.apache.spark.scheduler
 
-import scala.language.existentials
-
-import java.io._
-import java.util.zip.{GZIPInputStream, GZIPOutputStream}
-
-import scala.collection.mutable.HashMap
+import java.nio.ByteBuffer
 
 import org.apache.spark._
-import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
 import org.apache.spark.shuffle.ShuffleWriter
 
-private[spark] object ShuffleMapTask {
-
-  // A simple map between the stage id to the serialized byte array of a task.
-  // Served as a cache for task serialization because serialization can be
-  // expensive on the master node if it needs to launch thousands of tasks.
-  private val serializedInfoCache = new HashMap[Int, Array[Byte]]
-
-  def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
-    synchronized {
-      val old = serializedInfoCache.get(stageId).orNull
-      if (old != null) {
-        return old
-      } else {
-        val out = new ByteArrayOutputStream
-        val ser = SparkEnv.get.closureSerializer.newInstance()
-        val objOut = ser.serializeStream(new GZIPOutputStream(out))
-        objOut.writeObject(rdd)
-        objOut.writeObject(dep)
-        objOut.close()
-        val bytes = out.toByteArray
-        serializedInfoCache.put(stageId, bytes)
-        bytes
-      }
-    }
-  }
-
-  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
-    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val objIn = ser.deserializeStream(in)
-    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
-    val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
-    (rdd, dep)
-  }
-
-  // Since both the JarSet and FileSet have the same format this is used for both.
-  def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = {
-    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
-    val objIn = new ObjectInputStream(in)
-    val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
-    HashMap(set.toSeq: _*)
-  }
-
-  def removeStage(stageId: Int) {
-    serializedInfoCache.remove(stageId)
-  }
-
-  def clearCache() {
-    synchronized {
-      serializedInfoCache.clear()
-    }
-  }
-}
-
 /**
  * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
  * specified in the ShuffleDependency).
@@ -89,62 +31,47 @@ private[spark] object ShuffleMapTask {
  * See [[org.apache.spark.scheduler.Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param rdd the final RDD in this stage
+ * @param rddBinary broadcast version of the serialized RDD
  * @param dep the ShuffleDependency
- * @param _partitionId index of the number in the RDD
+ * @param partition partition of the RDD this task is associated with
  * @param locs preferred task execution locations for locality scheduling
  */
 private[spark] class ShuffleMapTask(
     stageId: Int,
-    var rdd: RDD[_],
+    var rddBinary: Broadcast[Array[Byte]],
     var dep: ShuffleDependency[_, _, _],
-    _partitionId: Int,
+    partition: Partition,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, _partitionId)
-  with Externalizable
-  with Logging {
-
-  protected def this() = this(0, null, null, 0, null)
+  extends Task[MapStatus](stageId, partition.index) with Logging {
+
+  // TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to
+  // keep a reference to it (perhaps in Stage).
+
+  def this(
+      stageId: Int,
+      rdd: RDD[_],
+      dep: ShuffleDependency[_, _, _],
+      partitionId: Int,
+      locs: Seq[TaskLocation]) = {
+    this(stageId, rdd.broadcasted, dep, rdd.partitions(partitionId), locs)
+  }
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
-  var split = if (rdd == null) null else rdd.partitions(partitionId)
-
-  override def writeExternal(out: ObjectOutput) {
-    RDDCheckpointData.synchronized {
-      split = rdd.partitions(partitionId)
-      out.writeInt(stageId)
-      val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
-      out.writeInt(bytes.length)
-      out.write(bytes)
-      out.writeInt(partitionId)
-      out.writeLong(epoch)
-      out.writeObject(split)
-    }
-  }
-
-  override def readExternal(in: ObjectInput) {
-    val stageId = in.readInt()
-    val numBytes = in.readInt()
-    val bytes = new Array[Byte](numBytes)
-    in.readFully(bytes)
-    val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
-    rdd = rdd_
-    dep = dep_
-    partitionId = in.readInt()
-    epoch = in.readLong()
-    split = in.readObject().asInstanceOf[Partition]
-  }
-
   override def runTask(context: TaskContext): MapStatus = {
+    // Deserialize the RDD using the broadcast variable.
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val rdd = ser.deserialize[RDD[_]](ByteBuffer.wrap(rddBinary.value),
+      Thread.currentThread.getContextClassLoader)
+
     metrics = Some(context.taskMetrics)
     var writer: ShuffleWriter[Any, Any] = null
     try {
       val manager = SparkEnv.get.shuffleManager
       writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
-      writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
+      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
       return writer.stop(success = true).get
     } catch {
       case e: Exception =>
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 13b415cccb647f4ca8bf34b5850895bcb4916c07..871f831531bee9639d3b87a84f981618b6b2eb2a 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -52,9 +52,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     }
   }
 
-
   test("cleanup RDD") {
-    val rdd = newRDD.persist()
+    val rdd = newRDD().persist()
     val collected = rdd.collect().toList
     val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
 
@@ -67,7 +66,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("cleanup shuffle") {
-    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
+    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
     val collected = rdd.collect().toList
     val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
 
@@ -80,7 +79,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("cleanup broadcast") {
-    val broadcast = newBroadcast
+    val broadcast = newBroadcast()
     val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
 
     // Explicit cleanup
@@ -89,7 +88,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup RDD") {
-    var rdd = newRDD.persist()
+    var rdd = newRDD().persist()
     rdd.count()
 
     // Test that GC does not cause RDD cleanup due to a strong reference
@@ -107,7 +106,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup shuffle") {
-    var rdd = newShuffleRDD
+    var rdd = newShuffleRDD()
     rdd.count()
 
     // Test that GC does not cause shuffle cleanup due to a strong reference
@@ -125,7 +124,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup broadcast") {
-    var broadcast = newBroadcast
+    var broadcast = newBroadcast()
 
     // Test that GC does not cause broadcast cleanup due to a strong reference
     val preGCTester =  new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
@@ -141,11 +140,23 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     postGCTester.assertCleanup()
   }
 
+  test("automatically cleanup broadcast data for task dispatching") {
+    var rdd = newRDDWithShuffleDependencies()._1
+    rdd.count()  // This triggers an action that broadcasts the RDDs.
+
+    // Test that GC causes broadcast task data cleanup after dereferencing the RDD.
+    val postGCTester = new CleanerTester(sc,
+      broadcastIds = Seq(rdd.broadcasted.id, rdd.firstParent.broadcasted.id))
+    rdd = null
+    runGC()
+    postGCTester.assertCleanup()
+  }
+
   test("automatically cleanup RDD + shuffle + broadcast") {
     val numRdds = 100
     val numBroadcasts = 4 // Broadcasts are more costly
-    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
     val broadcastIds = 0L until numBroadcasts
@@ -175,8 +186,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
 
     val numRdds = 10
     val numBroadcasts = 4 // Broadcasts are more costly
-    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
+    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
     val broadcastIds = 0L until numBroadcasts
@@ -197,17 +208,18 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
 
   //------ Helper functions ------
 
-  def newRDD = sc.makeRDD(1 to 10)
-  def newPairRDD = newRDD.map(_ -> 1)
-  def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
-  def newBroadcast = sc.broadcast(1 to 100)
-  def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+  private def newRDD() = sc.makeRDD(1 to 10)
+  private def newPairRDD() = newRDD().map(_ -> 1)
+  private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
+  private def newBroadcast() = sc.broadcast(1 to 100)
+
+  private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
     def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
       rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
         getAllDependencies(dep.rdd)
       }
     }
-    val rdd = newShuffleRDD
+    val rdd = newShuffleRDD()
 
     // Get all the shuffle dependencies
     val shuffleDeps = getAllDependencies(rdd)
@@ -216,34 +228,34 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     (rdd, shuffleDeps)
   }
 
-  def randomRdd = {
+  private def randomRdd() = {
     val rdd: RDD[_] = Random.nextInt(3) match {
-      case 0 => newRDD
-      case 1 => newShuffleRDD
-      case 2 => newPairRDD.join(newPairRDD)
+      case 0 => newRDD()
+      case 1 => newShuffleRDD()
+      case 2 => newPairRDD().join(newPairRDD())
     }
     if (Random.nextBoolean()) rdd.persist()
     rdd.count()
     rdd
   }
 
-  def randomBroadcast = {
+  private def randomBroadcast() = {
     sc.broadcast(Random.nextInt(Int.MaxValue))
   }
 
   /** Run GC and make sure it actually has run */
-  def runGC() {
+  private def runGC() {
     val weakRef = new WeakReference(new Object())
     val startTime = System.currentTimeMillis
     System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
     // Wait until a weak reference object has been GCed
-    while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
       System.gc()
       Thread.sleep(200)
     }
   }
 
-  def cleaner = sc.cleaner.get
+  private def cleaner = sc.cleaner.get
 }