diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
index 9dbec41efeada61bbb26f0641ed00834d90a8fff..d6f8b29a43dfdeb775f1d29f5d6a249c05022d98 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDAOptimizer.scala
@@ -144,6 +144,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
     this.checkpointInterval = lda.getCheckpointInterval
     this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
       checkpointInterval, graph.vertices.sparkContext)
+    this.graphCheckpointer.update(this.graph)
     this.globalTopicTotals = computeGlobalTopicTotals()
     this
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
index a835f96d5d0e3c528a19959e95bc067c2672f370..9ce6faa137c41854fe9a7c8e080655d8e3f508d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree
 import org.apache.spark.Logging
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.configuration.BoostingStrategy
 import org.apache.spark.mllib.tree.configuration.Algo._
@@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging {
       false
     }
 
+    // Prepare periodic checkpointers
+    val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+      treeStrategy.getCheckpointInterval, input.sparkContext)
+    val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
+      treeStrategy.getCheckpointInterval, input.sparkContext)
+
     timer.stop("init")
 
     logDebug("##########")
     logDebug("Building tree 0")
     logDebug("##########")
-    var data = input
 
     // Initialize tree
     timer.start("building tree 0")
-    val firstTreeModel = new DecisionTree(treeStrategy).run(data)
+    val firstTreeModel = new DecisionTree(treeStrategy).run(input)
     val firstTreeWeight = 1.0
     baseLearners(0) = firstTreeModel
     baseLearnerWeights(0) = firstTreeWeight
 
     var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
       computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
+    predErrorCheckpointer.update(predError)
     logDebug("error of gbt = " + predError.values.mean())
 
     // Note: A model of type regression is used since we require raw prediction
