Skip to content
Snippets Groups Projects
Commit be7be6d4 authored by Joseph K. Bradley's avatar Joseph K. Bradley
Browse files

[SPARK-6684] [MLLIB] [ML] Add checkpointing to GBTs

Add checkpointing to GradientBoostedTrees, GBTClassifier, GBTRegressor

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7804 from jkbradley/gbt-checkpoint3 and squashes the following commits:

3fbd7ba [Joseph K. Bradley] tiny fix
b3e160c [Joseph K. Bradley] unset checkpoint dir after test
9cc3a04 [Joseph K. Bradley] added checkpointing to GBTs
parent 7f7a319c
No related branches found
No related tags found
No related merge requests found
...@@ -144,6 +144,7 @@ final class EMLDAOptimizer extends LDAOptimizer { ...@@ -144,6 +144,7 @@ final class EMLDAOptimizer extends LDAOptimizer {
this.checkpointInterval = lda.getCheckpointInterval this.checkpointInterval = lda.getCheckpointInterval
this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount]( this.graphCheckpointer = new PeriodicGraphCheckpointer[TopicCounts, TokenCount](
checkpointInterval, graph.vertices.sparkContext) checkpointInterval, graph.vertices.sparkContext)
this.graphCheckpointer.update(this.graph)
this.globalTopicTotals = computeGlobalTopicTotals() this.globalTopicTotals = computeGlobalTopicTotals()
this this
} }
......
...@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree ...@@ -20,6 +20,7 @@ package org.apache.spark.mllib.tree
import org.apache.spark.Logging import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.BoostingStrategy import org.apache.spark.mllib.tree.configuration.BoostingStrategy
import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.Algo._
...@@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging { ...@@ -184,22 +185,28 @@ object GradientBoostedTrees extends Logging {
false false
} }
// Prepare periodic checkpointers
val predErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
treeStrategy.getCheckpointInterval, input.sparkContext)
val validatePredErrorCheckpointer = new PeriodicRDDCheckpointer[(Double, Double)](
treeStrategy.getCheckpointInterval, input.sparkContext)
timer.stop("init") timer.stop("init")
logDebug("##########") logDebug("##########")
logDebug("Building tree 0") logDebug("Building tree 0")
logDebug("##########") logDebug("##########")
var data = input
// Initialize tree // Initialize tree
timer.start("building tree 0") timer.start("building tree 0")
val firstTreeModel = new DecisionTree(treeStrategy).run(data) val firstTreeModel = new DecisionTree(treeStrategy).run(input)
val firstTreeWeight = 1.0 val firstTreeWeight = 1.0
baseLearners(0) = firstTreeModel baseLearners(0) = firstTreeModel
baseLearnerWeights(0) = firstTreeWeight baseLearnerWeights(0) = firstTreeWeight
var predError: RDD[(Double, Double)] = GradientBoostedTreesModel. var predError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss) computeInitialPredictionAndError(input, firstTreeWeight, firstTreeModel, loss)
predErrorCheckpointer.update(predError)
logDebug("error of gbt = " + predError.values.mean()) logDebug("error of gbt = " + predError.values.mean())
// Note: A model of type regression is used since we require raw prediction // Note: A model of type regression is used since we require raw prediction
...@@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging { ...@@ -207,35 +214,34 @@ object GradientBoostedTrees extends Logging {
var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel. var validatePredError: RDD[(Double, Double)] = GradientBoostedTreesModel.
computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss) computeInitialPredictionAndError(validationInput, firstTreeWeight, firstTreeModel, loss)
if (validate) validatePredErrorCheckpointer.update(validatePredError)
var bestValidateError = if (validate) validatePredError.values.mean() else 0.0 var bestValidateError = if (validate) validatePredError.values.mean() else 0.0
var bestM = 1 var bestM = 1
// pseudo-residual for second iteration
data = predError.zip(input).map { case ((pred, _), point) =>
LabeledPoint(-loss.gradient(pred, point.label), point.features)
}
var m = 1 var m = 1
while (m < numIterations) { var doneLearning = false
while (m < numIterations && !doneLearning) {
// Update data with pseudo-residuals
val data = predError.zip(input).map { case ((pred, _), point) =>
LabeledPoint(-loss.gradient(pred, point.label), point.features)
}
timer.start(s"building tree $m") timer.start(s"building tree $m")
logDebug("###################################################") logDebug("###################################################")
logDebug("Gradient boosting tree iteration " + m) logDebug("Gradient boosting tree iteration " + m)
logDebug("###################################################") logDebug("###################################################")
val model = new DecisionTree(treeStrategy).run(data) val model = new DecisionTree(treeStrategy).run(data)
timer.stop(s"building tree $m") timer.stop(s"building tree $m")
// Create partial model // Update partial model
baseLearners(m) = model baseLearners(m) = model
// Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError. // Note: The setting of baseLearnerWeights is incorrect for losses other than SquaredError.
// Technically, the weight should be optimized for the particular loss. // Technically, the weight should be optimized for the particular loss.
// However, the behavior should be reasonable, though not optimal. // However, the behavior should be reasonable, though not optimal.
baseLearnerWeights(m) = learningRate baseLearnerWeights(m) = learningRate
// Note: A model of type regression is used since we require raw prediction
val partialModel = new GradientBoostedTreesModel(
Regression, baseLearners.slice(0, m + 1),
baseLearnerWeights.slice(0, m + 1))
predError = GradientBoostedTreesModel.updatePredictionError( predError = GradientBoostedTreesModel.updatePredictionError(
input, predError, baseLearnerWeights(m), baseLearners(m), loss) input, predError, baseLearnerWeights(m), baseLearners(m), loss)
predErrorCheckpointer.update(predError)
logDebug("error of gbt = " + predError.values.mean()) logDebug("error of gbt = " + predError.values.mean())
if (validate) { if (validate) {
...@@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging { ...@@ -246,21 +252,15 @@ object GradientBoostedTrees extends Logging {
validatePredError = GradientBoostedTreesModel.updatePredictionError( validatePredError = GradientBoostedTreesModel.updatePredictionError(
validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss) validationInput, validatePredError, baseLearnerWeights(m), baseLearners(m), loss)
validatePredErrorCheckpointer.update(validatePredError)
val currentValidateError = validatePredError.values.mean() val currentValidateError = validatePredError.values.mean()
if (bestValidateError - currentValidateError < validationTol) { if (bestValidateError - currentValidateError < validationTol) {
return new GradientBoostedTreesModel( doneLearning = true
boostingStrategy.treeStrategy.algo,
baseLearners.slice(0, bestM),
baseLearnerWeights.slice(0, bestM))
} else if (currentValidateError < bestValidateError) { } else if (currentValidateError < bestValidateError) {
bestValidateError = currentValidateError bestValidateError = currentValidateError
bestM = m + 1 bestM = m + 1
} }
} }
// Update data with pseudo-residuals
data = predError.zip(input).map { case ((pred, _), point) =>
LabeledPoint(-loss.gradient(pred, point.label), point.features)
}
m += 1 m += 1
} }
...@@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging { ...@@ -269,6 +269,8 @@ object GradientBoostedTrees extends Logging {
logInfo("Internal timing for DecisionTree:") logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer") logInfo(s"$timer")
predErrorCheckpointer.deleteAllCheckpoints()
validatePredErrorCheckpointer.deleteAllCheckpoints()
if (persistedInput) input.unpersist() if (persistedInput) input.unpersist()
if (validate) { if (validate) {
......
...@@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss} ...@@ -36,7 +36,8 @@ import org.apache.spark.mllib.tree.loss.{LogLoss, SquaredError, Loss}
* learning rate should be in the interval (0, 1] * learning rate should be in the interval (0, 1]
* @param validationTol Useful when runWithValidation is used. If the error rate on the * @param validationTol Useful when runWithValidation is used. If the error rate on the
* validation input between two iterations is less than the validationTol * validation input between two iterations is less than the validationTol
* then stop. Ignored when [[run]] is used. * then stop. Ignored when
* [[org.apache.spark.mllib.tree.GradientBoostedTrees.run()]] is used.
*/ */
@Experimental @Experimental
case class BoostingStrategy( case class BoostingStrategy(
......
...@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} ...@@ -28,6 +28,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame import org.apache.spark.sql.DataFrame
import org.apache.spark.util.Utils
/** /**
...@@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -76,6 +77,25 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
} }
} }
test("Checkpointing") {
  // Smoke test: fit a GBTClassifier with a checkpoint directory configured and
  // checkpointInterval = 2. There are no assertions on the result; the test
  // passes when fit() completes without throwing.
  val tempDir = Utils.createTempDir()
  try {
    sc.setCheckpointDir(tempDir.toURI.toString)

    val categoricalFeatures = Map.empty[Int, Int]
    val df: DataFrame = TreeTests.setMetadata(data, categoricalFeatures, numClasses = 2)
    val gbt = new GBTClassifier()
      .setMaxDepth(2)
      .setLossType("logistic")
      .setMaxIter(5)
      .setStepSize(0.1)
      .setCheckpointInterval(2)
    // Only successful completion matters; the fitted model is not inspected.
    gbt.fit(df)
  } finally {
    // Run the cleanup even when fit() fails, so the checkpoint directory is
    // never left set on the SparkContext and the temp directory is removed.
    sc.checkpointDir = None
    Utils.deleteRecursively(tempDir)
  }
}
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132 // TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/* /*
test("runWithValidation stops early and performs better on a validation dataset") { test("runWithValidation stops early and performs better on a validation dataset") {
......
...@@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees => ...@@ -25,7 +25,8 @@ import org.apache.spark.mllib.tree.{EnsembleTestHelper, GradientBoostedTrees =>
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row} import org.apache.spark.sql.DataFrame
import org.apache.spark.util.Utils
/** /**
...@@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext { ...@@ -88,6 +89,23 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(predictions.min() < -1) assert(predictions.min() < -1)
} }
test("Checkpointing") {
  // Smoke test: fit a GBTRegressor with a checkpoint directory configured and
  // checkpointInterval = 2. There are no assertions on the result; the test
  // passes when fit() completes without throwing.
  val tempDir = Utils.createTempDir()
  try {
    sc.setCheckpointDir(tempDir.toURI.toString)

    val df = sqlContext.createDataFrame(data)
    val gbt = new GBTRegressor()
      .setMaxDepth(2)
      .setMaxIter(5)
      .setStepSize(0.1)
      .setCheckpointInterval(2)
    // Only successful completion matters; the fitted model is not inspected.
    gbt.fit(df)
  } finally {
    // Run the cleanup even when fit() fails, so the checkpoint directory is
    // never left set on the SparkContext and the temp directory is removed.
    sc.checkpointDir = None
    Utils.deleteRecursively(tempDir)
  }
}
// TODO: Reinstate test once runWithValidation is implemented SPARK-7132 // TODO: Reinstate test once runWithValidation is implemented SPARK-7132
/* /*
test("runWithValidation stops early and performs better on a validation dataset") { test("runWithValidation stops early and performs better on a validation dataset") {
......
...@@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext ...@@ -166,43 +166,58 @@ class GradientBoostedTreesSuite extends SparkFunSuite with MLlibTestSparkContext
val algos = Array(Regression, Regression, Classification) val algos = Array(Regression, Regression, Classification)
val losses = Array(SquaredError, AbsoluteError, LogLoss) val losses = Array(SquaredError, AbsoluteError, LogLoss)
(algos zip losses) map { algos.zip(losses).foreach { case (algo, loss) =>
case (algo, loss) => { val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2,
val treeStrategy = new Strategy(algo = algo, impurity = Variance, maxDepth = 2, categoricalFeaturesInfo = Map.empty)
categoricalFeaturesInfo = Map.empty) val boostingStrategy =
val boostingStrategy = new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0)
new BoostingStrategy(treeStrategy, loss, numIterations, validationTol = 0.0) val gbtValidate = new GradientBoostedTrees(boostingStrategy)
val gbtValidate = new GradientBoostedTrees(boostingStrategy) .runWithValidation(trainRdd, validateRdd)
.runWithValidation(trainRdd, validateRdd) val numTrees = gbtValidate.numTrees
val numTrees = gbtValidate.numTrees assert(numTrees !== numIterations)
assert(numTrees !== numIterations)
// Test that it performs better on the validation dataset.
// Test that it performs better on the validation dataset. val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd)
val gbt = new GradientBoostedTrees(boostingStrategy).run(trainRdd) val (errorWithoutValidation, errorWithValidation) = {
val (errorWithoutValidation, errorWithValidation) = { if (algo == Classification) {
if (algo == Classification) { val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features))
val remappedRdd = validateRdd.map(x => new LabeledPoint(2 * x.label - 1, x.features)) (loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd))
(loss.computeError(gbt, remappedRdd), loss.computeError(gbtValidate, remappedRdd)) } else {
} else { (loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
(loss.computeError(gbt, validateRdd), loss.computeError(gbtValidate, validateRdd))
}
}
assert(errorWithValidation <= errorWithoutValidation)
// Test that results from evaluateEachIteration comply with runWithValidation.
// Note that convergenceTol is set to 0.0
val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
assert(evaluationArray.length === numIterations)
assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
var i = 1
while (i < numTrees) {
assert(evaluationArray(i) <= evaluationArray(i - 1))
i += 1
} }
} }
assert(errorWithValidation <= errorWithoutValidation)
// Test that results from evaluateEachIteration comply with runWithValidation.
// Note that validationTol is set to 0.0
val evaluationArray = gbt.evaluateEachIteration(validateRdd, loss)
assert(evaluationArray.length === numIterations)
assert(evaluationArray(numTrees) > evaluationArray(numTrees - 1))
var i = 1
while (i < numTrees) {
assert(evaluationArray(i) <= evaluationArray(i - 1))
i += 1
}
} }
} }
test("Checkpointing") {
  // Smoke test: train GradientBoostedTrees with a checkpoint directory
  // configured and checkpointInterval = 2 on the tree Strategy. There are no
  // assertions on the result; the test passes when train() completes without
  // throwing.
  val tempDir = Utils.createTempDir()
  try {
    sc.setCheckpointDir(tempDir.toURI.toString)

    val rdd = sc.parallelize(GradientBoostedTreesSuite.data, 2)
    val treeStrategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
      categoricalFeaturesInfo = Map.empty, checkpointInterval = 2)
    val boostingStrategy = new BoostingStrategy(treeStrategy, SquaredError, 5, 0.1)
    // Only successful completion matters; the trained model is not inspected.
    GradientBoostedTrees.train(rdd, boostingStrategy)
  } finally {
    // Run the cleanup even when train() fails, so the checkpoint directory is
    // never left set on the SparkContext and the temp directory is removed.
    sc.checkpointDir = None
    Utils.deleteRecursively(tempDir)
  }
}
} }
private object GradientBoostedTreesSuite { private object GradientBoostedTreesSuite {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment