diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 98596569b8c95f2a034ad707939d7f799035cc6f..56bb8812100a78f4acdbabc25c1150d8f857fe4b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -87,17 +87,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     val maxDepth = strategy.maxDepth
     require(maxDepth <= 30,
       s"DecisionTree currently only supports maxDepth <= 30, but was given maxDepth = $maxDepth.")
-    // Number of nodes to allocate: max number of nodes possible given the depth of the tree, plus 1
-    val maxNumNodesPlus1 = Node.startIndexInLevel(maxDepth + 1)
-    // Initialize an array to hold parent impurity calculations for each node.
-    val parentImpurities = new Array[Double](maxNumNodesPlus1)
-    // dummy value for top node (updated during first split calculation)
-    val nodes = new Array[Node](maxNumNodesPlus1)
 
     // Calculate level for single group construction
 
     // Max memory usage for aggregates
-    val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
+    val maxMemoryUsage = strategy.maxMemoryInMB * 1024L * 1024L
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
     // TODO: Calculate memory usage more precisely.
     val numElementsPerNode = DecisionTree.getElementsPerNode(metadata)
@@ -120,81 +114,35 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
      * beforehand and is not used in later levels.
      */
 
+    var topNode: Node = null // set on first iteration
     var level = 0
     var break = false
     while (level <= maxDepth && !break) {
-
       logDebug("#####################################")
       logDebug("level = " + level)
       logDebug("#####################################")
 
       // Find best split for all nodes at a level.
       timer.start("findBestSplits")
-      val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
-        DecisionTree.findBestSplits(treeInput, parentImpurities,
-          metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
+      val (tmpTopNode: Node, doneTraining: Boolean) = DecisionTree.findBestSplits(treeInput,
+        metadata, level, topNode, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")
 
-      val levelNodeIndexOffset = Node.startIndexInLevel(level)
-      for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
-        val nodeIndex = levelNodeIndexOffset + index
-
-        // Extract info for this node (index) at the current level.
-        timer.start("extractNodeInfo")
-        val split = nodeSplitStats._1
-        val stats = nodeSplitStats._2
-        val predict = nodeSplitStats._3.predict
-        val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
-        val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
-        logDebug("Node = " + node)
-        nodes(nodeIndex) = node
-        timer.stop("extractNodeInfo")
-
-        if (level != 0) {
-          // Set parent.
-          val parentNodeIndex = Node.parentIndex(nodeIndex)
-          if (Node.isLeftChild(nodeIndex)) {
-            nodes(parentNodeIndex).leftNode = Some(nodes(nodeIndex))
-          } else {
-            nodes(parentNodeIndex).rightNode = Some(nodes(nodeIndex))
-          }
-        }
-        // Extract info for nodes at the next lower level.
-        timer.start("extractInfoForLowerLevels")
-        if (level < maxDepth) {
-          val leftChildIndex = Node.leftChildIndex(nodeIndex)
-          val leftImpurity = stats.leftImpurity
-          logDebug("leftChildIndex = " + leftChildIndex + ", impurity = " + leftImpurity)
-          parentImpurities(leftChildIndex) = leftImpurity
-
-          val rightChildIndex = Node.rightChildIndex(nodeIndex)
-          val rightImpurity = stats.rightImpurity
-          logDebug("rightChildIndex = " + rightChildIndex + ", impurity = " + rightImpurity)
-          parentImpurities(rightChildIndex) = rightImpurity
-        }
-        timer.stop("extractInfoForLowerLevels")
-        logDebug("final best split = " + split)
+      if (level == 0) {
+        topNode = tmpTopNode
       }
-      require(Node.maxNodesInLevel(level) == splitsStatsForLevel.length)
-      // Check whether all the nodes at the current level at leaves.
-      val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
-      logDebug("all leaf = " + allLeaf)
-      if (allLeaf) {
-        break = true // no more tree construction
-      } else {
-        level += 1
+      if (doneTraining) {
+        break = true
+        logDebug("done training")
       }
+
+      level += 1
     }
 
     logDebug("#####################################")
     logDebug("Extracting tree model")
     logDebug("#####################################")
 
-    // Initialize the top or root node of the tree.
-    val topNode = nodes(1)
-    // Build the full tree using the node info calculated in the level-wise best split calculations.
-    topNode.build(nodes)
-
     timer.stop("total")
 
     logInfo("Internal timing for DecisionTree:")
@@ -409,24 +357,26 @@ object DecisionTree extends Serializable with Logging {
    * multiple groups if the level-wise training task could lead to memory overflow.
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
-   * @param parentImpurities Impurities for all parent nodes for the current level
    * @param metadata Learning and dataset metadata
    * @param level Level of the tree
+   * @param topNode Root node of the tree (or invalid node when training first level).
    * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
-   * @return array (over nodes) of splits with best split for each node at a given level.
+   * @return  (root, doneTraining) where:
+   *          root = Root node (which is newly created on the first iteration),
+   *          doneTraining = true if no more internal nodes were created.
    */
   private[tree] def findBestSplits(
       input: RDD[TreePoint],
-      parentImpurities: Array[Double],
       metadata: DecisionTreeMetadata,
       level: Int,
-      nodes: Array[Node],
+      topNode: Node,
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       maxLevelForSingleGroup: Int,
-      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
+      timer: TimeTracker = new TimeTracker): (Node, Boolean) = {
+
     // split into groups to avoid memory overflow during aggregation
     if (level > maxLevelForSingleGroup) {
       // When information for all nodes at a given level cannot be stored in memory,
@@ -435,18 +385,18 @@ object DecisionTree extends Serializable with Logging {
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
       val numGroups = 1 << level - maxLevelForSingleGroup
       logDebug("numGroups = " + numGroups)
-      var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
       // Iterate over each group of nodes at a level.
       var groupIndex = 0
+      var doneTraining = true
       while (groupIndex < numGroups) {
-        val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, metadata, level,
-          nodes, splits, bins, timer, numGroups, groupIndex)
-        bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
+        val (tmpRoot, doneTrainingGroup) = findBestSplitsPerGroup(input, metadata, level,
+          topNode, splits, bins, timer, numGroups, groupIndex)
+        doneTraining = doneTraining && doneTrainingGroup
         groupIndex += 1
       }
-      bestSplits
+      (topNode, doneTraining) // Not first iteration, so topNode was already set.
     } else {
-      findBestSplitsPerGroup(input, parentImpurities, metadata, level, nodes, splits, bins, timer)
+      findBestSplitsPerGroup(input, metadata, level, topNode, splits, bins, timer)
     }
   }
 
@@ -586,27 +536,27 @@ object DecisionTree extends Serializable with Logging {
    * Returns an array of optimal splits for a group of nodes at a given level
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
-   * @param parentImpurities Impurities for all parent nodes for the current level
    * @param metadata Learning and dataset metadata
    * @param level Level of the tree
-   * @param nodes Array of all nodes in the tree.  Used for matching data points to nodes.
+   * @param topNode Root node of the tree (or invalid node when training first level).
    * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param numGroups total number of node groups at the current level. Default value is set to 1.
    * @param groupIndex index of the node group being processed. Default value is set to 0.
-   * @return array of splits with best splits for all nodes at a given level.
+   * @return  (root, doneTraining) where:
+   *          root = Root node (which is newly created on the first iteration),
+   *          doneTraining = true if no more internal nodes were created.
    */
   private def findBestSplitsPerGroup(
       input: RDD[TreePoint],
-      parentImpurities: Array[Double],
       metadata: DecisionTreeMetadata,
       level: Int,
-      nodes: Array[Node],
+      topNode: Node,
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       timer: TimeTracker,
       numGroups: Int = 1,
-      groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
+      groupIndex: Int = 0): (Node, Boolean) = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -663,7 +613,7 @@ object DecisionTree extends Serializable with Logging {
         0
       } else {
         val globalNodeIndex =
-          predictNodeIndex(nodes(1), treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
+          predictNodeIndex(topNode, treePoint.binnedFeatures, bins, metadata.unorderedFeatures)
         globalNodeIndex - globalNodeIndexOffset
       }
     }
@@ -706,33 +656,63 @@ object DecisionTree extends Serializable with Logging {
 
     // Calculate best splits for all nodes at a given level
     timer.start("chooseSplits")
-    val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
-    // Iterating over all nodes at this level
+    // On the first iteration, we need to get and return the newly created root node.
+    var newTopNode: Node = topNode
+
+    // Iterate over all nodes at this level
     var nodeIndex = 0
+    var internalNodeCount = 0
     while (nodeIndex < numNodes) {
-      val nodeImpurity = parentImpurities(globalNodeIndexOffset + nodeIndex)
-      logDebug("node impurity = " + nodeImpurity)
-      bestSplits(nodeIndex) =
-        binsToBestSplit(binAggregates, nodeIndex, nodeImpurity, level, metadata, splits)
-      logDebug("best split = " + bestSplits(nodeIndex)._1)
+      val (split: Split, stats: InformationGainStats, predict: Predict) =
+        binsToBestSplit(binAggregates, nodeIndex, level, metadata, splits)
+      logDebug("best split = " + split)
+
+      val globalNodeIndex = globalNodeIndexOffset + nodeIndex
+
+      // Extract info for this node at the current level.
+      val isLeaf = (stats.gain <= 0) || (level == metadata.maxDepth)
+      val node =
+        new Node(globalNodeIndex, predict.predict, isLeaf, Some(split), None, None, Some(stats))
+      logDebug("Node = " + node)
+
+      if (!isLeaf) {
+        internalNodeCount += 1
+      }
+      if (level == 0) {
+        newTopNode = node
+      } else {
+        // Set parent.
+        val parentNode = Node.getNode(Node.parentIndex(globalNodeIndex), topNode)
+        if (Node.isLeftChild(globalNodeIndex)) {
+          parentNode.leftNode = Some(node)
+        } else {
+          parentNode.rightNode = Some(node)
+        }
+      }
+      if (level < metadata.maxDepth) {
+        logDebug("leftChildIndex = " + Node.leftChildIndex(globalNodeIndex) +
+          ", impurity = " + stats.leftImpurity)
+        logDebug("rightChildIndex = " + Node.rightChildIndex(globalNodeIndex) +
+          ", impurity = " + stats.rightImpurity)
+      }
+
       nodeIndex += 1
     }
     timer.stop("chooseSplits")
 
-    bestSplits
+    val doneTraining = internalNodeCount == 0
+    (newTopNode, doneTraining)
   }
 
   /**
    * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
    * @param leftImpurityCalculator left node aggregates for this (feature, split)
    * @param rightImpurityCalculator right node aggregate for this (feature, split)
-   * @param topImpurity impurity of the parent node
    * @return information gain and statistics for all splits
    */
   private def calculateGainForSplit(
       leftImpurityCalculator: ImpurityCalculator,
       rightImpurityCalculator: ImpurityCalculator,
-      topImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata): InformationGainStats = {
     val leftCount = leftImpurityCalculator.count
@@ -747,14 +727,10 @@ object DecisionTree extends Serializable with Logging {
 
     val totalCount = leftCount + rightCount
 
-    // impurity of parent node
-    val impurity = if (level > 0) {
-      topImpurity
-    } else {
-      val parentNodeAgg = leftImpurityCalculator.copy
-      parentNodeAgg.add(rightImpurityCalculator)
-      parentNodeAgg.calculate()
-    }
+    val parentNodeAgg = leftImpurityCalculator.copy
+    parentNodeAgg.add(rightImpurityCalculator)
+
+    val impurity = parentNodeAgg.calculate()
 
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()
@@ -795,19 +771,15 @@ object DecisionTree extends Serializable with Logging {
    * Find the best split for a node.
    * @param binAggregates Bin statistics.
    * @param nodeIndex Index for node to split in this (level, group).
-   * @param nodeImpurity Impurity of the node (nodeIndex).
    * @return tuple for best split: (Split, information gain)
    */
   private def binsToBestSplit(
       binAggregates: DTStatsAggregator,
       nodeIndex: Int,
-      nodeImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata,
       splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
 
-    logDebug("node impurity = " + nodeImpurity)
-
     // calculate predict only once
     var predict: Option[Predict] = None
 
@@ -831,8 +803,7 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats =
-              calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
             (splitIdx, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -845,8 +816,7 @@ object DecisionTree extends Serializable with Logging {
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats =
-              calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
@@ -917,8 +887,7 @@ object DecisionTree extends Serializable with Logging {
               binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
             predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
-            val gainStats =
-              calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
+            val gainStats = calculateGainForSplit(leftChildStats, rightChildStats, level, metadata)
             (splitIndex, gainStats)
           }.maxBy(_._2.gain)
         val categoriesForSplit =
@@ -937,8 +906,8 @@ object DecisionTree extends Serializable with Logging {
   /**
    * Get the number of values to be stored per node in the bin aggregates.
    */
-  private def getElementsPerNode(metadata: DecisionTreeMetadata): Int = {
-    val totalBins = metadata.numBins.sum
+  private def getElementsPerNode(metadata: DecisionTreeMetadata): Long = {
+    val totalBins = metadata.numBins.map(_.toLong).sum
     if (metadata.isClassification) {
       metadata.numClasses * totalBins
     } else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 987fe632c91ed0f994295f89f0dd6a746d73daf1..31d1e8ac30eea81c65ead384cc024b257a53f996 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -75,6 +75,9 @@ class Strategy (
   if (algo == Classification) {
     require(numClassesForClassification >= 2)
   }
+  require(minInstancesPerNode >= 1,
+    s"DecisionTree Strategy requires minInstancesPerNode >= 1 but was given $minInstancesPerNode")
+
   val isMulticlassClassification =
     algo == Classification && numClassesForClassification > 2
   val isMulticlassWithCategoricalFeatures
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
index 866d85a79bea157236d2acdfebbf31b40a3f1ea4..61a94246711bf054d6b55abe9cf756253a7e2bca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala
@@ -65,14 +65,7 @@ private[tree] class DTStatsAggregator(
    * Offset for each feature for calculating indices into the [[allStats]] array.
    */
   private val featureOffsets: Array[Int] = {
-    def featureOffsetsCalc(total: Int, featureIndex: Int): Int = {
-      if (isUnordered(featureIndex)) {
-        total + 2 * numBins(featureIndex)
-      } else {
-        total + numBins(featureIndex)
-      }
-    }
-    Range(0, numFeatures).scanLeft(0)(featureOffsetsCalc).map(statsSize * _).toArray
+    numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
   }
 
   /**
@@ -149,7 +142,7 @@ private[tree] class DTStatsAggregator(
       s"DTStatsAggregator.getLeftRightNodeFeatureOffsets is for unordered features only," +
       s" but was called for ordered feature $featureIndex.")
     val baseOffset = nodeIndex * nodeStride + featureOffsets(featureIndex)
-    (baseOffset, baseOffset + numBins(featureIndex) * statsSize)
+    (baseOffset, baseOffset + (numBins(featureIndex) >> 1) * statsSize)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 5ceaa8154d11a77c3c77effb87e85e7d9a8ed560..b6d49e5555b1a18f196fd0f117e18884c61789a2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -46,6 +46,7 @@ private[tree] class DecisionTreeMetadata(
     val numBins: Array[Int],
     val impurity: Impurity,
     val quantileStrategy: QuantileStrategy,
+    val maxDepth: Int,
     val minInstancesPerNode: Int,
     val minInfoGain: Double) extends Serializable {
 
@@ -129,7 +130,7 @@ private[tree] object DecisionTreeMetadata {
 
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
-      strategy.impurity, strategy.quantileCalculationStrategy,
+      strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth,
       strategy.minInstancesPerNode, strategy.minInfoGain)
   }
 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 0594fd0749d21d1d32ad7a0552341044b6931768..271b2c4ad813e812211a7d1fa0a1e4d4579e24af 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -46,7 +46,7 @@ class DecisionTreeModel(val topNode: Node, val algo: Algo) extends Serializable
    * Predict values for the given data set using the model trained.
    *
    * @param features RDD representing data points to be predicted
-   * @return RDD[Int] where each entry contains the corresponding prediction
+   * @return RDD of predictions for each of the given data points
    */
   def predict(features: RDD[Vector]): RDD[Double] = {
     features.map(x => predict(x))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
index 5b8a4cbed2306d1519622a598ff692081b44de9f..5f0095d23c7ed2c4229b8792d5f30e9e10ff88e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Node.scala
@@ -55,6 +55,8 @@ class Node (
    * build the left node and right nodes if not leaf
    * @param nodes array of nodes
    */
+  @deprecated("build should no longer be used since trees are constructed on-the-fly in training",
+    "1.2.0")
   def build(nodes: Array[Node]): Unit = {
     logDebug("building node " + id + " at level " + Node.indexToLevel(id))
     logDebug("id = " + id + ", split = " + split)
@@ -93,6 +95,23 @@ class Node (
     }
   }
 
+  /**
+   * Returns a deep copy of the subtree rooted at this node.
+   */
+  private[tree] def deepCopy(): Node = {
+    val leftNodeCopy = if (leftNode.isEmpty) {
+      None
+    } else {
+      Some(leftNode.get.deepCopy())
+    }
+    val rightNodeCopy = if (rightNode.isEmpty) {
+      None
+    } else {
+      Some(rightNode.get.deepCopy())
+    }
+    new Node(id, predict, isLeaf, split, leftNodeCopy, rightNodeCopy, stats)
+  }
+
   /**
    * Get the number of nodes in tree below this node, including leaf nodes.
    * E.g., if this is a leaf, returns 0.  If both children are leaves, returns 2.
@@ -190,4 +209,22 @@ private[tree] object Node {
    */
   def startIndexInLevel(level: Int): Int = 1 << level
 
+  /**
+   * Traces down from a root node to get the node with the given node index.
+   * This assumes the node exists.
+   */
+  def getNode(nodeIndex: Int, rootNode: Node): Node = {
+    var tmpNode: Node = rootNode
+    var levelsToGo = indexToLevel(nodeIndex)
+    while (levelsToGo > 0) {
+      if ((nodeIndex & (1 << levelsToGo - 1)) == 0) {
+        tmpNode = tmpNode.leftNode.get
+      } else {
+        tmpNode = tmpNode.rightNode.get
+      }
+      levelsToGo -= 1
+    }
+    tmpNode
+  }
+
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index fd8547c1660fca51dfd7a7687ec5d096c4f91a3d..1bd7ea05c46c819f2a4356bbb48177ee78b1159b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -270,19 +270,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 0)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode: Node, doneTraining: Boolean) =
+      DecisionTree.findBestSplits(treeInput, metadata, 0, null, splits, bins, 10)
 
-    val split = bestSplits(0)._1
+    val split = rootNode.split.get
     assert(split.categories === List(1.0))
     assert(split.featureType === Categorical)
     assert(split.threshold === Double.MinValue)
 
-    val stats = bestSplits(0)._2
-    val predict = bestSplits(0)._3
+    val stats = rootNode.stats.get
     assert(stats.gain > 0)
-    assert(predict.predict === 1)
-    assert(predict.prob === 0.6)
+    assert(rootNode.predict === 1)
     assert(stats.impurity > 0.2)
   }
 
@@ -303,19 +301,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    val split = bestSplits(0)._1
+    val split = rootNode.split.get
     assert(split.categories.length === 1)
     assert(split.categories.contains(1.0))
     assert(split.featureType === Categorical)
     assert(split.threshold === Double.MinValue)
 
-    val stats = bestSplits(0)._2
-    val predict = bestSplits(0)._3.predict
+    val stats = rootNode.stats.get
     assert(stats.gain > 0)
-    assert(predict === 0.6)
+    assert(rootNode.predict === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -356,13 +353,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-    assert(bestSplits.length === 1)
-    assert(bestSplits(0)._1.feature === 0)
-    assert(bestSplits(0)._2.gain === 0)
-    assert(bestSplits(0)._2.leftImpurity === 0)
-    assert(bestSplits(0)._2.rightImpurity === 0)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+
+    val stats = rootNode.stats.get
+    assert(stats.gain === 0)
+    assert(stats.leftImpurity === 0)
+    assert(stats.rightImpurity === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Gini") {
@@ -382,14 +382,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-    assert(bestSplits.length === 1)
-    assert(bestSplits(0)._1.feature === 0)
-    assert(bestSplits(0)._2.gain === 0)
-    assert(bestSplits(0)._2.leftImpurity === 0)
-    assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._3.predict === 1)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+
+    val stats = rootNode.stats.get
+    assert(stats.gain === 0)
+    assert(stats.leftImpurity === 0)
+    assert(stats.rightImpurity === 0)
+    assert(rootNode.predict === 1)
   }
 
   test("Binary classification stump with fixed label 0 for Entropy") {
@@ -409,14 +412,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-    assert(bestSplits.length === 1)
-    assert(bestSplits(0)._1.feature === 0)
-    assert(bestSplits(0)._2.gain === 0)
-    assert(bestSplits(0)._2.leftImpurity === 0)
-    assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._3.predict === 0)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+
+    val stats = rootNode.stats.get
+    assert(stats.gain === 0)
+    assert(stats.leftImpurity === 0)
+    assert(stats.rightImpurity === 0)
+    assert(rootNode.predict === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Entropy") {
@@ -436,14 +442,17 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(2), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-    assert(bestSplits.length === 1)
-    assert(bestSplits(0)._1.feature === 0)
-    assert(bestSplits(0)._2.gain === 0)
-    assert(bestSplits(0)._2.leftImpurity === 0)
-    assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._3.predict === 1)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+
+    val stats = rootNode.stats.get
+    assert(stats.gain === 0)
+    assert(stats.leftImpurity === 0)
+    assert(stats.rightImpurity === 0)
+    assert(rootNode.predict === 1)
   }
 
   test("Second level node building with vs. without groups") {
@@ -459,40 +468,46 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bins(0).length === 100)
 
     // Train a 1-node model
-    val strategyOneNode = new Strategy(Classification, Entropy, 1, 2, 100)
+    val strategyOneNode = new Strategy(Classification, Entropy, maxDepth = 1,
+      numClassesForClassification = 2, maxBins = 100)
     val modelOneNode = DecisionTree.train(rdd, strategyOneNode)
-    val nodes: Array[Node] = new Array[Node](8)
-    nodes(1) = modelOneNode.topNode
-    nodes(1).leftNode = None
-    nodes(1).rightNode = None
-
-    val parentImpurities = Array(0, 0.5, 0.5, 0.5)
+    val rootNodeCopy1 = modelOneNode.topNode.deepCopy()
+    val rootNodeCopy2 = modelOneNode.topNode.deepCopy()
 
     // Single group second level tree construction.
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1, nodes,
-      splits, bins, 10)
-    assert(bestSplits.length === 2)
-    assert(bestSplits(0)._2.gain > 0)
-    assert(bestSplits(1)._2.gain > 0)
+    val (rootNode, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
+      rootNodeCopy1, splits, bins, 10)
+    assert(rootNode.leftNode.nonEmpty)
+    assert(rootNode.rightNode.nonEmpty)
+    val children1 = new Array[Node](2)
+    children1(0) = rootNode.leftNode.get
+    children1(1) = rootNode.rightNode.get
 
     // maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
     // level tree construction.
-    val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, metadata, 1,
-      nodes, splits, bins, 0)
-    assert(bestSplitsWithGroups.length === 2)
-    assert(bestSplitsWithGroups(0)._2.gain > 0)
-    assert(bestSplitsWithGroups(1)._2.gain > 0)
+    val (rootNode2, _) = DecisionTree.findBestSplits(treeInput, metadata, 1,
+      rootNodeCopy2, splits, bins, 0)
+    assert(rootNode2.leftNode.nonEmpty)
+    assert(rootNode2.rightNode.nonEmpty)
+    val children2 = new Array[Node](2)
+    children2(0) = rootNode2.leftNode.get
+    children2(1) = rootNode2.rightNode.get
 
     // Verify whether the splits obtained using single group and multiple group level
     // construction strategies are the same.
-    for (i <- 0 until bestSplits.length) {
-      assert(bestSplits(i)._1 === bestSplitsWithGroups(i)._1)
-      assert(bestSplits(i)._2.gain === bestSplitsWithGroups(i)._2.gain)
-      assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
-      assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
-      assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
-      assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
+    for (i <- 0 until 2) {
+      assert(children1(i).stats.nonEmpty && children1(i).stats.get.gain > 0)
+      assert(children2(i).stats.nonEmpty && children2(i).stats.get.gain > 0)
+      assert(children1(i).split === children2(i).split)
+      assert(children1(i).stats.nonEmpty && children2(i).stats.nonEmpty)
+      val stats1 = children1(i).stats.get
+      val stats2 = children2(i).stats.get
+      assert(stats1.gain === stats2.gain)
+      assert(stats1.impurity === stats2.impurity)
+      assert(stats1.leftImpurity === stats2.leftImpurity)
+      assert(stats1.rightImpurity === stats2.rightImpurity)
+      assert(children1(i).predict === children2(i).predict)
     }
   }
 
@@ -508,15 +523,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
-    assert(bestSplit.feature === 0)
-    assert(bestSplit.categories.length === 1)
-    assert(bestSplit.categories.contains(1))
-    assert(bestSplit.featureType === Categorical)
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+    assert(split.categories.length === 1)
+    assert(split.categories.contains(1))
+    assert(split.featureType === Categorical)
   }
 
   test("Binary classification stump with 1 continuous feature, to check off-by-1 error") {
@@ -573,16 +587,16 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
-    assert(bestSplit.feature === 0)
-    assert(bestSplit.categories.length === 1)
-    assert(bestSplit.categories.contains(1))
-    assert(bestSplit.featureType === Categorical)
-    val gain = bestSplits(0)._2
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
+
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+    assert(split.categories.length === 1)
+    assert(split.categories.contains(1))
+    assert(split.featureType === Categorical)
+
+    val gain = rootNode.stats.get
     assert(gain.leftImpurity === 0)
     assert(gain.rightImpurity === 0)
   }
@@ -600,16 +614,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
-
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplit.feature === 1)
-    assert(bestSplit.featureType === Continuous)
-    assert(bestSplit.threshold > 1980)
-    assert(bestSplit.threshold < 2020)
+    val split = rootNode.split.get
+    assert(split.feature === 1)
+    assert(split.featureType === Continuous)
+    assert(split.threshold > 1980)
+    assert(split.threshold < 2020)
 
   }
 
@@ -627,16 +639,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
-
-    assert(bestSplit.feature === 1)
-    assert(bestSplit.featureType === Continuous)
-    assert(bestSplit.threshold > 1980)
-    assert(bestSplit.threshold < 2020)
+    val split = rootNode.split.get
+    assert(split.feature === 1)
+    assert(split.featureType === Continuous)
+    assert(split.threshold > 1980)
+    assert(split.threshold < 2020)
   }
 
   test("Multiclass classification stump with 10-ary (ordered) categorical features") {
@@ -652,15 +662,14 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
 
     val (splits, bins) = DecisionTree.findSplitsBins(rdd, metadata)
     val treeInput = TreePoint.convertToTreeRDD(rdd, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(32), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length === 1)
-    val bestSplit = bestSplits(0)._1
-    assert(bestSplit.feature === 0)
-    assert(bestSplit.categories.length === 1)
-    assert(bestSplit.categories.contains(1.0))
-    assert(bestSplit.featureType === Categorical)
+    val split = rootNode.split.get
+    assert(split.feature === 0)
+    assert(split.categories.length === 1)
+    assert(split.categories.contains(1.0))
+    assert(split.featureType === Categorical)
   }
 
   test("Multiclass classification tree with 10-ary (ordered) categorical features," +
@@ -698,12 +707,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
     val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length == 1)
-    val bestInfoStats = bestSplits(0)._2
-    assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+    val gain = rootNode.stats.get
+    assert(gain == InformationGainStats.invalidInformationGainStats)
   }
 
   test("don't choose split that doesn't satisfy min instance per node requirements") {
@@ -722,14 +730,13 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
     val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length == 1)
-    val bestSplit = bestSplits(0)._1
-    val bestSplitStats = bestSplits(0)._1
-    assert(bestSplit.feature == 1)
-    assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
+    val split = rootNode.split.get
+    val gain = rootNode.stats.get
+    assert(split.feature == 1)
+    assert(gain != InformationGainStats.invalidInformationGainStats)
   }
 
   test("split must satisfy min info gain requirements") {
@@ -754,12 +761,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
     val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
     val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
-    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
-      new Array[Node](0), splits, bins, 10)
+    val (rootNode, doneTraining) = DecisionTree.findBestSplits(treeInput, metadata, 0,
+      null, splits, bins, 10)
 
-    assert(bestSplits.length == 1)
-    val bestInfoStats = bestSplits(0)._2
-    assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+    val gain = rootNode.stats.get
+    assert(gain == InformationGainStats.invalidInformationGainStats)
   }
 }
 
@@ -786,13 +792,16 @@ object DecisionTreeSuite {
   def generateOrderedLabeledPoints(): Array[LabeledPoint] = {
     val arr = new Array[LabeledPoint](1000)
     for (i <- 0 until 1000) {
-      if (i < 600) {
-        val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
-        arr(i) = lp
+      val label = if (i < 100) {
+        0.0
+      } else if (i < 500) {
+        1.0
+      } else if (i < 900) {
+        0.0
       } else {
-        val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 1000.0 - i))
-        arr(i) = lp
+        1.0
       }
+      arr(i) = new LabeledPoint(label, Vectors.dense(i.toDouble, 1000.0 - i))
     }
     arr
   }