@@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging {
 
     var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
       computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
+    if (validate) validatePredErrorCheckpointer.update(validatePredError)
     var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
     var bestM = 1
 
-    // pseudo-residual for second iteration
-    data = predError.zip(input).map { case ((pred, _), point) =>
-      LabeledPoint(-loss.gradient(pred, point.label), point.features)
-    }
-
     var m = 1
-    while (m < numIterations) {
+    var doneLearning = false
+    while (m < numIterations && !doneLearning) {
+      // Update data with pseudo-residuals
+      val data = predError.zip(input).map { case ((pred, _), point) =>
+        LabeledPoint(-loss.gradient(pred, point.label), point.features)
+      }
+
       timer.start(s"building tree $m")
       logDebug("###################################################")
       logDebug("Gradient boosting tree iteration " + m)
       logDebug("###################################################")
       val model = new DecisionTree(treeStrategy).run(data)
       timer.stop(s"building tree $m")
-      // Create partial model
+      // Update partial model
       baseLearners(m) = model
       // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
       //       Technically, the weight should be optimized for the particular loss.
       //       However, the behavior should be reasonable, though not optimal.
       baseLearnerWeights(m) = learningRate
-      // Note: A model of type regression is used since we require raw prediction
-      val partialModel = new GradientBoostedTreesModel(
-        Regression, baseLearners.slice(0, m + 1),
-        baseLearnerWeights.slice(0, m + 1))
 
       predError = GradientBoostedTreesModel.updatePredictionError(
         input, predError, baseLearnerWeights(m), baseLearners(m), loss)
+      predErrorCheckpointer.update(predError)
       logDebug("error of gbt = " + predError.values.mean())
 
       if (validate) {
@@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging {
 
         validatePredError = GradientBoostedTreesModel.updatePredictionError(
           validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
+        validatePredErrorCheckpointer.update(validatePredError)
         val currentValidateError = validatePredError.values.mean()
         if (bestValidateError - currentValidateError < validationTol) {
-          return new GradientBoostedTreesModel(
-            boostingStrategy.treeStrategy.algo,
-            baseLearners.slice(0, bestM),
-            baseLearnerWeights.slice(0, bestM))
+          doneLearning = true
         } else if (currentValidateError < bestValidateError) {
-            bestValidateError = currentValidateError
-            bestM = m + 1
+          bestValidateError = currentValidateError
+          bestM = m + 1
         }
       }
-      // Update data with pseudo-residuals
-      data = predError.zip(input).map { case ((pred, _), point) =>
-        LabeledPoint(-loss.gradient(pred, point.label), point.features)
-      }
       m += 1
     }
 
@@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging {
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
 
+    predErrorCheckpointer.deleteAllCheckpoints()
+    validatePredErrorCheckpointer.deleteAllCheckpoints()
     if (persistedInput) input.unpersist()
 
     if (validate) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
index 2d6b01524ff3d339caba09454f2098382dd66cc6..9fd30c9b56319621fa82ebb8d8615a9eba29552c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/BoostingStrategy.scala
@@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
  *                     learning rate should be between in the interval (0, 1]
  * @param validationTol Useful when runWithValidation is used. If the error rate on the
  *                      validation input between two iterations is less than the validationTol
- *                      then stop. Ignored when [[run]] is used.
+ *                      then stop.  Ignored when
+ *                      [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
  */
 @Experimental
 case class BoostingStrategy(
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index 82c345491bb3cc0d6004076c69ae3d0dcc03e56a..a7bc77965fefd28264ab0f708417b3af8469cc32 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.Utils
 
 
 /**
@@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("Checkpointing") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    sc.setCheckpointDir(path)
+
+    val categoricalFeatures = Map.empty[Int, Int]
+    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
+    val gbt = new GBTClassifier()
+      .setMaxDepth(2)
+      .setLossType("logistic")
+      .setMaxIter(5)
+      .setStepSize(0.1)
+      .setCheckpointInterval(2)
+    val model = gbt.fit(df)
+
+    sc.checkpointDir = None
+    Utils.deleteRecursively(tempDir)
+  }
+
   // TODO: Reinstate test once runWithValidation is implemented   SPARK-7132
   /*
   test("runWithValidation stops early and performs better on a validation dataset") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
index 9682edcd9ba8402fab99b7d09ddfb3dd316ec9bc..dbdce0c9dea54eef60e753d69dc849aeee811fe0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala
@@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees =>
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.Utils
 
 
 /**
@@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(predictions.min() < -1)
   }
 
+  test("Checkpointing") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    sc.setCheckpointDir(path)
+
+    val df = sqlContext.createDataFrame(data)
+    val gbt = new GBTRegressor()
+      .setMaxDepth(2)
+      .setMaxIter(5)
+      .setStepSize(0.1)
+      .setCheckpointInterval(2)
+    val model = gbt.fit(df)
+
+    sc.checkpointDir = None
+    Utils.deleteRecursively(tempDir)
+  }
+
   // TODO: Reinstate test once runWithValidation is implemented  SPARK-7132
   /*
   test("runWithValidation stops early and performs better on a validation dataset") {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
index 2521b3342181a5fe4926e80a40d7ba2e94a85580..6fc9e8df621dfdc02ec32188e833baa0a23d409d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/GradientBoostedTreesSuite.scala
@@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
 
     val algos = Array(Regression, Regression, Classification)
     val losses = Array(SquaredError, AbsoluteError, LogLoss)
-    (algos zip losses) map {
-      case (algo, loss) => {
-        val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
-          categoricalFeaturesInfo = Map.empty)
-        val boostingStrategy =
-          new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
-        val gbtValidate = new GradientBoostedTrees(boostingStrategy)
-          .runWithValidation(trainRdd, validateRdd)
-        val numTrees = gbtValidate.numTrees
-        assert(numTrees !== numIterations)
-
-        // Test that it performs better on the validation dataset.
-        val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
-        val (errorWithoutValidation, errorWithValidation) = {
-          if (algo == Classification) {
-            val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
-            (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
-          } else {
-            (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
-          }
-        }
-        assert(errorWithValidation <= errorWithoutValidation)
-
-        // Test that results from evaluateEachIteration comply with runWithValidation.
-        // Note that convergenceTol is set to 0.0
-        val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
-        assert(evaluationArray.length === numIterations)
-        assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
-        var i = 1
-        while (i < numTrees) {
-          assert(evaluationArray(i) <= evaluationArray(i - 1))
-          i += 1
+    algos.zip(losses).foreach { case (algo, loss) =>
+      val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
+        categoricalFeaturesInfo = Map.empty)
+      val boostingStrategy =
+        new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
+      val gbtValidate = new GradientBoostedTrees(boostingStrategy)
+        .runWithValidation(trainRdd, validateRdd)
+      val numTrees = gbtValidate.numTrees
+      assert(numTrees !== numIterations)
+
+      // Test that it performs better on the validation dataset.
+      val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
+      val (errorWithoutValidation, errorWithValidation) = {
+        if (algo == Classification) {
+          val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
+          (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
+        } else {
+          (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
         }
       }
+      assert(errorWithValidation <= errorWithoutValidation)
+
+      // Test that results from evaluateEachIteration comply with runWithValidation.
+      // Note that convergenceTol is set to 0.0
+      val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
+      assert(evaluationArray.length === numIterations)
+      assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
+      var i = 1
+      while (i < numTrees) {
+        assert(evaluationArray(i) <= evaluationArray(i - 1))
+        i += 1
+      }
     }
   }
 
+  test("Checkpointing") {
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    sc.setCheckpointDir(path)
+
+    val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
+
+    val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
+      categoricalFeaturesInfo = Map.empty, checkpointInterval = 2)
+    val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1)
+
+    val gbt = GradientBoostedTrees.train(rdd, boostingStrategy)
+
+    sc.checkpointDir = None
+    Utils.deleteRecursively(tempDir)
+  }
+
 }
 
 private object GradientBoostedTreesSuite {