Skip to content
Snippets Groups Projects
Commit d78fbcc3 authored by BenFradet's avatar BenFradet Committed by Joseph K. Bradley
Browse files

[SPARK-14570][ML] Log instrumentation in Random forests

## What changes were proposed in this pull request?

Added Instrumentation logging to DecisionTree{Classifier,Regressor} and RandomForest{Classifier,Regressor}

## How was this patch tested?

No tests involved since it's logging related.

Author: BenFradet <benjamin.fradet@gmail.com>

Closes #12536 from BenFradet/SPARK-14570.
parent af32f4ae
No related branches found
No related tags found
No related merge requests found
......@@ -88,17 +88,30 @@ class DecisionTreeClassifier @Since("1.4.0") (
val numClasses: Int = getNumClasses(dataset)
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeClassificationModel]
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
instr.logSuccess(m)
m
}
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeClassificationModel = {
val instr = Instrumentation.create(this, data)
instr.logParams(params: _*)
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 0L, parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeClassificationModel]
seed = 0L, instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeClassificationModel]
instr.logSuccess(m)
m
}
/** (private[ml]) Create a Strategy instance to use with the old API. */
......
......@@ -105,11 +105,18 @@ class RandomForestClassifier @Since("1.4.0") (
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset, numClasses)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses, OldAlgo.Classification, getOldImpurity)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeClassificationModel])
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
val trees = RandomForest
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeClassificationModel])
val numFeatures = oldDataset.first().features.size
new RandomForestClassificationModel(trees, numFeatures, numClasses)
val m = new RandomForestClassificationModel(trees, numFeatures, numClasses)
instr.logSuccess(m)
m
}
@Since("1.4.1")
......
......@@ -88,17 +88,30 @@ class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures)
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
instr.logSuccess(m)
m
}
/** (private[ml]) Train a decision tree on an RDD */
private[ml] def train(data: RDD[LabeledPoint],
oldStrategy: OldStrategy): DecisionTreeRegressionModel = {
val instr = Instrumentation.create(this, data)
instr.logParams(params: _*)
val trees = RandomForest.run(data, oldStrategy, numTrees = 1, featureSubsetStrategy = "all",
seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
seed = $(seed), instr = Some(instr), parentUID = Some(uid))
val m = trees.head.asInstanceOf[DecisionTreeRegressionModel]
instr.logSuccess(m)
m
}
/** (private[ml]) Create a Strategy instance to use with the old API. */
......@@ -167,7 +180,7 @@ class DecisionTreeRegressionModel private[ml] (
override protected def transformImpl(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf { (features: Vector) => predict(features) }
val predictVarianceUDF = udf { (features: Vector) => predictVariance(features) }
var output = dataset.toDF
var output = dataset.toDF()
if ($(predictionCol).nonEmpty) {
output = output.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
......
......@@ -99,11 +99,18 @@ class RandomForestRegressor @Since("1.4.0") (@Since("1.4.0") override val uid: S
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy =
super.getOldStrategy(categoricalFeatures, numClasses = 0, OldAlgo.Regression, getOldImpurity)
val trees =
RandomForest.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed)
.map(_.asInstanceOf[DecisionTreeRegressionModel])
val instr = Instrumentation.create(this, oldDataset)
instr.logParams(params: _*)
val trees = RandomForest
.run(oldDataset, strategy, getNumTrees, getFeatureSubsetStrategy, getSeed, Some(instr))
.map(_.asInstanceOf[DecisionTreeRegressionModel])
val numFeatures = oldDataset.first().features.size
new RandomForestRegressionModel(trees, numFeatures)
val m = new RandomForestRegressionModel(trees, numFeatures)
instr.logSuccess(m)
m
}
@Since("1.4.0")
......
......@@ -26,6 +26,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.regression.DecisionTreeRegressionModel
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.Instrumentation
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
......@@ -80,6 +81,7 @@ private[spark] object RandomForest extends Logging {
/**
* Train a random forest.
*
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @return an unweighted set of trees
*/
......@@ -89,6 +91,7 @@ private[spark] object RandomForest extends Logging {
numTrees: Int,
featureSubsetStrategy: String,
seed: Long,
instr: Option[Instrumentation[_]],
parentUID: Option[String] = None): Array[DecisionTreeModel] = {
val timer = new TimeTracker()
......@@ -100,13 +103,14 @@ private[spark] object RandomForest extends Logging {
val retaggedInput = input.retag(classOf[LabeledPoint])
val metadata =
DecisionTreeMetadata.buildMetadata(retaggedInput, strategy, numTrees, featureSubsetStrategy)
logDebug("algo = " + strategy.algo)
logDebug("numTrees = " + numTrees)
logDebug("seed = " + seed)
logDebug("maxBins = " + metadata.maxBins)
logDebug("featureSubsetStrategy = " + featureSubsetStrategy)
logDebug("numFeaturesPerNode = " + metadata.numFeaturesPerNode)
logDebug("subsamplingRate = " + strategy.subsamplingRate)
instr match {
case Some(instrumentation) =>
instrumentation.logNumFeatures(metadata.numFeatures)
instrumentation.logNumClasses(metadata.numClasses)
case None =>
logInfo("numFeatures: " + metadata.numFeatures)
logInfo("numClasses: " + metadata.numClasses)
}
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
......@@ -610,7 +614,9 @@ private[spark] object RandomForest extends Logging {
}
/**
* Calculate the impurity statistics for a give (feature, split) based upon left/right aggregates.
* Calculate the impurity statistics for a given (feature, split) based upon left/right
* aggregates.
*
* @param stats the recycle impurity statistics for this feature's all splits,
* only 'impurity' and 'impurityCalculator' are valid between each iteration
* @param leftImpurityCalculator left node aggregates for this (feature, split)
......@@ -668,6 +674,7 @@ private[spark] object RandomForest extends Logging {
/**
* Find the best split for a node.
*
* @param binAggregates Bin statistics.
* @return tuple for best split: (Split, information gain, prediction at node)
*/
......@@ -940,6 +947,7 @@ private[spark] object RandomForest extends Logging {
* NOTE: Returned number of splits is set based on `featureSamples` and
* could be different from the specified `numSplits`.
* The `numSplits` attribute in the `DecisionTreeMetadata` class will be set accordingly.
*
* @param featureSamples feature values of each sample
* @param metadata decision tree metadata
* NOTE: `metadata.numbins` will be changed accordingly
......@@ -1083,6 +1091,7 @@ private[spark] object RandomForest extends Logging {
/**
* Get the number of values to be stored for this node in the bin aggregates.
*
* @param featureSubset Indices of features which may be split at this node.
* If None, then use all features.
*/
......
......@@ -62,8 +62,7 @@ class DecisionTree private[spark] (private val strategy: Strategy, private val s
*/
@Since("1.2.0")
def run(input: RDD[LabeledPoint]): DecisionTreeModel = {
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = seed)
val rf = new RandomForest(strategy, numTrees = 1, featureSubsetStrategy = "all", seed = seed)
val rfModel = rf.run(input)
rfModel.trees(0)
}
......@@ -88,7 +87,7 @@ object DecisionTree extends Serializable with Logging {
* categorical), depth of the tree, quantile calculation strategy, etc.
* @return DecisionTreeModel that can be used for prediction.
*/
@Since("1.0.0")
@Since("1.0.0")
def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
new DecisionTree(strategy).run(input)
}
......
......@@ -45,10 +45,10 @@ import org.apache.spark.util.Utils
* - sqrt: recommended by Breiman manual for random forests
* - The defaults of sqrt (classification) and onethird (regression) match the R randomForest
* package.
*
* @see [[http://www.stat.berkeley.edu/~breiman/randomforest2001.pdf Breiman (2001)]]
* @see [[http://www.stat.berkeley.edu/~breiman/Using_random_forests_V3.1.pdf Breiman manual for
* random forests]]
*
* @param strategy The configuration parameters for the random forest algorithm which specify
* the type of random forest (classification or regression), feature type
* (continuous, categorical), depth of the tree, quantile calculation strategy,
......@@ -91,7 +91,7 @@ private class RandomForest (
*/
def run(input: RDD[LabeledPoint]): RandomForestModel = {
val trees: Array[NewDTModel] =
NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong)
NewRandomForest.run(input, strategy, numTrees, featureSubsetStrategy, seed.toLong, None)
new RandomForestModel(strategy.algo, trees.map(_.toOld))
}
......
......@@ -322,7 +322,7 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
numClasses = 2, categoricalFeaturesInfo = Map(0 -> 3), maxBins = 3)
val model = RandomForest.run(input, strategy, numTrees = 1, featureSubsetStrategy = "all",
seed = 42).head
seed = 42, instr = None).head
model.rootNode match {
case n: InternalNode => n.split match {
case s: CategoricalSplit =>
......@@ -345,9 +345,9 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
new OldStrategy(OldAlgo.Classification, Entropy, 3, 2, 100, maxMemoryInMB = 0)
val tree1 = RandomForest.run(rdd, strategy1, numTrees = 1, featureSubsetStrategy = "all",
seed = 42).head
seed = 42, instr = None).head
val tree2 = RandomForest.run(rdd, strategy2, numTrees = 1, featureSubsetStrategy = "all",
seed = 42).head
seed = 42, instr = None).head
def getChildren(rootNode: Node): Array[InternalNode] = rootNode match {
case n: InternalNode =>
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment