diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 91dc98569a21b12ebbeab56f9776b66cb84e9d9f..dd9a5f261f60f86c0190fe5b063c0b18f0b0518a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -244,8 +244,7 @@ private[ml] object RandomForest extends Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         val featureSplits = splits(featureIndex)
@@ -253,8 +252,6 @@ private[ml] object RandomForest extends Logging {
         while (splitIndex < numSplits) {
           if (featureSplits(splitIndex).shouldGoLeft(featureValue, featureSplits)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label, instanceWeight)
           }
           splitIndex += 1
         }
@@ -394,6 +391,7 @@ private[ml] object RandomForest extends Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
 
@@ -658,7 +656,7 @@ private[ml] object RandomForest extends Logging {
 
     // Calculate InformationGain and ImpurityStats if current node is top node
     val level = LearningNode.indexToLevel(node.id)
-    var gainAndImpurityStats: ImpurityStats = if (level ==0) {
+    var gainAndImpurityStats: ImpurityStats = if (level == 0) {
       null
     } else {
       node.stats
@@ -697,13 +695,12 @@ private[ml] object RandomForest extends Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val (leftChildOffset, rightChildOffset) =
-            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              val rightChildStats = binAggregates.getParentImpurityCalculator()
+                .subtract(leftChildStats)
               gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
                 leftChildStats, rightChildStats, binAggregates.metadata)
               (splitIndex, gainAndImpurityStats)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 18f66e65f19ca53e2fc4dcb0cb4958f88915b4d1..c0934d241f50a59018706267729d61e0191448db 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -52,6 +52,7 @@ class DecisionTree @Since("1.0.0") (private val strategy: Strategy)
 
   /**
    * Method to train a decision tree model over an RDD
+   *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    * @return DecisionTreeModel that can be used for prediction.
    */
@@ -368,8 +369,7 @@ object DecisionTree extends Serializable with Logging {
       if (unorderedFeatures.contains(featureIndex)) {
         // Unordered feature
         val featureValue = treePoint.binnedFeatures(featureIndex)
-        val (leftNodeFeatureOffset, rightNodeFeatureOffset) =
-          agg.getLeftRightFeatureOffsets(featureIndexIdx)
+        val leftNodeFeatureOffset = agg.getFeatureOffset(featureIndexIdx)
         // Update the left or right bin for each split.
         val numSplits = agg.metadata.numSplits(featureIndex)
         var splitIndex = 0
@@ -377,9 +377,6 @@ object DecisionTree extends Serializable with Logging {
           if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
             agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
               instanceWeight)
-          } else {
-            agg.featureUpdate(rightNodeFeatureOffset, splitIndex, treePoint.label,
-              instanceWeight)
           }
           splitIndex += 1
         }
@@ -521,6 +518,7 @@ object DecisionTree extends Serializable with Logging {
           mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
             metadata.unorderedFeatures, instanceWeight, featuresForNode)
         }
+        agg(aggNodeIndex).updateParent(baggedPoint.datum.label, instanceWeight)
       }
     }
 
