diff --git a/docs/mllib-decision-tree.md b/docs/mllib-decision-tree.md
index 9cd768599e5296e8b574fdbc100e51d90d90372b..9cbd880897578a3b2f1802dc52ba7dbeb205469b 100644
--- a/docs/mllib-decision-tree.md
+++ b/docs/mllib-decision-tree.md
@@ -77,15 +77,17 @@ bins if the condition is not satisfied.
 
 **Categorical features**
 
-For `$M$` categorical features, one could come up with `$2^M-1$` split candidates. However, for
-binary classification, the number of split candidates can be reduced to `$M-1$` by ordering the
+For `$M$` categorical feature values, one could come up with `$2^{M-1}-1$` split candidates. For
+binary classification, we can reduce the number of split candidates to `$M-1$` by ordering the
 categorical feature values by the proportion of labels falling in one of the two classes (see
 Section 9.2.4 in
 [The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/) for
 details). For example, for a binary classification problem with one categorical feature with three
 categories A, B and C with corresponding proportion of label 1 as 0.2, 0.6 and 0.4, the categorical
 features are ordered as A followed by C followed by B or A, C, B. The two split candidates are A \| C, B
-and A, C \| B where \| denotes the split.
+and A, C \| B where \| denotes the split. A similar heuristic is used for multiclass classification
+when `$2^{M-1}-1$` is greater than the number of bins -- the impurity for each categorical feature value
+is used for ordering.
 
 ### Stopping rule
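A quick way to sanity-check the worked example above (a standalone sketch, not part of the patch; the names are invented):

```scala
// Order categories by the proportion of label 1, as in the A/B/C example above.
object CategoryOrderingExample {
  def main(args: Array[String]): Unit = {
    val proportionOfLabelOne = Map("A" -> 0.2, "B" -> 0.6, "C" -> 0.4)
    // Sorting by proportion yields A, C, B, so the M - 1 = 2 candidate splits
    // are A | C, B and A, C | B.
    val ordered = proportionOfLabelOne.toList.sortBy(_._2).map(_._1)
    println(ordered.mkString(", ")) // prints: A, C, B
  }
}
```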
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index b3cc361154198e3442432433a88efe7477ffdcb2..43f13fe24f0d09ba4e5141f327df9e6bd46daf79 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -49,6 +49,7 @@ object DecisionTreeRunner {
   case class Params(
       input: String = null,
       algo: Algo = Classification,
+      numClassesForClassification: Int = 2,
       maxDepth: Int = 5,
       impurity: ImpurityType = Gini,
       maxBins: Int = 100)
@@ -68,6 +69,10 @@ object DecisionTreeRunner {
       opt[Int]("maxDepth")
         .text(s"max depth of the tree, default: ${defaultParams.maxDepth}")
         .action((x, c) => c.copy(maxDepth = x))
+      opt[Int]("numClassesForClassification")
+        .text(s"number of classes for classification, " +
+          s"default: ${defaultParams.numClassesForClassification}")
+        .action((x, c) => c.copy(numClassesForClassification = x))
       opt[Int]("maxBins")
         .text(s"max number of bins, default: ${defaultParams.maxBins}")
         .action((x, c) => c.copy(maxBins = x))
@@ -118,7 +123,12 @@ object DecisionTreeRunner {
       case Variance => impurity.Variance
     }
 
-    val strategy = new Strategy(params.algo, impurityCalculator, params.maxDepth, params.maxBins)
+    val strategy = new Strategy(
+      algo = params.algo,
+      impurity = impurityCalculator,
+      maxDepth = params.maxDepth,
+      maxBins = params.maxBins,
+      numClassesForClassification = params.numClassesForClassification)
     val model = DecisionTree.train(training, strategy)
 
     if (params.algo == Classification) {
@@ -139,12 +150,8 @@ object DecisionTreeRunner {
    */
   private def accuracyScore(
       model: DecisionTreeModel,
-      data: RDD[LabeledPoint],
-      threshold: Double = 0.5): Double = {
-    def predictedValue(features: Vector): Double = {
-      if (model.predict(features) < threshold) 0.0 else 1.0
-    }
-    val correctCount = data.filter(y => predictedValue(y.features) == y.label).count()
+      data: RDD[LabeledPoint]): Double = {
+    val correctCount = data.filter(y => model.predict(y.features) == y.label).count()
     val count = data.count()
     correctCount.toDouble / count
   }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 74d5d7ba1096091e3bb433204e0909c8290ed2fa..ad32e3f4560fe22b32cbf9ace642ae4741d58019 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -77,11 +77,9 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
     // Max memory usage for aggregates
     val maxMemoryUsage = strategy.maxMemoryInMB * 1024 * 1024
     logDebug("max memory usage for aggregates = " + maxMemoryUsage + " bytes.")
-    val numElementsPerNode =
-      strategy.algo match {
-        case Classification => 2 * numBins * numFeatures
-        case Regression => 3 * numBins * numFeatures
-      }
+    val numElementsPerNode = DecisionTree.getElementsPerNode(numFeatures, numBins,
+      strategy.numClassesForClassification, strategy.isMulticlassWithCategoricalFeatures,
+      strategy.algo)
 
     logDebug("numElementsPerNode = " + numElementsPerNode)
     val arraySizePerNode = 8 * numElementsPerNode // approx. memory usage for bin aggregate array
@@ -109,8 +107,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
       logDebug("#####################################")
 
       // Find best split for all nodes at a level.
-      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
-        level, filters, splits, bins, maxLevelForSingleGroup)
+      val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities,
+        strategy, level, filters, splits, bins, maxLevelForSingleGroup)
 
       for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
         // Extract info for nodes at the current level.
@@ -212,7 +210,7 @@ object DecisionTree extends Serializable with Logging {
    * @return a DecisionTreeModel that can be used for prediction
    */
   def train(input: RDD[LabeledPoint], strategy: Strategy): DecisionTreeModel = {
-    new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+    new DecisionTree(strategy).train(input)
   }
 
   /**
@@ -233,10 +231,33 @@ object DecisionTree extends Serializable with Logging {
       algo: Algo,
       impurity: Impurity,
       maxDepth: Int): DecisionTreeModel = {
-    val strategy = new Strategy(algo,impurity,maxDepth)
-    new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+    val strategy = new Strategy(algo, impurity, maxDepth)
+    new DecisionTree(strategy).train(input)
   }
 
+  /**
+   * Method to train a decision tree model where the instances are represented as an RDD of
+   * (label, features) pairs. The method supports binary and multiclass classification as well
+   * as regression. For classification, the label of each instance should be a class index in
+   * 0, 1, ..., numClassesForClassification - 1.
+   *
+   * @param input input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as
+   *              training data
+   * @param algo algorithm, classification or regression
+   * @param impurity impurity criterion used for information gain calculation
+   * @param maxDepth maximum depth of the tree
+   * @param numClassesForClassification number of classes for classification
+   * @return a DecisionTreeModel that can be used for prediction
+   */
+  def train(
+      input: RDD[LabeledPoint],
+      algo: Algo,
+      impurity: Impurity,
+      maxDepth: Int,
+      numClassesForClassification: Int): DecisionTreeModel = {
+    val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification)
+    new DecisionTree(strategy).train(input)
+  }
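For context, a minimal sketch of how the new overload is meant to be called (standalone, not part of the patch; the SparkContext setup and data path are invented, and the input file is assumed to hold labels 0.0, 1.0 or 2.0 in the "label,f1 f2 f3" format read by loadLabeledData):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.impurity.Gini
import org.apache.spark.mllib.util.MLUtils

object MulticlassTrainExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local", "MulticlassTrainExample")
    // Hypothetical input path.
    val data = MLUtils.loadLabeledData(sc, "data/sample_multiclass_data.txt")
    // Train a 3-class tree via the new overload.
    val model = DecisionTree.train(data, Classification, Gini, maxDepth = 5,
      numClassesForClassification = 3)
    println(model.predict(data.first().features))
    sc.stop()
  }
}
```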
 
   /**
    * Method to train a decision tree model where the instances are represented as an RDD of
@@ -250,6 +271,7 @@ object DecisionTree extends Serializable with Logging {
    * @param algo classification or regression
    * @param impurity criterion used for information gain calculation
    * @param maxDepth maximum depth of the tree
+   * @param numClassesForClassification number of classes for classification. Default value is 2.
    * @param maxBins maximum number of bins used for splitting features
    * @param quantileCalculationStrategy algorithm for calculating quantiles
    * @param categoricalFeaturesInfo A map storing information about the categorical variables and
@@ -264,12 +286,13 @@ object DecisionTree extends Serializable with Logging {
       algo: Algo,
       impurity: Impurity,
       maxDepth: Int,
+      numClassesForClassification: Int,
       maxBins: Int,
       quantileCalculationStrategy: QuantileStrategy,
       categoricalFeaturesInfo: Map[Int,Int]): DecisionTreeModel = {
-    val strategy = new Strategy(algo, impurity, maxDepth, maxBins, quantileCalculationStrategy,
-      categoricalFeaturesInfo)
-    new DecisionTree(strategy).train(input: RDD[LabeledPoint])
+    val strategy = new Strategy(algo, impurity, maxDepth, numClassesForClassification, maxBins,
+      quantileCalculationStrategy, categoricalFeaturesInfo)
+    new DecisionTree(strategy).train(input)
   }
 
   private val InvalidBinIndex = -1
@@ -381,6 +404,14 @@ object DecisionTree extends Serializable with Logging {
     logDebug("numFeatures = " + numFeatures)
     val numBins = bins(0).length
     logDebug("numBins = " + numBins)
+    val numClasses = strategy.numClassesForClassification
+    logDebug("numClasses = " + numClasses)
+    val isMulticlassClassification = strategy.isMulticlassClassification
+    logDebug("isMulticlassClassification = " + isMulticlassClassification)
+    val isMulticlassClassificationWithCategoricalFeatures =
+      strategy.isMulticlassWithCategoricalFeatures
+    logDebug("isMulticlassClassificationWithCategoricalFeatures = " +
+      isMulticlassClassificationWithCategoricalFeatures)
 
     // shift when more than one group is used at deep tree level
     val groupShift = numNodes * groupIndex
@@ -436,10 +467,8 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Find bin for one feature.
      */
-    def findBin(
-        featureIndex: Int,
-        labeledPoint: LabeledPoint,
-        isFeatureContinuous: Boolean): Int = {
+    def findBin(featureIndex: Int, labeledPoint: LabeledPoint,
+        isFeatureContinuous: Boolean, isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
       val binForFeatures = bins(featureIndex)
       val feature = labeledPoint.features(featureIndex)
 
@@ -467,17 +496,28 @@ object DecisionTree extends Serializable with Logging {
         -1
       }
 
+      /**
+       * Sequential search helper method to find bin for categorical feature in multiclass
+       * classification. The category is returned since each category can belong to multiple
+       * splits. The actual left/right child allocation per split is performed in the
+       * sequential phase of the bin aggregate operation.
+       */
+      def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
+        labeledPoint.features(featureIndex).toInt
+      }
+
       /**
        * Sequential search helper method to find bin for categorical feature.
       */
-      def sequentialBinSearchForCategoricalFeature(): Int = {
-        val numCategoricalBins = strategy.categoricalFeaturesInfo(featureIndex)
+      def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
+        val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+        // Ordered categorical features have one bin per category.
+        val numCategoricalBins = featureCategories
         var binIndex = 0
         while (binIndex < numCategoricalBins) {
           val bin = bins(featureIndex)(binIndex)
-          val category = bin.category
+          val categories = bin.highSplit.categories
           val features = labeledPoint.features
-          if (category == features(featureIndex)) {
+          if (categories.contains(features(featureIndex))) {
             return binIndex
           }
           binIndex += 1
@@ -494,7 +534,13 @@ object DecisionTree extends Serializable with Logging {
         binIndex
       } else {
         // Perform sequential search to find bin for categorical features.
-        val binIndex = sequentialBinSearchForCategoricalFeature()
+        val binIndex = {
+          if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+            sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
+          } else {
+            sequentialBinSearchForOrderedCategoricalFeatureInClassification()
+          }
+        }
         if (binIndex == -1) {
           throw new UnknownError("no bin was found for categorical variable.")
         }
@@ -506,13 +552,16 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Finds bins for all nodes (and all features) at a given level.
      * For l nodes, k features the storage is as follows:
      * label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
-     * where b_ij is an integer between 0 and numBins - 1.
+     * where b_ij is an integer between 0 and numBins - 1 for regression and binary
+     * classification, and the raw categorical feature value for unordered categorical
+     * features in multiclass classification.
      * An invalid sample is denoted by setting the bin for feature 1 to -1.
      */
     def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
       // Calculate bin index and label per feature per node.
       val arr = new Array[Double](1 + (numFeatures * numNodes))
+      // First element of the array is the label of the instance.
       arr(0) = labeledPoint.label
+      // Iterate over nodes.
       var nodeIndex = 0
       while (nodeIndex < numNodes) {
         val parentFilters = findParentFilters(nodeIndex)
@@ -525,8 +574,19 @@ object DecisionTree extends Serializable with Logging {
         } else {
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
-            arr(shift + featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous)
+            val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex)
+            val isFeatureContinuous = featureInfo.isEmpty
+            if (isFeatureContinuous) {
+              arr(shift + featureIndex) =
+                findBin(featureIndex, labeledPoint, isFeatureContinuous, false)
+            } else {
+              val featureCategories = featureInfo.get
+              val isSpaceSufficientForAllCategoricalSplits =
+                numBins > math.pow(2, featureCategories.toInt - 1) - 1
+              arr(shift + featureIndex) =
+                findBin(featureIndex, labeledPoint, isFeatureContinuous,
+                  isSpaceSufficientForAllCategoricalSplits)
+            }
             featureIndex += 1
           }
         }
@@ -535,18 +595,61 @@ object DecisionTree extends Serializable with Logging {
       arr
     }
 
+    // Find feature bins for all nodes at a level.
+    val binMappedRDD = input.map(x => findBinsForLevel(x))
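To make the flat encoding concrete, here is a standalone sketch (not part of the patch) of the index arithmetic behind the `arrShift`/`arrIndex` computations used throughout this file:

```scala
// Layout produced by findBinsForLevel: arr(0) = label, then one bin entry per
// (node, feature) pair.
object BinArrayLayoutExample {
  def binPosition(nodeIndex: Int, featureIndex: Int, numFeatures: Int): Int =
    1 + numFeatures * nodeIndex + featureIndex

  def main(args: Array[String]): Unit = {
    // With 3 features, the bin for feature 2 of node 1 sits at 1 + 3 * 1 + 2 = 6.
    println(binPosition(nodeIndex = 1, featureIndex = 2, numFeatures = 3))
  }
}
```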
+
+    def updateBinForOrderedFeature(arr: Array[Double], agg: Array[Double], nodeIndex: Int,
+        label: Double, featureIndex: Int) = {
+      // Find the bin index for this feature.
+      val arrShift = 1 + numFeatures * nodeIndex
+      val arrIndex = arrShift + featureIndex
+      // Update the count for this (bin, label) pair.
+      val aggShift = numClasses * numBins * numFeatures * nodeIndex
+      val aggIndex =
+        aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
+      val labelInt = label.toInt
+      agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
+    }
+
+    def updateBinForUnorderedFeature(nodeIndex: Int, featureIndex: Int, arr: Array[Double],
+        label: Double, agg: Array[Double], rightChildShift: Int) = {
+      // For unordered categorical features, arr holds the raw category value.
+      val arrShift = 1 + numFeatures * nodeIndex
+      val arrIndex = arrShift + featureIndex
+      val featureValue = arr(arrIndex).toInt
+      // Offset of this node's block in the bin aggregate array.
+      val aggShift = numClasses * numBins * numFeatures * nodeIndex
+      val aggIndex = aggShift + numClasses * featureIndex * numBins
+      val labelInt = label.toInt
+      // Find all matching bins and increment their values: the instance contributes to the
+      // left child of a split if its category belongs to the split's category set, and to
+      // the right child otherwise.
+      val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+      val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+      var binIndex = 0
+      while (binIndex < numCategoricalBins) {
+        if (bins(featureIndex)(binIndex).highSplit.categories.contains(featureValue)) {
+          agg(aggIndex + binIndex * numClasses + labelInt) =
+            agg(aggIndex + binIndex * numClasses + labelInt) + 1
+        } else {
+          agg(rightChildShift + aggIndex + binIndex * numClasses + labelInt) =
+            agg(rightChildShift + aggIndex + binIndex * numClasses + labelInt) + 1
+        }
+        binIndex += 1
+      }
+    }
+
     /**
      * Performs a sequential aggregation over a partition for classification. For l nodes and
      * k features, the count matching the instance's label is incremented for the bin of
      * each feature.
      *
      * @param agg Array[Double] storing aggregate calculation of size
-     *            2 * numSplits * numFeatures * numNodes for classification
+     *            numClasses * numBins * numFeatures * numNodes for classification
      * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
      * @return Array[Double] storing aggregate calculation of size
      *         numClasses * numBins * numFeatures * numNodes for classification
      */
-    def classificationBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+    def orderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
       // Iterate over all nodes.
       var nodeIndex = 0
       while (nodeIndex < numNodes) {
@@ -559,15 +662,52 @@ object DecisionTree extends Serializable with Logging {
           // Iterate over all features.
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            // Find the bin index for this feature.
-            val arrShift = 1 + numFeatures * nodeIndex
-            val arrIndex = arrShift + featureIndex
-            // Update the left or right count for one bin.
-            val aggShift = 2 * numBins * numFeatures * nodeIndex
-            val aggIndex = aggShift + 2 * featureIndex * numBins + arr(arrIndex).toInt * 2
-            label match {
-              case 0.0 => agg(aggIndex) = agg(aggIndex) + 1
-              case 1.0 => agg(aggIndex + 1) = agg(aggIndex + 1) + 1
-            }
+            updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
             featureIndex += 1
           }
         }
         nodeIndex += 1
       }
     }
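The two helpers above write into one flat array; here is a standalone sketch (not part of the patch) of the slot computation, matching how findAggForUnorderedFeatureClassification later reads the data back:

```scala
// One Double per (node, feature, bin, class); unordered features get a second,
// equally sized block (offset by rightChildShift) for right-child counts.
object AggregateLayoutExample {
  def slot(nodeIndex: Int, featureIndex: Int, binIndex: Int, classIndex: Int,
      numClasses: Int, numBins: Int, numFeatures: Int): Int = {
    numClasses * numBins * numFeatures * nodeIndex +
      numClasses * numBins * featureIndex +
      numClasses * binIndex +
      classIndex
  }

  def main(args: Array[String]): Unit = {
    // 3 classes, 4 bins, 2 features: (node 0, feature 1, bin 2, class 1)
    // lands at 0 + 12 + 6 + 1 = 19.
    println(slot(0, 1, 2, 1, numClasses = 3, numBins = 4, numFeatures = 2))
  }
}
```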
+
+    /**
+     * Performs a sequential aggregation over a partition for classification. For l nodes and
+     * k features, the left or right child count matching the instance's label is incremented
+     * for every candidate split of each unordered categorical feature.
+     *
+     * @param agg Array[Double] storing aggregate calculation of size
+     *            2 * numClasses * numBins * numFeatures * numNodes for classification
+     * @param arr Array[Double] of size 1 + (numFeatures * numNodes)
+     * @return Array[Double] storing aggregate calculation of size
+     *         2 * numClasses * numBins * numFeatures * numNodes for classification
+     */
+    def unorderedClassificationBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
+      // Iterate over all nodes.
+      var nodeIndex = 0
+      while (nodeIndex < numNodes) {
+        // Check whether the instance was valid for this nodeIndex.
+        val validSignalIndex = 1 + numFeatures * nodeIndex
+        val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
+        if (isSampleValidForNode) {
+          val rightChildShift = numClasses * numBins * numFeatures * numNodes
+          // actual class label
+          val label = arr(0)
+          // Iterate over all features.
+          var featureIndex = 0
+          while (featureIndex < numFeatures) {
+            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+            if (isFeatureContinuous) {
+              updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
+            } else {
+              val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+              val isSpaceSufficientForAllCategoricalSplits =
+                numBins > math.pow(2, featureCategories.toInt - 1) - 1
+              if (isSpaceSufficientForAllCategoricalSplits) {
+                updateBinForUnorderedFeature(nodeIndex, featureIndex, arr, label, agg,
+                  rightChildShift)
+              } else {
+                updateBinForOrderedFeature(arr, agg, nodeIndex, label, featureIndex)
+              }
+            }
+            featureIndex += 1
+          }
+        }
+        nodeIndex += 1
+      }
+    }
+
@@ -586,7 +726,7 @@ object DecisionTree extends Serializable with Logging {
      * @return Array[Double] storing aggregate calculation of size
      *         3 * numSplits * numFeatures * numNodes for regression
      */
-    def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) {
+    def regressionBinSeqOp(arr: Array[Double], agg: Array[Double]) = {
       // Iterate over all nodes.
       var nodeIndex = 0
       while (nodeIndex < numNodes) {
@@ -620,17 +760,20 @@ object DecisionTree extends Serializable with Logging {
      */
     def binSeqOp(agg: Array[Double], arr: Array[Double]): Array[Double] = {
       strategy.algo match {
-        case Classification => classificationBinSeqOp(arr, agg)
+        case Classification =>
+          if (isMulticlassClassificationWithCategoricalFeatures) {
+            unorderedClassificationBinSeqOp(arr, agg)
+          } else {
+            orderedClassificationBinSeqOp(arr, agg)
+          }
         case Regression => regressionBinSeqOp(arr, agg)
       }
       agg
     }
 
     // Calculate bin aggregate length for classification or regression.
-    val binAggregateLength = strategy.algo match {
-      case Classification => 2 * numBins * numFeatures * numNodes
-      case Regression => 3 * numBins * numFeatures * numNodes
-    }
+    val binAggregateLength = numNodes * getElementsPerNode(numFeatures, numBins, numClasses,
+      isMulticlassClassificationWithCategoricalFeatures, strategy.algo)
     logDebug("binAggregateLength = " + binAggregateLength)
 
     /**
@@ -649,9 +792,6 @@ object DecisionTree extends Serializable with Logging {
       combinedAggregate
     }
 
-    // Find feature bins for all nodes at a level.
-    val binMappedRDD = input.map(x => findBinsForLevel(x))
-
     // Calculate bin aggregates.
    val binAggregates = {
      binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp, binCombOp)
    }
@@ -668,42 +808,55 @@ object DecisionTree extends Serializable with Logging {
      * @return information gain and statistics for all splits
      */
     def calculateGainForSplit(
-        leftNodeAgg: Array[Array[Double]],
+        leftNodeAgg: Array[Array[Array[Double]]],
         featureIndex: Int,
         splitIndex: Int,
-        rightNodeAgg: Array[Array[Double]],
+        rightNodeAgg: Array[Array[Array[Double]]],
         topImpurity: Double): InformationGainStats = {
       strategy.algo match {
         case Classification =>
-          val left0Count = leftNodeAgg(featureIndex)(2 * splitIndex)
-          val left1Count = leftNodeAgg(featureIndex)(2 * splitIndex + 1)
-          val leftCount = left0Count + left1Count
-
-          val right0Count = rightNodeAgg(featureIndex)(2 * splitIndex)
-          val right1Count = rightNodeAgg(featureIndex)(2 * splitIndex + 1)
-          val rightCount = right0Count + right1Count
+          var classIndex = 0
+          val leftCounts: Array[Double] = new Array[Double](numClasses)
+          val rightCounts: Array[Double] = new Array[Double](numClasses)
+          var leftTotalCount = 0.0
+          var rightTotalCount = 0.0
+          while (classIndex < numClasses) {
+            val leftClassCount = leftNodeAgg(featureIndex)(splitIndex)(classIndex)
+            val rightClassCount = rightNodeAgg(featureIndex)(splitIndex)(classIndex)
+            leftCounts(classIndex) = leftClassCount
+            leftTotalCount += leftClassCount
+            rightCounts(classIndex) = rightClassCount
+            rightTotalCount += rightClassCount
+            classIndex += 1
+          }
 
           val impurity = {
             if (level > 0) {
               topImpurity
             } else {
               // Calculate impurity for root node.
-              strategy.impurity.calculate(left0Count + right0Count, left1Count + right1Count)
+              val rootNodeCounts = new Array[Double](numClasses)
+              var classIndex = 0
+              while (classIndex < numClasses) {
+                rootNodeCounts(classIndex) = leftCounts(classIndex) + rightCounts(classIndex)
+                classIndex += 1
+              }
+              strategy.impurity.calculate(rootNodeCounts, leftTotalCount + rightTotalCount)
             }
           }
 
-          if (leftCount == 0) {
-            return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1)
+          if (leftTotalCount == 0) {
+            return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 1)
           }
-          if (rightCount == 0) {
-            return new InformationGainStats(0, topImpurity, topImpurity, Double.MinValue, 0)
+          if (rightTotalCount == 0) {
+            return new InformationGainStats(0, topImpurity, Double.MinValue, topImpurity, 1)
           }
 
-          val leftImpurity = strategy.impurity.calculate(left0Count, left1Count)
-          val rightImpurity = strategy.impurity.calculate(right0Count, right1Count)
+          val leftImpurity = strategy.impurity.calculate(leftCounts, leftTotalCount)
+          val rightImpurity = strategy.impurity.calculate(rightCounts, rightTotalCount)
 
-          val leftWeight = leftCount.toDouble / (leftCount + rightCount)
-          val rightWeight = rightCount.toDouble / (leftCount + rightCount)
+          val leftWeight = leftTotalCount / (leftTotalCount + rightTotalCount)
+          val rightWeight = rightTotalCount / (leftTotalCount + rightTotalCount)
 
           val gain = {
             if (level > 0) {
@@ -713,17 +866,34 @@ object DecisionTree extends Serializable with Logging {
             }
           }
 
-          val predict = (left1Count + right1Count) / (leftCount + rightCount)
+          val totalCount = leftTotalCount + rightTotalCount
 
-          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict)
+          // Sum of count for each label
+          val leftRightCounts: Array[Double] =
+            leftCounts.zip(rightCounts).map { case (leftCount, rightCount) =>
+              leftCount + rightCount
+            }
+
+          def indexOfLargestArrayElement(array: Array[Double]): Int = {
+            val result = array.foldLeft((-1, Double.MinValue, 0)) {
+              case ((maxIndex, maxValue, currentIndex), currentValue) =>
+                if (currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
+                else (maxIndex, maxValue, currentIndex + 1)
+            }
+            if (result._1 < 0) 0 else result._1
+          }
+
+          val predict = indexOfLargestArrayElement(leftRightCounts)
+          val prob = leftRightCounts(predict) / totalCount
+
+          new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
 
         case Regression =>
-          val leftCount = leftNodeAgg(featureIndex)(3 * splitIndex)
-          val leftSum = leftNodeAgg(featureIndex)(3 * splitIndex + 1)
-          val leftSumSquares = leftNodeAgg(featureIndex)(3 * splitIndex + 2)
+          val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
+          val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
+          val leftSumSquares = leftNodeAgg(featureIndex)(splitIndex)(2)
 
-          val rightCount = rightNodeAgg(featureIndex)(3 * splitIndex)
-          val rightSum = rightNodeAgg(featureIndex)(3 * splitIndex + 1)
-          val rightSumSquares = rightNodeAgg(featureIndex)(3 * splitIndex + 2)
+          val rightCount = rightNodeAgg(featureIndex)(splitIndex)(0)
+          val rightSum = rightNodeAgg(featureIndex)(splitIndex)(1)
+          val rightSumSquares = rightNodeAgg(featureIndex)(splitIndex)(2)
 
           val impurity = {
             if (level > 0) {
@@ -768,104 +938,149 @@ object DecisionTree extends Serializable with Logging {
     /**
      * Extracts left and right split aggregates.
      * @param binData Array[Double] of size 2 * numFeatures * numSplits
-     * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Double],
-     *         Array[Double]) where each array is of size (numFeatures, 2 * (numSplits - 1))
+     * @return (leftNodeAgg, rightNodeAgg) tuple of type (Array[Array[Array[Double\]\]\],
+     *         Array[Array[Array[Double\]\]\]) where each array is of size (numFeatures,
+     *         numBins - 1, numClasses)
      */
     def extractLeftRightNodeAggregates(
-        binData: Array[Double]): (Array[Array[Double]], Array[Array[Double]]) = {
+        binData: Array[Double]): (Array[Array[Array[Double]]], Array[Array[Array[Double]]]) = {
+
+      def findAggForOrderedFeatureClassification(
+          leftNodeAgg: Array[Array[Array[Double]]],
+          rightNodeAgg: Array[Array[Array[Double]]],
+          featureIndex: Int) {
+
+        // shift for this featureIndex
+        val shift = numClasses * featureIndex * numBins
+
+        var classIndex = 0
+        while (classIndex < numClasses) {
+          // left node aggregate for the lowest split
+          leftNodeAgg(featureIndex)(0)(classIndex) = binData(shift + classIndex)
+          // right node aggregate for the highest split
+          rightNodeAgg(featureIndex)(numBins - 2)(classIndex) =
+            binData(shift + (numClasses * (numBins - 1)) + classIndex)
+          classIndex += 1
+        }
+
+        // Iterate over all splits.
+        var splitIndex = 1
+        while (splitIndex < numBins - 1) {
+          // calculating left node aggregate for a split as a sum of left node aggregate of a
+          // lower split and the left bin aggregate of a bin where the split is a high split
+          var innerClassIndex = 0
+          while (innerClassIndex < numClasses) {
+            leftNodeAgg(featureIndex)(splitIndex)(innerClassIndex) =
+              binData(shift + numClasses * splitIndex + innerClassIndex) +
+                leftNodeAgg(featureIndex)(splitIndex - 1)(innerClassIndex)
+            rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(innerClassIndex) =
+              binData(shift + (numClasses * (numBins - 1 - splitIndex) + innerClassIndex)) +
+                rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(innerClassIndex)
+            innerClassIndex += 1
+          }
+          splitIndex += 1
+        }
+      }
+
+      def findAggForUnorderedFeatureClassification(
+          leftNodeAgg: Array[Array[Array[Double]]],
+          rightNodeAgg: Array[Array[Array[Double]]],
+          featureIndex: Int) {
+
+        val rightChildShift = numClasses * numBins * numFeatures
+        var splitIndex = 0
+        while (splitIndex < numBins - 1) {
+          var classIndex = 0
+          while (classIndex < numClasses) {
+            // shift for this featureIndex
+            val shift = numClasses * featureIndex * numBins + splitIndex * numClasses
+            val leftBinValue = binData(shift + classIndex)
+            val rightBinValue = binData(rightChildShift + shift + classIndex)
+            leftNodeAgg(featureIndex)(splitIndex)(classIndex) = leftBinValue
+            rightNodeAgg(featureIndex)(splitIndex)(classIndex) = rightBinValue
+            classIndex += 1
+          }
+          splitIndex += 1
+        }
+      }
+
+      def findAggForRegression(
+          leftNodeAgg: Array[Array[Array[Double]]],
+          rightNodeAgg: Array[Array[Array[Double]]],
+          featureIndex: Int) {
+
+        // shift for this featureIndex
+        val shift = 3 * featureIndex * numBins
+        // left node aggregate for the lowest split
+        leftNodeAgg(featureIndex)(0)(0) = binData(shift + 0)
+        leftNodeAgg(featureIndex)(0)(1) = binData(shift + 1)
+        leftNodeAgg(featureIndex)(0)(2) = binData(shift + 2)
+
+        // right node aggregate for the highest split
+        rightNodeAgg(featureIndex)(numBins - 2)(0) =
+          binData(shift + (3 * (numBins - 1)))
+        rightNodeAgg(featureIndex)(numBins - 2)(1) =
+          binData(shift + (3 * (numBins - 1)) + 1)
+        rightNodeAgg(featureIndex)(numBins - 2)(2) =
+          binData(shift + (3 * (numBins - 1)) + 2)
+
+        // Iterate over all splits.
+        var splitIndex = 1
+        while (splitIndex < numBins - 1) {
+          var i = 0 // index for regression histograms
+          while (i < 3) { // count, sum, sum^2
+            // calculating left node aggregate for a split as a sum of left node aggregate of a
+            // lower split and the left bin aggregate of a bin where the split is a high split
+            leftNodeAgg(featureIndex)(splitIndex)(i) = binData(shift + 3 * splitIndex + i) +
+              leftNodeAgg(featureIndex)(splitIndex - 1)(i)
+            // calculating right node aggregate for a split as a sum of right node aggregate of a
+            // higher split and the right bin aggregate of a bin where the split is a low split
+            rightNodeAgg(featureIndex)(numBins - 2 - splitIndex)(i) =
+              binData(shift + (3 * (numBins - 1 - splitIndex) + i)) +
+                rightNodeAgg(featureIndex)(numBins - 1 - splitIndex)(i)
+            i += 1
+          }
+          splitIndex += 1
+        }
+      }
+
       strategy.algo match {
         case Classification =>
           // Initialize left and right split aggregates.
-          val leftNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
-          val rightNodeAgg = Array.ofDim[Double](numFeatures, 2 * (numBins - 1))
-          // Iterate over all features.
+          val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
+          val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            // shift for this featureIndex
-            val shift = 2 * featureIndex * numBins
-
-            // left node aggregate for the lowest split
-            leftNodeAgg(featureIndex)(0) = binData(shift + 0)
-            leftNodeAgg(featureIndex)(1) = binData(shift + 1)
-
-            // right node aggregate for the highest split
-            rightNodeAgg(featureIndex)(2 * (numBins - 2))
-              = binData(shift + (2 * (numBins - 1)))
-            rightNodeAgg(featureIndex)(2 * (numBins - 2) + 1)
-              = binData(shift + (2 * (numBins - 1)) + 1)
-
-            // Iterate over all splits.
-            var splitIndex = 1
-            while (splitIndex < numBins - 1) {
-              // calculating left node aggregate for a split as a sum of left node aggregate of a
-              // lower split and the left bin aggregate of a bin where the split is a high split
-              leftNodeAgg(featureIndex)(2 * splitIndex) = binData(shift + 2 * splitIndex) +
-                leftNodeAgg(featureIndex)(2 * splitIndex - 2)
-              leftNodeAgg(featureIndex)(2 * splitIndex + 1) = binData(shift + 2 * splitIndex + 1) +
-                leftNodeAgg(featureIndex)(2 * splitIndex - 2 + 1)
-
-              // calculating right node aggregate for a split as a sum of right node aggregate of a
-              // higher split and the right bin aggregate of a bin where the split is a low split
-              rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex)) =
-                binData(shift + (2 * (numBins - 1 - splitIndex))) +
-                  rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex))
-              rightNodeAgg(featureIndex)(2 * (numBins - 2 - splitIndex) + 1) =
-                binData(shift + (2 * (numBins - 1 - splitIndex) + 1)) +
-                  rightNodeAgg(featureIndex)(2 * (numBins - 1 - splitIndex) + 1)
-
-              splitIndex += 1
+            if (isMulticlassClassificationWithCategoricalFeatures) {
+              val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+              if (isFeatureContinuous) {
+                findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+              } else {
+                val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+                val isSpaceSufficientForAllCategoricalSplits =
+                  numBins > math.pow(2, featureCategories.toInt - 1) - 1
+                if (isSpaceSufficientForAllCategoricalSplits) {
+                  findAggForUnorderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+                } else {
+                  findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
+                }
+              }
+            } else {
+              findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
             }
             featureIndex += 1
           }
+          (leftNodeAgg, rightNodeAgg)
         case Regression =>
           // Initialize left and right split aggregates.
-          val leftNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
-          val rightNodeAgg = Array.ofDim[Double](numFeatures, 3 * (numBins - 1))
+          val leftNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
+          val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, 3)
           // Iterate over all features.
           var featureIndex = 0
           while (featureIndex < numFeatures) {
-            // shift for this featureIndex
-            val shift = 3 * featureIndex * numBins
-            // left node aggregate for the lowest split
-            leftNodeAgg(featureIndex)(0) = binData(shift + 0)
-            leftNodeAgg(featureIndex)(1) = binData(shift + 1)
-            leftNodeAgg(featureIndex)(2) = binData(shift + 2)
-
-            // right node aggregate for the highest split
-            rightNodeAgg(featureIndex)(3 * (numBins - 2)) =
-              binData(shift + (3 * (numBins - 1)))
-            rightNodeAgg(featureIndex)(3 * (numBins - 2) + 1) =
-              binData(shift + (3 * (numBins - 1)) + 1)
-            rightNodeAgg(featureIndex)(3 * (numBins - 2) + 2) =
-              binData(shift + (3 * (numBins - 1)) + 2)
-
-            // Iterate over all splits.
-            var splitIndex = 1
-            while (splitIndex < numBins - 1) {
-              // calculating left node aggregate for a split as a sum of left node aggregate of a
-              // lower split and the left bin aggregate of a bin where the split is a high split
-              leftNodeAgg(featureIndex)(3 * splitIndex) = binData(shift + 3 * splitIndex) +
-                leftNodeAgg(featureIndex)(3 * splitIndex - 3)
-              leftNodeAgg(featureIndex)(3 * splitIndex + 1) = binData(shift + 3 * splitIndex + 1) +
-                leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 1)
-              leftNodeAgg(featureIndex)(3 * splitIndex + 2) = binData(shift + 3 * splitIndex + 2) +
-                leftNodeAgg(featureIndex)(3 * splitIndex - 3 + 2)
-
-              // calculating right node aggregate for a split as a sum of right node aggregate of a
-              // higher split and the right bin aggregate of a bin where the split is a low split
-              rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex)) =
-                binData(shift + (3 * (numBins - 1 - splitIndex))) +
-                  rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex))
-              rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 1) =
-                binData(shift + (3 * (numBins - 1 - splitIndex) + 1)) +
-                  rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 1)
-              rightNodeAgg(featureIndex)(3 * (numBins - 2 - splitIndex) + 2) =
-                binData(shift + (3 * (numBins - 1 - splitIndex) + 2)) +
-                  rightNodeAgg(featureIndex)(3 * (numBins - 1 - splitIndex) + 2)
-
-              splitIndex += 1
-            }
+            findAggForRegression(leftNodeAgg, rightNodeAgg, featureIndex)
             featureIndex += 1
           }
           (leftNodeAgg, rightNodeAgg)
@@ -876,8 +1091,8 @@ object DecisionTree extends Serializable with Logging {
      * Calculates information gain for all node splits.
      */
     def calculateGainsForAllNodeSplits(
-        leftNodeAgg: Array[Array[Double]],
-        rightNodeAgg: Array[Array[Double]],
+        leftNodeAgg: Array[Array[Array[Double]]],
+        rightNodeAgg: Array[Array[Array[Double]]],
         nodeImpurity: Double): Array[Array[InformationGainStats]] = {
       val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
 
@@ -918,7 +1133,22 @@ object DecisionTree extends Serializable with Logging {
         while (featureIndex < numFeatures) {
           // Iterate over all splits.
          var splitIndex = 0
-          while (splitIndex < numBins - 1) {
+          val maxSplitIndex: Int = {
+            val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+            if (isFeatureContinuous) {
+              numBins - 1
+            } else { // Categorical feature
+              val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+              val isSpaceSufficientForAllCategoricalSplits =
+                numBins > math.pow(2, featureCategories.toInt - 1) - 1
+              if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+                math.pow(2.0, featureCategories - 1).toInt - 1
+              } else { // Ordered categorical splits
+                featureCategories
+              }
+            }
+          }
+          while (splitIndex < maxSplitIndex) {
            val gainStats = gains(featureIndex)(splitIndex)
            if (gainStats.gain > bestGainStats.gain) {
              bestGainStats = gainStats
@@ -944,9 +1174,23 @@ object DecisionTree extends Serializable with Logging {
     def getBinDataForNode(node: Int): Array[Double] = {
       strategy.algo match {
         case Classification =>
-          val shift = 2 * node * numBins * numFeatures
-          val binsForNode = binAggregates.slice(shift, shift + 2 * numBins * numFeatures)
-          binsForNode
+          if (isMulticlassClassificationWithCategoricalFeatures) {
+            val shift = numClasses * node * numBins * numFeatures
+            val rightChildShift = numClasses * numBins * numFeatures * numNodes
+            val binsForNode = {
+              val leftChildData =
+                binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+              val rightChildData =
+                binAggregates.slice(rightChildShift + shift,
+                  rightChildShift + shift + numClasses * numBins * numFeatures)
+              leftChildData ++ rightChildData
+            }
+            binsForNode
+          } else {
+            val shift = numClasses * node * numBins * numFeatures
+            val binsForNode =
+              binAggregates.slice(shift, shift + numClasses * numBins * numFeatures)
+            binsForNode
+          }
         case Regression =>
           val shift = 3 * node * numBins * numFeatures
           val binsForNode = binAggregates.slice(shift, shift + 3 * numBins * numFeatures)
@@ -963,14 +1207,26 @@ object DecisionTree extends Serializable with Logging {
       val binsForNode: Array[Double] = getBinDataForNode(node)
       logDebug("nodeImpurityIndex = " + nodeImpurityIndex)
       val parentNodeImpurity = parentImpurities(nodeImpurityIndex)
-      logDebug("node impurity = " + parentNodeImpurity)
+      logDebug("parent node impurity = " + parentNodeImpurity)
       bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
       node += 1
     }
-
     bestSplits
   }
 
+  private def getElementsPerNode(numFeatures: Int, numBins: Int, numClasses: Int,
+      isMulticlassClassificationWithCategoricalFeatures: Boolean, algo: Algo): Int = {
+    algo match {
+      case Classification =>
+        if (isMulticlassClassificationWithCategoricalFeatures) {
+          2 * numClasses * numBins * numFeatures
+        } else {
+          numClasses * numBins * numFeatures
+        }
+      case Regression => 3 * numBins * numFeatures
+    }
+  }
+
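A quick numeric check of getElementsPerNode (a standalone sketch, not part of the patch; the sizes are invented):

```scala
// Aggregate sizes implied by getElementsPerNode for 10 features and 32 bins.
object AggregateSizingExample {
  def main(args: Array[String]): Unit = {
    val (numFeatures, numBins) = (10, 32)
    // Binary classification: numClasses * numBins * numFeatures = 2 * 32 * 10 = 640
    println(2 * numBins * numFeatures)
    // 3-class classification with categorical features doubles the space for the
    // right-child block: 2 * 3 * 32 * 10 = 1920
    println(2 * 3 * numBins * numFeatures)
    // Regression keeps count, sum and sum of squares per bin: 3 * 32 * 10 = 960
    println(3 * numBins * numFeatures)
  }
}
```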
  /**
   * Returns split and bins for decision tree calculation.
   * @param input RDD of [[org.apache.spark.mllib.regression.LabeledPoint]] used as training data
@@ -992,17 +1248,23 @@ object DecisionTree extends Serializable with Logging {
     val maxBins = strategy.maxBins
     val numBins = if (maxBins <= count) maxBins else count.toInt
     logDebug("numBins = " + numBins)
+    val isMulticlassClassification = strategy.isMulticlassClassification
+    logDebug("isMulticlassClassification = " + isMulticlassClassification)
+
     /*
-     * TODO: Add a require statement ensuring #bins is always greater than the categories.
+     * Ensure #bins is always greater than the number of categories. For multiclass
+     * classification, #bins should be greater than 2^(maxCategories - 1) - 1 for all
+     * categorical splits to be considered.
      * It's a limitation of the current implementation but a reasonable trade-off since features
      * with a large number of categories get favored over continuous features.
     */
    if (strategy.categoricalFeaturesInfo.size > 0) {
      val maxCategoriesForFeatures = strategy.categoricalFeaturesInfo.maxBy(_._2)._2
-      require(numBins >= maxCategoriesForFeatures)
+      require(numBins > maxCategoriesForFeatures, "numBins should be greater than max categories " +
+        "in categorical features")
    }
 
+    // Calculate the number of samples for approximate quantile calculation.
     val requiredSamples = numBins * numBins
     val fraction = if (requiredSamples < count) requiredSamples.toDouble / count else 1.0
@@ -1036,48 +1298,93 @@ object DecisionTree extends Serializable with Logging {
             val split = new Split(featureIndex, featureSamples(sampleIndex), Continuous, List())
             splits(featureIndex)(index) = split
           }
-        } else {
-          val maxFeatureValue = strategy.categoricalFeaturesInfo(featureIndex)
-          require(maxFeatureValue < numBins, "number of categories should be less than number " +
-            "of bins")
-
-          // For categorical variables, each bin is a category. The bins are sorted and they
-          // are ordered by calculating the centroid of their corresponding labels.
-          val centroidForCategories =
-            sampledInput.map(lp => (lp.features(featureIndex),lp.label))
-              .groupBy(_._1)
-              .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
-
-          // Check for missing categorical variables and putting them last in the sorted list.
-          val fullCentroidForCategories = scala.collection.mutable.Map[Double,Double]()
-          for (i <- 0 until maxFeatureValue) {
-            if (centroidForCategories.contains(i)) {
-              fullCentroidForCategories(i) = centroidForCategories(i)
-            } else {
-              fullCentroidForCategories(i) = Double.MaxValue
-            }
-          }
-
-          // bins sorted by centroids
-          val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
-
-          logDebug("centroid for categorical variable = " + categoriesSortedByCentroid)
-
-          var categoriesForSplit = List[Double]()
-          categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
-            case ((key, value), index) =>
-              categoriesForSplit = key :: categoriesForSplit
-              splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue, Categorical,
-                categoriesForSplit)
-              bins(featureIndex)(index) = {
-                if (index == 0) {
-                  new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
-                    splits(featureIndex)(0), Categorical, key)
-                } else {
-                  new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
-                    Categorical, key)
-                }
-              }
-          }
+        } else { // Categorical feature
+          val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+          val isSpaceSufficientForAllCategoricalSplits =
+            numBins > math.pow(2, featureCategories.toInt - 1) - 1
+
+          // Use a different bin/split calculation strategy for categorical features in
+          // multiclass classification that satisfy the space constraint.
+          if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+            // 2^(featureCategories - 1) - 1 combinations
+            var index = 0
+            while (index < math.pow(2.0, featureCategories - 1).toInt - 1) {
+              val categories: List[Double] =
+                extractMultiClassCategories(index + 1, featureCategories)
+              splits(featureIndex)(index) =
+                new Split(featureIndex, Double.MinValue, Categorical, categories)
+              bins(featureIndex)(index) = {
+                if (index == 0) {
+                  new Bin(
+                    new DummyCategoricalSplit(featureIndex, Categorical),
+                    splits(featureIndex)(0),
+                    Categorical,
+                    Double.MinValue)
+                } else {
+                  new Bin(
+                    splits(featureIndex)(index - 1),
+                    splits(featureIndex)(index),
+                    Categorical,
+                    Double.MinValue)
+                }
+              }
+              index += 1
+            }
+          } else {
+
+            val centroidForCategories = {
+              if (isMulticlassClassification) {
+                // For categorical variables in multiclass classification, each bin is a
+                // category. The bins are ordered by the impurity of their corresponding
+                // label counts.
+                sampledInput.map(lp => (lp.features(featureIndex), lp.label))
+                  .groupBy(_._1)
+                  .mapValues(x => x.groupBy(_._2).mapValues(x => x.size.toDouble))
+                  .map(x => (x._1, x._2.values.toArray))
+                  .map(x => (x._1, strategy.impurity.calculate(x._2, x._2.sum)))
+              } else { // regression or binary classification
+                // For categorical variables in regression and binary classification,
+                // each bin is a category. The bins are ordered by the centroid of their
+                // corresponding labels.
+                sampledInput.map(lp => (lp.features(featureIndex), lp.label))
+                  .groupBy(_._1)
+                  .mapValues(x => x.map(_._2).sum / x.map(_._1).length)
+              }
+            }
+
+            logDebug("centroid for categories = " + centroidForCategories.mkString(","))
+
+            // Check for missing categorical variables and put them last in the sorted list.
+            val fullCentroidForCategories = scala.collection.mutable.Map[Double, Double]()
+            for (i <- 0 until featureCategories) {
+              if (centroidForCategories.contains(i)) {
+                fullCentroidForCategories(i) = centroidForCategories(i)
+              } else {
+                fullCentroidForCategories(i) = Double.MaxValue
+              }
+            }
+
+            // bins sorted by centroids
+            val categoriesSortedByCentroid = fullCentroidForCategories.toList.sortBy(_._2)
+
+            logDebug("centroid for categorical variable = " + categoriesSortedByCentroid)
+
+            var categoriesForSplit = List[Double]()
+            categoriesSortedByCentroid.iterator.zipWithIndex.foreach {
+              case ((key, value), index) =>
+                categoriesForSplit = key :: categoriesForSplit
+                splits(featureIndex)(index) = new Split(featureIndex, Double.MinValue,
+                  Categorical, categoriesForSplit)
+                bins(featureIndex)(index) = {
+                  if (index == 0) {
+                    new Bin(new DummyCategoricalSplit(featureIndex, Categorical),
+                      splits(featureIndex)(0), Categorical, key)
+                  } else {
+                    new Bin(splits(featureIndex)(index - 1), splits(featureIndex)(index),
+                      Categorical, key)
+                  }
+                }
+            }
+          }
         }
         featureIndex += 1
@@ -1107,4 +1414,29 @@ object DecisionTree extends Serializable with Logging {
       throw new UnsupportedOperationException("approximate histogram not supported yet.")
     }
   }
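The bit decoding added below is easy to get backwards, so here is a standalone restatement (not part of the patch) checked against the example used in the new test suite entry:

```scala
// Decode the set bit positions of `input`, scanning the `maxFeatureValue`
// rightmost bits, exactly as extractMultiClassCategories does.
object ExtractCategoriesExample {
  def extract(input: Int, maxFeatureValue: Int): List[Double] = {
    var categories = List[Double]()
    var bitShiftedInput = input
    var j = 0
    while (j < maxFeatureValue) {
      if (bitShiftedInput % 2 != 0) {
        categories = j.toDouble :: categories
      }
      bitShiftedInput = bitShiftedInput >> 1
      j += 1
    }
    categories
  }

  def main(args: Array[String]): Unit = {
    // 13 is binary 1101, so bits 0, 2 and 3 are set.
    assert(extract(13, 10) == List(3.0, 2.0, 0.0))
    println(extract(13, 10)) // prints: List(3.0, 2.0, 0.0)
  }
}
```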
+
+  /**
+   * Extracts the list of eligible categories given an index. It extracts the
+   * positions of ones in the binary representation of the input. If the binary
+   * representation of a number is 01101 (13), the output list should be (3.0, 2.0,
+   * 0.0). maxFeatureValue indicates the number of rightmost bits that will be tested for ones.
+   */
+  private[tree] def extractMultiClassCategories(
+      input: Int,
+      maxFeatureValue: Int): List[Double] = {
+    var categories = List[Double]()
+    var j = 0
+    var bitShiftedInput = input
+    while (j < maxFeatureValue) {
+      if (bitShiftedInput % 2 != 0) {
+        // Update the list of categories.
+        categories = j.toDouble :: categories
+      }
+      // Right shift by one.
+      bitShiftedInput = bitShiftedInput >> 1
+      j += 1
+    }
+    categories
+  }
+
 }
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 1b505fd76eb751f54338101c587bd3ee74424a17..7c027ac2fda6ba46eab5b90081a25304e980560b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -28,6 +28,8 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 * @param algo classification or regression
 * @param impurity criterion used for information gain calculation
 * @param maxDepth maximum depth of the tree
+ * @param numClassesForClassification number of classes for classification. The default value
+ *                                    of 2 leads to binary classification.
 * @param maxBins maximum number of bins used for splitting features
 * @param quantileCalculationStrategy algorithm for calculating quantiles
 * @param categoricalFeaturesInfo A map storing information about the categorical variables and the
@@ -44,7 +46,15 @@ class Strategy (
     val algo: Algo,
     val impurity: Impurity,
     val maxDepth: Int,
+    val numClassesForClassification: Int = 2,
     val maxBins: Int = 100,
     val quantileCalculationStrategy: QuantileStrategy = Sort,
     val categoricalFeaturesInfo: Map[Int, Int] = Map[Int, Int](),
-    val maxMemoryInMB: Int = 128) extends Serializable
+    val maxMemoryInMB: Int = 128) extends Serializable {
+
+  require(numClassesForClassification >= 2)
+  val isMulticlassClassification = numClassesForClassification > 2
+  val isMulticlassWithCategoricalFeatures =
+    isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 60f43e9278d2ae475eca153239d05e7dfe14ddcc..a0e2d91762782804fbe5dc8d099edb59fc5c8c51 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -31,23 +31,35 @@ object Entropy extends Impurity {
 
   /**
    * :: DeveloperApi ::
-   * entropy calculation
-   * @param c0 count of instances with label 0
-   * @param c1 count of instances with label 1
-   * @return entropy value
+   * information calculation for multiclass classification
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
+   * @return information value
    */
   @DeveloperApi
-  override def calculate(c0: Double, c1: Double): Double = {
-    if (c0 == 0 || c1 == 0) {
-      0
-    } else {
-      val total = c0 + c1
-      val f0 = c0 / total
-      val f1 = c1 / total
-      -(f0 * log2(f0)) - (f1 * log2(f1))
+  override def calculate(counts: Array[Double], totalCount: Double): Double = {
+    val numClasses = counts.length
+    var impurity = 0.0
+    var classIndex = 0
+    while (classIndex < numClasses) {
+      val classCount = counts(classIndex)
+      if (classCount != 0) {
+        val freq = classCount / totalCount
+        impurity -= freq * log2(freq)
+      }
+      classIndex += 1
     }
+    impurity
   }
 
+  /**
+   * :: DeveloperApi ::
+   * variance calculation (not supported for Entropy)
+   * @param count number of instances
+   * @param sum sum of labels
+   * @param sumSquares summation of squares of the labels
+   */
+  @DeveloperApi
   override def calculate(count: Double, sum: Double, sumSquares: Double): Double =
     throw new UnsupportedOperationException("Entropy.calculate")
 }
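A quick numeric check of the entropy calculation above and the analogous Gini calculation below (a standalone sketch, not part of the patch; the counts are invented):

```scala
// Gini and entropy for label counts (4, 4, 2) over 10 instances.
object ImpurityCheckExample {
  def main(args: Array[String]): Unit = {
    val counts = Array(4.0, 4.0, 2.0)
    val total = counts.sum
    // Gini: 1 - sum(f_i^2) = 1 - (0.16 + 0.16 + 0.04) = 0.64
    val gini = 1.0 - counts.map(c => (c / total) * (c / total)).sum
    // Entropy: -sum(f_i * log2(f_i)) is approximately 1.522
    val entropy = counts.map(_ / total).filter(_ > 0)
      .map(f => -f * math.log(f) / math.log(2)).sum
    println(s"gini = $gini, entropy = $entropy")
  }
}
```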
UnsupportedOperationException("Entropy.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index c51d76d9b4c5be6ba2ab4ffc48a240f2479c02a0..48144b5e6d1e44951dd6e06cffc3d746d838e47b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -30,23 +30,32 @@ object Gini extends Impurity { /** * :: DeveloperApi :: - * Gini coefficient calculation - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 - * @return Gini coefficient value + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return information value */ @DeveloperApi - override def calculate(c0: Double, c1: Double): Double = { - if (c0 == 0 || c1 == 0) { - 0 - } else { - val total = c0 + c1 - val f0 = c0 / total - val f1 = c1 / total - 1 - f0 * f0 - f1 * f1 + override def calculate(counts: Array[Double], totalCount: Double): Double = { + val numClasses = counts.length + var impurity = 1.0 + var classIndex = 0 + while (classIndex < numClasses) { + val freq = counts(classIndex) / totalCount + impurity -= freq * freq + classIndex += 1 } + impurity } + /** + * :: DeveloperApi :: + * variance calculation + * @param count number of instances + * @param sum sum of labels + * @param sumSquares summation of squares of the labels + */ + @DeveloperApi override def calculate(count: Double, sum: Double, sumSquares: Double): Double = throw new UnsupportedOperationException("Gini.calculate") } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 8eab247cf0932cf723587ca2bd75574942db0da2..7b2a9320cc21dc40b2871d4209e183d61c51bb03 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -28,13 +28,13 @@ trait Impurity extends Serializable { /** * :: DeveloperApi :: - * information calculation for binary classification - * @param c0 count of instances with label 0 - * @param c1 count of instances with label 1 + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels * @return information value */ @DeveloperApi - def calculate(c0 : Double, c1 : Double): Double + def calculate(counts: Array[Double], totalCount: Double): Double /** * :: DeveloperApi :: diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 47d07122af30f9a89e402809cdd7150313db4b48..97149a99ead59aa04457a02307b6569c1d41511f 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental} */ @Experimental object Variance extends Impurity { - override def calculate(c0: Double, c1: Double): Double = + + /** + * :: DeveloperApi :: + * information calculation for multiclass classification + * @param counts Array[Double] with counts for each label + * @param totalCount sum of counts for all labels + * @return 
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 47d07122af30f9a89e402809cdd7150313db4b48..97149a99ead59aa04457a02307b6569c1d41511f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -25,7 +25,16 @@ import org.apache.spark.annotation.{DeveloperApi, Experimental}
 */
 @Experimental
 object Variance extends Impurity {
-  override def calculate(c0: Double, c1: Double): Double =
+
+  /**
+   * :: DeveloperApi ::
+   * information calculation for multiclass classification (not supported for Variance)
+   * @param counts Array[Double] with counts for each label
+   * @param totalCount sum of counts for all labels
+   * @return information value
+   */
+  @DeveloperApi
+  override def calculate(counts: Array[Double], totalCount: Double): Double =
     throw new UnsupportedOperationException("Variance.calculate")
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
index 2d71e1e366069afc81da525ff600015b92dbd477..c89c1e371a40e16008e07ffbc8cc7d11a16ca2d6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.FeatureType._
 * @param highSplit signifying the upper threshold for the continuous feature to be
 *                  accepted in the bin
 * @param featureType type of feature -- categorical or continuous
- * @param category categorical feature value accepted in the bin
+ * @param category categorical feature value accepted in the bin for binary classification
 */
 private[tree] case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType,
     category: Double)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index cc8a24cce961459a021b394a5675043366a5f131..fb12298e0f5d3809e217a0aef0715e0d86f63e8a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -27,6 +27,7 @@ import org.apache.spark.annotation.DeveloperApi
 * @param leftImpurity left node impurity
 * @param rightImpurity right node impurity
 * @param predict predicted value
+ * @param prob probability of the predicted label (classification only)
 */
 @DeveloperApi
 class InformationGainStats(
@@ -34,10 +35,11 @@ class InformationGainStats(
     val impurity: Double,
     val leftImpurity: Double,
     val rightImpurity: Double,
-    val predict: Double) extends Serializable {
+    val predict: Double,
+    val prob: Double = 0.0) extends Serializable {
 
   override def toString = {
-    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f"
-      .format(gain, impurity, leftImpurity, rightImpurity, predict)
+    "gain = %f, impurity = %f, left impurity = %f, right impurity = %f, predict = %f, prob = %f"
+      .format(gain, impurity, leftImpurity, rightImpurity, predict, prob)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index bcb11876b8f4fb401ac0834dbb5d307c7f5b97e2..5961a618c59d9a98cedd9b2faaeb35cd9db501e0 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree
 
 import org.scalatest.FunSuite
 
-import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
 import org.apache.spark.mllib.tree.model.Filter
 import org.apache.spark.mllib.tree.model.Split
@@ -28,6 +27,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.FeatureType._
 import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.regression.LabeledPoint
 
@@ -35,7 +35,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(bins.length === 2) @@ -51,6 +51,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 2, 1-> 2)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) @@ -130,8 +131,9 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { Classification, Gini, maxDepth = 3, + numClassesForClassification = 2, maxBins = 100, - categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) // Check splits. @@ -231,6 +233,162 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bins(1)(3) === null) } + test("extract categories from a number for multiclass classification") { + val l = DecisionTree.extractMultiClassCategories(13, 10) + assert(l.length === 3) + assert(List(3.0, 2.0, 0.0).toSeq == l.toSeq) + } + + test("split and bin calculations for unordered categorical variables with multiclass " + + "classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPoints() + assert(arr.length === 1000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + numClassesForClassification = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // Expecting 2^2 - 1 = 3 bins/splits + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(0.0)) + assert(splits(1)(0).feature === 1) + assert(splits(1)(0).threshold === Double.MinValue) + assert(splits(1)(0).featureType === Categorical) + assert(splits(1)(0).categories.length === 1) + assert(splits(1)(0).categories.contains(0.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 1) + assert(splits(0)(1).categories.contains(1.0)) + assert(splits(1)(1).feature === 1) + assert(splits(1)(1).threshold === Double.MinValue) + assert(splits(1)(1).featureType === Categorical) + assert(splits(1)(1).categories.length === 1) + assert(splits(1)(1).categories.contains(1.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 2) + assert(splits(0)(2).categories.contains(0.0)) + assert(splits(0)(2).categories.contains(1.0)) + assert(splits(1)(2).feature === 1) + assert(splits(1)(2).threshold === Double.MinValue) + assert(splits(1)(2).featureType === Categorical) + assert(splits(1)(2).categories.length === 2) + assert(splits(1)(2).categories.contains(0.0)) + assert(splits(1)(2).categories.contains(1.0)) + + assert(splits(0)(3) === null) + assert(splits(1)(3) === null) + + + // Check bins. 
+ + assert(bins(0)(0).category === Double.MinValue) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(0.0)) + assert(bins(1)(0).category === Double.MinValue) + assert(bins(1)(0).lowSplit.categories.length === 0) + assert(bins(1)(0).highSplit.categories.length === 1) + assert(bins(1)(0).highSplit.categories.contains(0.0)) + + assert(bins(0)(1).category === Double.MinValue) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).lowSplit.categories.contains(0.0)) + assert(bins(0)(1).highSplit.categories.length === 1) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(1)(1).category === Double.MinValue) + assert(bins(1)(1).lowSplit.categories.length === 1) + assert(bins(1)(1).lowSplit.categories.contains(0.0)) + assert(bins(1)(1).highSplit.categories.length === 1) + assert(bins(1)(1).highSplit.categories.contains(1.0)) + + assert(bins(0)(2).category === Double.MinValue) + assert(bins(0)(2).lowSplit.categories.length === 1) + assert(bins(0)(2).lowSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.length === 2) + assert(bins(0)(2).highSplit.categories.contains(1.0)) + assert(bins(0)(2).highSplit.categories.contains(0.0)) + assert(bins(1)(2).category === Double.MinValue) + assert(bins(1)(2).lowSplit.categories.length === 1) + assert(bins(1)(2).lowSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.length === 2) + assert(bins(1)(2).highSplit.categories.contains(1.0)) + assert(bins(1)(2).highSplit.categories.contains(0.0)) + + assert(bins(0)(3) === null) + assert(bins(1)(3) === null) + + } + + test("split and bin calculations for ordered categorical variables with multiclass " + + "classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + assert(arr.length === 3000) + val rdd = sc.parallelize(arr) + val strategy = new Strategy( + Classification, + Gini, + maxDepth = 3, + numClassesForClassification = 100, + maxBins = 100, + categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) + + // 2^(10-1) - 1 = 511 > maxBins = 100, so categorical variables will be ordered + + assert(splits(0)(0).feature === 0) + assert(splits(0)(0).threshold === Double.MinValue) + assert(splits(0)(0).featureType === Categorical) + assert(splits(0)(0).categories.length === 1) + assert(splits(0)(0).categories.contains(1.0)) + + assert(splits(0)(1).feature === 0) + assert(splits(0)(1).threshold === Double.MinValue) + assert(splits(0)(1).featureType === Categorical) + assert(splits(0)(1).categories.length === 2) + assert(splits(0)(1).categories.contains(2.0)) + + assert(splits(0)(2).feature === 0) + assert(splits(0)(2).threshold === Double.MinValue) + assert(splits(0)(2).featureType === Categorical) + assert(splits(0)(2).categories.length === 3) + assert(splits(0)(2).categories.contains(2.0)) + assert(splits(0)(2).categories.contains(1.0)) + + assert(splits(0)(10) === null) + assert(splits(1)(10) === null) + + + // Check bins. 
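The two split/bin tests above differ only in how many values the categorical features take, which flips the ordered/unordered decision (both set numClassesForClassification = 100 simply to force the multiclass code path). A hedged sketch of that rule, assembled from the tests and the accompanying doc change rather than copied from Strategy:

    // A categorical feature with M values admits 2^(M-1) - 1 unordered split
    // candidates; enumerating them only pays off when they fit in the bin budget.
    // Hypothetical helper for illustration only.
    def useUnorderedSplits(numCategories: Int, numClasses: Int, maxBins: Int): Boolean = {
      val numUnorderedSplits = math.pow(2, numCategories - 1).toInt - 1
      numClasses > 2 && numUnorderedSplits <= maxBins
    }
    // M = 3:  2^(3-1) - 1 = 3    <= 100 bins -> unordered (previous test)
    // M = 10: 2^(10-1) - 1 = 511 >  100 bins -> ordered (this test)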
+ + assert(bins(0)(0).category === 1.0) + assert(bins(0)(0).lowSplit.categories.length === 0) + assert(bins(0)(0).highSplit.categories.length === 1) + assert(bins(0)(0).highSplit.categories.contains(1.0)) + assert(bins(0)(1).category === 2.0) + assert(bins(0)(1).lowSplit.categories.length === 1) + assert(bins(0)(1).highSplit.categories.length === 2) + assert(bins(0)(1).highSplit.categories.contains(1.0)) + assert(bins(0)(1).highSplit.categories.contains(2.0)) + + assert(bins(0)(10) === null) + + } + + + test("classification stump with all categorical variables") { val arr = DecisionTreeSuite.generateCategoricalDataPoints() assert(arr.length === 1000) @@ -238,6 +396,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val strategy = new Strategy( Classification, Gini, + numClassesForClassification = 2, maxDepth = 3, maxBins = 100, categoricalFeaturesInfo = Map(0 -> 3, 1-> 3)) @@ -253,8 +412,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict > 0.5) - assert(stats.predict < 0.7) + assert(stats.predict === 1) + assert(stats.prob === 0.6) assert(stats.impurity > 0.2) } @@ -280,8 +439,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val stats = bestSplits(0)._2 assert(stats.gain > 0) - assert(stats.predict > 0.5) - assert(stats.predict < 0.7) + assert(stats.predict === 0.6) assert(stats.impurity > 0.2) } @@ -289,7 +447,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -312,7 +470,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Gini, 3, 100) + val strategy = new Strategy(Classification, Gini, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -336,7 +494,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -360,7 +518,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -380,11 +538,11 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { assert(bestSplits(0)._2.predict === 1) } - test("test second level node building with/without groups") { + test("second level node building with/without groups") { 
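The stats.predict === 1 and stats.prob === 0.6 assertions in the hunk above follow directly from the node's label histogram: 600 of the 1000 generated points carry label 1. A sketch of how a prediction and its probability fall out of per-class counts (hypothetical helper; the real computation happens inside findBestSplits):

    // Majority label at a node and its relative frequency.
    def predictWithProb(labelCounts: Array[Double]): (Double, Double) = {
      val total = labelCounts.sum
      val best = labelCounts.indices.maxBy(labelCounts(_))
      (best.toDouble, labelCounts(best) / total)
    }
    // predictWithProb(Array(400.0, 600.0)) returns (1.0, 0.6)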
val arr = DecisionTreeSuite.generateOrderedLabeledPoints() assert(arr.length === 1000) val rdd = sc.parallelize(arr) - val strategy = new Strategy(Classification, Entropy, 3, 100) + val strategy = new Strategy(Classification, Entropy, 3, 2, 100) val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy) assert(splits.length === 2) assert(splits(0).length === 99) @@ -426,6 +584,82 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext { } + test("stump with categorical variables for multiclass classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1.0)) + assert(bestSplit.featureType === Categorical) + } + + test("stump with continuous variables for multiclass classification") { + val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + + assert(bestSplit.feature === 1) + assert(bestSplit.featureType === Continuous) + assert(bestSplit.threshold > 1980) + assert(bestSplit.threshold < 2020) + + } + + test("stump with continuous + categorical variables for multiclass classification") { + val arr = DecisionTreeSuite.generateContinuousDataPointsForMulticlass() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + val bestSplit = bestSplits(0)._1 + + assert(bestSplit.feature === 1) + assert(bestSplit.featureType === Continuous) + assert(bestSplit.threshold > 1980) + assert(bestSplit.threshold < 2020) + } + + test("stump with categorical variables for ordered multiclass classification") { + val arr = DecisionTreeSuite.generateCategoricalDataPointsForMulticlassForOrderedFeatures() + val input = sc.parallelize(arr) + val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5, + numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10)) + assert(strategy.isMulticlassClassification) + val (splits, bins) = DecisionTree.findSplitsBins(input, strategy) + val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0, + Array[List[Filter]](), splits, bins, 10) + + assert(bestSplits.length === 1) + 
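A note on the loose bounds assert(bestSplit.threshold > 1980) / assert(bestSplit.threshold < 2020) in the continuous tests above: the generated label flips from 2.0 to 1.0 at i = 2000, but candidate thresholds are taken from bin boundaries, and with the default bin budget of 100 over 3000 points the boundaries presumably sit roughly 30 apart, so the recovered split can only be expected to land near 2000, not exactly on it.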
val bestSplit = bestSplits(0)._1 + assert(bestSplit.feature === 0) + assert(bestSplit.categories.length === 1) + assert(bestSplit.categories.contains(1.0)) + assert(bestSplit.featureType === Categorical) + } + + } object DecisionTreeSuite { @@ -473,4 +707,47 @@ object DecisionTreeSuite { } arr } + + def generateCategoricalDataPointsForMulticlass(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 1000) { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } else if (i < 2000) { + arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + } else { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } + } + arr + } + + def generateContinuousDataPointsForMulticlass(): Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 2000) { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, i)) + } else { + arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, i)) + } + } + arr + } + + def generateCategoricalDataPointsForMulticlassForOrderedFeatures(): + Array[LabeledPoint] = { + val arr = new Array[LabeledPoint](3000) + for (i <- 0 until 3000) { + if (i < 1000) { + arr(i) = new LabeledPoint(2.0, Vectors.dense(2.0, 2.0)) + } else if (i < 2000) { + arr(i) = new LabeledPoint(1.0, Vectors.dense(1.0, 2.0)) + } else { + arr(i) = new LabeledPoint(1.0, Vectors.dense(2.0, 2.0)) + } + } + arr + } + + } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 3487f7c5c12551f680121456b54ddb91e3a2a9c4..e0f433b26f7ffcabc86fd92613619e07f7d1313c 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -82,7 +82,15 @@ object MimaExcludes { MimaBuild.excludeSparkClass("util.SerializableHyperLogLog") ++ MimaBuild.excludeSparkClass("storage.Values") ++ MimaBuild.excludeSparkClass("storage.Entry") ++ - MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") + MimaBuild.excludeSparkClass("storage.MemoryStore$Entry") ++ + Seq( + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Gini.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Entropy.calculate"), + ProblemFilters.exclude[IncompatibleMethTypeProblem]( + "org.apache.spark.mllib.tree.impurity.Variance.calculate") + ) case v if v.startsWith("1.0") => Seq( MimaBuild.excludeSparkPackage("api.java"),
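End to end, the new multiclass support is driven the same way the tests drive it: build a Strategy with numClassesForClassification and hand it to DecisionTree.train. A minimal usage sketch, assuming an active SparkContext sc and a tiny illustrative dataset (the data itself is made up for the example):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // Three classes; feature 0 is categorical with three values, feature 1 continuous.
    val data = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 2.0)),
      LabeledPoint(2.0, Vectors.dense(2.0, 3.0))))
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
      numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3))
    val model = DecisionTree.train(data, strategy)
    val prediction = model.predict(Vectors.dense(1.0, 2.0))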