diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 91dc98569a21b12ebbeab56f9776b66cb84e9d9f..dd9a5f261f60f86c0190fe5b063c0b18f0b0518a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -244,8 +244,7 @@ private[ml] object RandomForest extends Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         val featureSplits = splits(featureIndex)
@@ -253,8 +252,6 @@ private[ml] object RandomForest extends Logging {
         while (splitIndex < numSplits) {
           if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
           }
           splitIndex += 1
         }
@@ -394,6 +391,7 @@ private[ml] object RandomForest extends Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
 
@@ -658,7 +656,7 @@ private[ml] object RandomForest extends Logging {
 
     // Calculate InformationGain and ImpurityStats if current node is top node
     val level = LearningNode.indexToLevel(node.id)
-    var gainAndImpurityStats: ImpurityStats = if (level ==0) {
+    var gainAndImpurityStats: ImpurityStats = if (level == 0) {
       null
     } else {
       node.stats
@@ -697,13 +695,12 @@ private[ml] object RandomForest extends Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val (leftChildOffset, rightChildOffset) =
-            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              val rightChildStats = binAggregates.getParentImpurityCalculator()
+                .subtract(leftChildStats)
               gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
                 leftChildStats, rightChildStats, binAggregates.metadata)
               (splitIndex, gainAndImpurityStats)
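Note on the hunks above: for unordered categorical features, the aggregator now stores only left-child bins for each candidate split; right-child statistics are recovered by subtracting the left-child stats from the parent node's stats, via `getParentImpurityCalculator().subtract(leftChildStats)`. A minimal standalone sketch of that identity follows (toy arrays and an invented object name, not Spark's API):

```scala
// Toy sketch: recover right-child per-class counts for one candidate split
// as parent - left, mirroring the subtraction pattern used in the patch.
object ParentMinusLeft {
  def main(args: Array[String]): Unit = {
    // Weighted per-class counts for all points reaching the node (classes 0 and 1).
    val parentStats = Array(10.0, 6.0)
    // Counts for the points that one candidate split routes to the left child.
    val leftStats = Array(4.0, 1.0)
    // The right child needs no aggregated bins of its own: subtract left from parent.
    val rightStats = parentStats.zip(leftStats).map { case (p, l) => p - l }
    println(rightStats.mkString(", ")) // prints: 6.0, 5.0
  }
}
```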
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 18f66e65f19ca53e2fc4dcb0cb4958f88915b4d1..c0934d241f50a59018706267729d61e0191448db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -52,6 +52,7 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
 
   /**
    * Method to train a decision tree model over an RDD
+   *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return DecisionTreeModel that can be used for prediction.
    */
@@ -368,8 +369,7 @@ object DecisionTree extends Serializable with Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
@@ -377,9 +377,6 @@ object DecisionTree extends Serializable with Logging {
         while (splitIndex < numSplits) {
           if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
-              instanceWeight)
           }
           splitIndex += 1
         }
@@ -521,6 +518,7 @@ object DecisionTree extends Serializable with Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
 
@@ -847,13 +845,12 @@ object DecisionTree extends Serializable with Logging {
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
       } else if (binAggregates.metadata.isUnordered(featureIndex)) {
         // Unordered categorical feature
-        val (leftChildOffset, rightChildOffset) =
-          binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
        val (bestFeatureSplitIndex, bestFeatureGainStats) =
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-            val rightChildStats =
-              binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+            val rightChildStats = binAggregates.getParentImpurityCalculator()
+              .subtract(leftChildStats)
             predictWithImpurity = Some(predictWithImpurity.getOrElse(
               calculatePredictImpurity(leftChildStats, rightChildStats)))
             val gainStats = calculateGainForSplit(leftChildStats,
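Both tree implementations now call `updateParent` exactly once per (node, data point) in their bin sequence operator, so the parent stats vector holds exact weighted class counts even on the first pass, when `node.stats` is still null. A toy accumulator showing the intended semantics (invented class, not the Spark API):

```scala
// Toy per-node aggregator: parentStats(c) accumulates total instance weight
// of class c, the role played by DTStatsAggregator.updateParent in the patch.
object UpdateParentSketch {
  final class NodeAgg(numClasses: Int) {
    val parentStats = new Array[Double](numClasses)
    def updateParent(label: Double, instanceWeight: Double): Unit = {
      parentStats(label.toInt) += instanceWeight
    }
  }

  def main(args: Array[String]): Unit = {
    val agg = new NodeAgg(numClasses = 2)
    // (label, instanceWeight) pairs reaching this node; each updates the parent once.
    Seq((0.0, 1.0), (1.0, 1.0), (0.0, 2.0)).foreach { case (label, w) =>
      agg.updateParent(label, w)
    }
    println(agg.parentStats.mkString(", ")) // prints: 3.0, 1.0
  }
}
```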
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 7985ed4b4c0fa11905e420e3766060b93a6cd4ef..c745e9f8dbed56e5b42d1a1e60c14756cd14832b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -73,25 +73,33 @@ private[spark] class DTStatsAggregator(
    * Flat array of elements.
    * Index for start of stats for a (feature, bin) is:
    *   index = featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features,
-   *       the left child stats have binIndex in [0, numBins(featureIndex) / 2))
-   *       and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
    */
   private val allStats: Array[Double] = new Array[Double](allStatsSize)
 
+  /**
+   * Array of parent node sufficient stats.
+   *
+   * Note: this is necessary because stats for the parent node are not available
+   *       on the first iteration of tree learning.
+   */
+  private val parentStats: Array[Double] = new Array[Double](statsSize)
 
   /**
    * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
-   * @param featureOffset  For ordered features, this is a pre-computed (node, feature) offset
+   * @param featureOffset  This is a pre-computed (node, feature) offset
    *                       from [[getFeatureOffset]].
-   *                       For unordered features, this is a pre-computed
-   *                       (node, feature, left/right child) offset from
-   *                       [[getLeftRightFeatureOffsets]].
    */
   def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
     impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
   }
 
+  /**
+   * Get an [[ImpurityCalculator]] for the parent node.
+   */
+  def getParentImpurityCalculator(): ImpurityCalculator = {
+    impurityAggregator.getCalculator(parentStats, 0)
+  }
+
   /**
    * Update the stats for a given (feature, bin) for ordered features, using the given label.
    */
@@ -100,14 +108,18 @@ private[spark] class DTStatsAggregator(
     impurityAggregator.update(allStats, i, label, instanceWeight)
   }
 
+  /**
+   * Update the parent node stats using the given label.
+   */
+  def updateParent(label: Double, instanceWeight: Double): Unit = {
+    impurityAggregator.update(parentStats, 0, label, instanceWeight)
+  }
+
   /**
    * Faster version of [[update]].
    * Update the stats for a given (feature, bin), using the given label.
-   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   * @param featureOffset  This is a pre-computed feature offset
    *                       from [[getFeatureOffset]].
-   *                       For unordered features, this is a pre-computed
-   *                       (feature, left/right child) offset from
-   *                       [[getLeftRightFeatureOffsets]].
    */
   def featureUpdate(
       featureOffset: Int,
@@ -124,22 +136,10 @@ private[spark] class DTStatsAggregator(
    */
   def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
 
-  /**
-   * Pre-compute feature offset for use with [[featureUpdate]].
-   * For unordered features only.
-   */
-  def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
-    val baseOffset = featureOffsets(featureIndex)
-    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
-  }
-
   /**
    * For a given feature, merge the stats for two bins.
-   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   * @param featureOffset  This is a pre-computed feature offset
    *                       from [[getFeatureOffset]].
-   *                       For unordered features, this is a pre-computed
-   *                       (feature, left/right child) offset from
-   *                       [[getLeftRightFeatureOffsets]].
    * @param binIndex  The other bin is merged into this bin.
    * @param otherBinIndex  This bin is not modified.
    */
@@ -162,6 +162,17 @@ private[spark] class DTStatsAggregator(
       allStats(i) += other.allStats(i)
       i += 1
     }
+
+    require(statsSize == other.statsSize,
+      s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
+      s"stats vectors. This aggregator's parent stats are length $statsSize, " +
+      s"but the other is ${other.statsSize}.")
+    var j = 0
+    while (j < statsSize) {
+      parentStats(j) += other.parentStats(j)
+      j += 1
+    }
+
     this
   }
 }
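`merge` now also sums the two aggregators' `parentStats` vectors elementwise, guarded by a length check; without this, partition-local parent counts would be dropped during the reduce. A standalone sketch of that invariant (hypothetical helper, not Spark code):

```scala
// Sketch of the merge step added in this patch: partial parent-stats vectors
// from two partitions are combined by elementwise addition.
object MergeSketch {
  def mergeParentStats(a: Array[Double], b: Array[Double]): Array[Double] = {
    require(a.length == b.length,
      s"parent stats lengths differ: ${a.length} vs ${b.length}")
    val out = a.clone()
    var j = 0
    while (j < out.length) {
      out(j) += b(j)
      j += 1
    }
    out
  }

  def main(args: Array[String]): Unit = {
    // Partial per-class counts from two partitions for the same node.
    println(mergeParentStats(Array(1.0, 2.0), Array(3.0, 4.0)).mkString(", ")) // 4.0, 6.0
  }
}
```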
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index df13d291ca3962714626b19ea85e8258e0f435ca..4f27dc44eff4d1004d677c2168701e79ce595926 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -67,11 +67,11 @@ private[spark] class DecisionTreeMetadata(
 
   /**
    * Number of splits for the given feature.
-   * For unordered features, there are 2 bins per split.
+   * For unordered features, there is 1 bin per split.
    * For ordered features, there is 1 more bin than split.
    */
   def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
-    numBins(featureIndex) >> 1
+    numBins(featureIndex)
   } else {
     numBins(featureIndex) - 1
   }
@@ -212,6 +212,6 @@ private[spark] object DecisionTreeMetadata extends Logging {
    * there are math.pow(2, arity - 1) - 1 such splits.
    * Each split has 2 corresponding bins.
    */
-  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+  def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
 
 }
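The storage math behind these two hunks: an unordered feature of arity M has 2^(M-1) - 1 candidate splits (partitions of the categories into two non-empty sets, counted up to complement). Previously each split stored two bins (left and right), so numUnorderedBins was 2 * (2^(M-1) - 1); with right stats derived from the parent, one bin per split suffices, halving the aggregation storage for such features. A quick check of the new counts (toy object mirroring the `(1 << arity - 1) - 1` expression):

```scala
// Worked check of the new bin count: one stored bin per candidate split.
object UnorderedBinsCheck {
  def numUnorderedBins(arity: Int): Int = (1 << (arity - 1)) - 1

  def main(args: Array[String]): Unit = {
    // arity 3 gives 2^2 - 1 = 3 splits/bins, matching the test suite's
    // expectation below; the old layout would have stored 2 * 3 = 6 bins.
    (2 to 4).foreach { arity =>
      println(s"arity $arity: ${numUnorderedBins(arity)} bins (was ${2 * numUnorderedBins(arity)})")
    }
  }
}
```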
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 73df6b054a8ce003ede8e001f8457a28ac2f1eab..13aff110079ecee225dd17338c3add41e4fcda43 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -113,7 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
     new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index f21845b21a8020e7b080169ada28905a652ac819..39c7f9c3be8ab37a59a8b9f6541adc8b8dc255fe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -109,7 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
     new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index b2c6e2bba43b61345f01bbeb51a034df13dffc74..65f0163ec60590afc4ebef119771c309bced2ca0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -89,7 +89,6 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
    * @param offset  Start index of stats for this (node, feature, bin).
    */
   def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 09017d482a73ca32d71a40032cc3ec8abf68cc93..92d74a1b833410dd6d5530c420ffd0a5b8ab2f20 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -93,7 +93,6 @@ private[tree] class VarianceAggregator()
   def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
     new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5518bdf527c8a3fa96bb0b4657da5eefafa8344a..89b64fce96ebf2f653521ed2a4c0f3f5903dffc8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(bins.length === 2)
     assert(splits(0).length === 3)
     assert(bins(0).length === 0)
+    assert(metadata.numSplits(0) === 3)
+    assert(metadata.numBins(0) === 3)
+    assert(metadata.numSplits(1) === 3)
+    assert(metadata.numBins(1) === 3)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)
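As a closing sanity check in the spirit of the updated suite: for every candidate split of a small unordered feature, per-class counts derived as parent minus left must equal counts computed directly on the right side. A self-contained toy verification (invented names and data, not part of the patch):

```scala
// End-to-end toy check of the equivalence the patch relies on:
// parent - left == right, for every candidate split of an unordered feature.
object ParentMinusLeftCheck {
  def main(args: Array[String]): Unit = {
    // Toy arity-3 unordered feature: (categoryValue, classLabel) pairs.
    val data = Seq((0, 0), (0, 1), (1, 0), (2, 1), (2, 1))
    val numClasses = 2
    // The 2^(3-1) - 1 = 3 candidate splits, each named by its left category set.
    val splits = Seq(Set(0), Set(1), Set(0, 1))

    val parent = new Array[Double](numClasses)
    data.foreach { case (_, label) => parent(label) += 1.0 }

    splits.foreach { leftCats =>
      val left = new Array[Double](numClasses)
      val right = new Array[Double](numClasses)
      data.foreach { case (cat, label) =>
        if (leftCats.contains(cat)) left(label) += 1.0 else right(label) += 1.0
      }
      // Derive the right-child stats the way the patched code does.
      val derivedRight = parent.zip(left).map { case (p, l) => p - l }
      assert(derivedRight.sameElements(right))
    }
    println("parent - left == right for every candidate split")
  }
}
```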