diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 008dd19c2498d0403274cd6f7633df2996cfb4e8..82e1ed85a0a147b206bccfaf52be2ab5567fbf6b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -996,7 +996,7 @@ private[spark] object RandomForest extends Logging {
     require(metadata.isContinuous(featureIndex),
       "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
 
-    val splits = if (featureSamples.isEmpty) {
+    val splits: Array[Double] = if (featureSamples.isEmpty) {
       Array.empty[Double]
     } else {
       val numSplits = metadata.numSplits(featureIndex)
@@ -1009,10 +1009,16 @@ private[spark] object RandomForest extends Logging {
       // sort distinct values
       val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
 
-      // if possible splits is not enough or just enough, just return all possible splits
       val possibleSplits = valueCounts.length - 1
-      if (possibleSplits <= numSplits) {
-        valueCounts.map(_._1).init
+      if (possibleSplits == 0) {
+        // constant feature
+        Array.empty[Double]
+      } else if (possibleSplits <= numSplits) {
+        // if the number of possible splits is not enough or just enough, return all of them
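+        // as midpoints between adjacent distinct values, so a threshold never equals a sample value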
+        (1 to possibleSplits)
+          .map(index => (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0)
+          .toArray
       } else {
         // stride between splits
         val stride: Double = numSamples.toDouble / (numSplits + 1)
@@ -1037,7 +1043,8 @@ private[spark] object RandomForest extends Logging {
           // makes the gap between currentCount and targetCount smaller,
           // previous value is a split threshold.
           if (previousGap < currentGap) {
-            splitsBuilder += valueCounts(index - 1)._1
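+            // the split threshold is the midpoint between the previous and current distinct values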
+            splitsBuilder += (valueCounts(index - 1)._1 + valueCounts(index)._1) / 2.0
             targetCount += stride
           }
           index += 1
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index e1ab7c2d6520b25a3c65d6d1e6a5860bf7b2637e..df155b464c64b043adf0b430be8bed12b010d18d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -104,6 +104,34 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
       assert(splits.distinct.length === splits.length)
     }
 
+    // SPARK-16957: Use midpoints for split values.
+    {
+      val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+        Map(), Set(),
+        Array(3), Gini, QuantileStrategy.Sort,
+        0, 0, 0.0, 0, 0
+      )
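+      // Array(3) bins make this continuous feature's numSplits = numBins - 1 = 2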
+
+      // possibleSplits <= numSplits
+      {
+        val featureSamples = Array(0, 1, 0, 0, 1, 0, 1, 1).map(_.toDouble)
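+        // {0.0, 1.0} gives possibleSplits = 1 <= numSplits, so the single midpoint 0.5 is returned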
+        val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+        val expectedSplits = Array((0.0 + 1.0) / 2)
+        assert(splits === expectedSplits)
+      }
+
+      // possibleSplits > numSplits
+      {
+        val featureSamples = Array(0, 0, 1, 1, 2, 2, 3, 3).map(_.toDouble)
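+        // {0.0, 1.0, 2.0, 3.0} gives possibleSplits = 3 > numSplits, so the stride walk keeps two midpoints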
+        val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+        val expectedSplits = Array((0.0 + 1.0) / 2, (2.0 + 3.0) / 2)
+        assert(splits === expectedSplits)
+      }
+    }
+
     // find splits should not return identical splits
     // when there are not enough split candidates, reduce the number of splits in metadata
     {
@@ -112,9 +140,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
         Array(5), Gini, QuantileStrategy.Sort,
         0, 0, 0.0, 0, 0
       )
-      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+      val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3).map(_.toDouble)
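+      // only three distinct values, so both candidate midpoints (1.5 and 2.5) are returned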
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits === Array(1.0, 2.0))
+      val expectedSplits = Array((1.0 + 2.0) / 2, (2.0 + 3.0) / 2)
+      assert(splits === expectedSplits)
       // check returned splits are distinct
       assert(splits.distinct.length === splits.length)
     }
@@ -126,9 +156,11 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
         Array(3), Gini, QuantileStrategy.Sort,
         0, 0, 0.0, 0, 0
       )
-      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+      val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5)
+        .map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits === Array(2.0, 3.0))
+      val expectedSplits = Array((2.0 + 3.0) / 2, (3.0 + 4.0) / 2)
+      assert(splits === expectedSplits)
     }
 
     // find splits when most samples close to the maximum
@@ -138,9 +170,10 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
         Array(2), Gini, QuantileStrategy.Sort,
         0, 0, 0.0, 0, 0
       )
-      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+      val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
       val splits = RandomForest.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
-      assert(splits === Array(1.0))
+      val expectedSplits = Array((1.0 + 2.0) / 2)
+      assert(splits === expectedSplits)
     }
 
     // find splits for constant feature
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index a6089fc8b9d32e46ad9d3346ac513855b08c790f..619fa16d463f5f77c4ffcd92f310af397b4be7c8 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -199,9 +199,9 @@ class DecisionTree(object):
 
         >>> print(model.toDebugString())
         DecisionTreeModel classifier of depth 1 with 3 nodes
-          If (feature 0 <= 0.0)
+          If (feature 0 <= 0.5)
            Predict: 0.0
-          Else (feature 0 > 0.0)
+          Else (feature 0 > 0.5)
            Predict: 1.0
         <BLANKLINE>
         >>> model.predict(array([1.0]))
@@ -383,14 +383,14 @@ class RandomForest(object):
           Tree 0:
             Predict: 1.0
           Tree 1:
-            If (feature 0 <= 1.0)
+            If (feature 0 <= 1.5)
              Predict: 0.0
-            Else (feature 0 > 1.0)
+            Else (feature 0 > 1.5)
              Predict: 1.0
           Tree 2:
-            If (feature 0 <= 1.0)
+            If (feature 0 <= 1.5)
              Predict: 0.0
-            Else (feature 0 > 1.0)
+            Else (feature 0 > 1.5)
              Predict: 1.0
         <BLANKLINE>
         >>> model.predict([2.0])