diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index fd0b9556c7d5444ca9af7b7cda6a2f22656dc8db..ba7ccd8ce4b8b06179606c1bd225c79004fedf67 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -25,16 +25,14 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.mllib.classification._
 import org.apache.spark.mllib.clustering._
-import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
 import org.apache.spark.mllib.random.{RandomRDDGenerators => RG}
 import org.apache.spark.mllib.recommendation._
 import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.tree.configuration.Algo._
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.DecisionTree
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurity, Variance}
+import org.apache.spark.mllib.tree.impurity._
 import org.apache.spark.mllib.tree.model.DecisionTreeModel
 import org.apache.spark.mllib.stat.Statistics
 import org.apache.spark.mllib.stat.correlation.CorrelationNames
@@ -523,17 +521,8 @@ class PythonMLLibAPI extends Serializable {
 
     val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
 
-    val algo: Algo = algoStr match {
-      case "classification" => Classification
-      case "regression" => Regression
-      case _ => throw new IllegalArgumentException(s"Bad algoStr parameter: $algoStr")
-    }
-    val impurity: Impurity = impurityStr match {
-      case "gini" => Gini
-      case "entropy" => Entropy
-      case "variance" => Variance
-      case _ => throw new IllegalArgumentException(s"Bad impurityStr parameter: $impurityStr")
-    }
+    val algo = Algo.fromString(algoStr)
+    val impurity = Impurities.fromString(impurityStr)
 
     val strategy = new Strategy(
       algo = algo,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 1d03e6e3b36cfbd28cc523caed4b44a26649fe8c..c8a865659682fd800d532a9d1fc8659197f4e994 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,14 +17,18 @@
 
 package org.apache.spark.mllib.tree
 
+import scala.collection.JavaConverters._
+
+import org.apache.spark.api.java.JavaRDD
+
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.Logging
 import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
 import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.Impurity
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Impurities, Impurity}
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
@@ -200,6 +204,10 @@ object DecisionTree extends Serializable with Logging {
    * Method to train a decision tree model.
    * The method supports binary and multiclass classification and regression.
    *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   *       and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   *       is recommended to clearly separate classification and regression.
+   *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
    *              For regression, labels are real numbers.
@@ -213,10 +221,12 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The method supports binary classification and regression. For the
-   * binary classification, the label for each instance should either be 0 or 1 to denote the two
-   * classes.
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
+   *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   *       and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   *       is recommended to clearly separate classification and regression.
    *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -237,10 +247,12 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The method supports binary classification and regression. For the
-   * binary classification, the label for each instance should either be 0 or 1 to denote the two
-   * classes.
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
+   *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   *       and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   *       is recommended to clearly separate classification and regression.
    *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -263,11 +275,12 @@ object DecisionTree extends Serializable with Logging {
   }
 
   /**
-   * Method to train a decision tree model where the instances are represented as an RDD of
-   * (label, features) pairs. The decision tree method supports binary classification and
-   * regression. For the binary classification, the label for each instance should either be 0 or
-   * 1 to denote the two classes. The method also supports categorical features inputs where the
-   * number of categories can specified using the categoricalFeaturesInfo option.
+   * Method to train a decision tree model.
+   * The method supports binary and multiclass classification and regression.
+   *
+   * Note: Using [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   *       and [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   *       is recommended to clearly separate classification and regression.
    *
    * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
    *              For classification, labels should take values {0, 1, ..., numClasses-1}.
@@ -279,11 +292,9 @@ object DecisionTree extends Serializable with Logging {
    * @param numClassesForClassification number of classes for classification. Default value of 2.
    * @param maxBins maximum number of bins used for splitting features
    * @param quantileCalculationStrategy  algorithm for calculating quantiles
-   * @param categoricalFeaturesInfo A map storing information about the categorical variables and
-   *                                the number of discrete values they take. For example,
-   *                                an entry (n -> k) implies the feature n is categorical with k
-   *                                categories 0, 1, 2, ... , k-1. It's important to note that
-   *                                features are zero-indexed.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
    * @return DecisionTreeModel that can be used for prediction
    */
   def train(
@@ -300,6 +311,93 @@ object DecisionTree extends Serializable with Logging {
     new DecisionTree(strategy).train(input)
   }
 
+  /**
+   * Method to train a decision tree model for binary or multiclass classification.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels should take values {0, 1, ..., numClasses-1}.
+   * @param numClassesForClassification number of classes for classification.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "gini" (recommended) or "entropy".
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (suggested value: 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                 (suggested value: 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
+  def trainClassifier(
+      input: RDD[LabeledPoint],
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: Map[Int, Int],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    val impurityType = Impurities.fromString(impurity)
+    train(input, Classification, impurityType, maxDepth, numClassesForClassification, maxBins, Sort,
+      categoricalFeaturesInfo)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainClassifier]]
+   */
+  def trainClassifier(
+      input: JavaRDD[LabeledPoint],
+      numClassesForClassification: Int,
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    trainClassifier(input.rdd, numClassesForClassification,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      impurity, maxDepth, maxBins)
+  }
+
+  /**
+   * Method to train a decision tree model for regression.
+   *
+   * @param input Training dataset: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]].
+   *              Labels are real numbers.
+   * @param categoricalFeaturesInfo Map storing arity of categorical features.
+   *                                E.g., an entry (n -> k) indicates that feature n is categorical
+   *                                with k categories indexed from 0: {0, 1, ..., k-1}.
+   * @param impurity Criterion used for information gain calculation.
+   *                 Supported values: "variance".
+   * @param maxDepth Maximum depth of the tree.
+   *                 E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+   *                 (suggested value: 4)
+   * @param maxBins maximum number of bins used for splitting features
+   *                 (suggested value: 100)
+   * @return DecisionTreeModel that can be used for prediction
+   */
+  def trainRegressor(
+      input: RDD[LabeledPoint],
+      categoricalFeaturesInfo: Map[Int, Int],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    val impurityType = Impurities.fromString(impurity)
+    train(input, Regression, impurityType, maxDepth, 0, maxBins, Sort, categoricalFeaturesInfo)
+  }
+
+  /**
+   * Java-friendly API for [[org.apache.spark.mllib.tree.DecisionTree$#trainRegressor]]
+   */
+  def trainRegressor(
+      input: JavaRDD[LabeledPoint],
+      categoricalFeaturesInfo: java.util.Map[java.lang.Integer, java.lang.Integer],
+      impurity: String,
+      maxDepth: Int,
+      maxBins: Int): DecisionTreeModel = {
+    trainRegressor(input.rdd,
+      categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap,
+      impurity, maxDepth, maxBins)
+  }
+
+
   private val InvalidBinIndex = -1
 
   /**
@@ -1331,16 +1429,15 @@ object DecisionTree extends Serializable with Logging {
    * Categorical features:
    *   For each feature, there is 1 bin per split.
    *   Splits and bins are handled in 2 ways:
-   *   (a) For multiclass classification with a low-arity feature
+   *   (a) "unordered features"
+   *       For multiclass classification with a low-arity feature
    *       (i.e., if isMulticlass && isSpaceSufficientForAllCategoricalSplits),
    *       the feature is split based on subsets of categories.
-   *       There are 2^(maxFeatureValue - 1) - 1 splits.
-   *   (b) For regression and binary classification,
+   *       There are math.pow(2, maxFeatureValue - 1) - 1 splits.
+   *   (b) "ordered features"
+   *       For regression and binary classification,
    *       and for multiclass classification with a high-arity feature,
-   *       there is one split per category.
-
-   * Categorical case (a) features are called unordered features.
-   * Other cases are called ordered features.
+   *       there is one bin per category.
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
index 79a01f58319e8ffefc55e7ebab4289f399ca59f9..0ef9c6181a0a06ca8ab9c26e5ee431beee82be37 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Algo.scala
@@ -27,4 +27,10 @@ import org.apache.spark.annotation.Experimental
 object Algo extends Enumeration {
   type Algo = Value
   val Classification, Regression = Value
+
+  private[mllib] def fromString(name: String): Algo = name match {
+    case "classification" => Classification
+    case "regression" => Regression
+    case _ => throw new IllegalArgumentException(s"Did not recognize Algo name: $name")
+  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala
new file mode 100644
index 0000000000000000000000000000000000000000..9a6452aa13a61f5e5a879adf57b36692d24d98aa
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurities.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impurity
+
+/**
+ * Factory for Impurity instances.
+ */
+private[mllib] object Impurities {
+
+  def fromString(name: String): Impurity = name match {
+    case "gini" => Gini
+    case "entropy" => Entropy
+    case "variance" => Variance
+    case _ => throw new IllegalArgumentException(s"Did not recognize Impurity name: $name")
+  }
+
+}
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 2518001ea0b93b85be339751c570332ddcf28c1b..e1a4671709b7df29a77710d6ef73df8f2e9f325a 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -131,7 +131,7 @@ class DecisionTree(object):
     """
 
     @staticmethod
-    def trainClassifier(data, numClasses, categoricalFeaturesInfo={},
+    def trainClassifier(data, numClasses, categoricalFeaturesInfo,
                         impurity="gini", maxDepth=4, maxBins=100):
         """
         Train a DecisionTreeModel for classification.
@@ -150,12 +150,20 @@ class DecisionTree(object):
         :param maxBins: Number of bins used for finding splits at each node.
         :return: DecisionTreeModel
         """
-        return DecisionTree.train(data, "classification", numClasses,
-                                  categoricalFeaturesInfo,
-                                  impurity, maxDepth, maxBins)
+        sc = data.context
+        dataBytes = _get_unmangled_labeled_point_rdd(data)
+        categoricalFeaturesInfoJMap = \
+            MapConverter().convert(categoricalFeaturesInfo,
+                                   sc._gateway._gateway_client)
+        model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
+            dataBytes._jrdd, "classification",
+            numClasses, categoricalFeaturesInfoJMap,
+            impurity, maxDepth, maxBins)
+        dataBytes.unpersist()
+        return DecisionTreeModel(sc, model)
 
     @staticmethod
-    def trainRegressor(data, categoricalFeaturesInfo={},
+    def trainRegressor(data, categoricalFeaturesInfo,
                        impurity="variance", maxDepth=4, maxBins=100):
         """
         Train a DecisionTreeModel for regression.
@@ -173,42 +181,14 @@ class DecisionTree(object):
         :param maxBins: Number of bins used for finding splits at each node.
         :return: DecisionTreeModel
         """
-        return DecisionTree.train(data, "regression", 0,
-                                  categoricalFeaturesInfo,
-                                  impurity, maxDepth, maxBins)
-
-    @staticmethod
-    def train(data, algo, numClasses, categoricalFeaturesInfo,
-              impurity, maxDepth, maxBins=100):
-        """
-        Train a DecisionTreeModel for classification or regression.
-
-        :param data: Training data: RDD of LabeledPoint.
-                     For classification, labels are integers
-                      {0,1,...,numClasses}.
-                     For regression, labels are real numbers.
-        :param algo: "classification" or "regression"
-        :param numClasses: Number of classes for classification.
-        :param categoricalFeaturesInfo: Map from categorical feature index
-                                        to number of categories.
-                                        Any feature not in this map
-                                        is treated as continuous.
-        :param impurity: For classification: "entropy" or "gini".
-                         For regression: "variance".
-        :param maxDepth: Max depth of tree.
-                         E.g., depth 0 means 1 leaf node.
-                         Depth 1 means 1 internal node + 2 leaf nodes.
-        :param maxBins: Number of bins used for finding splits at each node.
-        :return: DecisionTreeModel
-        """
         sc = data.context
         dataBytes = _get_unmangled_labeled_point_rdd(data)
         categoricalFeaturesInfoJMap = \
             MapConverter().convert(categoricalFeaturesInfo,
                                    sc._gateway._gateway_client)
         model = sc._jvm.PythonMLLibAPI().trainDecisionTreeModel(
-            dataBytes._jrdd, algo,
-            numClasses, categoricalFeaturesInfoJMap,
+            dataBytes._jrdd, "regression",
+            0, categoricalFeaturesInfoJMap,
             impurity, maxDepth, maxBins)
         dataBytes.unpersist()
         return DecisionTreeModel(sc, model)