diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index 3935c8772252e3ccbe96b27d9498315c583caacd..09a60571238ea8d1dea24faeb98e0f6d1a011170 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -27,9 +27,7 @@ import org.apache.spark.shuffle.ShuffleHandle
  * Base class for dependencies.
  */
 @DeveloperApi
-abstract class Dependency[T] extends Serializable {
-  def rdd: RDD[T]
-}
+abstract class Dependency[T](val rdd: RDD[T]) extends Serializable
 
 
 /**
@@ -38,24 +36,20 @@ abstract class Dependency[T] extends Serializable {
  * partition of the child RDD. Narrow dependencies allow for pipelined execution.
  */
 @DeveloperApi
-abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
+abstract class NarrowDependency[T](rdd: RDD[T]) extends Dependency(rdd) {
   /**
    * Get the parent partitions for a child partition.
    * @param partitionId a partition of the child RDD
    * @return the partitions of the parent RDD that the child partition depends upon
    */
   def getParents(partitionId: Int): Seq[Int]
-
-  override def rdd: RDD[T] = _rdd
 }
 
 
 /**
  * :: DeveloperApi ::
- * Represents a dependency on the output of a shuffle stage. Note that in the case of shuffle,
- * the RDD is transient since we don't need it on the executor side.
- *
- * @param _rdd the parent RDD
+ * Represents a dependency on the output of a shuffle stage.
+ * @param rdd the parent RDD
  * @param partitioner partitioner used to partition the shuffle output
  * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
  *                   the default serializer, as specified by `spark.serializer` config option, will
@@ -63,22 +57,20 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
  */
 @DeveloperApi
 class ShuffleDependency[K, V, C](
-    @transient _rdd: RDD[_ <: Product2[K, V]],
+    @transient rdd: RDD[_ <: Product2[K, V]],
     val partitioner: Partitioner,
     val serializer: Option[Serializer] = None,
     val keyOrdering: Option[Ordering[K]] = None,
     val aggregator: Option[Aggregator[K, V, C]] = None,
     val mapSideCombine: Boolean = false)
-  extends Dependency[Product2[K, V]] {
-
-  override def rdd = _rdd.asInstanceOf[RDD[Product2[K, V]]]
+  extends Dependency(rdd.asInstanceOf[RDD[Product2[K, V]]]) {
 
-  val shuffleId: Int = _rdd.context.newShuffleId()
+  val shuffleId: Int = rdd.context.newShuffleId()
 
-  val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle(
-    shuffleId, _rdd.partitions.size, this)
+  val shuffleHandle: ShuffleHandle = rdd.context.env.shuffleManager.registerShuffle(
+    shuffleId, rdd.partitions.size, this)
 
-  _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
+  rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
 }
 
 
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 48a09657fde267b66de18ba5c8bc605eb9a4597d..8052499ab7526a0007fdabb88c0c5c87c29e7004 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -997,6 +997,8 @@ class SparkContext(config: SparkConf) extends Logging {
       // TODO: Cache.stop()?
       env.stop()
       SparkEnv.set(null)
+      ShuffleMapTask.clearCache()
+      ResultTask.clearCache()
       listenerBus.stop()
       eventLogger.foreach(_.stop())
       logInfo("Successfully stopped SparkContext")
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 2ee9a8f1a8e0d2b0f6c3895854cd4afb429d816b..88a918aebf76335caf744619a9c8870b6fec175c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -35,13 +35,12 @@ import org.apache.spark.Partitioner._
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.{DeveloperApi, Experimental}
 import org.apache.spark.api.java.JavaRDD
-import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.partial.BoundedDouble
 import org.apache.spark.partial.CountEvaluator
 import org.apache.spark.partial.GroupedCountEvaluator
 import org.apache.spark.partial.PartialResult
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.{BoundedPriorityQueue, Utils}
+import org.apache.spark.util.{BoundedPriorityQueue, CallSite, Utils}
 import org.apache.spark.util.collection.OpenHashMap
 import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, SamplingUtils}
 
@@ -1196,36 +1195,21 @@ abstract class RDD[T: ClassTag](
   /**
    * Return whether this RDD has been checkpointed or not
    */
-  def isCheckpointed: Boolean = checkpointData.exists(_.isCheckpointed)
+  def isCheckpointed: Boolean = {
+    checkpointData.map(_.isCheckpointed).getOrElse(false)
+  }
 
   /**
    * Gets the name of the file to which this RDD was checkpointed
    */
-  def getCheckpointFile: Option[String] = checkpointData.flatMap(_.getCheckpointFile)
+  def getCheckpointFile: Option[String] = {
+    checkpointData.flatMap(_.getCheckpointFile)
+  }
 
   // =======================================================================
   // Other internal methods and fields
   // =======================================================================
 
-  /**
-   * Broadcasted copy of this RDD, used to dispatch tasks to executors. Note that we broadcast
-   * the serialized copy of the RDD and for each task we will deserialize it, which means each
-   * task gets a different copy of the RDD. This provides stronger isolation between tasks that
-   * might modify state of objects referenced in their closures. This is necessary in Hadoop
-   * where the JobConf/Configuration object is not thread-safe.
-   */
-  @transient private[spark] lazy val broadcasted: Broadcast[Array[Byte]] = {
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val bytes = ser.serialize(this).array()
-    val size = Utils.bytesToString(bytes.length)
-    if (bytes.length > (1L << 20)) {
-      logWarning(s"Broadcasting RDD $id ($size), which contains large objects")
-    } else {
-      logDebug(s"Broadcasting RDD $id ($size)")
-    }
-    sc.broadcast(bytes)
-  }
-
   private var storageLevel: StorageLevel = StorageLevel.NONE
 
   /** User code that created this RDD (e.g. `textFile`, `parallelize`). */
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index f67e5f1857979fa8244838c4dcf26e9a86343451..c3b2a33fb54d095f32d988070115ae735ebf5c66 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -106,6 +106,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       cpRDD = Some(newRDD)
       rdd.markCheckpointed(newRDD)   // Update the RDD's dependencies and partitions
       cpState = Checkpointed
+      RDDCheckpointData.clearTaskCaches()
     }
     logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
   }
@@ -130,5 +131,9 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
   }
 }
 
-// Used for synchronization
-private[spark] object RDDCheckpointData
+private[spark] object RDDCheckpointData {
+  def clearTaskCaches() {
+    ShuffleMapTask.clearCache()
+    ResultTask.clearCache()
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 88cb5feaaff2a85cf36f335ab337f4ae2470da4b..ede3c7d9f01ae3b924bf5945ec4f9e38d312e031 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -376,6 +376,9 @@ class DAGScheduler(
             stageIdToStage -= stageId
             stageIdToJobIds -= stageId
 
+            ShuffleMapTask.removeStage(stageId)
+            ResultTask.removeStage(stageId)
+
             logDebug("After removal of stage %d, remaining stages = %d"
               .format(stageId, stageIdToStage.size))
           }
@@ -720,6 +723,7 @@ class DAGScheduler(
     }
   }
 
+  /** Called when stage's parents are available and we can now do its task. */
   private def submitMissingTasks(stage: Stage, jobId: Int) {
     logDebug("submitMissingTasks(" + stage + ")")
 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
index 62beb0d02a9c3602e56bf96c0965fd5744fca635..bbf9f7388b074582a56d1875c764fe375eeb7270 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala
@@ -17,68 +17,134 @@
 
 package org.apache.spark.scheduler
 
-import java.nio.ByteBuffer
+import scala.language.existentials
 
 import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
 
 import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
+
+private[spark] object ResultTask {
+
+  // A simple map between the stage id to the serialized byte array of a task.
+  // Served as a cache for task serialization because serialization can be
+  // expensive on the master node if it needs to launch thousands of tasks.
+  private val serializedInfoCache = new HashMap[Int, Array[Byte]]
+
+  def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] =
+  {
+    synchronized {
+      val old = serializedInfoCache.get(stageId).orNull
+      if (old != null) {
+        old
+      } else {
+        val out = new ByteArrayOutputStream
+        val ser = SparkEnv.get.closureSerializer.newInstance()
+        val objOut = ser.serializeStream(new GZIPOutputStream(out))
+        objOut.writeObject(rdd)
+        objOut.writeObject(func)
+        objOut.close()
+        val bytes = out.toByteArray
+        serializedInfoCache.put(stageId, bytes)
+        bytes
+      }
+    }
+  }
+
+  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], (TaskContext, Iterator[_]) => _) =
+  {
+    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val objIn = ser.deserializeStream(in)
+    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+    val func = objIn.readObject().asInstanceOf[(TaskContext, Iterator[_]) => _]
+    (rdd, func)
+  }
+
+  def removeStage(stageId: Int) {
+    serializedInfoCache.remove(stageId)
+  }
+
+  def clearCache() {
+    synchronized {
+      serializedInfoCache.clear()
+    }
+  }
+}
+
 
 /**
  * A task that sends back the output to the driver application.
  *
- * See [[Task]] for more information.
+ * See [[org.apache.spark.scheduler.Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param rddBinary broadcast version of of the serialized RDD
+ * @param rdd input to func
  * @param func a function to apply on a partition of the RDD
- * @param partition partition of the RDD this task is associated with
+ * @param _partitionId index of the number in the RDD
  * @param locs preferred task execution locations for locality scheduling
 * @param outputId index of the task in this job (a job can launch tasks on only a subset of the
 *                 input RDD's partitions).
 */
 private[spark] class ResultTask[T, U](
     stageId: Int,
-    val rddBinary: Broadcast[Array[Byte]],
-    val func: (TaskContext, Iterator[T]) => U,
-    val partition: Partition,
+    var rdd: RDD[T],
+    var func: (TaskContext, Iterator[T]) => U,
+    _partitionId: Int,
     @transient locs: Seq[TaskLocation],
-    val outputId: Int)
-  extends Task[U](stageId, partition.index) with Serializable {
-
-  // TODO: Should we also broadcast func? For that we would need a place to
-  // keep a reference to it (perhaps in DAGScheduler's job object).
-
-  def this(
-      stageId: Int,
-      rdd: RDD[T],
-      func: (TaskContext, Iterator[T]) => U,
-      partitionId: Int,
-      locs: Seq[TaskLocation],
-      outputId: Int) = {
-    this(stageId, rdd.broadcasted, func, rdd.partitions(partitionId), locs, outputId)
-  }
+    var outputId: Int)
+  extends Task[U](stageId, _partitionId) with Externalizable {
+
+  def this() = this(0, null, null, 0, null, 0)
+
+  var split = if (rdd == null) null else rdd.partitions(partitionId)
 
-  @transient private[this] val preferredLocs: Seq[TaskLocation] = {
+  @transient private val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
   override def runTask(context: TaskContext): U = {
-    // Deserialize the RDD using the broadcast variable.
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val rdd = ser.deserialize[RDD[T]](ByteBuffer.wrap(rddBinary.value),
-      Thread.currentThread.getContextClassLoader)
     metrics = Some(context.taskMetrics)
     try {
-      func(context, rdd.iterator(partition, context))
+      func(context, rdd.iterator(split, context))
     } finally {
       context.executeOnCompleteCallbacks()
     }
   }
 
-  // This is only callable on the driver side.
   override def preferredLocations: Seq[TaskLocation] = preferredLocs
 
   override def toString = "ResultTask(" + stageId + ", " + partitionId + ")"
+
+  override def writeExternal(out: ObjectOutput) {
+    RDDCheckpointData.synchronized {
+      split = rdd.partitions(partitionId)
+      out.writeInt(stageId)
+      val bytes = ResultTask.serializeInfo(
+        stageId, rdd, func.asInstanceOf[(TaskContext, Iterator[_]) => _])
+      out.writeInt(bytes.length)
+      out.write(bytes)
+      out.writeInt(partitionId)
+      out.writeInt(outputId)
+      out.writeLong(epoch)
+      out.writeObject(split)
+    }
+  }
+
+  override def readExternal(in: ObjectInput) {
+    val stageId = in.readInt()
+    val numBytes = in.readInt()
+    val bytes = new Array[Byte](numBytes)
+    in.readFully(bytes)
+    val (rdd_, func_) = ResultTask.deserializeInfo(stageId, bytes)
+    rdd = rdd_.asInstanceOf[RDD[T]]
+    func = func_.asInstanceOf[(TaskContext, Iterator[T]) => U]
+    partitionId = in.readInt()
+    outputId = in.readInt()
+    epoch = in.readLong()
+    split = in.readObject().asInstanceOf[Partition]
+  }
 }
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 033c6e52861e0623a76d50307e76db39c3d087aa..fdaf1de83f051ce07294c58b1ee329ec25685405 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -17,13 +17,71 @@
 
 package org.apache.spark.scheduler
 
-import java.nio.ByteBuffer
+import scala.language.existentials
+
+import java.io._
+import java.util.zip.{GZIPInputStream, GZIPOutputStream}
+
+import scala.collection.mutable.HashMap
 
 import org.apache.spark._
-import org.apache.spark.broadcast.Broadcast
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, RDDCheckpointData}
 import org.apache.spark.shuffle.ShuffleWriter
 
+private[spark] object ShuffleMapTask {
+
+  // A simple map between the stage id to the serialized byte array of a task.
+  // Served as a cache for task serialization because serialization can be
+  // expensive on the master node if it needs to launch thousands of tasks.
+  private val serializedInfoCache = new HashMap[Int, Array[Byte]]
+
+  def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_, _, _]): Array[Byte] = {
+    synchronized {
+      val old = serializedInfoCache.get(stageId).orNull
+      if (old != null) {
+        return old
+      } else {
+        val out = new ByteArrayOutputStream
+        val ser = SparkEnv.get.closureSerializer.newInstance()
+        val objOut = ser.serializeStream(new GZIPOutputStream(out))
+        objOut.writeObject(rdd)
+        objOut.writeObject(dep)
+        objOut.close()
+        val bytes = out.toByteArray
+        serializedInfoCache.put(stageId, bytes)
+        bytes
+      }
+    }
+  }
+
+  def deserializeInfo(stageId: Int, bytes: Array[Byte]): (RDD[_], ShuffleDependency[_, _, _]) = {
+    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+    val ser = SparkEnv.get.closureSerializer.newInstance()
+    val objIn = ser.deserializeStream(in)
+    val rdd = objIn.readObject().asInstanceOf[RDD[_]]
+    val dep = objIn.readObject().asInstanceOf[ShuffleDependency[_, _, _]]
+    (rdd, dep)
+  }
+
+  // Since both the JarSet and FileSet have the same format this is used for both.
+  def deserializeFileSet(bytes: Array[Byte]): HashMap[String, Long] = {
+    val in = new GZIPInputStream(new ByteArrayInputStream(bytes))
+    val objIn = new ObjectInputStream(in)
+    val set = objIn.readObject().asInstanceOf[Array[(String, Long)]].toMap
+    HashMap(set.toSeq: _*)
+  }
+
+  def removeStage(stageId: Int) {
+    serializedInfoCache.remove(stageId)
+  }
+
+  def clearCache() {
+    synchronized {
+      serializedInfoCache.clear()
+    }
+  }
+}
+
 /**
  * A ShuffleMapTask divides the elements of an RDD into multiple buckets (based on a partitioner
  * specified in the ShuffleDependency).
@@ -31,47 +89,62 @@ import org.apache.spark.shuffle.ShuffleWriter
  * See [[org.apache.spark.scheduler.Task]] for more information.
  *
  * @param stageId id of the stage this task belongs to
- * @param rddBinary broadcast version of of the serialized RDD
+ * @param rdd the final RDD in this stage
  * @param dep the ShuffleDependency
- * @param partition partition of the RDD this task is associated with
+ * @param _partitionId index of the number in the RDD
 * @param locs preferred task execution locations for locality scheduling
 */
 private[spark] class ShuffleMapTask(
     stageId: Int,
-    var rddBinary: Broadcast[Array[Byte]],
+    var rdd: RDD[_],
     var dep: ShuffleDependency[_, _, _],
-    partition: Partition,
+    _partitionId: Int,
     @transient private var locs: Seq[TaskLocation])
-  extends Task[MapStatus](stageId, partition.index) with Logging {
-
-  // TODO: Should we also broadcast the ShuffleDependency? For that we would need a place to
-  // keep a reference to it (perhaps in Stage).
-
-  def this(
-      stageId: Int,
-      rdd: RDD[_],
-      dep: ShuffleDependency[_, _, _],
-      partitionId: Int,
-      locs: Seq[TaskLocation]) = {
-    this(stageId, rdd.broadcasted, dep, rdd.partitions(partitionId), locs)
-  }
+  extends Task[MapStatus](stageId, _partitionId)
+  with Externalizable
+  with Logging {
+
+  protected def this() = this(0, null, null, 0, null)
 
   @transient private val preferredLocs: Seq[TaskLocation] = {
     if (locs == null) Nil else locs.toSet.toSeq
   }
 
-  override def runTask(context: TaskContext): MapStatus = {
-    // Deserialize the RDD using the broadcast variable.
-    val ser = SparkEnv.get.closureSerializer.newInstance()
-    val rdd = ser.deserialize[RDD[_]](ByteBuffer.wrap(rddBinary.value),
-      Thread.currentThread.getContextClassLoader)
+  var split = if (rdd == null) null else rdd.partitions(partitionId)
+
+  override def writeExternal(out: ObjectOutput) {
+    RDDCheckpointData.synchronized {
+      split = rdd.partitions(partitionId)
+      out.writeInt(stageId)
+      val bytes = ShuffleMapTask.serializeInfo(stageId, rdd, dep)
+      out.writeInt(bytes.length)
+      out.write(bytes)
+      out.writeInt(partitionId)
+      out.writeLong(epoch)
+      out.writeObject(split)
+    }
+  }
+  override def readExternal(in: ObjectInput) {
+    val stageId = in.readInt()
+    val numBytes = in.readInt()
+    val bytes = new Array[Byte](numBytes)
+    in.readFully(bytes)
+    val (rdd_, dep_) = ShuffleMapTask.deserializeInfo(stageId, bytes)
+    rdd = rdd_
+    dep = dep_
+    partitionId = in.readInt()
+    epoch = in.readLong()
+    split = in.readObject().asInstanceOf[Partition]
+  }
+
+  override def runTask(context: TaskContext): MapStatus = {
     metrics = Some(context.taskMetrics)
     var writer: ShuffleWriter[Any, Any] = null
     try {
       val manager = SparkEnv.get.shuffleManager
       writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
-      writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
+      writer.write(rdd.iterator(split, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
       return writer.stop(success = true).get
     } catch {
       case e: Exception =>
diff --git a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
index 871f831531bee9639d3b87a84f981618b6b2eb2a..13b415cccb647f4ca8bf34b5850895bcb4916c07 100644
--- a/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ContextCleanerSuite.scala
@@ -52,8 +52,9 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     }
   }
 
+
   test("cleanup RDD") {
-    val rdd = newRDD().persist()
+    val rdd = newRDD.persist()
     val collected = rdd.collect().toList
     val tester = new CleanerTester(sc, rddIds = Seq(rdd.id))
 
@@ -66,7 +67,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("cleanup shuffle") {
-    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies()
+    val (rdd, shuffleDeps) = newRDDWithShuffleDependencies
     val collected = rdd.collect().toList
     val tester = new CleanerTester(sc, shuffleIds = shuffleDeps.map(_.shuffleId))
 
@@ -79,7 +80,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("cleanup broadcast") {
-    val broadcast = newBroadcast()
+    val broadcast = newBroadcast
     val tester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
 
     // Explicit cleanup
@@ -88,7 +89,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup RDD") {
-    var rdd = newRDD().persist()
+    var rdd = newRDD.persist()
     rdd.count()
 
     // Test that GC does not cause RDD cleanup due to a strong reference
@@ -106,7 +107,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup shuffle") {
-    var rdd = newShuffleRDD()
+    var rdd = newShuffleRDD
     rdd.count()
 
     // Test that GC does not cause shuffle cleanup due to a strong reference
@@ -124,7 +125,7 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
   }
 
   test("automatically cleanup broadcast") {
-    var broadcast = newBroadcast()
+    var broadcast = newBroadcast
 
     // Test that GC does not cause broadcast cleanup due to a strong reference
     val preGCTester = new CleanerTester(sc, broadcastIds = Seq(broadcast.id))
@@ -140,23 +141,11 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     postGCTester.assertCleanup()
   }
 
-  test("automatically cleanup broadcast data for task dispatching") {
-    var rdd = newRDDWithShuffleDependencies()._1
-    rdd.count() // This triggers an action that broadcasts the RDDs.
-
-    // Test that GC causes broadcast task data cleanup after dereferencing the RDD.
-    val postGCTester = new CleanerTester(sc,
-      broadcastIds = Seq(rdd.broadcasted.id, rdd.firstParent.broadcasted.id))
-    rdd = null
-    runGC()
-    postGCTester.assertCleanup()
-  }
-
   test("automatically cleanup RDD + shuffle + broadcast") {
     val numRdds = 100
     val numBroadcasts = 4 // Broadcasts are more costly
-    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
     val broadcastIds = 0L until numBroadcasts
@@ -186,8 +175,8 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     val numRdds = 10
     val numBroadcasts = 4 // Broadcasts are more costly
-    val rddBuffer = (1 to numRdds).map(i => randomRdd()).toBuffer
-    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast()).toBuffer
+    val rddBuffer = (1 to numRdds).map(i => randomRdd).toBuffer
+    val broadcastBuffer = (1 to numBroadcasts).map(i => randomBroadcast).toBuffer
     val rddIds = sc.persistentRdds.keys.toSeq
     val shuffleIds = 0 until sc.newShuffleId
     val broadcastIds = 0L until numBroadcasts
 
@@ -208,18 +197,17 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
 
   //------ Helper functions ------
 
-  private def newRDD() = sc.makeRDD(1 to 10)
-  private def newPairRDD() = newRDD().map(_ -> 1)
-  private def newShuffleRDD() = newPairRDD().reduceByKey(_ + _)
-  private def newBroadcast() = sc.broadcast(1 to 100)
-
-  private def newRDDWithShuffleDependencies(): (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
+  def newRDD = sc.makeRDD(1 to 10)
+  def newPairRDD = newRDD.map(_ -> 1)
+  def newShuffleRDD = newPairRDD.reduceByKey(_ + _)
+  def newBroadcast = sc.broadcast(1 to 100)
+  def newRDDWithShuffleDependencies: (RDD[_], Seq[ShuffleDependency[_, _, _]]) = {
     def getAllDependencies(rdd: RDD[_]): Seq[Dependency[_]] = {
       rdd.dependencies ++ rdd.dependencies.flatMap { dep =>
         getAllDependencies(dep.rdd)
       }
     }
-    val rdd = newShuffleRDD()
+    val rdd = newShuffleRDD
 
     // Get all the shuffle dependencies
    val shuffleDeps = getAllDependencies(rdd)
@@ -228,34 +216,34 @@ class ContextCleanerSuite extends FunSuite with BeforeAndAfter with LocalSparkCo
     (rdd, shuffleDeps)
   }
 
-  private def randomRdd() = {
+  def randomRdd = {
     val rdd: RDD[_] = Random.nextInt(3) match {
-      case 0 => newRDD()
-      case 1 => newShuffleRDD()
-      case 2 => newPairRDD.join(newPairRDD())
+      case 0 => newRDD
+      case 1 => newShuffleRDD
+      case 2 => newPairRDD.join(newPairRDD)
     }
     if (Random.nextBoolean()) rdd.persist()
     rdd.count()
     rdd
   }
 
-  private def randomBroadcast() = {
+  def randomBroadcast = {
     sc.broadcast(Random.nextInt(Int.MaxValue))
   }
 
   /** Run GC and make sure it actually has run */
-  private def runGC() {
+  def runGC() {
     val weakRef = new WeakReference(new Object())
     val startTime = System.currentTimeMillis
     System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC.
     // Wait until a weak reference object has been GCed
-    while (System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
+    while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) {
       System.gc()
       Thread.sleep(200)
     }
   }
 
-  private def cleaner = sc.cleaner.get
+  def cleaner = sc.cleaner.get
 }
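
Reviewer note (illustration only, not part of the patch): this revert drops the broadcast-based task dispatch (`RDD.broadcasted`) and restores the older scheme in `ShuffleMapTask`/`ResultTask`, where the RDD plus its func or `ShuffleDependency` are serialized once per stage with the closure serializer, GZIP-compressed, cached by stage id on the driver, and shipped to executors through `writeExternal`/`readExternal`; the cache is invalidated on checkpointing (`RDDCheckpointData.clearTaskCaches`), stage removal, and `SparkContext.stop()`. A minimal, Spark-independent Scala sketch of that caching pattern follows; the names (`StageInfoCache`, `payload`) and the use of plain Java serialization instead of Spark's closure serializer are assumptions for illustration, not Spark APIs.

```scala
// Hypothetical sketch of the per-stage serialization cache pattern reintroduced by this
// revert: serialize a task payload once per stage, GZIP-compress it, and reuse the bytes
// for every task in that stage. Not Spark code; names and serializer choice are assumed.
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashMap

object StageInfoCache {
  private val cache = new HashMap[Int, Array[Byte]]

  // Serialize and compress the payload once per stage id; later calls reuse the cached bytes.
  def serialize(stageId: Int, payload: AnyRef): Array[Byte] = synchronized {
    cache.getOrElseUpdate(stageId, {
      val bos = new ByteArrayOutputStream()
      val out = new ObjectOutputStream(new GZIPOutputStream(bos))
      out.writeObject(payload)
      out.close()
      bos.toByteArray
    })
  }

  // Decompress and deserialize on the receiving side (the readExternal counterpart).
  def deserialize[T](bytes: Array[Byte]): T = {
    val in = new ObjectInputStream(new GZIPInputStream(new ByteArrayInputStream(bytes)))
    in.readObject().asInstanceOf[T]
  }

  // Mirrors removeStage()/clearCache(): drop entries when a stage finishes or the context
  // stops, otherwise the driver keeps serialized task data alive indefinitely.
  def removeStage(stageId: Int): Unit = synchronized { cache.remove(stageId) }
  def clear(): Unit = synchronized { cache.clear() }
}

object StageInfoCacheDemo extends App {
  val bytes1 = StageInfoCache.serialize(1, Vector.fill(1000)("payload"))
  val bytes2 = StageInfoCache.serialize(1, Vector.fill(1000)("ignored, cache hit"))
  assert(bytes1 eq bytes2) // second call for the same stage reuses the cached array
  val restored = StageInfoCache.deserialize[Vector[String]](bytes1)
  assert(restored.size == 1000)
  StageInfoCache.removeStage(1)
}
```

The point of the pattern is that the expensive serialize-and-compress step happens once per stage rather than once per task, at the cost of having to clear the cache explicitly (as the patch does in `DAGScheduler`, `RDDCheckpointData`, and `SparkContext.stop()`) to avoid leaking driver memory.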