diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index d1309b2b20f546173a84ba5370ce1663bb7d5f74..98596569b8c95f2a034ad707939d7f799035cc6f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -130,7 +130,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
 
       // Find best split for all nodes at a level.
       timer.start("findBestSplits")
-      val splitsStatsForLevel: Array[(Split, InformationGainStats)] =
+      val splitsStatsForLevel: Array[(Split, InformationGainStats, Predict)] =
         DecisionTree.findBestSplits(treeInput, parentImpurities,
           metadata, level, nodes, splits, bins, maxLevelForSingleGroup, timer)
       timer.stop("findBestSplits")
@@ -143,8 +143,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
         timer.start("extractNodeInfo")
         val split = nodeSplitStats._1
         val stats = nodeSplitStats._2
+        val predict = nodeSplitStats._3.predict
         val isLeaf = (stats.gain <= 0) || (level == strategy.maxDepth)
-        val node = new Node(nodeIndex, stats.predict, isLeaf, Some(split), None, None, Some(stats))
+        val node = new Node(nodeIndex, predict, isLeaf, Some(split), None, None, Some(stats))
         logDebug("Node = " + node)
         nodes(nodeIndex) = node
         timer.stop("extractNodeInfo")
@@ -425,7 +426,7 @@ object DecisionTree extends Serializable with Logging {
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       maxLevelForSingleGroup: Int,
-      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
+      timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats, Predict)] = {
     // split into groups to avoid memory overflow during aggregation
     if (level > maxLevelForSingleGroup) {
       // When information for all nodes at a given level cannot be stored in memory,
@@ -434,7 +435,7 @@ object DecisionTree extends Serializable with Logging {
       // numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
       val numGroups = 1 << level - maxLevelForSingleGroup
       logDebug("numGroups = " + numGroups)
-      var bestSplits = new Array[(Split, InformationGainStats)](0)
+      var bestSplits = new Array[(Split, InformationGainStats, Predict)](0)
       // Iterate over each group of nodes at a level.
       var groupIndex = 0
       while (groupIndex < numGroups) {
@@ -605,7 +606,7 @@ object DecisionTree extends Serializable with Logging {
       bins: Array[Array[Bin]],
       timer: TimeTracker,
       numGroups: Int = 1,
-      groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
+      groupIndex: Int = 0): Array[(Split, InformationGainStats, Predict)] = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -705,7 +706,7 @@ object DecisionTree extends Serializable with Logging {
 
     // Calculate best splits for all nodes at a given level
     timer.start("chooseSplits")
-    val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
+    val bestSplits = new Array[(Split, InformationGainStats, Predict)](numNodes)
     // Iterating over all nodes at this level
     var nodeIndex = 0
     while (nodeIndex < numNodes) {
@@ -734,28 +735,27 @@ object DecisionTree extends Serializable with Logging {
       topImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata): InformationGainStats = {
-
     val leftCount = leftImpurityCalculator.count
     val rightCount = rightImpurityCalculator.count
 
-    val totalCount = leftCount + rightCount
-    if (totalCount == 0) {
-      // Return arbitrary prediction.
-      return new InformationGainStats(0, topImpurity, topImpurity, topImpurity, 0)
+    // If left child or right child doesn't satisfy minimum instances per node,
+    // then this split is invalid, return invalid information gain stats.
+    if ((leftCount < metadata.minInstancesPerNode) ||
+        (rightCount < metadata.minInstancesPerNode)) {
+      return InformationGainStats.invalidInformationGainStats
     }
 
-    val parentNodeAgg = leftImpurityCalculator.copy
-    parentNodeAgg.add(rightImpurityCalculator)
+    val totalCount = leftCount + rightCount
+
     // impurity of parent node
     val impurity = if (level > 0) {
       topImpurity
     } else {
+      val parentNodeAgg = leftImpurityCalculator.copy
+      parentNodeAgg.add(rightImpurityCalculator)
       parentNodeAgg.calculate()
     }
 
-    val predict = parentNodeAgg.predict
-    val prob = parentNodeAgg.prob(predict)
-
     val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
     val rightImpurity = rightImpurityCalculator.calculate()
 
@@ -764,7 +764,31 @@ object DecisionTree extends Serializable with Logging {
 
     val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
 
-    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+    // if information gain doesn't satisfy minimum information gain,
+    // then this split is invalid, return invalid information gain stats.
+    if (gain < metadata.minInfoGain) {
+      return InformationGainStats.invalidInformationGainStats
+    }
+
+    new InformationGainStats(gain, impurity, leftImpurity, rightImpurity)
+  }
+
+  /**
+   * Calculate predict value for current node, given stats of any split.
+   * Note that this function is called only once for each node.
+   * @param leftImpurityCalculator left node aggregates for a split
+   * @param rightImpurityCalculator right node aggregates for a node
+   * @return predict value for current node
+   */
+  private def calculatePredict(
+      leftImpurityCalculator: ImpurityCalculator,
+      rightImpurityCalculator: ImpurityCalculator): Predict =  {
+    val parentNodeAgg = leftImpurityCalculator.copy
+    parentNodeAgg.add(rightImpurityCalculator)
+    val predict = parentNodeAgg.predict
+    val prob = parentNodeAgg.prob(predict)
+
+    new Predict(predict, prob)
   }
 
   /**
@@ -780,12 +804,15 @@ object DecisionTree extends Serializable with Logging {
       nodeImpurity: Double,
       level: Int,
       metadata: DecisionTreeMetadata,
-      splits: Array[Array[Split]]): (Split, InformationGainStats) = {
+      splits: Array[Array[Split]]): (Split, InformationGainStats, Predict) = {
 
     logDebug("node impurity = " + nodeImpurity)
 
+    // calculate predict only once
+    var predict: Option[Predict] = None
+
     // For each (feature, split), calculate the gain, and select the best (feature, split).
-    Range(0, metadata.numFeatures).map { featureIndex =>
+    val (bestSplit, bestSplitStats) = Range(0, metadata.numFeatures).map { featureIndex =>
       val numSplits = metadata.numSplits(featureIndex)
       if (metadata.isContinuous(featureIndex)) {
         // Cumulative sum (scanLeft) of bin statistics.
@@ -803,6 +830,7 @@ object DecisionTree extends Serializable with Logging {
             val leftChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, splitIdx)
             val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
             rightChildStats.subtract(leftChildStats)
+            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
             val gainStats =
               calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
             (splitIdx, gainStats)
@@ -816,6 +844,7 @@ object DecisionTree extends Serializable with Logging {
           Range(0, numSplits).map { splitIndex =>
             val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
             val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
+            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
             val gainStats =
               calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
             (splitIndex, gainStats)
@@ -887,6 +916,7 @@ object DecisionTree extends Serializable with Logging {
             val rightChildStats =
               binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
             rightChildStats.subtract(leftChildStats)
+            predict = Some(predict.getOrElse(calculatePredict(leftChildStats, rightChildStats)))
             val gainStats =
               calculateGainForSplit(leftChildStats, rightChildStats, nodeImpurity, level, metadata)
             (splitIndex, gainStats)
@@ -898,6 +928,10 @@ object DecisionTree extends Serializable with Logging {
         (bestFeatureSplit, bestFeatureGainStats)
       }
     }.maxBy(_._2.gain)
+
+    require(predict.isDefined, "must calculate predict for each node")
+
+    (bestSplit, bestSplitStats, predict.get)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 23f74d5360fe54fa6737234b6b1015a1f97b1ad4..987fe632c91ed0f994295f89f0dd6a746d73daf1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -49,6 +49,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
  *                                k) implies the feature n is categorical with k categories 0,
  *                                1, 2, ... , k-1. It's important to note that features are
  *                                zero-indexed.
+ * @param minInstancesPerNode Minimum number of instances each child must have after split.
+ *                            Default value is 1. If a split cause left or right child
+ *                            to have less than minInstancesPerNode,
+ *                            this split will not be considered as a valid split.
+ * @param minInfoGain Minimum information gain a split must get. Default value is 0.0.
+ *                    If a split has less information gain than minInfoGain,
+ *                    this split will not be considered as a valid split.
  * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
  *                      256 MB.
  */
@@ -61,6 +68,8 @@ class Strategy (
     val maxBins: Int = 32,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
+    val minInstancesPerNode: Int = 1,
+    val minInfoGain: Double = 0.0,
     val maxMemoryInMB: Int = 256) extends Serializable {
 
   if (algo == Classification) {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index e95add7558bcf026745beed255672e3665155fd1..5ceaa8154d11a77c3c77effb87e85e7d9a8ed560 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -45,7 +45,9 @@ private[tree] class DecisionTreeMetadata(
     val unorderedFeatures: Set[Int],
     val numBins: Array[Int],
     val impurity: Impurity,
-    val quantileStrategy: QuantileStrategy) extends Serializable {
+    val quantileStrategy: QuantileStrategy,
+    val minInstancesPerNode: Int,
+    val minInfoGain: Double) extends Serializable {
 
   def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex)
 
@@ -127,7 +129,8 @@ private[tree] object DecisionTreeMetadata {
 
     new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max,
       strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins,
-      strategy.impurity, strategy.quantileCalculationStrategy)
+      strategy.impurity, strategy.quantileCalculationStrategy,
+      strategy.minInstancesPerNode, strategy.minInfoGain)
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index fb12298e0f5d3809e217a0aef0715e0d86f63e8a..f3e2619bd8ba0ee405d7a292912a26dbdf636338 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -26,20 +26,26 @@ import org.apache.spark.annotation.DeveloperApi
  * @param impurity current node impurity
  * @param leftImpurity left node impurity
  * @param rightImpurity right node impurity
- * @param predict predicted value
- * @param prob probability of the label (classification only)
  */
 @DeveloperApi
 class InformationGainStats(
     val gain: Double,
     val impurity: Double,
     val leftImpurity: Double,
-    val rightImpurity: Double,
-    val predict: Double,
-    val prob: Double = 0.0) extends Serializable {
+    val rightImpurity: Double) extends Serializable {
 
   override def toString = {
-    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
-      .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f"
+      .format(gain, impurity, leftImpurity, rightImpurity)
   }
 }
+
+
+private[tree] object InformationGainStats {
+  /**
+   * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
+   * denote that current split doesn't satisfies minimum info gain or
+   * minimum number of instances per node.
+   */
+  val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
new file mode 100644
index 0000000000000000000000000000000000000000..6fac2be2797bc9cf360afd4f415e5d42f03c9747
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Predict.scala
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.model
+
+import org.apache.spark.annotation.DeveloperApi
+
+/**
+ * :: DeveloperApi ::
+ * Predicted value for a node
+ * @param predict predicted value
+ * @param prob probability of the label (classification only)
+ */
+@DeveloperApi
+private[tree] class Predict(
+    val predict: Double,
+    val prob: Double = 0.0) extends Serializable{
+
+  override def toString = {
+    "predict = %f, prob = %f".format(predict, prob)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
index 50fb48b40de3d1eba2b21d580afafeee64fd4fd2..b7a85f58544a3bbf426ee2d2094f63836b31c634 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Split.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree.model
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
+import org.apache.spark.mllib.tree.configuration.FeatureType
+import org.apache.spark.mllib.tree.configuration.FeatureType.FeatureType
 
 /**
  * :: DeveloperApi ::
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 69482f2acbb40b64ba19923bfccd8c523e2de1f2..fd8547c1660fca51dfd7a7687ec5d096c4f91a3d 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.Strategy
 import org.apache.spark.mllib.tree.impl.{DecisionTreeMetadata, TreePoint}
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Node}
+import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
 import org.apache.spark.mllib.util.LocalSparkContext
 
 class DecisionTreeSuite extends FunSuite with LocalSparkContext {
@@ -279,9 +279,10 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(split.threshold === Double.MinValue)
 
     val stats = bestSplits(0)._2
+    val predict = bestSplits(0)._3
     assert(stats.gain > 0)
-    assert(stats.predict === 1)
-    assert(stats.prob === 0.6)
+    assert(predict.predict === 1)
+    assert(predict.prob === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -312,8 +313,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(split.threshold === Double.MinValue)
 
     val stats = bestSplits(0)._2
+    val predict = bestSplits(0)._3.predict
     assert(stats.gain > 0)
-    assert(stats.predict === 0.6)
+    assert(predict === 0.6)
     assert(stats.impurity > 0.2)
   }
 
@@ -387,7 +389,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.gain === 0)
     assert(bestSplits(0)._2.leftImpurity === 0)
     assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._2.predict === 1)
+    assert(bestSplits(0)._3.predict === 1)
   }
 
   test("Binary classification stump with fixed label 0 for Entropy") {
@@ -414,7 +416,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.gain === 0)
     assert(bestSplits(0)._2.leftImpurity === 0)
     assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._2.predict === 0)
+    assert(bestSplits(0)._3.predict === 0)
   }
 
   test("Binary classification stump with fixed label 1 for Entropy") {
@@ -441,7 +443,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     assert(bestSplits(0)._2.gain === 0)
     assert(bestSplits(0)._2.leftImpurity === 0)
     assert(bestSplits(0)._2.rightImpurity === 0)
-    assert(bestSplits(0)._2.predict === 1)
+    assert(bestSplits(0)._3.predict === 1)
   }
 
   test("Second level node building with vs. without groups") {
@@ -490,7 +492,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
       assert(bestSplits(i)._2.impurity === bestSplitsWithGroups(i)._2.impurity)
       assert(bestSplits(i)._2.leftImpurity === bestSplitsWithGroups(i)._2.leftImpurity)
       assert(bestSplits(i)._2.rightImpurity === bestSplitsWithGroups(i)._2.rightImpurity)
-      assert(bestSplits(i)._2.predict === bestSplitsWithGroups(i)._2.predict)
+      assert(bestSplits(i)._3.predict === bestSplitsWithGroups(i)._3.predict)
     }
   }
 
@@ -674,6 +676,91 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
     validateClassifier(model, arr, 0.6)
   }
 
+  test("split must satisfy min instances per node requirements") {
+    val arr = new Array[LabeledPoint](3)
+    arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+    arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
+    arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini,
+      maxDepth = 2, numClassesForClassification = 2, minInstancesPerNode = 2)
+
+    val model = DecisionTree.train(input, strategy)
+    assert(model.topNode.isLeaf)
+    assert(model.topNode.predict == 0.0)
+    val predicts = input.map(p => model.predict(p.features)).collect()
+    predicts.foreach { predict =>
+      assert(predict == 0.0)
+    }
+
+    // test for findBestSplits when no valid split can be found
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
+
+    assert(bestSplits.length == 1)
+    val bestInfoStats = bestSplits(0)._2
+    assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+  }
+
+  test("don't choose split that doesn't satisfy min instance per node requirements") {
+    // if a split doesn't satisfy min instances per node requirements,
+    // this split is invalid, even though the information gain of split is large.
+    val arr = new Array[LabeledPoint](4)
+    arr(0) = new LabeledPoint(0.0, Vectors.dense(0.0, 1.0))
+    arr(1) = new LabeledPoint(1.0, Vectors.dense(1.0, 1.0))
+    arr(2) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+    arr(3) = new LabeledPoint(0.0, Vectors.dense(0.0, 0.0))
+
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini,
+      maxBins = 2, maxDepth = 2, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2),
+      numClassesForClassification = 2, minInstancesPerNode = 2)
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
+
+    assert(bestSplits.length == 1)
+    val bestSplit = bestSplits(0)._1
+    val bestSplitStats = bestSplits(0)._1
+    assert(bestSplit.feature == 1)
+    assert(bestSplitStats != InformationGainStats.invalidInformationGainStats)
+  }
+
+  test("split must satisfy min info gain requirements") {
+    val arr = new Array[LabeledPoint](3)
+    arr(0) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 0.0))))
+    arr(1) = new LabeledPoint(1.0, Vectors.sparse(2, Seq((1, 1.0))))
+    arr(2) = new LabeledPoint(0.0, Vectors.sparse(2, Seq((0, 1.0))))
+
+    val input = sc.parallelize(arr)
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, minInfoGain = 1.0)
+
+    val model = DecisionTree.train(input, strategy)
+    assert(model.topNode.isLeaf)
+    assert(model.topNode.predict == 0.0)
+    val predicts = input.map(p => model.predict(p.features)).collect()
+    predicts.foreach { predict =>
+      assert(predict == 0.0)
+    }
+
+    // test for findBestSplits when no valid split can be found
+    val metadata = DecisionTreeMetadata.buildMetadata(input, strategy)
+    val (splits, bins) = DecisionTree.findSplitsBins(input, metadata)
+    val treeInput = TreePoint.convertToTreeRDD(input, bins, metadata)
+    val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(8), metadata, 0,
+      new Array[Node](0), splits, bins, 10)
+
+    assert(bestSplits.length == 1)
+    val bestInfoStats = bestSplits(0)._2
+    assert(bestInfoStats == InformationGainStats.invalidInformationGainStats)
+  }
 }
 
 object DecisionTreeSuite {