From f9c5b0a6fe8d728e16c60c0cf51ced0054e3a387 Mon Sep 17 00:00:00 2001
From: Tathagata Das <tathagata.das1565@gmail.com>
Date: Thu, 20 Dec 2012 11:52:23 -0800
Subject: [PATCH] Changed checkpoint writing and reading process.

---
 .../main/scala/spark/RDDCheckpointData.scala  |  27 +---
 .../main/scala/spark/rdd/CheckpointRDD.scala  | 117 ++++++++++++++++++
 core/src/main/scala/spark/rdd/HadoopRDD.scala |   5 +-
 3 files changed, 124 insertions(+), 25 deletions(-)
 create mode 100644 core/src/main/scala/spark/rdd/CheckpointRDD.scala

diff --git a/core/src/main/scala/spark/RDDCheckpointData.scala b/core/src/main/scala/spark/RDDCheckpointData.scala
index e4c0912cdc..1aa9b9aa1e 100644
--- a/core/src/main/scala/spark/RDDCheckpointData.scala
+++ b/core/src/main/scala/spark/RDDCheckpointData.scala
@@ -1,7 +1,7 @@
 package spark
 
 import org.apache.hadoop.fs.Path
-import rdd.CoalescedRDD
+import rdd.{CheckpointRDD, CoalescedRDD}
 import scheduler.{ResultTask, ShuffleMapTask}
 
 /**
@@ -55,30 +55,13 @@ extends Logging with Serializable {
     }
 
     // Save to file, and reload it as an RDD
-    val file = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
-    rdd.saveAsObjectFile(file)
-
-    val newRDD = {
-      val hadoopRDD = rdd.context.objectFile[T](file, rdd.splits.size)
-
-      val oldSplits = rdd.splits.size
-      val newSplits = hadoopRDD.splits.size
-
-      logDebug("RDD splits = " + oldSplits + " --> " + newSplits)
-      if (newSplits < oldSplits) {
-        throw new Exception("# splits after checkpointing is less than before " +
-          "[" + oldSplits + " --> " + newSplits)
-      } else if (newSplits > oldSplits) {
-        new CoalescedRDD(hadoopRDD, rdd.splits.size)
-      } else {
-        hadoopRDD
-      }
-    }
-    logDebug("New RDD has " + newRDD.splits.size + " splits")
+    val path = new Path(rdd.context.checkpointDir, "rdd-" + rdd.id).toString
+    rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path) _)
+    val newRDD = new CheckpointRDD[T](rdd.context, path)
 
     // Change the dependencies and splits of the RDD
     RDDCheckpointData.synchronized {
-      cpFile = Some(file)
+      cpFile = Some(path)
       cpRDD = Some(newRDD)
       rdd.changeDependencies(newRDD)
       cpState = Checkpointed
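
With this change, doCheckpoint() no longer round-trips through saveAsObjectFile/objectFile with a CoalescedRDD fallback: it runs a single job that writes every partition with CheckpointRDD.writeToFile(path), then wraps the same directory in a CheckpointRDD, so the number of splits is preserved by construction. Below is a minimal sketch of how this path is exercised from user code; sc.setCheckpointDir and rdd.checkpoint() are assumed to exist as in the surrounding codebase, and the "local" master, the /tmp directory, and the CheckpointExample object are illustrative placeholders, not part of this patch.

    import spark.SparkContext

    object CheckpointExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "CheckpointExample")
        // Becomes rdd.context.checkpointDir; doCheckpoint() writes "rdd-<id>" under it.
        sc.setCheckpointDir("/tmp/spark-checkpoints")
        val rdd = sc.makeRDD(1 to 1000, 4).map(_ * 2)
        rdd.checkpoint()   // mark for checkpointing; doCheckpoint() runs after the next job
        rdd.count()        // materializes the RDD and triggers the checkpoint write job
        // The RDD's dependencies now point at a CheckpointRDD that reads the part files back.
        println(rdd.collect().length)
      }
    }
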
diff --git a/core/src/main/scala/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
new file mode 100644
index 0000000000..c673ab6aaa
--- /dev/null
+++ b/core/src/main/scala/spark/rdd/CheckpointRDD.scala
@@ -0,0 +1,117 @@
+package spark.rdd
+
+import spark._
+import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.io.{NullWritable, BytesWritable}
+import org.apache.hadoop.util.ReflectionUtils
+import org.apache.hadoop.fs.Path
+import java.io.{File, IOException, EOFException}
+import java.text.NumberFormat
+
+private[spark] class CheckpointRDDSplit(idx: Int, val splitFile: String) extends Split {
+  override val index: Int = idx
+}
+
+class CheckpointRDD[T: ClassManifest](sc: SparkContext, checkpointPath: String)
+  extends RDD[T](sc, Nil) {
+
+  @transient val path = new Path(checkpointPath)
+  @transient val fs = path.getFileSystem(new Configuration())
+
+  @transient val splits_ : Array[Split] = {
+    val splitFiles = fs.listStatus(path).map(_.getPath).filter(_.getName.startsWith("part-")).map(_.toString).sorted
+    splitFiles.zipWithIndex.map(x => new CheckpointRDDSplit(x._2, x._1)).toArray
+  }
+
+  override def getSplits = splits_
+
+  override def getPreferredLocations(split: Split): Seq[String] = {
+    val status = fs.getFileStatus(new Path(split.asInstanceOf[CheckpointRDDSplit].splitFile))
+    val locations = fs.getFileBlockLocations(status, 0, status.getLen)
+    locations.headOption.toList.flatMap(_.getHosts).filter(_ != "localhost")
+  }
+
+  override def compute(split: Split): Iterator[T] = {
+    CheckpointRDD.readFromFile(split.asInstanceOf[CheckpointRDDSplit].splitFile)
+  }
+
+  override def checkpoint() {
+    // Do nothing. CheckpointRDD should not be checkpointed again.
+  }
+}
+
+private[spark] object CheckpointRDD extends Logging {
+
+  def splitIdToFileName(splitId: Int): String = {
+    val numfmt = NumberFormat.getInstance()
+    numfmt.setMinimumIntegerDigits(5)
+    numfmt.setGroupingUsed(false)
+    "part-"  + numfmt.format(splitId)
+  }
+
+  def writeToFile[T](path: String, blockSize: Int = -1)(context: TaskContext, iterator: Iterator[T]) {
+    val outputDir = new Path(path)
+    val fs = outputDir.getFileSystem(new Configuration())
+
+    val finalOutputName = splitIdToFileName(context.splitId)
+    val finalOutputPath = new Path(outputDir, finalOutputName)
+    val tempOutputPath = new Path(outputDir, "." + finalOutputName + "-attempt-" + context.attemptId)
+
+    if (fs.exists(tempOutputPath)) {
+      throw new IOException("Checkpoint failed: temporary path " +
+        tempOutputPath + " already exists")
+    }
+    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+
+    val fileOutputStream = if (blockSize < 0) {
+      fs.create(tempOutputPath, false, bufferSize)
+    } else {
+      // This is mainly for testing purposes
+      fs.create(tempOutputPath, false, bufferSize, fs.getDefaultReplication, blockSize)
+    }
+    val serializer = SparkEnv.get.serializer.newInstance()
+    val serializeStream = serializer.serializeStream(fileOutputStream)
+    serializeStream.writeAll(iterator)
+    serializeStream.close()
+
+    if (!fs.rename(tempOutputPath, finalOutputPath)) {
+      if (!fs.delete(finalOutputPath, true)) {
+        throw new IOException("Checkpoint failed: failed to delete earlier output of task "
+          + context.attemptId)
+      }
+      if (!fs.rename(tempOutputPath, finalOutputPath)) {
+        throw new IOException("Checkpoint failed: failed to save output of task: "
+          + context.attemptId)
+      }
+    }
+  }
+
+  def readFromFile[T](path: String): Iterator[T] = {
+    val inputPath = new Path(path)
+    val fs = inputPath.getFileSystem(new Configuration())
+    val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
+    val fileInputStream = fs.open(inputPath, bufferSize)
+    val serializer = SparkEnv.get.serializer.newInstance()
+    val deserializeStream = serializer.deserializeStream(fileInputStream)
+    deserializeStream.asIterator.asInstanceOf[Iterator[T]]
+  }
+
+  // Test whether CheckpointRDD generates the expected number of splits despite
+  // each split file having multiple blocks. This needs to be run on a
+  // cluster (Mesos or standalone) using HDFS.
+  def main(args: Array[String]) {
+    import spark._
+
+    val Array(cluster, hdfsPath) = args
+    val sc = new SparkContext(cluster, "CheckpointRDD Test")
+    val rdd = sc.makeRDD(1 to 10, 10).flatMap(x => 1 to 10000)
+    val path = new Path(hdfsPath, "temp")
+    val fs = path.getFileSystem(new Configuration())
+    sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, 10) _)
+    val cpRDD = new CheckpointRDD[Int](sc, path.toString)
+    assert(cpRDD.splits.length == rdd.splits.length, "Number of splits is not the same")
+    assert(cpRDD.collect.toList == rdd.collect.toList, "Data of splits not the same")
+    fs.delete(path, true)
+  }
+}
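
writeToFile commits each partition by serializing into a hidden temporary file and renaming it to the name produced by splitIdToFileName, and getSplits later rediscovers those part files with a plain lexicographic sort. That sort only agrees with split numbering because the names are zero-padded to five digits. A small sketch of that naming contract follows; the SplitFileNames object is illustrative and sits in package spark.rdd only because the CheckpointRDD companion object is private[spark].

    package spark.rdd

    object SplitFileNames {
      def main(args: Array[String]) {
        // Five-digit zero-padding keeps the lexicographic sort used by getSplits
        // in numeric order for up to 100000 splits.
        assert(CheckpointRDD.splitIdToFileName(0) == "part-00000")
        assert(CheckpointRDD.splitIdToFileName(42) == "part-00042")
        assert(CheckpointRDD.splitIdToFileName(12345) == "part-12345")
      }
    }
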
diff --git a/core/src/main/scala/spark/rdd/HadoopRDD.scala b/core/src/main/scala/spark/rdd/HadoopRDD.scala
index fce190b860..eca51758e4 100644
--- a/core/src/main/scala/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/spark/rdd/HadoopRDD.scala
@@ -25,8 +25,7 @@ import spark.Split
  * A Spark split class that wraps around a Hadoop InputSplit.
  */
 private[spark] class HadoopSplit(rddId: Int, idx: Int, @transient s: InputSplit)
-  extends Split
-  with Serializable {
+  extends Split {
   
   val inputSplit = new SerializableWritable[InputSplit](s)
 
@@ -117,6 +116,6 @@ class HadoopRDD[K, V](
   }
 
   override def checkpoint() {
-    // Do nothing. Hadoop RDD cannot be checkpointed.
+    // Do nothing. Hadoop RDD should not be checkpointed.
   }
 }
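
Dropping the redundant "with Serializable" from HadoopSplit (and writing CheckpointRDDSplit the same way) relies on the Split trait itself being Serializable, since splits are still shipped to tasks. A quick round-trip sketch of that assumption follows; the SplitSerializationCheck object and the hdfs:// path are illustrative placeholders, and the check only holds if Split is indeed Serializable in this tree.

    package spark.rdd   // CheckpointRDDSplit is private[spark]

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

    object SplitSerializationCheck {
      def main(args: Array[String]) {
        val split = new CheckpointRDDSplit(3, "hdfs://namenode/checkpoints/rdd-1/part-00003")
        val bytes = new ByteArrayOutputStream()
        val out = new ObjectOutputStream(bytes)
        out.writeObject(split)   // throws NotSerializableException if Split is not Serializable
        out.close()
        val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
        val copy = in.readObject().asInstanceOf[CheckpointRDDSplit]
        assert(copy.index == 3 && copy.splitFile == split.splitFile)
      }
    }
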
-- 
GitLab