From 5184141936c18f12c6738caae6fceee4d15800e2 Mon Sep 17 00:00:00 2001 From: Tathagata Das <tathagata.das1565@gmail.com> Date: Tue, 18 Dec 2012 13:30:53 -0800 Subject: [PATCH] Introduced getSplits, getDependencies, and getPreferredLocations in RDD and RDDCheckpointData. --- .../main/scala/spark/PairRDDFunctions.scala | 4 +- .../main/scala/spark/ParallelCollection.scala | 9 +- core/src/main/scala/spark/RDD.scala | 123 +++++++++++------- .../main/scala/spark/RDDCheckpointData.scala | 10 +- core/src/main/scala/spark/rdd/BlockRDD.scala | 9 +- .../main/scala/spark/rdd/CartesianRDD.scala | 12 +- .../main/scala/spark/rdd/CoGroupedRDD.scala | 11 +- .../main/scala/spark/rdd/CoalescedRDD.scala | 10 +- .../main/scala/spark/rdd/FilteredRDD.scala | 2 +- .../main/scala/spark/rdd/FlatMappedRDD.scala | 2 +- .../src/main/scala/spark/rdd/GlommedRDD.scala | 2 +- core/src/main/scala/spark/rdd/HadoopRDD.scala | 4 +- .../scala/spark/rdd/MapPartitionsRDD.scala | 2 +- .../spark/rdd/MapPartitionsWithSplitRDD.scala | 2 +- core/src/main/scala/spark/rdd/MappedRDD.scala | 2 +- .../main/scala/spark/rdd/NewHadoopRDD.scala | 4 +- core/src/main/scala/spark/rdd/PipedRDD.scala | 2 +- .../src/main/scala/spark/rdd/SampledRDD.scala | 9 +- .../main/scala/spark/rdd/ShuffledRDD.scala | 7 +- core/src/main/scala/spark/rdd/UnionRDD.scala | 13 +- .../scala/spark/scheduler/DAGScheduler.scala | 2 +- .../test/scala/spark/CheckpointSuite.scala | 6 +- 22 files changed, 134 insertions(+), 113 deletions(-) diff --git a/core/src/main/scala/spark/PairRDDFunctions.scala b/core/src/main/scala/spark/PairRDDFunctions.scala index 1f82bd3ab8..09ac606cfb 100644 --- a/core/src/main/scala/spark/PairRDDFunctions.scala +++ b/core/src/main/scala/spark/PairRDDFunctions.scala @@ -628,7 +628,7 @@ private[spark] class MappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => U) extends RDD[(K, U)](prev.get) { - override def splits = firstParent[(K, V)].splits + override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = firstParent[(K, V)].iterator(split).map{case (k, v) => (k, f(v))} } @@ -637,7 +637,7 @@ private[spark] class FlatMappedValuesRDD[K, V, U](prev: WeakReference[RDD[(K, V)]], f: V => TraversableOnce[U]) extends RDD[(K, U)](prev.get) { - override def splits = firstParent[(K, V)].splits + override def getSplits = firstParent[(K, V)].splits override val partitioner = firstParent[(K, V)].partitioner override def compute(split: Split) = { firstParent[(K, V)].iterator(split).flatMap { case (k, v) => f(v).map(x => (k, x)) } diff --git a/core/src/main/scala/spark/ParallelCollection.scala b/core/src/main/scala/spark/ParallelCollection.scala index 9d12af6912..0bc5b2ff11 100644 --- a/core/src/main/scala/spark/ParallelCollection.scala +++ b/core/src/main/scala/spark/ParallelCollection.scala @@ -37,15 +37,12 @@ private[spark] class ParallelCollection[T: ClassManifest]( slices.indices.map(i => new ParallelCollectionSplit(id, i, slices(i))).toArray } - override def splits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_.asInstanceOf[Array[Split]] override def compute(s: Split) = s.asInstanceOf[ParallelCollectionSplit[T]].iterator - - override def preferredLocations(s: Split): Seq[String] = Nil - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git
a/core/src/main/scala/spark/RDD.scala b/core/src/main/scala/spark/RDD.scala index 6c04769c82..f3e422fa5f 100644 --- a/core/src/main/scala/spark/RDD.scala +++ b/core/src/main/scala/spark/RDD.scala @@ -81,48 +81,33 @@ abstract class RDD[T: ClassManifest]( def this(@transient oneParent: RDD[_]) = this(oneParent.context , List(new OneToOneDependency(oneParent))) - // Methods that must be implemented by subclasses: - - /** Set of partitions in this RDD. */ - def splits: Array[Split] + // ======================================================================= + // Methods that should be implemented by subclasses of RDD + // ======================================================================= /** Function for computing a given partition. */ def compute(split: Split): Iterator[T] - /** How this RDD depends on any parent RDDs. */ - def dependencies: List[Dependency[_]] = dependencies_ + /** Set of partitions in this RDD. */ + protected def getSplits(): Array[Split] - /** Record user function generating this RDD. */ - private[spark] val origin = Utils.getSparkCallSite - - /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None + /** How this RDD depends on any parent RDDs. */ + protected def getDependencies(): List[Dependency[_]] = dependencies_ /** Optionally overridden by subclasses to specify placement preferences. */ - def preferredLocations(split: Split): Seq[String] = Nil - - /** The [[spark.SparkContext]] that this RDD was created on. */ - def context = sc + protected def getPreferredLocations(split: Split): Seq[String] = Nil - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] - - /** A unique ID for this RDD (within its SparkContext). */ - val id = sc.newRddId() - - // Variables relating to persistence - private var storageLevel: StorageLevel = StorageLevel.NONE + /** Optionally overridden by subclasses to specify how they are partitioned. */ + val partitioner: Option[Partitioner] = None - protected[spark] var checkpointData: Option[RDDCheckpointData[T]] = None - /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassManifest] = { - dependencies.head.rdd.asInstanceOf[RDD[U]] - } - /** Returns the `i` th parent RDD */ - protected[spark] def parent[U: ClassManifest](i: Int) = dependencies(i).rdd.asInstanceOf[RDD[U]] + // ======================================================================= + // Methods and fields available on all RDDs + // ======================================================================= - // Methods available on all RDDs: + /** A unique ID for this RDD (within its SparkContext). */ + val id = sc.newRddId() /** * Set this RDD's storage level to persist its values across operations after the first time @@ -147,11 +132,39 @@ abstract class RDD[T: ClassManifest]( /** Get the RDD's current storage level, or StorageLevel.NONE if none is set. */ def getStorageLevel = storageLevel - def getPreferredLocations(split: Split) = { + /** + * Get the preferred location of a split, taking into account whether the + * RDD is checkpointed or not. + */ + final def preferredLocations(split: Split): Seq[String] = { + if (isCheckpointed) { + checkpointData.get.getPreferredLocations(split) + } else { + getPreferredLocations(split) + } + } + + /** + * Get the array of splits of this RDD, taking into account whether the + * RDD is checkpointed or not. 
+ */ + final def splits: Array[Split] = { + if (isCheckpointed) { + checkpointData.get.getSplits + } else { + getSplits + } + } + + /** + * Get the list of dependencies of this RDD, taking into account whether the + * RDD is checkpointed or not. + */ + final def dependencies: List[Dependency[_]] = { if (isCheckpointed) { - checkpointData.get.preferredLocations(split) + dependencies_ } else { - preferredLocations(split) + getDependencies } } @@ -536,6 +549,27 @@ abstract class RDD[T: ClassManifest]( if (checkpointData.isDefined) checkpointData.get.getCheckpointFile() else None } + // ======================================================================= + // Other internal methods and fields + // ======================================================================= + + private var storageLevel: StorageLevel = StorageLevel.NONE + + /** Record user function generating this RDD. */ + private[spark] val origin = Utils.getSparkCallSite + + private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] + + private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None + + /** Returns the first parent RDD */ + protected[spark] def firstParent[U: ClassManifest] = { + dependencies.head.rdd.asInstanceOf[RDD[U]] + } + + /** The [[spark.SparkContext]] that this RDD was created on. */ + def context = sc + /** * Performs the checkpointing of this RDD by saving this . It is called by the DAGScheduler * after a job using this RDD has completed (therefore the RDD has been materialized and * @@ -548,23 +582,18 @@ abstract class RDD[T: ClassManifest]( /** * Changes the dependencies of this RDD from its original parents to the new RDD - * (`newRDD`) created from the checkpoint file. This method must ensure that all references - * to the original parent RDDs must be removed to enable the parent RDDs to be garbage - * collected. Subclasses of RDD may override this method for implementing their own changing - * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea. + * (`newRDD`) created from the checkpoint file. */ protected[spark] def changeDependencies(newRDD: RDD[_]) { + clearDependencies() dependencies_ = List(new OneToOneDependency(newRDD)) } - @throws(classOf[IOException]) - private def writeObject(oos: ObjectOutputStream) { - oos.defaultWriteObject() - } - - @throws(classOf[IOException]) - private def readObject(ois: ObjectInputStream) { - ois.defaultReadObject() - } - + /** + * Clears the dependencies of this RDD. This method must ensure that all references + * to the original parent RDDs are removed to enable the parent RDDs to be garbage + * collected. Subclasses of RDD may override this method for implementing their own cleanup + * logic. See [[spark.rdd.UnionRDD]] and [[spark.rdd.ShuffledRDD]] to get a better idea.
+ */ + protected[spark] def clearDependencies() { } } diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala index 7613b338e6..e4c0912cdc 100644 --- a/core/src/main/scala/spark/RDDCheckpointData.scala +++ b/core/src/main/scala/spark/RDDCheckpointData.scala @@ -24,7 +24,6 @@ extends Logging with Serializable { var cpState = Initialized @transient var cpFile: Option[String] = None @transient var cpRDD: Option[RDD[T]] = None - @transient var cpRDDSplits: Seq[Split] = Nil // Mark the RDD for checkpointing def markForCheckpoint() { @@ -81,7 +80,6 @@ extends Logging with Serializable { RDDCheckpointData.synchronized { cpFile = Some(file) cpRDD = Some(newRDD) - cpRDDSplits = newRDD.splits rdd.changeDependencies(newRDD) cpState = Checkpointed RDDCheckpointData.checkpointCompleted() @@ -90,12 +88,18 @@ extends Logging with Serializable { } // Get preferred location of a split after checkpointing - def preferredLocations(split: Split) = { + def getPreferredLocations(split: Split) = { RDDCheckpointData.synchronized { cpRDD.get.preferredLocations(split) } } + def getSplits: Array[Split] = { + RDDCheckpointData.synchronized { + cpRDD.get.splits + } + } + // Get iterator. This is called at the worker nodes. def iterator(split: Split): Iterator[T] = { rdd.firstParent[T].iterator(split) diff --git a/core/src/main/scala/spark/rdd/BlockRDD.scala b/core/src/main/scala/spark/rdd/BlockRDD.scala index 0c8cdd10dd..68e570eb15 100644 --- a/core/src/main/scala/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/spark/rdd/BlockRDD.scala @@ -29,7 +29,7 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St HashMap(blockIds.zip(locations):_*) } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[T] = { val blockManager = SparkEnv.get.blockManager @@ -41,12 +41,11 @@ class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[St } } - override def preferredLocations(split: Split) = + override def getPreferredLocations(split: Split) = locations_(split.asInstanceOf[BlockRDDSplit].blockId) - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/CartesianRDD.scala b/core/src/main/scala/spark/rdd/CartesianRDD.scala index 9975e79b08..116644bd52 100644 --- a/core/src/main/scala/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/spark/rdd/CartesianRDD.scala @@ -45,9 +45,9 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( array } - override def splits = splits_ + override def getSplits = splits_ - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val currSplit = split.asInstanceOf[CartesianSplit] rdd1.preferredLocations(currSplit.s1) ++ rdd2.preferredLocations(currSplit.s2) } @@ -66,11 +66,11 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( } ) - override def dependencies = deps_ + override def getDependencies = deps_ - override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + deps_ = Nil + splits_ = null rdd1 = null rdd2 = null } diff --git a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala index 
bc6d16ee8b..9cc95dc172 100644 --- a/core/src/main/scala/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoGroupedRDD.scala @@ -65,9 +65,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) deps.toList } - // Pre-checkpoint dependencies deps_ should be transient (deps_) - // but post-checkpoint dependencies must not be transient (dependencies_) - override def dependencies = if (isCheckpointed) dependencies_ else deps_ + override def getDependencies = deps_ @transient var splits_ : Array[Split] = { @@ -85,7 +83,7 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) array } - override def splits = splits_ + override def getSplits = splits_ override val partitioner = Some(part) @@ -117,10 +115,9 @@ CoGroupedRDD[K](@transient var rdds: Seq[RDD[(_, _)]], part: Partitioner) map.iterator } - override def changeDependencies(newRDD: RDD[_]) { + override def clearDependencies() { deps_ = null - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/spark/rdd/CoalescedRDD.scala index 088958942e..85d0fa9f6a 100644 --- a/core/src/main/scala/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/spark/rdd/CoalescedRDD.scala @@ -44,7 +44,7 @@ class CoalescedRDD[T: ClassManifest]( } } - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[T] = { split.asInstanceOf[CoalescedRDDSplit].parents.iterator.flatMap { @@ -59,11 +59,11 @@ class CoalescedRDD[T: ClassManifest]( } ) - override def dependencies = deps_ + override def getDependencies() = deps_ - override def changeDependencies(newRDD: RDD[_]) { - deps_ = List(new OneToOneDependency(newRDD)) - splits_ = newRDD.splits + override def clearDependencies() { + deps_ = Nil + splits_ = null prev = null } } diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala index 02f2e7c246..309ed2399d 100644 --- a/core/src/main/scala/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala @@ -9,6 +9,6 @@ class FilteredRDD[T: ClassManifest]( f: T => Boolean) extends RDD[T](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).filter(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala index cdc8ecdcfe..1160e68bb8 100644 --- a/core/src/main/scala/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/spark/rdd/FlatMappedRDD.scala @@ -9,6 +9,6 @@ class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( f: T => TraversableOnce[U]) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).flatMap(f) } diff --git a/core/src/main/scala/spark/rdd/GlommedRDD.scala b/core/src/main/scala/spark/rdd/GlommedRDD.scala index df6f61c69d..4fab1a56fa 100644 --- a/core/src/main/scala/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/spark/rdd/GlommedRDD.scala @@ -6,6 +6,6 @@ import spark.Split private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) extends RDD[Array[T]](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def 
compute(split: Split) = Array(firstParent[T].iterator(split).toArray).iterator } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala index af54f23ebc..fce190b860 100644 --- a/core/src/main/scala/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala @@ -67,7 +67,7 @@ class HadoopRDD[K, V]( .asInstanceOf[InputFormat[K, V]] } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[HadoopSplit] @@ -110,7 +110,7 @@ class HadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { // TODO: Filtering out "localhost" in case of file:// URLs val hadoopSplit = split.asInstanceOf[HadoopSplit] hadoopSplit.inputSplit.value.getLocations.filter(_ != "localhost") diff --git a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala index 23b9fb023b..5f4acee041 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsRDD.scala @@ -12,6 +12,6 @@ class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = f(firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala index 41955c1d7a..f0f3f2c7c7 100644 --- a/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala +++ b/core/src/main/scala/spark/rdd/MapPartitionsWithSplitRDD.scala @@ -14,6 +14,6 @@ class MapPartitionsWithSplitRDD[U: ClassManifest, T: ClassManifest]( f: (Int, Iterator[T]) => Iterator[U]) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = f(split.index, firstParent[T].iterator(split)) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/MappedRDD.scala b/core/src/main/scala/spark/rdd/MappedRDD.scala index 6f8cb21fd3..44b542db93 100644 --- a/core/src/main/scala/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/spark/rdd/MappedRDD.scala @@ -9,6 +9,6 @@ class MappedRDD[U: ClassManifest, T: ClassManifest]( f: T => U) extends RDD[U](prev) { - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split) = firstParent[T].iterator(split).map(f) } \ No newline at end of file diff --git a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala index c12df5839e..91f89e3c75 100644 --- a/core/src/main/scala/spark/rdd/NewHadoopRDD.scala +++ b/core/src/main/scala/spark/rdd/NewHadoopRDD.scala @@ -55,7 +55,7 @@ class NewHadoopRDD[K, V]( result } - override def splits = splits_ + override def getSplits = splits_ override def compute(theSplit: Split) = new Iterator[(K, V)] { val split = theSplit.asInstanceOf[NewHadoopSplit] @@ -89,7 +89,7 @@ class NewHadoopRDD[K, V]( } } - override def preferredLocations(split: Split) = { + override def getPreferredLocations(split: Split) = { val theSplit = split.asInstanceOf[NewHadoopSplit] theSplit.serializableHadoopSplit.value.getLocations.filter(_ 
!= "localhost") } diff --git a/core/src/main/scala/spark/rdd/PipedRDD.scala b/core/src/main/scala/spark/rdd/PipedRDD.scala index d2047375ea..a88929e55e 100644 --- a/core/src/main/scala/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/spark/rdd/PipedRDD.scala @@ -29,7 +29,7 @@ class PipedRDD[T: ClassManifest]( // using a standard StringTokenizer (i.e. by spaces) def this(prev: RDD[T], command: String) = this(prev, PipedRDD.tokenize(command)) - override def splits = firstParent[T].splits + override def getSplits = firstParent[T].splits override def compute(split: Split): Iterator[String] = { val pb = new ProcessBuilder(command) diff --git a/core/src/main/scala/spark/rdd/SampledRDD.scala b/core/src/main/scala/spark/rdd/SampledRDD.scala index c622e14a66..da6f65765c 100644 --- a/core/src/main/scala/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/spark/rdd/SampledRDD.scala @@ -26,9 +26,9 @@ class SampledRDD[T: ClassManifest]( firstParent[T].splits.map(x => new SampledRDDSplit(x, rg.nextInt)) } - override def splits = splits_.asInstanceOf[Array[Split]] + override def getSplits = splits_.asInstanceOf[Array[Split]] - override def preferredLocations(split: Split) = + override def getPreferredLocations(split: Split) = firstParent[T].preferredLocations(split.asInstanceOf[SampledRDDSplit].prev) override def compute(splitIn: Split) = { @@ -51,8 +51,7 @@ class SampledRDD[T: ClassManifest]( } } - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/spark/rdd/ShuffledRDD.scala index a9dd3f35ed..2caf33c21e 100644 --- a/core/src/main/scala/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/spark/rdd/ShuffledRDD.scala @@ -25,15 +25,14 @@ class ShuffledRDD[K, V]( @transient var splits_ = Array.tabulate[Split](part.numPartitions)(i => new ShuffledRDDSplit(i)) - override def splits = splits_ + override def getSplits = splits_ override def compute(split: Split): Iterator[(K, V)] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[K, V](shuffledId, split.index) } - override def changeDependencies(newRDD: RDD[_]) { - dependencies_ = List(new OneToOneDependency(newRDD.asInstanceOf[RDD[Any]])) - splits_ = newRDD.splits + override def clearDependencies() { + splits_ = null } } diff --git a/core/src/main/scala/spark/rdd/UnionRDD.scala b/core/src/main/scala/spark/rdd/UnionRDD.scala index a84867492b..05ed6172d1 100644 --- a/core/src/main/scala/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/spark/rdd/UnionRDD.scala @@ -37,7 +37,7 @@ class UnionRDD[T: ClassManifest]( array } - override def splits = splits_ + override def getSplits = splits_ @transient var deps_ = { val deps = new ArrayBuffer[Dependency[_]] @@ -49,19 +49,16 @@ class UnionRDD[T: ClassManifest]( deps.toList } - // Pre-checkpoint dependencies deps_ should be transient (deps_) - // but post-checkpoint dependencies must not be transient (dependencies_) - override def dependencies = if (isCheckpointed) dependencies_ else deps_ + override def getDependencies = deps_ override def compute(s: Split): Iterator[T] = s.asInstanceOf[UnionSplit[T]].iterator() - override def preferredLocations(s: Split): Seq[String] = + override def getPreferredLocations(s: Split): Seq[String] = s.asInstanceOf[UnionSplit[T]].preferredLocations() - override def 
changeDependencies(newRDD: RDD[_]) { + override def clearDependencies() { deps_ = null - dependencies_ = List(new OneToOneDependency(newRDD)) - splits_ = newRDD.splits + splits_ = null rdds = null } } diff --git a/core/src/main/scala/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/spark/scheduler/DAGScheduler.scala index 33d35b35d1..4b2570fa2b 100644 --- a/core/src/main/scala/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/spark/scheduler/DAGScheduler.scala @@ -575,7 +575,7 @@ class DAGScheduler(taskSched: TaskScheduler) extends TaskSchedulerListener with return cached } // If the RDD has some placement preferences (as is the case for input RDDs), get those - val rddPrefs = rdd.getPreferredLocations(rdd.splits(partition)).toList + val rddPrefs = rdd.preferredLocations(rdd.splits(partition)).toList if (rddPrefs != Nil) { return rddPrefs } diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 0bffedb8db..19626d2450 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -57,7 +57,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) - assert(parCollection.splits.toList === parCollection.checkpointData.get.cpRDDSplits.toList) + assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList) assert(parCollection.collect() === result) } @@ -72,7 +72,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) - assert(blockRDD.splits.toList === blockRDD.checkpointData.get.cpRDDSplits.toList) + assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList) assert(blockRDD.collect() === result) } @@ -191,7 +191,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { assert(operatedRDD.dependencies.head.rdd != parentRDD) // Test whether the splits have been changed to the new Hadoop splits - assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.cpRDDSplits.toList) + assert(operatedRDD.splits.toList === operatedRDD.checkpointData.get.getSplits.toList) // Test whether the number of splits is same as before assert(operatedRDD.splits.length === numSplits) -- GitLab
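
Note for reviewers: below is a minimal sketch, not part of the patch itself, of what a one-to-one RDD subclass looks like against the new interface. The class and member names are illustrative assumptions only. Subclasses now override the protected getSplits, getPreferredLocations and clearDependencies (getDependencies defaults to the dependencies passed to the constructor, so a simple one-to-one RDD need not override it), while callers keep using the public splits, dependencies and preferredLocations, which redirect to RDDCheckpointData once the RDD has been checkpointed.

    import spark.{RDD, Split}

    // Hypothetical one-to-one RDD, shown only to illustrate the new contract.
    private[spark] class ExampleMirrorRDD[T: ClassManifest](@transient var prev: RDD[T])
      extends RDD[T](prev) {

      // Cache the splits in a @transient field, as the RDDs in this patch do,
      // so that clearDependencies() can drop the reference after checkpointing.
      @transient var splits_ : Array[Split] = firstParent[T].splits

      // Subclasses implement the protected getters instead of the public methods.
      override def getSplits = splits_

      override def getPreferredLocations(split: Split): Seq[String] =
        firstParent[T].preferredLocations(split)

      override def compute(split: Split): Iterator[T] = firstParent[T].iterator(split)

      // Each subclass clears its own references so the original parents can be
      // garbage collected once changeDependencies() points this RDD at the
      // checkpoint file.
      override def clearDependencies() {
        splits_ = null
        prev = null
      }
    }

Keeping the public accessors final means the checkpoint redirection cannot be bypassed by a subclass, and clearDependencies() isolates the per-subclass cleanup that each changeDependencies() override used to duplicate.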