@@ -847,13 +845,12 @@ object DecisionTree extends Serializable with Logging {
           (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
         } else if (binAggregates.metadata.isUnordered(featureIndex)) {
           // Unordered categorical feature
-          val (leftChildOffset, rightChildOffset) =
-            binAggregates.getLeftRightFeatureOffsets(featureIndexIdx)
+          val leftChildOffset = binAggregates.getFeatureOffset(featureIndexIdx)
           val (bestFeatureSplitIndex, bestFeatureGainStats) =
             Range(0, numSplits).map { splitIndex =>
               val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
-              val rightChildStats =
-                binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+              val rightChildStats = binAggregates.getParentImpurityCalculator()
+                .subtract(leftChildStats)
               predictWithImpurity = Some(predictWithImpurity.getOrElse(
                 calculatePredictImpurity(leftChildStats, rightChildStats)))
               val gainStats = calculateGainForSplit(leftChildStats,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 7985ed4b4c0fa11905e420e3766060b93a6cd4ef..c745e9f8dbed56e5b42d1a1e60c14756cd14832b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -73,25 +73,33 @@ private[spark] class DTStatsAggregator(
    * Flat array of elements.
    * Index for start of stats for a (feature, bin) is:
    *   index = featureOffsets(featureIndex) + binIndex * statsSize
-   * Note: For unordered features,
-   *       the left child stats have binIndex in [0, numBins(featureIndex) / 2))
-   *       and the right child stats in [numBins(featureIndex) / 2), numBins(featureIndex))
    */
   private val allStats: Array[Double] = new Array[Double](allStatsSize)
 
+  /**
+   * Array of parent node sufficient stats.
+   *
+   * Note: this is necessary because stats for the parent node are not available
+   *       on the first iteration of tree learning.
+   */
+  private val parentStats: Array[Double] = new Array[Double](statsSize)
 
   /**
    * Get an [[ImpurityCalculator]] for a given (node, feature, bin).
-   * @param featureOffset  For ordered features, this is a pre-computed (node, feature) offset
+   * @param featureOffset  This is a pre-computed (node, feature) offset
    *                           from [[getFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (node, feature, left/right child) offset from
-   *                           [[getLeftRightFeatureOffsets]].
    */
   def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = {
     impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize)
   }
 
+  /**
+   * Get an [[ImpurityCalculator]] for the parent node.
+   */
+  def getParentImpurityCalculator(): ImpurityCalculator = {
+    impurityAggregator.getCalculator(parentStats, 0)
+  }
+
   /**
    * Update the stats for a given (feature, bin) for ordered features, using the given label.
    */
@@ -100,14 +108,18 @@ private[spark] class DTStatsAggregator(
     impurityAggregator.update(allStats, i, label, instanceWeight)
   }
 
+  /**
+   * Update the parent node stats using the given label.
+   */
+  def updateParent(label: Double, instanceWeight: Double): Unit = {
+    impurityAggregator.update(parentStats, 0, label, instanceWeight)
+  }
+
   /**
    * Faster version of [[update]].
    * Update the stats for a given (feature, bin), using the given label.
-   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   * @param featureOffset  This is a pre-computed feature offset
    *                           from [[getFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (feature, left/right child) offset from
-   *                           [[getLeftRightFeatureOffsets]].
    */
   def featureUpdate(
       featureOffset: Int,
@@ -124,22 +136,10 @@ private[spark] class DTStatsAggregator(
    */
   def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex)
 
-  /**
-   * Pre-compute feature offset for use with [[featureUpdate]].
-   * For unordered features only.
-   */
-  def getLeftRightFeatureOffsets(featureIndex: Int): (Int, Int) = {
-    val baseOffset = featureOffsets(featureIndex)
-    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
-  }
-
   /**
    * For a given feature, merge the stats for two bins.
-   * @param featureOffset  For ordered features, this is a pre-computed feature offset
+   * @param featureOffset  This is a pre-computed feature offset
    *                           from [[getFeatureOffset]].
-   *                           For unordered features, this is a pre-computed
-   *                           (feature, left/right child) offset from
-   *                           [[getLeftRightFeatureOffsets]].
    * @param binIndex  The other bin is merged into this bin.
    * @param otherBinIndex  This bin is not modified.
    */
@@ -162,6 +162,17 @@ private[spark] class DTStatsAggregator(
       allStats(i) += other.allStats(i)
       i += 1
     }
+
+    require(statsSize == other.statsSize,
+      s"DTStatsAggregator.merge requires that both aggregators have the same length parent " +
+        s"stats vectors. This aggregator's parent stats are length $statsSize, " +
+        s"but the other is ${other.statsSize}.")
+    var j = 0
+    while (j < statsSize) {
+      parentStats(j) += other.parentStats(j)
+      j += 1
+    }
+
     this
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index df13d291ca3962714626b19ea85e8258e0f435ca..4f27dc44eff4d1004d677c2168701e79ce595926 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -67,11 +67,11 @@ private[spark] class DecisionTreeMetadata(
 
   /**
    * Number of splits for the given feature.
-   * For unordered features, there are 2 bins per split.
+   * For unordered features, there is 1 bin per split.
    * For ordered features, there is 1 more bin than split.
    */
   def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) {
-    numBins(featureIndex) >> 1
+    numBins(featureIndex)
   } else {
     numBins(featureIndex) - 1
   }
@@ -212,6 +212,6 @@ private[spark] object DecisionTreeMetadata extends Logging {
    * there are math.pow(2, arity - 1) - 1 such splits.
    * Each split has 2 corresponding bins.
    */
-  def numUnorderedBins(arity: Int): Int = 2 * ((1 << arity - 1) - 1)
+  def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1
 
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 73df6b054a8ce003ede8e001f8457a28ac2f1eab..13aff110079ecee225dd17338c3add41e4fcda43 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -113,7 +113,6 @@ private[tree] class EntropyAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): EntropyCalculator = {
     new EntropyCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index f21845b21a8020e7b080169ada28905a652ac819..39c7f9c3be8ab37a59a8b9f6541adc8b8dc255fe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -109,7 +109,6 @@ private[tree] class GiniAggregator(numClasses: Int)
   def getCalculator(allStats: Array[Double], offset: Int): GiniCalculator = {
     new GiniCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index b2c6e2bba43b61345f01bbeb51a034df13dffc74..65f0163ec60590afc4ebef119771c309bced2ca0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -89,7 +89,6 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
    * @param offset    Start index of stats for this (node, feature, bin).
    */
   def getCalculator(allStats: Array[Double], offset: Int): ImpurityCalculator
-
 }
 
 /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 09017d482a73ca32d71a40032cc3ec8abf68cc93..92d74a1b833410dd6d5530c420ffd0a5b8ab2f20 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -93,7 +93,6 @@ private[tree] class VarianceAggregator()
   def getCalculator(allStats: Array[Double], offset: Int): VarianceCalculator = {
     new VarianceCalculator(allStats.view(offset, offset + statsSize).toArray)
   }
-
 }
 
 /**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 5518bdf527c8a3fa96bb0b4657da5eefafa8344a..89b64fce96ebf2f653521ed2a4c0f3f5903dffc8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -189,6 +189,10 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(bins.length === 2)
     assert(splits(0).length === 3)
     assert(bins(0).length === 0)
+    assert(metadata.numSplits(0) === 3)
+    assert(metadata.numBins(0) === 3)
+    assert(metadata.numSplits(1) === 3)
+    assert(metadata.numBins(1) === 3)
 
     // Expecting 2^2 - 1 = 3 bins/splits
     assert(splits(0)(0).feature === 0)