diff --git a/core/src/main/scala/spark/SparkContext.scala b/core/src/main/scala/spark/SparkContext.scala index 71ed4ef0582ccab830a6bebedd70c46f92d72ca1..362aa04e66bf87a381f43f05e37d1a1a90662aa4 100644 --- a/core/src/main/scala/spark/SparkContext.scala +++ b/core/src/main/scala/spark/SparkContext.scala @@ -37,9 +37,7 @@ import spark.broadcast._ import spark.deploy.LocalSparkCluster import spark.partial.ApproximateEvaluator import spark.partial.PartialResult -import spark.rdd.HadoopRDD -import spark.rdd.NewHadoopRDD -import spark.rdd.UnionRDD +import rdd.{CheckpointRDD, HadoopRDD, NewHadoopRDD, UnionRDD} import scheduler.{ResultTask, ShuffleMapTask, DAGScheduler, TaskScheduler} import spark.scheduler.local.LocalScheduler import spark.scheduler.cluster.{SparkDeploySchedulerBackend, SchedulerBackend, ClusterScheduler} @@ -368,13 +366,9 @@ class SparkContext( protected[spark] def checkpointFile[T: ClassManifest]( - path: String, - minSplits: Int = defaultMinSplits + path: String ): RDD[T] = { - val rdd = objectFile[T](path, minSplits) - rdd.checkpointData = Some(new RDDCheckpointData(rdd)) - rdd.checkpointData.get.cpFile = Some(path) - rdd + new CheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. */ diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala index c673ab6aaa8404dc61d3358d7a42104193a42f18..fbf8a9ef83b8162c293e2b9bea82aefa179af9d5 100644 --- a/core/src/main/scala/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala @@ -24,6 +24,9 @@ class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String) splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray } + checkpointData = Some(new RDDCheckpointData[T](this)) + checkpointData.get.cpFile = Some(checkpointPath) + override def getSplits = splits_ override def getPreferredLocations(split: Split): Seq[String] = { diff --git a/core/src/test/scala/spark/CheckpointSuite.scala b/core/src/test/scala/spark/CheckpointSuite.scala index 19626d24500a06a4984d45902d34d72bdb704585..6bc667bd4c169cefb83b1b5806b3eca52ec962ab 100644 --- a/core/src/test/scala/spark/CheckpointSuite.scala +++ b/core/src/test/scala/spark/CheckpointSuite.scala @@ -54,7 +54,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { parCollection.checkpoint() assert(parCollection.dependencies === Nil) val result = parCollection.collect() - assert(sc.objectFile[Int](parCollection.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[Int](parCollection.getCheckpointFile.get).collect() === result) assert(parCollection.dependencies != Nil) assert(parCollection.splits.length === numSplits) assert(parCollection.splits.toList === parCollection.checkpointData.get.getSplits.toList) @@ -69,7 +69,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val numSplits = blockRDD.splits.size blockRDD.checkpoint() val result = blockRDD.collect() - assert(sc.objectFile[String](blockRDD.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[String](blockRDD.getCheckpointFile.get).collect() === result) assert(blockRDD.dependencies != Nil) assert(blockRDD.splits.length === numSplits) assert(blockRDD.splits.toList === blockRDD.checkpointData.get.getSplits.toList) @@ -185,7 +185,7 @@ class CheckpointSuite extends FunSuite with BeforeAndAfter with Logging { val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD) // Test whether the checkpoint file has been created - assert(sc.objectFile[U](operatedRDD.getCheckpointFile.get).collect() === result) + assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result) // Test whether dependencies have been changed from its earlier parent RDD assert(operatedRDD.dependencies.head.rdd != parentRDD)