From f9c5b0a6fe8d728e16c60c0cf51ced0054e3a387 Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Thu, 20 Dec 2012 11:52:23 -0800
Subject: [PATCH] Changed checkpoint writing and reading process.

---
 .../main/scala/spark/RDDCheckpointData.scala  |  27 +---
 .../main/scala/spark/rdd/CheckpointRDD.scala  | 117 ++++++++++++++++++
 core/src/main/scala/spark/rdd/HadoopRDD.scala |   5 +-
 3 files changed, 124 insertions(+), 25 deletions(-)
 create mode 100644 core/src/main/scala/spark/rdd/CheckpointRDD.scala

diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index e4c0912cdc..1aa9b9aa1e 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -1,7 +1,7 @@
 package spark
 
 import org.apache.hadoop.fs.Path
-import rdd.CoalescedRDD
+import rdd.{CheckpointRDD, CoalescedRDD}
 import scheduler.{ResultTask, ShuffleMapTask}
 
 /**
@@ -55,30 +55,13 @@ extends Logging with Serializable {
     }
 
     // Save to file, and reload it as an RDD
-    val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
-    rdd.saveAsObjectFile(file)
-
-    val newRDD = {
-      val hadoopRDD = rdd.context.objectFile[T](file, rdd.splits.size)
-
-      val oldSplits = rdd.splits.size
-      val newSplits = hadoopRDD.splits.size
-
-      logDebug("RDD splits = " + oldSplits + " --> " + newSplits)
-      if (newSplits < oldSplits) {
-        throw new Exception("# splits after checkpointing is less than before " +
-          "[" + oldSplits + " --> " + newSplits)
-      } else if (newSplits > oldSplits) {
-        new CoalescedRDD(hadoopRDD, rdd.splits.size)
-      } else {
-        hadoopRDD
-      }
-    }
-    logDebug("New RDD has " + newRDD.splits.size + " splits")
+    val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
+    rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
+    val newRDD = new CheckpointRDD[T](rdd.context, path)
 
     // Change the dependencies and splits of the RDD
     RDDCheckpointData.synchronized {
-      cpFile = Some(file)
+      cpFile = Some(path)
       cpRDD = Some(newRDD)
       rdd.changeDependencies(newRDD)
       cpState = Checkpointed
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
new file mode 100644
index 0000000000..c673ab6aaa
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -0,0 +1,117 @@
+package spark.rdd
+
+import spark._
+import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.fs.Path
+import java.io.{File, IOException, EOFException}
+import java.text.NumberFormat
+
+private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
+  override val index: Int = idx
+}
+
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
+  extends RDD[T](sc, Nil) {
+
+  @transient val path = new Path(checkpointPath)
+  @transient val fs = path.getFileSystem(new Configuration())
+
+  @transient val splits_ : Array[Split] = {
+    val splitFiles = fs.listStatus(path).map(_.getPath.toString).filter(_.contains("part-")).sorted
+    splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
+  }
+
+  override def getSplits = splits_
+
+  override def getPreferredLocations(split: Split): Seq[String] = {
+    val status = fs.getFileStatus(path)
+    val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+    locations.firstOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+  }
+
+  override def compute(split: Split): Iterator[T] = {
+    CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile)
+  }
+
+  override def checkpoint() {
+    // Do nothing. CheckpointRDD should not be checkpointed again.
+  }
+}
+
+private[spark] object CheckpointRDD extends Logging {
+
+  def splitIdToFileName(splitId: Int): String = {
+    val numfmt = NumberFormat.getInstance()
+    numfmt.setMinimumIntegerDigits(5)
+    numfmt.setGroupingUsed(false)
+    "part-" + numfmt.format(splitId)
+  }
+
+  def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
+    val outputDir = new Path(path)
+    val fs = outputDir.getFileSystem(new Configuration())
+
+    val finalOutputName = splitIdToFileName(context.splitId)
+    val finalOutputPath = new Path(outputDir, finalOutputName)
+    val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
+
+    if (fs.exists(tempOutputPath)) {
+      throw new IOException("Checkpoint failed: temporary path " +
+        tempOutputPath + " already exists")
+    }
+    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+    val fileOutputStream = if (blockSize < 0) {
+      fs.create(tempOutputPath, false, bufferSize)
+    } else {
+      // This is mainly for testing purposes
+      fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+    }
+    val serializer = SparkEnv.get.serializer.newInstance()
+    val serializeStream = serializer.serializeStream(fileOutputStream)
+    serializeStream.writeAll(iterator)
+    fileOutputStream.close()
+
+    if (!fs.rename(tempOutputPath, finalOutputPath)) {
+      if (!fs.delete(finalOutputPath, true)) {
+        throw new IOException("Checkpoint failed: failed to delete earlier output of task " +
+          context.attemptId)
+      }
+      if (!fs.rename(tempOutputPath, finalOutputPath)) {
+        throw new IOException("Checkpoint failed: failed to save output of task: " +
+          context.attemptId)
+      }
+    }
+  }
+
+  def readFromFile[T](path: String): Iterator[T] = {
+    val inputPath = new Path(path)
+    val fs = inputPath.getFileSystem(new Configuration())
+    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+    val fileInputStream = fs.open(inputPath, bufferSize)
+    val serializer = SparkEnv.get.serializer.newInstance()
+    val deserializeStream = serializer.deserializeStream(fileInputStream)
+    deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+  }
+
+  // Tests whether CheckpointRDD generates the expected number of splits despite
+  // each split file having multiple blocks. This needs to be run on a
+  // cluster (Mesos or standalone) using HDFS.
+  def main(args: Array[String]) {
+    import spark._
+
+    val Array(cluster, hdfsPath) = args
+    val sc = new SparkContext(cluster, "CheckpointRDD Test")
+    val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
+    val path = new Path(hdfsPath, "temp")
+    val fs = path.getFileSystem(new Configuration())
+    sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _)
+    val cpRDD = new CheckpointRDD[Int](sc, path.toString)
+    assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
+    assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
+    fs.delete(path)
+  }
+}
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index fce190b860..eca51758e4 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -25,8 +25,7 @@ import spark.Split
  * A Spark split class that wraps around a Hadoop InputSplit.
  */
 private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
-  extends Split
-  with Serializable {
+  extends Split {
 
   val inputSplit = new SerializableWritable[InputSplit](s)
 
@@ -117,6 +116,6 @@ class HadoopRDD[K, V](
   }
 
   override def checkpoint() {
-    // Do nothing. Hadoop RDD cannot be checkpointed.
+    // Do nothing. Hadoop RDD should not be checkpointed.
  }
 }
--
GitLab
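
Note on the new flow: checkpointing an RDD now writes one part-XXXXX file per split through CheckpointRDD.writeToFile (each split goes to a hidden ".part-XXXXX-attempt-N" temporary file that is only renamed into place on success, so a pre-existing temp file fails fast with an IOException instead of being silently overwritten) and then swaps the RDD's dependencies to a CheckpointRDD that reads those files back with readFromFile. Below is a minimal driver-side sketch of that flow, assuming a local master, a writable checkpoint directory, the SparkContext.setCheckpointDir and RDD.checkpoint API from this branch, and the usual mark-then-materialize behavior where the checkpoint data is written once a job has run on the RDD. Object and path names are illustrative only, not part of the patch.

    import spark.SparkContext

    object CheckpointFlowSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "CheckpointFlowSketch")
        // Assumed: a directory this process can write to (an HDFS path on a real cluster).
        sc.setCheckpointDir("/tmp/spark-checkpoint-sketch")

        val rdd = sc.makeRDD(1 to 100, 4).map(_ * 2)
        rdd.checkpoint()   // marks the RDD; data is saved when the checkpoint is performed

        // Materializing the checkpoint runs CheckpointRDD.writeToFile(path) over every
        // split and changes the RDD's dependencies to a CheckpointRDD over that directory.
        println("count = " + rdd.count())

        // Later actions read the part-XXXXX files back through CheckpointRDD.readFromFile
        // instead of recomputing the lineage.
        println("sum   = " + rdd.collect().sum)
      }
    }

The temp-file-then-rename step in writeToFile is what keeps retried task attempts safe: only a successful rename publishes a part file under the checkpoint directory, which is also the invariant the CheckpointRDD constructor relies on when it lists "part-" files to build its splits.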