diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 0e5a86f44e410ff7bd00d5a7a3e2985dd680e64b..8eed46759f340c9b7e7b1bb60d7a10a542881e2a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -1906,6 +1906,16 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
    * be a HDFS path if running on a cluster.
    */
   def setCheckpointDir(directory: String) {
+
+    // If we are running on a cluster, log a warning if the directory is local.
+    // Otherwise, the driver may attempt to reconstruct the checkpointed RDD from
+    // its own local file system, which is incorrect because the checkpoint files
+    // are actually on the executor machines.
+    if (!isLocal && Utils.nonLocalPaths(directory).isEmpty) {
+      logWarning("Checkpoint directory must be non-local " +
+        "if Spark is running on a cluster: " + directory)
+    }
+
     checkpointDir = Option(directory).map { dir =>
       val path = new Path(dir, UUID.randomUUID().toString)
       val fs = path.getFileSystem(hadoopConfiguration)
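
A minimal driver-side sketch of how this plays out for a user (the app name and the
hdfs:///tmp/checkpoints path below are placeholders, not part of the patch): on a cluster,
setCheckpointDir should point at a shared filesystem such as HDFS, otherwise the new warning
fires and the driver may later fail to read the checkpoint files back.

    import org.apache.spark.{SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setAppName("checkpoint-example"))
    sc.setCheckpointDir("hdfs:///tmp/checkpoints")  // must be non-local when on a cluster
    val rdd = sc.parallelize(1 to 100).map(_ * 2)
    rdd.cache()        // recommended: avoids recomputing the RDD when it is written out
    rdd.checkpoint()   // must be called before any job has been run on this RDD
    rdd.count()        // the first action triggers the actual checkpoint write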
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 33e6998b2cb1000e65ed5a55c0d4518247f59148..e17bd47905d7a14510886cf35bb6df348f5a2ec8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -28,7 +28,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.deploy.SparkHadoopUtil
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
-private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
+private[spark] class CheckpointRDDPartition(val index: Int) extends Partition
 
 /**
  * This RDD represents an RDD checkpoint file (similar to HadoopRDD).
@@ -37,9 +37,11 @@ private[spark]
 class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
   extends RDD[T](sc, Nil) {
 
-  val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
+  private val broadcastedConf = sc.broadcast(new SerializableConfiguration(sc.hadoopConfiguration))
 
-  @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
+  @transient private val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
+
+  override def getCheckpointFile: Option[String] = Some(checkpointPath)
 
   override def getPartitions: Array[Partition] = {
     val cpath = new Path(checkpointPath)
@@ -59,9 +61,6 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
     Array.tabulate(numPartitions)(i => new CheckpointRDDPartition(i))
   }
 
-  checkpointData = Some(new RDDCheckpointData[T](this))
-  checkpointData.get.cpFile = Some(checkpointPath)
-
   override def getPreferredLocations(split: Partition): Seq[String] = {
     val status = fs.getFileStatus(new Path(checkpointPath,
       CheckpointRDD.splitIdToFile(split.index)))
@@ -74,9 +73,9 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
     CheckpointRDD.readFromFile(file, broadcastedConf, context)
   }
 
-  override def checkpoint() {
-    // Do nothing. CheckpointRDD should not be checkpointed.
-  }
+  // CheckpointRDD should not be checkpointed again
+  override def checkpoint(): Unit = { }
+  override def doCheckpoint(): Unit = { }
 }
 
 private[spark] object CheckpointRDD extends Logging {
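
The user-visible effect of swapping in a CheckpointRDD as the new parent can be observed
through the public RDD API; a rough sketch, assuming a SparkContext with a checkpoint
directory already set as above:

    val data = sc.parallelize(1 to 4).flatMap(x => 1 to x)
    data.checkpoint()
    data.count()                         // materializes the RDD and writes the checkpoint
    assert(data.isCheckpointed)          // now backed by a CheckpointRDD
    println(data.getCheckpointFile)      // Some(<checkpoint dir>/rdd-<id>)
    println(data.dependencies.head.rdd)  // no longer the original parent RDD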
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index cac6e3b477e168e7d50cff68a89e0dda3a7a7df1..9f7ebae3e9af3bcf7ff83241163cc49d3a8329c2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -194,7 +194,7 @@ abstract class RDD[T: ClassTag](
   @transient private var partitions_ : Array[Partition] = null
 
   /** An Option holding our checkpoint RDD, if we are checkpointed */
-  private def checkpointRDD: Option[RDD[T]] = checkpointData.flatMap(_.checkpointRDD)
+  private def checkpointRDD: Option[CheckpointRDD[T]] = checkpointData.flatMap(_.checkpointRDD)
 
   /**
    * Get the list of dependencies of this RDD, taking into account whether the
@@ -1451,12 +1451,16 @@ abstract class RDD[T: ClassTag](
    * executed on this RDD. It is strongly recommended that this RDD is persisted in
    * memory, otherwise saving it on a file will require recomputation.
    */
-  def checkpoint() {
+  def checkpoint(): Unit = {
     if (context.checkpointDir.isEmpty) {
       throw new SparkException("Checkpoint directory has not been set in the SparkContext")
     } else if (checkpointData.isEmpty) {
-      checkpointData = Some(new RDDCheckpointData(this))
-      checkpointData.get.markForCheckpoint()
+      // NOTE: we use a global lock here due to complexities downstream with ensuring
+      // children RDD partitions point to the correct parent partitions. In the future
+      // we should revisit this consideration.
+      RDDCheckpointData.synchronized {
+        checkpointData = Some(new RDDCheckpointData(this))
+      }
     }
   }
 
@@ -1497,7 +1501,7 @@ abstract class RDD[T: ClassTag](
   private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
 
   /** Returns the first parent RDD */
-  protected[spark] def firstParent[U: ClassTag] = {
+  protected[spark] def firstParent[U: ClassTag]: RDD[U] = {
     dependencies.head.rdd.asInstanceOf[RDD[U]]
   }
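
A quick sketch of the contract enforced at the top of checkpoint() above, assuming a fresh
SparkContext with no checkpoint directory set yet (the /tmp/checkpoints path is a placeholder):
calling checkpoint() before setCheckpointDir throws, and calling it twice on the same RDD is
harmless because checkpointData is only created once.

    val rdd = sc.parallelize(Seq(1, 2, 3))
    try {
      rdd.checkpoint()   // SparkException if no checkpoint directory has been set
    } catch {
      case e: org.apache.spark.SparkException => println(s"Expected: ${e.getMessage}")
    }
    sc.setCheckpointDir("/tmp/checkpoints")
    rdd.checkpoint()
    rdd.checkpoint()     // second call is a no-op: checkpointData is already defined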
 
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index acbd31aacdf59b027f78d7b271764d77a0d1f923..4f954363bed8e738db9f1ccf30b09ebae6ed2eb1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -22,16 +22,15 @@ import scala.reflect.ClassTag
 import org.apache.hadoop.fs.Path
 
 import org.apache.spark._
-import org.apache.spark.scheduler.{ResultTask, ShuffleMapTask}
 import org.apache.spark.util.SerializableConfiguration
 
 /**
  * Enumeration to manage state transitions of an RDD through checkpointing
- * [ Initialized --> marked for checkpointing --> checkpointing in progress --> checkpointed ]
+ * [ Initialized --> checkpointing in progress --> checkpointed ].
  */
 private[spark] object CheckpointState extends Enumeration {
   type CheckpointState = Value
-  val Initialized, MarkedForCheckpoint, CheckpointingInProgress, Checkpointed = Value
+  val Initialized, CheckpointingInProgress, Checkpointed = Value
 }
 
 /**
@@ -46,37 +45,37 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
   import CheckpointState._
 
   // The checkpoint state of the associated RDD.
-  var cpState = Initialized
+  private var cpState = Initialized
 
   // The file to which the associated RDD has been checkpointed
-  @transient var cpFile: Option[String] = None
+  private var cpFile: Option[String] = None
 
   // The CheckpointRDD created from the checkpoint file, that is, the new parent of the associated RDD.
-  var cpRDD: Option[RDD[T]] = None
+  // This is defined if and only if `cpState` is `Checkpointed`.
+  private var cpRDD: Option[CheckpointRDD[T]] = None
 
-  // Mark the RDD for checkpointing
-  def markForCheckpoint() {
-    RDDCheckpointData.synchronized {
-      if (cpState == Initialized) cpState = MarkedForCheckpoint
-    }
-  }
+  // TODO: are we sure we need to use a global lock in the following methods?
 
   // Is the RDD already checkpointed
-  def isCheckpointed: Boolean = {
-    RDDCheckpointData.synchronized { cpState == Checkpointed }
+  def isCheckpointed: Boolean = RDDCheckpointData.synchronized {
+    cpState == Checkpointed
   }
 
   // Get the file to which this RDD was checkpointed, as an Option
-  def getCheckpointFile: Option[String] = {
-    RDDCheckpointData.synchronized { cpFile }
+  def getCheckpointFile: Option[String] = RDDCheckpointData.synchronized {
+    cpFile
   }
 
-  // Do the checkpointing of the RDD. Called after the first job using that RDD is over.
-  def doCheckpoint() {
-    // If it is marked for checkpointing AND checkpointing is not already in progress,
-    // then set it to be in progress, else return
+  /**
+   * Materialize this RDD and write its content to a reliable DFS.
+   * This is called immediately after the first action invoked on this RDD has completed.
+   */
+  def doCheckpoint(): Unit = {
+
+    // Guard against multiple threads checkpointing the same RDD by
+    // atomically flipping the state of this RDDCheckpointData
     RDDCheckpointData.synchronized {
-      if (cpState == MarkedForCheckpoint) {
+      if (cpState == Initialized) {
         cpState = CheckpointingInProgress
       } else {
         return
@@ -87,7 +86,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
     val path = RDDCheckpointData.rddCheckpointDataPath(rdd.context, rdd.id).get
     val fs = path.getFileSystem(rdd.context.hadoopConfiguration)
     if (!fs.mkdirs(path)) {
-      throw new SparkException("Failed to create checkpoint path " + path)
+      throw new SparkException(s"Failed to create checkpoint path $path")
     }
 
     // Save to file, and reload it as an RDD
@@ -99,6 +98,8 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
         cleaner.registerRDDCheckpointDataForCleanup(newRDD, rdd.id)
       }
     }
+
+    // TODO: This is expensive because it computes the RDD again unnecessarily (SPARK-8582)
     rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
     if (newRDD.partitions.length != rdd.partitions.length) {
       throw new SparkException(
@@ -113,34 +114,26 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
       rdd.markCheckpointed(newRDD)   // Update the RDD's dependencies and partitions
       cpState = Checkpointed
     }
-    logInfo("Done checkpointing RDD " + rdd.id + " to " + path + ", new parent is RDD " + newRDD.id)
-  }
-
-  // Get preferred location of a split after checkpointing
-  def getPreferredLocations(split: Partition): Seq[String] = {
-    RDDCheckpointData.synchronized {
-      cpRDD.get.preferredLocations(split)
-    }
+    logInfo(s"Done checkpointing RDD ${rdd.id} to $path, new parent is RDD ${newRDD.id}")
   }
 
-  def getPartitions: Array[Partition] = {
-    RDDCheckpointData.synchronized {
-      cpRDD.get.partitions
-    }
+  def getPartitions: Array[Partition] = RDDCheckpointData.synchronized {
+    cpRDD.get.partitions
   }
 
-  def checkpointRDD: Option[RDD[T]] = {
-    RDDCheckpointData.synchronized {
-      cpRDD
-    }
+  def checkpointRDD: Option[CheckpointRDD[T]] = RDDCheckpointData.synchronized {
+    cpRDD
   }
 }
 
 private[spark] object RDDCheckpointData {
+
+  /** Return the path of the directory to which this RDD's checkpoint data is written. */
   def rddCheckpointDataPath(sc: SparkContext, rddId: Int): Option[Path] = {
-    sc.checkpointDir.map { dir => new Path(dir, "rdd-" + rddId) }
+    sc.checkpointDir.map { dir => new Path(dir, s"rdd-$rddId") }
   }
 
+  /** Clean up the files associated with the checkpoint data for this RDD. */
   def clearRDDCheckpointData(sc: SparkContext, rddId: Int): Unit = {
     rddCheckpointDataPath(sc, rddId).foreach { path =>
       val fs = path.getFileSystem(sc.hadoopConfiguration)
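
The Initialized -> CheckpointingInProgress -> Checkpointed transition above follows a
flip-the-state-under-a-lock pattern; a stripped-down sketch of the same idea (Once and State
are illustrative names, not Spark classes):

    object Once {
      private object State extends Enumeration { val Initialized, InProgress, Done = Value }
      private var state = State.Initialized

      def run(work: () => Unit): Unit = {
        // Atomically claim the transition so that concurrent callers bail out early
        val shouldRun = this.synchronized {
          if (state == State.Initialized) { state = State.InProgress; true } else false
        }
        if (!shouldRun) return
        work()  // the expensive part runs outside the lock, exactly once
        this.synchronized { state = State.Done }
      }
    }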
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index d1761a48babbc61afe78bab0d8dea180baa24571..cc50e6d79a3e275c5fb0f5cdea52d44ee64edd86 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -46,7 +46,7 @@ class CheckpointSuite extends SparkFunSuite with LocalSparkContext with Logging
     val parCollection = sc.makeRDD(1 to 4)
     val flatMappedRDD = parCollection.flatMap(x => 1 to x)
     flatMappedRDD.checkpoint()
-    assert(flatMappedRDD.dependencies.head.rdd == parCollection)
+    assert(flatMappedRDD.dependencies.head.rdd === parCollection)
     val result = flatMappedRDD.collect()
     assert(flatMappedRDD.dependencies.head.rdd != parCollection)
     assert(flatMappedRDD.collect() === result)